[
  {
    "path": ".devcontainer/.gitignore",
    "content": "!devcontainer.json"
  },
  {
    "path": ".devcontainer/Dockerfile",
    "content": "FROM mcr.microsoft.com/devcontainers/python:3.10\n\nRUN python -m pip install --no-cache-dir --upgrade pip poetry \\\n    && pip install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cu118 'torch>=2.0.0' \\\n    && pip install --no-cache-dir --find-links https://data.pyg.org/whl/torch-2.0.0+cu118.html torch_geometric pyg_lib torch_scatter torch_sparse\n\nENV POETRY_VIRTUALENVS_CREATE=false\nCOPY pyproject.toml poetry.lock* /tmp/poetry/\nRUN poetry -C /tmp/poetry --no-cache install --no-root --no-directory"
  },
  {
    "path": ".devcontainer/devcontainer.json",
    "content": "// For format details, see https://aka.ms/devcontainer.json. For config options, see the\n// README at: https://github.com/devcontainers/templates/tree/main/src/python\n{\n\t\"name\": \"py-tgb\",\n\t\"build\": {\n\t\t\"dockerfile\": \"Dockerfile\",\n\t\t\"context\": \"..\"\n\t},\n\t\"customizations\": {\n\t\t\"vscode\": {\n\t\t\t\"extensions\": [\n\t\t\t\t\"editorconfig.editorconfig\",\n\t\t\t\t\"github.vscode-pull-request-github\",\n\t\t\t\t\"ms-azuretools.vscode-docker\",\n\t\t\t\t\"ms-python.python\",\n\t\t\t\t\"ms-python.vscode-pylance\",\n\t\t\t\t\"ms-python.pylint\",\n\t\t\t\t\"ms-python.isort\",\n\t\t\t\t\"ms-python.flake8\",\n\t\t\t\t\"ms-python.black-formatter\",\n\t\t\t\t\"ms-vsliveshare.vsliveshare\",\n\t\t\t\t\"ryanluker.vscode-coverage-gutters\",\n\t\t\t\t\"bungcip.better-toml\",\n\t\t\t\t\"GitHub.copilot\",\n\t\t\t\t\"redhat.vscode-yaml\"\n\t\t\t],\n\t\t\t\"settings\": {\n\t\t\t\t\"python.defaultInterpreterPath\": \"/usr/local/bin/python\",\n\t\t\t\t\"black-formatter.path\": [\n\t\t\t\t\t\"/usr/local/py-utils/bin/black\"\n\t\t\t\t],\n\t\t\t\t\"pylint.path\": [\n\t\t\t\t\t\"/usr/local/py-utils/bin/pylint\"\n\t\t\t\t],\n\t\t\t\t\"flake8.path\": [\n\t\t\t\t\t\"/usr/local/py-utils/bin/flake8\"\n\t\t\t\t],\n\t\t\t\t\"isort.path\": [\n\t\t\t\t\t\"/usr/local/py-utils/bin/isort\"\n\t\t\t\t]\n\t\t\t}\n\t\t}\n\t},\n\t\"features\": {\n\t\t\"ghcr.io/devcontainers-contrib/features/act:1\": {},\n\t\t\"ghcr.io/stuartleeks/dev-container-features/shell-history:0\": {},\n\t\t\"ghcr.io/devcontainers/features/common-utils:2\": {}\n\t},\n\t\"postCreateCommand\": \"poetry --no-cache install --only main\"\n}\n"
  },
  {
    "path": ".github/workflows/mkdocs.yaml",
    "content": "name: mkdocs\non:\n  push:\n    # branches:\n    #   - master \n    #   - main\n    tags:\n      - \"v*.*.*\"\npermissions:\n  contents: write\njobs:\n  deploy:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v3\n      - uses: actions/setup-python@v4\n        with:\n          python-version: 3.x\n      - uses: actions/cache@v3\n        with:\n          key: mkdocs-material-${{ github.ref }} \n          path: .cache\n          restore-keys: |\n            mkdocs-material-\n      - run: pip install mkdocs-material mkdocstrings-python mkdocs-jupyter\n      - run: mkdocs gh-deploy --force\n"
  },
  {
    "path": ".github/workflows/pypi.yaml",
    "content": "# https://github.com/JRubics/poetry-publish\n\nname: Publish to PyPI\non:\n  push:\n    tags:\n      - \"v*.*.*\"\n\njobs:\n  publish:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v3\n      - uses: actions/setup-python@v4\n      - name: Build and publish to pypi\n        uses: JRubics/poetry-publish@v1.17\n        with:\n          pypi_token: ${{ secrets.PYPI_API_TOKEN }}"
  },
  {
    "path": ".gitignore",
    "content": "!requirements*.txt\nget_croissant.py\n#dataset\nstats_figures/\nfigs/\n*.xz\n*.dict\n*.tab\n*.npz\n*.xz\n*.parquet\n*.gz\n*.tar\n*.pdf\n*.csv\n*.zip\n*.json\n*.npy\n*.pt\n*.out\n*.pkl\n*.txt\n*.attr\n*.edge\n.DS_Store\nstore_files/\n# Byte-compiled / optimized / DLL files\n__pycache__/\nraw/\nbooks/\nelectronics/\nsoftware/\n*.py[cod]\n*$py.class\nsaved_models/\ndump/\nsaved_results/\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n__pycache__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\ncc_env.sh\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 TGB Team\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "<!-- # TGB -->\n![TGB logo](imgs/logo.png)\n\n**Temporal Graph Benchmark for Machine Learning on Temporal Graphs** (NeurIPS 2023 Datasets and Benchmarks Track)\n<h3>\n  <a href=\"https://proceedings.neurips.cc/paper_files/paper/2023/hash/066b98e63313162f6562b35962671288-Abstract-Datasets_and_Benchmarks.html\"><img src=\"https://img.shields.io/badge/Paper-link-important\"></a>\n\t<a href=\"https://arxiv.org/abs/2307.01026\"><img src=\"https://img.shields.io/badge/arXiv-pdf-yellowgreen\"></a>\n\t<a href=\"https://pypi.org/project/py-tgb/\"><img src=\"https://img.shields.io/pypi/v/py-tgb.svg?color=brightgreen\"></a>\n\t<a href=\"https://tgb.complexdatalab.com/\"><img src=\"https://img.shields.io/badge/website-blue\"></a>\n\t<a href=\"https://docs.tgb.complexdatalab.com/\"><img src=\"https://img.shields.io/badge/docs-orange\"></a>\n</h3> \n\n**TGB 2.0: A Benchmark for Learning on Temporal Knowledge Graphs and Heterogeneous Graphs** (NeurIPS 2024 Datasets and Benchmarks Track)\n<h3>\n  <a href=\"https://openreview.net/forum?id=EADRzNJFn1#discussion\"><img src=\"https://img.shields.io/badge/Paper-link-important\"></a>\n  <a href=\"https://arxiv.org/abs/2406.09639v1\"><img src=\"https://img.shields.io/badge/arXiv-pdf-yellowgreen\"></a>\n  <a href=\"https://pypi.org/project/py-tgb/\"><img src=\"https://img.shields.io/pypi/v/py-tgb.svg?color=brightgreen\"></a>\n\t<a href=\"https://tgb.complexdatalab.com/\"><img src=\"https://img.shields.io/badge/website-blue\"></a>\n</h3> \n\n\nOverview of the Temporal Graph Benchmark (TGB) pipeline:\n- TGB includes large-scale and realistic datasets from 10 different domains with both dynamic link prediction and node property prediction tasks.\n- TGB automatically downloads datasets and processes them into `numpy`, `PyTorch` and `PyG compatible TemporalData` formats. \n- Novel TG models can be easily evaluated on TGB datasets via reproducible and realistic evaluation protocols. 
\n- TGB provides public and online leaderboards to track recent developments in the temporal graph learning domain.\n- Now TGB supports temporal homogeneous graphs, temporal knowledge graphs and temporal heterogeneous graph datasets.\n\n![TGB dataloading and evaluation pipeline](imgs/pipeline.png)\n\n**To submit to [TGB leaderboard](https://tgb.complexdatalab.com/), please fill in this [google form](https://forms.gle/SEsXvN1QHo9tSFwx9)**\n\n**See all version differences and update notes [here](https://tgb.complexdatalab.com/docs/update/)**\n\n### Announcements\n\n**Excited to announce that TGB 2.0 has been presented at NeurIPS 2024 Datasets and Benchmarks Track**\n\nSee our [camera ready version](https://openreview.net/forum?id=EADRzNJFn1#discussion) and [arXiv version](https://arxiv.org/abs/2406.09639) for details. Please [install locally](https://tgb.complexdatalab.com/docs/home/) first. We welcome your feedback and suggestions. \n\n\n**Excited to announce TGX, a companion package for analyzing temporal graphs in WSDM 2024 Demo Track**\n\nTGX supports all TGB datasets and provides numerous temporal graph visualization plots and statistics out of the box. See our paper: [Temporal Graph Analysis with TGX](https://arxiv.org/abs/2402.03651) and [TGX website](https://complexdata-mila.github.io/TGX/).\n\n<!-- **Excited to announce that TGB has been accepted to NeurIPS 2023 Datasets and Benchmarks Track**\n\nThanks to everyone for your help in improving TGB! We will continue to improve TGB based on your feedback and suggestions.  -->\n\n**Please update to version `2.2.0`**\n\n#### version `2.2.0`\nAdding license for TGB software (for dataset license please check TGB website). \nPrinted messages are no longer sent to stdout automatically; set `TGB_VERBOSE=True` in your shell to enable verbose printing.\nDefault option is to automatically download the datasets (rather than command line input as before).  
\n\n#### version `2.1.0`\nIncludes supplementary datasets `tgbl-lastfm` `tgbl-enron` `tgbl-uci` `tgbl-subreddit` for research purposes.\nFor more details, see the release notes.\n\n#### version `2.0.0`\n\nIncludes all new datasets from TGB 2.0 including temporal knowledge graphs and temporal heterogeneous graphs. \n\n<!-- \n#### version `0.9.2`\n\nUpdate the fix for `tgbl-flight` where now the unix timestamps are provided directly in the dataset. If you had issues with `tgbl-flight`, please remove `TGB/tgb/datasets/tgbl_flight` and redownload the dataset for a clean install. -->\n\n\n<!-- \n#### version `0.9.1`\n\nFixed an issue for `tgbl-flight` where the timestamp conversion is incorrect due to time zone differences. If you had issues with `tgbl-flight` before, please update your package. \n\n\n#### version `0.9.0`\n\nAdded the large `tgbn-token` dataset with 72 million edges to the `nodeproppred` dataset. \n\nFixed errors in `tgbl-coin` and `tgbl-flight` where a small set of edges are not sorted chronologically. Please update your dataset version for them to version 2 (will be prompted in terminal). -->\n\n\n### Pip Install\n\nYou can install TGB via [pip](https://pypi.org/project/py-tgb/). 
**Requires python >= 3.9**\n```\npip install py-tgb\n```\n\n### Links and Datasets\n\nThe project website can be found [here](https://tgb.complexdatalab.com/).\n\nThe API documentation can be found [here](https://shenyanghuang.github.io/TGB/).\n\nAll dataset download links can be found at [info.py](https://github.com/shenyangHuang/TGB/blob/main/tgb/utils/info.py).\n\nTGB dataloader will also automatically download the dataset as well as the negative samples for the link property prediction datasets.\n\nIf the website is inaccessible, please use [this link](https://tgb-website.pages.dev/) instead.\n\n\n### Running Example Methods\n\n- For the dynamic link property prediction task, see the [`examples/linkproppred`](https://github.com/shenyangHuang/TGB/tree/main/examples/linkproppred) folder for example scripts to run TGN, DyRep and EdgeBank on TGB datasets.\n- For the dynamic node property prediction task, see the [`examples/nodeproppred`](https://github.com/shenyangHuang/TGB/tree/main/examples/nodeproppred) folder for example scripts to run TGN, DyRep and EdgeBank on TGB datasets.\n- For all other baselines, please see the [TGB_Baselines](https://github.com/fpour/TGB_Baselines) repo.\n\n### Acknowledgments\nWe thank the [OGB](https://ogb.stanford.edu/) team for their support throughout this project and sharing their website code for the construction of [TGB website](https://tgb.complexdatalab.com/).\n\n\n### Software License\n\nThe code from this repo is licensed under the MIT License (see LICENSE).\n\n\n### Citation\n\nIf code or data from this repo is useful for your project, please consider citing our TGB and TGB 2.0 papers:\n```\n@article{huang2023temporal,\n  title={Temporal graph benchmark for machine learning on temporal graphs},\n  author={Huang, Shenyang and Poursafaei, Farimah and Danovitch, Jacob and Fey, Matthias and Hu, Weihua and Rossi, Emanuele and Leskovec, Jure and Bronstein, Michael and Rabusseau, Guillaume and Rabbany, Reihaneh},\n  journal={Advances in 
Neural Information Processing Systems},\n  year={2023}\n}\n```\n\n```\n@article{huang2024tgb2,\n  title={TGB 2.0: A Benchmark for Learning on Temporal Knowledge Graphs and Heterogeneous Graphs},\n  author={Gastinger, Julia and Huang, Shenyang and Galkin, Mikhail and Loghmani, Erfan and Parviz, Ali and Poursafaei, Farimah and Danovitch, Jacob and Rossi, Emanuele and Koutis, Ioannis and Stuckenschmidt, Heiner and      Rabbany, Reihaneh and Rabusseau, Guillaume},\n  journal={Advances in Neural Information Processing Systems},\n  year={2024}\n}\n```\n\n<!-- \n\n### Install dependency\nOur implementation works with python >= 3.9 and can be installed as follows\n\n1. set up virtual environment (conda should work as well)\n```\npython -m venv ~/tgb_env/\nsource ~/tgb_env/bin/activate\n```\n\n2. install external packages\n```\npip install pandas==1.5.3\npip install matplotlib==3.7.1\npip install clint==0.5.1\n```\n\ninstall Pytorch and PyG dependencies (needed to run the examples)\n```\npip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cu117\npip install torch_geometric==2.3.0\npip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu117.html\n```\n\n3. install local dependencies under root directory `/TGB`\n```\npip install -e .\n```\n\n\n### Instruction for tracking new documentation and running mkdocs locally\n\n1. first run the mkdocs server locally in your terminal \n```\nmkdocs serve\n```\n\n2. go to the local hosted web address similar to\n```\n[14:18:13] Browser connected: http://127.0.0.1:8000/\n```\n\nExample: to track documentation of a new hi.py file in tgb/edgeregression/hi.py\n\n\n3. create docs/api/tgb.hi.md and add the following\n```\n# `tgb.edgeregression`\n\n::: tgb.edgeregression.hi\n```\n\n4. 
edit mkdocs.yml \n```\nnav:\n  - Overview: index.md\n  - About: about.md\n  - API:\n\tother *.md files \n\t- tgb.edgeregression: api/tgb.hi.md\n```\n\n### Creating new branch ###\n```\ngit fetch origin\n\ngit checkout -b test origin/test\n```\n\n### dependencies for mkdocs (documentation)\n```\npip install mkdocs\npip install mkdocs-material\npip install mkdocstrings-python\npip install mkdocs-jupyter\npip install notebook\n```\n\n\n### full dependency list\nOur implementation works with python >= 3.9 and has the following dependencies\n```\npytorch == 2.0.0\ntorch-geometric == 2.3.0\ntorch-scatter==2.1.1\ntorch-sparse==0.6.17\ntorch-spline-conv==1.2.2\npandas==1.5.3\nclint==0.5.1\n``` -->\n"
  },
  {
    "path": "docs/about.md",
    "content": "# Temporal Graph Benchmark (TGB)\r\n![TGB logo](assets/logo.png)\r\n\r\n## Overview\r\n\r\nThe TGB repo provides an automated ML pipeline for learning on a diverse set of temporal graph datasets:\r\n\r\n- automatic download of datasets from url\r\n\r\n- processing the raw files into ML ready format\r\n\r\n- support datasets in `numpy`, `Pytorch` and `PyG TemporalData` formats\r\n\r\n- evaluation code for each dataset \r\n\r\n"
  },
  {
    "path": "docs/api/tgb.linkproppred.md",
    "content": "# `tgb.linkproppred`\r\n\r\n::: tgb.linkproppred.dataset\r\n::: tgb.linkproppred.dataset_pyg\r\n::: tgb.linkproppred.evaluate\r\n::: tgb.linkproppred.negative_sampler\r\n::: tgb.linkproppred.negative_generator\r\n::: tgb.linkproppred.tkg_negative_generator\r\n::: tgb.linkproppred.tkg_negative_sampler\r\n::: tgb.linkproppred.thg_negative_generator\r\n::: tgb.linkproppred.thg_negative_sampler\r\n"
  },
  {
    "path": "docs/api/tgb.nodeproppred.md",
    "content": "# `tgb.nodeproppred`\r\n\r\n::: tgb.nodeproppred.dataset\r\n::: tgb.nodeproppred.dataset_pyg\r\n::: tgb.nodeproppred.evaluate\r\n\r\n"
  },
  {
    "path": "docs/api/tgb.utils.md",
    "content": "# `tgb.utils`\r\n\r\n::: tgb.utils.pre_process\r\n::: tgb.utils.utils\r\n::: tgb.utils.info\r\n::: tgb.utils.stats"
  },
  {
    "path": "docs/index.md",
    "content": "# Welcome to Temporal Graph Benchmark\n![TGB logo](assets/logo.png)\n\n\n\n### Pip Install\n\nYou can install TGB via [pip](https://pypi.org/project/py-tgb/)\n```\npip install py-tgb\n```\n\n### Links and Datasets\n\nThe project website can be found [here](https://tgb.complexdatalab.com/).\n\nThe API documentations can be found [here](https://shenyanghuang.github.io/TGB/).\n\nall dataset download links can be found at [info.py](https://github.com/shenyangHuang/TGB/blob/main/tgb/utils/info.py)\n\nTGB dataloader will also automatically download the dataset as well as the negative samples for the link property prediction datasets.\n\n\n### Install dependency\nOur implementation works with python >= 3.9 and can be installed as follows\n\n1. set up virtual environment (conda should work as well)\n```\npython -m venv ~/tgb_env/\nsource ~/tgb_env/bin/activate\n```\n\n2. install external packages\n```\npip install pandas==1.5.3\npip install matplotlib==3.7.1\npip install clint==0.5.1\n```\n\ninstall Pytorch and PyG dependencies (needed to run the examples)\n```\npip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cu117\npip install torch_geometric==2.3.0\npip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu117.html\n```\n\n3. install local dependencies under root directory `/TGB`\n```\npip install -e .\n```\n\n\n### Instruction for tracking new documentation and running mkdocs locally\n\n1. first run the mkdocs server locally in your terminal \n```\nmkdocs serve\n```\n\n2. go to the local hosted web address similar to\n```\n[14:18:13] Browser connected: http://127.0.0.1:8000/\n```\n\nExample: to track documentation of a new hi.py file in tgb/edgeregression/hi.py\n\n\n3. create docs/api/tgb.hi.md and add the following\n```\n# `tgb.edgeregression`\n\n::: tgb.edgeregression.hi\n```\n\n4. 
edit mkdocs.yml \n```\nnav:\n  - Overview: index.md\n  - About: about.md\n  - API:\n\tother *.md files \n\t- tgb.edgeregression: api/tgb.hi.md\n```\n\n### Creating new branch ###\n```\ngit fetch origin\n\ngit checkout -b test origin/test\n```\n\n### dependencies for mkdocs (documentation)\n```\npip install mkdocs\npip install mkdocs-material\npip install mkdocstrings-python\npip install mkdocs-jupyter\npip install notebook\n```\n\n\n### full dependency list\nOur implementation works with python >= 3.9 and has the following dependencies\n```\npytorch == 2.0.0\ntorch-geometric == 2.3.0\ntorch-scatter==2.1.1\ntorch-sparse==0.6.17\ntorch-spline-conv==1.2.2\npandas==1.5.3\nclint==0.5.1\n```\n\n\n\n<!-- ## Code blocks\n\n`pip install tgb` -->\n\n\n\n<!-- \n\n### Plain codeblock\n\nA plain codeblock:\n\n```\nSome code here\ndef myfunction()\n// some comment\n```\n\n#### Code for a specific language\n\nSome more code with the `py` at the start:\n\n``` py\nimport tensorflow as tf\ndef whatever()\n```\n\n#### With a title\n\n``` py title=\"bubble_sort.py\"\ndef bubble_sort(items):\n    for i in range(len(items)):\n        for j in range(len(items) - 1 - i):\n            if items[j] > items[j + 1]:\n                items[j], items[j + 1] = items[j + 1], items[j]\n```\n\n#### With line numbers\n\n``` py linenums=\"1\"\ndef bubble_sort(items):\n    for i in range(len(items)):\n        for j in range(len(items) - 1 - i):\n            if items[j] > items[j + 1]:\n                items[j], items[j + 1] = items[j + 1], items[j]\n```\n\n#### Highlighting lines\n\n``` py hl_lines=\"2 3\"\ndef bubble_sort(items):\n    for i in range(len(items)):\n        for j in range(len(items) - 1 - i):\n            if items[j] > items[j + 1]:\n                items[j], items[j + 1] = items[j + 1], items[j]\n```\n\n## Icons and Emojs\n\n:smile: \n\n:fontawesome-regular-face-laugh-wink:\n\n:fontawesome-brands-twitter:{ .twitter }\n\n:octicons-heart-fill-24:{ .heart } -->"
  },
  {
    "path": "docs/tutorials/Edge_data_numpy.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d5e3f5a2\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Access edge data as numpy arrays\\n\",\n    \"\\n\",\n    \"This tutorial will show you how to access various datasets and their corresponding edgelists in `tgb`\\n\",\n    \"\\n\",\n    \"You can directly retrieve the edge data as `numpy` arrays, `PyG` and `Pytorch` dependencies are not necessary\\n\",\n    \"\\n\",\n    \"The logic is implemented in `dataset.py` under `tgb/linkproppred/` and `tgb/nodeproppred/` folders respectively\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"23f00c08\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from tgb.linkproppred.dataset import LinkPropPredDataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"60e52b7b\",\n   \"metadata\": {},\n   \"source\": [\n    \"specifying the name of the dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"48888070\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"name = \\\"tgbl-wiki\\\" \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"3511804a\",\n   \"metadata\": {},\n   \"source\": [\n    \"### process and loading the dataset\\n\",\n    \"\\n\",\n    \"if the dataset has been processed, it will be loaded from disc for fast access\\n\",\n    \"\\n\",\n    \"if the dataset has not been downloaded, it will be processed automatically\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"8486fa82\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Will you download the dataset(s) now? (y/N)\\n\",\n      \"y\\n\",\n      \"\\u001b[93mDownload started, this might take a while . . . 
\\u001b[0m\\n\",\n      \"Dataset title: tgbl-wiki\\n\",\n      \"\\u001b[92mDownload completed \\u001b[0m\\n\",\n      \"Dataset directory is  /mnt/f/code/TGB/tgb/datasets/tgbl_wiki\\n\",\n      \"file not processed, generating processed file\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"tgb.linkproppred.dataset.LinkPropPredDataset\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"dataset = LinkPropPredDataset(name=name, root=\\\"datasets\\\", preprocess=True)\\n\",\n    \"type(dataset)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"47c949b4\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Accessing the edge data\\n\",\n    \"\\n\",\n    \"the edge data can be easily accessed via the property of the method as `numpy` arrays \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"9e4e7421\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"dict\"\n      ]\n     },\n     \"execution_count\": 4,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"data = dataset.full_data  #a dictionary stores all the edge data\\n\",\n    \"type(data) \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"c6ec9ac0\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"numpy.ndarray\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"type(data['sources'])\\n\",\n    \"type(data['destinations'])\\n\",\n    \"type(data['timestamps'])\\n\",\n    \"type(data['edge_feat'])\\n\",\n    \"type(data['w'])\\n\",\n    \"type(data['edge_label']) #just all one array as all edges in the dataset are positive 
edges\\n\",\n    \"type(data['edge_idxs']) #just index of the edges increment by 1 for each edge\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"bb1bbfd6\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Accessing the train, test, val split\\n\",\n    \"\\n\",\n    \"the masks for training, validation, and test split can be accessed directly from the `dataset` as well\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"8cd3507c\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"numpy.ndarray\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_mask = dataset.train_mask\\n\",\n    \"val_mask = dataset.val_mask\\n\",\n    \"test_mask = dataset.test_mask\\n\",\n    \"\\n\",\n    \"type(train_mask)\\n\",\n    \"type(val_mask)\\n\",\n    \"type(test_mask)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"cf5eff06\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.16\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "docs/tutorials/Edge_data_pyg.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d5e3f5a2\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Access edge data in Pytorch Geometric\\n\",\n    \"\\n\",\n    \"This tutorial will show you how to access various datasets and their corresponding edgelists in `tgb`\\n\",\n    \"\\n\",\n    \"The logic for PyG data is stored in `dataset_pyg.py` in `tgb/linkproppred` and `tgb/nodeproppred` folders\\n\",\n    \"\\n\",\n    \"This tutorial requires `Pytorch` and `PyG`, refer to `README.md` for installation instructions\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"23f00c08\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"60e52b7b\",\n   \"metadata\": {},\n   \"source\": [\n    \"specifying the name of the dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"48888070\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"name = \\\"tgbl-wiki\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"3511804a\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Process and load the dataset\\n\",\n    \"\\n\",\n    \"if the dataset has been processed, it will be loaded from disc for fast access\\n\",\n    \"\\n\",\n    \"if the dataset has not been downloaded, it will be processed automatically\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"8486fa82\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"file found, skipping download\\n\",\n      \"Dataset directory is  /mnt/f/code/TGB/tgb/datasets/tgbl_wiki\\n\",\n      \"loading processed file\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       
\"tgb.linkproppred.dataset_pyg.PyGLinkPropPredDataset\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"dataset = PyGLinkPropPredDataset(name=name, root=\\\"datasets\\\")\\n\",\n    \"type(dataset)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"47c949b4\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Access edge data from TemporalData object \\n\",\n    \"\\n\",\n    \"You can retrieve `torch_geometric.data.temporal.TemporalData` directly from `PyGLinkPropPredDataset`\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"9e4e7421\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"torch_geometric.data.temporal.TemporalData\"\n      ]\n     },\n     \"execution_count\": 4,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"data = dataset.get_TemporalData()\\n\",\n    \"type(data)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"c6ec9ac0\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"torch.Tensor\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"type(data.src)\\n\",\n    \"type(data.dst)\\n\",\n    \"type(data.t)\\n\",\n    \"type(data.msg)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"52fd601f\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Directly access edge data as Pytorch tensors\\n\",\n    \"\\n\",\n    \"the edge data can be easily accessed via the property of the method, these are converted into pytorch tensors (from `PyGLinkPropPredDataset`)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"56fb3347\",\n   \"metadata\": {},\n  
 \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"torch.Tensor\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"type(dataset.src)  #same as src from above\\n\",\n    \"type(dataset.dst)  #same as dst\\n\",\n    \"type(dataset.ts)  #same as t\\n\",\n    \"type(dataset.edge_feat) #same as msg\\n\",\n    \"type(dataset.edge_label) #same as label used in tgn\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"bb1bbfd6\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Accessing the train, test, val split\\n\",\n    \"\\n\",\n    \"the masks for training, validation, and test split can be accessed directly from the `dataset` as well\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"8cd3507c\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"torch.Tensor\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"train_mask = dataset.train_mask\\n\",\n    \"val_mask = dataset.val_mask\\n\",\n    \"test_mask = dataset.test_mask\\n\",\n    \"\\n\",\n    \"type(train_mask)\\n\",\n    \"type(val_mask)\\n\",\n    \"type(test_mask)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9d6ed432\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": 
\"3.9.16\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "docs/tutorials/Node_label_tutorial.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d5e3f5a2\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Access node labels for Dynamic Node Property Prediction\\n\",\n    \"\\n\",\n    \"This tutorial will show you how to access node labels and edge data for the node property prediction datasets in `tgb`.\\n\",\n    \"\\n\",\n    \"The source code is stored in `dataset_pyg.py` in `tgb/nodeproppred` folder\\n\",\n    \"\\n\",\n    \"This tutorial requires `Pytorch` and `PyG`, refer to `README.md` for installation instructions\\n\",\n    \"\\n\",\n    \"This tutorial uses `PyG TemporalData` object, however it is possible to use `numpy` arrays as well.\\n\",\n    \"\\n\",\n    \"see examples in `examples/nodeproppred` folder for more details.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"23f00c08\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\\n\",\n    \"from torch_geometric.loader import TemporalDataLoader\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"60e52b7b\",\n   \"metadata\": {},\n   \"source\": [\n    \"specifying the name of the dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"48888070\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"name = \\\"tgbn-genre\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"3511804a\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Process and load the dataset\\n\",\n    \"\\n\",\n    \"if the dataset has been processed, it will be loaded from disc for fast access\\n\",\n    \"\\n\",\n    \"if the dataset has not been downloaded, it will be processed automatically\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"8486fa82\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n   
  \"output_type\": \"stream\",\n     \"text\": [\n      \"file found, skipping download\\n\",\n      \"Dataset directory is  /mnt/f/code/TGB/tgb/datasets/tgbn_genre\\n\",\n      \"loading processed file\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"tgb.nodeproppred.dataset_pyg.PyGNodePropPredDataset\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"dataset = PyGNodePropPredDataset(name=name, root=\\\"datasets\\\")\\n\",\n    \"type(dataset)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"31338262\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Train, Validation and Test splits with dataloaders\\n\",\n    \"\\n\",\n    \"splitting the edges into train, val, test sets and construct dataloader for each\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"27b4f6a1\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_mask = dataset.train_mask\\n\",\n    \"val_mask = dataset.val_mask\\n\",\n    \"test_mask = dataset.test_mask\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"data = dataset.get_TemporalData()\\n\",\n    \"\\n\",\n    \"train_data = data[train_mask]\\n\",\n    \"val_data = data[val_mask]\\n\",\n    \"test_data = data[test_mask]\\n\",\n    \"\\n\",\n    \"batch_size = 200\\n\",\n    \"train_loader = TemporalDataLoader(train_data, batch_size=batch_size)\\n\",\n    \"val_loader = TemporalDataLoader(val_data, batch_size=batch_size)\\n\",\n    \"test_loader = TemporalDataLoader(test_data, batch_size=batch_size)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"47c949b4\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Access node label data \\n\",\n    \"\\n\",\n    \"In `tgb`, the node label data are queried based on the nearest edge observed so far and retrieves the node label data for the corresponding day. 
\\n\",\n    \"\\n\",\n    \"Note that this is because the node labels often have different timestamps from the edges thus should be processed at the correct time in the edge stream.\\n\",\n    \"\\n\",\n    \"In the example below, we show how to iterate through the edges and retrieve the node labels of the corresponding time. \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"9e4e7421\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#query the timestamps for the first node labels\\n\",\n    \"label_t = dataset.get_label_time()\\n\",\n    \"\\n\",\n    \"for batch in train_loader:\\n\",\n    \"    #access the edges in this batch\\n\",\n    \"    src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\\n\",\n    \"    query_t = batch.t[-1]\\n\",\n    \"    # check if this batch moves to the next day\\n\",\n    \"    if query_t > label_t:\\n\",\n    \"        # find the node labels from the past day\\n\",\n    \"        label_tuple = dataset.get_node_label(query_t)\\n\",\n    \"        # node labels are structured as a tuple with (timestamps, source node, label) format, label is a vector\\n\",\n    \"        label_ts, label_srcs, labels = (\\n\",\n    \"            label_tuple[0],\\n\",\n    \"            label_tuple[1],\\n\",\n    \"            label_tuple[2],\\n\",\n    \"        )\\n\",\n    \"        label_t = dataset.get_label_time()\\n\",\n    \"\\n\",\n    \"        #insert your code for backproping with node labels here\\n\",\n    \"            \"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   
\"version\": \"3.9.16\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/linkproppred/tgbl-coin/dyrep.py",
    "content": "\"\"\"\nDyRep\n    This has been implemented with intuitions from the following sources:\n    - https://github.com/twitter-research/tgn\n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n\n    Spec.:\n        - Memory Updater: RNN\n        - Embedding Module: ID\n        - Message Function: ATTN\n\ncommand for an example run:\n    python examples/linkproppred/tgbl-coin/dyrep.py --data \"tgbl-coin\" --num_run 1 --seed 1\n\"\"\"\nimport math\nimport timeit\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\nimport torch\nfrom torch_geometric.loader import TemporalDataLoader\n\n# internal imports\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.decoder import LinkPredictor\nfrom modules.emb_module import GraphAttentionEmbedding\nfrom modules.msg_func import IdentityMessage\nfrom modules.msg_agg import LastAggregator\nfrom modules.neighbor_loader import LastNeighborLoader\nfrom modules.memory_module import DyRepMemory\nfrom modules.early_stopping import  EarlyStopMonitor\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\n\n\n\n# ==========\n# ========== Define helper function...\n# ==========\n\ndef train():\n    r\"\"\"\n    Training procedure for DyRep model\n    This function uses some objects that are globally defined in the current scrips \n\n    Parameters:\n        None\n    Returns:\n        None\n            \n    \"\"\"\n    model['memory'].train()\n    model['gnn'].train()\n    model['link_pred'].train()\n\n    model['memory'].reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        # Sample negative destination nodes.\n        neg_dst 
= torch.randint(\n            min_dst_idx,\n            max_dst_idx + 1,\n            (src.size(0),),\n            dtype=torch.long,\n            device=device,\n        )\n\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n        # Get updated memory of all nodes involved in the computation.\n        z, last_update = model['memory'](n_id)\n\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\n\n        loss = criterion(pos_out, torch.ones_like(pos_out))\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\n\n        # update the memory with ground-truth\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n        model['memory'].update_state(src, pos_dst, t, msg, z, assoc)\n\n        # update neighbor loader\n        neighbor_loader.insert(src, pos_dst)\n\n        loss.backward()\n        optimizer.step()\n        model['memory'].detach()\n        total_loss += float(loss.detach()) * batch.num_events\n\n    return total_loss / train_data.num_events\n\n\n@torch.no_grad()\ndef test(loader, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        loader: an object containing positive attributes of the positive edges of the evaluation set\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    model['memory'].eval()\n    model['gnn'].eval()\n    model['link_pred'].eval()\n\n    perf_list = []\n\n    for pos_batch in loader:\n        pos_src, pos_dst, pos_t, pos_msg = (\n            pos_batch.src,\n            pos_batch.dst,\n            pos_batch.t,\n            pos_batch.msg,\n        )\n\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n\n        for idx, neg_batch in enumerate(neg_batch_list):\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\n            dst = torch.tensor(\n                np.concatenate(\n                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\n                    axis=0,\n                ),\n                device=device,\n            )\n\n            n_id = torch.cat([src, dst]).unique()\n            n_id, edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n            # Get updated memory of all nodes involved in the computation.\n            z, last_update = model['memory'](n_id)\n\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\n                \"eval_metric\": [metric],\n            }\n            
perf_list.append(evaluator.eval(input_dict)[metric])\n        \n        # update the memory with positive edges\n        n_id = torch.cat([pos_src, pos_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n        z, last_update = model['memory'](n_id)\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg, z, assoc)\n\n        # update the neighbor loader\n        neighbor_loader.insert(pos_src, pos_dst)\n\n    perf_metric = float(torch.tensor(perf_list).mean())\n\n    return perf_metric\n\n# ==========\n# ==========\n# ==========\n\n# Start...\nstart_overall = timeit.default_timer()\n\n# ========== set parameters...\nargs, _ = get_args()\nprint(\"INFO: Arguments:\", args)\n\nDATA = \"tgbl-coin\"\nLR = args.lr\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value  \nNUM_EPOCH = args.num_epoch\nSEED = args.seed\nMEM_DIM = args.mem_dim\nTIME_DIM = args.time_dim\nEMB_DIM = args.emb_dim\nTOLERANCE = args.tolerance\nPATIENCE = args.patience\nNUM_RUNS = args.num_run\nNUM_NEIGHBORS = 10\n\nMODEL_NAME = 'DyRep'\nUSE_SRC_EMB_IN_MSG = False\nUSE_DST_EMB_IN_MSG = True\n# ==========\n\n# set the device\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# data loading\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\ndata = data.to(device)\nmetric = dataset.eval_metric\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\ntest_loader = TemporalDataLoader(test_data, 
batch_size=BATCH_SIZE)\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\n\n# neighborhood sampler\nneighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\n\n# define the model end-to-end\n# 1) memory\nmemory = DyRepMemory(\n    data.num_nodes,\n    data.msg.size(-1),\n    MEM_DIM,\n    TIME_DIM,\n    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\n    aggregator_module=LastAggregator(),\n    memory_updater_type='rnn',\n    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,\n    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG\n).to(device)\n\n# 2) GNN\ngnn = GraphAttentionEmbedding(\n    in_channels=MEM_DIM,\n    out_channels=EMB_DIM,\n    msg_dim=data.msg.size(-1),\n    time_enc=memory.time_enc,\n).to(device)\n\n# 3) link predictor\nlink_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\n\nmodel = {'memory': memory,\n         'gnn': gnn,\n         'link_pred': link_pred}\n\n# define an optimizer\noptimizer = torch.optim.Adam(\n    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\n    lr=LR,\n)\ncriterion = torch.nn.BCEWithLogitsLoss()\n\n# Helper vector to map global node indices to local ones.\nassoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\n\n\nprint(\"==========================================================\")\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = 
f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nfor run_idx in range(NUM_RUNS):\n    print('-------------------------------------------------------------------------------')\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\n    start_run = timeit.default_timer()\n\n    # set the seed for deterministic results...\n    torch.manual_seed(run_idx + SEED)\n    set_random_seed(run_idx + SEED)\n\n    # define an early stopper\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\n    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \n                                    tolerance=TOLERANCE, patience=PATIENCE)\n\n    # ==================================================== Train & Validation\n    # loading the validation negative samples\n    dataset.load_val_ns()\n\n    val_perf_list = []\n    start_train_val = timeit.default_timer()\n    for epoch in range(1, NUM_EPOCH + 1):\n        # training\n        start_epoch_train = timeit.default_timer()\n        loss = train()\n        print(\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}\"\n        )\n\n        # validation\n        start_val = timeit.default_timer()\n        perf_metric_val = test(val_loader, neg_sampler, split_mode=\"val\")\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\n        print(f\"\\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}\")\n        val_perf_list.append(perf_metric_val)\n\n        # check for early stopping\n        if early_stopper.step_check(perf_metric_val, model):\n            break\n\n    train_val_time = timeit.default_timer() - start_train_val\n    print(f\"Train & Validation Total Elapsed Time (s): {train_val_time: .4f}\")\n\n    # ==================================================== Test\n    # first, load the best model\n    
early_stopper.load_checkpoint(model)\n\n    # loading the test negative samples\n    dataset.load_test_ns()\n\n    # final testing\n    start_test = timeit.default_timer()\n    perf_metric_test = test(test_loader, neg_sampler, split_mode=\"test\")\n\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\n    test_time = timeit.default_timer() - start_test\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\n    save_results({'model': MODEL_NAME,\n                  'data': DATA,\n                  'run': run_idx,\n                  'seed': SEED,\n                  f'val {metric}': val_perf_list,\n                  f'test {metric}': perf_metric_test,\n                  'test_time': test_time,\n                  'tot_train_val_time': train_val_time\n                  }, \n    results_filename)\n\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\n    print('-------------------------------------------------------------------------------')\n\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\nprint(\"==============================================================\")\n"
  },
  {
    "path": "examples/linkproppred/tgbl-coin/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = 
neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n        \n        for idx, neg_batch in enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n\n    return perf_metrics\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-coin')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or 
`fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = \"tgbl-coin\"\n\nMODEL_NAME = 'EdgeBank'\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n\n# #! check if edges are sorted\n# sorted = np.all(np.diff(data['timestamps']) >= 0)\n# print (\" INFO: Edges are sorted: \", sorted)\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n        hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> 
ONE-VS-MANY <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {test_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: perf_metric_test,\n              'test_time': test_time,\n              'tot_train_val_time': 'NA'\n              }, \n    results_filename)\n"
  },
  {
    "path": "examples/linkproppred/tgbl-coin/tgn.py",
    "content": "\"\"\"\r\nDynamic Link Prediction with a TGN model with Early Stopping\r\nReference: \r\n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\r\n\r\ncommand for an example run:\r\n    python examples/linkproppred/tgbl-coin/tgn.py --data \"tgbl-coin\" --num_run 1 --seed 1\r\n\"\"\"\r\n\r\nimport math\r\nimport timeit\r\n\r\nimport os\r\nimport os.path as osp\r\nfrom pathlib import Path\r\nimport numpy as np\r\n\r\nimport torch\r\nfrom sklearn.metrics import average_precision_score, roc_auc_score\r\nfrom torch.nn import Linear\r\n\r\nfrom torch_geometric.datasets import JODIEDataset\r\nfrom torch_geometric.loader import TemporalDataLoader\r\n\r\nfrom torch_geometric.nn import TransformerConv\r\n\r\n# internal imports\r\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\r\nfrom tgb.linkproppred.evaluate import Evaluator\r\nfrom modules.decoder import LinkPredictor\r\nfrom modules.emb_module import GraphAttentionEmbedding\r\nfrom modules.msg_func import IdentityMessage\r\nfrom modules.msg_agg import LastAggregator\r\nfrom modules.neighbor_loader import LastNeighborLoader\r\nfrom modules.memory_module import TGNMemory\r\nfrom modules.early_stopping import  EarlyStopMonitor\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\n# ==========\r\n# ========== Define helper function...\r\n# ==========\r\n\r\ndef train():\r\n    r\"\"\"\r\n    Training procedure for TGN model\r\n    This function uses some objects that are globally defined in the current scrips \r\n\r\n    Parameters:\r\n        None\r\n    Returns:\r\n        None\r\n            \r\n    \"\"\"\r\n\r\n    model['memory'].train()\r\n    model['gnn'].train()\r\n    model['link_pred'].train()\r\n\r\n    model['memory'].reset_state()  # Start with a fresh memory.\r\n    neighbor_loader.reset_state()  # Start with an empty graph.\r\n\r\n    total_loss = 0\r\n    for batch in train_loader:\r\n        batch = batch.to(device)\r\n     
   optimizer.zero_grad()\r\n\r\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\r\n\r\n        # Sample negative destination nodes.\r\n        neg_dst = torch.randint(\r\n            min_dst_idx,\r\n            max_dst_idx + 1,\r\n            (src.size(0),),\r\n            dtype=torch.long,\r\n            device=device,\r\n        )\r\n\r\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\r\n        n_id, edge_index, e_id = neighbor_loader(n_id)\r\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\r\n\r\n        # Get updated memory of all nodes involved in the computation.\r\n        z, last_update = model['memory'](n_id)\r\n        z = model['gnn'](\r\n            z,\r\n            last_update,\r\n            edge_index,\r\n            data.t[e_id].to(device),\r\n            data.msg[e_id].to(device),\r\n        )\r\n\r\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\r\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\r\n\r\n        loss = criterion(pos_out, torch.ones_like(pos_out))\r\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\r\n\r\n        # Update memory and neighbor loader with ground-truth state.\r\n        model['memory'].update_state(src, pos_dst, t, msg)\r\n        neighbor_loader.insert(src, pos_dst)\r\n\r\n        loss.backward()\r\n        optimizer.step()\r\n        model['memory'].detach()\r\n        total_loss += float(loss.detach()) * batch.num_events\r\n\r\n    return total_loss / train_data.num_events\r\n\r\n\r\n@torch.no_grad()\r\ndef test(loader, neg_sampler, split_mode):\r\n    r\"\"\"\r\n    Evaluated the dynamic link prediction\r\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\r\n\r\n    Parameters:\r\n        loader: an object containing positive attributes of the positive edges of the evaluation set\r\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\r\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\r\n    Returns:\r\n        perf_metric: the result of the performance evaluation\r\n    \"\"\"\r\n    model['memory'].eval()\r\n    model['gnn'].eval()\r\n    model['link_pred'].eval()\r\n\r\n    perf_list = []\r\n\r\n    for pos_batch in loader:\r\n        pos_src, pos_dst, pos_t, pos_msg = (\r\n            pos_batch.src,\r\n            pos_batch.dst,\r\n            pos_batch.t,\r\n            pos_batch.msg,\r\n        )\r\n\r\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\r\n\r\n        for idx, neg_batch in enumerate(neg_batch_list):\r\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\r\n            dst = torch.tensor(\r\n                np.concatenate(\r\n                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\r\n                    axis=0,\r\n                ),\r\n                device=device,\r\n            )\r\n\r\n            n_id = torch.cat([src, dst]).unique()\r\n            n_id, edge_index, e_id = neighbor_loader(n_id)\r\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\r\n\r\n            # Get updated memory of all nodes involved in the computation.\r\n            z, last_update = model['memory'](n_id)\r\n            z = model['gnn'](\r\n                z,\r\n                last_update,\r\n                edge_index,\r\n                data.t[e_id].to(device),\r\n                data.msg[e_id].to(device),\r\n            )\r\n\r\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\r\n\r\n        
    # compute MRR\r\n            input_dict = {\r\n                \"y_pred_pos\": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),\r\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\r\n                \"eval_metric\": [metric],\r\n            }\r\n            perf_list.append(evaluator.eval(input_dict)[metric])\r\n\r\n        # Update memory and neighbor loader with ground-truth state.\r\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)\r\n        neighbor_loader.insert(pos_src, pos_dst)\r\n\r\n    perf_metrics = float(torch.tensor(perf_list).mean())\r\n\r\n    return perf_metrics\r\n\r\n# ==========\r\n# ==========\r\n# ==========\r\n\r\n\r\n# Start...\r\nstart_overall = timeit.default_timer()\r\nDATA = \"tgbl-coin\"\r\n\r\n\r\n# ========== set parameters...\r\nargs, _ = get_args()\r\nargs.data = DATA\r\nprint(\"INFO: Arguments:\", args)\r\n\r\nLR = args.lr\r\nBATCH_SIZE = args.bs\r\nK_VALUE = args.k_value  \r\nNUM_EPOCH = args.num_epoch\r\nSEED = args.seed\r\nMEM_DIM = args.mem_dim\r\nTIME_DIM = args.time_dim\r\nEMB_DIM = args.emb_dim\r\nTOLERANCE = args.tolerance\r\nPATIENCE = args.patience\r\nNUM_RUNS = args.num_run\r\nNUM_NEIGHBORS = 10\r\n\r\n\r\nMODEL_NAME = 'TGN'\r\n# ==========\r\n\r\n# set the device\r\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\n\r\n# data loading\r\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\r\ntrain_mask = dataset.train_mask\r\nval_mask = dataset.val_mask\r\ntest_mask = dataset.test_mask\r\ndata = dataset.get_TemporalData()\r\ndata = data.to(device)\r\nmetric = dataset.eval_metric\r\n\r\ntrain_data = data[train_mask]\r\nval_data = data[val_mask]\r\ntest_data = data[test_mask]\r\n\r\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\r\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\r\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\r\n\r\n# Ensure to only sample actual 
destination nodes as negatives.\r\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n\r\nprint(\"==========================================================\")\r\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============\")\r\nprint(\"==========================================================\")\r\n\r\nevaluator = Evaluator(name=DATA)\r\nneg_sampler = dataset.negative_sampler\r\n\r\n# for saving the results...\r\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\r\nif not osp.exists(results_path):\r\n    os.mkdir(results_path)\r\n    print('INFO: Create directory {}'.format(results_path))\r\nPath(results_path).mkdir(parents=True, exist_ok=True)\r\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\r\n\r\nfor run_idx in range(NUM_RUNS):\r\n    print('-------------------------------------------------------------------------------')\r\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\r\n    start_run = timeit.default_timer()\r\n\r\n    # set the seed for deterministic results...\r\n    torch.manual_seed(run_idx + SEED)\r\n    set_random_seed(run_idx + SEED)\r\n\r\n    # neighborhood sampler\r\n    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\r\n\r\n    # define the model end-to-end\r\n    memory = TGNMemory(\r\n        data.num_nodes,\r\n        data.msg.size(-1),\r\n        MEM_DIM,\r\n        TIME_DIM,\r\n        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\r\n        aggregator_module=LastAggregator(),\r\n    ).to(device)\r\n\r\n    gnn = GraphAttentionEmbedding(\r\n        in_channels=MEM_DIM,\r\n        out_channels=EMB_DIM,\r\n        msg_dim=data.msg.size(-1),\r\n        time_enc=memory.time_enc,\r\n    ).to(device)\r\n\r\n    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\r\n\r\n    model = {'memory': memory,\r\n            'gnn': gnn,\r\n            'link_pred': link_pred}\r\n\r\n    optimizer = 
torch.optim.Adam(\r\n        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\r\n        lr=LR,\r\n    )\r\n    criterion = torch.nn.BCEWithLogitsLoss()\r\n\r\n    # Helper vector to map global node indices to local ones.\r\n    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\r\n\r\n\r\n    # define an early stopper\r\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\r\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\r\n    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \r\n                                    tolerance=TOLERANCE, patience=PATIENCE)\r\n\r\n    # ==================================================== Train & Validation\r\n    # loading the validation negative samples\r\n    dataset.load_val_ns()\r\n\r\n    val_perf_list = []\r\n    start_train_val = timeit.default_timer()\r\n    for epoch in range(1, NUM_EPOCH + 1):\r\n        # training\r\n        start_epoch_train = timeit.default_timer()\r\n        loss = train()\r\n        print(\r\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}\"\r\n        )\r\n\r\n        # validation\r\n        start_val = timeit.default_timer()\r\n        perf_metric_val = test(val_loader, neg_sampler, split_mode=\"val\")\r\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\r\n        print(f\"\\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}\")\r\n        val_perf_list.append(perf_metric_val)\r\n\r\n        # check for early stopping\r\n        if early_stopper.step_check(perf_metric_val, model):\r\n            break\r\n\r\n    train_val_time = timeit.default_timer() - start_train_val\r\n    print(f\"Train & Validation: Elapsed Time (s): {train_val_time: .4f}\")\r\n\r\n    # ==================================================== Test\r\n    # 
first, load the best model\r\n    early_stopper.load_checkpoint(model)\r\n\r\n    # loading the test negative samples\r\n    dataset.load_test_ns()\r\n\r\n    # final testing\r\n    start_test = timeit.default_timer()\r\n    perf_metric_test = test(test_loader, neg_sampler, split_mode=\"test\")\r\n\r\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\r\n    print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\r\n    test_time = timeit.default_timer() - start_test\r\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\r\n\r\n    save_results({'model': MODEL_NAME,\r\n                  'data': DATA,\r\n                  'run': run_idx,\r\n                  'seed': SEED,\r\n                  f'val {metric}': val_perf_list,\r\n                  f'test {metric}': perf_metric_test,\r\n                  'test_time': test_time,\r\n                  'tot_train_val_time': train_val_time\r\n                  }, \r\n    results_filename)\r\n\r\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\r\n    print('-------------------------------------------------------------------------------')\r\n\r\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\r\nprint(\"==============================================================\")\r\n"
  },
  {
    "path": "examples/linkproppred/tgbl-comment/dyrep.py",
    "content": "\"\"\"\nDyRep\n    This has been implemented with intuitions from the following sources:\n    - https://github.com/twitter-research/tgn\n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n\n    Spec.:\n        - Memory Updater: RNN\n        - Embedding Module: ID\n        - Message Function: ATTN\n\"\"\"\nimport math\nimport timeit\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\nimport torch\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch.nn import Linear\nfrom torch_geometric.datasets import JODIEDataset\nfrom torch_geometric.loader import TemporalDataLoader\nfrom torch_geometric.nn import TransformerConv\n\n# internal imports\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.decoder import LinkPredictor\nfrom modules.emb_module import GraphAttentionEmbedding\nfrom modules.msg_func import IdentityMessage\nfrom modules.msg_agg import LastAggregator\nfrom modules.neighbor_loader import LastNeighborLoader\nfrom modules.memory_module import DyRepMemory\nfrom modules.early_stopping import  EarlyStopMonitor\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\n\n\n\n# ==========\n# ========== Define helper function...\n# ==========\n\ndef train():\n    model['memory'].train()\n    model['gnn'].train()\n    model['link_pred'].train()\n\n    model['memory'].reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        # Sample negative destination nodes.\n        neg_dst = torch.randint(\n            min_dst_idx,\n            max_dst_idx + 1,\n            (src.size(0),),\n            dtype=torch.long,\n            
device=device,\n        )\n\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n        # Get updated memory of all nodes involved in the computation.\n        z, last_update = model['memory'](n_id)\n\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\n\n        loss = criterion(pos_out, torch.ones_like(pos_out))\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\n\n        # update the memory with ground-truth\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n        model['memory'].update_state(src, pos_dst, t, msg, z, assoc)\n\n        # update neighbor loader\n        neighbor_loader.insert(src, pos_dst)\n\n        loss.backward()\n        optimizer.step()\n        model['memory'].detach()\n        total_loss += float(loss.detach()) * batch.num_events\n\n    return total_loss / train_data.num_events\n\n\n@torch.no_grad()\ndef test_one_vs_many(loader, neg_sampler, split_mode):\n    \"\"\"\n    Evaluate the dynamic link prediction\n    \"\"\"\n    model['memory'].eval()\n    model['gnn'].eval()\n    model['link_pred'].eval()\n\n    perf_list = []\n\n    for pos_batch in loader:\n        pos_src, pos_dst, pos_t, pos_msg = (\n            pos_batch.src,\n            pos_batch.dst,\n            pos_batch.t,\n            pos_batch.msg,\n        )\n\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n\n        for idx, neg_batch in enumerate(neg_batch_list):\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\n            dst = torch.tensor(\n                np.concatenate(\n                    
([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\n                    axis=0,\n                ),\n                device=device,\n            )\n\n            n_id = torch.cat([src, dst]).unique()\n            n_id, edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n            # Get updated memory of all nodes involved in the computation.\n            z, last_update = model['memory'](n_id)\n\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n        \n        # update the memory with positive edges\n        n_id = torch.cat([pos_src, pos_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n        z, last_update = model['memory'](n_id)\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg, z, assoc)\n\n        # update the neighbor loader\n        neighbor_loader.insert(pos_src, pos_dst)\n\n    perf_metric = float(torch.tensor(perf_list).mean())\n\n    return perf_metric\n\n# ==========\n# ==========\n# ==========\n\n# Start...\nstart_overall = timeit.default_timer()\n\n# ========== set parameters...\nargs, _ = get_args()\nprint(\"INFO: Arguments:\", args)\n\nDATA = \"tgbl-comment\"\nLR = args.lr\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value  \nNUM_EPOCH = args.num_epoch\nSEED = args.seed\nMEM_DIM = args.mem_dim\nTIME_DIM = args.time_dim\nEMB_DIM = 
args.emb_dim\nTOLERANCE = args.tolerance\nPATIENCE = args.patience\nNUM_RUNS = args.num_run\nNUM_NEIGHBORS = 10\n\nMODEL_NAME = 'DyRep'\nUSE_SRC_EMB_IN_MSG = False\nUSE_DST_EMB_IN_MSG = True\n# ==========\n\n# set the device\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# data loading\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\ndata = data.to(device)\nmetric = dataset.eval_metric\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\n\n# neighborhood sampler\nneighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\n\n# define the model end-to-end\nmemory = DyRepMemory(\n    data.num_nodes,\n    data.msg.size(-1),\n    MEM_DIM,\n    TIME_DIM,\n    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\n    aggregator_module=LastAggregator(),\n    memory_updater_type='rnn',\n    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,\n    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG\n).to(device)\n\ngnn = GraphAttentionEmbedding(\n    in_channels=MEM_DIM,\n    out_channels=EMB_DIM,\n    msg_dim=data.msg.size(-1),\n    time_enc=memory.time_enc,\n).to(device)\n\nlink_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\n\nmodel = {'memory': memory,\n         'gnn': gnn,\n         'link_pred': link_pred}\n\noptimizer = torch.optim.Adam(\n    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\n    lr=LR,\n)\ncriterion = 
torch.nn.BCEWithLogitsLoss()\n\n# Helper vector to map global node indices to local ones.\nassoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\n\n\nprint(\"==========================================================\")\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nfor run_idx in range(NUM_RUNS):\n    print('-------------------------------------------------------------------------------')\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\n    start_run = timeit.default_timer()\n\n    # set the seed for deterministic results...\n    torch.manual_seed(run_idx + SEED)\n    set_random_seed(run_idx + SEED)\n\n    # define an early stopper\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\n    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \n                                    tolerance=TOLERANCE, patience=PATIENCE)\n\n    # ==================================================== Train & Validation\n    # loading the validation negative samples\n    dataset.load_val_ns()\n\n    val_perf_list = []\n    start_train_val = timeit.default_timer()\n    for epoch in range(1, NUM_EPOCH + 1):\n        # training\n        start_epoch_train = timeit.default_timer()\n        loss = train()\n        print(\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - 
start_epoch_train: .4f}\"\n        )\n\n        # validation\n        start_val = timeit.default_timer()\n        perf_metric_val = test_one_vs_many(val_loader, neg_sampler, split_mode=\"val\")\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\n        print(f\"\\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}\")\n        val_perf_list.append(perf_metric_val)\n\n        # check for early stopping\n        if early_stopper.step_check(perf_metric_val, model):\n            break\n\n    train_val_time = timeit.default_timer() - start_train_val\n    print(f\"Train & Validation: Elapsed Time (s): {train_val_time: .4f}\")\n\n    # ==================================================== Test\n    # first, load the best model\n    early_stopper.load_checkpoint(model)\n\n    # loading the test negative samples\n    dataset.load_test_ns()\n\n    # final testing\n    start_test = timeit.default_timer()\n    perf_metric_test = test_one_vs_many(test_loader, neg_sampler, split_mode=\"test\")\n\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\n    test_time = timeit.default_timer() - start_test\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\n    save_results({'model': MODEL_NAME,\n                  'data': DATA,\n                  'run': run_idx,\n                  'seed': SEED,\n                  f'val {metric}': val_perf_list,\n                  f'test {metric}': perf_metric_test,\n                  'test_time': test_time,\n                  'tot_train_val_time': train_val_time\n                  }, \n    results_filename)\n\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\n    print('-------------------------------------------------------------------------------')\n\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: 
.4f}\")\nprint(\"==============================================================\")\n"
  },
  {
    "path": "examples/linkproppred/tgbl-comment/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = 
neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n        \n        for idx, neg_batch in enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n\n    return perf_metrics\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-comment')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or 
`fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = \"tgbl-comment\"\n\nMODEL_NAME = 'EdgeBank'\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n        hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_val\nprint(f\"\\tval: 
Elapsed Time (s): {test_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: perf_metric_test,\n              'test_time': test_time,\n              'tot_train_val_time': 'NA'\n              }, \n    results_filename)\n"
  },
  {
    "path": "examples/linkproppred/tgbl-comment/tgn.py",
    "content": "\"\"\"\nDynamic Link Prediction with a TGN model with Early Stopping\nReference: \n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n\ncommand for an example run:\n    python examples/linkproppred/tgbl-comment/tgn.py --data \"tgbl-comment\" --num_run 1 --seed 1\n\"\"\"\n\nimport math\nimport timeit\n\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\n\nimport torch\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch.nn import Linear\n\nfrom torch_geometric.datasets import JODIEDataset\nfrom torch_geometric.loader import TemporalDataLoader\n\nfrom torch_geometric.nn import TransformerConv\n\n# internal imports\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.decoder import LinkPredictor\nfrom modules.emb_module import GraphAttentionEmbedding\nfrom modules.msg_func import IdentityMessage\nfrom modules.msg_agg import LastAggregator\nfrom modules.neighbor_loader import LastNeighborLoader\nfrom modules.memory_module import TGNMemory\nfrom modules.early_stopping import  EarlyStopMonitor\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\n\n\n# ==========\n# ========== Define helper function...\n# ==========\n\ndef train():\n    r\"\"\"\n    Training procedure for TGN model\n    This function uses some objects that are globally defined in the current scrips \n\n    Parameters:\n        None\n    Returns:\n        None\n            \n    \"\"\"\n\n    model['memory'].train()\n    model['gnn'].train()\n    model['link_pred'].train()\n\n    model['memory'].reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        # Sample 
negative destination nodes.\n        neg_dst = torch.randint(\n            min_dst_idx,\n            max_dst_idx + 1,\n            (src.size(0),),\n            dtype=torch.long,\n            device=device,\n        )\n\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n        # Get updated memory of all nodes involved in the computation.\n        z, last_update = model['memory'](n_id)\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\n\n        loss = criterion(pos_out, torch.ones_like(pos_out))\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\n\n        # Update memory and neighbor loader with ground-truth state.\n        model['memory'].update_state(src, pos_dst, t, msg)\n        neighbor_loader.insert(src, pos_dst)\n\n        loss.backward()\n        optimizer.step()\n        model['memory'].detach()\n        total_loss += float(loss.detach()) * batch.num_events\n\n    return total_loss / train_data.num_events\n\n\n@torch.no_grad()\ndef test(loader, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        loader: an object containing positive attributes of the positive edges of the evaluation set\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    model['memory'].eval()\n    model['gnn'].eval()\n    model['link_pred'].eval()\n\n    perf_list = []\n\n    for pos_batch in loader:\n        pos_src, pos_dst, pos_t, pos_msg = (\n            pos_batch.src,\n            pos_batch.dst,\n            pos_batch.t,\n            pos_batch.msg,\n        )\n\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n\n        for idx, neg_batch in enumerate(neg_batch_list):\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\n            dst = torch.tensor(\n                np.concatenate(\n                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\n                    axis=0,\n                ),\n                device=device,\n            )\n\n            n_id = torch.cat([src, dst]).unique()\n            n_id, edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n            # Get updated memory of all nodes involved in the computation.\n            z, last_update = model['memory'](n_id)\n            z = model['gnn'](\n                z,\n                last_update,\n                edge_index,\n                data.t[e_id].to(device),\n                data.msg[e_id].to(device),\n            )\n\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\n\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0, 
:].squeeze(dim=-1).cpu()]),\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n\n        # Update memory and neighbor loader with ground-truth state.\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)\n        neighbor_loader.insert(pos_src, pos_dst)\n\n    perf_metrics = float(torch.tensor(perf_list).mean())\n\n    return perf_metrics\n\n# ==========\n# ==========\n# ==========\n\n\n# Start...\nstart_overall = timeit.default_timer()\nDATA = \"tgbl-comment\"\n\n# ========== set parameters...\nargs, _ = get_args()\nargs.data = DATA\nprint(\"INFO: Arguments:\", args)\n\nLR = args.lr\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value  \nNUM_EPOCH = args.num_epoch\nSEED = args.seed\nMEM_DIM = args.mem_dim\nTIME_DIM = args.time_dim\nEMB_DIM = args.emb_dim\nTOLERANCE = args.tolerance\nPATIENCE = args.patience\nNUM_RUNS = args.num_run\nNUM_NEIGHBORS = 10\n\n\nMODEL_NAME = 'TGN'\n# ==========\n\n# set the device\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# data loading\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\ndata = data.to(device)\nmetric = dataset.eval_metric\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\n\nprint(\"==========================================================\")\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} 
***=============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nfor run_idx in range(NUM_RUNS):\n    print('-------------------------------------------------------------------------------')\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\n    start_run = timeit.default_timer()\n\n    # set the seed for deterministic results...\n    torch.manual_seed(run_idx + SEED)\n    set_random_seed(run_idx + SEED)\n\n    # neighborhood sampler\n    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\n\n    # define the model end-to-end\n    memory = TGNMemory(\n        data.num_nodes,\n        data.msg.size(-1),\n        MEM_DIM,\n        TIME_DIM,\n        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\n        aggregator_module=LastAggregator(),\n    ).to(device)\n\n    gnn = GraphAttentionEmbedding(\n        in_channels=MEM_DIM,\n        out_channels=EMB_DIM,\n        msg_dim=data.msg.size(-1),\n        time_enc=memory.time_enc,\n    ).to(device)\n\n    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\n\n    model = {'memory': memory,\n            'gnn': gnn,\n            'link_pred': link_pred}\n\n    optimizer = torch.optim.Adam(\n        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\n        lr=LR,\n    )\n    criterion = torch.nn.BCEWithLogitsLoss()\n\n    # Helper vector to map global node indices to local ones.\n    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\n\n\n    
# define an early stopper\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\n    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \n                                    tolerance=TOLERANCE, patience=PATIENCE)\n\n    # ==================================================== Train & Validation\n    # loading the validation negative samples\n    dataset.load_val_ns()\n\n    val_perf_list = []\n    start_train_val = timeit.default_timer()\n    for epoch in range(1, NUM_EPOCH + 1):\n        # training\n        start_epoch_train = timeit.default_timer()\n        loss = train()\n        print(\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}\"\n        )\n\n        # validation\n        start_val = timeit.default_timer()\n        perf_metric_val = test(val_loader, neg_sampler, split_mode=\"val\")\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\n        print(f\"\\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}\")\n        val_perf_list.append(perf_metric_val)\n\n        # check for early stopping\n        if early_stopper.step_check(perf_metric_val, model):\n            break\n\n    train_val_time = timeit.default_timer() - start_train_val\n    print(f\"Train & Validation: Elapsed Time (s): {train_val_time: .4f}\")\n\n    # ==================================================== Test\n    # first, load the best model\n    early_stopper.load_checkpoint(model)\n\n    # loading the test negative samples\n    dataset.load_test_ns()\n\n    # final testing\n    start_test = timeit.default_timer()\n    perf_metric_test = test(test_loader, neg_sampler, split_mode=\"test\")\n\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\n    test_time = 
timeit.default_timer() - start_test\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\n    save_results({'model': MODEL_NAME,\n                  'data': DATA,\n                  'run': run_idx,\n                  'seed': SEED,\n                  f'val {metric}': val_perf_list,\n                  f'test {metric}': perf_metric_test,\n                  'test_time': test_time,\n                  'tot_train_val_time': train_val_time\n                  }, \n    results_filename)\n\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\n    print('-------------------------------------------------------------------------------')\n\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\nprint(\"==============================================================\")\n"
  },
  {
    "path": "examples/linkproppred/tgbl-enron/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluaiton\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n        \n        for idx, neg_batch in 
enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n\n    return perf_metrics\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = 
\"tgbl-enron\"\n\nMODEL_NAME = 'EdgeBank'\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n        hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {test_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading 
the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: perf_metric_test,\n              'test_time': test_time,\n              'tot_train_val_time': 'NA'\n              }, \n    results_filename)\n"
  },
  {
    "path": "examples/linkproppred/tgbl-flight/dyrep.py",
    "content": "\"\"\"\nDyRep\n    This has been implemented with intuitions from the following sources:\n    - https://github.com/twitter-research/tgn\n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n\n    Spec.:\n        - Memory Updater: RNN\n        - Embedding Module: ID\n        - Message Function: ATTN\n\ncommand for an example run:\n    python examples/linkproppred/tgbl-flight/dyrep.py --data \"tgbl-flight\" --num_run 1 --seed 1\n\"\"\"\nimport math\nimport timeit\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\nimport torch\nfrom torch_geometric.loader import TemporalDataLoader\n\n# internal imports\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.decoder import LinkPredictor\nfrom modules.emb_module import GraphAttentionEmbedding\nfrom modules.msg_func import IdentityMessage\nfrom modules.msg_agg import LastAggregator\nfrom modules.neighbor_loader import LastNeighborLoader\nfrom modules.memory_module import DyRepMemory\nfrom modules.early_stopping import  EarlyStopMonitor\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\n\n\n\n# ==========\n# ========== Define helper function...\n# ==========\n\ndef train():\n    r\"\"\"\n    Training procedure for DyRep model\n    This function uses some objects that are globally defined in the current scrips \n\n    Parameters:\n        None\n    Returns:\n        None\n            \n    \"\"\"\n    model['memory'].train()\n    model['gnn'].train()\n    model['link_pred'].train()\n\n    model['memory'].reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        # Sample negative destination nodes.\n        
neg_dst = torch.randint(\n            min_dst_idx,\n            max_dst_idx + 1,\n            (src.size(0),),\n            dtype=torch.long,\n            device=device,\n        )\n\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n        # Get updated memory of all nodes involved in the computation.\n        z, last_update = model['memory'](n_id)\n\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\n\n        loss = criterion(pos_out, torch.ones_like(pos_out))\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\n\n        # update the memory with ground-truth\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n        model['memory'].update_state(src, pos_dst, t, msg, z, assoc)\n\n        # update neighbor loader\n        neighbor_loader.insert(src, pos_dst)\n\n        loss.backward()\n        optimizer.step()\n        model['memory'].detach()\n        total_loss += float(loss.detach()) * batch.num_events\n\n    return total_loss / train_data.num_events\n\n\n@torch.no_grad()\ndef test(loader, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        loader: an object containing positive attributes of the positive edges of the evaluation set\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    model['memory'].eval()\n    model['gnn'].eval()\n    model['link_pred'].eval()\n\n    perf_list = []\n\n    for pos_batch in loader:\n        pos_src, pos_dst, pos_t, pos_msg = (\n            pos_batch.src,\n            pos_batch.dst,\n            pos_batch.t,\n            pos_batch.msg,\n        )\n\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n\n        for idx, neg_batch in enumerate(neg_batch_list):\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\n            dst = torch.tensor(\n                np.concatenate(\n                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\n                    axis=0,\n                ),\n                device=device,\n            )\n\n            n_id = torch.cat([src, dst]).unique()\n            n_id, edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n            # Get updated memory of all nodes involved in the computation.\n            z, last_update = model['memory'](n_id)\n\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\n                \"eval_metric\": [metric],\n            }\n            
perf_list.append(evaluator.eval(input_dict)[metric])\n        \n        # update the memory with positive edges\n        n_id = torch.cat([pos_src, pos_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n        z, last_update = model['memory'](n_id)\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg, z, assoc)\n\n        # update the neighbor loader\n        neighbor_loader.insert(pos_src, pos_dst)\n\n    perf_metric = float(torch.tensor(perf_list).mean())\n\n    return perf_metric\n\n# ==========\n# ==========\n# ==========\n\n# Start...\nstart_overall = timeit.default_timer()\n\n# ========== set parameters...\nargs, _ = get_args()\nprint(\"INFO: Arguments:\", args)\n\nDATA = \"tgbl-flight\"\nLR = args.lr\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value  \nNUM_EPOCH = args.num_epoch\nSEED = args.seed\nMEM_DIM = args.mem_dim\nTIME_DIM = args.time_dim\nEMB_DIM = args.emb_dim\nTOLERANCE = args.tolerance\nPATIENCE = args.patience\nNUM_RUNS = args.num_run\nNUM_NEIGHBORS = 10\n\nMODEL_NAME = 'DyRep'\nUSE_SRC_EMB_IN_MSG = False\nUSE_DST_EMB_IN_MSG = True\n# ==========\n\n# set the device\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# data loading\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\ndata = data.to(device)\nmetric = dataset.eval_metric\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\ntest_loader = TemporalDataLoader(test_data, 
batch_size=BATCH_SIZE)\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\n\n# neighborhood sampler\nneighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\n\n# define the model end-to-end\n# 1) memory\nmemory = DyRepMemory(\n    data.num_nodes,\n    data.msg.size(-1),\n    MEM_DIM,\n    TIME_DIM,\n    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\n    aggregator_module=LastAggregator(),\n    memory_updater_type='rnn',\n    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,\n    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG\n).to(device)\n\n# 2) GNN\ngnn = GraphAttentionEmbedding(\n    in_channels=MEM_DIM,\n    out_channels=EMB_DIM,\n    msg_dim=data.msg.size(-1),\n    time_enc=memory.time_enc,\n).to(device)\n\n# 3) link predictor\nlink_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\n\nmodel = {'memory': memory,\n         'gnn': gnn,\n         'link_pred': link_pred}\n\n# define an optimizer\noptimizer = torch.optim.Adam(\n    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\n    lr=LR,\n)\ncriterion = torch.nn.BCEWithLogitsLoss()\n\n# Helper vector to map global node indices to local ones.\nassoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\n\n\nprint(\"==========================================================\")\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = 
f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nfor run_idx in range(NUM_RUNS):\n    print('-------------------------------------------------------------------------------')\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\n    start_run = timeit.default_timer()\n\n    # set the seed for deterministic results...\n    torch.manual_seed(run_idx + SEED)\n    set_random_seed(run_idx + SEED)\n\n    # define an early stopper\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\n    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \n                                    tolerance=TOLERANCE, patience=PATIENCE)\n\n    # ==================================================== Train & Validation\n    # loading the validation negative samples\n    dataset.load_val_ns()\n\n    val_perf_list = []\n    train_times_l, val_times_l = [], []\n    free_mem_l, total_mem_l, used_mem_l = [], [], []\n    start_train_val = timeit.default_timer()\n    for epoch in range(1, NUM_EPOCH + 1):\n        # training\n        start_epoch_train = timeit.default_timer()\n        loss = train()\n        end_epoch_train = timeit.default_timer()\n        print(\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {end_epoch_train - start_epoch_train: .4f}\"\n        )\n        # checking GPU memory usage\n        free_mem, used_mem, total_mem = 0, 0, 0\n        if torch.cuda.is_available():\n            print(\"DEBUG: device: {}\".format(torch.cuda.get_device_name(0)))\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            used_mem = total_mem - free_mem\n            print(\"------------Epoch {}: GPU memory usage-----------\".format(epoch))\n            print(\"Free memory: {}\".format(free_mem))\n            print(\"Total available memory: {}\".format(total_mem))\n            print(\"Used memory: {}\".format(used_mem))\n          
  print(\"--------------------------------------------\")\n            \n        train_times_l.append(end_epoch_train - start_epoch_train)\n        free_mem_l.append(float((free_mem*1.0)/2**30))  # in GB\n        used_mem_l.append(float((used_mem*1.0)/2**30))  # in GB\n        total_mem_l.append(float((total_mem*1.0)/2**30))  # in GB\n\n        # validation\n        start_val = timeit.default_timer()\n        perf_metric_val = test(val_loader, neg_sampler, split_mode=\"val\")\n        end_val = timeit.default_timer()\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\n        print(f\"\\tValidation: Elapsed time (s): {end_val - start_val: .4f}\")\n        val_perf_list.append(perf_metric_val)\n        val_times_l.append(end_val - start_val)\n\n        # check for early stopping\n        if early_stopper.step_check(perf_metric_val, model):\n            break\n\n    train_val_time = timeit.default_timer() - start_train_val\n    print(f\"Train & Validation Total Elapsed Time (s): {train_val_time: .4f}\")\n\n    # ==================================================== Test\n    # first, load the best model\n    early_stopper.load_checkpoint(model)\n\n    # loading the test negative samples\n    dataset.load_test_ns()\n\n    # final testing\n    start_test = timeit.default_timer()\n    perf_metric_test = test(test_loader, neg_sampler, split_mode=\"test\")\n\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\n    test_time = timeit.default_timer() - start_test\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\n    save_results({'data': DATA,\n                  'model': MODEL_NAME,\n                  'run': run_idx,\n                  'seed': SEED,\n                  'train_times': train_times_l,\n                  'free_mem': free_mem_l,\n                  'total_mem': total_mem_l,\n                  'used_mem': used_mem_l,\n                  'max_used_mem': 
max(used_mem_l),\n                  'val_times': val_times_l,\n                  f'val {metric}': val_perf_list,\n                  f'test {metric}': perf_metric_test,\n                  'test_time': test_time,\n                  'train_val_total_time': np.sum(np.array(train_times_l)) + np.sum(np.array(val_times_l)),\n                  }, \n    results_filename)\n\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\n    print('-------------------------------------------------------------------------------')\n\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\nprint(\"==============================================================\")\n"
  },
  {
    "path": "examples/linkproppred/tgbl-flight/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = 
neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n        \n        for idx, neg_batch in enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n\n    return perf_metrics\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-flight')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or 
`fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = \"tgbl-flight\"\n\nMODEL_NAME = 'EdgeBank'\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n        hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - 
start_val\nprint(f\"\\tval: Elapsed Time (s): {test_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: perf_metric_test,\n              'test_time': test_time,\n              'tot_train_val_time': 'NA'\n              }, \n    results_filename)\n"
  },
  {
    "path": "examples/linkproppred/tgbl-flight/tgn.py",
    "content": "\"\"\"\nDynamic Link Prediction with a TGN model with Early Stopping\nReference: \n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n\ncommand for an example run:\n    python examples/linkproppred/tgbl-flight/tgn.py --data \"tgbl-flight\" --num_run 1 --seed 1\n\"\"\"\n\nimport math\nimport timeit\n\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\n\nimport torch\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch.nn import Linear\n\nfrom torch_geometric.datasets import JODIEDataset\nfrom torch_geometric.loader import TemporalDataLoader\n\nfrom torch_geometric.nn import TransformerConv\n\n# internal imports\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.decoder import LinkPredictor\nfrom modules.emb_module import GraphAttentionEmbedding\nfrom modules.msg_func import IdentityMessage\nfrom modules.msg_agg import LastAggregator\nfrom modules.neighbor_loader import LastNeighborLoader\nfrom modules.memory_module import TGNMemory\nfrom modules.early_stopping import  EarlyStopMonitor\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\n\n\n# ==========\n# ========== Define helper function...\n# ==========\n\ndef train():\n    r\"\"\"\n    Training procedure for TGN model\n    This function uses some objects that are globally defined in the current scrips \n\n    Parameters:\n        None\n    Returns:\n        None\n            \n    \"\"\"\n\n    model['memory'].train()\n    model['gnn'].train()\n    model['link_pred'].train()\n\n    model['memory'].reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        # Sample 
negative destination nodes.\n        neg_dst = torch.randint(\n            min_dst_idx,\n            max_dst_idx + 1,\n            (src.size(0),),\n            dtype=torch.long,\n            device=device,\n        )\n\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n        # Get updated memory of all nodes involved in the computation.\n        z, last_update = model['memory'](n_id)\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\n\n        loss = criterion(pos_out, torch.ones_like(pos_out))\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\n\n        # Update memory and neighbor loader with ground-truth state.\n        model['memory'].update_state(src, pos_dst, t, msg)\n        neighbor_loader.insert(src, pos_dst)\n\n        loss.backward()\n        optimizer.step()\n        model['memory'].detach()\n        total_loss += float(loss.detach()) * batch.num_events\n\n    return total_loss / train_data.num_events\n\n\n@torch.no_grad()\ndef test(loader, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        loader: an object containing positive attributes of the positive edges of the evaluation set\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    model['memory'].eval()\n    model['gnn'].eval()\n    model['link_pred'].eval()\n\n    perf_list = []\n\n    for pos_batch in loader:\n        pos_src, pos_dst, pos_t, pos_msg = (\n            pos_batch.src,\n            pos_batch.dst,\n            pos_batch.t,\n            pos_batch.msg,\n        )\n\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n\n        for idx, neg_batch in enumerate(neg_batch_list):\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\n            dst = torch.tensor(\n                np.concatenate(\n                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\n                    axis=0,\n                ),\n                device=device,\n            )\n\n            n_id = torch.cat([src, dst]).unique()\n            n_id, edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n            # Get updated memory of all nodes involved in the computation.\n            z, last_update = model['memory'](n_id)\n            z = model['gnn'](\n                z,\n                last_update,\n                edge_index,\n                data.t[e_id].to(device),\n                data.msg[e_id].to(device),\n            )\n\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\n\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0, 
:].squeeze(dim=-1).cpu()]),\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n\n        # Update memory and neighbor loader with ground-truth state.\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)\n        neighbor_loader.insert(pos_src, pos_dst)\n\n    perf_metrics = float(torch.tensor(perf_list).mean())\n\n    return perf_metrics\n\n# ==========\n# ==========\n# ==========\n\n\n# Start...\nstart_overall = timeit.default_timer()\nDATA = \"tgbl-flight\"\n\n# ========== set parameters...\nargs, _ = get_args()\nargs.data = DATA\nprint(\"INFO: Arguments:\", args)\n\nLR = args.lr\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value  \nNUM_EPOCH = args.num_epoch\nSEED = args.seed\nMEM_DIM = args.mem_dim\nTIME_DIM = args.time_dim\nEMB_DIM = args.emb_dim\nTOLERANCE = args.tolerance\nPATIENCE = args.patience\nNUM_RUNS = args.num_run\nNUM_NEIGHBORS = 10\n\n\nMODEL_NAME = 'TGN'\n# ==========\n\n# set the device\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# data loading\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\ndata = data.to(device)\nmetric = dataset.eval_metric\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\n\n\nprint(\"==========================================================\")\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} 
***=============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nfor run_idx in range(NUM_RUNS):\n    print('-------------------------------------------------------------------------------')\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\n    start_run = timeit.default_timer()\n\n    # set the seed for deterministic results...\n    torch.manual_seed(run_idx + SEED)\n    set_random_seed(run_idx + SEED)\n\n    # neighborhood sampler\n    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\n\n    # define the model end-to-end\n    memory = TGNMemory(\n        data.num_nodes,\n        data.msg.size(-1),\n        MEM_DIM,\n        TIME_DIM,\n        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\n        aggregator_module=LastAggregator(),\n    ).to(device)\n\n    gnn = GraphAttentionEmbedding(\n        in_channels=MEM_DIM,\n        out_channels=EMB_DIM,\n        msg_dim=data.msg.size(-1),\n        time_enc=memory.time_enc,\n    ).to(device)\n\n    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\n\n    model = {'memory': memory,\n            'gnn': gnn,\n            'link_pred': link_pred}\n\n    optimizer = torch.optim.Adam(\n        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\n        lr=LR,\n    )\n    criterion = torch.nn.BCEWithLogitsLoss()\n\n    # Helper vector to map global node indices to local ones.\n    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\n\n    # 
define an early stopper\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\n    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \n                                    tolerance=TOLERANCE, patience=PATIENCE)\n\n    # ==================================================== Train & Validation\n    # loading the validation negative samples\n    dataset.load_val_ns()\n\n    val_perf_list = []\n    train_times_l, val_times_l = [], []\n    free_mem_l, total_mem_l, used_mem_l = [], [], []\n    start_train_val = timeit.default_timer()\n    for epoch in range(1, NUM_EPOCH + 1):\n        # training\n        start_epoch_train = timeit.default_timer()\n        loss = train()\n        end_epoch_train = timeit.default_timer()\n        print(\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {end_epoch_train - start_epoch_train: .4f}\"\n        )\n        # checking GPU memory usage\n        free_mem, used_mem, total_mem = 0, 0, 0\n        if torch.cuda.is_available():\n            print(\"DEBUG: device: {}\".format(torch.cuda.get_device_name(0)))\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            used_mem = total_mem - free_mem\n            print(\"------------Epoch {}: GPU memory usage-----------\".format(epoch))\n            print(\"Free memory: {}\".format(free_mem))\n            print(\"Total available memory: {}\".format(total_mem))\n            print(\"Used memory: {}\".format(used_mem))\n            print(\"--------------------------------------------\")\n        \n        train_times_l.append(end_epoch_train - start_epoch_train)\n        free_mem_l.append(float((free_mem*1.0)/2**30))  # in GB\n        used_mem_l.append(float((used_mem*1.0)/2**30))  # in GB\n        total_mem_l.append(float((total_mem*1.0)/2**30))  # in GB\n\n        # validation\n        start_val = timeit.default_timer()\n     
   perf_metric_val = test(val_loader, neg_sampler, split_mode=\"val\")\n        end_val = timeit.default_timer()\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\n        print(f\"\\tValidation: Elapsed time (s): {end_val - start_val: .4f}\")\n        val_perf_list.append(perf_metric_val)\n        val_times_l.append(end_val - start_val)\n\n        # check for early stopping\n        if early_stopper.step_check(perf_metric_val, model):\n            break\n\n    train_val_time = timeit.default_timer() - start_train_val\n    print(f\"Train & Validation: Elapsed Time (s): {train_val_time: .4f}\")\n\n    # ==================================================== Test\n    # first, load the best model\n    early_stopper.load_checkpoint(model)\n\n    # loading the test negative samples\n    dataset.load_test_ns()\n\n    # final testing\n    start_test = timeit.default_timer()\n    perf_metric_test = test(test_loader, neg_sampler, split_mode=\"test\")\n\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\n    test_time = timeit.default_timer() - start_test\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\n    save_results({'data': DATA,\n                  'model': MODEL_NAME,\n                  'run': run_idx,\n                  'seed': SEED,\n                  'train_times': train_times_l,\n                  'free_mem': free_mem_l,\n                  'total_mem': total_mem_l,\n                  'used_mem': used_mem_l,\n                  'max_used_mem': max(used_mem_l),\n                  'val_times': val_times_l,\n                  f'val {metric}': val_perf_list,\n                  f'test {metric}': perf_metric_test,\n                  'test_time': test_time,\n                  'train_val_total_time': np.sum(np.array(train_times_l)) + np.sum(np.array(val_times_l)),\n                  }, \n    results_filename)\n\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: 
{timeit.default_timer() - start_run: .4f} <<<<<\")\n    print('-------------------------------------------------------------------------------')\n\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\nprint(\"==============================================================\")\n"
  },
  {
    "path": "examples/linkproppred/tgbl-lastfm/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluaiton\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n        \n        for idx, neg_batch in 
enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n\n    return perf_metrics\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = 
\"tgbl-lastfm\"\n\nMODEL_NAME = 'EdgeBank'\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n        hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {test_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading 
the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: perf_metric_test,\n              'test_time': test_time,\n              'tot_train_val_time': 'NA'\n              }, \n    results_filename)\n"
  },
  {
    "path": "examples/linkproppred/tgbl-lastfm/tgn.py",
    "content": "\"\"\"\nDynamic Link Prediction with a TGN model with Early Stopping\nReference: \n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n\ncommand for an example run:\n    python examples/linkproppred/tgbl-lastfm/tgn.py --data \"tgbl-lastfm\" --num_run 1 --seed 1\n\"\"\"\n\nimport math\nimport timeit\n\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\n\nimport torch\nfrom torch.nn import Linear\n\nfrom torch_geometric.loader import TemporalDataLoader\nfrom torch_geometric.nn import TransformerConv\n\n# internal imports\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.decoder import LinkPredictor\nfrom modules.emb_module import GraphAttentionEmbedding\nfrom modules.msg_func import IdentityMessage\nfrom modules.msg_agg import LastAggregator\nfrom modules.neighbor_loader import LastNeighborLoader\nfrom modules.memory_module import TGNMemory\nfrom modules.early_stopping import  EarlyStopMonitor\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\n\n\n# ==========\n# ========== Define helper function...\n# ==========\n\ndef train():\n    r\"\"\"\n    Training procedure for TGN model\n    This function uses some objects that are globally defined in the current scrips \n\n    Parameters:\n        None\n    Returns:\n        None\n            \n    \"\"\"\n\n    model['memory'].train()\n    model['gnn'].train()\n    model['link_pred'].train()\n\n    model['memory'].reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        # Sample negative destination nodes.\n        neg_dst = torch.randint(\n            min_dst_idx,\n            max_dst_idx + 1,\n        
    (src.size(0),),\n            dtype=torch.long,\n            device=device,\n        )\n\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n        # Get updated memory of all nodes involved in the computation.\n        z, last_update = model['memory'](n_id)\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\n\n        loss = criterion(pos_out, torch.ones_like(pos_out))\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\n\n        # Update memory and neighbor loader with ground-truth state.\n        model['memory'].update_state(src, pos_dst, t, msg)\n        neighbor_loader.insert(src, pos_dst)\n\n        loss.backward()\n        optimizer.step()\n        model['memory'].detach()\n        total_loss += float(loss.detach()) * batch.num_events\n\n    return total_loss / train_data.num_events\n\n\n@torch.no_grad()\ndef test(loader, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        loader: an object containing positive attributes of the positive edges of the evaluation set\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    model['memory'].eval()\n    model['gnn'].eval()\n    model['link_pred'].eval()\n\n    perf_list = []\n\n    for pos_batch in loader:\n        pos_src, pos_dst, pos_t, pos_msg = (\n            pos_batch.src,\n            pos_batch.dst,\n            pos_batch.t,\n            pos_batch.msg,\n        )\n\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n\n        for idx, neg_batch in enumerate(neg_batch_list):\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\n            dst = torch.tensor(\n                np.concatenate(\n                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\n                    axis=0,\n                ),\n                device=device,\n            )\n\n            n_id = torch.cat([src, dst]).unique()\n            n_id, edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n            # Get updated memory of all nodes involved in the computation.\n            z, last_update = model['memory'](n_id)\n            z = model['gnn'](\n                z,\n                last_update,\n                edge_index,\n                data.t[e_id].to(device),\n                data.msg[e_id].to(device),\n            )\n\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\n\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0, 
:].squeeze(dim=-1).cpu()]),\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n\n        # Update memory and neighbor loader with ground-truth state.\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)\n        neighbor_loader.insert(pos_src, pos_dst)\n\n    perf_metrics = float(torch.tensor(perf_list).mean())\n\n    return perf_metrics\n\n# ==========\n# ==========\n# ==========\n\n\n# Start...\nstart_overall = timeit.default_timer()\nDATA = \"tgbl-lastfm\"\n\n\n# ========== set parameters...\nargs, _ = get_args()\nargs.data = DATA\nprint(\"INFO: Arguments:\", args)\n\nLR = args.lr\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value  \nNUM_EPOCH = args.num_epoch\nSEED = args.seed\nMEM_DIM = args.mem_dim\nTIME_DIM = args.time_dim\nEMB_DIM = args.emb_dim\nTOLERANCE = args.tolerance\nPATIENCE = args.patience\nNUM_RUNS = args.num_run\nNUM_NEIGHBORS = 10\n\n\nMODEL_NAME = 'TGN'\n# ==========\n\n# set the device\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\n# data loading\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\ndata = data.to(device)\nmetric = dataset.eval_metric\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\n\nprint(\"==========================================================\")\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} 
***=============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nfor run_idx in range(NUM_RUNS):\n    print('-------------------------------------------------------------------------------')\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\n    start_run = timeit.default_timer()\n\n    # set the seed for deterministic results...\n    torch.manual_seed(run_idx + SEED)\n    set_random_seed(run_idx + SEED)\n\n    # define an early stopper\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\n    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \n                                    tolerance=TOLERANCE, patience=PATIENCE)\n    \n    # neighhorhood sampler\n    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\n\n    # define the model end-to-end\n    memory = TGNMemory(\n        data.num_nodes,\n        data.msg.size(-1),\n        MEM_DIM,\n        TIME_DIM,\n        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\n        aggregator_module=LastAggregator(),\n    ).to(device)\n\n    gnn = GraphAttentionEmbedding(\n        in_channels=MEM_DIM,\n        out_channels=EMB_DIM,\n        msg_dim=data.msg.size(-1),\n        time_enc=memory.time_enc,\n    ).to(device)\n\n    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\n\n    model = {'memory': memory,\n            'gnn': gnn,\n            'link_pred': link_pred}\n\n    optimizer = 
torch.optim.Adam(\n        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\n        lr=LR,\n    )\n    criterion = torch.nn.BCEWithLogitsLoss()\n\n    # Helper vector to map global node indices to local ones.\n    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\n\n    # ==================================================== Train & Validation\n    # loading the validation negative samples\n    dataset.load_val_ns()\n\n    val_perf_list = []\n    start_train_val = timeit.default_timer()\n    for epoch in range(1, NUM_EPOCH + 1):\n        # training\n        start_epoch_train = timeit.default_timer()\n        loss = train()\n        print(\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}\"\n        )\n\n        # validation\n        start_val = timeit.default_timer()\n        perf_metric_val = test(val_loader, neg_sampler, split_mode=\"val\")\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\n        print(f\"\\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}\")\n        val_perf_list.append(perf_metric_val)\n\n        # check for early stopping\n        if early_stopper.step_check(perf_metric_val, model):\n            break\n\n    train_val_time = timeit.default_timer() - start_train_val\n    print(f\"Train & Validation: Elapsed Time (s): {train_val_time: .4f}\")\n\n    # ==================================================== Test\n    # first, load the best model\n    early_stopper.load_checkpoint(model)\n\n    # loading the test negative samples\n    dataset.load_test_ns()\n\n    # final testing\n    start_test = timeit.default_timer()\n    perf_metric_test = test(test_loader, neg_sampler, split_mode=\"test\")\n\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\n    test_time = 
timeit.default_timer() - start_test\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\n    save_results({'model': MODEL_NAME,\n                  'data': DATA,\n                  'run': run_idx,\n                  'seed': SEED,\n                  f'val {metric}': val_perf_list,\n                  f'test {metric}': perf_metric_test,\n                  'test_time': test_time,\n                  'tot_train_val_time': train_val_time\n                  }, \n    results_filename)\n\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\n    print('-------------------------------------------------------------------------------')\n\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\nprint(\"==============================================================\")\n"
  },
  {
    "path": "examples/linkproppred/tgbl-review/dyrep.py",
    "content": "\"\"\"\nDyRep\n    This has been implemented with intuitions from the following sources:\n    - https://github.com/twitter-research/tgn\n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n\n    Spec.:\n        - Memory Updater: RNN\n        - Embedding Module: ID\n        - Message Function: ATTN\n\ncommand for an example run:\n    python examples/linkproppred/tgbl-review/dyrep.py --data \"tgbl-review\" --num_run 1 --seed 1\n\"\"\"\nimport math\nimport timeit\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\nimport torch\nfrom torch_geometric.loader import TemporalDataLoader\n\n# internal imports\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.decoder import LinkPredictor\nfrom modules.emb_module import GraphAttentionEmbedding\nfrom modules.msg_func import IdentityMessage\nfrom modules.msg_agg import LastAggregator\nfrom modules.neighbor_loader import LastNeighborLoader\nfrom modules.memory_module import DyRepMemory\nfrom modules.early_stopping import  EarlyStopMonitor\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\n\n\n\n# ==========\n# ========== Define helper function...\n# ==========\n\ndef train():\n    r\"\"\"\n    Training procedure for DyRep model\n    This function uses some objects that are globally defined in the current scrips \n\n    Parameters:\n        None\n    Returns:\n        None\n            \n    \"\"\"\n    model['memory'].train()\n    model['gnn'].train()\n    model['link_pred'].train()\n\n    model['memory'].reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        # Sample negative destination nodes.\n        
neg_dst = torch.randint(\n            min_dst_idx,\n            max_dst_idx + 1,\n            (src.size(0),),\n            dtype=torch.long,\n            device=device,\n        )\n\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n        # Get updated memory of all nodes involved in the computation.\n        z, last_update = model['memory'](n_id)\n\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\n\n        loss = criterion(pos_out, torch.ones_like(pos_out))\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\n\n        # update the memory with ground-truth\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n        model['memory'].update_state(src, pos_dst, t, msg, z, assoc)\n\n        # update neighbor loader\n        neighbor_loader.insert(src, pos_dst)\n\n        loss.backward()\n        optimizer.step()\n        model['memory'].detach()\n        total_loss += float(loss.detach()) * batch.num_events\n\n    return total_loss / train_data.num_events\n\n\n@torch.no_grad()\ndef test(loader, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        loader: an object containing positive attributes of the positive edges of the evaluation set\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    model['memory'].eval()\n    model['gnn'].eval()\n    model['link_pred'].eval()\n\n    perf_list = []\n\n    for pos_batch in loader:\n        pos_src, pos_dst, pos_t, pos_msg = (\n            pos_batch.src,\n            pos_batch.dst,\n            pos_batch.t,\n            pos_batch.msg,\n        )\n\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n\n        for idx, neg_batch in enumerate(neg_batch_list):\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\n            dst = torch.tensor(\n                np.concatenate(\n                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\n                    axis=0,\n                ),\n                device=device,\n            )\n\n            n_id = torch.cat([src, dst]).unique()\n            n_id, edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n            # Get updated memory of all nodes involved in the computation.\n            z, last_update = model['memory'](n_id)\n\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\n                \"eval_metric\": [metric],\n            }\n            
perf_list.append(evaluator.eval(input_dict)[metric])\n        \n        # update the memory with positive edges\n        n_id = torch.cat([pos_src, pos_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n        z, last_update = model['memory'](n_id)\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg, z, assoc)\n\n        # update the neighbor loader\n        neighbor_loader.insert(pos_src, pos_dst)\n\n    perf_metric = float(torch.tensor(perf_list).mean())\n\n    return perf_metric\n\n# ==========\n# ==========\n# ==========\n\n# Start...\nstart_overall = timeit.default_timer()\n\n# ========== set parameters...\nargs, _ = get_args()\nprint(\"INFO: Arguments:\", args)\n\nDATA = \"tgbl-review\"\nLR = args.lr\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value  \nNUM_EPOCH = args.num_epoch\nSEED = args.seed\nMEM_DIM = args.mem_dim\nTIME_DIM = args.time_dim\nEMB_DIM = args.emb_dim\nTOLERANCE = args.tolerance\nPATIENCE = args.patience\nNUM_RUNS = args.num_run\nNUM_NEIGHBORS = 10\n\nMODEL_NAME = 'DyRep'\nUSE_SRC_EMB_IN_MSG = False\nUSE_DST_EMB_IN_MSG = True\n# ==========\n\n# set the device\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# data loading\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\ndata = data.to(device)\nmetric = dataset.eval_metric\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\ntest_loader = TemporalDataLoader(test_data, 
batch_size=BATCH_SIZE)\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\n\n# neighborhood sampler\nneighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\n\n# define the model end-to-end\n# 1) memory\nmemory = DyRepMemory(\n    data.num_nodes,\n    data.msg.size(-1),\n    MEM_DIM,\n    TIME_DIM,\n    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\n    aggregator_module=LastAggregator(),\n    memory_updater_type='rnn',\n    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,\n    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG\n).to(device)\n\n# 2) GNN\ngnn = GraphAttentionEmbedding(\n    in_channels=MEM_DIM,\n    out_channels=EMB_DIM,\n    msg_dim=data.msg.size(-1),\n    time_enc=memory.time_enc,\n).to(device)\n\n# 3) link predictor\nlink_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\n\nmodel = {'memory': memory,\n         'gnn': gnn,\n         'link_pred': link_pred}\n\n# define an optimizer\noptimizer = torch.optim.Adam(\n    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\n    lr=LR,\n)\ncriterion = torch.nn.BCEWithLogitsLoss()\n\n# Helper vector to map global node indices to local ones.\nassoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\n\n\nprint(\"==========================================================\")\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = 
f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nfor run_idx in range(NUM_RUNS):\n    print('-------------------------------------------------------------------------------')\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\n    start_run = timeit.default_timer()\n\n    # set the seed for deterministic results...\n    torch.manual_seed(run_idx + SEED)\n    set_random_seed(run_idx + SEED)\n\n    # define an early stopper\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\n    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \n                                    tolerance=TOLERANCE, patience=PATIENCE)\n\n    # ==================================================== Train & Validation\n    # loading the validation negative samples\n    dataset.load_val_ns()\n\n    val_perf_list = []\n    start_train_val = timeit.default_timer()\n    for epoch in range(1, NUM_EPOCH + 1):\n        # training\n        start_epoch_train = timeit.default_timer()\n        loss = train()\n        print(\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}\"\n        )\n\n        # validation\n        start_val = timeit.default_timer()\n        perf_metric_val = test(val_loader, neg_sampler, split_mode=\"val\")\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\n        print(f\"\\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}\")\n        val_perf_list.append(perf_metric_val)\n\n        # check for early stopping\n        if early_stopper.step_check(perf_metric_val, model):\n            break\n\n    train_val_time = timeit.default_timer() - start_train_val\n    print(f\"Train & Validation Total Elapsed Time (s): {train_val_time: .4f}\")\n\n    # ==================================================== Test\n    # first, load the best model\n    
early_stopper.load_checkpoint(model)\n\n    # loading the test negative samples\n    dataset.load_test_ns()\n\n    # final testing\n    start_test = timeit.default_timer()\n    perf_metric_test = test(test_loader, neg_sampler, split_mode=\"test\")\n\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\n    test_time = timeit.default_timer() - start_test\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\n    save_results({'model': MODEL_NAME,\n                  'data': DATA,\n                  'run': run_idx,\n                  'seed': SEED,\n                  f'val {metric}': val_perf_list,\n                  f'test {metric}': perf_metric_test,\n                  'test_time': test_time,\n                  'tot_train_val_time': train_val_time\n                  }, \n    results_filename)\n\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\n    print('-------------------------------------------------------------------------------')\n\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\nprint(\"==============================================================\")\n"
  },
  {
    "path": "examples/linkproppred/tgbl-review/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n        \n        for idx, neg_batch in 
enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n\n    return perf_metrics\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-review')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = 
\"tgbl-review\"\n\nMODEL_NAME = 'EdgeBank'\n\n\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n        hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {test_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# 
loading the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: perf_metric_test,\n              'test_time': test_time,\n              'tot_train_val_time': 'NA'\n              }, \n    results_filename)\n"
  },
  {
    "path": "examples/linkproppred/tgbl-review/tgn.py",
    "content": "\"\"\"\nDynamic Link Prediction with a TGN model with Early Stopping\nReference: \n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n\ncommand for an example run:\n    python examples/linkproppred/tgbl-review/tgn.py --data \"tgbl-review\" --num_run 1 --seed 1\n\"\"\"\n\nimport math\nimport timeit\n\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\n\nimport torch\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch.nn import Linear\n\nfrom torch_geometric.datasets import JODIEDataset\nfrom torch_geometric.loader import TemporalDataLoader\n\nfrom torch_geometric.nn import TransformerConv\n\n# internal imports\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.decoder import LinkPredictor\nfrom modules.emb_module import GraphAttentionEmbedding\nfrom modules.msg_func import IdentityMessage\nfrom modules.msg_agg import LastAggregator\nfrom modules.neighbor_loader import LastNeighborLoader\nfrom modules.memory_module import TGNMemory\nfrom modules.early_stopping import  EarlyStopMonitor\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\n\n\n# ==========\n# ========== Define helper function...\n# ==========\n\ndef train():\n    r\"\"\"\n    Training procedure for TGN model\n    This function uses some objects that are globally defined in the current scrips \n\n    Parameters:\n        None\n    Returns:\n        None\n            \n    \"\"\"\n\n    model['memory'].train()\n    model['gnn'].train()\n    model['link_pred'].train()\n\n    model['memory'].reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        # Sample 
negative destination nodes.\n        neg_dst = torch.randint(\n            min_dst_idx,\n            max_dst_idx + 1,\n            (src.size(0),),\n            dtype=torch.long,\n            device=device,\n        )\n\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n        # Get updated memory of all nodes involved in the computation.\n        z, last_update = model['memory'](n_id)\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\n\n        loss = criterion(pos_out, torch.ones_like(pos_out))\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\n\n        # Update memory and neighbor loader with ground-truth state.\n        model['memory'].update_state(src, pos_dst, t, msg)\n        neighbor_loader.insert(src, pos_dst)\n\n        loss.backward()\n        optimizer.step()\n        model['memory'].detach()\n        total_loss += float(loss.detach()) * batch.num_events\n\n    return total_loss / train_data.num_events\n\n\n@torch.no_grad()\ndef test(loader, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        loader: an object containing positive attributes of the positive edges of the evaluation set\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    model['memory'].eval()\n    model['gnn'].eval()\n    model['link_pred'].eval()\n\n    perf_list = []\n\n    for pos_batch in loader:\n        pos_src, pos_dst, pos_t, pos_msg = (\n            pos_batch.src,\n            pos_batch.dst,\n            pos_batch.t,\n            pos_batch.msg,\n        )\n\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n\n        for idx, neg_batch in enumerate(neg_batch_list):\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\n            dst = torch.tensor(\n                np.concatenate(\n                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\n                    axis=0,\n                ),\n                device=device,\n            )\n\n            n_id = torch.cat([src, dst]).unique()\n            n_id, edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n            # Get updated memory of all nodes involved in the computation.\n            z, last_update = model['memory'](n_id)\n            z = model['gnn'](\n                z,\n                last_update,\n                edge_index,\n                data.t[e_id].to(device),\n                data.msg[e_id].to(device),\n            )\n\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\n\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0, 
:].squeeze(dim=-1).cpu()]),\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n\n        # Update memory and neighbor loader with ground-truth state.\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)\n        neighbor_loader.insert(pos_src, pos_dst)\n\n    perf_metrics = float(torch.tensor(perf_list).mean())\n\n    return perf_metrics\n\n# ==========\n# ==========\n# ==========\n\n\n# Start...\nstart_overall = timeit.default_timer()\nDATA = \"tgbl-review\"\n\n\n# ========== set parameters...\nargs, _ = get_args()\nargs.data = DATA\nprint(\"INFO: Arguments:\", args)\n\nLR = args.lr\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value  \nNUM_EPOCH = args.num_epoch\nSEED = args.seed\nMEM_DIM = args.mem_dim\nTIME_DIM = args.time_dim\nEMB_DIM = args.emb_dim\nTOLERANCE = args.tolerance\nPATIENCE = args.patience\nNUM_RUNS = args.num_run\nNUM_NEIGHBORS = 10\n\n\nMODEL_NAME = 'TGN'\n# ==========\n\n# set the device\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# data loading\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\ndata = data.to(device)\nmetric = dataset.eval_metric\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\n\nprint(\"==========================================================\")\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} 
***=============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nfor run_idx in range(NUM_RUNS):\n    print('-------------------------------------------------------------------------------')\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\n    start_run = timeit.default_timer()\n\n    # set the seed for deterministic results...\n    torch.manual_seed(run_idx + SEED)\n    set_random_seed(run_idx + SEED)\n\n\n    # neighborhood sampler\n    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\n\n    # define the model end-to-end\n    memory = TGNMemory(\n        data.num_nodes,\n        data.msg.size(-1),\n        MEM_DIM,\n        TIME_DIM,\n        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\n        aggregator_module=LastAggregator(),\n    ).to(device)\n\n    gnn = GraphAttentionEmbedding(\n        in_channels=MEM_DIM,\n        out_channels=EMB_DIM,\n        msg_dim=data.msg.size(-1),\n        time_enc=memory.time_enc,\n    ).to(device)\n\n    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\n\n    model = {'memory': memory,\n            'gnn': gnn,\n            'link_pred': link_pred}\n\n    optimizer = torch.optim.Adam(\n        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\n        lr=LR,\n    )\n    criterion = torch.nn.BCEWithLogitsLoss()\n\n    # Helper vector to map global node indices to local ones.\n    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\n\n\n   
 # define an early stopper\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\n    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \n                                    tolerance=TOLERANCE, patience=PATIENCE)\n\n    # ==================================================== Train & Validation\n    # loading the validation negative samples\n    dataset.load_val_ns()\n\n    val_perf_list = []\n    start_train_val = timeit.default_timer()\n    for epoch in range(1, NUM_EPOCH + 1):\n        # training\n        start_epoch_train = timeit.default_timer()\n        loss = train()\n        print(\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}\"\n        )\n\n        # validation\n        start_val = timeit.default_timer()\n        perf_metric_val = test(val_loader, neg_sampler, split_mode=\"val\")\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\n        print(f\"\\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}\")\n        val_perf_list.append(perf_metric_val)\n\n        # check for early stopping\n        if early_stopper.step_check(perf_metric_val, model):\n            break\n\n    train_val_time = timeit.default_timer() - start_train_val\n    print(f\"Train & Validation: Elapsed Time (s): {train_val_time: .4f}\")\n\n    # ==================================================== Test\n    # first, load the best model\n    early_stopper.load_checkpoint(model)\n\n    # loading the test negative samples\n    dataset.load_test_ns()\n\n    # final testing\n    start_test = timeit.default_timer()\n    perf_metric_test = test(test_loader, neg_sampler, split_mode=\"test\")\n\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\n    test_time = 
timeit.default_timer() - start_test\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\n    save_results({'model': MODEL_NAME,\n                  'data': DATA,\n                  'run': run_idx,\n                  'seed': SEED,\n                  f'val {metric}': val_perf_list,\n                  f'test {metric}': perf_metric_test,\n                  'test_time': test_time,\n                  'tot_train_val_time': train_val_time\n                  }, \n    results_filename)\n\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\n    print('-------------------------------------------------------------------------------')\n\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\nprint(\"==============================================================\")\n"
  },
  {
    "path": "examples/linkproppred/tgbl-subreddit/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluaiton\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n        \n        for idx, neg_batch in 
enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n\n    return perf_metrics\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = 
\"tgbl-subreddit\"\n\nMODEL_NAME = 'EdgeBank'\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n        hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {test_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# 
loading the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: perf_metric_test,\n              'test_time': test_time,\n              'tot_train_val_time': 'NA'\n              }, \n    results_filename)\n"
  },
  {
    "path": "examples/linkproppred/tgbl-subreddit/tgn.py",
    "content": "\"\"\"\nDynamic Link Prediction with a TGN model with Early Stopping\nReference: \n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n\ncommand for an example run:\n    python examples/linkproppred/tgbl-subreddit/tgn.py --data \"tgbl-subreddit\" --num_run 1 --seed 1\n\"\"\"\n\nimport math\nimport timeit\n\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\n\nimport torch\nfrom torch.nn import Linear\n\nfrom torch_geometric.datasets import JODIEDataset\nfrom torch_geometric.loader import TemporalDataLoader\nfrom torch_geometric.nn import TransformerConv\n\n# internal imports\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.decoder import LinkPredictor\nfrom modules.emb_module import GraphAttentionEmbedding\nfrom modules.msg_func import IdentityMessage\nfrom modules.msg_agg import LastAggregator\nfrom modules.neighbor_loader import LastNeighborLoader\nfrom modules.memory_module import TGNMemory\nfrom modules.early_stopping import  EarlyStopMonitor\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\n\n\n# ==========\n# ========== Define helper function...\n# ==========\n\ndef train():\n    r\"\"\"\n    Training procedure for TGN model\n    This function uses some objects that are globally defined in the current scrips \n\n    Parameters:\n        None\n    Returns:\n        None\n            \n    \"\"\"\n\n    model['memory'].train()\n    model['gnn'].train()\n    model['link_pred'].train()\n\n    model['memory'].reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        # Sample negative destination nodes.\n        neg_dst = torch.randint(\n       
     min_dst_idx,\n            max_dst_idx + 1,\n            (src.size(0),),\n            dtype=torch.long,\n            device=device,\n        )\n\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n        # Get updated memory of all nodes involved in the computation.\n        z, last_update = model['memory'](n_id)\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\n\n        loss = criterion(pos_out, torch.ones_like(pos_out))\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\n\n        # Update memory and neighbor loader with ground-truth state.\n        model['memory'].update_state(src, pos_dst, t, msg)\n        neighbor_loader.insert(src, pos_dst)\n\n        loss.backward()\n        optimizer.step()\n        model['memory'].detach()\n        total_loss += float(loss.detach()) * batch.num_events\n\n    return total_loss / train_data.num_events\n\n\n@torch.no_grad()\ndef test(loader, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        loader: an object containing positive attributes of the positive edges of the evaluation set\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluaiton\n    \"\"\"\n    model['memory'].eval()\n    model['gnn'].eval()\n    model['link_pred'].eval()\n\n    perf_list = []\n\n    for pos_batch in loader:\n        pos_src, pos_dst, pos_t, pos_msg = (\n            pos_batch.src,\n            pos_batch.dst,\n            pos_batch.t,\n            pos_batch.msg,\n        )\n\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n\n        for idx, neg_batch in enumerate(neg_batch_list):\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\n            dst = torch.tensor(\n                np.concatenate(\n                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\n                    axis=0,\n                ),\n                device=device,\n            )\n\n            n_id = torch.cat([src, dst]).unique()\n            n_id, edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n            # Get updated memory of all nodes involved in the computation.\n            z, last_update = model['memory'](n_id)\n            z = model['gnn'](\n                z,\n                last_update,\n                edge_index,\n                data.t[e_id].to(device),\n                data.msg[e_id].to(device),\n            )\n\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\n\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0, 
:].squeeze(dim=-1).cpu()]),\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n\n        # Update memory and neighbor loader with ground-truth state.\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)\n        neighbor_loader.insert(pos_src, pos_dst)\n\n    perf_metrics = float(torch.tensor(perf_list).mean())\n\n    return perf_metrics\n\n# ==========\n# ==========\n# ==========\n\n\n# Start...\nstart_overall = timeit.default_timer()\nDATA = \"tgbl-subreddit\"\n\n# ========== set parameters...\nargs, _ = get_args()\nargs.data = DATA\nprint(\"INFO: Arguments:\", args)\n\nLR = args.lr\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value  \nNUM_EPOCH = args.num_epoch\nSEED = args.seed\nMEM_DIM = args.mem_dim\nTIME_DIM = args.time_dim\nEMB_DIM = args.emb_dim\nTOLERANCE = args.tolerance\nPATIENCE = args.patience\nNUM_RUNS = args.num_run\nNUM_NEIGHBORS = 10\n\n\nMODEL_NAME = 'TGN'\n# ==========\n\n# set the device\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\n# data loading\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\ndata = data.to(device)\nmetric = dataset.eval_metric\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\n\nprint(\"==========================================================\")\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: 
{DATA} ***=============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nfor run_idx in range(NUM_RUNS):\n    print('-------------------------------------------------------------------------------')\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\n    start_run = timeit.default_timer()\n\n    # set the seed for deterministic results...\n    torch.manual_seed(run_idx + SEED)\n    set_random_seed(run_idx + SEED)\n\n    # neighhorhood sampler\n    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\n\n    # define the model end-to-end\n    memory = TGNMemory(\n        data.num_nodes,\n        data.msg.size(-1),\n        MEM_DIM,\n        TIME_DIM,\n        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\n        aggregator_module=LastAggregator(),\n    ).to(device)\n\n    gnn = GraphAttentionEmbedding(\n        in_channels=MEM_DIM,\n        out_channels=EMB_DIM,\n        msg_dim=data.msg.size(-1),\n        time_enc=memory.time_enc,\n    ).to(device)\n\n    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\n\n    model = {'memory': memory,\n            'gnn': gnn,\n            'link_pred': link_pred}\n\n    optimizer = torch.optim.Adam(\n        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\n        lr=LR,\n    )\n    criterion = torch.nn.BCEWithLogitsLoss()\n\n    # Helper vector to map global node indices to local ones.\n    assoc = torch.empty(data.num_nodes, dtype=torch.long, 
device=device)\n\n\n    # define an early stopper\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\n    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \n                                    tolerance=TOLERANCE, patience=PATIENCE)\n\n    # ==================================================== Train & Validation\n    # loading the validation negative samples\n    dataset.load_val_ns()\n\n    val_perf_list = []\n    start_train_val = timeit.default_timer()\n    for epoch in range(1, NUM_EPOCH + 1):\n        # training\n        start_epoch_train = timeit.default_timer()\n        loss = train()\n        print(\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}\"\n        )\n\n        # validation\n        start_val = timeit.default_timer()\n        perf_metric_val = test(val_loader, neg_sampler, split_mode=\"val\")\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\n        print(f\"\\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}\")\n        val_perf_list.append(perf_metric_val)\n\n        # check for early stopping\n        if early_stopper.step_check(perf_metric_val, model):\n            break\n\n    train_val_time = timeit.default_timer() - start_train_val\n    print(f\"Train & Validation: Elapsed Time (s): {train_val_time: .4f}\")\n\n    # ==================================================== Test\n    # first, load the best model\n    early_stopper.load_checkpoint(model)\n\n    # loading the test negative samples\n    dataset.load_test_ns()\n\n    # final testing\n    start_test = timeit.default_timer()\n    perf_metric_test = test(test_loader, neg_sampler, split_mode=\"test\")\n\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\n    
test_time = timeit.default_timer() - start_test\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\n    save_results({'model': MODEL_NAME,\n                  'data': DATA,\n                  'run': run_idx,\n                  'seed': SEED,\n                  f'val {metric}': val_perf_list,\n                  f'test {metric}': perf_metric_test,\n                  'test_time': test_time,\n                  'tot_train_val_time': train_val_time\n                  }, \n    results_filename)\n\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\n    print('-------------------------------------------------------------------------------')\n\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\nprint(\"==============================================================\")\n"
  },
  {
    "path": "examples/linkproppred/tgbl-uci/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluaiton\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n        \n        for idx, neg_batch in 
enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n\n    return perf_metrics\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = 
\"tgbl-uci\"\n\nMODEL_NAME = 'EdgeBank'\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n        hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {test_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading the 
test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: perf_metric_test,\n              'test_time': test_time,\n              'tot_train_val_time': 'NA'\n              }, \n    results_filename)\n"
  },
  {
    "path": "examples/linkproppred/tgbl-wiki/dyrep.py",
    "content": "\"\"\"\nDyRep\n    This has been implemented with intuitions from the following sources:\n    - https://github.com/twitter-research/tgn\n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n\n    Spec.:\n        - Memory Updater: RNN\n        - Embedding Module: ID\n        - Message Function: ATTN\n\ncommand for an example run:\n    python examples/linkproppred/tgbl-wiki/dyrep.py --data \"tgbl-wiki\" --num_run 1 --seed 1\n\"\"\"\nimport math\nimport timeit\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\nimport torch\nfrom torch_geometric.loader import TemporalDataLoader\n\n# internal imports\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.decoder import LinkPredictor\nfrom modules.emb_module import GraphAttentionEmbedding\nfrom modules.msg_func import IdentityMessage\nfrom modules.msg_agg import LastAggregator\nfrom modules.neighbor_loader import LastNeighborLoader\nfrom modules.memory_module import DyRepMemory\nfrom modules.early_stopping import  EarlyStopMonitor\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\n\n\n\n# ==========\n# ========== Define helper function...\n# ==========\n\ndef train():\n    r\"\"\"\n    Training procedure for DyRep model\n    This function uses some objects that are globally defined in the current scrips \n\n    Parameters:\n        None\n    Returns:\n        None\n            \n    \"\"\"\n    model['memory'].train()\n    model['gnn'].train()\n    model['link_pred'].train()\n\n    model['memory'].reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        # Sample negative destination nodes.\n        neg_dst 
= torch.randint(\n            min_dst_idx,\n            max_dst_idx + 1,\n            (src.size(0),),\n            dtype=torch.long,\n            device=device,\n        )\n\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n        # Get updated memory of all nodes involved in the computation.\n        z, last_update = model['memory'](n_id)\n\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\n\n        loss = criterion(pos_out, torch.ones_like(pos_out))\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\n\n        # update the memory with ground-truth\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n        model['memory'].update_state(src, pos_dst, t, msg, z, assoc)\n\n        # update neighbor loader\n        neighbor_loader.insert(src, pos_dst)\n\n        loss.backward()\n        optimizer.step()\n        model['memory'].detach()\n        total_loss += float(loss.detach()) * batch.num_events\n\n    return total_loss / train_data.num_events\n\n\n@torch.no_grad()\ndef test(loader, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        loader: an object containing positive attributes of the positive edges of the evaluation set\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    model['memory'].eval()\n    model['gnn'].eval()\n    model['link_pred'].eval()\n\n    perf_list = []\n\n    for pos_batch in loader:\n        pos_src, pos_dst, pos_t, pos_msg = (\n            pos_batch.src,\n            pos_batch.dst,\n            pos_batch.t,\n            pos_batch.msg,\n        )\n\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n\n        for idx, neg_batch in enumerate(neg_batch_list):\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\n            dst = torch.tensor(\n                np.concatenate(\n                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\n                    axis=0,\n                ),\n                device=device,\n            )\n\n            n_id = torch.cat([src, dst]).unique()\n            n_id, edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n            # Get updated memory of all nodes involved in the computation.\n            z, last_update = model['memory'](n_id)\n\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\n                \"eval_metric\": [metric],\n            }\n            
perf_list.append(evaluator.eval(input_dict)[metric])\n        \n        # update the memory with positive edges\n        n_id = torch.cat([pos_src, pos_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n        z, last_update = model['memory'](n_id)\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg, z, assoc)\n\n        # update the neighbor loader\n        neighbor_loader.insert(pos_src, pos_dst)\n\n    perf_metric = float(torch.tensor(perf_list).mean())\n\n    return perf_metric\n\n# ==========\n# ==========\n# ==========\n\n# Start...\nstart_overall = timeit.default_timer()\n\n# ========== set parameters...\nargs, _ = get_args()\nprint(\"INFO: Arguments:\", args)\n\nDATA = \"tgbl-wiki\"\nLR = args.lr\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value  \nNUM_EPOCH = args.num_epoch\nSEED = args.seed\nMEM_DIM = args.mem_dim\nTIME_DIM = args.time_dim\nEMB_DIM = args.emb_dim\nTOLERANCE = args.tolerance\nPATIENCE = args.patience\nNUM_RUNS = args.num_run\nNUM_NEIGHBORS = 10\n\nMODEL_NAME = 'DyRep'\nUSE_SRC_EMB_IN_MSG = False\nUSE_DST_EMB_IN_MSG = True\n# ==========\n\n# set the device\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# data loading\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\ndata = data.to(device)\nmetric = dataset.eval_metric\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\ntest_loader = TemporalDataLoader(test_data, 
batch_size=BATCH_SIZE)\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\n\n# neighborhood sampler\nneighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\n\n# define the model end-to-end\n# 1) memory\nmemory = DyRepMemory(\n    data.num_nodes,\n    data.msg.size(-1),\n    MEM_DIM,\n    TIME_DIM,\n    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\n    aggregator_module=LastAggregator(),\n    memory_updater_type='rnn',\n    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,\n    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG\n).to(device)\n\n# 2) GNN\ngnn = GraphAttentionEmbedding(\n    in_channels=MEM_DIM,\n    out_channels=EMB_DIM,\n    msg_dim=data.msg.size(-1),\n    time_enc=memory.time_enc,\n).to(device)\n\n# 3) link predictor\nlink_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\n\nmodel = {'memory': memory,\n         'gnn': gnn,\n         'link_pred': link_pred}\n\n# define an optimizer\noptimizer = torch.optim.Adam(\n    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\n    lr=LR,\n)\ncriterion = torch.nn.BCEWithLogitsLoss()\n\n# Helper vector to map global node indices to local ones.\nassoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\n\n\nprint(\"==========================================================\")\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = 
f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nfor run_idx in range(NUM_RUNS):\n    print('-------------------------------------------------------------------------------')\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\n    start_run = timeit.default_timer()\n\n    # set the seed for deterministic results...\n    torch.manual_seed(run_idx + SEED)\n    set_random_seed(run_idx + SEED)\n\n    # define an early stopper\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\n    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \n                                    tolerance=TOLERANCE, patience=PATIENCE)\n\n    # ==================================================== Train & Validation\n    # loading the validation negative samples\n    dataset.load_val_ns()\n\n    val_perf_list = []\n    start_train_val = timeit.default_timer()\n    for epoch in range(1, NUM_EPOCH + 1):\n        # training\n        start_epoch_train = timeit.default_timer()\n        loss = train()\n        print(\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}\"\n        )\n\n        # validation\n        start_val = timeit.default_timer()\n        perf_metric_val = test(val_loader, neg_sampler, split_mode=\"val\")\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\n        print(f\"\\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}\")\n        val_perf_list.append(perf_metric_val)\n\n        # check for early stopping\n        if early_stopper.step_check(perf_metric_val, model):\n            break\n\n    train_val_time = timeit.default_timer() - start_train_val\n    print(f\"Train & Validation Total Elapsed Time (s): {train_val_time: .4f}\")\n\n    # ==================================================== Test\n    # first, load the best model\n    
early_stopper.load_checkpoint(model)\n\n    # loading the test negative samples\n    dataset.load_test_ns()\n\n    # final testing\n    start_test = timeit.default_timer()\n    perf_metric_test = test(test_loader, neg_sampler, split_mode=\"test\")\n\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\n    test_time = timeit.default_timer() - start_test\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\n    save_results({'model': MODEL_NAME,\n                  'data': DATA,\n                  'run': run_idx,\n                  'seed': SEED,\n                  f'val {metric}': val_perf_list,\n                  f'test {metric}': perf_metric_test,\n                  'test_time': test_time,\n                  'tot_train_val_time': train_val_time\n                  }, \n    results_filename)\n\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\n    print('-------------------------------------------------------------------------------')\n\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\nprint(\"==============================================================\")\n"
  },
  {
    "path": "examples/linkproppred/tgbl-wiki/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n        \n        for idx, neg_batch in 
enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n\n    return perf_metrics\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = 
\"tgbl-wiki\"\n\nMODEL_NAME = 'EdgeBank'\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n        hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {test_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading 
the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: perf_metric_test,\n              'test_time': test_time,\n              'tot_train_val_time': 'NA'\n              }, \n    results_filename)\n"
  },
  {
    "path": "examples/linkproppred/tgbl-wiki/tgn.py",
    "content": "\"\"\"\nDynamic Link Prediction with a TGN model with Early Stopping\nReference: \n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n\ncommand for an example run:\n    python examples/linkproppred/tgbl-wiki/tgn.py --data \"tgbl-wiki\" --num_run 1 --seed 1\n\"\"\"\n\nimport math\nimport timeit\n\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\n\nimport torch\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch.nn import Linear\n\nfrom torch_geometric.datasets import JODIEDataset\nfrom torch_geometric.loader import TemporalDataLoader\nfrom torch_geometric.nn import TransformerConv\n\n# internal imports\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.decoder import LinkPredictor\nfrom modules.emb_module import GraphAttentionEmbedding\nfrom modules.msg_func import IdentityMessage\nfrom modules.msg_agg import LastAggregator\nfrom modules.neighbor_loader import LastNeighborLoader\nfrom modules.memory_module import TGNMemory\nfrom modules.early_stopping import  EarlyStopMonitor\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\n\n\n# ==========\n# ========== Define helper function...\n# ==========\n\ndef train():\n    r\"\"\"\n    Training procedure for TGN model\n    This function uses some objects that are globally defined in the current scrips \n\n    Parameters:\n        None\n    Returns:\n        None\n            \n    \"\"\"\n\n    model['memory'].train()\n    model['gnn'].train()\n    model['link_pred'].train()\n\n    model['memory'].reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        # Sample negative 
destination nodes.\n        neg_dst = torch.randint(\n            min_dst_idx,\n            max_dst_idx + 1,\n            (src.size(0),),\n            dtype=torch.long,\n            device=device,\n        )\n\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\n        n_id, edge_index, e_id = neighbor_loader(n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n        # Get updated memory of all nodes involved in the computation.\n        z, last_update = model['memory'](n_id)\n        z = model['gnn'](\n            z,\n            last_update,\n            edge_index,\n            data.t[e_id].to(device),\n            data.msg[e_id].to(device),\n        )\n\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\n\n        loss = criterion(pos_out, torch.ones_like(pos_out))\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\n\n        # Update memory and neighbor loader with ground-truth state.\n        model['memory'].update_state(src, pos_dst, t, msg)\n        neighbor_loader.insert(src, pos_dst)\n\n        loss.backward()\n        optimizer.step()\n        model['memory'].detach()\n        total_loss += float(loss.detach()) * batch.num_events\n\n    return total_loss / train_data.num_events\n\n\n@torch.no_grad()\ndef test(loader, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        loader: an object containing positive attributes of the positive edges of the evaluation set\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    model['memory'].eval()\n    model['gnn'].eval()\n    model['link_pred'].eval()\n\n    perf_list = []\n\n    for pos_batch in loader:\n        pos_src, pos_dst, pos_t, pos_msg = (\n            pos_batch.src,\n            pos_batch.dst,\n            pos_batch.t,\n            pos_batch.msg,\n        )\n\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)\n\n        for idx, neg_batch in enumerate(neg_batch_list):\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\n            dst = torch.tensor(\n                np.concatenate(\n                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\n                    axis=0,\n                ),\n                device=device,\n            )\n\n            n_id = torch.cat([src, dst]).unique()\n            n_id, edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n            # Get updated memory of all nodes involved in the computation.\n            z, last_update = model['memory'](n_id)\n            z = model['gnn'](\n                z,\n                last_update,\n                edge_index,\n                data.t[e_id].to(device),\n                data.msg[e_id].to(device),\n            )\n\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\n\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0, 
:].squeeze(dim=-1).cpu()]),\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\n                \"eval_metric\": [metric],\n            }\n            perf_list.append(evaluator.eval(input_dict)[metric])\n\n        # Update memory and neighbor loader with ground-truth state.\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)\n        neighbor_loader.insert(pos_src, pos_dst)\n\n    perf_metrics = float(torch.tensor(perf_list).mean())\n\n    return perf_metrics\n\n# ==========\n# ==========\n# ==========\n\n\n# Start...\nstart_overall = timeit.default_timer()\nDATA = \"tgbl-wiki\"\n\n# ========== set parameters...\nargs, _ = get_args()\nargs.data = DATA\nprint(\"INFO: Arguments:\", args)\n\nLR = args.lr\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value  \nNUM_EPOCH = args.num_epoch\nSEED = args.seed\nMEM_DIM = args.mem_dim\nTIME_DIM = args.time_dim\nEMB_DIM = args.emb_dim\nTOLERANCE = args.tolerance\nPATIENCE = args.patience\nNUM_RUNS = args.num_run\nNUM_NEIGHBORS = 10\n\n\nMODEL_NAME = 'TGN'\n# ==========\n\n# set the device\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# data loading\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\ndata = data.to(device)\nmetric = dataset.eval_metric\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\n\nprint(\"==========================================================\")\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} 
***=============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nfor run_idx in range(NUM_RUNS):\n    print('-------------------------------------------------------------------------------')\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\n    start_run = timeit.default_timer()\n\n    # set the seed for deterministic results...\n    torch.manual_seed(run_idx + SEED)\n    set_random_seed(run_idx + SEED)\n\n\n    # neighborhood sampler\n    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\n\n    # define the model end-to-end\n    memory = TGNMemory(\n        data.num_nodes,\n        data.msg.size(-1),\n        MEM_DIM,\n        TIME_DIM,\n        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\n        aggregator_module=LastAggregator(),\n    ).to(device)\n\n    gnn = GraphAttentionEmbedding(\n        in_channels=MEM_DIM,\n        out_channels=EMB_DIM,\n        msg_dim=data.msg.size(-1),\n        time_enc=memory.time_enc,\n    ).to(device)\n\n    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\n\n    model = {'memory': memory,\n            'gnn': gnn,\n            'link_pred': link_pred}\n\n    optimizer = torch.optim.Adam(\n        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\n        lr=LR,\n    )\n    criterion = torch.nn.BCEWithLogitsLoss()\n\n    # Helper vector to map global node indices to local ones.\n    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\n\n    
# define an early stopper\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\n    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \n                                    tolerance=TOLERANCE, patience=PATIENCE)\n\n    # ==================================================== Train & Validation\n    # loading the validation negative samples\n    dataset.load_val_ns()\n\n    val_perf_list = []\n    start_train_val = timeit.default_timer()\n    for epoch in range(1, NUM_EPOCH + 1):\n        # training\n        start_epoch_train = timeit.default_timer()\n        loss = train()\n        print(\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}\"\n        )\n\n        # validation\n        start_val = timeit.default_timer()\n        perf_metric_val = test(val_loader, neg_sampler, split_mode=\"val\")\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\n        print(f\"\\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}\")\n        val_perf_list.append(perf_metric_val)\n\n        # check for early stopping\n        if early_stopper.step_check(perf_metric_val, model):\n            break\n\n    train_val_time = timeit.default_timer() - start_train_val\n    print(f\"Train & Validation: Elapsed Time (s): {train_val_time: .4f}\")\n\n    # ==================================================== Test\n    # first, load the best model\n    early_stopper.load_checkpoint(model)\n\n    # loading the test negative samples\n    dataset.load_test_ns()\n\n    # final testing\n    start_test = timeit.default_timer()\n    perf_metric_test = test(test_loader, neg_sampler, split_mode=\"test\")\n\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\n    test_time = 
timeit.default_timer() - start_test\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\n    save_results({'model': MODEL_NAME,\n                  'data': DATA,\n                  'run': run_idx,\n                  'seed': SEED,\n                  f'val {metric}': val_perf_list,\n                  f'test {metric}': perf_metric_test,\n                  'test_time': test_time,\n                  'tot_train_val_time': train_val_time\n                  }, \n    results_filename)\n\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\n    print('-------------------------------------------------------------------------------')\n\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\nprint(\"==============================================================\")\n"
  },
  {
    "path": "examples/linkproppred/thgl-forum/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\ntgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(tgb_modules_path)\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    hits_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t, pos_edge = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n            data['edge_type'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)\n        \n        for idx, neg_batch in enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            results = evaluator.eval(input_dict)\n            perf_list.append(results[metric])\n            hits_list.append(results['hits@10'])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, 
pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n    perf_hits = float(np.mean(hits_list))\n\n    return perf_metrics, perf_hits\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='thgl-forum')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = args.data\nMODEL_NAME = 'EdgeBank'\n\n\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n        
hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_val: .4f}\")\nval_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {val_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>>  <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: 
perf_metric_test,\n              'val_mrr': perf_metric_val,\n              'test_time': test_time,\n              'tot_train_val_time': test_time+val_time,\n              'hits10': perf_hits_test}, \n    results_filename)\n\n"
  },
  {
    "path": "examples/linkproppred/thgl-forum/recurrencybaseline.py",
    "content": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)\n\n@inproceedings{gastinger2024baselines,\n  title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},\n  author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},\n  booktitle={33nd International Joint Conference on Artificial Intelligence (IJCAI 2024)},\n  year={2024},\n  organization={International Joint Conferences on Artificial Intelligence Organization}\n}\n\n\"\"\"\n\n## imports\nimport timeit\nimport argparse\nimport numpy as np\nfrom copy import copy\nfrom pathlib import Path\nimport ray\nimport sys\nimport os\nimport os.path as osp\nimport json\n#internal imports \nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom tgb.utils.utils import set_random_seed,  save_results \nfrom modules.tkg_utils import create_basis_dict, group_by, reformat_ts\n\ndef predict(num_processes,  data_c_rel, all_data_c_rel, alpha, lmbda_psi,\n            perf_list_all, hits_list_all, window, neg_sampler, split_mode):\n    \"\"\" create predictions for each relation on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list of hits for each test query\n    \"\"\"\n\n    first_ts = data_c_rel[0][3]\n    ## use this if you wanna use ray:\n    num_queries = len(data_c_rel) // num_processes\n    if num_queries < num_processes: # if we do not have enough queries for all the processes\n        num_processes_tmp = 
1\n        num_queries = len(data_c_rel)\n    else:\n        num_processes_tmp = num_processes  \n    if num_processes > 1:\n        object_references =[]                   \n        \n        for i in range(num_processes_tmp):\n            num_test_queries = len(data_c_rel) - (i + 1) * num_queries\n            if num_test_queries >= num_queries:\n                test_queries_idx =[i * num_queries, (i + 1) * num_queries]\n            else:\n                test_queries_idx = [i * num_queries, len(test_data)]\n\n            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]\n\n            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window, \n                                basis_dict, \n                                num_nodes, num_rels, lmbda_psi, \n                                alpha, evaluator,first_ts, neg_sampler, split_mode)\n            object_references.append(ob)\n\n        output = ray.get(object_references)\n\n        # updates the scores and logging dict for each process\n        for proc_loop in range(num_processes_tmp):\n            perf_list_all.extend(output[proc_loop][0])\n            hits_list_all.extend(output[proc_loop][1])\n\n    else:\n        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel, \n                            window, basis_dict, \n                            num_nodes, num_rels, lmbda_psi, \n                            alpha, evaluator, first_ts, neg_sampler, split_mode)                  \n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n    \n    return perf_list_all, hits_list_all\n\n## test\ndef test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):  \n    \"\"\" create predictions by loopoing through all relations on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list 
of hits for each test query\n    \"\"\"       \n    perf_list_all = []\n    hits_list_all =[]\n    \n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n\n    ## loop through relations and apply baselines\n    \n    for rel in all_relations:\n        start =  timeit.default_timer()\n        if rel in test_data_prel.keys():\n            lmbda_psi = best_config[str(rel)]['lmbda_psi'][0]\n            alpha = best_config[str(rel)]['alpha'][0]\n\n            # test data for this relation\n            test_data_c_rel = test_data_prel[rel]\n            timesteps_test = list(set(test_data_c_rel[:,3]))\n            timesteps_test.sort()\n            all_data_c_rel = all_data_prel[rel]         \n            perf_list_rel = []\n            hits_list_rel = []\n            perf_list_rel, hits_list_rel = predict(num_processes, test_data_c_rel,\n                                                all_data_c_rel, alpha, lmbda_psi,perf_list_rel, hits_list_rel, \n                                                window, neg_sampler, split_mode)\n            perf_list_all.extend(perf_list_rel)\n            hits_list_all.extend(hits_list_rel)\n        else:\n            perf_list_rel =[]\n         \n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n        \n\n        with open(csv_file, 'a') as f:\n            f.write(\"{},{}\\n\".format(rel, perf_list_rel))\n\n    return perf_list_all, hits_list_all\n\ndef read_dict_compute_mrr(split_mode='test'):\n    \"\"\" read the results per relation  from a precreated file and compute mrr\n    :return  mrr_per_rel: dictionary of mrrs for each relation\n    :return all_mrrs: list of mrrs for all relations\n    \"\"\"\n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n    # Initialize an empty dictionary to store the data\n    
results_per_rel_dict = {}\n    mrr_per_rel = {}\n    all_mrrs = []\n    # Open the file for reading\n    with open(csv_file, 'r') as f:\n        # Read each line in the file\n        for line in f:\n            # Split the line at the comma\n            parts = line.strip().split(',')\n            # Extract the key (the first part)\n            key = int(parts[0])\n            # Extract the values (the rest of the parts), remove square brackets\n            values = [float(value.strip('[]')) for value in parts[1:]]\n            # Add the key-value pair to the dictionary\n            if key in results_per_rel_dict.keys():\n                print(f\"Key {key} already exists in the dictionary!!! might have duplicate entries in results csv\")\n            results_per_rel_dict[key] = values\n            all_mrrs.extend(values)\n            mrr_per_rel[key] = np.mean(values)\n\n    if len(list(results_per_rel_dict.keys())) != num_rels:\n        print(\"we do not have entries for each rel in the results csv file. only num enties: \", len(list(results_per_rel_dict.keys())))\n\n    print(\"Split mode: \"+split_mode +\" Mean MRR: \", np.mean(all_mrrs))\n    print(\"mrr per relation: \", mrr_per_rel)\n\n\n\n## train\ndef train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):\n    \"\"\" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params\n    based on validation mrr\n    :return best_config: dictionary of best params for each relation\n    \"\"\"\n    best_config= {}\n    best_mrr = 0\n    for rel in rels: # loop through relations. 
for each relation, apply rules with selected params, compute valid mrr\n        start =  timeit.default_timer()\n        rel_key = int(rel)            \n\n        best_config[str(rel_key)] = {}\n        best_config[str(rel_key)]['not_trained'] = 'True'    \n        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi,0] #default\n        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))\n        best_config[str(rel_key)]['alpha'] = [default_alpha,0]  #default    \n        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))\n        \n        if rel in val_data_prel.keys():      \n            # valid data for this relation  \n            val_data_c_rel = val_data_prel[rel]\n            timesteps_valid = list(set(val_data_c_rel[:,3]))\n            timesteps_valid.sort()\n            trainval_data_c_rel = trainval_data_prel[rel]\n\n            ######  1) select lambda ###############        \n            lmbdas_psi = params_dict['lmbda_psi']        \n\n            alpha = 1\n            best_lmbda_psi = 0.1\n            best_mrr_psi = 0\n            lmbda_mrrs = []\n\n            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))\n            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))         \n            best_config[str(rel_key)]['not_trained'] = 'False'       \n            \n            for lmbda_psi in lmbdas_psi:   \n                perf_list_r = []\n                hits_list_r = []\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr = np.mean(perf_list_r)\n                # # is new mrr better than previous best? 
if yes: store lmbda\n                if mrr > best_mrr_psi:\n                    best_mrr_psi = float(mrr)\n                    best_lmbda_psi = lmbda_psi\n\n\n                lmbda_mrrs.append(float(mrr))\n            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]\n            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs\n            best_mrr = best_mrr_psi\n            ##### 2) select alpha ###############\n            best_config[str(rel_key)]['not_trained'] = 'False'    \n            alphas = params_dict['alpha'] \n            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0] # use the best lmbda psi\n\n            alpha_mrrs = []\n            # perf_list_all = []\n            best_mrr_alpha = 0\n            best_alpha=0.99\n            for alpha in alphas:\n                perf_list_r = []\n                hits_list_r = []\n\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr_alpha = np.mean(perf_list_r)\n\n                # is new mrr better than previous best? 
if yes: store alpha\n                if mrr_alpha > best_mrr_alpha:\n                    best_mrr_alpha = float(mrr_alpha)\n                    best_alpha = alpha\n                    best_mrr = best_mrr_alpha\n                alpha_mrrs.append(float(mrr_alpha))\n\n            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]\n            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs\n\n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n    return best_config\n\n\n\n## args\ndef get_args(): \n    \"\"\"parse all arguments for the script\"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", \"-d\", default=\"thgl-forum\", type=str) \n    parser.add_argument(\"--window\", \"-w\", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep\n    parser.add_argument(\"--num_processes\", \"-p\", default=1, type=int)\n    parser.add_argument(\"--lmbda\", \"-l\",  default=0.1, type=float) # fix lambda. used if trainflag == false\n    parser.add_argument(\"--alpha\", \"-alpha\",  default=0.99, type=float) # fix alpha. 
used if trainflag == false\n    parser.add_argument(\"--train_flag\", \"-tr\",  default='False') # do we need training, ie selection of lambda and alpha\n    parser.add_argument(\"--load_flag\", \"-lo\",  default='False') # if train_flag set to True: do you want to load best_config?\n    parser.add_argument(\"--save_config\", \"-c\",  default='True') # do we need to save the selection of lambda and alpha in config file?\n    parser.add_argument('--seed', type=int, help='Random seed', default=1) # not needed\n    parsed = vars(parser.parse_args())\n    return parsed\n\nstart_o =  timeit.default_timer()\n\nparsed = get_args()\nif parsed['num_processes']>1:\n    ray.init(num_cpus=parsed[\"num_processes\"], num_gpus=0)\nMODEL_NAME = 'RecurrencyBaseline'\nSEED = parsed['seed']  # set the random seed for consistency\nset_random_seed(SEED)\nperrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'\nif not osp.exists(perrel_results_path):\n    os.mkdir(perrel_results_path)\n    print('INFO: Create directory {}'.format(perrel_results_path))\nPath(perrel_results_path).mkdir(parents=True, exist_ok=True)\n\n## load dataset and prepare it accordingly\nname = parsed[\"dataset\"]\ndataset = LinkPropPredDataset(name=name, root=\"datasets\", preprocess=True)\nDATA = name\n\nrelations = dataset.edge_type\nnum_rels = dataset.num_rels\nrels = np.arange(0,num_rels)\nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nnum_nodes = dataset.num_nodes \ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n\nprint(\"split train valid test data\")\nall_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=name)\nneg_sampler = 
dataset.negative_sampler\n\ntrain_val_data = np.concatenate([train_data, val_data])\nall_data = np.concatenate([train_data, val_data, test_data])\n\n# create dicts with key: relation id, values: triples for that relation id\nprint(\"grouping data by relation\")\ntest_data_prel = group_by(test_data, 1)\nall_data_prel = group_by(all_data, 1)\nval_data_prel = group_by(val_data, 1)\ntrainval_data_prel = group_by(train_val_data, 1)\n\n#load the ns samples \n# if parsed['train_flag']:\nprint(\"loading negative samples\")\ndataset.load_val_ns()\ndataset.load_test_ns()\n\n# parameter options\nif parsed['train_flag'] == 'True':\n    params_dict = {}\n    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001] \n    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]\n    default_lmbda_psi = params_dict['lmbda_psi'][-1]\n    default_alpha = params_dict['alpha'][-2]\n\n## load rules\nprint(\"creating rules\")\nbasis_dict = create_basis_dict(train_val_data)\nprint(\"done with creating rules\")\n## init\n# rb_predictor = RecurrencyBaselinePredictor(rels)\n## train to find best lambda and alpha\nstart_train =  timeit.default_timer()\nif parsed['train_flag'] ==  'True':\n    if parsed['load_flag'] == 'True':\n        with open('best_config.json', 'r') as infile:\n            best_config = json.load(infile)\n    else:\n        print('start training')\n        best_config = train(params_dict,  rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'], \n            parsed['window'])\n        if parsed['save_config'] == 'True':\n            import json\n            with open('best_config.json', 'w') as outfile:\n                json.dump(best_config, outfile)\n\nelse: # use preset lmbda and alpha; same for all relations\n    best_config = {} \n    for rel in rels:\n        best_config[str(rel)] = {}\n        best_config[str(rel)]['lmbda_psi'] = 
[parsed['lmbda']]\n        best_config[str(rel)]['alpha'] = [parsed['alpha']]\n    \nend_train =  timeit.default_timer()\n\n# compute validation mrr\nprint(\"Computing validation MRR\")\nperf_list_all_val, hits_list_all_val = test(best_config,rels, val_data_prel, \n                                                trainval_data_prel, neg_sampler, parsed['num_processes'], \n                                            parsed['window'], split_mode='val')\nval_mrr = float(np.mean(perf_list_all_val))\n\n# compute test mrr\nprint(\"Computing test MRR\")\nstart_test =  timeit.default_timer()\nperf_list_all, hits_list_all = test(best_config,rels, test_data_prel, \n                                                 all_data_prel, neg_sampler, parsed['num_processes'], \n                                                parsed['window'])\n\nend_o =  timeit.default_timer()\ntrain_time_o = round(end_train- start_train, 6)  \ntest_time_o = round(end_o- start_test, 6)  \ntotal_time_o = round(end_o- start_o, 6)  \nprint(\"Running Training to find best configs finished in {} seconds.\".format(train_time_o))\nprint(\"Running testing with best configs finished in {} seconds.\".format(test_time_o))\nprint(\"Running all steps finished in {} seconds.\".format(total_time_o))\n\nprint(f\"The test MRR is {np.mean(perf_list_all)}\")\nprint(f\"The valid MRR is {val_mrr}\")\nprint(f\"The Hits@10 is {np.mean(hits_list_all)}\")\nprint(f\"We have {len(perf_list_all)} predictions\")\nprint(f\"The test set has len {len(test_data)} \")\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'\nmetric = dataset.eval_metric\nsave_results({'model': MODEL_NAME,\n              'train_flag': parsed['train_flag'],\n     
         'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: float(np.mean(perf_list_all)),\n              'val_mrr': val_mrr,\n              'hits10': float(np.mean(hits_list_all)),\n              'test_time': test_time_o,\n              'tot_train_val_time': total_time_o\n              }, \n    results_filename)\n\nif parsed['num_processes']>1:\n    ray.shutdown()\n\n\n    \n"
  },
  {
    "path": "examples/linkproppred/thgl-forum/sthn.py",
    "content": "import timeit\nimport numpy as np\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\nfrom tgb.linkproppred.evaluate import Evaluator\n\nimport argparse\nfrom modules.sthn import set_seed, pre_compute_subgraphs, get_inputs_for_ind, check_data_leakage\nimport torch\nimport pandas as pd\nimport itertools\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nfrom tgb.utils.utils import set_random_seed, save_results\n\n\n# Start...\nstart_overall = timeit.default_timer()\n\nDATA = \"thgl-forum\"\n\nMODEL_NAME = 'STHN'\n\n# data loading\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\nmetric = dataset.eval_metric\n\nprint (\"there are {} nodes and {} edges\".format(dataset.num_nodes, dataset.num_edges))\nprint (\"there are {} relation types\".format(dataset.num_rels))\n\n\ntimestamp = data.t\nhead = data.src\ntail = data.dst\nedge_type = data.edge_type\nneg_sampler = dataset.negative_sampler\n\nprint(data)\nprint(timestamp)\nprint(head)\nprint(tail)\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\n\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\n\n####################################################################\n####################################################################\n####################################################################\n\n\ndef print_model_info(model):\n    print(model)\n    
parameters = filter(lambda p: p.requires_grad, model.parameters())\n    parameters = sum([np.prod(p.size()) for p in parameters])\n    print('Trainable Parameters: %d' % parameters)\n\ndef get_args():\n    parser=argparse.ArgumentParser()\n    parser.add_argument('--data', type=str, default='movie')\n    parser.add_argument('--device', type=int, default=0)\n    parser.add_argument('--batch_size', type=int, default=600)\n    parser.add_argument('--epochs', type=int, default=2)\n    parser.add_argument('--max_edges', type=int, default=50)\n    parser.add_argument('--num_edgeType', type=int, default=0, help='num of edgeType')\n    parser.add_argument('--lr', type=float, default=0.0005)\n    parser.add_argument('--weight_decay', type=float, default=1e-4)\n    parser.add_argument('--predict_class', action='store_true')\n    \n    # model\n    parser.add_argument('--window_size', type=int, default=5)\n    parser.add_argument('--dropout', type=float, default=0.1)\n    parser.add_argument('--model', type=str, default='sthn') \n    parser.add_argument('--neg_samples', type=int, default=1)\n    parser.add_argument('--extra_neg_samples', type=int, default=5)\n    parser.add_argument('--num_neighbors', type=int, default=50)\n    parser.add_argument('--channel_expansion_factor', type=int, default=2)\n    parser.add_argument('--sampled_num_hops', type=int, default=1)\n    parser.add_argument('--time_dims', type=int, default=100)\n    parser.add_argument('--hidden_dims', type=int, default=100)\n    parser.add_argument('--num_layers', type=int, default=1)\n    parser.add_argument('--check_data_leakage', action='store_true')\n    \n    parser.add_argument('--ignore_node_feats', action='store_true')\n    parser.add_argument('--node_feats_as_edge_feats', action='store_true')\n    parser.add_argument('--ignore_edge_feats', action='store_true')\n    parser.add_argument('--use_onehot_node_feats', action='store_true')\n    parser.add_argument('--use_type_feats', action='store_true')\n\n  
  parser.add_argument('--use_graph_structure', action='store_true')\n    parser.add_argument('--structure_time_gap', type=int, default=2000)\n    parser.add_argument('--structure_hops', type=int, default=1) \n\n    parser.add_argument('--use_node_cls', action='store_true')\n    parser.add_argument('--use_cached_subgraph', action='store_true')\n    \n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--num_run', type=int, help='Number of iteration runs', default=5)\n    return parser.parse_args()\n\n\ndef load_model(args):\n    # get model\n    edge_predictor_configs = {\n        'dim_in_time': args.time_dims,\n        'dim_in_node': args.node_feat_dims,\n        'predict_class': 1 if not args.predict_class else args.num_edgeType+1,\n    }\n    if args.model == 'sthn':\n        if args.predict_class:\n            from modules.sthn import Multiclass_Interface as STHN_Interface\n        else:\n            from modules.sthn import STHN_Interface\n        from modules.sthn import link_pred_train\n\n        mixer_configs = {\n            'per_graph_size'  : args.max_edges,\n            'time_channels'   : args.time_dims, \n            'input_channels'  : args.edge_feat_dims, \n            'hidden_channels' : args.hidden_dims, \n            'out_channels'    : args.hidden_dims,\n            'num_layers'      : args.num_layers,\n            'dropout'         : args.dropout,\n            'channel_expansion_factor': args.channel_expansion_factor,\n            'window_size'     : args.window_size,\n            'use_single_layer' : False\n        }  \n        \n    else:\n        NotImplementedError()\n\n    model = STHN_Interface(mixer_configs, edge_predictor_configs)\n    for k, v in model.named_parameters():\n        print(k, v.requires_grad)\n\n    print_model_info(model)\n\n    return model, args, link_pred_train\n\ndef load_graph(data):\n    df = pd.DataFrame({\n        'idx': np.arange(len(data.t)),\n        'src': 
data.src,\n        'dst': data.dst,\n        'time': data.t,\n        'label': data.edge_type,\n    })\n\n    num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1 \n\n    ext_full_indptr = np.zeros(num_nodes + 1, dtype=np.int32)\n    ext_full_indices = [[] for _ in range(num_nodes)]\n    ext_full_ts = [[] for _ in range(num_nodes)]\n    ext_full_eid = [[] for _ in range(num_nodes)]\n\n    for idx, row in tqdm(df.iterrows(), total=len(df)):\n        src = int(row['src'])\n        dst = int(row['dst'])\n        \n        ext_full_indices[src].append(dst)\n        ext_full_ts[src].append(row['time'])\n        ext_full_eid[src].append(idx)\n        \n    for i in tqdm(range(num_nodes)):\n        ext_full_indptr[i + 1] = ext_full_indptr[i] + len(ext_full_indices[i])\n\n    ext_full_indices = np.array(list(itertools.chain(*ext_full_indices)))\n    ext_full_ts = np.array(list(itertools.chain(*ext_full_ts)))\n    ext_full_eid = np.array(list(itertools.chain(*ext_full_eid)))\n\n    print('Sorting...')\n\n    def tsort(i, indptr, indices, t, eid):\n        beg = indptr[i]\n        end = indptr[i + 1]\n        sidx = np.argsort(t[beg:end])\n        indices[beg:end] = indices[beg:end][sidx]\n        t[beg:end] = t[beg:end][sidx]\n        eid[beg:end] = eid[beg:end][sidx] \n\n    for i in tqdm(range(ext_full_indptr.shape[0] - 1)):\n        tsort(i, ext_full_indptr, ext_full_indices, ext_full_ts, ext_full_eid)\n\n    print('saving...')\n\n    np.savez('/tmp/ext_full.npz', indptr=ext_full_indptr,\n            indices=ext_full_indices, ts=ext_full_ts, eid=ext_full_eid)\n    g = np.load('/tmp/ext_full.npz')\n    return g, df\n\ndef load_all_data(args):\n\n    # load graph\n    g, df = load_graph(data)\n\n    args.train_mask = train_mask.numpy()\n    args.val_mask   = val_mask.numpy()\n    args.test_mask = test_mask.numpy()\n    args.num_edges = len(df)\n\n    print('Train %d, Valid %d, Test %d'%(sum(args.train_mask), \n                                         
sum(args.val_mask),\n                                         sum(test_mask)))\n    \n    args.num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1\n    args.num_edges = len(df)\n\n    print('Num nodes %d, num edges %d'%(args.num_nodes, args.num_edges))\n\n    # load feats \n    node_feats, edge_feats = dataset.node_feat, dataset.edge_feat\n    node_feat_dims = 0 if node_feats is None else node_feats.shape[1]\n    edge_feat_dims = 0 if edge_feats is None else edge_feats.shape[1]\n\n    # feature pre-processing\n    if args.use_onehot_node_feats:\n        print('>>> Use one-hot node features')\n        node_feats = torch.eye(args.num_nodes)\n        node_feat_dims = node_feats.size(1)\n\n    if args.ignore_node_feats:\n        print('>>> Ignore node features')\n        node_feats = None\n        node_feat_dims = 0\n\n    if args.use_type_feats:\n        edge_type = df.label.values\n        print(edge_type)\n        print(edge_type.sum())\n        args.num_edgeType = len(set(edge_type.tolist()))\n        edge_feats = torch.nn.functional.one_hot(torch.from_numpy(edge_type), \n                                                 num_classes=args.num_edgeType)\n        edge_feat_dims = edge_feats.size(1)\n        \n    print('Node feature dim %d, edge feature dim %d'%(node_feat_dims, edge_feat_dims))\n    \n    # double check (if data leakage then cannot continue the code)\n    if args.check_data_leakage:\n        check_data_leakage(args, g, df)\n\n    args.node_feat_dims = node_feat_dims\n    args.edge_feat_dims = edge_feat_dims\n    \n    if node_feats != None:\n        node_feats = node_feats.to(args.device)\n    if edge_feats != None:\n        edge_feats = edge_feats.to(args.device)\n    \n    return node_feats, edge_feats, g, df, 
args\n\n####################################################################\n####################################################################\n####################################################################\n\n@torch.no_grad()\ndef test(data, test_mask, model, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'val' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    test_subgraphs  = pre_compute_subgraphs(args, g, df, mode='test' if split_mode == 'test' else 'valid', negative_sampler=neg_sampler, split_mode=split_mode)\n    perf_list = []\n    \n    if split_mode == 'test':\n        cur_df = df[args.test_mask]\n    elif split_mode == 'val':\n        cur_df = df[args.val_mask]\n    neg_samples = 20\n    cached_neg_samples = 20\n\n    test_loader = cur_df.groupby(cur_df.index // args.batch_size)\n    pbar = tqdm(total=len(test_loader))\n    pbar.set_description('%s mode with negative samples %d ...'%(split_mode, neg_samples))        \n    \n    ###################################################\n    # compute + training + fetch all scores\n    cur_inds = 0\n\n    for ind in range(len(test_loader)):\n        ###################################################\n        inputs, subgraph_node_feats, cur_inds = get_inputs_for_ind(test_subgraphs, 'test' if split_mode == 'test' else 'tgb-val', cached_neg_samples, neg_samples, node_feats, edge_feats, cur_df, cur_inds, ind, args)\n        \n        loss, pred, edge_label = model(inputs, neg_samples, subgraph_node_feats)\n        
# print(ind, [l for l in inputs], pred.shape)\n\n        input_dict = {\n            \"y_pred_pos\": np.array([pred.cpu()[0]]),\n            \"y_pred_neg\": np.array(pred.cpu()[1:]),\n            \"eval_metric\": [metric],\n        }\n        perf_list.append(evaluator.eval(input_dict)[metric])\n\n    perf_metrics_mean = float(np.mean(perf_list))\n    perf_metrics_std = float(np.std(perf_list))\n\n    return perf_metrics_mean, perf_metrics_std, perf_list\n\n\nargs = get_args()\n\nargs.use_graph_structure = True\nargs.use_onehot_node_feats = False\nargs.ignore_node_feats = False # we only use graph structure\nargs.use_type_feats = True # type encoding\nargs.use_cached_subgraph = True\n\nprint(args)\n\nargs.device = f\"cuda:{args.device}\" if torch.cuda.is_available() else \"cpu\"\nargs.device = torch.device(args.device)\nSEED = args.seed\nBATCH_SIZE = args.batch_size\nNUM_RUNS = args.num_run\nset_seed(SEED)\n\n\n###################################################\n# load feats + graph\nnode_feats, edge_feats, g, df, args = load_all_data(args)\n\n###################################################\n# get model \nmodel, args, link_pred_train = load_model(args)\n\n###################################################\n\nprint(\"==========================================================\")\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============\")\nprint(\"==========================================================\")\n\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\n\n\nfor run_idx in range(NUM_RUNS):\n    print('-------------------------------------------------------------------------------')\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\n   
 start_run = timeit.default_timer()\n\n    # set the seed for deterministic results...\n    torch.manual_seed(run_idx + SEED)\n    set_random_seed(run_idx + SEED)\n\n    # define an early stopper\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\n    # early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \n    #                                 tolerance=TOLERANCE, patience=PATIENCE)\n\n    # ==================================================== Train & Validation\n    # loading the validation negative samples\n\n    # Link prediction\n    start_val = timeit.default_timer()\n    print('Train link prediction task from scratch ...')\n    model = link_pred_train(model.to(args.device), args, g, df, node_feats, edge_feats)\n\n    dataset.load_val_ns()\n\n    # Validation ...\n    \n    perf_metrics_val_mean, perf_metrics_val_std, perf_list_val = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='val')\n    end_val = timeit.default_timer()\n\n    print(f\"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tval: {metric}: {perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}\")\n    val_time = timeit.default_timer() - start_val\n    print(f\"\\tval: Elapsed Time (s): {val_time: .4f}\")\n\n\n    dataset.load_test_ns()\n\n    # testing ...\n    start_test = timeit.default_timer()\n    perf_metrics_test_mean, perf_metrics_test_std, perf_list_test = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='test')\n    end_test = timeit.default_timer()\n\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tTest: {metric}: {perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}\")\n    test_time = timeit.default_timer() - start_test\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\n    save_results({'model': 
MODEL_NAME,\n                  'data': DATA,\n                  'run': run_idx,\n                  'seed': SEED,\n                  f'val {metric}': f'{perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}' ,\n                  f'test {metric}': f'{perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}' ,\n                  'test_time': test_time,\n                  'tot_train_val_time': val_time\n                  }, \n    results_filename)\n\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\n    print('-------------------------------------------------------------------------------')\n\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\nprint(\"==============================================================\")\n\n# save_results({'model': MODEL_NAME,\n#             'data': DATA,\n#             'run': 1,\n#             'seed': SEED,\n#             metric: perf_metric_test,\n#             'test_time': test_time,\n#             'tot_train_val_time': 'NA'\n#             }, \n#     results_filename)"
  },
  {
    "path": "examples/linkproppred/thgl-forum/tgn.py",
    "content": "import numpy as np\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\nfrom tgb.linkproppred.evaluate import Evaluator\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom tqdm import tqdm\r\nimport timeit\r\n\r\n\r\nimport math\r\nimport timeit\r\n\r\nimport os\r\nimport os.path as osp\r\nfrom pathlib import Path\r\nimport numpy as np\r\n\r\nimport torch\r\nfrom sklearn.metrics import average_precision_score, roc_auc_score\r\nfrom torch.nn import Linear\r\n\r\nfrom torch_geometric.datasets import JODIEDataset\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom torch_geometric.nn import TransformerConv\r\n\r\n# internal imports\r\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\r\nfrom tgb.linkproppred.evaluate import Evaluator\r\nfrom modules.decoder import LinkPredictor\r\nfrom modules.emb_module import GraphAttentionEmbedding\r\nfrom modules.msg_func import IdentityMessage\r\nfrom modules.msg_agg import LastAggregator\r\nfrom modules.neighbor_loader import LastNeighborLoader\r\nfrom modules.memory_module import TGNMemory\r\nfrom modules.early_stopping import  EarlyStopMonitor\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\n# ==========\r\n# ========== Define helper function...\r\n# ==========\r\n\r\ndef train():\r\n    r\"\"\"\r\n    Training procedure for TGN model\r\n    This function uses some objects that are globally defined in the current scrips \r\n\r\n    Parameters:\r\n        None\r\n    Returns:\r\n        None\r\n            \r\n    \"\"\"\r\n\r\n    model['memory'].train()\r\n    model['gnn'].train()\r\n    model['link_pred'].train()\r\n\r\n    model['memory'].reset_state()  # Start with a fresh memory.\r\n    neighbor_loader.reset_state()  # Start with an empty graph.\r\n\r\n    total_loss = 0\r\n    for batch in train_loader:\r\n        batch = batch.to(device)\r\n        optimizer.zero_grad()\r\n\r\n        src, pos_dst, t, msg, rel = 
batch.src, batch.dst, batch.t, batch.msg, batch.edge_type\r\n\r\n        # Sample negative destination nodes.\r\n        neg_dst = torch.randint(\r\n            min_dst_idx,\r\n            max_dst_idx + 1,\r\n            (src.size(0),),\r\n            dtype=torch.long,\r\n            device=device,\r\n        )\r\n\r\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\r\n        n_id, edge_index, e_id = neighbor_loader(n_id)\r\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\r\n\r\n        # Get updated memory of all nodes involved in the computation.\r\n        z, last_update = model['memory'](n_id)\r\n        z = model['gnn'](\r\n            z,\r\n            last_update,\r\n            edge_index,\r\n            data.t[e_id].to(device),\r\n            data.msg[e_id].to(device),\r\n        )\r\n\r\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\r\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\r\n\r\n        loss = criterion(pos_out, torch.ones_like(pos_out))\r\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\r\n\r\n        # Update memory and neighbor loader with ground-truth state.\r\n        model['memory'].update_state(src, pos_dst, t, msg)\r\n        neighbor_loader.insert(src, pos_dst)\r\n\r\n        loss.backward()\r\n        optimizer.step()\r\n        model['memory'].detach()\r\n        total_loss += float(loss.detach()) * batch.num_events\r\n\r\n    return total_loss / train_data.num_events\r\n\r\n\r\n@torch.no_grad()\r\ndef test(loader, neg_sampler, split_mode):\r\n    r\"\"\"\r\n    Evaluated the dynamic link prediction\r\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\r\n\r\n    Parameters:\r\n        loader: an object containing positive attributes of the positive edges of the evaluation set\r\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\r\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\r\n    Returns:\r\n        perf_metric: the result of the performance evaluaiton\r\n    \"\"\"\r\n    model['memory'].eval()\r\n    model['gnn'].eval()\r\n    model['link_pred'].eval()\r\n\r\n    perf_list = []\r\n\r\n    for pos_batch in loader:\r\n        pos_src, pos_dst, pos_t, pos_msg, pos_rel = (\r\n            pos_batch.src,\r\n            pos_batch.dst,\r\n            pos_batch.t,\r\n            pos_batch.msg,\r\n            pos_batch.edge_type\r\n        )\r\n\r\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_rel, split_mode=split_mode)\r\n\r\n        # pos_msg_new = torch.cat([pos_msg,pos_rel.unsqueeze(dim=1)], dim=1)   \r\n\r\n\r\n        for idx, neg_batch in enumerate(neg_batch_list):\r\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\r\n            dst = torch.tensor(\r\n                np.concatenate(\r\n                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\r\n                    axis=0,\r\n                ),\r\n                device=device,\r\n            )\r\n\r\n            n_id = torch.cat([src, dst]).unique()\r\n            n_id, edge_index, e_id = neighbor_loader(n_id)\r\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\r\n\r\n            # Get updated memory of all nodes involved in the computation.\r\n            z, last_update = model['memory'](n_id)\r\n            z = model['gnn'](\r\n                z,\r\n                last_update,\r\n                edge_index,\r\n                data.t[e_id].to(device),\r\n        
        data.msg[e_id].to(device),\r\n            )\r\n\r\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\r\n\r\n            # compute MRR\r\n            input_dict = {\r\n                \"y_pred_pos\": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),\r\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\r\n                \"eval_metric\": [metric],\r\n            }\r\n            perf_list.append(evaluator.eval(input_dict)[metric])\r\n\r\n        # Update memory and neighbor loader with ground-truth state.\r\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)\r\n        neighbor_loader.insert(pos_src, pos_dst)\r\n\r\n    perf_metrics = float(torch.tensor(perf_list).mean())\r\n\r\n    return perf_metrics\r\n\r\n# ==========\r\n# ==========\r\n# ==========\r\n\r\n\r\n# ==========\r\n# ==========\r\n# ==========\r\n\r\n\r\n# Start...\r\nstart_overall = timeit.default_timer()\r\n\r\nDATA = \"thgl-forum\"\r\n\r\n# ========== set parameters...\r\nargs, _ = get_args()\r\nargs.data = DATA\r\nprint(\"INFO: Arguments:\", args)\r\n\r\nLR = args.lr\r\nBATCH_SIZE = args.bs\r\nK_VALUE = args.k_value  \r\nNUM_EPOCH = args.num_epoch\r\nSEED = args.seed\r\nMEM_DIM = args.mem_dim\r\nTIME_DIM = args.time_dim\r\nEMB_DIM = args.emb_dim\r\nTOLERANCE = args.tolerance\r\nPATIENCE = args.patience\r\nNUM_RUNS = args.num_run\r\nNUM_NEIGHBORS = 10\r\nUSE_EDGE_TYPE = True\r\nUSE_NODE_TYPE = True\r\n\r\n\r\n\r\nMODEL_NAME = 'TGN'\r\n# ==========\r\n\r\n# set the device\r\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\n\r\n\r\n# data loading\r\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\r\ntrain_mask = dataset.train_mask\r\nval_mask = dataset.val_mask\r\ntest_mask = dataset.test_mask\r\ndata = dataset.get_TemporalData()\r\ndata = data.to(device)\r\nmetric = dataset.eval_metric\r\n\r\nprint (\"there are {} nodes and {} edges\".format(dataset.num_nodes, 
dataset.num_edges))\r\nprint (\"there are {} relation types\".format(dataset.num_rels))\r\n\r\n\r\ntimestamp = data.t\r\nhead = data.src\r\ntail = data.dst\r\nedge_type = data.edge_type #relation\r\nedge_type_dim = len(torch.unique(edge_type))\r\n\r\nembed_edge_type = torch.nn.Embedding(edge_type_dim, 64).to(device)\r\nwith torch.no_grad():\r\n    edge_type_embeddings = embed_edge_type(edge_type)\r\n\r\n\r\nif USE_EDGE_TYPE:\r\n    data.msg = torch.cat([data.msg, edge_type_embeddings], dim=1)\r\n\r\n#! node type is a property of the dataset not the temporal data as temporal data has one entry per edge\r\nnode_type = dataset.node_type #node type\r\nneg_sampler = dataset.negative_sampler\r\n\r\ndata.__setattr__(\"node_type\", node_type)\r\n\r\nprint (\"shape of edge type is\", edge_type.shape)\r\nprint (\"shape of node type is\", node_type.shape)\r\n\r\ntrain_data = data[train_mask]\r\nval_data = data[val_mask]\r\ntest_data = data[test_mask]\r\nprint (\"finished loading PyG data\")\r\n\r\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\r\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\r\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\r\n\r\nstart_time = timeit.default_timer()\r\n\r\n# Ensure to only sample actual destination nodes as negatives.\r\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n\r\nprint(\"==========================================================\")\r\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============\")\r\nprint(\"==========================================================\")\r\n\r\nevaluator = Evaluator(name=DATA)\r\nneg_sampler = dataset.negative_sampler\r\n\r\n# for saving the results...\r\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\r\nif not osp.exists(results_path):\r\n    os.mkdir(results_path)\r\n    print('INFO: Create directory {}'.format(results_path))\r\nPath(results_path).mkdir(parents=True, 
exist_ok=True)\r\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\r\n\r\nfor run_idx in range(NUM_RUNS):\r\n    print('-------------------------------------------------------------------------------')\r\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\r\n    start_run = timeit.default_timer()\r\n\r\n    # set the seed for deterministic results...\r\n    torch.manual_seed(run_idx + SEED)\r\n    set_random_seed(run_idx + SEED)\r\n\r\n    # neighhorhood sampler\r\n    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\r\n\r\n    # define the model end-to-end\r\n    memory = TGNMemory(\r\n        data.num_nodes,\r\n        data.msg.size(-1),\r\n        MEM_DIM,\r\n        TIME_DIM,\r\n        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\r\n        aggregator_module=LastAggregator(),\r\n    ).to(device)\r\n\r\n    gnn = GraphAttentionEmbedding(\r\n        in_channels=MEM_DIM,\r\n        out_channels=EMB_DIM,\r\n        msg_dim=data.msg.size(-1),\r\n        time_enc=memory.time_enc,\r\n    ).to(device)\r\n\r\n    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\r\n\r\n    model = {'memory': memory,\r\n            'gnn': gnn,\r\n            'link_pred': link_pred}\r\n\r\n    optimizer = torch.optim.Adam(\r\n        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\r\n        lr=LR,\r\n    )\r\n    criterion = torch.nn.BCEWithLogitsLoss()\r\n\r\n    # Helper vector to map global node indices to local ones.\r\n    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\r\n\r\n    # define an early stopper\r\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\r\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\r\n    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \r\n                                    tolerance=TOLERANCE, 
patience=PATIENCE)\r\n\r\n    # ==================================================== Train & Validation\r\n    # loading the validation negative samples\r\n    dataset.load_val_ns()\r\n\r\n    val_perf_list = []\r\n    start_train_val = timeit.default_timer()\r\n    for epoch in range(1, NUM_EPOCH + 1):\r\n        # training\r\n        start_epoch_train = timeit.default_timer()\r\n        loss = train()\r\n        print(\r\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}\"\r\n        )\r\n\r\n        # validation\r\n        start_val = timeit.default_timer()\r\n        perf_metric_val = test(val_loader, neg_sampler, split_mode=\"val\")\r\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\r\n        print(f\"\\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}\")\r\n        val_perf_list.append(perf_metric_val)\r\n\r\n        # check for early stopping\r\n        if early_stopper.step_check(perf_metric_val, model):\r\n            break\r\n\r\n    train_val_time = timeit.default_timer() - start_train_val\r\n    print(f\"Train & Validation: Elapsed Time (s): {train_val_time: .4f}\")\r\n\r\n    # ==================================================== Test\r\n    # first, load the best model\r\n    early_stopper.load_checkpoint(model)\r\n\r\n    # loading the test negative samples\r\n    dataset.load_test_ns()\r\n\r\n    # final testing\r\n    start_test = timeit.default_timer()\r\n    perf_metric_test = test(test_loader, neg_sampler, split_mode=\"test\")\r\n\r\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\r\n    print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\r\n    test_time = timeit.default_timer() - start_test\r\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\r\n\r\n    save_results({'model': MODEL_NAME,\r\n                  'data': DATA,\r\n                  'run': run_idx,\r\n                  'seed': 
SEED,\r\n                  f'val {metric}': val_perf_list,\r\n                  f'test {metric}': perf_metric_test,\r\n                  'test_time': test_time,\r\n                  'tot_train_val_time': train_val_time\r\n                  }, \r\n    results_filename)\r\n\r\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\r\n    print('-------------------------------------------------------------------------------')\r\n\r\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\r\nprint(\"==============================================================\")\r\n\r\n# #* load numpy arrays instead\r\n# from tgb.linkproppred.dataset import LinkPropPredDataset\r\n\r\n# # data loading\r\n# dataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\r\n# data = dataset.full_data  \r\n# metric = dataset.eval_metric\r\n# sources = dataset.full_data['sources']\r\n# print (\"finished loading numpy arrays\")\r\n"
  },
  {
    "path": "examples/linkproppred/thgl-github/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\ntgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(tgb_modules_path)\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    hits_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t, pos_edge = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n            data['edge_type'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)\n        \n        for idx, neg_batch in enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            results = evaluator.eval(input_dict)\n            perf_list.append(results[metric])\n            hits_list.append(results['hits@10'])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, 
pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n    perf_hits = float(np.mean(hits_list))\n\n    return perf_metrics, perf_hits\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='thgl-github')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = args.data\nMODEL_NAME = 'EdgeBank'\n\n\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n        
hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_val: .4f}\")\nval_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {val_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>>  <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: 
perf_metric_test,\n              'val_mrr': perf_metric_val,\n              'test_time': test_time,\n              'tot_train_val_time': test_time+val_time,\n              'hits10': perf_hits_test}, \n    results_filename)\n\n"
  },
  {
    "path": "examples/linkproppred/thgl-github/recurrencybaseline.py",
    "content": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)\n\n@inproceedings{gastinger2024baselines,\n  title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},\n  author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},\n  booktitle={33nd International Joint Conference on Artificial Intelligence (IJCAI 2024)},\n  year={2024},\n  organization={International Joint Conferences on Artificial Intelligence Organization}\n}\n\npython recurrencybaseline.py --seed 1 --num_processes 1 -tr False\n\"\"\"\n\n## imports\nimport timeit\nimport argparse\nimport numpy as np\nfrom copy import copy\nfrom pathlib import Path\nimport ray\nimport sys\nimport os\nimport os.path as osp\nimport json\n#internal imports \nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom tgb.utils.utils import set_random_seed,  save_results \nfrom modules.tkg_utils import create_basis_dict, group_by, reformat_ts\n\ndef predict(num_processes,  data_c_rel, all_data_c_rel, alpha, lmbda_psi,\n            perf_list_all, hits_list_all, window, neg_sampler, split_mode):\n    \"\"\" create predictions for each relation on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list of hits for each test query\n    \"\"\"\n\n    first_ts = data_c_rel[0][3]\n    ## use this if you wanna use ray:\n    num_queries = len(data_c_rel) // num_processes\n    if num_queries < num_processes: # if we do not have 
enough queries for all the processes\n        num_processes_tmp = 1\n        num_queries = len(data_c_rel)\n    else:\n        num_processes_tmp = num_processes  \n    if num_processes > 1:\n        object_references =[]                   \n        \n        for i in range(num_processes_tmp):\n            num_test_queries = len(data_c_rel) - (i + 1) * num_queries\n            if num_test_queries >= num_queries:\n                test_queries_idx =[i * num_queries, (i + 1) * num_queries]\n            else:\n                test_queries_idx = [i * num_queries, len(test_data)]\n\n            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]\n\n            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window, \n                                basis_dict, \n                                num_nodes, num_rels, lmbda_psi, \n                                alpha, evaluator,first_ts, neg_sampler, split_mode)\n            object_references.append(ob)\n\n        output = ray.get(object_references)\n\n        # updates the scores and logging dict for each process\n        for proc_loop in range(num_processes_tmp):\n            perf_list_all.extend(output[proc_loop][0])\n            hits_list_all.extend(output[proc_loop][1])\n\n    else:\n        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel, \n                            window, basis_dict, \n                            num_nodes, num_rels, lmbda_psi, \n                            alpha, evaluator, first_ts, neg_sampler, split_mode)                  \n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n    \n    return perf_list_all, hits_list_all\n\n## test\ndef test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):  \n    \"\"\" create predictions by loopoing through all relations on test or valid set and compute mrr\n    :return  perf_list_all: 
list of mrrs for each test query\n    :return hits_list_all: list of hits for each test query\n    \"\"\"       \n    perf_list_all = []\n    hits_list_all =[]\n    \n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n\n    ## loop through relations and apply baselines\n    \n    for rel in all_relations:\n        start =  timeit.default_timer()\n        if rel in test_data_prel.keys():\n            lmbda_psi = best_config[str(rel)]['lmbda_psi'][0]\n            alpha = best_config[str(rel)]['alpha'][0]\n\n            # test data for this relation\n            test_data_c_rel = test_data_prel[rel]\n            timesteps_test = list(set(test_data_c_rel[:,3]))\n            timesteps_test.sort()\n            all_data_c_rel = all_data_prel[rel]         \n            perf_list_rel = []\n            hits_list_rel = []\n            perf_list_rel, hits_list_rel = predict(num_processes, test_data_c_rel,\n                                                all_data_c_rel, alpha, lmbda_psi,perf_list_rel, hits_list_rel, \n                                                window, neg_sampler, split_mode)\n            perf_list_all.extend(perf_list_rel)\n            hits_list_all.extend(hits_list_rel)\n        else:\n            perf_list_rel =[]\n         \n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n        \n\n        with open(csv_file, 'a') as f:\n            f.write(\"{},{}\\n\".format(rel, perf_list_rel))\n\n    return perf_list_all, hits_list_all\n\ndef read_dict_compute_mrr(split_mode='test'):\n    \"\"\" read the results per relation  from a precreated file and compute mrr\n    :return  mrr_per_rel: dictionary of mrrs for each relation\n    :return all_mrrs: list of mrrs for all relations\n    \"\"\"\n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n    # 
Initialize an empty dictionary to store the data\n    results_per_rel_dict = {}\n    mrr_per_rel = {}\n    all_mrrs = []\n    # Open the file for reading\n    with open(csv_file, 'r') as f:\n        # Read each line in the file\n        for line in f:\n            # Split the line at the comma\n            parts = line.strip().split(',')\n            # Extract the key (the first part)\n            key = int(parts[0])\n            # Extract the values (the rest of the parts), remove square brackets\n            values = [float(value.strip('[]')) for value in parts[1:]]\n            # Add the key-value pair to the dictionary\n            if key in results_per_rel_dict.keys():\n                print(f\"Key {key} already exists in the dictionary!!! might have duplicate entries in results csv\")\n            results_per_rel_dict[key] = values\n            all_mrrs.extend(values)\n            mrr_per_rel[key] = np.mean(values)\n\n    if len(list(results_per_rel_dict.keys())) != num_rels:\n        print(\"we do not have entries for each rel in the results csv file. only num enties: \", len(list(results_per_rel_dict.keys())))\n\n    print(\"Split mode: \"+split_mode +\" Mean MRR: \", np.mean(all_mrrs))\n    print(\"mrr per relation: \", mrr_per_rel)\n\n\n\n## train\ndef train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):\n    \"\"\" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params\n    based on validation mrr\n    :return best_config: dictionary of best params for each relation\n    \"\"\"\n    best_config= {}\n    best_mrr = 0\n    for rel in rels: # loop through relations. 
for each relation, apply rules with selected params, compute valid mrr\n        start =  timeit.default_timer()\n        rel_key = int(rel)            \n\n        best_config[str(rel_key)] = {}\n        best_config[str(rel_key)]['not_trained'] = 'True'    \n        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi,0] #default\n        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))\n        best_config[str(rel_key)]['alpha'] = [default_alpha,0]  #default    \n        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))\n        \n        if rel in val_data_prel.keys():      \n            # valid data for this relation  \n            val_data_c_rel = val_data_prel[rel]\n            timesteps_valid = list(set(val_data_c_rel[:,3]))\n            timesteps_valid.sort()\n            trainval_data_c_rel = trainval_data_prel[rel]\n\n            ######  1) select lambda ###############        \n            lmbdas_psi = params_dict['lmbda_psi']        \n\n            alpha = 1\n            best_lmbda_psi = 0.1\n            best_mrr_psi = 0\n            lmbda_mrrs = []\n\n            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))\n            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))         \n            best_config[str(rel_key)]['not_trained'] = 'False'       \n            \n            for lmbda_psi in lmbdas_psi:   \n                perf_list_r = []\n                hits_list_r = []\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr = np.mean(perf_list_r)\n                # # is new mrr better than previous best? 
if yes: store lmbda\n                if mrr > best_mrr_psi:\n                    best_mrr_psi = float(mrr)\n                    best_lmbda_psi = lmbda_psi\n\n\n                lmbda_mrrs.append(float(mrr))\n            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]\n            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs\n            best_mrr = best_mrr_psi\n            ##### 2) select alpha ###############\n            best_config[str(rel_key)]['not_trained'] = 'False'    \n            alphas = params_dict['alpha'] \n            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0] # use the best lmbda psi\n\n            alpha_mrrs = []\n            # perf_list_all = []\n            best_mrr_alpha = 0\n            best_alpha=0.99\n            for alpha in alphas:\n                perf_list_r = []\n                hits_list_r = []\n\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr_alpha = np.mean(perf_list_r)\n\n                # is new mrr better than previous best? 
if yes: store alpha\n                if mrr_alpha > best_mrr_alpha:\n                    best_mrr_alpha = float(mrr_alpha)\n                    best_alpha = alpha\n                    best_mrr = best_mrr_alpha\n                alpha_mrrs.append(float(mrr_alpha))\n\n            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]\n            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs\n\n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n    return best_config\n\n\n\n## args\ndef get_args(): \n    \"\"\"parse all arguments for the script\"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", \"-d\", default=\"thgl-github\", type=str) \n    parser.add_argument(\"--window\", \"-w\", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep\n    parser.add_argument(\"--num_processes\", \"-p\", default=1, type=int)\n    parser.add_argument(\"--lmbda\", \"-l\",  default=0.1, type=float) # fix lambda. used if trainflag == false\n    parser.add_argument(\"--alpha\", \"-alpha\",  default=0.99, type=float) # fix alpha. 
used if trainflag == false\n    parser.add_argument(\"--train_flag\", \"-tr\",  default='False') # do we need training, ie selection of lambda and alpha\n    parser.add_argument(\"--load_flag\", \"-lo\",  default='False') # if train_flag set to True: do you want to load best_config?\n    parser.add_argument(\"--save_config\", \"-c\",  default='True') # do we need to save the selection of lambda and alpha in config file?\n    parser.add_argument('--seed', type=int, help='Random seed', default=1) # not needed\n    parsed = vars(parser.parse_args())\n    return parsed\n\nstart_o =  timeit.default_timer()\n\nparsed = get_args()\nif parsed['num_processes']>1:\n    ray.init(num_cpus=parsed[\"num_processes\"], num_gpus=0)\nMODEL_NAME = 'RecurrencyBaseline'\nSEED = parsed['seed']  # set the random seed for consistency\nset_random_seed(SEED)\nperrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'\nif not osp.exists(perrel_results_path):\n    os.mkdir(perrel_results_path)\n    print('INFO: Create directory {}'.format(perrel_results_path))\nPath(perrel_results_path).mkdir(parents=True, exist_ok=True)\n\n## load dataset and prepare it accordingly\nname = parsed[\"dataset\"]\ndataset = LinkPropPredDataset(name=name, root=\"datasets\", preprocess=True)\nDATA = name\n\nrelations = dataset.edge_type\nnum_rels = dataset.num_rels\nrels = np.arange(0,num_rels)\nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nnum_nodes = dataset.num_nodes \ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n\nprint(\"split train valid test data\")\nall_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=name)\nneg_sampler = 
dataset.negative_sampler\n\ntrain_val_data = np.concatenate([train_data, val_data])\nall_data = np.concatenate([train_data, val_data, test_data])\n\n# create dicts with key: relation id, values: triples for that relation id\nprint(\"grouping data by relation\")\ntest_data_prel = group_by(test_data, 1)\nall_data_prel = group_by(all_data, 1)\nval_data_prel = group_by(val_data, 1)\ntrainval_data_prel = group_by(train_val_data, 1)\n\n#load the ns samples \n# if parsed['train_flag']:\nprint(\"loading negative samples\")\ndataset.load_val_ns()\ndataset.load_test_ns()\n\n# parameter options\nif parsed['train_flag'] == 'True':\n    params_dict = {}\n    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001] \n    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]\n    default_lmbda_psi = params_dict['lmbda_psi'][-1]\n    default_alpha = params_dict['alpha'][-2]\n\n## load rules\nprint(\"creating rules\")\nbasis_dict = create_basis_dict(train_val_data)\nprint(\"done with creating rules\")\n## init\n# rb_predictor = RecurrencyBaselinePredictor(rels)\n## train to find best lambda and alpha\nstart_train =  timeit.default_timer()\nif parsed['train_flag'] ==  'True':\n    if parsed['load_flag'] == 'True':\n        with open('best_config.json', 'r') as infile:\n            best_config = json.load(infile)\n    else:\n        print('start training')\n        best_config = train(params_dict,  rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'], \n            parsed['window'])\n        if parsed['save_config'] == 'True':\n            import json\n            with open('best_config.json', 'w') as outfile:\n                json.dump(best_config, outfile)\n\nelse: # use preset lmbda and alpha; same for all relations\n    best_config = {} \n    for rel in rels:\n        best_config[str(rel)] = {}\n        best_config[str(rel)]['lmbda_psi'] = 
[parsed['lmbda']]\n        best_config[str(rel)]['alpha'] = [parsed['alpha']]\n    \nend_train =  timeit.default_timer()\n\n# compute validation mrr\nprint(\"Computing validation MRR\")\nperf_list_all_val, hits_list_all_val = test(best_config,rels, val_data_prel, \n                                                trainval_data_prel, neg_sampler, parsed['num_processes'], \n                                            parsed['window'], split_mode='val')\nval_mrr = float(np.mean(perf_list_all_val))\n\n# compute test mrr\nprint(\"Computing test MRR\")\nstart_test =  timeit.default_timer()\nperf_list_all, hits_list_all = test(best_config,rels, test_data_prel, \n                                                 all_data_prel, neg_sampler, parsed['num_processes'], \n                                                parsed['window'])\n\nend_o =  timeit.default_timer()\ntrain_time_o = round(end_train- start_train, 6)  \ntest_time_o = round(end_o- start_test, 6)  \ntotal_time_o = round(end_o- start_o, 6)  \nprint(\"Running Training to find best configs finished in {} seconds.\".format(train_time_o))\nprint(\"Running testing with best configs finished in {} seconds.\".format(test_time_o))\nprint(\"Running all steps finished in {} seconds.\".format(total_time_o))\n\nprint(f\"The test MRR is {np.mean(perf_list_all)}\")\nprint(f\"The valid MRR is {val_mrr}\")\nprint(f\"The Hits@10 is {np.mean(hits_list_all)}\")\nprint(f\"We have {len(perf_list_all)} predictions\")\nprint(f\"The test set has len {len(test_data)} \")\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'\nmetric = dataset.eval_metric\nsave_results({'model': MODEL_NAME,\n              'train_flag': parsed['train_flag'],\n     
         'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: float(np.mean(perf_list_all)),\n              'val_mrr': val_mrr,\n              'hits10': float(np.mean(hits_list_all)),\n              'test_time': test_time_o,\n              'tot_train_val_time': total_time_o\n              }, \n    results_filename)\n\nif parsed['num_processes']>1:\n    ray.shutdown()\n\n\n    \n"
  },
  {
    "path": "examples/linkproppred/thgl-github/run_seeds.sh",
    "content": "python -u tgn.py --seed 1 --mem_dim 16 --time_dim 16 --emb_dim 16 --num_epoch 5 --num_run 1 | tee tgn_s1_new_output.txt\n\npython -u tgn.py --seed 2 --mem_dim 16 --time_dim 16 --emb_dim 16 --num_epoch 5 --num_run 1 | tee tgn_s2_new_output.txt\n\npython -u tgn.py --seed 3 --mem_dim 16 --time_dim 16 --emb_dim 16 --num_epoch 5 --num_run 1 | tee tgn_s3_new_output.txt\n\npython -u tgn.py --seed 4 --mem_dim 16 --time_dim 16 --emb_dim 16 --num_epoch 5 --num_run 1 | tee tgn_s4_new_output.txt\n\npython -u tgn.py --seed 5 --mem_dim 16 --time_dim 16 --emb_dim 16 --num_epoch 5 --num_run 1 | tee tgn_s5_new_output.txt"
  },
  {
    "path": "examples/linkproppred/thgl-github/sthn.py",
    "content": "import timeit\nimport numpy as np\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\nfrom tgb.linkproppred.evaluate import Evaluator\n\nimport argparse\nfrom modules.sthn import set_seed, pre_compute_subgraphs, get_inputs_for_ind, check_data_leakage\nimport torch\nimport pandas as pd\nimport itertools\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nfrom tgb.utils.utils import set_random_seed, save_results\n\n\n# Start...\nstart_overall = timeit.default_timer()\n\nDATA = \"thgl-github\"\n\nMODEL_NAME = 'STHN'\n\n# data loading\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\nmetric = dataset.eval_metric\n\nprint (\"there are {} nodes and {} edges\".format(dataset.num_nodes, dataset.num_edges))\nprint (\"there are {} relation types\".format(dataset.num_rels))\n\n\ntimestamp = data.t\nhead = data.src\ntail = data.dst\nedge_type = data.edge_type\nneg_sampler = dataset.negative_sampler\n\nprint(data)\nprint(timestamp)\nprint(head)\nprint(tail)\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\n\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\n\n####################################################################\n####################################################################\n####################################################################\n\n\ndef print_model_info(model):\n    print(model)\n    
parameters = filter(lambda p: p.requires_grad, model.parameters())\n    parameters = sum([np.prod(p.size()) for p in parameters])\n    print('Trainable Parameters: %d' % parameters)\n\ndef get_args():\n    parser=argparse.ArgumentParser()\n    parser.add_argument('--data', type=str, default='movie')\n    parser.add_argument('--device', type=int, default=0)\n    parser.add_argument('--batch_size', type=int, default=600)\n    parser.add_argument('--epochs', type=int, default=2)\n    parser.add_argument('--max_edges', type=int, default=50)\n    parser.add_argument('--num_edgeType', type=int, default=0, help='num of edgeType')\n    parser.add_argument('--lr', type=float, default=0.0005)\n    parser.add_argument('--weight_decay', type=float, default=1e-4)\n    parser.add_argument('--predict_class', action='store_true')\n    \n    # model\n    parser.add_argument('--window_size', type=int, default=5)\n    parser.add_argument('--dropout', type=float, default=0.1)\n    parser.add_argument('--model', type=str, default='sthn') \n    parser.add_argument('--neg_samples', type=int, default=1)\n    parser.add_argument('--extra_neg_samples', type=int, default=5)\n    parser.add_argument('--num_neighbors', type=int, default=50)\n    parser.add_argument('--channel_expansion_factor', type=int, default=2)\n    parser.add_argument('--sampled_num_hops', type=int, default=1)\n    parser.add_argument('--time_dims', type=int, default=100)\n    parser.add_argument('--hidden_dims', type=int, default=100)\n    parser.add_argument('--num_layers', type=int, default=1)\n    parser.add_argument('--check_data_leakage', action='store_true')\n    \n    parser.add_argument('--ignore_node_feats', action='store_true')\n    parser.add_argument('--node_feats_as_edge_feats', action='store_true')\n    parser.add_argument('--ignore_edge_feats', action='store_true')\n    parser.add_argument('--use_onehot_node_feats', action='store_true')\n    parser.add_argument('--use_type_feats', action='store_true')\n\n  
  parser.add_argument('--use_graph_structure', action='store_true')\n    parser.add_argument('--structure_time_gap', type=int, default=2000)\n    parser.add_argument('--structure_hops', type=int, default=1) \n\n    parser.add_argument('--use_node_cls', action='store_true')\n    parser.add_argument('--use_cached_subgraph', action='store_true')\n    \n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--num_run', type=int, help='Number of iteration runs', default=5)\n    return parser.parse_args()\n\n\ndef load_model(args):\n    # get model\n    edge_predictor_configs = {\n        'dim_in_time': args.time_dims,\n        'dim_in_node': args.node_feat_dims,\n        'predict_class': 1 if not args.predict_class else args.num_edgeType+1,\n    }\n    if args.model == 'sthn':\n        if args.predict_class:\n            from modules.sthn import Multiclass_Interface as STHN_Interface\n        else:\n            from modules.sthn import STHN_Interface\n        from modules.sthn import link_pred_train\n\n        mixer_configs = {\n            'per_graph_size'  : args.max_edges,\n            'time_channels'   : args.time_dims, \n            'input_channels'  : args.edge_feat_dims, \n            'hidden_channels' : args.hidden_dims, \n            'out_channels'    : args.hidden_dims,\n            'num_layers'      : args.num_layers,\n            'dropout'         : args.dropout,\n            'channel_expansion_factor': args.channel_expansion_factor,\n            'window_size'     : args.window_size,\n            'use_single_layer' : False\n        }  \n        \n    else:\n        NotImplementedError()\n\n    model = STHN_Interface(mixer_configs, edge_predictor_configs)\n    for k, v in model.named_parameters():\n        print(k, v.requires_grad)\n\n    print_model_info(model)\n\n    return model, args, link_pred_train\n\ndef load_graph(data):\n    df = pd.DataFrame({\n        'idx': np.arange(len(data.t)),\n        'src': 
data.src,\n        'dst': data.dst,\n        'time': data.t,\n        'label': data.edge_type,\n    })\n\n    num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1 \n\n    ext_full_indptr = np.zeros(num_nodes + 1, dtype=np.int32)\n    ext_full_indices = [[] for _ in range(num_nodes)]\n    ext_full_ts = [[] for _ in range(num_nodes)]\n    ext_full_eid = [[] for _ in range(num_nodes)]\n\n    for idx, row in tqdm(df.iterrows(), total=len(df)):\n        src = int(row['src'])\n        dst = int(row['dst'])\n        \n        ext_full_indices[src].append(dst)\n        ext_full_ts[src].append(row['time'])\n        ext_full_eid[src].append(idx)\n        \n    for i in tqdm(range(num_nodes)):\n        ext_full_indptr[i + 1] = ext_full_indptr[i] + len(ext_full_indices[i])\n\n    ext_full_indices = np.array(list(itertools.chain(*ext_full_indices)))\n    ext_full_ts = np.array(list(itertools.chain(*ext_full_ts)))\n    ext_full_eid = np.array(list(itertools.chain(*ext_full_eid)))\n\n    print('Sorting...')\n\n    def tsort(i, indptr, indices, t, eid):\n        beg = indptr[i]\n        end = indptr[i + 1]\n        sidx = np.argsort(t[beg:end])\n        indices[beg:end] = indices[beg:end][sidx]\n        t[beg:end] = t[beg:end][sidx]\n        eid[beg:end] = eid[beg:end][sidx] \n\n    for i in tqdm(range(ext_full_indptr.shape[0] - 1)):\n        tsort(i, ext_full_indptr, ext_full_indices, ext_full_ts, ext_full_eid)\n\n    print('saving...')\n\n    np.savez('/tmp/ext_full.npz', indptr=ext_full_indptr,\n            indices=ext_full_indices, ts=ext_full_ts, eid=ext_full_eid)\n    g = np.load('/tmp/ext_full.npz')\n    return g, df\n\ndef load_all_data(args):\n\n    # load graph\n    g, df = load_graph(data)\n\n    args.train_mask = train_mask.numpy()\n    args.val_mask   = val_mask.numpy()\n    args.test_mask = test_mask.numpy()\n    args.num_edges = len(df)\n\n    print('Train %d, Valid %d, Test %d'%(sum(args.train_mask), \n                                         
sum(args.val_mask),\n                                         sum(test_mask)))\n    \n    args.num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1\n    args.num_edges = len(df)\n\n    print('Num nodes %d, num edges %d'%(args.num_nodes, args.num_edges))\n\n    # load feats \n    node_feats, edge_feats = dataset.node_feat, dataset.edge_feat\n    node_feat_dims = 0 if node_feats is None else node_feats.shape[1]\n    edge_feat_dims = 0 if edge_feats is None else edge_feats.shape[1]\n\n    # feature pre-processing\n    if args.use_onehot_node_feats:\n        print('>>> Use one-hot node features')\n        node_feats = torch.eye(args.num_nodes)\n        node_feat_dims = node_feats.size(1)\n\n    if args.ignore_node_feats:\n        print('>>> Ignore node features')\n        node_feats = None\n        node_feat_dims = 0\n\n    if args.use_type_feats:\n        edge_type = df.label.values\n        print(edge_type)\n        print(edge_type.sum())\n        args.num_edgeType = len(set(edge_type.tolist()))\n        edge_feats = torch.nn.functional.one_hot(torch.from_numpy(edge_type), \n                                                 num_classes=args.num_edgeType)\n        edge_feat_dims = edge_feats.size(1)\n        \n    print('Node feature dim %d, edge feature dim %d'%(node_feat_dims, edge_feat_dims))\n    \n    # double check (if data leakage then cannot continue the code)\n    if args.check_data_leakage:\n        check_data_leakage(args, g, df)\n\n    args.node_feat_dims = node_feat_dims\n    args.edge_feat_dims = edge_feat_dims\n    \n    if node_feats != None:\n        node_feats = node_feats.to(args.device)\n    if edge_feats != None:\n        edge_feats = edge_feats.to(args.device)\n    \n    return node_feats, edge_feats, g, df, 
args\n\n####################################################################\n####################################################################\n####################################################################\n\n@torch.no_grad()\ndef test(data, test_mask, model, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'val' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    test_subgraphs  = pre_compute_subgraphs(args, g, df, mode='test' if split_mode == 'test' else 'valid', negative_sampler=neg_sampler, split_mode=split_mode)\n    perf_list = []\n    \n    if split_mode == 'test':\n        cur_df = df[args.test_mask]\n    elif split_mode == 'val':\n        cur_df = df[args.val_mask]\n    neg_samples = 20\n    cached_neg_samples = 20\n\n    test_loader = cur_df.groupby(cur_df.index // args.batch_size)\n    pbar = tqdm(total=len(test_loader))\n    pbar.set_description('%s mode with negative samples %d ...'%(split_mode, neg_samples))        \n    \n    ###################################################\n    # compute + training + fetch all scores\n    cur_inds = 0\n\n    for ind in range(len(test_loader)):\n        ###################################################\n        inputs, subgraph_node_feats, cur_inds = get_inputs_for_ind(test_subgraphs, 'test' if split_mode == 'test' else 'tgb-val', cached_neg_samples, neg_samples, node_feats, edge_feats, cur_df, cur_inds, ind, args)\n        \n        loss, pred, edge_label = model(inputs, neg_samples, subgraph_node_feats)\n        
# print(ind, [l for l in inputs], pred.shape)\n\n        input_dict = {\n            \"y_pred_pos\": np.array([pred.cpu()[0]]),\n            \"y_pred_neg\": np.array(pred.cpu()[1:]),\n            \"eval_metric\": [metric],\n        }\n        perf_list.append(evaluator.eval(input_dict)[metric])\n\n    perf_metrics_mean = float(np.mean(perf_list))\n    perf_metrics_std = float(np.std(perf_list))\n\n    return perf_metrics_mean, perf_metrics_std, perf_list\n\n\nargs = get_args()\n\nargs.use_graph_structure = True\nargs.use_onehot_node_feats = False\nargs.ignore_node_feats = False # we only use graph structure\nargs.use_type_feats = True # type encoding\nargs.use_cached_subgraph = True\n\nprint(args)\n\nargs.device = f\"cuda:{args.device}\" if torch.cuda.is_available() else \"cpu\"\nargs.device = torch.device(args.device)\nSEED = args.seed\nBATCH_SIZE = args.batch_size\nNUM_RUNS = args.num_run\nset_seed(SEED)\n\n\n###################################################\n# load feats + graph\nnode_feats, edge_feats, g, df, args = load_all_data(args)\n\n###################################################\n# get model \nmodel, args, link_pred_train = load_model(args)\n\n###################################################\n\nprint(\"==========================================================\")\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============\")\nprint(\"==========================================================\")\n\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\n\n\nfor run_idx in range(NUM_RUNS):\n    print('-------------------------------------------------------------------------------')\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\n   
 start_run = timeit.default_timer()\n\n    # set the seed for deterministic results...\n    torch.manual_seed(run_idx + SEED)\n    set_random_seed(run_idx + SEED)\n\n    # define an early stopper\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\n    # early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \n    #                                 tolerance=TOLERANCE, patience=PATIENCE)\n\n    # ==================================================== Train & Validation\n    # loading the validation negative samples\n\n    # Link prediction\n    start_val = timeit.default_timer()\n    print('Train link prediction task from scratch ...')\n    model = link_pred_train(model.to(args.device), args, g, df, node_feats, edge_feats)\n\n    dataset.load_val_ns()\n\n    # Validation ...\n    \n    perf_metrics_val_mean, perf_metrics_val_std, perf_list_val = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='val')\n    end_val = timeit.default_timer()\n\n    print(f\"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tval: {metric}: {perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}\")\n    val_time = timeit.default_timer() - start_val\n    print(f\"\\tval: Elapsed Time (s): {val_time: .4f}\")\n\n\n    dataset.load_test_ns()\n\n    # testing ...\n    start_test = timeit.default_timer()\n    perf_metrics_test_mean, perf_metrics_test_std, perf_list_test = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='test')\n    end_test = timeit.default_timer()\n\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tTest: {metric}: {perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}\")\n    test_time = timeit.default_timer() - start_test\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\n    save_results({'model': 
MODEL_NAME,\n                  'data': DATA,\n                  'run': run_idx,\n                  'seed': SEED,\n                  f'val {metric}': f'{perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}' ,\n                  f'test {metric}': f'{perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}' ,\n                  'test_time': test_time,\n                  'tot_train_val_time': val_time\n                  }, \n    results_filename)\n\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\n    print('-------------------------------------------------------------------------------')\n\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\nprint(\"==============================================================\")\n\n# save_results({'model': MODEL_NAME,\n#             'data': DATA,\n#             'run': 1,\n#             'seed': SEED,\n#             metric: perf_metric_test,\n#             'test_time': test_time,\n#             'tot_train_val_time': 'NA'\n#             }, \n#     results_filename)"
  },
  {
    "path": "examples/linkproppred/thgl-github/tgn.py",
    "content": "\"\"\"\r\npython -u tgn.py --seed 1 --mem_dim 16 --time_dim 16 --emb_dim 16 --num_epoch 5 | tee tgn_s1_github_output.txt\r\n\"\"\"\r\n\r\nimport numpy as np\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\nfrom tgb.linkproppred.evaluate import Evaluator\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom tqdm import tqdm\r\nimport timeit\r\n\r\n\r\nimport math\r\nimport timeit\r\n\r\nimport os\r\nimport os.path as osp\r\nfrom pathlib import Path\r\nimport numpy as np\r\n\r\nimport torch\r\nfrom sklearn.metrics import average_precision_score, roc_auc_score\r\nfrom torch.nn import Linear\r\n\r\nfrom torch_geometric.datasets import JODIEDataset\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom torch_geometric.nn import TransformerConv\r\n\r\n# internal imports\r\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\r\nfrom tgb.linkproppred.evaluate import Evaluator\r\nfrom modules.decoder import LinkPredictor\r\nfrom modules.emb_module import GraphAttentionEmbedding\r\nfrom modules.msg_func import IdentityMessage\r\nfrom modules.msg_agg import LastAggregator\r\nfrom modules.neighbor_loader import LastNeighborLoader\r\nfrom modules.memory_module import TGNMemory\r\nfrom modules.early_stopping import  EarlyStopMonitor\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\n# ==========\r\n# ========== Define helper function...\r\n# ==========\r\n\r\ndef train():\r\n    r\"\"\"\r\n    Training procedure for TGN model\r\n    This function uses some objects that are globally defined in the current scrips \r\n\r\n    Parameters:\r\n        None\r\n    Returns:\r\n        None\r\n            \r\n    \"\"\"\r\n\r\n    model['memory'].train()\r\n    model['gnn'].train()\r\n    model['link_pred'].train()\r\n\r\n    model['memory'].reset_state()  # Start with a fresh memory.\r\n    neighbor_loader.reset_state()  # Start with an empty graph.\r\n\r\n    total_loss = 0\r\n    for 
batch in train_loader:\r\n        batch = batch.to(device)\r\n        optimizer.zero_grad()\r\n\r\n        src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type\r\n\r\n        # Sample negative destination nodes.\r\n        neg_dst = torch.randint(\r\n            min_dst_idx,\r\n            max_dst_idx + 1,\r\n            (src.size(0),),\r\n            dtype=torch.long,\r\n            device=device,\r\n        )\r\n\r\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\r\n        n_id, edge_index, e_id = neighbor_loader(n_id)\r\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\r\n\r\n        # Get updated memory of all nodes involved in the computation.\r\n        z, last_update = model['memory'](n_id)\r\n        z = model['gnn'](\r\n            z,\r\n            last_update,\r\n            edge_index,\r\n            data.t[e_id].to(device),\r\n            data.msg[e_id].to(device),\r\n        )\r\n\r\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\r\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\r\n\r\n        loss = criterion(pos_out, torch.ones_like(pos_out))\r\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\r\n\r\n        # Update memory and neighbor loader with ground-truth state.\r\n        model['memory'].update_state(src, pos_dst, t, msg)\r\n        neighbor_loader.insert(src, pos_dst)\r\n\r\n        loss.backward()\r\n        optimizer.step()\r\n        model['memory'].detach()\r\n        total_loss += float(loss.detach()) * batch.num_events\r\n\r\n    return total_loss / train_data.num_events\r\n\r\n\r\n@torch.no_grad()\r\ndef test(loader, neg_sampler, split_mode):\r\n    r\"\"\"\r\n    Evaluated the dynamic link prediction\r\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\r\n\r\n    Parameters:\r\n        loader: an object containing positive attributes of the positive edges of the evaluation set\r\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\r\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\r\n    Returns:\r\n        perf_metric: the result of the performance evaluaiton\r\n    \"\"\"\r\n    model['memory'].eval()\r\n    model['gnn'].eval()\r\n    model['link_pred'].eval()\r\n\r\n    perf_list = []\r\n\r\n    for pos_batch in loader:\r\n        pos_src, pos_dst, pos_t, pos_msg, pos_rel = (\r\n            pos_batch.src,\r\n            pos_batch.dst,\r\n            pos_batch.t,\r\n            pos_batch.msg,\r\n            pos_batch.edge_type\r\n        )\r\n\r\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_rel, split_mode=split_mode)\r\n\r\n        # pos_msg_new = torch.cat([pos_msg,pos_rel.unsqueeze(dim=1)], dim=1)   \r\n\r\n\r\n        for idx, neg_batch in enumerate(neg_batch_list):\r\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\r\n            dst = torch.tensor(\r\n                np.concatenate(\r\n                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\r\n                    axis=0,\r\n                ),\r\n                device=device,\r\n            )\r\n\r\n            n_id = torch.cat([src, dst]).unique()\r\n            n_id, edge_index, e_id = neighbor_loader(n_id)\r\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\r\n\r\n            # Get updated memory of all nodes involved in the computation.\r\n            z, last_update = model['memory'](n_id)\r\n            z = model['gnn'](\r\n                z,\r\n                last_update,\r\n                edge_index,\r\n                data.t[e_id].to(device),\r\n        
        data.msg[e_id].to(device),\r\n            )\r\n\r\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\r\n\r\n            # compute MRR\r\n            input_dict = {\r\n                \"y_pred_pos\": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),\r\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\r\n                \"eval_metric\": [metric],\r\n            }\r\n            perf_list.append(evaluator.eval(input_dict)[metric])\r\n\r\n        # Update memory and neighbor loader with ground-truth state.\r\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)\r\n        neighbor_loader.insert(pos_src, pos_dst)\r\n\r\n    perf_metrics = float(torch.tensor(perf_list).mean())\r\n\r\n    return perf_metrics\r\n\r\n# ==========\r\n# ==========\r\n# ==========\r\n\r\n\r\n# ==========\r\n# ==========\r\n# ==========\r\n\r\n\r\n# Start...\r\nstart_overall = timeit.default_timer()\r\nDATA = \"thgl-github\"\r\n\r\n\r\n# ========== set parameters...\r\nargs, _ = get_args()\r\nargs.data = DATA\r\nprint(\"INFO: Arguments:\", args)\r\n\r\nLR = args.lr\r\nBATCH_SIZE = args.bs\r\nK_VALUE = args.k_value  \r\nNUM_EPOCH = args.num_epoch\r\nSEED = args.seed\r\nMEM_DIM = 16 #args.mem_dim\r\nTIME_DIM = 16 #args.time_dim\r\nEMB_DIM = 16 #args.emb_dim\r\nTOLERANCE = args.tolerance\r\nPATIENCE = args.patience\r\nNUM_RUNS = 1 #args.num_run\r\nNUM_NEIGHBORS = 10\r\nUSE_EDGE_TYPE = True\r\nUSE_NODE_TYPE = True\r\n\r\n\r\n\r\nMODEL_NAME = 'TGN'\r\n# ==========\r\n\r\n# set the device\r\n# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\ndevice = \"cpu\"\r\n\r\n\r\ntorch.manual_seed(SEED)\r\nset_random_seed(SEED)\r\n\r\n# data loading\r\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\r\ntrain_mask = dataset.train_mask\r\nval_mask = dataset.val_mask\r\ntest_mask = dataset.test_mask\r\ndata = dataset.get_TemporalData()\r\ndata = data.to(device)\r\nmetric = 
dataset.eval_metric\r\n\r\nprint (\"there are {} nodes and {} edges\".format(dataset.num_nodes, dataset.num_edges))\r\nprint (\"there are {} relation types\".format(dataset.num_rels))\r\n\r\n\r\ntimestamp = data.t\r\nhead = data.src\r\ntail = data.dst\r\nedge_type = data.edge_type #relation\r\nedge_type_dim = len(torch.unique(edge_type))\r\n\r\nembed_edge_type = torch.nn.Embedding(edge_type_dim, EMB_DIM).to(device)\r\nwith torch.no_grad():\r\n    edge_type_embeddings = embed_edge_type(edge_type)\r\n\r\n\r\nif USE_EDGE_TYPE:\r\n    data.msg = torch.cat([data.msg, edge_type_embeddings], dim=1)\r\n\r\n#! node type is a property of the dataset not the temporal data as temporal data has one entry per edge\r\nnode_type = dataset.node_type #node type\r\nneg_sampler = dataset.negative_sampler\r\n\r\ndata.__setattr__(\"node_type\", node_type)\r\n\r\nprint (\"shape of edge type is\", edge_type.shape)\r\nprint (\"shape of node type is\", node_type.shape)\r\n\r\ntrain_data = data[train_mask]\r\nval_data = data[val_mask]\r\ntest_data = data[test_mask]\r\nprint (\"finished loading PyG data\")\r\n\r\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\r\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\r\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\r\n\r\nstart_time = timeit.default_timer()\r\n\r\n# Ensure to only sample actual destination nodes as negatives.\r\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n\r\nprint(\"==========================================================\")\r\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============\")\r\nprint(\"==========================================================\")\r\n\r\nevaluator = Evaluator(name=DATA)\r\nneg_sampler = dataset.negative_sampler\r\n\r\nbest_test = 0\r\nbest_val = 0\r\nbest_epoch = 0\r\n\r\n# for saving the results...\r\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\r\nif not 
osp.exists(results_path):\r\n    os.mkdir(results_path)\r\n    print('INFO: Create directory {}'.format(results_path))\r\nPath(results_path).mkdir(parents=True, exist_ok=True)\r\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\r\n\r\nfor run_idx in range(NUM_RUNS):\r\n    print('-------------------------------------------------------------------------------')\r\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\r\n    start_run = timeit.default_timer()\r\n\r\n    # set the seed for deterministic results...\r\n    torch.manual_seed(run_idx + SEED)\r\n    set_random_seed(run_idx + SEED)\r\n\r\n    # neighhorhood sampler\r\n    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\r\n\r\n    # define the model end-to-end\r\n    memory = TGNMemory(\r\n        data.num_nodes,\r\n        data.msg.size(-1),\r\n        MEM_DIM,\r\n        TIME_DIM,\r\n        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\r\n        aggregator_module=LastAggregator(),\r\n    ).to(device)\r\n\r\n    gnn = GraphAttentionEmbedding(\r\n        in_channels=MEM_DIM,\r\n        out_channels=EMB_DIM,\r\n        msg_dim=data.msg.size(-1),\r\n        time_enc=memory.time_enc,\r\n    ).to(device)\r\n\r\n    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\r\n\r\n    model = {'memory': memory,\r\n            'gnn': gnn,\r\n            'link_pred': link_pred}\r\n\r\n    optimizer = torch.optim.Adam(\r\n        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\r\n        lr=LR,\r\n    )\r\n    criterion = torch.nn.BCEWithLogitsLoss()\r\n\r\n    # Helper vector to map global node indices to local ones.\r\n    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\r\n\r\n    # # define an early stopper\r\n    # save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\r\n    # save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\r\n   
 # early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \r\n    #                                 tolerance=TOLERANCE, patience=PATIENCE)\r\n\r\n    # ==================================================== Train & Validation\r\n    # loading the validation negative samples\r\n    dataset.load_val_ns()\r\n\r\n    val_perf_list = []\r\n    start_train_val = timeit.default_timer()\r\n    for epoch in range(1, NUM_EPOCH + 1):\r\n        # training\r\n        start_epoch_train = timeit.default_timer()\r\n        loss = train()\r\n        print(\r\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}\"\r\n        )\r\n\r\n        # validation\r\n        start_val = timeit.default_timer()\r\n        perf_metric_val = test(val_loader, neg_sampler, split_mode=\"val\")\r\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\r\n        print(f\"\\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}\")\r\n        val_perf_list.append(perf_metric_val)\r\n\r\n        # # check for early stopping\r\n        # if early_stopper.step_check(perf_metric_val, model):\r\n        #     break\r\n\r\n        train_val_time = timeit.default_timer() - start_train_val\r\n        print(f\"Train & Validation: Elapsed Time (s): {train_val_time: .4f}\")\r\n\r\n        # # ==================================================== Test\r\n        # # first, load the best model\r\n        # early_stopper.load_checkpoint(model)\r\n\r\n        # loading the test negative samples\r\n        dataset.load_test_ns()\r\n\r\n        # final testing\r\n        start_test = timeit.default_timer()\r\n        perf_metric_test = test(test_loader, neg_sampler, split_mode=\"test\")\r\n\r\n        print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\r\n        print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\r\n        test_time = timeit.default_timer() - 
start_test\r\n        print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\r\n\r\n        # save_results({'model': MODEL_NAME,\r\n        #             'data': DATA,\r\n        #             'run': run_idx,\r\n        #             'seed': SEED,\r\n        #             f'val {metric}': val_perf_list,\r\n        #             f'test {metric}': perf_metric_test,\r\n        #             'test_time': test_time,\r\n        #             'tot_train_val_time': train_val_time\r\n        #             }, \r\n        # results_filename)\r\n        if (perf_metric_val > best_val):\r\n            best_val = perf_metric_val\r\n            best_epoch = epoch\r\n            best_test = perf_metric_test\r\n\r\n        print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\r\n        print('-------------------------------------------------------------------------------')\r\n\r\n    print (\"INFO: Best Epoch: \", best_epoch)\r\n    print (\"INFO: Best Validation Performance: \", best_val)\r\n    print (\"INFO: Best Test Performance: \", best_test)\r\n\r\n    print(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\r\n    print(\"==============================================================\")\r\n"
  },
  {
    "path": "examples/linkproppred/thgl-myket/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\ntgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(tgb_modules_path)\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    hits_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t, pos_edge = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n            data['edge_type'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)\n        \n        for idx, neg_batch in enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            results = evaluator.eval(input_dict)\n            perf_list.append(results[metric])\n            hits_list.append(results['hits@10'])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, 
pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n    perf_hits = float(np.mean(hits_list))\n\n    return perf_metrics, perf_hits\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='thgl-myket')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = args.data\nMODEL_NAME = 'EdgeBank'\n\n\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n        
hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_val: .4f}\")\nval_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {val_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>>  <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: 
perf_metric_test,\n              'val_mrr': perf_metric_val,\n              'test_time': test_time,\n              'tot_train_val_time': test_time+val_time,\n              'hits10': perf_hits_test}, \n    results_filename)\n\n"
  },
  {
    "path": "examples/linkproppred/thgl-myket/recurrencybaseline.py",
    "content": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)\n\n@inproceedings{gastinger2024baselines,\n  title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},\n  author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},\n  booktitle={33nd International Joint Conference on Artificial Intelligence (IJCAI 2024)},\n  year={2024},\n  organization={International Joint Conferences on Artificial Intelligence Organization}\n}\n\n\"\"\"\n\n## imports\nimport timeit\nimport argparse\nimport numpy as np\nfrom copy import copy\nfrom pathlib import Path\nimport ray\nimport sys\nimport os\nimport os.path as osp\nimport json\n#internal imports \nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom tgb.utils.utils import set_random_seed,  save_results \nfrom modules.tkg_utils import create_basis_dict, group_by, reformat_ts\n\ndef predict(num_processes,  data_c_rel, all_data_c_rel, alpha, lmbda_psi,\n            perf_list_all, hits_list_all, window, neg_sampler, split_mode):\n    \"\"\" create predictions for each relation on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list of hits for each test query\n    \"\"\"\n\n    first_ts = data_c_rel[0][3]\n    ## use this if you wanna use ray:\n    num_queries = len(data_c_rel) // num_processes\n    if num_queries < num_processes: # if we do not have enough queries for all the processes\n        num_processes_tmp = 
1\n        num_queries = len(data_c_rel)\n    else:\n        num_processes_tmp = num_processes  \n    if num_processes > 1:\n        object_references =[]                   \n        \n        for i in range(num_processes_tmp):\n            num_test_queries = len(data_c_rel) - (i + 1) * num_queries\n            if num_test_queries >= num_queries:\n                test_queries_idx =[i * num_queries, (i + 1) * num_queries]\n            else:\n                test_queries_idx = [i * num_queries, len(test_data)]\n\n            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]\n\n            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window, \n                                basis_dict, \n                                num_nodes, num_rels, lmbda_psi, \n                                alpha, evaluator,first_ts, neg_sampler, split_mode)\n            object_references.append(ob)\n\n        output = ray.get(object_references)\n\n        # updates the scores and logging dict for each process\n        for proc_loop in range(num_processes_tmp):\n            perf_list_all.extend(output[proc_loop][0])\n            hits_list_all.extend(output[proc_loop][1])\n\n    else:\n        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel, \n                            window, basis_dict, \n                            num_nodes, num_rels, lmbda_psi, \n                            alpha, evaluator, first_ts, neg_sampler, split_mode)                  \n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n    \n    return perf_list_all, hits_list_all\n\n## test\ndef test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):  \n    \"\"\" create predictions by loopoing through all relations on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list 
of hits for each test query\n    \"\"\"       \n    perf_list_all = []\n    hits_list_all =[]\n    \n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n\n    ## loop through relations and apply baselines\n    \n    for rel in all_relations:\n        start =  timeit.default_timer()\n        if rel in test_data_prel.keys():\n            lmbda_psi = best_config[str(rel)]['lmbda_psi'][0]\n            alpha = best_config[str(rel)]['alpha'][0]\n\n            # test data for this relation\n            test_data_c_rel = test_data_prel[rel]\n            timesteps_test = list(set(test_data_c_rel[:,3]))\n            timesteps_test.sort()\n            all_data_c_rel = all_data_prel[rel]         \n            perf_list_rel = []\n            hits_list_rel = []\n            perf_list_rel, hits_list_rel = predict(num_processes, test_data_c_rel,\n                                                all_data_c_rel, alpha, lmbda_psi,perf_list_rel, hits_list_rel, \n                                                window, neg_sampler, split_mode)\n            perf_list_all.extend(perf_list_rel)\n            hits_list_all.extend(hits_list_rel)\n        else:\n            perf_list_rel =[]\n         \n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n        \n\n        with open(csv_file, 'a') as f:\n            f.write(\"{},{}\\n\".format(rel, perf_list_rel))\n\n    return perf_list_all, hits_list_all\n\ndef read_dict_compute_mrr(split_mode='test'):\n    \"\"\" read the results per relation  from a precreated file and compute mrr\n    :return  mrr_per_rel: dictionary of mrrs for each relation\n    :return all_mrrs: list of mrrs for all relations\n    \"\"\"\n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n    # Initialize an empty dictionary to store the data\n    
results_per_rel_dict = {}\n    mrr_per_rel = {}\n    all_mrrs = []\n    # Open the file for reading\n    with open(csv_file, 'r') as f:\n        # Read each line in the file\n        for line in f:\n            # Split the line at the comma\n            parts = line.strip().split(',')\n            # Extract the key (the first part)\n            key = int(parts[0])\n            # Extract the values (the rest of the parts), remove square brackets\n            values = [float(value.strip('[]')) for value in parts[1:]]\n            # Add the key-value pair to the dictionary\n            if key in results_per_rel_dict.keys():\n                print(f\"Key {key} already exists in the dictionary!!! might have duplicate entries in results csv\")\n            results_per_rel_dict[key] = values\n            all_mrrs.extend(values)\n            mrr_per_rel[key] = np.mean(values)\n\n    if len(list(results_per_rel_dict.keys())) != num_rels:\n        print(\"we do not have entries for each rel in the results csv file. only num enties: \", len(list(results_per_rel_dict.keys())))\n\n    print(\"Split mode: \"+split_mode +\" Mean MRR: \", np.mean(all_mrrs))\n    print(\"mrr per relation: \", mrr_per_rel)\n\n\n\n## train\ndef train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):\n    \"\"\" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params\n    based on validation mrr\n    :return best_config: dictionary of best params for each relation\n    \"\"\"\n    best_config= {}\n    best_mrr = 0\n    for rel in rels: # loop through relations. 
for each relation, apply rules with selected params, compute valid mrr\n        start =  timeit.default_timer()\n        rel_key = int(rel)            \n\n        best_config[str(rel_key)] = {}\n        best_config[str(rel_key)]['not_trained'] = 'True'    \n        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi,0] #default\n        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))\n        best_config[str(rel_key)]['alpha'] = [default_alpha,0]  #default    \n        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))\n        \n        if rel in val_data_prel.keys():      \n            # valid data for this relation  \n            val_data_c_rel = val_data_prel[rel]\n            timesteps_valid = list(set(val_data_c_rel[:,3]))\n            timesteps_valid.sort()\n            trainval_data_c_rel = trainval_data_prel[rel]\n\n            ######  1) select lambda ###############        \n            lmbdas_psi = params_dict['lmbda_psi']        \n\n            alpha = 1\n            best_lmbda_psi = 0.1\n            best_mrr_psi = 0\n            lmbda_mrrs = []\n\n            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))\n            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))         \n            best_config[str(rel_key)]['not_trained'] = 'False'       \n            \n            for lmbda_psi in lmbdas_psi:   \n                perf_list_r = []\n                hits_list_r = []\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr = np.mean(perf_list_r)\n                # # is new mrr better than previous best? 
if yes: store lmbda\n                if mrr > best_mrr_psi:\n                    best_mrr_psi = float(mrr)\n                    best_lmbda_psi = lmbda_psi\n\n\n                lmbda_mrrs.append(float(mrr))\n            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]\n            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs\n            best_mrr = best_mrr_psi\n            ##### 2) select alpha ###############\n            best_config[str(rel_key)]['not_trained'] = 'False'    \n            alphas = params_dict['alpha'] \n            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0] # use the best lmbda psi\n\n            alpha_mrrs = []\n            # perf_list_all = []\n            best_mrr_alpha = 0\n            best_alpha=0.99\n            for alpha in alphas:\n                perf_list_r = []\n                hits_list_r = []\n\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr_alpha = np.mean(perf_list_r)\n\n                # is new mrr better than previous best? 
if yes: store alpha\n                if mrr_alpha > best_mrr_alpha:\n                    best_mrr_alpha = float(mrr_alpha)\n                    best_alpha = alpha\n                    best_mrr = best_mrr_alpha\n                alpha_mrrs.append(float(mrr_alpha))\n\n            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]\n            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs\n\n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n    return best_config\n\n\n\n## args\ndef get_args(): \n    \"\"\"parse all arguments for the script\"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", \"-d\", default=\"thgl-myket\", type=str) \n    parser.add_argument(\"--window\", \"-w\", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep\n    parser.add_argument(\"--num_processes\", \"-p\", default=1, type=int)\n    parser.add_argument(\"--lmbda\", \"-l\",  default=0.1, type=float) # fix lambda. used if trainflag == false\n    parser.add_argument(\"--alpha\", \"-alpha\",  default=0.99, type=float) # fix alpha. 
used if trainflag == false\n    parser.add_argument(\"--train_flag\", \"-tr\",  default='False') # do we need training, ie selection of lambda and alpha\n    parser.add_argument(\"--load_flag\", \"-lo\",  default='False') # if train_flag set to True: do you want to load best_config?\n    parser.add_argument(\"--save_config\", \"-c\",  default='True') # do we need to save the selection of lambda and alpha in config file?\n    parser.add_argument('--seed', type=int, help='Random seed', default=1) # not needed\n    parsed = vars(parser.parse_args())\n    return parsed\n\nstart_o =  timeit.default_timer()\n\nparsed = get_args()\nif parsed['num_processes']>1:\n    ray.init(num_cpus=parsed[\"num_processes\"], num_gpus=0)\nMODEL_NAME = 'RecurrencyBaseline'\nSEED = parsed['seed']  # set the random seed for consistency\nset_random_seed(SEED)\nperrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'\nif not osp.exists(perrel_results_path):\n    os.mkdir(perrel_results_path)\n    print('INFO: Create directory {}'.format(perrel_results_path))\nPath(perrel_results_path).mkdir(parents=True, exist_ok=True)\n\n## load dataset and prepare it accordingly\nname = parsed[\"dataset\"]\ndataset = LinkPropPredDataset(name=name, root=\"datasets\", preprocess=True)\nDATA = name\n\nrelations = dataset.edge_type\nnum_rels = dataset.num_rels\nrels = np.arange(0,num_rels)\nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nnum_nodes = dataset.num_nodes \ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n\nprint(\"split train valid test data\")\nall_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=name)\nneg_sampler = 
dataset.negative_sampler\n\ntrain_val_data = np.concatenate([train_data, val_data])\nall_data = np.concatenate([train_data, val_data, test_data])\n\n# create dicts with key: relation id, values: triples for that relation id\nprint(\"grouping data by relation\")\ntest_data_prel = group_by(test_data, 1)\nall_data_prel = group_by(all_data, 1)\nval_data_prel = group_by(val_data, 1)\ntrainval_data_prel = group_by(train_val_data, 1)\n\n#load the ns samples \n# if parsed['train_flag']:\nprint(\"loading negative samples\")\ndataset.load_val_ns()\ndataset.load_test_ns()\n\n# parameter options\nif parsed['train_flag'] == 'True':\n    params_dict = {}\n    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001] \n    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]\n    default_lmbda_psi = params_dict['lmbda_psi'][-1]\n    default_alpha = params_dict['alpha'][-2]\n\n## load rules\nprint(\"creating rules\")\nbasis_dict = create_basis_dict(train_val_data)\nprint(\"done with creating rules\")\n## init\n# rb_predictor = RecurrencyBaselinePredictor(rels)\n## train to find best lambda and alpha\nstart_train =  timeit.default_timer()\nif parsed['train_flag'] ==  'True':\n    if parsed['load_flag'] == 'True':\n        with open('best_config.json', 'r') as infile:\n            best_config = json.load(infile)\n    else:\n        print('start training')\n        best_config = train(params_dict,  rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'], \n            parsed['window'])\n        if parsed['save_config'] == 'True':\n            import json\n            with open('best_config.json', 'w') as outfile:\n                json.dump(best_config, outfile)\n\nelse: # use preset lmbda and alpha; same for all relations\n    best_config = {} \n    for rel in rels:\n        best_config[str(rel)] = {}\n        best_config[str(rel)]['lmbda_psi'] = 
[parsed['lmbda']]\n        best_config[str(rel)]['alpha'] = [parsed['alpha']]\n    \nend_train =  timeit.default_timer()\n\n# compute validation mrr\nprint(\"Computing validation MRR\")\nperf_list_all_val, hits_list_all_val = test(best_config,rels, val_data_prel, \n                                                trainval_data_prel, neg_sampler, parsed['num_processes'], \n                                            parsed['window'], split_mode='val')\nval_mrr = float(np.mean(perf_list_all_val))\n\n# compute test mrr\nprint(\"Computing test MRR\")\nstart_test =  timeit.default_timer()\nperf_list_all, hits_list_all = test(best_config,rels, test_data_prel, \n                                                 all_data_prel, neg_sampler, parsed['num_processes'], \n                                                parsed['window'])\n\nend_o =  timeit.default_timer()\ntrain_time_o = round(end_train- start_train, 6)  \ntest_time_o = round(end_o- start_test, 6)  \ntotal_time_o = round(end_o- start_o, 6)  \nprint(\"Running Training to find best configs finished in {} seconds.\".format(train_time_o))\nprint(\"Running testing with best configs finished in {} seconds.\".format(test_time_o))\nprint(\"Running all steps finished in {} seconds.\".format(total_time_o))\n\nprint(f\"The test MRR is {np.mean(perf_list_all)}\")\nprint(f\"The valid MRR is {val_mrr}\")\nprint(f\"The Hits@10 is {np.mean(hits_list_all)}\")\nprint(f\"We have {len(perf_list_all)} predictions\")\nprint(f\"The test set has len {len(test_data)} \")\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'\nmetric = dataset.eval_metric\nsave_results({'model': MODEL_NAME,\n              'train_flag': parsed['train_flag'],\n     
         'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: float(np.mean(perf_list_all)),\n              'val_mrr': val_mrr,\n              'hits10': float(np.mean(hits_list_all)),\n              'test_time': test_time_o,\n              'tot_train_val_time': total_time_o\n              }, \n    results_filename)\n\nif parsed['num_processes']>1:\n    ray.shutdown()\n\n\n    \n"
  },
  {
    "path": "examples/linkproppred/thgl-myket/sthn.py",
    "content": "import timeit\nimport numpy as np\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\nfrom tgb.linkproppred.evaluate import Evaluator\n\nimport argparse\nfrom modules.sthn import set_seed, pre_compute_subgraphs, get_inputs_for_ind, check_data_leakage\nimport torch\nimport pandas as pd\nimport itertools\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nfrom tgb.utils.utils import set_random_seed, save_results\n\n\n# Start...\nstart_overall = timeit.default_timer()\n\nDATA = \"thgl-myket\"\n\nMODEL_NAME = 'STHN'\n\n# data loading\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\nmetric = dataset.eval_metric\n\nprint (\"there are {} nodes and {} edges\".format(dataset.num_nodes, dataset.num_edges))\nprint (\"there are {} relation types\".format(dataset.num_rels))\n\n\ntimestamp = data.t\nhead = data.src\ntail = data.dst\nedge_type = data.edge_type\nneg_sampler = dataset.negative_sampler\n\nprint(data)\nprint(timestamp)\nprint(head)\nprint(tail)\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\n\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\n\n####################################################################\n####################################################################\n####################################################################\n\n\ndef print_model_info(model):\n    print(model)\n    
parameters = filter(lambda p: p.requires_grad, model.parameters())\n    parameters = sum([np.prod(p.size()) for p in parameters])\n    print('Trainable Parameters: %d' % parameters)\n\ndef get_args():\n    parser=argparse.ArgumentParser()\n    parser.add_argument('--data', type=str, default='movie')\n    parser.add_argument('--device', type=int, default=0)\n    parser.add_argument('--batch_size', type=int, default=600)\n    parser.add_argument('--epochs', type=int, default=2)\n    parser.add_argument('--max_edges', type=int, default=50)\n    parser.add_argument('--num_edgeType', type=int, default=0, help='num of edgeType')\n    parser.add_argument('--lr', type=float, default=0.0005)\n    parser.add_argument('--weight_decay', type=float, default=1e-4)\n    parser.add_argument('--predict_class', action='store_true')\n    \n    # model\n    parser.add_argument('--window_size', type=int, default=5)\n    parser.add_argument('--dropout', type=float, default=0.1)\n    parser.add_argument('--model', type=str, default='sthn') \n    parser.add_argument('--neg_samples', type=int, default=1)\n    parser.add_argument('--extra_neg_samples', type=int, default=5)\n    parser.add_argument('--num_neighbors', type=int, default=50)\n    parser.add_argument('--channel_expansion_factor', type=int, default=2)\n    parser.add_argument('--sampled_num_hops', type=int, default=1)\n    parser.add_argument('--time_dims', type=int, default=100)\n    parser.add_argument('--hidden_dims', type=int, default=100)\n    parser.add_argument('--num_layers', type=int, default=1)\n    parser.add_argument('--check_data_leakage', action='store_true')\n    \n    parser.add_argument('--ignore_node_feats', action='store_true')\n    parser.add_argument('--node_feats_as_edge_feats', action='store_true')\n    parser.add_argument('--ignore_edge_feats', action='store_true')\n    parser.add_argument('--use_onehot_node_feats', action='store_true')\n    parser.add_argument('--use_type_feats', action='store_true')\n\n  
  parser.add_argument('--use_graph_structure', action='store_true')\n    parser.add_argument('--structure_time_gap', type=int, default=2000)\n    parser.add_argument('--structure_hops', type=int, default=1) \n\n    parser.add_argument('--use_node_cls', action='store_true')\n    parser.add_argument('--use_cached_subgraph', action='store_true')\n    \n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--num_run', type=int, help='Number of iteration runs', default=5)\n    return parser.parse_args()\n\n\ndef load_model(args):\n    # get model\n    edge_predictor_configs = {\n        'dim_in_time': args.time_dims,\n        'dim_in_node': args.node_feat_dims,\n        'predict_class': 1 if not args.predict_class else args.num_edgeType+1,\n    }\n    if args.model == 'sthn':\n        if args.predict_class:\n            from modules.sthn import Multiclass_Interface as STHN_Interface\n        else:\n            from modules.sthn import STHN_Interface\n        from modules.sthn import link_pred_train\n\n        mixer_configs = {\n            'per_graph_size'  : args.max_edges,\n            'time_channels'   : args.time_dims, \n            'input_channels'  : args.edge_feat_dims, \n            'hidden_channels' : args.hidden_dims, \n            'out_channels'    : args.hidden_dims,\n            'num_layers'      : args.num_layers,\n            'dropout'         : args.dropout,\n            'channel_expansion_factor': args.channel_expansion_factor,\n            'window_size'     : args.window_size,\n            'use_single_layer' : False\n        }  \n        \n    else:\n        NotImplementedError()\n\n    model = STHN_Interface(mixer_configs, edge_predictor_configs)\n    for k, v in model.named_parameters():\n        print(k, v.requires_grad)\n\n    print_model_info(model)\n\n    return model, args, link_pred_train\n\ndef load_graph(data):\n    df = pd.DataFrame({\n        'idx': np.arange(len(data.t)),\n        'src': 
data.src,\n        'dst': data.dst,\n        'time': data.t,\n        'label': data.edge_type,\n    })\n\n    num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1 \n\n    ext_full_indptr = np.zeros(num_nodes + 1, dtype=np.int32)\n    ext_full_indices = [[] for _ in range(num_nodes)]\n    ext_full_ts = [[] for _ in range(num_nodes)]\n    ext_full_eid = [[] for _ in range(num_nodes)]\n\n    for idx, row in tqdm(df.iterrows(), total=len(df)):\n        src = int(row['src'])\n        dst = int(row['dst'])\n        \n        ext_full_indices[src].append(dst)\n        ext_full_ts[src].append(row['time'])\n        ext_full_eid[src].append(idx)\n        \n    for i in tqdm(range(num_nodes)):\n        ext_full_indptr[i + 1] = ext_full_indptr[i] + len(ext_full_indices[i])\n\n    ext_full_indices = np.array(list(itertools.chain(*ext_full_indices)))\n    ext_full_ts = np.array(list(itertools.chain(*ext_full_ts)))\n    ext_full_eid = np.array(list(itertools.chain(*ext_full_eid)))\n\n    print('Sorting...')\n\n    def tsort(i, indptr, indices, t, eid):\n        beg = indptr[i]\n        end = indptr[i + 1]\n        sidx = np.argsort(t[beg:end])\n        indices[beg:end] = indices[beg:end][sidx]\n        t[beg:end] = t[beg:end][sidx]\n        eid[beg:end] = eid[beg:end][sidx] \n\n    for i in tqdm(range(ext_full_indptr.shape[0] - 1)):\n        tsort(i, ext_full_indptr, ext_full_indices, ext_full_ts, ext_full_eid)\n\n    print('saving...')\n\n    np.savez('/tmp/ext_full.npz', indptr=ext_full_indptr,\n            indices=ext_full_indices, ts=ext_full_ts, eid=ext_full_eid)\n    g = np.load('/tmp/ext_full.npz')\n    return g, df\n\ndef load_all_data(args):\n\n    # load graph\n    g, df = load_graph(data)\n\n    args.train_mask = train_mask.numpy()\n    args.val_mask   = val_mask.numpy()\n    args.test_mask = test_mask.numpy()\n    args.num_edges = len(df)\n\n    print('Train %d, Valid %d, Test %d'%(sum(args.train_mask), \n                                         
sum(args.val_mask),\n                                         sum(test_mask)))\n    \n    args.num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1\n    args.num_edges = len(df)\n\n    print('Num nodes %d, num edges %d'%(args.num_nodes, args.num_edges))\n\n    # load feats \n    node_feats, edge_feats = dataset.node_feat, dataset.edge_feat\n    node_feat_dims = 0 if node_feats is None else node_feats.shape[1]\n    edge_feat_dims = 0 if edge_feats is None else edge_feats.shape[1]\n\n    # feature pre-processing\n    if args.use_onehot_node_feats:\n        print('>>> Use one-hot node features')\n        node_feats = torch.eye(args.num_nodes)\n        node_feat_dims = node_feats.size(1)\n\n    if args.ignore_node_feats:\n        print('>>> Ignore node features')\n        node_feats = None\n        node_feat_dims = 0\n\n    if args.use_type_feats:\n        edge_type = df.label.values\n        print(edge_type)\n        print(edge_type.sum())\n        args.num_edgeType = len(set(edge_type.tolist()))\n        edge_feats = torch.nn.functional.one_hot(torch.from_numpy(edge_type), \n                                                 num_classes=args.num_edgeType)\n        edge_feat_dims = edge_feats.size(1)\n        \n    print('Node feature dim %d, edge feature dim %d'%(node_feat_dims, edge_feat_dims))\n    \n    # double check (if data leakage then cannot continue the code)\n    if args.check_data_leakage:\n        check_data_leakage(args, g, df)\n\n    args.node_feat_dims = node_feat_dims\n    args.edge_feat_dims = edge_feat_dims\n    \n    if node_feats != None:\n        node_feats = node_feats.to(args.device)\n    if edge_feats != None:\n        edge_feats = edge_feats.to(args.device)\n    \n    return node_feats, edge_feats, g, df, 
args\n\n####################################################################\n####################################################################\n####################################################################\n\n@torch.no_grad()\ndef test(data, test_mask, model, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'val' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    test_subgraphs  = pre_compute_subgraphs(args, g, df, mode='test' if split_mode == 'test' else 'valid', negative_sampler=neg_sampler, split_mode=split_mode)\n    perf_list = []\n    \n    if split_mode == 'test':\n        cur_df = df[args.test_mask]\n    elif split_mode == 'val':\n        cur_df = df[args.val_mask]\n    neg_samples = 20\n    cached_neg_samples = 20\n\n    test_loader = cur_df.groupby(cur_df.index // args.batch_size)\n    pbar = tqdm(total=len(test_loader))\n    pbar.set_description('%s mode with negative samples %d ...'%(split_mode, neg_samples))        \n    \n    ###################################################\n    # compute + training + fetch all scores\n    cur_inds = 0\n\n    for ind in range(len(test_loader)):\n        ###################################################\n        inputs, subgraph_node_feats, cur_inds = get_inputs_for_ind(test_subgraphs, 'test' if split_mode == 'test' else 'tgb-val', cached_neg_samples, neg_samples, node_feats, edge_feats, cur_df, cur_inds, ind, args)\n        \n        loss, pred, edge_label = model(inputs, neg_samples, subgraph_node_feats)\n        
# print(ind, [l for l in inputs], pred.shape)\n\n        input_dict = {\n            \"y_pred_pos\": np.array([pred.cpu()[0]]),\n            \"y_pred_neg\": np.array(pred.cpu()[1:]),\n            \"eval_metric\": [metric],\n        }\n        perf_list.append(evaluator.eval(input_dict)[metric])\n\n    perf_metrics_mean = float(np.mean(perf_list))\n    perf_metrics_std = float(np.std(perf_list))\n\n    return perf_metrics_mean, perf_metrics_std, perf_list\n\n\nargs = get_args()\n\nargs.use_graph_structure = True\nargs.use_onehot_node_feats = False\nargs.ignore_node_feats = False # we only use graph structure\nargs.use_type_feats = True # type encoding\nargs.use_cached_subgraph = True\n\nprint(args)\n\nargs.device = f\"cuda:{args.device}\" if torch.cuda.is_available() else \"cpu\"\nargs.device = torch.device(args.device)\nSEED = args.seed\nBATCH_SIZE = args.batch_size\nNUM_RUNS = args.num_run\nset_seed(SEED)\n\n\n###################################################\n# load feats + graph\nnode_feats, edge_feats, g, df, args = load_all_data(args)\n\n###################################################\n# get model \nmodel, args, link_pred_train = load_model(args)\n\n###################################################\n\nprint(\"==========================================================\")\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============\")\nprint(\"==========================================================\")\n\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\n\n\nfor run_idx in range(NUM_RUNS):\n    print('-------------------------------------------------------------------------------')\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\n   
 start_run = timeit.default_timer()\n\n    # set the seed for deterministic results...\n    torch.manual_seed(run_idx + SEED)\n    set_random_seed(run_idx + SEED)\n\n    # define an early stopper\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\n    # early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \n    #                                 tolerance=TOLERANCE, patience=PATIENCE)\n\n    # ==================================================== Train & Validation\n    # loading the validation negative samples\n\n    # Link prediction\n    start_val = timeit.default_timer()\n    print('Train link prediction task from scratch ...')\n    model = link_pred_train(model.to(args.device), args, g, df, node_feats, edge_feats)\n\n    dataset.load_val_ns()\n\n    # Validation ...\n    \n    perf_metrics_val_mean, perf_metrics_val_std, perf_list_val = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='val')\n    end_val = timeit.default_timer()\n\n    print(f\"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tval: {metric}: {perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}\")\n    val_time = timeit.default_timer() - start_val\n    print(f\"\\tval: Elapsed Time (s): {val_time: .4f}\")\n\n\n    dataset.load_test_ns()\n\n    # testing ...\n    start_test = timeit.default_timer()\n    perf_metrics_test_mean, perf_metrics_test_std, perf_list_test = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='test')\n    end_test = timeit.default_timer()\n\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tTest: {metric}: {perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}\")\n    test_time = timeit.default_timer() - start_test\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\n    save_results({'model': 
MODEL_NAME,\n                  'data': DATA,\n                  'run': run_idx,\n                  'seed': SEED,\n                  f'val {metric}': f'{perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}' ,\n                  f'test {metric}': f'{perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}' ,\n                  'test_time': test_time,\n                  'tot_train_val_time': val_time\n                  }, \n    results_filename)\n\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\n    print('-------------------------------------------------------------------------------')\n\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\nprint(\"==============================================================\")\n\n# save_results({'model': MODEL_NAME,\n#             'data': DATA,\n#             'run': 1,\n#             'seed': SEED,\n#             metric: perf_metric_test,\n#             'test_time': test_time,\n#             'tot_train_val_time': 'NA'\n#             }, \n#     results_filename)"
  },
  {
    "path": "examples/linkproppred/thgl-myket/tgn.py",
    "content": "import numpy as np\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\nfrom tgb.linkproppred.evaluate import Evaluator\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom tqdm import tqdm\r\nimport timeit\r\n\r\n\r\nimport math\r\nimport timeit\r\n\r\nimport os\r\nimport os.path as osp\r\nfrom pathlib import Path\r\nimport numpy as np\r\n\r\nimport torch\r\nfrom sklearn.metrics import average_precision_score, roc_auc_score\r\nfrom torch.nn import Linear\r\n\r\nfrom torch_geometric.datasets import JODIEDataset\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom torch_geometric.nn import TransformerConv\r\n\r\n# internal imports\r\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\r\nfrom tgb.linkproppred.evaluate import Evaluator\r\nfrom modules.decoder import LinkPredictor\r\nfrom modules.emb_module import GraphAttentionEmbedding\r\nfrom modules.msg_func import IdentityMessage\r\nfrom modules.msg_agg import LastAggregator\r\nfrom modules.neighbor_loader import LastNeighborLoader\r\nfrom modules.memory_module import TGNMemory\r\nfrom modules.early_stopping import  EarlyStopMonitor\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\n# ==========\r\n# ========== Define helper function...\r\n# ==========\r\n\r\ndef train():\r\n    r\"\"\"\r\n    Training procedure for TGN model\r\n    This function uses some objects that are globally defined in the current scrips \r\n\r\n    Parameters:\r\n        None\r\n    Returns:\r\n        None\r\n            \r\n    \"\"\"\r\n\r\n    model['memory'].train()\r\n    model['gnn'].train()\r\n    model['link_pred'].train()\r\n\r\n    model['memory'].reset_state()  # Start with a fresh memory.\r\n    neighbor_loader.reset_state()  # Start with an empty graph.\r\n\r\n    total_loss = 0\r\n    for batch in train_loader:\r\n        batch = batch.to(device)\r\n        optimizer.zero_grad()\r\n\r\n        src, pos_dst, t, msg, rel = 
batch.src, batch.dst, batch.t, batch.msg, batch.edge_type\r\n\r\n        # Sample negative destination nodes.\r\n        neg_dst = torch.randint(\r\n            min_dst_idx,\r\n            max_dst_idx + 1,\r\n            (src.size(0),),\r\n            dtype=torch.long,\r\n            device=device,\r\n        )\r\n\r\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\r\n        n_id, edge_index, e_id = neighbor_loader(n_id)\r\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\r\n\r\n        # Get updated memory of all nodes involved in the computation.\r\n        z, last_update = model['memory'](n_id)\r\n        z = model['gnn'](\r\n            z,\r\n            last_update,\r\n            edge_index,\r\n            data.t[e_id].to(device),\r\n            data.msg[e_id].to(device),\r\n        )\r\n\r\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\r\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\r\n\r\n        loss = criterion(pos_out, torch.ones_like(pos_out))\r\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\r\n\r\n        # Update memory and neighbor loader with ground-truth state.\r\n        model['memory'].update_state(src, pos_dst, t, msg)\r\n        neighbor_loader.insert(src, pos_dst)\r\n\r\n        loss.backward()\r\n        optimizer.step()\r\n        model['memory'].detach()\r\n        total_loss += float(loss.detach()) * batch.num_events\r\n\r\n    return total_loss / train_data.num_events\r\n\r\n\r\n@torch.no_grad()\r\ndef test(loader, neg_sampler, split_mode):\r\n    r\"\"\"\r\n    Evaluated the dynamic link prediction\r\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\r\n\r\n    Parameters:\r\n        loader: an object containing positive attributes of the positive edges of the evaluation set\r\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\r\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\r\n    Returns:\r\n        perf_metric: the result of the performance evaluation\r\n    \"\"\"\r\n    model['memory'].eval()\r\n    model['gnn'].eval()\r\n    model['link_pred'].eval()\r\n\r\n    perf_list = []\r\n\r\n    for pos_batch in loader:\r\n        pos_src, pos_dst, pos_t, pos_msg, pos_rel = (\r\n            pos_batch.src,\r\n            pos_batch.dst,\r\n            pos_batch.t,\r\n            pos_batch.msg,\r\n            pos_batch.edge_type\r\n        )\r\n\r\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_rel, split_mode=split_mode)\r\n\r\n        # pos_msg_new = torch.cat([pos_msg,pos_rel.unsqueeze(dim=1)], dim=1)   \r\n\r\n\r\n        for idx, neg_batch in enumerate(neg_batch_list):\r\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\r\n            dst = torch.tensor(\r\n                np.concatenate(\r\n                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\r\n                    axis=0,\r\n                ),\r\n                device=device,\r\n            )\r\n\r\n            n_id = torch.cat([src, dst]).unique()\r\n            n_id, edge_index, e_id = neighbor_loader(n_id)\r\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\r\n\r\n            # Get updated memory of all nodes involved in the computation.\r\n            z, last_update = model['memory'](n_id)\r\n            z = model['gnn'](\r\n                z,\r\n                last_update,\r\n                edge_index,\r\n                data.t[e_id].to(device),\r\n        
        data.msg[e_id].to(device),\r\n            )\r\n\r\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\r\n\r\n            # compute MRR\r\n            input_dict = {\r\n                \"y_pred_pos\": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),\r\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\r\n                \"eval_metric\": [metric],\r\n            }\r\n            perf_list.append(evaluator.eval(input_dict)[metric])\r\n\r\n        # Update memory and neighbor loader with ground-truth state.\r\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)\r\n        neighbor_loader.insert(pos_src, pos_dst)\r\n\r\n    perf_metrics = float(torch.tensor(perf_list).mean())\r\n\r\n    return perf_metrics\r\n\r\n# ==========\r\n# ==========\r\n# ==========\r\n\r\n\r\n# ==========\r\n# ==========\r\n# ==========\r\n\r\n\r\n# Start...\r\nstart_overall = timeit.default_timer()\r\nDATA = \"thgl-myket\"\r\n\r\n\r\n# ========== set parameters...\r\nargs, _ = get_args()\r\nargs.data = DATA\r\nprint(\"INFO: Arguments:\", args)\r\n\r\nLR = args.lr\r\nBATCH_SIZE = args.bs\r\nK_VALUE = args.k_value  \r\nNUM_EPOCH = args.num_epoch\r\nSEED = args.seed\r\nMEM_DIM = args.mem_dim\r\nTIME_DIM = args.time_dim\r\nEMB_DIM = args.emb_dim\r\nTOLERANCE = args.tolerance\r\nPATIENCE = args.patience\r\nNUM_RUNS = args.num_run\r\nNUM_NEIGHBORS = 10\r\nUSE_EDGE_TYPE = True\r\nUSE_NODE_TYPE = True\r\n\r\n\r\n\r\nMODEL_NAME = 'TGN'\r\n# ==========\r\n\r\n# set the device\r\ndevice = torch.device(\"cuda:2\" if torch.cuda.is_available() else \"cpu\")\r\nprint(device)\r\n\r\n\r\n# data loading\r\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\r\ntrain_mask = dataset.train_mask\r\nval_mask = dataset.val_mask\r\ntest_mask = dataset.test_mask\r\ndata = dataset.get_TemporalData()\r\ndata = data.to(device)\r\nmetric = dataset.eval_metric\r\n\r\nprint (\"there are {} nodes and {} edges\".format(dataset.num_nodes, 
dataset.num_edges))\r\nprint (\"there are {} relation types\".format(dataset.num_rels))\r\n\r\n\r\ntimestamp = data.t\r\nhead = data.src\r\ntail = data.dst\r\nedge_type = data.edge_type #relation\r\nedge_type_dim = len(torch.unique(edge_type))\r\n\r\nembed_edge_type = torch.nn.Embedding(edge_type_dim, 128).to(device)\r\nwith torch.no_grad():\r\n    edge_type_embeddings = embed_edge_type(edge_type)\r\n\r\n\r\nif USE_EDGE_TYPE:\r\n    data.msg = torch.cat([data.msg, edge_type_embeddings], dim=1)\r\n\r\n#! node type is a property of the dataset not the temporal data as temporal data has one entry per edge\r\nnode_type = dataset.node_type #node type\r\nneg_sampler = dataset.negative_sampler\r\n\r\ndata.__setattr__(\"node_type\", node_type)\r\n\r\nprint (\"shape of edge type is\", edge_type.shape)\r\nprint (\"shape of node type is\", node_type.shape)\r\n\r\ntrain_data = data[train_mask]\r\nval_data = data[val_mask]\r\ntest_data = data[test_mask]\r\nprint (\"finished loading PyG data\")\r\n\r\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\r\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\r\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\r\n\r\nstart_time = timeit.default_timer()\r\n\r\n# Ensure to only sample actual destination nodes as negatives.\r\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n\r\nprint(\"==========================================================\")\r\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============\")\r\nprint(\"==========================================================\")\r\n\r\nevaluator = Evaluator(name=DATA)\r\nneg_sampler = dataset.negative_sampler\r\n\r\n# for saving the results...\r\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\r\nif not osp.exists(results_path):\r\n    os.mkdir(results_path)\r\n    print('INFO: Create directory {}'.format(results_path))\r\nPath(results_path).mkdir(parents=True, 
exist_ok=True)\r\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\r\n\r\nfor run_idx in range(NUM_RUNS):\r\n    print('-------------------------------------------------------------------------------')\r\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\r\n    start_run = timeit.default_timer()\r\n\r\n    # set the seed for deterministic results...\r\n    torch.manual_seed(run_idx + SEED)\r\n    set_random_seed(run_idx + SEED)\r\n\r\n    # neighhorhood sampler\r\n    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\r\n\r\n    # define the model end-to-end\r\n    memory = TGNMemory(\r\n        data.num_nodes,\r\n        data.msg.size(-1),\r\n        MEM_DIM,\r\n        TIME_DIM,\r\n        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\r\n        aggregator_module=LastAggregator(),\r\n    ).to(device)\r\n\r\n    gnn = GraphAttentionEmbedding(\r\n        in_channels=MEM_DIM,\r\n        out_channels=EMB_DIM,\r\n        msg_dim=data.msg.size(-1),\r\n        time_enc=memory.time_enc,\r\n    ).to(device)\r\n\r\n    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\r\n\r\n    model = {'memory': memory,\r\n            'gnn': gnn,\r\n            'link_pred': link_pred}\r\n\r\n    optimizer = torch.optim.Adam(\r\n        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\r\n        lr=LR,\r\n    )\r\n    criterion = torch.nn.BCEWithLogitsLoss()\r\n\r\n    # Helper vector to map global node indices to local ones.\r\n    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\r\n\r\n    # define an early stopper\r\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\r\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\r\n    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \r\n                                    tolerance=TOLERANCE, 
patience=PATIENCE)\r\n\r\n    # ==================================================== Train & Validation\r\n    # loading the validation negative samples\r\n    dataset.load_val_ns()\r\n\r\n    val_perf_list = []\r\n    start_train_val = timeit.default_timer()\r\n    for epoch in range(1, NUM_EPOCH + 1):\r\n        # training\r\n        start_epoch_train = timeit.default_timer()\r\n        loss = train()\r\n        print(\r\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}\"\r\n        )\r\n\r\n        # validation\r\n        start_val = timeit.default_timer()\r\n        perf_metric_val = test(val_loader, neg_sampler, split_mode=\"val\")\r\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\r\n        print(f\"\\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}\")\r\n        val_perf_list.append(perf_metric_val)\r\n\r\n        # check for early stopping\r\n        if early_stopper.step_check(perf_metric_val, model):\r\n            break\r\n\r\n    train_val_time = timeit.default_timer() - start_train_val\r\n    print(f\"Train & Validation: Elapsed Time (s): {train_val_time: .4f}\")\r\n\r\n    # ==================================================== Test\r\n    # first, load the best model\r\n    early_stopper.load_checkpoint(model)\r\n\r\n    # loading the test negative samples\r\n    dataset.load_test_ns()\r\n\r\n    # final testing\r\n    start_test = timeit.default_timer()\r\n    perf_metric_test = test(test_loader, neg_sampler, split_mode=\"test\")\r\n\r\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\r\n    print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\r\n    test_time = timeit.default_timer() - start_test\r\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\r\n\r\n    save_results({'model': MODEL_NAME,\r\n                  'data': DATA,\r\n                  'run': run_idx,\r\n                  'seed': 
SEED,\r\n                  f'val {metric}': val_perf_list,\r\n                  f'test {metric}': perf_metric_test,\r\n                  'test_time': test_time,\r\n                  'tot_train_val_time': train_val_time\r\n                  }, \r\n    results_filename)\r\n\r\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\r\n    print('-------------------------------------------------------------------------------')\r\n\r\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\r\nprint(\"==============================================================\")\r\n\r\n# #* load numpy arrays instead\r\n# from tgb.linkproppred.dataset import LinkPropPredDataset\r\n\r\n# # data loading\r\n# dataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\r\n# data = dataset.full_data  \r\n# metric = dataset.eval_metric\r\n# sources = dataset.full_data['sources']\r\n# print (\"finished loading numpy arrays\")\r\n"
  },
  {
    "path": "examples/linkproppred/thgl-software/STHN_README.md",
    "content": "STHN method adopted from: https://github.com/celi52/STHN/tree/main\n\nTo run:\n\n1. Install requirements. The two new additional requirements for STHN are `pybind11` and `torchmetrics==0.11.0`\n2. Compile the sampler\n```bash\npython sthn_sampler_setup.py build_ext --inplace\n```\n\n3. Run the example code\n\n```bash\npython sthn.py\n```\n\nIf the code runs correctly the output would end with\n\n```\nINFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \n        Test: mrr: X\n        Test: Elapsed Time (s):  Y\n```"
  },
  {
    "path": "examples/linkproppred/thgl-software/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\ntgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(tgb_modules_path)\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    hits_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t, pos_edge = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n            data['edge_type'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)\n        \n        for idx, neg_batch in enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            results = evaluator.eval(input_dict)\n            perf_list.append(results[metric])\n            hits_list.append(results['hits@10'])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, 
pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n    perf_hits = float(np.mean(hits_list))\n\n    return perf_metrics, perf_hits\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='thgl-software')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = args.data\nMODEL_NAME = 'EdgeBank'\n\n\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n      
  hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_val: .4f}\")\nval_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {val_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>>  <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: 
perf_metric_test,\n              'val_mrr': perf_metric_val,\n              'test_time': test_time,\n              'tot_train_val_time': test_time+val_time,\n              'hits10': perf_hits_test}, \n    results_filename)\n\n"
  },
  {
    "path": "examples/linkproppred/thgl-software/recurrencybaseline.py",
    "content": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)\n\n@inproceedings{gastinger2024baselines,\n  title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},\n  author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},\n  booktitle={33nd International Joint Conference on Artificial Intelligence (IJCAI 2024)},\n  year={2024},\n  organization={International Joint Conferences on Artificial Intelligence Organization}\n}\n\n\"\"\"\n\n## imports\nimport timeit\nimport argparse\nimport numpy as np\nfrom copy import copy\nfrom pathlib import Path\nimport ray\nimport sys\nimport os\nimport os.path as osp\nimport json\n#internal imports \nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom tgb.utils.utils import set_random_seed,  save_results \nfrom modules.tkg_utils import create_basis_dict, group_by, reformat_ts\n\ndef predict(num_processes,  data_c_rel, all_data_c_rel, alpha, lmbda_psi,\n            perf_list_all, hits_list_all, window, neg_sampler, split_mode):\n    \"\"\" create predictions for each relation on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list of hits for each test query\n    \"\"\"\n\n    first_ts = data_c_rel[0][3]\n    ## use this if you wanna use ray:\n    num_queries = len(data_c_rel) // num_processes\n    if num_queries < num_processes: # if we do not have enough queries for all the processes\n        num_processes_tmp = 
1\n        num_queries = len(data_c_rel)\n    else:\n        num_processes_tmp = num_processes  \n    if num_processes > 1:\n        object_references =[]                   \n        \n        for i in range(num_processes_tmp):\n            num_test_queries = len(data_c_rel) - (i + 1) * num_queries\n            if num_test_queries >= num_queries:\n                test_queries_idx =[i * num_queries, (i + 1) * num_queries]\n            else:\n                test_queries_idx = [i * num_queries, len(test_data)]\n\n            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]\n\n            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window, \n                                basis_dict, \n                                num_nodes, num_rels, lmbda_psi, \n                                alpha, evaluator,first_ts, neg_sampler, split_mode)\n            object_references.append(ob)\n\n        output = ray.get(object_references)\n\n        # updates the scores and logging dict for each process\n        for proc_loop in range(num_processes_tmp):\n            perf_list_all.extend(output[proc_loop][0])\n            hits_list_all.extend(output[proc_loop][1])\n\n    else:\n        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel, \n                            window, basis_dict, \n                            num_nodes, num_rels, lmbda_psi, \n                            alpha, evaluator, first_ts, neg_sampler, split_mode)                  \n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n    \n    return perf_list_all, hits_list_all\n\n## test\ndef test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):  \n    \"\"\" create predictions by loopoing through all relations on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list 
of hits for each test query\n    \"\"\"       \n    perf_list_all = []\n    hits_list_all =[]\n    \n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n\n    ## loop through relations and apply baselines\n    \n    for rel in all_relations:\n        start =  timeit.default_timer()\n        if rel in test_data_prel.keys():\n            lmbda_psi = best_config[str(rel)]['lmbda_psi'][0]\n            alpha = best_config[str(rel)]['alpha'][0]\n\n            # test data for this relation\n            test_data_c_rel = test_data_prel[rel]\n            timesteps_test = list(set(test_data_c_rel[:,3]))\n            timesteps_test.sort()\n            all_data_c_rel = all_data_prel[rel]         \n            perf_list_rel = []\n            hits_list_rel = []\n            perf_list_rel, hits_list_rel = predict(num_processes, test_data_c_rel,\n                                                all_data_c_rel, alpha, lmbda_psi,perf_list_rel, hits_list_rel, \n                                                window, neg_sampler, split_mode)\n            perf_list_all.extend(perf_list_rel)\n            hits_list_all.extend(hits_list_rel)\n        else:\n            perf_list_rel =[]\n         \n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n        \n\n        with open(csv_file, 'a') as f:\n            f.write(\"{},{}\\n\".format(rel, perf_list_rel))\n\n    return perf_list_all, hits_list_all\n\ndef read_dict_compute_mrr(split_mode='test'):\n    \"\"\" read the results per relation  from a precreated file and compute mrr\n    :return  mrr_per_rel: dictionary of mrrs for each relation\n    :return all_mrrs: list of mrrs for all relations\n    \"\"\"\n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n    # Initialize an empty dictionary to store the data\n    
results_per_rel_dict = {}\n    mrr_per_rel = {}\n    all_mrrs = []\n    # Open the file for reading\n    with open(csv_file, 'r') as f:\n        # Read each line in the file\n        for line in f:\n            # Split the line at the comma\n            parts = line.strip().split(',')\n            # Extract the key (the first part)\n            key = int(parts[0])\n            # Extract the values (the rest of the parts), remove square brackets\n            values = [float(value.strip('[]')) for value in parts[1:]]\n            # Add the key-value pair to the dictionary\n            if key in results_per_rel_dict.keys():\n                print(f\"Key {key} already exists in the dictionary!!! might have duplicate entries in results csv\")\n            results_per_rel_dict[key] = values\n            all_mrrs.extend(values)\n            mrr_per_rel[key] = np.mean(values)\n\n    if len(list(results_per_rel_dict.keys())) != num_rels:\n        print(\"we do not have entries for each rel in the results csv file. only num enties: \", len(list(results_per_rel_dict.keys())))\n\n    print(\"Split mode: \"+split_mode +\" Mean MRR: \", np.mean(all_mrrs))\n    print(\"mrr per relation: \", mrr_per_rel)\n\n\n\n## train\ndef train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):\n    \"\"\" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params\n    based on validation mrr\n    :return best_config: dictionary of best params for each relation\n    \"\"\"\n    best_config= {}\n    best_mrr = 0\n    for rel in rels: # loop through relations. 
for each relation, apply rules with selected params, compute valid mrr\n        start =  timeit.default_timer()\n        rel_key = int(rel)            \n\n        best_config[str(rel_key)] = {}\n        best_config[str(rel_key)]['not_trained'] = 'True'    \n        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi,0] #default\n        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))\n        best_config[str(rel_key)]['alpha'] = [default_alpha,0]  #default    \n        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))\n        \n        if rel in val_data_prel.keys():      \n            # valid data for this relation  \n            val_data_c_rel = val_data_prel[rel]\n            timesteps_valid = list(set(val_data_c_rel[:,3]))\n            timesteps_valid.sort()\n            trainval_data_c_rel = trainval_data_prel[rel]\n\n            ######  1) select lambda ###############        \n            lmbdas_psi = params_dict['lmbda_psi']        \n\n            alpha = 1\n            best_lmbda_psi = 0.1\n            best_mrr_psi = 0\n            lmbda_mrrs = []\n\n            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))\n            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))         \n            best_config[str(rel_key)]['not_trained'] = 'False'       \n            \n            for lmbda_psi in lmbdas_psi:   \n                perf_list_r = []\n                hits_list_r = []\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr = np.mean(perf_list_r)\n                # # is new mrr better than previous best? 
if yes: store lmbda\n                if mrr > best_mrr_psi:\n                    best_mrr_psi = float(mrr)\n                    best_lmbda_psi = lmbda_psi\n\n\n                lmbda_mrrs.append(float(mrr))\n            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]\n            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs\n            best_mrr = best_mrr_psi\n            ##### 2) select alpha ###############\n            best_config[str(rel_key)]['not_trained'] = 'False'    \n            alphas = params_dict['alpha'] \n            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0] # use the best lmbda psi\n\n            alpha_mrrs = []\n            # perf_list_all = []\n            best_mrr_alpha = 0\n            best_alpha=0.99\n            for alpha in alphas:\n                perf_list_r = []\n                hits_list_r = []\n\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr_alpha = np.mean(perf_list_r)\n\n                # is new mrr better than previous best? 
if yes: store alpha\n                if mrr_alpha > best_mrr_alpha:\n                    best_mrr_alpha = float(mrr_alpha)\n                    best_alpha = alpha\n                    best_mrr = best_mrr_alpha\n                alpha_mrrs.append(float(mrr_alpha))\n\n            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]\n            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs\n\n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n    return best_config\n\n\n\n## args\ndef get_args(): \n    \"\"\"parse all arguments for the script\"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", \"-d\", default=\"thgl-software\", type=str) \n    parser.add_argument(\"--window\", \"-w\", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep\n    parser.add_argument(\"--num_processes\", \"-p\", default=1, type=int)\n    parser.add_argument(\"--lmbda\", \"-l\",  default=0.1, type=float) # fix lambda. used if trainflag == false\n    parser.add_argument(\"--alpha\", \"-alpha\",  default=0.99, type=float) # fix alpha. 
used if trainflag == false\n    parser.add_argument(\"--train_flag\", \"-tr\",  default='False') # do we need training, ie selection of lambda and alpha\n    parser.add_argument(\"--load_flag\", \"-lo\",  default='False') # if train_flag set to True: do you want to load best_config?\n    parser.add_argument(\"--save_config\", \"-c\",  default='True') # do we need to save the selection of lambda and alpha in config file?\n    parser.add_argument('--seed', type=int, help='Random seed', default=1) # not needed\n    parsed = vars(parser.parse_args())\n    return parsed\n\nstart_o =  timeit.default_timer()\n\nparsed = get_args()\nif parsed['num_processes']>1:\n    ray.init(num_cpus=parsed[\"num_processes\"], num_gpus=0)\nMODEL_NAME = 'RecurrencyBaseline'\nSEED = parsed['seed']  # set the random seed for consistency\nset_random_seed(SEED)\nperrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'\nif not osp.exists(perrel_results_path):\n    os.mkdir(perrel_results_path)\n    print('INFO: Create directory {}'.format(perrel_results_path))\nPath(perrel_results_path).mkdir(parents=True, exist_ok=True)\n\n## load dataset and prepare it accordingly\nname = parsed[\"dataset\"]\ndataset = LinkPropPredDataset(name=name, root=\"datasets\", preprocess=True)\nDATA = name\n\nrelations = dataset.edge_type\nnum_rels = dataset.num_rels\nrels = np.arange(0,num_rels)\nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nnum_nodes = dataset.num_nodes \ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n\nprint(\"split train valid test data\")\nall_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=name)\nneg_sampler = 
dataset.negative_sampler\n\ntrain_val_data = np.concatenate([train_data, val_data])\nall_data = np.concatenate([train_data, val_data, test_data])\n\n# create dicts with key: relation id, values: triples for that relation id\nprint(\"grouping data by relation\")\ntest_data_prel = group_by(test_data, 1)\nall_data_prel = group_by(all_data, 1)\nval_data_prel = group_by(val_data, 1)\ntrainval_data_prel = group_by(train_val_data, 1)\n\n#load the ns samples \n# if parsed['train_flag']:\nprint(\"loading negative samples\")\ndataset.load_val_ns()\ndataset.load_test_ns()\n\n# parameter options\nif parsed['train_flag'] == 'True':\n    params_dict = {}\n    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001] \n    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]\n    default_lmbda_psi = params_dict['lmbda_psi'][-1]\n    default_alpha = params_dict['alpha'][-2]\n\n## load rules\nprint(\"creating rules\")\nbasis_dict = create_basis_dict(train_val_data)\nprint(\"done with creating rules\")\n## init\n# rb_predictor = RecurrencyBaselinePredictor(rels)\n## train to find best lambda and alpha\nstart_train =  timeit.default_timer()\nif parsed['train_flag'] ==  'True':\n    if parsed['load_flag'] == 'True':\n        with open('best_config.json', 'r') as infile:\n            best_config = json.load(infile)\n    else:\n        print('start training')\n        best_config = train(params_dict,  rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'], \n            parsed['window'])\n        if parsed['save_config'] == 'True':\n            import json\n            with open('best_config.json', 'w') as outfile:\n                json.dump(best_config, outfile)\n\nelse: # use preset lmbda and alpha; same for all relations\n    best_config = {} \n    for rel in rels:\n        best_config[str(rel)] = {}\n        best_config[str(rel)]['lmbda_psi'] = 
[parsed['lmbda']]\n        best_config[str(rel)]['alpha'] = [parsed['alpha']]\n    \nend_train =  timeit.default_timer()\n\n# compute validation mrr\nprint(\"Computing validation MRR\")\nperf_list_all_val, hits_list_all_val = test(best_config,rels, val_data_prel, \n                                                trainval_data_prel, neg_sampler, parsed['num_processes'], \n                                            parsed['window'], split_mode='val')\nval_mrr = float(np.mean(perf_list_all_val))\n\n# compute test mrr\nprint(\"Computing test MRR\")\nstart_test =  timeit.default_timer()\nperf_list_all, hits_list_all = test(best_config,rels, test_data_prel, \n                                                 all_data_prel, neg_sampler, parsed['num_processes'], \n                                                parsed['window'])\n\nend_o =  timeit.default_timer()\ntrain_time_o = round(end_train- start_train, 6)  \ntest_time_o = round(end_o- start_test, 6)  \ntotal_time_o = round(end_o- start_o, 6)  \nprint(\"Running Training to find best configs finished in {} seconds.\".format(train_time_o))\nprint(\"Running testing with best configs finished in {} seconds.\".format(test_time_o))\nprint(\"Running all steps finished in {} seconds.\".format(total_time_o))\n\nprint(f\"The test MRR is {np.mean(perf_list_all)}\")\nprint(f\"The valid MRR is {val_mrr}\")\nprint(f\"The Hits@10 is {np.mean(hits_list_all)}\")\nprint(f\"We have {len(perf_list_all)} predictions\")\nprint(f\"The test set has len {len(test_data)} \")\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'\nmetric = dataset.eval_metric\nsave_results({'model': MODEL_NAME,\n              'train_flag': parsed['train_flag'],\n     
         'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: float(np.mean(perf_list_all)),\n              'val_mrr': val_mrr,\n              'hits10': float(np.mean(hits_list_all)),\n              'test_time': test_time_o,\n              'tot_train_val_time': total_time_o\n              }, \n    results_filename)\n\nif parsed['num_processes']>1:\n    ray.shutdown()\n\n\n    \n"
  },
  {
    "path": "examples/linkproppred/thgl-software/sthn.py",
    "content": "import timeit\nimport numpy as np\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\nfrom tgb.linkproppred.evaluate import Evaluator\n\nimport argparse\nfrom modules.sthn import set_seed, pre_compute_subgraphs, get_inputs_for_ind, check_data_leakage\nimport torch\nimport pandas as pd\nimport itertools\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nfrom tgb.utils.utils import set_random_seed, save_results\n\n\n# Start...\nstart_overall = timeit.default_timer()\n\nDATA = \"thgl-software\"\n\nMODEL_NAME = 'STHN'\n\n# data loading\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\nmetric = dataset.eval_metric\n\nprint (\"there are {} nodes and {} edges\".format(dataset.num_nodes, dataset.num_edges))\nprint (\"there are {} relation types\".format(dataset.num_rels))\n\n\ntimestamp = data.t\nhead = data.src\ntail = data.dst\nedge_type = data.edge_type\nneg_sampler = dataset.negative_sampler\n\nprint(data)\nprint(timestamp)\nprint(head)\nprint(tail)\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\n\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\n\n####################################################################\n####################################################################\n####################################################################\n\n\ndef print_model_info(model):\n    print(model)\n    
parameters = filter(lambda p: p.requires_grad, model.parameters())\n    parameters = sum([np.prod(p.size()) for p in parameters])\n    print('Trainable Parameters: %d' % parameters)\n\ndef get_args():\n    parser=argparse.ArgumentParser()\n    parser.add_argument('--data', type=str, default='movie')\n    parser.add_argument('--device', type=int, default=0)\n    parser.add_argument('--batch_size', type=int, default=600)\n    parser.add_argument('--epochs', type=int, default=2)\n    parser.add_argument('--max_edges', type=int, default=50)\n    parser.add_argument('--num_edgeType', type=int, default=0, help='num of edgeType')\n    parser.add_argument('--lr', type=float, default=0.0005)\n    parser.add_argument('--weight_decay', type=float, default=1e-4)\n    parser.add_argument('--predict_class', action='store_true')\n    \n    # model\n    parser.add_argument('--window_size', type=int, default=5)\n    parser.add_argument('--dropout', type=float, default=0.1)\n    parser.add_argument('--model', type=str, default='sthn') \n    parser.add_argument('--neg_samples', type=int, default=1)\n    parser.add_argument('--extra_neg_samples', type=int, default=5)\n    parser.add_argument('--num_neighbors', type=int, default=50)\n    parser.add_argument('--channel_expansion_factor', type=int, default=2)\n    parser.add_argument('--sampled_num_hops', type=int, default=1)\n    parser.add_argument('--time_dims', type=int, default=100)\n    parser.add_argument('--hidden_dims', type=int, default=100)\n    parser.add_argument('--num_layers', type=int, default=1)\n    parser.add_argument('--check_data_leakage', action='store_true')\n    \n    parser.add_argument('--ignore_node_feats', action='store_true')\n    parser.add_argument('--node_feats_as_edge_feats', action='store_true')\n    parser.add_argument('--ignore_edge_feats', action='store_true')\n    parser.add_argument('--use_onehot_node_feats', action='store_true')\n    parser.add_argument('--use_type_feats', action='store_true')\n\n  
  parser.add_argument('--use_graph_structure', action='store_true')\n    parser.add_argument('--structure_time_gap', type=int, default=2000)\n    parser.add_argument('--structure_hops', type=int, default=1) \n\n    parser.add_argument('--use_node_cls', action='store_true')\n    parser.add_argument('--use_cached_subgraph', action='store_true')\n    \n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--num_run', type=int, help='Number of iteration runs', default=5)\n    return parser.parse_args()\n\n\ndef load_model(args):\n    # get model\n    edge_predictor_configs = {\n        'dim_in_time': args.time_dims,\n        'dim_in_node': args.node_feat_dims,\n        'predict_class': 1 if not args.predict_class else args.num_edgeType+1,\n    }\n    if args.model == 'sthn':\n        if args.predict_class:\n            from modules.sthn import Multiclass_Interface as STHN_Interface\n        else:\n            from modules.sthn import STHN_Interface\n        from modules.sthn import link_pred_train\n\n        mixer_configs = {\n            'per_graph_size'  : args.max_edges,\n            'time_channels'   : args.time_dims, \n            'input_channels'  : args.edge_feat_dims, \n            'hidden_channels' : args.hidden_dims, \n            'out_channels'    : args.hidden_dims,\n            'num_layers'      : args.num_layers,\n            'dropout'         : args.dropout,\n            'channel_expansion_factor': args.channel_expansion_factor,\n            'window_size'     : args.window_size,\n            'use_single_layer' : False\n        }  \n        \n    else:\n        NotImplementedError()\n\n    model = STHN_Interface(mixer_configs, edge_predictor_configs)\n    for k, v in model.named_parameters():\n        print(k, v.requires_grad)\n\n    print_model_info(model)\n\n    return model, args, link_pred_train\n\ndef load_graph(data):\n    df = pd.DataFrame({\n        'idx': np.arange(len(data.t)),\n        'src': 
data.src,\n        'dst': data.dst,\n        'time': data.t,\n        'label': data.edge_type,\n    })\n\n    num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1 \n\n    ext_full_indptr = np.zeros(num_nodes + 1, dtype=np.int32)\n    ext_full_indices = [[] for _ in range(num_nodes)]\n    ext_full_ts = [[] for _ in range(num_nodes)]\n    ext_full_eid = [[] for _ in range(num_nodes)]\n\n    for idx, row in tqdm(df.iterrows(), total=len(df)):\n        src = int(row['src'])\n        dst = int(row['dst'])\n        \n        ext_full_indices[src].append(dst)\n        ext_full_ts[src].append(row['time'])\n        ext_full_eid[src].append(idx)\n        \n    for i in tqdm(range(num_nodes)):\n        ext_full_indptr[i + 1] = ext_full_indptr[i] + len(ext_full_indices[i])\n\n    ext_full_indices = np.array(list(itertools.chain(*ext_full_indices)))\n    ext_full_ts = np.array(list(itertools.chain(*ext_full_ts)))\n    ext_full_eid = np.array(list(itertools.chain(*ext_full_eid)))\n\n    print('Sorting...')\n\n    def tsort(i, indptr, indices, t, eid):\n        beg = indptr[i]\n        end = indptr[i + 1]\n        sidx = np.argsort(t[beg:end])\n        indices[beg:end] = indices[beg:end][sidx]\n        t[beg:end] = t[beg:end][sidx]\n        eid[beg:end] = eid[beg:end][sidx] \n\n    for i in tqdm(range(ext_full_indptr.shape[0] - 1)):\n        tsort(i, ext_full_indptr, ext_full_indices, ext_full_ts, ext_full_eid)\n\n    print('saving...')\n\n    np.savez('/tmp/ext_full.npz', indptr=ext_full_indptr,\n            indices=ext_full_indices, ts=ext_full_ts, eid=ext_full_eid)\n    g = np.load('/tmp/ext_full.npz')\n    return g, df\n\ndef load_all_data(args):\n\n    # load graph\n    g, df = load_graph(data)\n\n    args.train_mask = train_mask.numpy()\n    args.val_mask   = val_mask.numpy()\n    args.test_mask = test_mask.numpy()\n    args.num_edges = len(df)\n\n    print('Train %d, Valid %d, Test %d'%(sum(args.train_mask), \n                                         
sum(args.val_mask),\n                                         sum(test_mask)))\n    \n    args.num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1\n    args.num_edges = len(df)\n\n    print('Num nodes %d, num edges %d'%(args.num_nodes, args.num_edges))\n\n    # load feats \n    node_feats, edge_feats = dataset.node_feat, dataset.edge_feat\n    node_feat_dims = 0 if node_feats is None else node_feats.shape[1]\n    edge_feat_dims = 0 if edge_feats is None else edge_feats.shape[1]\n\n    # feature pre-processing\n    if args.use_onehot_node_feats:\n        print('>>> Use one-hot node features')\n        node_feats = torch.eye(args.num_nodes)\n        node_feat_dims = node_feats.size(1)\n\n    if args.ignore_node_feats:\n        print('>>> Ignore node features')\n        node_feats = None\n        node_feat_dims = 0\n\n    if args.use_type_feats:\n        edge_type = df.label.values\n        print(edge_type)\n        print(edge_type.sum())\n        args.num_edgeType = len(set(edge_type.tolist()))\n        edge_feats = torch.nn.functional.one_hot(torch.from_numpy(edge_type), \n                                                 num_classes=args.num_edgeType)\n        edge_feat_dims = edge_feats.size(1)\n        \n    print('Node feature dim %d, edge feature dim %d'%(node_feat_dims, edge_feat_dims))\n    \n    # double check (if data leakage then cannot continue the code)\n    if args.check_data_leakage:\n        check_data_leakage(args, g, df)\n\n    args.node_feat_dims = node_feat_dims\n    args.edge_feat_dims = edge_feat_dims\n    \n    if node_feats != None:\n        node_feats = node_feats.to(args.device)\n    if edge_feats != None:\n        edge_feats = edge_feats.to(args.device)\n    \n    return node_feats, edge_feats, g, df, 
args\n\n####################################################################\n####################################################################\n####################################################################\n\n@torch.no_grad()\ndef test(data, test_mask, model, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'val' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    test_subgraphs  = pre_compute_subgraphs(args, g, df, mode='test' if split_mode == 'test' else 'valid', negative_sampler=neg_sampler, split_mode=split_mode)\n    perf_list = []\n    \n    if split_mode == 'test':\n        cur_df = df[args.test_mask]\n    elif split_mode == 'val':\n        cur_df = df[args.val_mask]\n    neg_samples = 20\n    cached_neg_samples = 20\n\n    test_loader = cur_df.groupby(cur_df.index // args.batch_size)\n    pbar = tqdm(total=len(test_loader))\n    pbar.set_description('%s mode with negative samples %d ...'%(split_mode, neg_samples))        \n    \n    ###################################################\n    # compute + training + fetch all scores\n    cur_inds = 0\n\n    for ind in range(len(test_loader)):\n        ###################################################\n        inputs, subgraph_node_feats, cur_inds = get_inputs_for_ind(test_subgraphs, 'test' if split_mode == 'test' else 'tgb-val', cached_neg_samples, neg_samples, node_feats, edge_feats, cur_df, cur_inds, ind, args)\n        \n        loss, pred, edge_label = model(inputs, neg_samples, subgraph_node_feats)\n        
# print(ind, [l for l in inputs], pred.shape)\n\n        input_dict = {\n            \"y_pred_pos\": np.array([pred.cpu()[0]]),\n            \"y_pred_neg\": np.array(pred.cpu()[1:]),\n            \"eval_metric\": [metric],\n        }\n        perf_list.append(evaluator.eval(input_dict)[metric])\n\n    perf_metrics_mean = float(np.mean(perf_list))\n    perf_metrics_std = float(np.std(perf_list))\n\n    return perf_metrics_mean, perf_metrics_std, perf_list\n\n\nargs = get_args()\n\nargs.use_graph_structure = True\nargs.use_onehot_node_feats = False\nargs.ignore_node_feats = False # we only use graph structure\nargs.use_type_feats = True # type encoding\nargs.use_cached_subgraph = True\n\nprint(args)\n\nargs.device = f\"cuda:{args.device}\" if torch.cuda.is_available() else \"cpu\"\nargs.device = torch.device(args.device)\nSEED = args.seed\nBATCH_SIZE = args.batch_size\nNUM_RUNS = args.num_run\nset_seed(SEED)\n\n\n###################################################\n# load feats + graph\nnode_feats, edge_feats, g, df, args = load_all_data(args)\n\n###################################################\n# get model \nmodel, args, link_pred_train = load_model(args)\n\n###################################################\n\nprint(\"==========================================================\")\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============\")\nprint(\"==========================================================\")\n\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\n\n\nfor run_idx in range(NUM_RUNS):\n    print('-------------------------------------------------------------------------------')\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\n   
 start_run = timeit.default_timer()\n\n    # set the seed for deterministic results...\n    torch.manual_seed(run_idx + SEED)\n    set_random_seed(run_idx + SEED)\n\n    # define an early stopper\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\n    # early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \n    #                                 tolerance=TOLERANCE, patience=PATIENCE)\n\n    # ==================================================== Train & Validation\n    # loading the validation negative samples\n\n    # Link prediction\n    start_val = timeit.default_timer()\n    print('Train link prediction task from scratch ...')\n    model = link_pred_train(model.to(args.device), args, g, df, node_feats, edge_feats)\n\n    dataset.load_val_ns()\n\n    # Validation ...\n    \n    perf_metrics_val_mean, perf_metrics_val_std, perf_list_val = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='val')\n    end_val = timeit.default_timer()\n\n    print(f\"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tval: {metric}: {perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}\")\n    val_time = timeit.default_timer() - start_val\n    print(f\"\\tval: Elapsed Time (s): {val_time: .4f}\")\n\n\n    dataset.load_test_ns()\n\n    # testing ...\n    start_test = timeit.default_timer()\n    perf_metrics_test_mean, perf_metrics_test_std, perf_list_test = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='test')\n    end_test = timeit.default_timer()\n\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\n    print(f\"\\tTest: {metric}: {perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}\")\n    test_time = timeit.default_timer() - start_test\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\n    save_results({'model': 
MODEL_NAME,\n                  'data': DATA,\n                  'run': run_idx,\n                  'seed': SEED,\n                  f'val {metric}': f'{perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}' ,\n                  f'test {metric}': f'{perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}' ,\n                  'test_time': test_time,\n                  'tot_train_val_time': val_time\n                  }, \n    results_filename)\n\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\n    print('-------------------------------------------------------------------------------')\n\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\nprint(\"==============================================================\")\n\n# save_results({'model': MODEL_NAME,\n#             'data': DATA,\n#             'run': 1,\n#             'seed': SEED,\n#             metric: perf_metric_test,\n#             'test_time': test_time,\n#             'tot_train_val_time': 'NA'\n#             }, \n#     results_filename)"
  },
  {
    "path": "examples/linkproppred/thgl-software/tgn.py",
    "content": "import numpy as np\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\nfrom tgb.linkproppred.evaluate import Evaluator\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom tqdm import tqdm\r\nimport timeit\r\n\r\n\r\nimport math\r\nimport timeit\r\n\r\nimport os\r\nimport os.path as osp\r\nfrom pathlib import Path\r\nimport numpy as np\r\n\r\nimport torch\r\nfrom sklearn.metrics import average_precision_score, roc_auc_score\r\nfrom torch.nn import Linear\r\n\r\nfrom torch_geometric.datasets import JODIEDataset\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom torch_geometric.nn import TransformerConv\r\n\r\n# internal imports\r\nfrom tgb.utils.utils import get_args, set_random_seed, save_results\r\nfrom tgb.linkproppred.evaluate import Evaluator\r\nfrom modules.decoder import LinkPredictor\r\nfrom modules.emb_module import GraphAttentionEmbedding\r\nfrom modules.msg_func import IdentityMessage\r\nfrom modules.msg_agg import LastAggregator\r\nfrom modules.neighbor_loader import LastNeighborLoader\r\nfrom modules.memory_module import TGNMemory\r\nfrom modules.early_stopping import  EarlyStopMonitor\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\n# ==========\r\n# ========== Define helper function...\r\n# ==========\r\n\r\ndef train():\r\n    r\"\"\"\r\n    Training procedure for TGN model\r\n    This function uses some objects that are globally defined in the current script \r\n\r\n    Parameters:\r\n        None\r\n    Returns:\r\n        None\r\n            \r\n    \"\"\"\r\n\r\n    model['memory'].train()\r\n    model['gnn'].train()\r\n    model['link_pred'].train()\r\n\r\n    model['memory'].reset_state()  # Start with a fresh memory.\r\n    neighbor_loader.reset_state()  # Start with an empty graph.\r\n\r\n    total_loss = 0\r\n    for batch in train_loader:\r\n        batch = batch.to(device)\r\n        optimizer.zero_grad()\r\n\r\n        src, pos_dst, t, msg, rel = 
batch.src, batch.dst, batch.t, batch.msg, batch.edge_type\r\n\r\n        # Sample negative destination nodes.\r\n        neg_dst = torch.randint(\r\n            min_dst_idx,\r\n            max_dst_idx + 1,\r\n            (src.size(0),),\r\n            dtype=torch.long,\r\n            device=device,\r\n        )\r\n\r\n        n_id = torch.cat([src, pos_dst, neg_dst]).unique()\r\n        n_id, edge_index, e_id = neighbor_loader(n_id)\r\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\r\n\r\n        # Get updated memory of all nodes involved in the computation.\r\n        z, last_update = model['memory'](n_id)\r\n        z = model['gnn'](\r\n            z,\r\n            last_update,\r\n            edge_index,\r\n            data.t[e_id].to(device),\r\n            data.msg[e_id].to(device),\r\n        )\r\n\r\n        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])\r\n        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])\r\n\r\n        loss = criterion(pos_out, torch.ones_like(pos_out))\r\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\r\n\r\n        # Update memory and neighbor loader with ground-truth state.\r\n        model['memory'].update_state(src, pos_dst, t, msg)\r\n        neighbor_loader.insert(src, pos_dst)\r\n\r\n        loss.backward()\r\n        optimizer.step()\r\n        model['memory'].detach()\r\n        total_loss += float(loss.detach()) * batch.num_events\r\n\r\n    return total_loss / train_data.num_events\r\n\r\n\r\n@torch.no_grad()\r\ndef test(loader, neg_sampler, split_mode):\r\n    r\"\"\"\r\n    Evaluated the dynamic link prediction\r\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\r\n\r\n    Parameters:\r\n        loader: an object containing positive attributes of the positive edges of the evaluation set\r\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\r\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\r\n    Returns:\r\n        perf_metric: the result of the performance evaluation\r\n    \"\"\"\r\n    model['memory'].eval()\r\n    model['gnn'].eval()\r\n    model['link_pred'].eval()\r\n\r\n    perf_list = []\r\n\r\n    for pos_batch in loader:\r\n        pos_src, pos_dst, pos_t, pos_msg, pos_rel = (\r\n            pos_batch.src,\r\n            pos_batch.dst,\r\n            pos_batch.t,\r\n            pos_batch.msg,\r\n            pos_batch.edge_type\r\n        )\r\n\r\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_rel, split_mode=split_mode)\r\n\r\n        # pos_msg_new = torch.cat([pos_msg,pos_rel.unsqueeze(dim=1)], dim=1)   \r\n\r\n\r\n        for idx, neg_batch in enumerate(neg_batch_list):\r\n            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)\r\n            dst = torch.tensor(\r\n                np.concatenate(\r\n                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),\r\n                    axis=0,\r\n                ),\r\n                device=device,\r\n            )\r\n\r\n            n_id = torch.cat([src, dst]).unique()\r\n            n_id, edge_index, e_id = neighbor_loader(n_id)\r\n            assoc[n_id] = torch.arange(n_id.size(0), device=device)\r\n\r\n            # Get updated memory of all nodes involved in the computation.\r\n            z, last_update = model['memory'](n_id)\r\n            z = model['gnn'](\r\n                z,\r\n                last_update,\r\n                edge_index,\r\n                data.t[e_id].to(device),\r\n        
        data.msg[e_id].to(device),\r\n            )\r\n\r\n            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])\r\n\r\n            # compute MRR\r\n            input_dict = {\r\n                \"y_pred_pos\": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),\r\n                \"y_pred_neg\": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),\r\n                \"eval_metric\": [metric],\r\n            }\r\n            perf_list.append(evaluator.eval(input_dict)[metric])\r\n\r\n        # Update memory and neighbor loader with ground-truth state.\r\n        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)\r\n        neighbor_loader.insert(pos_src, pos_dst)\r\n\r\n    perf_metrics = float(torch.tensor(perf_list).mean())\r\n\r\n    return perf_metrics\r\n\r\n# ==========\r\n# ==========\r\n# ==========\r\n\r\n\r\n# ==========\r\n# ==========\r\n# ==========\r\n\r\n\r\n# Start...\r\nstart_overall = timeit.default_timer()\r\nDATA = \"thgl-software\"\r\n\r\n\r\n# ========== set parameters...\r\nargs, _ = get_args()\r\nargs.data = DATA\r\nprint(\"INFO: Arguments:\", args)\r\n\r\nLR = args.lr\r\nBATCH_SIZE = args.bs\r\nK_VALUE = args.k_value  \r\nNUM_EPOCH = args.num_epoch\r\nSEED = args.seed\r\nMEM_DIM = args.mem_dim\r\nTIME_DIM = args.time_dim\r\nEMB_DIM = args.emb_dim\r\nTOLERANCE = args.tolerance\r\nPATIENCE = args.patience\r\nNUM_RUNS = args.num_run\r\nNUM_NEIGHBORS = 10\r\nUSE_EDGE_TYPE = True\r\nUSE_NODE_TYPE = True\r\n\r\n\r\n\r\nMODEL_NAME = 'TGN'\r\n# ==========\r\n\r\n# set the device\r\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\n\r\n\r\n# data loading\r\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\r\ntrain_mask = dataset.train_mask\r\nval_mask = dataset.val_mask\r\ntest_mask = dataset.test_mask\r\ndata = dataset.get_TemporalData()\r\ndata = data.to(device)\r\nmetric = dataset.eval_metric\r\n\r\nprint (\"there are {} nodes and {} edges\".format(dataset.num_nodes, 
dataset.num_edges))\r\nprint (\"there are {} relation types\".format(dataset.num_rels))\r\n\r\n\r\ntimestamp = data.t\r\nhead = data.src\r\ntail = data.dst\r\nedge_type = data.edge_type #relation\r\nedge_type_dim = len(torch.unique(edge_type))\r\n\r\nembed_edge_type = torch.nn.Embedding(edge_type_dim, 128).to(device)\r\nwith torch.no_grad():\r\n    edge_type_embeddings = embed_edge_type(edge_type)\r\n\r\n\r\nif USE_EDGE_TYPE:\r\n    data.msg = torch.cat([data.msg, edge_type_embeddings], dim=1)\r\n\r\n#! node type is a property of the dataset not the temporal data as temporal data has one entry per edge\r\nnode_type = dataset.node_type #node type\r\nneg_sampler = dataset.negative_sampler\r\n\r\ndata.__setattr__(\"node_type\", node_type)\r\n\r\nprint (\"shape of edge type is\", edge_type.shape)\r\nprint (\"shape of node type is\", node_type.shape)\r\n\r\ntrain_data = data[train_mask]\r\nval_data = data[val_mask]\r\ntest_data = data[test_mask]\r\nprint (\"finished loading PyG data\")\r\n\r\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\r\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\r\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\r\n\r\nstart_time = timeit.default_timer()\r\n\r\n# Ensure to only sample actual destination nodes as negatives.\r\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n\r\nprint(\"==========================================================\")\r\nprint(f\"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============\")\r\nprint(\"==========================================================\")\r\n\r\nevaluator = Evaluator(name=DATA)\r\nneg_sampler = dataset.negative_sampler\r\n\r\n# for saving the results...\r\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\r\nif not osp.exists(results_path):\r\n    os.mkdir(results_path)\r\n    print('INFO: Create directory {}'.format(results_path))\r\nPath(results_path).mkdir(parents=True, 
exist_ok=True)\r\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\r\n\r\nfor run_idx in range(NUM_RUNS):\r\n    print('-------------------------------------------------------------------------------')\r\n    print(f\"INFO: >>>>> Run: {run_idx} <<<<<\")\r\n    start_run = timeit.default_timer()\r\n\r\n    # set the seed for deterministic results...\r\n    torch.manual_seed(run_idx + SEED)\r\n    set_random_seed(run_idx + SEED)\r\n\r\n    # neighhorhood sampler\r\n    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\r\n\r\n    # define the model end-to-end\r\n    memory = TGNMemory(\r\n        data.num_nodes,\r\n        data.msg.size(-1),\r\n        MEM_DIM,\r\n        TIME_DIM,\r\n        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\r\n        aggregator_module=LastAggregator(),\r\n    ).to(device)\r\n\r\n    gnn = GraphAttentionEmbedding(\r\n        in_channels=MEM_DIM,\r\n        out_channels=EMB_DIM,\r\n        msg_dim=data.msg.size(-1),\r\n        time_enc=memory.time_enc,\r\n    ).to(device)\r\n\r\n    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)\r\n\r\n    model = {'memory': memory,\r\n            'gnn': gnn,\r\n            'link_pred': link_pred}\r\n\r\n    optimizer = torch.optim.Adam(\r\n        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),\r\n        lr=LR,\r\n    )\r\n    criterion = torch.nn.BCEWithLogitsLoss()\r\n\r\n    # Helper vector to map global node indices to local ones.\r\n    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\r\n\r\n    # define an early stopper\r\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\r\n    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'\r\n    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id, \r\n                                    tolerance=TOLERANCE, 
patience=PATIENCE)\r\n\r\n    # ==================================================== Train & Validation\r\n    # loading the validation negative samples\r\n    dataset.load_val_ns()\r\n\r\n    val_perf_list = []\r\n    start_train_val = timeit.default_timer()\r\n    for epoch in range(1, NUM_EPOCH + 1):\r\n        # training\r\n        start_epoch_train = timeit.default_timer()\r\n        loss = train()\r\n        print(\r\n            f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}\"\r\n        )\r\n\r\n        # validation\r\n        start_val = timeit.default_timer()\r\n        perf_metric_val = test(val_loader, neg_sampler, split_mode=\"val\")\r\n        print(f\"\\tValidation {metric}: {perf_metric_val: .4f}\")\r\n        print(f\"\\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}\")\r\n        val_perf_list.append(perf_metric_val)\r\n\r\n        # check for early stopping\r\n        if early_stopper.step_check(perf_metric_val, model):\r\n            break\r\n\r\n    train_val_time = timeit.default_timer() - start_train_val\r\n    print(f\"Train & Validation: Elapsed Time (s): {train_val_time: .4f}\")\r\n\r\n    # ==================================================== Test\r\n    # first, load the best model\r\n    early_stopper.load_checkpoint(model)\r\n\r\n    # loading the test negative samples\r\n    dataset.load_test_ns()\r\n\r\n    # final testing\r\n    start_test = timeit.default_timer()\r\n    perf_metric_test = test(test_loader, neg_sampler, split_mode=\"test\")\r\n\r\n    print(f\"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< \")\r\n    print(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\r\n    test_time = timeit.default_timer() - start_test\r\n    print(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\r\n\r\n    save_results({'model': MODEL_NAME,\r\n                  'data': DATA,\r\n                  'run': run_idx,\r\n                  'seed': 
SEED,\r\n                  f'val {metric}': val_perf_list,\r\n                  f'test {metric}': perf_metric_test,\r\n                  'test_time': test_time,\r\n                  'tot_train_val_time': train_val_time\r\n                  }, \r\n    results_filename)\r\n\r\n    print(f\"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<\")\r\n    print('-------------------------------------------------------------------------------')\r\n\r\nprint(f\"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}\")\r\nprint(\"==============================================================\")\r\n\r\n# #* load numpy arrays instead\r\n# from tgb.linkproppred.dataset import LinkPropPredDataset\r\n\r\n# # data loading\r\n# dataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\r\n# data = dataset.full_data  \r\n# metric = dataset.eval_metric\r\n# sources = dataset.full_data['sources']\r\n# print (\"finished loading numpy arrays\")\r\n"
  },
  {
    "path": "examples/linkproppred/tkgl-icews/cen.py",
    "content": "\"\"\"\nComplex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning\nReference:\n- https://github.com/Lee-zix/CEN\nZixuan Li, Saiping Guan, Xiaolong Jin, Weihua Peng, Yajuan Lyu , Yong Zhu, Long Bai, Wei Li, Jiafeng Guo, Xueqi Cheng. \nComplex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning. ACL 2022.\n\"\"\"\nimport timeit\nimport os\nimport sys\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\nimport torch\nimport random\nfrom tqdm import tqdm\nimport json\n# internal imports\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.rrgcn import RecurrentRGCNCEN\nfrom tgb.utils.utils import set_random_seed, split_by_time,  save_results\nfrom modules.tkg_utils import get_args_cen, reformat_ts\nfrom modules.tkg_utils_dgl import build_sub_graph\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \n\ndef test(model, history_len, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, mode, split_mode):\n    \"\"\"\n    Test the model\n    :param model: model used to test\n    :param history_list:    all input history snap shot list, not include output label train list or valid list\n    :param test_list:   test triple snap shot list\n    :param num_rels:    number of relations\n    :param num_nodes:   number of nodes\n    :param use_cuda:\n    :param model_name:\n    :param mode:\n    :param split_mode: 'test' or 'val' to state which negative samples to load\n    :return mrr\n    \"\"\"\n    print(\"Testing for mode: \", split_mode)\n    if split_mode == 'test':\n        timesteps_to_eval = test_timestamps_orig\n    else:\n        timesteps_to_eval = val_timestamps_orig\n\n    idx = 0\n\n    if mode == \"test\":\n        # test mode: load parameter form file \n        if use_cuda:\n            checkpoint = torch.load(model_name, 
map_location=torch.device(args.gpu))\n        else:\n            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))        \n        # use best stat checkpoint:\n        print(\"Load Model name: {}. Using best epoch : {}\".format(model_name, checkpoint['epoch']))  \n        print(\"\\n\"+\"-\"*10+\"start testing\"+\"-\"*10+\"\\n\")\n        model.load_state_dict(checkpoint['state_dict'])\n\n    model.eval()\n\n    input_list = [snap for snap in history_list[-history_len:]] \n    perf_list_all = []\n    hits_list_all = []\n    perf_per_rel = {}\n    for rel in range(num_rels):\n        perf_per_rel[rel] = []\n\n    for time_idx, test_snap in enumerate(tqdm(test_list)):\n        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]    \n        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)\n        timesteps_batch =timesteps_to_eval[time_idx]*np.ones(len(test_triples_input[:,0]))\n\n        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:,0], test_triples_input[:,2], \n                                timesteps_batch, edge_type=test_triples_input[:,1], split_mode=split_mode)\n        pos_samples_batch = test_triples_input[:,2]\n\n        _, perf_list, hits_list = model.predict(history_glist, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch, \n                                    evaluator, METRIC) \n\n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n        if split_mode == \"test\":\n            if args.log_per_rel:\n                for score, rel in zip(perf_list, test_triples_input[:,1].tolist()):\n                    perf_per_rel[rel].append(score)\n        # reconstruct history graph list\n        input_list.pop(0)\n        input_list.append(test_snap)\n        idx += 1\n    if split_mode == \"test\":\n        if args.log_per_rel:   \n            for rel in range(num_rels):\n        
        if len(perf_per_rel[rel]) > 0:\n                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))\n                else:\n                    perf_per_rel.pop(rel)\n    mrr = np.mean(perf_list_all) \n    hits10 = np.mean(hits_list_all)\n    return mrr, perf_per_rel, hits10\n\n\n\ndef run_experiment(args, trainvalidtest_id=0, n_hidden=None, n_layers=None, dropout=None, n_bases=None):\n    '''\n    Run experiment for CEN model\n    :param args: arguments for the model\n    :param trainvalidtest_id: -1: pretrainig, 0: curriculum training (to find best test history len), 1: test on valid set, 2: test on test set\n    :param n_hidden: number of hidden units\n    :param n_layers: number of layers\n    :param dropout: dropout rate\n    :param n_bases: number of bases\n    return: mrr, perf_per_rel: mean reciprocal rank and performance per relation\n    '''\n    # 1) load configuration for grid search the best configuration\n    if n_hidden:\n        args.n_hidden = n_hidden\n    if n_layers:\n        args.n_layers = n_layers\n    if dropout:\n        args.dropout = dropout\n    if n_bases:\n        args.n_bases = n_bases\n    test_history_len = args.test_history_len\n    # 2) set save model path\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    if not osp.exists(save_model_dir):\n        os.mkdir(save_model_dir)\n        print('INFO: Create directory {}'.format(save_model_dir))\n    test_model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.test_history_len}'\n    test_state_file = save_model_dir+test_model_name\n    perf_per_rel ={}\n    use_cuda = args.gpu >= 0 and torch.cuda.is_available()\n    # create stat\n\n    model = RecurrentRGCNCEN(args.decoder,\n                            args.encoder,\n                            num_nodes,\n                            num_rels,\n                            args.n_hidden,\n                            args.opn,\n                            
sequence_len=args.train_history_len,\n                            num_bases=args.n_bases,\n                            num_basis=args.n_basis,\n                            num_hidden_layers=args.n_layers,\n                            dropout=args.dropout,\n                            self_loop=args.self_loop,\n                            skip_connect=args.skip_connect,\n                            layer_norm=args.layer_norm,\n                            input_dropout=args.input_dropout,\n                            hidden_dropout=args.hidden_dropout,\n                            feat_dropout=args.feat_dropout,\n                            entity_prediction=args.entity_prediction,\n                            relation_prediction=args.relation_prediction,\n                            use_cuda=use_cuda,\n                            gpu = args.gpu)\n    if use_cuda:\n        torch.cuda.set_device(args.gpu)\n        model.cuda()\n\n    # optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)\n    \n    if trainvalidtest_id == 1:  # normal test on validation set  Note that mode=test\n        if os.path.exists(test_state_file):\n            mrr, _, hits10 = test(model, args.test_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                        test_state_file, \"test\", split_mode=\"val\")      \n        else:\n            print('Cannot do testing because model does not exist: ', test_state_file)\n            mrr = 0\n            hits10 = 0\n    elif trainvalidtest_id == 2: # normal test on test set\n        if os.path.exists(test_state_file):\n            mrr, perf_per_rel, hits10 = test(model, args.test_history_len, train_list+valid_list, test_list, num_rels, num_nodes, use_cuda, \n                        test_state_file, \"test\", split_mode=\"test\")\n        else:\n            print('Cannot do testing because model does not exist: ', test_state_file)\n            mrr = 0\n            hits10 = 0\n    
elif trainvalidtest_id == -1:\n        print(\"-------------start pre training model with history length {}----------\\n\".format(args.start_history_len))\n        model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'\n        model_state_file = save_model_dir + model_name\n        print(\"Sanity Check: stat name : {}\".format(model_state_file))\n        print(\"Sanity Check: Is cuda available ? {}\".format(torch.cuda.is_available()))\n            \n        best_mrr = 0\n        best_epoch = 0\n        best_hits10= 0\n\n        ## training loop\n        for epoch in range(args.n_epochs):\n            model.train()\n            losses = []\n\n            idx = [_ for _ in range(len(train_list))]\n            random.shuffle(idx)\n            for train_sample_num in idx:\n                if train_sample_num == 0 or train_sample_num == 1: continue\n                if train_sample_num - args.start_history_len<0:\n                    input_list = train_list[0: train_sample_num]\n                    output = train_list[1:train_sample_num+1]\n                else:\n                    input_list = train_list[train_sample_num-args.start_history_len: train_sample_num]\n                    output = train_list[train_sample_num-args.start_history_len+1:train_sample_num+1]\n\n                # generate history graph\n                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]\n                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]\n\n                loss= model.get_loss(history_glist, output[-1], None, use_cuda)\n                losses.append(loss.item())\n\n                loss.backward()\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients\n                optimizer.step()\n                optimizer.zero_grad()\n\n            print(\"His {:04d}, Epoch {:04d} | 
Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} \"\n                .format(args.start_history_len, epoch, np.mean(losses), best_mrr, model_name))\n            \n            #! checking GPU usage\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            print (\"--------------GPU memory usage-----------\")\n            print (\"there are \", free_mem, \" free memory\")\n            print (\"there are \", total_mem, \" total available memory\")\n            print (\"there are \", total_mem - free_mem, \" used memory\")\n            print (\"--------------GPU memory usage-----------\")\n\n            # validation        \n            if epoch % args.evaluate_every == 0:\n                mrr, _, hits10 = test(model, args.start_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                                model_state_file, mode=\"train\", split_mode= \"val\")\n\n                if mrr< best_mrr:\n                    if epoch >= args.n_epochs or epoch - best_epoch > 5:\n                        break\n                else:\n                    best_mrr = mrr\n                    best_epoch = epoch\n                    best_hits10 = hits10\n                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)\n\n        \n    elif trainvalidtest_id == 0: #curriculum training\n        model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'\n        init_state_file = save_model_dir + model_name\n        init_checkpoint = torch.load(init_state_file, map_location=torch.device(args.gpu))\n        # use best stat checkpoint:\n        print(\"Load Previous Model name: {}. 
Using best epoch : {}\".format(init_state_file, init_checkpoint['epoch']))  \n        print(\"\\n\"+\"-\"*10+\"Load model with history length {}\".format(args.start_history_len)+\"-\"*10+\"\\n\")\n        model.load_state_dict(init_checkpoint['state_dict'])\n        test_history_len = args.start_history_len\n\n        mrr, _, hits10 = test(model, \n                    args.start_history_len,\n                    train_list,\n                    valid_list, \n                    num_rels, \n                    num_nodes, \n                    use_cuda, \n                    init_state_file,  \n                    mode=\"test\", split_mode= \"val\") \n        best_mrr_list = [mrr.item()]         \n        best_hits_list = [hits10.item()]                                          \n        # start knowledge distillation\n        ks_idx = 0\n        for history_len in range(args.start_history_len+1, args.train_history_len+1, 1):\n            # current model\n            print(\"best mrr list :\", best_mrr_list)\n            # lr = 0.1*args.lr - 0.002*args.lr*ks_idx\n            optimizer = torch.optim.Adam(model.parameters(), lr=0.1*args.lr, weight_decay=0.00001)\n            model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len}'\n            model_state_file = save_model_dir + model_name\n\n            print(\"Sanity Check: stat name : {}\".format(model_state_file))\n\n            # load model with the least history length\n            prev_model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len-1}'\n            prev_state_file = save_model_dir + prev_model_name\n            checkpoint = torch.load(prev_state_file, map_location=torch.device(args.gpu)) \n            model.load_state_dict(checkpoint['state_dict']) \n            print(\"\\n\"+\"-\"*10+\"start knowledge distillation for history length at \"+ str(history_len)+\"-\"*10+\"\\n\")\n \n            best_mrr = 0\n            best_hits10 = 0\n            best_epoch = 0\n            for 
epoch in range(args.n_epochs):\n                model.train()\n                losses = []\n\n                idx = [_ for _ in range(len(train_list))]\n                random.shuffle(idx)\n                for train_sample_num in idx:\n                    if train_sample_num == 0 or train_sample_num == 1: continue\n                    if train_sample_num - history_len<0:\n                        input_list = train_list[0: train_sample_num]\n                        output = train_list[1:train_sample_num+1]\n                    else:\n                        input_list = train_list[train_sample_num - history_len: train_sample_num]\n                        output = train_list[train_sample_num-history_len+1:train_sample_num+1]\n\n                    # generate history graph\n                    history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]\n                    output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]\n\n                    loss= model.get_loss(history_glist, output[-1], None, use_cuda)\n                    # print(loss)\n                    losses.append(loss.item())\n\n                    loss.backward()\n                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients\n                    optimizer.step()\n                    optimizer.zero_grad()\n\n                print(\"His {:04d}, Epoch {:04d} | Ave Loss: {:.4f} |Best MRR {:.4f} | Model {} \"\n                    .format(history_len, epoch, np.mean(losses), best_mrr, model_name))\n\n                #! 
checking GPU usage\n                free_mem, total_mem = torch.cuda.mem_get_info()\n                print (\"--------------GPU memory usage-----------\")\n                print (\"there are \", free_mem, \" free memory\")\n                print (\"there are \", total_mem, \" total available memory\")\n                print (\"there are \", total_mem - free_mem, \" used memory\")\n                print (\"--------------GPU memory usage-----------\")\n\n                # validation\n                if epoch % args.evaluate_every == 0:\n                    mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                                model_state_file, mode=\"train\", split_mode= \"val\")\n                    \n                    if mrr< best_mrr:\n                        if epoch >= args.n_epochs or epoch-best_epoch>2:\n                            break\n                    else:\n                        best_mrr = mrr\n                        best_epoch = epoch\n                        best_hits10 = hits10\n                        torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)  \n            mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                        model_state_file, mode=\"test\", split_mode= \"val\")\n            ks_idx += 1\n            if mrr.item() < max(best_mrr_list):\n                test_history_len = history_len-1\n                print(\"early stopping, best history length: \", test_history_len)\n                break\n            else:\n                best_mrr_list.append(mrr.item())\n                best_hits_list.append(hits10.item())\n        \n    return mrr, test_history_len, perf_per_rel, hits10\n\n\n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args_cen()\nargs.dataset = 'tkgl-icews'\n\nSEED = 
args.seed  # set the random seed for consistency\nset_random_seed(SEED)\n\nDATA=args.dataset\nMODEL_NAME = 'CEN'\n\nprint(\"logging mrrs per relation: \", args.log_per_rel)\nprint(\"do test and valid? do only test no validation?: \", args.validtest, args.test_only)\n\n# load data\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\n\nnum_rels = dataset.num_rels\nnum_nodes = dataset.num_nodes \nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nrelations = dataset.edge_type\n\ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\nall_quads = np.stack((subjects, relations, objects, timestamps), axis=1)\n\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nval_timestamps_orig = list(set(timestamps_orig[dataset.val_mask])) # needed for getting the negative samples\nval_timestamps_orig.sort()\ntest_timestamps_orig = list(set(timestamps_orig[dataset.test_mask])) # needed for getting the negative samples\ntest_timestamps_orig.sort()\n\ntrain_list = split_by_time(train_data)\nvalid_list = split_by_time(val_data)\ntest_list = split_by_time(test_data)\n\n# evaluation metric\nMETRIC = dataset.eval_metric\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n#load the ns samples \ndataset.load_val_ns()\ndataset.load_test_ns()\n\nif args.grid_search:\n    print(\"TODO: implement hyperparameter grid search\")\n# single run\nelse:\n    \n    start_train = timeit.default_timer()\n    if args.validtest:\n        print('directly start testing')\n        if args.test_history_len_2 != args.test_history_len:\n            args.test_history_len = args.test_history_len_2 # hyperparameter value as given in original paper \n    else:\n        print('running pretrain and train')\n        # pretrain\n        mrr, _, _, hits10 = run_experiment(args, trainvalidtest_id=-1)\n   
     # train\n        mrr, args.test_history_len, _, hits10 = run_experiment(args, trainvalidtest_id=0) # overwrite test_history_len with \n        # the best history len (for valid mrr)       \n        \n    if args.test_only == False:\n        print(\"running test (on val and test dataset) with test_history_len of: \", args.test_history_len)\n        # test on val set\n        val_mrr, _, _, val_hits10 = run_experiment(args, trainvalidtest_id=1)\n    else:\n        val_mrr = 0\n        val_hits10 = 0\n\n    # test on test set\n    start_test = timeit.default_timer()\n    test_mrr, _, perf_per_rel, test_hits10 = run_experiment(args, trainvalidtest_id=2)\n\ntest_time = timeit.default_timer() - start_test\nall_time = timeit.default_timer() - start_train\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\nprint(f\"\\Train and Test: Elapsed Time (s): {all_time: .4f}\")\n\nprint(f\"\\tTest: {METRIC}: {test_mrr: .4f}\")\nprint(f\"\\tValid: {METRIC}: {val_mrr: .4f}\")\n\n# saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nsave_results({'model': MODEL_NAME,\n              'data': DATA,\n              'run': args.run_nr,\n              'seed': SEED,\n              'test_history_len': args.test_history_len,\n              f'val {METRIC}': float(val_mrr),\n              f'test {METRIC}': float(test_mrr),\n              'test_time': test_time,\n              'tot_train_val_time': all_time,\n              'test_hits10': float(test_hits10)\n              }, \n    results_filename)\n\nif args.log_per_rel:\n    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'\n    with open(results_filename, 'w') as json_file:\n        json.dump(perf_per_rel, 
json_file)\n\n\nsys.exit()"
  },
  {
    "path": "examples/linkproppred/tkgl-icews/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\ntgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(tgb_modules_path)\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    hits_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t, pos_edge = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n            data['edge_type'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)\n        \n        for idx, neg_batch in enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            results = evaluator.eval(input_dict)\n            perf_list.append(results[metric])\n            hits_list.append(results['hits@10'])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, 
pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n    perf_hits = float(np.mean(hits_list))\n\n    return perf_metrics, perf_hits\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tkgl-icews')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = args.data\nMODEL_NAME = 'EdgeBank'\n\n\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n        
hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_val: .4f}\")\nval_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {val_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>>  <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: 
perf_metric_test,\n              'val_mrr': perf_metric_val,\n              'test_time': test_time,\n              'tot_train_val_time': test_time+val_time,\n              'hits10': perf_hits_test}, \n    results_filename)\n\n"
  },
  {
    "path": "examples/linkproppred/tkgl-icews/recurrencybaseline.py",
    "content": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)\n\n@inproceedings{gastinger2024baselines,\n  title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},\n  author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},\n  booktitle={33nd International Joint Conference on Artificial Intelligence (IJCAI 2024)},\n  year={2024},\n  organization={International Joint Conferences on Artificial Intelligence Organization}\n}\n\n\"\"\"\n\n## imports\nimport timeit\nimport argparse\nimport numpy as np\nfrom copy import copy\nfrom pathlib import Path\nimport ray\nimport sys\nimport os\nimport os.path as osp\nimport json\n#internal imports \nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom tgb.utils.utils import set_random_seed,  save_results \nfrom modules.tkg_utils import create_basis_dict, group_by, reformat_ts\n\ndef predict(num_processes,  data_c_rel, all_data_c_rel, alpha, lmbda_psi,\n            perf_list_all, hits_list_all, window, neg_sampler, split_mode):\n    \"\"\" create predictions for each relation on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list of hits for each test query\n    \"\"\"\n\n    first_ts = data_c_rel[0][3]\n    ## use this if you wanna use ray:\n    num_queries = len(data_c_rel) // num_processes\n    if num_queries < num_processes: # if we do not have enough queries for all the processes\n        num_processes_tmp = 
1\n        num_queries = len(data_c_rel)\n    else:\n        num_processes_tmp = num_processes  \n    if num_processes > 1:\n        object_references =[]                   \n        \n        for i in range(num_processes_tmp):\n            num_test_queries = len(data_c_rel) - (i + 1) * num_queries\n            if num_test_queries >= num_queries:\n                test_queries_idx =[i * num_queries, (i + 1) * num_queries]\n            else:\n                test_queries_idx = [i * num_queries, len(test_data)]\n\n            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]\n\n            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window, \n                                basis_dict, \n                                num_nodes, num_rels, lmbda_psi, \n                                alpha, evaluator,first_ts, neg_sampler, split_mode)\n            object_references.append(ob)\n\n        output = ray.get(object_references)\n\n        # updates the scores and logging dict for each process\n        for proc_loop in range(num_processes_tmp):\n            perf_list_all.extend(output[proc_loop][0])\n            hits_list_all.extend(output[proc_loop][1])\n\n    else:\n        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel, \n                            window, basis_dict, \n                            num_nodes, num_rels, lmbda_psi, \n                            alpha, evaluator, first_ts, neg_sampler, split_mode)                  \n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n    \n    return perf_list_all, hits_list_all\n\n## test\ndef test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):  \n    \"\"\" create predictions by loopoing through all relations on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list 
of hits for each test query\n    \"\"\"       \n    perf_list_all = []\n    hits_list_all =[]\n    \n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n\n    ## loop through relations and apply baselines\n    \n    for rel in all_relations:\n        start =  timeit.default_timer()\n        if rel in test_data_prel.keys():\n            lmbda_psi = best_config[str(rel)]['lmbda_psi'][0]\n            alpha = best_config[str(rel)]['alpha'][0]\n\n            # test data for this relation\n            test_data_c_rel = test_data_prel[rel]\n            timesteps_test = list(set(test_data_c_rel[:,3]))\n            timesteps_test.sort()\n            all_data_c_rel = all_data_prel[rel]         \n            perf_list_rel = []\n            hits_list_rel = []\n            perf_list_rel, hits_list_rel = predict(num_processes, test_data_c_rel,\n                                                all_data_c_rel, alpha, lmbda_psi,perf_list_rel, hits_list_rel, \n                                                window, neg_sampler, split_mode)\n            perf_list_all.extend(perf_list_rel)\n            hits_list_all.extend(hits_list_rel)\n        else:\n            perf_list_rel =[]\n         \n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n        \n\n        with open(csv_file, 'a') as f:\n            f.write(\"{},{}\\n\".format(rel, perf_list_rel))\n\n    return perf_list_all, hits_list_all\n\ndef read_dict_compute_mrr(split_mode='test'):\n    \"\"\" read the results per relation  from a precreated file and compute mrr\n    :return  mrr_per_rel: dictionary of mrrs for each relation\n    :return all_mrrs: list of mrrs for all relations\n    \"\"\"\n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n    # Initialize an empty dictionary to store the data\n    
results_per_rel_dict = {}\n    mrr_per_rel = {}\n    all_mrrs = []\n    # Open the file for reading\n    with open(csv_file, 'r') as f:\n        # Read each line in the file\n        for line in f:\n            # Split the line at the comma\n            parts = line.strip().split(',')\n            # Extract the key (the first part)\n            key = int(parts[0])\n            # Extract the values (the rest of the parts), remove square brackets\n            values = [float(value.strip('[]')) for value in parts[1:]]\n            # Add the key-value pair to the dictionary\n            if key in results_per_rel_dict.keys():\n                print(f\"Key {key} already exists in the dictionary!!! might have duplicate entries in results csv\")\n            results_per_rel_dict[key] = values\n            all_mrrs.extend(values)\n            mrr_per_rel[key] = np.mean(values)\n\n    if len(list(results_per_rel_dict.keys())) != num_rels:\n        print(\"we do not have entries for each rel in the results csv file. only num enties: \", len(list(results_per_rel_dict.keys())))\n\n    print(\"Split mode: \"+split_mode +\" Mean MRR: \", np.mean(all_mrrs))\n    print(\"mrr per relation: \", mrr_per_rel)\n\n\n\n## train\ndef train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):\n    \"\"\" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params\n    based on validation mrr\n    :return best_config: dictionary of best params for each relation\n    \"\"\"\n    best_config= {}\n    best_mrr = 0\n    for rel in rels: # loop through relations. 
for each relation, apply rules with selected params, compute valid mrr\n        start =  timeit.default_timer()\n        rel_key = int(rel)            \n\n        best_config[str(rel_key)] = {}\n        best_config[str(rel_key)]['not_trained'] = 'True'    \n        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi,0] #default\n        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))\n        best_config[str(rel_key)]['alpha'] = [default_alpha,0]  #default    \n        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))\n        \n        if rel in val_data_prel.keys():      \n            # valid data for this relation  \n            val_data_c_rel = val_data_prel[rel]\n            timesteps_valid = list(set(val_data_c_rel[:,3]))\n            timesteps_valid.sort()\n            trainval_data_c_rel = trainval_data_prel[rel]\n\n            ######  1) select lambda ###############        \n            lmbdas_psi = params_dict['lmbda_psi']        \n\n            alpha = 1\n            best_lmbda_psi = 0.1\n            best_mrr_psi = 0\n            lmbda_mrrs = []\n\n            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))\n            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))         \n            best_config[str(rel_key)]['not_trained'] = 'False'       \n            \n            for lmbda_psi in lmbdas_psi:   \n                perf_list_r = []\n                hits_list_r = []\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr = np.mean(perf_list_r)\n                # # is new mrr better than previous best? 
if yes: store lmbda\n                if mrr > best_mrr_psi:\n                    best_mrr_psi = float(mrr)\n                    best_lmbda_psi = lmbda_psi\n\n\n                lmbda_mrrs.append(float(mrr))\n            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]\n            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs\n            best_mrr = best_mrr_psi\n            ##### 2) select alpha ###############\n            best_config[str(rel_key)]['not_trained'] = 'False'    \n            alphas = params_dict['alpha'] \n            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0] # use the best lmbda psi\n\n            alpha_mrrs = []\n            # perf_list_all = []\n            best_mrr_alpha = 0\n            best_alpha=0.99\n            for alpha in alphas:\n                perf_list_r = []\n                hits_list_r = []\n\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr_alpha = np.mean(perf_list_r)\n\n                # is new mrr better than previous best? 
if yes: store alpha\n                if mrr_alpha > best_mrr_alpha:\n                    best_mrr_alpha = float(mrr_alpha)\n                    best_alpha = alpha\n                    best_mrr = best_mrr_alpha\n                alpha_mrrs.append(float(mrr_alpha))\n\n            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]\n            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs\n\n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n    return best_config\n\n\n\n## args\ndef get_args(): \n    \"\"\"parse all arguments for the script\"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", \"-d\", default=\"tkgl-icews\", type=str) \n    parser.add_argument(\"--window\", \"-w\", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep\n    parser.add_argument(\"--num_processes\", \"-p\", default=1, type=int)\n    parser.add_argument(\"--lmbda\", \"-l\",  default=0.1, type=float) # fix lambda. used if trainflag == false\n    parser.add_argument(\"--alpha\", \"-alpha\",  default=0.99, type=float) # fix alpha. 
used if trainflag == false\n    parser.add_argument(\"--train_flag\", \"-tr\",  default='False') # do we need training, ie selection of lambda and alpha\n    parser.add_argument(\"--load_flag\", \"-lo\",  default='False') # if train_flag set to True: do you want to load best_config?\n    parser.add_argument(\"--save_config\", \"-c\",  default='True') # do we need to save the selection of lambda and alpha in config file?\n    parser.add_argument('--seed', type=int, help='Random seed', default=1) # not needed\n    parsed = vars(parser.parse_args())\n    return parsed\n\nstart_o =  timeit.default_timer()\n\nparsed = get_args()\nif parsed['num_processes']>1:\n    ray.init(num_cpus=parsed[\"num_processes\"], num_gpus=0)\nMODEL_NAME = 'RecurrencyBaseline'\nSEED = parsed['seed']  # set the random seed for consistency\nset_random_seed(SEED)\nperrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'\nif not osp.exists(perrel_results_path):\n    os.mkdir(perrel_results_path)\n    print('INFO: Create directory {}'.format(perrel_results_path))\nPath(perrel_results_path).mkdir(parents=True, exist_ok=True)\n\n## load dataset and prepare it accordingly\nname = parsed[\"dataset\"]\ndataset = LinkPropPredDataset(name=name, root=\"datasets\", preprocess=True)\nDATA = name\n\nrelations = dataset.edge_type\nnum_rels = dataset.num_rels\nrels = np.arange(0,num_rels)\nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nnum_nodes = dataset.num_nodes \ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n\nprint(\"split train valid test data\")\nall_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=name)\nneg_sampler = 
dataset.negative_sampler\n\ntrain_val_data = np.concatenate([train_data, val_data])\nall_data = np.concatenate([train_data, val_data, test_data])\n\n# create dicts with key: relation id, values: triples for that relation id\nprint(\"grouping data by relation\")\ntest_data_prel = group_by(test_data, 1)\nall_data_prel = group_by(all_data, 1)\nval_data_prel = group_by(val_data, 1)\ntrainval_data_prel = group_by(train_val_data, 1)\n\n#load the ns samples \n# if parsed['train_flag']:\nprint(\"loading negative samples\")\ndataset.load_val_ns()\ndataset.load_test_ns()\n\n# parameter options\nif parsed['train_flag'] == 'True':\n    params_dict = {}\n    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001] \n    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]\n    default_lmbda_psi = params_dict['lmbda_psi'][-1]\n    default_alpha = params_dict['alpha'][-2]\n\n## load rules\nprint(\"creating rules\")\nbasis_dict = create_basis_dict(train_val_data)\nprint(\"done with creating rules\")\n## init\n# rb_predictor = RecurrencyBaselinePredictor(rels)\n## train to find best lambda and alpha\nstart_train =  timeit.default_timer()\nif parsed['train_flag'] ==  'True':\n    if parsed['load_flag'] == 'True':\n        with open('best_config.json', 'r') as infile:\n            best_config = json.load(infile)\n    else:\n        print('start training')\n        best_config = train(params_dict,  rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'], \n            parsed['window'])\n        if parsed['save_config'] == 'True':\n            import json\n            with open('best_config.json', 'w') as outfile:\n                json.dump(best_config, outfile)\n\nelse: # use preset lmbda and alpha; same for all relations\n    best_config = {} \n    for rel in rels:\n        best_config[str(rel)] = {}\n        best_config[str(rel)]['lmbda_psi'] = 
[parsed['lmbda']]\n        best_config[str(rel)]['alpha'] = [parsed['alpha']]\n    \nend_train =  timeit.default_timer()\n\n# compute validation mrr\nprint(\"Computing validation MRR\")\nperf_list_all_val, hits_list_all_val = test(best_config,rels, val_data_prel, \n                                                trainval_data_prel, neg_sampler, parsed['num_processes'], \n                                            parsed['window'], split_mode='val')\nval_mrr = float(np.mean(perf_list_all_val))\n\n# compute test mrr\nprint(\"Computing test MRR\")\nstart_test =  timeit.default_timer()\nperf_list_all, hits_list_all = test(best_config,rels, test_data_prel, \n                                                 all_data_prel, neg_sampler, parsed['num_processes'], \n                                                parsed['window'])\n\nend_o =  timeit.default_timer()\ntrain_time_o = round(end_train- start_train, 6)  \ntest_time_o = round(end_o- start_test, 6)  \ntotal_time_o = round(end_o- start_o, 6)  \nprint(\"Running Training to find best configs finished in {} seconds.\".format(train_time_o))\nprint(\"Running testing with best configs finished in {} seconds.\".format(test_time_o))\nprint(\"Running all steps finished in {} seconds.\".format(total_time_o))\n\nprint(f\"The test MRR is {np.mean(perf_list_all)}\")\nprint(f\"The valid MRR is {val_mrr}\")\nprint(f\"The Hits@10 is {np.mean(hits_list_all)}\")\nprint(f\"We have {len(perf_list_all)} predictions\")\nprint(f\"The test set has len {len(test_data)} \")\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'\nmetric = dataset.eval_metric\nsave_results({'model': MODEL_NAME,\n              'train_flag': parsed['train_flag'],\n     
         'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: float(np.mean(perf_list_all)),\n              'val_mrr': val_mrr,\n              'hits10': float(np.mean(hits_list_all)),\n              'test_time': test_time_o,\n              'tot_train_val_time': total_time_o\n              }, \n    results_filename)\n\nif parsed['num_processes']>1:\n    ray.shutdown()\n\n\n    \n"
  },
  {
    "path": "examples/linkproppred/tkgl-icews/regcn.py",
    "content": "\"\"\"\nTemporal Knowledge Graph Reasoning Based on Evolutional Representation Learning\nReference:\n- https://github.com/Lee-zix/RE-GCN\nZixuan Li, Xiaolong Jin, Wei Li, Saiping Guan, Jiafeng Guo, Huawei Shen, Yuanzhuo Wang and Xueqi Cheng. Temporal \nKnowledge Graph Reasoning Based on Evolutional Representation Learning. SIGIR 2021.\n\"\"\"\nimport sys\nimport timeit\nimport os\nimport sys\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\nimport torch\nimport random\nfrom tqdm import tqdm\n# internal imports\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.rrgcn import RecurrentRGCNREGCN\nfrom tgb.utils.utils import set_random_seed, split_by_time, save_results\nfrom modules.tkg_utils import get_args_regcn, reformat_ts\nfrom modules.tkg_utils_dgl import build_sub_graph\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nimport json\n\ndef test(model, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, static_graph, mode, split_mode):\n    \"\"\"\n    Test the model on either test or validation set\n    :param model: model used to test\n    :param history_list:    all input history snap shot list, not include output label train list or valid list\n    :param test_list:   test triple snap shot list\n    :param num_rels:    number of relations\n    :param num_nodes:   number of nodes\n    :param use_cuda:\n    :param model_name:\n    :param mode:\n    :param split_mode: 'test' or 'val' to state which negative samples to load\n    :return mrr\n    \"\"\"\n    print(\"Testing for mode: \", split_mode)\n    if split_mode == 'test':\n        timesteps_to_eval = test_timestamps_orig\n    else:\n        timesteps_to_eval = val_timestamps_orig\n\n    idx = 0\n\n    if mode == \"test\":\n        # test mode: load parameter form file \n        if use_cuda:\n            checkpoint = 
torch.load(model_name, map_location=torch.device(args.gpu))\n        else:\n            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))        \n        # use best stat checkpoint:\n        print(\"Load Model name: {}. Using best epoch : {}\".format(model_name, checkpoint['epoch']))  \n        print(\"\\n\"+\"-\"*10+\"start testing\"+\"-\"*10+\"\\n\")\n        model.load_state_dict(checkpoint['state_dict'])\n\n    model.eval()\n    input_list = [snap for snap in history_list[-args.test_history_len:]] \n    perf_list_all = []\n    hits_list_all = []\n    perf_per_rel = {}\n    for rel in range(num_rels):\n        perf_per_rel[rel] = []\n\n    for time_idx, test_snap in enumerate(tqdm(test_list)):\n        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]    \n        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)\n        timesteps_batch =timesteps_to_eval[time_idx]*np.ones(len(test_triples_input[:,0]))\n\n        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:,0], test_triples_input[:,2], \n                                timesteps_batch, edge_type=test_triples_input[:,1], split_mode=split_mode)\n        pos_samples_batch = test_triples_input[:,2]\n\n        _, perf_list, hits_list = model.predict(history_glist, num_rels, static_graph, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch, \n                                    evaluator, METRIC)  \n\n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n        if split_mode == \"test\":\n            if args.log_per_rel:\n                for score, rel in zip(perf_list, test_triples_input[:,1].tolist()):\n                    perf_per_rel[rel].append(score)\n        # reconstruct history graph list\n        input_list.pop(0)\n        input_list.append(test_snap)\n        idx += 1\n\n    if split_mode == \"test\":\n        if 
args.log_per_rel:   \n            for rel in range(num_rels):\n                if len(perf_per_rel[rel]) > 0:\n                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))\n                else:\n                    perf_per_rel.pop(rel)\n    mrr = np.mean(perf_list_all) \n    hits10 = np.mean(hits_list_all) \n    return mrr, perf_per_rel, hits10\n\n\n\ndef run_experiment(args, n_hidden=None, n_layers=None, dropout=None, n_bases=None):\n    \"\"\"\n    Run the experiment with the given configuration\n    :param args: arguments\n    :param n_hidden: hidden dimension\n    :param n_layers: number of layers\n    :param dropout: dropout rate\n    :param n_bases: number of bases\n    :return: mrr, perf_per_rel  (mean reciprocal rank, performance per relation)\n    \"\"\"\n    # load configuration for grid search the best configuration\n    if n_hidden:\n        args.n_hidden = n_hidden\n    if n_layers:\n        args.n_layers = n_layers\n    if dropout:\n        args.dropout = dropout\n    if n_bases:\n        args.n_bases = n_bases\n    mrr = 0\n    hits10=0\n    # 2) set save model path\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    if not osp.exists(save_model_dir):\n        os.mkdir(save_model_dir)\n        print('INFO: Create directory {}'.format(save_model_dir))\n    model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}'\n    model_state_file = save_model_dir+model_name\n    perf_per_rel = {}\n    use_cuda = args.gpu >= 0 and torch.cuda.is_available()\n\n    num_static_rels, num_words, static_triples, static_graph = 0, 0, [], None\n\n    # create stat\n    model = RecurrentRGCNREGCN(args.decoder,\n                          args.encoder,\n                        num_nodes,\n                        int(num_rels/2),\n                        num_static_rels, # DIFFERENT\n                        num_words, # DIFFERENT\n                        args.n_hidden,\n                        args.opn,\n                        
sequence_len=args.train_history_len,\n                        num_bases=args.n_bases,\n                        num_basis=args.n_basis,\n                        num_hidden_layers=args.n_layers,\n                        dropout=args.dropout,\n                        self_loop=args.self_loop,\n                        skip_connect=args.skip_connect,\n                        layer_norm=args.layer_norm,\n                        input_dropout=args.input_dropout,\n                        hidden_dropout=args.hidden_dropout,\n                        feat_dropout=args.feat_dropout,\n                        aggregation=args.aggregation, # DIFFERENT\n                        weight=args.weight, # DIFFERENT\n                        discount=args.discount, # DIFFERENT\n                        angle=args.angle, # DIFFERENT\n                        use_static=args.add_static_graph, # DIFFERENT\n                        entity_prediction=args.entity_prediction, \n                        relation_prediction=args.relation_prediction,\n                        use_cuda=use_cuda,\n                        gpu = args.gpu,\n                        analysis=args.run_analysis) # DIFFERENT\n\n    if use_cuda:\n        torch.cuda.set_device(args.gpu)\n        model.cuda()\n\n    # optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)\n\n    if args.test and os.path.exists(model_state_file):\n        mrr, perf_per_rel, hits10 = test(model, \n                    train_list+valid_list, \n                    test_list, \n                    num_rels, \n                    num_nodes, \n                    use_cuda, \n                    model_state_file, \n                    static_graph, \n                    \"test\", \n                    \"test\")\n        return mrr, perf_per_rel, hits10\n    elif args.test and not os.path.exists(model_state_file):\n        print(\"--------------{} not exist, Change mode to train and generate stat for 
testing----------------\\n\".format(model_state_file))\n        return 0, 0\n    else:\n        print(\"----------------------------------------start training----------------------------------------\\n\")\n        best_mrr = 0\n        best_hits = 0\n        for epoch in range(args.n_epochs):\n\n            model.train()\n            losses = []\n            idx = [_ for _ in range(len(train_list))]\n            random.shuffle(idx)\n\n            for train_sample_num in tqdm(idx):\n                if train_sample_num == 0: continue\n                output = train_list[train_sample_num:train_sample_num+1]\n                if train_sample_num - args.train_history_len<0:\n                    input_list = train_list[0: train_sample_num]\n                else:\n                    input_list = train_list[train_sample_num - args.train_history_len:\n                                        train_sample_num]\n\n                # generate history graph\n                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]\n                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]\n                loss_e, loss_r, loss_static = model.get_loss(history_glist, output[0], static_graph, use_cuda)\n                loss = args.task_weight*loss_e + (1-args.task_weight)*loss_r + loss_static\n\n                losses.append(loss.item())\n\n                loss.backward()\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients\n                optimizer.step()\n                optimizer.zero_grad()\n\n            print(\"Epoch {:04d} | Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} \"\n                  .format(epoch, np.mean(losses),  best_mrr, model_name))\n            \n            #! 
checking GPU usage\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            print (\"--------------GPU memory usage-----------\")\n            print (\"there are \", free_mem, \" free memory\")\n            print (\"there are \", total_mem, \" total available memory\")\n            print (\"there are \", total_mem - free_mem, \" used memory\")\n            print (\"--------------GPU memory usage-----------\")\n\n            # validation\n            if epoch and epoch % args.evaluate_every == 0:\n                mrr,perf_per_rel, hits10 = test(model, train_list, \n                            valid_list, \n                            num_rels, \n                            num_nodes, \n                            use_cuda, \n                            model_state_file, \n                            static_graph, \n                            mode=\"train\", split_mode='val')\n            \n                if mrr < best_mrr:\n                    if epoch >= args.n_epochs:\n                        break\n                else:\n                    best_mrr = mrr\n                    best_hits = hits10\n                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)\n\n        return best_mrr, perf_per_rel, hits10\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args_regcn()\nargs.dataset = 'tkgl-icews'\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\n\nDATA=args.dataset\nMODEL_NAME = 'REGCN'\n\nprint(\"logging mrrs per relation: \", args.log_per_rel)\n\n# load data\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\n\nnum_rels = dataset.num_rels\nnum_nodes = dataset.num_nodes \nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nrelations = dataset.edge_type\n\ntimestamps_orig = 
dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\nall_quads = np.stack((subjects, relations, objects, timestamps), axis=1)\n\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nval_timestamps_orig = list(set(timestamps_orig[dataset.val_mask])) # needed for getting the negative samples\nval_timestamps_orig.sort()\ntest_timestamps_orig = list(set(timestamps_orig[dataset.test_mask])) # needed for getting the negative samples\ntest_timestamps_orig.sort()\n\ntrain_list = split_by_time(train_data)\nvalid_list = split_by_time(val_data)\ntest_list = split_by_time(test_data)\n\n# evaluation metric\nMETRIC = dataset.eval_metric\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n#load the ns samples \ndataset.load_val_ns()\ndataset.load_test_ns()\n\n## run training and testing\nval_mrr, test_mrr = 0, 0\ntest_hits10 = 0\nif args.grid_search:\n    print(\"hyperparameter grid search not implemented. 
Exiting.\")\n# single run\nelse:\n    start_train = timeit.default_timer()\n    if args.test == False: #if they are true: directly test on a previously trained and stored model\n        print('start training')\n        val_mrr, perf_per_rel, val_hits10 = run_experiment(args) # do training\n    start_test = timeit.default_timer()\n    args.test = True\n    print('start testing')\n    test_mrr, perf_per_rel, test_hits10 = run_experiment(args) # do testing\n\n\ntest_time = timeit.default_timer() - start_test\nall_time = timeit.default_timer() - start_train\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\nprint(f\"\\Train and Test: Elapsed Time (s): {all_time: .4f}\")\n\nprint(f\"\\tTest: {METRIC}: {test_mrr: .4f}\")\nprint(f\"\\tValid: {METRIC}: {val_mrr: .4f}\")\n\n# saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nsave_results({'model': MODEL_NAME,\n              'data': DATA,\n              'run': args.run_nr,\n              'seed': SEED,\n              f'val {METRIC}': float(val_mrr),\n              f'test {METRIC}': float(test_mrr),\n              'test_time': test_time,\n              'tot_train_val_time': all_time,\n              'test_hits10': float(test_hits10)\n              }, \n    results_filename)\n\nif args.log_per_rel:\n    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'\n    with open(results_filename, 'w') as json_file:\n        json.dump(perf_per_rel, json_file)\n\nsys.exit()"
  },
  {
    "path": "examples/linkproppred/tkgl-icews/timetraveler.py",
    "content": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HUST/TITer\nHaohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021\n\"\"\"\nimport sys\nimport timeit\n\nimport torch\nfrom torch.utils.data import Dataset,DataLoader\nimport logging\n\nimport numpy as np\nimport pickle\nfrom tqdm import tqdm\nimport os.path as osp\nfrom pathlib import Path\nimport os\n\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.timetraveler_agent import Agent\nfrom modules.timetraveler_environment import Env\nfrom modules.timetraveler_dirichlet import Dirichlet, MLE_Dirchlet\nfrom modules.timetraveler_episode import Episode\nfrom modules.timetraveler_policygradient import PG\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.timetraveler_trainertester import Trainer, Tester, getRelEntCooccurrence\nfrom tgb.utils.utils import set_random_seed,save_results \nfrom modules.tkg_utils import  get_args_timetraveler, reformat_ts, get_model_config_timetraveler\n\nclass QuadruplesDataset(Dataset):\n    \"\"\" this is an internal way how Timetraveler represents the data\n    \"\"\"\n    def __init__(self, examples):\n        \"\"\"\n        examples: a list of quadruples.\n        num_r: number of relations\n        \"\"\"\n        self.quadruples = examples.copy()\n\n\n    def __len__(self):\n        return len(self.quadruples)\n\n    def __getitem__(self, item):\n        return self.quadruples[item][0], \\\n               self.quadruples[item][1], \\\n               self.quadruples[item][2], \\\n               self.quadruples[item][3], \\\n               self.quadruples[item][4]\n    \ndef set_logger(save_path):\n    \"\"\"Write logs to checkpoint and console\"\"\"\n    if 
args.do_train:\n        log_file = os.path.join(save_path, 'train.log')\n    else:\n        log_file = os.path.join(save_path, 'test.log')\n\n    logging.basicConfig(\n        format='%(asctime)s %(levelname)-8s %(message)s',\n        level=logging.INFO,\n        datefmt='%Y-%m-%d %H:%M:%S',\n        filename=log_file,\n        filemode='w'\n    )\n    console = logging.StreamHandler()\n    console.setLevel(logging.INFO)\n    formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')\n    console.setFormatter(formatter)\n    logging.getLogger('').addHandler(console)\n\ndef preprocess_data(args, config, timestamps, save_path, all_quads):\n    \"\"\"\n    Preprocess the data and save the state-action space (pickle dump)\n    \"\"\"\n    # parser = argparse.ArgumentParser(description='Data preprocess', usage='preprocess_data.py [<args>] [-h | --help]')\n    # parser.add_argument('--data_dir', default='data/ICEWS14', type=str, help='Path to data.')\n\n    env = Env(all_quads, config)\n    state_actions_space = {}\n\n    with tqdm(total=len(all_quads)) as bar:\n        for (head, rel, tail, t, _) in all_quads:\n            if (head, t, True) not in state_actions_space.keys():\n                state_actions_space[(head, t, True)] = env.get_state_actions_space_complete(head, t, True, args.store_actions_num)\n                state_actions_space[(head, t, False)] = env.get_state_actions_space_complete(head, t, False, args.store_actions_num)\n            if (tail, t, True) not in state_actions_space.keys():\n                state_actions_space[(tail, t, True)] = env.get_state_actions_space_complete(tail, t, True, args.store_actions_num)\n                state_actions_space[(tail, t, False)] = env.get_state_actions_space_complete(tail, t, False, args.store_actions_num)\n            bar.update(1)\n    pickle.dump(state_actions_space, open(os.path.join(save_path,  args.state_actions_path), 'wb'))\n\ndef log_metrics(mode, step, metrics):\n    \"\"\"Print the 
evaluation logs\"\"\"\n    for metric in metrics:\n        logging.info('%s %s at epoch %d: %f' % (mode, metric, step, metrics[metric]))\n\ndef main(args):\n    \"\"\"\n    Main function to train and test the TimeTraveler model\"\"\"\n\n    start_overall = timeit.default_timer()\n    #######################Set Logger#################################\n    \n    save_path = f'{os.path.dirname(os.path.abspath(__file__))}/saved_models/'\n    if not os.path.exists(save_path):\n        os.makedirs(save_path)\n\n    if args.cuda and torch.cuda.is_available():\n        args.cuda = True\n    else:\n        args.cuda = False\n    set_logger(save_path)\n\n    #######################Create DataLoader#################################\n    # set hyperparameters\n    args.dataset = 'tkgl-icews'\n\n    SEED = args.seed  # set the random seed for consistency\n    set_random_seed(SEED)\n\n    DATA=args.dataset\n    MODEL_NAME = 'TIMETRAVELER'\n\n    # load data\n    dataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\n\n    num_rels = dataset.num_rels\n    num_nodes = dataset.num_nodes \n    subjects = dataset.full_data[\"sources\"]\n    objects= dataset.full_data[\"destinations\"]\n    relations = dataset.edge_type\n\n    timestamps_orig = dataset.full_data[\"timestamps\"]\n    timestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n    all_quads = np.stack((subjects, relations, objects, timestamps,timestamps_orig), axis=1)\n\n    train_data = all_quads[dataset.train_mask]\n    train_entities = np.unique(np.concatenate((train_data[:, 0], train_data[:, 2])))\n    RelEntCooccurrence = getRelEntCooccurrence(train_data, num_rels)\n    train_data =QuadruplesDataset(train_data)\n    val_data = QuadruplesDataset(all_quads[dataset.val_mask])\n    test_data = QuadruplesDataset(all_quads[dataset.test_mask])\n\n    METRIC = dataset.eval_metric\n    evaluator = Evaluator(name=DATA)\n    neg_sampler = dataset.negative_sampler\n    #load the ns samples \n    
dataset.load_val_ns()\n    dataset.load_test_ns()\n\n    train_dataloader = DataLoader(\n        train_data,\n        batch_size=args.batch_size,\n        shuffle=True,\n        num_workers=args.num_workers,\n    )\n\n    valid_dataloader = DataLoader(\n        val_data,\n        batch_size=args.test_batch_size,\n        shuffle=False,\n        num_workers=args.num_workers,\n    )\n\n    test_dataloader = DataLoader(\n        test_data,\n        batch_size=args.test_batch_size,\n        shuffle=False,\n        num_workers=args.num_workers,\n    )\n\n    ######################Creat the agent and the environment###########################\n    config = get_model_config_timetraveler(args, num_nodes, num_rels)\n    logging.info(config)\n    logging.info(args)\n\n    # creat the agent\n    agent = Agent(config)\n\n\n    # creat the environment\n    state_actions_path = os.path.join(save_path, args.state_actions_path)\n\n\n    ######################preprocessing###########################\n    if not os.path.exists(state_actions_path):\n        if args.preprocess:\n            print(\"preprocessing data...\")\n            preprocess_data(args, config, timestamps, save_path, list(all_quads))\n            state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))\n        else:\n            state_action_space = None\n    else:\n        print(\"load preprocessed data...\")\n        state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))\n\n\n    env = Env(list(all_quads), config, state_action_space)\n    # Create episode controller\n    episode = Episode(env, agent, config)\n    if args.cuda:\n        episode = episode.cuda()\n    pg = PG(config)  # Policy Gradient\n    optimizer = torch.optim.Adam(episode.parameters(), lr=args.lr, weight_decay=0.00001)\n\n    ######################Reward Shaping: MLE DIRICHLET alphas###########################\n    if args.reward_shaping: \n        try:\n            
print(\"load alphas from pickle file\")\n            alphas = pickle.load(open(os.path.join(save_path, args.alphas_pkl), 'rb'))\n        except:\n            print('running MLE dirichlet now')\n            mle_d = MLE_Dirchlet(all_quads, num_rels, args.k, args.time_span,\n                         args.tol, args.method, args.maxiter)\n            pickle.dump(mle_d.alphas, open(os.path.join(save_path, args.alphas_pkl), 'wb'))\n\n            print('dumped alphas')\n            alphas = mle_d.alphas\n        distributions = Dirichlet(alphas, args.k)\n    else:\n        distributions = None\n\n    ######################Training and Testing###########################\n\n    trainer = Trainer(episode, pg, optimizer, args, distributions)\n    tester = Tester(episode, args, train_entities, RelEntCooccurrence, dataset.metric)\n    test_metrics ={}\n    val_metrics = {}\n    test_metrics[METRIC] = None\n    val_metrics[METRIC] = None\n\n    if args.do_train:\n        start_train =timeit.default_timer()\n        logging.info('Start Training......')\n        for i in range(args.max_epochs):\n            loss, reward = trainer.train_epoch(train_dataloader, len(train_data))\n            logging.info('Epoch {}/{} Loss: {}, reward: {}'.format(i, args.max_epochs, loss, reward))\n\n            #! 
checking GPU usage\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            print (\"--------------GPU memory usage-----------\")\n            print (\"there are \", free_mem, \" free memory\")\n            print (\"there are \", total_mem, \" total available memory\")\n            print (\"there are \", total_mem - free_mem, \" used memory\")\n            print (\"--------------GPU memory usage-----------\")\n            \n            if i % args.save_epoch == 0 and i != 0:\n                trainer.save_model(save_path, 'checkpoint_{}.pth'.format(i))\n                logging.info('Save Model in {}'.format(save_path))\n\n            if i % args.valid_epoch == 0 and i != 0:\n                logging.info('Start Val......')\n                val_metrics = tester.test(valid_dataloader,\n                                      len(val_data), num_nodes, neg_sampler, evaluator, split_mode='val')\n                for mode in val_metrics.keys():\n                    logging.info('{} at epoch {}: {}'.format(mode, i, val_metrics[mode]))\n\n        trainer.save_model(save_path)\n        logging.info('Save Model in {}'.format(save_path))\n    else:\n          # # Load the model parameters\n        if os.path.isfile(save_path):\n            params = torch.load(save_path)\n            episode.load_state_dict(params['model_state_dict'])\n            optimizer.load_state_dict(params['optimizer_state_dict'])\n            logging.info('Load pretrain model: {}'.format(save_path))\n    if args.do_test:\n        logging.info('Start Testing......')\n        start_test = timeit.default_timer()\n        test_metrics = tester.test(test_dataloader,\n                              len(test_data), num_nodes, neg_sampler, evaluator, split_mode='test')\n        for mode in test_metrics.keys():\n            logging.info('Test {} : {}'.format(mode, test_metrics[mode]))\n\n        # saving the results...\n        results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\n    
    if not osp.exists(results_path):\n            os.mkdir(results_path)\n            print('INFO: Create directory {}'.format(results_path))\n        Path(results_path).mkdir(parents=True, exist_ok=True)\n        results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n        test_time = timeit.default_timer() - start_test\n        all_time = timeit.default_timer() - start_train \n        all_time_preprocess = timeit.default_timer() - start_overall \n\n        save_results({'model': MODEL_NAME,\n                    'data': DATA,\n                    'seed': SEED,\n                    f'val {METRIC}': float(val_metrics[METRIC]),\n                    f'test {METRIC}': float(test_metrics[METRIC]),\n                    'test_time': test_time,\n                    'tot_train_val_time': all_time,\n                    'tot_preprocess_train_val_time': all_time_preprocess\n                    }, \n            results_filename)     \n\nif __name__ == '__main__':\n    args = get_args_timetraveler()\n    main(args)"
  },
  {
    "path": "examples/linkproppred/tkgl-icews/tkgl-icews_example.py",
    "content": "import numpy as np\r\nimport timeit\r\nfrom tqdm import tqdm\r\nimport sys\r\nimport os\r\nimport os.path as osp\r\n\r\n#internal imports \r\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\r\nsys.path.append(modules_path)\r\n\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom tgb.linkproppred.evaluate import Evaluator\r\n\r\nDATA = \"tkgl-icews\"\r\n\r\n# data loading\r\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\r\ntrain_mask = dataset.train_mask\r\nval_mask = dataset.val_mask\r\ntest_mask = dataset.test_mask\r\ndata = dataset.get_TemporalData()\r\nmetric = dataset.eval_metric\r\n\r\nprint (\"there are {} nodes and {} edges\".format(dataset.num_nodes, dataset.num_edges))\r\nprint (\"there are {} relation types\".format(dataset.num_rels))\r\n\r\n\r\ntimestamp = data.t\r\nhead = data.src\r\ntail = data.dst\r\nedge_type = data.edge_type #relation\r\nneg_sampler = dataset.negative_sampler\r\n\r\ntrain_data = data[train_mask]\r\nval_data = data[val_mask]\r\ntest_data = data[test_mask]\r\n\r\n\r\nmetric = dataset.eval_metric\r\nevaluator = Evaluator(name=DATA)\r\nneg_sampler = dataset.negative_sampler\r\n\r\nBATCH_SIZE = 200\r\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\r\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\r\n\r\nstart_time = timeit.default_timer()\r\n#load the ns samples first\r\ndataset.load_val_ns()\r\nfor batch in tqdm(val_loader):\r\n    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type\r\n    neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='val')\r\nprint (\"loading ns samples from validation\", timeit.default_timer() - start_time)\r\n# for i, (src, dst, t, rel) in enumerate(zip(val_data.src, val_data.dst, 
val_data.t, val_data.edge_type)):\r\n#     #must use np array to query\r\n#     neg_batch_list = neg_sampler.query_batch(np.array([src]), np.array([dst]), np.array([t]), edge_type=np.array([rel]), split_mode='val')\r\n\r\nstart_time = timeit.default_timer()\r\ndataset.load_test_ns()\r\nfor batch in test_loader:\r\n    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type\r\n    neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='test')\r\nprint (\"loading ns samples from test\", timeit.default_timer() - start_time)\r\n# for i, (src, dst, t, rel) in enumerate(zip(test_data.src, test_data.dst, test_data.t, test_data.edge_type)):\r\n#     #must use np array to query\r\n#     neg_batch_list = neg_sampler.query_batch(np.array([src]), np.array([dst]), np.array([t]), edge_type=np.array([rel]), split_mode='test')\r\nprint (\"retrieved all negative samples\")\r\n\r\n\r\n# #* load numpy arrays instead\r\n# from tgb.linkproppred.dataset import LinkPropPredDataset\r\n\r\n# # data loading\r\n# dataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\r\n# data = dataset.full_data  \r\n# metric = dataset.eval_metric\r\n# sources = dataset.full_data['sources']\r\n# print (sources.dtype)\r\n\r\n"
  },
  {
    "path": "examples/linkproppred/tkgl-icews/tlogic.py",
    "content": "\"\"\"\nhttps://github.com/liu-yushan/TLogic/tree/main/mycode\nTLogic: Temporal Logical Rules for Explainable Link Forecasting on Temporal Knowledge Graphs.\nYushan Liu, Yunpu Ma, Marcel Hildebrandt, Mitchell Joblin, Volker Tresp\n\"\"\"\n\n# imports\nimport sys\nimport os\nimport os.path as osp\nfrom pathlib import Path\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nimport timeit\nimport argparse\nimport numpy as np\nimport json\nfrom joblib import Parallel, delayed\nimport itertools\n\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom modules.tlogic_learn_modules import Temporal_Walk, Rule_Learner, store_edges\nimport modules.tlogic_apply_modules as ra\nfrom tgb.utils.utils import set_random_seed,  save_results\nfrom modules.tkg_utils import reformat_ts, get_inv_relation_id, create_scores_array\n\ndef learn_rules(i, num_relations):\n    \"\"\"\n    Learn rules (multiprocessing possible).\n\n    Parameters:\n        i (int): process number\n        num_relations (int): minimum number of relations for each process\n\n    Returns:\n        rl.rules_dict (dict): rules dictionary\n    \"\"\"\n\n    # if seed:\n    #     np.random.seed(seed)\n\n    num_rest_relations = len(all_relations) - (i + 1) * num_relations\n    if num_rest_relations >= num_relations:\n        relations_idx = range(i * num_relations, (i + 1) * num_relations)\n    else:\n        relations_idx = range(i * num_relations, len(all_relations))\n\n    num_rules = [0]\n    for k in relations_idx:\n        rel = all_relations[k]\n        for length in rule_lengths:\n            it_start =  timeit.default_timer()\n            for _ in range(num_walks):\n                walk_successful, walk = temporal_walk.sample_walk(length + 1, rel)\n                if walk_successful:\n                    rl.create_rule(walk)\n            it_end =  
timeit.default_timer()\n            it_time = round(it_end - it_start, 6)\n            num_rules.append(sum([len(v) for k, v in rl.rules_dict.items()]) // 2)\n            num_new_rules = num_rules[-1] - num_rules[-2]\n            print(\n                \"Process {0}: relation {1}/{2}, length {3}: {4} sec, {5} rules\".format(\n                    i,\n                    k - relations_idx[0] + 1,\n                    len(relations_idx),\n                    length,\n                    it_time,\n                    num_new_rules,\n                )\n            )\n\n    return rl.rules_dict\n\ndef apply_rules(i, num_queries, rules_dict, neg_sampler, data, window, learn_edges, all_quads, args, split_mode, \n                log_per_rel=False, num_rels=0):\n    \"\"\"\n    Apply rules (multiprocessing possible).\n\n    Parameters:\n        i (int): process number\n        num_queries (int): minimum number of queries for each process\n\n    Returns:\n        hits_list (list): hits list (hits@10 per sample)\n        perf_list (list): performance list (mrr per sample)\n    \"\"\"\n    perf_per_rel = {}\n    for rel in range(num_rels):\n            perf_per_rel[rel] = []\n    print(\"Start process\", i, \"...\")\n    all_candidates = [dict() for _ in range(len(args))]\n    no_cands_counter = 0\n\n    num_rest_queries = len(data) - (i + 1) * num_queries\n    if num_rest_queries >= num_queries:\n        test_queries_idx = range(i * num_queries, (i + 1) * num_queries)\n    else:\n        test_queries_idx = range(i * num_queries, len(data))\n\n    cur_ts = data[test_queries_idx[0]][3]\n    edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)\n\n    it_start =  timeit.default_timer()\n    hits_list = [0] * len(test_queries_idx)\n    perf_list = [0] * len(test_queries_idx)\n    for index, j in enumerate(test_queries_idx):\n        neg_sample_el =  neg_sampler.query_batch(np.expand_dims(np.array(data[j,0]), axis=0), \n                                         
       np.expand_dims(np.array(data[j,2]), axis=0), \n                                                np.expand_dims(np.array(data[j,4]), axis=0), \n                                                np.expand_dims(np.array(data[j,1]), axis=0), \n                                                split_mode=split_mode)[0]        \n        \n        # neg_samples_batch[j]\n        pos_sample_el =  data[j,2]\n        test_query = data[j]\n        assert pos_sample_el == test_query[2]\n        cands_dict = [dict() for _ in range(len(args))]\n\n        if test_query[3] != cur_ts:\n            cur_ts = test_query[3]\n            edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)\n\n        if test_query[1] in rules_dict:\n            dicts_idx = list(range(len(args)))\n            for rule in rules_dict[test_query[1]]:\n                walk_edges = ra.match_body_relations(rule, edges, test_query[0])\n\n                if 0 not in [len(x) for x in walk_edges]:\n                    rule_walks = ra.get_walks(rule, walk_edges)\n                    if rule[\"var_constraints\"]:\n                        rule_walks = ra.check_var_constraints(\n                            rule[\"var_constraints\"], rule_walks\n                        )\n\n                    if not rule_walks.empty:\n                        cands_dict = ra.get_candidates(\n                            rule,\n                            rule_walks,\n                            cur_ts,\n                            cands_dict,\n                            score_func,\n                            args,\n                            dicts_idx,\n                        )\n                        for s in dicts_idx:\n                            cands_dict[s] = {\n                                x: sorted(cands_dict[s][x], reverse=True)\n                                for x in cands_dict[s].keys()\n                            }\n                            cands_dict[s] = dict(\n                        
        sorted(\n                                    cands_dict[s].items(),\n                                    key=lambda item: item[1],\n                                    reverse=True,\n                                )\n                            )\n                            top_k_scores = [v for _, v in cands_dict[s].items()][:top_k]\n                            unique_scores = list(\n                                scores for scores, _ in itertools.groupby(top_k_scores)\n                            )\n                            if len(unique_scores) >= top_k:\n                                dicts_idx.remove(s)\n                        if not dicts_idx:\n                            break\n\n            if cands_dict[0]:\n                for s in range(len(args)):\n                    # Calculate noisy-or scores\n                    scores = list(\n                        map(\n                            lambda x: 1 - np.product(1 - np.array(x)),\n                            cands_dict[s].values(),\n                        )\n                    )\n                    cands_scores = dict(zip(cands_dict[s].keys(), scores))\n                    noisy_or_cands = dict(\n                        sorted(cands_scores.items(), key=lambda x: x[1], reverse=True)\n                    )\n                    all_candidates[s][j] = noisy_or_cands\n            else:  # No candidates found by applying rules\n                no_cands_counter += 1\n                for s in range(len(args)):\n                    all_candidates[s][j] = dict()\n\n        else:  # No rules exist for this relation\n            no_cands_counter += 1\n            for s in range(len(args)):\n                all_candidates[s][j] = dict()\n\n        if not (j - test_queries_idx[0] + 1) % 100:\n            it_end =  timeit.default_timer()\n            it_time = round(it_end - it_start, 6)\n            print(\n                \"Process {0}: test samples finished: {1}/{2}, {3} sec\".format(\n          
          i, j - test_queries_idx[0] + 1, len(test_queries_idx), it_time\n                )\n            )\n            it_start =  timeit.default_timer()\n\n        predictions = create_scores_array(all_candidates[s][j], num_nodes)  \n        predictions_of_interest_pos = np.array(predictions[pos_sample_el])\n        predictions_of_interest_neg = predictions[neg_sample_el]\n        input_dict = {\n            \"y_pred_pos\": predictions_of_interest_pos,\n            \"y_pred_neg\": predictions_of_interest_neg,\n            \"eval_metric\": ['mrr'], \n        }\n\n        predictions = evaluator.eval(input_dict)\n        perf_list[index] = predictions['mrr']\n        hits_list[index] = predictions['hits@10']\n        if split_mode == \"test\":\n            if log_per_rel:\n                perf_per_rel[test_query[1]].append(perf_list[index]) #test_query[1] is the relation index\n\n    if split_mode == \"test\":\n        if log_per_rel:   \n            for rel in range(num_rels):\n                if len(perf_per_rel[rel]) > 0:\n                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))\n                else:\n                    perf_per_rel.pop(rel)       \n               \n\n    return perf_list, hits_list, perf_per_rel\n\n\n## args\ndef get_args(): \n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", \"-d\", default=\"tkgl-icews\", type=str) \n    parser.add_argument(\"--rule_lengths\", \"-l\", default=\"1\", type=int, nargs=\"+\")\n    parser.add_argument(\"--num_walks\", \"-n\", default=\"100\", type=int)\n    parser.add_argument(\"--transition_distr\", default=\"exp\", type=str)\n    parser.add_argument(\"--window\", \"-w\", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. 
set to -2 if multistep\n    parser.add_argument(\"--top_k\", default=20, type=int)\n    parser.add_argument(\"--num_processes\", \"-p\", default=1, type=int)\n    parser.add_argument(\"--alpha\", \"-alpha\",  default=0.99, type=float) # fix alpha. used if trainflag == false\n    # parser.add_argument(\"--train_flag\", \"-tr\",  default=True) # do we need training, ie selection of lambda and alpha\n    parser.add_argument(\"--save_config\", \"-c\",  default=True) # do we need to save the selection of lambda and alpha in config file?\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--run_nr', type=int, help='Run Number', default=1)\n    parser.add_argument('--learn_rules_flag', type=bool, help='Do we want to learn the rules', default=True)\n    parser.add_argument('--rule_filename', type=str, help='if rules not learned: where are they stored', default='0_r[3]_n100_exp_s1_rules.json')\n    parser.add_argument('--log_per_rel', type=bool, help='Do we want to log mrr per relation', default=False)\n    parser.add_argument('--compute_valid_mrr', type=bool, help='Do we want to compute mrr for valid set', default=True)\n    parsed = vars(parser.parse_args())\n    return parsed\n\nstart_o =  timeit.default_timer()\n\n## get args\nparsed = get_args()\ndataset = parsed[\"dataset\"]\nrule_lengths = parsed[\"rule_lengths\"]\nrule_lengths = [rule_lengths] if (type(rule_lengths) == int) else rule_lengths\nnum_walks = parsed[\"num_walks\"]\ntransition_distr = parsed[\"transition_distr\"]\nnum_processes = parsed[\"num_processes\"]\nwindow = parsed[\"window\"]\ntop_k = parsed[\"top_k\"]\nlog_per_rel = parsed['log_per_rel']\n\nMODEL_NAME = 'TLogic'\nSEED = parsed['seed']  # set the random seed for consistency\nset_random_seed(SEED)\n\n## load dataset and prepare it accordingly\nname = parsed[\"dataset\"]\ncompute_valid_mrr = parsed[\"compute_valid_mrr\"]\ndataset = LinkPropPredDataset(name=name, root=\"datasets\", 
preprocess=True)\nDATA = name\n\nrelations = dataset.edge_type\nnum_rels = dataset.num_rels\n\nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nnum_nodes = dataset.num_nodes \ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n\nall_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)\ntrain_data = all_quads[dataset.train_mask,0:4] # we do not need the original timestamps\nval_data = all_quads[dataset.val_mask,0:5]\ntest_data = all_quads[dataset.test_mask,0:5]\nall_data = all_quads[:,0:4]\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=name)\nneg_sampler = dataset.negative_sampler\n\ninv_relation_id = get_inv_relation_id(num_rels)\n\n#load the ns samples \n\ndataset.load_val_ns()\ndataset.load_test_ns()\noutput_dir =  f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\nlearn_rules_flag = parsed['learn_rules_flag']\n## 1. learn rules\nstart_train =  timeit.default_timer()\nif learn_rules_flag:\n    print(\"start learning rules\")\n    # edges (dict): edges for each relation\n    # inv_relation_id (dict): mapping of relation to inverse relation\n    \n    temporal_walk = Temporal_Walk(train_data, inv_relation_id, transition_distr)\n    rl = Rule_Learner(edges=temporal_walk.edges, id2relation=None, inv_relation_id=inv_relation_id,  \n                        output_dir=output_dir)\n    all_relations = sorted(temporal_walk.edges)  # Learn for all relations\n\n    start =  timeit.default_timer()\n    num_relations = len(all_relations) // num_processes\n    output = Parallel(n_jobs=num_processes)(\n        delayed(learn_rules)(i, num_relations) for i in range(num_processes)\n    )\n    end =  timeit.default_timer()\n\n    all_rules = output[0]\n    for i in range(1, num_processes):\n        all_rules.update(output[i])\n\n    total_time = round(end - start, 6)\n    print(\"Learning finished in {} 
seconds.\".format(total_time))\n\n    rl.rules_dict = all_rules\n    rl.sort_rules_dict()\n\n    rule_filename = rl.save_rules(0, rule_lengths, num_walks, transition_distr, SEED)\n    # rl.save_rules_verbalized(0, rule_lengths, num_walks, transition_distr, seed)\n    # rules_statistics(rl.rules_dict)\nelse:\n    rule_filename = parsed['rule_filename']\n    print(\"Loading rules from file {}\".format(parsed['rule_filename']))\n\nend_train =  timeit.default_timer()\n\n## 2. Apply rules\n\nrules_dict = json.load(open(output_dir + rule_filename))\nrules_dict = {int(k): v for k, v in rules_dict.items()}\n\nrules_dict = ra.filter_rules(\n    rules_dict, min_conf=0.01, min_body_supp=2, rule_lengths=rule_lengths\n) # filter rules for minimum confidence, body support and rule length\n\nlearn_edges = store_edges(train_data)\nscore_func = ra.score_12\n# It is possible to specify a list of list of arguments for tuning\nargs = [[0.1, 0.5]]\n\n# compute valid mrr\nstart_valid =  timeit.default_timer()\nif compute_valid_mrr:\n    print('Computing valid MRR')\n\n    num_queries = len(val_data) // num_processes\n\n    output = Parallel(n_jobs=num_processes)(\n        delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, val_data, window, learn_edges, \n                            all_quads, args, split_mode='val') for i in range(num_processes))\n    end =  timeit.default_timer()\n\n    perf_list_val = []\n    hits_list_val = []\n\n    for i in range(num_processes):\n        perf_list_val.extend(output[i][0])\n        hits_list_val.extend(output[i][1])\nelse:\n    perf_list_val = [0]\n    hits_list_val = [0]\n    \n\nend_valid =  timeit.default_timer()\n\n# compute test mrr\nif log_per_rel ==True:\n    num_processes = 1 #otherwise logging per rel does not work for our implementation\nstart_test =  timeit.default_timer()\nprint('Computing test MRR')\nstart =  timeit.default_timer()\nnum_queries = len(test_data) // num_processes\n\noutput = Parallel(n_jobs=num_processes)(\n    
delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, test_data, window, learn_edges, \n                         all_quads, args, split_mode='test', log_per_rel=log_per_rel, num_rels=num_rels) for i in range(num_processes))\nend =  timeit.default_timer()\n\nperf_list_all = []\nhits_list_all = []\n\n\nfor i in range(num_processes):\n    perf_list_all.extend(output[i][0])\n    hits_list_all.extend(output[i][1])\nif log_per_rel == True:\n    perf_per_rel = output[0][2]\n\n\ntotal_time = round(end - start, 6)\ntotal_valid_time = round(end_valid - start_valid, 6)\nprint(\"Application finished in {} seconds.\".format(total_time))\n\nprint(f\"The valid MRR is {np.mean(perf_list_val)}\")\nprint(f\"The MRR is {np.mean(perf_list_all)}\")\nprint(f\"The Hits@10 is {np.mean(hits_list_all)}\")\nprint(f\"We have {len(perf_list_all)} predictions\")\nprint(f\"The test set has len {len(test_data)} \")\n\nend_o =  timeit.default_timer()\ntrain_time_o = round(end_train- start_train, 6)  \ntest_time_o = round(end_o- start_test, 6)  \ntotal_time_o = round(end_o- start_o, 6)  \nprint(\"Running Training to find best configs finished in {} seconds.\".format(train_time_o))\nprint(\"Running testing with best configs finished in {} seconds.\".format(test_time_o))\nprint(\"Running all steps finished in {} seconds.\".format(total_time_o))\n\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\n\nif log_per_rel == True:\n    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'\n    with open(results_filename, 'w') as json_file:\n        json.dump(perf_per_rel, json_file)\n\nresults_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'\nmetric = dataset.eval_metric\nsave_results({'model': MODEL_NAME,\n              'train_flag': None,\n              'rule_len': 
rule_lengths,\n              'window': window,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: float(np.mean(perf_list_all)),\n              'hits10': float(np.mean(hits_list_all)),\n              'val_mrr': float(np.mean(perf_list_val)),\n              'test_time': test_time_o,\n              'tot_train_val_time': total_time_o,\n              'valid_time': total_valid_time\n              }, \n    results_filename)\n"
  },
  {
    "path": "examples/linkproppred/tkgl-polecat/cen.py",
    "content": "\"\"\"\nComplex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning\nReference:\n- https://github.com/Lee-zix/CEN\nZixuan Li, Saiping Guan, Xiaolong Jin, Weihua Peng, Yajuan Lyu , Yong Zhu, Long Bai, Wei Li, Jiafeng Guo, Xueqi Cheng. \nComplex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning. ACL 2022.\n\"\"\"\nimport timeit\nimport os\nimport sys\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\nimport torch\nimport random\nfrom tqdm import tqdm\nimport json\n# internal imports\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.rrgcn import RecurrentRGCNCEN\nfrom tgb.utils.utils import set_random_seed, split_by_time,  save_results\nfrom modules.tkg_utils import get_args_cen, reformat_ts\nfrom modules.tkg_utils_dgl import build_sub_graph\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \n\ndef test(model, history_len, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, mode, split_mode):\n    \"\"\"\n    Test the model\n    :param model: model used to test\n    :param history_list:    all input history snap shot list, not include output label train list or valid list\n    :param test_list:   test triple snap shot list\n    :param num_rels:    number of relations\n    :param num_nodes:   number of nodes\n    :param use_cuda:\n    :param model_name:\n    :param mode:\n    :param split_mode: 'test' or 'val' to state which negative samples to load\n    :return mrr\n    \"\"\"\n    print(\"Testing for mode: \", split_mode)\n    if split_mode == 'test':\n        timesteps_to_eval = test_timestamps_orig\n    else:\n        timesteps_to_eval = val_timestamps_orig\n\n    idx = 0\n\n    if mode == \"test\":\n        # test mode: load parameter form file \n        if use_cuda:\n            checkpoint = torch.load(model_name, 
map_location=torch.device(args.gpu))\n        else:\n            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))        \n        # use best stat checkpoint:\n        print(\"Load Model name: {}. Using best epoch : {}\".format(model_name, checkpoint['epoch']))  \n        print(\"\\n\"+\"-\"*10+\"start testing\"+\"-\"*10+\"\\n\")\n        model.load_state_dict(checkpoint['state_dict'])\n\n    model.eval()\n\n    input_list = [snap for snap in history_list[-history_len:]] \n    perf_list_all = []\n    hits_list_all = []\n    perf_per_rel = {}\n    for rel in range(num_rels):\n        perf_per_rel[rel] = []\n\n    for time_idx, test_snap in enumerate(tqdm(test_list)):\n        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]    \n        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)\n        timesteps_batch =timesteps_to_eval[time_idx]*np.ones(len(test_triples_input[:,0]))\n\n        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:,0], test_triples_input[:,2], \n                                timesteps_batch, edge_type=test_triples_input[:,1], split_mode=split_mode)\n        pos_samples_batch = test_triples_input[:,2]\n\n        _, perf_list, hits_list = model.predict(history_glist, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch, \n                                    evaluator, METRIC) \n\n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n        if split_mode == \"test\":\n            if args.log_per_rel:\n                for score, rel in zip(perf_list, test_triples_input[:,1].tolist()):\n                    perf_per_rel[rel].append(score)\n        # reconstruct history graph list\n        input_list.pop(0)\n        input_list.append(test_snap)\n        idx += 1\n    if split_mode == \"test\":\n        if args.log_per_rel:   \n            for rel in range(num_rels):\n        
        if len(perf_per_rel[rel]) > 0:\n                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))\n                else:\n                    perf_per_rel.pop(rel)\n    mrr = np.mean(perf_list_all) \n    hits10 = np.mean(hits_list_all)\n    return mrr, perf_per_rel, hits10\n\n\n\ndef run_experiment(args, trainvalidtest_id=0, n_hidden=None, n_layers=None, dropout=None, n_bases=None):\n    '''\n    Run experiment for CEN model\n    :param args: arguments for the model\n    :param trainvalidtest_id: -1: pretrainig, 0: curriculum training (to find best test history len), 1: test on valid set, 2: test on test set\n    :param n_hidden: number of hidden units\n    :param n_layers: number of layers\n    :param dropout: dropout rate\n    :param n_bases: number of bases\n    return: mrr, perf_per_rel: mean reciprocal rank and performance per relation\n    '''\n    # 1) load configuration for grid search the best configuration\n    if n_hidden:\n        args.n_hidden = n_hidden\n    if n_layers:\n        args.n_layers = n_layers\n    if dropout:\n        args.dropout = dropout\n    if n_bases:\n        args.n_bases = n_bases\n    test_history_len = args.test_history_len\n    # 2) set save model path\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    if not osp.exists(save_model_dir):\n        os.mkdir(save_model_dir)\n        print('INFO: Create directory {}'.format(save_model_dir))\n    test_model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.test_history_len}'\n    test_state_file = save_model_dir+test_model_name\n    perf_per_rel ={}\n    use_cuda = args.gpu >= 0 and torch.cuda.is_available()\n    # create stat\n\n    model = RecurrentRGCNCEN(args.decoder,\n                            args.encoder,\n                            num_nodes,\n                            num_rels,\n                            args.n_hidden,\n                            args.opn,\n                            
sequence_len=args.train_history_len,\n                            num_bases=args.n_bases,\n                            num_basis=args.n_basis,\n                            num_hidden_layers=args.n_layers,\n                            dropout=args.dropout,\n                            self_loop=args.self_loop,\n                            skip_connect=args.skip_connect,\n                            layer_norm=args.layer_norm,\n                            input_dropout=args.input_dropout,\n                            hidden_dropout=args.hidden_dropout,\n                            feat_dropout=args.feat_dropout,\n                            entity_prediction=args.entity_prediction,\n                            relation_prediction=args.relation_prediction,\n                            use_cuda=use_cuda,\n                            gpu = args.gpu)\n    if use_cuda:\n        torch.cuda.set_device(args.gpu)\n        model.cuda()\n\n    # optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)\n    \n    if trainvalidtest_id == 1:  # normal test on validation set  Note that mode=test\n        if os.path.exists(test_state_file):\n            mrr, _, hits10 = test(model, args.test_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                        test_state_file, \"test\", split_mode=\"val\")      \n        else:\n            print('Cannot do testing because model does not exist: ', test_state_file)\n            mrr = 0\n            hits10 = 0\n    elif trainvalidtest_id == 2: # normal test on test set\n        if os.path.exists(test_state_file):\n            mrr, perf_per_rel, hits10 = test(model, args.test_history_len, train_list+valid_list, test_list, num_rels, num_nodes, use_cuda, \n                        test_state_file, \"test\", split_mode=\"test\")\n        else:\n            print('Cannot do testing because model does not exist: ', test_state_file)\n            mrr = 0\n            hits10 = 0\n    
elif trainvalidtest_id == -1:\n        print(\"-------------start pre training model with history length {}----------\\n\".format(args.start_history_len))\n        model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'\n        model_state_file = save_model_dir + model_name\n        print(\"Sanity Check: stat name : {}\".format(model_state_file))\n        print(\"Sanity Check: Is cuda available ? {}\".format(torch.cuda.is_available()))\n            \n        best_mrr = 0\n        best_epoch = 0\n        best_hits10= 0\n\n        ## training loop\n        for epoch in range(args.n_epochs):\n            model.train()\n            losses = []\n\n            idx = [_ for _ in range(len(train_list))]\n            random.shuffle(idx)\n            for train_sample_num in idx:\n                if train_sample_num == 0 or train_sample_num == 1: continue\n                if train_sample_num - args.start_history_len<0:\n                    input_list = train_list[0: train_sample_num]\n                    output = train_list[1:train_sample_num+1]\n                else:\n                    input_list = train_list[train_sample_num-args.start_history_len: train_sample_num]\n                    output = train_list[train_sample_num-args.start_history_len+1:train_sample_num+1]\n\n                # generate history graph\n                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]\n                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]\n\n                loss= model.get_loss(history_glist, output[-1], None, use_cuda)\n                losses.append(loss.item())\n\n                loss.backward()\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients\n                optimizer.step()\n                optimizer.zero_grad()\n\n            print(\"His {:04d}, Epoch {:04d} | 
Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} \"\n                .format(args.start_history_len, epoch, np.mean(losses), best_mrr, model_name))\n            \n            #! checking GPU usage\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            print (\"--------------GPU memory usage-----------\")\n            print (\"there are \", free_mem, \" free memory\")\n            print (\"there are \", total_mem, \" total available memory\")\n            print (\"there are \", total_mem - free_mem, \" used memory\")\n            print (\"--------------GPU memory usage-----------\")\n\n            # validation        \n            if epoch % args.evaluate_every == 0:\n                mrr, _, hits10 = test(model, args.start_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                                model_state_file, mode=\"train\", split_mode= \"val\")\n\n                if mrr< best_mrr:\n                    if epoch >= args.n_epochs or epoch - best_epoch > 5:\n                        break\n                else:\n                    best_mrr = mrr\n                    best_epoch = epoch\n                    best_hits10 = hits10\n                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)\n\n        \n    elif trainvalidtest_id == 0: #curriculum training\n        model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'\n        init_state_file = save_model_dir + model_name\n        init_checkpoint = torch.load(init_state_file, map_location=torch.device(args.gpu))\n        # use best stat checkpoint:\n        print(\"Load Previous Model name: {}. 
Using best epoch : {}\".format(init_state_file, init_checkpoint['epoch']))  \n        print(\"\\n\"+\"-\"*10+\"Load model with history length {}\".format(args.start_history_len)+\"-\"*10+\"\\n\")\n        model.load_state_dict(init_checkpoint['state_dict'])\n        test_history_len = args.start_history_len\n\n        mrr, _, hits10 = test(model, \n                    args.start_history_len,\n                    train_list,\n                    valid_list, \n                    num_rels, \n                    num_nodes, \n                    use_cuda, \n                    init_state_file,  \n                    mode=\"test\", split_mode= \"val\") \n        best_mrr_list = [mrr.item()]         \n        best_hits_list = [hits10.item()]                                          \n        # start knowledge distillation\n        ks_idx = 0\n        for history_len in range(args.start_history_len+1, args.train_history_len+1, 1):\n            # current model\n            print(\"best mrr list :\", best_mrr_list)\n            # lr = 0.1*args.lr - 0.002*args.lr*ks_idx\n            optimizer = torch.optim.Adam(model.parameters(), lr=0.1*args.lr, weight_decay=0.00001)\n            model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len}'\n            model_state_file = save_model_dir + model_name\n\n            print(\"Sanity Check: stat name : {}\".format(model_state_file))\n\n            # load model with the least history length\n            prev_model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len-1}'\n            prev_state_file = save_model_dir + prev_model_name\n            checkpoint = torch.load(prev_state_file, map_location=torch.device(args.gpu)) \n            model.load_state_dict(checkpoint['state_dict']) \n            print(\"\\n\"+\"-\"*10+\"start knowledge distillation for history length at \"+ str(history_len)+\"-\"*10+\"\\n\")\n \n            best_mrr = 0\n            best_hits10 = 0\n            best_epoch = 0\n            for 
epoch in range(args.n_epochs):\n                model.train()\n                losses = []\n\n                idx = [_ for _ in range(len(train_list))]\n                random.shuffle(idx)\n                for train_sample_num in idx:\n                    if train_sample_num == 0 or train_sample_num == 1: continue\n                    if train_sample_num - history_len<0:\n                        input_list = train_list[0: train_sample_num]\n                        output = train_list[1:train_sample_num+1]\n                    else:\n                        input_list = train_list[train_sample_num - history_len: train_sample_num]\n                        output = train_list[train_sample_num-history_len+1:train_sample_num+1]\n\n                    # generate history graph\n                    history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]\n                    output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]\n\n                    loss= model.get_loss(history_glist, output[-1], None, use_cuda)\n                    # print(loss)\n                    losses.append(loss.item())\n\n                    loss.backward()\n                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients\n                    optimizer.step()\n                    optimizer.zero_grad()\n\n                print(\"His {:04d}, Epoch {:04d} | Ave Loss: {:.4f} |Best MRR {:.4f} | Model {} \"\n                    .format(history_len, epoch, np.mean(losses), best_mrr, model_name))\n\n                #! 
checking GPU usage\n                free_mem, total_mem = torch.cuda.mem_get_info()\n                print (\"--------------GPU memory usage-----------\")\n                print (\"there are \", free_mem, \" free memory\")\n                print (\"there are \", total_mem, \" total available memory\")\n                print (\"there are \", total_mem - free_mem, \" used memory\")\n                print (\"--------------GPU memory usage-----------\")\n\n                # validation\n                if epoch % args.evaluate_every == 0:\n                    mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                                model_state_file, mode=\"train\", split_mode= \"val\")\n                    \n                    if mrr< best_mrr:\n                        if epoch >= args.n_epochs or epoch-best_epoch>2:\n                            break\n                    else:\n                        best_mrr = mrr\n                        best_epoch = epoch\n                        best_hits10 = hits10\n                        torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)  \n            mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                        model_state_file, mode=\"test\", split_mode= \"val\")\n            ks_idx += 1\n            if mrr.item() < max(best_mrr_list):\n                test_history_len = history_len-1\n                print(\"early stopping, best history length: \", test_history_len)\n                break\n            else:\n                best_mrr_list.append(mrr.item())\n                best_hits_list.append(hits10.item())\n        \n    return mrr, test_history_len, perf_per_rel, hits10\n\n\n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args_cen()\nargs.dataset = 'tkgl-polecat'\n\nSEED = 
args.seed  # set the random seed for consistency\nset_random_seed(SEED)\n\nDATA=args.dataset\nMODEL_NAME = 'CEN'\n\nprint(\"logging mrrs per relation: \", args.log_per_rel)\nprint(\"do test and valid? do only test no validation?: \", args.validtest, args.test_only)\n\n# load data\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\n\nnum_rels = dataset.num_rels\nnum_nodes = dataset.num_nodes \nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nrelations = dataset.edge_type\n\ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\nall_quads = np.stack((subjects, relations, objects, timestamps), axis=1)\n\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nval_timestamps_orig = list(set(timestamps_orig[dataset.val_mask])) # needed for getting the negative samples\nval_timestamps_orig.sort()\ntest_timestamps_orig = list(set(timestamps_orig[dataset.test_mask])) # needed for getting the negative samples\ntest_timestamps_orig.sort()\n\ntrain_list = split_by_time(train_data)\nvalid_list = split_by_time(val_data)\ntest_list = split_by_time(test_data)\n\n# evaluation metric\nMETRIC = dataset.eval_metric\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n#load the ns samples \ndataset.load_val_ns()\ndataset.load_test_ns()\n\nif args.grid_search:\n    print(\"TODO: implement hyperparameter grid search\")\n# single run\nelse:\n    \n    start_train = timeit.default_timer()\n    if args.validtest:\n        print('directly start testing')\n        if args.test_history_len_2 != args.test_history_len:\n            args.test_history_len = args.test_history_len_2 # hyperparameter value as given in original paper \n    else:\n        print('running pretrain and train')\n        # pretrain\n        mrr, _, _, hits10 = run_experiment(args, trainvalidtest_id=-1)\n   
     # train\n        mrr, args.test_history_len, _, hits10 = run_experiment(args, trainvalidtest_id=0) # overwrite test_history_len with \n        # the best history len (for valid mrr)       \n        \n    if args.test_only == False:\n        print(\"running test (on val and test dataset) with test_history_len of: \", args.test_history_len)\n        # test on val set\n        val_mrr, _, _, val_hits10 = run_experiment(args, trainvalidtest_id=1)\n    else:\n        val_mrr = 0\n        val_hits10 = 0\n\n    # test on test set\n    start_test = timeit.default_timer()\n    test_mrr, _, perf_per_rel, test_hits10 = run_experiment(args, trainvalidtest_id=2)\n\ntest_time = timeit.default_timer() - start_test\nall_time = timeit.default_timer() - start_train\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\nprint(f\"\\tTrain and Test: Elapsed Time (s): {all_time: .4f}\")\n\nprint(f\"\\tTest: {METRIC}: {test_mrr: .4f}\")\nprint(f\"\\tValid: {METRIC}: {val_mrr: .4f}\")\n\n# saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nsave_results({'model': MODEL_NAME,\n              'data': DATA,\n              'run': args.run_nr,\n              'seed': SEED,\n              'test_history_len': args.test_history_len,\n              f'val {METRIC}': float(val_mrr),\n              f'test {METRIC}': float(test_mrr),\n              'test_time': test_time,\n              'tot_train_val_time': all_time,\n              'test_hits10': float(test_hits10)\n              }, \n    results_filename)\n\nif args.log_per_rel:\n    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'\n    with open(results_filename, 'w') as json_file:\n        json.dump(perf_per_rel, 
json_file)\n\n\nsys.exit()"
  },
  {
    "path": "examples/linkproppred/tkgl-polecat/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\ntgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(tgb_modules_path)\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluate the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    hits_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t, pos_edge = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n            data['edge_type'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)\n        \n        for idx, neg_batch in enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            results = evaluator.eval(input_dict)\n            perf_list.append(results[metric])\n            hits_list.append(results['hits@10'])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, 
pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n    perf_hits = float(np.mean(hits_list))\n\n    return perf_metrics, perf_hits\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tkgl-polecat')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = args.data\nMODEL_NAME = 'EdgeBank'\n\n\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n       
 hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_val: .4f}\")\nval_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {val_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>>  <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: 
perf_metric_test,\n              'val_mrr': perf_metric_val,\n              'test_time': test_time,\n              'tot_train_val_time': test_time+val_time,\n              'hits10': perf_hits_test}, \n    results_filename)\n\n"
  },
  {
    "path": "examples/linkproppred/tkgl-polecat/example.py",
    "content": "from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\nDATA = \"tkgl-polecat\"\r\n\r\n# data loading\r\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\r\ntrain_mask = dataset.train_mask\r\nval_mask = dataset.val_mask\r\ntest_mask = dataset.test_mask\r\ndata = dataset.get_TemporalData()\r\nmetric = dataset.eval_metric\r\n\r\n\r\ntimestamp = data.t\r\nhead = data.src\r\ntail = data.dst\r\nedge_type = data.edge_type #relation\r\n\r\nprint (edge_type)\r\n\r\ntrain_data = data[train_mask]\r\nval_data = data[val_mask]\r\ntest_data = data[test_mask]"
  },
  {
    "path": "examples/linkproppred/tkgl-polecat/recurrencybaseline.py",
    "content": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)\n\n@inproceedings{gastinger2024baselines,\n  title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},\n  author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},\n  booktitle={33rd International Joint Conference on Artificial Intelligence (IJCAI 2024)},\n  year={2024},\n  organization={International Joint Conferences on Artificial Intelligence Organization}\n}\n\n\"\"\"\n\n## imports\nimport timeit\nimport argparse\nimport numpy as np\nfrom copy import copy\nfrom pathlib import Path\nimport ray\nimport sys\nimport os\nimport os.path as osp\nimport json\n#internal imports \nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom tgb.utils.utils import set_random_seed,  save_results \nfrom modules.tkg_utils import create_basis_dict, group_by, reformat_ts\n\ndef predict(num_processes,  data_c_rel, all_data_c_rel, alpha, lmbda_psi,\n            perf_list_all, hits_list_all, window, neg_sampler, split_mode):\n    \"\"\" create predictions for each relation on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list of hits for each test query\n    \"\"\"\n\n    first_ts = data_c_rel[0][3]\n    ## use this if you wanna use ray:\n    num_queries = len(data_c_rel) // num_processes\n    if num_queries < num_processes: # if we do not have enough queries for all the processes\n        num_processes_tmp = 
1\n        num_queries = len(data_c_rel)\n    else:\n        num_processes_tmp = num_processes  \n    if num_processes > 1:\n        object_references =[]                   \n        \n        for i in range(num_processes_tmp):\n            num_test_queries = len(data_c_rel) - (i + 1) * num_queries\n            if num_test_queries >= num_queries:\n                test_queries_idx =[i * num_queries, (i + 1) * num_queries]\n            else:\n                test_queries_idx = [i * num_queries, len(test_data)]\n\n            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]\n\n            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window, \n                                basis_dict, \n                                num_nodes, num_rels, lmbda_psi, \n                                alpha, evaluator,first_ts, neg_sampler, split_mode)\n            object_references.append(ob)\n\n        output = ray.get(object_references)\n\n        # updates the scores and logging dict for each process\n        for proc_loop in range(num_processes_tmp):\n            perf_list_all.extend(output[proc_loop][0])\n            hits_list_all.extend(output[proc_loop][1])\n\n    else:\n        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel, \n                            window, basis_dict, \n                            num_nodes, num_rels, lmbda_psi, \n                            alpha, evaluator, first_ts, neg_sampler, split_mode)                  \n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n    \n    return perf_list_all, hits_list_all\n\n## test\ndef test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):  \n    \"\"\" create predictions by loopoing through all relations on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list 
of hits for each test query\n    \"\"\"       \n    perf_list_all = []\n    hits_list_all =[]\n    \n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n\n    ## loop through relations and apply baselines\n    \n    for rel in all_relations:\n        start =  timeit.default_timer()\n        if rel in test_data_prel.keys():\n            lmbda_psi = best_config[str(rel)]['lmbda_psi'][0]\n            alpha = best_config[str(rel)]['alpha'][0]\n\n            # test data for this relation\n            test_data_c_rel = test_data_prel[rel]\n            timesteps_test = list(set(test_data_c_rel[:,3]))\n            timesteps_test.sort()\n            all_data_c_rel = all_data_prel[rel]         \n            perf_list_rel = []\n            hits_list_rel = []\n            perf_list_rel, hits_list_rel = predict(num_processes, test_data_c_rel,\n                                                all_data_c_rel, alpha, lmbda_psi,perf_list_rel, hits_list_rel, \n                                                window, neg_sampler, split_mode)\n            perf_list_all.extend(perf_list_rel)\n            hits_list_all.extend(hits_list_rel)\n        else:\n            perf_list_rel =[]\n         \n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n        \n\n        with open(csv_file, 'a') as f:\n            f.write(\"{},{}\\n\".format(rel, perf_list_rel))\n\n    return perf_list_all, hits_list_all\n\ndef read_dict_compute_mrr(split_mode='test'):\n    \"\"\" read the results per relation  from a precreated file and compute mrr\n    :return  mrr_per_rel: dictionary of mrrs for each relation\n    :return all_mrrs: list of mrrs for all relations\n    \"\"\"\n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n    # Initialize an empty dictionary to store the data\n    
results_per_rel_dict = {}\n    mrr_per_rel = {}\n    all_mrrs = []\n    # Open the file for reading\n    with open(csv_file, 'r') as f:\n        # Read each line in the file\n        for line in f:\n            # Split the line at the comma\n            parts = line.strip().split(',')\n            # Extract the key (the first part)\n            key = int(parts[0])\n            # Extract the values (the rest of the parts), remove square brackets\n            values = [float(value.strip('[]')) for value in parts[1:]]\n            # Add the key-value pair to the dictionary\n            if key in results_per_rel_dict.keys():\n                print(f\"Key {key} already exists in the dictionary!!! might have duplicate entries in results csv\")\n            results_per_rel_dict[key] = values\n            all_mrrs.extend(values)\n            mrr_per_rel[key] = np.mean(values)\n\n    if len(list(results_per_rel_dict.keys())) != num_rels:\n        print(\"we do not have entries for each rel in the results csv file. only num enties: \", len(list(results_per_rel_dict.keys())))\n\n    print(\"Split mode: \"+split_mode +\" Mean MRR: \", np.mean(all_mrrs))\n    print(\"mrr per relation: \", mrr_per_rel)\n\n\n\n## train\ndef train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):\n    \"\"\" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params\n    based on validation mrr\n    :return best_config: dictionary of best params for each relation\n    \"\"\"\n    best_config= {}\n    best_mrr = 0\n    for rel in rels: # loop through relations. 
for each relation, apply rules with selected params, compute valid mrr\n        start =  timeit.default_timer()\n        rel_key = int(rel)            \n\n        best_config[str(rel_key)] = {}\n        best_config[str(rel_key)]['not_trained'] = 'True'    \n        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi,0] #default\n        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))\n        best_config[str(rel_key)]['alpha'] = [default_alpha,0]  #default    \n        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))\n        \n        if rel in val_data_prel.keys():      \n            # valid data for this relation  \n            val_data_c_rel = val_data_prel[rel]\n            timesteps_valid = list(set(val_data_c_rel[:,3]))\n            timesteps_valid.sort()\n            trainval_data_c_rel = trainval_data_prel[rel]\n\n            ######  1) select lambda ###############        \n            lmbdas_psi = params_dict['lmbda_psi']        \n\n            alpha = 1\n            best_lmbda_psi = 0.1\n            best_mrr_psi = 0\n            lmbda_mrrs = []\n\n            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))\n            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))         \n            best_config[str(rel_key)]['not_trained'] = 'False'       \n            \n            for lmbda_psi in lmbdas_psi:   \n                perf_list_r = []\n                hits_list_r = []\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr = np.mean(perf_list_r)\n                # # is new mrr better than previous best? 
if yes: store lmbda\n                if mrr > best_mrr_psi:\n                    best_mrr_psi = float(mrr)\n                    best_lmbda_psi = lmbda_psi\n\n\n                lmbda_mrrs.append(float(mrr))\n            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]\n            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs\n            best_mrr = best_mrr_psi\n            ##### 2) select alpha ###############\n            best_config[str(rel_key)]['not_trained'] = 'False'    \n            alphas = params_dict['alpha'] \n            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0] # use the best lmbda psi\n\n            alpha_mrrs = []\n            # perf_list_all = []\n            best_mrr_alpha = 0\n            best_alpha=0.99\n            for alpha in alphas:\n                perf_list_r = []\n                hits_list_r = []\n\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr_alpha = np.mean(perf_list_r)\n\n                # is new mrr better than previous best? 
if yes: store alpha\n                if mrr_alpha > best_mrr_alpha:\n                    best_mrr_alpha = float(mrr_alpha)\n                    best_alpha = alpha\n                    best_mrr = best_mrr_alpha\n                alpha_mrrs.append(float(mrr_alpha))\n\n            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]\n            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs\n\n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n    return best_config\n\n\n\n## args\ndef get_args(): \n    \"\"\"parse all arguments for the script\"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", \"-d\", default=\"tkgl-polecat\", type=str) \n    parser.add_argument(\"--window\", \"-w\", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep\n    parser.add_argument(\"--num_processes\", \"-p\", default=1, type=int)\n    parser.add_argument(\"--lmbda\", \"-l\",  default=0.1, type=float) # fix lambda. used if trainflag == false\n    parser.add_argument(\"--alpha\", \"-alpha\",  default=0.99, type=float) # fix alpha. 
used if trainflag == false\n    parser.add_argument(\"--train_flag\", \"-tr\",  default='False') # do we need training, ie selection of lambda and alpha\n    parser.add_argument(\"--load_flag\", \"-lo\",  default='False') # if train_flag set to True: do you want to load best_config?\n    parser.add_argument(\"--save_config\", \"-c\",  default='True') # do we need to save the selection of lambda and alpha in config file?\n    parser.add_argument('--seed', type=int, help='Random seed', default=1) # not needed\n    parsed = vars(parser.parse_args())\n    return parsed\n\nstart_o =  timeit.default_timer()\n\nparsed = get_args()\nif parsed['num_processes']>1:\n    ray.init(num_cpus=parsed[\"num_processes\"], num_gpus=0)\nMODEL_NAME = 'RecurrencyBaseline'\nSEED = parsed['seed']  # set the random seed for consistency\nset_random_seed(SEED)\nperrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'\nif not osp.exists(perrel_results_path):\n    os.mkdir(perrel_results_path)\n    print('INFO: Create directory {}'.format(perrel_results_path))\nPath(perrel_results_path).mkdir(parents=True, exist_ok=True)\n\n## load dataset and prepare it accordingly\nname = parsed[\"dataset\"]\ndataset = LinkPropPredDataset(name=name, root=\"datasets\", preprocess=True)\nDATA = name\n\nrelations = dataset.edge_type\nnum_rels = dataset.num_rels\nrels = np.arange(0,num_rels)\nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nnum_nodes = dataset.num_nodes \ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n\nprint(\"split train valid test data\")\nall_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=name)\nneg_sampler = 
dataset.negative_sampler\n\ntrain_val_data = np.concatenate([train_data, val_data])\nall_data = np.concatenate([train_data, val_data, test_data])\n\n# create dicts with key: relation id, values: triples for that relation id\nprint(\"grouping data by relation\")\ntest_data_prel = group_by(test_data, 1)\nall_data_prel = group_by(all_data, 1)\nval_data_prel = group_by(val_data, 1)\ntrainval_data_prel = group_by(train_val_data, 1)\n\n#load the ns samples \n# if parsed['train_flag']:\nprint(\"loading negative samples\")\ndataset.load_val_ns()\ndataset.load_test_ns()\n\n# parameter options\nif parsed['train_flag'] == 'True':\n    params_dict = {}\n    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001] \n    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]\n    default_lmbda_psi = params_dict['lmbda_psi'][-1]\n    default_alpha = params_dict['alpha'][-2]\n\n## load rules\nprint(\"creating rules\")\nbasis_dict = create_basis_dict(train_val_data)\nprint(\"done with creating rules\")\n## init\n# rb_predictor = RecurrencyBaselinePredictor(rels)\n## train to find best lambda and alpha\nstart_train =  timeit.default_timer()\nif parsed['train_flag'] ==  'True':\n    if parsed['load_flag'] == 'True':\n        with open('best_config.json', 'r') as infile:\n            best_config = json.load(infile)\n    else:\n        print('start training')\n        best_config = train(params_dict,  rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'], \n            parsed['window'])\n        if parsed['save_config'] == 'True':\n            import json\n            with open('best_config.json', 'w') as outfile:\n                json.dump(best_config, outfile)\n\nelse: # use preset lmbda and alpha; same for all relations\n    best_config = {} \n    for rel in rels:\n        best_config[str(rel)] = {}\n        best_config[str(rel)]['lmbda_psi'] = 
[parsed['lmbda']]\n        best_config[str(rel)]['alpha'] = [parsed['alpha']]\n    \nend_train =  timeit.default_timer()\n\n# compute validation mrr\nprint(\"Computing validation MRR\")\nperf_list_all_val, hits_list_all_val = test(best_config,rels, val_data_prel, \n                                                trainval_data_prel, neg_sampler, parsed['num_processes'], \n                                            parsed['window'], split_mode='val')\nval_mrr = float(np.mean(perf_list_all_val))\n\n# compute test mrr\nprint(\"Computing test MRR\")\nstart_test =  timeit.default_timer()\nperf_list_all, hits_list_all = test(best_config,rels, test_data_prel, \n                                                 all_data_prel, neg_sampler, parsed['num_processes'], \n                                                parsed['window'])\n\nend_o =  timeit.default_timer()\ntrain_time_o = round(end_train- start_train, 6)  \ntest_time_o = round(end_o- start_test, 6)  \ntotal_time_o = round(end_o- start_o, 6)  \nprint(\"Running Training to find best configs finished in {} seconds.\".format(train_time_o))\nprint(\"Running testing with best configs finished in {} seconds.\".format(test_time_o))\nprint(\"Running all steps finished in {} seconds.\".format(total_time_o))\n\nprint(f\"The test MRR is {np.mean(perf_list_all)}\")\nprint(f\"The valid MRR is {val_mrr}\")\nprint(f\"The Hits@10 is {np.mean(hits_list_all)}\")\nprint(f\"We have {len(perf_list_all)} predictions\")\nprint(f\"The test set has len {len(test_data)} \")\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'\nmetric = dataset.eval_metric\nsave_results({'model': MODEL_NAME,\n              'train_flag': parsed['train_flag'],\n     
         'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: float(np.mean(perf_list_all)),\n              'val_mrr': val_mrr,\n              'hits10': float(np.mean(hits_list_all)),\n              'test_time': test_time_o,\n              'tot_train_val_time': total_time_o\n              }, \n    results_filename)\n\nif parsed['num_processes']>1:\n    ray.shutdown()\n\n\n    \n"
  },
  {
    "path": "examples/linkproppred/tkgl-polecat/regcn.py",
    "content": "\"\"\"\nTemporal Knowledge Graph Reasoning Based on Evolutional Representation Learning\nReference:\n- https://github.com/Lee-zix/RE-GCN\nZixuan Li, Xiaolong Jin, Wei Li, Saiping Guan, Jiafeng Guo, Huawei Shen, Yuanzhuo Wang and Xueqi Cheng. Temporal \nKnowledge Graph Reasoning Based on Evolutional Representation Learning. SIGIR 2021.\n\"\"\"\nimport sys\nimport timeit\nimport os\nimport sys\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\nimport torch\nimport random\nfrom tqdm import tqdm\n# internal imports\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.rrgcn import RecurrentRGCNREGCN\nfrom tgb.utils.utils import set_random_seed, split_by_time, save_results\nfrom modules.tkg_utils import get_args_regcn, reformat_ts\nfrom modules.tkg_utils_dgl import build_sub_graph\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nimport json\n\ndef test(model, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, static_graph, mode, split_mode):\n    \"\"\"\n    Test the model on either test or validation set\n    :param model: model used to test\n    :param history_list:    all input history snap shot list, not include output label train list or valid list\n    :param test_list:   test triple snap shot list\n    :param num_rels:    number of relations\n    :param num_nodes:   number of nodes\n    :param use_cuda:\n    :param model_name:\n    :param mode:\n    :param split_mode: 'test' or 'val' to state which negative samples to load\n    :return mrr\n    \"\"\"\n    print(\"Testing for mode: \", split_mode)\n    if split_mode == 'test':\n        timesteps_to_eval = test_timestamps_orig\n    else:\n        timesteps_to_eval = val_timestamps_orig\n\n    idx = 0\n\n    if mode == \"test\":\n        # test mode: load parameter form file \n        if use_cuda:\n            checkpoint = 
torch.load(model_name, map_location=torch.device(args.gpu))\n        else:\n            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))        \n        # use best stat checkpoint:\n        print(\"Load Model name: {}. Using best epoch : {}\".format(model_name, checkpoint['epoch']))  \n        print(\"\\n\"+\"-\"*10+\"start testing\"+\"-\"*10+\"\\n\")\n        model.load_state_dict(checkpoint['state_dict'])\n\n    model.eval()\n    input_list = [snap for snap in history_list[-args.test_history_len:]] \n    perf_list_all = []\n    hits_list_all = []\n    perf_per_rel = {}\n    for rel in range(num_rels):\n        perf_per_rel[rel] = []\n\n    for time_idx, test_snap in enumerate(tqdm(test_list)):\n        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]    \n        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)\n        timesteps_batch =timesteps_to_eval[time_idx]*np.ones(len(test_triples_input[:,0]))\n\n        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:,0], test_triples_input[:,2], \n                                timesteps_batch, edge_type=test_triples_input[:,1], split_mode=split_mode)\n        pos_samples_batch = test_triples_input[:,2]\n\n        _, perf_list, hits_list = model.predict(history_glist, num_rels, static_graph, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch, \n                                    evaluator, METRIC)  \n\n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n        if split_mode == \"test\":\n            if args.log_per_rel:\n                for score, rel in zip(perf_list, test_triples_input[:,1].tolist()):\n                    perf_per_rel[rel].append(score)\n        # reconstruct history graph list\n        input_list.pop(0)\n        input_list.append(test_snap)\n        idx += 1\n\n    if split_mode == \"test\":\n        if 
args.log_per_rel:   \n            for rel in range(num_rels):\n                if len(perf_per_rel[rel]) > 0:\n                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))\n                else:\n                    perf_per_rel.pop(rel)\n    mrr = np.mean(perf_list_all) \n    hits10 = np.mean(hits_list_all) \n    return mrr, perf_per_rel, hits10\n\n\n\ndef run_experiment(args, n_hidden=None, n_layers=None, dropout=None, n_bases=None):\n    \"\"\"\n    Run the experiment with the given configuration\n    :param args: arguments\n    :param n_hidden: hidden dimension\n    :param n_layers: number of layers\n    :param dropout: dropout rate\n    :param n_bases: number of bases\n    :return: mrr, perf_per_rel  (mean reciprocal rank, performance per relation)\n    \"\"\"\n    # load configuration for grid search the best configuration\n    if n_hidden:\n        args.n_hidden = n_hidden\n    if n_layers:\n        args.n_layers = n_layers\n    if dropout:\n        args.dropout = dropout\n    if n_bases:\n        args.n_bases = n_bases\n    mrr = 0\n    hits10=0\n    # 2) set save model path\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    if not osp.exists(save_model_dir):\n        os.mkdir(save_model_dir)\n        print('INFO: Create directory {}'.format(save_model_dir))\n    model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}'\n    model_state_file = save_model_dir+model_name\n    perf_per_rel = {}\n    use_cuda = args.gpu >= 0 and torch.cuda.is_available()\n\n    num_static_rels, num_words, static_triples, static_graph = 0, 0, [], None\n\n    # create stat\n    model = RecurrentRGCNREGCN(args.decoder,\n                          args.encoder,\n                        num_nodes,\n                        int(num_rels/2),\n                        num_static_rels, # DIFFERENT\n                        num_words, # DIFFERENT\n                        args.n_hidden,\n                        args.opn,\n                        
sequence_len=args.train_history_len,\n                        num_bases=args.n_bases,\n                        num_basis=args.n_basis,\n                        num_hidden_layers=args.n_layers,\n                        dropout=args.dropout,\n                        self_loop=args.self_loop,\n                        skip_connect=args.skip_connect,\n                        layer_norm=args.layer_norm,\n                        input_dropout=args.input_dropout,\n                        hidden_dropout=args.hidden_dropout,\n                        feat_dropout=args.feat_dropout,\n                        aggregation=args.aggregation, # DIFFERENT\n                        weight=args.weight, # DIFFERENT\n                        discount=args.discount, # DIFFERENT\n                        angle=args.angle, # DIFFERENT\n                        use_static=args.add_static_graph, # DIFFERENT\n                        entity_prediction=args.entity_prediction, \n                        relation_prediction=args.relation_prediction,\n                        use_cuda=use_cuda,\n                        gpu = args.gpu,\n                        analysis=args.run_analysis) # DIFFERENT\n\n    if use_cuda:\n        torch.cuda.set_device(args.gpu)\n        model.cuda()\n\n    # optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)\n\n    if args.test and os.path.exists(model_state_file):\n        mrr, perf_per_rel, hits10 = test(model, \n                    train_list+valid_list, \n                    test_list, \n                    num_rels, \n                    num_nodes, \n                    use_cuda, \n                    model_state_file, \n                    static_graph, \n                    \"test\", \n                    \"test\")\n        return mrr, perf_per_rel, hits10\n    elif args.test and not os.path.exists(model_state_file):\n        print(\"--------------{} not exist, Change mode to train and generate stat for 
testing----------------\\n\".format(model_state_file))\n        return 0, 0\n    else:\n        print(\"----------------------------------------start training----------------------------------------\\n\")\n        best_mrr = 0\n        best_hits = 0\n        for epoch in range(args.n_epochs):\n\n            model.train()\n            losses = []\n            idx = [_ for _ in range(len(train_list))]\n            random.shuffle(idx)\n\n            for train_sample_num in tqdm(idx):\n                if train_sample_num == 0: continue\n                output = train_list[train_sample_num:train_sample_num+1]\n                if train_sample_num - args.train_history_len<0:\n                    input_list = train_list[0: train_sample_num]\n                else:\n                    input_list = train_list[train_sample_num - args.train_history_len:\n                                        train_sample_num]\n\n                # generate history graph\n                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]\n                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]\n                loss_e, loss_r, loss_static = model.get_loss(history_glist, output[0], static_graph, use_cuda)\n                loss = args.task_weight*loss_e + (1-args.task_weight)*loss_r + loss_static\n\n                losses.append(loss.item())\n\n                loss.backward()\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients\n                optimizer.step()\n                optimizer.zero_grad()\n\n            print(\"Epoch {:04d} | Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} \"\n                  .format(epoch, np.mean(losses),  best_mrr, model_name))\n            \n            #! 
checking GPU usage\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            print (\"--------------GPU memory usage-----------\")\n            print (\"there are \", free_mem, \" free memory\")\n            print (\"there are \", total_mem, \" total available memory\")\n            print (\"there are \", total_mem - free_mem, \" used memory\")\n            print (\"--------------GPU memory usage-----------\")\n\n            # validation\n            if epoch and epoch % args.evaluate_every == 0:\n                mrr,perf_per_rel, hits10 = test(model, train_list, \n                            valid_list, \n                            num_rels, \n                            num_nodes, \n                            use_cuda, \n                            model_state_file, \n                            static_graph, \n                            mode=\"train\", split_mode='val')\n            \n                if mrr < best_mrr:\n                    if epoch >= args.n_epochs:\n                        break\n                else:\n                    best_mrr = mrr\n                    best_hits = hits10\n                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)\n\n        return best_mrr, perf_per_rel, best_hits\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args_regcn()\nargs.dataset = 'tkgl-polecat'\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\n\nDATA=args.dataset\nMODEL_NAME = 'REGCN'\n\nprint(\"logging mrrs per relation: \", args.log_per_rel)\n\n# load data\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\n\nnum_rels = dataset.num_rels\nnum_nodes = dataset.num_nodes \nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nrelations = dataset.edge_type\n\ntimestamps_orig = 
dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\nall_quads = np.stack((subjects, relations, objects, timestamps), axis=1)\n\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nval_timestamps_orig = list(set(timestamps_orig[dataset.val_mask])) # needed for getting the negative samples\nval_timestamps_orig.sort()\ntest_timestamps_orig = list(set(timestamps_orig[dataset.test_mask])) # needed for getting the negative samples\ntest_timestamps_orig.sort()\n\ntrain_list = split_by_time(train_data)\nvalid_list = split_by_time(val_data)\ntest_list = split_by_time(test_data)\n\n# evaluation metric\nMETRIC = dataset.eval_metric\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n#load the ns samples \ndataset.load_val_ns()\ndataset.load_test_ns()\n\n## run training and testing\nval_mrr, test_mrr = 0, 0\ntest_hits10 = 0\nif args.grid_search:\n    print(\"hyperparameter grid search not implemented. 
Exiting.\")\n# single run\nelse:\n    start_train = timeit.default_timer()\n    if args.test == False: #if they are true: directly test on a previously trained and stored model\n        print('start training')\n        val_mrr, perf_per_rel, val_hits10 = run_experiment(args) # do training\n    start_test = timeit.default_timer()\n    args.test = True\n    print('start testing')\n    test_mrr, perf_per_rel, test_hits10 = run_experiment(args) # do testing\n\n\ntest_time = timeit.default_timer() - start_test\nall_time = timeit.default_timer() - start_train\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\nprint(f\"\\Train and Test: Elapsed Time (s): {all_time: .4f}\")\n\nprint(f\"\\tTest: {METRIC}: {test_mrr: .4f}\")\nprint(f\"\\tValid: {METRIC}: {val_mrr: .4f}\")\n\n# saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nsave_results({'model': MODEL_NAME,\n              'data': DATA,\n              'run': args.run_nr,\n              'seed': SEED,\n              f'val {METRIC}': float(val_mrr),\n              f'test {METRIC}': float(test_mrr),\n              'test_time': test_time,\n              'tot_train_val_time': all_time,\n              'test_hits10': float(test_hits10)\n              }, \n    results_filename)\n\nif args.log_per_rel:\n    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'\n    with open(results_filename, 'w') as json_file:\n        json.dump(perf_per_rel, json_file)\n\nsys.exit()"
  },
  {
    "path": "examples/linkproppred/tkgl-polecat/timetraveler.py",
    "content": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HUST/TITer\nHaohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021\n\"\"\"\nimport sys\nimport timeit\n\nimport torch\nfrom torch.utils.data import Dataset,DataLoader\nimport logging\n\nimport numpy as np\nimport pickle\nfrom tqdm import tqdm\nimport os.path as osp\nfrom pathlib import Path\nimport os\n\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.timetraveler_agent import Agent\nfrom modules.timetraveler_environment import Env\nfrom modules.timetraveler_dirichlet import Dirichlet, MLE_Dirchlet\nfrom modules.timetraveler_episode import Episode\nfrom modules.timetraveler_policygradient import PG\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.timetraveler_trainertester import Trainer, Tester, getRelEntCooccurrence\nfrom tgb.utils.utils import set_random_seed,save_results \nfrom modules.tkg_utils import  get_args_timetraveler, reformat_ts, get_model_config_timetraveler\n\nclass QuadruplesDataset(Dataset):\n    \"\"\" this is an internal way how Timetraveler represents the data\n    \"\"\"\n    def __init__(self, examples):\n        \"\"\"\n        examples: a list of quadruples.\n        num_r: number of relations\n        \"\"\"\n        self.quadruples = examples.copy()\n\n\n    def __len__(self):\n        return len(self.quadruples)\n\n    def __getitem__(self, item):\n        return self.quadruples[item][0], \\\n               self.quadruples[item][1], \\\n               self.quadruples[item][2], \\\n               self.quadruples[item][3], \\\n               self.quadruples[item][4]\n    \ndef set_logger(save_path):\n    \"\"\"Write logs to checkpoint and console\"\"\"\n    if 
args.do_train:\n        log_file = os.path.join(save_path, 'train.log')\n    else:\n        log_file = os.path.join(save_path, 'test.log')\n\n    logging.basicConfig(\n        format='%(asctime)s %(levelname)-8s %(message)s',\n        level=logging.INFO,\n        datefmt='%Y-%m-%d %H:%M:%S',\n        filename=log_file,\n        filemode='w'\n    )\n    console = logging.StreamHandler()\n    console.setLevel(logging.INFO)\n    formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')\n    console.setFormatter(formatter)\n    logging.getLogger('').addHandler(console)\n\ndef preprocess_data(args, config, timestamps, save_path, all_quads):\n    \"\"\"\n    Preprocess the data and save the state-action space (pickle dump)\n    \"\"\"\n    # parser = argparse.ArgumentParser(description='Data preprocess', usage='preprocess_data.py [<args>] [-h | --help]')\n    # parser.add_argument('--data_dir', default='data/ICEWS14', type=str, help='Path to data.')\n\n    env = Env(all_quads, config)\n    state_actions_space = {}\n\n    with tqdm(total=len(all_quads)) as bar:\n        for (head, rel, tail, t, _) in all_quads:\n            if (head, t, True) not in state_actions_space.keys():\n                state_actions_space[(head, t, True)] = env.get_state_actions_space_complete(head, t, True, args.store_actions_num)\n                state_actions_space[(head, t, False)] = env.get_state_actions_space_complete(head, t, False, args.store_actions_num)\n            if (tail, t, True) not in state_actions_space.keys():\n                state_actions_space[(tail, t, True)] = env.get_state_actions_space_complete(tail, t, True, args.store_actions_num)\n                state_actions_space[(tail, t, False)] = env.get_state_actions_space_complete(tail, t, False, args.store_actions_num)\n            bar.update(1)\n    pickle.dump(state_actions_space, open(os.path.join(save_path,  args.state_actions_path), 'wb'))\n\ndef log_metrics(mode, step, metrics):\n    \"\"\"Print the 
evaluation logs\"\"\"\n    for metric in metrics:\n        logging.info('%s %s at epoch %d: %f' % (mode, metric, step, metrics[metric]))\n\ndef main(args):\n    \"\"\"\n    Main function to train and test the TimeTraveler model\"\"\"\n\n    start_overall = timeit.default_timer()\n    #######################Set Logger#################################\n    \n    save_path = f'{os.path.dirname(os.path.abspath(__file__))}/saved_models/'\n    if not os.path.exists(save_path):\n        os.makedirs(save_path)\n\n    if args.cuda and torch.cuda.is_available():\n        args.cuda = True\n    else:\n        args.cuda = False\n    set_logger(save_path)\n\n    #######################Create DataLoader#################################\n    # set hyperparameters\n    args.dataset = 'tkgl-yago'\n\n    SEED = args.seed  # set the random seed for consistency\n    set_random_seed(SEED)\n\n    DATA=args.dataset\n    MODEL_NAME = 'TIMETRAVELER'\n\n    # load data\n    dataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\n\n    num_rels = dataset.num_rels\n    num_nodes = dataset.num_nodes \n    subjects = dataset.full_data[\"sources\"]\n    objects= dataset.full_data[\"destinations\"]\n    relations = dataset.edge_type\n\n    timestamps_orig = dataset.full_data[\"timestamps\"]\n    timestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n    all_quads = np.stack((subjects, relations, objects, timestamps,timestamps_orig), axis=1)\n\n    train_data = all_quads[dataset.train_mask]\n    train_entities = np.unique(np.concatenate((train_data[:, 0], train_data[:, 2])))\n    RelEntCooccurrence = getRelEntCooccurrence(train_data, num_rels)\n    train_data =QuadruplesDataset(train_data)\n    val_data = QuadruplesDataset(all_quads[dataset.val_mask])\n    test_data = QuadruplesDataset(all_quads[dataset.test_mask])\n\n    METRIC = dataset.eval_metric\n    evaluator = Evaluator(name=DATA)\n    neg_sampler = dataset.negative_sampler\n    #load the ns samples \n    
dataset.load_val_ns()\n    dataset.load_test_ns()\n\n    train_dataloader = DataLoader(\n        train_data,\n        batch_size=args.batch_size,\n        shuffle=True,\n        num_workers=args.num_workers,\n    )\n\n    valid_dataloader = DataLoader(\n        val_data,\n        batch_size=args.test_batch_size,\n        shuffle=False,\n        num_workers=args.num_workers,\n    )\n\n    test_dataloader = DataLoader(\n        test_data,\n        batch_size=args.test_batch_size,\n        shuffle=False,\n        num_workers=args.num_workers,\n    )\n\n    ######################Creat the agent and the environment###########################\n    config = get_model_config_timetraveler(args, num_nodes, num_rels)\n    logging.info(config)\n    logging.info(args)\n\n    # creat the agent\n    agent = Agent(config)\n\n\n    # creat the environment\n    state_actions_path = os.path.join(save_path, args.state_actions_path)\n\n\n    ######################preprocessing###########################\n    if not os.path.exists(state_actions_path):\n        if args.preprocess:\n            print(\"preprocessing data...\")\n            preprocess_data(args, config, timestamps, save_path, list(all_quads))\n            state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))\n        else:\n            state_action_space = None\n    else:\n        print(\"load preprocessed data...\")\n        state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))\n\n\n    env = Env(list(all_quads), config, state_action_space)\n    # Create episode controller\n    episode = Episode(env, agent, config)\n    if args.cuda:\n        episode = episode.cuda()\n    pg = PG(config)  # Policy Gradient\n    optimizer = torch.optim.Adam(episode.parameters(), lr=args.lr, weight_decay=0.00001)\n\n    ######################Reward Shaping: MLE DIRICHLET alphas###########################\n    if args.reward_shaping: \n        try:\n            
print(\"load alphas from pickle file\")\n            alphas = pickle.load(open(os.path.join(save_path, args.alphas_pkl), 'rb'))\n        except:\n            print('running MLE dirichlet now')\n            mle_d = MLE_Dirchlet(all_quads, num_rels, args.k, args.time_span,\n                         args.tol, args.method, args.maxiter)\n            pickle.dump(mle_d.alphas, open(os.path.join(save_path, args.alphas_pkl), 'wb'))\n\n            print('dumped alphas')\n            alphas = mle_d.alphas\n        distributions = Dirichlet(alphas, args.k)\n    else:\n        distributions = None\n\n    ######################Training and Testing###########################\n\n    trainer = Trainer(episode, pg, optimizer, args, distributions)\n    tester = Tester(episode, args, train_entities, RelEntCooccurrence, dataset.metric)\n    test_metrics ={}\n    val_metrics = {}\n    test_metrics[METRIC] = None\n    val_metrics[METRIC] = None\n\n    if args.do_train:\n        start_train =timeit.default_timer()\n        logging.info('Start Training......')\n        for i in range(args.max_epochs):\n            loss, reward = trainer.train_epoch(train_dataloader, len(train_data))\n            logging.info('Epoch {}/{} Loss: {}, reward: {}'.format(i, args.max_epochs, loss, reward))\n\n            #! 
checking GPU usage\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            print (\"--------------GPU memory usage-----------\")\n            print (\"there are \", free_mem, \" free memory\")\n            print (\"there are \", total_mem, \" total available memory\")\n            print (\"there are \", total_mem - free_mem, \" used memory\")\n            print (\"--------------GPU memory usage-----------\")\n            \n            if i % args.save_epoch == 0 and i != 0:\n                trainer.save_model(save_path, 'checkpoint_{}.pth'.format(i))\n                logging.info('Save Model in {}'.format(save_path))\n\n            if i % args.valid_epoch == 0 and i != 0:\n                logging.info('Start Val......')\n                val_metrics = tester.test(valid_dataloader,\n                                      len(val_data), num_nodes, neg_sampler, evaluator, split_mode='val')\n                for mode in val_metrics.keys():\n                    logging.info('{} at epoch {}: {}'.format(mode, i, val_metrics[mode]))\n\n        trainer.save_model(save_path)\n        logging.info('Save Model in {}'.format(save_path))\n    else:\n          # # Load the model parameters\n        if os.path.isfile(save_path):\n            params = torch.load(save_path)\n            episode.load_state_dict(params['model_state_dict'])\n            optimizer.load_state_dict(params['optimizer_state_dict'])\n            logging.info('Load pretrain model: {}'.format(save_path))\n    if args.do_test:\n        logging.info('Start Testing......')\n        start_test = timeit.default_timer()\n        test_metrics = tester.test(test_dataloader,\n                              len(test_data), num_nodes, neg_sampler, evaluator, split_mode='test')\n        for mode in test_metrics.keys():\n            logging.info('Test {} : {}'.format(mode, test_metrics[mode]))\n\n        # saving the results...\n        results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\n    
    if not osp.exists(results_path):\n            os.mkdir(results_path)\n            print('INFO: Create directory {}'.format(results_path))\n        Path(results_path).mkdir(parents=True, exist_ok=True)\n        results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n        test_time = timeit.default_timer() - start_test\n        all_time = timeit.default_timer() - start_train \n        all_time_preprocess = timeit.default_timer() - start_overall \n\n        save_results({'model': MODEL_NAME,\n                    'data': DATA,\n                    'seed': SEED,\n                    f'val {METRIC}': float(val_metrics[METRIC]),\n                    f'test {METRIC}': float(test_metrics[METRIC]),\n                    'test_time': test_time,\n                    'tot_train_val_time': all_time,\n                    'tot_preprocess_train_val_time': all_time_preprocess\n                    }, \n            results_filename)     \n\nif __name__ == '__main__':\n    args = get_args_timetraveler()\n    main(args)"
  },
  {
    "path": "examples/linkproppred/tkgl-polecat/tkgl-polecat_example.py",
    "content": "import sys\r\nsys.path.insert(0,'/../../../')\r\nimport numpy as np\r\nimport timeit\r\nfrom tqdm import tqdm\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom tgb.linkproppred.evaluate import Evaluator\r\n\r\nDATA = \"tkgl-polecat\"\r\n\r\n# data loading\r\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\r\ntrain_mask = dataset.train_mask\r\nval_mask = dataset.val_mask\r\ntest_mask = dataset.test_mask\r\ndata = dataset.get_TemporalData()\r\nmetric = dataset.eval_metric\r\n\r\nprint (\"there are {} nodes and {} edges\".format(dataset.num_nodes, dataset.num_edges))\r\nprint (\"there are {} relation types\".format(dataset.num_rels))\r\n\r\n\r\ntimestamp = data.t\r\nhead = data.src\r\ntail = data.dst\r\nedge_type = data.edge_type #relation\r\nneg_sampler = dataset.negative_sampler\r\n\r\ntrain_data = data[train_mask]\r\nval_data = data[val_mask]\r\ntest_data = data[test_mask]\r\n\r\n\r\nmetric = dataset.eval_metric\r\nevaluator = Evaluator(name=DATA)\r\nneg_sampler = dataset.negative_sampler\r\n\r\nBATCH_SIZE = 200\r\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\r\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\r\n\r\nstart_time = timeit.default_timer()\r\n#load the ns samples first\r\ndataset.load_val_ns()\r\nfor batch in tqdm(val_loader):\r\n    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type\r\n    neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='val')\r\nprint (\"loading ns samples from validation\", timeit.default_timer() - start_time)\r\n# for i, (src, dst, t, rel) in enumerate(zip(val_data.src, val_data.dst, val_data.t, val_data.edge_type)):\r\n#     #must use np array to query\r\n#     neg_batch_list = neg_sampler.query_batch(np.array([src]), 
np.array([dst]), np.array([t]), edge_type=np.array([rel]), split_mode='val')\r\n\r\nstart_time = timeit.default_timer()\r\ndataset.load_test_ns()\r\nfor batch in test_loader:\r\n    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type\r\n    neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='test')\r\nprint (\"loading ns samples from test\", timeit.default_timer() - start_time)\r\n# for i, (src, dst, t, rel) in enumerate(zip(test_data.src, test_data.dst, test_data.t, test_data.edge_type)):\r\n#     #must use np array to query\r\n#     neg_batch_list = neg_sampler.query_batch(np.array([src]), np.array([dst]), np.array([t]), edge_type=np.array([rel]), split_mode='test')\r\nprint (\"retrieved all negative samples\")\r\n\r\n\r\n# #* load numpy arrays instead\r\n# from tgb.linkproppred.dataset import LinkPropPredDataset\r\n\r\n# # data loading\r\n# dataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\r\n# data = dataset.full_data  \r\n# metric = dataset.eval_metric\r\n# sources = dataset.full_data['sources']\r\n# print (sources.dtype)\r\n\r\n"
  },
  {
    "path": "examples/linkproppred/tkgl-polecat/tlogic.py",
    "content": "\"\"\"\nhttps://github.com/liu-yushan/TLogic/tree/main/mycode\nTLogic: Temporal Logical Rules for Explainable Link Forecasting on Temporal Knowledge Graphs.\nYushan Liu, Yunpu Ma, Marcel Hildebrandt, Mitchell Joblin, Volker Tresp\n\"\"\"\n\n# imports\nimport sys\nimport os\nimport os.path as osp\nfrom pathlib import Path\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nimport timeit\nimport argparse\nimport numpy as np\nimport json\nfrom joblib import Parallel, delayed\nimport itertools\n\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom modules.tlogic_learn_modules import Temporal_Walk, Rule_Learner, store_edges\nimport modules.tlogic_apply_modules as ra\nfrom tgb.utils.utils import set_random_seed,  save_results\nfrom modules.tkg_utils import reformat_ts, get_inv_relation_id, create_scores_array\n\ndef learn_rules(i, num_relations):\n    \"\"\"\n    Learn rules (multiprocessing possible).\n\n    Parameters:\n        i (int): process number\n        num_relations (int): minimum number of relations for each process\n\n    Returns:\n        rl.rules_dict (dict): rules dictionary\n    \"\"\"\n\n    # if seed:\n    #     np.random.seed(seed)\n\n    num_rest_relations = len(all_relations) - (i + 1) * num_relations\n    if num_rest_relations >= num_relations:\n        relations_idx = range(i * num_relations, (i + 1) * num_relations)\n    else:\n        relations_idx = range(i * num_relations, len(all_relations))\n\n    num_rules = [0]\n    for k in relations_idx:\n        rel = all_relations[k]\n        for length in rule_lengths:\n            it_start =  timeit.default_timer()\n            for _ in range(num_walks):\n                walk_successful, walk = temporal_walk.sample_walk(length + 1, rel)\n                if walk_successful:\n                    rl.create_rule(walk)\n            it_end =  
timeit.default_timer()\n            it_time = round(it_end - it_start, 6)\n            num_rules.append(sum([len(v) for k, v in rl.rules_dict.items()]) // 2)\n            num_new_rules = num_rules[-1] - num_rules[-2]\n            print(\n                \"Process {0}: relation {1}/{2}, length {3}: {4} sec, {5} rules\".format(\n                    i,\n                    k - relations_idx[0] + 1,\n                    len(relations_idx),\n                    length,\n                    it_time,\n                    num_new_rules,\n                )\n            )\n\n    return rl.rules_dict\n\ndef apply_rules(i, num_queries, rules_dict, neg_sampler, data, window, learn_edges, all_quads, args, split_mode, \n                log_per_rel=False, num_rels=0):\n    \"\"\"\n    Apply rules (multiprocessing possible).\n\n    Parameters:\n        i (int): process number\n        num_queries (int): minimum number of queries for each process\n\n    Returns:\n        hits_list (list): hits list (hits@10 per sample)\n        perf_list (list): performance list (mrr per sample)\n    \"\"\"\n    perf_per_rel = {}\n    for rel in range(num_rels):\n            perf_per_rel[rel] = []\n    print(\"Start process\", i, \"...\")\n    all_candidates = [dict() for _ in range(len(args))]\n    no_cands_counter = 0\n\n    num_rest_queries = len(data) - (i + 1) * num_queries\n    if num_rest_queries >= num_queries:\n        test_queries_idx = range(i * num_queries, (i + 1) * num_queries)\n    else:\n        test_queries_idx = range(i * num_queries, len(data))\n\n    cur_ts = data[test_queries_idx[0]][3]\n    edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)\n\n    it_start =  timeit.default_timer()\n    hits_list = [0] * len(test_queries_idx)\n    perf_list = [0] * len(test_queries_idx)\n    for index, j in enumerate(test_queries_idx):\n        neg_sample_el =  neg_sampler.query_batch(np.expand_dims(np.array(data[j,0]), axis=0), \n                                         
       np.expand_dims(np.array(data[j,2]), axis=0), \n                                                np.expand_dims(np.array(data[j,4]), axis=0), \n                                                np.expand_dims(np.array(data[j,1]), axis=0), \n                                                split_mode=split_mode)[0]        \n        \n        # neg_samples_batch[j]\n        pos_sample_el =  data[j,2]\n        test_query = data[j]\n        assert pos_sample_el == test_query[2]\n        cands_dict = [dict() for _ in range(len(args))]\n\n        if test_query[3] != cur_ts:\n            cur_ts = test_query[3]\n            edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)\n\n        if test_query[1] in rules_dict:\n            dicts_idx = list(range(len(args)))\n            for rule in rules_dict[test_query[1]]:\n                walk_edges = ra.match_body_relations(rule, edges, test_query[0])\n\n                if 0 not in [len(x) for x in walk_edges]:\n                    rule_walks = ra.get_walks(rule, walk_edges)\n                    if rule[\"var_constraints\"]:\n                        rule_walks = ra.check_var_constraints(\n                            rule[\"var_constraints\"], rule_walks\n                        )\n\n                    if not rule_walks.empty:\n                        cands_dict = ra.get_candidates(\n                            rule,\n                            rule_walks,\n                            cur_ts,\n                            cands_dict,\n                            score_func,\n                            args,\n                            dicts_idx,\n                        )\n                        for s in dicts_idx:\n                            cands_dict[s] = {\n                                x: sorted(cands_dict[s][x], reverse=True)\n                                for x in cands_dict[s].keys()\n                            }\n                            cands_dict[s] = dict(\n                        
        sorted(\n                                    cands_dict[s].items(),\n                                    key=lambda item: item[1],\n                                    reverse=True,\n                                )\n                            )\n                            top_k_scores = [v for _, v in cands_dict[s].items()][:top_k]\n                            unique_scores = list(\n                                scores for scores, _ in itertools.groupby(top_k_scores)\n                            )\n                            if len(unique_scores) >= top_k:\n                                dicts_idx.remove(s)\n                        if not dicts_idx:\n                            break\n\n            if cands_dict[0]:\n                for s in range(len(args)):\n                    # Calculate noisy-or scores\n                    scores = list(\n                        map(\n                            lambda x: 1 - np.product(1 - np.array(x)),\n                            cands_dict[s].values(),\n                        )\n                    )\n                    cands_scores = dict(zip(cands_dict[s].keys(), scores))\n                    noisy_or_cands = dict(\n                        sorted(cands_scores.items(), key=lambda x: x[1], reverse=True)\n                    )\n                    all_candidates[s][j] = noisy_or_cands\n            else:  # No candidates found by applying rules\n                no_cands_counter += 1\n                for s in range(len(args)):\n                    all_candidates[s][j] = dict()\n\n        else:  # No rules exist for this relation\n            no_cands_counter += 1\n            for s in range(len(args)):\n                all_candidates[s][j] = dict()\n\n        if not (j - test_queries_idx[0] + 1) % 100:\n            it_end =  timeit.default_timer()\n            it_time = round(it_end - it_start, 6)\n            print(\n                \"Process {0}: test samples finished: {1}/{2}, {3} sec\".format(\n          
          i, j - test_queries_idx[0] + 1, len(test_queries_idx), it_time\n                )\n            )\n            it_start =  timeit.default_timer()\n\n        predictions = create_scores_array(all_candidates[s][j], num_nodes)  \n        predictions_of_interest_pos = np.array(predictions[pos_sample_el])\n        predictions_of_interest_neg = predictions[neg_sample_el]\n        input_dict = {\n            \"y_pred_pos\": predictions_of_interest_pos,\n            \"y_pred_neg\": predictions_of_interest_neg,\n            \"eval_metric\": ['mrr'], \n        }\n\n        predictions = evaluator.eval(input_dict)\n        perf_list[index] = predictions['mrr']\n        hits_list[index] = predictions['hits@10']\n        if split_mode == \"test\":\n            if log_per_rel:\n                perf_per_rel[test_query[1]].append(perf_list[index]) #test_query[1] is the relation index\n\n    if split_mode == \"test\":\n        if log_per_rel:   \n            for rel in range(num_rels):\n                if len(perf_per_rel[rel]) > 0:\n                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))\n                else:\n                    perf_per_rel.pop(rel)       \n               \n\n    return perf_list, hits_list, perf_per_rel\n\n\n## args\ndef get_args(): \n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", \"-d\", default=\"tkgl-polecat\", type=str) \n    parser.add_argument(\"--rule_lengths\", \"-l\", default=\"1\", type=int, nargs=\"+\")\n    parser.add_argument(\"--num_walks\", \"-n\", default=\"100\", type=int)\n    parser.add_argument(\"--transition_distr\", default=\"exp\", type=str)\n    parser.add_argument(\"--window\", \"-w\", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. 
set to -2 if multistep\n    parser.add_argument(\"--top_k\", default=20, type=int)\n    parser.add_argument(\"--num_processes\", \"-p\", default=1, type=int)\n    parser.add_argument(\"--alpha\", \"-alpha\",  default=0.99, type=float) # fix alpha. used if trainflag == false\n    # parser.add_argument(\"--train_flag\", \"-tr\",  default=True) # do we need training, ie selection of lambda and alpha\n    parser.add_argument(\"--save_config\", \"-c\",  default=True) # do we need to save the selection of lambda and alpha in config file?\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--run_nr', type=int, help='Run Number', default=1)\n    parser.add_argument('--learn_rules_flag', type=bool, help='Do we want to learn the rules', default=True)\n    parser.add_argument('--rule_filename', type=str, help='if rules not learned: where are they stored', default='0_r[3]_n100_exp_s1_rules.json')\n    parser.add_argument('--log_per_rel', type=bool, help='Do we want to log mrr per relation', default=False)\n    parser.add_argument('--compute_valid_mrr', type=bool, help='Do we want to compute mrr for valid set', default=True)\n    parsed = vars(parser.parse_args())\n    return parsed\n\nstart_o =  timeit.default_timer()\n\n## get args\nparsed = get_args()\ndataset = parsed[\"dataset\"]\nrule_lengths = parsed[\"rule_lengths\"]\nrule_lengths = [rule_lengths] if (type(rule_lengths) == int) else rule_lengths\nnum_walks = parsed[\"num_walks\"]\ntransition_distr = parsed[\"transition_distr\"]\nnum_processes = parsed[\"num_processes\"]\nwindow = parsed[\"window\"]\ntop_k = parsed[\"top_k\"]\nlog_per_rel = parsed['log_per_rel']\n\nMODEL_NAME = 'TLogic'\nSEED = parsed['seed']  # set the random seed for consistency\nset_random_seed(SEED)\n\n## load dataset and prepare it accordingly\nname = parsed[\"dataset\"]\ncompute_valid_mrr = parsed[\"compute_valid_mrr\"]\ndataset = LinkPropPredDataset(name=name, root=\"datasets\", 
preprocess=True)\nDATA = name\n\nrelations = dataset.edge_type\nnum_rels = dataset.num_rels\n\nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nnum_nodes = dataset.num_nodes \ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n\nall_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)\ntrain_data = all_quads[dataset.train_mask,0:4] # we do not need the original timestamps\nval_data = all_quads[dataset.val_mask,0:5]\ntest_data = all_quads[dataset.test_mask,0:5]\nall_data = all_quads[:,0:4]\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=name)\nneg_sampler = dataset.negative_sampler\n\ninv_relation_id = get_inv_relation_id(num_rels)\n\n#load the ns samples \n\ndataset.load_val_ns()\ndataset.load_test_ns()\noutput_dir =  f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\nlearn_rules_flag = parsed['learn_rules_flag']\n## 1. learn rules\nstart_train =  timeit.default_timer()\nif learn_rules_flag:\n    print(\"start learning rules\")\n    # edges (dict): edges for each relation\n    # inv_relation_id (dict): mapping of relation to inverse relation\n    \n    temporal_walk = Temporal_Walk(train_data, inv_relation_id, transition_distr)\n    rl = Rule_Learner(edges=temporal_walk.edges, id2relation=None, inv_relation_id=inv_relation_id,  \n                        output_dir=output_dir)\n    all_relations = sorted(temporal_walk.edges)  # Learn for all relations\n\n    start =  timeit.default_timer()\n    num_relations = len(all_relations) // num_processes\n    output = Parallel(n_jobs=num_processes)(\n        delayed(learn_rules)(i, num_relations) for i in range(num_processes)\n    )\n    end =  timeit.default_timer()\n\n    all_rules = output[0]\n    for i in range(1, num_processes):\n        all_rules.update(output[i])\n\n    total_time = round(end - start, 6)\n    print(\"Learning finished in {} 
seconds.\".format(total_time))\n\n    rl.rules_dict = all_rules\n    rl.sort_rules_dict()\n\n    rule_filename = rl.save_rules(0, rule_lengths, num_walks, transition_distr, SEED)\n    # rl.save_rules_verbalized(0, rule_lengths, num_walks, transition_distr, seed)\n    # rules_statistics(rl.rules_dict)\nelse:\n    rule_filename = parsed['rule_filename']\n    print(\"Loading rules from file {}\".format(parsed['rule_filename']))\n\nend_train =  timeit.default_timer()\n\n## 2. Apply rules\n\nrules_dict = json.load(open(output_dir + rule_filename))\nrules_dict = {int(k): v for k, v in rules_dict.items()}\n\nrules_dict = ra.filter_rules(\n    rules_dict, min_conf=0.01, min_body_supp=2, rule_lengths=rule_lengths\n) # filter rules for minimum confidence, body support and rule length\n\nlearn_edges = store_edges(train_data)\nscore_func = ra.score_12\n# It is possible to specify a list of list of arguments for tuning\nargs = [[0.1, 0.5]]\n\n# compute valid mrr\nstart_valid =  timeit.default_timer()\nif compute_valid_mrr:\n    print('Computing valid MRR')\n\n    num_queries = len(val_data) // num_processes\n\n    output = Parallel(n_jobs=num_processes)(\n        delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, val_data, window, learn_edges, \n                            all_quads, args, split_mode='val') for i in range(num_processes))\n    end =  timeit.default_timer()\n\n    perf_list_val = []\n    hits_list_val = []\n\n    for i in range(num_processes):\n        perf_list_val.extend(output[i][0])\n        hits_list_val.extend(output[i][1])\nelse:\n    perf_list_val = [0]\n    hits_list_val = [0]\n    \n\nend_valid =  timeit.default_timer()\n\n# compute test mrr\nif log_per_rel ==True:\n    num_processes = 1 #otherwise logging per rel does not work for our implementation\nstart_test =  timeit.default_timer()\nprint('Computing test MRR')\nstart =  timeit.default_timer()\nnum_queries = len(test_data) // num_processes\n\noutput = Parallel(n_jobs=num_processes)(\n    
delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, test_data, window, learn_edges, \n                         all_quads, args, split_mode='test', log_per_rel=log_per_rel, num_rels=num_rels) for i in range(num_processes))\nend =  timeit.default_timer()\n\nperf_list_all = []\nhits_list_all = []\n\n\nfor i in range(num_processes):\n    perf_list_all.extend(output[i][0])\n    hits_list_all.extend(output[i][1])\nif log_per_rel == True:\n    perf_per_rel = output[0][2]\n\n\ntotal_time = round(end - start, 6)\ntotal_valid_time = round(end_valid - start_valid, 6)\nprint(\"Application finished in {} seconds.\".format(total_time))\n\nprint(f\"The valid MRR is {np.mean(perf_list_val)}\")\nprint(f\"The MRR is {np.mean(perf_list_all)}\")\nprint(f\"The Hits@10 is {np.mean(hits_list_all)}\")\nprint(f\"We have {len(perf_list_all)} predictions\")\nprint(f\"The test set has len {len(test_data)} \")\n\nend_o =  timeit.default_timer()\ntrain_time_o = round(end_train- start_train, 6)  \ntest_time_o = round(end_o- start_test, 6)  \ntotal_time_o = round(end_o- start_o, 6)  \nprint(\"Running Training to find best configs finished in {} seconds.\".format(train_time_o))\nprint(\"Running testing with best configs finished in {} seconds.\".format(test_time_o))\nprint(\"Running all steps finished in {} seconds.\".format(total_time_o))\n\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\n\nif log_per_rel == True:\n    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'\n    with open(results_filename, 'w') as json_file:\n        json.dump(perf_per_rel, json_file)\n\nresults_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'\nmetric = dataset.eval_metric\nsave_results({'model': MODEL_NAME,\n              'train_flag': None,\n              'rule_len': 
rule_lengths,\n              'window': window,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: float(np.mean(perf_list_all)),\n              'hits10': float(np.mean(hits_list_all)),\n              'val_mrr': float(np.mean(perf_list_val)),\n              'test_time': test_time_o,\n              'tot_train_val_time': total_time_o,\n              'valid_time': total_valid_time\n              }, \n    results_filename)\n"
  },
  {
    "path": "examples/linkproppred/tkgl-smallpedia/cen.py",
    "content": "\"\"\"\nComplex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning\nReference:\n- https://github.com/Lee-zix/CEN\nZixuan Li, Saiping Guan, Xiaolong Jin, Weihua Peng, Yajuan Lyu , Yong Zhu, Long Bai, Wei Li, Jiafeng Guo, Xueqi Cheng. \nComplex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning. ACL 2022.\n\"\"\"\nimport timeit\nimport os\nimport sys\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\nimport torch\nimport random\nfrom tqdm import tqdm\nimport json\n# internal imports\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.rrgcn import RecurrentRGCNCEN\nfrom tgb.utils.utils import set_random_seed, split_by_time,  save_results\nfrom modules.tkg_utils import get_args_cen, reformat_ts\nfrom modules.tkg_utils_dgl import build_sub_graph\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \n\ndef test(model, history_len, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, mode, split_mode):\n    \"\"\"\n    Test the model\n    :param model: model used to test\n    :param history_list:    all input history snap shot list, not include output label train list or valid list\n    :param test_list:   test triple snap shot list\n    :param num_rels:    number of relations\n    :param num_nodes:   number of nodes\n    :param use_cuda:\n    :param model_name:\n    :param mode:\n    :param split_mode: 'test' or 'val' to state which negative samples to load\n    :return mrr\n    \"\"\"\n    print(\"Testing for mode: \", split_mode)\n    if split_mode == 'test':\n        timesteps_to_eval = test_timestamps_orig\n    else:\n        timesteps_to_eval = val_timestamps_orig\n\n    idx = 0\n\n    if mode == \"test\":\n        # test mode: load parameter form file \n        if use_cuda:\n            checkpoint = torch.load(model_name, 
map_location=torch.device(args.gpu))\n        else:\n            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))        \n        # use best stat checkpoint:\n        print(\"Load Model name: {}. Using best epoch : {}\".format(model_name, checkpoint['epoch']))  \n        print(\"\\n\"+\"-\"*10+\"start testing\"+\"-\"*10+\"\\n\")\n        model.load_state_dict(checkpoint['state_dict'])\n\n    model.eval()\n\n    input_list = [snap for snap in history_list[-history_len:]] \n    perf_list_all = []\n    hits_list_all = []\n    perf_per_rel = {}\n    for rel in range(num_rels):\n        perf_per_rel[rel] = []\n\n    for time_idx, test_snap in enumerate(tqdm(test_list)):\n        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]    \n        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)\n        timesteps_batch =timesteps_to_eval[time_idx]*np.ones(len(test_triples_input[:,0]))\n\n        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:,0], test_triples_input[:,2], \n                                timesteps_batch, edge_type=test_triples_input[:,1], split_mode=split_mode)\n        pos_samples_batch = test_triples_input[:,2]\n\n        _, perf_list, hits_list = model.predict(history_glist, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch, \n                                    evaluator, METRIC) \n\n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n        if split_mode == \"test\":\n            if args.log_per_rel:\n                for score, rel in zip(perf_list, test_triples_input[:,1].tolist()):\n                    perf_per_rel[rel].append(score)\n        # reconstruct history graph list\n        input_list.pop(0)\n        input_list.append(test_snap)\n        idx += 1\n    if split_mode == \"test\":\n        if args.log_per_rel:   \n            for rel in range(num_rels):\n        
        if len(perf_per_rel[rel]) > 0:\n                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))\n                else:\n                    perf_per_rel.pop(rel)\n    mrr = np.mean(perf_list_all) \n    hits10 = np.mean(hits_list_all)\n    return mrr, perf_per_rel, hits10\n\n\n\ndef run_experiment(args, trainvalidtest_id=0, n_hidden=None, n_layers=None, dropout=None, n_bases=None):\n    '''\n    Run experiment for CEN model\n    :param args: arguments for the model\n    :param trainvalidtest_id: -1: pretrainig, 0: curriculum training (to find best test history len), 1: test on valid set, 2: test on test set\n    :param n_hidden: number of hidden units\n    :param n_layers: number of layers\n    :param dropout: dropout rate\n    :param n_bases: number of bases\n    return: mrr, perf_per_rel: mean reciprocal rank and performance per relation\n    '''\n    # 1) load configuration for grid search the best configuration\n    if n_hidden:\n        args.n_hidden = n_hidden\n    if n_layers:\n        args.n_layers = n_layers\n    if dropout:\n        args.dropout = dropout\n    if n_bases:\n        args.n_bases = n_bases\n    test_history_len = args.test_history_len\n    # 2) set save model path\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    if not osp.exists(save_model_dir):\n        os.mkdir(save_model_dir)\n        print('INFO: Create directory {}'.format(save_model_dir))\n    test_model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.test_history_len}'\n    test_state_file = save_model_dir+test_model_name\n    perf_per_rel ={}\n    use_cuda = args.gpu >= 0 and torch.cuda.is_available()\n    # create stat\n\n    model = RecurrentRGCNCEN(args.decoder,\n                            args.encoder,\n                            num_nodes,\n                            num_rels,\n                            args.n_hidden,\n                            args.opn,\n                            
sequence_len=args.train_history_len,\n                            num_bases=args.n_bases,\n                            num_basis=args.n_basis,\n                            num_hidden_layers=args.n_layers,\n                            dropout=args.dropout,\n                            self_loop=args.self_loop,\n                            skip_connect=args.skip_connect,\n                            layer_norm=args.layer_norm,\n                            input_dropout=args.input_dropout,\n                            hidden_dropout=args.hidden_dropout,\n                            feat_dropout=args.feat_dropout,\n                            entity_prediction=args.entity_prediction,\n                            relation_prediction=args.relation_prediction,\n                            use_cuda=use_cuda,\n                            gpu = args.gpu)\n    if use_cuda:\n        torch.cuda.set_device(args.gpu)\n        model.cuda()\n\n    # optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)\n    \n    if trainvalidtest_id == 1:  # normal test on validation set  Note that mode=test\n        if os.path.exists(test_state_file):\n            mrr, _, hits10 = test(model, args.test_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                        test_state_file, \"test\", split_mode=\"val\")      \n        else:\n            print('Cannot do testing because model does not exist: ', test_state_file)\n            mrr = 0\n            hits10 = 0\n    elif trainvalidtest_id == 2: # normal test on test set\n        if os.path.exists(test_state_file):\n            mrr, perf_per_rel, hits10 = test(model, args.test_history_len, train_list+valid_list, test_list, num_rels, num_nodes, use_cuda, \n                        test_state_file, \"test\", split_mode=\"test\")\n        else:\n            print('Cannot do testing because model does not exist: ', test_state_file)\n            mrr = 0\n            hits10 = 0\n    
elif trainvalidtest_id == -1:\n        print(\"-------------start pre training model with history length {}----------\\n\".format(args.start_history_len))\n        model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'\n        model_state_file = save_model_dir + model_name\n        print(\"Sanity Check: stat name : {}\".format(model_state_file))\n        print(\"Sanity Check: Is cuda available ? {}\".format(torch.cuda.is_available()))\n            \n        best_mrr = 0\n        best_epoch = 0\n        best_hits10= 0\n\n        ## training loop\n        for epoch in range(args.n_epochs):\n            model.train()\n            losses = []\n\n            idx = [_ for _ in range(len(train_list))]\n            random.shuffle(idx)\n            for train_sample_num in idx:\n                if train_sample_num == 0 or train_sample_num == 1: continue\n                if train_sample_num - args.start_history_len<0:\n                    input_list = train_list[0: train_sample_num]\n                    output = train_list[1:train_sample_num+1]\n                else:\n                    input_list = train_list[train_sample_num-args.start_history_len: train_sample_num]\n                    output = train_list[train_sample_num-args.start_history_len+1:train_sample_num+1]\n\n                # generate history graph\n                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]\n                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]\n\n                loss= model.get_loss(history_glist, output[-1], None, use_cuda)\n                losses.append(loss.item())\n\n                loss.backward()\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients\n                optimizer.step()\n                optimizer.zero_grad()\n\n            print(\"His {:04d}, Epoch {:04d} | 
Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} \"\n                .format(args.start_history_len, epoch, np.mean(losses), best_mrr, model_name))\n            \n            #! checking GPU usage\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            print (\"--------------GPU memory usage-----------\")\n            print (\"there are \", free_mem, \" free memory\")\n            print (\"there are \", total_mem, \" total available memory\")\n            print (\"there are \", total_mem - free_mem, \" used memory\")\n            print (\"--------------GPU memory usage-----------\")\n\n            # validation        \n            if epoch % args.evaluate_every == 0:\n                mrr, _, hits10 = test(model, args.start_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                                model_state_file, mode=\"train\", split_mode= \"val\")\n\n                if mrr< best_mrr:\n                    if epoch >= args.n_epochs or epoch - best_epoch > 5:\n                        break\n                else:\n                    best_mrr = mrr\n                    best_epoch = epoch\n                    best_hits10 = hits10\n                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)\n\n        \n    elif trainvalidtest_id == 0: #curriculum training\n        model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'\n        init_state_file = save_model_dir + model_name\n        init_checkpoint = torch.load(init_state_file, map_location=torch.device(args.gpu))\n        # use best stat checkpoint:\n        print(\"Load Previous Model name: {}. 
Using best epoch : {}\".format(init_state_file, init_checkpoint['epoch']))  \n        print(\"\\n\"+\"-\"*10+\"Load model with history length {}\".format(args.start_history_len)+\"-\"*10+\"\\n\")\n        model.load_state_dict(init_checkpoint['state_dict'])\n        test_history_len = args.start_history_len\n\n        mrr, _, hits10 = test(model, \n                    args.start_history_len,\n                    train_list,\n                    valid_list, \n                    num_rels, \n                    num_nodes, \n                    use_cuda, \n                    init_state_file,  \n                    mode=\"test\", split_mode= \"val\") \n        best_mrr_list = [mrr.item()]         \n        best_hits_list = [hits10.item()]                                          \n        # start knowledge distillation\n        ks_idx = 0\n        for history_len in range(args.start_history_len+1, args.train_history_len+1, 1):\n            # current model\n            print(\"best mrr list :\", best_mrr_list)\n            # lr = 0.1*args.lr - 0.002*args.lr*ks_idx\n            optimizer = torch.optim.Adam(model.parameters(), lr=0.1*args.lr, weight_decay=0.00001)\n            model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len}'\n            model_state_file = save_model_dir + model_name\n\n            print(\"Sanity Check: stat name : {}\".format(model_state_file))\n\n            # load model with the least history length\n            prev_model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len-1}'\n            prev_state_file = save_model_dir + prev_model_name\n            checkpoint = torch.load(prev_state_file, map_location=torch.device(args.gpu)) \n            model.load_state_dict(checkpoint['state_dict']) \n            print(\"\\n\"+\"-\"*10+\"start knowledge distillation for history length at \"+ str(history_len)+\"-\"*10+\"\\n\")\n \n            best_mrr = 0\n            best_hits10 = 0\n            best_epoch = 0\n            for 
epoch in range(args.n_epochs):\n                model.train()\n                losses = []\n\n                idx = [_ for _ in range(len(train_list))]\n                random.shuffle(idx)\n                for train_sample_num in idx:\n                    if train_sample_num == 0 or train_sample_num == 1: continue\n                    if train_sample_num - history_len<0:\n                        input_list = train_list[0: train_sample_num]\n                        output = train_list[1:train_sample_num+1]\n                    else:\n                        input_list = train_list[train_sample_num - history_len: train_sample_num]\n                        output = train_list[train_sample_num-history_len+1:train_sample_num+1]\n\n                    # generate history graph\n                    history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]\n                    output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]\n\n                    loss= model.get_loss(history_glist, output[-1], None, use_cuda)\n                    # print(loss)\n                    losses.append(loss.item())\n\n                    loss.backward()\n                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients\n                    optimizer.step()\n                    optimizer.zero_grad()\n\n                print(\"His {:04d}, Epoch {:04d} | Ave Loss: {:.4f} |Best MRR {:.4f} | Model {} \"\n                    .format(history_len, epoch, np.mean(losses), best_mrr, model_name))\n\n                #! 
checking GPU usage\n                free_mem, total_mem = torch.cuda.mem_get_info()\n                print (\"--------------GPU memory usage-----------\")\n                print (\"there are \", free_mem, \" free memory\")\n                print (\"there are \", total_mem, \" total available memory\")\n                print (\"there are \", total_mem - free_mem, \" used memory\")\n                print (\"--------------GPU memory usage-----------\")\n\n                # validation\n                if epoch % args.evaluate_every == 0:\n                    mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                                model_state_file, mode=\"train\", split_mode= \"val\")\n                    \n                    if mrr< best_mrr:\n                        if epoch >= args.n_epochs or epoch-best_epoch>2:\n                            break\n                    else:\n                        best_mrr = mrr\n                        best_epoch = epoch\n                        best_hits10 = hits10\n                        torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)  \n            mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                        model_state_file, mode=\"test\", split_mode= \"val\")\n            ks_idx += 1\n            if mrr.item() < max(best_mrr_list):\n                test_history_len = history_len-1\n                print(\"early stopping, best history length: \", test_history_len)\n                break\n            else:\n                best_mrr_list.append(mrr.item())\n                best_hits_list.append(hits10.item())\n        \n    return mrr, test_history_len, perf_per_rel, hits10\n\n\n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args_cen()\nargs.dataset = 'tkgl-smallpedia'\n\nSEED 
= args.seed  # set the random seed for consistency\nset_random_seed(SEED)\n\nDATA=args.dataset\nMODEL_NAME = 'CEN'\n\nprint(\"logging mrrs per relation: \", args.log_per_rel)\nprint(\"do test and valid? do only test no validation?: \", args.validtest, args.test_only)\n\n# load data\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\n\nnum_rels = dataset.num_rels\nnum_nodes = dataset.num_nodes \nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nrelations = dataset.edge_type\n\ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\nall_quads = np.stack((subjects, relations, objects, timestamps), axis=1)\n\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nval_timestamps_orig = list(set(timestamps_orig[dataset.val_mask])) # needed for getting the negative samples\nval_timestamps_orig.sort()\ntest_timestamps_orig = list(set(timestamps_orig[dataset.test_mask])) # needed for getting the negative samples\ntest_timestamps_orig.sort()\n\ntrain_list = split_by_time(train_data)\nvalid_list = split_by_time(val_data)\ntest_list = split_by_time(test_data)\n\n# evaluation metric\nMETRIC = dataset.eval_metric\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n#load the ns samples \ndataset.load_val_ns()\ndataset.load_test_ns()\n\nif args.grid_search:\n    print(\"TODO: implement hyperparameter grid search\")\n# single run\nelse:\n    \n    start_train = timeit.default_timer()\n    if args.validtest:\n        print('directly start testing')\n        if args.test_history_len_2 != args.test_history_len:\n            args.test_history_len = args.test_history_len_2 # hyperparameter value as given in original paper \n    else:\n        print('running pretrain and train')\n        # pretrain\n        mrr, _, _, hits10 = run_experiment(args, trainvalidtest_id=-1)\n 
       # train\n        mrr, args.test_history_len, _, hits10 = run_experiment(args, trainvalidtest_id=0) # overwrite test_history_len with \n        # the best history len (for valid mrr)       \n        \n    if args.test_only == False:\n        print(\"running test (on val and test dataset) with test_history_len of: \", args.test_history_len)\n        # test on val set\n        val_mrr, _, _, val_hits10 = run_experiment(args, trainvalidtest_id=1)\n    else:\n        val_mrr = 0\n        val_hits10 = 0\n\n    # test on test set\n    start_test = timeit.default_timer()\n    test_mrr, _, perf_per_rel, test_hits10 = run_experiment(args, trainvalidtest_id=2)\n\ntest_time = timeit.default_timer() - start_test\nall_time = timeit.default_timer() - start_train\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\nprint(f\"\\Train and Test: Elapsed Time (s): {all_time: .4f}\")\n\nprint(f\"\\tTest: {METRIC}: {test_mrr: .4f}\")\nprint(f\"\\tValid: {METRIC}: {val_mrr: .4f}\")\n\n# saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nsave_results({'model': MODEL_NAME,\n              'data': DATA,\n              'run': args.run_nr,\n              'seed': SEED,\n              'test_history_len': args.test_history_len,\n              f'val {METRIC}': float(val_mrr),\n              f'test {METRIC}': float(test_mrr),\n              'test_time': test_time,\n              'tot_train_val_time': all_time,\n              'test_hits10': float(test_hits10)\n              }, \n    results_filename)\n\nif args.log_per_rel:\n    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'\n    with open(results_filename, 'w') as json_file:\n        json.dump(perf_per_rel, 
json_file)\n\n\nsys.exit()"
  },
  {
    "path": "examples/linkproppred/tkgl-smallpedia/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\ntgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(tgb_modules_path)\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    hits_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t, pos_edge = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n            data['edge_type'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)\n        \n        for idx, neg_batch in enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            results = evaluator.eval(input_dict)\n            perf_list.append(results[metric])\n            hits_list.append(results['hits@10'])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, 
pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n    perf_hits = float(np.mean(hits_list))\n\n    return perf_metrics, perf_hits\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tkgl-smallpedia')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = args.data\nMODEL_NAME = 'EdgeBank'\n\n\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n    
    hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_val: .4f}\")\nval_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {val_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>>  <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: 
perf_metric_test,\n              'val_mrr': perf_metric_val,\n              'test_time': test_time,\n              'tot_train_val_time': test_time+val_time,\n              'hits10': perf_hits_test}, \n    results_filename)\n\n"
  },
  {
    "path": "examples/linkproppred/tkgl-smallpedia/recurrencybaseline.py",
    "content": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)\n\n@inproceedings{gastinger2024baselines,\n  title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},\n  author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},\n  booktitle={33nd International Joint Conference on Artificial Intelligence (IJCAI 2024)},\n  year={2024},\n  organization={International Joint Conferences on Artificial Intelligence Organization}\n}\n\n\"\"\"\n\n## imports\nimport timeit\nimport argparse\nimport numpy as np\nfrom copy import copy\nfrom pathlib import Path\nimport ray\nimport sys\nimport os\nimport os.path as osp\nimport json\n#internal imports \nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom tgb.utils.utils import set_random_seed,  save_results \nfrom modules.tkg_utils import create_basis_dict, group_by, reformat_ts\n\ndef predict(num_processes,  data_c_rel, all_data_c_rel, alpha, lmbda_psi,\n            perf_list_all, hits_list_all, window, neg_sampler, split_mode):\n    \"\"\" create predictions for each relation on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list of hits for each test query\n    \"\"\"\n\n    first_ts = data_c_rel[0][3]\n    ## use this if you wanna use ray:\n    num_queries = len(data_c_rel) // num_processes\n    if num_queries < num_processes: # if we do not have enough queries for all the processes\n        num_processes_tmp = 
1\n        num_queries = len(data_c_rel)\n    else:\n        num_processes_tmp = num_processes  \n    if num_processes > 1:\n        object_references =[]                   \n        \n        for i in range(num_processes_tmp):\n            num_test_queries = len(data_c_rel) - (i + 1) * num_queries\n            if num_test_queries >= num_queries:\n                test_queries_idx =[i * num_queries, (i + 1) * num_queries]\n            else:\n                test_queries_idx = [i * num_queries, len(test_data)]\n\n            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]\n\n            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window, \n                                basis_dict, \n                                num_nodes, num_rels, lmbda_psi, \n                                alpha, evaluator,first_ts, neg_sampler, split_mode)\n            object_references.append(ob)\n\n        output = ray.get(object_references)\n\n        # updates the scores and logging dict for each process\n        for proc_loop in range(num_processes_tmp):\n            perf_list_all.extend(output[proc_loop][0])\n            hits_list_all.extend(output[proc_loop][1])\n\n    else:\n        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel, \n                            window, basis_dict, \n                            num_nodes, num_rels, lmbda_psi, \n                            alpha, evaluator, first_ts, neg_sampler, split_mode)                  \n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n    \n    return perf_list_all, hits_list_all\n\n## test\ndef test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):  \n    \"\"\" create predictions by loopoing through all relations on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list 
of hits for each test query\n    \"\"\"       \n    perf_list_all = []\n    hits_list_all =[]\n    \n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n\n    ## loop through relations and apply baselines\n    \n    for rel in all_relations:\n        start =  timeit.default_timer()\n        if rel in test_data_prel.keys():\n            lmbda_psi = best_config[str(rel)]['lmbda_psi'][0]\n            alpha = best_config[str(rel)]['alpha'][0]\n\n            # test data for this relation\n            test_data_c_rel = test_data_prel[rel]\n            timesteps_test = list(set(test_data_c_rel[:,3]))\n            timesteps_test.sort()\n            all_data_c_rel = all_data_prel[rel]         \n            perf_list_rel = []\n            hits_list_rel = []\n            perf_list_rel, hits_list_rel = predict(num_processes, test_data_c_rel,\n                                                all_data_c_rel, alpha, lmbda_psi,perf_list_rel, hits_list_rel, \n                                                window, neg_sampler, split_mode)\n            perf_list_all.extend(perf_list_rel)\n            hits_list_all.extend(hits_list_rel)\n        else:\n            perf_list_rel =[]\n         \n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n        \n\n        with open(csv_file, 'a') as f:\n            f.write(\"{},{}\\n\".format(rel, perf_list_rel))\n\n    return perf_list_all, hits_list_all\n\ndef read_dict_compute_mrr(split_mode='test'):\n    \"\"\" read the results per relation  from a precreated file and compute mrr\n    :return  mrr_per_rel: dictionary of mrrs for each relation\n    :return all_mrrs: list of mrrs for all relations\n    \"\"\"\n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n    # Initialize an empty dictionary to store the data\n    
results_per_rel_dict = {}\n    mrr_per_rel = {}\n    all_mrrs = []\n    # Open the file for reading\n    with open(csv_file, 'r') as f:\n        # Read each line in the file\n        for line in f:\n            # Split the line at the comma\n            parts = line.strip().split(',')\n            # Extract the key (the first part)\n            key = int(parts[0])\n            # Extract the values (the rest of the parts), remove square brackets\n            values = [float(value.strip('[]')) for value in parts[1:]]\n            # Add the key-value pair to the dictionary\n            if key in results_per_rel_dict.keys():\n                print(f\"Key {key} already exists in the dictionary!!! might have duplicate entries in results csv\")\n            results_per_rel_dict[key] = values\n            all_mrrs.extend(values)\n            mrr_per_rel[key] = np.mean(values)\n\n    if len(list(results_per_rel_dict.keys())) != num_rels:\n        print(\"we do not have entries for each rel in the results csv file. only num enties: \", len(list(results_per_rel_dict.keys())))\n\n    print(\"Split mode: \"+split_mode +\" Mean MRR: \", np.mean(all_mrrs))\n    print(\"mrr per relation: \", mrr_per_rel)\n\n\n\n## train\ndef train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):\n    \"\"\" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params\n    based on validation mrr\n    :return best_config: dictionary of best params for each relation\n    \"\"\"\n    best_config= {}\n    best_mrr = 0\n    for rel in rels: # loop through relations. 
for each relation, apply rules with selected params, compute valid mrr\n        start =  timeit.default_timer()\n        rel_key = int(rel)            \n\n        best_config[str(rel_key)] = {}\n        best_config[str(rel_key)]['not_trained'] = 'True'    \n        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi,0] #default\n        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))\n        best_config[str(rel_key)]['alpha'] = [default_alpha,0]  #default    \n        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))\n        \n        if rel in val_data_prel.keys():      \n            # valid data for this relation  \n            val_data_c_rel = val_data_prel[rel]\n            timesteps_valid = list(set(val_data_c_rel[:,3]))\n            timesteps_valid.sort()\n            trainval_data_c_rel = trainval_data_prel[rel]\n\n            ######  1) select lambda ###############        \n            lmbdas_psi = params_dict['lmbda_psi']        \n\n            alpha = 1\n            best_lmbda_psi = 0.1\n            best_mrr_psi = 0\n            lmbda_mrrs = []\n\n            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))\n            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))         \n            best_config[str(rel_key)]['not_trained'] = 'False'       \n            \n            for lmbda_psi in lmbdas_psi:   \n                perf_list_r = []\n                hits_list_r = []\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr = np.mean(perf_list_r)\n                # # is new mrr better than previous best? 
if yes: store lmbda\n                if mrr > best_mrr_psi:\n                    best_mrr_psi = float(mrr)\n                    best_lmbda_psi = lmbda_psi\n\n\n                lmbda_mrrs.append(float(mrr))\n            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]\n            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs\n            best_mrr = best_mrr_psi\n            ##### 2) select alpha ###############\n            best_config[str(rel_key)]['not_trained'] = 'False'    \n            alphas = params_dict['alpha'] \n            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0] # use the best lmbda psi\n\n            alpha_mrrs = []\n            # perf_list_all = []\n            best_mrr_alpha = 0\n            best_alpha=0.99\n            for alpha in alphas:\n                perf_list_r = []\n                hits_list_r = []\n\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr_alpha = np.mean(perf_list_r)\n\n                # is new mrr better than previous best? 
if yes: store alpha\n                if mrr_alpha > best_mrr_alpha:\n                    best_mrr_alpha = float(mrr_alpha)\n                    best_alpha = alpha\n                    best_mrr = best_mrr_alpha\n                alpha_mrrs.append(float(mrr_alpha))\n\n            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]\n            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs\n\n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n    return best_config\n\n\n\n## args\ndef get_args(): \n    \"\"\"parse all arguments for the script\"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", \"-d\", default=\"tkgl-smallpedia\", type=str) \n    parser.add_argument(\"--window\", \"-w\", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep\n    parser.add_argument(\"--num_processes\", \"-p\", default=1, type=int)\n    parser.add_argument(\"--lmbda\", \"-l\",  default=0.1, type=float) # fix lambda. used if trainflag == false\n    parser.add_argument(\"--alpha\", \"-alpha\",  default=0.99, type=float) # fix alpha. 
used if trainflag == false\n    parser.add_argument(\"--train_flag\", \"-tr\",  default='False') # do we need training, ie selection of lambda and alpha\n    parser.add_argument(\"--load_flag\", \"-lo\",  default='False') # if train_flag set to True: do you want to load best_config?\n    parser.add_argument(\"--save_config\", \"-c\",  default='True') # do we need to save the selection of lambda and alpha in config file?\n    parser.add_argument('--seed', type=int, help='Random seed', default=1) # not needed\n    parsed = vars(parser.parse_args())\n    return parsed\n\nstart_o =  timeit.default_timer()\n\nparsed = get_args()\nif parsed['num_processes']>1:\n    ray.init(num_cpus=parsed[\"num_processes\"], num_gpus=0)\nMODEL_NAME = 'RecurrencyBaseline'\nSEED = parsed['seed']  # set the random seed for consistency\nset_random_seed(SEED)\nperrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'\nif not osp.exists(perrel_results_path):\n    os.mkdir(perrel_results_path)\n    print('INFO: Create directory {}'.format(perrel_results_path))\nPath(perrel_results_path).mkdir(parents=True, exist_ok=True)\n\n## load dataset and prepare it accordingly\nname = parsed[\"dataset\"]\ndataset = LinkPropPredDataset(name=name, root=\"datasets\", preprocess=True)\nDATA = name\n\nrelations = dataset.edge_type\nnum_rels = dataset.num_rels\nrels = np.arange(0,num_rels)\nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nnum_nodes = dataset.num_nodes \ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n\nprint(\"split train valid test data\")\nall_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=name)\nneg_sampler = 
dataset.negative_sampler\n\ntrain_val_data = np.concatenate([train_data, val_data])\nall_data = np.concatenate([train_data, val_data, test_data])\n\n# create dicts with key: relation id, values: triples for that relation id\nprint(\"grouping data by relation\")\ntest_data_prel = group_by(test_data, 1)\nall_data_prel = group_by(all_data, 1)\nval_data_prel = group_by(val_data, 1)\ntrainval_data_prel = group_by(train_val_data, 1)\n\n#load the ns samples \n# if parsed['train_flag']:\nprint(\"loading negative samples\")\ndataset.load_val_ns()\ndataset.load_test_ns()\n\n# parameter options\nif parsed['train_flag'] == 'True':\n    params_dict = {}\n    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001] \n    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]\n    default_lmbda_psi = params_dict['lmbda_psi'][-1]\n    default_alpha = params_dict['alpha'][-2]\n\n## load rules\nprint(\"creating rules\")\nbasis_dict = create_basis_dict(train_val_data)\nprint(\"done with creating rules\")\n## init\n# rb_predictor = RecurrencyBaselinePredictor(rels)\n## train to find best lambda and alpha\nstart_train =  timeit.default_timer()\nif parsed['train_flag'] ==  'True':\n    if parsed['load_flag'] == 'True':\n        with open('best_config.json', 'r') as infile:\n            best_config = json.load(infile)\n    else:\n        print('start training')\n        best_config = train(params_dict,  rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'], \n            parsed['window'])\n        if parsed['save_config'] == 'True':\n            import json\n            with open('best_config.json', 'w') as outfile:\n                json.dump(best_config, outfile)\n\nelse: # use preset lmbda and alpha; same for all relations\n    best_config = {} \n    for rel in rels:\n        best_config[str(rel)] = {}\n        best_config[str(rel)]['lmbda_psi'] = 
[parsed['lmbda']]\n        best_config[str(rel)]['alpha'] = [parsed['alpha']]\n    \nend_train =  timeit.default_timer()\n\n# compute validation mrr\nprint(\"Computing validation MRR\")\nperf_list_all_val, hits_list_all_val = test(best_config,rels, val_data_prel, \n                                                trainval_data_prel, neg_sampler, parsed['num_processes'], \n                                            parsed['window'], split_mode='val')\nval_mrr = float(np.mean(perf_list_all_val))\n\n# compute test mrr\nprint(\"Computing test MRR\")\nstart_test =  timeit.default_timer()\nperf_list_all, hits_list_all = test(best_config,rels, test_data_prel, \n                                                 all_data_prel, neg_sampler, parsed['num_processes'], \n                                                parsed['window'])\n\nend_o =  timeit.default_timer()\ntrain_time_o = round(end_train- start_train, 6)  \ntest_time_o = round(end_o- start_test, 6)  \ntotal_time_o = round(end_o- start_o, 6)  \nprint(\"Running Training to find best configs finished in {} seconds.\".format(train_time_o))\nprint(\"Running testing with best configs finished in {} seconds.\".format(test_time_o))\nprint(\"Running all steps finished in {} seconds.\".format(total_time_o))\n\nprint(f\"The test MRR is {np.mean(perf_list_all)}\")\nprint(f\"The valid MRR is {val_mrr}\")\nprint(f\"The Hits@10 is {np.mean(hits_list_all)}\")\nprint(f\"We have {len(perf_list_all)} predictions\")\nprint(f\"The test set has len {len(test_data)} \")\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'\nmetric = dataset.eval_metric\nsave_results({'model': MODEL_NAME,\n              'train_flag': parsed['train_flag'],\n     
         'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: float(np.mean(perf_list_all)),\n              'val_mrr': val_mrr,\n              'hits10': float(np.mean(hits_list_all)),\n              'test_time': test_time_o,\n              'tot_train_val_time': total_time_o\n              }, \n    results_filename)\n\nif parsed['num_processes']>1:\n    ray.shutdown()\n\n\n    \n"
  },
  {
    "path": "examples/linkproppred/tkgl-smallpedia/regcn.py",
    "content": "\"\"\"\nTemporal Knowledge Graph Reasoning Based on Evolutional Representation Learning\nReference:\n- https://github.com/Lee-zix/RE-GCN\nZixuan Li, Xiaolong Jin, Wei Li, Saiping Guan, Jiafeng Guo, Huawei Shen, Yuanzhuo Wang and Xueqi Cheng. Temporal \nKnowledge Graph Reasoning Based on Evolutional Representation Learning. SIGIR 2021.\n\"\"\"\nimport sys\nimport timeit\nimport os\nimport sys\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\nimport torch\nimport random\nfrom tqdm import tqdm\n# internal imports\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.rrgcn import RecurrentRGCNREGCN\nfrom tgb.utils.utils import set_random_seed, split_by_time, save_results\nfrom modules.tkg_utils import get_args_regcn, reformat_ts\nfrom modules.tkg_utils_dgl import build_sub_graph\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nimport json\n\ndef test(model, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, static_graph, mode, split_mode):\n    \"\"\"\n    Test the model on either test or validation set\n    :param model: model used to test\n    :param history_list:    all input history snap shot list, not include output label train list or valid list\n    :param test_list:   test triple snap shot list\n    :param num_rels:    number of relations\n    :param num_nodes:   number of nodes\n    :param use_cuda:\n    :param model_name:\n    :param mode:\n    :param split_mode: 'test' or 'val' to state which negative samples to load\n    :return mrr\n    \"\"\"\n    print(\"Testing for mode: \", split_mode)\n    if split_mode == 'test':\n        timesteps_to_eval = test_timestamps_orig\n    else:\n        timesteps_to_eval = val_timestamps_orig\n\n    idx = 0\n\n    if mode == \"test\":\n        # test mode: load parameter form file \n        if use_cuda:\n            checkpoint = 
torch.load(model_name, map_location=torch.device(args.gpu))\n        else:\n            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))        \n        # use best stat checkpoint:\n        print(\"Load Model name: {}. Using best epoch : {}\".format(model_name, checkpoint['epoch']))  \n        print(\"\\n\"+\"-\"*10+\"start testing\"+\"-\"*10+\"\\n\")\n        model.load_state_dict(checkpoint['state_dict'])\n\n    model.eval()\n    input_list = [snap for snap in history_list[-args.test_history_len:]] \n    perf_list_all = []\n    hits_list_all = []\n    perf_per_rel = {}\n    for rel in range(num_rels):\n        perf_per_rel[rel] = []\n\n    for time_idx, test_snap in enumerate(tqdm(test_list)):\n        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]    \n        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)\n        timesteps_batch =timesteps_to_eval[time_idx]*np.ones(len(test_triples_input[:,0]))\n\n        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:,0], test_triples_input[:,2], \n                                timesteps_batch, edge_type=test_triples_input[:,1], split_mode=split_mode)\n        pos_samples_batch = test_triples_input[:,2]\n\n        _, perf_list, hits_list = model.predict(history_glist, num_rels, static_graph, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch, \n                                    evaluator, METRIC)  \n\n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n        if split_mode == \"test\":\n            if args.log_per_rel:\n                for score, rel in zip(perf_list, test_triples_input[:,1].tolist()):\n                    perf_per_rel[rel].append(score)\n        # reconstruct history graph list\n        input_list.pop(0)\n        input_list.append(test_snap)\n        idx += 1\n\n    if split_mode == \"test\":\n        if 
args.log_per_rel:   \n            for rel in range(num_rels):\n                if len(perf_per_rel[rel]) > 0:\n                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))\n                else:\n                    perf_per_rel.pop(rel)\n    mrr = np.mean(perf_list_all) \n    hits10 = np.mean(hits_list_all) \n    return mrr, perf_per_rel, hits10\n\n\n\ndef run_experiment(args, n_hidden=None, n_layers=None, dropout=None, n_bases=None):\n    \"\"\"\n    Run the experiment with the given configuration\n    :param args: arguments\n    :param n_hidden: hidden dimension\n    :param n_layers: number of layers\n    :param dropout: dropout rate\n    :param n_bases: number of bases\n    :return: mrr, perf_per_rel  (mean reciprocal rank, performance per relation)\n    \"\"\"\n    # load configuration for grid search the best configuration\n    if n_hidden:\n        args.n_hidden = n_hidden\n    if n_layers:\n        args.n_layers = n_layers\n    if dropout:\n        args.dropout = dropout\n    if n_bases:\n        args.n_bases = n_bases\n    mrr = 0\n    hits10=0\n    # 2) set save model path\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    if not osp.exists(save_model_dir):\n        os.mkdir(save_model_dir)\n        print('INFO: Create directory {}'.format(save_model_dir))\n    model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}'\n    model_state_file = save_model_dir+model_name\n    perf_per_rel = {}\n    use_cuda = args.gpu >= 0 and torch.cuda.is_available()\n\n    num_static_rels, num_words, static_triples, static_graph = 0, 0, [], None\n\n    # create stat\n    model = RecurrentRGCNREGCN(args.decoder,\n                          args.encoder,\n                        num_nodes,\n                        int(num_rels/2),\n                        num_static_rels, # DIFFERENT\n                        num_words, # DIFFERENT\n                        args.n_hidden,\n                        args.opn,\n                        
sequence_len=args.train_history_len,\n                        num_bases=args.n_bases,\n                        num_basis=args.n_basis,\n                        num_hidden_layers=args.n_layers,\n                        dropout=args.dropout,\n                        self_loop=args.self_loop,\n                        skip_connect=args.skip_connect,\n                        layer_norm=args.layer_norm,\n                        input_dropout=args.input_dropout,\n                        hidden_dropout=args.hidden_dropout,\n                        feat_dropout=args.feat_dropout,\n                        aggregation=args.aggregation, # DIFFERENT\n                        weight=args.weight, # DIFFERENT\n                        discount=args.discount, # DIFFERENT\n                        angle=args.angle, # DIFFERENT\n                        use_static=args.add_static_graph, # DIFFERENT\n                        entity_prediction=args.entity_prediction, \n                        relation_prediction=args.relation_prediction,\n                        use_cuda=use_cuda,\n                        gpu = args.gpu,\n                        analysis=args.run_analysis) # DIFFERENT\n\n    if use_cuda:\n        torch.cuda.set_device(args.gpu)\n        model.cuda()\n\n    # optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)\n\n    if args.test and os.path.exists(model_state_file):\n        mrr, perf_per_rel, hits10 = test(model, \n                    train_list+valid_list, \n                    test_list, \n                    num_rels, \n                    num_nodes, \n                    use_cuda, \n                    model_state_file, \n                    static_graph, \n                    \"test\", \n                    \"test\")\n        return mrr, perf_per_rel, hits10\n    elif args.test and not os.path.exists(model_state_file):\n        print(\"--------------{} not exist, Change mode to train and generate stat for 
testing----------------\\n\".format(model_state_file))\n        return 0, 0\n    else:\n        print(\"----------------------------------------start training----------------------------------------\\n\")\n        best_mrr = 0\n        best_hits = 0\n        for epoch in range(args.n_epochs):\n\n            model.train()\n            losses = []\n            idx = [_ for _ in range(len(train_list))]\n            random.shuffle(idx)\n\n            for train_sample_num in tqdm(idx):\n                if train_sample_num == 0: continue\n                output = train_list[train_sample_num:train_sample_num+1]\n                if train_sample_num - args.train_history_len<0:\n                    input_list = train_list[0: train_sample_num]\n                else:\n                    input_list = train_list[train_sample_num - args.train_history_len:\n                                        train_sample_num]\n\n                # generate history graph\n                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]\n                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]\n                loss_e, loss_r, loss_static = model.get_loss(history_glist, output[0], static_graph, use_cuda)\n                loss = args.task_weight*loss_e + (1-args.task_weight)*loss_r + loss_static\n\n                losses.append(loss.item())\n\n                loss.backward()\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients\n                optimizer.step()\n                optimizer.zero_grad()\n\n            print(\"Epoch {:04d} | Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} \"\n                  .format(epoch, np.mean(losses),  best_mrr, model_name))\n            \n            #! 
checking GPU usage\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            print (\"--------------GPU memory usage-----------\")\n            print (\"there are \", free_mem, \" free memory\")\n            print (\"there are \", total_mem, \" total available memory\")\n            print (\"there are \", total_mem - free_mem, \" used memory\")\n            print (\"--------------GPU memory usage-----------\")\n\n            # validation\n            if epoch and epoch % args.evaluate_every == 0:\n                mrr,perf_per_rel, hits10 = test(model, train_list, \n                            valid_list, \n                            num_rels, \n                            num_nodes, \n                            use_cuda, \n                            model_state_file, \n                            static_graph, \n                            mode=\"train\", split_mode='val')\n            \n                if mrr < best_mrr:\n                    if epoch >= args.n_epochs:\n                        break\n                else:\n                    best_mrr = mrr\n                    best_hits = hits10\n                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)\n\n        return best_mrr, perf_per_rel, hits10\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args_regcn()\nargs.dataset = 'tkgl-smallpedia'\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\n\nDATA=args.dataset\nMODEL_NAME = 'REGCN'\n\nprint(\"logging mrrs per relation: \", args.log_per_rel)\n\n# load data\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\n\nnum_rels = dataset.num_rels\nnum_nodes = dataset.num_nodes \nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nrelations = dataset.edge_type\n\ntimestamps_orig = 
dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\nall_quads = np.stack((subjects, relations, objects, timestamps), axis=1)\n\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nval_timestamps_orig = list(set(timestamps_orig[dataset.val_mask])) # needed for getting the negative samples\nval_timestamps_orig.sort()\ntest_timestamps_orig = list(set(timestamps_orig[dataset.test_mask])) # needed for getting the negative samples\ntest_timestamps_orig.sort()\n\ntrain_list = split_by_time(train_data)\nvalid_list = split_by_time(val_data)\ntest_list = split_by_time(test_data)\n\n# evaluation metric\nMETRIC = dataset.eval_metric\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n#load the ns samples \ndataset.load_val_ns()\ndataset.load_test_ns()\n\n## run training and testing\nval_mrr, test_mrr = 0, 0\ntest_hits10 = 0\nif args.grid_search:\n    print(\"hyperparameter grid search not implemented. 
Exiting.\")\n# single run\nelse:\n    start_train = timeit.default_timer()\n    if args.test == False: #if they are true: directly test on a previously trained and stored model\n        print('start training')\n        val_mrr, perf_per_rel, val_hits10 = run_experiment(args) # do training\n    start_test = timeit.default_timer()\n    args.test = True\n    print('start testing')\n    test_mrr, perf_per_rel, test_hits10 = run_experiment(args) # do testing\n\n\ntest_time = timeit.default_timer() - start_test\nall_time = timeit.default_timer() - start_train\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\nprint(f\"\\tTrain and Test: Elapsed Time (s): {all_time: .4f}\")\n\nprint(f\"\\tTest: {METRIC}: {test_mrr: .4f}\")\nprint(f\"\\tValid: {METRIC}: {val_mrr: .4f}\")\n\n# saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nsave_results({'model': MODEL_NAME,\n              'data': DATA,\n              'run': args.run_nr,\n              'seed': SEED,\n              f'val {METRIC}': float(val_mrr),\n              f'test {METRIC}': float(test_mrr),\n              'test_time': test_time,\n              'tot_train_val_time': all_time,\n              'test_hits10': float(test_hits10)\n              }, \n    results_filename)\n\nif args.log_per_rel:\n    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'\n    with open(results_filename, 'w') as json_file:\n        json.dump(perf_per_rel, json_file)\n\nsys.exit()"
  },
  {
    "path": "examples/linkproppred/tkgl-smallpedia/timetraveler.py",
    "content": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HUST/TITer\nHaohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021\n\"\"\"\nimport sys\nimport timeit\n\nimport torch\nfrom torch.utils.data import Dataset,DataLoader\nimport logging\n\nimport numpy as np\nimport pickle\nfrom tqdm import tqdm\nimport os.path as osp\nfrom pathlib import Path\nimport os\n\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.timetraveler_agent import Agent\nfrom modules.timetraveler_environment import Env\nfrom modules.timetraveler_dirichlet import Dirichlet, MLE_Dirchlet\nfrom modules.timetraveler_episode import Episode\nfrom modules.timetraveler_policygradient import PG\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.timetraveler_trainertester import Trainer, Tester, getRelEntCooccurrence\nfrom tgb.utils.utils import set_random_seed,save_results \nfrom modules.tkg_utils import  get_args_timetraveler, reformat_ts, get_model_config_timetraveler\n\nclass QuadruplesDataset(Dataset):\n    \"\"\" this is an internal way how Timetraveler represents the data\n    \"\"\"\n    def __init__(self, examples):\n        \"\"\"\n        examples: a list of quadruples.\n        num_r: number of relations\n        \"\"\"\n        self.quadruples = examples.copy()\n\n\n    def __len__(self):\n        return len(self.quadruples)\n\n    def __getitem__(self, item):\n        return self.quadruples[item][0], \\\n               self.quadruples[item][1], \\\n               self.quadruples[item][2], \\\n               self.quadruples[item][3], \\\n               self.quadruples[item][4]\n    \ndef set_logger(save_path):\n    \"\"\"Write logs to checkpoint and console\"\"\"\n    if 
args.do_train:\n        log_file = os.path.join(save_path, 'train.log')\n    else:\n        log_file = os.path.join(save_path, 'test.log')\n\n    logging.basicConfig(\n        format='%(asctime)s %(levelname)-8s %(message)s',\n        level=logging.INFO,\n        datefmt='%Y-%m-%d %H:%M:%S',\n        filename=log_file,\n        filemode='w'\n    )\n    console = logging.StreamHandler()\n    console.setLevel(logging.INFO)\n    formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')\n    console.setFormatter(formatter)\n    logging.getLogger('').addHandler(console)\n\ndef preprocess_data(args, config, timestamps, save_path, all_quads):\n    \"\"\"\n    Preprocess the data and save the state-action space (pickle dump)\n    \"\"\"\n    # parser = argparse.ArgumentParser(description='Data preprocess', usage='preprocess_data.py [<args>] [-h | --help]')\n    # parser.add_argument('--data_dir', default='data/ICEWS14', type=str, help='Path to data.')\n\n    env = Env(all_quads, config)\n    state_actions_space = {}\n\n    with tqdm(total=len(all_quads)) as bar:\n        for (head, rel, tail, t, _) in all_quads:\n            if (head, t, True) not in state_actions_space.keys():\n                state_actions_space[(head, t, True)] = env.get_state_actions_space_complete(head, t, True, args.store_actions_num)\n                state_actions_space[(head, t, False)] = env.get_state_actions_space_complete(head, t, False, args.store_actions_num)\n            if (tail, t, True) not in state_actions_space.keys():\n                state_actions_space[(tail, t, True)] = env.get_state_actions_space_complete(tail, t, True, args.store_actions_num)\n                state_actions_space[(tail, t, False)] = env.get_state_actions_space_complete(tail, t, False, args.store_actions_num)\n            bar.update(1)\n    pickle.dump(state_actions_space, open(os.path.join(save_path,  args.state_actions_path), 'wb'))\n\ndef log_metrics(mode, step, metrics):\n    \"\"\"Print the 
evaluation logs\"\"\"\n    for metric in metrics:\n        logging.info('%s %s at epoch %d: %f' % (mode, metric, step, metrics[metric]))\n\ndef main(args):\n    \"\"\"\n    Main function to train and test the TimeTraveler model\"\"\"\n\n    start_overall = timeit.default_timer()\n    #######################Set Logger#################################\n    \n    save_path = f'{os.path.dirname(os.path.abspath(__file__))}/saved_models/'\n    if not os.path.exists(save_path):\n        os.makedirs(save_path)\n\n    if args.cuda and torch.cuda.is_available():\n        args.cuda = True\n    else:\n        args.cuda = False\n    set_logger(save_path)\n\n    #######################Create DataLoader#################################\n    # set hyperparameters\n    args.dataset = 'tkgl-smallpedia'\n\n    SEED = args.seed  # set the random seed for consistency\n    set_random_seed(SEED)\n\n    DATA=args.dataset\n    MODEL_NAME = 'TIMETRAVELER'\n\n    # load data\n    dataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\n\n    num_rels = dataset.num_rels\n    num_nodes = dataset.num_nodes \n    subjects = dataset.full_data[\"sources\"]\n    objects= dataset.full_data[\"destinations\"]\n    relations = dataset.edge_type\n\n    timestamps_orig = dataset.full_data[\"timestamps\"]\n    timestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n    all_quads = np.stack((subjects, relations, objects, timestamps,timestamps_orig), axis=1)\n\n    train_data = all_quads[dataset.train_mask]\n    train_entities = np.unique(np.concatenate((train_data[:, 0], train_data[:, 2])))\n    RelEntCooccurrence = getRelEntCooccurrence(train_data, num_rels)\n    train_data =QuadruplesDataset(train_data)\n    val_data = QuadruplesDataset(all_quads[dataset.val_mask])\n    test_data = QuadruplesDataset(all_quads[dataset.test_mask])\n\n    METRIC = dataset.eval_metric\n    evaluator = Evaluator(name=DATA)\n    neg_sampler = dataset.negative_sampler\n    #load the ns samples \n    
dataset.load_val_ns()\n    dataset.load_test_ns()\n\n    train_dataloader = DataLoader(\n        train_data,\n        batch_size=args.batch_size,\n        shuffle=True,\n        num_workers=args.num_workers,\n    )\n\n    valid_dataloader = DataLoader(\n        val_data,\n        batch_size=args.test_batch_size,\n        shuffle=False,\n        num_workers=args.num_workers,\n    )\n\n    test_dataloader = DataLoader(\n        test_data,\n        batch_size=args.test_batch_size,\n        shuffle=False,\n        num_workers=args.num_workers,\n    )\n\n    ######################Creat the agent and the environment###########################\n    config = get_model_config_timetraveler(args, num_nodes, num_rels)\n    logging.info(config)\n    logging.info(args)\n\n    # creat the agent\n    agent = Agent(config)\n\n\n    # creat the environment\n    state_actions_path = os.path.join(save_path, args.state_actions_path)\n\n\n    ######################preprocessing###########################\n    if not os.path.exists(state_actions_path):\n        if args.preprocess:\n            print(\"preprocessing data...\")\n            preprocess_data(args, config, timestamps, save_path, list(all_quads))\n            state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))\n        else:\n            state_action_space = None\n    else:\n        print(\"load preprocessed data...\")\n        state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))\n\n\n    env = Env(list(all_quads), config, state_action_space)\n    # Create episode controller\n    episode = Episode(env, agent, config)\n    if args.cuda:\n        episode = episode.cuda()\n    pg = PG(config)  # Policy Gradient\n    optimizer = torch.optim.Adam(episode.parameters(), lr=args.lr, weight_decay=0.00001)\n\n    ######################Reward Shaping: MLE DIRICHLET alphas###########################\n    if args.reward_shaping: \n        try:\n            
print(\"load alphas from pickle file\")\n            alphas = pickle.load(open(os.path.join(save_path, args.alphas_pkl), 'rb'))\n        except:\n            print('running MLE dirichlet now')\n            mle_d = MLE_Dirchlet(all_quads, num_rels, args.k, args.time_span,\n                         args.tol, args.method, args.maxiter)\n            pickle.dump(mle_d.alphas, open(os.path.join(save_path, args.alphas_pkl), 'wb'))\n\n            print('dumped alphas')\n            alphas = mle_d.alphas\n        distributions = Dirichlet(alphas, args.k)\n    else:\n        distributions = None\n\n    ######################Training and Testing###########################\n\n    trainer = Trainer(episode, pg, optimizer, args, distributions)\n    tester = Tester(episode, args, train_entities, RelEntCooccurrence, dataset.metric)\n    test_metrics ={}\n    val_metrics = {}\n    test_metrics[METRIC] = None\n    val_metrics[METRIC] = None\n\n    if args.do_train:\n        start_train =timeit.default_timer()\n        logging.info('Start Training......')\n        for i in range(args.max_epochs):\n            loss, reward = trainer.train_epoch(train_dataloader, len(train_data))\n            logging.info('Epoch {}/{} Loss: {}, reward: {}'.format(i, args.max_epochs, loss, reward))\n\n            #! 
checking GPU usage\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            print (\"--------------GPU memory usage-----------\")\n            print (\"there are \", free_mem, \" free memory\")\n            print (\"there are \", total_mem, \" total available memory\")\n            print (\"there are \", total_mem - free_mem, \" used memory\")\n            print (\"--------------GPU memory usage-----------\")\n            \n            if i % args.save_epoch == 0 and i != 0:\n                trainer.save_model(save_path, 'checkpoint_{}.pth'.format(i))\n                logging.info('Save Model in {}'.format(save_path))\n\n            if i % args.valid_epoch == 0 and i != 0:\n                logging.info('Start Val......')\n                val_metrics = tester.test(valid_dataloader,\n                                      len(val_data), num_nodes, neg_sampler, evaluator, split_mode='val')\n                for mode in val_metrics.keys():\n                    logging.info('{} at epoch {}: {}'.format(mode, i, val_metrics[mode]))\n\n        trainer.save_model(save_path)\n        logging.info('Save Model in {}'.format(save_path))\n    else:\n          # # Load the model parameters\n        if os.path.isfile(save_path):\n            params = torch.load(save_path)\n            episode.load_state_dict(params['model_state_dict'])\n            optimizer.load_state_dict(params['optimizer_state_dict'])\n            logging.info('Load pretrain model: {}'.format(save_path))\n    if args.do_test:\n        logging.info('Start Testing......')\n        start_test = timeit.default_timer()\n        test_metrics = tester.test(test_dataloader,\n                              len(test_data), num_nodes, neg_sampler, evaluator, split_mode='test')\n        for mode in test_metrics.keys():\n            logging.info('Test {} : {}'.format(mode, test_metrics[mode]))\n\n        # saving the results...\n        results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\n    
    if not osp.exists(results_path):\n            os.mkdir(results_path)\n            print('INFO: Create directory {}'.format(results_path))\n        Path(results_path).mkdir(parents=True, exist_ok=True)\n        results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n        test_time = timeit.default_timer() - start_test\n        all_time = timeit.default_timer() - start_train \n        all_time_preprocess = timeit.default_timer() - start_overall \n\n        save_results({'model': MODEL_NAME,\n                    'data': DATA,\n                    'seed': SEED,\n                    f'val {METRIC}': float(val_metrics[METRIC]),\n                    f'test {METRIC}': float(test_metrics[METRIC]),\n                    'test_time': test_time,\n                    'tot_train_val_time': all_time,\n                    'tot_preprocess_train_val_time': all_time_preprocess\n                    }, \n            results_filename)     \n\nif __name__ == '__main__':\n    args = get_args_timetraveler()\n    main(args)"
  },
  {
    "path": "examples/linkproppred/tkgl-smallpedia/tkgl-smallpedia_example.py",
    "content": "import numpy as np\r\nimport timeit\r\nfrom tqdm import tqdm\r\nimport sys\r\nimport os.path as osp\r\nimport os\r\nfrom pathlib import Path\r\n# internal imports\r\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\r\nsys.path.append(modules_path)\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom tgb.linkproppred.evaluate import Evaluator\r\n\r\nDATA = \"tkgl-smallpedia\"\r\n\r\n# data loading\r\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\r\ntrain_mask = dataset.train_mask\r\nval_mask = dataset.val_mask\r\ntest_mask = dataset.test_mask\r\ndata = dataset.get_TemporalData()\r\nmetric = dataset.eval_metric\r\n\r\nprint (\"there are {} nodes and {} edges\".format(dataset.num_nodes, dataset.num_edges))\r\nprint (\"there are {} relation types\".format(dataset.num_rels))\r\n\r\n#! must run in this order\r\nstatic_data = dataset.static_data\r\nstatic_head = static_data[\"head\"]\r\nstatic_tail = static_data[\"tail\"]\r\nstatic_edge_type = static_data[\"edge_type\"]\r\nprint ('static edges processed')\r\nprint (\"static data has \", static_head.shape[0], \" edges\")\r\n\r\n\r\ntimestamp = data.t\r\nhead = data.src\r\ntail = data.dst\r\nedge_type = data.edge_type #relation\r\n\r\n\r\n\r\nneg_sampler = dataset.negative_sampler\r\n\r\ntrain_data = data[train_mask]\r\nval_data = data[val_mask]\r\ntest_data = data[test_mask]\r\n\r\n\r\nmetric = dataset.eval_metric\r\nevaluator = Evaluator(name=DATA)\r\nneg_sampler = dataset.negative_sampler\r\n\r\nBATCH_SIZE = 200\r\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\r\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\r\n\r\nstart_time = timeit.default_timer()\r\n#load the ns samples first\r\ndataset.load_val_ns()\r\nfor batch in tqdm(val_loader):\r\n    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type\r\n    
neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='val')\r\nprint (\"loading ns samples from validation\", timeit.default_timer() - start_time)\r\n\r\nstart_time = timeit.default_timer()\r\ndataset.load_test_ns()\r\nfor batch in test_loader:\r\n    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type\r\n    neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='test')\r\nprint (\"loading ns samples from test\", timeit.default_timer() - start_time)\r\nprint (\"retrieved all negative samples\")\r\n\r\n\r\n# #* load numpy arrays instead\r\n# from tgb.linkproppred.dataset import LinkPropPredDataset\r\n\r\n# # data loading\r\n# dataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\r\n# data = dataset.full_data  \r\n# metric = dataset.eval_metric\r\n# sources = dataset.full_data['sources']\r\n# print (sources.dtype)\r\n\r\n"
  },
  {
    "path": "examples/linkproppred/tkgl-smallpedia/tlogic.py",
    "content": "\"\"\"\nhttps://github.com/liu-yushan/TLogic/tree/main/mycode\nTLogic: Temporal Logical Rules for Explainable Link Forecasting on Temporal Knowledge Graphs.\nYushan Liu, Yunpu Ma, Marcel Hildebrandt, Mitchell Joblin, Volker Tresp\n\"\"\"\n\n# imports\nimport sys\nimport os\nimport os.path as osp\nfrom pathlib import Path\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nimport timeit\nimport argparse\nimport numpy as np\nimport json\nfrom joblib import Parallel, delayed\nimport itertools\n\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom modules.tlogic_learn_modules import Temporal_Walk, Rule_Learner, store_edges\nimport modules.tlogic_apply_modules as ra\nfrom tgb.utils.utils import set_random_seed,  save_results\nfrom modules.tkg_utils import reformat_ts, get_inv_relation_id, create_scores_array\n\ndef learn_rules(i, num_relations):\n    \"\"\"\n    Learn rules (multiprocessing possible).\n\n    Parameters:\n        i (int): process number\n        num_relations (int): minimum number of relations for each process\n\n    Returns:\n        rl.rules_dict (dict): rules dictionary\n    \"\"\"\n\n    # if seed:\n    #     np.random.seed(seed)\n\n    num_rest_relations = len(all_relations) - (i + 1) * num_relations\n    if num_rest_relations >= num_relations:\n        relations_idx = range(i * num_relations, (i + 1) * num_relations)\n    else:\n        relations_idx = range(i * num_relations, len(all_relations))\n\n    num_rules = [0]\n    for k in relations_idx:\n        rel = all_relations[k]\n        for length in rule_lengths:\n            it_start =  timeit.default_timer()\n            for _ in range(num_walks):\n                walk_successful, walk = temporal_walk.sample_walk(length + 1, rel)\n                if walk_successful:\n                    rl.create_rule(walk)\n            it_end =  
timeit.default_timer()\n            it_time = round(it_end - it_start, 6)\n            num_rules.append(sum([len(v) for k, v in rl.rules_dict.items()]) // 2)\n            num_new_rules = num_rules[-1] - num_rules[-2]\n            print(\n                \"Process {0}: relation {1}/{2}, length {3}: {4} sec, {5} rules\".format(\n                    i,\n                    k - relations_idx[0] + 1,\n                    len(relations_idx),\n                    length,\n                    it_time,\n                    num_new_rules,\n                )\n            )\n\n    return rl.rules_dict\n\ndef apply_rules(i, num_queries, rules_dict, neg_sampler, data, window, learn_edges, all_quads, args, split_mode, \n                log_per_rel=False, num_rels=0):\n    \"\"\"\n    Apply rules (multiprocessing possible).\n\n    Parameters:\n        i (int): process number\n        num_queries (int): minimum number of queries for each process\n\n    Returns:\n        hits_list (list): hits list (hits@10 per sample)\n        perf_list (list): performance list (mrr per sample)\n    \"\"\"\n    perf_per_rel = {}\n    for rel in range(num_rels):\n            perf_per_rel[rel] = []\n    print(\"Start process\", i, \"...\")\n    all_candidates = [dict() for _ in range(len(args))]\n    no_cands_counter = 0\n\n    num_rest_queries = len(data) - (i + 1) * num_queries\n    if num_rest_queries >= num_queries:\n        test_queries_idx = range(i * num_queries, (i + 1) * num_queries)\n    else:\n        test_queries_idx = range(i * num_queries, len(data))\n\n    cur_ts = data[test_queries_idx[0]][3]\n    edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)\n\n    it_start =  timeit.default_timer()\n    hits_list = [0] * len(test_queries_idx)\n    perf_list = [0] * len(test_queries_idx)\n    for index, j in enumerate(test_queries_idx):\n        neg_sample_el =  neg_sampler.query_batch(np.expand_dims(np.array(data[j,0]), axis=0), \n                                         
       np.expand_dims(np.array(data[j,2]), axis=0), \n                                                np.expand_dims(np.array(data[j,4]), axis=0), \n                                                np.expand_dims(np.array(data[j,1]), axis=0), \n                                                split_mode=split_mode)[0]        \n        \n        # neg_samples_batch[j]\n        pos_sample_el =  data[j,2]\n        test_query = data[j]\n        assert pos_sample_el == test_query[2]\n        cands_dict = [dict() for _ in range(len(args))]\n\n        if test_query[3] != cur_ts:\n            cur_ts = test_query[3]\n            edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)\n\n        if test_query[1] in rules_dict:\n            dicts_idx = list(range(len(args)))\n            for rule in rules_dict[test_query[1]]:\n                walk_edges = ra.match_body_relations(rule, edges, test_query[0])\n\n                if 0 not in [len(x) for x in walk_edges]:\n                    rule_walks = ra.get_walks(rule, walk_edges)\n                    if rule[\"var_constraints\"]:\n                        rule_walks = ra.check_var_constraints(\n                            rule[\"var_constraints\"], rule_walks\n                        )\n\n                    if not rule_walks.empty:\n                        cands_dict = ra.get_candidates(\n                            rule,\n                            rule_walks,\n                            cur_ts,\n                            cands_dict,\n                            score_func,\n                            args,\n                            dicts_idx,\n                        )\n                        for s in dicts_idx:\n                            cands_dict[s] = {\n                                x: sorted(cands_dict[s][x], reverse=True)\n                                for x in cands_dict[s].keys()\n                            }\n                            cands_dict[s] = dict(\n                        
        sorted(\n                                    cands_dict[s].items(),\n                                    key=lambda item: item[1],\n                                    reverse=True,\n                                )\n                            )\n                            top_k_scores = [v for _, v in cands_dict[s].items()][:top_k]\n                            unique_scores = list(\n                                scores for scores, _ in itertools.groupby(top_k_scores)\n                            )\n                            if len(unique_scores) >= top_k:\n                                dicts_idx.remove(s)\n                        if not dicts_idx:\n                            break\n\n            if cands_dict[0]:\n                for s in range(len(args)):\n                    # Calculate noisy-or scores\n                    scores = list(\n                        map(\n                            lambda x: 1 - np.product(1 - np.array(x)),\n                            cands_dict[s].values(),\n                        )\n                    )\n                    cands_scores = dict(zip(cands_dict[s].keys(), scores))\n                    noisy_or_cands = dict(\n                        sorted(cands_scores.items(), key=lambda x: x[1], reverse=True)\n                    )\n                    all_candidates[s][j] = noisy_or_cands\n            else:  # No candidates found by applying rules\n                no_cands_counter += 1\n                for s in range(len(args)):\n                    all_candidates[s][j] = dict()\n\n        else:  # No rules exist for this relation\n            no_cands_counter += 1\n            for s in range(len(args)):\n                all_candidates[s][j] = dict()\n\n        if not (j - test_queries_idx[0] + 1) % 100:\n            it_end =  timeit.default_timer()\n            it_time = round(it_end - it_start, 6)\n            print(\n                \"Process {0}: test samples finished: {1}/{2}, {3} sec\".format(\n          
          i, j - test_queries_idx[0] + 1, len(test_queries_idx), it_time\n                )\n            )\n            it_start =  timeit.default_timer()\n\n        predictions = create_scores_array(all_candidates[s][j], num_nodes)  \n        predictions_of_interest_pos = np.array(predictions[pos_sample_el])\n        predictions_of_interest_neg = predictions[neg_sample_el]\n        input_dict = {\n            \"y_pred_pos\": predictions_of_interest_pos,\n            \"y_pred_neg\": predictions_of_interest_neg,\n            \"eval_metric\": ['mrr'], \n        }\n\n        predictions = evaluator.eval(input_dict)\n        perf_list[index] = predictions['mrr']\n        hits_list[index] = predictions['hits@10']\n        if split_mode == \"test\":\n            if log_per_rel:\n                perf_per_rel[test_query[1]].append(perf_list[index]) #test_query[1] is the relation index\n\n    if split_mode == \"test\":\n        if log_per_rel:   \n            for rel in range(num_rels):\n                if len(perf_per_rel[rel]) > 0:\n                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))\n                else:\n                    perf_per_rel.pop(rel)       \n               \n\n    return perf_list, hits_list, perf_per_rel\n\n\n## args\ndef get_args(): \n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", \"-d\", default=\"tkgl-smallpedia\", type=str) \n    parser.add_argument(\"--rule_lengths\", \"-l\", default=\"1\", type=int, nargs=\"+\")\n    parser.add_argument(\"--num_walks\", \"-n\", default=\"100\", type=int)\n    parser.add_argument(\"--transition_distr\", default=\"exp\", type=str)\n    parser.add_argument(\"--window\", \"-w\", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. 
set to -2 if multistep\n    parser.add_argument(\"--top_k\", default=20, type=int)\n    parser.add_argument(\"--num_processes\", \"-p\", default=1, type=int)\n    parser.add_argument(\"--alpha\", \"-alpha\",  default=0.99, type=float) # fix alpha. used if trainflag == false\n    # parser.add_argument(\"--train_flag\", \"-tr\",  default=True) # do we need training, ie selection of lambda and alpha\n    parser.add_argument(\"--save_config\", \"-c\",  default=True) # do we need to save the selection of lambda and alpha in config file?\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--run_nr', type=int, help='Run Number', default=1)\n    parser.add_argument('--learn_rules_flag', type=bool, help='Do we want to learn the rules', default=True)\n    parser.add_argument('--rule_filename', type=str, help='if rules not learned: where are they stored', default='0_r[3]_n100_exp_s1_rules.json')\n    parser.add_argument('--log_per_rel', type=bool, help='Do we want to log mrr per relation', default=False)\n    parser.add_argument('--compute_valid_mrr', type=bool, help='Do we want to compute mrr for valid set', default=True)\n    parsed = vars(parser.parse_args())\n    return parsed\n\nstart_o =  timeit.default_timer()\n\n## get args\nparsed = get_args()\ndataset = parsed[\"dataset\"]\nrule_lengths = parsed[\"rule_lengths\"]\nrule_lengths = [rule_lengths] if (type(rule_lengths) == int) else rule_lengths\nprint('rule_lengths', rule_lengths)\nnum_walks = parsed[\"num_walks\"]\ntransition_distr = parsed[\"transition_distr\"]\nnum_processes = parsed[\"num_processes\"]\nwindow = parsed[\"window\"]\ntop_k = parsed[\"top_k\"]\nlog_per_rel = parsed['log_per_rel']\n\nMODEL_NAME = 'TLogic'\nSEED = parsed['seed']  # set the random seed for consistency\nset_random_seed(SEED)\n\n## load dataset and prepare it accordingly\nname = parsed[\"dataset\"]\ncompute_valid_mrr = parsed[\"compute_valid_mrr\"]\ndataset = LinkPropPredDataset(name=name, 
root=\"datasets\", preprocess=True)\nDATA = name\n\nrelations = dataset.edge_type\nnum_rels = dataset.num_rels\n\nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nnum_nodes = dataset.num_nodes \ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n\nall_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)\ntrain_data = all_quads[dataset.train_mask,0:4] # we do not need the original timestamps\nval_data = all_quads[dataset.val_mask,0:5]\ntest_data = all_quads[dataset.test_mask,0:5]\nall_data = all_quads[:,0:4]\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=name)\nneg_sampler = dataset.negative_sampler\n\ninv_relation_id = get_inv_relation_id(num_rels)\n\n#load the ns samples \n\ndataset.load_val_ns()\ndataset.load_test_ns()\noutput_dir =  f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\nlearn_rules_flag = parsed['learn_rules_flag']\n## 1. 
learn rules\nstart_train =  timeit.default_timer()\nif learn_rules_flag:\n    print(\"start learning rules\")\n    # edges (dict): edges for each relation\n    # inv_relation_id (dict): mapping of relation to inverse relation\n    \n    temporal_walk = Temporal_Walk(train_data, inv_relation_id, transition_distr)\n    rl = Rule_Learner(edges=temporal_walk.edges, id2relation=None, inv_relation_id=inv_relation_id,  \n                        output_dir=output_dir)\n    all_relations = sorted(temporal_walk.edges)  # Learn for all relations\n\n    start =  timeit.default_timer()\n    num_relations = len(all_relations) // num_processes\n    output = Parallel(n_jobs=num_processes)(\n        delayed(learn_rules)(i, num_relations) for i in range(num_processes)\n    )\n    end =  timeit.default_timer()\n\n    all_rules = output[0]\n    for i in range(1, num_processes):\n        all_rules.update(output[i])\n\n    total_time = round(end - start, 6)\n    print(\"Learning finished in {} seconds.\".format(total_time))\n\n    rl.rules_dict = all_rules\n    rl.sort_rules_dict()\n\n    rule_filename = rl.save_rules(0, rule_lengths, num_walks, transition_distr, SEED)\n    # rl.save_rules_verbalized(0, rule_lengths, num_walks, transition_distr, seed)\n    # rules_statistics(rl.rules_dict)\nelse:\n    rule_filename = parsed['rule_filename']\n    print(\"Loading rules from file {}\".format(parsed['rule_filename']))\n\nend_train =  timeit.default_timer()\n\n## 2. 
Apply rules\n\nrules_dict = json.load(open(output_dir + rule_filename))\nrules_dict = {int(k): v for k, v in rules_dict.items()}\n\nrules_dict = ra.filter_rules(\n    rules_dict, min_conf=0.01, min_body_supp=2, rule_lengths=rule_lengths\n) # filter rules for minimum confidence, body support and rule length\n\nlearn_edges = store_edges(train_data)\nscore_func = ra.score_12\n# It is possible to specify a list of list of arguments for tuning\nargs = [[0.1, 0.5]]\n\n# compute valid mrr\nstart_valid =  timeit.default_timer()\nif compute_valid_mrr:\n    print('Computing valid MRR')\n\n    num_queries = len(val_data) // num_processes\n\n    output = Parallel(n_jobs=num_processes)(\n        delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, val_data, window, learn_edges, \n                            all_quads, args, split_mode='val') for i in range(num_processes))\n    end =  timeit.default_timer()\n\n    perf_list_val = []\n    hits_list_val = []\n\n    for i in range(num_processes):\n        perf_list_val.extend(output[i][0])\n        hits_list_val.extend(output[i][1])\nelse:\n    perf_list_val = [0]\n    hits_list_val = [0]\n    \n\nend_valid =  timeit.default_timer()\n\n# compute test mrr\nif log_per_rel ==True:\n    num_processes = 1 #otherwise logging per rel does not work for our implementation\nstart_test =  timeit.default_timer()\nprint('Computing test MRR')\nstart =  timeit.default_timer()\nnum_queries = len(test_data) // num_processes\n\noutput = Parallel(n_jobs=num_processes)(\n    delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, test_data, window, learn_edges, \n                         all_quads, args, split_mode='test', log_per_rel=log_per_rel, num_rels=num_rels) for i in range(num_processes))\nend =  timeit.default_timer()\n\nperf_list_all = []\nhits_list_all = []\n\n\nfor i in range(num_processes):\n    perf_list_all.extend(output[i][0])\n    hits_list_all.extend(output[i][1])\nif log_per_rel == True:\n    perf_per_rel = 
output[0][2]\n\n\ntotal_time = round(end - start, 6)\ntotal_valid_time = round(end_valid - start_valid, 6)\nprint(\"Application finished in {} seconds.\".format(total_time))\n\nprint(f\"The valid MRR is {np.mean(perf_list_val)}\")\nprint(f\"The MRR is {np.mean(perf_list_all)}\")\nprint(f\"The Hits@10 is {np.mean(hits_list_all)}\")\nprint(f\"We have {len(perf_list_all)} predictions\")\nprint(f\"The test set has len {len(test_data)} \")\n\nend_o =  timeit.default_timer()\ntrain_time_o = round(end_train- start_train, 6)  \ntest_time_o = round(end_o- start_test, 6)  \ntotal_time_o = round(end_o- start_o, 6)  \nprint(\"Running Training to find best configs finished in {} seconds.\".format(train_time_o))\nprint(\"Running testing with best configs finished in {} seconds.\".format(test_time_o))\nprint(\"Running all steps finished in {} seconds.\".format(total_time_o))\n\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\n\nif log_per_rel == True:\n    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'\n    with open(results_filename, 'w') as json_file:\n        json.dump(perf_per_rel, json_file)\n\nresults_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'\nmetric = dataset.eval_metric\nsave_results({'model': MODEL_NAME,\n              'train_flag': None,\n              'rule_len': rule_lengths,\n              'window': window,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: float(np.mean(perf_list_all)),\n              'hits10': float(np.mean(hits_list_all)),\n              'val_mrr': float(np.mean(perf_list_val)),\n              'test_time': test_time_o,\n              'tot_train_val_time': total_time_o,\n              'valid_time': total_valid_time\n              }, \n    
results_filename)\n"
  },
  {
    "path": "examples/linkproppred/tkgl-wikidata/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\ntgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(tgb_modules_path)\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    hits_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t, pos_edge = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n            data['edge_type'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)\n        \n        for idx, neg_batch in enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            results = evaluator.eval(input_dict)\n            perf_list.append(results[metric])\n            hits_list.append(results['hits@10'])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, 
pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n    perf_hits = float(np.mean(hits_list))\n\n    return perf_metrics, perf_hits\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tkgl-wikidata')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = args.data\nMODEL_NAME = 'EdgeBank'\n\n\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n      
  hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_val: .4f}\")\nval_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {val_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>>  <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: 
perf_metric_test,\n              'val_mrr': perf_metric_val,\n              'test_time': test_time,\n              'tot_train_val_time': test_time+val_time,\n              'hits10': perf_hits_test}, \n    results_filename)\n\n"
  },
  {
    "path": "examples/linkproppred/tkgl-wikidata/recurrencybaseline.py",
    "content": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)\n\n@inproceedings{gastinger2024baselines,\n  title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},\n  author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},\n  booktitle={33nd International Joint Conference on Artificial Intelligence (IJCAI 2024)},\n  year={2024},\n  organization={International Joint Conferences on Artificial Intelligence Organization}\n}\n\n\"\"\"\n\n## imports\nimport timeit\nimport argparse\nimport numpy as np\nfrom copy import copy\nfrom pathlib import Path\nimport ray\nimport sys\nimport os\nimport os.path as osp\nimport json\n#internal imports \nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom tgb.utils.utils import set_random_seed,  save_results \nfrom modules.tkg_utils import create_basis_dict, group_by, reformat_ts\n\ndef predict(num_processes,  data_c_rel, all_data_c_rel, alpha, lmbda_psi,\n            perf_list_all, hits_list_all, window, neg_sampler, split_mode):\n    \"\"\" create predictions for each relation on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list of hits for each test query\n    \"\"\"\n\n    first_ts = data_c_rel[0][3]\n    ## use this if you wanna use ray:\n    num_queries = len(data_c_rel) // num_processes\n    if num_queries < num_processes: # if we do not have enough queries for all the processes\n        num_processes_tmp = 
1\n        num_queries = len(data_c_rel)\n    else:\n        num_processes_tmp = num_processes  \n    if num_processes > 1:\n        object_references =[]                   \n        \n        for i in range(num_processes_tmp):\n            num_test_queries = len(data_c_rel) - (i + 1) * num_queries\n            if num_test_queries >= num_queries:\n                test_queries_idx =[i * num_queries, (i + 1) * num_queries]\n            else:\n                test_queries_idx = [i * num_queries, len(test_data)]\n\n            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]\n\n            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window, \n                                basis_dict, \n                                num_nodes, num_rels, lmbda_psi, \n                                alpha, evaluator,first_ts, neg_sampler, split_mode)\n            object_references.append(ob)\n\n        output = ray.get(object_references)\n\n        # updates the scores and logging dict for each process\n        for proc_loop in range(num_processes_tmp):\n            perf_list_all.extend(output[proc_loop][0])\n            hits_list_all.extend(output[proc_loop][1])\n\n    else:\n        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel, \n                            window, basis_dict, \n                            num_nodes, num_rels, lmbda_psi, \n                            alpha, evaluator, first_ts, neg_sampler, split_mode)                  \n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n    \n    return perf_list_all, hits_list_all\n\n## test\ndef test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):  \n    \"\"\" create predictions by loopoing through all relations on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list 
of hits for each test query\n    \"\"\"       \n    perf_list_all = []\n    hits_list_all =[]\n    \n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n\n    ## loop through relations and apply baselines\n    \n    for rel in all_relations:\n        start =  timeit.default_timer()\n        if rel in test_data_prel.keys():\n            lmbda_psi = best_config[str(rel)]['lmbda_psi'][0]\n            alpha = best_config[str(rel)]['alpha'][0]\n\n            # test data for this relation\n            test_data_c_rel = test_data_prel[rel]\n            timesteps_test = list(set(test_data_c_rel[:,3]))\n            timesteps_test.sort()\n            all_data_c_rel = all_data_prel[rel]         \n            perf_list_rel = []\n            hits_list_rel = []\n            perf_list_rel, hits_list_rel = predict(num_processes, test_data_c_rel,\n                                                all_data_c_rel, alpha, lmbda_psi,perf_list_rel, hits_list_rel, \n                                                window, neg_sampler, split_mode)\n            perf_list_all.extend(perf_list_rel)\n            hits_list_all.extend(hits_list_rel)\n        else:\n            perf_list_rel =[]\n         \n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n        \n\n        with open(csv_file, 'a') as f:\n            f.write(\"{},{}\\n\".format(rel, perf_list_rel))\n\n    return perf_list_all, hits_list_all\n\ndef read_dict_compute_mrr(split_mode='test'):\n    \"\"\" read the results per relation  from a precreated file and compute mrr\n    :return  mrr_per_rel: dictionary of mrrs for each relation\n    :return all_mrrs: list of mrrs for all relations\n    \"\"\"\n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n    # Initialize an empty dictionary to store the data\n    
results_per_rel_dict = {}\n    mrr_per_rel = {}\n    all_mrrs = []\n    # Open the file for reading\n    with open(csv_file, 'r') as f:\n        # Read each line in the file\n        for line in f:\n            # Split the line at the comma\n            parts = line.strip().split(',')\n            # Extract the key (the first part)\n            key = int(parts[0])\n            # Extract the values (the rest of the parts), remove square brackets\n            values = [float(value.strip('[]')) for value in parts[1:]]\n            # Add the key-value pair to the dictionary\n            if key in results_per_rel_dict.keys():\n                print(f\"Key {key} already exists in the dictionary!!! might have duplicate entries in results csv\")\n            results_per_rel_dict[key] = values\n            all_mrrs.extend(values)\n            mrr_per_rel[key] = np.mean(values)\n\n    if len(list(results_per_rel_dict.keys())) != num_rels:\n        print(\"we do not have entries for each rel in the results csv file. only num enties: \", len(list(results_per_rel_dict.keys())))\n\n    print(\"Split mode: \"+split_mode +\" Mean MRR: \", np.mean(all_mrrs))\n    print(\"mrr per relation: \", mrr_per_rel)\n\n\n\n## train\ndef train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):\n    \"\"\" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params\n    based on validation mrr\n    :return best_config: dictionary of best params for each relation\n    \"\"\"\n    best_config= {}\n    best_mrr = 0\n    for rel in rels: # loop through relations. 
for each relation, apply rules with selected params, compute valid mrr\n        start =  timeit.default_timer()\n        rel_key = int(rel)            \n\n        best_config[str(rel_key)] = {}\n        best_config[str(rel_key)]['not_trained'] = 'True'    \n        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi,0] #default\n        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))\n        best_config[str(rel_key)]['alpha'] = [default_alpha,0]  #default    \n        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))\n        \n        if rel in val_data_prel.keys():      \n            # valid data for this relation  \n            val_data_c_rel = val_data_prel[rel]\n            timesteps_valid = list(set(val_data_c_rel[:,3]))\n            timesteps_valid.sort()\n            trainval_data_c_rel = trainval_data_prel[rel]\n\n            ######  1) select lambda ###############        \n            lmbdas_psi = params_dict['lmbda_psi']        \n\n            alpha = 1\n            best_lmbda_psi = 0.1\n            best_mrr_psi = 0\n            lmbda_mrrs = []\n\n            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))\n            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))         \n            best_config[str(rel_key)]['not_trained'] = 'False'       \n            \n            for lmbda_psi in lmbdas_psi:   \n                perf_list_r = []\n                hits_list_r = []\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr = np.mean(perf_list_r)\n                # # is new mrr better than previous best? 
if yes: store lmbda\n                if mrr > best_mrr_psi:\n                    best_mrr_psi = float(mrr)\n                    best_lmbda_psi = lmbda_psi\n\n\n                lmbda_mrrs.append(float(mrr))\n            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]\n            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs\n            best_mrr = best_mrr_psi\n            ##### 2) select alpha ###############\n            best_config[str(rel_key)]['not_trained'] = 'False'    \n            alphas = params_dict['alpha'] \n            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0] # use the best lmbda psi\n\n            alpha_mrrs = []\n            # perf_list_all = []\n            best_mrr_alpha = 0\n            best_alpha=0.99\n            for alpha in alphas:\n                perf_list_r = []\n                hits_list_r = []\n\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr_alpha = np.mean(perf_list_r)\n\n                # is new mrr better than previous best? 
if yes: store alpha\n                if mrr_alpha > best_mrr_alpha:\n                    best_mrr_alpha = float(mrr_alpha)\n                    best_alpha = alpha\n                    best_mrr = best_mrr_alpha\n                alpha_mrrs.append(float(mrr_alpha))\n\n            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]\n            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs\n\n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n    return best_config\n\n\n\n## args\ndef get_args(): \n    \"\"\"parse all arguments for the script\"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", \"-d\", default=\"tkgl-wikidata\", type=str) \n    parser.add_argument(\"--window\", \"-w\", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep\n    parser.add_argument(\"--num_processes\", \"-p\", default=1, type=int)\n    parser.add_argument(\"--lmbda\", \"-l\",  default=0.1, type=float) # fix lambda. used if trainflag == false\n    parser.add_argument(\"--alpha\", \"-alpha\",  default=0.99, type=float) # fix alpha. 
used if trainflag == false\n    parser.add_argument(\"--train_flag\", \"-tr\",  default='False') # do we need training, ie selection of lambda and alpha\n    parser.add_argument(\"--load_flag\", \"-lo\",  default='False') # if train_flag set to True: do you want to load best_config?\n    parser.add_argument(\"--save_config\", \"-c\",  default='True') # do we need to save the selection of lambda and alpha in config file?\n    parser.add_argument('--seed', type=int, help='Random seed', default=1) # not needed\n    parsed = vars(parser.parse_args())\n    return parsed\n\nstart_o =  timeit.default_timer()\n\nparsed = get_args()\nif parsed['num_processes']>1:\n    ray.init(num_cpus=parsed[\"num_processes\"], num_gpus=0)\nMODEL_NAME = 'RecurrencyBaseline'\nSEED = parsed['seed']  # set the random seed for consistency\nset_random_seed(SEED)\nperrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'\nif not osp.exists(perrel_results_path):\n    os.mkdir(perrel_results_path)\n    print('INFO: Create directory {}'.format(perrel_results_path))\nPath(perrel_results_path).mkdir(parents=True, exist_ok=True)\n\n## load dataset and prepare it accordingly\nname = parsed[\"dataset\"]\ndataset = LinkPropPredDataset(name=name, root=\"datasets\", preprocess=True)\nDATA = name\n\nrelations = dataset.edge_type\nnum_rels = dataset.num_rels\nrels = np.arange(0,num_rels)\nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nnum_nodes = dataset.num_nodes \ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n\nprint(\"split train valid test data\")\nall_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=name)\nneg_sampler = 
dataset.negative_sampler\n\ntrain_val_data = np.concatenate([train_data, val_data])\nall_data = np.concatenate([train_data, val_data, test_data])\n\n# create dicts with key: relation id, values: triples for that relation id\nprint(\"grouping data by relation\")\ntest_data_prel = group_by(test_data, 1)\nall_data_prel = group_by(all_data, 1)\nval_data_prel = group_by(val_data, 1)\ntrainval_data_prel = group_by(train_val_data, 1)\n\n#load the ns samples \n# if parsed['train_flag']:\nprint(\"loading negative samples\")\ndataset.load_val_ns()\ndataset.load_test_ns()\n\n# parameter options\nif parsed['train_flag'] == 'True':\n    params_dict = {}\n    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001] \n    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]\n    default_lmbda_psi = params_dict['lmbda_psi'][-1]\n    default_alpha = params_dict['alpha'][-2]\n\n## load rules\nprint(\"creating rules\")\nbasis_dict = create_basis_dict(train_val_data)\nprint(\"done with creating rules\")\n## init\n# rb_predictor = RecurrencyBaselinePredictor(rels)\n## train to find best lambda and alpha\nstart_train =  timeit.default_timer()\nif parsed['train_flag'] ==  'True':\n    if parsed['load_flag'] == 'True':\n        with open('best_config.json', 'r') as infile:\n            best_config = json.load(infile)\n    else:\n        print('start training')\n        best_config = train(params_dict,  rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'], \n            parsed['window'])\n        if parsed['save_config'] == 'True':\n            import json\n            with open('best_config.json', 'w') as outfile:\n                json.dump(best_config, outfile)\n\nelse: # use preset lmbda and alpha; same for all relations\n    best_config = {} \n    for rel in rels:\n        best_config[str(rel)] = {}\n        best_config[str(rel)]['lmbda_psi'] = 
[parsed['lmbda']]\n        best_config[str(rel)]['alpha'] = [parsed['alpha']]\n    \nend_train =  timeit.default_timer()\n\n# compute validation mrr\nprint(\"Computing validation MRR\")\nperf_list_all_val, hits_list_all_val = test(best_config,rels, val_data_prel, \n                                                trainval_data_prel, neg_sampler, parsed['num_processes'], \n                                            parsed['window'], split_mode='val')\nval_mrr = float(np.mean(perf_list_all_val))\n\n# compute test mrr\nprint(\"Computing test MRR\")\nstart_test =  timeit.default_timer()\nperf_list_all, hits_list_all = test(best_config,rels, test_data_prel, \n                                                 all_data_prel, neg_sampler, parsed['num_processes'], \n                                                parsed['window'])\n\nend_o =  timeit.default_timer()\ntrain_time_o = round(end_train- start_train, 6)  \ntest_time_o = round(end_o- start_test, 6)  \ntotal_time_o = round(end_o- start_o, 6)  \nprint(\"Running Training to find best configs finished in {} seconds.\".format(train_time_o))\nprint(\"Running testing with best configs finished in {} seconds.\".format(test_time_o))\nprint(\"Running all steps finished in {} seconds.\".format(total_time_o))\n\nprint(f\"The test MRR is {np.mean(perf_list_all)}\")\nprint(f\"The valid MRR is {val_mrr}\")\nprint(f\"The Hits@10 is {np.mean(hits_list_all)}\")\nprint(f\"We have {len(perf_list_all)} predictions\")\nprint(f\"The test set has len {len(test_data)} \")\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'\nmetric = dataset.eval_metric\nsave_results({'model': MODEL_NAME,\n              'train_flag': parsed['train_flag'],\n     
         'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: float(np.mean(perf_list_all)),\n              'val_mrr': val_mrr,\n              'hits10': float(np.mean(hits_list_all)),\n              'test_time': test_time_o,\n              'tot_train_val_time': total_time_o\n              }, \n    results_filename)\n\nif parsed['num_processes']>1:\n    ray.shutdown()\n\n\n    \n"
  },
  {
    "path": "examples/linkproppred/tkgl-wikidata/regcn.py",
    "content": "\"\"\"\nTemporal Knowledge Graph Reasoning Based on Evolutional Representation Learning\nReference:\n- https://github.com/Lee-zix/RE-GCN\nZixuan Li, Xiaolong Jin, Wei Li, Saiping Guan, Jiafeng Guo, Huawei Shen, Yuanzhuo Wang and Xueqi Cheng. Temporal \nKnowledge Graph Reasoning Based on Evolutional Representation Learning. SIGIR 2021.\n\"\"\"\nimport sys\nimport timeit\nimport os\nimport sys\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\nimport torch\nimport random\nfrom tqdm import tqdm\n# internal imports\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.rrgcn import RecurrentRGCNREGCN\nfrom tgb.utils.utils import set_random_seed, split_by_time, save_results\nfrom modules.tkg_utils import get_args_regcn, reformat_ts\nfrom modules.tkg_utils_dgl import build_sub_graph\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nimport json\n\ndef test(model, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, static_graph, mode, split_mode):\n    \"\"\"\n    Test the model on either test or validation set\n    :param model: model used to test\n    :param history_list:    all input history snap shot list, not include output label train list or valid list\n    :param test_list:   test triple snap shot list\n    :param num_rels:    number of relations\n    :param num_nodes:   number of nodes\n    :param use_cuda:\n    :param model_name:\n    :param mode:\n    :param split_mode: 'test' or 'val' to state which negative samples to load\n    :return mrr\n    \"\"\"\n    print(\"Testing for mode: \", split_mode)\n    if split_mode == 'test':\n        timesteps_to_eval = test_timestamps_orig\n    else:\n        timesteps_to_eval = val_timestamps_orig\n\n    idx = 0\n\n    if mode == \"test\":\n        # test mode: load parameter form file \n        if use_cuda:\n            checkpoint = 
torch.load(model_name, map_location=torch.device(args.gpu))\n        else:\n            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))        \n        # use best stat checkpoint:\n        print(\"Load Model name: {}. Using best epoch : {}\".format(model_name, checkpoint['epoch']))  \n        print(\"\\n\"+\"-\"*10+\"start testing\"+\"-\"*10+\"\\n\")\n        model.load_state_dict(checkpoint['state_dict'])\n\n    model.eval()\n    input_list = [snap for snap in history_list[-args.test_history_len:]] \n    perf_list_all = []\n    hits_list_all = []\n    perf_per_rel = {}\n    for rel in range(num_rels):\n        perf_per_rel[rel] = []\n\n    for time_idx, test_snap in enumerate(tqdm(test_list)):\n        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]    \n        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)\n        timesteps_batch =timesteps_to_eval[time_idx]*np.ones(len(test_triples_input[:,0]))\n\n        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:,0], test_triples_input[:,2], \n                                timesteps_batch, edge_type=test_triples_input[:,1], split_mode=split_mode)\n        pos_samples_batch = test_triples_input[:,2]\n\n        _, perf_list, hits_list = model.predict(history_glist, num_rels, static_graph, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch, \n                                    evaluator, METRIC)  \n\n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n        if split_mode == \"test\":\n            if args.log_per_rel:\n                for score, rel in zip(perf_list, test_triples_input[:,1].tolist()):\n                    perf_per_rel[rel].append(score)\n        # reconstruct history graph list\n        input_list.pop(0)\n        input_list.append(test_snap)\n        idx += 1\n\n    if split_mode == \"test\":\n        if 
args.log_per_rel:   \n            for rel in range(num_rels):\n                if len(perf_per_rel[rel]) > 0:\n                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))\n                else:\n                    perf_per_rel.pop(rel)\n    mrr = np.mean(perf_list_all) \n    hits10 = np.mean(hits_list_all) \n    return mrr, perf_per_rel, hits10\n\n\n\ndef run_experiment(args, n_hidden=None, n_layers=None, dropout=None, n_bases=None):\n    \"\"\"\n    Run the experiment with the given configuration\n    :param args: arguments\n    :param n_hidden: hidden dimension\n    :param n_layers: number of layers\n    :param dropout: dropout rate\n    :param n_bases: number of bases\n    :return: mrr, perf_per_rel  (mean reciprocal rank, performance per relation)\n    \"\"\"\n    # load configuration for grid search the best configuration\n    if n_hidden:\n        args.n_hidden = n_hidden\n    if n_layers:\n        args.n_layers = n_layers\n    if dropout:\n        args.dropout = dropout\n    if n_bases:\n        args.n_bases = n_bases\n    mrr = 0\n    hits10=0\n    # 2) set save model path\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    if not osp.exists(save_model_dir):\n        os.mkdir(save_model_dir)\n        print('INFO: Create directory {}'.format(save_model_dir))\n    model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}'\n    model_state_file = save_model_dir+model_name\n    perf_per_rel = {}\n    use_cuda = args.gpu >= 0 and torch.cuda.is_available()\n\n    num_static_rels, num_words, static_triples, static_graph = 0, 0, [], None\n\n    # create stat\n    model = RecurrentRGCNREGCN(args.decoder,\n                          args.encoder,\n                        num_nodes,\n                        int(num_rels/2),\n                        num_static_rels, # DIFFERENT\n                        num_words, # DIFFERENT\n                        args.n_hidden,\n                        args.opn,\n                        
sequence_len=args.train_history_len,\n                        num_bases=args.n_bases,\n                        num_basis=args.n_basis,\n                        num_hidden_layers=args.n_layers,\n                        dropout=args.dropout,\n                        self_loop=args.self_loop,\n                        skip_connect=args.skip_connect,\n                        layer_norm=args.layer_norm,\n                        input_dropout=args.input_dropout,\n                        hidden_dropout=args.hidden_dropout,\n                        feat_dropout=args.feat_dropout,\n                        aggregation=args.aggregation, # DIFFERENT\n                        weight=args.weight, # DIFFERENT\n                        discount=args.discount, # DIFFERENT\n                        angle=args.angle, # DIFFERENT\n                        use_static=args.add_static_graph, # DIFFERENT\n                        entity_prediction=args.entity_prediction, \n                        relation_prediction=args.relation_prediction,\n                        use_cuda=use_cuda,\n                        gpu = args.gpu,\n                        analysis=args.run_analysis) # DIFFERENT\n\n    if use_cuda:\n        torch.cuda.set_device(args.gpu)\n        model.cuda()\n\n    # optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)\n\n    if args.test and os.path.exists(model_state_file):\n        mrr, perf_per_rel, hits10 = test(model, \n                    train_list+valid_list, \n                    test_list, \n                    num_rels, \n                    num_nodes, \n                    use_cuda, \n                    model_state_file, \n                    static_graph, \n                    \"test\", \n                    \"test\")\n        return mrr, perf_per_rel, hits10\n    elif args.test and not os.path.exists(model_state_file):\n        print(\"--------------{} not exist, Change mode to train and generate stat for 
testing----------------\\n\".format(model_state_file))\n        return 0, 0\n    else:\n        print(\"----------------------------------------start training----------------------------------------\\n\")\n        best_mrr = 0\n        best_hits = 0\n        for epoch in range(args.n_epochs):\n\n            model.train()\n            losses = []\n            idx = [_ for _ in range(len(train_list))]\n            random.shuffle(idx)\n\n            for train_sample_num in tqdm(idx):\n                if train_sample_num == 0: continue\n                output = train_list[train_sample_num:train_sample_num+1]\n                if train_sample_num - args.train_history_len<0:\n                    input_list = train_list[0: train_sample_num]\n                else:\n                    input_list = train_list[train_sample_num - args.train_history_len:\n                                        train_sample_num]\n\n                # generate history graph\n                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]\n                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]\n                loss_e, loss_r, loss_static = model.get_loss(history_glist, output[0], static_graph, use_cuda)\n                loss = args.task_weight*loss_e + (1-args.task_weight)*loss_r + loss_static\n\n                losses.append(loss.item())\n\n                loss.backward()\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients\n                optimizer.step()\n                optimizer.zero_grad()\n\n            print(\"Epoch {:04d} | Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} \"\n                  .format(epoch, np.mean(losses),  best_mrr, model_name))\n            \n            #! 
checking GPU usage\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            print (\"--------------GPU memory usage-----------\")\n            print (\"there are \", free_mem, \" free memory\")\n            print (\"there are \", total_mem, \" total available memory\")\n            print (\"there are \", total_mem - free_mem, \" used memory\")\n            print (\"--------------GPU memory usage-----------\")\n\n            # validation\n            if epoch and epoch % args.evaluate_every == 0:\n                mrr,perf_per_rel, hits10 = test(model, train_list, \n                            valid_list, \n                            num_rels, \n                            num_nodes, \n                            use_cuda, \n                            model_state_file, \n                            static_graph, \n                            mode=\"train\", split_mode='val')\n            \n                if mrr < best_mrr:\n                    if epoch >= args.n_epochs:\n                        break\n                else:\n                    best_mrr = mrr\n                    best_hits = hits10\n                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)\n\n        return best_mrr, perf_per_rel, best_hits\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args_regcn()\nargs.dataset = 'tkgl-wikidata'\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\n\nDATA=args.dataset\nMODEL_NAME = 'REGCN'\n\nprint(\"logging mrrs per relation: \", args.log_per_rel)\n\n# load data\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\n\nnum_rels = dataset.num_rels\nnum_nodes = dataset.num_nodes \nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nrelations = dataset.edge_type\n\ntimestamps_orig = 
dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\nall_quads = np.stack((subjects, relations, objects, timestamps), axis=1)\n\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nval_timestamps_orig = list(set(timestamps_orig[dataset.val_mask])) # needed for getting the negative samples\nval_timestamps_orig.sort()\ntest_timestamps_orig = list(set(timestamps_orig[dataset.test_mask])) # needed for getting the negative samples\ntest_timestamps_orig.sort()\n\ntrain_list = split_by_time(train_data)\nvalid_list = split_by_time(val_data)\ntest_list = split_by_time(test_data)\n\n# evaluation metric\nMETRIC = dataset.eval_metric\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n#load the ns samples \ndataset.load_val_ns()\ndataset.load_test_ns()\n\n## run training and testing\nval_mrr, test_mrr = 0, 0\ntest_hits10 = 0\nif args.grid_search:\n    print(\"hyperparameter grid search not implemented. 
Exiting.\")\n# single run\nelse:\n    start_train = timeit.default_timer()\n    if args.test == False: #if they are true: directly test on a previously trained and stored model\n        print('start training')\n        val_mrr, perf_per_rel, val_hits10 = run_experiment(args) # do training\n    start_test = timeit.default_timer()\n    args.test = True\n    print('start testing')\n    test_mrr, perf_per_rel, test_hits10 = run_experiment(args) # do testing\n\n\ntest_time = timeit.default_timer() - start_test\nall_time = timeit.default_timer() - start_train\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\nprint(f\"\\Train and Test: Elapsed Time (s): {all_time: .4f}\")\n\nprint(f\"\\tTest: {METRIC}: {test_mrr: .4f}\")\nprint(f\"\\tValid: {METRIC}: {val_mrr: .4f}\")\n\n# saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nsave_results({'model': MODEL_NAME,\n              'data': DATA,\n              'run': args.run_nr,\n              'seed': SEED,\n              f'val {METRIC}': float(val_mrr),\n              f'test {METRIC}': float(test_mrr),\n              'test_time': test_time,\n              'tot_train_val_time': all_time,\n              'test_hits10': float(test_hits10)\n              }, \n    results_filename)\n\nif args.log_per_rel:\n    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'\n    with open(results_filename, 'w') as json_file:\n        json.dump(perf_per_rel, json_file)\n\nsys.exit()"
  },
  {
    "path": "examples/linkproppred/tkgl-wikidata/tkgl-wikidata_example.py",
    "content": "import numpy as np\r\nimport timeit\r\nfrom tqdm import tqdm\r\nimport os.path as osp\r\nimport sys\r\nimport os\r\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\r\nsys.path.append(modules_path)\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom tgb.linkproppred.evaluate import Evaluator\r\n\r\nDATA = \"tkgl-wikidata\"\r\n\r\n# data loading\r\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\r\ntrain_mask = dataset.train_mask\r\nval_mask = dataset.val_mask\r\ntest_mask = dataset.test_mask\r\ndata = dataset.get_TemporalData()\r\nmetric = dataset.eval_metric\r\n\r\nprint (\"there are {} nodes and {} edges\".format(dataset.num_nodes, dataset.num_edges))\r\nprint (\"there are {} relation types\".format(dataset.num_rels))\r\n\r\n\r\ntimestamp = data.t\r\nhead = data.src\r\ntail = data.dst\r\nedge_type = data.edge_type #relation\r\nneg_sampler = dataset.negative_sampler\r\n\r\ntrain_data = data[train_mask]\r\nval_data = data[val_mask]\r\ntest_data = data[test_mask]\r\n\r\n\r\nmetric = dataset.eval_metric\r\nevaluator = Evaluator(name=DATA)\r\nneg_sampler = dataset.negative_sampler\r\n\r\nBATCH_SIZE = 1 ## 200\r\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\r\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\r\n\r\nstart_time = timeit.default_timer()\r\n#load the ns samples first\r\ndataset.load_val_ns()\r\nfor batch in tqdm(val_loader):\r\n    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type\r\n    neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='val')\r\n    \r\n    if len(neg_batch_list[0]) > 1500:\r\n        print(rel, len(neg_batch_list[0]))\r\nprint (\"loading ns samples from validation\", timeit.default_timer() - 
start_time)\r\n\r\nstart_time = timeit.default_timer()\r\ndataset.load_test_ns()\r\nfor batch in test_loader:\r\n    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type\r\n    neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='test')\r\nprint (\"loading ns samples from test\", timeit.default_timer() - start_time)\r\nprint (\"retrieved all negative samples\")\r\n\r\n\r\n# #* load numpy arrays instead\r\n# from tgb.linkproppred.dataset import LinkPropPredDataset\r\n\r\n# # data loading\r\n# dataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\r\n# data = dataset.full_data  \r\n# metric = dataset.eval_metric\r\n# sources = dataset.full_data['sources']\r\n# print (sources.dtype)\r\n\r\n"
  },
  {
    "path": "examples/linkproppred/tkgl-wikidata/tlogic.py",
    "content": "\"\"\"\nhttps://github.com/liu-yushan/TLogic/tree/main/mycode\nTLogic: Temporal Logical Rules for Explainable Link Forecasting on Temporal Knowledge Graphs.\nYushan Liu, Yunpu Ma, Marcel Hildebrandt, Mitchell Joblin, Volker Tresp\n\"\"\"\n\n# imports\nimport sys\nimport os\nimport os.path as osp\nfrom pathlib import Path\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nimport timeit\nimport argparse\nimport numpy as np\nimport json\nfrom joblib import Parallel, delayed\nimport itertools\n\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom modules.tlogic_learn_modules import Temporal_Walk, Rule_Learner, store_edges\nimport modules.tlogic_apply_modules as ra\nfrom tgb.utils.utils import set_random_seed,  save_results\nfrom modules.tkg_utils import reformat_ts, get_inv_relation_id, create_scores_array\n\ndef learn_rules(i, num_relations):\n    \"\"\"\n    Learn rules (multiprocessing possible).\n\n    Parameters:\n        i (int): process number\n        num_relations (int): minimum number of relations for each process\n\n    Returns:\n        rl.rules_dict (dict): rules dictionary\n    \"\"\"\n\n    # if seed:\n    #     np.random.seed(seed)\n\n    num_rest_relations = len(all_relations) - (i + 1) * num_relations\n    if num_rest_relations >= num_relations:\n        relations_idx = range(i * num_relations, (i + 1) * num_relations)\n    else:\n        relations_idx = range(i * num_relations, len(all_relations))\n\n    num_rules = [0]\n    for k in relations_idx:\n        rel = all_relations[k]\n        for length in rule_lengths:\n            it_start =  timeit.default_timer()\n            for _ in range(num_walks):\n                walk_successful, walk = temporal_walk.sample_walk(length + 1, rel)\n                if walk_successful:\n                    rl.create_rule(walk)\n            it_end =  
timeit.default_timer()\n            it_time = round(it_end - it_start, 6)\n            num_rules.append(sum([len(v) for k, v in rl.rules_dict.items()]) // 2)\n            num_new_rules = num_rules[-1] - num_rules[-2]\n            print(\n                \"Process {0}: relation {1}/{2}, length {3}: {4} sec, {5} rules\".format(\n                    i,\n                    k - relations_idx[0] + 1,\n                    len(relations_idx),\n                    length,\n                    it_time,\n                    num_new_rules,\n                )\n            )\n\n    return rl.rules_dict\n\ndef apply_rules(i, num_queries, rules_dict, neg_sampler, data, window, learn_edges, all_quads, args, split_mode, \n                log_per_rel=False, num_rels=0):\n    \"\"\"\n    Apply rules (multiprocessing possible).\n\n    Parameters:\n        i (int): process number\n        num_queries (int): minimum number of queries for each process\n\n    Returns:\n        hits_list (list): hits list (hits@10 per sample)\n        perf_list (list): performance list (mrr per sample)\n    \"\"\"\n    perf_per_rel = {}\n    for rel in range(num_rels):\n            perf_per_rel[rel] = []\n    print(\"Start process\", i, \"...\")\n    all_candidates = [dict() for _ in range(len(args))]\n    no_cands_counter = 0\n\n    num_rest_queries = len(data) - (i + 1) * num_queries\n    if num_rest_queries >= num_queries:\n        test_queries_idx = range(i * num_queries, (i + 1) * num_queries)\n    else:\n        test_queries_idx = range(i * num_queries, len(data))\n\n    cur_ts = data[test_queries_idx[0]][3]\n    edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)\n\n    it_start =  timeit.default_timer()\n    hits_list = [0] * len(test_queries_idx)\n    perf_list = [0] * len(test_queries_idx)\n    for index, j in enumerate(test_queries_idx):\n        neg_sample_el =  neg_sampler.query_batch(np.expand_dims(np.array(data[j,0]), axis=0), \n                                         
       np.expand_dims(np.array(data[j,2]), axis=0), \n                                                np.expand_dims(np.array(data[j,4]), axis=0), \n                                                np.expand_dims(np.array(data[j,1]), axis=0), \n                                                split_mode=split_mode)[0]        \n        \n        # neg_samples_batch[j]\n        pos_sample_el =  data[j,2]\n        test_query = data[j]\n        assert pos_sample_el == test_query[2]\n        cands_dict = [dict() for _ in range(len(args))]\n\n        if test_query[3] != cur_ts:\n            cur_ts = test_query[3]\n            edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)\n\n        if test_query[1] in rules_dict:\n            dicts_idx = list(range(len(args)))\n            for rule in rules_dict[test_query[1]]:\n                walk_edges = ra.match_body_relations(rule, edges, test_query[0])\n\n                if 0 not in [len(x) for x in walk_edges]:\n                    rule_walks = ra.get_walks(rule, walk_edges)\n                    if rule[\"var_constraints\"]:\n                        rule_walks = ra.check_var_constraints(\n                            rule[\"var_constraints\"], rule_walks\n                        )\n\n                    if not rule_walks.empty:\n                        cands_dict = ra.get_candidates(\n                            rule,\n                            rule_walks,\n                            cur_ts,\n                            cands_dict,\n                            score_func,\n                            args,\n                            dicts_idx,\n                        )\n                        for s in dicts_idx:\n                            cands_dict[s] = {\n                                x: sorted(cands_dict[s][x], reverse=True)\n                                for x in cands_dict[s].keys()\n                            }\n                            cands_dict[s] = dict(\n                        
        sorted(\n                                    cands_dict[s].items(),\n                                    key=lambda item: item[1],\n                                    reverse=True,\n                                )\n                            )\n                            top_k_scores = [v for _, v in cands_dict[s].items()][:top_k]\n                            unique_scores = list(\n                                scores for scores, _ in itertools.groupby(top_k_scores)\n                            )\n                            if len(unique_scores) >= top_k:\n                                dicts_idx.remove(s)\n                        if not dicts_idx:\n                            break\n\n            if cands_dict[0]:\n                for s in range(len(args)):\n                    # Calculate noisy-or scores\n                    scores = list(\n                        map(\n                            lambda x: 1 - np.prod(1 - np.array(x)),\n                            cands_dict[s].values(),\n                        )\n                    )\n                    cands_scores = dict(zip(cands_dict[s].keys(), scores))\n                    noisy_or_cands = dict(\n                        sorted(cands_scores.items(), key=lambda x: x[1], reverse=True)\n                    )\n                    all_candidates[s][j] = noisy_or_cands\n            else:  # No candidates found by applying rules\n                no_cands_counter += 1\n                for s in range(len(args)):\n                    all_candidates[s][j] = dict()\n\n        else:  # No rules exist for this relation\n            no_cands_counter += 1\n            for s in range(len(args)):\n                all_candidates[s][j] = dict()\n\n        if not (j - test_queries_idx[0] + 1) % 100:\n            it_end =  timeit.default_timer()\n            it_time = round(it_end - it_start, 6)\n            print(\n                \"Process {0}: test samples finished: {1}/{2}, {3} sec\".format(\n          
          i, j - test_queries_idx[0] + 1, len(test_queries_idx), it_time\n                )\n            )\n            it_start =  timeit.default_timer()\n\n        predictions = create_scores_array(all_candidates[s][j], num_nodes)  \n        predictions_of_interest_pos = np.array(predictions[pos_sample_el])\n        predictions_of_interest_neg = predictions[neg_sample_el]\n        input_dict = {\n            \"y_pred_pos\": predictions_of_interest_pos,\n            \"y_pred_neg\": predictions_of_interest_neg,\n            \"eval_metric\": ['mrr'], \n        }\n\n        predictions = evaluator.eval(input_dict)\n        perf_list[index] = predictions['mrr']\n        hits_list[index] = predictions['hits@10']\n        if split_mode == \"test\":\n            if log_per_rel:\n                perf_per_rel[test_query[1]].append(perf_list[index]) #test_query[1] is the relation index\n\n    if split_mode == \"test\":\n        if log_per_rel:   \n            for rel in range(num_rels):\n                if len(perf_per_rel[rel]) > 0:\n                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))\n                else:\n                    perf_per_rel.pop(rel)       \n               \n\n    return perf_list, hits_list, perf_per_rel\n\n\n## args\ndef get_args(): \n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", \"-d\", default=\"tkgl-wikidata\", type=str) \n    parser.add_argument(\"--rule_lengths\", \"-l\", default=\"1\", type=int, nargs=\"+\")\n    parser.add_argument(\"--num_walks\", \"-n\", default=\"100\", type=int)\n    parser.add_argument(\"--transition_distr\", default=\"exp\", type=str)\n    parser.add_argument(\"--window\", \"-w\", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. 
set to -2 if multistep\n    parser.add_argument(\"--top_k\", default=20, type=int)\n    parser.add_argument(\"--num_processes\", \"-p\", default=1, type=int)\n    parser.add_argument(\"--alpha\", \"-alpha\",  default=0.99, type=float) # fix alpha. used if trainflag == false\n    # parser.add_argument(\"--train_flag\", \"-tr\",  default=True) # do we need training, ie selection of lambda and alpha\n    parser.add_argument(\"--save_config\", \"-c\",  default=True) # do we need to save the selection of lambda and alpha in config file?\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--run_nr', type=int, help='Run Number', default=1)\n    parser.add_argument('--learn_rules_flag', type=bool, help='Do we want to learn the rules', default=True)\n    parser.add_argument('--rule_filename', type=str, help='if rules not learned: where are they stored', default='0_r[3]_n100_exp_s1_rules.json')\n    parser.add_argument('--log_per_rel', type=bool, help='Do we want to log mrr per relation', default=False)\n    parser.add_argument('--compute_valid_mrr', type=bool, help='Do we want to compute mrr for valid set', default=True)\n    parsed = vars(parser.parse_args())\n    return parsed\n\nstart_o =  timeit.default_timer()\n\n## get args\nparsed = get_args()\ndataset = parsed[\"dataset\"]\nrule_lengths = parsed[\"rule_lengths\"]\nrule_lengths = [rule_lengths] if (type(rule_lengths) == int) else rule_lengths\nnum_walks = parsed[\"num_walks\"]\ntransition_distr = parsed[\"transition_distr\"]\nnum_processes = parsed[\"num_processes\"]\nwindow = parsed[\"window\"]\ntop_k = parsed[\"top_k\"]\nlog_per_rel = parsed['log_per_rel']\n\nMODEL_NAME = 'TLogic'\nSEED = parsed['seed']  # set the random seed for consistency\nset_random_seed(SEED)\n\n## load dataset and prepare it accordingly\nname = parsed[\"dataset\"]\ncompute_valid_mrr = parsed[\"compute_valid_mrr\"]\ndataset = LinkPropPredDataset(name=name, root=\"datasets\", 
preprocess=True)\nDATA = name\n\nrelations = dataset.edge_type\nnum_rels = dataset.num_rels\n\nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nnum_nodes = dataset.num_nodes \ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n\nall_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)\ntrain_data = all_quads[dataset.train_mask,0:4] # we do not need the original timestamps\nval_data = all_quads[dataset.val_mask,0:5]\ntest_data = all_quads[dataset.test_mask,0:5]\nall_data = all_quads[:,0:4]\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=name)\nneg_sampler = dataset.negative_sampler\n\ninv_relation_id = get_inv_relation_id(num_rels)\n\n#load the ns samples \n\ndataset.load_val_ns()\ndataset.load_test_ns()\noutput_dir =  f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\nlearn_rules_flag = parsed['learn_rules_flag']\n## 1. learn rules\nstart_train =  timeit.default_timer()\nif learn_rules_flag:\n    print(\"start learning rules\")\n    # edges (dict): edges for each relation\n    # inv_relation_id (dict): mapping of relation to inverse relation\n    \n    temporal_walk = Temporal_Walk(train_data, inv_relation_id, transition_distr)\n    rl = Rule_Learner(edges=temporal_walk.edges, id2relation=None, inv_relation_id=inv_relation_id,  \n                        output_dir=output_dir)\n    all_relations = sorted(temporal_walk.edges)  # Learn for all relations\n\n    start =  timeit.default_timer()\n    num_relations = len(all_relations) // num_processes\n    output = Parallel(n_jobs=num_processes)(\n        delayed(learn_rules)(i, num_relations) for i in range(num_processes)\n    )\n    end =  timeit.default_timer()\n\n    all_rules = output[0]\n    for i in range(1, num_processes):\n        all_rules.update(output[i])\n\n    total_time = round(end - start, 6)\n    print(\"Learning finished in {} 
seconds.\".format(total_time))\n\n    rl.rules_dict = all_rules\n    rl.sort_rules_dict()\n\n    rule_filename = rl.save_rules(0, rule_lengths, num_walks, transition_distr, SEED)\n    # rl.save_rules_verbalized(0, rule_lengths, num_walks, transition_distr, seed)\n    # rules_statistics(rl.rules_dict)\nelse:\n    rule_filename = parsed['rule_filename']\n    print(\"Loading rules from file {}\".format(parsed['rule_filename']))\n\nend_train =  timeit.default_timer()\n\n## 2. Apply rules\n\nrules_dict = json.load(open(output_dir + rule_filename))\nrules_dict = {int(k): v for k, v in rules_dict.items()}\n\nrules_dict = ra.filter_rules(\n    rules_dict, min_conf=0.01, min_body_supp=2, rule_lengths=rule_lengths\n) # filter rules for minimum confidence, body support and rule length\n\nlearn_edges = store_edges(train_data)\nscore_func = ra.score_12\n# It is possible to specify a list of list of arguments for tuning\nargs = [[0.1, 0.5]]\n\n# compute valid mrr\nstart_valid =  timeit.default_timer()\nif compute_valid_mrr:\n    print('Computing valid MRR')\n\n    num_queries = len(val_data) // num_processes\n\n    output = Parallel(n_jobs=num_processes)(\n        delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, val_data, window, learn_edges, \n                            all_quads, args, split_mode='val') for i in range(num_processes))\n    end =  timeit.default_timer()\n\n    perf_list_val = []\n    hits_list_val = []\n\n    for i in range(num_processes):\n        perf_list_val.extend(output[i][0])\n        hits_list_val.extend(output[i][1])\nelse:\n    perf_list_val = [0]\n    hits_list_val = [0]\n    \n\nend_valid =  timeit.default_timer()\n\n# compute test mrr\nif log_per_rel ==True:\n    num_processes = 1 #otherwise logging per rel does not work for our implementation\nstart_test =  timeit.default_timer()\nprint('Computing test MRR')\nstart =  timeit.default_timer()\nnum_queries = len(test_data) // num_processes\n\noutput = Parallel(n_jobs=num_processes)(\n    
delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, test_data, window, learn_edges, \n                         all_quads, args, split_mode='test', log_per_rel=log_per_rel, num_rels=num_rels) for i in range(num_processes))\nend =  timeit.default_timer()\n\nperf_list_all = []\nhits_list_all = []\n\n\nfor i in range(num_processes):\n    perf_list_all.extend(output[i][0])\n    hits_list_all.extend(output[i][1])\nif log_per_rel == True:\n    perf_per_rel = output[0][2]\n\n\ntotal_time = round(end - start, 6)\ntotal_valid_time = round(end_valid - start_valid, 6)\nprint(\"Application finished in {} seconds.\".format(total_time))\n\nprint(f\"The valid MRR is {np.mean(perf_list_val)}\")\nprint(f\"The MRR is {np.mean(perf_list_all)}\")\nprint(f\"The Hits@10 is {np.mean(hits_list_all)}\")\nprint(f\"We have {len(perf_list_all)} predictions\")\nprint(f\"The test set has len {len(test_data)} \")\n\nend_o =  timeit.default_timer()\ntrain_time_o = round(end_train- start_train, 6)  \ntest_time_o = round(end_o- start_test, 6)  \ntotal_time_o = round(end_o- start_o, 6)  \nprint(\"Running Training to find best configs finished in {} seconds.\".format(train_time_o))\nprint(\"Running testing with best configs finished in {} seconds.\".format(test_time_o))\nprint(\"Running all steps finished in {} seconds.\".format(total_time_o))\n\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\n\nif log_per_rel == True:\n    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'\n    with open(results_filename, 'w') as json_file:\n        json.dump(perf_per_rel, json_file)\n\nresults_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'\nmetric = dataset.eval_metric\nsave_results({'model': MODEL_NAME,\n              'train_flag': None,\n              'rule_len': 
rule_lengths,\n              'window': window,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: float(np.mean(perf_list_all)),\n              'hits10': float(np.mean(hits_list_all)),\n              'val_mrr': float(np.mean(perf_list_val)),\n              'test_time': test_time_o,\n              'tot_train_val_time': total_time_o,\n              'valid_time': total_valid_time\n              }, \n    results_filename)\n"
  },
  {
    "path": "examples/linkproppred/tkgl-yago/cen.py",
    "content": "\"\"\"\nComplex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning\nReference:\n- https://github.com/Lee-zix/CEN\nZixuan Li, Saiping Guan, Xiaolong Jin, Weihua Peng, Yajuan Lyu , Yong Zhu, Long Bai, Wei Li, Jiafeng Guo, Xueqi Cheng. \nComplex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning. ACL 2022.\n\"\"\"\nimport timeit\nimport os\nimport sys\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\nimport torch\nimport random\nfrom tqdm import tqdm\nimport json\n# internal imports\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.rrgcn import RecurrentRGCNCEN\nfrom tgb.utils.utils import set_random_seed, split_by_time,  save_results\nfrom modules.tkg_utils import get_args_cen, reformat_ts\nfrom modules.tkg_utils_dgl import build_sub_graph\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \n\ndef test(model, history_len, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, mode, split_mode):\n    \"\"\"\n    Test the model\n    :param model: model used to test\n    :param history_list:    all input history snap shot list, not include output label train list or valid list\n    :param test_list:   test triple snap shot list\n    :param num_rels:    number of relations\n    :param num_nodes:   number of nodes\n    :param use_cuda:\n    :param model_name:\n    :param mode:\n    :param split_mode: 'test' or 'val' to state which negative samples to load\n    :return mrr\n    \"\"\"\n    print(\"Testing for mode: \", split_mode)\n    if split_mode == 'test':\n        timesteps_to_eval = test_timestamps_orig\n    else:\n        timesteps_to_eval = val_timestamps_orig\n\n    idx = 0\n\n    if mode == \"test\":\n        # test mode: load parameter form file \n        if use_cuda:\n            checkpoint = torch.load(model_name, 
map_location=torch.device(args.gpu))\n        else:\n            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))        \n        # use best stat checkpoint:\n        print(\"Load Model name: {}. Using best epoch : {}\".format(model_name, checkpoint['epoch']))  \n        print(\"\\n\"+\"-\"*10+\"start testing\"+\"-\"*10+\"\\n\")\n        model.load_state_dict(checkpoint['state_dict'])\n\n    model.eval()\n\n    input_list = [snap for snap in history_list[-history_len:]] \n    perf_list_all = []\n    hits_list_all = []\n    perf_per_rel = {}\n    for rel in range(num_rels):\n        perf_per_rel[rel] = []\n\n    for time_idx, test_snap in enumerate(tqdm(test_list)):\n        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]    \n        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)\n        timesteps_batch =timesteps_to_eval[time_idx]*np.ones(len(test_triples_input[:,0]))\n\n        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:,0], test_triples_input[:,2], \n                                timesteps_batch, edge_type=test_triples_input[:,1], split_mode=split_mode)\n        pos_samples_batch = test_triples_input[:,2]\n\n        _, perf_list, hits_list = model.predict(history_glist, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch, \n                                    evaluator, METRIC) \n\n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n        if split_mode == \"test\":\n            if args.log_per_rel:\n                for score, rel in zip(perf_list, test_triples_input[:,1].tolist()):\n                    perf_per_rel[rel].append(score)\n        # reconstruct history graph list\n        input_list.pop(0)\n        input_list.append(test_snap)\n        idx += 1\n    if split_mode == \"test\":\n        if args.log_per_rel:   \n            for rel in range(num_rels):\n        
        if len(perf_per_rel[rel]) > 0:\n                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))\n                else:\n                    perf_per_rel.pop(rel)\n    mrr = np.mean(perf_list_all) \n    hits10 = np.mean(hits_list_all)\n    return mrr, perf_per_rel, hits10\n\n\n\ndef run_experiment(args, trainvalidtest_id=0, n_hidden=None, n_layers=None, dropout=None, n_bases=None):\n    '''\n    Run experiment for CEN model\n    :param args: arguments for the model\n    :param trainvalidtest_id: -1: pretrainig, 0: curriculum training (to find best test history len), 1: test on valid set, 2: test on test set\n    :param n_hidden: number of hidden units\n    :param n_layers: number of layers\n    :param dropout: dropout rate\n    :param n_bases: number of bases\n    return: mrr, perf_per_rel: mean reciprocal rank and performance per relation\n    '''\n    # 1) load configuration for grid search the best configuration\n    if n_hidden:\n        args.n_hidden = n_hidden\n    if n_layers:\n        args.n_layers = n_layers\n    if dropout:\n        args.dropout = dropout\n    if n_bases:\n        args.n_bases = n_bases\n    test_history_len = args.test_history_len\n    # 2) set save model path\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    if not osp.exists(save_model_dir):\n        os.mkdir(save_model_dir)\n        print('INFO: Create directory {}'.format(save_model_dir))\n    test_model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.test_history_len}'\n    test_state_file = save_model_dir+test_model_name\n    perf_per_rel ={}\n    use_cuda = args.gpu >= 0 and torch.cuda.is_available()\n    # create stat\n\n    model = RecurrentRGCNCEN(args.decoder,\n                            args.encoder,\n                            num_nodes,\n                            num_rels,\n                            args.n_hidden,\n                            args.opn,\n                            
sequence_len=args.train_history_len,\n                            num_bases=args.n_bases,\n                            num_basis=args.n_basis,\n                            num_hidden_layers=args.n_layers,\n                            dropout=args.dropout,\n                            self_loop=args.self_loop,\n                            skip_connect=args.skip_connect,\n                            layer_norm=args.layer_norm,\n                            input_dropout=args.input_dropout,\n                            hidden_dropout=args.hidden_dropout,\n                            feat_dropout=args.feat_dropout,\n                            entity_prediction=args.entity_prediction,\n                            relation_prediction=args.relation_prediction,\n                            use_cuda=use_cuda,\n                            gpu = args.gpu)\n    if use_cuda:\n        torch.cuda.set_device(args.gpu)\n        model.cuda()\n\n    # optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)\n    \n    if trainvalidtest_id == 1:  # normal test on validation set  Note that mode=test\n        if os.path.exists(test_state_file):\n            mrr, _, hits10 = test(model, args.test_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                        test_state_file, \"test\", split_mode=\"val\")      \n        else:\n            print('Cannot do testing because model does not exist: ', test_state_file)\n            mrr = 0\n            hits10 = 0\n    elif trainvalidtest_id == 2: # normal test on test set\n        if os.path.exists(test_state_file):\n            mrr, perf_per_rel, hits10 = test(model, args.test_history_len, train_list+valid_list, test_list, num_rels, num_nodes, use_cuda, \n                        test_state_file, \"test\", split_mode=\"test\")\n        else:\n            print('Cannot do testing because model does not exist: ', test_state_file)\n            mrr = 0\n            hits10 = 0\n    
elif trainvalidtest_id == -1:\n        print(\"-------------start pre training model with history length {}----------\\n\".format(args.start_history_len))\n        model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'\n        model_state_file = save_model_dir + model_name\n        print(\"Sanity Check: stat name : {}\".format(model_state_file))\n        print(\"Sanity Check: Is cuda available ? {}\".format(torch.cuda.is_available()))\n            \n        best_mrr = 0\n        best_epoch = 0\n        best_hits10= 0\n\n        ## training loop\n        for epoch in range(args.n_epochs):\n            model.train()\n            losses = []\n\n            idx = [_ for _ in range(len(train_list))]\n            random.shuffle(idx)\n            for train_sample_num in idx:\n                if train_sample_num == 0 or train_sample_num == 1: continue\n                if train_sample_num - args.start_history_len<0:\n                    input_list = train_list[0: train_sample_num]\n                    output = train_list[1:train_sample_num+1]\n                else:\n                    input_list = train_list[train_sample_num-args.start_history_len: train_sample_num]\n                    output = train_list[train_sample_num-args.start_history_len+1:train_sample_num+1]\n\n                # generate history graph\n                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]\n                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]\n\n                loss= model.get_loss(history_glist, output[-1], None, use_cuda)\n                losses.append(loss.item())\n\n                loss.backward()\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients\n                optimizer.step()\n                optimizer.zero_grad()\n\n            print(\"His {:04d}, Epoch {:04d} | 
Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} \"\n                .format(args.start_history_len, epoch, np.mean(losses), best_mrr, model_name))\n            \n            #! checking GPU usage\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            print (\"--------------GPU memory usage-----------\")\n            print (\"there are \", free_mem, \" free memory\")\n            print (\"there are \", total_mem, \" total available memory\")\n            print (\"there are \", total_mem - free_mem, \" used memory\")\n            print (\"--------------GPU memory usage-----------\")\n\n            # validation        \n            if epoch % args.evaluate_every == 0:\n                mrr, _, hits10 = test(model, args.start_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                                model_state_file, mode=\"train\", split_mode= \"val\")\n\n                if mrr< best_mrr:\n                    if epoch >= args.n_epochs or epoch - best_epoch > 5:\n                        break\n                else:\n                    best_mrr = mrr\n                    best_epoch = epoch\n                    best_hits10 = hits10\n                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)\n\n        \n    elif trainvalidtest_id == 0: #curriculum training\n        model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'\n        init_state_file = save_model_dir + model_name\n        init_checkpoint = torch.load(init_state_file, map_location=torch.device(args.gpu))\n        # use best stat checkpoint:\n        print(\"Load Previous Model name: {}. 
Using best epoch : {}\".format(init_state_file, init_checkpoint['epoch']))  \n        print(\"\\n\"+\"-\"*10+\"Load model with history length {}\".format(args.start_history_len)+\"-\"*10+\"\\n\")\n        model.load_state_dict(init_checkpoint['state_dict'])\n        test_history_len = args.start_history_len\n\n        mrr, _, hits10 = test(model, \n                    args.start_history_len,\n                    train_list,\n                    valid_list, \n                    num_rels, \n                    num_nodes, \n                    use_cuda, \n                    init_state_file,  \n                    mode=\"test\", split_mode= \"val\") \n        best_mrr_list = [mrr.item()]         \n        best_hits_list = [hits10.item()]                                          \n        # start knowledge distillation\n        ks_idx = 0\n        for history_len in range(args.start_history_len+1, args.train_history_len+1, 1):\n            # current model\n            print(\"best mrr list :\", best_mrr_list)\n            # lr = 0.1*args.lr - 0.002*args.lr*ks_idx\n            optimizer = torch.optim.Adam(model.parameters(), lr=0.1*args.lr, weight_decay=0.00001)\n            model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len}'\n            model_state_file = save_model_dir + model_name\n\n            print(\"Sanity Check: stat name : {}\".format(model_state_file))\n\n            # load model with the least history length\n            prev_model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len-1}'\n            prev_state_file = save_model_dir + prev_model_name\n            checkpoint = torch.load(prev_state_file, map_location=torch.device(args.gpu)) \n            model.load_state_dict(checkpoint['state_dict']) \n            print(\"\\n\"+\"-\"*10+\"start knowledge distillation for history length at \"+ str(history_len)+\"-\"*10+\"\\n\")\n \n            best_mrr = 0\n            best_hits10 = 0\n            best_epoch = 0\n            for 
epoch in range(args.n_epochs):\n                model.train()\n                losses = []\n\n                idx = [_ for _ in range(len(train_list))]\n                random.shuffle(idx)\n                for train_sample_num in idx:\n                    if train_sample_num == 0 or train_sample_num == 1: continue\n                    if train_sample_num - history_len<0:\n                        input_list = train_list[0: train_sample_num]\n                        output = train_list[1:train_sample_num+1]\n                    else:\n                        input_list = train_list[train_sample_num - history_len: train_sample_num]\n                        output = train_list[train_sample_num-history_len+1:train_sample_num+1]\n\n                    # generate history graph\n                    history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]\n                    output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]\n\n                    loss= model.get_loss(history_glist, output[-1], None, use_cuda)\n                    # print(loss)\n                    losses.append(loss.item())\n\n                    loss.backward()\n                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients\n                    optimizer.step()\n                    optimizer.zero_grad()\n\n                print(\"His {:04d}, Epoch {:04d} | Ave Loss: {:.4f} |Best MRR {:.4f} | Model {} \"\n                    .format(history_len, epoch, np.mean(losses), best_mrr, model_name))\n\n                #! 
checking GPU usage\n                free_mem, total_mem = torch.cuda.mem_get_info()\n                print (\"--------------GPU memory usage-----------\")\n                print (\"there are \", free_mem, \" free memory\")\n                print (\"there are \", total_mem, \" total available memory\")\n                print (\"there are \", total_mem - free_mem, \" used memory\")\n                print (\"--------------GPU memory usage-----------\")\n\n                # validation\n                if epoch % args.evaluate_every == 0:\n                    mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                                model_state_file, mode=\"train\", split_mode= \"val\")\n                    \n                    if mrr< best_mrr:\n                        if epoch >= args.n_epochs or epoch-best_epoch>2:\n                            break\n                    else:\n                        best_mrr = mrr\n                        best_epoch = epoch\n                        best_hits10 = hits10\n                        torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)  \n            mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda, \n                        model_state_file, mode=\"test\", split_mode= \"val\")\n            ks_idx += 1\n            if mrr.item() < max(best_mrr_list):\n                test_history_len = history_len-1\n                print(\"early stopping, best history length: \", test_history_len)\n                break\n            else:\n                best_mrr_list.append(mrr.item())\n                best_hits_list.append(hits10.item())\n        \n    return mrr, test_history_len, perf_per_rel, hits10\n\n\n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args_cen()\nargs.dataset = 'tkgl-yago'\n\nSEED = 
args.seed  # set the random seed for consistency\nset_random_seed(SEED)\n\nDATA=args.dataset\nMODEL_NAME = 'CEN'\n\nprint(\"logging mrrs per relation: \", args.log_per_rel)\nprint(\"do test and valid? do only test no validation?: \", args.validtest, args.test_only)\n\n# load data\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\n\nnum_rels = dataset.num_rels\nnum_nodes = dataset.num_nodes \nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nrelations = dataset.edge_type\n\ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\nall_quads = np.stack((subjects, relations, objects, timestamps), axis=1)\n\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nval_timestamps_orig = list(set(timestamps_orig[dataset.val_mask])) # needed for getting the negative samples\nval_timestamps_orig.sort()\ntest_timestamps_orig = list(set(timestamps_orig[dataset.test_mask])) # needed for getting the negative samples\ntest_timestamps_orig.sort()\n\ntrain_list = split_by_time(train_data)\nvalid_list = split_by_time(val_data)\ntest_list = split_by_time(test_data)\n\n# evaluation metric\nMETRIC = dataset.eval_metric\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n#load the ns samples \ndataset.load_val_ns()\ndataset.load_test_ns()\n\nif args.grid_search:\n    print(\"TODO: implement hyperparameter grid search\")\n# single run\nelse:\n    \n    start_train = timeit.default_timer()\n    if args.validtest:\n        print('directly start testing')\n        if args.test_history_len_2 != args.test_history_len:\n            args.test_history_len = args.test_history_len_2 # hyperparameter value as given in original paper \n    else:\n        print('running pretrain and train')\n        # pretrain\n        mrr, _, _, hits10 = run_experiment(args, trainvalidtest_id=-1)\n   
     # train\n        mrr, args.test_history_len, _, hits10 = run_experiment(args, trainvalidtest_id=0) # overwrite test_history_len with \n        # the best history len (for valid mrr)       \n        \n    if args.test_only == False:\n        print(\"running test (on val and test dataset) with test_history_len of: \", args.test_history_len)\n        # test on val set\n        val_mrr, _, _, val_hits10 = run_experiment(args, trainvalidtest_id=1)\n    else:\n        val_mrr = 0\n        val_hits10 = 0\n\n    # test on test set\n    start_test = timeit.default_timer()\n    test_mrr, _, perf_per_rel, test_hits10 = run_experiment(args, trainvalidtest_id=2)\n\ntest_time = timeit.default_timer() - start_test\nall_time = timeit.default_timer() - start_train\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\nprint(f\"\\tTrain and Test: Elapsed Time (s): {all_time: .4f}\")\n\nprint(f\"\\tTest: {METRIC}: {test_mrr: .4f}\")\nprint(f\"\\tValid: {METRIC}: {val_mrr: .4f}\")\n\n# saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nsave_results({'model': MODEL_NAME,\n              'data': DATA,\n              'run': args.run_nr,\n              'seed': SEED,\n              'test_history_len': args.test_history_len,\n              f'val {METRIC}': float(val_mrr),\n              f'test {METRIC}': float(test_mrr),\n              'test_time': test_time,\n              'tot_train_val_time': all_time,\n              'test_hits10': float(test_hits10)\n              }, \n    results_filename)\n\nif args.log_per_rel:\n    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'\n    with open(results_filename, 'w') as json_file:\n        json.dump(perf_per_rel, 
json_file)\n\n\nsys.exit()"
  },
  {
    "path": "examples/linkproppred/tkgl-yago/edgebank.py",
    "content": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n    - https://github.com/fpour/DGB/tree/main\n\n\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\ntgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(tgb_modules_path)\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import save_results\n\n# ==================\n# ==================\n# ==================\n\ndef test(data, test_mask, neg_sampler, split_mode):\n    r\"\"\"\n    Evaluate the dynamic link prediction\n    Evaluation happens as 'one vs. 
many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    perf_list = []\n    hits_list = []\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t, pos_edge = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n            data['edge_type'][test_mask][start_idx: end_idx],\n        )\n        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)\n        \n        for idx, neg_batch in enumerate(neg_batch_list):\n            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])\n            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])\n\n            y_pred = edgebank.predict_link(query_src, query_dst)\n            # compute MRR\n            input_dict = {\n                \"y_pred_pos\": np.array([y_pred[0]]),\n                \"y_pred_neg\": np.array(y_pred[1:]),\n                \"eval_metric\": [metric],\n            }\n            results = evaluator.eval(input_dict)\n            perf_list.append(results[metric])\n            hits_list.append(results['hits@10'])\n            \n        # update edgebank memory after each positive batch\n        edgebank.update_memory(pos_src, 
pos_dst, pos_t)\n\n    perf_metrics = float(np.mean(perf_list))\n    perf_hits = float(np.mean(hits_list))\n\n    return perf_metrics, perf_hits\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tkgl-yago')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = args.data\nMODEL_NAME = 'EdgeBank'\n\n\n\n# data loading with `numpy`\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n        
hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'\n\n# ==================================================== Test\n# loading the validation negative samples\ndataset.load_val_ns()\n\n# testing ...\nstart_val = timeit.default_timer()\nperf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')\nend_val = timeit.default_timer()\n\nprint(f\"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< \")\nprint(f\"\\tval: {metric}: {perf_metric_val: .4f}\")\nval_time = timeit.default_timer() - start_val\nprint(f\"\\tval: Elapsed Time (s): {val_time: .4f}\")\n\n\n\n\n# ==================================================== Test\n# loading the test negative samples\ndataset.load_test_ns()\n\n# testing ...\nstart_test = timeit.default_timer()\nperf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')\nend_test = timeit.default_timer()\n\nprint(f\"INFO: Test: Evaluation Setting: >>>  <<< \")\nprint(f\"\\tTest: {metric}: {perf_metric_test: .4f}\")\ntest_time = timeit.default_timer() - start_test\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\n\nsave_results({'model': MODEL_NAME,\n              'memory_mode': MEMORY_MODE,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: 
perf_metric_test,\n              'val_mrr': perf_metric_val,\n              'test_time': test_time,\n              'tot_train_val_time': test_time+val_time,\n              'hits10': perf_hits_test}, \n    results_filename)\n\n"
  },
  {
    "path": "examples/linkproppred/tkgl-yago/recurrencybaseline.py",
    "content": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)\n\n@inproceedings{gastinger2024baselines,\n  title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},\n  author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},\n  booktitle={33nd International Joint Conference on Artificial Intelligence (IJCAI 2024)},\n  year={2024},\n  organization={International Joint Conferences on Artificial Intelligence Organization}\n}\n\n\"\"\"\n\n## imports\nimport timeit\nimport argparse\nimport numpy as np\nfrom copy import copy\nfrom pathlib import Path\nimport ray\nimport sys\nimport os\nimport os.path as osp\nimport json\n#internal imports \nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom tgb.utils.utils import set_random_seed,  save_results \nfrom modules.tkg_utils import create_basis_dict, group_by, reformat_ts\n\ndef predict(num_processes,  data_c_rel, all_data_c_rel, alpha, lmbda_psi,\n            perf_list_all, hits_list_all, window, neg_sampler, split_mode):\n    \"\"\" create predictions for each relation on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list of hits for each test query\n    \"\"\"\n\n    first_ts = data_c_rel[0][3]\n    ## use this if you wanna use ray:\n    num_queries = len(data_c_rel) // num_processes\n    if num_queries < num_processes: # if we do not have enough queries for all the processes\n        num_processes_tmp = 
1\n        num_queries = len(data_c_rel)\n    else:\n        num_processes_tmp = num_processes  \n    if num_processes > 1:\n        object_references =[]                   \n        \n        for i in range(num_processes_tmp):\n            num_test_queries = len(data_c_rel) - (i + 1) * num_queries\n            if num_test_queries >= num_queries:\n                test_queries_idx =[i * num_queries, (i + 1) * num_queries]\n            else:\n                test_queries_idx = [i * num_queries, len(test_data)]\n\n            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]\n\n            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window, \n                                basis_dict, \n                                num_nodes, num_rels, lmbda_psi, \n                                alpha, evaluator,first_ts, neg_sampler, split_mode)\n            object_references.append(ob)\n\n        output = ray.get(object_references)\n\n        # updates the scores and logging dict for each process\n        for proc_loop in range(num_processes_tmp):\n            perf_list_all.extend(output[proc_loop][0])\n            hits_list_all.extend(output[proc_loop][1])\n\n    else:\n        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel, \n                            window, basis_dict, \n                            num_nodes, num_rels, lmbda_psi, \n                            alpha, evaluator, first_ts, neg_sampler, split_mode)                  \n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n    \n    return perf_list_all, hits_list_all\n\n## test\ndef test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):  \n    \"\"\" create predictions by loopoing through all relations on test or valid set and compute mrr\n    :return  perf_list_all: list of mrrs for each test query\n    :return hits_list_all: list 
of hits for each test query\n    \"\"\"       \n    perf_list_all = []\n    hits_list_all =[]\n    \n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n\n    ## loop through relations and apply baselines\n    \n    for rel in all_relations:\n        start =  timeit.default_timer()\n        if rel in test_data_prel.keys():\n            lmbda_psi = best_config[str(rel)]['lmbda_psi'][0]\n            alpha = best_config[str(rel)]['alpha'][0]\n\n            # test data for this relation\n            test_data_c_rel = test_data_prel[rel]\n            timesteps_test = list(set(test_data_c_rel[:,3]))\n            timesteps_test.sort()\n            all_data_c_rel = all_data_prel[rel]         \n            perf_list_rel = []\n            hits_list_rel = []\n            perf_list_rel, hits_list_rel = predict(num_processes, test_data_c_rel,\n                                                all_data_c_rel, alpha, lmbda_psi,perf_list_rel, hits_list_rel, \n                                                window, neg_sampler, split_mode)\n            perf_list_all.extend(perf_list_rel)\n            hits_list_all.extend(hits_list_rel)\n        else:\n            perf_list_rel =[]\n         \n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n        \n\n        with open(csv_file, 'a') as f:\n            f.write(\"{},{}\\n\".format(rel, perf_list_rel))\n\n    return perf_list_all, hits_list_all\n\ndef read_dict_compute_mrr(split_mode='test'):\n    \"\"\" read the results per relation  from a precreated file and compute mrr\n    :return  mrr_per_rel: dictionary of mrrs for each relation\n    :return all_mrrs: list of mrrs for all relations\n    \"\"\"\n    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'\n    # Initialize an empty dictionary to store the data\n    
results_per_rel_dict = {}\n    mrr_per_rel = {}\n    all_mrrs = []\n    # Open the file for reading\n    with open(csv_file, 'r') as f:\n        # Read each line in the file\n        for line in f:\n            # Split the line at the comma\n            parts = line.strip().split(',')\n            # Extract the key (the first part)\n            key = int(parts[0])\n            # Extract the values (the rest of the parts), remove square brackets\n            values = [float(value.strip('[]')) for value in parts[1:]]\n            # Add the key-value pair to the dictionary\n            if key in results_per_rel_dict.keys():\n                print(f\"Key {key} already exists in the dictionary!!! might have duplicate entries in results csv\")\n            results_per_rel_dict[key] = values\n            all_mrrs.extend(values)\n            mrr_per_rel[key] = np.mean(values)\n\n    if len(list(results_per_rel_dict.keys())) != num_rels:\n        print(\"we do not have entries for each rel in the results csv file. only num enties: \", len(list(results_per_rel_dict.keys())))\n\n    print(\"Split mode: \"+split_mode +\" Mean MRR: \", np.mean(all_mrrs))\n    print(\"mrr per relation: \", mrr_per_rel)\n\n\n\n## train\ndef train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):\n    \"\"\" optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params\n    based on validation mrr\n    :return best_config: dictionary of best params for each relation\n    \"\"\"\n    best_config= {}\n    best_mrr = 0\n    for rel in rels: # loop through relations. 
for each relation, apply rules with selected params, compute valid mrr\n        start =  timeit.default_timer()\n        rel_key = int(rel)            \n\n        best_config[str(rel_key)] = {}\n        best_config[str(rel_key)]['not_trained'] = 'True'    \n        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi,0] #default\n        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))\n        best_config[str(rel_key)]['alpha'] = [default_alpha,0]  #default    \n        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))\n        \n        if rel in val_data_prel.keys():      \n            # valid data for this relation  \n            val_data_c_rel = val_data_prel[rel]\n            timesteps_valid = list(set(val_data_c_rel[:,3]))\n            timesteps_valid.sort()\n            trainval_data_c_rel = trainval_data_prel[rel]\n\n            ######  1) select lambda ###############        \n            lmbdas_psi = params_dict['lmbda_psi']        \n\n            alpha = 1\n            best_lmbda_psi = 0.1\n            best_mrr_psi = 0\n            lmbda_mrrs = []\n\n            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))\n            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))         \n            best_config[str(rel_key)]['not_trained'] = 'False'       \n            \n            for lmbda_psi in lmbdas_psi:   \n                perf_list_r = []\n                hits_list_r = []\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr = np.mean(perf_list_r)\n                # # is new mrr better than previous best? 
if yes: store lmbda\n                if mrr > best_mrr_psi:\n                    best_mrr_psi = float(mrr)\n                    best_lmbda_psi = lmbda_psi\n\n\n                lmbda_mrrs.append(float(mrr))\n            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]\n            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs\n            best_mrr = best_mrr_psi\n            ##### 2) select alpha ###############\n            best_config[str(rel_key)]['not_trained'] = 'False'    \n            alphas = params_dict['alpha'] \n            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0] # use the best lmbda psi\n\n            alpha_mrrs = []\n            # perf_list_all = []\n            best_mrr_alpha = 0\n            best_alpha=0.99\n            for alpha in alphas:\n                perf_list_r = []\n                hits_list_r = []\n\n                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel, \n                                                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r, \n                                                    window, neg_sampler, split_mode='val')\n                # compute mrr\n                mrr_alpha = np.mean(perf_list_r)\n\n                # is new mrr better than previous best? 
if yes: store alpha\n                if mrr_alpha > best_mrr_alpha:\n                    best_mrr_alpha = float(mrr_alpha)\n                    best_alpha = alpha\n                    best_mrr = best_mrr_alpha\n                alpha_mrrs.append(float(mrr_alpha))\n\n            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]\n            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs\n\n        end =  timeit.default_timer()\n        total_time = round(end - start, 6)  \n        print(\"Relation {} finished in {} seconds.\".format(rel, total_time))\n    return best_config\n\n\n\n## args\ndef get_args(): \n    \"\"\"parse all arguments for the script\"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", \"-d\", default=\"tkgl-yago\", type=str) \n    parser.add_argument(\"--window\", \"-w\", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep\n    parser.add_argument(\"--num_processes\", \"-p\", default=1, type=int)\n    parser.add_argument(\"--lmbda\", \"-l\",  default=0.1, type=float) # fix lambda. used if trainflag == false\n    parser.add_argument(\"--alpha\", \"-alpha\",  default=0.99, type=float) # fix alpha. 
used if trainflag == false\n    parser.add_argument(\"--train_flag\", \"-tr\",  default='False') # do we need training, ie selection of lambda and alpha\n    parser.add_argument(\"--load_flag\", \"-lo\",  default='False') # if train_flag set to True: do you want to load best_config?\n    parser.add_argument(\"--save_config\", \"-c\",  default='True') # do we need to save the selection of lambda and alpha in config file?\n    parser.add_argument('--seed', type=int, help='Random seed', default=1) # not needed\n    parsed = vars(parser.parse_args())\n    return parsed\n\nstart_o =  timeit.default_timer()\n\nparsed = get_args()\nif parsed['num_processes']>1:\n    ray.init(num_cpus=parsed[\"num_processes\"], num_gpus=0)\nMODEL_NAME = 'RecurrencyBaseline'\nSEED = parsed['seed']  # set the random seed for consistency\nset_random_seed(SEED)\nperrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'\nif not osp.exists(perrel_results_path):\n    os.mkdir(perrel_results_path)\n    print('INFO: Create directory {}'.format(perrel_results_path))\nPath(perrel_results_path).mkdir(parents=True, exist_ok=True)\n\n## load dataset and prepare it accordingly\nname = parsed[\"dataset\"]\ndataset = LinkPropPredDataset(name=name, root=\"datasets\", preprocess=True)\nDATA = name\n\nrelations = dataset.edge_type\nnum_rels = dataset.num_rels\nrels = np.arange(0,num_rels)\nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nnum_nodes = dataset.num_nodes \ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n\nprint(\"split train valid test data\")\nall_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=name)\nneg_sampler = 
dataset.negative_sampler\n\ntrain_val_data = np.concatenate([train_data, val_data])\nall_data = np.concatenate([train_data, val_data, test_data])\n\n# create dicts with key: relation id, values: triples for that relation id\nprint(\"grouping data by relation\")\ntest_data_prel = group_by(test_data, 1)\nall_data_prel = group_by(all_data, 1)\nval_data_prel = group_by(val_data, 1)\ntrainval_data_prel = group_by(train_val_data, 1)\n\n#load the ns samples \n# if parsed['train_flag']:\nprint(\"loading negative samples\")\ndataset.load_val_ns()\ndataset.load_test_ns()\n\n# parameter options\nif parsed['train_flag'] == 'True':\n    params_dict = {}\n    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001] \n    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]\n    default_lmbda_psi = params_dict['lmbda_psi'][-1]\n    default_alpha = params_dict['alpha'][-2]\n\n## load rules\nprint(\"creating rules\")\nbasis_dict = create_basis_dict(train_val_data)\nprint(\"done with creating rules\")\n## init\n# rb_predictor = RecurrencyBaselinePredictor(rels)\n## train to find best lambda and alpha\nstart_train =  timeit.default_timer()\nif parsed['train_flag'] ==  'True':\n    if parsed['load_flag'] == 'True':\n        with open('best_config.json', 'r') as infile:\n            best_config = json.load(infile)\n    else:\n        print('start training')\n        best_config = train(params_dict,  rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'], \n            parsed['window'])\n        if parsed['save_config'] == 'True':\n            import json\n            with open('best_config.json', 'w') as outfile:\n                json.dump(best_config, outfile)\n\nelse: # use preset lmbda and alpha; same for all relations\n    best_config = {} \n    for rel in rels:\n        best_config[str(rel)] = {}\n        best_config[str(rel)]['lmbda_psi'] = 
[parsed['lmbda']]\n        best_config[str(rel)]['alpha'] = [parsed['alpha']]\n    \nend_train =  timeit.default_timer()\n\n# compute validation mrr\nprint(\"Computing validation MRR\")\nperf_list_all_val, hits_list_all_val = test(best_config,rels, val_data_prel, \n                                                trainval_data_prel, neg_sampler, parsed['num_processes'], \n                                            parsed['window'], split_mode='val')\nval_mrr = float(np.mean(perf_list_all_val))\n\n# compute test mrr\nprint(\"Computing test MRR\")\nstart_test =  timeit.default_timer()\nperf_list_all, hits_list_all = test(best_config,rels, test_data_prel, \n                                                 all_data_prel, neg_sampler, parsed['num_processes'], \n                                                parsed['window'])\n\nend_o =  timeit.default_timer()\ntrain_time_o = round(end_train- start_train, 6)  \ntest_time_o = round(end_o- start_test, 6)  \ntotal_time_o = round(end_o- start_o, 6)  \nprint(\"Running Training to find best configs finished in {} seconds.\".format(train_time_o))\nprint(\"Running testing with best configs finished in {} seconds.\".format(test_time_o))\nprint(\"Running all steps finished in {} seconds.\".format(total_time_o))\n\nprint(f\"The test MRR is {np.mean(perf_list_all)}\")\nprint(f\"The valid MRR is {val_mrr}\")\nprint(f\"The Hits@10 is {np.mean(hits_list_all)}\")\nprint(f\"We have {len(perf_list_all)} predictions\")\nprint(f\"The test set has len {len(test_data)} \")\n\n# for saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'\nmetric = dataset.eval_metric\nsave_results({'model': MODEL_NAME,\n              'train_flag': parsed['train_flag'],\n     
         'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: float(np.mean(perf_list_all)),\n              'val_mrr': val_mrr,\n              'hits10': float(np.mean(hits_list_all)),\n              'test_time': test_time_o,\n              'tot_train_val_time': total_time_o\n              }, \n    results_filename)\n\nif parsed['num_processes']>1:\n    ray.shutdown()\n\n\n    \n"
  },
  {
    "path": "examples/linkproppred/tkgl-yago/regcn.py",
    "content": "\"\"\"\nTemporal Knowledge Graph Reasoning Based on Evolutional Representation Learning\nReference:\n- https://github.com/Lee-zix/RE-GCN\nZixuan Li, Xiaolong Jin, Wei Li, Saiping Guan, Jiafeng Guo, Huawei Shen, Yuanzhuo Wang and Xueqi Cheng. Temporal \nKnowledge Graph Reasoning Based on Evolutional Representation Learning. SIGIR 2021.\n\"\"\"\nimport sys\nimport timeit\nimport os\nimport sys\nimport os.path as osp\nfrom pathlib import Path\nimport numpy as np\nimport torch\nimport random\nfrom tqdm import tqdm\n# internal imports\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.rrgcn import RecurrentRGCNREGCN\nfrom tgb.utils.utils import set_random_seed, split_by_time, save_results\nfrom modules.tkg_utils import get_args_regcn, reformat_ts\nfrom modules.tkg_utils_dgl import build_sub_graph\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nimport json\n\ndef test(model, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, static_graph, mode, split_mode):\n    \"\"\"\n    Test the model on either test or validation set\n    :param model: model used to test\n    :param history_list:    all input history snap shot list, not include output label train list or valid list\n    :param test_list:   test triple snap shot list\n    :param num_rels:    number of relations\n    :param num_nodes:   number of nodes\n    :param use_cuda:\n    :param model_name:\n    :param mode:\n    :param split_mode: 'test' or 'val' to state which negative samples to load\n    :return mrr\n    \"\"\"\n    print(\"Testing for mode: \", split_mode)\n    if split_mode == 'test':\n        timesteps_to_eval = test_timestamps_orig\n    else:\n        timesteps_to_eval = val_timestamps_orig\n\n    idx = 0\n\n    if mode == \"test\":\n        # test mode: load parameter form file \n        if use_cuda:\n            checkpoint = 
torch.load(model_name, map_location=torch.device(args.gpu))\n        else:\n            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))        \n        # use best stat checkpoint:\n        print(\"Load Model name: {}. Using best epoch : {}\".format(model_name, checkpoint['epoch']))  \n        print(\"\\n\"+\"-\"*10+\"start testing\"+\"-\"*10+\"\\n\")\n        model.load_state_dict(checkpoint['state_dict'])\n\n    model.eval()\n    input_list = [snap for snap in history_list[-args.test_history_len:]] \n    perf_list_all = []\n    hits_list_all = []\n    perf_per_rel = {}\n    for rel in range(num_rels):\n        perf_per_rel[rel] = []\n\n    for time_idx, test_snap in enumerate(tqdm(test_list)):\n        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]    \n        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)\n        timesteps_batch =timesteps_to_eval[time_idx]*np.ones(len(test_triples_input[:,0]))\n\n        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:,0], test_triples_input[:,2], \n                                timesteps_batch, edge_type=test_triples_input[:,1], split_mode=split_mode)\n        pos_samples_batch = test_triples_input[:,2]\n\n        _, perf_list, hits_list = model.predict(history_glist, num_rels, static_graph, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch, \n                                    evaluator, METRIC)  \n\n        perf_list_all.extend(perf_list)\n        hits_list_all.extend(hits_list)\n        if split_mode == \"test\":\n            if args.log_per_rel:\n                for score, rel in zip(perf_list, test_triples_input[:,1].tolist()):\n                    perf_per_rel[rel].append(score)\n        # reconstruct history graph list\n        input_list.pop(0)\n        input_list.append(test_snap)\n        idx += 1\n\n    if split_mode == \"test\":\n        if 
args.log_per_rel:   \n            for rel in range(num_rels):\n                if len(perf_per_rel[rel]) > 0:\n                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))\n                else:\n                    perf_per_rel.pop(rel)\n    mrr = np.mean(perf_list_all) \n    hits10 = np.mean(hits_list_all) \n    return mrr, perf_per_rel, hits10\n\n\n\ndef run_experiment(args, n_hidden=None, n_layers=None, dropout=None, n_bases=None):\n    \"\"\"\n    Run the experiment with the given configuration\n    :param args: arguments\n    :param n_hidden: hidden dimension\n    :param n_layers: number of layers\n    :param dropout: dropout rate\n    :param n_bases: number of bases\n    :return: mrr, perf_per_rel  (mean reciprocal rank, performance per relation)\n    \"\"\"\n    # load configuration for grid search the best configuration\n    if n_hidden:\n        args.n_hidden = n_hidden\n    if n_layers:\n        args.n_layers = n_layers\n    if dropout:\n        args.dropout = dropout\n    if n_bases:\n        args.n_bases = n_bases\n    mrr = 0\n    hits10=0\n    # 2) set save model path\n    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\n    if not osp.exists(save_model_dir):\n        os.mkdir(save_model_dir)\n        print('INFO: Create directory {}'.format(save_model_dir))\n    model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}'\n    model_state_file = save_model_dir+model_name\n    perf_per_rel = {}\n    use_cuda = args.gpu >= 0 and torch.cuda.is_available()\n\n    num_static_rels, num_words, static_triples, static_graph = 0, 0, [], None\n\n    # create stat\n    model = RecurrentRGCNREGCN(args.decoder,\n                          args.encoder,\n                        num_nodes,\n                        int(num_rels/2),\n                        num_static_rels, # DIFFERENT\n                        num_words, # DIFFERENT\n                        args.n_hidden,\n                        args.opn,\n                        
sequence_len=args.train_history_len,\n                        num_bases=args.n_bases,\n                        num_basis=args.n_basis,\n                        num_hidden_layers=args.n_layers,\n                        dropout=args.dropout,\n                        self_loop=args.self_loop,\n                        skip_connect=args.skip_connect,\n                        layer_norm=args.layer_norm,\n                        input_dropout=args.input_dropout,\n                        hidden_dropout=args.hidden_dropout,\n                        feat_dropout=args.feat_dropout,\n                        aggregation=args.aggregation, # DIFFERENT\n                        weight=args.weight, # DIFFERENT\n                        discount=args.discount, # DIFFERENT\n                        angle=args.angle, # DIFFERENT\n                        use_static=args.add_static_graph, # DIFFERENT\n                        entity_prediction=args.entity_prediction, \n                        relation_prediction=args.relation_prediction,\n                        use_cuda=use_cuda,\n                        gpu = args.gpu,\n                        analysis=args.run_analysis) # DIFFERENT\n\n    if use_cuda:\n        torch.cuda.set_device(args.gpu)\n        model.cuda()\n\n    # optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)\n\n    if args.test and os.path.exists(model_state_file):\n        mrr, perf_per_rel, hits10 = test(model, \n                    train_list+valid_list, \n                    test_list, \n                    num_rels, \n                    num_nodes, \n                    use_cuda, \n                    model_state_file, \n                    static_graph, \n                    \"test\", \n                    \"test\")\n        return mrr, perf_per_rel, hits10\n    elif args.test and not os.path.exists(model_state_file):\n        print(\"--------------{} not exist, Change mode to train and generate stat for 
testing----------------\\n\".format(model_state_file))\n        return 0, 0\n    else:\n        print(\"----------------------------------------start training----------------------------------------\\n\")\n        best_mrr = 0\n        best_hits = 0\n        for epoch in range(args.n_epochs):\n\n            model.train()\n            losses = []\n            idx = [_ for _ in range(len(train_list))]\n            random.shuffle(idx)\n\n            for train_sample_num in tqdm(idx):\n                if train_sample_num == 0: continue\n                output = train_list[train_sample_num:train_sample_num+1]\n                if train_sample_num - args.train_history_len<0:\n                    input_list = train_list[0: train_sample_num]\n                else:\n                    input_list = train_list[train_sample_num - args.train_history_len:\n                                        train_sample_num]\n\n                # generate history graph\n                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]\n                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]\n                loss_e, loss_r, loss_static = model.get_loss(history_glist, output[0], static_graph, use_cuda)\n                loss = args.task_weight*loss_e + (1-args.task_weight)*loss_r + loss_static\n\n                losses.append(loss.item())\n\n                loss.backward()\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients\n                optimizer.step()\n                optimizer.zero_grad()\n\n            print(\"Epoch {:04d} | Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} \"\n                  .format(epoch, np.mean(losses),  best_mrr, model_name))\n            \n            #! 
checking GPU usage\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            print (\"--------------GPU memory usage-----------\")\n            print (\"there are \", free_mem, \" free memory\")\n            print (\"there are \", total_mem, \" total available memory\")\n            print (\"there are \", total_mem - free_mem, \" used memory\")\n            print (\"--------------GPU memory usage-----------\")\n\n            # validation\n            if epoch and epoch % args.evaluate_every == 0:\n                mrr,perf_per_rel, hits10 = test(model, train_list, \n                            valid_list, \n                            num_rels, \n                            num_nodes, \n                            use_cuda, \n                            model_state_file, \n                            static_graph, \n                            mode=\"train\", split_mode='val')\n            \n                if mrr < best_mrr:\n                    if epoch >= args.n_epochs:\n                        break\n                else:\n                    best_mrr = mrr\n                    best_hits = hits10\n                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)\n\n        return best_mrr, perf_per_rel, hits10\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args_regcn()\nargs.dataset = 'tkgl-yago'\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\n\nDATA=args.dataset\nMODEL_NAME = 'REGCN'\n\nprint(\"logging mrrs per relation: \", args.log_per_rel)\n\n# load data\ndataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\n\nnum_rels = dataset.num_rels\nnum_nodes = dataset.num_nodes \nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nrelations = dataset.edge_type\n\ntimestamps_orig = 
dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\nall_quads = np.stack((subjects, relations, objects, timestamps), axis=1)\n\ntrain_data = all_quads[dataset.train_mask]\nval_data = all_quads[dataset.val_mask]\ntest_data = all_quads[dataset.test_mask]\n\nval_timestamps_orig = list(set(timestamps_orig[dataset.val_mask])) # needed for getting the negative samples\nval_timestamps_orig.sort()\ntest_timestamps_orig = list(set(timestamps_orig[dataset.test_mask])) # needed for getting the negative samples\ntest_timestamps_orig.sort()\n\ntrain_list = split_by_time(train_data)\nvalid_list = split_by_time(val_data)\ntest_list = split_by_time(test_data)\n\n# evaluation metric\nMETRIC = dataset.eval_metric\nevaluator = Evaluator(name=DATA)\nneg_sampler = dataset.negative_sampler\n#load the ns samples \ndataset.load_val_ns()\ndataset.load_test_ns()\n\n## run training and testing\nval_mrr, test_mrr = 0, 0\ntest_hits10 = 0\nif args.grid_search:\n    print(\"hyperparameter grid search not implemented. 
Exiting.\")\n# single run\nelse:\n    start_train = timeit.default_timer()\n    if args.test == False: #if they are true: directly test on a previously trained and stored model\n        print('start training')\n        val_mrr, perf_per_rel, val_hits10 = run_experiment(args) # do training\n    start_test = timeit.default_timer()\n    args.test = True\n    print('start testing')\n    test_mrr, perf_per_rel, test_hits10 = run_experiment(args) # do testing\n\n\ntest_time = timeit.default_timer() - start_test\nall_time = timeit.default_timer() - start_train\nprint(f\"\\tTest: Elapsed Time (s): {test_time: .4f}\")\nprint(f\"\\tTrain and Test: Elapsed Time (s): {all_time: .4f}\")\n\nprint(f\"\\tTest: {METRIC}: {test_mrr: .4f}\")\nprint(f\"\\tValid: {METRIC}: {val_mrr: .4f}\")\n\n# saving the results...\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\nresults_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n\nsave_results({'model': MODEL_NAME,\n              'data': DATA,\n              'run': args.run_nr,\n              'seed': SEED,\n              f'val {METRIC}': float(val_mrr),\n              f'test {METRIC}': float(test_mrr),\n              'test_time': test_time,\n              'tot_train_val_time': all_time,\n              'test_hits10': float(test_hits10)\n              }, \n    results_filename)\n\nif args.log_per_rel:\n    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'\n    with open(results_filename, 'w') as json_file:\n        json.dump(perf_per_rel, json_file)\n\nsys.exit()"
  },
  {
    "path": "examples/linkproppred/tkgl-yago/timetraveler.py",
    "content": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HUST/TITer\nHaohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021\n\"\"\"\nimport sys\nimport timeit\n\nimport torch\nfrom torch.utils.data import Dataset,DataLoader\nimport logging\n\nimport numpy as np\nimport pickle\nfrom tqdm import tqdm\nimport os.path as osp\nfrom pathlib import Path\nimport os\n\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nfrom modules.timetraveler_agent import Agent\nfrom modules.timetraveler_environment import Env\nfrom modules.timetraveler_dirichlet import Dirichlet, MLE_Dirchlet\nfrom modules.timetraveler_episode import Episode\nfrom modules.timetraveler_policygradient import PG\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.timetraveler_trainertester import Trainer, Tester, getRelEntCooccurrence\nfrom tgb.utils.utils import set_random_seed,save_results \nfrom modules.tkg_utils import  get_args_timetraveler, reformat_ts, get_model_config_timetraveler\n\nclass QuadruplesDataset(Dataset):\n    \"\"\" this is an internal way how Timetraveler represents the data\n    \"\"\"\n    def __init__(self, examples):\n        \"\"\"\n        examples: a list of quadruples.\n        num_r: number of relations\n        \"\"\"\n        self.quadruples = examples.copy()\n\n\n    def __len__(self):\n        return len(self.quadruples)\n\n    def __getitem__(self, item):\n        return self.quadruples[item][0], \\\n               self.quadruples[item][1], \\\n               self.quadruples[item][2], \\\n               self.quadruples[item][3], \\\n               self.quadruples[item][4]\n    \ndef set_logger(save_path):\n    \"\"\"Write logs to checkpoint and console\"\"\"\n    if 
args.do_train:\n        log_file = os.path.join(save_path, 'train.log')\n    else:\n        log_file = os.path.join(save_path, 'test.log')\n\n    logging.basicConfig(\n        format='%(asctime)s %(levelname)-8s %(message)s',\n        level=logging.INFO,\n        datefmt='%Y-%m-%d %H:%M:%S',\n        filename=log_file,\n        filemode='w'\n    )\n    console = logging.StreamHandler()\n    console.setLevel(logging.INFO)\n    formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')\n    console.setFormatter(formatter)\n    logging.getLogger('').addHandler(console)\n\ndef preprocess_data(args, config, timestamps, save_path, all_quads):\n    \"\"\"\n    Preprocess the data and save the state-action space (pickle dump)\n    \"\"\"\n    # parser = argparse.ArgumentParser(description='Data preprocess', usage='preprocess_data.py [<args>] [-h | --help]')\n    # parser.add_argument('--data_dir', default='data/ICEWS14', type=str, help='Path to data.')\n\n    env = Env(all_quads, config)\n    state_actions_space = {}\n\n    with tqdm(total=len(all_quads)) as bar:\n        for (head, rel, tail, t, _) in all_quads:\n            if (head, t, True) not in state_actions_space.keys():\n                state_actions_space[(head, t, True)] = env.get_state_actions_space_complete(head, t, True, args.store_actions_num)\n                state_actions_space[(head, t, False)] = env.get_state_actions_space_complete(head, t, False, args.store_actions_num)\n            if (tail, t, True) not in state_actions_space.keys():\n                state_actions_space[(tail, t, True)] = env.get_state_actions_space_complete(tail, t, True, args.store_actions_num)\n                state_actions_space[(tail, t, False)] = env.get_state_actions_space_complete(tail, t, False, args.store_actions_num)\n            bar.update(1)\n    pickle.dump(state_actions_space, open(os.path.join(save_path,  args.state_actions_path), 'wb'))\n\ndef log_metrics(mode, step, metrics):\n    \"\"\"Print the 
evaluation logs\"\"\"\n    for metric in metrics:\n        logging.info('%s %s at epoch %d: %f' % (mode, metric, step, metrics[metric]))\n\ndef main(args):\n    \"\"\"\n    Main function to train and test the TimeTraveler model\"\"\"\n\n    start_overall = timeit.default_timer()\n    #######################Set Logger#################################\n    \n    save_path = f'{os.path.dirname(os.path.abspath(__file__))}/saved_models/'\n    if not os.path.exists(save_path):\n        os.makedirs(save_path)\n\n    if args.cuda and torch.cuda.is_available():\n        args.cuda = True\n    else:\n        args.cuda = False\n    set_logger(save_path)\n\n    #######################Create DataLoader#################################\n    # set hyperparameters\n    args.dataset = 'tkgl-yago'\n\n    SEED = args.seed  # set the random seed for consistency\n    set_random_seed(SEED)\n\n    DATA=args.dataset\n    MODEL_NAME = 'TIMETRAVELER'\n\n    # load data\n    dataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\n\n    num_rels = dataset.num_rels\n    num_nodes = dataset.num_nodes \n    subjects = dataset.full_data[\"sources\"]\n    objects= dataset.full_data[\"destinations\"]\n    relations = dataset.edge_type\n\n    timestamps_orig = dataset.full_data[\"timestamps\"]\n    timestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n    all_quads = np.stack((subjects, relations, objects, timestamps,timestamps_orig), axis=1)\n\n    train_data = all_quads[dataset.train_mask]\n    train_entities = np.unique(np.concatenate((train_data[:, 0], train_data[:, 2])))\n    RelEntCooccurrence = getRelEntCooccurrence(train_data, num_rels)\n    train_data =QuadruplesDataset(train_data)\n    val_data = QuadruplesDataset(all_quads[dataset.val_mask])\n    test_data = QuadruplesDataset(all_quads[dataset.test_mask])\n\n    METRIC = dataset.eval_metric\n    evaluator = Evaluator(name=DATA)\n    neg_sampler = dataset.negative_sampler\n    #load the ns samples \n    
dataset.load_val_ns()\n    dataset.load_test_ns()\n\n    train_dataloader = DataLoader(\n        train_data,\n        batch_size=args.batch_size,\n        shuffle=True,\n        num_workers=args.num_workers,\n    )\n\n    valid_dataloader = DataLoader(\n        val_data,\n        batch_size=args.test_batch_size,\n        shuffle=False,\n        num_workers=args.num_workers,\n    )\n\n    test_dataloader = DataLoader(\n        test_data,\n        batch_size=args.test_batch_size,\n        shuffle=False,\n        num_workers=args.num_workers,\n    )\n\n    ######################Creat the agent and the environment###########################\n    config = get_model_config_timetraveler(args, num_nodes, num_rels)\n    logging.info(config)\n    logging.info(args)\n\n    # creat the agent\n    agent = Agent(config)\n\n\n    # creat the environment\n    state_actions_path = os.path.join(save_path, args.state_actions_path)\n\n\n    ######################preprocessing###########################\n    if not os.path.exists(state_actions_path):\n        if args.preprocess:\n            print(\"preprocessing data...\")\n            preprocess_data(args, config, timestamps, save_path, list(all_quads))\n            state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))\n        else:\n            state_action_space = None\n    else:\n        print(\"load preprocessed data...\")\n        state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))\n\n\n    env = Env(list(all_quads), config, state_action_space)\n    # Create episode controller\n    episode = Episode(env, agent, config)\n    if args.cuda:\n        episode = episode.cuda()\n    pg = PG(config)  # Policy Gradient\n    optimizer = torch.optim.Adam(episode.parameters(), lr=args.lr, weight_decay=0.00001)\n\n    ######################Reward Shaping: MLE DIRICHLET alphas###########################\n    if args.reward_shaping: \n        try:\n            
print(\"load alphas from pickle file\")\n            alphas = pickle.load(open(os.path.join(save_path, args.alphas_pkl), 'rb'))\n        except:\n            print('running MLE dirichlet now')\n            mle_d = MLE_Dirchlet(all_quads, num_rels, args.k, args.time_span,\n                         args.tol, args.method, args.maxiter)\n            pickle.dump(mle_d.alphas, open(os.path.join(save_path, args.alphas_pkl), 'wb'))\n\n            print('dumped alphas')\n            alphas = mle_d.alphas\n        distributions = Dirichlet(alphas, args.k)\n    else:\n        distributions = None\n\n    ######################Training and Testing###########################\n\n    trainer = Trainer(episode, pg, optimizer, args, distributions)\n    tester = Tester(episode, args, train_entities, RelEntCooccurrence, dataset.metric)\n    test_metrics ={}\n    val_metrics = {}\n    test_metrics[METRIC] = None\n    val_metrics[METRIC] = None\n\n    if args.do_train:\n        start_train =timeit.default_timer()\n        logging.info('Start Training......')\n        for i in range(args.max_epochs):\n            loss, reward = trainer.train_epoch(train_dataloader, len(train_data))\n            logging.info('Epoch {}/{} Loss: {}, reward: {}'.format(i, args.max_epochs, loss, reward))\n\n            #! 
checking GPU usage\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            print (\"--------------GPU memory usage-----------\")\n            print (\"there are \", free_mem, \" free memory\")\n            print (\"there are \", total_mem, \" total available memory\")\n            print (\"there are \", total_mem - free_mem, \" used memory\")\n            print (\"--------------GPU memory usage-----------\")\n            \n            if i % args.save_epoch == 0 and i != 0:\n                trainer.save_model(save_path, 'checkpoint_{}.pth'.format(i))\n                logging.info('Save Model in {}'.format(save_path))\n\n            if i % args.valid_epoch == 0 and i != 0:\n                logging.info('Start Val......')\n                val_metrics = tester.test(valid_dataloader,\n                                      len(val_data), num_nodes, neg_sampler, evaluator, split_mode='val')\n                for mode in val_metrics.keys():\n                    logging.info('{} at epoch {}: {}'.format(mode, i, val_metrics[mode]))\n\n        trainer.save_model(save_path)\n        logging.info('Save Model in {}'.format(save_path))\n    else:\n          # # Load the model parameters\n        if os.path.isfile(save_path):\n            params = torch.load(save_path)\n            episode.load_state_dict(params['model_state_dict'])\n            optimizer.load_state_dict(params['optimizer_state_dict'])\n            logging.info('Load pretrain model: {}'.format(save_path))\n    if args.do_test:\n        logging.info('Start Testing......')\n        start_test = timeit.default_timer()\n        test_metrics = tester.test(test_dataloader,\n                              len(test_data), num_nodes, neg_sampler, evaluator, split_mode='test')\n        for mode in test_metrics.keys():\n            logging.info('Test {} : {}'.format(mode, test_metrics[mode]))\n\n        # saving the results...\n        results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\n    
    if not osp.exists(results_path):\n            os.mkdir(results_path)\n            print('INFO: Create directory {}'.format(results_path))\n        Path(results_path).mkdir(parents=True, exist_ok=True)\n        results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'\n        test_time = timeit.default_timer() - start_test\n        all_time = timeit.default_timer() - start_train \n        all_time_preprocess = timeit.default_timer() - start_overall \n\n        save_results({'model': MODEL_NAME,\n                    'data': DATA,\n                    'seed': SEED,\n                    f'val {METRIC}': float(val_metrics[METRIC]),\n                    f'test {METRIC}': float(test_metrics[METRIC]),\n                    'test_time': test_time,\n                    'tot_train_val_time': all_time,\n                    'tot_preprocess_train_val_time': all_time_preprocess\n                    }, \n            results_filename)     \n\nif __name__ == '__main__':\n    args = get_args_timetraveler()\n    main(args)"
  },
  {
    "path": "examples/linkproppred/tkgl-yago/tkgl-yago_example.py",
    "content": "import numpy as np\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\nfrom tgb.linkproppred.evaluate import Evaluator\r\n\r\nDATA = \"tkgl-yago\"\r\n\r\n# data loading\r\ndataset = PyGLinkPropPredDataset(name=DATA, root=\"datasets\")\r\ntrain_mask = dataset.train_mask\r\nval_mask = dataset.val_mask\r\ntest_mask = dataset.test_mask\r\ndata = dataset.get_TemporalData()\r\nmetric = dataset.eval_metric\r\n\r\nprint (\"there are {} nodes and {} edges\".format(dataset.num_nodes, dataset.num_edges))\r\nprint (\"there are {} relation types\".format(dataset.num_rels))\r\n\r\n\r\ntimestamp = data.t\r\nhead = data.src\r\ntail = data.dst\r\nedge_type = data.edge_type #relation\r\nneg_sampler = dataset.negative_sampler\r\n\r\ntrain_data = data[train_mask]\r\nval_data = data[val_mask]\r\ntest_data = data[test_mask]\r\n\r\n\r\nmetric = dataset.eval_metric\r\nevaluator = Evaluator(name=DATA)\r\nneg_sampler = dataset.negative_sampler\r\n\r\n\r\n#load the ns samples first\r\ndataset.load_val_ns()\r\nfor i, (src, dst, t, rel) in enumerate(zip(val_data.src, val_data.dst, val_data.t, val_data.edge_type)):\r\n    #must use np array to query\r\n    neg_batch_list = neg_sampler.query_batch(np.array([src]), np.array([dst]), np.array([t]), edge_type=np.array([rel]), split_mode='val')\r\n\r\nprint (\"retrieved all negative samples\")\r\n\r\n\r\n# #* load numpy arrays instead\r\n# from tgb.linkproppred.dataset import LinkPropPredDataset\r\n\r\n# # data loading\r\n# dataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\r\n# data = dataset.full_data  \r\n# metric = dataset.eval_metric\r\n# sources = dataset.full_data['sources']\r\n# print (sources.dtype)\r\n\r\n"
  },
  {
    "path": "examples/linkproppred/tkgl-yago/tlogic.py",
    "content": "\"\"\"\nhttps://github.com/liu-yushan/TLogic/tree/main/mycode\nTLogic: Temporal Logical Rules for Explainable Link Forecasting on Temporal Knowledge Graphs.\nYushan Liu, Yunpu Ma, Marcel Hildebrandt, Mitchell Joblin, Volker Tresp\n\"\"\"\n\n# imports\nimport sys\nimport os\nimport os.path as osp\nfrom pathlib import Path\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\nsys.path.append(modules_path)\nimport timeit\nimport argparse\nimport numpy as np\nimport json\nfrom joblib import Parallel, delayed\nimport itertools\n\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom tgb.linkproppred.dataset import LinkPropPredDataset \nfrom modules.tlogic_learn_modules import Temporal_Walk, Rule_Learner, store_edges\nimport modules.tlogic_apply_modules as ra\nfrom tgb.utils.utils import set_random_seed,  save_results\nfrom modules.tkg_utils import reformat_ts, get_inv_relation_id, create_scores_array\n\ndef learn_rules(i, num_relations):\n    \"\"\"\n    Learn rules (multiprocessing possible).\n\n    Parameters:\n        i (int): process number\n        num_relations (int): minimum number of relations for each process\n\n    Returns:\n        rl.rules_dict (dict): rules dictionary\n    \"\"\"\n\n    # if seed:\n    #     np.random.seed(seed)\n\n    num_rest_relations = len(all_relations) - (i + 1) * num_relations\n    if num_rest_relations >= num_relations:\n        relations_idx = range(i * num_relations, (i + 1) * num_relations)\n    else:\n        relations_idx = range(i * num_relations, len(all_relations))\n\n    num_rules = [0]\n    for k in relations_idx:\n        rel = all_relations[k]\n        for length in rule_lengths:\n            it_start =  timeit.default_timer()\n            for _ in range(num_walks):\n                walk_successful, walk = temporal_walk.sample_walk(length + 1, rel)\n                if walk_successful:\n                    rl.create_rule(walk)\n            it_end =  
timeit.default_timer()\n            it_time = round(it_end - it_start, 6)\n            num_rules.append(sum([len(v) for k, v in rl.rules_dict.items()]) // 2)\n            num_new_rules = num_rules[-1] - num_rules[-2]\n            print(\n                \"Process {0}: relation {1}/{2}, length {3}: {4} sec, {5} rules\".format(\n                    i,\n                    k - relations_idx[0] + 1,\n                    len(relations_idx),\n                    length,\n                    it_time,\n                    num_new_rules,\n                )\n            )\n\n    return rl.rules_dict\n\ndef apply_rules(i, num_queries, rules_dict, neg_sampler, data, window, learn_edges, all_quads, args, split_mode, \n                log_per_rel=False, num_rels=0):\n    \"\"\"\n    Apply rules (multiprocessing possible).\n\n    Parameters:\n        i (int): process number\n        num_queries (int): minimum number of queries for each process\n\n    Returns:\n        hits_list (list): hits list (hits@10 per sample)\n        perf_list (list): performance list (mrr per sample)\n    \"\"\"\n    perf_per_rel = {}\n    for rel in range(num_rels):\n            perf_per_rel[rel] = []\n    print(\"Start process\", i, \"...\")\n    all_candidates = [dict() for _ in range(len(args))]\n    no_cands_counter = 0\n\n    num_rest_queries = len(data) - (i + 1) * num_queries\n    if num_rest_queries >= num_queries:\n        test_queries_idx = range(i * num_queries, (i + 1) * num_queries)\n    else:\n        test_queries_idx = range(i * num_queries, len(data))\n\n    cur_ts = data[test_queries_idx[0]][3]\n    edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)\n\n    it_start =  timeit.default_timer()\n    hits_list = [0] * len(test_queries_idx)\n    perf_list = [0] * len(test_queries_idx)\n    for index, j in enumerate(test_queries_idx):\n        neg_sample_el =  neg_sampler.query_batch(np.expand_dims(np.array(data[j,0]), axis=0), \n                                         
       np.expand_dims(np.array(data[j,2]), axis=0), \n                                                np.expand_dims(np.array(data[j,4]), axis=0), \n                                                np.expand_dims(np.array(data[j,1]), axis=0), \n                                                split_mode=split_mode)[0]        \n        \n        # neg_samples_batch[j]\n        pos_sample_el =  data[j,2]\n        test_query = data[j]\n        assert pos_sample_el == test_query[2]\n        cands_dict = [dict() for _ in range(len(args))]\n\n        if test_query[3] != cur_ts:\n            cur_ts = test_query[3]\n            edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)\n\n        if test_query[1] in rules_dict:\n            dicts_idx = list(range(len(args)))\n            for rule in rules_dict[test_query[1]]:\n                walk_edges = ra.match_body_relations(rule, edges, test_query[0])\n\n                if 0 not in [len(x) for x in walk_edges]:\n                    rule_walks = ra.get_walks(rule, walk_edges)\n                    if rule[\"var_constraints\"]:\n                        rule_walks = ra.check_var_constraints(\n                            rule[\"var_constraints\"], rule_walks\n                        )\n\n                    if not rule_walks.empty:\n                        cands_dict = ra.get_candidates(\n                            rule,\n                            rule_walks,\n                            cur_ts,\n                            cands_dict,\n                            score_func,\n                            args,\n                            dicts_idx,\n                        )\n                        for s in dicts_idx:\n                            cands_dict[s] = {\n                                x: sorted(cands_dict[s][x], reverse=True)\n                                for x in cands_dict[s].keys()\n                            }\n                            cands_dict[s] = dict(\n                        
        sorted(\n                                    cands_dict[s].items(),\n                                    key=lambda item: item[1],\n                                    reverse=True,\n                                )\n                            )\n                            top_k_scores = [v for _, v in cands_dict[s].items()][:top_k]\n                            unique_scores = list(\n                                scores for scores, _ in itertools.groupby(top_k_scores)\n                            )\n                            if len(unique_scores) >= top_k:\n                                dicts_idx.remove(s)\n                        if not dicts_idx:\n                            break\n\n            if cands_dict[0]:\n                for s in range(len(args)):\n                    # Calculate noisy-or scores\n                    scores = list(\n                        map(\n                            lambda x: 1 - np.product(1 - np.array(x)),\n                            cands_dict[s].values(),\n                        )\n                    )\n                    cands_scores = dict(zip(cands_dict[s].keys(), scores))\n                    noisy_or_cands = dict(\n                        sorted(cands_scores.items(), key=lambda x: x[1], reverse=True)\n                    )\n                    all_candidates[s][j] = noisy_or_cands\n            else:  # No candidates found by applying rules\n                no_cands_counter += 1\n                for s in range(len(args)):\n                    all_candidates[s][j] = dict()\n\n        else:  # No rules exist for this relation\n            no_cands_counter += 1\n            for s in range(len(args)):\n                all_candidates[s][j] = dict()\n\n        if not (j - test_queries_idx[0] + 1) % 100:\n            it_end =  timeit.default_timer()\n            it_time = round(it_end - it_start, 6)\n            print(\n                \"Process {0}: test samples finished: {1}/{2}, {3} sec\".format(\n          
          i, j - test_queries_idx[0] + 1, len(test_queries_idx), it_time\n                )\n            )\n            it_start =  timeit.default_timer()\n\n        predictions = create_scores_array(all_candidates[s][j], num_nodes)  \n        predictions_of_interest_pos = np.array(predictions[pos_sample_el])\n        predictions_of_interest_neg = predictions[neg_sample_el]\n        input_dict = {\n            \"y_pred_pos\": predictions_of_interest_pos,\n            \"y_pred_neg\": predictions_of_interest_neg,\n            \"eval_metric\": ['mrr'], \n        }\n\n        predictions = evaluator.eval(input_dict)\n        perf_list[index] = predictions['mrr']\n        hits_list[index] = predictions['hits@10']\n        if split_mode == \"test\":\n            if log_per_rel:\n                perf_per_rel[test_query[1]].append(perf_list[index]) #test_query[1] is the relation index\n\n    if split_mode == \"test\":\n        if log_per_rel:   \n            for rel in range(num_rels):\n                if len(perf_per_rel[rel]) > 0:\n                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))\n                else:\n                    perf_per_rel.pop(rel)       \n               \n\n    return perf_list, hits_list, perf_per_rel\n\n\n## args\ndef get_args(): \n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", \"-d\", default=\"tkgl-yago\", type=str) \n    parser.add_argument(\"--rule_lengths\", \"-l\", default=\"1\", type=int, nargs=\"+\")\n    parser.add_argument(\"--num_walks\", \"-n\", default=\"100\", type=int)\n    parser.add_argument(\"--transition_distr\", default=\"exp\", type=str)\n    parser.add_argument(\"--window\", \"-w\", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. 
set to -2 if multistep\n    parser.add_argument(\"--top_k\", default=20, type=int)\n    parser.add_argument(\"--num_processes\", \"-p\", default=1, type=int)\n    parser.add_argument(\"--alpha\", \"-alpha\",  default=0.99, type=float) # fix alpha. used if trainflag == false\n    # parser.add_argument(\"--train_flag\", \"-tr\",  default=True) # do we need training, ie selection of lambda and alpha\n    parser.add_argument(\"--save_config\", \"-c\",  default=True) # do we need to save the selection of lambda and alpha in config file?\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--run_nr', type=int, help='Run Number', default=1)\n    parser.add_argument('--learn_rules_flag', type=bool, help='Do we want to learn the rules', default=True)\n    parser.add_argument('--rule_filename', type=str, help='if rules not learned: where are they stored', default='0_r[3]_n100_exp_s1_rules.json')\n    parser.add_argument('--log_per_rel', type=bool, help='Do we want to log mrr per relation', default=False)\n    parser.add_argument('--compute_valid_mrr', type=bool, help='Do we want to compute mrr for valid set', default=True)\n    parsed = vars(parser.parse_args())\n    return parsed\n\nstart_o =  timeit.default_timer()\n\n## get args\nparsed = get_args()\ndataset = parsed[\"dataset\"]\nrule_lengths = parsed[\"rule_lengths\"]\nrule_lengths = [rule_lengths] if (type(rule_lengths) == int) else rule_lengths\nnum_walks = parsed[\"num_walks\"]\ntransition_distr = parsed[\"transition_distr\"]\nnum_processes = parsed[\"num_processes\"]\nwindow = parsed[\"window\"]\ntop_k = parsed[\"top_k\"]\nlog_per_rel = parsed['log_per_rel']\n\nMODEL_NAME = 'TLogic'\nSEED = parsed['seed']  # set the random seed for consistency\nset_random_seed(SEED)\n\n## load dataset and prepare it accordingly\nname = parsed[\"dataset\"]\ncompute_valid_mrr = parsed[\"compute_valid_mrr\"]\ndataset = LinkPropPredDataset(name=name, root=\"datasets\", 
preprocess=True)\nDATA = name\n\nrelations = dataset.edge_type\nnum_rels = dataset.num_rels\n\nsubjects = dataset.full_data[\"sources\"]\nobjects= dataset.full_data[\"destinations\"]\nnum_nodes = dataset.num_nodes \ntimestamps_orig = dataset.full_data[\"timestamps\"]\ntimestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1\n\nall_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)\ntrain_data = all_quads[dataset.train_mask,0:4] # we do not need the original timestamps\nval_data = all_quads[dataset.val_mask,0:5]\ntest_data = all_quads[dataset.test_mask,0:5]\nall_data = all_quads[:,0:4]\n\nmetric = dataset.eval_metric\nevaluator = Evaluator(name=name)\nneg_sampler = dataset.negative_sampler\n\ninv_relation_id = get_inv_relation_id(num_rels)\n\n#load the ns samples \n\ndataset.load_val_ns()\ndataset.load_test_ns()\noutput_dir =  f'{osp.dirname(osp.abspath(__file__))}/saved_models/'\nlearn_rules_flag = parsed['learn_rules_flag']\n## 1. learn rules\nstart_train =  timeit.default_timer()\nif learn_rules_flag:\n    print(\"start learning rules\")\n    # edges (dict): edges for each relation\n    # inv_relation_id (dict): mapping of relation to inverse relation\n    \n    temporal_walk = Temporal_Walk(train_data, inv_relation_id, transition_distr)\n    rl = Rule_Learner(edges=temporal_walk.edges, id2relation=None, inv_relation_id=inv_relation_id,  \n                        output_dir=output_dir)\n    all_relations = sorted(temporal_walk.edges)  # Learn for all relations\n\n    start =  timeit.default_timer()\n    num_relations = len(all_relations) // num_processes\n    output = Parallel(n_jobs=num_processes)(\n        delayed(learn_rules)(i, num_relations) for i in range(num_processes)\n    )\n    end =  timeit.default_timer()\n\n    all_rules = output[0]\n    for i in range(1, num_processes):\n        all_rules.update(output[i])\n\n    total_time = round(end - start, 6)\n    print(\"Learning finished in {} 
seconds.\".format(total_time))\n\n    rl.rules_dict = all_rules\n    rl.sort_rules_dict()\n\n    rule_filename = rl.save_rules(0, rule_lengths, num_walks, transition_distr, SEED)\n    # rl.save_rules_verbalized(0, rule_lengths, num_walks, transition_distr, seed)\n    # rules_statistics(rl.rules_dict)\nelse:\n    rule_filename = parsed['rule_filename']\n    print(\"Loading rules from file {}\".format(parsed['rule_filename']))\n\nend_train =  timeit.default_timer()\n\n## 2. Apply rules\n\nrules_dict = json.load(open(output_dir + rule_filename))\nrules_dict = {int(k): v for k, v in rules_dict.items()}\n\nrules_dict = ra.filter_rules(\n    rules_dict, min_conf=0.01, min_body_supp=2, rule_lengths=rule_lengths\n) # filter rules for minimum confidence, body support and rule length\n\nlearn_edges = store_edges(train_data)\nscore_func = ra.score_12\n# It is possible to specify a list of list of arguments for tuning\nargs = [[0.1, 0.5]]\n\n# compute valid mrr\nstart_valid =  timeit.default_timer()\nif compute_valid_mrr:\n    print('Computing valid MRR')\n\n    num_queries = len(val_data) // num_processes\n\n    output = Parallel(n_jobs=num_processes)(\n        delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, val_data, window, learn_edges, \n                            all_quads, args, split_mode='val') for i in range(num_processes))\n    end =  timeit.default_timer()\n\n    perf_list_val = []\n    hits_list_val = []\n\n    for i in range(num_processes):\n        perf_list_val.extend(output[i][0])\n        hits_list_val.extend(output[i][1])\nelse:\n    perf_list_val = [0]\n    hits_list_val = [0]\n    \n\nend_valid =  timeit.default_timer()\n\n# compute test mrr\nif log_per_rel ==True:\n    num_processes = 1 #otherwise logging per rel does not work for our implementation\nstart_test =  timeit.default_timer()\nprint('Computing test MRR')\nstart =  timeit.default_timer()\nnum_queries = len(test_data) // num_processes\n\noutput = Parallel(n_jobs=num_processes)(\n    
delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, test_data, window, learn_edges, \n                         all_quads, args, split_mode='test', log_per_rel=log_per_rel, num_rels=num_rels) for i in range(num_processes))\nend =  timeit.default_timer()\n\nperf_list_all = []\nhits_list_all = []\n\n\nfor i in range(num_processes):\n    perf_list_all.extend(output[i][0])\n    hits_list_all.extend(output[i][1])\nif log_per_rel == True:\n    perf_per_rel = output[0][2]\n\n\ntotal_time = round(end - start, 6)\ntotal_valid_time = round(end_valid - start_valid, 6)\nprint(\"Application finished in {} seconds.\".format(total_time))\n\nprint(f\"The valid MRR is {np.mean(perf_list_val)}\")\nprint(f\"The MRR is {np.mean(perf_list_all)}\")\nprint(f\"The Hits@10 is {np.mean(hits_list_all)}\")\nprint(f\"We have {len(perf_list_all)} predictions\")\nprint(f\"The test set has len {len(test_data)} \")\n\nend_o =  timeit.default_timer()\ntrain_time_o = round(end_train- start_train, 6)  \ntest_time_o = round(end_o- start_test, 6)  \ntotal_time_o = round(end_o- start_o, 6)  \nprint(\"Running Training to find best configs finished in {} seconds.\".format(train_time_o))\nprint(\"Running testing with best configs finished in {} seconds.\".format(test_time_o))\nprint(\"Running all steps finished in {} seconds.\".format(total_time_o))\n\nresults_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'\nif not osp.exists(results_path):\n    os.mkdir(results_path)\n    print('INFO: Create directory {}'.format(results_path))\nPath(results_path).mkdir(parents=True, exist_ok=True)\n\nif log_per_rel == True:\n    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'\n    with open(results_filename, 'w') as json_file:\n        json.dump(perf_per_rel, json_file)\n\nresults_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'\nmetric = dataset.eval_metric\nsave_results({'model': MODEL_NAME,\n              'train_flag': None,\n              'rule_len': 
rule_lengths,\n              'window': window,\n              'data': DATA,\n              'run': 1,\n              'seed': SEED,\n              metric: float(np.mean(perf_list_all)),\n              'hits10': float(np.mean(hits_list_all)),\n              'val_mrr': float(np.mean(perf_list_val)),\n              'test_time': test_time_o,\n              'tot_train_val_time': total_time_o,\n              'valid_time': total_valid_time\n              }, \n    results_filename)\n"
  },
  {
    "path": "examples/nodeproppred/tgbn-genre/dyrep.py",
    "content": "\"\"\"\nDyRep\n    This has been implemented with intuitions from the following sources:\n    - https://github.com/twitter-research/tgn\n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n\n    Spec.:\n        - Memory Updater: RNN\n        - Embedding Module: ID\n        - Message Function: ATTN\n\"\"\"\nimport timeit\nfrom tqdm import tqdm\nimport torch\nfrom torch_geometric.loader import TemporalDataLoader\n\n# internal imports\nfrom tgb.utils.utils import get_args, set_random_seed\nfrom tgb.nodeproppred.evaluate import Evaluator\nfrom modules.decoder import NodePredictor\nfrom modules.emb_module import GraphAttentionEmbedding\nfrom modules.msg_func import IdentityMessage\nfrom modules.msg_agg import LastAggregator\nfrom modules.neighbor_loader import LastNeighborLoader\nfrom modules.memory_module import DyRepMemory\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\n\n\ndef process_edges(src, dst, t, msg):\n    if src.nelement() > 0:\n        model['memory'].update_state(src, dst, t, msg)\n        neighbor_loader.insert(src, dst)\n\n\n# ==========\n# ========== Define helper function...\n# ==========\n\ndef train():\n    model['memory'].train()\n    model['gnn'].train()\n    model['node_pred'].train()\n\n    model['memory'].reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    label_t = dataset.get_label_time()  # check when does the first label start\n    num_label_ts = 0\n    total_score = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        query_t = batch.t[-1]\n        # check if this batch moves to the next day\n        if query_t > label_t:\n            # find the node labels from the past day\n            label_tuple = dataset.get_node_label(query_t)\n            label_ts, label_srcs, 
labels = (\n                label_tuple[0],\n                label_tuple[1],\n                label_tuple[2],\n            )\n            label_t = dataset.get_label_time()\n            label_srcs = label_srcs.to(device)\n\n            # Process all edges that are still in the past day\n            previous_day_mask = batch.t < label_t\n            process_edges(\n                src[previous_day_mask],\n                dst[previous_day_mask],\n                t[previous_day_mask],\n                msg[previous_day_mask],\n            )\n            # Reset edges to be the edges from tomorrow so they can be used later\n            src, dst, t, msg = (\n                src[~previous_day_mask],\n                dst[~previous_day_mask],\n                t[~previous_day_mask],\n                msg[~previous_day_mask],\n            )\n\n            \"\"\"\n            modified for node property prediction\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\n            2. extract memory from the sampled neighbors and the nodes\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\n            \"\"\"\n            n_id = label_srcs\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\n\n            z, last_update = model['memory'](n_id_neighbors)\n            z = model['gnn'](\n                z,\n                last_update,\n                mem_edge_index,\n                data.t[e_id].to(device),\n                data.msg[e_id].to(device),\n            )\n            z = z[assoc[n_id]]\n\n            # loss and metric computation\n            pred = model['node_pred'](z)\n\n            loss = criterion(pred, labels.to(device))\n            np_pred = pred.cpu().detach().numpy()\n            np_true = labels.cpu().detach().numpy()\n\n            input_dict = {\n                \"y_true\": np_true,\n                \"y_pred\": np_pred,\n                \"eval_metric\": [metric],\n            }\n            result_dict = evaluator.eval(input_dict)\n            score = result_dict[metric]\n            total_score += score\n            num_label_ts += 1\n\n            loss.backward()\n            optimizer.step()\n            total_loss += float(loss.detach())\n\n        # Update memory and neighbor loader with ground-truth state.\n        process_edges(src, dst, t, msg)\n        model['memory'].detach()\n\n    metric_dict = {\n        \"ce\": total_loss / num_label_ts,\n    }\n    metric_dict[metric] = total_score / num_label_ts\n    return metric_dict\n\n\n@torch.no_grad()\ndef test(loader):\n    model['memory'].eval()\n    model['gnn'].eval()\n    model['node_pred'].eval()\n\n    total_score = 0\n    label_t = dataset.get_label_time()  # check when does the first label start\n    num_label_ts = 0\n\n    for batch in tqdm(loader):\n        batch = batch.to(device)\n        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        query_t = 
batch.t[-1]\n        if query_t > label_t:\n            label_tuple = dataset.get_node_label(query_t)\n            if label_tuple is None:\n                break\n            label_ts, label_srcs, labels = (\n                label_tuple[0],\n                label_tuple[1],\n                label_tuple[2],\n            )\n            label_t = dataset.get_label_time()\n            label_srcs = label_srcs.to(device)\n\n            # Process all edges that are still in the past day\n            previous_day_mask = batch.t < label_t\n            process_edges(\n                src[previous_day_mask],\n                dst[previous_day_mask],\n                t[previous_day_mask],\n                msg[previous_day_mask],\n            )\n            # Reset edges to be the edges from tomorrow so they can be used later\n            src, dst, t, msg = (\n                src[~previous_day_mask],\n                dst[~previous_day_mask],\n                t[~previous_day_mask],\n                msg[~previous_day_mask],\n            )\n\n            \"\"\"\n            modified for node property prediction\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\n            2. extract memory from the sampled neighbors and the nodes\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\n            \"\"\"\n            n_id = label_srcs\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\n\n            z, last_update = model['memory'](n_id_neighbors)\n            z = model['gnn'](\n                z,\n                last_update,\n                mem_edge_index,\n                data.t[e_id].to(device),\n                data.msg[e_id].to(device),\n            )\n            z = z[assoc[n_id]]\n\n            # loss and metric computation\n            pred = model['node_pred'](z)\n            np_pred = pred.cpu().detach().numpy()\n            np_true = labels.cpu().detach().numpy()\n\n            input_dict = {\n                \"y_true\": np_true,\n                \"y_pred\": np_pred,\n                \"eval_metric\": [metric],\n            }\n            result_dict = evaluator.eval(input_dict)\n            score = result_dict[metric]\n            total_score += score\n            num_label_ts += 1\n\n        process_edges(src, dst, t, msg)\n\n    metric_dict = {}\n    metric_dict[metric] = total_score / num_label_ts\n    return metric_dict\n\n# ==========\n# ==========\n# ==========\n\n# Start...\nstart_overall = timeit.default_timer()\n\n# ========== set parameters...\nargs, _ = get_args()\nprint(\"INFO: Arguments:\", args)\n\nDATA = \"tgbn-genre\"\nLR = args.lr\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value  \nNUM_EPOCH = args.num_epoch\nSEED = args.seed\nMEM_DIM = args.mem_dim\nTIME_DIM = args.time_dim\nEMB_DIM = args.emb_dim\nTOLERANCE = args.tolerance\nPATIENCE = args.patience\nNUM_RUNS = args.num_run\nNUM_NEIGHBORS = 10\n\n\n# setting random seed\ntorch.manual_seed(SEED)\nset_random_seed(SEED)\n\nMODEL_NAME = 'DyRep'\nUSE_SRC_EMB_IN_MSG = False\nUSE_DST_EMB_IN_MSG = True\nevaluator = Evaluator(name=DATA)\n# ==========\n\n# set the device\ndevice = 
torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# data loading\ndataset = PyGNodePropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\ndata = data.to(device)\nmetric = dataset.eval_metric\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\nnum_classes = dataset.num_classes\n\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\n\n# neighborhood sampler\nneighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\n\n# define the model end-to-end\nmemory = DyRepMemory(\n    data.num_nodes,\n    data.msg.size(-1),\n    MEM_DIM,\n    TIME_DIM,\n    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\n    aggregator_module=LastAggregator(),\n    memory_updater_type='rnn',\n    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,\n    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG\n).to(device)\n\ngnn = GraphAttentionEmbedding(\n    in_channels=MEM_DIM,\n    out_channels=EMB_DIM,\n    msg_dim=data.msg.size(-1),\n    time_enc=memory.time_enc,\n).to(device)\n\nnode_pred = NodePredictor(in_dim=EMB_DIM, out_dim=num_classes).to(device)\n\nmodel = {'memory': memory,\n         'gnn': gnn,\n         'node_pred': node_pred}\n\noptimizer = torch.optim.Adam(\n    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['node_pred'].parameters()),\n    lr=LR,\n)\n\ncriterion = torch.nn.CrossEntropyLoss()\n# Helper vector to map global node indices to local ones.\nassoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\n\ntrain_curve = []\nval_curve = []\ntest_curve = 
[]\nmax_val_score = 0  #find the best test score based on validation score\nbest_test_idx = 0\nfor epoch in range(1, NUM_EPOCH + 1):\n    start_time = timeit.default_timer()\n    train_dict = train()\n    print(\"------------------------------------\")\n    print(f\"training Epoch: {epoch:02d}\")\n    print(train_dict)\n    train_curve.append(train_dict[metric])\n    print(\"Training takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\n    \n    start_time = timeit.default_timer()\n    val_dict = test(val_loader)\n    print(val_dict)\n    val_curve.append(val_dict[metric])\n    if (val_dict[metric] > max_val_score):\n        max_val_score = val_dict[metric]\n        best_test_idx = epoch - 1\n    print(\"Validation takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\n\n    start_time = timeit.default_timer()\n    test_dict = test(test_loader)\n    print(test_dict)\n    test_curve.append(test_dict[metric])\n    print(\"Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\n    print(\"------------------------------------\")\n    dataset.reset_label_time()\n\nmax_test_score = test_curve[best_test_idx]\nprint(\"------------------------------------\")\nprint(\"------------------------------------\")\nprint (\"best val score: \", max_val_score)\nprint (\"best validation epoch   : \", best_test_idx + 1)\nprint (\"best test score: \", max_test_score)\n\n"
  },
  {
    "path": "examples/nodeproppred/tgbn-genre/moving_average.py",
    "content": "\"\"\"\nimplement persistant forecast as baseline for the node prop pred task\nsimply predict last seen label for the node\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom torch_geometric.loader import TemporalDataLoader\n\n# local imports\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\nfrom modules.heuristics import MovingAverage\nfrom tgb.nodeproppred.evaluate import Evaluator\n\nwindow = 6\ndevice = \"cpu\"\nname = \"tgbn-genre\"\ndataset = PyGNodePropPredDataset(name=name, root=\"datasets\")\nnum_classes = dataset.num_classes\ndata = dataset.get_TemporalData()\ndata = data.to(device)\n\neval_metric = dataset.eval_metric\nforecaster = MovingAverage(num_classes, window=window)\nevaluator = Evaluator(name=name)\n\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\ntrain_data, val_data, test_data = data.train_val_test_split(\n    val_ratio=0.15, test_ratio=0.15\n)\n\nbatch_size = 200\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=batch_size)\nval_loader = TemporalDataLoader(val_data, batch_size=batch_size)\ntest_loader = TemporalDataLoader(test_data, batch_size=batch_size)\n\n\ndef test_n_upate(loader):\n    label_t = dataset.get_label_time()  # check when does the first label start\n    num_label_ts = 0\n    total_score = 0\n\n    for batch in loader:\n        batch = batch.to(device)\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        query_t = batch.t[-1]\n        if query_t > label_t:\n            label_tuple = dataset.get_node_label(query_t)\n            if label_tuple is None:\n                break\n            label_ts, label_srcs, labels = (\n                label_tuple[0],\n                label_tuple[1],\n                label_tuple[2],\n            )\n            label_ts = label_ts.numpy()\n            label_srcs = label_srcs.numpy()\n            labels = labels.numpy()\n            label_t 
= dataset.get_label_time()\n\n            preds = []\n\n            for i in range(0, label_srcs.shape[0]):\n                node_id = label_srcs[i]\n                pred_vec = forecaster.query_dict(node_id)\n                preds.append(pred_vec)\n                forecaster.update_dict(node_id, labels[i])\n\n            np_pred = np.stack(preds, axis=0)\n            np_true = labels\n\n            input_dict = {\n                \"y_true\": np_true,\n                \"y_pred\": np_pred,\n                \"eval_metric\": [eval_metric],\n            }\n            result_dict = evaluator.eval(input_dict)\n            score = result_dict[eval_metric]\n\n            total_score += score\n            num_label_ts += 1\n\n    metric_dict = {}\n    metric_dict[eval_metric] = total_score / num_label_ts\n    return metric_dict\n\n\n\"\"\"\ntrain, val and test for one epoch only\n\"\"\"\n\nstart_time = timeit.default_timer()\nmetric_dict = test_n_upate(train_loader)\nprint(metric_dict)\nprint(\n    \"Persistant forecast on Training takes--- %s seconds ---\"\n    % (timeit.default_timer() - start_time)\n)\n\nstart_time = timeit.default_timer()\nval_dict = test_n_upate(val_loader)\nprint(val_dict)\nprint(\n    \"Persistant forecast on validation takes--- %s seconds ---\"\n    % (timeit.default_timer() - start_time)\n)\n\n\nstart_time = timeit.default_timer()\ntest_dict = test_n_upate(test_loader)\nprint(test_dict)\nprint(\n    \"Persistant forecast on Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time)\n)\ndataset.reset_label_time()\n"
  },
  {
    "path": "examples/nodeproppred/tgbn-genre/persistant_forecast.py",
    "content": "\"\"\"\nimplement persistant forecast as baseline for the node prop pred task\nsimply predict last seen label for the node\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\nimport torch\n\n# local imports\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\nfrom modules.heuristics import PersistantForecaster\nfrom tgb.nodeproppred.evaluate import Evaluator\n\n\ndevice = \"cpu\"\nname = \"tgbn-genre\"\ndataset = PyGNodePropPredDataset(name=name, root=\"datasets\")\nnum_classes = dataset.num_classes\ndata = dataset.get_TemporalData()\ndata = data.to(device)\n\nall_nodes = torch.cat((data.src, data.dst), 0)\nall_nodes = all_nodes.unique()\nprint (all_nodes.shape[0])\n\neval_metric = dataset.eval_metric\nforecaster = PersistantForecaster(num_classes)\nevaluator = Evaluator(name=name)\n\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\ntrain_data, val_data, test_data = data.train_val_test_split(\n    val_ratio=0.15, test_ratio=0.15\n)\n\nbatch_size = 200\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=batch_size)\nval_loader = TemporalDataLoader(val_data, batch_size=batch_size)\ntest_loader = TemporalDataLoader(test_data, batch_size=batch_size)\n\n\n\"\"\"\ncontinue debug here\n\"\"\"\n\n\ndef test_n_upate(loader):\n    label_t = dataset.get_label_time()  # check when does the first label start\n    num_label_ts = 0\n    total_score = 0\n\n    for batch in tqdm(loader):\n        batch = batch.to(device)\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        query_t = batch.t[-1]\n        if query_t > label_t:\n            label_tuple = dataset.get_node_label(query_t)\n            if label_tuple is None:\n                break\n            label_ts, label_srcs, labels = (\n                label_tuple[0],\n                label_tuple[1],\n        
        label_tuple[2],\n            )\n            label_ts = label_ts.numpy()\n            label_srcs = label_srcs.numpy()\n            labels = labels.numpy()\n            label_t = dataset.get_label_time()\n\n            preds = []\n\n            for i in range(0, label_srcs.shape[0]):\n                node_id = label_srcs[i]\n                pred_vec = forecaster.query_dict(node_id)\n                preds.append(pred_vec)\n                forecaster.update_dict(node_id, labels[i])\n\n            np_pred = np.stack(preds, axis=0)\n            np_true = labels\n\n            input_dict = {\n                \"y_true\": np_true,\n                \"y_pred\": np_pred,\n                \"eval_metric\": [eval_metric],\n            }\n            result_dict = evaluator.eval(input_dict)\n            score = result_dict[eval_metric]\n            total_score += score\n            num_label_ts += 1\n\n    metric_dict = {}\n    metric_dict[eval_metric] = total_score / num_label_ts\n    return metric_dict\n\n\n\"\"\"\ntrain, val and test for one epoch only\n\"\"\"\n\nstart_time = timeit.default_timer()\nmetric_dict = test_n_upate(train_loader)\nprint(metric_dict)\nprint(\n    \"Persistant forecast on Training takes--- %s seconds ---\"\n    % (timeit.default_timer() - start_time)\n)\n\nstart_time = timeit.default_timer()\nval_dict = test_n_upate(val_loader)\nprint(val_dict)\nprint(\n    \"Persistant forecast on validation takes--- %s seconds ---\"\n    % (timeit.default_timer() - start_time)\n)\n\n\nstart_time = timeit.default_timer()\ntest_dict = test_n_upate(test_loader)\nprint(test_dict)\nprint(\n    \"Persistant forecast on Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time)\n)\ndataset.reset_label_time()\n"
  },
  {
    "path": "examples/nodeproppred/tgbn-genre/tgn.py",
    "content": "from tqdm import tqdm\r\nimport torch\r\nimport timeit\r\nimport argparse\r\nimport matplotlib.pyplot as plt\r\n\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom torch_geometric.nn import TGNMemory\r\nfrom torch_geometric.nn.models.tgn import (\r\n    IdentityMessage,\r\n    LastAggregator,\r\n    LastNeighborLoader,\r\n)\r\n\r\nfrom modules.decoder import NodePredictor\r\nfrom modules.emb_module import GraphAttentionEmbedding\r\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\r\nfrom tgb.nodeproppred.evaluate import Evaluator\r\nfrom tgb.utils.utils import set_random_seed\r\n\r\nparser = argparse.ArgumentParser(description='parsing command line arguments as hyperparameters')\r\nparser.add_argument('-s', '--seed', type=int, default=1,\r\n                    help='random seed to use')\r\nparser.parse_args()\r\nargs = parser.parse_args()\r\n# setting random seed\r\nseed = int(args.seed) #1,2,3,4,5\r\nprint (\"setting random seed to be\", seed)\r\ntorch.manual_seed(seed)\r\nset_random_seed(seed)\r\n\r\n# hyperparameters\r\nlr = 0.0001\r\nepochs = 50\r\n\r\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\nname = \"tgbn-genre\"\r\ndataset = PyGNodePropPredDataset(name=name, root=\"datasets\")\r\ntrain_mask = dataset.train_mask\r\nval_mask = dataset.val_mask\r\ntest_mask = dataset.test_mask\r\n\r\neval_metric = dataset.eval_metric\r\nnum_classes = dataset.num_classes\r\ndata = dataset.get_TemporalData()\r\ndata = data.to(device)\r\n\r\nevaluator = Evaluator(name=name)\r\n\r\n\r\ntrain_data = data[train_mask]\r\nval_data = data[val_mask]\r\ntest_data = data[test_mask]\r\n\r\n# Ensure to only sample actual destination nodes as negatives.\r\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\nbatch_size = 200\r\n\r\ntrain_loader = TemporalDataLoader(train_data, batch_size=batch_size)\r\nval_loader = TemporalDataLoader(val_data, batch_size=batch_size)\r\ntest_loader = 
TemporalDataLoader(test_data, batch_size=batch_size)\r\n\r\nneighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)\r\n\r\nmemory_dim = time_dim = embedding_dim = 100\r\n\r\nmemory = TGNMemory(\r\n    data.num_nodes,\r\n    data.msg.size(-1),\r\n    memory_dim,\r\n    time_dim,\r\n    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),\r\n    aggregator_module=LastAggregator(),\r\n).to(device)\r\n\r\ngnn = (\r\n    GraphAttentionEmbedding(\r\n        in_channels=memory_dim,\r\n        out_channels=embedding_dim,\r\n        msg_dim=data.msg.size(-1),\r\n        time_enc=memory.time_enc,\r\n    )\r\n    .to(device)\r\n    .float()\r\n)\r\n\r\nnode_pred = NodePredictor(in_dim=embedding_dim, out_dim=num_classes).to(device)\r\n\r\noptimizer = torch.optim.Adam(\r\n    set(memory.parameters()) | set(gnn.parameters()) | set(node_pred.parameters()),\r\n    lr=lr,\r\n)\r\n\r\ncriterion = torch.nn.CrossEntropyLoss()\r\n# Helper vector to map global node indices to local ones.\r\nassoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\r\n\r\n\r\ndef plot_curve(scores, out_name):\r\n    plt.plot(scores, color=\"#e34a33\")\r\n    plt.ylabel(\"score\")\r\n    plt.savefig(out_name + \".pdf\")\r\n    plt.close()\r\n\r\n\r\ndef process_edges(src, dst, t, msg):\r\n    if src.nelement() > 0:\r\n        # msg = msg.to(torch.float32)\r\n        memory.update_state(src, dst, t, msg)\r\n        neighbor_loader.insert(src, dst)\r\n\r\n\r\ndef train():\r\n    memory.train()\r\n    gnn.train()\r\n    node_pred.train()\r\n\r\n    memory.reset_state()  # Start with a fresh memory.\r\n    neighbor_loader.reset_state()  # Start with an empty graph.\r\n\r\n    total_loss = 0\r\n    label_t = dataset.get_label_time()  # check when does the first label start\r\n    total_score = 0\r\n    num_label_ts = 0\r\n\r\n    for batch in tqdm(train_loader):\r\n        batch = batch.to(device)\r\n        optimizer.zero_grad()\r\n        src, dst, t, msg = 
batch.src, batch.dst, batch.t, batch.msg\r\n\r\n        query_t = batch.t[-1]\r\n        # check if this batch moves to the next day\r\n        if query_t > label_t:\r\n            # find the node labels from the past day\r\n            label_tuple = dataset.get_node_label(query_t)\r\n            label_ts, label_srcs, labels = (\r\n                label_tuple[0],\r\n                label_tuple[1],\r\n                label_tuple[2],\r\n            )\r\n            label_t = dataset.get_label_time()\r\n            label_srcs = label_srcs.to(device)\r\n\r\n            # Process all edges that are still in the past day\r\n            previous_day_mask = batch.t < label_t\r\n            process_edges(\r\n                src[previous_day_mask],\r\n                dst[previous_day_mask],\r\n                t[previous_day_mask],\r\n                msg[previous_day_mask],\r\n            )\r\n            # Reset edges to be the edges from tomorrow so they can be used later\r\n            src, dst, t, msg = (\r\n                src[~previous_day_mask],\r\n                dst[~previous_day_mask],\r\n                t[~previous_day_mask],\r\n                msg[~previous_day_mask],\r\n            )\r\n\r\n            \"\"\"\r\n            modified for node property prediction\r\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\r\n            2. extract memory from the sampled neighbors and the nodes\r\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\r\n            \"\"\"\r\n            n_id = label_srcs\r\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\r\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\r\n\r\n            z, last_update = memory(n_id_neighbors)\r\n\r\n            z = gnn(\r\n                z,\r\n                last_update,\r\n                mem_edge_index,\r\n                data.t[e_id].to(device),\r\n                data.msg[e_id].to(device),\r\n            )\r\n            z = z[assoc[n_id]]\r\n\r\n            # loss and metric computation\r\n            pred = node_pred(z)\r\n\r\n            loss = criterion(pred, labels.to(device))\r\n            np_pred = pred.cpu().detach().numpy()\r\n            np_true = labels.cpu().detach().numpy()\r\n\r\n            input_dict = {\r\n                \"y_true\": np_true,\r\n                \"y_pred\": np_pred,\r\n                \"eval_metric\": [eval_metric],\r\n            }\r\n            result_dict = evaluator.eval(input_dict)\r\n            score = result_dict[eval_metric]\r\n            total_score += score\r\n            num_label_ts += 1\r\n\r\n            loss.backward()\r\n            optimizer.step()\r\n            total_loss += float(loss.detach())\r\n\r\n        # Update memory and neighbor loader with ground-truth state.\r\n        process_edges(src, dst, t, msg)\r\n        memory.detach()\r\n\r\n    metric_dict = {\r\n        \"ce\": total_loss / num_label_ts,\r\n    }\r\n    metric_dict[eval_metric] = total_score / num_label_ts\r\n    return metric_dict\r\n\r\n\r\n@torch.no_grad()\r\ndef test(loader):\r\n    memory.eval()\r\n    gnn.eval()\r\n    node_pred.eval()\r\n\r\n    label_t = dataset.get_label_time()  # check when does the first label start\r\n    num_label_ts = 0\r\n    total_score = 0\r\n\r\n    for batch in tqdm(loader):\r\n        batch = batch.to(device)\r\n        src, 
dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\r\n\r\n        query_t = batch.t[-1]\r\n        if query_t > label_t:\r\n            label_tuple = dataset.get_node_label(query_t)\r\n            if label_tuple is None:\r\n                break\r\n            label_ts, label_srcs, labels = (\r\n                label_tuple[0],\r\n                label_tuple[1],\r\n                label_tuple[2],\r\n            )\r\n            label_t = dataset.get_label_time()\r\n            label_srcs = label_srcs.to(device)\r\n\r\n            # Process all edges that are still in the past day\r\n            previous_day_mask = batch.t < label_t\r\n            process_edges(\r\n                src[previous_day_mask],\r\n                dst[previous_day_mask],\r\n                t[previous_day_mask],\r\n                msg[previous_day_mask],\r\n            )\r\n            # Reset edges to be the edges from tomorrow so they can be used later\r\n            src, dst, t, msg = (\r\n                src[~previous_day_mask],\r\n                dst[~previous_day_mask],\r\n                t[~previous_day_mask],\r\n                msg[~previous_day_mask],\r\n            )\r\n\r\n            \"\"\"\r\n            modified for node property prediction\r\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\r\n            2. extract memory from the sampled neighbors and the nodes\r\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\r\n            \"\"\"\r\n            n_id = label_srcs\r\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\r\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\r\n\r\n            z, last_update = memory(n_id_neighbors)\r\n            z = gnn(\r\n                z,\r\n                last_update,\r\n                mem_edge_index,\r\n                data.t[e_id].to(device),\r\n                data.msg[e_id].to(device),\r\n            )\r\n            z = z[assoc[n_id]]\r\n\r\n            # loss and metric computation\r\n            pred = node_pred(z)\r\n            np_pred = pred.cpu().detach().numpy()\r\n            np_true = labels.cpu().detach().numpy()\r\n\r\n            input_dict = {\r\n                \"y_true\": np_true,\r\n                \"y_pred\": np_pred,\r\n                \"eval_metric\": [eval_metric],\r\n            }\r\n            result_dict = evaluator.eval(input_dict)\r\n            score = result_dict[eval_metric]\r\n            total_score += score\r\n            num_label_ts += 1\r\n\r\n        process_edges(src, dst, t, msg)\r\n\r\n    metric_dict = {}\r\n    metric_dict[eval_metric] = total_score / num_label_ts\r\n    return metric_dict\r\n\r\n\r\ntrain_curve = []\r\nval_curve = []\r\ntest_curve = []\r\nmax_val_score = 0  #find the best test score based on validation score\r\nbest_test_idx = 0\r\nfor epoch in range(1, epochs + 1):\r\n    start_time = timeit.default_timer()\r\n    train_dict = train()\r\n    print(\"------------------------------------\")\r\n    print(f\"training Epoch: {epoch:02d}\")\r\n    print(train_dict)\r\n    train_curve.append(train_dict[eval_metric])\r\n    print(\"Training takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\r\n    \r\n    start_time = timeit.default_timer()\r\n    val_dict = test(val_loader)\r\n    print(val_dict)\r\n    
val_curve.append(val_dict[eval_metric])\r\n    if (val_dict[eval_metric] > max_val_score):\r\n        max_val_score = val_dict[eval_metric]\r\n        best_test_idx = epoch - 1\r\n    print(\"Validation takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\r\n\r\n    start_time = timeit.default_timer()\r\n    test_dict = test(test_loader)\r\n    print(test_dict)\r\n    test_curve.append(test_dict[eval_metric])\r\n    print(\"Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\r\n    print(\"------------------------------------\")\r\n    dataset.reset_label_time()\r\n\r\n\r\n# # code for plotting\r\n# plot_curve(train_curve, \"train_curve\")\r\n# plot_curve(val_curve, \"val_curve\")\r\n# plot_curve(test_curve, \"test_curve\")\r\n\r\nmax_test_score = test_curve[best_test_idx]\r\nprint(\"------------------------------------\")\r\nprint(\"------------------------------------\")\r\nprint (\"best val score: \", max_val_score)\r\nprint (\"best validation epoch   : \", best_test_idx + 1)\r\nprint (\"best test score: \", max_test_score)\r\n"
  },
  {
    "path": "examples/nodeproppred/tgbn-reddit/dyrep.py",
    "content": "\"\"\"\nDyRep\n    This has been implemented with intuitions from the following sources:\n    - https://github.com/twitter-research/tgn\n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n\n    Spec.:\n        - Memory Updater: RNN\n        - Embedding Module: ID\n        - Message Function: ATTN\n\"\"\"\nimport timeit\nfrom tqdm import tqdm\nimport torch\nfrom torch_geometric.loader import TemporalDataLoader\nimport numpy as np\n\n# internal imports\nfrom tgb.utils.utils import get_args, set_random_seed\nfrom tgb.nodeproppred.evaluate import Evaluator\nfrom modules.decoder import NodePredictor\nfrom modules.emb_module import GraphAttentionEmbedding\nfrom modules.msg_func import IdentityMessage\nfrom modules.msg_agg import LastAggregator\nfrom modules.neighbor_loader import LastNeighborLoader\nfrom modules.memory_module import DyRepMemory\nfrom modules.early_stopping import  EarlyStopMonitor\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\n\n\ndef process_edges(src, dst, t, msg):\n    if src.nelement() > 0:\n        model['memory'].update_state(src, dst, t, msg)\n        neighbor_loader.insert(src, dst)\n\n\n# ==========\n# ========== Define helper function...\n# ==========\n\ndef train():\n    model['memory'].train()\n    model['gnn'].train()\n    model['node_pred'].train()\n\n    model['memory'].reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    label_t = dataset.get_label_time()  # check when does the first label start\n    num_label_ts = 0\n    total_score = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        query_t = batch.t[-1]\n        # check if this batch moves to the next day\n        if query_t > label_t:\n            # find the node labels from the past day\n            
label_tuple = dataset.get_node_label(query_t)\n            label_ts, label_srcs, labels = (\n                label_tuple[0],\n                label_tuple[1],\n                label_tuple[2],\n            )\n            label_t = dataset.get_label_time()\n            label_srcs = label_srcs.to(device)\n\n            # Process all edges that are still in the past day\n            previous_day_mask = batch.t < label_t\n            process_edges(\n                src[previous_day_mask],\n                dst[previous_day_mask],\n                t[previous_day_mask],\n                msg[previous_day_mask],\n            )\n            # Reset edges to be the edges from tomorrow so they can be used later\n            src, dst, t, msg = (\n                src[~previous_day_mask],\n                dst[~previous_day_mask],\n                t[~previous_day_mask],\n                msg[~previous_day_mask],\n            )\n\n            \"\"\"\n            modified for node property prediction\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\n            2. extract memory from the sampled neighbors and the nodes\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\n            \"\"\"\n            n_id = label_srcs\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\n\n            z, last_update = model['memory'](n_id_neighbors)\n            z = model['gnn'](\n                z,\n                last_update,\n                mem_edge_index,\n                data.t[e_id].to(device),\n                data.msg[e_id].to(device),\n            )\n            z = z[assoc[n_id]]\n\n            # loss and metric computation\n            pred = model['node_pred'](z)\n\n            loss = criterion(pred, labels.to(device))\n            np_pred = pred.cpu().detach().numpy()\n            np_true = labels.cpu().detach().numpy()\n\n            input_dict = {\n                \"y_true\": np_true,\n                \"y_pred\": np_pred,\n                \"eval_metric\": [metric],\n            }\n            result_dict = evaluator.eval(input_dict)\n            score = result_dict[metric]\n            total_score += score\n            num_label_ts += 1\n\n            loss.backward()\n            optimizer.step()\n            total_loss += float(loss.detach())\n\n        # Update memory and neighbor loader with ground-truth state.\n        process_edges(src, dst, t, msg)\n        model['memory'].detach()\n\n    metric_dict = {\n        \"ce\": total_loss / num_label_ts,\n    }\n    metric_dict[metric] = total_score / num_label_ts\n    return metric_dict\n\n\n@torch.no_grad()\ndef test(loader):\n    model['memory'].eval()\n    model['gnn'].eval()\n    model['node_pred'].eval()\n\n    total_score = 0\n    label_t = dataset.get_label_time()  # check when does the first label start\n    num_label_ts = 0\n\n    for batch in tqdm(loader):\n        batch = batch.to(device)\n        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        query_t = 
batch.t[-1]\n        if query_t > label_t:\n            label_tuple = dataset.get_node_label(query_t)\n            if label_tuple is None:\n                break\n            label_ts, label_srcs, labels = (\n                label_tuple[0],\n                label_tuple[1],\n                label_tuple[2],\n            )\n            label_t = dataset.get_label_time()\n            label_srcs = label_srcs.to(device)\n\n            # Process all edges that are still in the past day\n            previous_day_mask = batch.t < label_t\n            process_edges(\n                src[previous_day_mask],\n                dst[previous_day_mask],\n                t[previous_day_mask],\n                msg[previous_day_mask],\n            )\n            # Reset edges to be the edges from tomorrow so they can be used later\n            src, dst, t, msg = (\n                src[~previous_day_mask],\n                dst[~previous_day_mask],\n                t[~previous_day_mask],\n                msg[~previous_day_mask],\n            )\n\n            \"\"\"\n            modified for node property prediction\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\n            2. extract memory from the sampled neighbors and the nodes\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\n            \"\"\"\n            n_id = label_srcs\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\n\n            z, last_update = model['memory'](n_id_neighbors)\n            z = model['gnn'](\n                z,\n                last_update,\n                mem_edge_index,\n                data.t[e_id].to(device),\n                data.msg[e_id].to(device),\n            )\n            z = z[assoc[n_id]]\n\n            # loss and metric computation\n            pred = model['node_pred'](z)\n            np_pred = pred.cpu().detach().numpy()\n            np_true = labels.cpu().detach().numpy()\n\n            input_dict = {\n                \"y_true\": np_true,\n                \"y_pred\": np_pred,\n                \"eval_metric\": [metric],\n            }\n            result_dict = evaluator.eval(input_dict)\n            score = result_dict[metric]\n            total_score += score\n            num_label_ts += 1\n\n        process_edges(src, dst, t, msg)\n\n    metric_dict = {}\n    metric_dict[metric] = total_score / num_label_ts\n    return metric_dict\n\n# ==========\n# ==========\n# ==========\n\n# Start...\nstart_overall = timeit.default_timer()\n\n# ========== set parameters...\nargs, _ = get_args()\nprint(\"INFO: Arguments:\", args)\n\nDATA = \"tgbn-reddit\"\nLR = args.lr\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value  \nNUM_EPOCH = args.num_epoch\nSEED = args.seed\nMEM_DIM = args.mem_dim\nTIME_DIM = args.time_dim\nEMB_DIM = args.emb_dim\nTOLERANCE = args.tolerance\nPATIENCE = args.patience\nNUM_RUNS = args.num_run\nNUM_NEIGHBORS = 10\n\n\n# setting random seed\ntorch.manual_seed(SEED)\nset_random_seed(SEED)\n\nMODEL_NAME = 'DyRep'\nUSE_SRC_EMB_IN_MSG = False\nUSE_DST_EMB_IN_MSG = True\nevaluator = Evaluator(name=DATA)\n# ==========\n\n# set the device\ndevice = 
torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# data loading\ndataset = PyGNodePropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\ndata = data.to(device)\nmetric = dataset.eval_metric\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\nnum_classes = dataset.num_classes\n\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\n\n# neighborhood sampler\nneighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\n\n# define the model end-to-end\nmemory = DyRepMemory(\n    data.num_nodes,\n    data.msg.size(-1),\n    MEM_DIM,\n    TIME_DIM,\n    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\n    aggregator_module=LastAggregator(),\n    memory_updater_type='rnn',\n    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,\n    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG\n).to(device)\n\ngnn = GraphAttentionEmbedding(\n    in_channels=MEM_DIM,\n    out_channels=EMB_DIM,\n    msg_dim=data.msg.size(-1),\n    time_enc=memory.time_enc,\n).to(device)\n\nnode_pred = NodePredictor(in_dim=EMB_DIM, out_dim=num_classes).to(device)\n\nmodel = {'memory': memory,\n         'gnn': gnn,\n         'node_pred': node_pred}\n\noptimizer = torch.optim.Adam(\n    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['node_pred'].parameters()),\n    lr=LR,\n)\n\ncriterion = torch.nn.CrossEntropyLoss()\n# Helper vector to map global node indices to local ones.\nassoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\n\ntrain_curve = []\nval_curve = []\ntest_curve = 
[]\nmax_val_score = 0  #find the best test score based on validation score\nbest_test_idx = 0\nfor epoch in range(1, NUM_EPOCH + 1):\n    start_time = timeit.default_timer()\n    train_dict = train()\n    print(\"------------------------------------\")\n    print(f\"training Epoch: {epoch:02d}\")\n    print(train_dict)\n    train_curve.append(train_dict[metric])\n    print(\"Training takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\n    \n    start_time = timeit.default_timer()\n    val_dict = test(val_loader)\n    print(val_dict)\n    val_curve.append(val_dict[metric])\n    if (val_dict[metric] > max_val_score):\n        max_val_score = val_dict[metric]\n        best_test_idx = epoch - 1\n    print(\"Validation takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\n\n    start_time = timeit.default_timer()\n    test_dict = test(test_loader)\n    print(test_dict)\n    test_curve.append(test_dict[metric])\n    print(\"Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\n    print(\"------------------------------------\")\n    dataset.reset_label_time()\n\nmax_test_score = test_curve[best_test_idx]\nprint(\"------------------------------------\")\nprint(\"------------------------------------\")\nprint (\"best val score: \", max_val_score)\nprint (\"best validation epoch   : \", best_test_idx + 1)\nprint (\"best test score: \", max_test_score)\n\n"
  },
  {
    "path": "examples/nodeproppred/tgbn-reddit/moving_average.py",
    "content": "\"\"\"\nimplement persistant forecast as baseline for the node prop pred task\nsimply predict last seen label for the node\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom torch_geometric.loader import TemporalDataLoader\n\n# local imports\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\nfrom modules.heuristics import MovingAverage\nfrom tgb.nodeproppred.evaluate import Evaluator\n\nwindow = 7\ndevice = \"cpu\"\nname = \"tgbn-reddit\"\ndataset = PyGNodePropPredDataset(name=name, root=\"datasets\")\nnum_classes = dataset.num_classes\ndata = dataset.get_TemporalData()\ndata = data.to(device)\n\neval_metric = dataset.eval_metric\nforecaster = MovingAverage(num_classes, window=window)\nevaluator = Evaluator(name=name)\n\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\ntrain_data, val_data, test_data = data.train_val_test_split(\n    val_ratio=0.15, test_ratio=0.15\n)\n\nbatch_size = 200\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=batch_size)\nval_loader = TemporalDataLoader(val_data, batch_size=batch_size)\ntest_loader = TemporalDataLoader(test_data, batch_size=batch_size)\n\n\ndef test_n_upate(loader):\n    label_t = dataset.get_label_time()  # check when does the first label start\n    num_label_ts = 0\n    total_score = 0\n\n    for batch in loader:\n        batch = batch.to(device)\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        query_t = batch.t[-1]\n        if query_t > label_t:\n            label_tuple = dataset.get_node_label(query_t)\n            if label_tuple is None:\n                break\n            label_ts, label_srcs, labels = (\n                label_tuple[0],\n                label_tuple[1],\n                label_tuple[2],\n            )\n            label_ts = label_ts.numpy()\n            label_srcs = label_srcs.numpy()\n            labels = labels.numpy()\n            
label_t = dataset.get_label_time()\n\n            preds = []\n\n            for i in range(0, label_srcs.shape[0]):\n                node_id = label_srcs[i]\n                pred_vec = forecaster.query_dict(node_id)\n                preds.append(pred_vec)\n                forecaster.update_dict(node_id, labels[i])\n\n            np_pred = np.stack(preds, axis=0)\n            np_true = labels\n\n            input_dict = {\n                \"y_true\": np_true,\n                \"y_pred\": np_pred,\n                \"eval_metric\": [eval_metric],\n            }\n            result_dict = evaluator.eval(input_dict)\n            score = result_dict[eval_metric]\n\n            total_score += score\n            num_label_ts += 1\n\n    metric_dict = {}\n    metric_dict[eval_metric] = total_score / num_label_ts\n    return metric_dict\n\n\n\"\"\"\ntrain, val and test for one epoch only\n\"\"\"\n\nstart_time = timeit.default_timer()\nmetric_dict = test_n_upate(train_loader)\nprint(metric_dict)\nprint(\n    \"Persistant forecast on Training takes--- %s seconds ---\"\n    % (timeit.default_timer() - start_time)\n)\n\nstart_time = timeit.default_timer()\nval_dict = test_n_upate(val_loader)\nprint(val_dict)\nprint(\n    \"Persistant forecast on validation takes--- %s seconds ---\"\n    % (timeit.default_timer() - start_time)\n)\n\n\nstart_time = timeit.default_timer()\ntest_dict = test_n_upate(test_loader)\nprint(test_dict)\nprint(\n    \"Persistant forecast on Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time)\n)\ndataset.reset_label_time()"
  },
  {
    "path": "examples/nodeproppred/tgbn-reddit/persistant_forecast.py",
    "content": "\"\"\"\nimplement persistant forecast as baseline for the node prop pred task\nsimply predict last seen label for the node\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\nimport torch\n\n# local imports\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\nfrom modules.heuristics import PersistantForecaster\nfrom tgb.nodeproppred.evaluate import Evaluator\n\n\ndevice = \"cpu\"\nname = \"tgbn-reddit\"\ndataset = PyGNodePropPredDataset(name=name, root=\"datasets\")\nnum_classes = dataset.num_classes\ndata = dataset.get_TemporalData()\ndata = data.to(device)\n\nall_nodes = torch.cat((data.src, data.dst), 0)\nall_nodes = all_nodes.unique()\nprint (all_nodes.shape[0])\n\neval_metric = dataset.eval_metric\nforecaster = PersistantForecaster(num_classes)\nevaluator = Evaluator(name=name)\n\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\ntrain_data, val_data, test_data = data.train_val_test_split(\n    val_ratio=0.15, test_ratio=0.15\n)\n\nbatch_size = 200\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=batch_size)\nval_loader = TemporalDataLoader(val_data, batch_size=batch_size)\ntest_loader = TemporalDataLoader(test_data, batch_size=batch_size)\n\n\n\"\"\"\ncontinue debug here\n\"\"\"\n\n\ndef test_n_upate(loader):\n    label_t = dataset.get_label_time()  # check when does the first label start\n    num_label_ts = 0\n    total_score = 0\n\n    for batch in tqdm(loader):\n        batch = batch.to(device)\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        query_t = batch.t[-1]\n        if query_t > label_t:\n            label_tuple = dataset.get_node_label(query_t)\n            if label_tuple is None:\n                break\n            label_ts, label_srcs, labels = (\n                label_tuple[0],\n                label_tuple[1],\n       
         label_tuple[2],\n            )\n            label_ts = label_ts.numpy()\n            label_srcs = label_srcs.numpy()\n            labels = labels.numpy()\n            label_t = dataset.get_label_time()\n\n            preds = []\n\n            for i in range(0, label_srcs.shape[0]):\n                node_id = label_srcs[i]\n                pred_vec = forecaster.query_dict(node_id)\n                preds.append(pred_vec)\n                forecaster.update_dict(node_id, labels[i])\n\n            np_pred = np.stack(preds, axis=0)\n            np_true = labels\n\n            input_dict = {\n                \"y_true\": np_true,\n                \"y_pred\": np_pred,\n                \"eval_metric\": [eval_metric],\n            }\n            result_dict = evaluator.eval(input_dict)\n            score = result_dict[eval_metric]\n            total_score += score\n            num_label_ts += 1\n\n    metric_dict = {}\n    metric_dict[eval_metric] = total_score / num_label_ts\n    return metric_dict\n\n\n\"\"\"\ntrain, val and test for one epoch only\n\"\"\"\n\nstart_time = timeit.default_timer()\nmetric_dict = test_n_upate(train_loader)\nprint(metric_dict)\nprint(\n    \"Persistant forecast on Training takes--- %s seconds ---\"\n    % (timeit.default_timer() - start_time)\n)\n\nstart_time = timeit.default_timer()\nval_dict = test_n_upate(val_loader)\nprint(val_dict)\nprint(\n    \"Persistant forecast on validation takes--- %s seconds ---\"\n    % (timeit.default_timer() - start_time)\n)\n\n\nstart_time = timeit.default_timer()\ntest_dict = test_n_upate(test_loader)\nprint(test_dict)\nprint(\n    \"Persistant forecast on Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time)\n)\ndataset.reset_label_time()"
  },
  {
    "path": "examples/nodeproppred/tgbn-reddit/tgn.py",
    "content": "import timeit\r\nimport argparse\r\nfrom tqdm import tqdm\r\nimport torch\r\nimport matplotlib.pyplot as plt\r\n\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom torch_geometric.nn import TGNMemory\r\nfrom torch_geometric.nn.models.tgn import (\r\n    IdentityMessage,\r\n    LastAggregator,\r\n    LastNeighborLoader,\r\n)\r\n\r\nfrom modules.decoder import NodePredictor\r\nfrom modules.emb_module import GraphAttentionEmbedding\r\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\r\nfrom tgb.nodeproppred.evaluate import Evaluator\r\nfrom tgb.utils.utils import set_random_seed\r\nfrom tgb.utils.stats import plot_curve\r\n\r\nparser = argparse.ArgumentParser(description='parsing command line arguments as hyperparameters')\r\nparser.add_argument('-s', '--seed', type=int, default=1,\r\n                    help='random seed to use')\r\nparser.parse_args()\r\nargs = parser.parse_args()\r\n# setting random seed\r\nseed = int(args.seed) #1,2,3,4,5\r\nprint (\"setting random seed to be\", seed)\r\ntorch.manual_seed(seed)\r\nset_random_seed(seed)\r\n\r\n# hyperparameters\r\nlr = 0.0001\r\nepochs = 50\r\n\r\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\nname = \"tgbn-reddit\"\r\ndataset = PyGNodePropPredDataset(name=name, root=\"datasets\")\r\ntrain_mask = dataset.train_mask\r\nval_mask = dataset.val_mask\r\ntest_mask = dataset.test_mask\r\n\r\neval_metric = dataset.eval_metric\r\nnum_classes = dataset.num_classes\r\ndata = dataset.get_TemporalData()\r\ndata = data.to(device)\r\n\r\nevaluator = Evaluator(name=name)\r\n\r\n\r\ntrain_data = data[train_mask]\r\nval_data = data[val_mask]\r\ntest_data = data[test_mask]\r\n\r\n# Ensure to only sample actual destination nodes as negatives.\r\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\nbatch_size = 200\r\n\r\ntrain_loader = TemporalDataLoader(train_data, batch_size=batch_size)\r\nval_loader = TemporalDataLoader(val_data, 
batch_size=batch_size)\r\ntest_loader = TemporalDataLoader(test_data, batch_size=batch_size)\r\n\r\nneighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)\r\n\r\nmemory_dim = time_dim = embedding_dim = 100\r\n\r\nmemory = TGNMemory(\r\n    data.num_nodes,\r\n    data.msg.size(-1),\r\n    memory_dim,\r\n    time_dim,\r\n    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),\r\n    aggregator_module=LastAggregator(),\r\n).to(device)\r\n\r\ngnn = (\r\n    GraphAttentionEmbedding(\r\n        in_channels=memory_dim,\r\n        out_channels=embedding_dim,\r\n        msg_dim=data.msg.size(-1),\r\n        time_enc=memory.time_enc,\r\n    )\r\n    .to(device)\r\n    .float()\r\n)\r\n\r\nnode_pred = NodePredictor(in_dim=embedding_dim, out_dim=num_classes).to(device)\r\n\r\noptimizer = torch.optim.Adam(\r\n    set(memory.parameters()) | set(gnn.parameters()) | set(node_pred.parameters()),\r\n    lr=lr,\r\n)\r\n\r\ncriterion = torch.nn.CrossEntropyLoss()\r\n# Helper vector to map global node indices to local ones.\r\nassoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\r\n\r\n\r\ndef plot_curve(scores, out_name):\r\n    plt.plot(scores, color=\"#e34a33\")\r\n    plt.ylabel(\"score\")\r\n    plt.savefig(out_name + \".pdf\")\r\n    plt.close()\r\n\r\n\r\ndef process_edges(src, dst, t, msg):\r\n    if src.nelement() > 0:\r\n        # msg = msg.to(torch.float32)\r\n        memory.update_state(src, dst, t, msg)\r\n        neighbor_loader.insert(src, dst)\r\n\r\n\r\ndef train():\r\n    memory.train()\r\n    gnn.train()\r\n    node_pred.train()\r\n\r\n    memory.reset_state()  # Start with a fresh memory.\r\n    neighbor_loader.reset_state()  # Start with an empty graph.\r\n\r\n    total_loss = 0\r\n    label_t = dataset.get_label_time()  # check when does the first label start\r\n    total_score = 0\r\n    num_label_ts = 0\r\n\r\n    for batch in tqdm(train_loader):\r\n        batch = batch.to(device)\r\n        
optimizer.zero_grad()\r\n        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\r\n\r\n        query_t = batch.t[-1]\r\n        # check if this batch moves to the next day\r\n        if query_t > label_t:\r\n            # find the node labels from the past day\r\n            label_tuple = dataset.get_node_label(query_t)\r\n            label_ts, label_srcs, labels = (\r\n                label_tuple[0],\r\n                label_tuple[1],\r\n                label_tuple[2],\r\n            )\r\n            label_t = dataset.get_label_time()\r\n            label_srcs = label_srcs.to(device)\r\n\r\n            # Process all edges that are still in the past day\r\n            previous_day_mask = batch.t < label_t\r\n            process_edges(\r\n                src[previous_day_mask],\r\n                dst[previous_day_mask],\r\n                t[previous_day_mask],\r\n                msg[previous_day_mask],\r\n            )\r\n            # Reset edges to be the edges from tomorrow so they can be used later\r\n            src, dst, t, msg = (\r\n                src[~previous_day_mask],\r\n                dst[~previous_day_mask],\r\n                t[~previous_day_mask],\r\n                msg[~previous_day_mask],\r\n            )\r\n\r\n            \"\"\"\r\n            modified for node property prediction\r\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\r\n            2. extract memory from the sampled neighbors and the nodes\r\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\r\n            \"\"\"\r\n            n_id = label_srcs\r\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\r\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\r\n\r\n            z, last_update = memory(n_id_neighbors)\r\n\r\n            z = gnn(\r\n                z,\r\n                last_update,\r\n                mem_edge_index,\r\n                data.t[e_id].to(device),\r\n                data.msg[e_id].to(device),\r\n            )\r\n            z = z[assoc[n_id]]\r\n\r\n            # loss and metric computation\r\n            pred = node_pred(z)\r\n\r\n            loss = criterion(pred, labels.to(device))\r\n            np_pred = pred.cpu().detach().numpy()\r\n            np_true = labels.cpu().detach().numpy()\r\n\r\n            input_dict = {\r\n                \"y_true\": np_true,\r\n                \"y_pred\": np_pred,\r\n                \"eval_metric\": [eval_metric],\r\n            }\r\n            result_dict = evaluator.eval(input_dict)\r\n            score = result_dict[eval_metric]\r\n            total_score += score\r\n            num_label_ts += 1\r\n            loss.backward()\r\n            optimizer.step()\r\n            total_loss += float(loss.detach())\r\n\r\n        # Update memory and neighbor loader with ground-truth state.\r\n        process_edges(src, dst, t, msg)\r\n        memory.detach()\r\n\r\n    metric_dict = {\r\n        \"ce\": total_loss / num_label_ts,\r\n    }\r\n    metric_dict[eval_metric] = total_score / num_label_ts\r\n    return metric_dict\r\n\r\n\r\n@torch.no_grad()\r\ndef test(loader):\r\n    memory.eval()\r\n    gnn.eval()\r\n    node_pred.eval()\r\n\r\n    label_t = dataset.get_label_time()  # check when does the first label start\r\n    num_label_ts = 0\r\n    total_score = 0\r\n\r\n    for batch in tqdm(loader):\r\n        batch = batch.to(device)\r\n        src, dst, 
t, msg = batch.src, batch.dst, batch.t, batch.msg\r\n\r\n        query_t = batch.t[-1]\r\n        if query_t > label_t:\r\n            label_tuple = dataset.get_node_label(query_t)\r\n            if label_tuple is None:\r\n                break\r\n            label_ts, label_srcs, labels = (\r\n                label_tuple[0],\r\n                label_tuple[1],\r\n                label_tuple[2],\r\n            )\r\n            label_t = dataset.get_label_time()\r\n            label_srcs = label_srcs.to(device)\r\n\r\n            # Process all edges that are still in the past day\r\n            previous_day_mask = batch.t < label_t\r\n            process_edges(\r\n                src[previous_day_mask],\r\n                dst[previous_day_mask],\r\n                t[previous_day_mask],\r\n                msg[previous_day_mask],\r\n            )\r\n            # Reset edges to be the edges from tomorrow so they can be used later\r\n            src, dst, t, msg = (\r\n                src[~previous_day_mask],\r\n                dst[~previous_day_mask],\r\n                t[~previous_day_mask],\r\n                msg[~previous_day_mask],\r\n            )\r\n\r\n            \"\"\"\r\n            modified for node property prediction\r\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\r\n            2. extract memory from the sampled neighbors and the nodes\r\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\r\n            \"\"\"\r\n            n_id = label_srcs\r\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\r\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\r\n\r\n            z, last_update = memory(n_id_neighbors)\r\n            z = gnn(\r\n                z,\r\n                last_update,\r\n                mem_edge_index,\r\n                data.t[e_id].to(device),\r\n                data.msg[e_id].to(device),\r\n            )\r\n            z = z[assoc[n_id]]\r\n\r\n            # loss and metric computation\r\n            pred = node_pred(z)\r\n            np_pred = pred.cpu().detach().numpy()\r\n            np_true = labels.cpu().detach().numpy()\r\n\r\n            input_dict = {\r\n                \"y_true\": np_true,\r\n                \"y_pred\": np_pred,\r\n                \"eval_metric\": [eval_metric],\r\n            }\r\n            result_dict = evaluator.eval(input_dict)\r\n            score = result_dict[eval_metric]\r\n            total_score += score\r\n            num_label_ts += 1\r\n\r\n        process_edges(src, dst, t, msg)\r\n\r\n    metric_dict = {}\r\n    metric_dict[eval_metric] = total_score / num_label_ts\r\n    return metric_dict\r\n\r\n\r\ntrain_curve = []\r\nval_curve = []\r\ntest_curve = []\r\nmax_val_score = 0  #find the best test score based on validation score\r\nbest_test_idx = 0\r\nfor epoch in range(1, epochs + 1):\r\n    start_time = timeit.default_timer()\r\n    train_dict = train()\r\n    print(\"------------------------------------\")\r\n    print(f\"training Epoch: {epoch:02d}\")\r\n    print(train_dict)\r\n    train_curve.append(train_dict[eval_metric])\r\n    print(\"Training takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\r\n    \r\n    start_time = timeit.default_timer()\r\n    val_dict = test(val_loader)\r\n    print(val_dict)\r\n    
val_curve.append(val_dict[eval_metric])\r\n    if (val_dict[eval_metric] > max_val_score):\r\n        max_val_score = val_dict[eval_metric]\r\n        best_test_idx = epoch - 1\r\n    print(\"Validation takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\r\n\r\n    start_time = timeit.default_timer()\r\n    test_dict = test(test_loader)\r\n    print(test_dict)\r\n    test_curve.append(test_dict[eval_metric])\r\n    print(\"Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\r\n    print(\"------------------------------------\")\r\n    dataset.reset_label_time()\r\n\r\n\r\n# code for plotting\r\nplot_curve(train_curve, \"train_curve\")\r\nplot_curve(val_curve, \"val_curve\")\r\nplot_curve(test_curve, \"test_curve\")\r\n\r\nmax_test_score = test_curve[best_test_idx]\r\nprint(\"------------------------------------\")\r\nprint(\"------------------------------------\")\r\nprint (\"best val score: \", max_val_score)\r\nprint (\"best validation epoch   : \", best_test_idx + 1)\r\nprint (\"best test score: \", max_test_score)\r\n"
  },
  {
    "path": "examples/nodeproppred/tgbn-token/dyrep.py",
    "content": "\"\"\"\nDyRep\n    This has been implemented with intuitions from the following sources:\n    - https://github.com/twitter-research/tgn\n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n\n    Spec.:\n        - Memory Updater: RNN\n        - Embedding Module: ID\n        - Message Function: ATTN\n\"\"\"\nimport timeit\nfrom tqdm import tqdm\nimport torch\nfrom torch_geometric.loader import TemporalDataLoader\nimport numpy as np\n\n# internal imports\nfrom tgb.utils.utils import get_args, set_random_seed\nfrom tgb.nodeproppred.evaluate import Evaluator\nfrom modules.decoder import NodePredictor\nfrom modules.emb_module import GraphAttentionEmbedding\nfrom modules.msg_func import IdentityMessage\nfrom modules.msg_agg import LastAggregator\nfrom modules.neighbor_loader import LastNeighborLoader\nfrom modules.memory_module import DyRepMemory\nfrom modules.early_stopping import  EarlyStopMonitor\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\n\n\ndef process_edges(src, dst, t, msg):\n    if src.nelement() > 0:\n        model['memory'].update_state(src, dst, t, msg)\n        neighbor_loader.insert(src, dst)\n\n\n# ==========\n# ========== Define helper function...\n# ==========\n\ndef train():\n    model['memory'].train()\n    model['gnn'].train()\n    model['node_pred'].train()\n\n    model['memory'].reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    label_t = dataset.get_label_time()  # check when does the first label start\n    num_label_ts = 0\n    total_score = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        query_t = batch.t[-1]\n        # check if this batch moves to the next day\n        if query_t > label_t:\n            # find the node labels from the past day\n            
label_tuple = dataset.get_node_label(query_t)\n            label_ts, label_srcs, labels = (\n                label_tuple[0],\n                label_tuple[1],\n                label_tuple[2],\n            )\n            label_t = dataset.get_label_time()\n            label_srcs = label_srcs.to(device)\n\n            # Process all edges that are still in the past day\n            previous_day_mask = batch.t < label_t\n            process_edges(\n                src[previous_day_mask],\n                dst[previous_day_mask],\n                t[previous_day_mask],\n                msg[previous_day_mask],\n            )\n            # Reset edges to be the edges from tomorrow so they can be used later\n            src, dst, t, msg = (\n                src[~previous_day_mask],\n                dst[~previous_day_mask],\n                t[~previous_day_mask],\n                msg[~previous_day_mask],\n            )\n\n            \"\"\"\n            modified for node property prediction\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\n            2. extract memory from the sampled neighbors and the nodes\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\n            \"\"\"\n            n_id = label_srcs\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\n\n            z, last_update = model['memory'](n_id_neighbors)\n            z = model['gnn'](\n                z,\n                last_update,\n                mem_edge_index,\n                data.t[e_id].to(device),\n                data.msg[e_id].to(device),\n            )\n            z = z[assoc[n_id]]\n\n            # loss and metric computation\n            pred = model['node_pred'](z)\n\n            loss = criterion(pred, labels.to(device))\n            np_pred = pred.cpu().detach().numpy()\n            np_true = labels.cpu().detach().numpy()\n\n            input_dict = {\n                \"y_true\": np_true,\n                \"y_pred\": np_pred,\n                \"eval_metric\": [metric],\n            }\n            result_dict = evaluator.eval(input_dict)\n            score = result_dict[metric]\n            total_score += score\n            num_label_ts += 1\n\n            loss.backward()\n            optimizer.step()\n            total_loss += float(loss.detach())\n\n        # Update memory and neighbor loader with ground-truth state.\n        process_edges(src, dst, t, msg)\n        model['memory'].detach()\n\n    metric_dict = {\n        \"ce\": total_loss / num_label_ts,\n    }\n    metric_dict[metric] = total_score / num_label_ts\n    return metric_dict\n\n\n@torch.no_grad()\ndef test(loader):\n    model['memory'].eval()\n    model['gnn'].eval()\n    model['node_pred'].eval()\n\n    total_score = 0\n    label_t = dataset.get_label_time()  # check when does the first label start\n    num_label_ts = 0\n\n    for batch in tqdm(loader):\n        batch = batch.to(device)\n        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        query_t = 
batch.t[-1]\n        if query_t > label_t:\n            label_tuple = dataset.get_node_label(query_t)\n            if label_tuple is None:\n                break\n            label_ts, label_srcs, labels = (\n                label_tuple[0],\n                label_tuple[1],\n                label_tuple[2],\n            )\n            label_t = dataset.get_label_time()\n            label_srcs = label_srcs.to(device)\n\n            # Process all edges that are still in the past day\n            previous_day_mask = batch.t < label_t\n            process_edges(\n                src[previous_day_mask],\n                dst[previous_day_mask],\n                t[previous_day_mask],\n                msg[previous_day_mask],\n            )\n            # Reset edges to be the edges from tomorrow so they can be used later\n            src, dst, t, msg = (\n                src[~previous_day_mask],\n                dst[~previous_day_mask],\n                t[~previous_day_mask],\n                msg[~previous_day_mask],\n            )\n\n            \"\"\"\n            modified for node property prediction\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\n            2. extract memory from the sampled neighbors and the nodes\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\n            \"\"\"\n            n_id = label_srcs\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\n\n            z, last_update = model['memory'](n_id_neighbors)\n            z = model['gnn'](\n                z,\n                last_update,\n                mem_edge_index,\n                data.t[e_id].to(device),\n                data.msg[e_id].to(device),\n            )\n            z = z[assoc[n_id]]\n\n            # loss and metric computation\n            pred = model['node_pred'](z)\n            np_pred = pred.cpu().detach().numpy()\n            np_true = labels.cpu().detach().numpy()\n\n            input_dict = {\n                \"y_true\": np_true,\n                \"y_pred\": np_pred,\n                \"eval_metric\": [metric],\n            }\n            result_dict = evaluator.eval(input_dict)\n            score = result_dict[metric]\n            total_score += score\n            num_label_ts += 1\n\n        process_edges(src, dst, t, msg)\n\n    metric_dict = {}\n    metric_dict[metric] = total_score / num_label_ts\n    return metric_dict\n\n# ==========\n# ==========\n# ==========\n\n# Start...\nstart_overall = timeit.default_timer()\n\n# ========== set parameters...\nargs, _ = get_args()\nprint(\"INFO: Arguments:\", args)\n\nDATA = \"tgbn-token\"\nLR = args.lr\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value  \nNUM_EPOCH = args.num_epoch\nSEED = args.seed\nMEM_DIM = args.mem_dim\nTIME_DIM = args.time_dim\nEMB_DIM = args.emb_dim\nTOLERANCE = args.tolerance\nPATIENCE = args.patience\nNUM_RUNS = args.num_run\nNUM_NEIGHBORS = 10\n\n\n# setting random seed\ntorch.manual_seed(SEED)\nset_random_seed(SEED)\n\nMODEL_NAME = 'DyRep'\nUSE_SRC_EMB_IN_MSG = False\nUSE_DST_EMB_IN_MSG = True\nevaluator = Evaluator(name=DATA)\n# ==========\n\n# set the device\ndevice = 
torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# data loading\ndataset = PyGNodePropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\ndata = data.to(device)\nmetric = dataset.eval_metric\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\nnum_classes = dataset.num_classes\n\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\n\n# neighborhood sampler\nneighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\n\n# define the model end-to-end\nmemory = DyRepMemory(\n    data.num_nodes,\n    data.msg.size(-1),\n    MEM_DIM,\n    TIME_DIM,\n    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\n    aggregator_module=LastAggregator(),\n    memory_updater_type='rnn',\n    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,\n    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG\n).to(device)\n\ngnn = GraphAttentionEmbedding(\n    in_channels=MEM_DIM,\n    out_channels=EMB_DIM,\n    msg_dim=data.msg.size(-1),\n    time_enc=memory.time_enc,\n).to(device)\n\nnode_pred = NodePredictor(in_dim=EMB_DIM, out_dim=num_classes).to(device)\n\nmodel = {'memory': memory,\n         'gnn': gnn,\n         'node_pred': node_pred}\n\noptimizer = torch.optim.Adam(\n    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['node_pred'].parameters()),\n    lr=LR,\n)\n\ncriterion = torch.nn.CrossEntropyLoss()\n# Helper vector to map global node indices to local ones.\nassoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\n\ntrain_curve = []\nval_curve = []\ntest_curve = 
[]\nmax_val_score = 0  #find the best test score based on validation score\nbest_test_idx = 0\nfor epoch in range(1, NUM_EPOCH + 1):\n    start_time = timeit.default_timer()\n    train_dict = train()\n    print(\"------------------------------------\")\n    print(f\"training Epoch: {epoch:02d}\")\n    print(train_dict)\n    train_curve.append(train_dict[metric])\n    print(\"Training takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\n    \n    start_time = timeit.default_timer()\n    val_dict = test(val_loader)\n    print(val_dict)\n    val_curve.append(val_dict[metric])\n    if (val_dict[metric] > max_val_score):\n        max_val_score = val_dict[metric]\n        best_test_idx = epoch - 1\n    print(\"Validation takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\n\n    start_time = timeit.default_timer()\n    test_dict = test(test_loader)\n    print(test_dict)\n    test_curve.append(test_dict[metric])\n    print(\"Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\n    print(\"------------------------------------\")\n    dataset.reset_label_time()\n\nmax_test_score = test_curve[best_test_idx]\nprint(\"------------------------------------\")\nprint(\"------------------------------------\")\nprint (\"best val score: \", max_val_score)\nprint (\"best validation epoch   : \", best_test_idx + 1)\nprint (\"best test score: \", max_test_score)\n\n"
  },
  {
    "path": "examples/nodeproppred/tgbn-token/moving_average.py",
    "content": "\"\"\"\nimplement persistant forecast as baseline for the node prop pred task\nsimply predict last seen label for the node\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\n\n# local imports\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\nfrom modules.heuristics import MovingAverage\nfrom tgb.nodeproppred.evaluate import Evaluator\n\nwindow = 7\ndevice = \"cpu\"\nname = \"tgbn-token\"\ndataset = PyGNodePropPredDataset(name=name, root=\"datasets\")\nnum_classes = dataset.num_classes\ndata = dataset.get_TemporalData()\ndata = data.to(device)\n\neval_metric = dataset.eval_metric\nforecaster = MovingAverage(num_classes, window=window)\nevaluator = Evaluator(name=name)\n\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\ntrain_data, val_data, test_data = data.train_val_test_split(\n    val_ratio=0.15, test_ratio=0.15\n)\n\nbatch_size = 200\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=batch_size)\nval_loader = TemporalDataLoader(val_data, batch_size=batch_size)\ntest_loader = TemporalDataLoader(test_data, batch_size=batch_size)\n\n\ndef test_n_upate(loader):\n    label_t = dataset.get_label_time()  # check when does the first label start\n    num_label_ts = 0\n    total_score = 0\n\n    for batch in tqdm(loader):\n        batch = batch.to(device)\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        query_t = batch.t[-1]\n        if query_t > label_t:\n            label_tuple = dataset.get_node_label(query_t)\n            if label_tuple is None:\n                break\n            label_ts, label_srcs, labels = (\n                label_tuple[0],\n                label_tuple[1],\n                label_tuple[2],\n            )\n            label_ts = label_ts.numpy()\n            label_srcs = label_srcs.numpy()\n            labels = 
labels.numpy()\n            label_t = dataset.get_label_time()\n\n            preds = []\n\n            for i in range(0, label_srcs.shape[0]):\n                node_id = label_srcs[i]\n                pred_vec = forecaster.query_dict(node_id)\n                preds.append(pred_vec)\n                forecaster.update_dict(node_id, labels[i])\n\n            np_pred = np.stack(preds, axis=0)\n            np_true = labels\n\n            input_dict = {\n                \"y_true\": np_true,\n                \"y_pred\": np_pred,\n                \"eval_metric\": [eval_metric],\n            }\n            result_dict = evaluator.eval(input_dict)\n            score = result_dict[eval_metric]\n\n            total_score += score\n            num_label_ts += 1\n\n    metric_dict = {}\n    metric_dict[eval_metric] = total_score / num_label_ts\n    return metric_dict\n\n\n\"\"\"\ntrain, val and test for one epoch only\n\"\"\"\n\nstart_time = timeit.default_timer()\nmetric_dict = test_n_upate(train_loader)\nprint(metric_dict)\nprint(\n    \"Persistant forecast on Training takes--- %s seconds ---\"\n    % (timeit.default_timer() - start_time)\n)\n\nstart_time = timeit.default_timer()\nval_dict = test_n_upate(val_loader)\nprint(val_dict)\nprint(\n    \"Persistant forecast on validation takes--- %s seconds ---\"\n    % (timeit.default_timer() - start_time)\n)\n\n\nstart_time = timeit.default_timer()\ntest_dict = test_n_upate(test_loader)\nprint(test_dict)\nprint(\n    \"Persistant forecast on Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time)\n)\ndataset.reset_label_time()"
  },
  {
    "path": "examples/nodeproppred/tgbn-token/persistant_forecast.py",
    "content": "\"\"\"\nimplement persistant forecast as baseline for the node prop pred task\nsimply predict last seen label for the node\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom torch_geometric.loader import TemporalDataLoader\nfrom tqdm import tqdm\nimport torch\n\n# local imports\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\nfrom modules.heuristics import PersistantForecaster\nfrom tgb.nodeproppred.evaluate import Evaluator\n\n\ndevice = \"cpu\"\nname = \"tgbn-token\"\ndataset = PyGNodePropPredDataset(name=name, root=\"datasets\")\nnum_classes = dataset.num_classes\ndata = dataset.get_TemporalData()\ndata = data.to(device)\n\nall_nodes = torch.cat((data.src, data.dst), 0)\nall_nodes = all_nodes.unique()\nprint (all_nodes.shape[0])\n\neval_metric = dataset.eval_metric\nforecaster = PersistantForecaster(num_classes)\nevaluator = Evaluator(name=name)\n\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\ntrain_data, val_data, test_data = data.train_val_test_split(\n    val_ratio=0.15, test_ratio=0.15\n)\n\nbatch_size = 200\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=batch_size)\nval_loader = TemporalDataLoader(val_data, batch_size=batch_size)\ntest_loader = TemporalDataLoader(test_data, batch_size=batch_size)\n\n\n\"\"\"\ncontinue debug here\n\"\"\"\n\n\ndef test_n_upate(loader):\n    label_t = dataset.get_label_time()  # check when does the first label start\n    num_label_ts = 0\n    total_score = 0\n\n    for batch in tqdm(loader):\n        batch = batch.to(device)\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        query_t = batch.t[-1]\n        if query_t > label_t:\n            label_tuple = dataset.get_node_label(query_t)\n            if label_tuple is None:\n                break\n            label_ts, label_srcs, labels = (\n                label_tuple[0],\n                label_tuple[1],\n        
        label_tuple[2],\n            )\n            label_ts = label_ts.numpy()\n            label_srcs = label_srcs.numpy()\n            labels = labels.numpy()\n            label_t = dataset.get_label_time()\n\n            preds = []\n\n            for i in range(0, label_srcs.shape[0]):\n                node_id = label_srcs[i]\n                pred_vec = forecaster.query_dict(node_id)\n                preds.append(pred_vec)\n                forecaster.update_dict(node_id, labels[i])\n\n            np_pred = np.stack(preds, axis=0)\n            np_true = labels\n\n            input_dict = {\n                \"y_true\": np_true,\n                \"y_pred\": np_pred,\n                \"eval_metric\": [eval_metric],\n            }\n            result_dict = evaluator.eval(input_dict)\n            score = result_dict[eval_metric]\n            total_score += score\n            num_label_ts += 1\n\n    metric_dict = {}\n    metric_dict[eval_metric] = total_score / num_label_ts\n    return metric_dict\n\n\n\"\"\"\ntrain, val and test for one epoch only\n\"\"\"\n\nstart_time = timeit.default_timer()\nmetric_dict = test_n_upate(train_loader)\nprint(metric_dict)\nprint(\n    \"Persistant forecast on Training takes--- %s seconds ---\"\n    % (timeit.default_timer() - start_time)\n)\n\nstart_time = timeit.default_timer()\nval_dict = test_n_upate(val_loader)\nprint(val_dict)\nprint(\n    \"Persistant forecast on validation takes--- %s seconds ---\"\n    % (timeit.default_timer() - start_time)\n)\n\n\nstart_time = timeit.default_timer()\ntest_dict = test_n_upate(test_loader)\nprint(test_dict)\nprint(\n    \"Persistant forecast on Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time)\n)\ndataset.reset_label_time()"
  },
  {
    "path": "examples/nodeproppred/tgbn-token/tgn.py",
    "content": "import timeit\r\nimport argparse\r\nfrom tqdm import tqdm\r\nimport torch\r\nimport matplotlib.pyplot as plt\r\nimport numpy as np\r\n\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom torch_geometric.nn import TGNMemory\r\nfrom torch_geometric.nn.models.tgn import (\r\n    IdentityMessage,\r\n    LastAggregator,\r\n    LastNeighborLoader,\r\n)\r\n\r\nfrom modules.decoder import NodePredictor\r\nfrom modules.emb_module import GraphAttentionEmbedding\r\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\r\nfrom tgb.nodeproppred.evaluate import Evaluator\r\nfrom tgb.utils.utils import set_random_seed\r\nfrom tgb.utils.stats import plot_curve\r\n\r\nparser = argparse.ArgumentParser(description='parsing command line arguments as hyperparameters')\r\nparser.add_argument('-s', '--seed', type=int, default=1,\r\n                    help='random seed to use')\r\nparser.parse_args()\r\nargs = parser.parse_args()\r\n# setting random seed\r\nseed = int(args.seed) #1,2,3,4,5\r\nprint (\"setting random seed to be\", seed)\r\ntorch.manual_seed(seed)\r\nset_random_seed(seed)\r\n\r\n# hyperparameters\r\nlr = 0.0001\r\nepochs = 50\r\n\r\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\nname = \"tgbn-token\"\r\ndataset = PyGNodePropPredDataset(name=name, root=\"datasets\")\r\ntrain_mask = dataset.train_mask\r\nval_mask = dataset.val_mask\r\ntest_mask = dataset.test_mask\r\n\r\neval_metric = dataset.eval_metric\r\nnum_classes = dataset.num_classes\r\ndata = dataset.get_TemporalData()\r\ndata = data.to(device)\r\n\r\nevaluator = Evaluator(name=name)\r\n\r\n\r\ntrain_data = data[train_mask]\r\nval_data = data[val_mask]\r\ntest_data = data[test_mask]\r\n\r\n# Ensure to only sample actual destination nodes as negatives.\r\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\nbatch_size = 200\r\n\r\ntrain_loader = TemporalDataLoader(train_data, batch_size=batch_size)\r\nval_loader = 
TemporalDataLoader(val_data, batch_size=batch_size)\r\ntest_loader = TemporalDataLoader(test_data, batch_size=batch_size)\r\n\r\nneighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)\r\n\r\nmemory_dim = time_dim = embedding_dim = 100\r\n\r\nmemory = TGNMemory(\r\n    data.num_nodes,\r\n    data.msg.size(-1),\r\n    memory_dim,\r\n    time_dim,\r\n    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),\r\n    aggregator_module=LastAggregator(),\r\n).to(device)\r\n\r\ngnn = (\r\n    GraphAttentionEmbedding(\r\n        in_channels=memory_dim,\r\n        out_channels=embedding_dim,\r\n        msg_dim=data.msg.size(-1),\r\n        time_enc=memory.time_enc,\r\n    )\r\n    .to(device)\r\n    .float()\r\n)\r\n\r\nnode_pred = NodePredictor(in_dim=embedding_dim, out_dim=num_classes).to(device)\r\n\r\noptimizer = torch.optim.Adam(\r\n    set(memory.parameters()) | set(gnn.parameters()) | set(node_pred.parameters()),\r\n    lr=lr,\r\n)\r\n\r\ncriterion = torch.nn.CrossEntropyLoss()\r\n# Helper vector to map global node indices to local ones.\r\nassoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\r\n\r\n\r\ndef plot_curve(scores, out_name):\r\n    plt.plot(scores, color=\"#e34a33\")\r\n    plt.ylabel(\"score\")\r\n    plt.savefig(out_name + \".pdf\")\r\n    plt.close()\r\n\r\n\r\ndef process_edges(src, dst, t, msg):\r\n    if src.nelement() > 0:\r\n        # msg = msg.to(torch.float32)\r\n        memory.update_state(src, dst, t, msg)\r\n        neighbor_loader.insert(src, dst)\r\n\r\n\r\ndef train():\r\n    memory.train()\r\n    gnn.train()\r\n    node_pred.train()\r\n\r\n    memory.reset_state()  # Start with a fresh memory.\r\n    neighbor_loader.reset_state()  # Start with an empty graph.\r\n\r\n    total_loss = 0\r\n    label_t = dataset.get_label_time()  # check when does the first label start\r\n    total_score = 0\r\n    num_label_ts = 0\r\n\r\n    for batch in tqdm(train_loader):\r\n        batch = 
batch.to(device)\r\n        optimizer.zero_grad()\r\n        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\r\n\r\n        query_t = batch.t[-1]\r\n        # check if this batch moves to the next day\r\n        if query_t > label_t:\r\n            # find the node labels from the past day\r\n            label_tuple = dataset.get_node_label(query_t)\r\n            label_ts, label_srcs, labels = (\r\n                label_tuple[0],\r\n                label_tuple[1],\r\n                label_tuple[2],\r\n            )\r\n            label_t = dataset.get_label_time()\r\n            label_srcs = label_srcs.to(device)\r\n\r\n            # Process all edges that are still in the past day\r\n            previous_day_mask = batch.t < label_t\r\n            process_edges(\r\n                src[previous_day_mask],\r\n                dst[previous_day_mask],\r\n                t[previous_day_mask],\r\n                msg[previous_day_mask],\r\n            )\r\n            # Reset edges to be the edges from tomorrow so they can be used later\r\n            src, dst, t, msg = (\r\n                src[~previous_day_mask],\r\n                dst[~previous_day_mask],\r\n                t[~previous_day_mask],\r\n                msg[~previous_day_mask],\r\n            )\r\n\r\n            \"\"\"\r\n            modified for node property prediction\r\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\r\n            2. extract memory from the sampled neighbors and the nodes\r\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\r\n            \"\"\"\r\n            n_id = label_srcs\r\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\r\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\r\n\r\n            z, last_update = memory(n_id_neighbors)\r\n\r\n            z = gnn(\r\n                z,\r\n                last_update,\r\n                mem_edge_index,\r\n                data.t[e_id].to(device),\r\n                data.msg[e_id].to(device),\r\n            )\r\n            z = z[assoc[n_id]]\r\n\r\n            # loss and metric computation\r\n            pred = node_pred(z)\r\n\r\n            loss = criterion(pred, labels.to(device))\r\n            np_pred = pred.cpu().detach().numpy()\r\n            np_true = labels.cpu().detach().numpy()\r\n            input_dict = {\r\n                \"y_true\": np_true,\r\n                \"y_pred\": np_pred,\r\n                \"eval_metric\": [eval_metric],\r\n            }\r\n            result_dict = evaluator.eval(input_dict)\r\n            score = result_dict[eval_metric]\r\n            total_score += score\r\n            num_label_ts += 1\r\n            loss.backward()\r\n            optimizer.step()\r\n            total_loss += float(loss.detach())\r\n\r\n        # Update memory and neighbor loader with ground-truth state.\r\n        process_edges(src, dst, t, msg)\r\n        memory.detach()\r\n\r\n    metric_dict = {\r\n        \"ce\": total_loss / num_label_ts,\r\n    }\r\n    metric_dict[eval_metric] = total_score / num_label_ts\r\n    return metric_dict\r\n\r\n\r\n@torch.no_grad()\r\ndef test(loader):\r\n    memory.eval()\r\n    gnn.eval()\r\n    node_pred.eval()\r\n\r\n    label_t = dataset.get_label_time()  # check when does the first label start\r\n    num_label_ts = 0\r\n    total_score = 0\r\n\r\n    for batch in tqdm(loader):\r\n        batch = batch.to(device)\r\n        src, dst, t, 
msg = batch.src, batch.dst, batch.t, batch.msg\r\n\r\n        query_t = batch.t[-1]\r\n        if query_t > label_t:\r\n            label_tuple = dataset.get_node_label(query_t)\r\n            if label_tuple is None:\r\n                break\r\n            label_ts, label_srcs, labels = (\r\n                label_tuple[0],\r\n                label_tuple[1],\r\n                label_tuple[2],\r\n            )\r\n            label_t = dataset.get_label_time()\r\n            label_srcs = label_srcs.to(device)\r\n\r\n            # Process all edges that are still in the past day\r\n            previous_day_mask = batch.t < label_t\r\n            process_edges(\r\n                src[previous_day_mask],\r\n                dst[previous_day_mask],\r\n                t[previous_day_mask],\r\n                msg[previous_day_mask],\r\n            )\r\n            # Reset edges to be the edges from tomorrow so they can be used later\r\n            src, dst, t, msg = (\r\n                src[~previous_day_mask],\r\n                dst[~previous_day_mask],\r\n                t[~previous_day_mask],\r\n                msg[~previous_day_mask],\r\n            )\r\n\r\n            \"\"\"\r\n            modified for node property prediction\r\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\r\n            2. extract memory from the sampled neighbors and the nodes\r\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\r\n            \"\"\"\r\n            n_id = label_srcs\r\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\r\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\r\n\r\n            z, last_update = memory(n_id_neighbors)\r\n            z = gnn(\r\n                z,\r\n                last_update,\r\n                mem_edge_index,\r\n                data.t[e_id].to(device),\r\n                data.msg[e_id].to(device),\r\n            )\r\n            z = z[assoc[n_id]]\r\n\r\n            # loss and metric computation\r\n            pred = node_pred(z)\r\n            np_pred = pred.cpu().detach().numpy()\r\n            np_true = labels.cpu().detach().numpy()\r\n\r\n            input_dict = {\r\n                \"y_true\": np_true,\r\n                \"y_pred\": np_pred,\r\n                \"eval_metric\": [eval_metric],\r\n            }\r\n            result_dict = evaluator.eval(input_dict)\r\n            score = result_dict[eval_metric]\r\n            total_score += score\r\n            num_label_ts += 1\r\n\r\n        process_edges(src, dst, t, msg)\r\n\r\n    metric_dict = {}\r\n    metric_dict[eval_metric] = total_score / num_label_ts\r\n    return metric_dict\r\n\r\n\r\ntrain_curve = []\r\nval_curve = []\r\ntest_curve = []\r\nmax_val_score = 0  #find the best test score based on validation score\r\nbest_test_idx = 0\r\nfor epoch in range(1, epochs + 1):\r\n    start_time = timeit.default_timer()\r\n    train_dict = train()\r\n    print(\"------------------------------------\")\r\n    print(f\"training Epoch: {epoch:02d}\")\r\n    print(train_dict)\r\n    train_curve.append(train_dict[eval_metric])\r\n    print(\"Training takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\r\n    \r\n    start_time = timeit.default_timer()\r\n    val_dict = test(val_loader)\r\n    print(val_dict)\r\n    
val_curve.append(val_dict[eval_metric])\r\n    if (val_dict[eval_metric] > max_val_score):\r\n        max_val_score = val_dict[eval_metric]\r\n        best_test_idx = epoch - 1\r\n    print(\"Validation takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\r\n\r\n    start_time = timeit.default_timer()\r\n    test_dict = test(test_loader)\r\n    print(test_dict)\r\n    test_curve.append(test_dict[eval_metric])\r\n    print(\"Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\r\n    print(\"------------------------------------\")\r\n    dataset.reset_label_time()\r\n\r\n\r\n# code for plotting\r\nplot_curve(train_curve, \"train_curve\")\r\nplot_curve(val_curve, \"val_curve\")\r\nplot_curve(test_curve, \"test_curve\")\r\n\r\nmax_test_score = test_curve[best_test_idx]\r\nprint(\"------------------------------------\")\r\nprint(\"------------------------------------\")\r\nprint (\"best val score: \", max_val_score)\r\nprint (\"best validation epoch   : \", best_test_idx + 1)\r\nprint (\"best test score: \", max_test_score)\r\n"
  },
  {
    "path": "examples/nodeproppred/tgbn-trade/count_new_nodes.py",
    "content": "\nimport timeit\nimport numpy as np\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import Path\nimport sys\nimport argparse\n\n# internal imports\nfrom modules.nodebank import NodeBank\nfrom tgb.linkproppred.evaluate import Evaluator\nfrom modules.edgebank_predictor import EdgeBankPredictor\nfrom tgb.utils.utils import set_random_seed\nfrom tgb.nodeproppred.dataset import NodePropPredDataset\n\n# ==================\n# ==================\n# ==================\n\ndef count_nodes(data, test_mask, nodebank):\n    r\"\"\"\n    Evaluated the dynamic link prediction\n    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges\n\n    Parameters:\n        data: a dataset object\n        test_mask: required masks to load the test set edges\n        neg_sampler: an object that gives the negative edges corresponding to each positive edge\n        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives\n    Returns:\n        perf_metric: the result of the performance evaluation\n    \"\"\"\n    node_dict_new = {}\n    node_dict = {}\n    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)\n    for batch_idx in tqdm(range(num_batches)):\n        start_idx = batch_idx * BATCH_SIZE\n        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))\n        pos_src, pos_dst, pos_t = (\n            data['sources'][test_mask][start_idx: end_idx],\n            data['destinations'][test_mask][start_idx: end_idx],\n            data['timestamps'][test_mask][start_idx: end_idx],\n        )\n\n        for node in pos_src:\n            if (not nodebank.query_node(node)):\n                if (node not in node_dict_new):\n                    node_dict_new[node] = 1\n\n            if (node not in node_dict):\n                node_dict[node] = 1\n        \n        for node in pos_dst:\n            if (not 
nodebank.query_node(node)):\n                if (node not in node_dict_new):\n                    node_dict_new[node] = 1\n\n            if (node not in node_dict):\n                node_dict[node] = 1\n\n    return len(node_dict_new), len(node_dict)\n\n\n\n\n\n\ndef get_args():\n    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])\n    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n# ==================\n# ==================\n# ==================\n\nstart_overall = timeit.default_timer()\n\n# set hyperparameters\nargs, _ = get_args()\n\nSEED = args.seed  # set the random seed for consistency\nset_random_seed(SEED)\nMEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`\nBATCH_SIZE = 10000\nK_VALUE = args.k_value\nTIME_WINDOW_RATIO = args.time_window_ratio\nDATA = \"tgbn-token\" #\"tgbl-wiki\"\n\nMODEL_NAME = 'EdgeBank'\n\n# data loading with `numpy`\ndataset = NodePropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\ndata = dataset.full_data  \nmetric = dataset.eval_metric\n\n# get masks\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\n\ntrain_src = data['sources'][train_mask]\ntrain_dst = data['destinations'][train_mask]\n\n#data for memory in edgebank\nhist_src = np.concatenate([data['sources'][train_mask]])\nhist_dst = 
np.concatenate([data['destinations'][train_mask]])\nhist_ts = np.concatenate([data['timestamps'][train_mask]])\n\n# Set EdgeBank with memory updater\nedgebank = EdgeBankPredictor(\n        hist_src,\n        hist_dst,\n        hist_ts,\n        memory_mode=MEMORY_MODE,\n        time_window_ratio=TIME_WINDOW_RATIO)\n\nprint(\"==========================================================\")\nprint(f\"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============\")\nprint(\"==========================================================\")\n\nevaluator = Evaluator(name=DATA)\n\n\nnodebank = NodeBank(train_src, train_dst)\n\nnew_val_num, val_total = count_nodes(data, val_mask, nodebank)\nprint ()\nprint (\"-------------------------------------------------------\")\nprint (\"there are \", new_val_num, \" new nodes in the validation set\")\nprint (\"there are \", val_total, \" total nodes in the validation set\")\nprint (\" the percentage of new nodes in the validation set is \", (new_val_num/val_total))\n\n\nnew_test_num, test_total = count_nodes(data, test_mask, nodebank)\nprint ()\nprint (\"-------------------------------------------------------\")\nprint (\"there are \", new_test_num, \" new nodes in the test set\")\nprint (\"there are \", test_total, \" total nodes in the test set\") \nprint (\" the percentage of new nodes in the test set is \", (new_test_num/test_total))\n"
  },
  {
    "path": "examples/nodeproppred/tgbn-trade/dyrep.py",
    "content": "\"\"\"\nDyRep\n    This has been implemented with intuitions from the following sources:\n    - https://github.com/twitter-research/tgn\n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n\n    Spec.:\n        - Memory Updater: RNN\n        - Embedding Module: ID\n        - Message Function: ATTN\n\"\"\"\nimport timeit\nfrom tqdm import tqdm\nimport torch\nfrom torch_geometric.loader import TemporalDataLoader\n\n# internal imports\nfrom tgb.utils.utils import get_args, set_random_seed\nfrom tgb.nodeproppred.evaluate import Evaluator\nfrom modules.decoder import NodePredictor\nfrom modules.emb_module import GraphAttentionEmbedding\nfrom modules.msg_func import IdentityMessage\nfrom modules.msg_agg import LastAggregator\nfrom modules.neighbor_loader import LastNeighborLoader\nfrom modules.memory_module import DyRepMemory\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\n\n\ndef process_edges(src, dst, t, msg):\n    if src.nelement() > 0:\n        model['memory'].update_state(src, dst, t, msg)\n        neighbor_loader.insert(src, dst)\n\n\n# ==========\n# ========== Define helper function...\n# ==========\n\ndef train():\n    model['memory'].train()\n    model['gnn'].train()\n    model['node_pred'].train()\n\n    model['memory'].reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    label_t = dataset.get_label_time()  # check when does the first label start\n    num_label_ts = 0\n    total_score = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        query_t = batch.t[-1]\n        # check if this batch moves to the next day\n        if query_t > label_t:\n            # find the node labels from the past day\n            label_tuple = dataset.get_node_label(query_t)\n            label_ts, label_srcs, 
labels = (\n                label_tuple[0],\n                label_tuple[1],\n                label_tuple[2],\n            )\n            label_t = dataset.get_label_time()\n            label_srcs = label_srcs.to(device)\n\n            # Process all edges that are still in the past day\n            previous_day_mask = batch.t < label_t\n            process_edges(\n                src[previous_day_mask],\n                dst[previous_day_mask],\n                t[previous_day_mask],\n                msg[previous_day_mask],\n            )\n            # Reset edges to be the edges from tomorrow so they can be used later\n            src, dst, t, msg = (\n                src[~previous_day_mask],\n                dst[~previous_day_mask],\n                t[~previous_day_mask],\n                msg[~previous_day_mask],\n            )\n\n            \"\"\"\n            modified for node property prediction\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\n            2. extract memory from the sampled neighbors and the nodes\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\n            \"\"\"\n            n_id = label_srcs\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\n\n            z, last_update = model['memory'](n_id_neighbors)\n            z = model['gnn'](\n                z,\n                last_update,\n                mem_edge_index,\n                data.t[e_id].to(device),\n                data.msg[e_id].to(device),\n            )\n            z = z[assoc[n_id]]\n\n            # loss and metric computation\n            pred = model['node_pred'](z)\n\n            loss = criterion(pred, labels.to(device))\n            np_pred = pred.cpu().detach().numpy()\n            np_true = labels.cpu().detach().numpy()\n\n            input_dict = {\n                \"y_true\": np_true,\n                \"y_pred\": np_pred,\n                \"eval_metric\": [metric],\n            }\n            result_dict = evaluator.eval(input_dict)\n            score = result_dict[metric]\n            total_score += score\n            num_label_ts += 1\n\n            loss.backward()\n            optimizer.step()\n            total_loss += float(loss.detach())\n\n        # Update memory and neighbor loader with ground-truth state.\n        process_edges(src, dst, t, msg)\n        model['memory'].detach()\n\n    metric_dict = {\n        \"ce\": total_loss / num_label_ts,\n    }\n    metric_dict[metric] = total_score / num_label_ts\n    return metric_dict\n\n\n@torch.no_grad()\ndef test(loader):\n    model['memory'].eval()\n    model['gnn'].eval()\n    model['node_pred'].eval()\n\n    total_score = 0\n    label_t = dataset.get_label_time()  # check when does the first label start\n    num_label_ts = 0\n\n    for batch in loader:\n        batch = batch.to(device)\n        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        query_t = 
batch.t[-1]\n        if query_t > label_t:\n            label_tuple = dataset.get_node_label(query_t)\n            if label_tuple is None:\n                break\n            label_ts, label_srcs, labels = (\n                label_tuple[0],\n                label_tuple[1],\n                label_tuple[2],\n            )\n            label_t = dataset.get_label_time()\n            label_srcs = label_srcs.to(device)\n\n            # Process all edges that are still in the past day\n            previous_day_mask = batch.t < label_t\n            process_edges(\n                src[previous_day_mask],\n                dst[previous_day_mask],\n                t[previous_day_mask],\n                msg[previous_day_mask],\n            )\n            # Reset edges to be the edges from tomorrow so they can be used later\n            src, dst, t, msg = (\n                src[~previous_day_mask],\n                dst[~previous_day_mask],\n                t[~previous_day_mask],\n                msg[~previous_day_mask],\n            )\n\n            \"\"\"\n            modified for node property prediction\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\n            2. extract memory from the sampled neighbors and the nodes\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\n            \"\"\"\n            n_id = label_srcs\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\n\n            z, last_update = model['memory'](n_id_neighbors)\n            z = model['gnn'](\n                z,\n                last_update,\n                mem_edge_index,\n                data.t[e_id].to(device),\n                data.msg[e_id].to(device),\n            )\n            z = z[assoc[n_id]]\n\n            # loss and metric computation\n            pred = model['node_pred'](z)\n            np_pred = pred.cpu().detach().numpy()\n            np_true = labels.cpu().detach().numpy()\n\n            input_dict = {\n                \"y_true\": np_true,\n                \"y_pred\": np_pred,\n                \"eval_metric\": [metric],\n            }\n            result_dict = evaluator.eval(input_dict)\n            score = result_dict[metric]\n            total_score += score\n            num_label_ts += 1\n\n        process_edges(src, dst, t, msg)\n\n    metric_dict = {}\n    metric_dict[metric] = total_score / num_label_ts\n    return metric_dict\n\n# ==========\n# ==========\n# ==========\n\n# Start...\nstart_overall = timeit.default_timer()\n\n# ========== set parameters...\nargs, _ = get_args()\nprint(\"INFO: Arguments:\", args)\n\nDATA = \"tgbn-trade\"\nLR = args.lr\nBATCH_SIZE = args.bs\nK_VALUE = args.k_value  \nNUM_EPOCH = args.num_epoch\nSEED = args.seed\nMEM_DIM = args.mem_dim\nTIME_DIM = args.time_dim\nEMB_DIM = args.emb_dim\nTOLERANCE = args.tolerance\nPATIENCE = args.patience\nNUM_RUNS = args.num_run\nNUM_NEIGHBORS = 10\n\n\n# setting random seed\ntorch.manual_seed(SEED)\nset_random_seed(SEED)\n\nMODEL_NAME = 'DyRep'\nUSE_SRC_EMB_IN_MSG = False\nUSE_DST_EMB_IN_MSG = True\nevaluator = Evaluator(name=DATA)\n# ==========\n\n# set the device\ndevice = 
torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# data loading\ndataset = PyGNodePropPredDataset(name=DATA, root=\"datasets\")\ntrain_mask = dataset.train_mask\nval_mask = dataset.val_mask\ntest_mask = dataset.test_mask\ndata = dataset.get_TemporalData()\ndata = data.to(device)\nmetric = dataset.eval_metric\n\ntrain_data = data[train_mask]\nval_data = data[val_mask]\ntest_data = data[test_mask]\nnum_classes = dataset.num_classes\n\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)\nval_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)\ntest_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\n\n# neighborhood sampler\nneighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)\n\n# define the model end-to-end\nmemory = DyRepMemory(\n    data.num_nodes,\n    data.msg.size(-1),\n    MEM_DIM,\n    TIME_DIM,\n    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),\n    aggregator_module=LastAggregator(),\n    memory_updater_type='rnn',\n    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,\n    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG\n).to(device)\n\ngnn = GraphAttentionEmbedding(\n    in_channels=MEM_DIM,\n    out_channels=EMB_DIM,\n    msg_dim=data.msg.size(-1),\n    time_enc=memory.time_enc,\n).to(device)\n\nnode_pred = NodePredictor(in_dim=EMB_DIM, out_dim=num_classes).to(device)\n\nmodel = {'memory': memory,\n         'gnn': gnn,\n         'node_pred': node_pred}\n\noptimizer = torch.optim.Adam(\n    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['node_pred'].parameters()),\n    lr=LR,\n)\n\ncriterion = torch.nn.CrossEntropyLoss()\n# Helper vector to map global node indices to local ones.\nassoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\n\ntrain_curve = []\nval_curve = []\ntest_curve = 
[]\nmax_val_score = 0  #find the best test score based on validation score\nbest_test_idx = 0\nfor epoch in range(1, NUM_EPOCH + 1):\n    start_time = timeit.default_timer()\n    train_dict = train()\n    print(\"------------------------------------\")\n    print(f\"training Epoch: {epoch:02d}\")\n    print(train_dict)\n    train_curve.append(train_dict[metric])\n    print(\"Training takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\n    \n    start_time = timeit.default_timer()\n    val_dict = test(val_loader)\n    print(val_dict)\n    val_curve.append(val_dict[metric])\n    if (val_dict[metric] > max_val_score):\n        max_val_score = val_dict[metric]\n        best_test_idx = epoch - 1\n    print(\"Validation takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\n\n    start_time = timeit.default_timer()\n    test_dict = test(test_loader)\n    print(test_dict)\n    test_curve.append(test_dict[metric])\n    print(\"Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\n    print(\"------------------------------------\")\n    dataset.reset_label_time()\n\nmax_test_score = test_curve[best_test_idx]\nprint(\"------------------------------------\")\nprint(\"------------------------------------\")\nprint (\"best val score: \", max_val_score)\nprint (\"best validation epoch   : \", best_test_idx + 1)\nprint (\"best test score: \", max_test_score)\n\n"
  },
  {
    "path": "examples/nodeproppred/tgbn-trade/moving_average.py",
    "content": "\"\"\"\nimplement persistant forecast as baseline for the node prop pred task\nsimply predict last seen label for the node\n\"\"\"\n\nimport timeit\nimport numpy as np\nfrom torch_geometric.loader import TemporalDataLoader\n\n# local imports\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\nfrom modules.heuristics import MovingAverage\nfrom tgb.nodeproppred.evaluate import Evaluator\n\n\ndevice = \"cpu\"\n\nwindow = 7\nname = \"tgbn-trade\"\ndataset = PyGNodePropPredDataset(name=name, root=\"datasets\")\nnum_classes = dataset.num_classes\ndata = dataset.get_TemporalData()\ndata = data.to(device)\n\neval_metric = dataset.eval_metric\nforecaster = MovingAverage(num_classes, window=window)\nevaluator = Evaluator(name=name)\n\n\n# Ensure to only sample actual destination nodes as negatives.\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\ntrain_data, val_data, test_data = data.train_val_test_split(\n    val_ratio=0.15, test_ratio=0.15\n)\n\nbatch_size = 200\n\ntrain_loader = TemporalDataLoader(train_data, batch_size=batch_size)\nval_loader = TemporalDataLoader(val_data, batch_size=batch_size)\ntest_loader = TemporalDataLoader(test_data, batch_size=batch_size)\n\n\ndef test_n_upate(loader):\n    label_t = dataset.get_label_time()  # check when does the first label start\n    num_label_ts = 0\n    total_score = 0\n\n    for batch in loader:\n        batch = batch.to(device)\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n\n        query_t = batch.t[-1]\n        if query_t > label_t:\n            label_tuple = dataset.get_node_label(query_t)\n            if label_tuple is None:\n                break\n            label_ts, label_srcs, labels = (\n                label_tuple[0],\n                label_tuple[1],\n                label_tuple[2],\n            )\n            label_ts = label_ts.numpy()\n            label_srcs = label_srcs.numpy()\n            labels = labels.numpy()\n            
label_t = dataset.get_label_time()\n\n            preds = []\n\n            for i in range(0, label_srcs.shape[0]):\n                node_id = label_srcs[i]\n                pred_vec = forecaster.query_dict(node_id)\n                preds.append(pred_vec)\n                forecaster.update_dict(node_id, labels[i])\n\n            np_pred = np.stack(preds, axis=0)\n            np_true = labels\n\n            input_dict = {\n                \"y_true\": np_true,\n                \"y_pred\": np_pred,\n                \"eval_metric\": [eval_metric],\n            }\n            result_dict = evaluator.eval(input_dict)\n            score = result_dict[eval_metric]\n\n            total_score += score\n            num_label_ts += 1\n\n    metric_dict = {}\n    metric_dict[eval_metric] = total_score / num_label_ts\n    return metric_dict\n\n\n\"\"\"\ntrain, val and test for one epoch only\n\"\"\"\nstart_time = timeit.default_timer()\nmetric_dict = test_n_upate(train_loader)\nprint(metric_dict)\nprint(\n    \"Moving average forecast on Training takes--- %s seconds ---\"\n    % (timeit.default_timer() - start_time)\n)\n\nstart_time = timeit.default_timer()\nval_dict = test_n_upate(val_loader)\nprint(val_dict)\nprint(\n    \"Moving average forecast on validation takes--- %s seconds ---\"\n    % (timeit.default_timer() - start_time)\n)\n\n\nstart_time = timeit.default_timer()\ntest_dict = test_n_upate(test_loader)\nprint(test_dict)\nprint(\n    \"Moving average forecast on Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time)\n)\ndataset.reset_label_time()"
  },
  {
    "path": "examples/nodeproppred/tgbn-trade/persistant_forecast.py",
    "content": "\"\"\"\r\nimplement persistant forecast as baseline for the node prop pred task\r\nsimply predict last seen label for the node\r\n\"\"\"\r\n\r\nimport timeit\r\nimport numpy as np\r\nfrom torch_geometric.loader import TemporalDataLoader\r\n\r\n# local imports\r\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\r\nfrom modules.heuristics import PersistantForecaster\r\nfrom tgb.nodeproppred.evaluate import Evaluator\r\n\r\n\r\ndevice = \"cpu\"\r\n\r\nname = \"tgbn-trade\"\r\ndataset = PyGNodePropPredDataset(name=name, root=\"datasets\")\r\nnum_classes = dataset.num_classes\r\ndata = dataset.get_TemporalData()\r\ndata = data.to(device)\r\n\r\neval_metric = dataset.eval_metric\r\nforecaster = PersistantForecaster(num_classes)\r\nevaluator = Evaluator(name=name)\r\n\r\n\r\n# Ensure to only sample actual destination nodes as negatives.\r\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\ntrain_data, val_data, test_data = data.train_val_test_split(\r\n    val_ratio=0.15, test_ratio=0.15\r\n)\r\n\r\nbatch_size = 200\r\n\r\ntrain_loader = TemporalDataLoader(train_data, batch_size=batch_size)\r\nval_loader = TemporalDataLoader(val_data, batch_size=batch_size)\r\ntest_loader = TemporalDataLoader(test_data, batch_size=batch_size)\r\n\r\n\r\ndef test_n_upate(loader):\r\n    label_t = dataset.get_label_time()  # check when does the first label start\r\n    num_label_ts = 0\r\n    total_score = 0\r\n\r\n    for batch in loader:\r\n        batch = batch.to(device)\r\n        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\r\n\r\n        query_t = batch.t[-1]\r\n        if query_t > label_t:\r\n            label_tuple = dataset.get_node_label(query_t)\r\n            if label_tuple is None:\r\n                break\r\n            label_ts, label_srcs, labels = (\r\n                label_tuple[0],\r\n                label_tuple[1],\r\n                label_tuple[2],\r\n            )\r\n            label_ts = 
label_ts.numpy()\r\n            label_srcs = label_srcs.numpy()\r\n            labels = labels.numpy()\r\n            label_t = dataset.get_label_time()\r\n\r\n            preds = []\r\n\r\n            for i in range(0, label_srcs.shape[0]):\r\n                node_id = label_srcs[i]\r\n                pred_vec = forecaster.query_dict(node_id)\r\n                preds.append(pred_vec)\r\n                forecaster.update_dict(node_id, labels[i])\r\n\r\n            np_pred = np.stack(preds, axis=0)\r\n            np_true = labels\r\n\r\n            input_dict = {\r\n                \"y_true\": np_true,\r\n                \"y_pred\": np_pred,\r\n                \"eval_metric\": [eval_metric],\r\n            }\r\n            result_dict = evaluator.eval(input_dict)\r\n            score = result_dict[eval_metric]\r\n            total_score += score\r\n            num_label_ts += 1\r\n\r\n    metric_dict = {}\r\n    metric_dict[eval_metric] = total_score / num_label_ts \r\n    return metric_dict\r\n\r\n\r\n\"\"\"\r\ntrain, val and test for one epoch only\r\n\"\"\"\r\n\r\nstart_time = timeit.default_timer()\r\nmetric_dict = test_n_upate(train_loader)\r\nprint(metric_dict)\r\nprint(\r\n    \"Persistent forecast on Training takes--- %s seconds ---\"\r\n    % (timeit.default_timer() - start_time)\r\n)\r\n\r\nstart_time = timeit.default_timer()\r\nval_dict = test_n_upate(val_loader)\r\nprint(val_dict)\r\nprint(\r\n    \"Persistent forecast on validation takes--- %s seconds ---\"\r\n    % (timeit.default_timer() - start_time)\r\n)\r\n\r\n\r\nstart_time = timeit.default_timer()\r\ntest_dict = test_n_upate(test_loader)\r\nprint(test_dict)\r\nprint(\r\n    \"Persistent forecast on Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time)\r\n)\r\ndataset.reset_label_time()"
  },
  {
    "path": "examples/nodeproppred/tgbn-trade/tgn.py",
    "content": "import timeit\r\nimport argparse\r\nfrom tqdm import tqdm\r\nimport torch\r\nimport matplotlib.pyplot as plt\r\n\r\nfrom torch_geometric.loader import TemporalDataLoader\r\nfrom torch_geometric.nn import TGNMemory\r\nfrom torch_geometric.nn.models.tgn import (\r\n    IdentityMessage,\r\n    LastAggregator,\r\n    LastNeighborLoader,\r\n)\r\n\r\nfrom modules.decoder import NodePredictor\r\nfrom modules.emb_module import GraphAttentionEmbedding\r\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\r\nfrom tgb.nodeproppred.evaluate import Evaluator\r\nfrom tgb.utils.utils import set_random_seed\r\nfrom tgb.utils.stats import plot_curve\r\n\r\nparser = argparse.ArgumentParser(description='parsing command line arguments as hyperparameters')\r\nparser.add_argument('-s', '--seed', type=int, default=1,\r\n                    help='random seed to use')\r\nparser.parse_args()\r\nargs = parser.parse_args()\r\n# setting random seed\r\nseed = int(args.seed) #1,2,3,4,5\r\ntorch.manual_seed(seed)\r\nset_random_seed(seed)\r\n\r\n# hyperparameters\r\nlr = 0.0001\r\nepochs = 50\r\n\r\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\nname = \"tgbn-trade\"\r\ndataset = PyGNodePropPredDataset(name=name, root=\"datasets\")\r\ntrain_mask = dataset.train_mask\r\nval_mask = dataset.val_mask\r\ntest_mask = dataset.test_mask\r\n\r\neval_metric = dataset.eval_metric\r\nnum_classes = dataset.num_classes\r\ndata = dataset.get_TemporalData()\r\ndata = data.to(device)\r\n\r\nevaluator = Evaluator(name=name)\r\n\r\n\r\ntrain_data = data[train_mask]\r\nval_data = data[val_mask]\r\ntest_data = data[test_mask]\r\n\r\n# Ensure to only sample actual destination nodes as negatives.\r\nmin_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\nbatch_size = 200\r\n\r\ntrain_loader = TemporalDataLoader(train_data, batch_size=batch_size)\r\nval_loader = TemporalDataLoader(val_data, batch_size=batch_size)\r\ntest_loader = 
TemporalDataLoader(test_data, batch_size=batch_size)\r\n\r\nneighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)\r\n\r\nmemory_dim = time_dim = embedding_dim = 100\r\n\r\nmemory = TGNMemory(\r\n    data.num_nodes,\r\n    data.msg.size(-1),\r\n    memory_dim,\r\n    time_dim,\r\n    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),\r\n    aggregator_module=LastAggregator(),\r\n).to(device)\r\n\r\ngnn = (\r\n    GraphAttentionEmbedding(\r\n        in_channels=memory_dim,\r\n        out_channels=embedding_dim,\r\n        msg_dim=data.msg.size(-1),\r\n        time_enc=memory.time_enc,\r\n    )\r\n    .to(device)\r\n    .float()\r\n)\r\n\r\nnode_pred = NodePredictor(in_dim=embedding_dim, out_dim=num_classes).to(device)\r\n\r\noptimizer = torch.optim.Adam(\r\n    set(memory.parameters()) | set(gnn.parameters()) | set(node_pred.parameters()),\r\n    lr=lr,\r\n)\r\n\r\ncriterion = torch.nn.CrossEntropyLoss()\r\n# Helper vector to map global node indices to local ones.\r\nassoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)\r\n\r\n\r\ndef plot_curve(scores, out_name):\r\n    plt.plot(scores, color=\"#e34a33\")\r\n    plt.ylabel(\"score\")\r\n    plt.savefig(out_name + \".pdf\")\r\n    plt.close()\r\n\r\n\r\ndef process_edges(src, dst, t, msg):\r\n    if src.nelement() > 0:\r\n        # msg = msg.to(torch.float32)\r\n        memory.update_state(src, dst, t, msg)\r\n        neighbor_loader.insert(src, dst)\r\n\r\n\r\ndef train():\r\n    memory.train()\r\n    gnn.train()\r\n    node_pred.train()\r\n\r\n    memory.reset_state()  # Start with a fresh memory.\r\n    neighbor_loader.reset_state()  # Start with an empty graph.\r\n\r\n    total_loss = 0\r\n    label_t = dataset.get_label_time()  # check when does the first label start\r\n    num_label_ts = 0\r\n    total_score = 0\r\n\r\n    for batch in train_loader:\r\n        batch = batch.to(device)\r\n        optimizer.zero_grad()\r\n        src, dst, t, msg = 
batch.src, batch.dst, batch.t, batch.msg\r\n\r\n        query_t = batch.t[-1]\r\n        # check if this batch moves to the next day\r\n        if query_t > label_t:\r\n            # find the node labels from the past day\r\n            label_tuple = dataset.get_node_label(query_t)\r\n            label_ts, label_srcs, labels = (\r\n                label_tuple[0],\r\n                label_tuple[1],\r\n                label_tuple[2],\r\n            )\r\n            label_t = dataset.get_label_time()\r\n            label_srcs = label_srcs.to(device)\r\n\r\n            # Process all edges that are still in the past day\r\n            previous_day_mask = batch.t < label_t\r\n            process_edges(\r\n                src[previous_day_mask],\r\n                dst[previous_day_mask],\r\n                t[previous_day_mask],\r\n                msg[previous_day_mask],\r\n            )\r\n            # Reset edges to be the edges from tomorrow so they can be used later\r\n            src, dst, t, msg = (\r\n                src[~previous_day_mask],\r\n                dst[~previous_day_mask],\r\n                t[~previous_day_mask],\r\n                msg[~previous_day_mask],\r\n            )\r\n\r\n            \"\"\"\r\n            modified for node property prediction\r\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\r\n            2. extract memory from the sampled neighbors and the nodes\r\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\r\n            \"\"\"\r\n            n_id = label_srcs\r\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\r\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\r\n\r\n            z, last_update = memory(n_id_neighbors)\r\n\r\n            z = gnn(\r\n                z,\r\n                last_update,\r\n                mem_edge_index,\r\n                data.t[e_id].to(device),\r\n                data.msg[e_id].to(device),\r\n            )\r\n            z = z[assoc[n_id]]\r\n\r\n            # loss and metric computation\r\n            pred = node_pred(z)\r\n\r\n            loss = criterion(pred, labels.to(device))\r\n            np_pred = pred.cpu().detach().numpy()\r\n            np_true = labels.cpu().detach().numpy()\r\n\r\n            input_dict = {\r\n                \"y_true\": np_true,\r\n                \"y_pred\": np_pred,\r\n                \"eval_metric\": [eval_metric],\r\n            }\r\n            result_dict = evaluator.eval(input_dict)\r\n            score = result_dict[eval_metric]\r\n            total_score += score\r\n            num_label_ts += 1\r\n\r\n            loss.backward()\r\n            optimizer.step()\r\n            total_loss += float(loss.detach())\r\n\r\n        # Update memory and neighbor loader with ground-truth state.\r\n        process_edges(src, dst, t, msg)\r\n        memory.detach()\r\n\r\n    metric_dict = {\r\n        \"ce\": total_loss / num_label_ts,\r\n    }\r\n    metric_dict[eval_metric] = total_score / num_label_ts\r\n    return metric_dict\r\n\r\n\r\n@torch.no_grad()\r\ndef test(loader):\r\n    memory.eval()\r\n    gnn.eval()\r\n    node_pred.eval()\r\n    total_score = 0\r\n    label_t = dataset.get_label_time()  # check when does the first label start\r\n    num_label_ts = 0\r\n\r\n    for batch in loader:\r\n        batch = batch.to(device)\r\n        src, dst, t, 
msg = batch.src, batch.dst, batch.t, batch.msg\r\n\r\n        query_t = batch.t[-1]\r\n        if query_t > label_t:\r\n            label_tuple = dataset.get_node_label(query_t)\r\n            if label_tuple is None:\r\n                break\r\n            label_ts, label_srcs, labels = (\r\n                label_tuple[0],\r\n                label_tuple[1],\r\n                label_tuple[2],\r\n            )\r\n            label_t = dataset.get_label_time()\r\n            label_srcs = label_srcs.to(device)\r\n\r\n            # Process all edges that are still in the past day\r\n            previous_day_mask = batch.t < label_t\r\n            process_edges(\r\n                src[previous_day_mask],\r\n                dst[previous_day_mask],\r\n                t[previous_day_mask],\r\n                msg[previous_day_mask],\r\n            )\r\n            # Reset edges to be the edges from tomorrow so they can be used later\r\n            src, dst, t, msg = (\r\n                src[~previous_day_mask],\r\n                dst[~previous_day_mask],\r\n                t[~previous_day_mask],\r\n                msg[~previous_day_mask],\r\n            )\r\n\r\n            \"\"\"\r\n            modified for node property prediction\r\n            1. sample neighbors from the neighbor loader for all nodes to be predicted\r\n            2. extract memory from the sampled neighbors and the nodes\r\n            3. 
run gnn with the extracted memory embeddings and the corresponding time and message\r\n            \"\"\"\r\n            n_id = label_srcs\r\n            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)\r\n            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)\r\n\r\n            z, last_update = memory(n_id_neighbors)\r\n            z = gnn(\r\n                z,\r\n                last_update,\r\n                mem_edge_index,\r\n                data.t[e_id].to(device),\r\n                data.msg[e_id].to(device),\r\n            )\r\n            z = z[assoc[n_id]]\r\n\r\n            # loss and metric computation\r\n            pred = node_pred(z)\r\n            np_pred = pred.cpu().detach().numpy()\r\n            np_true = labels.cpu().detach().numpy()\r\n\r\n            input_dict = {\r\n                \"y_true\": np_true,\r\n                \"y_pred\": np_pred,\r\n                \"eval_metric\": [eval_metric],\r\n            }\r\n            result_dict = evaluator.eval(input_dict)\r\n            score = result_dict[eval_metric]\r\n            total_score += score\r\n            num_label_ts += 1\r\n\r\n        process_edges(src, dst, t, msg)\r\n\r\n    metric_dict = {}\r\n    metric_dict[eval_metric] = total_score / num_label_ts\r\n    return metric_dict\r\n\r\n\r\ntrain_curve = []\r\nval_curve = []\r\ntest_curve = []\r\nmax_val_score = 0  #find the best test score based on validation score\r\nbest_test_idx = 0\r\nfor epoch in range(1, epochs + 1):\r\n    start_time = timeit.default_timer()\r\n    train_dict = train()\r\n    print(\"------------------------------------\")\r\n    print(f\"training Epoch: {epoch:02d}\")\r\n    print(train_dict)\r\n    train_curve.append(train_dict[eval_metric])\r\n    print(\"Training takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\r\n    \r\n    start_time = timeit.default_timer()\r\n    val_dict = test(val_loader)\r\n    print(val_dict)\r\n    
val_curve.append(val_dict[eval_metric])\r\n    if (val_dict[eval_metric] > max_val_score):\r\n        max_val_score = val_dict[eval_metric]\r\n        best_test_idx = epoch - 1\r\n    print(\"Validation takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\r\n\r\n    start_time = timeit.default_timer()\r\n    test_dict = test(test_loader)\r\n    print(test_dict)\r\n    test_curve.append(test_dict[eval_metric])\r\n    print(\"Test takes--- %s seconds ---\" % (timeit.default_timer() - start_time))\r\n    print(\"------------------------------------\")\r\n    dataset.reset_label_time()\r\n\r\n\r\n# code for plotting\r\nplot_curve(train_curve, \"train_curve\")\r\nplot_curve(val_curve, \"val_curve\")\r\nplot_curve(test_curve, \"test_curve\")\r\n\r\nmax_test_score = test_curve[best_test_idx]\r\nprint(\"------------------------------------\")\r\nprint(\"------------------------------------\")\r\nprint (\"best val score: \", max_val_score)\r\nprint (\"best validation epoch   : \", best_test_idx + 1)\r\nprint (\"best test score: \", max_test_score)\r\n"
  },
  {
    "path": "mkdocs.yml",
    "content": "site_name: Temporal Graph Benchmark\n\nnav:\n  - Overview: index.md\n  - About: about.md\n  - API:\n    - tgb.linkproppred: api/tgb.linkproppred.md\n    - tgb.nodeproppred: api/tgb.nodeproppred.md\n    - tgb.utils: api/tgb.utils.md\n  - Tutorials:\n    - Access Edge Data in PyG: tutorials/Edge_data_pyg.ipynb\n    - Access Edge Data in Numpy: tutorials/Edge_data_numpy.ipynb\n\ntheme:\n  logo: assets/logo.png\n  name: material\n  features:\n    - navigation.tabs\n    - navigation.sections\n    - toc.integrate\n    - navigation.top\n    - search.suggest\n    - search.highlight\n    - content.tabs.link\n    - content.code.annotation\n    - content.code.copy\n  language: en\n  palette:\n    - scheme: default\n      toggle:\n        icon: material/toggle-switch-off-outline \n        name: Switch to dark mode\n      primary: purple \n      accent: orange\n    - scheme: slate \n      toggle:\n        icon: material/toggle-switch\n        name: Switch to light mode    \n      primary: orange\n      accent: lime\n\nextra:\n  social:\n    - icon: fontawesome/brands/github-alt\n      link: https://github.com/shenyangHuang/TGB\n    - icon: fontawesome/solid/envelope\n      link: shenyang.huang@mail.mcgill.ca\n    - icon: fontawesome/brands/twitter\n      link: https://twitter.com/shenyangHuang\n    - icon: fontawesome/brands/linkedin\n      link: https://www.linkedin.com/in/shenyang-huang/\n\n\nmarkdown_extensions:\n  - pymdownx.highlight:\n      anchor_linenums: true\n  - pymdownx.inlinehilite\n  - pymdownx.snippets\n  - admonition\n  - pymdownx.arithmatex:\n      generic: true\n  - footnotes\n  - pymdownx.details\n  - pymdownx.superfences\n  - pymdownx.mark\n  - attr_list\n  - pymdownx.emoji:\n      emoji_index: !!python/name:materialx.emoji.twemoji\n      emoji_generator: !!python/name:materialx.emoji.to_svg\n\n\nplugins:\n  - search\n\n  - mkdocstrings:\n      watch:\n        - tgb/\n      handlers:\n        python:\n          setup_commands:\n            - 
import sys\n            - sys.path.append(\"docs\")\n            - sys.path.append(\"tgb\")\n          selection:\n            new_path_syntax: true\n          rendering:\n            show_root_heading: false\n            heading_level: 3\n            show_root_full_path: false\n\n  - mkdocs-jupyter:\n      execute: false\n"
  },
  {
    "path": "modules/decoder.py",
    "content": "\"\"\"\nDecoder modules for dynamic link prediction\n\n\"\"\"\n\nimport torch\nfrom torch.nn import Linear\nimport torch.nn.functional as F\nfrom torch.nn.parameter import Parameter\nimport math\n\nclass LinkPredictor(torch.nn.Module):\n    \"\"\"\n    Reference:\n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n    \"\"\"\n\n    def __init__(self, in_channels):\n        super().__init__()\n        self.lin_src = Linear(in_channels, in_channels)\n        self.lin_dst = Linear(in_channels, in_channels)\n        self.lin_final = Linear(in_channels, 1)\n\n    def forward(self, z_src, z_dst):\n        h = self.lin_src(z_src) + self.lin_dst(z_dst)\n        h = h.relu()\n        return self.lin_final(h).sigmoid()\n\n\nclass NodePredictor(torch.nn.Module):\n    def __init__(self, in_dim, out_dim):\n        super().__init__()\n        self.lin_node = Linear(in_dim, in_dim)\n        self.out = Linear(in_dim, out_dim)\n\n    def forward(self, node_embed):\n        h = self.lin_node(node_embed)\n        h = h.relu()\n        h = self.out(h)\n        # h = F.log_softmax(h, dim=-1)\n        return h\n\n\n### for TKG:\nclass ConvTransE(torch.nn.Module):\n    \"\"\"\n    https://github.com/Lee-zix/CEN/blob/main/src/decoder.py\n    \"\"\"\n    def __init__(self, num_entities, embedding_dim, input_dropout=0, hidden_dropout=0, \n    feature_map_dropout=0, channels=50, kernel_size=3, sequence_len = 1, use_bias=True, model_name='REGCN'):\n\n        super(ConvTransE, self).__init__()\n        self.model_name = model_name #'REGCN' or 'CEN'\n        self.inp_drop = torch.nn.Dropout(input_dropout)\n        self.hidden_drop = torch.nn.Dropout(hidden_dropout)\n        self.feature_map_drop = torch.nn.Dropout(feature_map_dropout)\n        self.embedding_dim = embedding_dim\n\n        # self.sequence_len = sequence_len\n\n        self.conv_list = torch.nn.ModuleList()\n        self.bn0_list = torch.nn.ModuleList()\n        self.bn1_list = 
torch.nn.ModuleList()\n        self.bn2_list = torch.nn.ModuleList()\n        for _ in range(sequence_len):\n            self.conv_list.append(torch.nn.Conv1d(2, channels, kernel_size, stride=1, \n            padding=int(math.floor(kernel_size / 2)))  ) # kernel size is odd, then padding = math.floor(kernel_size/2))\n            self.bn0_list.append(torch.nn.BatchNorm1d(2))\n            self.bn1_list.append( torch.nn.BatchNorm1d(channels))\n            self.bn2_list.append(torch.nn.BatchNorm1d(embedding_dim)) \n\n\n        self.fc = torch.nn.Linear(embedding_dim * channels, embedding_dim)\n\n    def forward(self, embedding, emb_rel, triplets, partial_embeding=None, samples_of_interest_emb=None):\n        \"\"\" forward for the ConvTransE decoder that computes scores for the given query triples\n        return: score_list: list of scores for each triple in the batch\n        \"\"\"\n        score_list = []\n        batch_size = len(triplets)\n        if self.model_name == 'CEN': #CEN\n            for idx in range(len(embedding)): # length of test_graph\n                if samples_of_interest_emb != None:\n                    x= self.forward_inner(embedding[idx], emb_rel, triplets, idx, partial_embeding, samples_of_interest_emb[idx])     \n                else:\n                    x= self.forward_inner(embedding[idx], emb_rel, triplets, idx, partial_embeding, samples_of_interest_emb)\n                score_list.append(x)\n            return score_list\n        else: #RE-GCN\n            scores = self.forward_inner(embedding, emb_rel, triplets, 0, partial_embeding, samples_of_interest_emb)\n            return scores \n\n\n\n    def forward_inner(self, embedding, emb_rel, triplets, idx=0, partial_embeding=None, samples_of_interest_emb=None):\n        \"\"\" forward for the ConvTransE decoder that computes scores for the given query triples for each graph in the history of test graphs\n        return: x: list of scores for each triple in the batch\n        \"\"\"\n        
batch_size = len(triplets)\n        e1_embedded_all = F.tanh(embedding)\n        e1_embedded = e1_embedded_all[triplets[:, 0]].unsqueeze(1)\n        rel_embedded = emb_rel[triplets[:, 1]].unsqueeze(1)\n        stacked_inputs = torch.cat([e1_embedded, rel_embedded], 1)\n        stacked_inputs = self.bn0_list[idx](stacked_inputs)\n        x = self.inp_drop(stacked_inputs)\n        x = self.conv_list[idx](x)\n        x = self.bn1_list[idx](x)\n        x = F.relu(x)\n        x = self.feature_map_drop(x)\n        x = x.view(batch_size, -1)\n        x = self.fc(x)\n        x = self.hidden_drop(x)\n        if batch_size > 1:\n            x = self.bn2_list[idx](x)\n        x = F.relu(x)\n        if partial_embeding !=None:\n            x = torch.mm(x, partial_embeding.transpose(1, 0))\n        elif samples_of_interest_emb !=None: # added tgb team: predict only for nodes of interest\n            x = torch.mm(x, F.tanh(samples_of_interest_emb).transpose(1, 0)) \n        else: #predict for all nodes\n            x = torch.mm(x, e1_embedded_all.transpose(1, 0))\n\n        return x"
  },
  {
    "path": "modules/early_stopping.py",
    "content": "\"\"\"\nAn Early Stopping Module\n\"\"\"\nimport os\nfrom pathlib import Path\nimport torch\nimport torch.nn as nn\nimport numpy as np\n\nclass EarlyStopMonitor(object):\n    \n    def __init__(self, save_model_dir: str, save_model_id: str, \n                tolerance: float=1e-10, patience: int=5,\n                higher_better: bool=True):\n        r\"\"\"\n        Early Stopping Monitor\n        :param: save_model_path: strc, where to save the model\n        :param: save_model_id: str, an id to save the model with\n        :param: tolerance: float, the amount of tolerance of the early stopper\n        :param: patience: int, how many round to wait\n        :param: higher_better: whether higher_value of the a metric is better\n        \"\"\"\n        self.tolerance = tolerance\n        self.patience = patience\n        self.higher_better = higher_better\n        self.counter = 0\n        self.best_sofar = None\n        self.best_epoch = 0\n        self.epoch_idx = 1\n\n        self.save_model_dir = save_model_dir\n        if not os.path.exists(self.save_model_dir):\n            os.mkdir(self.save_model_dir)\n            print('INFO: Create directory {}'.format(save_model_dir))\n        Path(self.save_model_dir).mkdir(parents=True, exist_ok=True)\n        self.save_model_id = save_model_id\n\n    def get_best_model_path(self):\n        r\"\"\"\n        return the path of the best model\n        \"\"\"\n        return self.save_model_dir + '/{}.pth'.format(self.save_model_id)\n    \n    def step_check(self, curr_metric: float, models_dict: dict):\n        r\"\"\"\n        execute the early stop strategy\n        :param: metric: a metric to evaluate the early stopping on\n        :param: models_dict: a dictionary containing all models to be saved\n        \"\"\"\n        if not self.higher_better:\n            curr_metric *= -1\n        \n        if (self.best_sofar is None) or ((curr_metric - self.best_sofar) / np.abs(self.best_sofar) > 
self.tolerance):\n            # first iteration or observing an improvement\n            self.best_sofar = curr_metric\n            print(\"INFO: save a checkpoint...\")\n            self.save_checkpoint(models_dict)\n            self.counter = 0\n            self.best_epoch = self.epoch_idx\n        else:\n            # no improvement observed\n            self.counter += 1\n        \n        self.epoch_idx += 1\n        \n        return self.counter >= self.patience\n    \n    def save_checkpoint(self, models_dict: dict):\n        r\"\"\"\n        save models as a checkpoint\n        :param: models_dict: a dictionary containing all models to be saved \n        \"\"\"\n        model_path = self.get_best_model_path()\n        print(\"INFO: save the model to {}\".format(model_path))\n        model_names = list(models_dict.keys())\n        model_components = list(models_dict.values())\n        torch.save({model_names[i]: model_components[i].state_dict() for i in range(len(model_names))}, \n                    model_path)\n\n    def load_checkpoint(self, models_dict: dict):\n        r\"\"\"\n        load models from the checkpoint\n        :param: models_dict: a dictionary containing all models\n        \"\"\"\n        model_path = self.get_best_model_path()\n        print(\"INFO: load the model of epoch {} from {}\".format(self.best_epoch, model_path))\n        checkpoint = torch.load(model_path)\n        for model_name, model in models_dict.items():\n            model.load_state_dict(checkpoint[model_name])\n        \n\n\n        "
  },
  {
    "path": "modules/edgebank_predictor.py",
    "content": "\"\"\"\r\nEdgeBank is a simple strong baseline for dynamic link prediction\r\nit predicts the existence of edges based on their history of occurrence\r\n\r\nReference:\r\n    - https://github.com/fpour/DGB/tree/main\r\n\"\"\"\r\n\r\n\r\nimport numpy as np\r\nimport warnings\r\n\r\nclass EdgeBankPredictor(object):\r\n    def __init__(\r\n        self,\r\n        src: np.ndarray,\r\n        dst: np.ndarray,\r\n        ts: np.ndarray,\r\n        memory_mode: str = 'unlimited',  # could be `unlimited` or `fixed_time_window`\r\n        time_window_ratio: float = 0.15,\r\n        pos_prob: float = 1.0,\r\n    ):\r\n        r\"\"\"\r\n        intialize edgebank and specify the memory mode\r\n        Parameters:\r\n            src: source node id of the edges for initialization \r\n            dst: destination node id of the edges for initialization\r\n            ts: timestamp of the edges for initialization\r\n            memory_mode: 'unlimited' or 'fixed_time_window'\r\n            time_window_ratio: the ratio of the time window length to the total time length         \r\n            pos_prob: the probability of the link existence for the edges in memory   \r\n        \"\"\"\r\n        assert memory_mode in ['unlimited', 'fixed_time_window'], \"Invalide memory mode for EdgeBank!\"\r\n        self.memory_mode = memory_mode\r\n        if self.memory_mode == 'fixed_time_window':\r\n            self.time_window_ratio = time_window_ratio\r\n            #determine the time window size based on ratio from the given src, dst, and ts for initialization\r\n            duration = ts.max() - ts.min()\r\n            self.prev_t = ts.min() + duration * (1-time_window_ratio) #the time windows starts from the last ratio% of time\r\n            self.cur_t = ts.max()\r\n            self.duration = self.cur_t - self.prev_t\r\n        else:\r\n            self.time_window_ratio = -1\r\n            self.prev_t = -1 \r\n            self.cur_t = -1 \r\n            
self.duration = -1\r\n\r\n        self.memory = {} #{(u,v):1}\r\n        self.pos_prob = pos_prob\r\n        self.update_memory(src, dst, ts)\r\n\r\n    def update_memory(self, \r\n                       src: np.ndarray, \r\n                       dst: np.ndarray, \r\n                       ts: np.ndarray):\r\n        r\"\"\"\r\n        generate the current and correct state of the memory with the observed edges so far\r\n        note that historical edges may include training, validation, and already observed test edges\r\n        Parameters:\r\n            src: source node id of the edges\r\n            dst: destination node id of the edges\r\n            ts: timestamp of the edges\r\n        \"\"\"\r\n        if self.memory_mode == 'unlimited':\r\n            self._update_unlimited_memory(src, dst)  #ignores time\r\n        elif self.memory_mode == 'fixed_time_window':\r\n            self._update_time_window_memory(src, dst, ts)\r\n        else:\r\n            raise ValueError(\"Invalide memory mode!\")\r\n        \r\n    @property\r\n    def start_time(self) -> int:\r\n        \"\"\"\r\n        return the start of time window for edgebank `fixed_time_window` only\r\n        Returns:\r\n            start of time window\r\n        \"\"\"\r\n        if (self.memory_mode == \"unlimited\"):\r\n            warnings.warn(\"start_time is not defined for unlimited memory mode, returns -1\")\r\n        return self.prev_t\r\n    \r\n    @property\r\n    def end_time(self) -> int:\r\n        \"\"\"\r\n        return the end of time window for edgebank `fixed_time_window` only\r\n        Returns:\r\n            end of time window\r\n        \"\"\"\r\n        if (self.memory_mode == \"unlimited\"):\r\n            warnings.warn(\"end_time is not defined for unlimited memory mode, returns -1\")\r\n        return self.cur_t\r\n    \r\n    def _update_unlimited_memory(self, \r\n                                 update_src: np.ndarray, \r\n                                 
update_dst: np.ndarray):\r\n        r\"\"\"\r\n        update self.memory with newly arrived src and dst\r\n        Parameters:\r\n            src: source node id of the edges\r\n            dst: destination node id of the edges\r\n        \"\"\"\r\n        for src, dst in zip(update_src, update_dst):\r\n            if (src, dst) not in self.memory:\r\n                self.memory[(src, dst)] = 1\r\n\r\n    def _update_time_window_memory(self, \r\n                                   update_src: np.ndarray, \r\n                                   update_dst: np.ndarray, \r\n                                   update_ts: np.ndarray) -> None:\r\n        r\"\"\"\r\n        move the time window forward until end of dst timestamp here\r\n        also need to remove earlier edges from memory which is not in the time window\r\n        Parameters:\r\n            update_src: source node id of the edges\r\n            update_dst: destination node id of the edges\r\n            update_ts: timestamp of the edges\r\n        \"\"\"\r\n\r\n        #* initialize the memory if it is empty\r\n        if (len(self.memory) == 0):\r\n            for src, dst, ts in zip(update_src, update_dst, update_ts):\r\n                self.memory[(src, dst)] = ts\r\n            return None\r\n        \r\n        #* update the memory if it is not empty\r\n        if (update_ts.max() > self.cur_t):\r\n            self.cur_t = update_ts.max()\r\n            self.prev_t = self.cur_t - self.duration\r\n\r\n        #* add new edges to the time window\r\n        for src, dst, ts in zip(update_src, update_dst, update_ts):\r\n            self.memory[(src, dst)] = ts\r\n\r\n        \r\n    def predict_link(self, \r\n                    query_src: np.ndarray, \r\n                    query_dst: np.ndarray) -> np.ndarray:\r\n        r\"\"\"\r\n        predict the probability from query src,dst pair given the current memory, \r\n        all edges not in memory will return 0.0 while all observed edges in memory will 
return self.pos_prob\r\n        Parameters:\r\n            query_src: source node id of the query edges\r\n            query_dst: destination node id of the query edges\r\n        Returns: \r\n            pred: the prediction for all query edges\r\n        \"\"\"\r\n        pred = np.zeros(len(query_src))\r\n        idx = 0\r\n        for src, dst in zip(query_src, query_dst):\r\n            if (src, dst) in self.memory:\r\n                if (self.memory_mode == 'fixed_time_window'):\r\n                    if (self.memory[(src,dst)] >= self.prev_t):\r\n                        pred[idx] = self.pos_prob\r\n                else:\r\n                    pred[idx] = self.pos_prob\r\n            idx += 1\r\n        \r\n        return pred\r\n    \r\n    "
  },
  {
    "path": "modules/emb_module.py",
    "content": "\"\"\"\nGNN-based modules used in the architecture of MP-TG models\n\n\"\"\"\n\nimport math\nfrom torch_geometric.nn import TransformerConv\nimport torch\n\n\nclass GraphAttentionEmbedding(torch.nn.Module):\n    \"\"\"\n    Reference:\n    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py\n    \"\"\"\n\n    def __init__(self, in_channels, out_channels, msg_dim, time_enc):\n        super().__init__()\n        self.time_enc = time_enc\n        edge_dim = msg_dim + time_enc.out_channels\n        self.conv = TransformerConv(\n            in_channels, out_channels // 2, heads=2, dropout=0.1, edge_dim=edge_dim\n        )\n\n    def forward(self, x, last_update, edge_index, t, msg):\n        rel_t = last_update[edge_index[0]] - t\n        rel_t_enc = self.time_enc(rel_t.to(x.dtype))\n        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)\n        return self.conv(x, edge_index, edge_attr)\n\n\nclass TimeEmbedding(torch.nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n\n        class NormalLinear(torch.nn.Linear):\n            # From TGN code: From JODIE code\n            def reset_parameters(self):\n                stdv = 1.0 / math.sqrt(self.weight.size(1))\n                self.weight.data.normal_(0, stdv)\n                if self.bias is not None:\n                    self.bias.data.normal_(0, stdv)\n\n        self.embedding_layer = NormalLinear(1, self.out_channels)\n\n    def forward(self, x, last_update, t):\n        rel_t = last_update - t\n        embeddings = x * (1 + self.embedding_layer(rel_t.to(x.dtype).unsqueeze(1)))\n\n        return embeddings\n"
  },
  {
    "path": "modules/heuristics.py",
    "content": "import numpy as np\r\n\r\n\r\nclass PersistantForecaster:\r\n    def __init__(self, num_class):\r\n        self.dict = {}\r\n        self.num_class = num_class\r\n\r\n    def update_dict(self, node_id, label):\r\n        self.dict[node_id] = label\r\n\r\n    def query_dict(self, node_id):\r\n        r\"\"\"\r\n        Parameters:\r\n            node_id: the node to query\r\n        Returns:\r\n            returns the last seen label of the node if it exists, if not return zero vector\r\n        \"\"\"\r\n        if node_id in self.dict:\r\n            return self.dict[node_id]\r\n        else:\r\n            return np.zeros(self.num_class)\r\n\r\n\r\nclass MovingAverage:\r\n    def __init__(self, num_class, window=7):\r\n        self.dict = {}\r\n        self.num_class = num_class\r\n        self.window = window\r\n\r\n    def update_dict(self, node_id, label):\r\n        if node_id in self.dict:\r\n            total = self.dict[node_id] * (self.window - 1) + label\r\n            self.dict[node_id] = total / self.window\r\n        else:\r\n            self.dict[node_id] = label\r\n\r\n    def query_dict(self, node_id):\r\n        r\"\"\"\r\n        Parameters:\r\n            node_id: the node to query\r\n        Returns:\r\n            returns the last seen label of the node if it exists, if not return zero vector\r\n        \"\"\"\r\n        if node_id in self.dict:\r\n            return self.dict[node_id]\r\n        else:\r\n            return np.zeros(self.num_class)\r\n"
  },
  {
    "path": "modules/memory_module.py",
    "content": "\"\"\"\nMemory Module\n\nReference:\n    - https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/tgn.html\n\"\"\"\n\n\nimport copy\nfrom typing import Callable, Dict, Tuple\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import GRUCell, RNNCell, Linear\n\nfrom torch_geometric.nn.inits import zeros\nfrom torch_geometric.utils import scatter\n\nfrom modules.time_enc import TimeEncoder\n\n\nTGNMessageStoreType = Dict[int, Tuple[Tensor, Tensor, Tensor, Tensor]]\n\n\nclass TGNMemory(torch.nn.Module):\n    r\"\"\"The Temporal Graph Network (TGN) memory model from the\n    `\"Temporal Graph Networks for Deep Learning on Dynamic Graphs\"\n    <https://arxiv.org/abs/2006.10637>`_ paper.\n\n    .. note::\n\n        For an example of using TGN, see `examples/tgn.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        tgn.py>`_.\n\n    Args:\n        num_nodes (int): The number of nodes to save memories for.\n        raw_msg_dim (int): The raw message dimensionality.\n        memory_dim (int): The hidden memory dimensionality.\n        time_dim (int): The time encoding dimensionality.\n        message_module (torch.nn.Module): The message function which\n            combines source and destination node memory embeddings, the raw\n            message and the time encoding.\n        aggregator_module (torch.nn.Module): The message aggregator function\n            which aggregates messages to the same destination into a single\n            representation.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_nodes: int,\n        raw_msg_dim: int,\n        memory_dim: int,\n        time_dim: int,\n        message_module: Callable,\n        aggregator_module: Callable,\n        memory_updater_cell: str = \"gru\",\n    ):\n        super().__init__()\n\n        self.num_nodes = num_nodes\n        self.raw_msg_dim = raw_msg_dim\n        self.memory_dim = memory_dim\n        self.time_dim 
= time_dim\n\n        self.msg_s_module = message_module\n        self.msg_d_module = copy.deepcopy(message_module)\n        self.aggr_module = aggregator_module\n        self.time_enc = TimeEncoder(time_dim)\n        # self.gru = GRUCell(message_module.out_channels, memory_dim)\n        if memory_updater_cell == \"gru\":  # for TGN\n            self.memory_updater = GRUCell(message_module.out_channels, memory_dim)\n        elif memory_updater_cell == \"rnn\":  # for JODIE & DyRep\n            self.memory_updater = RNNCell(message_module.out_channels, memory_dim)\n        else:\n            raise ValueError(\n                \"Undefined memory updater!!! Memory updater can be either 'gru' or 'rnn'.\"\n            )\n\n        self.register_buffer(\"memory\", torch.empty(num_nodes, memory_dim))\n        last_update = torch.empty(self.num_nodes, dtype=torch.long)\n        self.register_buffer(\"last_update\", last_update)\n        self.register_buffer(\"_assoc\", torch.empty(num_nodes, dtype=torch.long))\n\n        self.msg_s_store = {}\n        self.msg_d_store = {}\n\n        self.reset_parameters()\n\n    @property\n    def device(self) -> torch.device:\n        return self.time_enc.lin.weight.device\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        if hasattr(self.msg_s_module, \"reset_parameters\"):\n            self.msg_s_module.reset_parameters()\n        if hasattr(self.msg_d_module, \"reset_parameters\"):\n            self.msg_d_module.reset_parameters()\n        if hasattr(self.aggr_module, \"reset_parameters\"):\n            self.aggr_module.reset_parameters()\n        self.time_enc.reset_parameters()\n        self.memory_updater.reset_parameters()\n        self.reset_state()\n\n    def reset_state(self):\n        \"\"\"Resets the memory to its initial state.\"\"\"\n        zeros(self.memory)\n        zeros(self.last_update)\n        self._reset_message_store()\n\n    def detach(self):\n        
\"\"\"Detaches the memory from gradient computation.\"\"\"\n        self.memory.detach_()\n\n    def forward(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:\n        \"\"\"Returns, for all nodes :obj:`n_id`, their current memory and their\n        last updated timestamp.\"\"\"\n        if self.training:\n            memory, last_update = self._get_updated_memory(n_id)\n        else:\n            memory, last_update = self.memory[n_id], self.last_update[n_id]\n\n        return memory, last_update\n\n    def update_state(self, src: Tensor, dst: Tensor, t: Tensor, raw_msg: Tensor):\n        \"\"\"Updates the memory with newly encountered interactions\n        :obj:`(src, dst, t, raw_msg)`.\"\"\"\n        n_id = torch.cat([src, dst]).unique()\n\n        if self.training:\n            self._update_memory(n_id)\n            self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store)\n            self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store)\n        else:\n            self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store)\n            self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store)\n            self._update_memory(n_id)\n\n    def _reset_message_store(self):\n        i = self.memory.new_empty((0,), device=self.device, dtype=torch.long)\n        msg = self.memory.new_empty((0, self.raw_msg_dim), device=self.device)\n        # Message store format: (src, dst, t, msg)\n        self.msg_s_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}\n        self.msg_d_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}\n\n    def _update_memory(self, n_id: Tensor):\n        memory, last_update = self._get_updated_memory(n_id)\n        self.memory[n_id] = memory\n        self.last_update[n_id] = last_update\n\n    def _get_updated_memory(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:\n        self._assoc[n_id] = torch.arange(n_id.size(0), device=n_id.device)\n\n        # Compute messages (src -> dst).\n        msg_s, t_s, src_s, 
dst_s = self._compute_msg(\n            n_id, self.msg_s_store, self.msg_s_module\n        )\n\n        # Compute messages (dst -> src).\n        msg_d, t_d, src_d, dst_d = self._compute_msg(\n            n_id, self.msg_d_store, self.msg_d_module\n        )\n\n        # Aggregate messages.\n        idx = torch.cat([src_s, src_d], dim=0)\n        msg = torch.cat([msg_s, msg_d], dim=0)\n        t = torch.cat([t_s, t_d], dim=0)\n        aggr = self.aggr_module(msg, self._assoc[idx], t, n_id.size(0))\n\n        # Get local copy of updated memory.\n        memory = self.memory_updater(aggr, self.memory[n_id])\n\n        # Get local copy of updated `last_update`.\n        dim_size = self.last_update.size(0)\n        last_update = scatter(t, idx, 0, dim_size, reduce=\"max\")[n_id]\n\n        return memory, last_update\n\n    def _update_msg_store(\n        self,\n        src: Tensor,\n        dst: Tensor,\n        t: Tensor,\n        raw_msg: Tensor,\n        msg_store: TGNMessageStoreType,\n    ):\n        n_id, perm = src.sort()\n        n_id, count = n_id.unique_consecutive(return_counts=True)\n        for i, idx in zip(n_id.tolist(), perm.split(count.tolist())):\n            msg_store[i] = (src[idx], dst[idx], t[idx], raw_msg[idx])\n\n    def _compute_msg(\n        self, n_id: Tensor, msg_store: TGNMessageStoreType, msg_module: Callable\n    ):\n        data = [msg_store[i] for i in n_id.tolist()]\n        src, dst, t, raw_msg = list(zip(*data))\n        src = torch.cat(src, dim=0)\n        dst = torch.cat(dst, dim=0)\n        t = torch.cat(t, dim=0)\n        raw_msg = torch.cat(raw_msg, dim=0)\n        t_rel = t - self.last_update[src]\n        t_enc = self.time_enc(t_rel.to(raw_msg.dtype))\n\n        msg = msg_module(self.memory[src], self.memory[dst], raw_msg, t_enc)\n\n        return msg, t, src, dst\n\n    def train(self, mode: bool = True):\n        \"\"\"Sets the module in training mode.\"\"\"\n        if self.training and not mode:\n            # Flush message 
store to memory in case we just entered eval mode.\n            self._update_memory(torch.arange(self.num_nodes, device=self.memory.device))\n            self._reset_message_store()\n        super().train(mode)\n\n\nclass DyRepMemory(torch.nn.Module):\n    r\"\"\"\n    Based on intuitions from TGN Memory...\n    Differences with the original TGN Memory:\n        - can use source or destination embeddings in message generation\n        - can use a RNN or GRU module as the memory updater\n\n    Args:\n        num_nodes (int): The number of nodes to save memories for.\n        raw_msg_dim (int): The raw message dimensionality.\n        memory_dim (int): The hidden memory dimensionality.\n        time_dim (int): The time encoding dimensionality.\n        message_module (torch.nn.Module): The message function which\n            combines source and destination node memory embeddings, the raw\n            message and the time encoding.\n        aggregator_module (torch.nn.Module): The message aggregator function\n            which aggregates messages to the same destination into a single\n            representation.\n        memory_updater_type (str): specifies whether the memory updater is GRU or RNN\n        use_src_emb_in_msg (bool): whether to use the source embeddings \n            in generation of messages\n        use_dst_emb_in_msg (bool): whether to use the destination embeddings \n            in generation of messages\n    \"\"\"\n    def __init__(self, num_nodes: int, raw_msg_dim: int, memory_dim: int,\n                 time_dim: int, message_module: Callable,\n                 aggregator_module: Callable, memory_updater_type: str,\n                 use_src_emb_in_msg: bool = False, use_dst_emb_in_msg: bool = False):\n        super().__init__()\n\n        self.num_nodes = num_nodes\n        self.raw_msg_dim = raw_msg_dim\n        self.memory_dim = memory_dim\n        self.time_dim = time_dim\n\n        self.msg_s_module = message_module\n        
self.msg_d_module = copy.deepcopy(message_module)\n        self.aggr_module = aggregator_module\n        self.time_enc = TimeEncoder(time_dim)\n\n        assert memory_updater_type in ['gru', 'rnn'], \"Memor updater can be either `rnn` or `gru`.\"\n        if memory_updater_type == 'gru':  # for TGN\n            self.memory_updater = GRUCell(message_module.out_channels, memory_dim)\n        elif memory_updater_type == 'rnn':  # for JODIE & DyRep\n            self.memory_updater = RNNCell(message_module.out_channels, memory_dim)\n        else:\n            raise ValueError(\"Undefined memory updater!!! Memory updater can be either 'gru' or 'rnn'.\")\n        \n        self.use_src_emb_in_msg = use_src_emb_in_msg\n        self.use_dst_emb_in_msg = use_dst_emb_in_msg\n\n        self.register_buffer('memory', torch.empty(num_nodes, memory_dim))\n        last_update = torch.empty(self.num_nodes, dtype=torch.long)\n        self.register_buffer('last_update', last_update)\n        self.register_buffer('_assoc', torch.empty(num_nodes,\n                                                   dtype=torch.long))\n\n        self.msg_s_store = {}\n        self.msg_d_store = {}\n\n        self.reset_parameters()\n\n    @property\n    def device(self) -> torch.device:\n        return self.time_enc.lin.weight.device\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        if hasattr(self.msg_s_module, 'reset_parameters'):\n            self.msg_s_module.reset_parameters()\n        if hasattr(self.msg_d_module, 'reset_parameters'):\n            self.msg_d_module.reset_parameters()\n        if hasattr(self.aggr_module, 'reset_parameters'):\n            self.aggr_module.reset_parameters()\n        self.time_enc.reset_parameters()\n        self.memory_updater.reset_parameters()\n        self.reset_state()\n\n    def reset_state(self):\n        \"\"\"Resets the memory to its initial state.\"\"\"\n        zeros(self.memory)\n        
zeros(self.last_update)\n        self._reset_message_store()\n\n    def detach(self):\n        \"\"\"Detaches the memory from gradient computation.\"\"\"\n        self.memory.detach_()\n\n    def forward(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:\n        \"\"\"Returns, for all nodes :obj:`n_id`, their current memory and their\n        last updated timestamp.\"\"\"\n        if self.training:\n            memory, last_update = self._get_updated_memory(n_id)\n        else:\n            memory, last_update = self.memory[n_id], self.last_update[n_id]\n\n        return memory, last_update\n\n    def update_state(self, src: Tensor, dst: Tensor, t: Tensor, raw_msg: Tensor, \n                     embeddings: Tensor = None, assoc: Tensor = None):\n        \"\"\"Updates the memory with newly encountered interactions\n        :obj:`(src, dst, t, raw_msg)`.\"\"\"\n        n_id = torch.cat([src, dst]).unique()\n        \n        if self.training:\n            self._update_memory(n_id, embeddings, assoc)\n            self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store)\n            self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store)\n        else:\n            self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store)\n            self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store)\n            self._update_memory(n_id, embeddings, assoc)\n\n    def _reset_message_store(self):\n        i = self.memory.new_empty((0, ), device=self.device, dtype=torch.long)\n        msg = self.memory.new_empty((0, self.raw_msg_dim), device=self.device)\n        # Message store format: (src, dst, t, msg)\n        self.msg_s_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}\n        self.msg_d_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}\n\n    def _update_memory(self, n_id: Tensor, embeddings: Tensor = None, assoc: Tensor = None):\n        memory, last_update = self._get_updated_memory(n_id, embeddings, assoc)\n        self.memory[n_id] 
= memory\n        self.last_update[n_id] = last_update\n\n    def _get_updated_memory(self, n_id: Tensor, embeddings: Tensor = None, assoc: Tensor = None) -> Tuple[Tensor, Tensor]:\n        self._assoc[n_id] = torch.arange(n_id.size(0), device=n_id.device)\n\n        # Compute messages (src -> dst).\n        msg_s, t_s, src_s, dst_s = self._compute_msg(n_id, self.msg_s_store,\n                                                     self.msg_s_module, embeddings, assoc)                                          \n\n        # Compute messages (dst -> src).\n        msg_d, t_d, src_d, dst_d = self._compute_msg(n_id, self.msg_d_store,\n                                                     self.msg_d_module, embeddings, assoc)\n\n        # Aggregate messages.\n        idx = torch.cat([src_s, src_d], dim=0)\n        msg = torch.cat([msg_s, msg_d], dim=0)\n        t = torch.cat([t_s, t_d], dim=0)\n        aggr = self.aggr_module(msg, self._assoc[idx], t, n_id.size(0))\n\n        # Get local copy of updated memory.\n        memory = self.memory_updater(aggr, self.memory[n_id])\n\n        # Get local copy of updated `last_update`.\n        dim_size = self.last_update.size(0)\n        last_update = scatter(t, idx, 0, dim_size, reduce='max')[n_id]\n\n        return memory, last_update\n\n    def _update_msg_store(self, src: Tensor, dst: Tensor, t: Tensor,\n                          raw_msg: Tensor, msg_store: TGNMessageStoreType):\n        n_id, perm = src.sort()\n        n_id, count = n_id.unique_consecutive(return_counts=True)\n        for i, idx in zip(n_id.tolist(), perm.split(count.tolist())):\n            msg_store[i] = (src[idx], dst[idx], t[idx], raw_msg[idx])\n\n    def _compute_msg(self, n_id: Tensor, msg_store: TGNMessageStoreType, msg_module: Callable, \n                     embeddings: Tensor = None, assoc: Tensor = None):\n        data = [msg_store[i] for i in n_id.tolist()]\n        src, dst, t, raw_msg = list(zip(*data))\n        src = torch.cat(src, dim=0)\n       
 dst = torch.cat(dst, dim=0)\n        t = torch.cat(t, dim=0)\n        raw_msg = torch.cat(raw_msg, dim=0)\n        t_rel = t - self.last_update[src]\n        t_enc = self.time_enc(t_rel.to(raw_msg.dtype))\n\n        # source nodes: retrieve embeddings\n        source_memory = self.memory[src]\n        if self.use_src_emb_in_msg and embeddings != None:\n            if src.size(0) > 0:\n                curr_src, curr_src_idx = [], []\n                for s_idx, s in enumerate(src):\n                    if s in n_id:\n                        curr_src.append(s.item())\n                        curr_src_idx.append(s_idx)\n\n                source_memory[curr_src_idx] = embeddings[assoc[curr_src]]\n\n        # destination nodes: retrieve embeddings\n        destination_memory = self.memory[dst]\n        if self.use_dst_emb_in_msg and embeddings != None:\n            if dst.size(0) > 0:\n                curr_dst, curr_dst_idx = [], []\n                for d_idx, d in enumerate(dst):\n                    if d in n_id:\n                        curr_dst.append(d.item())\n                        curr_dst_idx.append(d_idx)\n                destination_memory[curr_dst_idx] = embeddings[assoc[curr_dst]]\n            \n        msg = msg_module(source_memory, destination_memory, raw_msg, t_enc)\n\n        return msg, t, src, dst\n\n    def train(self, mode: bool = True):\n        \"\"\"Sets the module in training mode.\"\"\"\n        if self.training and not mode:\n            # Flush message store to memory in case we just entered eval mode.\n            self._update_memory(\n                torch.arange(self.num_nodes, device=self.memory.device))\n            self._reset_message_store()\n        super().train(mode)"
  },
  {
    "path": "modules/msg_agg.py",
    "content": "\"\"\"\nMessage Aggregator Module\n\nReference:\n    - https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/tgn.html\n\"\"\"\n\n\nimport torch\nfrom torch import Tensor\nfrom torch_geometric.utils import scatter\nfrom torch_scatter import scatter_max\n\n\nclass LastAggregator(torch.nn.Module):\n    def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):\n        _, argmax = scatter_max(t, index, dim=0, dim_size=dim_size)\n        out = msg.new_zeros((dim_size, msg.size(-1)))\n        mask = argmax < msg.size(0)  # Filter items with at least one entry.\n        out[mask] = msg[argmax[mask]]\n        return out\n\n\nclass MeanAggregator(torch.nn.Module):\n    def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):\n        return scatter(msg, index, dim=0, dim_size=dim_size, reduce=\"mean\")\n"
  },
  {
    "path": "modules/msg_func.py",
    "content": "\"\"\"\nMessage Function Module\n\nReference:\n    - https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/tgn.html\n\"\"\"\n\nimport torch\nfrom torch import Tensor\n\n\nclass IdentityMessage(torch.nn.Module):\n    def __init__(self, raw_msg_dim: int, memory_dim: int, time_dim: int):\n        super().__init__()\n        self.out_channels = raw_msg_dim + 2 * memory_dim + time_dim\n\n    def forward(self, z_src: Tensor, z_dst: Tensor, raw_msg: Tensor, t_enc: Tensor):\n        return torch.cat([z_src, z_dst, raw_msg, t_enc], dim=-1)\n"
  },
  {
    "path": "modules/neighbor_loader.py",
    "content": "\"\"\"\nNeighbor Loader\n\nReference:\n    - https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/tgn.html\n\"\"\"\n\nimport copy\nfrom typing import Callable, Dict, Tuple\n\nimport torch\nfrom torch import Tensor\n\n\nclass LastNeighborLoader:\n    def __init__(self, num_nodes: int, size: int, device=None):\n        self.size = size\n\n        self.neighbors = torch.empty((num_nodes, size), dtype=torch.long, device=device)\n        self.e_id = torch.empty((num_nodes, size), dtype=torch.long, device=device)\n        self._assoc = torch.empty(num_nodes, dtype=torch.long, device=device)\n\n        self.reset_state()\n\n    def __call__(self, n_id: Tensor) -> Tuple[Tensor, Tensor, Tensor]:\n        neighbors = self.neighbors[n_id]\n        nodes = n_id.view(-1, 1).repeat(1, self.size)\n        e_id = self.e_id[n_id]\n\n        # Filter invalid neighbors (identified by `e_id < 0`).\n        mask = e_id >= 0\n        neighbors, nodes, e_id = neighbors[mask], nodes[mask], e_id[mask]\n\n        # Relabel node indices.\n        n_id = torch.cat([n_id, neighbors]).unique()\n        self._assoc[n_id] = torch.arange(n_id.size(0), device=n_id.device)\n        neighbors, nodes = self._assoc[neighbors], self._assoc[nodes]\n\n        return n_id, torch.stack([neighbors, nodes]), e_id\n\n    def insert(self, src: Tensor, dst: Tensor):\n        # Inserts newly encountered interactions into an ever-growing\n        # (undirected) temporal graph.\n\n        # Collect central nodes, their neighbors and the current event ids.\n        neighbors = torch.cat([src, dst], dim=0)\n        nodes = torch.cat([dst, src], dim=0)\n        e_id = torch.arange(\n            self.cur_e_id, self.cur_e_id + src.size(0), device=src.device\n        ).repeat(2)\n        self.cur_e_id += src.numel()\n\n        # Convert newly encountered interaction ids so that they point to\n        # locations of a \"dense\" format of shape [num_nodes, size].\n        
nodes, perm = nodes.sort()\n        neighbors, e_id = neighbors[perm], e_id[perm]\n\n        n_id = nodes.unique()\n        self._assoc[n_id] = torch.arange(n_id.numel(), device=n_id.device)\n\n        dense_id = torch.arange(nodes.size(0), device=nodes.device) % self.size\n        dense_id += self._assoc[nodes].mul_(self.size)\n\n        dense_e_id = e_id.new_full((n_id.numel() * self.size,), -1)\n        dense_e_id[dense_id] = e_id\n        dense_e_id = dense_e_id.view(-1, self.size)\n\n        dense_neighbors = e_id.new_empty(n_id.numel() * self.size)\n        dense_neighbors[dense_id] = neighbors\n        dense_neighbors = dense_neighbors.view(-1, self.size)\n\n        # Collect new and old interactions...\n        e_id = torch.cat([self.e_id[n_id, : self.size], dense_e_id], dim=-1)\n        neighbors = torch.cat(\n            [self.neighbors[n_id, : self.size], dense_neighbors], dim=-1\n        )\n\n        # And sort them based on `e_id`.\n        e_id, perm = e_id.topk(self.size, dim=-1)\n        self.e_id[n_id] = e_id\n        self.neighbors[n_id] = torch.gather(neighbors, 1, perm)\n\n    def reset_state(self):\n        self.cur_e_id = 0\n        self.e_id.fill_(-1)\n"
  },
  {
    "path": "modules/nodebank.py",
    "content": "import numpy as np\r\n\r\n\r\nclass NodeBank(object):\r\n    def __init__(\r\n        self,\r\n        src: np.ndarray,\r\n        dst: np.ndarray,\r\n    ):\r\n        r\"\"\"\r\n        maintains a dictionary of all nodes seen so far (specified by the input src and dst)\r\n        Parameters:\r\n            src: source node id of the edges\r\n            dst: destination node id of the edges\r\n            ts: timestamp of the edges\r\n        \"\"\"\r\n        self.nodebank = {}\r\n        self.update_memory(src, dst)\r\n\r\n\r\n    def update_memory(self, \r\n                      update_src: np.ndarray, \r\n                      update_dst: np.ndarray) -> None:\r\n        r\"\"\"\r\n        update self.memory with newly arrived src and dst\r\n        Parameters:\r\n            src: source node id of the edges\r\n            dst: destination node id of the edges\r\n        \"\"\"\r\n        for src, dst in zip(update_src, update_dst):\r\n            if src not in self.nodebank:\r\n                self.nodebank[src] = 1\r\n            if dst not in self.nodebank:\r\n                self.nodebank[dst] = 1\r\n\r\n\r\n    def query_node(self, node: int) -> bool:\r\n        r\"\"\"\r\n        query if node is in the memory\r\n        Parameters:\r\n            node: node id to query\r\n        Returns:\r\n            True if node is in the memory, False otherwise\r\n        \"\"\"\r\n        return node in self.nodebank\r\n"
  },
  {
    "path": "modules/recurrencybaseline_predictor.py",
    "content": "\"\"\"\n from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)\n\n@inproceedings{gastinger2024baselines,\n  title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},\n  author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},\n  booktitle={33nd International Joint Conference on Artificial Intelligence (IJCAI 2024)},\n  year={2024},\n  organization={International Joint Conferences on Artificial Intelligence Organization}\n}\n\n\"\"\"\nimport numpy as np\nfrom collections import Counter\nimport ray\nfrom modules.tkg_utils import create_scores_array\n\n@ray.remote\ndef baseline_predict_remote(num_queries, test_data, all_data, window, basis_dict, num_nodes, \n                num_rels, lmbda_psi, alpha, evaluator,first_test_ts, neg_sampler, split_mode='test'):\n    \"\"\"\n    Apply baselines psi and xi (multiprocessing possible). 
See baseline_predict for more details.\"\"\"\n\n    return baseline_predict(num_queries, test_data, all_data, window, basis_dict, num_nodes, \n                num_rels, lmbda_psi, alpha,  evaluator,first_test_ts, neg_sampler, split_mode)\n\n\ndef baseline_predict(num_queries, test_data, all_data, window, basis_dict, num_nodes, \n                num_rels, lmbda_psi, alpha, evaluator,first_test_ts, neg_sampler, split_mode='test'):\n    \"\"\"\n    Apply baselines psi and xi and compute scores and mrr per test or valid query (multiprocessing possible).\n\n    Parameters:\n        num_queries (int): minimum number of queries for each process\n        test_data (np.array): test quadruples (only used in single-step prediction, depending on window specified);\n            including inverse quadruples for subject prediction\n        all_data (np.array): train valid and test quadruples (test only used in single-step prediction, depending \n            on window specified); including inverse quadruples  for subject prediction\n        window: int, specifying which values from the past can be used for prediction. 0: all edges before the test \n        query timestamp are included. -2: multistep. all edges from train and validation set used. as long as they are \n        < first_test_query_ts. Int n > 0, all edges within n timestamps before the test query timestamp are included.\n        basis_dict (dict): keys: rel_ids; specifies the predefined rules for each relation. \n            in our case: head rel = tail rel, confidence =1 for all rels in train/valid set\n        score_func_psi (method): method to use for computing time decay for psi\n        num_nodes (int): number of nodes in the dataset\n        num_rels (int): number of relations in the dataset\n        lambda_psi (float): parameter for time decay function for baselinepsi. 0: no decay, >1 very steep decay\n        alpha (float): parameter, weight to combine the scores from psi and xi. 
alpha*scores_psi + (1-alpha)*scores_xi\n        evaluator (method): method to compute mrr and hits\n        first_test_ts (int): timestamp of the first test query\n        neg_sampler (NegSampler): negative sampler\n        split_mode (str): 'test' or 'valid'\n    Returns:\n        performance_list and hits_list (one entry per query)\n    \"\"\"\n    num_this_queries = len(test_data)\n    cur_ts = test_data[0][3]\n    first_test_query_ts = first_test_ts #test_data[0][3]\n    edges, all_data_ts = get_window_edges(all_data, cur_ts, window, first_test_query_ts) # get for the current \n                            # timestep all previous quadruples per relation that fullfill time constraints\n\n    rel_obj_dist_cur_ts = update_distributions(edges, num_rels)\n    if len(all_data_ts) >0:\n        sum_delta_t = update_delta_t(np.min(all_data_ts[:,3]), np.max(all_data_ts[:,3]), cur_ts, lmbda_psi)\n\n    predictions_xi=np.zeros(num_nodes) \n    predictions_psi=np.zeros(num_nodes)\n    # if num_queries != len(test_queries_idx):\n        # print('num_queries not equal to len(test_queries_idx)')\n        \n    hits_list = [0] * num_this_queries #len(test_queries_idx)\n    perf_list = [0] * num_this_queries #* len(test_queries_idx)\n    for j in range(num_this_queries):   \n        neg_sample_el =  neg_sampler.query_batch(np.expand_dims(np.array(test_data[j,0]), axis=0), \n                                                np.expand_dims(np.array(test_data[j,2]), axis=0), \n                                                np.expand_dims(np.array(test_data[j,4]), axis=0), \n                                                np.expand_dims(np.array(test_data[j,1]), axis=0), \n                                                split_mode)[0]\n        pos_sample_el =  test_data[j,2]      \n        test_query = test_data[j]\n        assert(pos_sample_el == test_query[2])\n        cands_dict = dict() \n        cands_dict_psi = dict() \n        # 1) update timestep and known triples\n        if 
test_query[3] != cur_ts: # if we have a new timestep\n            cur_ts = test_query[3]\n            edges, all_data_ts = get_window_edges(all_data, cur_ts, window, first_test_query_ts) # get for the current \n            # timestep all previous quadruples per relation that fullfill time constraints\n            # update the object and rel-object distritbutions to take into account what timesteps to use\n            if window > -1: #otherwise: multistep, we do not need to update\n                rel_obj_dist_cur_ts = update_distributions( edges, num_rels)\n\n            if len(all_data_ts) >0:\n                if window > -1: #otherwise: multistep, we do not need to update\n                    sum_delta_t = update_delta_t(np.min(all_data_ts[:,3]), np.max(all_data_ts[:,3]), cur_ts, lmbda_psi)\n                        \n        #### BASELINE  PSI\n        # 2) apply rules for relation of interest, if we have any\n        if str(test_query[1]) in basis_dict: # do we have rules for the given relation?                \n            walk_edges = match_body_relations(basis_dict[str(test_query[1])][0], edges, test_query[0]) \n                                # Find quadruples that match the rule (starting from the test query subject)\n                                # Find edges whose subject match the query subject and the relation matches\n                                # the relation in the rule body. 
np array with [[sub, obj, ts]]\n            if 0 not in [len(x) for x in walk_edges]: # if we found at least one potential rule                        \n                cands_dict_psi = get_candidates_psi(walk_edges[0][:,1:3], cur_ts, cands_dict, lmbda_psi, sum_delta_t)\n                if len(cands_dict_psi)>0:                \n                    # predictions_psi = create_scores_tensor(cands_dict_psi, num_nodes)\n                    predictions_psi = create_scores_array(cands_dict_psi, num_nodes)\n\n        #### BASELINE XI      \n        predictions_xi = create_scores_array(rel_obj_dist_cur_ts[test_query[1]], num_nodes)\n        # predictions_xi = create_scores_tensor(rel_obj_dist_cur_ts[test_query[1]], num_nodes)\n\n        #### Combine Both\n        predictions_all = 1000*alpha*predictions_psi + 1000*(1-alpha)*predictions_xi           \n        # predictions_of_interest_pos = predimctions_all[pos_sample_el].unsqueeze(0)\n        predictions_of_interest_pos = np.array(predictions_all[pos_sample_el])\n        predictions_of_interest_neg = predictions_all[neg_sample_el]\n        input_dict = {\n            \"y_pred_pos\": predictions_of_interest_pos,\n            \"y_pred_neg\": predictions_of_interest_neg,\n            \"eval_metric\": ['mrr'], \n        }\n\n        predictions = evaluator.eval(input_dict)\n        perf_list[j] = float(predictions['mrr'])\n        hits_list[j] = float(predictions['hits@10'])\n        \n\n    return perf_list, hits_list\n\n\ndef match_body_relations(rule, edges, test_query_sub):\n    \"\"\"\n    for rules of length 1\n    Find quadruples that match the rule (starting from the test query subject)\n    Find edges whose subject match the query subject and the relation matches\n    the relation in the rule body. 
\n    Memory-efficient implementation.\n\n    modified from Tlogic rule_application.py https://github.com/liu-yushan/TLogic/blob/main/mycode/rule_application.py\n    shortened because we only have rules of length one \n\n    Parameters:\n        rule (dict): rule from rules_dict\n        edges (dict): edges for rule application\n        test_query_sub (int): test query subject\n    Returns:\n        walk_edges (list of np.ndarrays): edges that could constitute rule walks\n    \"\"\"\n\n    rels = rule[\"body_rels\"]\n    # Match query subject and first body relation\n    try:\n        rel_edges = edges[rels[0]]\n        mask = rel_edges[:, 0] == test_query_sub\n        new_edges = rel_edges[mask]\n        walk_edges = [np.hstack((new_edges[:, 0:1], new_edges[:, 2:4]))]  # [sub, obj, ts]\n\n    except KeyError:\n        walk_edges = [[]]\n    return walk_edges #subject object timestamp\n\ndef score_delta(cands_ts, test_query_ts, lmbda):\n    \"\"\" deta function to score a given candidate based on its distance to current timestep and based on param lambda\n    Parameters:\n        cands_ts (int): timestep of candidate(s)\n        test_query_ts (int): timestep of current test quadruple\n        lmbda (float): param to specify how steep decay is\n    Returns:\n        score (float): score for a given candicate\n    \"\"\"\n    score = pow(2, lmbda * (cands_ts - test_query_ts))\n    return score\n\ndef get_window_edges(all_data, test_query_ts, window=-2, first_test_query_ts=0): \n    \"\"\"\n    modified from Tlogic rule_application.py https://github.com/liu-yushan/TLogic/blob/main/mycode/rule_application.py\n    introduce window -2 \n\n    Get the edges in the data (for rule application) that occur in the specified time window.\n    If window is 0, all edges before the test query timestamp are included.\n    If window is -2, all edges from train and validation set are used. 
as long as they are < first_test_query_ts\n    If window is an integer n > 0, all edges within n timestamps before the test query\n    timestamp are included.\n\n    Parameters:\n        all_data (np.ndarray): complete dataset (train/valid/test)\n        test_query_ts (np.ndarray): test query timestamp\n        window (int): time window used for rule application\n        first_test_query_ts (int): smallest timestamp from test set (eval_paper_authors)\n\n    Returns:\n        window_edges (dict): edges in the window for rule application\n    \"\"\"\n\n    if window > 0:\n        mask = (all_data[:, 3] < test_query_ts) * (\n            all_data[:, 3] >= test_query_ts - window \n        )\n        window_edges = quads_per_rel(all_data[mask]) # quadruples per relation that fullfill the time constraints \n    elif window == 0:\n        mask = all_data[:, 3] < test_query_ts #!!! \n        window_edges = quads_per_rel(all_data[mask]) \n    elif window == -2: #modified eval_paper_authors: added this option\n        mask = all_data[:, 3] < first_test_query_ts # all edges at timestep smaller then the test queries. 
meaning all from train and valid set\n        window_edges = quads_per_rel(all_data[mask])  \n    elif window == -200: #modified eval_paper_authors: added this option\n        abswindow = 200\n        mask = (all_data[:, 3] < first_test_query_ts) * (\n            all_data[:, 3] >= first_test_query_ts - abswindow  # all edges at timestep smaller than the test queries - 200\n        )\n        window_edges = quads_per_rel(all_data[mask])\n    all_data_ts = all_data[mask]\n    return window_edges, all_data_ts\n\n\ndef quads_per_rel(quads):\n    \"\"\"\n    modified from Tlogic rule_application.py https://github.com/liu-yushan/TLogic/blob/main/mycode/rule_application.py\n    Store all edges for each relation.\n\n    Parameters:\n        quads (np.ndarray): indices of quadruples\n\n    Returns:\n        edges (dict): edges for each relation\n    \"\"\"\n\n    edges = dict()\n    relations = list(set(quads[:, 1]))\n    for rel in relations:\n        edges[rel] = quads[quads[:, 1] == rel]\n    return edges\n\ndef get_candidates_psi(rule_walks, test_query_ts, cands_dict,lmbda, sum_delta_t):\n    \"\"\"\n    Get answer candidates from the walks that follow the rule.\n    Add the confidence of the rule that leads to these candidates.\n    originally from TLogic https://github.com/liu-yushan/TLogic/blob/main/mycode/apply.py but heavily modified\n\n    Parameters:\n        rule_walks (np.array): rule walks np array with [[sub, obj]]\n        test_query_ts (int): test query timestamp\n        cands_dict (dict): candidates along with the confidences of the rules that generated these candidates\n        score_func (function): function for calculating the candidate score\n        lmbda (float): parameter to describe decay of the scoring function\n        sum_delta_t: to be used in denominator of scoring fct\n    Returns:\n        cands_dict (dict): keys: candidates, values: score for the candidates  \"\"\"\n\n    cands = set(rule_walks[:,0]) \n\n    for cand in cands:\n        
cands_walks = rule_walks[rule_walks[:,0] == cand] \n        score = score_psi(cands_walks, test_query_ts, lmbda, sum_delta_t).astype(np.float64)\n        cands_dict[cand] = score\n\n    return cands_dict\n\ndef update_delta_t(min_ts, max_ts, cur_ts, lmbda):\n    \"\"\" compute denominator for scoring function psi_delta\n    Patameters:\n        min_ts (int): minimum available timestep\n        max_ts (int): maximum available timestep\n        cur_ts (int): current timestep\n        lmbda (float): time decay parameter\n    Returns:\n        delta_all (float): sum(delta_t for all available timesteps between min_ts and max_ts)\n    \"\"\"\n    timesteps = np.arange(min_ts, max_ts)\n    now = np.ones(len(timesteps))*cur_ts\n    delta_all = score_delta(timesteps, now, lmbda)\n    delta_all = np.sum(delta_all)\n    return delta_all\n\ndef score_psi(cands_walks, test_query_ts, lmbda, sum_delta_t):\n    \"\"\"\n    Calculate candidate score depending on the time difference.\n\n    Parameters:\n        cands_walks (np.array): rule walks np array with [[sub, obj]]\n        test_query_ts (int): test query timestamp\n        lmbda (float): rate of exponential distribution\n\n    Returns:\n        score (float): candidate score\n    \"\"\"\n\n    all_cands_ts = cands_walks[:,1] #cands_walks[\"timestamp_0\"].reset_index()[\"timestamp_0\"]\n    ts_series = np.ones(len(all_cands_ts))*test_query_ts \n    scores =  score_delta(all_cands_ts, ts_series, lmbda) # Score depending on time difference\n    if sum_delta_t == 0:\n        print(scores, \"sum_delta_t is zero\")\n        print(all_cands_ts)\n        score = np.sum(scores)\n        # print(score)\n    else:            \n        score = np.sum(scores)/sum_delta_t\n\n    return score   \n\ndef update_distributions(ts_edges,num_rels):\n    \"\"\" update the distributions with more recent infos, if there is a more recent timestep available, depending on window parameter\n    take into account scaling factor\n    \"\"\"\n    
rel_obj_dist_cur_ts= calculate_obj_distribution(ts_edges, num_rels) #, lmbda, cur_ts)\n    return  rel_obj_dist_cur_ts\n\ndef calculate_obj_distribution(edges, num_rels):\n    \"\"\"\n    Calculate the overall object distribution and the object distribution for each relation in the data.\n\n    Parameters:\n        edges (dict): edges from the data on which the rules should be learned\n\n    Returns:\n        rel_obj_dist (dict): object distribution for each relation\n    \"\"\"\n    rel_obj_dist_scaled = dict()\n    for rel in range(num_rels):\n        rel_obj_dist_scaled[rel] = {}\n    \n    for rel in edges:\n        objects = edges[rel][:, 2]\n        dist = Counter(objects)\n        for obj in dist:\n            dist[obj] /= len(objects)\n        rel_obj_dist_scaled[rel] = {k: v for k, v in dist.items()}\n\n    return rel_obj_dist_scaled\n\ndef update_delta_t(min_ts, max_ts, cur_ts, lmbda):\n    \"\"\" compute denominator for scoring function psi_delta\n    Patameters:\n        min_ts (int): minimum available timestep\n        max_ts (int): maximum available timestep\n        cur_ts (int): current timestep\n        lmbda (float): time decay parameter\n    Returns:\n        delta_all (float): sum(delta_t for all available timesteps between min_ts and max_ts)\n    \"\"\"\n    timesteps = np.arange(min_ts, max_ts)\n    now = np.ones(len(timesteps))*cur_ts\n    delta_all = score_delta(timesteps, now, lmbda)\n    delta_all = np.sum(delta_all)\n    return delta_all"
  },
  {
    "path": "modules/rgcn_layers.py",
    "content": "\"\"\"\nhttps://github.com/Lee-zix/CEN/blob/main/rgcn/layers.py\n\"\"\"\n\nimport dgl.function as fn\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass RGCNLayer(nn.Module):\n    def __init__(self, in_feat, out_feat, bias=None, activation=None,\n                 self_loop=False, skip_connect=False, dropout=0.0, layer_norm=False):\n        \"\"\" init of the RGCN layer class\n        from https://github.com/Lee-zix/CEN/blob/main/rgcn/layers.py\n        \"\"\"\n        super(RGCNLayer, self).__init__()\n        self.bias = bias\n        self.activation = activation\n        self.self_loop = self_loop\n        self.skip_connect = skip_connect\n        self.layer_norm = layer_norm\n\n        if self.bias:\n            self.bias = nn.Parameter(torch.Tensor(out_feat))\n            nn.init.xavier_uniform_(self.bias,\n                                    gain=nn.init.calculate_gain('relu'))\n\n        # weight for self loop\n        if self.self_loop:\n            self.loop_weight = nn.Parameter(torch.Tensor(in_feat, out_feat))\n            nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain('relu'))\n            # self.loop_weight = nn.Parameter(torch.eye(out_feat), requires_grad=False)\n\n        if self.skip_connect:\n            self.skip_connect_weight = nn.Parameter(torch.Tensor(out_feat, out_feat))   # 和self-loop不一样，是跨层的计算\n            nn.init.xavier_uniform_(self.skip_connect_weight,\n                                    gain=nn.init.calculate_gain('relu'))\n\n            self.skip_connect_bias = nn.Parameter(torch.Tensor(out_feat))\n            nn.init.zeros_(self.skip_connect_bias)  # 初始化设置为0\n\n        if dropout:\n            self.dropout = nn.Dropout(dropout)\n        else:\n            self.dropout = None\n\n        if self.layer_norm:\n            self.normalization_layer = nn.LayerNorm(out_feat, elementwise_affine=False)\n\n    # define how propagation is done in subclass\n    def propagate(self, 
g):\n        raise NotImplementedError\n\n    def forward(self, g, prev_h=[]):\n        if self.self_loop:\n            #print(self.loop_weight)\n            loop_message = torch.mm(g.ndata['h'], self.loop_weight)\n            if self.dropout is not None:\n                loop_message = self.dropout(loop_message)\n        # self.skip_connect_weight.register_hook(lambda g: print(\"grad of skip connect weight: {}\".format(g)))\n        if len(prev_h) != 0 and self.skip_connect:\n            skip_weight = F.sigmoid(torch.mm(prev_h, self.skip_connect_weight) + self.skip_connect_bias)     # 使用sigmoid，让值在0~1\n            # print(\"skip_ weight\")\n            # print(skip_weight)\n            # print(\"skip connect weight\")\n            # print(self.skip_connect_weight)\n            # print(torch.mm(prev_h, self.skip_connect_weight))\n\n        self.propagate(g)  # 这里是在计算从周围节点传来的信息\n\n        # apply bias and activation\n        node_repr = g.ndata['h']\n        if self.bias:\n            node_repr = node_repr + self.bias\n        # print(len(prev_h))\n        if len(prev_h) != 0 and self.skip_connect:   # 两次计算loop_message的方式不一样，前者激活后再加权\n            previous_node_repr = (1 - skip_weight) * prev_h\n            if self.activation:\n                node_repr = self.activation(node_repr)\n            if self.self_loop:\n                if self.activation:\n                    loop_message = skip_weight * self.activation(loop_message)\n                else:\n                    loop_message = skip_weight * loop_message\n                node_repr = node_repr + loop_message\n            node_repr = node_repr + previous_node_repr\n        else:\n            if self.self_loop:\n                node_repr = node_repr + loop_message\n            if self.layer_norm:\n                node_repr = self.normalization_layer(node_repr)\n            if self.activation:\n                node_repr = self.activation(node_repr)\n            # print(\"node_repr\")\n            # 
print(node_repr)\n        g.ndata['h'] = node_repr\n        return node_repr\n\n\nclass RGCNBasisLayer(RGCNLayer):\n    def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None,\n                 activation=None, is_input_layer=False):\n        super(RGCNBasisLayer, self).__init__(in_feat, out_feat, bias, activation)\n        self.in_feat = in_feat\n        self.out_feat = out_feat\n        self.num_rels = num_rels\n        self.num_bases = num_bases\n        self.is_input_layer = is_input_layer\n        if self.num_bases <= 0 or self.num_bases > self.num_rels:\n            self.num_bases = self.num_rels\n\n        # add basis weights\n        self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.in_feat,\n                                                self.out_feat))\n        if self.num_bases < self.num_rels:\n            # linear combination coefficients\n            self.w_comp = nn.Parameter(torch.Tensor(self.num_rels,\n                                                    self.num_bases))\n        nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))\n        if self.num_bases < self.num_rels:\n            nn.init.xavier_uniform_(self.w_comp,\n                                    gain=nn.init.calculate_gain('relu'))\n\n    def propagate(self, g):\n        if self.num_bases < self.num_rels:\n            # generate all weights from bases\n            weight = self.weight.view(self.num_bases,\n                                      self.in_feat * self.out_feat)\n            weight = torch.matmul(self.w_comp, weight).view(\n                self.num_rels, self.in_feat, self.out_feat)\n        else:\n            weight = self.weight\n\n        if self.is_input_layer:\n            def msg_func(edges):\n                # for input layer, matrix multiply can be converted to be\n                # an embedding lookup using source node id\n                embed = weight.view(-1, self.out_feat)\n                index = 
edges.data['type'] * self.in_feat + edges.src['id']\n                return {'msg': embed.index_select(0, index)}\n        else:\n            def msg_func(edges):\n                w = weight.index_select(0, edges.data['type'])\n                msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze()\n                return {'msg': msg}\n\n        def apply_func(nodes):\n            return {'h': nodes.data['h'] * nodes.data['norm']}\n\n        g.update_all(msg_func, fn.sum(msg='msg', out='h'), apply_func)\n\n\nclass RGCNBlockLayer(RGCNLayer):\n    def __init__(self, in_feat, out_feat, num_rels, num_bases, bias=None,\n                 activation=None, self_loop=False, dropout=0.0, skip_connect=False, layer_norm=False):\n        super(RGCNBlockLayer, self).__init__(in_feat, out_feat, bias,\n                                             activation, self_loop=self_loop, skip_connect=skip_connect,\n                                             dropout=dropout)\n        self.num_rels = num_rels\n        self.num_bases = num_bases\n\n        assert self.num_bases > 0\n\n        self.out_feat = out_feat\n        self.submat_in = in_feat // self.num_bases\n        self.submat_out = out_feat // self.num_bases\n\n        # assuming in_feat and out_feat are both divisible by num_bases\n        self.weight = nn.Parameter(torch.Tensor(\n            self.num_rels, self.num_bases * self.submat_in * self.submat_out))\n        nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))\n\n    def msg_func(self, edges):\n        weight = self.weight.index_select(0, edges.data['type']).view(\n                    -1, self.submat_in, self.submat_out)    # [edge_num, submat_in, submat_out]\n        node = edges.src['h'].view(-1, 1, self.submat_in)   # [edge_num * num_bases, 1, submat_in]->\n        msg = torch.bmm(node, weight).view(-1, self.out_feat)   # [edge_num, out_feat]\n        return {'msg': msg}\n\n    def propagate(self, g):\n        g.update_all(self.msg_func, 
fn.sum(msg='msg', out='h'), self.apply_func)\n        # g.updata_all ({'msg': msg} , fn.sum(msg='msg', out='h'), {'h': nodes.data['h'] * nodes.data[''norm]})\n\n    def apply_func(self, nodes):\n        return {'h': nodes.data['h'] * nodes.data['norm']}\n\n\nclass UnionRGCNLayer(nn.Module):\n    def __init__(self, in_feat, out_feat, num_rels, num_bases=-1,  bias=None,\n                 activation=None, self_loop=False, dropout=0.0, skip_connect=False, rel_emb=None):\n        super(UnionRGCNLayer, self).__init__()\n\n        self.in_feat = in_feat\n        self.out_feat = out_feat\n        self.bias = bias\n        self.activation = activation\n        self.self_loop = self_loop\n        self.num_rels = num_rels\n        self.skip_connect = skip_connect\n        self.emb_rel = rel_emb\n        self.ob = None\n        self.sub = None\n\n        # WL\n        self.weight_neighbor = nn.Parameter(torch.Tensor(self.in_feat, self.out_feat))\n        nn.init.xavier_uniform_(self.weight_neighbor, gain=nn.init.calculate_gain('relu'))\n\n        if self.self_loop:\n            self.loop_weight = nn.Parameter(torch.Tensor(in_feat, out_feat))\n            nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain('relu'))\n            self.evolve_loop_weight = nn.Parameter(torch.Tensor(in_feat, out_feat))\n            nn.init.xavier_uniform_(self.evolve_loop_weight, gain=nn.init.calculate_gain('relu'))\n\n        if self.skip_connect:\n            self.skip_connect_weight = nn.Parameter(torch.Tensor(out_feat, out_feat))   # 和self-loop不一样，是跨层的计算\n            nn.init.xavier_uniform_(self.skip_connect_weight,gain=nn.init.calculate_gain('relu'))\n            self.skip_connect_bias = nn.Parameter(torch.Tensor(out_feat))\n            nn.init.zeros_(self.skip_connect_bias)  # 初始化设置为0\n\n        if dropout:\n            self.dropout = nn.Dropout(dropout)\n        else:\n            self.dropout = None\n\n    def propagate(self, g):\n        g.update_all(lambda x: 
self.msg_func(x), fn.sum(msg='msg', out='h'), self.apply_func)\n\n    def forward(self, g, prev_h):\n        # self.sub = sub\n        # self.ob = ob\n        if self.self_loop:\n            #loop_message = torch.mm(g.ndata['h'], self.loop_weight)\n            # masked_index = torch.masked_select(torch.arange(0, g.number_of_nodes(), dtype=torch.long), (g.in_degrees(range(g.number_of_nodes())) > 0))\n            masked_index = torch.masked_select(\n                torch.arange(0, g.number_of_nodes(), dtype=torch.long).cuda(),\n                (g.in_degrees(range(g.number_of_nodes())) > 0))\n            loop_message = torch.mm(g.ndata['h'], self.evolve_loop_weight)\n            loop_message[masked_index, :] = torch.mm(g.ndata['h'], self.loop_weight)[masked_index, :]\n        if len(prev_h) != 0 and self.skip_connect:\n            skip_weight = F.sigmoid(torch.mm(prev_h, self.skip_connect_weight) + self.skip_connect_bias)     # 使用sigmoid，让值在0~1\n\n        # calculate the neighbor message with weight_neighbor\n        self.propagate(g)\n        node_repr = g.ndata['h']\n\n        # print(len(prev_h))\n        if len(prev_h) != 0 and self.skip_connect:  # 两次计算loop_message的方式不一样，前者激活后再加权\n            if self.self_loop:\n                node_repr = node_repr + loop_message\n            node_repr = skip_weight * node_repr + (1 - skip_weight) * prev_h\n        else:\n            if self.self_loop:\n                node_repr = node_repr + loop_message\n\n        if self.activation:\n            node_repr = self.activation(node_repr)\n        if self.dropout is not None:\n            node_repr = self.dropout(node_repr)\n        g.ndata['h'] = node_repr\n        return node_repr\n\n    def msg_func(self, edges):\n        # if reverse:\n        #     relation = self.rel_emb.index_select(0, edges.data['type_o']).view(-1, self.out_feat)\n        # else:\n        #     relation = self.rel_emb.index_select(0, edges.data['type_s']).view(-1, self.out_feat)\n        relation = 
self.emb_rel.index_select(0, edges.data['type']).view(-1, self.out_feat)\n        edge_type = edges.data['type']\n        edge_num = edge_type.shape[0]\n        node = edges.src['h'].view(-1, self.out_feat)\n        # node = torch.cat([torch.matmul(node[:edge_num // 2, :], self.sub),\n        #                  torch.matmul(node[edge_num // 2:, :], self.ob)])\n        # node = torch.matmul(node, self.sub)\n\n        # after add inverse edges, we only use message pass when h as tail entity\n        # 这里计算的是每个节点发出的消息，节点发出消息时其作为头实体\n        # msg = torch.cat((node, relation), dim=1)\n        msg = node + relation\n        # calculate the neighbor message with weight_neighbor\n        msg = torch.mm(msg, self.weight_neighbor)\n        return {'msg': msg}\n\n    def apply_func(self, nodes):\n        return {'h': nodes.data['h'] * nodes.data['norm']}"
  },
  {
    "path": "modules/rgcn_model.py",
    "content": "\"\"\"\nhttps://github.com/nec-research/CEN/blob/main/src/model.py\n\"\"\"\n\nimport torch.nn as nn\n\n\nclass BaseRGCN(nn.Module):\n    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases=-1, num_basis=-1,\n                 num_hidden_layers=1, dropout=0, self_loop=False, skip_connect=False, encoder_name=\"\", opn=\"sub\", \n                 rel_emb=None, use_cuda=False, analysis=False):\n        super(BaseRGCN, self).__init__()\n        self.num_nodes = num_nodes\n        self.h_dim = h_dim\n        self.out_dim = out_dim\n        self.num_rels = num_rels\n        self.num_bases = num_bases\n        self.num_basis = num_basis\n        self.num_hidden_layers = num_hidden_layers\n        self.dropout = dropout\n        self.skip_connect = skip_connect\n        self.self_loop = self_loop\n        self.encoder_name = encoder_name\n        self.use_cuda = use_cuda\n        self.run_analysis = analysis\n        self.skip_connect = skip_connect\n        print(\"use layer :{}\".format(encoder_name))\n        self.rel_emb = rel_emb\n        self.opn = opn\n        # create rgcn layers\n        self.build_model()\n        # create initial features\n        self.features = self.create_features()\n        \n    def build_model(self):\n        self.layers = nn.ModuleList()\n        # i2h\n        i2h = self.build_input_layer()\n        if i2h is not None:\n            self.layers.append(i2h)\n        # h2h\n        for idx in range(self.num_hidden_layers):\n\n            h2h = self.build_hidden_layer(idx)\n            self.layers.append(h2h)\n        # h2o\n        h2o = self.build_output_layer()\n        if h2o is not None:\n            self.layers.append(h2o)\n\n    # initialize feature for each node\n    def create_features(self):\n        return None\n\n    def build_input_layer(self):\n        return None\n\n    def build_hidden_layer(self, idx):\n        raise NotImplementedError\n\n    def build_output_layer(self):\n        return None\n\n 
   def forward(self, g):\n        if self.features is not None:\n            g.ndata['id'] = self.features\n        print(\"h before GCN message passing\")\n        print(g.ndata['h'])\n        print(\"h behind GCN message passing\")\n        for layer in self.layers:\n            layer(g)\n        print(g.ndata['h'])\n        return g.ndata.pop('h')"
  },
  {
    "path": "modules/rrgcn.py",
    "content": "\"\"\"\nhttps://github.com/Lee-zix/CEN/blob/main/src/rrgcn.py\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\n\nfrom modules.rgcn_layers import UnionRGCNLayer, RGCNBlockLayer\nfrom modules.rgcn_model import BaseRGCN\nfrom modules.decoder import ConvTransE\nimport numpy as np\nclass RGCNCell(BaseRGCN):\n    def build_hidden_layer(self, idx):\n        act = F.rrelu\n        if idx:\n            self.num_basis = 0\n        print(\"activate function: {}\".format(act))\n        if self.skip_connect:\n            sc = False if idx == 0 else True\n        else:\n            sc = False\n        if self.encoder_name == \"uvrgcn\":\n            return UnionRGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases,\n                             activation=act, dropout=self.dropout, self_loop=self.self_loop, skip_connect=sc, \n                             rel_emb=self.rel_emb)\n        else:\n            raise NotImplementedError\n\n    def forward(self, g, init_ent_emb):\n        if self.encoder_name == \"uvrgcn\":\n            node_id = g.ndata['id'].squeeze()\n            g.ndata['h'] = init_ent_emb[node_id]\n            for i, layer in enumerate(self.layers):\n                layer(g, [])\n            return g.ndata.pop('h')\n        else:\n            if self.features is not None:\n                print(\"----------------Feature is not None, Attention ------------\")\n                g.ndata['id'] = self.features\n            node_id = g.ndata['id'].squeeze()\n            g.ndata['h'] = init_ent_emb[node_id]\n            if self.skip_connect:\n                prev_h = []\n                for layer in self.layers:\n                    prev_h = layer(g, prev_h)\n            else:\n                for layer in self.layers:\n                    layer(g, [])\n            return g.ndata.pop('h')\n\nclass RecurrentRGCNCEN(nn.Module):\n    def __init__(self, decoder_name, encoder_name, num_ents, num_rels, h_dim, 
opn, sequence_len, num_bases=-1, num_basis=-1,\n                 num_hidden_layers=1, dropout=0, self_loop=False, skip_connect=False, layer_norm=False, input_dropout=0,\n                 hidden_dropout=0, feat_dropout=0, entity_prediction=False, relation_prediction=False, use_cuda=False,\n                 gpu = 0):\n        super(RecurrentRGCNCEN, self).__init__()\n\n        self.decoder_name = decoder_name\n        self.encoder_name = encoder_name\n        self.num_rels = num_rels\n        self.num_ents = num_ents\n        self.opn = opn\n        self.sequence_len = sequence_len\n        self.h_dim = h_dim\n        self.layer_norm = layer_norm\n        self.h = None\n        self.relation_prediction = relation_prediction\n        self.entity_prediction = entity_prediction\n        self.gpu = gpu\n\n        self.emb_rel = torch.nn.Parameter(torch.Tensor(self.num_rels * 2, self.h_dim), requires_grad=True).float() #TODO: correct number?\n        torch.nn.init.xavier_normal_(self.emb_rel)\n\n        self.dynamic_emb = torch.nn.Parameter(torch.Tensor(num_ents, h_dim), requires_grad=True).float()\n        torch.nn.init.normal_(self.dynamic_emb)\n\n        self.loss_e = torch.nn.CrossEntropyLoss()\n\n\n        self.rgcn = RGCNCell(num_ents,\n                             h_dim,\n                             h_dim,\n                             num_rels * 2,\n                             num_bases,\n                             num_basis,\n                             num_hidden_layers,\n                             dropout,\n                             self_loop,\n                             skip_connect,\n                             encoder_name,\n                             self.opn,\n                             self.emb_rel,\n                             use_cuda)\n\n        self.time_gate_weight = nn.Parameter(torch.Tensor(h_dim, h_dim))    \n        nn.init.xavier_uniform_(self.time_gate_weight, gain=nn.init.calculate_gain('relu'))\n        self.time_gate_bias = 
nn.Parameter(torch.Tensor(h_dim))\n        nn.init.zeros_(self.time_gate_bias)    \n        \n      \n        if decoder_name == \"convtranse\":\n            self.decoder_ob = ConvTransE(num_ents, h_dim, input_dropout, hidden_dropout, feat_dropout, \n                                         sequence_len=self.sequence_len, model_name='CEN')\n        else:\n            raise NotImplementedError \n\n\n    def forward(self, g_list, use_cuda):\n        evolve_embs = []\n        self.h = F.normalize(self.dynamic_emb) if self.layer_norm else self.dynamic_emb\n        for i, g in enumerate(g_list):\n            g = g.to(self.gpu)\n            current_h = self.rgcn.forward(g, self.h)\n            current_h = F.normalize(current_h) if self.layer_norm else current_h\n            time_weight = F.sigmoid(torch.mm(self.h, self.time_gate_weight) + self.time_gate_bias)\n            self.h = time_weight * current_h + (1-time_weight) * self.h\n            self.h = F.normalize(self.h)\n            evolve_embs.append(self.h)\n        return evolve_embs, self.emb_rel\n\n    def predict(self, test_graph, test_triplets, use_cuda, neg_samples_batch=None, pos_samples_batch=None, \n                evaluator=None, metric=None):\n        with torch.no_grad():\n            scores = torch.zeros(len(test_triplets), self.num_ents).cuda()\n            evolve_embeddings = []\n            for idx in range(len(test_graph)):\n                evolve_embs, r_emb = self.forward(test_graph[idx:], use_cuda)\n                evolve_embeddings.append(evolve_embs[-1])\n            evolve_embeddings.reverse()\n\n            if neg_samples_batch != None: # added by tgb team\n                perf_list = []\n                hits_list = []\n                for query_id, query in enumerate(neg_samples_batch): # for each sample separately\n                    pos = pos_samples_batch[query_id]\n                    neg = torch.tensor(query).to(pos.device)\n                    all =torch.cat((pos.unsqueeze(0), neg), 
dim=0)\n                    score_list = self.decoder_ob.forward(evolve_embeddings, r_emb, test_triplets[query_id].unsqueeze(0),\n                            samples_of_interest_emb= [evolve_embeddings[i][all] for i in range(len(evolve_embeddings))])\n                    score_list = [_.unsqueeze(2) for _ in score_list]\n                    scores_b = torch.cat(score_list, dim=2)\n                    scores_b = torch.softmax(scores_b, dim=1)\n                    scores_b = torch.sum(scores_b, dim=-1)\n                    # compute MRR\n                    input_dict = {\n                        \"y_pred_pos\": np.array([scores_b[0,0].cpu()]),\n                        \"y_pred_neg\": np.array(scores_b[0,1:].cpu()),\n                        \"eval_metric\": [metric],\n                    }\n                    prediction_perf = evaluator.eval(input_dict)\n                    perf_list.append(prediction_perf[metric])\n                    hits_list.append(prediction_perf['hits@10'])\n\n            else:\n                score_list = self.decoder_ob.forward(evolve_embeddings, r_emb, test_triplets, mode=\"test\")\n\n                score_list = [_.unsqueeze(2) for _ in score_list]\n                scores = torch.cat(score_list, dim=2)\n                scores = torch.softmax(scores, dim=1)\n                scores = torch.sum(scores, dim=-1)\n\n            return scores, perf_list, hits_list\n\n    def get_ft_loss(self, glist, triple_list,  use_cuda):\n        #\"\"\"\n        #:param glist:\n        #:param triplets:\n        #:param use_cuda:\n        #:return:\n        #\"\"\"\n        glist = [g.to(self.gpu) for g in glist]\n        loss_ent = torch.zeros(1).cuda().to(self.gpu) if use_cuda else torch.zeros(1)\n\n        # for step, triples in enumerate(triple_list):\n        evolve_embeddings = []\n        for idx in range(len(glist)):\n            evolve_embs, r_emb = self.forward(glist[idx:], use_cuda)\n            evolve_embeddings.append(evolve_embs[-1])\n        
evolve_embeddings.reverse()\n        scores_ob = self.decoder_ob.forward(evolve_embeddings, r_emb, triple_list[-1])#.view(-1, self.num_ents)\n        for idx in range(len(glist)):\n            loss_ent += self.loss_e(scores_ob[idx], triple_list[-1][:, 2])\n        return loss_ent\n\n    def get_loss(self, glist, triples, prev_model, use_cuda):\n        \"\"\"\n        :param glist:\n        :param triplets:\n        :param use_cuda:\n        :return:\n        \"\"\"\n        loss_ent = torch.zeros(1).cuda().to(self.gpu) if use_cuda else torch.zeros(1)\n\n        evolve_embeddings = []\n        for idx in range(len(glist)):\n            evolve_embs, r_emb = self.forward(glist[idx:], use_cuda)\n            evolve_embeddings.append(evolve_embs[-1])\n        evolve_embeddings.reverse()\n        if self.entity_prediction:\n            scores_ob = self.decoder_ob.forward(evolve_embeddings, r_emb, triples)#.view(-1, self.num_ents)\n            for idx in range(len(glist)):\n                loss_ent += self.loss_e(scores_ob[idx], triples[:, 2])\n        return loss_ent\n    \n\n\nclass RecurrentRGCNREGCN(nn.Module):\n    def __init__(self, decoder_name, encoder_name, num_ents, num_rels, num_static_rels, num_words, h_dim, opn, sequence_len, num_bases=-1, num_basis=-1,\n                 num_hidden_layers=1, dropout=0, self_loop=False, skip_connect=False, layer_norm=False, input_dropout=0,\n                 hidden_dropout=0, feat_dropout=0, aggregation='cat', weight=1, discount=0, angle=0, use_static=False,\n                 entity_prediction=False, relation_prediction=False, use_cuda=False,\n                 gpu = 0, analysis=False):\n        super(RecurrentRGCNREGCN, self).__init__()\n\n        self.decoder_name = decoder_name\n        self.encoder_name = encoder_name\n        self.num_rels = num_rels\n        self.num_ents = num_ents\n        self.opn = opn\n        self.num_words = num_words\n        self.num_static_rels = num_static_rels\n        self.sequence_len = 
sequence_len\n        self.h_dim = h_dim\n        self.layer_norm = layer_norm\n        self.h = None\n        self.run_analysis = analysis\n        self.aggregation = aggregation\n        self.relation_evolve = False\n        self.weight = weight\n        self.discount = discount\n        self.use_static = use_static\n        self.angle = angle\n        self.relation_prediction = relation_prediction\n        self.entity_prediction = entity_prediction\n        self.emb_rel = None\n        self.gpu = gpu\n\n        self.w1 = torch.nn.Parameter(torch.Tensor(self.h_dim, self.h_dim), requires_grad=True).float()\n        torch.nn.init.xavier_normal_(self.w1)\n\n        self.w2 = torch.nn.Parameter(torch.Tensor(self.h_dim, self.h_dim), requires_grad=True).float()\n        torch.nn.init.xavier_normal_(self.w2)\n\n        self.emb_rel = torch.nn.Parameter(torch.Tensor(self.num_rels * 2, self.h_dim), requires_grad=True).float()\n        torch.nn.init.xavier_normal_(self.emb_rel)\n\n        self.dynamic_emb = torch.nn.Parameter(torch.Tensor(num_ents, h_dim), requires_grad=True).float()\n        torch.nn.init.normal_(self.dynamic_emb)\n\n\n        if self.use_static:\n            self.words_emb = torch.nn.Parameter(torch.Tensor(self.num_words, h_dim), requires_grad=True).float()\n            torch.nn.init.xavier_normal_(self.words_emb)\n            self.statci_rgcn_layer = RGCNBlockLayer(self.h_dim, self.h_dim, self.num_static_rels*2, num_bases,\n                                                    activation=F.rrelu, dropout=dropout, self_loop=False, skip_connect=False)\n            self.static_loss = torch.nn.MSELoss()\n\n        self.loss_r = torch.nn.CrossEntropyLoss()\n        self.loss_e = torch.nn.CrossEntropyLoss()\n\n        self.rgcn = RGCNCell(num_ents,\n                             h_dim,\n                             h_dim,\n                             num_rels * 2,\n                             num_bases,\n                             num_basis,\n                
             num_hidden_layers,\n                             dropout,\n                             self_loop,\n                             skip_connect,\n                             encoder_name,\n                             self.opn,\n                             self.emb_rel,\n                             use_cuda,\n                             analysis)\n\n        self.time_gate_weight = nn.Parameter(torch.Tensor(h_dim, h_dim))    \n        nn.init.xavier_uniform_(self.time_gate_weight, gain=nn.init.calculate_gain('relu'))\n        self.time_gate_bias = nn.Parameter(torch.Tensor(h_dim))\n        nn.init.zeros_(self.time_gate_bias)                                 \n\n        # GRU cell for relation evolving\n        self.relation_cell_1 = nn.GRUCell(self.h_dim*2, self.h_dim)\n\n        # decoder\n        if decoder_name == \"convtranse\":\n            self.decoder_ob = ConvTransE(num_ents, h_dim, input_dropout, hidden_dropout, feat_dropout)\n            # self.rdecoder = ConvTransR(num_rels, h_dim, input_dropout, hidden_dropout, feat_dropout)\n        else:\n            raise NotImplementedError \n\n    def forward(self, g_list, static_graph, use_cuda):\n        gate_list = []\n        degree_list = []\n        # a = True\n        if self.use_static:\n            static_graph = static_graph.to(self.gpu)\n            static_graph.ndata['h'] = torch.cat((self.dynamic_emb, self.words_emb), dim=0)  # 演化得到的表示，和wordemb满足静态图约束\n            self.statci_rgcn_layer(static_graph, [])\n            static_emb = static_graph.ndata.pop('h')[:self.num_ents, :]\n            static_emb = F.normalize(static_emb) if self.layer_norm else static_emb\n            self.h = static_emb\n            a = torch.isnan(F.normalize(static_emb)).any() or torch.isinf(static_emb).any()\n            if a ==True:\n                print(\"static_emb is nan\")\n        else:\n            self.h = F.normalize(self.dynamic_emb) if self.layer_norm else self.dynamic_emb[:, :]\n            static_emb 
= None\n        history_embs = []\n\n        for i, g in enumerate(g_list):\n            g = g.to(self.gpu)\n            temp_e = self.h[g.r_to_e]\n            x_input = torch.zeros(self.num_rels * 2, self.h_dim).float().cuda() if use_cuda else torch.zeros(self.num_rels * 2, self.h_dim).float()\n            for span, r_idx in zip(g.r_len, g.uniq_r):\n                x = temp_e[span[0]:span[1],:]\n                x_mean = torch.mean(x, dim=0, keepdim=True)\n                x_input[r_idx] = x_mean\n            if i == 0:\n                x_input = torch.cat((self.emb_rel, x_input), dim=1)\n                self.h_0 = self.relation_cell_1(x_input, self.emb_rel)    # 第1层输入\n                self.h_0 = F.normalize(self.h_0) if self.layer_norm else self.h_0\n            else:\n                x_input = torch.cat((self.emb_rel, x_input), dim=1)\n                self.h_0 = self.relation_cell_1(x_input, self.h_0)  # 第2层输出==下一时刻第一层输入\n                self.h_0 = F.normalize(self.h_0) if self.layer_norm else self.h_0\n            current_h = self.rgcn.forward(g, self.h) #, [self.h_0, self.h_0])\n            current_h = F.normalize(current_h) if self.layer_norm else current_h\n            time_weight = F.sigmoid(torch.mm(self.h, self.time_gate_weight) + self.time_gate_bias)\n            self.h = time_weight * current_h + (1-time_weight) * self.h\n            history_embs.append(self.h)\n        return history_embs, static_emb, self.h_0, gate_list, degree_list\n\n\n    def predict(self, test_graph, num_rels, static_graph, test_triplets, use_cuda, neg_samples_batch=None,\n                pos_samples_batch=None, evaluator=None, metric=None):\n        perf_list = [None]*len(neg_samples_batch)\n        hits_list = [None]*len(neg_samples_batch)\n        with torch.no_grad():\n            # inverse_test_triplets = test_triplets[:, [2, 1, 0]]\n            # inverse_test_triplets[:, 1] = inverse_test_triplets[:, 1] + num_rels  # 将逆关系换成逆关系的id\n            all_triples =test_triplets # 
torch.cat((test_triplets, inverse_test_triplets))\n            \n            evolve_embs, _, r_emb, _, _ = self.forward(test_graph, static_graph, use_cuda)\n            embedding = F.normalize(evolve_embs[-1]) if self.layer_norm else evolve_embs[-1]\n            if neg_samples_batch != None: # added by tgb team\n                perf_list = []\n                hits_list = []\n                for query_id, query in enumerate(neg_samples_batch): # for each sample separately\n                    pos = pos_samples_batch[query_id]\n                    neg = torch.tensor(query).to(pos.device)\n                    all =torch.cat((pos.unsqueeze(0), neg), dim=0)\n                    score = self.decoder_ob.forward(embedding, r_emb, test_triplets[query_id].unsqueeze(0),\n                            samples_of_interest_emb=embedding[all] )\n                    # compute MRR\n                    input_dict = {\n                        \"y_pred_pos\": np.array([score[0,0].cpu()]),\n                        \"y_pred_neg\": np.array(score[0,1:].cpu()),\n                        \"eval_metric\": [metric],\n                    }\n                    prediction_perf = evaluator.eval(input_dict)\n                    perf_list.append(prediction_perf[metric])\n                    hits_list.append(prediction_perf['hits@10'])\n            else:\n                score = self.decoder_ob.forward(embedding, r_emb, all_triples, mode=\"test\")\n            # score_rel = self.rdecoder.forward(embedding, r_emb, all_triples, mode=\"test\")\n            return score, perf_list, hits_list\n        \n    def get_mask_nonzero(self, static_embedding):\n        \"\"\" Each element of this resulting tensor will be True if the sum of the corresponding row in \n        static_emb is not zero, and False otherwise\n        \"\"\"\n        mask = torch.sum(static_embedding, dim=1) != 0\n        return mask\n\n    def get_loss(self, glist, triples, static_graph, use_cuda):\n        \"\"\"\n        :param 
glist:\n        :param triplets:\n        :param static_graph: \n        :param use_cuda:\n        :return:\n        \"\"\"\n        loss_ent = torch.zeros(1).cuda().to(self.gpu) if use_cuda else torch.zeros(1)\n        loss_rel = torch.zeros(1).cuda().to(self.gpu) if use_cuda else torch.zeros(1)\n        loss_static = torch.zeros(1).cuda().to(self.gpu) if use_cuda else torch.zeros(1)\n\n        # inverse_triples = triples[:, [2, 1, 0]]\n        # inverse_triples[:, 1] = inverse_triples[:, 1] + self.num_rels\n        all_triples = triples #torch.cat([triples, inverse_triples])\n        all_triples = all_triples.to(self.gpu)\n\n        evolve_embs, static_emb, r_emb, _, _ = self.forward(glist, static_graph, use_cuda)\n        pre_emb = F.normalize(evolve_embs[-1]) if self.layer_norm else evolve_embs[-1]\n\n        if self.entity_prediction:\n            scores_ob = self.decoder_ob.forward(pre_emb, r_emb, all_triples).view(-1, self.num_ents)\n            loss_ent += self.loss_e(scores_ob, all_triples[:, 2])\n     \n\n        if self.use_static:\n            if self.discount == 1:\n                for time_step, evolve_emb in enumerate(evolve_embs):\n                    step = (self.angle * math.pi / 180) * (time_step + 1)\n                    if self.layer_norm:\n                        a= torch.isnan(F.normalize(evolve_emb)).any() or torch.isinf(evolve_emb).any()\n                        if a ==True:\n                            print(\"evolve_emb is nan\")\n                        sim_matrix = torch.sum(static_emb * F.normalize(evolve_emb), dim=1)\n                        a = torch.isnan(sim_matrix).any() or torch.isinf(sim_matrix).any()\n                        if a ==True:\n                            print(\"sim_matrix is nan\")\n                    else:\n                        sim_matrix = torch.sum(static_emb * evolve_emb, dim=1)\n                        c = torch.norm(static_emb, p=2, dim=1) * torch.norm(evolve_emb, p=2, dim=1)\n                        
non_zero_mask = c != 0\n\n                        # Initialize b_sim_matrix with zeros (or another appropriate value)\n                        sim_matrix = torch.zeros_like(sim_matrix)\n\n                        # Perform division only where c is not zero\n                        sim_matrix[non_zero_mask] = sim_matrix[non_zero_mask] / c[non_zero_mask]\n                        # sim_matrix = sim_matrix / c\n                    mask = (math.cos(step) - sim_matrix) > 0\n                    # mask = self.get_mask_nonzero(static_emb) #modified! to only consider non-zero rows\n                    loss_static += self.weight * torch.sum(torch.masked_select(math.cos(step) - sim_matrix, mask))\n            elif self.discount == 0:\n                for time_step, evolve_emb in enumerate(evolve_embs):\n                    step = (self.angle * math.pi / 180)\n                    if self.layer_norm:\n                        sim_matrix = torch.sum(static_emb * F.normalize(evolve_emb), dim=1)\n                    else:\n                        sim_matrix = torch.sum(static_emb * evolve_emb, dim=1)\n                        c = torch.norm(static_emb, p=2, dim=1) * torch.norm(evolve_emb, p=2, dim=1)\n                        sim_matrix = sim_matrix / c\n                    mask = (math.cos(step) - sim_matrix) > 0\n                    loss_static += self.weight * torch.sum(torch.masked_select(math.cos(step) - sim_matrix, mask))\n\n        return loss_ent, loss_rel, loss_static"
  },
  {
    "path": "modules/sampler_core.cpp",
    "content": "#include <iostream>\n#include <string>\n#include <cstdlib>\n#include <random>\n#include <omp.h>\n#include <math.h>\n#include <pybind11/pybind11.h>\n#include <pybind11/numpy.h>\n#include <pybind11/stl.h>\n\nnamespace py = pybind11;\n\ntypedef int NodeIDType;\ntypedef int EdgeIDType;\ntypedef float TimeStampType;\n\nclass TemporalGraphBlock\n{\npublic:\n    std::vector<NodeIDType> row;\n    std::vector<NodeIDType> col;\n    std::vector<EdgeIDType> eid;\n    std::vector<TimeStampType> ts;\n    std::vector<TimeStampType> dts;\n    std::vector<NodeIDType> nodes;\n    NodeIDType dim_in, dim_out;\n    double ptr_time = 0;\n    double search_time = 0;\n    double sample_time = 0;\n    double tot_time = 0;\n    double coo_time = 0;\n\n    TemporalGraphBlock() {}\n    TemporalGraphBlock(std::vector<NodeIDType> &_row, std::vector<NodeIDType> &_col,\n                       std::vector<EdgeIDType> &_eid, std::vector<TimeStampType> &_ts,\n                       std::vector<TimeStampType> &_dts, std::vector<NodeIDType> &_nodes,\n                       NodeIDType _dim_in, NodeIDType _dim_out) : row(_row), col(_col), eid(_eid), ts(_ts), dts(_dts),\n                                                                  nodes(_nodes), dim_in(_dim_in), dim_out(_dim_out) {}\n};\n\nclass ParallelSampler\n{\npublic:\n    std::vector<EdgeIDType> indptr;\n    std::vector<EdgeIDType> indices;\n    std::vector<EdgeIDType> eid;\n    std::vector<TimeStampType> ts;\n    NodeIDType num_nodes;\n    EdgeIDType num_edges;\n    int num_thread_per_worker;\n    int num_workers;\n    int num_threads;\n    int num_layers;\n    std::vector<int> num_neighbors;\n    bool recent;\n    bool prop_time;\n    int num_history;\n    TimeStampType window_duration;\n    std::vector<std::vector<std::vector<EdgeIDType>::size_type>> ts_ptr;\n    omp_lock_t *ts_ptr_lock;\n    std::vector<TemporalGraphBlock> ret;\n\n    ParallelSampler(std::vector<EdgeIDType> &_indptr, std::vector<EdgeIDType> &_indices,\n     
               std::vector<EdgeIDType> &_eid, std::vector<TimeStampType> &_ts,\n                    int _num_thread_per_worker, int _num_workers, int _num_layers,\n                    std::vector<int> &_num_neighbors, bool _recent, bool _prop_time,\n                    int _num_history, TimeStampType _window_duration) : indptr(_indptr), indices(_indices), eid(_eid), ts(_ts), prop_time(_prop_time),\n                                                                        num_thread_per_worker(_num_thread_per_worker), num_workers(_num_workers),\n                                                                        num_layers(_num_layers), num_neighbors(_num_neighbors), recent(_recent),\n                                                                        num_history(_num_history), window_duration(_window_duration)\n    {\n        omp_set_num_threads(num_thread_per_worker * num_workers);\n        num_threads = num_thread_per_worker * num_workers;\n        num_nodes = indptr.size() - 1;\n        num_edges = indices.size();\n        ts_ptr_lock = (omp_lock_t *)malloc(num_nodes * sizeof(omp_lock_t));\n        for (int i = 0; i < num_nodes; i++)\n            omp_init_lock(&ts_ptr_lock[i]);\n        ts_ptr.resize(num_history + 1);\n        for (auto it = ts_ptr.begin(); it != ts_ptr.end(); it++)\n        {\n            it->resize(indptr.size() - 1);\n#pragma omp parallel for\n            for (auto itt = indptr.begin(); itt < indptr.end() - 1; itt++)\n                (*it)[itt - indptr.begin()] = *itt;\n        }\n    }\n\n    void reset()\n    {\n        for (auto it = ts_ptr.begin(); it != ts_ptr.end(); it++)\n        {\n            it->resize(indptr.size() - 1);\n#pragma omp parallel for\n            for (auto itt = indptr.begin(); itt < indptr.end() - 1; itt++)\n                (*it)[itt - indptr.begin()] = *itt;\n        }\n    }\n\n    void update_ts_ptr(int slc, std::vector<NodeIDType> &root_nodes,\n                       std::vector<TimeStampType> &root_ts, 
float offset)\n    {\n#pragma omp parallel for schedule(static, int(ceil(static_cast <float>(root_nodes.size()) / num_threads)))\n        for (std::vector<NodeIDType>::size_type i = 0; i < root_nodes.size(); i++)\n        {\n            NodeIDType n = root_nodes[i];\n            omp_set_lock(&(ts_ptr_lock[n]));\n            for (std::vector<EdgeIDType>::size_type j = ts_ptr[slc][n]; j < indptr[n + 1]; j++)\n            {\n                // std::cout << \"comparing \" << ts[j] << \" with \" << root_ts[i] << std::endl;\n                if (ts[j] > (root_ts[i] + offset - 1e-7f))\n                {\n                    if (j != ts_ptr[slc][n])\n                        ts_ptr[slc][n] = j - 1;\n                    break;\n                }\n                if (j == indptr[n + 1] - 1)\n                {\n                    ts_ptr[slc][n] = j;\n                }\n            }\n            omp_unset_lock(&(ts_ptr_lock[n]));\n        }\n    }\n\n    inline void add_neighbor(std::vector<NodeIDType> *_row, std::vector<NodeIDType> *_col,\n                             std::vector<EdgeIDType> *_eid, std::vector<TimeStampType> *_ts,\n                             std::vector<TimeStampType> *_dts, std::vector<NodeIDType> *_nodes,\n                             EdgeIDType &k, TimeStampType &src_ts, int &row_id)\n    {\n        _row->push_back(row_id);\n        _col->push_back(_nodes->size());\n        _eid->push_back(eid[k]);\n        if (prop_time)\n            _ts->push_back(src_ts);\n        else\n            _ts->push_back(ts[k]);\n        _dts->push_back(src_ts - ts[k]);\n        _nodes->push_back(indices[k]);\n        // _row.push_back(0);\n        // _col.push_back(0);\n        // _eid.push_back(0);\n        // if (prop_time)\n        //     _ts.push_back(src_ts);\n        // else\n        //     _ts.push_back(10000);\n        // _nodes.push_back(100);\n    }\n\n    inline void combine_coo(TemporalGraphBlock &_ret, std::vector<NodeIDType> **_row,\n                            
std::vector<NodeIDType> **_col,\n                            std::vector<EdgeIDType> **_eid,\n                            std::vector<TimeStampType> **_ts,\n                            std::vector<TimeStampType> **_dts,\n                            std::vector<NodeIDType> **_nodes,\n                            std::vector<int> &_out_nodes)\n    {\n        std::vector<EdgeIDType> cum_row, cum_col;\n        cum_row.push_back(0);\n        cum_col.push_back(0);\n        for (int tid = 0; tid < num_threads; tid++)\n        {\n            // std::cout<<tid<<\" here \"<<_out_nodes[tid]<<std::endl;\n            cum_row.push_back(cum_row.back() + _out_nodes[tid]);\n            cum_col.push_back(cum_col.back() + _col[tid]->size());\n        }\n        int num_root_nodes = _ret.nodes.size();\n        _ret.row.resize(cum_col.back());\n        _ret.col.resize(cum_col.back());\n        _ret.eid.resize(cum_col.back());\n        _ret.ts.resize(cum_col.back() + num_root_nodes);\n        _ret.dts.resize(cum_col.back() + num_root_nodes);\n        _ret.nodes.resize(cum_col.back() + num_root_nodes);\n#pragma omp parallel for schedule(static, 1)\n        for (int tid = 0; tid < num_threads; tid++)\n        {\n            std::transform(_row[tid]->begin(), _row[tid]->end(), _row[tid]->begin(),\n                           [&](auto &v)\n                           { return v + cum_row[tid]; });\n            std::transform(_col[tid]->begin(), _col[tid]->end(), _col[tid]->begin(),\n                           [&](auto &v)\n                           { return v + cum_col[tid] + num_root_nodes; });\n            std::copy(_row[tid]->begin(), _row[tid]->end(), _ret.row.begin() + cum_col[tid]);\n            std::copy(_col[tid]->begin(), _col[tid]->end(), _ret.col.begin() + cum_col[tid]);\n            std::copy(_eid[tid]->begin(), _eid[tid]->end(), _ret.eid.begin() + cum_col[tid]);\n            std::copy(_ts[tid]->begin(), _ts[tid]->end(), _ret.ts.begin() + cum_col[tid] + num_root_nodes);\n          
  std::copy(_dts[tid]->begin(), _dts[tid]->end(), _ret.dts.begin() + cum_col[tid] + num_root_nodes);\n            std::copy(_nodes[tid]->begin(), _nodes[tid]->end(), _ret.nodes.begin() + cum_col[tid] + num_root_nodes);\n            delete _row[tid];\n            delete _col[tid];\n            delete _eid[tid];\n            delete _ts[tid];\n            delete _dts[tid];\n            delete _nodes[tid];\n        }\n        _ret.dim_in = _ret.nodes.size();\n        _ret.dim_out = cum_row.back();\n    }\n\n    void sample_layer(std::vector<NodeIDType> &_root_nodes, std::vector<TimeStampType> &_root_ts,\n                      int neighs, bool use_ptr, bool from_root)\n    {\n        double t_s = omp_get_wtime();\n        std::vector<NodeIDType> *root_nodes;\n        std::vector<TimeStampType> *root_ts;\n        if (from_root)\n        {\n            root_nodes = &_root_nodes;\n            root_ts = &_root_ts;\n        }\n        double t_ptr_s = omp_get_wtime();\n        if (use_ptr)\n            update_ts_ptr(num_history, *root_nodes, *root_ts, 0);\n        ret[0].ptr_time += omp_get_wtime() - t_ptr_s;\n        for (int i = 0; i < num_history; i++)\n        {\n            if (!from_root)\n            {\n                root_nodes = &(ret[ret.size() - 1 - i - num_history].nodes);\n                root_ts = &(ret[ret.size() - 1 - i - num_history].ts);\n            }\n            TimeStampType offset = -i * window_duration;\n            t_ptr_s = omp_get_wtime();\n            if ((use_ptr) && (std::abs(window_duration) > 1e-7f))\n                update_ts_ptr(num_history - 1 - i, *root_nodes, *root_ts, offset - window_duration);\n            ret[0].ptr_time += omp_get_wtime() - t_ptr_s;\n            std::vector<NodeIDType> *_row[num_threads];\n            std::vector<NodeIDType> *_col[num_threads];\n            std::vector<EdgeIDType> *_eid[num_threads];\n            std::vector<TimeStampType> *_ts[num_threads];\n            std::vector<TimeStampType> 
*_dts[num_threads];\n            std::vector<NodeIDType> *_nodes[num_threads];\n            std::vector<int> _out_node(num_threads, 0);\n            int reserve_capacity = int(ceil((*root_nodes).size() / num_threads)) * neighs;\n#pragma omp parallel\n            {\n                int tid = omp_get_thread_num();\n                unsigned int loc_seed = tid;\n                _row[tid] = new std::vector<NodeIDType>;\n                _col[tid] = new std::vector<NodeIDType>;\n                _eid[tid] = new std::vector<EdgeIDType>;\n                _ts[tid] = new std::vector<TimeStampType>;\n                _dts[tid] = new std::vector<TimeStampType>;\n                _nodes[tid] = new std::vector<NodeIDType>;\n                _row[tid]->reserve(reserve_capacity);\n                _col[tid]->reserve(reserve_capacity);\n                _eid[tid]->reserve(reserve_capacity);\n                _ts[tid]->reserve(reserve_capacity);\n                _dts[tid]->reserve(reserve_capacity);\n                _nodes[tid]->reserve(reserve_capacity);\n// #pragma omp critical\n//                     std::cout<<tid<<\" sampling: \"<<root_nodes->size()<<\" \"<<int(ceil((*root_nodes).size() / num_threads))<<std::endl;\n#pragma omp for schedule(static, int(ceil(static_cast <float>((*root_nodes).size()) / num_threads)))\n                for (std::vector<NodeIDType>::size_type j = 0; j < (*root_nodes).size(); j++)\n                {\n                    NodeIDType n = (*root_nodes)[j];\n                    // if (tid == 16)\n                    //     std::cout << _out_node[tid] << \" \" <<j << \" \" << n << std::endl;\n                    TimeStampType nts = (*root_ts)[j];\n                    EdgeIDType s_search, e_search;\n                    if (use_ptr)\n                    {\n                        s_search = ts_ptr[num_history - 1 - i][n];\n                        e_search = ts_ptr[num_history - i][n];\n                    }\n                    else\n                    {\n           
             // search for start and end pointer\n                        double t_search_s = omp_get_wtime();\n                        if (num_history == 1)\n                        {\n                            // TGAT style\n                            s_search = indptr[n];\n                            auto e_it = std::upper_bound(ts.begin() + indptr[n],\n                                                         ts.begin() + indptr[n + 1], nts);\n                            e_search = std::max(int(e_it - ts.begin()) - 1, s_search);\n                        }\n                        else\n                        {\n                            // DySAT style\n                            auto s_it = std::upper_bound(ts.begin() + indptr[n],\n                                                         ts.begin() + indptr[n + 1],\n                                                         nts + offset - window_duration);\n                            s_search = std::max(int(s_it - ts.begin()) - 1, indptr[n]);\n                            auto e_it = std::upper_bound(ts.begin() + indptr[n],\n                                                         ts.begin() + indptr[n + 1], nts + offset);\n                            e_search = std::max(int(e_it - ts.begin()) - 1, s_search);\n                        }\n                        if (tid == 0)\n                            ret[0].search_time += omp_get_wtime() - t_search_s;\n                    }\n                    // std::cout << n << \" \" << s_search << \" \" << e_search << std::endl;\n                    double t_sample_s = omp_get_wtime();\n                    if ((recent) || (e_search - s_search < neighs))\n                    {\n                        // no sampling, pick recent neighbors\n                        for (EdgeIDType k = e_search; k > std::max(s_search, e_search - neighs); k--)\n                        {\n                            if (ts[k] < nts + offset - 1e-7f)\n                            {\n         
                       add_neighbor(_row[tid], _col[tid], _eid[tid], _ts[tid],\n                                             _dts[tid], _nodes[tid], k, nts, _out_node[tid]);\n                            }\n                        }\n                    }\n                    else\n                    {\n                        // random sampling within ptr\n                        for (int _i = 0; _i < neighs; _i++)\n                        {\n                            EdgeIDType picked = s_search + rand_r(&loc_seed) % (e_search - s_search + 1);\n                            if (ts[picked] < nts + offset - 1e-7f)\n                            {\n                                add_neighbor(_row[tid], _col[tid], _eid[tid], _ts[tid],\n                                             _dts[tid], _nodes[tid], picked, nts, _out_node[tid]);\n                            }\n                        }\n                    }\n                    _out_node[tid] += 1;\n                    if (tid == 0)\n                        ret[0].sample_time += omp_get_wtime() - t_sample_s;\n                }\n            }\n            double t_coo_s = omp_get_wtime();\n            ret[ret.size() - 1 - i].ts.insert(ret[ret.size() - 1 - i].ts.end(),\n                                              root_ts->begin(), root_ts->end());\n            ret[ret.size() - 1 - i].nodes.insert(ret[ret.size() - 1 - i].nodes.end(),\n                                                 root_nodes->begin(), root_nodes->end());\n            ret[ret.size() - 1 - i].dts.resize(root_nodes->size());\n            combine_coo(ret[ret.size() - 1 - i], _row, _col, _eid, _ts, _dts, _nodes, _out_node);\n            ret[0].coo_time += omp_get_wtime() - t_coo_s;\n        }\n        ret[0].tot_time += omp_get_wtime() - t_s;\n    }\n\n    void sample(std::vector<NodeIDType> &root_nodes, std::vector<TimeStampType> &root_ts)\n    {\n        // a weird bug, dgl library seems to modify the total number of threads\n        
omp_set_num_threads(num_threads);\n        ret.resize(0);\n        bool first_layer = true;\n        bool use_ptr = false;\n        for (int i = 0; i < num_layers; i++)\n        {\n            ret.resize(ret.size() + num_history);\n            if ((first_layer) || ((prop_time) && num_history == 1) || (recent))\n            {\n                first_layer = false;\n                use_ptr = true;\n            }\n            else\n                use_ptr = false;\n            if (i == 0)\n                sample_layer(root_nodes, root_ts, num_neighbors[i], use_ptr, true);\n            else\n                sample_layer(root_nodes, root_ts, num_neighbors[i], use_ptr, false);\n        }\n    }\n};\n\ntemplate <typename T>\ninline py::array vec2npy(const std::vector<T> &vec)\n{\n    // need to let python garbage collector handle C++ vector memory\n    // see https://github.com/pybind/pybind11/issues/1042\n    auto v = new std::vector<T>(vec);\n    auto capsule = py::capsule(v, [](void *v)\n                               { delete reinterpret_cast<std::vector<T> *>(v); });\n    return py::array(v->size(), v->data(), capsule);\n    // return py::array(vec.size(), vec.data());\n}\n\nPYBIND11_MODULE(sampler_core, m)\n{\n    py::class_<TemporalGraphBlock>(m, \"TemporalGraphBlock\")\n        .def(py::init<std::vector<NodeIDType> &, std::vector<NodeIDType> &,\n                      std::vector<EdgeIDType> &, std::vector<TimeStampType> &,\n                      std::vector<TimeStampType> &, std::vector<NodeIDType> &,\n                      NodeIDType, NodeIDType>())\n        .def(\"row\", [](const TemporalGraphBlock &tgb)\n             { return vec2npy(tgb.row); })\n        .def(\"col\", [](const TemporalGraphBlock &tgb)\n             { return vec2npy(tgb.col); })\n        .def(\"eid\", [](const TemporalGraphBlock &tgb)\n             { return vec2npy(tgb.eid); })\n        .def(\"ts\", [](const TemporalGraphBlock &tgb)\n             { return vec2npy(tgb.ts); })\n        
.def(\"dts\", [](const TemporalGraphBlock &tgb)\n             { return vec2npy(tgb.dts); })\n        .def(\"nodes\", [](const TemporalGraphBlock &tgb)\n             { return vec2npy(tgb.nodes); })\n        .def(\"dim_in\", [](const TemporalGraphBlock &tgb)\n             { return tgb.dim_in; })\n        .def(\"dim_out\", [](const TemporalGraphBlock &tgb)\n             { return tgb.dim_out; })\n        .def(\"tot_time\", [](const TemporalGraphBlock &tgb)\n             { return tgb.tot_time; })\n        .def(\"ptr_time\", [](const TemporalGraphBlock &tgb)\n             { return tgb.ptr_time; })\n        .def(\"search_time\", [](const TemporalGraphBlock &tgb)\n             { return tgb.search_time; })\n        .def(\"sample_time\", [](const TemporalGraphBlock &tgb)\n             { return tgb.sample_time; })\n        .def(\"coo_time\", [](const TemporalGraphBlock &tgb)\n             { return tgb.coo_time; });\n             \n    py::class_<ParallelSampler>(m, \"ParallelSampler\")\n        .def(py::init<std::vector<EdgeIDType> &, std::vector<EdgeIDType> &,\n                      std::vector<EdgeIDType> &, std::vector<TimeStampType> &,\n                      int, int, int, std::vector<int> &, bool, bool,\n                      int, TimeStampType>())\n        .def(\"sample\", &ParallelSampler::sample)\n        .def(\"reset\", &ParallelSampler::reset)\n        .def(\"get_ret\", [](const ParallelSampler &ps)\n             { return ps.ret; });\n}"
  },
  {
    "path": "modules/sthn.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Optional\nimport numpy as np\nfrom torch import Tensor\n\n\nfrom tqdm import tqdm\nfrom sampler_core import ParallelSampler\nimport torch_sparse\n\n\nimport time\nimport copy\nimport random\nfrom torch_sparse import SparseTensor\nfrom torchmetrics.classification import MulticlassAUROC, MulticlassAveragePrecision\nfrom torchmetrics.classification import BinaryAUROC, BinaryAveragePrecision\nfrom sklearn.preprocessing import MinMaxScaler\nimport os\nimport pickle\n\n\n\"\"\"\nSource: STHN: utils.py\nURL: https://github.com/celi52/STHN/blob/main/utils.py\n\"\"\"\n\n# utility function\ndef set_seed(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\ndef row_norm(adj_t):\n    if isinstance(adj_t, torch_sparse.SparseTensor):\n        # adj_t = torch_sparse.fill_diag(adj, 1)\n        deg = torch_sparse.sum(adj_t, dim=1)\n        deg_inv = 1. 
/ deg\n        deg_inv.masked_fill_(deg_inv == float('inf'), 0.)\n        adj_t = torch_sparse.mul(adj_t, deg_inv.view(-1, 1))\n        return adj_t\n\n\n\"\"\"\nSource: STHN: construct_subgraph.py\nURL: https://github.com/celi52/STHN/blob/main/construct_subgraph.py\n\nNotes: The NegLinkSampler is only used for STHN internal sampling and not for TGB\n\"\"\"\n\n\n##############################################################################\n##############################################################################\n##############################################################################\n\n\n# get sampler\nclass NegLinkSampler:\n    \"\"\"\n    From https://github.com/amazon-research/tgl/blob/main/sampler.py\n    \"\"\"\n    def __init__(self, num_nodes):\n        self.num_nodes = num_nodes\n\n    def sample(self, n):\n        return np.random.randint(self.num_nodes, size=n)\n    \ndef get_parallel_sampler(g, num_neighbors=10):\n    \"\"\"\n    Function wrapper of the C++ sampler (https://github.com/amazon-research/tgl/blob/main/sampler_core.cpp)\n    Sample the 1-hop most recent neighbors of each node\n    \"\"\"\n\n    configs = [\n        g['indptr'],       # indptr --> fixed: data info\n        g['indices'],      # indices --> fixed: data info\n        g['eid'],          # eid --> fixed: data info\n        g['ts'],           # ts --> fixed: data info\n        32, # num_thread_per_worker --> change this based on machine's setup\n        1,  # num_workers --> change this based on machine's setup\n        1,  # num_layers --> change this based on machine's setup\n        [num_neighbors],   # num_neighbors --> hyper-parameters. 
Reddit 10, WIKI 30\n        True,  # recent --> fixed: never touch\n        False, # prop_time --> never touch\n        1,     # num_history --> fixed: never touch\n        0      # window_duration --> fixed: never touch\n    ]\n    \n    sampler = ParallelSampler(*configs)       \n    neg_link_sampler = NegLinkSampler(g['indptr'].shape[0] - 1)\n    return sampler, neg_link_sampler\n    \n##############################################################################\n##############################################################################\n##############################################################################\n# sampling\n\ndef get_mini_batch(sampler, root_nodes, ts, num_hops): # neg_samples is not used\n    \"\"\"\n    Call function fetch_subgraph()\n    Return: Subgraph of each node. \n    \"\"\"\n    all_graphs = []\n    \n    for root_node, root_time in zip(root_nodes, ts):\n        all_graphs.append(fetch_subgraph(sampler, root_node, root_time, num_hops))\n\n    return all_graphs\n\ndef fetch_subgraph(sampler, root_node, root_time, num_hops):\n    \"\"\"\n    Sample a subgraph for each node or node pair\n    \"\"\"\n    all_row_col_times_nodes_eid = []\n\n    # suppose sampling for both a single node and a node pair (two side of a link)\n    if isinstance(root_node, list):\n        nodes, ts = [i for i in root_node], [root_time for i in root_node]\n    else:\n        nodes, ts = [root_node], [root_time]\n    \n    # fetch all nodes+edges\n    for _ in range(num_hops):\n        sampler.sample(nodes, ts)\n        ret = sampler.get_ret() # 1-hop recent neighbors\n        row, col, eid = ret[0].row(), ret[0].col(), ret[0].eid()\n        nodes, ts = ret[0].nodes(), ret[0].ts().astype(np.float32)\n             \n        row_col_times_nodes_eid = np.stack([ts[row], nodes[row], ts[col], nodes[col], eid]).T\n        all_row_col_times_nodes_eid.append(row_col_times_nodes_eid)\n    all_row_col_times_nodes_eid = np.concatenate(all_row_col_times_nodes_eid, 
axis=0)\n\n    # remove duplicate edges and sort according to the root node time (descending)\n    all_row_col_times_nodes_eid = np.unique(all_row_col_times_nodes_eid, axis=0)[::-1]\n    all_row_col_times_nodes = all_row_col_times_nodes_eid[:, :-1]\n    eid = all_row_col_times_nodes_eid[:, -1]\n\n    # remove duplicate (node+time) and sorted by time decending order\n    all_row_col_times_nodes = np.array_split(all_row_col_times_nodes, 2, axis=1)\n    times_nodes = np.concatenate(all_row_col_times_nodes, axis=0)\n    times_nodes = np.unique(times_nodes, axis=0)[::-1]\n    \n    # each (node, time) pair identifies a node\n    node_2_ind = dict()\n    for ind, (time, node) in enumerate(times_nodes):\n        node_2_ind[(time, node)] = ind\n\n    # translate the nodes into new index\n    row = np.zeros(len(eid), dtype=np.int32)\n    col = np.zeros(len(eid), dtype=np.int32)\n    for i, ((t1, n1), (t2, n2)) in enumerate(zip(*all_row_col_times_nodes)):\n        row[i] = node_2_ind[(t1, n1)]\n        col[i] = node_2_ind[(t2, n2)]\n        \n    # fetch get time + node information\n    eid = eid.astype(np.int32)\n    ts = times_nodes[:,0].astype(np.float32)\n    nodes = times_nodes[:,1].astype(np.int32)\n    dts = root_time - ts # make sure the root node time is 0\n    \n    return {\n        # edge info: sorted with descending row (src) node temporal order\n        'row': row, # src\n        'col': col, # dst\n        'eid': eid, \n        # node info\n        'nodes': nodes , # sorted by the ascending order of node's dts (root_node's dts = 0)\n        'dts': dts,\n        # graph info\n        'num_nodes': len(nodes),\n        'num_edges': len(eid),\n        # root info\n        'root_node': root_node,\n        'root_time': root_time,\n    }\n\n\ndef construct_mini_batch_giant_graph(all_graphs, max_num_edges):\n    \"\"\"\n    Take the subgraph computed by fetch_subgraph() and combine it into a giant graph\n    Return: the new indices of the graph\n    \"\"\"\n    \n    
all_rows, all_cols, all_eids, all_nodes, all_dts = [], [], [], [], []\n    \n    cumsum_edges = 0\n    all_edge_indptr = [0]\n    \n    cumsum_nodes = 0\n    all_node_indptr = [0]\n    \n    all_root_nodes = []\n    all_root_times = []\n    for all_graph in all_graphs:\n        # record inds\n        num_nodes = all_graph['num_nodes']\n        num_edges = min(all_graph['num_edges'], max_num_edges)\n\n        # add graph information\n        all_rows.append(all_graph['row'][:num_edges] + cumsum_nodes)\n        all_cols.append(all_graph['col'][:num_edges] + cumsum_nodes)\n        all_eids.append(all_graph['eid'][:num_edges])\n        \n        all_nodes.append(all_graph['nodes'])\n        all_dts.append(all_graph['dts'])\n\n        # update cumsum\n        cumsum_nodes += num_nodes\n        all_node_indptr.append(cumsum_nodes)\n        \n        cumsum_edges += num_edges\n        all_edge_indptr.append(cumsum_edges)\n        \n        # add root nodes\n        all_root_nodes.append(all_graph['root_node'])\n        all_root_times.append(all_graph['root_time'])\n    # for each edges\n    all_rows = np.concatenate(all_rows).astype(np.int32)\n    all_cols = np.concatenate(all_cols).astype(np.int32)\n    all_eids = np.concatenate(all_eids).astype(np.int32)\n    all_edge_indptr = np.array(all_edge_indptr).astype(np.int32)\n    \n    # for each nodes\n    all_nodes = np.concatenate(all_nodes).astype(np.int32)\n    all_dts = np.concatenate(all_dts).astype(np.float32)\n    all_node_indptr = np.array(all_node_indptr).astype(np.int32)\n        \n    return {\n        # for edges\n        'row': all_rows, \n        'col': all_cols, \n        'eid': all_eids, \n        'edts': all_dts[all_cols] - all_dts[all_rows],\n        # number of subgraphs + 1\n        'all_node_indptr': all_node_indptr,\n        'all_edge_indptr': all_edge_indptr,\n        # for nodes\n        'nodes': all_nodes, \n        'dts': all_dts, \n        # general information\n        'all_num_nodes': 
cumsum_nodes,\n        'all_num_edges': cumsum_edges,\n        # root nodes\n        'root_nodes': np.array(all_root_nodes, dtype=np.int32), \n        'root_times': np.array(all_root_times, dtype=np.float32), \n    }\n\n##############################################################################\n##############################################################################\n##############################################################################\n\ndef print_subgraph_data(subgraph_data):\n    \"\"\"\n    Used to double check see if the sampled graph is as expected\n    \"\"\"\n    for key, vals in subgraph_data.items():\n        if isinstance(vals, np.ndarray):\n            print(key, vals.shape)\n        else:\n            print(key, vals)\n\n\n\"\"\"\nSource: STHN data_process_utils.py\nURL: https://github.com/celi52/STHN/blob/main/data_process_utils.py\n\nNote:\nCurrently only using pre_compute_subgraphs because use_cached_subgraph is True\nget_subgraph_sampler needs to be modified if use_cached_subgraph is False\n\nThe function get_all_inds is new to handle TGB evaluation\n\"\"\"\n\n\nclass SubgraphSampler:\n    def __init__(self, all_root_nodes, all_ts, sampler, args):\n        self.all_root_nodes = all_root_nodes\n        self.all_ts = all_ts\n        self.sampler = sampler\n        self.sampled_num_hops = args.sampled_num_hops\n\n    def mini_batch(self, ind, mini_batch_inds):\n        root_nodes = self.all_root_nodes[ind][mini_batch_inds]\n        ts = self.all_ts[ind][mini_batch_inds]\n        return get_mini_batch(self.sampler, root_nodes, ts, self.sampled_num_hops)\n\ndef get_subgraph_sampler(args, g, df, mode):\n    ###################################################\n    # get cached file_name\n    if mode == 'train':\n        extra_neg_samples = args.extra_neg_samples\n    else:\n        extra_neg_samples = 1\n\n    ###################################################\n    # for each node, sample its neighbors with the most recent neighbors 
######################################################################################################
######################################################################################################
######################################################################################################
# for small datasets we can cache the sampled subgraphs
def pre_compute_subgraphs(args, g, df, mode, negative_sampler=None, split_mode='test', cache=False):
    """Sample (and optionally cache on disk) the temporal subgraphs for one split.

    Parameters
    ----------
    args : namespace providing batch_size, sampled_num_hops, num_neighbors, data,
        extra_neg_samples and the train/val/test boolean masks.
    g : graph structure handed to ``get_parallel_sampler``.
    df : edge dataframe with ``src``, ``dst``, ``time`` and ``label`` columns.
    mode : one of 'train', 'valid', 'test'.
    negative_sampler : optional TGB negative sampler; when given, its cached
        negatives replace random negative link sampling.
    split_mode : split name forwarded to ``negative_sampler.query_batch``.
    cache : when True, pickle the result to disk for reuse.

    Returns
    -------
    tuple ``(all_subgraphs, all_elabel)`` — per-batch sampled subgraphs and
    the corresponding edge labels.
    """
    # training uses extra negatives; evaluation uses one negative per link
    if mode == 'train':
        extra_neg_samples = args.extra_neg_samples
    else:
        extra_neg_samples = 1

    fn = os.path.join(os.getcwd(), 'DATA', args.data,
                      '%s_neg_sample_neg%d_bs%d_hops%d_neighbors%d.pickle' % (
                          mode, extra_neg_samples, args.batch_size,
                          args.sampled_num_hops, args.num_neighbors))

    # reuse the cached subgraphs when they exist
    if os.path.exists(fn):
        subgraph_elabel = pickle.load(open(fn, 'rb'))
    else:
        # for each node, sample its most recent neighbors (sorted)
        print('Sample subgraphs ... for %s mode' % mode)
        sampler, neg_link_sampler = get_parallel_sampler(g, args.num_neighbors)

        if mode == 'train':
            cur_df = df[args.train_mask]
        elif mode == 'valid':
            cur_df = df[args.val_mask]
        elif mode == 'test':
            cur_df = df[args.test_mask]
        else:
            # previously an invalid mode crashed later with NameError on cur_df
            raise ValueError('Unknown mode: %s' % mode)

        loader = cur_df.groupby(cur_df.index // args.batch_size)
        pbar = tqdm(total=len(loader))
        pbar.set_description('Pre-sampling: %s mode' % (mode,))

        all_subgraphs = []
        all_elabel = []
        sampler.reset()
        for _, rows in loader:
            if negative_sampler is not None:
                # use the pre-generated TGB negatives for this batch
                neg_batch_list = negative_sampler.query_batch(
                    rows.src.values,
                    rows.dst.values,
                    rows.time.values,
                    rows.label.values,
                    split_mode=split_mode
                )
                neg_batch_list = np.concatenate(neg_batch_list)
                extra_neg_samples = neg_batch_list.shape[0] // len(rows)
            else:
                neg_batch_list = neg_link_sampler.sample(len(rows) * extra_neg_samples)

            root_nodes = np.concatenate(
                [rows.src.values, rows.dst.values, neg_batch_list]
            ).astype(np.int32)

            # node time-stamp = edge time-stamp, repeated for src/dst/negatives
            ts = np.tile(rows.time.values, extra_neg_samples + 2).astype(np.float32)
            all_elabel.append(rows.label.values)
            all_subgraphs.append(get_mini_batch(sampler, root_nodes, ts, args.sampled_num_hops))

            pbar.update(1)
        pbar.close()
        subgraph_elabel = (all_subgraphs, all_elabel)

        if cache:
            try:
                pickle.dump(subgraph_elabel, open(fn, 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
            except (OSError, pickle.PicklingError) as exc:
                # caching is best-effort: report the failure and continue
                print('Warning: could not cache subgraphs to %s (%s)' % (fn, exc))

    return subgraph_elabel


def get_random_inds(num_subgraph, cached_neg_samples, neg_samples):
    """Select positive src/dst indices plus ``neg_samples`` randomly chosen
    cached negatives per link.

    Returns int32 indices into the flattened subgraph list, laid out as
    [pos_src | pos_dst | negatives].
    """
    batch_size = num_subgraph // (2 + cached_neg_samples)
    pos_src_inds = np.arange(batch_size)
    pos_dst_inds = np.arange(batch_size) + batch_size
    # pick, per link and per negative sample, one cached negative group
    # from [2, 2 + cached_neg_samples)
    neg_dst_inds = np.random.randint(low=2, high=2 + cached_neg_samples,
                                     size=batch_size * neg_samples)
    # np.tile (not a bare arange) so the per-link offsets broadcast correctly
    # when neg_samples > 1; identical to the original for neg_samples == 1
    neg_dst_inds = batch_size * neg_dst_inds + np.tile(np.arange(batch_size), neg_samples)
    mini_batch_inds = np.concatenate([pos_src_inds, pos_dst_inds, neg_dst_inds]).astype(np.int32)

    return mini_batch_inds
def get_all_inds(num_subgraph, neg_samples):
    """Deterministically select every subgraph index: positive sources first,
    then positive destinations, then all cached negatives, as int32."""
    batch_size = num_subgraph // (2 + neg_samples)
    src_part = np.arange(batch_size)
    dst_part = src_part + batch_size
    neg_part = np.arange(batch_size * neg_samples) + 2 * batch_size
    return np.concatenate([src_part, dst_part, neg_part]).astype(np.int32)


def check_data_leakage(args, g, df):
    """
    Double-check that every sampled subgraph only contains edges whose eid is
    strictly smaller than the eid of the positive pair (if so, no data leakage).
    """
    for mode in ['train', 'valid', 'test']:

        if mode == 'train':
            cur_df = df[:args.train_edge_end]
        elif mode == 'valid':
            cur_df = df[args.train_edge_end:args.val_edge_end]
        elif mode == 'test':
            cur_df = df[args.val_edge_end:]

        loader = cur_df.groupby(cur_df.index // args.batch_size)
        subgraphs = pre_compute_subgraphs(args, g, df, mode)

        for i, (_, rows) in enumerate(loader):
            root_nodes = np.concatenate([rows.src.values, rows.dst.values]).astype(np.int32)
            # each target edge id appears twice: once for src, once for dst
            eids = np.tile(rows.index.values, 2)
            cur_subgraphs = subgraphs[i][:args.batch_size * 2]

            for eid, cur_subgraph in zip(eids, cur_subgraphs):
                sampled_eids = cur_subgraph['eid']
                if len(sampled_eids) == 0:
                    continue
                # every sampled edge must predate the target edge
                assert sum(sampled_eids < eid) == len(sampled_eids)

    print('Does not detect information leakage ...')
use masked version\n        subgraph_node_feats = compute_sign_feats(node_feats, cur_df, cur_inds, num_of_df_links, subgraph_data['root_nodes'], args)\n        cur_inds += num_of_df_links\n    else:\n        subgraph_node_feats = None\n    # scale\n    scaler.fit(subgraph_edts.reshape(-1,1))\n    subgraph_edts = scaler.transform(subgraph_edts.reshape(-1,1)).ravel().astype(np.float32) * 1000\n    subgraph_edts = torch.from_numpy(subgraph_edts)\n    \n    # get mini-batch inds\n    all_inds, has_temporal_neighbors = [], []\n\n    # ignore an edge pair if (src_node, dst_node) does not have temporal neighbors\n    all_edge_indptr = subgraph_data['all_edge_indptr']\n    \n    for i in range(len(all_edge_indptr)-1):\n        num_edges = all_edge_indptr[i+1] - all_edge_indptr[i]\n        all_inds.extend([(args.max_edges * i + j) for j in range(num_edges)])\n        has_temporal_neighbors.append(num_edges>0)\n        \n    if not args.predict_class:\n        inputs = [\n            subgraph_edge_feats.to(args.device), \n            subgraph_edts.to(args.device), \n            len(has_temporal_neighbors), \n            torch.tensor(all_inds).long() \n        ]\n    else:\n        subgraph_edge_type = elabel[ind]\n        inputs = [\n            subgraph_edge_feats.to(args.device), \n            subgraph_edts.to(args.device), \n            len(has_temporal_neighbors), \n            torch.tensor(all_inds).long(),  \n            torch.from_numpy(subgraph_edge_type).to(args.device)\n        ]\n    return inputs, subgraph_node_feats, cur_inds\n\ndef run(model, optimizer, args, subgraphs, df, node_feats, edge_feats, MLAUROC, MLAUPRC, mode):\n    time_epoch = 0\n    ###################################################\n    # setup modes\n    cur_inds = 0\n    if mode == 'train':\n        model.train()\n        cur_df = df[args.train_mask]\n        neg_samples = args.neg_samples\n        cached_neg_samples = args.extra_neg_samples\n\n    elif mode == 'valid':\n        model.eval()\n 
       cur_df = df[args.val_mask]\n        neg_samples = 1\n        cached_neg_samples = 1\n\n    elif mode == 'test':\n        ## Erfan: remove this part use TGB evaluation\n        raise('Use TGB evaluation')\n        # model.eval()\n        # cur_df = df[args.test_mask]\n        # neg_samples = 1\n        # cached_neg_samples = 1\n        # cur_inds = args.val_edge_end\n\n    train_loader = cur_df.groupby(cur_df.index // args.batch_size)\n    pbar = tqdm(total=len(train_loader))\n    pbar.set_description('%s mode with negative samples %d ...'%(mode, neg_samples))        \n        \n    ###################################################\n    # compute + training + fetch all scores\n    loss_lst = []\n    MLAUROC.reset()\n    MLAUPRC.reset()\n    \n    for ind in range(len(train_loader)):\n        ###################################################\n        inputs, subgraph_node_feats, cur_inds = get_inputs_for_ind(subgraphs, mode, cached_neg_samples, neg_samples, node_feats, edge_feats, cur_df, cur_inds, ind, args)\n        \n        start_time = time.time()\n        loss, pred, edge_label = model(inputs, neg_samples, subgraph_node_feats)\n        if mode == 'train' and optimizer != None:\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n        time_epoch += (time.time() - start_time)\n        \n        batch_auroc = MLAUROC.update(pred, edge_label)\n        batch_auprc = MLAUPRC.update(pred, edge_label)\n        loss_lst.append(float(loss.detach()))\n\n        pbar.update(1)\n    pbar.close()    \n    total_auroc = MLAUROC.compute()\n    total_auprc = MLAUPRC.compute()\n    print('%s mode with time %.4f, AUROC %.4f, AUPRC %.4f, loss %.4f'%(mode, time_epoch, total_auroc, total_auprc, loss.item()))\n    return_loss = np.mean(loss_lst)\n    return total_auroc, total_auprc, return_loss, time_epoch\n\n\ndef link_pred_train(model, args, g, df, node_feats, edge_feats):\n    \n    optimizer = 
def link_pred_train(model, args, g, df, node_feats, edge_feats):
    """Train the link-prediction model with early stopping on validation loss;
    returns the best (lowest validation loss) model moved to CPU."""
    optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # pre-computed (or lazily sampled) subgraphs for each split
    if args.use_cached_subgraph:
        train_subgraphs = pre_compute_subgraphs(args, g, df, mode='train')
    else:
        train_subgraphs = get_subgraph_sampler(args, g, df, mode='train')
    valid_subgraphs = pre_compute_subgraphs(args, g, df, mode='valid')

    all_results = {
        'train_ap': [],
        'valid_ap': [],
        'train_auc': [],
        'valid_auc': [],
        'train_loss': [],
        'valid_loss': [],
    }

    low_loss = 100000
    user_train_total_time = 0
    user_epoch_num = 0

    # metric objects are reused across epochs (reset inside run())
    if args.predict_class:
        num_classes = args.num_edgeType + 1
        train_AUROC = MulticlassAUROC(num_classes, average="macro", thresholds=None)
        valid_AUROC = MulticlassAUROC(num_classes, average="macro", thresholds=None)
        train_AUPRC = MulticlassAveragePrecision(num_classes, average="macro", thresholds=None)
        valid_AUPRC = MulticlassAveragePrecision(num_classes, average="macro", thresholds=None)
    else:
        train_AUROC = BinaryAUROC(thresholds=None)
        valid_AUROC = BinaryAUROC(thresholds=None)
        train_AUPRC = BinaryAveragePrecision(thresholds=None)
        valid_AUPRC = BinaryAveragePrecision(thresholds=None)

    for epoch in range(args.epochs):
        print('>>> Epoch ', epoch + 1)
        train_auc, train_ap, train_loss, time_train = run(
            model, optimizer, args, train_subgraphs, df,
            node_feats, edge_feats, train_AUROC, train_AUPRC, mode='train')
        with torch.no_grad():
            # the optimizer argument is only needed for training, hence None
            valid_auc, valid_ap, valid_loss, time_valid = run(
                copy.deepcopy(model), None, args, valid_subgraphs, df,
                node_feats, edge_feats, valid_AUROC, valid_AUPRC, mode='valid')

        if valid_loss < low_loss:
            # snapshot the best model so far (kept on CPU)
            best_auc_model = copy.deepcopy(model).cpu()
            best_auc = valid_auc
            low_loss = valid_loss
            best_epoch = epoch

        user_train_total_time += time_train + time_valid
        user_epoch_num += 1
        # early stopping: patience of 20 epochs after the best one
        if epoch > best_epoch + 20:
            break

        all_results['train_ap'].append(train_ap)
        all_results['valid_ap'].append(valid_ap)
        all_results['valid_auc'].append(valid_auc)
        all_results['train_auc'].append(train_auc)
        all_results['train_loss'].append(train_loss)
        all_results['valid_loss'].append(valid_loss)

    print('best epoch %d, auc score %.4f' % (best_epoch, best_auc))
    return best_auc_model
args.structure_hops == 0:\n            sign_feats = node_feats.clone()\n        else:\n            prev_i = max(0, i - args.structure_time_gap)\n            cur_df = df[prev_i: i] # get adj's row, col indices (as undirected)\n            src = torch.from_numpy(cur_df.src.values)\n            dst = torch.from_numpy(cur_df.dst.values)\n            edge_index = torch.stack([\n                torch.cat([src, dst]), \n                torch.cat([dst, src])\n            ])\n            edge_index, edge_cnt = torch.unique(edge_index, dim=1, return_counts=True) \n            mask = edge_index[0]!=edge_index[1] # ignore self-loops\n            adj = SparseTensor(\n                value = torch.ones_like(edge_cnt[mask]).float(),\n                row = edge_index[0][mask].long(),\n                col = edge_index[1][mask].long(),\n                sparse_sizes=(num_nodes, num_nodes)\n            )\n            adj_norm = row_norm(adj).to(args.device)\n            sign_feats = [node_feats]\n            for _ in range(args.structure_hops):\n                sign_feats.append(adj_norm@sign_feats[-1])\n            sign_feats = torch.sum(torch.stack(sign_feats), dim=0)\n\n        output_feats[_root_ind] = sign_feats[root_nodes[_root_ind]]\n\n        i += len(_root_ind) // num_duplicate\n\n    return output_feats\n\n\n\n################################################################################################\n################################################################################################\n################################################################################################\n\n\"\"\"\nSource: STHN torch_encodings\nURL: https://github.com/celi52/STHN/blob/main/torch_encodings.py\n\"\"\"\n\ndef get_emb(sin_inp):\n    \"\"\"\n    Gets a base embedding for one dimension with sin and cos intertwined\n    \"\"\"\n    emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)\n    return torch.flatten(emb, -2, -1)\n\n\nclass PositionalEncoding1D(nn.Module):\n 
   def __init__(self, channels):\n        \"\"\"\n        :param channels: The last dimension of the tensor you want to apply pos emb to.\n        \"\"\"\n        super(PositionalEncoding1D, self).__init__()\n        self.org_channels = channels\n        channels = int(np.ceil(channels / 2) * 2)\n        self.channels = channels\n        inv_freq = 1.0 / (1000 ** (torch.arange(0, channels, 2).float() / channels))\n        self.register_buffer(\"inv_freq\", inv_freq)\n        self.cached_penc = None\n\n    def forward(self, tensor):\n        \"\"\"\n        :param tensor: A 3d tensor of size (batch_size, x, ch)\n        :return: Positional Encoding Matrix of size (batch_size, x, ch)\n        \"\"\"\n        if len(tensor.shape) != 3:\n            raise RuntimeError(\"The input tensor has to be 3d!\")\n\n        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:\n            return self.cached_penc\n\n        self.cached_penc = None\n        batch_size, x, orig_ch = tensor.shape\n        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())\n        sin_inp_x = torch.einsum(\"i,j->ij\", pos_x, self.inv_freq)\n        emb_x = get_emb(sin_inp_x)\n        emb = torch.zeros((x, self.channels), device=tensor.device).type(tensor.type())\n        emb[:, : self.channels] = emb_x\n\n        self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1)\n        return self.cached_penc\n\n\nclass PositionalEncodingPermute1D(nn.Module):\n    def __init__(self, channels):\n        \"\"\"\n        Accepts (batchsize, ch, x) instead of (batchsize, x, ch)\n        \"\"\"\n        super(PositionalEncodingPermute1D, self).__init__()\n        self.penc = PositionalEncoding1D(channels)\n\n    def forward(self, tensor):\n        tensor = tensor.permute(0, 2, 1)\n        enc = self.penc(tensor)\n        return enc.permute(0, 2, 1)\n\n    @property\n    def org_channels(self):\n        return self.penc.org_channels\n\n\nclass 
PositionalEncoding2D(nn.Module):\n    def __init__(self, channels):\n        \"\"\"\n        :param channels: The last dimension of the tensor you want to apply pos emb to.\n        \"\"\"\n        super(PositionalEncoding2D, self).__init__()\n        self.org_channels = channels\n        channels = int(np.ceil(channels / 4) * 2)\n        self.channels = channels\n        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))\n        self.register_buffer(\"inv_freq\", inv_freq)\n        self.cached_penc = None\n\n    def forward(self, tensor):\n        \"\"\"\n        :param tensor: A 4d tensor of size (batch_size, x, y, ch)\n        :return: Positional Encoding Matrix of size (batch_size, x, y, ch)\n        \"\"\"\n        if len(tensor.shape) != 4:\n            raise RuntimeError(\"The input tensor has to be 4d!\")\n\n        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:\n            return self.cached_penc\n\n        self.cached_penc = None\n        batch_size, x, y, orig_ch = tensor.shape\n        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())\n        pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())\n        sin_inp_x = torch.einsum(\"i,j->ij\", pos_x, self.inv_freq)\n        sin_inp_y = torch.einsum(\"i,j->ij\", pos_y, self.inv_freq)\n        emb_x = get_emb(sin_inp_x).unsqueeze(1)\n        emb_y = get_emb(sin_inp_y)\n        emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type(\n            tensor.type()\n        )\n        emb[:, :, : self.channels] = emb_x\n        emb[:, :, self.channels : 2 * self.channels] = emb_y\n\n        self.cached_penc = emb[None, :, :, :orig_ch].repeat(tensor.shape[0], 1, 1, 1)\n        return self.cached_penc\n\n\nclass PositionalEncodingPermute2D(nn.Module):\n    def __init__(self, channels):\n        \"\"\"\n        Accepts (batchsize, ch, x, y) instead of (batchsize, x, y, ch)\n        \"\"\"\n        
class PositionalEncoding3D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding3D, self).__init__()
        self.org_channels = channels
        # a third of the channels per axis, forced to an even count
        channels = int(np.ceil(channels / 6) * 2)
        if channels % 2:
            channels += 1
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_penc = None

    def forward(self, tensor):
        """
        :param tensor: A 5d tensor of size (batch_size, x, y, z, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch)
        """
        if len(tensor.shape) != 5:
            raise RuntimeError("The input tensor has to be 5d!")

        # reuse the cached encoding while the input shape is unchanged
        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            return self.cached_penc

        self.cached_penc = None
        batch_size, dim_x, dim_y, dim_z, orig_ch = tensor.shape
        pos_x = torch.arange(dim_x, device=tensor.device).type(self.inv_freq.type())
        pos_y = torch.arange(dim_y, device=tensor.device).type(self.inv_freq.type())
        pos_z = torch.arange(dim_z, device=tensor.device).type(self.inv_freq.type())
        emb_x = get_emb(torch.einsum("i,j->ij", pos_x, self.inv_freq)).unsqueeze(1).unsqueeze(1)
        emb_y = get_emb(torch.einsum("i,j->ij", pos_y, self.inv_freq)).unsqueeze(1)
        emb_z = get_emb(torch.einsum("i,j->ij", pos_z, self.inv_freq))
        emb = torch.zeros((dim_x, dim_y, dim_z, self.channels * 3), device=tensor.device).type(
            tensor.type()
        )
        emb[:, :, :, : self.channels] = emb_x
        emb[:, :, :, self.channels : 2 * self.channels] = emb_y
        emb[:, :, :, 2 * self.channels :] = emb_z

        self.cached_penc = emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1)
        return self.cached_penc


class PositionalEncodingPermute3D(nn.Module):
    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch)
        """
        super(PositionalEncodingPermute3D, self).__init__()
        self.penc = PositionalEncoding3D(channels)

    def forward(self, tensor):
        # move channels last, encode, move channels back
        return self.penc(tensor.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)

    @property
    def org_channels(self):
        return self.penc.org_channels


class Summer(nn.Module):
    def __init__(self, penc):
        """
        :param model: The type of positional encoding to run the summer on.
        """
        super(Summer, self).__init__()
        self.penc = penc

    def forward(self, tensor):
        """
        :param tensor: A 3, 4 or 5d tensor that matches the model output size
        :return: Positional Encoding Matrix summed to the original tensor
        """
        encoded = self.penc(tensor)
        assert (
            tensor.size() == encoded.size()
        ), "The original tensor size {} and the positional encoding tensor size {} must match!".format(
            tensor.size(), encoded.size()
        )
        return tensor + encoded
1-->time_dims\n    out = cos(out)\n    \"\"\"\n    def __init__(self, dim):\n        super(TimeEncode, self).__init__()\n        self.dim = dim\n        self.w = nn.Linear(1, dim)\n        self.reset_parameters()\n    \n    def reset_parameters(self, ):\n        self.w.weight = nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, self.dim, dtype=np.float32))).reshape(self.dim, -1))\n        self.w.bias = nn.Parameter(torch.zeros(self.dim))\n\n        self.w.weight.requires_grad = False\n        self.w.bias.requires_grad = False\n    \n    @torch.no_grad()\n    def forward(self, t):\n        output = torch.cos(self.w(t.reshape((-1, 1))))\n        return output\n\n\n\n################################################################################################\n################################################################################################\n################################################################################################\n\"\"\"\nModule: STHN\n\"\"\"\n\nclass FeedForward(nn.Module):\n    \"\"\"\n    2-layer MLP with GeLU (fancy version of ReLU) as activation\n    \"\"\"\n    def __init__(self, dims, expansion_factor, dropout=0, use_single_layer=False):\n        super().__init__()\n\n        self.dims = dims\n        self.use_single_layer = use_single_layer\n        \n        self.expansion_factor = expansion_factor\n        self.dropout = dropout\n\n        if use_single_layer:\n            self.linear_0 = nn.Linear(dims, dims)\n        else:\n            self.linear_0 = nn.Linear(dims, int(expansion_factor * dims))\n            self.linear_1 = nn.Linear(int(expansion_factor * dims), dims)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.linear_0.reset_parameters()\n        if self.use_single_layer==False:\n            self.linear_1.reset_parameters()\n\n    def forward(self, x):\n        x = self.linear_0(x)\n        x = F.gelu(x)\n        x = F.dropout(x, p=self.dropout, 
training=self.training)\n        \n        if self.use_single_layer==False:\n            x = self.linear_1(x)\n            x = F.dropout(x, p=self.dropout, training=self.training)\n        return x\n\n\nclass TransformerBlock(nn.Module):\n    \"\"\"\n    out = X.T + MLP_Layernorm(X.T)     # apply token mixing\n    out = out.T + MLP_Layernorm(out.T) # apply channel mixing\n    \"\"\"\n    def __init__(self, dims, \n                 channel_expansion_factor=4, \n                 dropout=0.2, \n                 module_spec=None, use_single_layer=False):\n        super().__init__()\n        \n        if module_spec == None:\n            self.module_spec = ['token', 'channel']\n        else:\n            self.module_spec = module_spec.split('+')\n\n        self.dims = dims\n        if 'token' in self.module_spec:\n            self.transformer_encoder = _MultiheadAttention(d_model=dims, \n                                                           n_heads=2,\n                                                           d_k=None,\n                                                           d_v=None,\n                                                           attn_dropout=dropout)\n        if 'channel' in self.module_spec:\n            self.channel_layernorm = nn.LayerNorm(dims)\n            self.channel_forward = FeedForward(dims, channel_expansion_factor, dropout, use_single_layer)\n        \n    def reset_parameters(self):\n        if 'token' in self.module_spec:\n            self.transformer_encoder.reset_parameters()\n        if 'channel' in self.module_spec:\n            self.channel_layernorm.reset_parameters()\n            self.channel_forward.reset_parameters()\n        \n    def token_mixer(self, x):\n        x = self.transformer_encoder(x, x, x)\n        return x\n    \n    def channel_mixer(self, x):\n        x = self.channel_layernorm(x)\n        x = self.channel_forward(x)\n        return x\n\n    def forward(self, x):\n        if 'token' in self.module_spec:\n   
         x = x + self.token_mixer(x)\n        if 'channel' in self.module_spec:\n            x = x + self.channel_mixer(x)\n        return x\n\n\nclass _MultiheadAttention(nn.Module):\n    def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False):\n        \"\"\"Multi Head Attention Layer\n        Input shape:\n            Q:       [batch_size (bs) x max_q_len x d_model]\n            K, V:    [batch_size (bs) x q_len x d_model]\n            mask:    [q_len x q_len]\n        \"\"\"\n        super().__init__()\n        d_k = d_model // n_heads if d_k is None else d_k\n        d_v = d_model // n_heads if d_v is None else d_v\n\n        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v\n\n        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)\n        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)\n        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)\n\n        # Scaled Dot-Product Attention (multiple heads)\n        self.res_attention = res_attention\n        self.sdp_attn = _ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout, res_attention=self.res_attention, lsa=lsa)\n\n        # Poject output\n        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(attn_dropout))\n\n    def reset_parameters(self):\n        self.to_out[0].reset_parameters()\n        self.W_Q.reset_parameters()\n        self.W_K.reset_parameters()\n        self.W_V.reset_parameters()\n\n    def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,\n                key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):\n\n        bs = Q.size(0)\n        if K is None: K = Q\n        if V is None: V = Q\n\n        # Linear (+ split in multiple heads)\n        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2)       # q_s    : [bs x n_heads x max_q_len x 
d_k]\n        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1)     # k_s    : [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3)\n        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2)       # v_s    : [bs x n_heads x q_len x d_v]\n\n        # Apply Scaled Dot-Product Attention (multiple heads)\n        output, attn_weights = self.sdp_attn(q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n\n        # back to the original inputs dimensions\n        output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v) # output: [bs x q_len x n_heads * d_v]\n        output = self.to_out(output)\n\n        return output\n\n\nclass _ScaledDotProductAttention(nn.Module):\n    r\"\"\"Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer\n    (Realformer: Transformer likes residual attention by He et al, 2020) and locality self sttention (Vision Transformer for Small-Size Datasets\n    by Lee et al, 2021)\"\"\"\n\n    def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False):\n        super().__init__()\n        self.attn_dropout = nn.Dropout(attn_dropout)\n        self.res_attention = res_attention\n        head_dim = d_model // n_heads\n        self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)\n        self.lsa = lsa\n\n    def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):\n        '''\n        Input shape:\n            q               : [bs x n_heads x max_q_len x d_k]\n            k               : [bs x n_heads x d_k x seq_len]\n            v               : [bs x n_heads x seq_len x d_v]\n            prev            : [bs x n_heads x q_len x seq_len]\n            key_padding_mask: [bs x seq_len]\n            attn_mask       : 
[1 x seq_len x seq_len]\n        Output shape:\n            output:  [bs x n_heads x q_len x d_v]\n            attn   : [bs x n_heads x q_len x seq_len]\n            scores : [bs x n_heads x q_len x seq_len]\n        '''\n\n        # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence\n        attn_scores = torch.matmul(q, k) * self.scale      # attn_scores : [bs x n_heads x max_q_len x q_len]\n\n        # Add pre-softmax attention scores from the previous layer (optional)\n        if prev is not None: attn_scores = attn_scores + prev\n\n        # Attention mask (optional)\n        if attn_mask is not None:                                     # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len\n            if attn_mask.dtype == torch.bool:\n                attn_scores.masked_fill_(attn_mask, -np.inf)\n            else:\n                attn_scores += attn_mask\n\n        # Key padding mask (optional)\n        if key_padding_mask is not None:                              # mask with shape [bs x q_len] (only when max_w_len == q_len)\n            attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)\n\n        # normalize the attention weights\n        attn_weights = F.softmax(attn_scores, dim=-1)                 # attn_weights   : [bs x n_heads x max_q_len x q_len]\n        attn_weights = self.attn_dropout(attn_weights)\n\n        # compute the new values given the attention weights\n        output = torch.matmul(attn_weights, v)                        # output: [bs x n_heads x max_q_len x d_v]\n\n        if self.res_attention: return output, attn_weights, attn_scores\n        else: return output, attn_weights\n\n\n    \nclass FeatEncode(nn.Module):\n    \"\"\"\n    Return [raw_edge_feat | TimeEncode(edge_time_stamp)]\n    \"\"\"\n    def __init__(self, time_dims, feat_dims, out_dims):\n        super().__init__()\n        \n        self.time_encoder = TimeEncode(time_dims)\n        
self.feat_encoder = nn.Linear(time_dims + feat_dims, out_dims) \n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.time_encoder.reset_parameters()\n        self.feat_encoder.reset_parameters()\n        \n    def forward(self, edge_feats, edge_ts):\n        edge_time_feats = self.time_encoder(edge_ts)\n        x = torch.cat([edge_feats, edge_time_feats], dim=1)\n        return self.feat_encoder(x)\n\nclass Patch_Encoding(nn.Module):\n    \"\"\"\n    Input : [ batch_size, graph_size, edge_dims+time_dims]\n    Output: [ batch_size, graph_size, output_dims]\n    \"\"\"\n    def __init__(self, per_graph_size, time_channels,\n                 input_channels, hidden_channels, out_channels,\n                 num_layers, dropout,\n                 channel_expansion_factor,\n                 window_size,\n                 module_spec=None, \n                 use_single_layer=False\n                ):\n        super().__init__()\n        self.per_graph_size = per_graph_size\n        self.dropout = nn.Dropout(dropout)\n        self.num_layers = num_layers\n        \n        # input & output classifer\n        self.feat_encoder = FeatEncode(time_channels, input_channels, hidden_channels)\n        self.layernorm = nn.LayerNorm(hidden_channels)\n        self.mlp_head = nn.Linear(hidden_channels, out_channels)\n        \n        # inner layers\n        self.mixer_blocks = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            self.mixer_blocks.append(\n                TransformerBlock(hidden_channels, \n                                 channel_expansion_factor, \n                                 dropout, \n                                 module_spec=None, \n                                 use_single_layer=use_single_layer)\n            )\n        # padding\n        self.stride = window_size\n        self.window_size = window_size\n        self.pad_projector = nn.Linear(window_size*hidden_channels, hidden_channels)\n        
self.p_enc_1d_model_sum = Summer(PositionalEncoding1D(hidden_channels))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for layer in self.mixer_blocks:\n            layer.reset_parameters()\n        self.feat_encoder.reset_parameters()\n        self.layernorm.reset_parameters()\n        self.mlp_head.reset_parameters()\n\n    def forward(self, edge_feats, edge_ts, batch_size, inds):\n        # x : [ batch_size, graph_size, edge_dims+time_dims]\n        edge_time_feats = self.feat_encoder(edge_feats, edge_ts)\n        x = torch.zeros((batch_size * self.per_graph_size, \n                         edge_time_feats.size(1)), device=edge_feats.device)\n        x[inds] = x[inds] + edge_time_feats         \n        x = x. view(-1, self.per_graph_size//self.window_size, self.window_size*x.shape[-1])\n        x = self.pad_projector(x)\n        x = self.p_enc_1d_model_sum(x) \n        for i in range(self.num_layers):\n            # apply to channel + feat dim\n            x = self.mixer_blocks[i](x)    \n        x = self.layernorm(x)\n        x = torch.mean(x, dim=1)\n        x = self.mlp_head(x)\n        return x\n    \n################################################################################################\n################################################################################################\n################################################################################################\n\n\"\"\"\nEdge predictor\n\"\"\"\n\nclass EdgePredictor_per_node(torch.nn.Module):\n    \"\"\"\n    out = linear(src_node_feats) + linear(dst_node_feats)\n    out = ReLU(out)\n    \"\"\"\n    def __init__(self, dim_in_time, dim_in_node, predict_class):\n        super().__init__()\n\n        self.dim_in_time = dim_in_time\n        self.dim_in_node = dim_in_node\n\n        # dim_in_time + dim_in_node\n        self.src_fc = torch.nn.Linear(dim_in_time+dim_in_node, 100)\n        self.dst_fc = torch.nn.Linear(dim_in_time+dim_in_node, 100)\n    \n 
       self.out_fc = torch.nn.Linear(100, predict_class)\n        self.reset_parameters()\n        \n    def reset_parameters(self, ):\n        self.src_fc.reset_parameters()\n        self.dst_fc.reset_parameters()\n        self.out_fc.reset_parameters()\n\n    def forward(self, h, neg_samples=1):\n        num_edge = h.shape[0]//(neg_samples + 2)\n        h_src = self.src_fc(h[:num_edge])\n        h_pos_dst = self.dst_fc(h[num_edge:2 * num_edge])\n        h_neg_dst = self.dst_fc(h[2 * num_edge:])\n        \n        h_pos_edge = torch.nn.functional.relu(h_src + h_pos_dst)\n        h_neg_edge = torch.nn.functional.relu(h_src.tile(neg_samples, 1) + h_neg_dst)\n        \n        return self.out_fc(h_pos_edge), self.out_fc(h_neg_edge)\n    \n    \nclass STHN_Interface(nn.Module):\n    def __init__(self, mlp_mixer_configs, edge_predictor_configs):\n        super(STHN_Interface, self).__init__()\n\n        self.time_feats_dim = edge_predictor_configs['dim_in_time']\n        self.node_feats_dim = edge_predictor_configs['dim_in_node']\n\n        if self.time_feats_dim > 0:\n            self.base_model = Patch_Encoding(**mlp_mixer_configs)\n\n        self.edge_predictor = EdgePredictor_per_node(**edge_predictor_configs)        \n        self.creterion = nn.BCEWithLogitsLoss(reduction='none') \n        self.reset_parameters()            \n\n    def reset_parameters(self):\n        if self.time_feats_dim > 0:\n            self.base_model.reset_parameters()\n        self.edge_predictor.reset_parameters()\n        \n    def forward(self, model_inputs, neg_samples, node_feats):        \n        pred_pos, pred_neg = self.predict(model_inputs, neg_samples, node_feats)\n        all_pred = torch.cat((pred_pos, pred_neg), dim=0)\n        all_edge_label = torch.cat((torch.ones_like(pred_pos), \n                                    torch.zeros_like(pred_neg)), dim=0)\n        loss = self.creterion(all_pred, all_edge_label).mean()\n        return loss, all_pred, all_edge_label\n    \n    
def predict(self, model_inputs, neg_samples, node_feats):\n        if self.time_feats_dim > 0 and self.node_feats_dim == 0:\n            x = self.base_model(*model_inputs)\n        elif self.time_feats_dim > 0 and self.node_feats_dim > 0:\n            x = self.base_model(*model_inputs)\n            x = torch.cat([x, node_feats], dim=1)\n        elif self.time_feats_dim == 0 and self.node_feats_dim > 0:\n            x = node_feats\n        else:\n            print('Either time_feats_dim or node_feats_dim must larger than 0!')\n        \n        pred_pos, pred_neg = self.edge_predictor(x, neg_samples=neg_samples)\n        return pred_pos, pred_neg\n\nclass Multiclass_Interface(nn.Module):\n    def __init__(self, mlp_mixer_configs, edge_predictor_configs):\n        super(Multiclass_Interface, self).__init__()\n\n        self.time_feats_dim = edge_predictor_configs['dim_in_time']\n        self.node_feats_dim = edge_predictor_configs['dim_in_node']\n\n        if self.time_feats_dim > 0:\n            self.base_model = Patch_Encoding(**mlp_mixer_configs)\n\n        self.edge_predictor = EdgePredictor_per_node(**edge_predictor_configs)        \n        self.creterion = nn.CrossEntropyLoss(reduction='none')\n        self.reset_parameters()            \n\n    def reset_parameters(self):\n        if self.time_feats_dim > 0:\n            self.base_model.reset_parameters()\n        self.edge_predictor.reset_parameters()\n        \n    def forward(self, model_inputs, neg_samples, node_feats):        \n        pos_edge_label = model_inputs[-1].view(-1,1)\n        model_inputs = model_inputs[:-1]\n        pred_pos, pred_neg = self.predict(model_inputs, neg_samples, node_feats)\n        \n        all_pred = torch.cat((pred_pos, pred_neg), dim=0)\n        all_edge_label = torch.squeeze(torch.cat((pos_edge_label, torch.zeros_like(pos_edge_label)), dim=0))\n        loss = self.creterion(all_pred, all_edge_label).mean()\n            \n        return loss, all_pred, all_edge_label\n    
\n    def predict(self, model_inputs, neg_samples, node_feats):\n        if self.time_feats_dim > 0 and self.node_feats_dim == 0:\n            x = self.base_model(*model_inputs)\n        elif self.time_feats_dim > 0 and self.node_feats_dim > 0:\n            x = self.base_model(*model_inputs)\n            x = torch.cat([x, node_feats], dim=1)\n        elif self.time_feats_dim == 0 and self.node_feats_dim > 0:\n            x = node_feats\n        else:\n            print('Either time_feats_dim or node_feats_dim must larger than 0!')\n        \n        pred_pos, pred_neg = self.edge_predictor(x, neg_samples=neg_samples)\n        return pred_pos, pred_neg\n\n    "
  },
  {
    "path": "modules/sthn_sampler_setup.py",
    "content": "from glob import glob\nfrom setuptools import setup\nfrom pybind11.setup_helpers import Pybind11Extension\n\next_modules = [\n    Pybind11Extension(\"sampler_core\", \n                      ['sampler_core.cpp'],\n                      extra_compile_args = ['-fopenmp'],\n                      extra_link_args = ['-fopenmp'],),\n]\n\nsetup(\n    name = \"sampler_core\",\n    version = \"0.0.1\",\n    author = \"XXXX-2\",\n    author_email = \"XXXX-3\",\n    url = \"XXXX-4\",\n    description = \"Parallel Sampling for Temporal Graphs\",\n    ext_modules = ext_modules,\n)"
  },
  {
    "path": "modules/time_enc.py",
    "content": "\"\"\"\nTime Encoding Module\n\nReference:\n    - https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/tgn.html\n\"\"\"\n\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Linear\n\n\nclass TimeEncoder(torch.nn.Module):\n    def __init__(self, out_channels: int):\n        super().__init__()\n        self.out_channels = out_channels\n        self.lin = Linear(1, out_channels)\n\n    def reset_parameters(self):\n        self.lin.reset_parameters()\n\n    def forward(self, t: Tensor) -> Tensor:\n        return self.lin(t.view(-1, 1)).cos()\n"
  },
  {
    "path": "modules/timetraveler_agent.py",
    "content": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HUST/TITer/blob/master/model/agent.py\nHaohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nclass HistoryEncoder(nn.Module):\n    def __init__(self, config):\n        super(HistoryEncoder, self).__init__()\n        self.config = config\n        self.lstm_cell = torch.nn.LSTMCell(input_size=config['action_dim'],\n                                           hidden_size=config['state_dim'])\n\n    def set_hiddenx(self, batch_size):\n        \"\"\"Set hidden layer parameters. Initialize to 0\"\"\"\n        if self.config['cuda']:\n            self.hx = torch.zeros(batch_size, self.config['state_dim'], device='cuda')\n            self.cx = torch.zeros(batch_size, self.config['state_dim'], device='cuda')\n        else:\n            self.hx = torch.zeros(batch_size, self.config['state_dim'])\n            self.cx = torch.zeros(batch_size, self.config['state_dim'])\n\n    def forward(self, prev_action, mask):\n        \"\"\"mask: True if NO_OP. 
ON_OP does not affect history coding results\"\"\"\n        self.hx_, self.cx_ = self.lstm_cell(prev_action, (self.hx, self.cx))\n        self.hx = torch.where(mask, self.hx, self.hx_)\n        self.cx = torch.where(mask, self.cx, self.cx_)\n        return self.hx\n\nclass PolicyMLP(nn.Module):\n    def __init__(self, config):\n        super(PolicyMLP, self).__init__()\n        self.mlp_l1= nn.Linear(config['mlp_input_dim'], config['mlp_hidden_dim'], bias=True)\n        self.mlp_l2 = nn.Linear(config['mlp_hidden_dim'], config['action_dim'], bias=True)\n\n    def forward(self, state_query):\n        hidden = torch.relu(self.mlp_l1(state_query))\n        output = self.mlp_l2(hidden).unsqueeze(1)\n        return output\n\nclass DynamicEmbedding(nn.Module):\n    def __init__(self, n_ent, dim_ent, dim_t):\n        super(DynamicEmbedding, self).__init__()\n        self.ent_embs = nn.Embedding(n_ent, dim_ent - dim_t)\n        self.w = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dim_t))).float())\n        self.b = torch.nn.Parameter(torch.zeros(dim_t).float())\n\n    def forward(self, entities, dt):\n        dt = dt.unsqueeze(-1)\n        batch_size = dt.size(0)\n        seq_len = dt.size(1)\n\n        dt = dt.view(batch_size, seq_len, 1)\n        t = torch.cos(self.w.view(1, 1, -1) * dt + self.b.view(1, 1, -1))\n        t = t.squeeze(1)  # [batch_size, time_dim]\n\n        e = self.ent_embs(entities)\n        return torch.cat((e, t), -1)\n\nclass StaticEmbedding(nn.Module):\n    def __init__(self, n_ent, dim_ent):\n        super(StaticEmbedding, self).__init__()\n        self.ent_embs = nn.Embedding(n_ent, dim_ent)\n\n    def forward(self, entities, timestamps=None):\n        return self.ent_embs(entities)\n\nclass Agent(nn.Module):\n    def __init__(self, config):\n        super(Agent, self).__init__()\n        self.num_rel = config['num_rel'] * 2 + 2\n        self.config = config\n\n        # [0, num_rel) -> normal relations; num_rel -> stay in 
place，(num_rel, num_rel * 2] reversed relations.\n        self.NO_OP = self.num_rel  # Stay in place; No Operation\n        self.ePAD = config['num_ent']  # Padding entity\n        self.rPAD = config['num_rel'] * 2 + 1  # Padding relation\n        self.tPAD = 0  # Padding time\n\n        if self.config['entities_embeds_method'] == 'dynamic':\n            self.ent_embs = DynamicEmbedding(config['num_ent']+1, config['ent_dim'], config['time_dim'])\n        else:\n            self.ent_embs = StaticEmbedding(config['num_ent']+1, config['ent_dim'])\n\n        self.rel_embs = nn.Embedding(config['num_ent'], config['rel_dim'])\n\n        self.policy_step = HistoryEncoder(config)\n        self.policy_mlp = PolicyMLP(config)\n\n        self.score_weighted_fc = nn.Linear(\n            self.config['ent_dim'] * 2 + self.config['rel_dim'] * 2 + self.config['state_dim'],\n            1, bias=True)\n\n    def forward(self, prev_relation, current_entities, current_timestamps,\n                query_relation, query_entity, query_timestamps, action_space):\n        \"\"\"\n        Args:\n            prev_relation: [batch_size]\n            current_entities: [batch_size]\n            current_timestamps: [batch_size]\n            query_relation: embeddings of query relation，[batch_size, rel_dim]\n            query_entity: embeddings of query entity, [batch_size, ent_dim]\n            query_timestamps: [batch_size]\n            action_space: [batch_size, max_actions_num, 3] (relations, entities, timestamps)\n        \"\"\"\n        # embeddings\n        current_delta_time = query_timestamps - current_timestamps\n        current_embds = self.ent_embs(current_entities, current_delta_time)  # [batch_size, ent_dim] #dynamic embedding\n        prev_relation_embds = self.rel_embs(prev_relation)  # [batch_size, rel_dim]\n\n        # Pad Mask\n        pad_mask = torch.ones_like(action_space[:, :, 0]) * self.rPAD  # [batch_size, action_number]\n        pad_mask = torch.eq(action_space[:, :, 0], 
pad_mask)  # [batch_size, action_number]\n\n        # History Encode\n        NO_OP_mask = torch.eq(prev_relation, torch.ones_like(prev_relation) * self.NO_OP)  # [batch_size]\n        NO_OP_mask = NO_OP_mask.repeat(self.config['state_dim'], 1).transpose(1, 0)  # [batch_size, state_dim]\n        prev_action_embedding = torch.cat([prev_relation_embds, current_embds], dim=-1)  # [batch_size, rel_dim + ent_dim]\n        lstm_output = self.policy_step(prev_action_embedding, NO_OP_mask)  # [batch_size, state_dim] (5) Path encoding\n\n        # Neighbor/condidate_actions embeddings\n        action_num = action_space.size(1)\n        neighbors_delta_time = query_timestamps.unsqueeze(-1).repeat(1, action_num) - action_space[:, :, 2]\n        neighbors_entities = self.ent_embs(action_space[:, :, 1], neighbors_delta_time)  # [batch_size, action_num, ent_dim]\n        neighbors_relations = self.rel_embs(action_space[:, :, 0])  # [batch_size, action_num, rel_dim]\n\n        # agent state representation\n        agent_state = torch.cat([lstm_output, query_entity, query_relation], dim=-1)  # [batch_size, state_dim + ent_dim + rel_dim]\n        output = self.policy_mlp(agent_state)  # [batch_size, 1, action_dim] action_dim == rel_dim + ent_dim\n\n        # scoring\n        entitis_output = output[:, :, self.config['rel_dim']:]\n        relation_ouput = output[:, :, :self.config['rel_dim']]\n        relation_score = torch.sum(torch.mul(neighbors_relations, relation_ouput), dim=2)\n        entities_score = torch.sum(torch.mul(neighbors_entities, entitis_output), dim=2)  # [batch_size, action_number]\n\n        actions = torch.cat([neighbors_relations, neighbors_entities], dim=-1)  # [batch_size, action_number, action_dim]\n\n        agent_state_repeats = agent_state.unsqueeze(1).repeat(1, actions.shape[1], 1)\n        score_attention_input = torch.cat([actions, agent_state_repeats], dim=-1)\n        a = self.score_weighted_fc(score_attention_input)                                   
# (8)\n        a = torch.sigmoid(a).squeeze()                      # [batch_size, action_number]   # (8)\n\n        scores = (1 - a) * relation_score + a * entities_score                              # (6) a= beta\n\n        # Padding mask\n        scores = scores.masked_fill(pad_mask, -1e10)  # [batch_size ,action_number]\n\n        action_prob = torch.softmax(scores, dim=1)\n        action_id = torch.multinomial(action_prob, 1)  # Randomly select an action. [batch_size, 1] # ACTION SELECTION\n\n        logits = torch.nn.functional.log_softmax(scores, dim=1)  # [batch_size, action_number]\n        one_hot = torch.zeros_like(logits).scatter(1, action_id, 1)\n        loss = - torch.sum(torch.mul(logits, one_hot), dim=1)\n        return loss, logits, action_id\n\n    def get_im_embedding(self, cooccurrence_entities):\n        \"\"\"Get the inductive mean representation of the co-occurrence relation.\n        cooccurrence_entities: a list that contains the trained entities with the co-occurrence relation.\n        return: torch.tensor, representation of the co-occurrence entities.\n        \"\"\"\n        entities = self.ent_embs.ent_embs.weight.data[cooccurrence_entities]\n        im = torch.mean(entities, dim=0)\n        return im\n\n    def update_entity_embedding(self, entity, ims, mu):\n        \"\"\"Update the entity representation with the co-occurrence relations in the last timestamp.\n        entity: int, the entity that needs to be updated.\n        ims: torch.tensor, [number of co-occurrence, -1], the IM representations of the co-occurrence relations\n        mu: update ratio, the hyperparam.\n        \"\"\"\n        self.source_entity = self.ent_embs.ent_embs.weight.data[entity]\n        self.ent_embs.ent_embs.weight.data[entity] = mu * self.source_entity + (1 - mu) * torch.mean(ims, dim=0)\n\n    def entities_embedding_shift(self, entity, im, mu):\n        \"\"\"Prediction shift.\"\"\"\n        self.source_entity = 
self.ent_embs.ent_embs.weight.data[entity]\n        self.ent_embs.ent_embs.weight.data[entity] = mu * self.source_entity + (1 - mu) * im\n\n    def back_entities_embedding(self, entity):\n        \"\"\"Go back after shift ends.\"\"\"\n        self.ent_embs.ent_embs.weight.data[entity] = self.source_entity"
  },
  {
    "path": "modules/timetraveler_dirichlet.py",
    "content": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HUST/TITer/blob/master/model/dirichlet.py\nHaohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021\n\"\"\"\n\n\"\"\"Dirichlet.py\nMaximum likelihood estimation and likelihood ratio tests of Dirichlet\ndistribution models of data.\nMost of this package is a port of Thomas P. Minka's wonderful Fastfit MATLAB\ncode. Much thanks to him for that and his clear paper \"Estimating a Dirichlet\ndistribution\". See the following URL for more information:\n    http://research.microsoft.com/en-us/um/people/minka/\"\"\"\n\nimport sys\n\nimport numpy as np\nimport scipy as sp\nimport scipy.stats as stats\nfrom scipy.stats import dirichlet\nfrom tqdm import tqdm\nfrom numpy import (\n    arange,\n    array,\n    asanyarray,\n    asarray,\n    diag,\n    exp,\n    isscalar,\n    log,\n    ndarray,\n    ones,\n    vstack,\n    zeros,\n)\nfrom numpy.linalg import norm\nfrom scipy.special import gammaln, polygamma, psi\n\nMAXINT = sys.maxsize\n\n__all__ = [\n    \"loglikelihood\",\n    \"meanprecision\",\n    \"mle\",\n    \"pdf\",\n    \"test\",\n]\n\neuler = -1 * psi(1)  # Euler-Mascheroni constant\n\n\nclass NotConvergingError(Exception):\n    \"\"\"Error when a successive approximation method doesn't converge\n    \"\"\"\n    pass\n\n\ndef test(D1, D2, method=\"meanprecision\", maxiter=None):\n    \"\"\"Test for statistical difference between observed proportions.\n    Parameters\n    ----------\n    D1 : (N1, K) shape array\n    D2 : (N2, K) shape array\n        Input observations. ``N1`` and ``N2`` are the number of observations,\n        and ``K`` is the number of parameters for the Dirichlet distribution\n        (i.e. 
the number of levels or categorical possibilities).\n        Each cell is the proportion seen in that category for a particular\n        observation. Rows of the matrices must add up to 1.\n    method : string\n        One of ``'fixedpoint'`` and ``'meanprecision'``, designates method by\n        which to find MLE Dirichlet distribution. Default is\n        ``'meanprecision'``, which is faster.\n    maxiter : int\n        Maximum number of iterations to take calculations. Default is\n        ``sys.maxint``.\n    Returns\n    -------\n    D : float\n        Test statistic, which is ``-2 * log`` of likelihood ratios.\n    p : float\n        p-value of test.\n    a0 : (K,) shape array\n    a1 : (K,) shape array\n    a2 : (K,) shape array\n        MLE parameters for the Dirichlet distributions fit to\n        ``D1`` and ``D2`` together, ``D1``, and ``D2``, respectively.\"\"\"\n\n    N1, K1 = D1.shape\n    N2, K2 = D2.shape\n    if K1 != K2:\n        raise ValueError(\"D1 and D2 must have the same number of columns\")\n\n    D0 = vstack((D1, D2))\n    a0 = mle(D0, method=method, maxiter=maxiter)\n    a1 = mle(D1, method=method, maxiter=maxiter)\n    a2 = mle(D2, method=method, maxiter=maxiter)\n\n    D = 2 * (loglikelihood(D1, a1) + loglikelihood(D2, a2) - loglikelihood(D0, a0))\n    return (D, stats.chi2.sf(D, K1), a0, a1, a2)\n\n\ndef pdf(alphas):\n    \"\"\"Returns a Dirichlet PDF function\n    Parameters\n    ----------\n    alphas : (K,) shape array\n        The parameters for the distribution of shape ``(K,)``.\n    Returns\n    -------\n    function\n        The PDF function, takes an ``(N, K)`` shape input and gives an\n        ``(N,)`` output.\n    \"\"\"\n    alphap = alphas - 1\n    c = np.exp(gammaln(alphas.sum()) - gammaln(alphas).sum())\n\n    def dirichlet(xs):\n        \"\"\"Dirichlet PDF\n        Parameters\n        ----------\n        xs : (N, K) shape array\n            The ``(N, K)`` shape input matrix\n\n        Returns\n        -------\n        
(N,) shape array\n            Point value for PDF\n        \"\"\"\n        return c * (xs ** alphap).prod(axis=1)\n\n    return dirichlet\n\n\ndef meanprecision(a):\n    \"\"\"Mean and precision of a Dirichlet distribution.\n    Parameters\n    ----------\n    a : (K,) shape array\n        Parameters of a Dirichlet distribution.\n    Returns\n    -------\n    mean : (K,) shape array\n        Means of the Dirichlet distribution. Values are in [0,1].\n    precision : float\n        Precision or concentration parameter of the Dirichlet distribution.\"\"\"\n\n    s = a.sum()\n    m = a / s\n    return (m, s)\n\n\ndef loglikelihood(D, a):\n    \"\"\"Compute log likelihood of Dirichlet distribution, i.e. log p(D|a).\n    Parameters\n    ----------\n    D : (N, K) shape array\n        ``N`` is the number of observations, ``K`` is the number of\n        parameters for the Dirichlet distribution.\n    a : (K,) shape array\n        Parameters for the Dirichlet distribution.\n    Returns\n    -------\n    logl : float\n        The log likelihood of the Dirichlet distribution\"\"\"\n    N, K = D.shape\n    logp = log(D).mean(axis=0)\n    return N * (gammaln(a.sum()) - gammaln(a).sum() + ((a - 1) * logp).sum())\n\n\ndef mle(D, tol=1e-7, method=\"meanprecision\", maxiter=None):\n    \"\"\"Iteratively computes maximum likelihood Dirichlet distribution\n    for an observed data set, i.e. a for which log p(D|a) is maximum.\n    Parameters\n    ----------\n    D : (N, K) shape array\n        ``N`` is the number of observations, ``K`` is the number of\n        parameters for the Dirichlet distribution.\n    tol : float\n        If Euclidean distance between successive parameter arrays is less than\n        ``tol``, calculation is taken to have converged.\n    method : string\n        One of ``'fixedpoint'`` and ``'meanprecision'``, designates method by\n        which to find MLE Dirichlet distribution. 
Default is\n        ``'meanprecision'``, which is faster.\n    maxiter : int\n        Maximum number of iterations to take calculations. Default is\n        ``sys.maxint``.\n    Returns\n    -------\n    a : (K,) shape array\n        Maximum likelihood parameters for Dirichlet distribution.\"\"\"\n\n    if method == \"meanprecision\":\n        return _meanprecision(D, tol=tol, maxiter=maxiter)\n    else:\n        return _fixedpoint(D, tol=tol, maxiter=maxiter)\n\n\ndef _fixedpoint(D, tol=1e-7, maxiter=None):\n    \"\"\"Simple fixed point iteration method for MLE of Dirichlet distribution\n    Parameters\n    ----------\n    D : (N, K) shape array\n        ``N`` is the number of observations, ``K`` is the number of\n        parameters for the Dirichlet distribution.\n    tol : float\n        If Euclidean distance between successive parameter arrays is less than\n        ``tol``, calculation is taken to have converged.\n    maxiter : int\n        Maximum number of iterations to take calculations. 
Default is\n        ``sys.maxint``.\n    Returns\n    -------\n    a : (K,) shape array\n        Fixed-point estimated parameters for Dirichlet distribution.\"\"\"\n    logp = log(D).mean(axis=0)\n    a0 = _init_a(D)\n\n    # Start updating\n    if maxiter is None:\n        maxiter = MAXINT\n    for i in range(maxiter):\n        a1 = _ipsi(psi(a0.sum()) + logp)\n        # Much faster convergence than with the more obvious condition\n        # `norm(a1-a0) < tol`\n        if abs(loglikelihood(D, a1) - loglikelihood(D, a0)) < tol:\n            return a1\n        a0 = a1\n    raise NotConvergingError(\n        \"Failed to converge after {} iterations, values are {}.\".format(maxiter, a1)\n    )\n\n\ndef _meanprecision(D, tol=1e-7, maxiter=None):\n    \"\"\"Mean/precision method for MLE of Dirichlet distribution\n    Uses alternating estimations of mean and precision.\n    Parameters\n    ----------\n    D : (N, K) shape array\n        ``N`` is the number of observations, ``K`` is the number of\n        parameters for the Dirichlet distribution.\n    tol : float\n        If Euclidean distance between successive parameter arrays is less than\n        ``tol``, calculation is taken to have converged.\n    maxiter : int\n        Maximum number of iterations to take calculations. 
Default is\n        ``sys.maxint``.\n    Returns\n    -------\n    a : (K,) shape array\n        Estimated parameters for Dirichlet distribution.\"\"\"\n    D = D + 1e-9\n    logp = log(D).mean(axis=0)\n    a0 = _init_a(D)\n    s0 = a0.sum()\n    if s0 < 0:\n        a0 = a0 / s0\n        s0 = 1\n    elif s0 == 0:\n        a0 = ones(a0.shape) / len(a0)\n        s0 = 1\n    m0 = a0 / s0\n\n    # Start updating\n    if maxiter is None:\n        maxiter = MAXINT\n    for i in range(maxiter):\n        a1 = _fit_s(D, a0, logp, tol=tol)\n        s1 = sum(a1)\n        a1 = _fit_m(D, a1, logp, tol=tol)\n        m = a1 / s1\n        # Much faster convergence than with the more obvious condition\n        # `norm(a1-a0) < tol`\n        if abs(loglikelihood(D, a1) - loglikelihood(D, a0)) < tol:\n            return a1\n        a0 = a1\n    return a1\n    # raise NotConvergingError(\n    #     f\"Failed to converge after {maxiter} iterations, \" f\"values are {a1}.\"\n    # )\n\n\ndef _fit_s(D, a0, logp, tol=1e-7, maxiter=1000):\n    \"\"\"Update parameters via MLE of precision with fixed mean\n    Parameters\n    ----------\n    D : (N, K) shape array\n        ``N`` is the number of observations, ``K`` is the number of\n        parameters for the Dirichlet distribution.\n    a0 : (K,) shape array\n        Current parameters for Dirichlet distribution\n    logp : (K,) shape array\n        Mean of log-transformed D across N observations\n    tol : float\n        If Euclidean distance between successive parameter arrays is less than\n        ``tol``, calculation is taken to have converged.\n    maxiter : int\n        Maximum number of iterations to take calculations. 
Default is 1000.\n    Returns\n    -------\n    (K,) shape array\n        Updated parameters for Dirichlet distribution.\"\"\"\n    s1 = a0.sum()\n    m = a0 / s1\n    mlogp = (m * logp).sum()\n    for i in range(maxiter):\n        s0 = s1\n        g = psi(s1) - (m * psi(s1 * m)).sum() + mlogp\n        h = _trigamma(s1) - ((m ** 2) * _trigamma(s1 * m)).sum()\n\n        if g + s1 * h < 0:\n            s1 = 1 / (1 / s0 + g / h / (s0 ** 2))\n        if s1 <= 0:\n            s1 = s0 * exp(-g / (s0 * h + g))  # Newton on log s\n        if s1 <= 0:\n            s1 = 1 / (1 / s0 + g / ((s0 ** 2) * h + 2 * s0 * g))  # Newton on 1/s\n        if s1 <= 0:\n            s1 = s0 - g / h  # Newton\n        if s1 <= 0:\n            raise NotConvergingError(f\"Unable to update s from {s0}\")\n\n        a = s1 * m\n        if abs(s1 - s0) < tol:\n            return a\n\n    return a\n    # raise NotConvergingError(f\"Failed to converge after {maxiter} iterations, \" f\"s is {s1}\")\n\n\ndef _fit_m(D, a0, logp, tol=1e-7, maxiter=1000):\n    \"\"\"Update parameters via MLE of mean with fixed precision s\n    Parameters\n    ----------\n    D : (N, K) shape array\n        ``N`` is the number of observations, ``K`` is the number of\n        parameters for the Dirichlet distribution.\n    a0 : (K,) shape array\n        Current parameters for Dirichlet distribution\n    logp : (K,) shape array\n        Mean of log-transformed D across N observations\n    tol : float\n        If Euclidean distance between successive parameter arrays is less than\n        ``tol``, calculation is taken to have converged.\n    maxiter : int\n        Maximum number of iterations to take calculations. 
Default is 1000.\n    Returns\n    -------\n    (K,) shape array\n        Updated parameters for Dirichlet distribution.\"\"\"\n    s = a0.sum()\n    for i in range(maxiter):\n        m = a0 / s\n        a1 = _ipsi(logp + (m * (psi(a0) - logp)).sum())\n        a1 = a1 / a1.sum() * s\n\n        if norm(a1 - a0) < tol:\n            return a1\n        a0 = a1\n    return a1\n    # raise NotConvergingError(f\"Failed to converge after {maxiter} iterations, \" f\"s is {s}\")\n\n\ndef _init_a(D):\n    \"\"\"Initial guess for Dirichlet alpha parameters given data D\n    Parameters\n    ----------\n    D : (N, K) shape array\n        ``N`` is the number of observations, ``K`` is the number of\n        parameters for the Dirichlet distribution.\n    Returns\n    -------\n    (K,) shape array\n        Crude guess for parameters of Dirichlet distribution.\"\"\"\n    E = D.mean(axis=0)\n    E2 = (D ** 2).mean(axis=0)\n    return ((E[0] - E2[0]) / ((E2[0] - E[0] ** 2) + 1e-9 ) * E)\n\n\ndef _ipsi(y, tol=1.48e-9, maxiter=10):\n    \"\"\"Inverse of psi (digamma) using Newton's method. For the purposes\n    of Dirichlet MLE, since the parameters a[i] must always\n    satisfy a > 0, we define ipsi :: R -> (0,inf).\n\n    Parameters\n    ----------\n    y : (K,) shape array\n        y-values of psi(x)\n    tol : float\n        If Euclidean distance between successive parameter arrays is less than\n        ``tol``, calculation is taken to have converged.\n    maxiter : int\n        Maximum number of iterations to take calculations. 
Default is 10.\n    Returns\n    -------\n    (K,) shape array\n        Approximate x for psi(x).\"\"\"\n    y = asanyarray(y, dtype=\"float\")\n    x0 = np.piecewise(\n        y,\n        [y >= -2.22, y < -2.22],\n        [(lambda x: exp(x) + 0.5), (lambda x: -1 / (x + euler))],\n    )\n    for i in range(maxiter):\n        x1 = x0 - (psi(x0) - y) / _trigamma(x0)\n        if norm(x1 - x0) < tol:\n            return x1\n        x0 = x1\n    return x1\n    # raise NotConvergingError(f\"Failed to converge after {maxiter} iterations, \" f\"value is {x1}\")\n\n\ndef _trigamma(x):\n    return polygamma(1, x)\n\n\nclass MLE_Dirchlet(object):\n    def __init__(self, trainQuads, num_r, k, timespan,\n                 tol=1e-7, method=\"meanprecision\", maxiter=10000):\n        \"\"\"\n        num_r:int,  number of relations.\n        k:int, statistics recent K historical snapshots.\n        timespan:int, 24 for ICEWS, 1 for WIKI and YAGO\n        tol : float, If Euclidean distance between successive parameter arrays is less than\n        ``tol``, calculation is taken to have converged.\n        method : string, One of ``'fixedpoint'`` and ``'meanprecision'``, designates method by\n        which to find MLE Dirichlet distribution. Default is ``'meanprecision'``, which is faster.\n        maxiter : int, Maximum number of iterations to take calculations. 
Default is ``sys.maxint``.\n        \"\"\"\n        self.num_r = num_r\n        self.k = k\n        self.timespan = timespan\n        self.tol = tol\n        self.method = method\n        self.maxiter = maxiter\n        self.entity_occ_times = self.get_entity_occ_times(trainQuads) # The number of occurrences of the entity at each time in the training set\n        self.relations_observed_data = self.get_relations_observed_data(trainQuads)\n        self.alphas = self.mle_dirchlet()\n\n    def get_entity_occ_times(self, trainQuads):\n        entity_occ_times = {}  # key -> entity, value -> dict [key: time, value: times]\n        for quad in trainQuads:\n            for entity in [quad[0], quad[2]]:\n                if entity in entity_occ_times.keys():\n                    if quad[3] in entity_occ_times[entity].keys():\n                        entity_occ_times[entity][quad[3]] += 1\n                    else:\n                        entity_occ_times[entity][quad[3]] = 1\n                else:\n                    entity_occ_times[entity] = {quad[3]: 1, }\n        return entity_occ_times\n\n    def get_relations_observed_data(self, trainQuads):\n        relations_observed_data = {}  # key: relation, value: list of observed data\n        for quad in trainQuads:\n            if quad[1] not in relations_observed_data.keys():\n                relations_observed_data[quad[1]] = []\n            observed = np.zeros([self.k+1])\n            occ_times = self.entity_occ_times[quad[2]]\n            for time in occ_times.keys():\n                if time >= quad[3]:\n                    continue\n                observed[(quad[3] - time) // self.timespan] = occ_times[time]\n            relations_observed_data[quad[1]].append(observed)\n\n            # reversed_r = quad[1] + 1 + self.num_r\n            # if reversed_r not in relations_observed_data.keys():\n            #     relations_observed_data[reversed_r] = []\n            # reversed_r_observed = np.zeros([self.k+1])\n          
  # occ_times = self.entity_occ_times[quad[0]]\n            # for time in occ_times.keys():\n            #     if time >= quad[3]:\n            #         continue\n            #     reversed_r_observed[(quad[3] - time) // self.timespan] = occ_times[time]\n            # relations_observed_data[reversed_r].append(reversed_r_observed)\n        return relations_observed_data\n\n    def mle_dirchlet(self):\n        alphas = {}  # key: relation, value: alpha array\n        with tqdm(total=len(self.relations_observed_data)) as bar:\n            for r, observed in self.relations_observed_data.items():\n                alphas[r] = mle(np.array(observed), tol=self.tol, method=self.method, maxiter=self.maxiter)\n                bar.update(1)\n        return alphas\n\n\nclass Dirichlet(object):\n    def __init__(self, alphas, k):\n        \"\"\"alphas: Get from MLE_Dirchlet\n        k: int, statistics recent K historical snapshots.\n        \"\"\"\n        self.k = k\n        self.distributions = {}\n        for rel, alpha in alphas.items():\n            self.distributions[rel] = dirichlet(alpha)\n\n    def __call__(self, rel, dt):\n        if dt >= self.k:\n            return 0.0\n        p_dt = self.distributions[rel].rvs(1)[0][dt]\n        return p_dt"
  },
  {
    "path": "modules/timetraveler_environment.py",
    "content": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HUST/TITer/blob/master/model/environment.py\nHaohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021\n\"\"\"\n\nimport networkx as nx\nfrom collections import defaultdict\nimport numpy as np\nimport torch\n\nclass Env(object):\n    def __init__(self, examples, config, state_action_space=None):\n        \"\"\"Temporal Knowledge Graph Environment.\n        examples: quadruples (subject, relation, object, timestamps);\n        config: config dict;\n        state_action_space: Pre-processed action space;\n        \"\"\"\n        self.config = config\n        self.num_rel = config['num_rel']\n        self.graph, self.label2nodes = self.build_graph(examples)\n        # [0, num_rel) -> normal relations; num_rel -> stay in place，(num_rel, num_rel * 2] reversed relations.\n        self.NO_OP = self.num_rel  # Stay in place; No Operation\n        self.ePAD = config['num_ent']  # Padding entity\n        self.rPAD = config['num_rel'] * 2 # + 1  # Padding relation.\n        self.tPAD = 0  # Padding time\n        self.state_action_space = state_action_space  # Pre-processed action space\n        if state_action_space:\n            self.state_action_space_key = self.state_action_space.keys()\n\n    def build_graph(self, examples):\n        \"\"\"The graph node is represented as (entity, time), and the edges are directed and labeled relation.\n        return:\n            graph: nx.MultiDiGraph;\n            label2nodes: a dict [keys -> entities, value-> nodes in the graph (entity, time)]\n        \"\"\"\n        graph = nx.MultiDiGraph()\n        label2nodes = defaultdict(set)\n        examples.sort(key=lambda x: x[3], reverse=True)  # Reverse chronological order\n        for example in examples:\n            src = example[0]\n            rel = 
example[1]\n            dst = example[2]\n            time = example[3]\n\n            # Add the nodes and edges of the current quadruple\n            src_node = (src, time)\n            dst_node = (dst, time)\n            if src_node not in label2nodes[src]:\n                graph.add_node(src_node, label=src)\n            if dst_node not in label2nodes[dst]:\n                graph.add_node(dst_node, label=dst)\n\n            graph.add_edge(src_node, dst_node, relation=rel)\n            # graph.add_edge(dst_node, src_node, relation=rel+self.num_rel+1) #REMOVED by JULIA \n\n            label2nodes[src].add(src_node)\n            label2nodes[dst].add(dst_node)\n        return graph, label2nodes\n\n    def get_state_actions_space_complete(self, entity, time, current_=False, max_action_num=None):\n        \"\"\"Get the action space of the current state.\n        Args:\n            entity: The entity of the current state;\n            time: Maximum timestamp for candidate actions;\n            current_: Can the current time of the event be used;\n            max_action_num: Maximum number of events stored;\n        Return:\n            numpy array，shape: [number of events，3], (relation, dst, time)\n        \"\"\"\n        if self.state_action_space:\n            if (entity, time, current_) in self.state_action_space_key:\n                return self.state_action_space[(entity, time, current_)]\n        nodes = self.label2nodes[entity].copy()\n        if current_:\n            # Delete future events, you can see current events, before query time\n            nodes = list(filter((lambda x: x[1] <= time), nodes))\n        else:\n            # No future events, no current events\n            nodes = list(filter((lambda x: x[1] < time), nodes))\n        nodes.sort(key=lambda x: x[1], reverse=True)\n        actions_space = []\n        i = 0\n        for node in nodes:\n            for src, dst, rel in self.graph.out_edges(node, data=True):\n                
actions_space.append((rel['relation'], dst[0], dst[1]))\n                i += 1\n                if max_action_num and i >= max_action_num:\n                    break\n            if max_action_num and i >= max_action_num:\n                break\n        return np.array(list(actions_space), dtype=np.dtype('int32'))\n\n    def next_actions(self, entites, times, query_times, max_action_num=200, first_step=False):\n        \"\"\"Get the current action space. There must be an action that stays at the current position in the action space.\n        Args:\n            entites: torch.tensor, shape: [batch_size], the entity where the agent is currently located;\n            times: torch.tensor, shape: [batch_size], the timestamp of the current entity;\n            query_times: torch.tensor, shape: [batch_size], the timestamp of query;\n            max_action_num: The size of the action space;\n            first_step: Is it the first step for the agent.\n        Return: torch.tensor, shape: [batch_size, max_action_num, 3], (relation, entity, time)\n        \"\"\"\n        if self.config['cuda']:\n            entites = entites.cpu()\n            times = times.cpu()\n            query_times = times.cpu()\n\n        entites = entites.numpy()\n        times = times.numpy()\n        query_times = query_times.numpy()\n\n        actions = self.get_padd_actions(entites, times, query_times, max_action_num, first_step)\n\n        if self.config['cuda']:\n            actions = torch.tensor(actions, dtype=torch.long, device='cuda')\n        else:\n            actions = torch.tensor(actions, dtype=torch.long)\n        return actions\n\n    def get_padd_actions(self, entites, times, query_times, max_action_num=200, first_step=False):\n        \"\"\"Construct the model input array.\n        If the optional actions are greater than the maximum number of actions, then sample,\n        otherwise all are selected, and the insufficient part is pad.\n        \"\"\"\n        actions = 
np.ones((entites.shape[0], max_action_num, 3), dtype=np.dtype('int32'))\n        actions[:, :, 0] *= self.rPAD\n        actions[:, :, 1] *= self.ePAD\n        actions[:, :, 2] *= self.tPAD\n        for i in range(entites.shape[0]):\n            # NO OPERATION\n            actions[i, 0, 0] = self.NO_OP\n            actions[i, 0, 1] = entites[i]\n            actions[i, 0, 2] = times[i]\n\n            if times[i] == query_times[i]:\n                action_array = self.get_state_actions_space_complete(entites[i], times[i], False)\n            else:\n                action_array = self.get_state_actions_space_complete(entites[i], times[i], True)\n\n            if action_array.shape[0] == 0:\n                continue\n\n            # Whether to keep the action NO_OPERATION\n            start_idx = 1\n            if first_step:\n                # The first step cannot stay in place\n                start_idx = 0\n\n            if action_array.shape[0] > (max_action_num - start_idx):\n                # Sample. Take the latest events.\n                actions[i, start_idx:, ] = action_array[:max_action_num-start_idx]\n            else:\n                actions[i, start_idx:action_array.shape[0]+start_idx, ] = action_array\n        return actions"
  },
  {
    "path": "modules/timetraveler_episode.py",
    "content": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HUST/TITer/blob/master/model/episode.py\nHaohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021\n\"\"\"\n\nimport torch\nimport torch.nn as nn\n\nclass Episode(nn.Module):\n    def __init__(self, env, agent, config):\n        super(Episode, self).__init__()\n        self.config = config\n        self.env = env\n        self.agent = agent\n        self.path_length = config['path_length']\n        self.num_rel = config['num_rel']\n        self.max_action_num = config['max_action_num']\n\n    def forward(self, query_entities, query_timestamps, query_relations):\n        \"\"\"\n        Args:\n            query_entities: [batch_size]\n            query_timestamps: [batch_size]\n            query_relations: [batch_size]\n        Return:\n            all_loss: list\n            all_logits: list\n            all_actions_idx: list\n            current_entities: torch.tensor, [batch_size]\n            current_timestamps: torch.tensor, [batch_size]\n        \"\"\"\n        query_entities_embeds = self.agent.ent_embs(query_entities, torch.zeros_like(query_timestamps))\n        query_relations_embeds = self.agent.rel_embs(query_relations)\n\n        current_entites = query_entities\n        current_timestamps = query_timestamps\n        prev_relations = torch.ones_like(query_relations) * self.num_rel  # NO_OP\n\n        all_loss = []\n        all_logits = []\n        all_actions_idx = []\n\n        self.agent.policy_step.set_hiddenx(query_relations.shape[0])\n        for t in range(self.path_length):\n            if t == 0:\n                first_step = True\n            else:\n                first_step = False\n\n            action_space = self.env.next_actions(\n                current_entites,\n                current_timestamps,\n                
query_timestamps,\n                self.max_action_num,\n                first_step\n            )\n\n            loss, logits, action_id = self.agent(\n                prev_relations,\n                current_entites,\n                current_timestamps,\n                query_relations_embeds,\n                query_entities_embeds,\n                query_timestamps,\n                action_space,\n            )\n\n            chosen_relation = torch.gather(action_space[:, :, 0], dim=1, index=action_id).reshape(action_space.shape[0])\n            chosen_entity = torch.gather(action_space[:, :, 1], dim=1, index=action_id).reshape(action_space.shape[0])\n            chosen_entity_timestamps = torch.gather(action_space[:, :, 2], dim=1, index=action_id).reshape(action_space.shape[0])\n\n            all_loss.append(loss)\n            all_logits.append(logits)\n            all_actions_idx.append(action_id)\n\n            current_entites = chosen_entity\n            current_timestamps = chosen_entity_timestamps\n            prev_relations = chosen_relation\n\n        return all_loss, all_logits, all_actions_idx, current_entites, current_timestamps\n\n    def beam_search(self, query_entities, query_timestamps, query_relations):\n        \"\"\"\n        Args:\n            query_entities: [batch_size]\n            query_timestamps: [batch_size]\n            query_relations: [batch_size]\n        Return:\n            current_entites: [batch_size, test_rollouts_num]\n            beam_prob: [batch_size, test_rollouts_num]\n        \"\"\"\n        batch_size = query_entities.shape[0]\n        query_entities_embeds = self.agent.ent_embs(query_entities, torch.zeros_like(query_timestamps))\n        query_relations_embeds = self.agent.rel_embs(query_relations)\n\n        self.agent.policy_step.set_hiddenx(batch_size)\n\n        # In the first step, if rollouts_num is greater than the maximum number of actions, select all actions\n        current_entites = query_entities\n        
current_timestamps = query_timestamps\n        prev_relations = torch.ones_like(query_relations) * self.num_rel  # NO_OP\n        action_space = self.env.next_actions(current_entites, current_timestamps,\n                                             query_timestamps, self.max_action_num, True)\n        loss, logits, action_id = self.agent(\n            prev_relations,\n            current_entites,\n            current_timestamps,\n            query_relations_embeds,\n            query_entities_embeds,\n            query_timestamps,\n            action_space\n        )  # logits.shape: [batch_size, max_action_num]\n\n        action_space_size = action_space.shape[1]\n        if self.config['beam_size'] > action_space_size:\n            beam_size = action_space_size\n        else:\n            beam_size = self.config['beam_size']\n        beam_log_prob, top_k_action_id = torch.topk(logits, beam_size, dim=1)  # beam_log_prob.shape [batch_size, beam_size]\n        beam_log_prob = beam_log_prob.reshape(-1)  # [batch_size * beam_size]\n\n        current_entites = torch.gather(action_space[:, :, 1], dim=1, index=top_k_action_id).reshape(-1)  # [batch_size * beam_size]\n        current_timestamps = torch.gather(action_space[:, :, 2], dim=1, index=top_k_action_id).reshape(-1) # [batch_size * beam_size]\n        prev_relations = torch.gather(action_space[:, :, 0], dim=1, index=top_k_action_id).reshape(-1)  # [batch_size * beam_size]\n        self.agent.policy_step.hx = self.agent.policy_step.hx.repeat(1, 1, beam_size).reshape([batch_size * beam_size, -1])  # [batch_size * beam_size, state_dim]\n        self.agent.policy_step.cx = self.agent.policy_step.cx.repeat(1, 1, beam_size).reshape([batch_size * beam_size, -1])  # [batch_size * beam_size, state_dim]\n\n        beam_tmp = beam_log_prob.repeat([action_space_size, 1]).transpose(1, 0)  # [batch_size * beam_size, max_action_num]\n        for t in range(1, self.path_length):\n            query_timestamps_roll = 
query_timestamps.repeat(beam_size, 1).permute(1, 0).reshape(-1)\n            query_entities_embeds_roll = query_entities_embeds.repeat(1, 1, beam_size)\n            query_entities_embeds_roll = query_entities_embeds_roll.reshape([batch_size * beam_size, -1])  # [batch_size * beam_size, ent_dim]\n            query_relations_embeds_roll = query_relations_embeds.repeat(1, 1, beam_size)\n            query_relations_embeds_roll = query_relations_embeds_roll.reshape([batch_size * beam_size, -1])  # [batch_size * beam_size, rel_dim]\n\n            action_space = self.env.next_actions(current_entites, current_timestamps,\n                                                     query_timestamps_roll, self.max_action_num)\n\n            loss, logits, action_id = self.agent(\n                prev_relations,\n                current_entites,\n                current_timestamps,\n                query_relations_embeds_roll,\n                query_entities_embeds_roll,\n                query_timestamps_roll,\n                action_space\n            ) # logits.shape [bs * rollouts_num, max_action_num]\n\n            hx_tmp = self.agent.policy_step.hx.reshape(batch_size, beam_size, -1)\n            cx_tmp = self.agent.policy_step.cx.reshape(batch_size, beam_size, -1)\n\n            beam_tmp = beam_log_prob.repeat([action_space_size, 1]).transpose(1, 0) # [batch_size * beam_size, max_action_num]\n            beam_tmp += logits\n            beam_tmp = beam_tmp.reshape(batch_size, -1)  # [batch_size, beam_size * max_actions_num]\n\n            if action_space_size * beam_size >= self.config['beam_size']:\n                beam_size = self.config['beam_size']\n            else:\n                beam_size = action_space_size * beam_size\n\n            top_k_log_prob, top_k_action_id = torch.topk(beam_tmp, beam_size, dim=1)  # [batch_size, beam_size]\n            offset = top_k_action_id // action_space_size  # [batch_size, beam_size]\n            offset = offset.unsqueeze(-1).repeat(1, 
1, self.config['state_dim'])  # [batch_size, beam_size]\n            self.agent.policy_step.hx = torch.gather(hx_tmp, dim=1, index=offset)\n            self.agent.policy_step.hx = self.agent.policy_step.hx.reshape([batch_size * beam_size, -1])\n            self.agent.policy_step.cx = torch.gather(cx_tmp, dim=1, index=offset)\n            self.agent.policy_step.cx = self.agent.policy_step.cx.reshape([batch_size * beam_size, -1])\n\n            current_entites = torch.gather(action_space[:, :, 1].reshape(batch_size, -1), dim=1, index=top_k_action_id).reshape(-1)\n            current_timestamps = torch.gather(action_space[:, :, 2].reshape(batch_size, -1), dim=1, index=top_k_action_id).reshape(-1)\n            prev_relations = torch.gather(action_space[:, :, 0].reshape(batch_size, -1), dim=1, index=top_k_action_id).reshape(-1)\n\n            beam_log_prob = top_k_log_prob.reshape(-1)  # [batch_size * beam_size]\n\n        return action_space[:, :, 1].reshape(batch_size, -1), beam_tmp"
  },
  {
    "path": "modules/timetraveler_policygradient.py",
    "content": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HUST/TITer/blob/master/model/policyGradient.py and\nhttps://github.com/JHL-HUST/TITer/blob/master/model/baseline.py\nHaohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021\n\"\"\"\n\nimport torch\nimport numpy as np\nimport math\n\nclass ReactiveBaseline(object):\n    def __init__(self, config, update_rate):\n        self.update_rate = update_rate\n        self.value = torch.zeros(1)\n        if config['cuda']:\n            self.value = self.value.cuda()\n\n    def get_baseline_value(self):\n        return self.value\n\n    def update(self, target):\n        self.value = torch.add((1 - self.update_rate) * self.value, self.update_rate * target)\n\nclass PG(object):\n    def __init__(self, config):\n        self.config = config\n        self.positive_reward = 1.0\n        self.negative_reward = 0.0\n        self.baseline = ReactiveBaseline(config, config['lambda'])\n        self.now_epoch = 0\n\n    def get_reward(self, current_entites, answers):\n         positive = torch.ones_like(current_entites, dtype=torch.float32) * self.positive_reward\n         negative = torch.ones_like(current_entites, dtype=torch.float32) * self.negative_reward\n         reward = torch.where(current_entites == answers, positive, negative)\n         return reward\n\n    def calc_cum_discounted_reward(self, rewards):\n        running_add = torch.zeros([rewards.shape[0]])\n        cum_disc_reward = torch.zeros([rewards.shape[0], self.config['path_length']])\n        if self.config['cuda']:\n            running_add = running_add.cuda()\n            cum_disc_reward = cum_disc_reward.cuda()\n\n        cum_disc_reward[:, self.config['path_length'] - 1] = rewards\n        for t in reversed(range(self.config['path_length'])):\n            running_add = self.config['gamma'] 
* running_add + cum_disc_reward[:, t]\n            cum_disc_reward[:, t] = running_add\n        return cum_disc_reward\n\n    def entropy_reg_loss(self, all_logits):\n        all_logits = torch.stack(all_logits, dim=2)\n        entropy_loss = - torch.mean(torch.sum(torch.mul(torch.exp(all_logits), all_logits), dim=1))\n        return entropy_loss\n\n    def calc_reinforce_loss(self, all_loss, all_logits, cum_discounted_reward):\n        loss = torch.stack(all_loss, dim=1)\n        base_value = self.baseline.get_baseline_value()\n        final_reward = cum_discounted_reward - base_value\n\n        reward_mean = torch.mean(final_reward)\n        reward_std = torch.std(final_reward) + 1e-6\n        final_reward = torch.div(final_reward - reward_mean, reward_std)\n\n        loss = torch.mul(loss, final_reward)\n        entropy_loss = self.config['ita'] * math.pow(self.config['zita'], self.now_epoch) * self.entropy_reg_loss(all_logits)\n\n        total_loss = torch.mean(loss) - entropy_loss\n        return total_loss"
  },
  {
    "path": "modules/timetraveler_trainertester.py",
    "content": "import torch\nimport json\nimport os\nimport tqdm\nimport numpy as np\n\nclass Trainer(object):\n    def __init__(self, model, pg, optimizer, args, distribution=None):\n        self.model = model\n        self.pg = pg\n        self.optimizer = optimizer\n        self.args = args\n        self.distribution = distribution\n\n    def train_epoch(self, dataloader, ntriple):\n        self.model.train()\n        total_loss = 0.0\n        total_reward = 0.0\n        counter = 0\n        with tqdm.tqdm(total=ntriple, unit='ex') as bar:\n            bar.set_description('Train')\n            for src_batch, rel_batch, dst_batch, time_batch, time_orig_batch in dataloader:\n                if self.args.cuda:\n                    src_batch = src_batch.cuda()\n                    rel_batch = rel_batch.cuda()\n                    dst_batch = dst_batch.cuda()\n                    time_batch = time_batch.cuda()\n\n                all_loss, all_logits, _, current_entities, current_time = self.model(src_batch, time_batch, rel_batch)\n\n                reward = self.pg.get_reward(current_entities, dst_batch)\n                if self.args.reward_shaping:\n                    # reward shaping\n                    delta_time = time_batch - current_time\n                    p_dt = []\n\n                    for i in range(rel_batch.shape[0]):\n                        rel = rel_batch[i].item()\n                        dt = delta_time[i].item()\n                        p_dt.append(self.distribution(rel, dt // self.args.time_span))\n\n                    p_dt = torch.tensor(p_dt)\n                    if self.args.cuda:\n                        p_dt = p_dt.cuda()\n                    shaped_reward = (1 + p_dt) * reward\n                    cum_discounted_reward = self.pg.calc_cum_discounted_reward(shaped_reward)\n                else:\n                    cum_discounted_reward = self.pg.calc_cum_discounted_reward(reward)\n                reinfore_loss = 
self.pg.calc_reinforce_loss(all_loss, all_logits, cum_discounted_reward)\n                self.pg.baseline.update(torch.mean(cum_discounted_reward))\n                self.pg.now_epoch += 1\n\n                self.optimizer.zero_grad()\n                reinfore_loss.backward()\n                if self.args.clip_gradient:\n                    total_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_gradient)\n                self.optimizer.step()\n\n                total_loss += reinfore_loss\n                total_reward += torch.mean(reward)\n                counter += 1\n                bar.update(self.args.batch_size)\n                bar.set_postfix(loss='%.4f' % reinfore_loss, reward='%.4f' % torch.mean(reward).item())\n        return total_loss / counter, total_reward / counter\n\n    def save_model(self, save_path, checkpoint_path='checkpoint.pth'):\n        \"\"\"Save the parameters of the model and the optimizer,\"\"\"\n        argparse_dict = vars(self.args)\n\n        with open(os.path.join(save_path, 'config.json'), 'w') as fjson:\n            json.dump(argparse_dict, fjson)\n\n        torch.save({\n            'model_state_dict': self.model.state_dict(),\n            'optimizer_state_dict': self.optimizer.state_dict()},\n            os.path.join(save_path, checkpoint_path)\n        )\n\nclass Tester(object):\n    def __init__(self, model, args, train_entities, RelEntCooccurrence, metric='mrr'):\n        self.model = model\n        self.args = args\n        self.train_entities = train_entities\n        self.RelEntCooccurrence = RelEntCooccurrence\n        self.metric = metric\n\n\n    def get_rank(self, score, answer, entities_space, num_ent):\n        \"\"\"Get the location of the answer, if the answer is not in the array,\n        the ranking will be the total number of entities.\n        Args:\n            score: list, entity score\n            answer: int, the ground truth entity\n            entities_space: corresponding 
entity with the score\n            num_ent: the total number of entities\n        Return: the rank of the ground truth.\n        \"\"\"\n        if answer not in entities_space:\n            rank = num_ent\n        else:\n            answer_prob = score[entities_space.index(answer)]\n            score.sort(reverse=True)\n            rank = score.index(answer_prob) + 1\n        return rank\n\n    def test(self, dataloader, ntriple, num_nodes, neg_sampler, evaluator, split_mode='test'):\n        \"\"\"Get time-aware filtered metrics(MRR, Hits@1/3/10).\n        Args:\n            ntriple: number of the test examples.\n            skip_dict: time-aware filter. Get from baseDataset\n            num_ent: number of the entities.\n        Return: a dict (key -> MRR/HITS@1/HITS@3/HITS@10, values -> float)\n        \"\"\"\n        self.model.eval()\n        logs = []\n        perf_list =[]\n        with torch.no_grad():            \n            with tqdm.tqdm(total=ntriple, unit='ex') as bar:\n                current_time = 0\n                cache_IM = {}  # key -> entity, values: list, IM representations of the co-o relations.\n                for src_batch, rel_batch, dst_batch, time_batch,time_orig_batch in dataloader:\n                    batch_size = dst_batch.size(0)\n\n                    if self.args.IM:\n                        src = src_batch[0].item()\n                        rel = rel_batch[0].item()\n                        dst = dst_batch[0].item()\n                        time = time_batch[0].item()\n\n                        # representation update\n                        if current_time != time:\n                            current_time = time\n                            for k, v in cache_IM.items():\n                                ims = torch.stack(v, dim=0)\n                                self.model.agent.update_entity_embedding(k, ims, self.args.mu)\n                            cache_IM = {}\n\n                        if src not in 
self.train_entities and rel in self.RelEntCooccurrence['subject'].keys():\n                            im = self.model.agent.get_im_embedding(list(self.RelEntCooccurrence['subject'][rel]))\n                            if src in cache_IM.keys():\n                                cache_IM[src].append(im)\n                            else:\n                                cache_IM[src] = [im]\n\n                            # prediction shift\n                            self.model.agent.entities_embedding_shift(src, im, self.args.mu)\n\n                    if self.args.cuda:\n                        src_batch = src_batch.cuda()\n                        rel_batch = rel_batch.cuda()\n                        dst_batch = dst_batch.cuda()\n                        time_batch = time_batch.cuda()\n\n                    current_entities, beam_prob = \\\n                        self.model.beam_search(src_batch, time_batch, rel_batch)\n\n                    if self.args.IM and src not in self.train_entities:\n                        # We do this\n                        # because events that happen at the same time in the future cannot see each other.\n                        self.model.agent.back_entities_embedding(src)\n\n                    if self.args.cuda:\n                        current_entities = current_entities.cpu()\n                        beam_prob = beam_prob.cpu()\n\n                    current_entities = current_entities.numpy()\n                    beam_prob = beam_prob.numpy()\n\n                    MRR = 0\n                    for i in range(batch_size):\n                        candidate_answers = current_entities[i]\n                        candidate_score = beam_prob[i]\n                        scores_eval_paper_authors = -10000000000.0*np.ones(num_nodes, dtype=np.float32)\n                        # sort by score from largest to smallest\n                        idx = np.argsort(-candidate_score)\n                        candidate_answers = 
candidate_answers[idx]\n                        candidate_score = candidate_score[idx]\n\n                        # remove duplicate entities\n                        candidate_answers, idx = np.unique(candidate_answers, return_index=True)\n                        candidate_answers = list(candidate_answers)\n                        candidate_score = list(candidate_score[idx])\n\n                        src = src_batch[i].item()\n                        rel = rel_batch[i].item()\n                        dst = dst_batch[i].item()\n                        time = time_batch[i].item()\n                        time_orig = time_orig_batch[i].item()\n\n                        if np.max(candidate_answers) >= num_nodes:\n                            if candidate_answers[-1] == num_nodes:\n                                logging_score_answers = candidate_answers[0:-1]\n                                logging_score = candidate_score[0:-1]\n                            else:\n                                print(\"Problem with the score ids\", np.max(candidate_answers))\n                        else:\n                            logging_score_answers = candidate_answers\n                            logging_score = candidate_score\n\n                        neg_samples_batch = neg_sampler.query_batch(np.expand_dims(np.array(src), axis=0),\n                                                np.expand_dims(np.array(dst), axis=0), \n                                                np.expand_dims(np.array(time_orig), axis=0), \n                                                edge_type=np.expand_dims(np.array(rel), axis=0), \n                                                split_mode=split_mode)\n                        pos_samples_batch = dst\n                        # get inductive inference performance.\n                        # Only count the results of the example containing new entities.\n                        if self.args.test_inductive and src in self.train_entities and dst in 
self.train_entities:\n                            continue\n\n                        # filter = skip_dict[(src, rel, time)]  # a set of ground truth entities\n                        # tmp_entities = candidate_answers.copy()\n                        # tmp_prob = candidate_score.copy()\n                        # # time-aware filter\n                        # for j in range(len(tmp_entities)):\n                        #     if tmp_entities[j] in filter and tmp_entities[j] != dst:\n                        #         candidate_answers.remove(tmp_entities[j])\n                        #         candidate_score.remove(tmp_prob[j])\n\n                        # ranking_raw = self.get_rank(candidate_score, dst, candidate_answers, num_ent)\n                        scores_eval_paper_authors[logging_score_answers] = logging_score\n                        # logs.append({\n                        #     'MRR': 1.0 / ranking_raw,\n                        #     'HITS@1': 1.0 if ranking_raw <= 1 else 0.0,\n                        #     'HITS@3': 1.0 if ranking_raw <= 3 else 0.0,\n                        #     'HITS@10': 1.0 if ranking_raw <= 10 else 0.0,\n                        # })\n                        neg_scores = scores_eval_paper_authors[neg_samples_batch]\n                        pos_scores = scores_eval_paper_authors[pos_samples_batch]\n                        input_dict = {\n                            \"y_pred_pos\": np.array([pos_scores]),\n                            \"y_pred_neg\": np.array(neg_scores),\n                            \"eval_metric\": [self.metric],\n                        }\n                        perf_list.append(evaluator.eval(input_dict)[self.metric])\n\n\n                    bar.update(batch_size)\n                    bar.set_postfix(MRR='{}'.format(perf_list[-1] / batch_size))\n        metrics = {}\n        metrics[self.metric] = np.mean(perf_list)\n        # for metric in logs[0].keys():\n        #     metrics[metric] = sum([log[metric] for log 
in logs]) / len(logs)\n        return metrics\n    \n\ndef getRelEntCooccurrence(quadruples, num_rels):\n    \"\"\"Used for Inductive-Mean. Get co-occurrence in the training set.\n    https://github.com/JHL-HUST/TITer/blob/master/dataset/baseDataset.py\n    from Timetraveler\n    return:\n        {'subject': a dict[key -> relation, values -> a set of co-occurrence subject entities],\n            'object': a dict[key -> relation, values -> a set of co-occurrence object entities],}\n    \"\"\"\n    relation_entities_s = {}\n    relation_entities_o = {}\n    for ex in quadruples:\n        s, r, o = ex[0], ex[1], ex[2]\n        reversed_r = r + num_rels + 1\n        if r not in relation_entities_s.keys():\n            relation_entities_s[r] = set()\n        relation_entities_s[r].add(s)\n        if r not in relation_entities_o.keys():\n            relation_entities_o[r] = set()\n        relation_entities_o[r].add(o)\n\n        if reversed_r not in relation_entities_s.keys():\n            relation_entities_s[reversed_r] = set()\n        relation_entities_s[reversed_r].add(o)\n        if reversed_r not in relation_entities_o.keys():\n            relation_entities_o[reversed_r] = set()\n        relation_entities_o[reversed_r].add(s)\n    return {'subject': relation_entities_s, 'object': relation_entities_o}\n"
  },
  {
    "path": "modules/tkg_utils.py",
    "content": "\nfrom itertools import groupby\nfrom operator import itemgetter\nfrom collections import defaultdict\nimport sys\nimport argparse\nimport numpy as np\n\ndef get_args_timetraveler(args=None):\n    \"\"\" Parse the arguments for \"timetraveler\" model\n    \"\"\"\n    parser = argparse.ArgumentParser(\n        description='Timetraveler',\n        usage='main.py [<args>] [-h | --help]'\n    )\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--cuda', action='store_true', help='whether to use GPU or not.')\n    parser.add_argument('--do_train', default=True, action='store_true', help='whether to train.')\n    parser.add_argument('--do_test', default=True, action='store_true', help='whether to test.')\n\n    # Train Params\n    parser.add_argument('--batch_size', default=512, type=int, help='training batch size.')\n    parser.add_argument('--max_epochs', default=400, type=int, help='max training epochs.') #400\n    parser.add_argument('--num_workers', default=8, type=int, help='workers number used for dataloader.')\n    parser.add_argument('--valid_epoch', default=30, type=int, help='validation frequency.') # 30\n    parser.add_argument('--lr', default=0.001, type=float, help='learning rate.')\n    parser.add_argument('--save_epoch', default=30, type=int, help='model saving frequency.')\n    parser.add_argument('--clip_gradient', default=10.0, type=float, help='for gradient crop.')\n\n    # Test Params\n    parser.add_argument('--test_batch_size', default=1, type=int,\n                        help='test batch size, it needs to be set to 1 when using IM module.')\n    parser.add_argument('--beam_size', default=100, type=int, help='the beam number of the beam search.')\n    parser.add_argument('--test_inductive', action='store_true', help='whether to verify inductive inference performance.')\n    parser.add_argument('--IM', default=True, action='store_true', help='whether to use IM module.')\n    
parser.add_argument('--mu', default=0.1, type=float, help='the hyperparameter of IM module.')\n\n    # Agent Params\n    parser.add_argument('--ent_dim', default=80, type=int, help='Embedding dimension of the entities')\n    parser.add_argument('--rel_dim', default=100, type=int, help='Embedding dimension of the relations')\n    parser.add_argument('--state_dim', default=100, type=int, help='dimension of the LSTM hidden state')\n    parser.add_argument('--hidden_dim', default=100, type=int, help='dimension of the MLP hidden layer')\n    parser.add_argument('--time_dim', default=20, type=int, help='Embedding dimension of the timestamps')\n    parser.add_argument('--entities_embeds_method', default='dynamic', type=str,\n                        help='representation method of the entities, dynamic or static')\n\n    # Environment Params\n    parser.add_argument('--state_actions_path', default='state_actions_space.pkl', type=str,\n                        help='the file stores preprocessed candidate action array.')\n\n    # Episode Params\n    parser.add_argument('--path_length', default=3, type=int, help='the agent search path length.')\n    parser.add_argument('--max_action_num', default=30, type=int, help='the max candidate actions number.')\n\n    # Policy Gradient Params\n    parser.add_argument('--Lambda', default=0.0, type=float, help='update rate of baseline.')\n    parser.add_argument('--Gamma', default=0.95, type=float, help='discount factor of Bellman Eq.')\n    parser.add_argument('--Ita', default=0.01, type=float, help='regular proportionality constant.')\n    parser.add_argument('--Zita', default=0.9, type=float, help='attenuation factor of entropy regular term.')\n\n    # reward shaping params\n    parser.add_argument('--reward_shaping', default=False, help='whether to use reward shaping.')\n    parser.add_argument('--time_span', default=1, type=int, help='24 for ICEWS, 1 for WIKI and YAGO')\n    parser.add_argument('--alphas_pkl', 
default='dirchlet_alphas.pkl', type=str,\n                        help='the file storing the alpha parameters of the Dirichlet distribution.')\n    parser.add_argument('--k', default=12000, type=int, help='statistics recent K historical snapshots.')\n    # configuration for preprocessor \n    parser.add_argument('--store_actions_num', default=0, type=int,\n                        help='maximum number of stored neighbors, 0 means store all.')\n    parser.add_argument('--preprocess', default=True,\n                        help=\"Do we want preprocessing for the actionspace\")\n    # configuration for dirichlet\n    parser.add_argument('--tol', default=1e-7, type=float)\n    parser.add_argument('--method', default='meanprecision', type=str)\n    parser.add_argument('--maxiter', default=100, type=int)\n    return parser.parse_args(args)\n\ndef get_model_config_timetraveler(args, num_ent, num_rel):\n    \"\"\" Get the model configuration for \"timetraveler\" model\"\"\"\n    config = {\n        'cuda': args.cuda,  # whether to use GPU or not.\n        'batch_size': args.batch_size,  # training batch size.\n        'num_ent': num_ent,  # number of entities\n        'num_rel': num_rel,  # number of relations\n        'ent_dim': args.ent_dim,  # Embedding dimension of the entities\n        'rel_dim': args.rel_dim,  # Embedding dimension of the relations\n        'time_dim': args.time_dim,  # Embedding dimension of the timestamps\n        'state_dim': args.state_dim,  # dimension of the LSTM hidden state\n        'action_dim': args.ent_dim + args.rel_dim,  # dimension of the actions\n        'mlp_input_dim': args.ent_dim + args.rel_dim + args.state_dim,  # dimension of the input of the MLP\n        'mlp_hidden_dim': args.hidden_dim,  # dimension of the MLP hidden layer\n        'path_length': args.path_length,  # agent search path length\n        'max_action_num': args.max_action_num,  # max candidate action number\n        'lambda': args.Lambda,  # update rate of 
baseline\n        'gamma': args.Gamma,  # discount factor of Bellman Eq.\n        'ita': args.Ita,  # regular proportionality constant\n        'zita': args.Zita,  # attenuation factor of entropy regular term\n        'beam_size': args.beam_size,  # beam size for beam search\n        'entities_embeds_method': args.entities_embeds_method,  # default: 'dynamic', otherwise static encoder will be used\n    }\n    return config\n\ndef get_args_cen():\n    \"\"\" Get the arguments for \"CEN\" model\"\"\"\n    parser = argparse.ArgumentParser(description='CEN')\n    parser.add_argument(\"--gpu\", type=int, default=0,\n                        help=\"gpu\")\n    parser.add_argument(\"--batch-size\", type=int, default=1,\n                        help=\"batch-size\")\n    parser.add_argument(\"-d\", \"--dataset\", type=str, default='tkgl-yago',\n                        help=\"dataset to use\")\n    parser.add_argument(\"--test\", type=int, default=0,\n                        help=\"1: formal test 2: continual test\")\n    parser.add_argument(\"--validtest\",  default=False,\n                        help=\"load stat from dir and directly valid and test\")\n    parser.add_argument(\"--test-only\", type=bool, default=False,\n                        help=\"do we want to compute valid mrr or only test\")\n    parser.add_argument(\"--run-statistic\", action='store_true', default=False,\n                        help=\"statistic the result\")\n\n    parser.add_argument(\"--relation-evaluation\", action='store_true', default=False,\n                        help=\"save model accordding to the relation evalution\")\n    parser.add_argument(\"--log-per-rel\", action='store_true', default=False,\n                        help=\"log mrr per relation in json\")\n\n    \n    # configuration for encoder RGCN stat\n    parser.add_argument(\"--weight\", type=float, default=1,\n                        help=\"weight of static constraint\")\n    parser.add_argument(\"--task-weight\", type=float, 
default=1,\n                        help=\"weight of entity prediction task\")\n    parser.add_argument(\"--kl-weight\", type=float, default=0.7,\n                        help=\"weight of entity prediction task\")\n   \n    parser.add_argument(\"--encoder\", type=str, default=\"uvrgcn\",\n                        help=\"method of encoder\")\n\n    parser.add_argument(\"--dropout\", type=float, default=0.2,\n                        help=\"dropout probability\")\n    parser.add_argument(\"--skip-connect\", action='store_true', default=False,\n                        help=\"whether to use skip connect in a RGCN Unit\")\n    parser.add_argument(\"--n-hidden\", type=int, default=200,\n                        help=\"number of hidden units\")\n    parser.add_argument(\"--opn\", type=str, default=\"sub\",\n                        help=\"opn of compgcn\")\n\n    parser.add_argument(\"--n-bases\", type=int, default=100,\n                        help=\"number of weight blocks for each relation\")\n    parser.add_argument(\"--n-basis\", type=int, default=100,\n                        help=\"number of basis vector for compgcn\")\n    parser.add_argument(\"--n-layers\", type=int, default=2,\n                        help=\"number of propagation rounds\")\n    parser.add_argument(\"--self-loop\", action='store_true', default=True,\n                        help=\"perform layer normalization in every layer of gcn \")\n    parser.add_argument(\"--layer-norm\", action='store_true', default=True,\n                        help=\"perform layer normalization in every layer of gcn \")\n    parser.add_argument(\"--relation-prediction\", action='store_true', default=False,\n                        help=\"add relation prediction loss\")\n    parser.add_argument(\"--entity-prediction\", action='store_true', default=True,\n                        help=\"add entity prediction loss\")\n\n\n    # configuration for stat training\n    parser.add_argument(\"--n-epochs\", type=int, default=30,\n        
                help=\"number of minimum training epochs on each time step\")\n    parser.add_argument(\"--lr\", type=float, default=0.001,\n                        help=\"learning rate\")\n    parser.add_argument(\"--ft_epochs\", type=int, default=30,\n                        help=\"number of minimum fine-tuning epoch\")\n    parser.add_argument(\"--ft_lr\", type=float, default=0.001,\n                        help=\"learning rate\")\n    parser.add_argument(\"--norm_weight\", type=float, default=1,\n                        help=\"learning rate\")\n    parser.add_argument(\"--grad-norm\", type=float, default=1.0,\n                        help=\"norm to clip gradient to\")\n\n    # configuration for evaluating\n    parser.add_argument(\"--evaluate-every\", type=int, default=1,\n                        help=\"perform evaluation every n epochs\")\n\n    # configuration for decoder\n    parser.add_argument(\"--decoder\", type=str, default=\"convtranse\",\n                        help=\"method of decoder\")\n    parser.add_argument(\"--input-dropout\", type=float, default=0.2,\n                        help=\"input dropout for decoder \")\n    parser.add_argument(\"--hidden-dropout\", type=float, default=0.2,\n                        help=\"hidden dropout for decoder\")\n    parser.add_argument(\"--feat-dropout\", type=float, default=0.2,\n                        help=\"feat dropout for decoder\")\n\n    # configuration for sequences stat\n    parser.add_argument(\"--train-history-len\", type=int, default=3,\n                        help=\"history length\")\n    parser.add_argument(\"--test-history-len\", type=int, default=10,\n                        help=\"history length for test\")\n    parser.add_argument(\"--test-history-len-2\", type=int, default=2,\n                        help=\"history length for test\")\n    parser.add_argument(\"--start-history-len\", type=int, default=3,\n                    help=\"start history length\")\n    
parser.add_argument(\"--dilate-len\", type=int, default=1,\n                        help=\"dilate history graph\")\n\n    # configuration for optimal parameters\n    parser.add_argument(\"--grid-search\", action='store_true', default=False,\n                        help=\"perform grid search for best configuration\")\n    parser.add_argument(\"-tune\", \"--tune\", type=str, default=\"n_hidden,n_layers,dropout,n_bases\",\n                        help=\"stat to use\")\n    parser.add_argument(\"--num-k\", type=int, default=500,\n                        help=\"number of triples generated\")\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--run-nr', type=int, help='Run Number', default=1)\n\n\n\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\ndef get_args_regcn():\n    \"\"\"Parses the arguments for REGCN model\"\"\"\n    parser = argparse.ArgumentParser(description='REGCN')\n\n    parser.add_argument(\"--gpu\", type=int, default=0,\n                        help=\"gpu\")\n    parser.add_argument(\"--batch-size\", type=int, default=1,\n                        help=\"batch-size\")\n    parser.add_argument(\"-d\", \"--dataset\", type=str, default='tkgl-yago',\n                        help=\"dataset to use\")\n    parser.add_argument(\"--test\", default=False,\n                        help=\"load stat from dir and directly test\")\n    parser.add_argument(\"--run-analysis\", action='store_true', default=False,\n                        help=\"print log info\")\n    parser.add_argument(\"--run-statistic\", action='store_true', default=False,\n                        help=\"statistic the result\")\n    parser.add_argument(\"--multi-step\", action='store_true', default=False,\n                        help=\"do multi-steps inference without ground truth\")\n    parser.add_argument(\"--topk\", type=int, default=10,\n                     
   help=\"choose top k entities as results when do multi-steps without ground truth\")\n    parser.add_argument(\"--add-static-graph\",  action='store_true', default=False,\n                        help=\"use the info of static graph\")\n    parser.add_argument(\"--add-rel-word\", action='store_true', default=False,\n                        help=\"use words in relaitons\")\n    parser.add_argument(\"--relation-evaluation\", action='store_true', default=False,\n                        help=\"save model accordding to the relation evalution\")\n\n    # configuration for encoder RGCN stat\n\n    parser.add_argument(\"--weight\", type=float, default=0.5,\n                        help=\"weight of static constraint\")\n    parser.add_argument(\"--task-weight\", type=float, default=0.7,\n                        help=\"weight of entity prediction task\")\n    parser.add_argument(\"--discount\", type=float, default=1.0,\n                        help=\"discount of weight of static constraint\")\n    parser.add_argument(\"--angle\", type=int, default=10,\n                        help=\"evolution speed\")\n\n    parser.add_argument(\"--encoder\", type=str, default=\"uvrgcn\",\n                        help=\"method of encoder\")\n    parser.add_argument(\"--aggregation\", type=str, default=\"none\",\n                        help=\"method of aggregation\")\n    parser.add_argument(\"--dropout\", type=float, default=0.2,\n                        help=\"dropout probability\")\n    parser.add_argument(\"--skip-connect\", action='store_true', default=False,\n                        help=\"whether to use skip connect in a RGCN Unit\")\n    parser.add_argument(\"--n-hidden\", type=int, default=200,\n                        help=\"number of hidden units\")\n    parser.add_argument(\"--opn\", type=str, default=\"sub\",\n                        help=\"opn of compgcn\")\n\n    parser.add_argument(\"--n-bases\", type=int, default=100,\n                        help=\"number of weight blocks 
for each relation\")\n    parser.add_argument(\"--n-basis\", type=int, default=100,\n                        help=\"number of basis vector for compgcn\")\n    parser.add_argument(\"--n-layers\", type=int, default=2,\n                        help=\"number of propagation rounds\")\n    parser.add_argument(\"--self-loop\", action='store_true', default=True,\n                        help=\"perform layer normalization in every layer of gcn \")\n    parser.add_argument(\"--layer-norm\", action='store_true', default=True,\n                        help=\"perform layer normalization in every layer of gcn \")\n    parser.add_argument(\"--relation-prediction\", action='store_true', default=False,\n                        help=\"add relation prediction loss\")\n    parser.add_argument(\"--entity-prediction\", action='store_true', default=True,\n                        help=\"add entity prediction loss\")\n    parser.add_argument(\"--split_by_relation\", action='store_true', default=False,\n                        help=\"do relation prediction\")\n\n    # configuration for stat training\n    parser.add_argument(\"--n-epochs\", type=int, default=10,\n                        help=\"number of minimum training epochs on each time step\") #100\n    parser.add_argument(\"--lr\", type=float, default=0.001,\n                        help=\"learning rate\")\n    parser.add_argument(\"--grad-norm\", type=float, default=1.0,\n                        help=\"norm to clip gradient to\")\n\n    # configuration for evaluating\n    parser.add_argument(\"--evaluate-every\", type=int, default=1,\n                        help=\"perform evaluation every n epochs\")\n    parser.add_argument(\"--log-per-rel\", action='store_true', default=False,\n                        help=\"log mrr per relation in json\")\n\n    # configuration for decoder\n    parser.add_argument(\"--decoder\", type=str, default=\"convtranse\",\n                        help=\"method of decoder\")\n    
parser.add_argument(\"--input-dropout\", type=float, default=0.2,\n                        help=\"input dropout for decoder \")\n    parser.add_argument(\"--hidden-dropout\", type=float, default=0.2,\n                        help=\"hidden dropout for decoder\")\n    parser.add_argument(\"--feat-dropout\", type=float, default=0.2,\n                        help=\"feat dropout for decoder\")\n\n    # configuration for sequences stat\n    parser.add_argument(\"--train-history-len\", type=int, default=3,\n                        help=\"history length\")\n    parser.add_argument(\"--test-history-len\", type=int, default=3,\n                        help=\"history length for test\")\n    parser.add_argument(\"--dilate-len\", type=int, default=1,\n                        help=\"dilate history graph\")\n\n    # configuration for optimal parameters\n    parser.add_argument(\"--grid-search\", action='store_true', default=False,\n                        help=\"perform grid search for best configuration\")\n    parser.add_argument(\"-tune\", \"--tune\", type=str, default=\"n_hidden,n_layers,dropout,n_bases\",\n                        help=\"stat to use\")\n    parser.add_argument(\"--num-k\", type=int, default=500,\n                        help=\"number of triples generated\")\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\n    parser.add_argument('--run-nr', type=int, help='Run Number', default=1)\n    try:\n        args = parser.parse_args()\n    except:\n        parser.print_help()\n        sys.exit(0)\n    return args, sys.argv \n\n\ndef compute_min_distance(unique_sorted_timestamps):\n    \"\"\" compute the minimum distance between timestamps, where the timestamps are in a sorted list\n    \"\"\"\n    min_distance = np.inf\n    for i in range(1, len(unique_sorted_timestamps)):\n        min_distance = min(min_distance, unique_sorted_timestamps[i] - unique_sorted_timestamps[i-1])\n    return min_distance\n\ndef 
compute_maxminmean_distances(unique_sorted_timestamps):\n    \"\"\" compute the maximum, minimum and mean distances between timestamps, where the timestamps are in a sorted list\"\"\"\n    differences = []\n    \n    # Iterate over the list and compute the differences between successive elements\n    for i in range(len(unique_sorted_timestamps) - 1):\n        diff = unique_sorted_timestamps[i+1] - unique_sorted_timestamps[i]\n        differences.append(diff)\n    \n    # Calculate the mean of the differences\n    mean_diff = sum(differences) / len(differences)\n    \n    return np.max(differences), np.min(differences), np.mean(differences)\n\ndef group_by(data: np.array, key_idx: int) -> dict:\n    \"\"\"\n    group data in an np array to dict; where key is specified by key_idx. for example groups elements of array by relations\n    :param data: [np.array] data to be grouped\n    :param key_idx: [int] index for element of interest\n    returns data_dict: dict with key: values of element at index key_idx, values: all elements in data that have that value\n    \"\"\"\n    data_dict = {}\n    data_sorted = sorted(data, key=itemgetter(key_idx))\n    for key, group in groupby(data_sorted, key=itemgetter(key_idx)):\n        data_dict[key] = np.array(list(group))\n    return data_dict\n\ndef tkg_granularity_lookup(dataset_name, ts_distmean):\n    \"\"\" lookup the granularity of the dataset, and return the corresponding granularity\n    \"\"\"\n    if 'icews' in dataset_name or 'polecat' in dataset_name:\n        return 86400\n    elif 'wiki' in dataset_name or 'yago' in dataset_name:\n        return 31536000\n    else:\n        return ts_distmean\n\n    \n\ndef reformat_ts(timestamps, dataset_name='tkgl'):\n    \"\"\" reformat timestamps s.t. 
they start with 0, and have stepsize 1.\n    :param timestamps: np.array() with timestamps\n    returns: np.array(ts_new)\n    \"\"\"\n    all_ts = list(set(timestamps))\n    all_ts.sort()\n    ts_min = np.min(all_ts)\n    if 'tkgl' in dataset_name:\n        ts_distmax, ts_distmin, ts_distmean = compute_maxminmean_distances(all_ts)\n        if ts_distmean != ts_distmin:\n            ts_dist = tkg_granularity_lookup(dataset_name, ts_distmean)\n            if ts_dist - ts_distmean > 0.1*ts_distmean:\n                print('PROBLEM: the distances are somehwat off from the granularity of the dataset. using original mean distance')\n                ts_dist = ts_distmean\n        else:\n            ts_dist = ts_distmean\n    else:\n        ts_dist = compute_min_distance(all_ts) # all_ts[1] - all_ts[0]\n\n    ts_new = []\n    timestamps2 = timestamps - ts_min\n    ts_new = np.ceil(timestamps2/ts_dist).astype(int)\n\n    return np.array(ts_new)\n\ndef get_original_ts(reformatted_ts, ts_dist, min_ts):\n    \"\"\" get original timestamps from reformatted timestamps\n    :param reformatted_ts: np.array() with reformatted timestamps\n    returns: np.array(ts_new)\n    \"\"\"\n    reformatted_ts = list(set(reformatted_ts))\n    reformatted_ts.sort()\n    ts_new = []\n    for ts in reformatted_ts:\n        ts_new.append((ts * ts_dist)+min_ts)\n    return np.array(ts_new)\n\n\ndef create_basis_dict(data):\n    \"\"\"\n    Create basis dictionary for the recurrency baseline model with rules of confidence 1\n    data: concatenated train and vali data, INCLUDING INVERSE QUADRUPLES. 
we need it for the relation ids.\n    \"\"\"\n    rels = list(set(data[:,1]))\n    basis_dict = {}\n    for rel in rels:\n        basis_id_new = []\n        rule_dict = {}\n        rule_dict[\"head_rel\"] = int(rel)\n        rule_dict[\"body_rels\"] = [int(rel)] #same body and head relation -> what happened before happens again\n        rule_dict[\"conf\"] = 1 #same confidence for every rule\n        rule_new = rule_dict\n        basis_id_new.append(rule_new)\n        basis_dict[str(rel)] = basis_id_new\n    return basis_dict\n\n\ndef get_inv_relation_id(num_rels):\n    \"\"\"\n    Get inverse relation id.\n    parameters:\n        num_rels (int): number of relations\n    returns:\n        inv_relation_id (dict): mapping of relation to inverse relation\n    \"\"\"\n    inv_relation_id = dict()\n    for i in range(int(num_rels / 2)):\n        inv_relation_id[i] = i + int(num_rels / 2)\n    for i in range(int(num_rels / 2), num_rels):\n        inv_relation_id[i] = i % int(num_rels / 2)\n    return inv_relation_id\n\n\ndef create_scores_array(predictions_dict, num_nodes):\n    \"\"\" \n    Create an array of scores from a dictionary of predictions.\n    predictions_dict: a dictionary mapping indices to values\n    num_nodes: the size of the array\n    returns: an array of scores\n    \"\"\"\n    # predictions_dict is a dictionary mapping indices to values\n    # num_nodes is the size of the array\n\n    # Convert keys and values of the predictions_dict into NumPy arrays\n    keys_array = np.array(list(predictions_dict.keys()))\n    values_array = np.array(list(predictions_dict.values()))\n\n    # Create an array of zeros with the desired shape\n    predictions = np.zeros(num_nodes)\n\n    # Use advanced indexing to scatter values into predictions array\n    predictions[keys_array.astype(int)] = values_array.astype(float)\n    return predictions\n\n"
  },
  {
    "path": "modules/tkg_utils_dgl.py",
    "content": "\nimport dgl\nimport torch\nimport numpy as np\nfrom collections import defaultdict\n\ndef build_sub_graph(num_nodes, num_rels, triples, use_cuda, gpu, mode='dyn'):\n    \"\"\"\n    https://github.com/Lee-zix/CEN/blob/main/rgcn/utils.py\n    :param node_id: node id in the large graph\n    :param num_rels: number of relation\n    :param src: relabeled src id\n    :param rel: original rel id\n    :param dst: relabeled dst id\n    :param use_cuda:\n    :return:\n    \"\"\"\n    def comp_deg_norm(g):\n        in_deg = g.in_degrees(range(g.number_of_nodes())).float()\n        in_deg[torch.nonzero(in_deg == 0).view(-1)] = 1\n        norm = 1.0 / in_deg\n        return norm\n\n    src, rel, dst = triples.transpose()\n    if mode =='static':\n        src, dst = np.concatenate((src, dst)), np.concatenate((dst, src))\n        rel = np.concatenate((rel, rel + num_rels))\n    g = dgl.DGLGraph()\n    g.add_nodes(num_nodes)\n    #g.ndata['original_id'] = np.unique(np.concatenate((np.unique(triples[:,0]), np.unique(triples[:,2]))))\n    g.add_edges(src, dst)\n    norm = comp_deg_norm(g)\n    #node_id =torch.arange(0, g.num_nodes(), dtype=torch.long).view(-1, 1) #updated to deal with the fact that ot only the first k nodes of our graph have static infos\n    node_id = torch.arange(0, num_nodes, dtype=torch.long).view(-1, 1)\n    g.ndata.update({'id': node_id, 'norm': norm.view(-1, 1)})\n    g.apply_edges(lambda edges: {'norm': edges.dst['norm'] * edges.src['norm']})\n    g.edata['type'] = torch.LongTensor(rel)\n\n\n    uniq_r, r_len, r_to_e = r2e(triples, num_rels)\n    g.uniq_r = uniq_r\n    g.r_to_e = r_to_e\n    g.r_len = r_len\n\n    if use_cuda:\n        g = g.to(gpu)\n        g.r_to_e = torch.from_numpy(np.array(r_to_e))\n    return g\n\n\ndef r2e(triplets, num_rels):\n    \"\"\" get the mapping from relation to entities helper function for build_sub_graph()\n    returns: \n    uniq_r: set of unique relations\n    r_len: list of tuples, where each tuple is 
the start and end index of entities for a relation\n    e_idx: indices of entities\"\"\"\n    src, rel, dst = triplets.transpose()\n    # get all relations\n    uniq_r = np.unique(rel)\n    # uniq_r = np.concatenate((uniq_r, uniq_r+num_rels)) #we already have the inverse triples\n    # generate r2e\n    r_to_e = defaultdict(set)\n    for j, (src, rel, dst) in enumerate(triplets):\n        r_to_e[rel].add(src)\n        r_to_e[rel].add(dst)\n        r_to_e[rel+num_rels].add(src)\n        r_to_e[rel+num_rels].add(dst)\n    r_len = []\n    e_idx = []\n    idx = 0\n    for r in uniq_r:\n        r_len.append((idx,idx+len(r_to_e[r])))\n        e_idx.extend(list(r_to_e[r]))\n        idx += len(r_to_e[r])\n    return uniq_r, r_len, e_idx"
  },
  {
    "path": "modules/tlogic_apply_modules.py",
    "content": "\"\"\"\nhttps://github.com/liu-yushan/TLogic/blob/main/mycode/rule_application.py\nTLogic: Temporal Logical Rules for Explainable Link Forecasting on Temporal Knowledge Graphs.\nYushan Liu, Yunpu Ma, Marcel Hildebrandt, Mitchell Joblin, Volker Tresp\n\"\"\"\n\nimport json\nimport numpy as np\nimport pandas as pd\n\nfrom modules.tlogic_learn_modules import store_edges\n\n\ndef filter_rules(rules_dict, min_conf, min_body_supp, rule_lengths):\n    \"\"\"\n    Filter for rules with a minimum confidence, minimum body support, and\n    specified rule lengths.\n\n    Parameters.\n        rules_dict (dict): rules\n        min_conf (float): minimum confidence value\n        min_body_supp (int): minimum body support value\n        rule_lengths (list): rule lengths\n\n    Returns:\n        new_rules_dict (dict): filtered rules\n    \"\"\"\n\n    new_rules_dict = dict()\n    for k in rules_dict:\n        new_rules_dict[k] = []\n        for rule in rules_dict[k]:\n            cond = (\n                (rule[\"conf\"] >= min_conf)\n                and (rule[\"body_supp\"] >= min_body_supp)\n                and (len(rule[\"body_rels\"]) in rule_lengths)\n            )\n            if cond:\n                new_rules_dict[k].append(rule)\n\n    return new_rules_dict\n\n\ndef get_window_edges(all_data, test_query_ts, learn_edges, window=-1, first_test_query_ts=0): #modified eval_paper_authors: added first_test_query_ts for validation set usage\n    \"\"\"\n    Get the edges in the data (for rule application) that occur in the specified time window.\n    If window is 0, all edges before the test query timestamp are included.\n    If window is -1, the edges on which the rules are learned are used.\n    If window is -2, all edges from train and validation set are used. 
modified by eval_paper_authors.\n    If window is an integer n > 0, all edges within n timestamps before the test query\n    timestamp are included.\n    Note: modified according to Julia Gastinger, Timo Sztyler, Lokesh Sharma, Anett Schuelke, Heiner Stuckenschmidt. \n    Comparing Apples and Oranges? On the Evaluation of Methods for Temporal Knowledge Graphs. In ECML PKDD, 2023. \n    https://github.com/nec-research/TLogic/blob/374c7e34f5949f98b2eccc9628f98125a63763f1/mycode/rule_application.py\n    Parameters:\n        all_data (np.ndarray): complete dataset (train/valid/test)\n        test_query_ts (np.ndarray): test query timestamp\n        learn_edges (dict): edges on which the rules are learned\n        window (int): time window used for rule application\n        first_test_query_ts (int): smallest timestamp from test set (eval_paper_authors)\n\n    Returns:\n        window_edges (dict): edges in the window for rule application\n    \"\"\"\n\n    if window > 0:\n        mask = (all_data[:, 3] < test_query_ts) * (\n            all_data[:, 3] >= test_query_ts - window \n        )\n        window_edges = store_edges(all_data[mask])\n    elif window == 0:\n        mask = all_data[:, 3] < test_query_ts #!!! \n        window_edges = store_edges(all_data[mask]) \n    elif window == -1:\n        window_edges = learn_edges\n    elif window == -2: #modified eval_paper_authors: added this option\n        mask = all_data[:, 3] < first_test_query_ts # all edges at timestep smaller then the test queries. 
\n        # meaning all from train and valid set\n        window_edges = store_edges(all_data[mask])  \n    elif window == -200: #modified eval_paper_authors: added this option\n        abswindow = 200\n        mask = (all_data[:, 3] < first_test_query_ts) * (\n            all_data[:, 3] >= first_test_query_ts - abswindow  # all edges at timestep smaller than the test queries - 200\n        )\n        window_edges = store_edges(all_data[mask])\n    return window_edges\n\n\ndef match_body_relations(rule, edges, test_query_sub):\n    \"\"\"\n    Find edges that could constitute walks (starting from the test query subject)\n    that match the rule.\n    First, find edges whose subject match the query subject and the relation matches\n    the first relation in the rule body. Then, find edges whose subjects match the\n    current targets and the relation the next relation in the rule body.\n    Memory-efficient implementation.\n\n    Parameters:\n        rule (dict): rule from rules_dict\n        edges (dict): edges for rule application\n        test_query_sub (int): test query subject\n\n    Returns:\n        walk_edges (list of np.ndarrays): edges that could constitute rule walks\n    \"\"\"\n\n    rels = rule[\"body_rels\"]\n    # Match query subject and first body relation\n    try:\n        rel_edges = edges[rels[0]]\n        mask = rel_edges[:, 0] == test_query_sub\n        new_edges = rel_edges[mask]\n        walk_edges = [\n            np.hstack((new_edges[:, 0:1], new_edges[:, 2:4]))\n        ]  # [sub, obj, ts]\n        cur_targets = np.array(list(set(walk_edges[0][:, 1])))\n\n        for i in range(1, len(rels)):\n            # Match current targets and next body relation\n            try:\n                rel_edges = edges[rels[i]]\n                mask = np.any(rel_edges[:, 0] == cur_targets[:, None], axis=0)\n                new_edges = rel_edges[mask]\n                walk_edges.append(\n                    np.hstack((new_edges[:, 0:1], new_edges[:, 
2:4]))\n                )  # [sub, obj, ts]\n                cur_targets = np.array(list(set(walk_edges[i][:, 1])))\n            except KeyError:\n                walk_edges.append([])\n                break\n    except KeyError:\n        walk_edges = [[]]\n\n    return walk_edges\n\n\ndef match_body_relations_complete(rule, edges, test_query_sub):\n    \"\"\"\n    Find edges that could constitute walks (starting from the test query subject)\n    that match the rule.\n    First, find edges whose subject match the query subject and the relation matches\n    the first relation in the rule body. Then, find edges whose subjects match the\n    current targets and the relation the next relation in the rule body.\n\n    Parameters:\n        rule (dict): rule from rules_dict\n        edges (dict): edges for rule application\n        test_query_sub (int): test query subject\n\n    Returns:\n        walk_edges (list of np.ndarrays): edges that could constitute rule walks\n    \"\"\"\n\n    rels = rule[\"body_rels\"]\n    # Match query subject and first body relation\n    try:\n        rel_edges = edges[rels[0]]\n        mask = rel_edges[:, 0] == test_query_sub\n        new_edges = rel_edges[mask]\n        walk_edges = [new_edges]\n        cur_targets = np.array(list(set(walk_edges[0][:, 2])))\n\n        for i in range(1, len(rels)):\n            # Match current targets and next body relation\n            try:\n                rel_edges = edges[rels[i]]\n                mask = np.any(rel_edges[:, 0] == cur_targets[:, None], axis=0)\n                new_edges = rel_edges[mask]\n                walk_edges.append(new_edges)\n                cur_targets = np.array(list(set(walk_edges[i][:, 2])))\n            except KeyError:\n                walk_edges.append([])\n                break\n    except KeyError:\n        walk_edges = [[]]\n\n    return walk_edges\n\n\ndef get_walks(rule, walk_edges):\n    \"\"\"\n    Get walks for a given rule. 
Take the time constraints into account.\n    Memory-efficient implementation.\n\n    Parameters:\n        rule (dict): rule from rules_dict\n        walk_edges (list of np.ndarrays): edges from match_body_relations\n\n    Returns:\n        rule_walks (pd.DataFrame): all walks matching the rule\n    \"\"\"\n\n    df_edges = []\n    #pd.Series(values).astype(uint16)\n    df = pd.DataFrame(\n        walk_edges[0],\n        columns=[\"entity_\" + str(0), \"entity_\" + str(1), \"timestamp_\" + str(0)]#,\n    #    dtype=np.uint16,\n    )  # Change type if necessary for better memory efficiency\n    if not rule[\"var_constraints\"]:\n        del df[\"entity_\" + str(0)]\n    df_edges.append(df)\n    df = df[0:0]  # Memory efficiency\n\n    for i in range(1, len(walk_edges)):\n        df = pd.DataFrame(\n            walk_edges[i],\n            columns=[\"entity_\" + str(i), \"entity_\" + str(i + 1), \"timestamp_\" + str(i)],\n            dtype=np.uint16,\n        )  # Change type if necessary\n        df_edges.append(df)\n        df = df[0:0]\n\n    rule_walks = df_edges[0]\n    df_edges[0] = df_edges[0][0:0]\n    for i in range(1, len(df_edges)):\n        rule_walks = pd.merge(rule_walks, df_edges[i], on=[\"entity_\" + str(i)])\n        rule_walks = rule_walks[\n            rule_walks[\"timestamp_\" + str(i - 1)] <= rule_walks[\"timestamp_\" + str(i)]\n        ]\n        if not rule[\"var_constraints\"]:\n            del rule_walks[\"entity_\" + str(i)]\n        df_edges[i] = df_edges[i][0:0]\n\n    for i in range(1, len(rule[\"body_rels\"])):\n        del rule_walks[\"timestamp_\" + str(i)]\n\n    return rule_walks\n\n\ndef get_walks_complete(rule, walk_edges):\n    \"\"\"\n    Get complete walks for a given rule. 
Take the time constraints into account.\n\n    Parameters:\n        rule (dict): rule from rules_dict\n        walk_edges (list of np.ndarrays): edges from match_body_relations\n\n    Returns:\n        rule_walks (pd.DataFrame): all walks matching the rule\n    \"\"\"\n\n    df_edges = []\n    df = pd.DataFrame(\n        walk_edges[0],\n        columns=[\n            \"entity_\" + str(0),\n            \"relation_\" + str(0),\n            \"entity_\" + str(1),\n            \"timestamp_\" + str(0),\n        ],\n        dtype=np.uint16,\n    )  # Change type if necessary for better memory efficiency\n    df_edges.append(df)\n\n    for i in range(1, len(walk_edges)):\n        df = pd.DataFrame(\n            walk_edges[i],\n            columns=[\n                \"entity_\" + str(i),\n                \"relation_\" + str(i),\n                \"entity_\" + str(i + 1),\n                \"timestamp_\" + str(i),\n            ],\n            dtype=np.uint16,\n        )  # Change type if necessary\n        df_edges.append(df)\n\n    rule_walks = df_edges[0]\n    for i in range(1, len(df_edges)):\n        rule_walks = pd.merge(rule_walks, df_edges[i], on=[\"entity_\" + str(i)])\n        rule_walks = rule_walks[\n            rule_walks[\"timestamp_\" + str(i - 1)] <= rule_walks[\"timestamp_\" + str(i)]\n        ]\n\n    return rule_walks\n\n\ndef check_var_constraints(var_constraints, rule_walks):\n    \"\"\"\n    Check variable constraints of the rule.\n\n    Parameters:\n        var_constraints (list): variable constraints from the rule\n        rule_walks (pd.DataFrame): all walks matching the rule\n\n    Returns:\n        rule_walks (pd.DataFrame): all walks matching the rule including the variable constraints\n    \"\"\"\n\n    for const in var_constraints:\n        for i in range(len(const) - 1):\n            rule_walks = rule_walks[\n                rule_walks[\"entity_\" + str(const[i])]\n                == rule_walks[\"entity_\" + str(const[i + 1])]\n            ]\n\n   
 return rule_walks\n\n\ndef get_candidates(\n    rule, rule_walks, test_query_ts, cands_dict, score_func, args, dicts_idx\n):\n    \"\"\"\n    Get from the walks that follow the rule the answer candidates.\n    Add the confidence of the rule that leads to these candidates.\n\n    Parameters:\n        rule (dict): rule from rules_dict\n        rule_walks (pd.DataFrame): rule walks (satisfying all constraints from the rule)\n        test_query_ts (int): test query timestamp\n        cands_dict (dict): candidates along with the confidences of the rules that generated these candidates\n        score_func (function): function for calculating the candidate score\n        args (list): arguments for the scoring function\n        dicts_idx (list): indices for candidate dictionaries\n\n    Returns:\n        cands_dict (dict): updated candidates\n    \"\"\"\n\n    max_entity = \"entity_\" + str(len(rule[\"body_rels\"]))\n    cands = set(rule_walks[max_entity])\n\n    for cand in cands:\n        cands_walks = rule_walks[rule_walks[max_entity] == cand]\n        for s in dicts_idx:\n            score = score_func(rule, cands_walks, test_query_ts, *args[s]).astype(\n                np.float32\n            )\n            try:\n                cands_dict[s][cand].append(score)\n            except KeyError:\n                cands_dict[s][cand] = [score]\n\n    return cands_dict\n\n\ndef save_candidates(\n    rules_file, dir_path, all_candidates, rule_lengths, window, score_func_str\n):\n    \"\"\"\n    Save the candidates.\n\n    Parameters:\n        rules_file (str): name of rules file\n        dir_path (str): path to output directory\n        all_candidates (dict): candidates for all test queries\n        rule_lengths (list): rule lengths\n        window (int): time window used for rule application\n        score_func_str (str): scoring function\n\n    Returns:\n        None\n    \"\"\"\n\n    all_candidates = {int(k): v for k, v in all_candidates.items()}\n    for k in 
all_candidates:\n        all_candidates[k] = {int(cand): v for cand, v in all_candidates[k].items()}\n    filename = \"{0}_cands_r{1}_w{2}_{3}.json\".format(\n        rules_file[:-11], rule_lengths, window, score_func_str\n    )\n    filename = filename.replace(\" \", \"\")\n    with open(dir_path + filename, \"w\", encoding=\"utf-8\") as fout:\n        json.dump(all_candidates, fout)\n\n\ndef verbalize_walk(walk, data):\n    \"\"\"\n    Verbalize walk from rule application.\n\n    Parameters:\n        walk (pandas.core.series.Series): walk that matches the rule body from get_walks\n        data (grapher.Grapher): graph data\n\n    Returns:\n        walk_str (str): verbalized walk\n    \"\"\"\n\n    l = len(walk) // 3\n    walk = walk.values.tolist()\n\n    walk_str = data.id2entity[walk[0]] + \"\\t\"\n    for j in range(l):\n        walk_str += data.id2relation[walk[3 * j + 1]] + \"\\t\"\n        walk_str += data.id2entity[walk[3 * j + 2]] + \"\\t\"\n        walk_str += data.id2ts[walk[3 * j + 3]] + \"\\t\"\n\n    return walk_str[:-1]\n\n\ndef score1(rule, c=0):\n    \"\"\"\n    Calculate candidate score depending on the rule's confidence.\n\n    Parameters:\n        rule (dict): rule from rules_dict\n        c (int): constant for smoothing\n\n    Returns:\n        score (float): candidate score\n    \"\"\"\n\n    score = rule[\"rule_supp\"] / (rule[\"body_supp\"] + c)\n\n    return score\n\n\ndef score2(cands_walks, test_query_ts, lmbda):\n    \"\"\"\n    Calculate candidate score depending on the time difference.\n\n    Parameters:\n        cands_walks (pd.DataFrame): walks leading to the candidate\n        test_query_ts (int): test query timestamp\n        lmbda (float): rate of exponential distribution\n\n    Returns:\n        score (float): candidate score\n    \"\"\"\n\n    max_cands_ts = max(cands_walks[\"timestamp_0\"])\n    score = np.exp(\n        lmbda * (max_cands_ts - test_query_ts)\n    )  # Score depending on time difference\n\n    return 
score\n\n\ndef score_12(rule, cands_walks, test_query_ts, lmbda, a):\n    \"\"\"\n    Combined score function.\n\n    Parameters:\n        rule (dict): rule from rules_dict\n        cands_walks (pd.DataFrame): walks leading to the candidate\n        test_query_ts (int): test query timestamp\n        lmbda (float): rate of exponential distribution\n        a (float): value between 0 and 1\n\n    Returns:\n        score (float): candidate score\n    \"\"\"\n\n    score = a * score1(rule) + (1 - a) * score2(cands_walks, test_query_ts, lmbda)\n\n    return score"
  },
  {
    "path": "modules/tlogic_learn_modules.py",
    "content": "\"\"\"\nhttps://github.com/liu-yushan/TLogic/blob/main/mycode/temporal_walk.py\nAND\nhttps://github.com/liu-yushan/TLogic/blob/main/mycode/rule_learning.py\nTLogic: Temporal Logical Rules for Explainable Link Forecasting on Temporal Knowledge Graphs.\nYushan Liu, Yunpu Ma, Marcel Hildebrandt, Mitchell Joblin, Volker Tresp\n\"\"\"\n\n\nimport os\nimport json\nimport itertools\nimport numpy as np\n\nclass Temporal_Walk(object):\n    def __init__(self, learn_data, inv_relation_id, transition_distr):\n        \"\"\"\n        Initialize temporal random walk object.\n\n        Parameters:\n            learn_data (np.ndarray): data on which the rules should be learned\n            inv_relation_id (dict): mapping of relation to inverse relation\n            transition_distr (str): transition distribution\n                                    \"unif\" - uniform distribution\n                                    \"exp\"  - exponential distribution\n\n        Returns:\n            None\n        \"\"\"\n\n        self.learn_data = learn_data\n        self.inv_relation_id = inv_relation_id\n        self.transition_distr = transition_distr\n        self.neighbors = store_neighbors(learn_data)\n        self.edges = store_edges(learn_data)\n\n    def sample_start_edge(self, rel_idx):\n        \"\"\"\n        Define start edge distribution.\n\n        Parameters:\n            rel_idx (int): relation index\n\n        Returns:\n            start_edge (np.ndarray): start edge\n        \"\"\"\n\n        rel_edges = self.edges[rel_idx]\n        start_edge = rel_edges[np.random.choice(len(rel_edges))]\n\n        return start_edge\n\n    def sample_next_edge(self, filtered_edges, cur_ts):\n        \"\"\"\n        Define next edge distribution.\n\n        Parameters:\n            filtered_edges (np.ndarray): filtered (according to time) edges\n            cur_ts (int): current timestamp\n\n        Returns:\n            next_edge (np.ndarray): next edge\n        \"\"\"\n\n     
   if self.transition_distr == \"unif\":\n            next_edge = filtered_edges[np.random.choice(len(filtered_edges))]\n        elif self.transition_distr == \"exp\":\n            tss = filtered_edges[:, 3]\n            prob = np.exp(tss - cur_ts)\n            try:\n                prob = prob / np.sum(prob)\n                next_edge = filtered_edges[\n                    np.random.choice(range(len(filtered_edges)), p=prob)\n                ]\n            except ValueError:  # All timestamps are far away\n                next_edge = filtered_edges[np.random.choice(len(filtered_edges))]\n\n        return next_edge\n\n    def transition_step(self, cur_node, cur_ts, prev_edge, start_node, step, L):\n        \"\"\"\n        Sample a neighboring edge given the current node and timestamp.\n        In the second step (step == 1), the next timestamp should be smaller than the current timestamp.\n        In the other steps, the next timestamp should be smaller than or equal to the current timestamp.\n        In the last step (step == L-1), the edge should connect to the source of the walk (cyclic walk).\n        It is not allowed to go back using the inverse edge.\n\n        Parameters:\n            cur_node (int): current node\n            cur_ts (int): current timestamp\n            prev_edge (np.ndarray): previous edge\n            start_node (int): start node\n            step (int): number of current step\n            L (int): length of random walk\n\n        Returns:\n            next_edge (np.ndarray): next edge\n        \"\"\"\n\n        next_edges = self.neighbors[cur_node]\n\n        if step == 1:  # The next timestamp should be smaller than the current timestamp\n            filtered_edges = next_edges[next_edges[:, 3] < cur_ts]\n        else:  # The next timestamp should be smaller than or equal to the current timestamp\n            filtered_edges = next_edges[next_edges[:, 3] <= cur_ts]\n            # Delete inverse edge\n            inv_edge = [\n            
    cur_node,\n                self.inv_relation_id[prev_edge[1]],\n                prev_edge[0],\n                cur_ts,\n            ]\n            row_idx = np.where(np.all(filtered_edges == inv_edge, axis=1))\n            filtered_edges = np.delete(filtered_edges, row_idx, axis=0)\n\n        if step == L - 1:  # Find an edge that connects to the source of the walk\n            filtered_edges = filtered_edges[filtered_edges[:, 2] == start_node]\n\n        if len(filtered_edges):\n            next_edge = self.sample_next_edge(filtered_edges, cur_ts)\n        else:\n            next_edge = []\n\n        return next_edge\n\n    def sample_walk(self, L, rel_idx):\n        \"\"\"\n        Try to sample a cyclic temporal random walk of length L (for a rule of length L-1).\n\n        Parameters:\n            L (int): length of random walk\n            rel_idx (int): relation index\n\n        Returns:\n            walk_successful (bool): if a cyclic temporal random walk has been successfully sampled\n            walk (dict): information about the walk (entities, relations, timestamps)\n        \"\"\"\n\n        walk_successful = True\n        walk = dict()\n        prev_edge = self.sample_start_edge(rel_idx)\n        start_node = prev_edge[0]\n        cur_node = prev_edge[2]\n        cur_ts = prev_edge[3]\n        walk[\"entities\"] = [start_node, cur_node]\n        walk[\"relations\"] = [prev_edge[1]]\n        walk[\"timestamps\"] = [cur_ts]\n\n        for step in range(1, L):\n            next_edge = self.transition_step(\n                cur_node, cur_ts, prev_edge, start_node, step, L\n            )\n            if len(next_edge):\n                cur_node = next_edge[2]\n                cur_ts = next_edge[3]\n                walk[\"relations\"].append(next_edge[1])\n                walk[\"entities\"].append(cur_node)\n                walk[\"timestamps\"].append(cur_ts)\n                prev_edge = next_edge\n            else:  # No valid neighbors (due to temporal 
or cyclic constraints)\n                walk_successful = False\n                break\n\n        return walk_successful, walk\n\n\ndef store_neighbors(quads):\n    \"\"\"\n    Store all neighbors (outgoing edges) for each node.\n\n    Parameters:\n        quads (np.ndarray): indices of quadruples\n\n    Returns:\n        neighbors (dict): neighbors for each node\n    \"\"\"\n\n    neighbors = dict()\n    nodes = list(set(quads[:, 0]))\n    for node in nodes:\n        neighbors[node] = quads[quads[:, 0] == node]\n\n    return neighbors\n\n\ndef store_edges(quads):\n    \"\"\"\n    Store all edges for each relation.\n\n    Parameters:\n        quads (np.ndarray): indices of quadruples\n\n    Returns:\n        edges (dict): edges for each relation\n    \"\"\"\n\n    edges = dict()\n    relations = list(set(quads[:, 1]))\n    for rel in relations:\n        edges[rel] = quads[quads[:, 1] == rel]\n\n    return edges\n\n\nclass Rule_Learner(object):\n    def __init__(self, edges, id2relation, inv_relation_id, output_dir):\n        \"\"\"\n        Initialize rule learner object.\n\n        Parameters:\n            edges (dict): edges for each relation\n            id2relation (dict): mapping of index to relation\n            inv_relation_id (dict): mapping of relation to inverse relation\n            output_dir (str): directory name where to store learned rules\n\n        Returns:\n            None\n        \"\"\"\n\n        self.edges = edges\n        self.id2relation = id2relation\n        self.inv_relation_id = inv_relation_id\n\n        self.found_rules = []\n        self.rules_dict = dict()\n        self.output_dir = output_dir\n        if not os.path.exists(self.output_dir):\n            os.makedirs(self.output_dir)\n\n    def create_rule(self, walk):\n        \"\"\"\n        Create a rule given a cyclic temporal random walk.\n        The rule contains information about head relation, body relations,\n        variable constraints, confidence, rule support, and body 
support.\n        A rule is a dictionary with the content\n        {\"head_rel\": int, \"body_rels\": list, \"var_constraints\": list,\n         \"conf\": float, \"rule_supp\": int, \"body_supp\": int}\n\n        Parameters:\n            walk (dict): cyclic temporal random walk\n                         {\"entities\": list, \"relations\": list, \"timestamps\": list}\n\n        Returns:\n            rule (dict): created rule\n        \"\"\"\n\n        rule = dict()\n        rule[\"head_rel\"] = int(walk[\"relations\"][0])\n        rule[\"body_rels\"] = [\n            self.inv_relation_id[x] for x in walk[\"relations\"][1:][::-1]\n        ]\n        rule[\"var_constraints\"] = self.define_var_constraints(\n            walk[\"entities\"][1:][::-1]\n        )\n\n        if rule not in self.found_rules:\n            self.found_rules.append(rule.copy())\n            (\n                rule[\"conf\"],\n                rule[\"rule_supp\"],\n                rule[\"body_supp\"],\n            ) = self.estimate_confidence(rule)\n\n            if rule[\"conf\"]:\n                self.update_rules_dict(rule)\n\n    def define_var_constraints(self, entities):\n        \"\"\"\n        Define variable constraints, i.e., state the indices of reoccurring entities in a walk.\n\n        Parameters:\n            entities (list): entities in the temporal walk\n\n        Returns:\n            var_constraints (list): list of indices for reoccurring entities\n        \"\"\"\n\n        var_constraints = []\n        for ent in set(entities):\n            all_idx = [idx for idx, x in enumerate(entities) if x == ent]\n            var_constraints.append(all_idx)\n        var_constraints = [x for x in var_constraints if len(x) > 1]\n\n        return sorted(var_constraints)\n\n    def estimate_confidence(self, rule, num_samples=500):\n        \"\"\"\n        Estimate the confidence of the rule by sampling bodies and checking the rule support.\n\n        Parameters:\n            rule (dict): rule\n 
                        {\"head_rel\": int, \"body_rels\": list, \"var_constraints\": list}\n            num_samples (int): number of samples\n\n        Returns:\n            confidence (float): confidence of the rule, rule_support/body_support\n            rule_support (int): rule support\n            body_support (int): body support\n        \"\"\"\n\n        all_bodies = []\n        for _ in range(num_samples):\n            sample_successful, body_ents_tss = self.sample_body(\n                rule[\"body_rels\"], rule[\"var_constraints\"]\n            )\n            if sample_successful:\n                all_bodies.append(body_ents_tss)\n\n        all_bodies.sort()\n        unique_bodies = list(x for x, _ in itertools.groupby(all_bodies))\n        body_support = len(unique_bodies)\n\n        confidence, rule_support = 0, 0\n        if body_support:\n            rule_support = self.calculate_rule_support(unique_bodies, rule[\"head_rel\"])\n            confidence = round(rule_support / body_support, 6)\n\n        return confidence, rule_support, body_support\n\n    def sample_body(self, body_rels, var_constraints):\n        \"\"\"\n        Sample a walk according to the rule body.\n        The sequence of timesteps should be non-decreasing.\n\n        Parameters:\n            body_rels (list): relations in the rule body\n            var_constraints (list): variable constraints for the entities\n\n        Returns:\n            sample_successful (bool): if a body has been successfully sampled\n            body_ents_tss (list): entities and timestamps (alternately entity and timestamp)\n                                  of the sampled body\n        \"\"\"\n\n        sample_successful = True\n        body_ents_tss = []\n        cur_rel = body_rels[0]\n        rel_edges = self.edges[cur_rel]\n        next_edge = rel_edges[np.random.choice(len(rel_edges))]\n        cur_ts = next_edge[3]\n        cur_node = next_edge[2]\n        body_ents_tss.append(next_edge[0])\n       
 body_ents_tss.append(cur_ts)\n        body_ents_tss.append(cur_node)\n\n        for cur_rel in body_rels[1:]:\n            next_edges = self.edges[cur_rel]\n            mask = (next_edges[:, 0] == cur_node) * (next_edges[:, 3] >= cur_ts)\n            filtered_edges = next_edges[mask]\n\n            if len(filtered_edges):\n                next_edge = filtered_edges[np.random.choice(len(filtered_edges))]\n                cur_ts = next_edge[3]\n                cur_node = next_edge[2]\n                body_ents_tss.append(cur_ts)\n                body_ents_tss.append(cur_node)\n            else:\n                sample_successful = False\n                break\n\n        if sample_successful and var_constraints:\n            # Check variable constraints\n            body_var_constraints = self.define_var_constraints(body_ents_tss[::2])\n            if body_var_constraints != var_constraints:\n                sample_successful = False\n\n        return sample_successful, body_ents_tss\n\n    def calculate_rule_support(self, unique_bodies, head_rel):\n        \"\"\"\n        Calculate the rule support. 
Check for each body if there is a timestamp\n        (larger than the timestamps in the rule body) for which the rule head holds.\n\n        Parameters:\n            unique_bodies (list): bodies from self.sample_body\n            head_rel (int): head relation\n\n        Returns:\n            rule_support (int): rule support\n        \"\"\"\n\n        rule_support = 0\n        head_rel_edges = self.edges[head_rel]\n        for body in unique_bodies:\n            mask = (\n                (head_rel_edges[:, 0] == body[0])\n                * (head_rel_edges[:, 2] == body[-1])\n                * (head_rel_edges[:, 3] > body[-2])\n            )\n\n            if True in mask:\n                rule_support += 1\n\n        return rule_support\n\n    def update_rules_dict(self, rule):\n        \"\"\"\n        Update the rules if a new rule has been found.\n\n        Parameters:\n            rule (dict): generated rule from self.create_rule\n\n        Returns:\n            None\n        \"\"\"\n\n        try:\n            self.rules_dict[rule[\"head_rel\"]].append(rule)\n        except KeyError:\n            self.rules_dict[rule[\"head_rel\"]] = [rule]\n\n    def sort_rules_dict(self):\n        \"\"\"\n        Sort the found rules for each head relation by decreasing confidence.\n\n        Parameters:\n            None\n\n        Returns:\n            None\n        \"\"\"\n\n        for rel in self.rules_dict:\n            self.rules_dict[rel] = sorted(\n                self.rules_dict[rel], key=lambda x: x[\"conf\"], reverse=True\n            )\n\n    def save_rules(self, dt, rule_lengths, num_walks, transition_distr, seed):\n        \"\"\"\n        Save all rules.\n\n        Parameters:\n            dt (str): time now\n            rule_lengths (list): rule lengths\n            num_walks (int): number of walks\n            transition_distr (str): transition distribution\n            seed (int): random seed\n\n        Returns:\n            None\n        \"\"\"\n\n        
rules_dict = {int(k): v for k, v in self.rules_dict.items()}\n        filename = \"{0}_r{1}_n{2}_{3}_s{4}_rules.json\".format(\n            dt, rule_lengths, num_walks, transition_distr, seed\n        )\n        filename = filename.replace(\" \", \"\")\n        with open(self.output_dir + filename, \"w\", encoding=\"utf-8\") as fout:\n            json.dump(rules_dict, fout)\n\n        return filename\n\n    def save_rules_verbalized(\n        self, dt, rule_lengths, num_walks, transition_distr, seed\n    ):\n        \"\"\"\n        Save all rules in a human-readable format.\n\n        Parameters:\n            dt (str): time now\n            rule_lengths (list): rule lengths\n            num_walks (int): number of walks\n            transition_distr (str): transition distribution\n            seed (int): random seed\n\n        Returns:\n            None\n        \"\"\"\n\n        rules_str = \"\"\n        for rel in self.rules_dict:\n            for rule in self.rules_dict[rel]:\n                rules_str += verbalize_rule(rule, self.id2relation) + \"\\n\"\n\n        filename = \"{0}_r{1}_n{2}_{3}_s{4}_rules.txt\".format(\n            dt, rule_lengths, num_walks, transition_distr, seed\n        )\n        filename = filename.replace(\" \", \"\")\n        with open(self.output_dir + filename, \"w\", encoding=\"utf-8\") as fout:\n            fout.write(rules_str)\n\n\ndef verbalize_rule(rule, id2relation):\n    \"\"\"\n    Verbalize the rule to be in a human-readable format.\n\n    Parameters:\n        rule (dict): rule from Rule_Learner.create_rule\n        id2relation (dict): mapping of index to relation\n\n    Returns:\n        rule_str (str): human-readable rule\n    \"\"\"\n\n    if rule[\"var_constraints\"]:\n        var_constraints = rule[\"var_constraints\"]\n        constraints = [x for sublist in var_constraints for x in sublist]\n        for i in range(len(rule[\"body_rels\"]) + 1):\n            if i not in constraints:\n                
var_constraints.append([i])\n        var_constraints = sorted(var_constraints)\n    else:\n        var_constraints = [[x] for x in range(len(rule[\"body_rels\"]) + 1)]\n\n    rule_str = \"{0:8.6f}  {1:4}  {2:4}  {3}(X0,X{4},T{5}) <- \"\n    obj_idx = [\n        idx\n        for idx in range(len(var_constraints))\n        if len(rule[\"body_rels\"]) in var_constraints[idx]\n    ][0]\n    rule_str = rule_str.format(\n        rule[\"conf\"],\n        rule[\"rule_supp\"],\n        rule[\"body_supp\"],\n        id2relation[rule[\"head_rel\"]],\n        obj_idx,\n        len(rule[\"body_rels\"]),\n    )\n\n    for i in range(len(rule[\"body_rels\"])):\n        sub_idx = [\n            idx for idx in range(len(var_constraints)) if i in var_constraints[idx]\n        ][0]\n        obj_idx = [\n            idx for idx in range(len(var_constraints)) if i + 1 in var_constraints[idx]\n        ][0]\n        rule_str += \"{0}(X{1},X{2},T{3}), \".format(\n            id2relation[rule[\"body_rels\"][i]], sub_idx, obj_idx, i\n        )\n\n    return rule_str[:-2]\n\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.poetry]\nname = \"py-tgb\"\nversion = \"2.2.0\"\ndescription = \"Temporal Graph Benchmark project repo\"\nauthors = [\"shenyang Huang <shenyang.huang@mail.mcgill.ca>\", \"Julia Gastinger\", \"Farimah Poursafaei\", \"Emanuele Rossi <emanuele.rossi1909@gmail.com>\", \"Jacob Danovitch <jacob.danovitch@mila.quebec>\"]\nreadme = \"README.md\"\npackages = [{include = \"tgb\"}]\n\n[tool.poetry.dependencies]\npython = \"^3.9\"\ntorch-geometric = \"^2.3.0\"\ntqdm = \"^4.65.0\"\nnumpy = \"^2.0.2\"\nclint = \"^0.5.1\"\nrequests = \"^2.28.2\"\npandas = \">=2.2.3\"\nscikit-learn = \"^1.2.2\"\n\n[tool.poetry.group.dev.dependencies]\nmkdocs = \"^1.4.3\"\nmkdocs-material = \"^9.1.15\"\nmkdocstrings-python = \"^1.1.2\"\nmkdocs-jupyter = \"^0.24.1\"\npoetry = \"^1.5.1\"\n\n[build-system]\nrequires = [\"poetry-core\"]\nbuild-backend = \"poetry.core.masonry.api\"\n"
  },
  {
    "path": "run.sh",
    "content": "#!/bin/bash\n#SBATCH --partition=long  #unkillable #main #long\n#SBATCH --output=tgnlog_genre_s5.txt #tgn_lastfmgenre_s5.txt \n#SBATCH --error=tgnlog_genre_s5error.txt #tgn_lastfmgenre_s5_error.txt   \n#SBATCH --cpus-per-task=4                     # Ask for 4 CPUs\n#SBATCH --gres=gpu:rtx8000:1                  # Ask for 1 titan xp\n#SBATCH --mem=32G                             # Ask for 32 GB of RAM\n#SBATCH --time=48:00:00                       # The job will run for 1 day\n\nexport HOME=\"/home/mila/h/huangshe\"\nmodule load python/3.9\nsource $HOME/tgbenv/bin/activate\n\npwd\nCUDA_VISIBLE_DEVICES=0 python examples/nodeproppred/tgbn-genre/tgn.py --seed 5\n# CUDA_VISIBLE_DEVICES=0 python examples/nodeproppred/lastfmgenre/dyrep.py --seed 5\n# CUDA_VISIBLE_DEVICES=0 python examples/nodeproppred/un_trade/tgn.py -s 5\n# CUDA_VISIBLE_DEVICES=0 python examples/linkproppred/amazonreview/tgn.py -s 1\n"
  },
  {
    "path": "scripts/env.sh",
    "content": "module load python/3.9\nsource $HOME/tgbenv/bin/activate\n\n"
  },
  {
    "path": "scripts/mila.sh",
    "content": "salloc --partition=unkillable --cpus-per-task=4 --gres=gpu:1 --mem=32G\n"
  },
  {
    "path": "scripts/mila_install.sh",
    "content": "module load python/3.9\npython -m venv $HOME/tgbenv\nsource $HOME/tgbenv/bin/activate\npip3 install -r requirements.txt\npip3 install -e .\n\n\n"
  },
  {
    "path": "scripts/run.sh",
    "content": "#!/bin/bash\n#SBATCH --partition=long  #unkillable #main #long\n#SBATCH --output=dyrep_trade_s5.txt #tgn_lastfmgenre_s5.txt \n#SBATCH --error=dyrep_trade_s5error.txt #tgn_lastfmgenre_s5_error.txt   \n#SBATCH --cpus-per-task=4                     # Ask for 4 CPUs\n#SBATCH --gres=gpu:rtx8000:1                  # Ask for 1 titan xp\n#SBATCH --mem=32G                             # Ask for 32 GB of RAM\n#SBATCH --time=48:00:00                       # The job will run for 1 day\n\nexport HOME=\"/home/mila/h/huangshe\"\nmodule load python/3.9\nsource $HOME/tgbenv/bin/activate\n\npwd\nCUDA_VISIBLE_DEVICES=0 python examples/nodeproppred/un_trade/dyrep.py --seed 5\n# CUDA_VISIBLE_DEVICES=0 python examples/nodeproppred/lastfmgenre/dyrep.py --seed 5\n# CUDA_VISIBLE_DEVICES=0 python examples/nodeproppred/un_trade/tgn.py -s 5\n# CUDA_VISIBLE_DEVICES=0 python examples/linkproppred/amazonreview/tgn.py -s 1\n# CUDA_VISIBLE_DEVICES=0 python examples/nodeproppred/lastfmgenre/tgn.py -s 5\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\r\n\r\nsetup(name=\"py-tgb\", version=\"2.2.0\", packages=find_packages())\r\n"
  },
  {
    "path": "tgb/datasets/ICEWS14/ent2word.py",
    "content": "# -*- coding: utf-8 -*-\n# @Time    : 2019/12/5 4:20 下午\n# @Author  : Lee_zix\n# @Email   : Lee_zix@163.com\n# @File    : ent2word.py.py\n# @Software: PyCharm\n\nimport os\n\ndef load_index(input_path):\n    index, rev_index = {}, {}\n    with open(input_path) as f:\n        for i, line in enumerate(f.readlines()):        # relaions.dict和entities.dict中的id都是按顺序排列的\n            rel, id = line.strip().split(\"\\t\")\n            index[rel] = id\n            rev_index[id] = rel\n    return index, rev_index\n\nentity2id, id2entity = load_index(os.path.join('entity2id.txt'))\nrelation2id, id2relation = load_index(os.path.join('relation2id.txt'))\n\ncount = 0\ncount1 = 0\nword_list = set()\nfor entity_str in entity2id.keys():\n    if \"(\" in entity_str and \")\" in entity_str:\n        count += 1\n        begin = entity_str.find('(')\n        end = entity_str.find(')')\n        w1 = entity_str[:begin].strip()\n        w2 = entity_str[begin+1: end]\n        if w2 not in entity2id.keys():\n            print(w2)\n            count1 += 1\n        word_list.add(w1)\n        word_list.add(w2)\n    else:\n        word_list.add(entity_str)\n\nnum_word = len(word_list)\n\nword2id = {word: id for id, word in enumerate(word_list)}\nid2word = {id: word for id, word in enumerate(word_list)}\n# print(word2id)\n# print(id2word)\n\nprint(\"words num: {}, enity_num: {}\".format(num_word, len(entity2id.keys())))\nprint(float(count)/len(entity2id.keys()))\nprint(float(count1)/float(count))\n\nwith open(\"word2id.txt\", \"w\") as f:\n    for word in word2id.keys():\n        f.write(word + \"\\t\" + str(word2id[word])+'\\n')\n\neid2wid = []\nfor id in range(len(id2entity.keys())):\n    entity_str = id2entity[str(id)]\n    if \"(\" in entity_str and \")\" in entity_str:\n        count += 1\n        begin = entity_str.find('(')\n        end = entity_str.find(')')\n        w1 = entity_str[:begin].strip()\n        w2 = entity_str[begin+1: end]\n        
eid2wid.append([str(entity2id[entity_str]), \"0\", str(word2id[w1])])   # isA关系\n        eid2wid.append([str(entity2id[entity_str]), \"1\", str(word2id[w2])])     # 隶属关系\n    else:\n        eid2wid.append([str(entity2id[entity_str]), \"2\", str(word2id[entity_str])])\n\nwith open(\"e-w-graph.txt\", \"w\") as f:\n    for line in eid2wid:\n        f.write(\"\\t\".join(line)+'\\n')\n\n\n\n\n"
  },
  {
    "path": "tgb/datasets/ICEWS14/icews14.py",
    "content": "import csv\r\n\r\ndef load_index(input_path):\r\n    index, rev_index = {}, {}\r\n    with open(input_path) as f:\r\n        for i, line in enumerate(f.readlines()):        # relaions.dict和entities.dict中的id都是按顺序排列的\r\n            rel, id = line.strip().split(\"\\t\")\r\n            index[rel] = id\r\n            rev_index[id] = rel\r\n    return index, rev_index\r\n\r\n\r\ndef load_tab_list(input_path):\r\n    rows = []\r\n    with open(input_path) as f:\r\n        for i, line in enumerate(f.readlines()): \r\n            head,relation,tail,t, = line.strip().split(\"\\t\")\r\n            rows.append([t,head,tail,relation])\r\n    return rows\r\n\r\n        \r\ndef write2csv(rows, output_path):\r\n    with open(output_path, \"w\") as f:\r\n        writer = csv.writer(f)\r\n        writer.writerow([\"timestamp\", \"head\", \"tail\", \"relation_type\"])\r\n        writer.writerows(rows)\r\n\r\n\r\ndef main():\r\n    \"\"\"\r\n    concatenate and merge the edgelists into one \r\n    change tab to ,\r\n    \"\"\"\r\n    train_name = \"train.txt\"\r\n    train_rows = load_tab_list(train_name)\r\n\r\n    val_name = \"valid.txt\"\r\n    val_rows = load_tab_list(val_name)\r\n\r\n    test_name = \"test.txt\"\r\n    test_rows = load_tab_list(test_name)\r\n\r\n    all_rows = train_rows + val_rows + test_rows\r\n    output_path = \"icews14.csv\"\r\n    write2csv(all_rows, output_path)\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/MAG/mag.py",
    "content": "import pandas as pd\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    df = pd.read_parquet(\"nodes.parquet/nodes.parquet\", engine=\"pyarrow\")\r\n    data_top = df.head()\r\n\r\n    print(data_top)\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/MAG/old/plot_stats.py",
    "content": "import networkx as nx\r\nimport matplotlib.pyplot as plt\r\n\r\n\r\ndef load_csv(fname: str):\r\n    \"\"\"\r\n    plot the number of citations in each year for the MAG dataset\r\n    \"\"\"\r\n    f = open(fname, \"r\")\r\n    lines = list(f.readlines())\r\n    f.close()\r\n\r\n    years = []\r\n    cites = []\r\n    for i in range(len(lines)):\r\n        if i == 0:\r\n            continue\r\n        line = lines[i]\r\n        line = line.split(\",\")\r\n        try:\r\n            year = int(line[0])\r\n        except:\r\n            continue\r\n        num_citations = int(line[1])\r\n        years.append(year)\r\n        cites.append(num_citations)\r\n\r\n    plt.plot(years, cites, color=\"#e34a33\")\r\n    plt.xlabel(\"Year\")\r\n    plt.ylabel(\"Paper Count\")\r\n    plt.savefig(\"paper_count.pdf\")\r\n    plt.close()\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    load_csv(\"paper_year.txt\")\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/dgraph.py",
    "content": "import dateutil.parser as dparser\r\nimport csv\r\nimport matplotlib.pyplot as plt\r\nimport numpy as np\r\nfrom tqdm import tqdm\r\nfrom os import listdir\r\nfrom datetime import datetime\r\n\r\n\r\n\"\"\"\r\n# Description of DGraphFin datafile.\r\n\r\n#! File **dgraphfin.npz** including below keys:  \r\n\r\n#* **x**: 17-dimensional node features.\r\n#* **y**: node label.  \r\n    There four classes. Below are the nodes counts of each class.     \r\n    0: 1210092    \r\n    1: 15509    \r\n    2: 1620851    \r\n    3: 854098    \r\n    Nodes of Class 1 are fraud users and nodes of 0 are normal users, and they the two classes to be predicted.    \r\n    Nodes of Class 2 and Class 3 are background users.    \r\n    \r\n#* **edge_index**: shape (4300999, 2).   \r\n    Each edge is in the form (id_a, id_b), where ids are the indices in x.        \r\n\r\n#* **edge_type**: 11 types of edges. \r\n    \r\n#* **edge_timestamp**: the desensitized timestamp of each edge.\r\n    \r\n#* **train_mask, valid_mask, test_mask**:  \r\n    Nodes of Class 0 and Class 1 are randomly splitted by 70/15/15.  
\r\n\"\"\"\r\n\r\n\r\n\r\n\r\ndef main():\r\n    \r\n    #* load the raw data from numpy\r\n    with np.load('dgraphfin.npz') as data:\r\n        \r\n        x = data['x']\r\n        print (\"shape of the node feature vectors are\")\r\n        print (x.shape)\r\n        \r\n        y = data['y']\r\n        print (\"shape of the node labels are\")\r\n        print (y.shape)\r\n        \r\n        edge_index = data['edge_index']\r\n        print (\"shape of the edge index are\")\r\n        print (edge_index.shape)\r\n        \r\n        edge_type = data['edge_type']\r\n        print (\"shape of the edge type are\")\r\n        print (edge_type.shape)\r\n        \r\n        edge_timestamp = data['edge_timestamp']\r\n        print (\"shape of the edge timestamp are\")\r\n        print (edge_timestamp.shape)\r\n        \r\n        print (\"check if the timestamps are sorted\")\r\n        print(np.all(edge_timestamp[:-1] <= edge_timestamp[1:]))\r\n\r\n                \r\n    \r\n    \r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/dgraph_Readme.md",
    "content": "# Description of DGraphFin datafile.\n\nFile **dgraphfin.npz** including below keys:  \n\n- **x**: 17-dimensional node features.\n- **y**: node label.  \n    There four classes. Below are the nodes counts of each class.     \n    0: 1210092    \n    1: 15509    \n    2: 1620851    \n    3: 854098    \n    Nodes of Class 1 are fraud users and nodes of 0 are normal users, and they the two classes to be predicted.    \n    Nodes of Class 2 and Class 3 are background users.    \n    \n- **edge_index**: shape (4300999, 2).   \n    Each edge is in the form (id_a, id_b), where ids are the indices in x.        \n\n- **edge_type**: 11 types of edges. \n    \n- **edge_timestamp**: the desensitized timestamp of each edge.\n    \n- **train_mask, valid_mask, test_mask**:  \n    Nodes of Class 0 and Class 1 are randomly splitted by 70/15/15.  \n\n    \n\n\n    "
  },
  {
    "path": "tgb/datasets/dataset_scripts/process_arxiv.py",
    "content": "import json\r\nimport networkx as nx \r\nimport numpy as np\r\nimport csv\r\nfrom datetime import date\r\n\r\n\r\ndef load_full_json(fname):\r\n    json_str = \"\"\r\n    ctr = 0\r\n    with open(fname, \"r\", encoding='utf-8') as f:\r\n\r\n        #TODO need to determine how many lines form a json object \r\n        for line in f:\r\n            data = json.loads(line)\r\n            print (data)\r\n            quit() #remove this when you write the code\r\n\r\n\r\ndef main():\r\n    fname = \"nodes.json\"\r\n    load_full_json(fname)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/process_github.py",
    "content": "import json\nfrom datetime import datetime\n\nrels = {\n    \"IC_Created_IC_I\": \"IC_AO_C_I\",\n    \"IC_Created_U_IC\": \"U_SO_C_IC\",\n    \"I_Opened_U_I\": \"U_SE_O_I\",\n    \"I_Opened_I_R\": \"I_AO_O_R\",\n    \"I_Closed_U_I\": \"U_SE_C_I\",\n    \"I_Closed_I_R\": \"I_AO_C_R\",\n    \"I_Reopened_U_I\": \"U_SE_RO_I\",\n    \"I_Reopened_I_R\": \"I_AO_RO_R\",\n    \"PR_Opened_U_PR\": \"U_SO_O_P\",\n    \"PR_Opened_PR_R\": \"P_AO_O_R\",\n    \"PR_Closed_U_PR\": \"U_SO_C_P\",\n    \"PR_Closed_PR_R\": \"P_AO_C_R\",\n    \"PR_Reopened_U_PR\": \"U_SO_R_P\",\n    \"PR_Reopened_PR_R\": \"P_AO_R_R\",\n    \"PRRC_Created_U_PRC\": \"U_SO_C_PRC\",\n    \"PRRC_Created_PRC_PR\": \"PRC_AO_C_P\",\n    \"Forked_R_R\": \"R_FO_R\",\n    \"AddMember_U_R\": \"U_CO_A_R\",\n}\n\nissue_comment_format = \"/issue_comment/{}\"\nissue_format = \"/issue/{}\"\nuser_format = \"/user/{}\"\nrepo_format = \"/repo/{}\"\npull_request_format = \"/pr/{}\"\npull_request_review_comment_format = \"/pr_review_comment/{}\"\n\n\ndef str_to_timestamp(time_str):\n    dt = datetime.strptime(time_str, \"%Y-%m-%dT%H:%M:%SZ\")\n    return int(dt.timestamp())\n\n\ndef parse_issue_comment_events(event):\n    if event[\"payload\"][\"action\"] == \"created\":\n        issue_comment_id = event[\"payload\"][\"comment\"][\"id\"]\n        issue_id = event[\"payload\"][\"issue\"][\"id\"]\n        user_id = event[\"actor\"][\"id\"]\n        created_at = str_to_timestamp(event[\"created_at\"])\n\n        ici_event = [\n            issue_comment_format.format(issue_comment_id),\n            rels[\"IC_Created_IC_I\"],\n            issue_format.format(issue_id),\n            created_at,\n        ]\n        uic_event = [\n            user_format.format(user_id),\n            rels[\"IC_Created_U_IC\"],\n            issue_comment_format.format(issue_comment_id),\n            created_at,\n        ]\n        return [ici_event, uic_event]\n    return []\n\n\ndef parse_issue_event(event):\n    issue_id = 
event[\"payload\"][\"issue\"][\"id\"]\n    user_id = event[\"actor\"][\"id\"]\n    repo_id = event[\"repo\"][\"id\"]\n    created_at = str_to_timestamp(event[\"created_at\"])\n    action_map = {\n        \"opened\": (\"I_Opened_U_I\", \"I_Opened_I_R\"),\n        \"closed\": (\"I_Closed_U_I\", \"I_Closed_I_R\"),\n        \"reopened\": (\"I_Reopened_U_I\", \"I_Reopened_I_R\"),\n    }\n    for action, event_rels in action_map.items():\n        if event[\"payload\"][\"action\"] == action:\n            ui_event = [\n                user_format.format(user_id),\n                rels[event_rels[0]],\n                issue_format.format(issue_id),\n                created_at,\n            ]\n\n            ir_event = [\n                issue_format.format(issue_id),\n                rels[event_rels[1]],\n                repo_format.format(repo_id),\n                created_at,\n            ]\n            return [ui_event, ir_event]\n    return []\n\n\ndef parse_pull_request_event(event):\n    pull_request_id = event[\"payload\"][\"pull_request\"][\"id\"]\n    user_id = event[\"actor\"][\"id\"]\n    repo_id = event[\"repo\"][\"id\"]\n    created_at = str_to_timestamp(event[\"created_at\"])\n    action_map = {\n        \"opened\": (\"PR_Opened_U_PR\", \"PR_Opened_PR_R\"),\n        \"closed\": (\"PR_Closed_U_PR\", \"PR_Closed_PR_R\"),\n        \"reopened\": (\"PR_Reopened_U_PR\", \"PR_Reopened_PR_R\"),\n    }\n    for action, event_rels in action_map.items():\n        if event[\"payload\"][\"action\"] == action:\n            upr_event = [\n                user_format.format(user_id),\n                rels[event_rels[0]],\n                pull_request_format.format(pull_request_id),\n                created_at,\n            ]\n\n            prr_event = [\n                pull_request_format.format(pull_request_id),\n                rels[event_rels[1]],\n                repo_format.format(repo_id),\n                created_at,\n            ]\n            return [upr_event, 
prr_event]\n    return []\n\n\ndef parse_pull_request_review_comment_event(event):\n    pull_request_review_comment_id = event[\"payload\"][\"comment\"][\"id\"]\n    pull_request_id = event[\"payload\"][\"pull_request\"][\"id\"]\n    user_id = event[\"actor\"][\"id\"]\n    created_at = str_to_timestamp(event[\"created_at\"])\n    if event[\"payload\"][\"action\"] == \"created\":\n        uprc_event = [\n            user_format.format(user_id),\n            rels[\"PRRC_Created_U_PRC\"],\n            pull_request_review_comment_format.format(pull_request_review_comment_id),\n            created_at,\n        ]\n\n        prcpr_event = [\n            pull_request_review_comment_format.format(pull_request_review_comment_id),\n            rels[\"PRRC_Created_PRC_PR\"],\n            pull_request_format.format(pull_request_id),\n            created_at,\n        ]\n        return [uprc_event, prcpr_event]\n    return []\n\n\ndef parse_fork_event(event):\n    forkee_repo_id = event[\"payload\"][\"forkee\"][\"id\"]\n    forked_repo_id = event[\"repo\"][\"id\"]\n    created_at = str_to_timestamp(event[\"created_at\"])\n    return [\n        [\n            repo_format.format(forkee_repo_id),\n            rels[\"Forked_R_R\"],\n            repo_format.format(forked_repo_id),\n            created_at,\n        ]\n    ]\n\n\ndef parse_member_event(event):\n    user_id = event[\"payload\"][\"member\"][\"id\"]\n    repo_id = event[\"repo\"][\"id\"]\n    created_at = str_to_timestamp(event[\"created_at\"])\n    return [\n        [\n            user_format.format(user_id),\n            rels[\"AddMember_U_R\"],\n            repo_format.format(repo_id),\n            created_at,\n        ]\n    ]\n\n\nevent_handler_dict = {\n    \"IssueCommentEvent\": parse_issue_comment_events,\n    \"IssuesEvent\": parse_issue_event,\n    \"PullRequestEvent\": parse_pull_request_event,\n    \"PullRequestReviewCommentEvent\": parse_pull_request_review_comment_event,\n    \"ForkEvent\": 
parse_fork_event,\n    \"MemberEvent\": parse_member_event,\n}\n\n\ndef parse_event(event):\n    event_type = event[\"type\"]\n    if event_type in event_handler_dict:\n        output_list = event_handler_dict[event_type](event)\n        # print(\"Got {} outputs for event type {}\".format(len(output_list), event_type))\n    else:\n        # print(\"Unknown event type: {}\".format(event_type))\n        output_list = []\n    return output_list\n\n\ndef parse_file(filename):\n    events = []\n    with open(filename) as f:\n        for i, line in enumerate(f):\n            event = json.loads(line)\n            parsed_events = parse_event(event)\n            events.append(parsed_events)\n    events = [event for sublist in events for event in sublist]\n    print(\"Parsed {} events\".format(len(events)))\n    return events\n\n\nfilename = \"2015-01-01-15.json\"\nparse_file(filename)\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/tgbl-coin.py",
    "content": "import csv\r\n\r\n\"\"\"\r\n#! analyze statistics from the dataset\r\n#* 1). # of unique nodes, 2). # of edges. 3). # of unique edges, 4). # of timestamps 5). min & max of edge weights, 6). recurrence of nodes\r\n\"\"\"\r\n\r\n\r\ndef analyze_csv(fname):\r\n    node_dict = {}\r\n    edge_dict = {}\r\n    num_edges = 0\r\n    num_time = 0\r\n    prev_t = \"none\"\r\n    min_w = 100000\r\n    max_w = 0\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        line_count = 0\r\n        for row in csv_reader:\r\n            if line_count == 0:\r\n                line_count += 1\r\n            else:\r\n                # t,u,v,w\r\n                t = row[0]\r\n                u = row[1]\r\n                v = row[2]\r\n                w = float(row[3].strip())\r\n\r\n                # min & max edge weights\r\n                if w > max_w:\r\n                    max_w = w\r\n\r\n                if w < min_w:\r\n                    min_w = w\r\n\r\n                # count unique time\r\n                if t != prev_t:\r\n                    num_time += 1\r\n                    prev_t = t\r\n\r\n                # unique nodes\r\n                if u not in node_dict:\r\n                    node_dict[u] = 1\r\n                else:\r\n                    node_dict[u] += 1\r\n\r\n                if v not in node_dict:\r\n                    node_dict[v] = 1\r\n                else:\r\n                    node_dict[v] += 1\r\n\r\n                # unique edges\r\n                num_edges += 1\r\n                if (u, v) not in edge_dict:\r\n                    edge_dict[(u, v)] = 1\r\n                else:\r\n                    edge_dict[(u, v)] += 1\r\n\r\n    print(\"----------------------high level statistics-------------------------\")\r\n    print(\"number of total edges are \", num_edges)\r\n    print(\"number of nodes are \", len(node_dict))\r\n    print(\"number of unique edges are \", 
len(edge_dict))\r\n    print(\"number of unique timestamps are \", num_time)\r\n    print(\"maximum edge weight is \", max_w)\r\n    print(\"minimum edge weight is \", min_w)\r\n\r\n    num_10 = 0\r\n    num_100 = 0\r\n    num_1000 = 0\r\n\r\n    for node in node_dict:\r\n        if node_dict[node] >= 10:\r\n            num_10 += 1\r\n        if node_dict[node] >= 100:\r\n            num_100 += 1\r\n        if node_dict[node] >= 1000:\r\n            num_1000 += 1\r\n    print(\"number of nodes with # edges >= 10 is \", num_10)\r\n    print(\"number of nodes with # edges >= 100 is \", num_100)\r\n    print(\"number of nodes with # edges >= 1000 is \", num_1000)\r\n    print(\"----------------------high level statistics-------------------------\")\r\n\r\n\r\n\"\"\"\r\nreturn a node dict only keeping nodes with > 10 edges\r\n\"\"\"\r\n\r\n\r\ndef extract_node_dict(fname, freq=10):\r\n    node_dict = {}\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        line_count = 0\r\n        for row in csv_reader:\r\n            if line_count == 0:\r\n                line_count += 1\r\n            else:\r\n                # t,u,v,w\r\n                t = row[0]\r\n                u = row[1]\r\n                v = row[2]\r\n                w = float(row[3].strip())\r\n                if u not in node_dict:\r\n                    node_dict[u] = 1\r\n                else:\r\n                    node_dict[u] += 1\r\n\r\n                if v not in node_dict:\r\n                    node_dict[v] = 1\r\n                else:\r\n                    node_dict[v] += 1\r\n\r\n    out_dict = {}\r\n    for node in node_dict:\r\n        if node_dict[node] >= freq:\r\n            out_dict[node] = node_dict[node]\r\n    return out_dict\r\n\r\n\r\n\"\"\"\r\nremove any edges do not contain either src or dst not in the node dict\r\n\"\"\"\r\n\r\n\r\ndef clean_edgelist(fname, outname, node_dict):\r\n    with open(outname, \"w\") as 
outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"time\", \"src\", \"dst\", \"weight\"]\r\n        write.writerow(fields)\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            line_count = 0\r\n            for row in csv_reader:\r\n                if line_count == 0:\r\n                    line_count += 1\r\n                else:\r\n                    # t,u,v,w\r\n                    t = row[0]\r\n                    u = row[1]\r\n                    v = row[2]\r\n                    w = float(row[3].strip())\r\n                    if u in node_dict and v in node_dict:\r\n                        write.writerow([t, u, v, w])\r\n\r\n\r\n\r\ndef sort_edgelist(in_file, outname):\r\n    \"\"\"\r\n    sort the edges by timestamp\r\n    \"\"\"\r\n    row_dict = {} #{day: {row: row}}\r\n    line_idx = 0\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"day\", \"src\", \"dst\", \"callsign\", \"typecode\"]\r\n        write.writerow(fields)\r\n        with open(in_file, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            for row in csv_reader:\r\n                if line_idx == 0:  # header\r\n                    line_idx += 1\r\n                    continue\r\n                ts = int(row[0])\r\n                if ts not in row_dict:\r\n                    row_dict[ts] = {}\r\n                    row_dict[ts][line_idx] = row\r\n                else:\r\n                    row_dict[ts][line_idx] = row\r\n                line_idx += 1\r\n        \r\n        for ts in sorted(row_dict.keys()):\r\n            for idx in row_dict[ts].keys():\r\n                row = row_dict[ts][idx]\r\n                write.writerow(row)\r\n\r\n\r\n\r\n\r\ndef main():\r\n    \"\"\"\r\n    keeping subgraph of most active nodes\r\n    \"\"\"\r\n    # freq = 10\r\n    # fname = \"stablecoin_edgelist.csv\"\r\n    
# node_dict = extract_node_dict(fname, freq=freq)\r\n\r\n    # outname = \"stablecoin_freq10.csv\"\r\n    # clean_edgelist(fname, outname, node_dict)\r\n\r\n    # fname = \"stablecoin_freq10.csv\"\r\n    # analyze_csv(fname)\r\n\r\n    \"\"\"\r\n    sort edgelist by time\r\n    \"\"\"\r\n    in_file = \"tgbl-coin_edgelist.csv\"\r\n    outname = \"tgbl-coin_edgelist_sorted.csv\"\r\n    sort_edgelist(in_file, outname)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/tgbl-coin_neg_generator.py",
    "content": "import time\r\n\r\n\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = 20 #100\r\n    neg_sample_strategy = \"hist_rnd\" #\"rnd\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"tgbl-coin\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    # Ensure to only sample actual destination nodes as negatives.\r\n    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n    # After successfully loading the dataset...\r\n    if neg_sample_strategy == \"hist_rnd\":\r\n        historical_data = data_splits[\"train\"]\r\n    else:\r\n        historical_data = None\r\n\r\n    neg_sampler = NegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_dst_id=min_dst_idx,\r\n        last_dst_id=max_dst_idx,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        historical_data=historical_data,\r\n    )\r\n\r\n    # generate evaluation set\r\n    partial_path = \".\"\r\n    # generate validation negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        data=data_splits[split_mode], split_mode=split_mode, 
partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/tgbl-comment.py",
    "content": "import csv\r\nfrom tqdm import tqdm\r\nfrom os import listdir\r\nfrom tgb.utils.stats import analyze_csv\r\n\r\n\r\ndef find_filenames(path_to_dir):\r\n    r\"\"\"\r\n    find all files in a folder\r\n    Parameters:\r\n        path_to_dir (str): path to the directory\r\n    \"\"\"\r\n    filenames = listdir(path_to_dir)\r\n    return filenames\r\n\r\n\r\ndef read_edgelist(fname, outfname, write_header=False):\r\n    \"\"\"\r\n    read a space separated edgelist\r\n    comment’s author, author of the parent (the post that the comment is replied to), comment’s creation time, comment’s edge id\r\n    u,v,t,edge_id\r\n    3746738\t1637382\t1551398391\t31534079835\r\n    Parameters:\r\n        fname (str): path to the edgelist\r\n        outfname (str): path to the output file\r\n    \"\"\"\r\n    edgelist = open(fname, \"r\")\r\n    lines = list(edgelist.readlines())\r\n    edgelist.close()\r\n\r\n    with open(outfname, \"a\") as outf:\r\n        write = csv.writer(outf)\r\n        if write_header:\r\n            fields = [\"ts\", \"src\", \"dst\", \"edge_id\"]\r\n            write.writerow(fields)\r\n        for line in lines:\r\n            line = line.split()\r\n            if len(line) < 4:\r\n                continue\r\n            src = line[0]\r\n            dst = line[1]\r\n            ts = line[2]\r\n            edge_id = line[3]\r\n            write.writerow([ts, src, dst, edge_id])\r\n\r\n\r\ndef read_nodeattr(fname, outfname, write_header=False):\r\n    \"\"\"\r\n    read a space separated edgelist\r\n    comment’s edge id, Reddit’s identifier of the comment, Reddit’s identifier of the parent (the post that the comment is replied to)\r\n    Reddit’s identifier of the submission that the comment is in, name of the subreddit that the comment is in, number of characters in the comment’s body\r\n    number of words in the comment’s body, score of the comment, a flag indicating if the comment has been edited\r\n\r\n\r\n    edge_id, subreddit, 
num_characters, num_words, score, 'edited_flag'\r\n    Parameters:\r\n        fname (str): path to the edgelist\r\n        outfname (str): path to the output file\r\n    \"\"\"\r\n    edgelist = open(fname, \"r\")\r\n    lines = list(edgelist.readlines())\r\n    edgelist.close()\r\n\r\n    with open(outfname, \"a\") as outf:\r\n        write = csv.writer(outf)\r\n        if write_header:\r\n            fields = [\r\n                \"edge_id\",\r\n                \"subreddit\",\r\n                \"num_characters\",\r\n                \"num_words\",\r\n                \"score\",\r\n                \"edited_flag\",\r\n            ]\r\n            write.writerow(fields)\r\n        for line in lines:\r\n            line = line.split()\r\n            if len(line) < 4:\r\n                continue\r\n            edge_id = line[0]\r\n            subreddit = line[4]\r\n            num_characters = line[5]\r\n            num_words = line[6]\r\n            score = line[7]\r\n            edited_flag = line[8].strip(\"\\n\")\r\n            write.writerow(\r\n                [edge_id, subreddit, num_characters, num_words, score, edited_flag]\r\n            )\r\n\r\n\r\ndef combine_edgelist_edgefeat(edgefname, featfname, outname):\r\n    \"\"\"\r\n    combine edgelist and edge features\r\n    #! 
remove subreddit from feature\r\n    \"\"\"\r\n    total_lines = sum(1 for line in open(edgefname))\r\n    subreddit_ids = {}\r\n\r\n    missing_ts = 0\r\n    missing_src = 0\r\n    missing_dst = 0\r\n    line_idx = 0\r\n\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"ts\", \"src\", \"dst\", \"subreddit\", \"num_words\", \"score\"]\r\n        write.writerow(fields)\r\n        sub_id = 0\r\n        edgelist = open(edgefname, \"r\")\r\n        edgefeat = open(featfname, \"r\")\r\n        edgelist.readline()\r\n        edgefeat.readline()\r\n\r\n        while True:\r\n            #'ts', 'src', 'dst', 'edge_id'\r\n            edge_line = edgelist.readline()\r\n            edge_line = edge_line.split(\",\")\r\n            if len(edge_line) < 4:\r\n                break\r\n            edge_id = int(edge_line[3])\r\n            ts = int(edge_line[0])\r\n            src = int(edge_line[1])\r\n            dst = int(edge_line[2])\r\n\r\n            #'edge_id', 'subreddit', 'num_characters', 'num_words', 'score', 'edited_flag'\r\n            feat_line = edgefeat.readline()\r\n            feat_line = feat_line.split(\",\")\r\n            edge_id_feat = int(feat_line[0])\r\n            subreddit = feat_line[1]\r\n            if subreddit not in subreddit_ids:\r\n                subreddit_ids[subreddit] = sub_id\r\n                sub_id += 1\r\n            subreddit = subreddit_ids[subreddit]\r\n            num_characters = int(feat_line[2])\r\n            num_words = int(feat_line[3])\r\n            score = int(feat_line[4])\r\n            edited_flag = bool(feat_line[5])\r\n\r\n            #! 
check if ts, src, dst is -1\r\n            if ts == -1:\r\n                missing_ts += 1\r\n                continue\r\n            if src == -1:\r\n                missing_src += 1\r\n                continue\r\n            if dst == -1:\r\n                missing_dst += 1\r\n                continue\r\n\r\n            if edge_id != edge_id_feat:\r\n                print(\"edge_id != edge_id_feat\")\r\n                print(edge_id)\r\n                print(edge_id_feat)\r\n                break\r\n\r\n            # write.writerow([ts, src, dst, subreddit, num_words, score])\r\n            write.writerow([ts, src, dst, num_words, score])\r\n            line_idx += 1\r\n    print(\"processed\", line_idx, \"lines\")\r\n    # print (\"there are lines\", missing_ts, \" missing timestamps\")\r\n    # print (\"there are lines\", missing_src, \" missing src\")\r\n    # print (\"there are lines\", missing_dst, \" missing dst\")\r\n\r\n\r\ndef main():\r\n    # #! unzip all xz files by $ unxz *.xz\r\n\r\n    # f_dir = \"raw/raw_2008_2010/\" #\"raw/raw_2005_2010/\" #\"raw/raw_2013_2014/\"\r\n    # fnames = find_filenames(f_dir)\r\n    # outname = \"redditcomments_edgelist_2008_2010.csv\" #\"redditcomments_edgelist_2013_2014.csv\"\r\n    # idx = 0\r\n    # for fname in tqdm(fnames):\r\n    #     if (idx == 0):\r\n    #         read_edgelist(f_dir+fname, outname, write_header=True)\r\n    #     else:\r\n    #         read_edgelist(f_dir+fname, outname, write_header=False)\r\n    #     idx += 1\r\n\r\n    # # #! extract the node attributes\r\n    f_dir = \"raw/node_2008_2010/\"#\"raw/node_2005_2010/\"\r\n    fnames = find_filenames(f_dir)\r\n    outname = \"redditcomments_edgefeat_2008_2010.csv\"\r\n    idx = 0\r\n    for fname in tqdm(fnames):\r\n        if (idx == 0):\r\n            read_nodeattr(f_dir+fname, outname, write_header=True)\r\n        else:\r\n            read_nodeattr(f_dir+fname, outname, write_header=False)\r\n        idx += 1\r\n\r\n    #! 
combine edgelist and edge feat file check if the edge_id matches\r\n    # edgefname = \"redditcomments_edgelist_2005_2010.csv\"\r\n    # featfname = \"redditcomments_edgefeat_2005_2010.csv\"\r\n    # outname = \"redditcomments_edgelist.csv\"\r\n    # combine_edgelist_edgefeat(edgefname, featfname, outname)\r\n\r\n    # #! analyze the extracted csv\r\n    # fname = \"redditcomments_edgelist_2005_2010.csv\"\r\n    # analyze_csv(fname)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/tgbl-comment_neg_generator.py",
    "content": "import time\r\n\r\n\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = 20 #100\r\n    neg_sample_strategy = \"hist_rnd\" #\"rnd\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"tgbl-comment\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    # Ensure to only sample actual destination nodes as negatives.\r\n    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n    # After successfully loading the dataset...\r\n    if neg_sample_strategy == \"hist_rnd\":\r\n        historical_data = data_splits[\"train\"]\r\n    else:\r\n        historical_data = None\r\n\r\n    neg_sampler = NegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_dst_id=min_dst_idx,\r\n        last_dst_id=max_dst_idx,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        historical_data=historical_data,\r\n    )\r\n\r\n    # generate evaluation set\r\n    partial_path = \".\"\r\n    # generate validation negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        data=data_splits[split_mode], split_mode=split_mode, 
partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/tgbl-flight.py",
    "content": "import dateutil.parser as dparser\r\nimport csv\r\nimport matplotlib.pyplot as plt\r\nimport numpy as np\r\nfrom tqdm import tqdm\r\nfrom os import listdir\r\nfrom datetime import datetime\r\n\r\ndef find_csv_filenames(path_to_dir, suffix=\".csv\"):\r\n    r\"\"\"\r\n    find all csv files in a directory\r\n    Parameters:\r\n        path_to_dir (str): path to the directory\r\n                suffix (str): suffix of the file\r\n    \"\"\"\r\n    filenames = listdir(path_to_dir)\r\n    return [filename for filename in filenames if filename.endswith(suffix)]\r\n\r\n\r\ndef flight2edgelist(\r\n    fname,\r\n    outname,\r\n    node_dict=None,\r\n):\r\n    \"\"\"\r\n    process all rows into\r\n    Day, src, dst, callsign, number, icao24, registration, typecode\r\n    and save it as an edgelist file\r\n    \"\"\"\r\n    miss_node_lines = 0\r\n\r\n    skip_lines = 0\r\n    print(\"processing \", outname)\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\r\n            \"day\",\r\n            \"src\",\r\n            \"dst\",\r\n            \"callsign\",\r\n            \"number\",\r\n            \"icao24\",\r\n            \"registration\",\r\n            \"typecode\",\r\n        ]\r\n        write.writerow(fields)\r\n\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            line_count = 0\r\n            # callsign,number,icao24,registration,typecode,origin,destination,firstseen,lastseen,day,latitude_1,longitude_1,altitude_1,latitude_2,longitude_2,altitude_2\r\n            for row in csv_reader:\r\n                if line_count == 0:\r\n                    line_count += 1\r\n                else:\r\n                    out = []\r\n                    callsign = row[0]\r\n                    number = row[1]\r\n                    icao24 = row[2]\r\n                    registration = row[3]\r\n                    typecode = row[4]\r\n      
              src = row[5]\r\n                    if src == \"\":\r\n                        skip_lines += 1\r\n                        continue\r\n                    dst = row[6]\r\n                    if dst == \"\":\r\n                        skip_lines += 1\r\n                        continue\r\n\r\n                    if node_dict is not None:\r\n                        if src not in node_dict:\r\n                            miss_node_lines += 1\r\n                            continue\r\n                        if dst not in node_dict:\r\n                            miss_node_lines += 1\r\n                            continue\r\n                    day = row[9]\r\n                    day = day[0:10]\r\n\r\n                    out.append(day)\r\n                    out.append(src)\r\n                    out.append(dst)\r\n                    out.append(callsign)\r\n                    out.append(number)\r\n                    out.append(icao24)\r\n                    out.append(registration)\r\n                    out.append(typecode)\r\n                    write.writerow(out)\r\n                    line_count += 1\r\n        print(f\"Processed {line_count} lines.\")\r\n        print(f\"Skipped {skip_lines} lines.\")\r\n        print(f\"missing node {miss_node_lines} lines.\")\r\n    return line_count, skip_lines, miss_node_lines\r\n\r\n\r\ndef load_icao_airports(fname=\"airport_codes.csv\"):\r\n    airports_continent = {}\r\n    airports_country = {}\r\n\r\n    edgelist = open(fname, \"r\")\r\n    lines = list(edgelist.readlines())\r\n    edgelist.close()\r\n    # date u  v  w\r\n    # find how many timestamps there are\r\n\r\n    for i in range(0, len(lines)):\r\n        line = lines[i]\r\n        values = line.split(\",\")\r\n        icao = values[0]\r\n        continent = values[4]\r\n        country = values[5]\r\n        airports_continent[icao] = continent\r\n        airports_country[icao] = country\r\n    return airports_continent, 
airports_country\r\n\r\n\r\ndef merge_edgelist(input_names: str, in_dir: str, outname: str):\r\n    \"\"\"\r\n    merge a list of edgefiles into one file\r\n    \"\"\"\r\n    line_count = 0\r\n    total = 0\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"day\", \"src\", \"dst\", \"callsign\", \"typecode\"]\r\n        write.writerow(fields)\r\n        for csv_name in tqdm(input_names):\r\n            in_name = in_dir + csv_name\r\n            line_count = 0\r\n            with open(in_name, \"r\") as csv_file:\r\n                csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n                for row in csv_reader:\r\n                    if line_count == 0:  # header\r\n                        line_count += 1\r\n                    else:\r\n                        # Day, src, dst, callsign, number, icao24, registration, typecode\r\n                        day = row[0]\r\n                        src = row[1]\r\n                        dst = row[2]\r\n                        callsign = row[3]\r\n                        typecode = row[-1]\r\n                        out = [day, src, dst, callsign, typecode]\r\n                        write.writerow(out)\r\n                        total += 1\r\n\r\n\r\ndef clean_node_feat(in_file, outname):\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\r\n            \"airport_code\",\r\n            \"type\",\r\n            \"continent\",\r\n            \"iso_region\",\r\n            \"longitude\",\r\n            \"latitude\",\r\n        ]\r\n        write.writerow(fields)\r\n        idx = 0\r\n        with open(in_file, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            for row in csv_reader:\r\n                if idx == 0:\r\n                    idx += 1\r\n                    continue\r\n                else:\r\n                    # 
ident,type,name,elevation_ft,continent,iso_country,iso_region,municipality,gps_code,iata_code,local_code,coordinates\r\n                    airport_code = row[0]\r\n                    type = row[1]\r\n                    continent = row[4]\r\n                    iso_region = row[6]\r\n                    longitude = float(row[-1].split(\",\")[0])\r\n                    latitude = float(row[-1].split(\",\")[1])\r\n                    out = [\r\n                        airport_code,\r\n                        type,\r\n                        continent,\r\n                        iso_region,\r\n                        longitude,\r\n                        latitude,\r\n                    ]\r\n                    idx += 1\r\n                    write.writerow(out)\r\n\r\n\r\n\r\ndef sort_edgelist(in_file, outname):\r\n    \"\"\"\r\n    sort the edges by day\r\n    \"\"\"\r\n    TIME_FORMAT = \"%Y-%m-%d\"\r\n    row_dict = {} #{day: {row: row}}\r\n    line_idx = 0\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"day\", \"src\", \"dst\", \"callsign\", \"typecode\"]\r\n        write.writerow(fields)\r\n        with open(in_file, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            for row in csv_reader:\r\n                if line_idx == 0:  # header\r\n                    line_idx += 1\r\n                    continue\r\n                day = row[0]\r\n                ts = datetime.strptime(day, TIME_FORMAT)\r\n                ts = ts.timestamp()\r\n                if ts not in row_dict:\r\n                    row_dict[ts] = {}\r\n                    row_dict[ts][line_idx] = row\r\n                else:\r\n                    row_dict[ts][line_idx] = row\r\n                line_idx += 1\r\n        \r\n        for ts in sorted(row_dict.keys()):\r\n            for idx in row_dict[ts].keys():\r\n                row = row_dict[ts][idx]\r\n                
write.writerow(row)\r\n\r\n\r\ndef date2ts(date_str: str) -> float:\r\n    r\"\"\"\r\n    convert date string to timestamp\r\n    \"\"\"\r\n    TIME_FORMAT = \"%Y-%m-%d-%z\"\r\n    date_cur = datetime.strptime(date_str, TIME_FORMAT)\r\n    return float(date_cur.timestamp())\r\n\r\n\r\ndef main():\r\n    \"\"\"\r\n    instructions for recompiling the dataset from\r\n    https://zenodo.org/record/7323875#.ZD1-43ZKguX\r\n\r\n    1. download all datasets into a folder specified by in_dir (such as full_dataset)\r\n    2. run the following code to extract the needed information\r\n    \"\"\"\r\n\r\n    # _, airports_country = load_icao_airports(fname=\"airport_codes.csv\")\r\n\r\n    # in_dir = \"full_dataset/\"\r\n    # out_dir = \"edgelists/\"\r\n\r\n    # csv_name = \"flightlist_20190101_20190131.csv\"\r\n\r\n    # csv_names = find_csv_filenames(in_dir)\r\n    # processed_lines = 0\r\n    # skipped_lines = 0\r\n    # miss_node_lines = 0\r\n\r\n    # for csv_name in tqdm(csv_names):\r\n    #     fname = in_dir + csv_name\r\n    #     outname = out_dir + csv_name[11:-4] + \"edgelist\"+\".csv\"\r\n    #     line_count, skip_lines, miss_node = flight2edgelist(fname, outname, node_dict=airports_country)\r\n    #     processed_lines += line_count\r\n    #     skipped_lines += skip_lines\r\n    #     miss_node_lines += miss_node\r\n    # print(f'Processed {processed_lines} lines.')\r\n    # print(f'Skipped {skipped_lines} lines.')\r\n    # print(f'missing node {miss_node_lines} lines.')\r\n\r\n    \"\"\"\r\n    merge all edgelists into one file\r\n    \"\"\"\r\n    # in_dir = \"edgelists/\"\r\n    # outname = \"opensky_edgelist.csv\"\r\n    # csv_names = find_csv_filenames(in_dir)\r\n    # merge_edgelist(csv_names, in_dir, outname)\r\n\r\n    \"\"\"\r\n    clean the node features\r\n    \"\"\"\r\n    # in_file = \"edgelists/airport_codes.csv\"\r\n    # outname = \"airport_node_feat.csv\"\r\n    # clean_node_feat(in_file, outname)\r\n\r\n\r\n    \"\"\"\r\n    sort the 
edgelist by day\r\n    \"\"\"\r\n    # in_file = \"tgbl-flight_edgelist.csv\"\r\n    # outname = \"tgbl-flight_edgelist_sorted.csv\"\r\n    # sort_edgelist(in_file, outname)\r\n\r\n\r\n    \"\"\"\r\n    fixing time zone different for strip time\r\n    \"\"\"\r\n    tz_offset = \"-0500\"\r\n    ts = \"2021-11-29\" + \"-\" + tz_offset\r\n    print (date2ts(ts))\r\n\r\n\r\n    tz_offset = \"+0000\"\r\n    ts_utc = \"2021-11-29\" + \"-\" + tz_offset\r\n    print (date2ts(ts_utc))\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/tgbl-flight_neg_generator.py",
    "content": "import time\r\n\r\n\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = 20\r\n    neg_sample_strategy = \"hist_rnd\" #\"rnd\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"tgbl-flight\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    # Ensure to only sample actual destination nodes as negatives.\r\n    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n    # After successfully loading the dataset...\r\n    if neg_sample_strategy == \"hist_rnd\":\r\n        historical_data = data_splits[\"train\"]\r\n    else:\r\n        historical_data = None\r\n\r\n    neg_sampler = NegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_dst_id=min_dst_idx,\r\n        last_dst_id=max_dst_idx,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        historical_data=historical_data,\r\n    )\r\n\r\n    # generate evaluation set\r\n    partial_path = \".\"\r\n    # generate validation negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        data=data_splits[split_mode], split_mode=split_mode, 
partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/tgbl-review.py",
    "content": "import pyarrow.dataset as ds\r\nimport csv\r\nimport numpy as np\r\nfrom tgb.utils.stats import analyze_csv\r\nimport pandas as pd\r\nfrom tqdm import tqdm\r\n\r\n\r\ndef collect_csv(dir_name=\"software\"):\r\n    dataset = ds.dataset(dir_name, format=\"csv\")\r\n    df = dataset.to_table().to_pandas()\r\n    df.to_csv(dir_name + \".csv\", index=True)\r\n\r\n\r\ndef reorder_column(fname: str, outname: str):\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"ts\", \"source\", \"target\", \"weight\"]\r\n        write.writerow(fields)\r\n        line_count = 0\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            for row in csv_reader:\r\n                if line_count == 0:  # header\r\n                    line_count += 1\r\n                else:\r\n                    # edgeid, SourceId,TargetId,Weight,Timestamp\r\n                    src = row[1]\r\n                    dst = row[2]\r\n                    w = row[3]\r\n                    ts = row[4]\r\n                    write.writerow([ts, src, dst, w])\r\n                    line_count += 1\r\n\r\n\r\ndef sort_edgelist(fname: str, outname: str):\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"ts\", \"source\", \"target\", \"weight\"]\r\n        write.writerow(fields)\r\n        line_count = 0\r\n        ts_list = []\r\n        line_list = []\r\n\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            for row in csv_reader:\r\n                if line_count == 0:  # header\r\n                    line_count += 1\r\n                else:\r\n                    ts = int(row[0])\r\n                    src = row[1]\r\n                    dst = row[2]\r\n                    w = row[3]\r\n                    ts_list.append(ts)\r\n                    
line_list.append([ts, src, dst, w])\r\n                    # write.writerow([ts, src, dst, w])\r\n                    line_count += 1\r\n\r\n        ts_list = np.array(ts_list)\r\n        idx = np.argsort(ts_list)\r\n        idx = idx.tolist()\r\n\r\n        line_list_out = []\r\n        for i in idx:\r\n            line_list_out.append(line_list[i])\r\n        for line in line_list_out:\r\n            write.writerow(line)\r\n\r\n\r\ndef count_degree(fname: str):\r\n    node_counts = {}\r\n    line_count = 0\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        for row in csv_reader:\r\n            if line_count == 0:  # header\r\n                line_count += 1\r\n            else:\r\n                ts = int(row[0])\r\n                src = row[1]\r\n                dst = row[2]\r\n                w = row[3]\r\n                if src not in node_counts:\r\n                    node_counts[src] = 1\r\n                else:\r\n                    node_counts[src] += 1\r\n\r\n                if dst not in node_counts:\r\n                    node_counts[dst] = 1\r\n                else:\r\n                    node_counts[dst] += 1\r\n                line_count += 1\r\n    return node_counts\r\n\r\n\r\ndef reduce_edgelist(fname: str, outname: str, node10_id: dict):\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"ts\", \"source\", \"target\", \"weight\"]\r\n        write.writerow(fields)\r\n        line_count = 0\r\n\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            for row in csv_reader:\r\n                if line_count == 0:  # header\r\n                    line_count += 1\r\n                else:\r\n                    ts = int(row[0])\r\n                    src = row[1]\r\n                    dst = row[2]\r\n                    w = row[3]\r\n                    if (src in 
node10_id) and (dst in node10_id):\r\n                        write.writerow([ts, src, dst, w])\r\n                    line_count += 1\r\n\r\n\"\"\"\r\nfunction for review\r\n\"\"\"\r\ndef csv_process_review(\r\n    fname: str,\r\n    outname: str = \"review.csv\",\r\n) -> pd.DataFrame:\r\n    r\"\"\"\r\n    used for processing review dataset, helper function, not used in actual dataloading\r\n    input .csv file format should be: timestamp, node u, node v, attributes\r\n    Parameters:\r\n        fname: the path to the raw data\r\n    Returns:\r\n        df: a pandas dataframe containing the edgelist data\r\n        feat_l: a numpy array containing the node features\r\n        node_ids: a dictionary mapping node id to integer\r\n    \"\"\"\r\n    src_ids = {}\r\n    dst_ids = {}\r\n    src_ctr = 0\r\n    dst_ctr = 0\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        \"\"\"\r\n        ts,source,target,weight\r\n        929232000,137139,30122,5.0\r\n        930787200,129185,175070,2.0\r\n        931824000,246213,30122,2.0\r\n        \"\"\"\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n                continue\r\n            else:\r\n                ts = int(row[0])\r\n                src = row[1]\r\n                dst = row[2]\r\n                if src not in src_ids:\r\n                    src_ids[src] = src_ctr\r\n                    src_ctr += 1\r\n                if dst not in dst_ids:\r\n                    dst_ids[dst] = dst_ctr\r\n                    dst_ctr += 1\r\n                w = float(row[3])\r\n    \r\n    #! 
ensure that source and destination nodes are unique and non-overlapping\r\n    src_ctr += 1\r\n    dst_ids = {k:v+src_ctr for k,v in dst_ids.items()}\r\n\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"ts\",\"source\",\"target\",\"weight\"]\r\n        write.writerow(fields)\r\n\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            idx = 0\r\n            \"\"\"\r\n            ts,source,target,weight\r\n            929232000,137139,30122,5.0\r\n            930787200,129185,175070,2.0\r\n            931824000,246213,30122,2.0\r\n            \"\"\"\r\n            for row in tqdm(csv_reader):\r\n                if idx == 0:\r\n                    idx += 1\r\n                    continue\r\n                else:\r\n                    ts = int(row[0])\r\n                    src = src_ids[row[1]]\r\n                    dst = dst_ids[row[2]]\r\n                    w = float(row[3])\r\n                write.writerow([ts,src,dst,w])\r\n\r\n\r\ndef main():\r\n    # # collect csv\r\n    # # collect_csv(dir_name = \"software\")\r\n    # collect_csv(dir_name=\"books\")\r\n    # # collect_csv(dir_name = \"electronics\")\r\n\r\n    # # #* reorder column\r\n    # # fname = \"electronics.csv\"\r\n    # # outname = \"amazonreview_edgelist.csv\"\r\n    # # reorder_column(fname,\r\n    # #                outname)\r\n\r\n    # # #* sort edgelist\r\n    # # fname = \"amazonreview_edgelist.csv\"\r\n    # # outname = \"amazonreview_edgelist_sort.csv\"\r\n    # # sort_edgelist(fname,\r\n    # #               outname)\r\n\r\n    # fname = \"amazonreview_edgelist_reduce.csv\"\r\n    # analyze_csv(fname)\r\n\r\n    # # fname = \"amazonreview_edgelist.csv\"\r\n    # # node_counts = count_degree(fname)\r\n    # # node10_id = {}\r\n    # # for node in node_counts:\r\n    # #     if node_counts[node] > 10:\r\n    # #         node10_id[node] = node_counts[node]\r\n\r\n    
# # outname = \"amazonreview_edgelist_reduce.csv\"\r\n    # # reduce_edgelist(fname,\r\n    # #                 outname,\r\n    # #                 node10_id)\r\n\r\n    csv_process_review(\"tgbl-review_edgelist_v2.csv\", \"review.csv\")\r\n\r\n    \r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/tgbl-review_neg_generator.py",
    "content": "import time\r\n\r\n\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = 100 #20 #100\r\n    neg_sample_strategy = \"hist_rnd\" #\"rnd\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"tgbl-review\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    min_src_idx = int(data.src.min())\r\n    print (f\"min_src_idx: {min_src_idx}\")\r\n\r\n    # Ensure to only sample actual destination nodes as negatives.\r\n    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n    # After successfully loading the dataset...\r\n    if neg_sample_strategy == \"hist_rnd\":\r\n        historical_data = data_splits[\"train\"]\r\n    else:\r\n        historical_data = None\r\n\r\n    neg_sampler = NegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_dst_id=min_dst_idx,\r\n        last_dst_id=max_dst_idx,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        historical_data=historical_data,\r\n    )\r\n\r\n    # generate evaluation set\r\n    partial_path = \".\"\r\n    # generate validation negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    
neg_sampler.generate_negative_samples(\r\n        data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/tgbl-wiki_neg_generator.py",
    "content": "import timeit\r\n\r\n\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = 12000 #11000 #10000 #20 #100\r\n    neg_sample_strategy = \"hist_rnd\" #\"rnd\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"tgbl-wiki\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    # Ensure to only sample actual destination nodes as negatives.\r\n    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n    # After successfully loading the dataset...\r\n    if neg_sample_strategy == \"hist_rnd\":\r\n        historical_data = data_splits[\"train\"]\r\n    else:\r\n        historical_data = None\r\n\r\n    neg_sampler = NegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_dst_id=min_dst_idx,\r\n        last_dst_id=max_dst_idx,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        historical_data=historical_data,\r\n    )\r\n\r\n    # generate evaluation set\r\n    partial_path = \"./\"\r\n    # generate validation negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        data=data_splits[split_mode], 
split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = timeit.default_timer()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer()- start_time: .4f}\"\r\n\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/tgbn-genre.py",
    "content": "import networkx as nx\r\nimport numpy as np\r\nimport matplotlib.pyplot as plt\r\nimport csv\r\nfrom typing import Optional, Dict, Any, Tuple\r\nimport datetime\r\nfrom datetime import date, timedelta\r\nfrom difflib import SequenceMatcher\r\n\r\n\r\n# similarity_dict = {('electronic', 'electronica'): 0.9523809523809523, ('electronic', 'electro'): 0.8235294117647058, ('alternative', 'alternative rock'): 0.8148148148148148, ('nu jazz', 'nu-jazz'): 0.8571428571428571,\r\n#                    ('funky', 'funk'): 0.8888888888888888, ('funky', 'funny'): 0.8, ('post rock', 'pop rock'): 0.8235294117647058, ('post rock', 'post-rock'): 0.8888888888888888,\r\n#                    ('instrumental', 'instrumental rock'): 0.8275862068965517, ('chill', 'chile'): 0.8, ('Drum and bass', 'Drum n Bass'): 0.8333333333333334, ('female vocalists', 'female vocalist'): 0.967741935483871,\r\n#                    ('female vocalists', 'male vocalists'): 0.9333333333333333, ('female vocalists', 'male vocalist'): 0.896551724137931, ('electro', 'electropop'): 0.8235294117647058, ('funk', 'fun'): 0.8571428571428571,\r\n#                    ('hip hop', 'trip hop'): 0.8, ('hip hop', 'hiphop'): 0.9230769230769231, ('trip-hop', 'trip hop'): 0.875, ('indie rock', 'indie folk'): 0.8, ('new age', 'new wave'): 0.8, ('new age', 'new rave'): 0.8,\r\n#                    ('synthpop', 'synth pop'): 0.9411764705882353, ('industrial', 'industrial rock'): 0.8, ('cover', 'covers'): 0.9090909090909091, ('post hardcore', 'post-hardcore'): 0.9230769230769231, ('mathcore', 'deathcore'): 0.8235294117647058,\r\n#                    ('deutsch', 'dutch'): 0.8333333333333334, ('swing', 'sting'): 0.8, ('female vocalist', 'male vocalists'): 0.896551724137931, ('female vocalist', 'male vocalist'): 0.9285714285714286, ('new wave', 'new rave'): 0.875,\r\n#                    ('male vocalists', 'male vocalist'): 0.9629629629629629, ('Progressive rock', 'Progressive'): 0.8148148148148148, ('Alt-country', 'alt 
country'): 0.8181818181818182, ('favorites', 'Favourites'): 0.8421052631578947,\r\n#                    ('favorites', 'favourite'): 0.8888888888888888, ('favorites', 'Favorite'): 0.8235294117647058, ('1970s', '1980s'): 0.8, ('1970s', '1990s'): 0.8, ('proto-punk', 'post-punk'): 0.8421052631578947,\r\n#                    ('folk rock', 'folk-rock'): 0.8888888888888888, ('1980s', '1990s'): 0.8, ('favorite songs', 'Favourite Songs'): 0.8275862068965517, ('melancholic', 'melancholy'): 0.8571428571428571,\r\n#                    ('Favourites', 'favourite'): 0.8421052631578947, ('Favourites', 'Favorite'): 0.8888888888888888, ('Favourites', 'Favourite Songs'): 0.8, ('favourite', 'Favorite'): 0.8235294117647058,\r\n#                    ('american', 'americana'): 0.9411764705882353, ('american', 'african'): 0.8, ('american', 'mexican'): 0.8, ('rock en español', 'Rock en Espanol'): 0.8, ('trance', 'psytrance'): 0.8,\r\n#                    ('power pop', 'powerpop'): 0.9411764705882353, ('psychill', 'psychobilly'): 0.8421052631578947, ('Progressive metal', 'progressive death metal'): 0.8, ('Progressive metal', 'progressive black metal'): 0.8,\r\n#                    ('progressive death metal', 'progressive black metal'): 0.8260869565217391, ('romantic', 'new romantic'): 0.8, ('hair metal', 'Dark metal'): 0.8, ('melodic metal', 'melodic black metal'): 0.8125,\r\n#                    ('funk metal', 'folk metal'): 0.8, ('death metal', 'math metal'): 0.8571428571428571, ('Technical Metal', 'Technical Death Metal'): 0.8333333333333334, ('speed metal', 'sid metal'): 0.8}\r\n\r\n#! map diferent spelling and similar ones to the same one, use space if possible\r\n# ? 
key = to replace, value = to keep\r\n\r\nsimilarity_dict = {\r\n    \"nu-jazz\": \"nu jazz\",\r\n    \"funky\": \"funk\",\r\n    \"post-rock\": \"post rock\",\r\n    \"Drum n Bass\": \"Drum and bass\",\r\n    \"female vocalists\": \"female vocalist\",\r\n    \"male vocalists\": \"male vocalist\",\r\n    \"hiphop\": \"hip hop\",\r\n    \"trip-hop\": \"trip hop\",\r\n    \"synthpop\": \"synth pop\",\r\n    \"covers\": \"cover\",\r\n    \"post-hardcore\": \"post hardcore\",\r\n    \"Favourites\": \"favorites\",\r\n    \"favourite\": \"favorites\",\r\n    \"Favorite\": \"favorites\",\r\n    \"folk-rock\": \"folk rock\",\r\n    \"favorite songs\": \"favorites\",\r\n    \"Favourite Songs\": \"favorites\",\r\n    \"americana\": \"american\",\r\n    \"Rock en Espanol\": \"rock en español\",\r\n    \"melancholy\": \"melancholic\",\r\n    \"powerpop\": \"power pop\",\r\n}\r\n\r\n\r\ndef filter_genre_edgelist(fname, genres_dict):\r\n    \"\"\"\r\n    rewrite the edgelist but only keeping the genres with high frequency, also uses similarity_dict\r\n    \"\"\"\r\n    edgelist = open(fname, \"r\")\r\n    lines = list(edgelist.readlines())\r\n    edgelist.close()\r\n\r\n    with open(\"lastfm_edgelist_clean.csv\", \"w\") as f:\r\n        write = csv.writer(f)\r\n        fields = [\"user_id\", \"timestamp\", \"tags\", \"weight\"]\r\n        write.writerow(fields)\r\n\r\n        for i in range(1, len(lines)):\r\n            vals = lines[i].split(\",\")\r\n            user_id = vals[1]\r\n            time = vals[2]\r\n            genre = vals[3].strip('\"').strip(\"['\")\r\n            w = vals[4][:-3]\r\n            if genre in genres_dict:\r\n                if genre in similarity_dict:\r\n                    genre = similarity_dict[genre]\r\n                write.writerow([user_id, time, genre, w])\r\n\r\n\r\ndef get_genre_list(fname):\r\n    \"\"\"\r\n    edge_id, user_id, timestamp, tags\r\n\r\n    0,user_000001,2006-08-13 14:59:59+00:00,\"['electronic', 
0.5319148936170213]\"\r\n    0,user_000001,2006-08-13 14:59:59+00:00,\"['alternative', 0.46808510638297873]\"\r\n    1,user_000001,2006-08-13 15:36:22+00:00,\"['electronic', 0.6410256410256411]\"\r\n    1,user_000001,2006-08-13 15:36:22+00:00,\"['chillout', 0.358974358974359]\"\r\n    2,user_000001,2006-08-13 15:40:13+00:00,\"['math rock', 1.0]\"\r\n    3,user_000001,2006-08-15 13:41:18+00:00,\"['electronica', 1.0]\"\r\n    \"\"\"\r\n    edgelist = open(fname, \"r\")\r\n    lines = list(edgelist.readlines())\r\n    edgelist.close()\r\n    genre_dict = {}\r\n    for i in range(1, len(lines)):\r\n        vals = lines[i].split(\",\")\r\n        user_id = vals[1]\r\n        time = vals[2]\r\n        genre = vals[3].strip('\"').strip(\"['\")\r\n        # genre = vals[3]\r\n        w = float(vals[4][:-3])\r\n        if genre not in genre_dict:\r\n            genre_dict[genre] = 1\r\n        else:\r\n            genre_dict[genre] += 1\r\n\r\n    # TODO check the frequency of genres and threshold\r\n    genre_list_10 = []\r\n    genre_list_100 = []\r\n    genre_list_1000 = []\r\n    genre_list_2000 = []\r\n    for key, freq in genre_dict.items():\r\n        if freq > 10:\r\n            genre_list_10.append([key])\r\n        if freq > 100:\r\n            genre_list_100.append([key])\r\n        if freq > 1000:\r\n            genre_list_1000.append([key])\r\n        if freq > 2000:\r\n            genre_list_2000.append([key])\r\n    print(\"number of genres with frequency > 10: \" + str(len(genre_list_10)))\r\n    print(\"number of genres with frequency > 100: \" + str(len(genre_list_100)))\r\n    print(\"number of genres with frequency > 1000: \" + str(len(genre_list_1000)))\r\n    print(\"number of genres with frequency > 2000: \" + str(len(genre_list_2000)))\r\n    fields = [\"genre\"]\r\n\r\n    with open(\"genre_list_1000.csv\", \"w\") as f:\r\n        write = csv.writer(f)\r\n        write.writerow(fields)\r\n        write.writerows(genre_list_1000)\r\n\r\n\r\ndef 
find_unique_genres(fname: str, threshold: float = 0.8):\r\n    \"\"\"\r\n    identify fuzzy strings which are actually the same genre, differences can be spacing, typo etc.\r\n    \"\"\"\r\n    # load all genre names into a list\r\n    edgelist = open(fname, \"r\")\r\n    lines = list(edgelist.readlines())\r\n    edgelist.close()\r\n\r\n    genres = []\r\n    sim_genres = {}\r\n    for i in range(1, len(lines)):\r\n        line = lines[i]\r\n        genre = line.strip(\"\\n\")\r\n        genres.append(genre)\r\n\r\n    for i in range(len(genres)):\r\n        for j in range(i + 1, len(genres)):\r\n            text = genres[i]\r\n            search_key = genres[j]\r\n            sim = SequenceMatcher(None, text, search_key)\r\n            sim = sim.ratio()\r\n            if sim >= threshold:\r\n                sim_genres[(text, search_key)] = sim\r\n\r\n    print(\"there are \" + str(len(sim_genres)) + \" similar genres\")\r\n    print(sim_genres)\r\n\r\n\r\ndef load_genre_dict(\r\n    fname: str,\r\n) -> Dict[str, Any]:\r\n    \"\"\"\r\n    reading the list of genres from genre_list.csv\r\n    parameters:\r\n        fname: file name of the genre list\r\n    Returns:\r\n        genre_dict: a dictionary of genres\r\n    \"\"\"\r\n    genre_dict = {}\r\n    with open(fname, \"r\") as f:\r\n        reader = csv.reader(f)\r\n        for row in reader:\r\n            genre_dict[row[0]] = 1\r\n    return genre_dict\r\n\r\n\r\ndef generate_daily_node_labels(fname: str):\r\n    r\"\"\"\r\n    read a temporal edgelist\r\n    node label = fav genre in this day\r\n    generate the node label for each day for each user\r\n    Note: only genres from the genre_list are considered\r\n\r\n    user_000001,2006-08-13 14:59:59+00:00,\"['electronic', 0.5319148936170213]\"\r\n    user_000001,2006-08-13 14:59:59+00:00,\"['alternative', 0.46808510638297873]\"\r\n    user_000001,2006-08-13 15:36:22+00:00,\"['electronic', 0.6410256410256411]\"\r\n    user_000001,2006-08-13 
15:36:22+00:00,\"['chillout', 0.358974358974359]\"\r\n    user_000001,2006-08-13 15:40:13+00:00,\"['math rock', 1.0]\"\r\n    user_000001,2006-08-15 13:41:18+00:00,\"['electronica', 1.0]\"\r\n    user_000001,2006-08-15 13:59:27+00:00,\"['acid jazz', 0.3546099290780142]\"\r\n    user_000001,2006-08-15 13:59:27+00:00,\"['nu jazz', 0.3333333333333333]\"\r\n    user_000001,2006-08-15 13:59:27+00:00,\"['chillout', 0.3120567375886525]\"\r\n    \"\"\"\r\n    edgelist = open(fname, \"r\")\r\n    lines = list(edgelist.readlines())\r\n    edgelist.close()\r\n\r\n    format = \"%Y-%m-%d %H:%M:%S\"\r\n    day_dict = {}  # store the weights of genres on this day\r\n    cur_day = -1\r\n\r\n    with open(\"daily_labels.csv\", \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"user_id\", \"year\", \"month\", \"day\", \"genre\", \"weight\"]\r\n        write.writerow(fields)\r\n\r\n        # generate daily labels for users\r\n        for i in range(1, len(lines)):\r\n            vals = lines[i].split(\",\")\r\n            user_id = vals[0]\r\n            time = vals[1][:-7]\r\n            date_object = datetime.datetime.strptime(time, format)\r\n            if i == 1:\r\n                cur_day = date_object.day\r\n\r\n            genre = vals[2]\r\n            w = float(vals[3].strip())\r\n            if date_object.day != cur_day:\r\n                #! normalize the weights in the day_dict to sum 1\r\n                # * remove normalization for future aggregation\r\n                # total = sum(day_dict.values())\r\n                # day_dict = {k: v / total for k, v in day_dict.items()}\r\n\r\n                #! 
user,time,genre,weight  # genres = # of weights\r\n                out = [\r\n                    user_id,\r\n                    str(date_object.year),\r\n                    str(date_object.month),\r\n                    str(date_object.day),\r\n                ]\r\n                for genre, w in day_dict.items():\r\n                    write.writerow(out + [genre] + [w])\r\n\r\n                cur_day = date_object.day\r\n                day_dict = {}\r\n            else:\r\n                if genre not in day_dict:\r\n                    day_dict[genre] = w\r\n                else:\r\n                    day_dict[genre] += w\r\n\r\n\r\ndef generate_aggregate_labels(fname: str, days: int = 7):\r\n    \"\"\"\r\n    aggregate the genres over a number of days,  as specified by days\r\n    #! current generation includes edges from the day of the label, thus the label should be set to be beginning of the day\r\n    prediction should always be at the first second of the day\r\n    \"\"\"\r\n    edgelist = open(fname, \"r\")\r\n    lines = list(edgelist.readlines())\r\n    edgelist.close()\r\n    date_prev = 0\r\n\r\n    genre_dict = {}\r\n    user_prev = 0\r\n\r\n    # \"user_id\", \"year\", \"month\", \"day\", \"genre\", \"weight\"\r\n    with open(str(days) + \"days_labels.csv\", \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"user_id\", \"year\", \"month\", \"day\", \"genre\", \"weight\"]\r\n        write.writerow(fields)\r\n\r\n        for i in range(1, len(lines)):\r\n            vals = lines[i].split(\",\")\r\n            user_id = vals[0]\r\n            year = int(vals[1])\r\n            month = int(vals[2])\r\n            day = int(vals[3])\r\n            genre = vals[4]\r\n            w = float(vals[5])\r\n            if i == 1:\r\n                date_prev = date(year, month, day)\r\n                user_prev = user_id\r\n\r\n            date_cur = date(year, month, day)\r\n\r\n            if user_id != user_prev:\r\n              
  date_prev = date(year, month, day)\r\n                user_prev = user_id\r\n\r\n            if (\r\n                date_cur - date_prev\r\n            ).days <= days:  #! this means that the date = [0,7] which includes the current day\r\n                if genre not in genre_dict:\r\n                    genre_dict[genre] = w\r\n                else:\r\n                    genre_dict[genre] += w\r\n            else:\r\n                # start a new week\r\n                # normalize the weight to sum 1\r\n                total = sum(genre_dict.values())\r\n                genre_dict = {k: v / total for k, v in genre_dict.items()}\r\n\r\n                out = [\r\n                    user_id,\r\n                    str(date_prev.year),\r\n                    str(date_prev.month),\r\n                    str(date_prev.day),\r\n                ]\r\n                for genre, w in genre_dict.items():\r\n                    write.writerow(out + [genre] + [w])\r\n                date_prev = date_prev + datetime.timedelta(days=1)\r\n                genre_dict = {}\r\n\r\n\r\ndef most_frequent(List):\r\n    \"\"\"\r\n    helper function to find the most frequent element in a list\r\n    the ties are broken by choosing the earlier element\r\n    \"\"\"\r\n    counter = 0\r\n    out = List[0]\r\n\r\n    for item in List:\r\n        curr_frequency = List.count(item)\r\n        if curr_frequency > counter:  # update on most frequent item is found\r\n            counter = curr_frequency\r\n            out = item\r\n    return out\r\n\r\n\r\ndef convert_ts_unix(fname: str, outname: str):\r\n    \"\"\"\r\n    convert all time from datetime to unix time\r\n    \"\"\"\r\n    TIME_FORMAT = \"%Y-%m-%d\"\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"ts\", \"user_id\", \"genre\", \"weight\"]\r\n        write.writerow(fields)\r\n\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, 
delimiter=\",\")\r\n            line_count = 0\r\n            # time,user_id,genre,weight\r\n            for row in csv_reader:\r\n                if line_count == 0:\r\n                    line_count += 1\r\n                else:\r\n                    ts = datetime.datetime.strptime(row[0], TIME_FORMAT)\r\n                    ts += timedelta(days=1)\r\n                    ts = int(ts.timestamp())\r\n                    user_id = row[1]\r\n                    genre = row[2]\r\n                    weight = float(row[3])\r\n                    write.writerow([ts, user_id, genre, weight])\r\n\r\n\r\ndef convert_ts_edgelist(fname: str, outname: str):\r\n    \"\"\"\r\n    convert all time from datetime to unix time\r\n    \"\"\"\r\n    TIME_FORMAT = \"%Y-%m-%d %H:%M:%S\"\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"ts\", \"user_id\", \"genre\", \"weight\"]\r\n        write.writerow(fields)\r\n\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            line_count = 0\r\n            # time,user_id,genre,weight\r\n            for row in csv_reader:\r\n                if line_count == 0:\r\n                    line_count += 1\r\n                else:\r\n                    ts = datetime.datetime.strptime(row[0], TIME_FORMAT)\r\n                    ts = int(ts.timestamp())\r\n                    user_id = row[1]\r\n                    genre = row[2]\r\n                    weight = float(row[3])\r\n                    write.writerow([ts, user_id, genre, weight])\r\n\r\n\r\ndef sort_node_labels(fname, outname):\r\n    r\"\"\"\r\n    sort the node labels by time\r\n    \"\"\"\r\n    edgelist = open(fname, \"r\")\r\n    lines = list(edgelist.readlines())\r\n    edgelist.close()\r\n\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"time\", \"user_id\", \"genre\", \"weight\"]\r\n        
write.writerow(fields)\r\n        rows_dict = {}\r\n\r\n        for i in range(1, len(lines)):\r\n            vals = lines[i].split(\",\")\r\n            user_id = vals[0]\r\n            year = int(vals[1])\r\n            month = int(vals[2])\r\n            day = int(vals[3])\r\n            genre = vals[4]\r\n            w = float(vals[5])\r\n            date_cur = datetime.datetime(year, month, day)\r\n            time_ts = date_cur.strftime(\"%Y-%m-%d\")\r\n            if time_ts not in rows_dict:\r\n                rows_dict[time_ts] = [(user_id, genre, w)]\r\n            else:\r\n                rows_dict[time_ts].append((user_id, genre, w))\r\n\r\n        time_keys = list(rows_dict.keys())\r\n        time_keys.sort()\r\n\r\n        for ts in time_keys:\r\n            rows = rows_dict[ts]\r\n            for user_id, genre, w in rows:\r\n                write.writerow([ts, user_id, genre, w])\r\n\r\n\r\ndef sort_edgelist(fname, outname=\"sorted_lastfm_edgelist.csv\"):\r\n    r\"\"\"\r\n    sort the edgelist by time\r\n    \"\"\"\r\n    edgelist = open(fname, \"r\")\r\n    lines = list(edgelist.readlines())\r\n    edgelist.close()\r\n\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"time\", \"user_id\", \"genre\", \"weight\"]\r\n        write.writerow(fields)\r\n\r\n        rows_dict = {}\r\n        for idx in range(1, len(lines)):\r\n            vals = lines[idx].split(\",\")\r\n            user_id = vals[0]\r\n            time_ts = vals[1][:-7]\r\n            genre = vals[2]\r\n            w = float(vals[3].strip())\r\n            if time_ts not in rows_dict:\r\n                rows_dict[time_ts] = [(user_id, genre, w)]\r\n            else:\r\n                rows_dict[time_ts].append((user_id, genre, w))\r\n\r\n        time_keys = list(rows_dict.keys())\r\n        time_keys.sort()\r\n\r\n        for ts in time_keys:\r\n            rows = rows_dict[ts]\r\n            for user_id, genre, w in rows:\r\n                
write.writerow([ts, user_id, genre, w])\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    #! generate the list of genres by frequency\r\n    # get_genre_list(\"/mnt/c/Users/sheny/Desktop/TGB/tgb/datasets/lastfmGenre/dataset.csv\")\r\n    # genre_dict = load_genre_dict(\"/mnt/c/Users/sheny/Desktop/TGB/tgb/datasets/lastfmGenre/genre_list.csv\")\r\n\r\n    #! find similar genres\r\n    # find_unique_genres(\"genre_list_1000.csv\",threshold= 0.8)\r\n\r\n    #! filter edgelist with genres to keep\r\n    # genres_dict = load_genre_dict(\"genre_list_1000.csv\")\r\n    # filter_genre_edgelist(\"dataset.csv\", genres_dict)\r\n\r\n    #! generate the daily node labels\r\n    # generate_daily_node_labels(\"lastfm_edgelist_clean.csv\")\r\n\r\n    # generate_daily_node_labels(\"/mnt/c/Users/sheny/Desktop/TGB/tgb/datasets/lastfmGenre/dataset.csv\")\r\n    # load_node_labels(\"/mnt/c/Users/sheny/Desktop/TGB/tgb/datasets/lastfmGenre/daily_labels.csv\")\r\n\r\n    # #! generate normalized weekly node labels\r\n    # generate_aggregate_labels(\"daily_labels.csv\", days=7)\\\r\n\r\n    # \"\"\"\r\n    # sort edgelist by time for lastfm dataset\r\n    # \"\"\"\r\n    # fname = \"../datasets/lastfmGenre/lastfm_edgelist_clean.csv\"\r\n    # outname = '../datasets/lastfmGenre/sorted_lastfm_edgelist.csv'\r\n    # sort_edgelist(fname,\r\n    #               outname = outname)\r\n\r\n    # \"\"\"\r\n    # sort node labels by time for lastfm dataset\r\n    # \"\"\"\r\n    # fname = \"../datasets/lastfmGenre/7days_labels.csv\"\r\n    # outname = '../datasets/lastfmGenre/sorted_7days_node_labels.csv'\r\n    # sort_node_labels(fname,\r\n    #                  outname)\r\n\r\n    # #! convert from date to ts\r\n    # convert_ts_unix(\"lastfmgenre_node_labels_datetime.csv\",\r\n    #                 \"lastfmgenre_node_labels.csv\")\r\n    convert_ts_edgelist(\"lastfmgenre_edgelist.csv\", \"lastfmgenre_edgelist_ts.csv\")\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/tgbn-reddit.py",
    "content": "import csv\r\nfrom tqdm import tqdm\r\nfrom os import listdir\r\nfrom tgb.utils.stats import analyze_csv\r\n\r\n\r\ndef find_filenames(path_to_dir):\r\n    r\"\"\"\r\n    find all files in a folder\r\n    Parameters:\r\n        path_to_dir (str): path to the directory\r\n    \"\"\"\r\n    filenames = listdir(path_to_dir)\r\n    return filenames\r\n\r\n\r\ndef combine_edgelist_edgefeat2subreddits(edgefname, featfname, outname):\r\n    \"\"\"\r\n    combine edgelist and edge features\r\n    'ts', 'src', 'subreddit', 'num_words', 'score'\r\n    \"\"\"\r\n    line_idx = 0\r\n\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"ts\", \"src\", \"subreddit\", \"num_words\", \"score\"]\r\n        write.writerow(fields)\r\n        sub_id = 0\r\n        edgelist = open(edgefname, \"r\")\r\n        edgefeat = open(featfname, \"r\")\r\n        edgelist.readline()\r\n        edgefeat.readline()\r\n\r\n        while True:\r\n            #'ts', 'src', 'dst', 'edge_id'\r\n            edge_line = edgelist.readline()\r\n            edge_line = edge_line.split(\",\")\r\n            if len(edge_line) < 4:\r\n                break\r\n            edge_id = int(edge_line[3])\r\n            ts = int(edge_line[0])\r\n            src = int(edge_line[1])\r\n\r\n            #'edge_id', 'subreddit', 'num_characters', 'num_words', 'score', 'edited_flag'\r\n            feat_line = edgefeat.readline()\r\n            feat_line = feat_line.split(\",\")\r\n            edge_id_feat = int(feat_line[0])\r\n            subreddit = feat_line[1]\r\n            num_words = int(feat_line[3])\r\n            score = int(feat_line[4])\r\n\r\n            if edge_id != edge_id_feat:\r\n                print(\"edge_id != edge_id_feat\")\r\n                print(edge_id)\r\n                print(edge_id_feat)\r\n                break\r\n\r\n            write.writerow([ts, src, subreddit, num_words, score])\r\n            line_idx += 1\r\n    
print(\"processed\", line_idx, \"lines\")\r\n\r\n\r\ndef filter_subreddits(fname):\r\n    \"\"\"\r\n    check the frequency of subreddits\r\n    \"\"\"\r\n    subreddit_count = {}\r\n    node_count = {}\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        line_count = 0\r\n        # ts, src, subreddit, num_words, score\r\n        for row in csv_reader:\r\n            if line_count == 0:\r\n                line_count += 1\r\n            else:\r\n                ts = row[0]\r\n                src = row[1]\r\n                subreddit = row[2]\r\n                if subreddit not in subreddit_count:\r\n                    subreddit_count[subreddit] = 1\r\n                else:\r\n                    subreddit_count[subreddit] += 1\r\n                if src not in node_count:\r\n                    node_count[src] = 1\r\n                else:\r\n                    node_count[src] += 1\r\n    return subreddit_count, node_count\r\n\r\n\r\ndef clean_edgelist(fname, node_counts, outname, threshold=1000):\r\n    \"\"\"\r\n    helper function for filtering out low frequency nodes\r\n    \"\"\"\r\n    node_dict = {}\r\n\r\n    for node in node_counts:\r\n        if node_counts[node] >= threshold:\r\n            node_dict[node] = 1\r\n\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"ts\", \"user\", \"subreddit\", \"num_words\", \"score\"]\r\n        write.writerow(fields)\r\n\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            line_count = 0\r\n            # ts, src, subreddit, num_words, score\r\n            for row in csv_reader:\r\n                if line_count == 0:\r\n                    line_count += 1\r\n                else:\r\n                    ts = row[0]\r\n                    src = row[1]\r\n                    subreddit = row[2]\r\n                    num_words = 
int(row[3])\r\n                    score = int(row[4])\r\n                    if src in node_dict:\r\n                        write.writerow([ts, src, subreddit, num_words, score])\r\n\r\n\r\ndef clean_edgelist_reddits(fname, reddit_counts, outname, threshold=50):\r\n    \"\"\"\r\n    helper function for filtering out low frequency subreddits\r\n    \"\"\"\r\n    reddit_dict = {}\r\n\r\n    for reddit in reddit_counts:\r\n        if reddit_counts[reddit] >= threshold:\r\n            reddit_dict[reddit] = 1\r\n    print (\"there remains, \", len(reddit_dict), \" subreddits\")\r\n\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"ts\", \"user\", \"subreddit\", \"num_words\", \"score\"]\r\n        write.writerow(fields)\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            line_count = 0\r\n            # ts, src, subreddit, num_words, score\r\n            for row in csv_reader:\r\n                if line_count == 0:\r\n                    line_count += 1\r\n                else:\r\n                    ts = row[0]\r\n                    src = row[1]\r\n                    subreddit = row[2]\r\n                    num_words = int(row[3])\r\n                    score = int(row[4])\r\n                    if subreddit in reddit_dict:\r\n                        write.writerow([ts, src, subreddit, num_words, score])\r\n\r\n\r\ndef remove_missing_user(fname, outname):\r\n    \"\"\"\r\n    remove all lines that are missing the user\r\n    \"\"\"\r\n\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"ts\", \"user\", \"subreddit\", \"num_words\", \"score\"]\r\n        write.writerow(fields)\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            line_count = 0\r\n            # ts, src, subreddit, num_words, score\r\n            for 
row in csv_reader:\r\n                if line_count == 0:\r\n                    line_count += 1\r\n                else:\r\n                    ts = row[0]\r\n                    src = int(row[1])\r\n                    subreddit = row[2]\r\n                    num_words = int(row[3])\r\n                    score = int(row[4])\r\n                    if src != -1:\r\n                        write.writerow([ts, src, subreddit, num_words, score])\r\n\r\n\r\ndef generate_daily_node_labels(\r\n    fname: str,\r\n    outname: str,\r\n):\r\n    r\"\"\"\r\n    function for generating daily node labels then can be used for aggregation\r\n    \"\"\"\r\n\r\n    day_dict = {}  # store the weights of genres on this day\r\n    prev_t = -1\r\n    DAY_IN_SEC = 86400\r\n    # WEEK_IN_SEC = 604800\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"ts\", \"user\", \"subreddit\", \"weight\"]\r\n        write.writerow(fields)\r\n\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            line_count = 0\r\n            # ts, src, subreddit, num_words, score\r\n            for row in csv_reader:\r\n                if line_count == 0:\r\n                    line_count += 1\r\n                else:\r\n                    ts = int(row[0])\r\n                    user_id = row[1]\r\n                    subreddit = row[2]\r\n\r\n                    if line_count == 1:\r\n                        prev_t = ts\r\n\r\n                    if (prev_t + DAY_IN_SEC) < ts:\r\n                        #! 
user,time,genre,weight  # genres = # of weights\r\n                        out = [user_id, ts]\r\n                        for subreddit, w in day_dict.items():\r\n                            write.writerow(out + [subreddit] + [w])\r\n                        prev_t = ts\r\n                        day_dict = {}\r\n                    else:\r\n                        if subreddit not in day_dict:\r\n                            day_dict[subreddit] = 1\r\n                        else:\r\n                            day_dict[subreddit] += 1\r\n                    line_count += 1\r\n\r\n\r\n#! note that the edgelist are not sorted by users then by time, should keep multiple users when aggregating\r\ndef generate_aggregate_labels(fname: str, outname: str, days: int = 7):\r\n    \"\"\"\r\n    aggregate the genres over a number of days,  as specified by days\r\n    prediction should always be at the first second of the day\r\n    #! daily labels are always shifted by 1 day\r\n    \"\"\"\r\n\r\n    ts_prev = 0\r\n\r\n    DAY_IN_SEC = 86400\r\n    timespan = days * DAY_IN_SEC\r\n\r\n    user_dict = {}\r\n\r\n    # ts, src, subreddit, num_words, score\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"ts\", \"user\", \"subreddit\", \"weight\"]\r\n        write.writerow(fields)\r\n\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            line_count = 0\r\n            # ts, src, subreddit, num_words, score\r\n            for row in csv_reader:\r\n                if line_count == 0:\r\n                    line_count += 1\r\n                else:\r\n                    ts = int(row[0])\r\n                    user = row[1]\r\n                    subreddit = row[2]\r\n                    w = int(row[3])\r\n                    if line_count == 1:\r\n                        ts_prev = ts\r\n\r\n                    if (ts - ts_prev) > timespan:\r\n                     
   for user in user_dict:\r\n                            total = sum(user_dict[user].values())\r\n                            subreddit_dict = {\r\n                                k: v / total for k, v in user_dict[user].items()\r\n                            }\r\n                            for subreddit, w in subreddit_dict.items():\r\n                                write.writerow(\r\n                                    [ts_prev + DAY_IN_SEC, user, subreddit, w]\r\n                                )\r\n                        user_dict = {}\r\n                        ts_prev = ts_prev + DAY_IN_SEC  #! move label to the next day\r\n                    else:\r\n                        if user in user_dict:\r\n                            if subreddit in user_dict[user]:\r\n                                user_dict[user][subreddit] += w\r\n                            else:\r\n                                user_dict[user][subreddit] = w\r\n                        else:\r\n                            user_dict[user] = {}\r\n                            user_dict[user][subreddit] = w\r\n                    line_count += 1\r\n\r\n\r\ndef main():\r\n    # #? see redditcomments.py for the extraction from the raw files\r\n\r\n    #! combine edgelist and edge feat file check if the edge_id matches\r\n    # edgefname = \"redditcomments_edgelist_2008_2010.csv\"\r\n    # featfname = \"redditcomments_edgefeat_2008_2010.csv\"\r\n    # outname = \"subreddits_edgelist.csv\"\r\n    # combine_edgelist_edgefeat2subreddits(edgefname, featfname, outname)\r\n\r\n    #! remove all edges missing user\r\n    # fname = \"subreddits_edgelist.csv\"\r\n    # outname = \"subreddits_edgelist_filtered.csv\"\r\n    # remove_missing_user(fname,\r\n    #                     outname)\r\n\r\n    #! 
should clean subreddits first, frequency count of reddits\r\n    # fname = \"subreddits_edgelist.csv\"\r\n    # outname = \"subreddits_edgelist_filter.csv\"\r\n    # subreddit_count, node_count = filter_subreddits(fname)\r\n    # threshold = 1000 #200 #100\r\n    # clean_edgelist_reddits(fname, subreddit_count, outname, threshold=threshold)\r\n\r\n\r\n    #! filter out nodes with low frequency frequency count of nodes\r\n    # fname = \"subreddits_edgelist.csv\"\r\n    # outname = \"subreddits_edgelist_clean.csv\"\r\n    # subreddit_count, node_count = filter_subreddits(fname)\r\n    # threshold = 1000\r\n    # clean_edgelist(fname, node_count, outname, threshold=threshold)\r\n    # print (\"finish cleaning\")\r\n\r\n    #! generate aggregate labels, the label for each day is shifted by 1 day as it uses the edges from today\r\n    # fname = \"subreddits_edgelist.csv\"\r\n    # outname = \"subreddits_node_labels.csv\"\r\n    # generate_aggregate_labels(fname, outname, days=7)\r\n\r\n    #! analyze the extracted csv\r\n    fname = \"subreddits_edgelist.csv\"\r\n    analyze_csv(fname)\r\n\r\n\r\n    \r\n    # #! generate daily node labels\r\n    # outname = 'subreddits_daily_labels.csv'\r\n    # fname = \"subreddits_edgelist.csv\"\r\n    # generate_daily_node_labels(fname,outname)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/dataset_scripts/tgbn-token.py",
    "content": "import csv\r\nimport datetime\r\n\r\ndef count_node_freq(fname, filter_size=100):\r\n\r\n    node_dict = {}\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        ctr = 0\r\n        for row in csv_reader:\r\n            if ctr == 0:\r\n                ctr += 1\r\n                continue\r\n            else:\r\n                token_type = row[0]\r\n                src = row[1]\r\n                if (src not in node_dict):\r\n                    node_dict[src] = 1\r\n                else:\r\n                    node_dict[src] += 1\r\n                dst = row[2]\r\n                if (dst not in node_dict):\r\n                    node_dict[dst] = 1\r\n                else:\r\n                    node_dict[dst] += 1\r\n                ctr += 1\r\n\r\n    num_10 = 0\r\n    num_100 = 0\r\n    num_1000 = 0\r\n    num_2000 = 0\r\n    num_5000 = 0 \r\n\r\n    for node in node_dict:\r\n        if node_dict[node] >= 10:\r\n            num_10 += 1\r\n        if node_dict[node] >= 100:\r\n            num_100 += 1\r\n        if node_dict[node] >= 1000:\r\n            num_1000 += 1\r\n        if node_dict[node] >= 2000:\r\n            num_2000 += 1\r\n        if node_dict[node] >= 5000:\r\n            num_5000 += 1\r\n\r\n    print(\"number of nodes with # edges >= 10 is \", num_10)\r\n    print(\"number of nodes with # edges >= 100 is \", num_100)\r\n    print(\"number of nodes with # edges >= 1000 is \", num_1000)\r\n    print(\"number of nodes with # edges >= 2000 is \", num_2000)\r\n    print(\"number of nodes with # edges >= 5000 is \", num_5000)\r\n    print(\"----------------------high level statistics-------------------------\")\r\n\r\n\r\n    #! 
keep nodes with at least 100 edges\r\n    node_dict_filtered = {}\r\n    for node in node_dict:\r\n        if node_dict[node] >= filter_size:\r\n            node_dict_filtered[node] = node_dict[node]\r\n    return node_dict_filtered\r\n\r\n\r\n\r\n\r\n\r\n\r\ndef filter_edgelist(token_fname, edgefile, outname):\r\n    \"\"\"\r\n    preserve only the tokens in the token file\r\n    Parameters:\r\n        token_fname: the file of the token file\r\n        edgefile: the edgelist file name\r\n        outname: the output filtered edgelistname\r\n    \"\"\"\r\n    #* read tokens from the file\r\n    token_dict = {}\r\n    with open(token_fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        ctr = 0\r\n        for row in csv_reader:\r\n            if ctr == 0:\r\n                ctr += 1\r\n                continue\r\n            else:\r\n                token_type = row[0]\r\n                token_dict[token_type] = 1\r\n    \r\n    with open(edgefile, \"r\") as in_file:\r\n        with open(outname, \"w\") as out_file:\r\n            csv_reader = csv.reader(in_file, delimiter=\",\")\r\n            csv_writer = csv.writer(out_file, delimiter=\",\")\r\n            csv_writer.writerow([\"token_address\", \"from_address\", \"to_address\", \"value\", \"block_timestamp\"])\r\n            ctr = 0\r\n            for row in csv_reader:\r\n                if ctr == 0:\r\n                    ctr += 1\r\n                    continue\r\n                else:\r\n                    token_type = row[0]\r\n                    if token_type in token_dict:\r\n                        csv_writer.writerow(row)\r\n                    ctr += 1\r\n\r\n\r\ndef filter_by_node(node_dict, edgefile, outname):\r\n    with open(edgefile, \"r\") as in_file:\r\n        with open(outname, \"w\") as out_file:\r\n            csv_reader = csv.reader(in_file, delimiter=\",\")\r\n            csv_writer = csv.writer(out_file, delimiter=\",\")\r\n            
csv_writer.writerow([\"token_address\", \"from_address\", \"to_address\", \"value\", \"block_timestamp\"])\r\n            ctr = 0\r\n            for row in csv_reader:\r\n                if ctr == 0:\r\n                    ctr += 1\r\n                    continue\r\n                else:\r\n                    token_type = row[0]\r\n                    src = row[1]\r\n                    dst = row[2]\r\n                    if (src in node_dict) or (dst in node_dict):\r\n                        csv_writer.writerow(row)\r\n                    ctr += 1\r\n\r\n\r\n\r\ndef store_node_list(node_dict, outname):\r\n    \"\"\"\r\n    Parameters:\r\n        outname: name of the output csv file\r\n    Output:\r\n        output csv file with node list\r\n    \"\"\"\r\n    with open(outname, \"w\") as csv_file:\r\n        csv_writer = csv.writer(csv_file, delimiter=\",\")\r\n        csv_writer.writerow([\"node_list\", \"frequency\"])\r\n        for key, value in node_dict.items():\r\n            csv_writer.writerow([key, value])\r\n\r\n\r\ndef load_node_dict(fname):\r\n    node_dict = {}\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        ctr = 0\r\n        for row in csv_reader:\r\n            if ctr == 0:\r\n                ctr += 1\r\n                continue\r\n            else:\r\n                node = row[0]\r\n                freq = int(row[1])\r\n                node_dict[node] = freq\r\n                ctr += 1\r\n    return node_dict\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\ndef store_token_address(token_dict, outname, topk=1000):\r\n    \"\"\"\r\n    Parameters:\r\n        outname: name of the output csv file\r\n    Output:\r\n        output csv file with topk token addresses\r\n    \"\"\"\r\n    sorted_tokens = {k: v for k, v in sorted(token_dict.items(), key=lambda item: item[1], reverse=True)}\r\n    ctr = 0\r\n    with open(outname, \"w\") as csv_file:\r\n        csv_writer = csv.writer(csv_file, 
delimiter=\",\")\r\n        csv_writer.writerow([\"token_address\", \"frequency\"])\r\n        for key, value in sorted_tokens.items():\r\n            if (ctr <= topk):\r\n                csv_writer.writerow([key, value])\r\n            else:\r\n                break\r\n            ctr += 1\r\n\r\ndef analyze_token_frequency(fname):\r\n    # ['token_address', 'from_address', 'to_address', 'value', 'block_timestamp']\r\n    token_dict = {}\r\n    node_dict = {}\r\n    time_dict = {}\r\n    max_w = 0\r\n    min_w = 100000\r\n    num_edges = 0\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        ctr = 0\r\n        for row in csv_reader:\r\n            if ctr == 0:\r\n                ctr += 1\r\n                continue\r\n            else:\r\n                token_type = row[0]\r\n                if (token_type not in token_dict):\r\n                    token_dict[token_type] = 1\r\n                else:\r\n                    token_dict[token_type] += 1\r\n                src = row[1]\r\n                if (src not in node_dict):\r\n                    node_dict[src] = 1\r\n                else:\r\n                    node_dict[src] += 1\r\n                dst = row[2]\r\n                if (dst not in node_dict):\r\n                    node_dict[dst] = 1\r\n                else:\r\n                    node_dict[dst] += 1\r\n\r\n                w = float(row[3])\r\n                if (w > max_w):\r\n                    max_w = w\r\n                elif (w < min_w):\r\n                    min_w = w\r\n                timestamp = row[4]\r\n                if (timestamp not in time_dict):\r\n                    time_dict[timestamp] = 1\r\n                ctr += 1\r\n                num_edges += 1\r\n\r\n    print ( \"number of edges are \", num_edges)\r\n    print (\" number of unique tokens are \", len(token_dict))\r\n    print (\" number of unique nodes are \", len(node_dict))\r\n    print (\" number 
of unique timestamps are \", len(time_dict))\r\n    print (\" max weight is \", max_w)\r\n    print (\" min weight is \", min_w)\r\n\r\n    # topk = 1000\r\n    # store_token_address(token_dict, \"token_list.csv\", topk=topk)\r\n\r\ndef to_bipartite(in_name, out_name, node_dict):\r\n    \"\"\"\r\n    load and convert a user-user graph into a user-token bipartite graph\r\n    \"\"\"\r\n    with open(in_name, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        with open(out_name, \"w\") as out_file:\r\n            csv_writer = csv.writer(out_file, delimiter=\",\")\r\n            csv_writer.writerow([\"timestamp\", \"user_address\", \"token_address\", \"value\", \"IsSender\"])\r\n            ctr = 0\r\n            for row in csv_reader:\r\n                if ctr == 0:\r\n                    ctr += 1\r\n                    continue\r\n                else:\r\n                    token_type = row[0]\r\n                    src = row[1]\r\n                    dst = row[2]\r\n                    w = float(row[3])\r\n                    timestamp = row[4]\r\n                    if (src in node_dict):\r\n                        csv_writer.writerow([timestamp, src, token_type, w, 1])\r\n                    if (dst in node_dict):\r\n                        csv_writer.writerow([timestamp, dst, token_type, w, 0])\r\n                    \r\n\r\ndef analyze_csv(fname):\r\n    node_dict = {}\r\n    edge_dict = {}\r\n    num_edges = 0\r\n    num_time = 0\r\n    time_dict = {}\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        line_count = 0\r\n        for row in csv_reader:\r\n            if line_count == 0:\r\n                line_count += 1\r\n            else:\r\n                # t,u,v,w\r\n                t = row[0]\r\n                u = row[1]\r\n                v = row[2]\r\n\r\n                # count unique time\r\n                if t not in time_dict:\r\n        
            time_dict[t] = 1\r\n                    num_time += 1\r\n\r\n                # unique nodes\r\n                if u not in node_dict:\r\n                    node_dict[u] = 1\r\n                else:\r\n                    node_dict[u] += 1\r\n\r\n                if v not in node_dict:\r\n                    node_dict[v] = 1\r\n                else:\r\n                    node_dict[v] += 1\r\n\r\n                # unique edges\r\n                num_edges += 1\r\n                if (u, v) not in edge_dict:\r\n                    edge_dict[(u, v)] = 1\r\n                else:\r\n                    edge_dict[(u, v)] += 1\r\n\r\n    print(\"----------------------high level statistics-------------------------\")\r\n    print(\"number of total edges are \", num_edges)\r\n    print(\"number of nodes are \", len(node_dict))\r\n    print(\"number of unique edges are \", len(edge_dict))\r\n    print(\"number of unique timestamps are \", num_time)\r\n\r\n    num_10 = 0\r\n    num_100 = 0\r\n    num_1000 = 0\r\n\r\n    for node in node_dict:\r\n        if node_dict[node] >= 10:\r\n            num_10 += 1\r\n        if node_dict[node] >= 100:\r\n            num_100 += 1\r\n        if node_dict[node] >= 1000:\r\n            num_1000 += 1\r\n    print(\"number of nodes with # edges >= 10 is \", num_10)\r\n    print(\"number of nodes with # edges >= 100 is \", num_100)\r\n    print(\"number of nodes with # edges >= 1000 is \", num_1000)\r\n    print(\"----------------------high level statistics-------------------------\")        \r\n\r\n\r\n\r\ndef convert_2_sec(fname, outname):\r\n    \"\"\"\r\n    convert datetime object format = \"%Y-%m-%d %H:%M:%S\" to seconds\r\n    #2017-07-24 17:48:15+00:00\r\n    \"\"\"\r\n    format = \"%Y-%m-%d %H:%M:%S\"\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        with open(outname, \"w\") as out_file:\r\n            csv_writer = csv.writer(out_file, 
delimiter=\",\")\r\n            csv_writer.writerow([\"timestamp\", \"user_address\", \"token_address\", \"value\", \"IsSender\"])\r\n            ctr = 0\r\n            for row in csv_reader:\r\n                if ctr == 0:\r\n                    ctr += 1\r\n                    continue\r\n                else:\r\n                    timestamp = row[0][:19]\r\n                    date_object = datetime.datetime.strptime(timestamp, format)\r\n                    timestamp_sec = int(date_object.timestamp())\r\n                    src = row[1]\r\n                    dst = row[2]\r\n                    w = float(row[3])\r\n                    IsSender = int(row[4])\r\n                    if (w != 0):\r\n                        csv_writer.writerow([timestamp_sec, src, dst, w, IsSender])\r\n\r\n    \r\n\r\n\r\n\r\ndef print_csv(fname):\r\n    # ['token_address', 'from_address', 'to_address', 'value', 'block_timestamp']\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        ctr = 0\r\n        for row in csv_reader:\r\n            ctr += 1\r\n    print (\"there are \", ctr, \" rows in the csv file\")\r\n\r\n\r\n\r\ndef sort_edgelist_by_time(fname, outname):\r\n    row_dict = {}\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        with open(outname, \"w\") as out_file:\r\n            csv_writer = csv.writer(out_file, delimiter=\",\")\r\n            csv_writer.writerow([\"timestamp\", \"user_address\", \"token_address\", \"value\", \"IsSender\"])\r\n            ctr = 0\r\n            for row in csv_reader:\r\n                if ctr == 0:\r\n                    ctr += 1\r\n                    continue\r\n                else:\r\n                    timestamp =int(row[0])\r\n                    if (timestamp not in row_dict):\r\n                        row_dict[timestamp] = [row]\r\n                    else:\r\n                        
row_dict[timestamp].append(row)\r\n            for i in sorted(row_dict.keys()):\r\n                rows = row_dict[i]\r\n                for row in rows:\r\n                    csv_writer.writerow(row)\r\n\r\n\r\n\r\n\r\n\r\n#! aggregate node labels\r\ndef generate_aggregate_labels(fname: str, outname: str, days: int = 7):\r\n    \"\"\"\r\n    aggregate the genres over a number of days,  as specified by days\r\n    prediction should always be at the first second of the day\r\n    #! daily labels are always shifted by 1 day\r\n    \"\"\"\r\n\r\n    ts_prev = 0\r\n\r\n    DAY_IN_SEC = 86400\r\n    timespan = days * DAY_IN_SEC\r\n\r\n    user_dict = {}\r\n\r\n    # ts, src, subreddit, num_words, score\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"ts\", \"user_address\", \"token_address\", \"weight\"] #[\"ts\", \"user\", \"subreddit\", \"weight\"]\r\n        write.writerow(fields)\r\n\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            line_count = 0\r\n            # ts, src, subreddit, num_words, score\r\n            for row in csv_reader:\r\n                if line_count == 0:\r\n                    line_count += 1\r\n                else:\r\n                    ts = float(row[0])\r\n                    ts = int(ts)\r\n                    user = row[1]\r\n                    item = row[2]\r\n                    w = float(row[3])\r\n                    if (w == 0):\r\n                        print (row)\r\n\r\n                    if line_count == 1:\r\n                        ts_prev = ts\r\n\r\n                    if (ts - ts_prev) > timespan:\r\n                        for user in user_dict:\r\n                            total = sum(user_dict[user].values())\r\n                            item_dict = {\r\n                                k: v / total for k, v in user_dict[user].items()\r\n                            }\r\n               
             for item, w in item_dict.items():\r\n                                write.writerow(\r\n                                    [ts_prev + DAY_IN_SEC, user, item, w]\r\n                                )\r\n                        user_dict = {}\r\n                        ts_prev = ts_prev + DAY_IN_SEC  #! move label to the next day\r\n                    else:\r\n                        if user in user_dict:\r\n                            if item in user_dict[user]:\r\n                                user_dict[user][item] += w\r\n                            else:\r\n                                user_dict[user][item] = w\r\n                        else:\r\n                            user_dict[user] = {}\r\n                            user_dict[user][item] = w\r\n                    line_count += 1\r\n\r\n\r\n\r\n\r\ndef main():\r\n\r\n    \"\"\"\r\n    processing token types\r\n    \"\"\"\r\n    # fname = \"ERC20_token_network.csv\"\r\n    # #analyze_token_frequency(fname)\r\n\r\n    # token_file = \"token_list.csv\"\r\n    # outname = \"filtered_token_edgelist.csv\"\r\n\r\n    #! filter by token frequency\r\n    # filter_edgelist(token_file, fname, outname)\r\n    # #print_csv(fname)\r\n    # #analyze_csv(fname)\r\n\r\n    \"\"\"\r\n    processing node dict\r\n    \"\"\"\r\n    # fname = \"filtered_token_edgelist.csv\"\r\n    # #! filter by node frequency\r\n    # node_dict = count_node_freq(fname, filter_size=100)\r\n    # store_node_list(node_dict, \"node_list.csv\")\r\n    # #store_token_address(node_dict, \"node_list.csv\", topk=0)\r\n\r\n    # outname = \"tgbl-token-edgelist_100.csv\"\r\n    # filter_by_node(node_dict, fname, outname)\r\n    # analyze_token_frequency('tgbl-token-edgelist_100.csv')\r\n\r\n\r\n    #! 
converting user-user graph to user-token bipartite graph\r\n    # out_name = \"tgbl-token_edgelist.csv\"\r\n    # node_dict = load_node_dict(\"node_list.csv\")\r\n    # to_bipartite('tgbl-token-edgelist_100.csv', out_name, node_dict)\r\n    # analyze_csv(out_name)\r\n\r\n\r\n    #! convert datetime to seconds\r\n    #convert_2_sec(\"tgbl-token_edgelist_old.csv\", \"tgbn-token_edgelist.csv\")\r\n\r\n\r\n    #! sort the timestamps in the edgelist\r\n    # fname = \"tgbn-token_edgelist.csv\"\r\n    # outname = \"tgbn-token_edgelist_sorted.csv\"\r\n    # sort_edgelist_by_time(fname, outname)\r\n\r\n\r\n\r\n    #! generate node labels\r\n    edgefile = \"tgbn-token_edgelist.csv\"\r\n    outfile = \"tgbn-token_node_labels.csv\"\r\n    days = 7\r\n    generate_aggregate_labels(edgefile, outfile, days=days)\r\n\r\n\r\n\r\n\r\n\r\n    \r\n\r\n\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()"
  },
  {
    "path": "tgb/datasets/dataset_scripts/tgbn-trade.py",
    "content": "import dateutil.parser as dparser\r\nimport csv\r\nimport matplotlib.pyplot as plt\r\nimport numpy as np\r\nfrom tqdm import tqdm\r\nfrom os import listdir\r\n\r\n\r\ndef count_unique_countries(fname):\r\n    node_dict = {}\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        line_count = 0\r\n        # year,u,v,w\r\n        for row in csv_reader:\r\n            if line_count == 0:\r\n                line_count += 1\r\n            else:\r\n                year = int(row[0])\r\n                u = row[1]\r\n                v = row[2]\r\n                w = float(row[3])\r\n                if u not in node_dict:\r\n                    node_dict[u] = 1\r\n                if v not in node_dict:\r\n                    node_dict[v] = 1\r\n\r\n    print(\"there are {} unique countries\".format(len(node_dict)))\r\n\r\n\r\n\r\n#! incorrect, do not use\r\ndef normalize_edgelist(fname: str, outname: str):\r\n    \"\"\"\r\n    need to track id for nodes\r\n    normalize the edgelist by row for each year\r\n    \"\"\"\r\n    prev_t = 0\r\n    uid = 0\r\n    node_dict = {}\r\n    year_dict = {}\r\n\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"year\", \"nation\", \"trading nation\", \"weight\"]\r\n        write.writerow(fields)\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            line_count = 0\r\n            for row in csv_reader:\r\n                if line_count == 0:\r\n                    line_count += 1\r\n                else:\r\n                    ts = int(row[0])\r\n                    u = row[1]\r\n                    v = row[2]\r\n                    w = float(row[3])\r\n                    if line_count == 1:\r\n                        prev_t = ts\r\n                    if u not in node_dict:\r\n                        node_dict[u] = uid\r\n                
        uid += 1\r\n                    if v not in node_dict:\r\n                        node_dict[v] = uid\r\n                        uid += 1\r\n                    if w == 0:\r\n                        line_count += 1\r\n                        continue\r\n\r\n                    if ts != prev_t:  # a new year now, write everything\r\n                        # normalize the counts\r\n                        for u in year_dict:\r\n                            if np.sum(year_dict[u]) == 0:\r\n                                continue\r\n                            year_dict[u] = year_dict[u] / np.sum(year_dict[u])\r\n                            invert_dict = {v: k for k, v in node_dict.items()}\r\n                            for v in range(len(year_dict[u])):\r\n                                if year_dict[u][v] > 0:\r\n                                    write.writerow(\r\n                                        [prev_t, u, invert_dict[v], year_dict[u][v]]\r\n                                    )\r\n                        year_dict = {}\r\n                        prev_t = ts\r\n                    else:\r\n                        if u not in year_dict:\r\n                            year_dict[u] = np.zeros(255)\r\n                            year_dict[u][node_dict[v]] = w\r\n                        else:\r\n                            year_dict[u][node_dict[v]] = w\r\n                    line_count += 1\r\n\r\n\r\ndef generate_aggregate_labels(fname: str, outname: str):\r\n    \"\"\"\r\n    aggregate the node label for next year\r\n    \"\"\"\r\n\r\n    ts_init = 1986\r\n\r\n    # ts, src, subreddit, num_words, score\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        fields = [\"year\", \"nation\", \"trading nation\", \"weight\"]\r\n        write.writerow(fields)\r\n\r\n        with open(fname, \"r\") as csv_file:\r\n            csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n            line_count = 0\r\n            # 
ts, src, subreddit, num_words, score\r\n            for row in csv_reader:\r\n                if line_count == 0:\r\n                    line_count += 1\r\n                else:\r\n                    ts = int(row[0])\r\n                    u = row[1]\r\n                    v = row[2]\r\n                    w = float(row[3])\r\n                    if (ts > ts_init):\r\n                        write.writerow([ts, u, v, w])\r\n                    line_count += 1\r\n\r\n\r\ndef check_sum_to_one(fname: str):\r\n    \"\"\"\r\n    just to check if weights sum to 1 in a year\r\n    \"\"\"\r\n    u_dict = {}\r\n    ts_prev = 0\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        line_count = 0\r\n        # ts, src, subreddit, num_words, score\r\n        for row in csv_reader:\r\n            if line_count == 0:\r\n                line_count += 1\r\n            else:\r\n                ts = int(row[0])\r\n                if (line_count == 1):\r\n                    ts_prev = ts\r\n                if (ts != ts_prev):\r\n                    ts_prev = ts\r\n                    for u in u_dict:\r\n                        print (u_dict[u])\r\n                    u_dict = {}\r\n                u = row[1]\r\n                v = row[2]\r\n                w = float(row[3])\r\n                if (u not in u_dict):\r\n                    u_dict[u] = w\r\n                else:\r\n                    u_dict[u] += w\r\n                line_count += 1\r\n\r\n\r\n\r\n\r\ndef main():\r\n    #! should have the normalized version on the edgelist\r\n\r\n    # #find the number of unique countries\r\n    # fname = \"un_trade_edgelist.csv\"\r\n    # count_unique_countries(fname)\r\n\r\n    #! normalize edgelist by row for each year\r\n    # fname = \"un_trade_edgelist.csv\"\r\n    # outname = \"un_trade_edgelist_normalized.csv\"\r\n    # normalize_edgelist(fname, outname)\r\n\r\n    #! 
find the node label for next year\r\n    # * the node labels are simply the edgelist in this case\r\n    # fname = \"un_trade_edgelist.csv\"\r\n    # outname = \"un_trade_node_labels.csv\"\r\n    # generate_aggregate_labels(fname, outname)\r\n\r\n\r\n    # #! check if all sums are correct\r\n    # fname = \"un_trade_node_labels.csv\"\r\n    # check_sum_to_one(fname)\r\n\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/tgbl_enron/tgbl-enron_neg_generator.py",
    "content": "import timeit\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = 1000 # more than half the nodes in the graph\r\n    neg_sample_strategy = \"hist_rnd\" #\"rnd\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"tgbl-enron\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    # Ensure to only sample actual destination nodes as negatives.\r\n    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n    # After successfully loading the dataset...\r\n    if neg_sample_strategy == \"hist_rnd\":\r\n        historical_data = data_splits[\"train\"]\r\n    else:\r\n        historical_data = None\r\n\r\n    neg_sampler = NegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_dst_id=min_dst_idx,\r\n        last_dst_id=max_dst_idx,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        historical_data=historical_data,\r\n    )\r\n\r\n    # generate evaluation set\r\n    partial_path = \"./\"\r\n    # generate validation negative edge set\r\n    start_time = timeit.default_timer()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        
data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = timeit.default_timer()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer()- start_time: .4f}\"\r\n\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/tgbl_enron/tgbl_enron.py",
    "content": "import csv\n\n\nwith open('ml_enron.csv', 'r', newline='\\n') as infile, open('tgbl-enron_edgelist.csv', 'w', newline='\\n') as outfile:\n    reader = csv.reader(infile)\n    writer = csv.writer(outfile)\n    for row in reader:\n        writer.writerow(row[1:])"
  },
  {
    "path": "tgb/datasets/tgbl_lastfm/tgbl-lastfm_neg_generator.py",
    "content": "import timeit\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = 2000 #this is all nodes in the dataset\r\n    neg_sample_strategy = \"hist_rnd\" #\"rnd\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"tgbl-lastfm\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    # Ensure to only sample actual destination nodes as negatives.\r\n    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n    # After successfully loading the dataset...\r\n    if neg_sample_strategy == \"hist_rnd\":\r\n        historical_data = data_splits[\"train\"]\r\n    else:\r\n        historical_data = None\r\n\r\n    neg_sampler = NegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_dst_id=min_dst_idx,\r\n        last_dst_id=max_dst_idx,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        historical_data=historical_data,\r\n    )\r\n\r\n    # generate evaluation set\r\n    partial_path = \"./\"\r\n    # generate validation negative edge set\r\n    start_time = timeit.default_timer()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        
data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = timeit.default_timer()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer()- start_time: .4f}\"\r\n\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/tgbl_subreddit/tgbl-subreddit_neg_generator.py",
    "content": "import timeit\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = 1000 \r\n    neg_sample_strategy = \"hist_rnd\" #\"rnd\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"tgbl-subreddit\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    # Ensure to only sample actual destination nodes as negatives.\r\n    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n    # After successfully loading the dataset...\r\n    if neg_sample_strategy == \"hist_rnd\":\r\n        historical_data = data_splits[\"train\"]\r\n    else:\r\n        historical_data = None\r\n\r\n    neg_sampler = NegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_dst_id=min_dst_idx,\r\n        last_dst_id=max_dst_idx,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        historical_data=historical_data,\r\n    )\r\n\r\n    # generate evaluation set\r\n    partial_path = \"./\"\r\n    # generate validation negative edge set\r\n    start_time = timeit.default_timer()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        data=data_splits[split_mode], split_mode=split_mode, 
partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = timeit.default_timer()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer()- start_time: .4f}\"\r\n\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/tgbl_uci/tgbl-uci_neg_generator.py",
    "content": "import timeit\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = 1000 # more than half the nodes in the graph\r\n    neg_sample_strategy = \"hist_rnd\" #\"rnd\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"tgbl-uci\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    # Ensure to only sample actual destination nodes as negatives.\r\n    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n    # After successfully loading the dataset...\r\n    if neg_sample_strategy == \"hist_rnd\":\r\n        historical_data = data_splits[\"train\"]\r\n    else:\r\n        historical_data = None\r\n\r\n    neg_sampler = NegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_dst_id=min_dst_idx,\r\n        last_dst_id=max_dst_idx,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        historical_data=historical_data,\r\n    )\r\n\r\n    # generate evaluation set\r\n    partial_path = \"./\"\r\n    # generate validation negative edge set\r\n    start_time = timeit.default_timer()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        
data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = timeit.default_timer()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer()- start_time: .4f}\"\r\n\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/tgbl_uci/tgbl_uci.py",
    "content": "import csv\n\n\nwith open('ml_uci.csv', 'r', newline='\\n') as infile, open('tgbl-uci_edgelist.csv', 'w', newline='\\n') as outfile:\n    reader = csv.reader(infile)\n    writer = csv.writer(outfile)\n    for row in reader:\n        writer.writerow(row[1:])"
  },
  {
    "path": "tgb/datasets/thgl_forum/merge_files.py",
    "content": "\r\nimport csv\r\nfrom tqdm import tqdm\r\nfrom os import listdir\r\nimport glob\r\n\r\ndef find_filenames(path_to_dir):\r\n    r\"\"\"\r\n    find all files in a folder\r\n    Parameters:\r\n        path_to_dir (str): path to the directory\r\n    \"\"\"\r\n    # filenames = glob.glob(path_to_dir)\r\n    filenames = listdir(path_to_dir)\r\n    return filenames\r\n\r\ndef read_edgelist(fname, outfname, write_header=False):\r\n    \"\"\"\r\n    read a space separated edgelist\r\n    comment’s author, author of the parent (the post that the comment is replied to), comment’s creation time, comment’s edge id\r\n    u,v,t,edge_id\r\n    3746738\t1637382\t1551398391\t31534079835\r\n    Parameters:\r\n        fname (str): path to the edgelist\r\n        outfname (str): path to the output file\r\n    \"\"\"\r\n    edgelist = open(fname, \"r\")\r\n    lines = list(edgelist.readlines())\r\n    edgelist.close()\r\n\r\n    with open(outfname, \"a\") as outf:\r\n        write = csv.writer(outf)\r\n        if write_header:\r\n            fields = [\"ts\", \"src\", \"dst\", \"edge_id\"]\r\n            write.writerow(fields)\r\n        for line in lines:\r\n            line = line.split()\r\n            if len(line) < 4:\r\n                continue\r\n            src = line[0]\r\n            dst = line[1]\r\n            ts = line[2]\r\n            edge_id = line[3]\r\n            write.writerow([ts, src, dst, edge_id])\r\n\r\n\r\ndef read_nodeattr(fname, outfname, write_header=False):\r\n    \"\"\"\r\n    read a space separated edgelist\r\n    comment’s edge id, Reddit’s identifier of the comment, Reddit’s identifier of the parent (the post that the comment is replied to)\r\n    Reddit’s identifier of the submission that the comment is in, name of the subreddit that the comment is in, number of characters in the comment’s body\r\n    number of words in the comment’s body, score of the comment, a flag indicating if the comment has been edited\r\n\r\n    - comment’s 
edge id\r\n    - Reddit’s identifier of the comment\r\n    - Reddit’s identifier of the parent (the post that the comment is replied to)\r\n    - Reddit’s identifier of the submission that the comment is in\r\n    - name of the subreddit that the comment is in\r\n    - number of characters in the comment’s body\r\n    - number of words in the comment’s body\r\n    - score of the comment\r\n    - a flag indicating if the comment has been edited\r\n\r\n\r\n\r\n    edge_id, subreddit, num_characters, num_words, score, 'edited_flag'\r\n    Parameters:\r\n        fname (str): path to the edgelist\r\n        outfname (str): path to the output file\r\n    \"\"\"\r\n    edgelist = open(fname, \"r\")\r\n    lines = list(edgelist.readlines())\r\n    edgelist.close()\r\n\r\n    max_words = 0\r\n    max_score = 0\r\n    min_score = 100000000\r\n\r\n    with open(outfname, \"a\") as outf:\r\n        write = csv.writer(outf)\r\n        if write_header:\r\n            fields = [\r\n                \"edge_id\",\r\n                \"reddit_id\",\r\n                \"reddit_parent_id\",\r\n                \"subreddit\",\r\n                \"num_characters\",\r\n                \"num_words\",\r\n                \"score\",\r\n                \"edited_flag\",\r\n            ]\r\n            write.writerow(fields)\r\n        for line in lines:\r\n            line = line.split()\r\n            if len(line) < 4:\r\n                continue\r\n            edge_id = line[0]\r\n            reddit_id = line[1]\r\n            reddit_parent_id = line[2]\r\n            subreddit = line[4]\r\n            num_characters = line[5]\r\n            num_words = line[6]\r\n            if (int(num_words) > max_words):\r\n                max_words = int(num_words)\r\n            score = line[7]\r\n            if (int(score) < min_score):\r\n                min_score = int(score)\r\n            if (int(score) > max_score):\r\n                max_score = int(score)\r\n            edited_flag = line[8].strip(\"\\n\")\r\n            write.writerow(\r\n                [edge_id, reddit_id, 
reddit_parent_id, subreddit, num_characters, num_words, score, edited_flag]\r\n            )\r\n    print(\"max # words\", max_words)\r\n    print(\"min score\", min_score)\r\n    print(\"max score\", max_score)\r\n\r\n\r\ndef combine_edgelist_edgefeat(edgefname, featfname, outname):\r\n    \"\"\"\r\n    combine edgelist and edge features\r\n    \"\"\"\r\n    total_lines = sum(1 for line in open(edgefname))\r\n    subreddit_ids = {}\r\n\r\n    missing_ts = 0\r\n    missing_src = 0\r\n    missing_dst = 0\r\n    line_idx = 0\r\n\r\n    with open(outname, \"w\") as outf:\r\n        write = csv.writer(outf)\r\n        #fields = [\"ts\", \"src\", \"dst\", \"subreddit\", \"num_words\", \"score\"]\r\n        fields = [\"ts\", \"src\", \"dst\", \"reddit_id\", \"reddit_parent_id\", \"subreddit\", \"num_words\", \"score\"]\r\n        write.writerow(fields)\r\n        sub_id = 0\r\n        edgelist = open(edgefname, \"r\")\r\n        edgefeat = open(featfname, \"r\")\r\n        edgelist.readline()\r\n        edgefeat.readline()\r\n\r\n        while True:\r\n            #'ts', 'src', 'dst', 'edge_id'\r\n            edge_line = edgelist.readline()\r\n            edge_line = edge_line.split(\",\")\r\n            if len(edge_line) < 4:\r\n                break\r\n            edge_id = int(edge_line[3])\r\n            ts = int(edge_line[0])\r\n            src = int(edge_line[1])\r\n            dst = int(edge_line[2])\r\n\r\n            # \"edge_id\", \"reddit_id\", \"reddit_parent_id\", \"subreddit\", \"num_characters\", \"num_words\", \"score\", \"edited_flag\",\r\n            feat_line = edgefeat.readline()\r\n            feat_line = feat_line.split(\",\")\r\n            edge_id_feat = int(feat_line[0])\r\n            reddit_id = feat_line[1]\r\n            reddit_parent_id = feat_line[2]\r\n            subreddit = feat_line[3]\r\n            num_characters = int(feat_line[4])\r\n            num_words = int(feat_line[5])\r\n            score = int(feat_line[6])\r\n            
edited_flag = bool(feat_line[7])\r\n\r\n            #! check if ts, src, dst is -1\r\n            if ts == -1:\r\n                missing_ts += 1\r\n                continue\r\n            if src == -1:\r\n                missing_src += 1\r\n                continue\r\n            if dst == -1:\r\n                missing_dst += 1\r\n                continue\r\n\r\n            if edge_id != edge_id_feat:\r\n                print(\"edge_id != edge_id_feat\")\r\n                print(edge_id)\r\n                print(edge_id_feat)\r\n                break\r\n\r\n            # write.writerow([ts, src, dst, subreddit, num_words, score])\r\n            #write.writerow([ts, src, dst, num_words, score])\r\n            #? ts: int, src: int (user_id), dst: int (user_id), subreddit: str, reddit_id: str (comment_id), reddit_parent_id: str (post_id), num_words: int, score: int\r\n            write.writerow([ts, src, dst, subreddit, reddit_id, reddit_parent_id, num_words, score])\r\n            line_idx += 1\r\n    print(\"processed\", line_idx, \"lines\")\r\n    print (\"there are lines\", missing_ts, \" missing timestamps\")\r\n    print (\"there are lines\", missing_src, \" missing src\")\r\n    print (\"there are lines\", missing_dst, \" missing dst\")\r\n\r\n\r\n\r\ndef main():\r\n    # #! unzip all xz files by $ unxz *.xz\r\n\r\n    f_dir = \"edge_files/\" #\"raw/raw_2005_2010/\" #\"raw/raw_2013_2014/\"\r\n    fnames = find_filenames(f_dir)\r\n    edgefname = \"reddit_edgefile_2014_01.csv\" #\"redditcomments_edgelist_2013_2014.csv\"\r\n    idx = 0\r\n    for fname in tqdm(fnames):\r\n        print (\"processing, \", fname)\r\n        if (idx == 0):\r\n            read_edgelist(f_dir+fname, edgefname, write_header=True)\r\n        else:\r\n            read_edgelist(f_dir+fname, edgefname, write_header=False)\r\n        idx += 1\r\n\r\n    \r\n    # # #! 
extract the node attributes\r\n    f_dir = \"attribute_files/\" \r\n    fnames = find_filenames(f_dir)\r\n    featfname = \"reddit_attribute_2014_01.csv\"\r\n    idx = 0\r\n    for fname in tqdm(fnames):\r\n        print (\"processing, \", fname)\r\n        if (idx == 0):\r\n            read_nodeattr(f_dir+fname, featfname, write_header=True)\r\n        else:\r\n            read_nodeattr(f_dir+fname, featfname, write_header=False)\r\n        idx += 1\r\n\r\n    #! combine edgelist and edge feat file check if the edge_id matches\r\n    # edgefname = \"reddit_edgefile_2019_01_03.csv\"\r\n    # featfname = \"reddit_attribute_2019_01_03.csv\"\r\n    outname = \"reddit_edgelist.csv\"\r\n    combine_edgelist_edgefeat(edgefname, featfname, outname)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/thgl_forum/thgl-forum.py",
    "content": "\r\n\"\"\"\r\nData source: \r\n# https://surfdrive.surf.nl/files/index.php/s/M09RDerAMZrQy8q#editor\r\n\r\n# https://dl.acm.org/doi/abs/10.1145/3487553.3524699\r\n\r\nalso see here: https://arxiv.org/pdf/1803.03697\r\n\r\n#  Temporal Social Network Dataset of Reddit\r\n### Dataset to accompany “A large-scale temporal analysis of user lifespan durability on the Reddit social media platform” (WWW 2022).\r\n\r\n## Overview\r\n\r\nThis dataset consists of more than 6.7 billion Reddit comment interactions made from the beginning of Reddit in 2005 until the end of 2019. \r\n\r\n### Nodes\r\n\r\nNodes in the network represent users who posted at least one comment or one submission until the end of 2019 and have not deleted their accounts by the time of the data ingestion. \r\n\r\nEach user is assigned a unique identifier starting from 0, and -1 is the identifier of the node representing deleted users. The `nodes` file maintains the node-identifier assignment to each user’s username.\r\n\r\n### Edges\r\n\r\nFor each month of our data, we maintain two separate files, an edge file that consists of temporal edges data and an attribute file that consists of attributes of each interaction. All these files are in a tab-separated format. The compressed edge files and compressed attribute files are available in the `edges` and the `attributes` directory, respectively. 
The name of files indicates the timeframe they belong to.\r\n\r\nEach line in an edge file corresponds to a comment and includes:\r\n\r\n - comment’s author\r\n - author of the parent (the post that the comment is replied to)\r\n - comment’s creation time\r\n - comment’s edge id\r\n\r\nEach line in an attribute file corresponds to the line with the same line number in the corresponding edge file and includes:\r\n\r\n - comment’s edge id\r\n - Reddit’s identifier of the comment\r\n - Reddit’s identifier of the parent (the post that the comment is replied to)\r\n - Reddit’s identifier of the submission that the comment is in\r\n - name of the subreddit that the comment is in\r\n - number of characters in the comment’s body\r\n - number of words in the comment’s body\r\n - score of the comment\r\n - a flag indicating if the comment has been edited\r\n\r\n### Stats\r\n\r\nSize (compressed): 125GB\r\nSize (uncompressed): 652GB\r\nNumber of nodes: 62,402,844\r\nNumber of edges: 6,728,759,080\r\n\r\n### Notes\r\n\r\nReddit banned the subreddit `/r/Incels` in November of 2017, and its data is no longer available via the Reddit API. This has resulted in the loss of score data for 119,111 comments made in October and November of 2017 in this subreddit. The affected entries have a null value as their score. \r\n\r\n## Citation\r\n\r\nIf you want to reuse this dataset, you can reference it as follows:\r\n\r\nA. Nadiri and F.W. 
Takes, A large-scale temporal analysis of user lifespan durability on the Reddit social media platform, in Proceedings of the 28th ACM International Web Conference (TheWebConf) Workshops, 2022.\r\n\r\n## Online repository\r\n\r\nThe dataset is available for download at [**LINK**](https://surfdrive.surf.nl/files/index.php/s/M09RDerAMZrQy8q)\r\n\r\n## Acknowledgments\r\n\r\nThe dataset is constructed using data provided by [The Pushshift Reddit Dataset](https://ojs.aaai.org/index.php/ICWSM/article/view/7347)\r\n\r\n\r\n\"\"\"\r\n\r\n\r\n\r\n\"\"\"\r\nideas for temporal heterogenous graph in reddit data:\r\n\r\nnode types:\r\n1. user\r\n2. subreddit\r\n\r\nedge types\r\n1. user post in subreddit (top level)\r\n2. user replies to another user \r\n3. user replies in subreddit\r\n\r\n\r\n\r\n# node types:\r\n# 1. user \r\n# 2. subreddit\r\n# 3. comment\r\n\r\n\r\n# edge types\r\n# 1. user makes comment in subreddit (top level comment)\r\n# 2. user replies to comment in subreddit (comments that has a parent)\r\n# 2. comment is child of comment (comments that has a parent)\r\n# 3. 
comment belongs to subreddit\r\n\"\"\"\r\nimport csv\r\nfrom tgb.utils.utils import save_pkl, load_pkl\r\n\r\n\r\ndef load_csv_raw(fname):\r\n    \"\"\"\r\n    load the raw csv file and merge them into one\r\n    ts, src, dst, subreddit, reddit_id, reddit_parent_id, num_words, score\r\n    \"\"\"\r\n    out_dict = {}\r\n    num_lines = 0\r\n    max_words = 0\r\n    min_words = 10000000\r\n\r\n    max_score = 0\r\n    min_score = 1000000\r\n\r\n    \"\"\"\r\n    relation types:\r\n    0: user replies to user\r\n    1: user replies to subreddit\r\n\r\n    node types:\r\n    0: user\r\n    1: subreddit\r\n    \"\"\"\r\n\r\n    node_dict = {}\r\n    node_type_dict = {}\r\n    reddit_deg_dict = {}\r\n    node_deg_dict = {}\r\n    header = True\r\n    with open(fname, 'r') as f:\r\n        reader = csv.reader(f, delimiter =',')\r\n        #* ts, src, dst, subreddit, reddit_id, reddit_parent_id, num_words, score\r\n        #? 1388534400,32183137,51851117,AskReddit,t1_ceefsvy,t3_1u4kbf,32,1\r\n        for row in reader: \r\n            if header:\r\n                header = False\r\n                continue\r\n            ts = int(row[0])\r\n            src = row[1]\r\n            if (src not in node_dict):\r\n                node_dict[src] = len(node_dict)\r\n                node_type_dict[node_dict[src]] = 0\r\n            dst = row[2]\r\n            if (dst not in node_dict):\r\n                node_dict[dst] = len(node_dict)\r\n                node_type_dict[node_dict[dst]] = 0\r\n\r\n            if (src not in node_deg_dict):\r\n                node_deg_dict[src] = 1\r\n            else:\r\n                node_deg_dict[src] += 1\r\n            if (dst not in node_deg_dict):\r\n                node_deg_dict[dst] = 1\r\n            else:\r\n                node_deg_dict[dst] += 1\r\n\r\n            subreddit = row[3]\r\n            if (subreddit not in node_dict):\r\n                node_dict[subreddit] = len(node_dict)\r\n                
node_type_dict[node_dict[subreddit]] = 1\r\n            \r\n            if (subreddit not in reddit_deg_dict):\r\n                reddit_deg_dict[subreddit] = 1\r\n            else:\r\n                reddit_deg_dict[subreddit] += 1\r\n            \r\n            num_words = int(row[6])\r\n            if (num_words > max_words):\r\n                max_words = num_words\r\n            if (num_words < min_words):\r\n                min_words = num_words\r\n            score = int(row[7])\r\n            if (score > max_score):\r\n                max_score = score\r\n            if (score < min_score):\r\n                min_score = score\r\n\r\n            if (ts in out_dict):\r\n                out_dict[ts][(node_dict[src], node_dict[dst], 0)] = (num_words, score)\r\n                out_dict[ts][(node_dict[src], node_dict[subreddit], 1)] = (num_words, score)\r\n            else:\r\n                out_dict[ts] = {}\r\n                out_dict[ts][(node_dict[src], node_dict[dst], 0)] = (num_words, score)\r\n                out_dict[ts][(node_dict[src], node_dict[subreddit], 1)] = (num_words, score)\r\n            num_lines += 1\r\n\r\n    print (\"max words: \", max_words)\r\n    print (\"min words: \", min_words)\r\n    print (\"max score: \", max_score)\r\n    print (\"min score: \", min_score)\r\n    return out_dict, num_lines, node_dict, node_type_dict, reddit_deg_dict, node_deg_dict\r\n\r\n\r\n\r\ndef load_csv_filtered_node(fname, low_deg_dict):\r\n    \"\"\"\r\n    load the raw csv file, remove edges with low degree nodes\r\n    ts, src, dst, subreddit, reddit_id, reddit_parent_id, num_words, score\r\n    \"\"\"\r\n    out_dict = {}\r\n    num_lines = 0\r\n    \"\"\"\r\n    relation types:\r\n    0: user replies to user\r\n    1: user replies to subreddit\r\n\r\n    node types:\r\n    0: user\r\n    1: subreddit\r\n    \"\"\"\r\n\r\n    node_dict = {}\r\n    node_type_dict = {}\r\n    header = True\r\n    with open(fname, 'r') as f:\r\n        reader = 
csv.reader(f, delimiter =',')\r\n        #* ts, src, dst, subreddit, reddit_id, reddit_parent_id, num_words, score\r\n        #? 1388534400,32183137,51851117,AskReddit,t1_ceefsvy,t3_1u4kbf,32,1\r\n        for row in reader: \r\n            if header:\r\n                header = False\r\n                continue\r\n            ts = int(row[0])\r\n            src = row[1]\r\n            dst = row[2]\r\n\r\n            #* filter low degree nodes\r\n            if (src in low_deg_dict or dst in low_deg_dict):\r\n                continue\r\n\r\n            if (src not in node_dict):\r\n                node_dict[src] = len(node_dict)\r\n                node_type_dict[node_dict[src]] = 0\r\n            if (dst not in node_dict):\r\n                node_dict[dst] = len(node_dict)\r\n                node_type_dict[node_dict[dst]] = 0\r\n\r\n            subreddit = row[3]\r\n            if (subreddit not in node_dict):\r\n                node_dict[subreddit] = len(node_dict)\r\n                node_type_dict[node_dict[subreddit]] = 1\r\n       \r\n            num_words = int(row[6])\r\n            score = int(row[7])\r\n\r\n            if (ts in out_dict):\r\n                out_dict[ts][(node_dict[src], node_dict[dst], 0)] = (num_words, score)\r\n                out_dict[ts][(node_dict[src], node_dict[subreddit], 1)] = (num_words, score)\r\n            else:\r\n                out_dict[ts] = {}\r\n                out_dict[ts][(node_dict[src], node_dict[dst], 0)] = (num_words, score)\r\n                out_dict[ts][(node_dict[src], node_dict[subreddit], 1)] = (num_words, score)\r\n            num_lines += 1\r\n    return out_dict, num_lines, node_dict, node_type_dict\r\n\r\n\r\n\r\n\r\n\r\ndef writeNodeType(node_type_dict, outname):\r\n    r\"\"\"\r\n    write the node type mapping to a file\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['node_id', 'type'])\r\n        for key in 
node_type_dict:\r\n            writer.writerow([key, node_type_dict[key]])\r\n\r\n\r\ndef write2csv(outname, out_dict):\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['ts', 'src', 'dst', 'relation_type', 'num_words', 'score'])\r\n        ts_list = list(out_dict.keys())\r\n        ts_list.sort()\r\n\r\n        for ts in ts_list:\r\n            for edge in out_dict[ts]:\r\n                head = edge[0]\r\n                tail = edge[1]\r\n                relation_type = edge[2]\r\n                num_words, score = out_dict[ts][edge]\r\n                row = [ts, head, tail, relation_type, num_words, score]\r\n                writer.writerow(row)\r\n\r\n\r\ndef writeNodeIDMapping(node_dict, outname):\r\n    r\"\"\"\r\n    write the node id mapping to a file\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['node_name', 'node_id'])\r\n        for key in node_dict:\r\n            writer.writerow([key, node_dict[key]])\r\n\r\ndef node_deg_filter(node_deg_dict):\r\n    \"\"\"\r\n    filter out nodes with degree less than threshold\r\n    \"\"\"\r\n    deg10_nodes = 0\r\n    deg100_nodes = 0\r\n    deg1000_nodes = 0\r\n\r\n    for key in node_deg_dict:\r\n        if (node_deg_dict[key] < 10):\r\n            deg10_nodes += 1\r\n        if (node_deg_dict[key] < 100):\r\n            deg100_nodes += 1\r\n        if (node_deg_dict[key] < 1000):\r\n            deg1000_nodes += 1\r\n    print (\"nodes with degree less than 10: \", deg10_nodes)\r\n    print (\"nodes with degree less than 100: \", deg100_nodes)\r\n    print (\"nodes with degree less than 1000: \", deg1000_nodes)\r\n\r\ndef find_low_degree_nodes(node_deg_dict, threshold=10):\r\n    \"\"\"\r\n    find nodes with degree less than threshold\r\n    \"\"\"\r\n    low_degree_nodes = {}\r\n    for key in node_deg_dict:\r\n        if (node_deg_dict[key] < threshold):\r\n        
    low_degree_nodes[key] = 1\r\n    return low_degree_nodes\r\n\r\n\r\ndef main():\r\n    fname = \"reddit_edgelist.csv\"\r\n    _, _, _, _, _, node_deg_dict = load_csv_raw(fname)\r\n    # print (\"checking node degree\")\r\n    # node_deg_filter(node_deg_dict)\r\n    # print (\"checking reddit degree\")\r\n    # node_deg_filter(reddit_deg_dict)\r\n    # low_degree_nodes = find_low_degree_nodes(node_deg_dict, threshold=100)\r\n    # save_pkl(low_degree_nodes, 'low_degree_nodes.pkl')\r\n\r\n\r\n    low_degree_nodes = load_pkl('low_degree_nodes.pkl')\r\n    out_dict, num_lines, node_dict, node_type_dict = load_csv_filtered_node(fname, low_degree_nodes)\r\n    writeNodeType(node_type_dict, 'thgl-forum_nodetype.csv')\r\n    writeNodeIDMapping(node_dict, 'thgl-forum_nodeIDmapping.csv')\r\n    write2csv('thgl-forum_edgelist.csv', out_dict)\r\n\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()"
  },
  {
    "path": "tgb/datasets/thgl_forum/thgl_forum_ns_gen.py",
    "content": "import time\r\nfrom tgb.linkproppred.thg_negative_generator import THGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = 100 #-1 \r\n    neg_sample_strategy = \"node-type-filtered\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"thgl-forum\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    min_node_idx = min(int(data.src.min()), int(data.dst.min()))\r\n    max_node_idx = max(int(data.src.max()), int(data.dst.max()))\r\n\r\n    neg_sampler = THGNegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_node_id=min_node_idx,\r\n        last_node_id=max_node_idx,\r\n        node_type=dataset.node_type,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        edge_data=data,\r\n    )\r\n\r\n   \r\n    # generate evaluation set\r\n    partial_path = \".\"\r\n    # generate validation negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. 
Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/thgl_github/2024_01/github_extract.py",
    "content": "import json\r\nfrom datetime import datetime\r\nimport glob\r\nimport gzip\r\nimport csv\r\n\"\"\"\r\ngo to https://www.gharchive.org/\r\n\r\nwget https://data.gharchive.org/2024-01-{01..31}-{0..23}.json.gz\r\n\r\nCreates (src, edge_type, dst, time) edges from the GitHub archive JSON file.\r\nUsing the rules from https://arxiv.org/pdf/2007.01231 (page 11)\r\nThe parser creates 18 rules that are in the GITHUB-SE-1Y-Repo dataset. I wrote the meaning of the rules and sources and destination types here.\r\n\"\"\"\r\n\r\n\r\n\r\nrels = {\r\n    \"IC_Created_IC_I\": \"IC_AO_C_I\",\r\n    \"IC_Created_U_IC\": \"U_SO_C_IC\",\r\n    \"I_Opened_U_I\": \"U_SE_O_I\",\r\n    \"I_Opened_I_R\": \"I_AO_O_R\",\r\n    \"I_Closed_U_I\": \"U_SE_C_I\",\r\n    \"I_Closed_I_R\": \"I_AO_C_R\",\r\n    \"I_Reopened_U_I\": \"U_SE_RO_I\",\r\n    \"I_Reopened_I_R\": \"I_AO_RO_R\",\r\n    \"PR_Opened_U_PR\": \"U_SO_O_P\",\r\n    \"PR_Opened_PR_R\": \"P_AO_O_R\",\r\n    \"PR_Closed_U_PR\": \"U_SO_C_P\",\r\n    \"PR_Closed_PR_R\": \"P_AO_C_R\",\r\n    \"PR_Reopened_U_PR\": \"U_SO_R_P\",\r\n    \"PR_Reopened_PR_R\": \"P_AO_R_R\",\r\n    \"PRRC_Created_U_PRC\": \"U_SO_C_PRC\",\r\n    \"PRRC_Created_PRC_PR\": \"PRC_AO_C_P\",\r\n    \"Forked_R_R\": \"R_FO_R\",\r\n    \"AddMember_U_R\": \"U_CO_A_R\",\r\n}\r\n\r\nissue_comment_format = \"/issue_comment/{}\"\r\nissue_format = \"/issue/{}\"\r\nuser_format = \"/user/{}\"\r\nrepo_format = \"/repo/{}\"\r\npull_request_format = \"/pr/{}\"\r\npull_request_review_comment_format = \"/pr_review_comment/{}\"\r\n\r\n\r\ndef str_to_timestamp(time_str):\r\n    dt = datetime.strptime(time_str, \"%Y-%m-%dT%H:%M:%SZ\")\r\n    return int(dt.timestamp())\r\n\r\n\r\ndef parse_issue_comment_events(event):\r\n    try:\r\n        if \"action\" not in event[\"payload\"]:\r\n            return []\r\n        if event[\"payload\"][\"action\"] == \"created\":\r\n            issue_comment_id = event[\"payload\"][\"comment\"][\"id\"]\r\n            issue_id = 
event[\"payload\"][\"issue\"][\"id\"]\r\n            user_id = event[\"actor\"][\"id\"]\r\n            created_at = str_to_timestamp(event[\"created_at\"])\r\n\r\n            ici_event = [\r\n                issue_comment_format.format(issue_comment_id),\r\n                rels[\"IC_Created_IC_I\"],\r\n                issue_format.format(issue_id),\r\n                created_at,\r\n            ]\r\n            uic_event = [\r\n                user_format.format(user_id),\r\n                rels[\"IC_Created_U_IC\"],\r\n                issue_comment_format.format(issue_comment_id),\r\n                created_at,\r\n            ]\r\n            return [ici_event, uic_event]\r\n        return []\r\n    except:\r\n        return []\r\n\r\n\r\ndef parse_issue_event(event):\r\n    try:\r\n        issue_id = event[\"payload\"][\"issue\"][\"id\"]\r\n        user_id = event[\"actor\"][\"id\"]\r\n        repo_id = event[\"repo\"][\"id\"]\r\n        created_at = str_to_timestamp(event[\"created_at\"])\r\n        action_map = {\r\n            \"opened\": (\"I_Opened_U_I\", \"I_Opened_I_R\"),\r\n            \"closed\": (\"I_Closed_U_I\", \"I_Closed_I_R\"),\r\n            \"reopened\": (\"I_Reopened_U_I\", \"I_Reopened_I_R\"),\r\n        }\r\n        for action, event_rels in action_map.items():\r\n            if event[\"payload\"][\"action\"] == action:\r\n                ui_event = [\r\n                    user_format.format(user_id),\r\n                    rels[event_rels[0]],\r\n                    issue_format.format(issue_id),\r\n                    created_at,\r\n                ]\r\n\r\n                ir_event = [\r\n                    issue_format.format(issue_id),\r\n                    rels[event_rels[1]],\r\n                    repo_format.format(repo_id),\r\n                    created_at,\r\n                ]\r\n                return [ui_event, ir_event]\r\n        return []\r\n    except:\r\n        return []\r\n\r\n\r\ndef parse_pull_request_event(event):\r\n  
  try:\r\n        pull_request_id = event[\"payload\"][\"pull_request\"][\"id\"]\r\n        user_id = event[\"actor\"][\"id\"]\r\n        repo_id = event[\"repo\"][\"id\"]\r\n        created_at = str_to_timestamp(event[\"created_at\"])\r\n        action_map = {\r\n            \"opened\": (\"PR_Opened_U_PR\", \"PR_Opened_PR_R\"),\r\n            \"closed\": (\"PR_Closed_U_PR\", \"PR_Closed_PR_R\"),\r\n            \"reopened\": (\"PR_Reopened_U_PR\", \"PR_Reopened_PR_R\"),\r\n        }\r\n        for action, event_rels in action_map.items():\r\n            if event[\"payload\"][\"action\"] == action:\r\n                upr_event = [\r\n                    user_format.format(user_id),\r\n                    rels[event_rels[0]],\r\n                    pull_request_format.format(pull_request_id),\r\n                    created_at,\r\n                ]\r\n\r\n                prr_event = [\r\n                    pull_request_format.format(pull_request_id),\r\n                    rels[event_rels[1]],\r\n                    repo_format.format(repo_id),\r\n                    created_at,\r\n                ]\r\n                return [upr_event, prr_event]\r\n        return []\r\n    except:\r\n        return []\r\n\r\ndef parse_pull_request_review_comment_event(event):\r\n    try:\r\n        pull_request_review_comment_id = event[\"payload\"][\"comment\"][\"id\"]\r\n        pull_request_id = event[\"payload\"][\"pull_request\"][\"id\"]\r\n        user_id = event[\"actor\"][\"id\"]\r\n        created_at = str_to_timestamp(event[\"created_at\"])\r\n        if event[\"payload\"][\"action\"] == \"created\":\r\n            uprc_event = [\r\n                user_format.format(user_id),\r\n                rels[\"PRRC_Created_U_PRC\"],\r\n                pull_request_review_comment_format.format(pull_request_review_comment_id),\r\n                created_at,\r\n            ]\r\n\r\n            prcpr_event = [\r\n                
pull_request_review_comment_format.format(pull_request_review_comment_id),\r\n                rels[\"PRRC_Created_PRC_PR\"],\r\n                pull_request_format.format(pull_request_id),\r\n                created_at,\r\n            ]\r\n            return [uprc_event, prcpr_event]\r\n        return []\r\n    except:\r\n        return []\r\n\r\n\r\ndef parse_fork_event(event):\r\n    try:\r\n        forkee_repo_id = event[\"payload\"][\"forkee\"][\"id\"]\r\n        forked_repo_id = event[\"repo\"][\"id\"]\r\n        created_at = str_to_timestamp(event[\"created_at\"])\r\n        return [\r\n            [\r\n                repo_format.format(forkee_repo_id),\r\n                rels[\"Forked_R_R\"],\r\n                repo_format.format(forked_repo_id),\r\n                created_at,\r\n            ]\r\n        ]\r\n    except:\r\n        return []\r\n\r\n\r\ndef parse_member_event(event):\r\n    try:\r\n        user_id = event[\"payload\"][\"member\"][\"id\"]\r\n        repo_id = event[\"repo\"][\"id\"]\r\n        created_at = str_to_timestamp(event[\"created_at\"])\r\n        return [\r\n            [\r\n                user_format.format(user_id),\r\n                rels[\"AddMember_U_R\"],\r\n                repo_format.format(repo_id),\r\n                created_at,\r\n            ]\r\n        ]\r\n    except:\r\n        return []\r\n\r\n\r\nevent_handler_dict = {\r\n    \"IssueCommentEvent\": parse_issue_comment_events,\r\n    \"IssuesEvent\": parse_issue_event,\r\n    \"PullRequestEvent\": parse_pull_request_event,\r\n    \"PullRequestReviewCommentEvent\": parse_pull_request_review_comment_event,\r\n    \"ForkEvent\": parse_fork_event,\r\n    \"MemberEvent\": parse_member_event,\r\n}\r\n\r\n\r\ndef parse_event(event):\r\n    event_type = event[\"type\"]\r\n    if event_type in event_handler_dict:\r\n        output_list = event_handler_dict[event_type](event)\r\n        # print(\"Got {} outputs for event type {}\".format(len(output_list), event_type))\r\n    
else:\r\n        # print(\"Unknown event type: {}\".format(event_type))\r\n        output_list = []\r\n    return output_list\r\n\r\n\r\ndef parse_file(filename):\r\n    # events = []\r\n    output_dict = {}\r\n    num_edge = 1\r\n    #with open(filename) as f:\r\n    with gzip.open(filename, 'r') as f:\r\n        for i, line in enumerate(f):\r\n            djson = json.loads(line)\r\n            parsed_events = parse_event(djson)\r\n            if (len(parsed_events) > 0):\r\n                for edge in parsed_events:\r\n                    #? ['/user/41898282', 'U_SE_O_I', '/issue/2061196208', 1704085558]\r\n                    ts = int(edge[3])\r\n                    head = edge[0]\r\n                    rel = edge[1]\r\n                    tail = edge[2]\r\n                    if ts not in output_dict:\r\n                        output_dict[ts] = {}\r\n                        output_dict[ts][(head,tail,rel)] = 1\r\n                        num_edge += 1\r\n                    else:\r\n                        if (head,tail,rel) in output_dict[ts]:\r\n                            output_dict[ts][(head,tail,rel)] += 1\r\n                        else:\r\n                            output_dict[ts][(head,tail,rel)] = 1\r\n                            num_edge += 1\r\n    print(\"Parsed {} events\".format(num_edge))\r\n    return output_dict\r\n\r\ndef write2csv(outname, out_dict):\r\n    with open(outname, 'a') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        # writer.writerow(['ts', 'head', 'tail', 'relation_type'])\r\n        ts_list = list(out_dict.keys())\r\n        ts_list.sort()\r\n\r\n        for ts in ts_list:\r\n            for edge in out_dict[ts]:\r\n                head = edge[0]\r\n                tail = edge[1]\r\n                relation_type = edge[2]\r\n                row = [ts, head, tail, relation_type]\r\n                writer.writerow(row)\r\n\r\n\r\n\r\ndef main():\r\n    total_edge_dict = {}\r\n    for file in 
glob.glob(\"*.json.gz\"):\r\n        print (\"processing,\", file)\r\n        edge_dict = parse_file(file)\r\n        # print ('check for edge overlap')\r\n        # print(edge_dict.keys() & total_edge_dict.keys())\r\n        # print (\"-------------------------\")\r\n        #! write to csv after each file is processed. \r\n        # total_edge_dict.update(edge_dict)\r\n        outname = \"github_01_2024.csv\"\r\n        write2csv(outname, edge_dict)\r\n        \r\n\r\n\r\n\r\n\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()"
  },
  {
    "path": "tgb/datasets/thgl_github/2024_02/github_extract.py",
    "content": "import json\r\nfrom datetime import datetime\r\nimport glob\r\nimport gzip\r\nimport csv\r\n\"\"\"\r\ngo to https://www.gharchive.org/\r\n\r\nwget https://data.gharchive.org/2024-02-{01..29}-{0..23}.json.gz\r\n\r\nCreates (src, edge_type, dst, time) edges from the GitHub archive JSON file.\r\nUsing the rules from https://arxiv.org/pdf/2007.01231 (page 11)\r\nThe parser creates 18 rules that are in the GITHUB-SE-1Y-Repo dataset. I wrote the meaning of the rules and sources and destination types here.\r\n\"\"\"\r\n\r\n\r\n\r\nrels = {\r\n    \"IC_Created_IC_I\": \"IC_AO_C_I\",\r\n    \"IC_Created_U_IC\": \"U_SO_C_IC\",\r\n    \"I_Opened_U_I\": \"U_SE_O_I\",\r\n    \"I_Opened_I_R\": \"I_AO_O_R\",\r\n    \"I_Closed_U_I\": \"U_SE_C_I\",\r\n    \"I_Closed_I_R\": \"I_AO_C_R\",\r\n    \"I_Reopened_U_I\": \"U_SE_RO_I\",\r\n    \"I_Reopened_I_R\": \"I_AO_RO_R\",\r\n    \"PR_Opened_U_PR\": \"U_SO_O_P\",\r\n    \"PR_Opened_PR_R\": \"P_AO_O_R\",\r\n    \"PR_Closed_U_PR\": \"U_SO_C_P\",\r\n    \"PR_Closed_PR_R\": \"P_AO_C_R\",\r\n    \"PR_Reopened_U_PR\": \"U_SO_R_P\",\r\n    \"PR_Reopened_PR_R\": \"P_AO_R_R\",\r\n    \"PRRC_Created_U_PRC\": \"U_SO_C_PRC\",\r\n    \"PRRC_Created_PRC_PR\": \"PRC_AO_C_P\",\r\n    \"Forked_R_R\": \"R_FO_R\",\r\n    \"AddMember_U_R\": \"U_CO_A_R\",\r\n}\r\n\r\nissue_comment_format = \"/issue_comment/{}\"\r\nissue_format = \"/issue/{}\"\r\nuser_format = \"/user/{}\"\r\nrepo_format = \"/repo/{}\"\r\npull_request_format = \"/pr/{}\"\r\npull_request_review_comment_format = \"/pr_review_comment/{}\"\r\n\r\n\r\ndef str_to_timestamp(time_str):\r\n    dt = datetime.strptime(time_str, \"%Y-%m-%dT%H:%M:%SZ\")\r\n    return int(dt.timestamp())\r\n\r\n\r\ndef parse_issue_comment_events(event):\r\n    try:\r\n        if \"action\" not in event[\"payload\"]:\r\n            return []\r\n        if event[\"payload\"][\"action\"] == \"created\":\r\n            issue_comment_id = event[\"payload\"][\"comment\"][\"id\"]\r\n            issue_id = 
event[\"payload\"][\"issue\"][\"id\"]\r\n            user_id = event[\"actor\"][\"id\"]\r\n            created_at = str_to_timestamp(event[\"created_at\"])\r\n\r\n            ici_event = [\r\n                issue_comment_format.format(issue_comment_id),\r\n                rels[\"IC_Created_IC_I\"],\r\n                issue_format.format(issue_id),\r\n                created_at,\r\n            ]\r\n            uic_event = [\r\n                user_format.format(user_id),\r\n                rels[\"IC_Created_U_IC\"],\r\n                issue_comment_format.format(issue_comment_id),\r\n                created_at,\r\n            ]\r\n            return [ici_event, uic_event]\r\n        return []\r\n    except:\r\n        return []\r\n\r\n\r\ndef parse_issue_event(event):\r\n    try:\r\n        issue_id = event[\"payload\"][\"issue\"][\"id\"]\r\n        user_id = event[\"actor\"][\"id\"]\r\n        repo_id = event[\"repo\"][\"id\"]\r\n        created_at = str_to_timestamp(event[\"created_at\"])\r\n        action_map = {\r\n            \"opened\": (\"I_Opened_U_I\", \"I_Opened_I_R\"),\r\n            \"closed\": (\"I_Closed_U_I\", \"I_Closed_I_R\"),\r\n            \"reopened\": (\"I_Reopened_U_I\", \"I_Reopened_I_R\"),\r\n        }\r\n        for action, event_rels in action_map.items():\r\n            if event[\"payload\"][\"action\"] == action:\r\n                ui_event = [\r\n                    user_format.format(user_id),\r\n                    rels[event_rels[0]],\r\n                    issue_format.format(issue_id),\r\n                    created_at,\r\n                ]\r\n\r\n                ir_event = [\r\n                    issue_format.format(issue_id),\r\n                    rels[event_rels[1]],\r\n                    repo_format.format(repo_id),\r\n                    created_at,\r\n                ]\r\n                return [ui_event, ir_event]\r\n        return []\r\n    except:\r\n        return []\r\n\r\n\r\ndef parse_pull_request_event(event):\r\n  
  try:\r\n        pull_request_id = event[\"payload\"][\"pull_request\"][\"id\"]\r\n        user_id = event[\"actor\"][\"id\"]\r\n        repo_id = event[\"repo\"][\"id\"]\r\n        created_at = str_to_timestamp(event[\"created_at\"])\r\n        action_map = {\r\n            \"opened\": (\"PR_Opened_U_PR\", \"PR_Opened_PR_R\"),\r\n            \"closed\": (\"PR_Closed_U_PR\", \"PR_Closed_PR_R\"),\r\n            \"reopened\": (\"PR_Reopened_U_PR\", \"PR_Reopened_PR_R\"),\r\n        }\r\n        for action, event_rels in action_map.items():\r\n            if event[\"payload\"][\"action\"] == action:\r\n                upr_event = [\r\n                    user_format.format(user_id),\r\n                    rels[event_rels[0]],\r\n                    pull_request_format.format(pull_request_id),\r\n                    created_at,\r\n                ]\r\n\r\n                prr_event = [\r\n                    pull_request_format.format(pull_request_id),\r\n                    rels[event_rels[1]],\r\n                    repo_format.format(repo_id),\r\n                    created_at,\r\n                ]\r\n                return [upr_event, prr_event]\r\n        return []\r\n    except:\r\n        return []\r\n\r\ndef parse_pull_request_review_comment_event(event):\r\n    try:\r\n        pull_request_review_comment_id = event[\"payload\"][\"comment\"][\"id\"]\r\n        pull_request_id = event[\"payload\"][\"pull_request\"][\"id\"]\r\n        user_id = event[\"actor\"][\"id\"]\r\n        created_at = str_to_timestamp(event[\"created_at\"])\r\n        if event[\"payload\"][\"action\"] == \"created\":\r\n            uprc_event = [\r\n                user_format.format(user_id),\r\n                rels[\"PRRC_Created_U_PRC\"],\r\n                pull_request_review_comment_format.format(pull_request_review_comment_id),\r\n                created_at,\r\n            ]\r\n\r\n            prcpr_event = [\r\n                
pull_request_review_comment_format.format(pull_request_review_comment_id),\r\n                rels[\"PRRC_Created_PRC_PR\"],\r\n                pull_request_format.format(pull_request_id),\r\n                created_at,\r\n            ]\r\n            return [uprc_event, prcpr_event]\r\n        return []\r\n    except:\r\n        return []\r\n\r\n\r\ndef parse_fork_event(event):\r\n    try:\r\n        forkee_repo_id = event[\"payload\"][\"forkee\"][\"id\"]\r\n        forked_repo_id = event[\"repo\"][\"id\"]\r\n        created_at = str_to_timestamp(event[\"created_at\"])\r\n        return [\r\n            [\r\n                repo_format.format(forkee_repo_id),\r\n                rels[\"Forked_R_R\"],\r\n                repo_format.format(forked_repo_id),\r\n                created_at,\r\n            ]\r\n        ]\r\n    except:\r\n        return []\r\n\r\n\r\ndef parse_member_event(event):\r\n    try:\r\n        user_id = event[\"payload\"][\"member\"][\"id\"]\r\n        repo_id = event[\"repo\"][\"id\"]\r\n        created_at = str_to_timestamp(event[\"created_at\"])\r\n        return [\r\n            [\r\n                user_format.format(user_id),\r\n                rels[\"AddMember_U_R\"],\r\n                repo_format.format(repo_id),\r\n                created_at,\r\n            ]\r\n        ]\r\n    except:\r\n        return []\r\n\r\n\r\nevent_handler_dict = {\r\n    \"IssueCommentEvent\": parse_issue_comment_events,\r\n    \"IssuesEvent\": parse_issue_event,\r\n    \"PullRequestEvent\": parse_pull_request_event,\r\n    \"PullRequestReviewCommentEvent\": parse_pull_request_review_comment_event,\r\n    \"ForkEvent\": parse_fork_event,\r\n    \"MemberEvent\": parse_member_event,\r\n}\r\n\r\n\r\ndef parse_event(event):\r\n    event_type = event[\"type\"]\r\n    if event_type in event_handler_dict:\r\n        output_list = event_handler_dict[event_type](event)\r\n        # print(\"Got {} outputs for event type {}\".format(len(output_list), event_type))\r\n    
else:\r\n        # print(\"Unknown event type: {}\".format(event_type))\r\n        output_list = []\r\n    return output_list\r\n\r\n\r\ndef parse_file(filename):\r\n    # events = []\r\n    output_dict = {}\r\n    num_edge = 1\r\n    #with open(filename) as f:\r\n    with gzip.open(filename, 'r') as f:\r\n        for i, line in enumerate(f):\r\n            djson = json.loads(line)\r\n            parsed_events = parse_event(djson)\r\n            if (len(parsed_events) > 0):\r\n                for edge in parsed_events:\r\n                    #? ['/user/41898282', 'U_SE_O_I', '/issue/2061196208', 1704085558]\r\n                    ts = int(edge[3])\r\n                    head = edge[0]\r\n                    rel = edge[1]\r\n                    tail = edge[2]\r\n                    if ts not in output_dict:\r\n                        output_dict[ts] = {}\r\n                        output_dict[ts][(head,tail,rel)] = 1\r\n                        num_edge += 1\r\n                    else:\r\n                        if (head,tail,rel) in output_dict[ts]:\r\n                            output_dict[ts][(head,tail,rel)] += 1\r\n                        else:\r\n                            output_dict[ts][(head,tail,rel)] = 1\r\n                            num_edge += 1\r\n    print(\"Parsed {} events\".format(num_edge))\r\n    return output_dict\r\n\r\ndef write2csv(outname, out_dict):\r\n    with open(outname, 'a') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        # writer.writerow(['ts', 'head', 'tail', 'relation_type'])\r\n        ts_list = list(out_dict.keys())\r\n        ts_list.sort()\r\n\r\n        for ts in ts_list:\r\n            for edge in out_dict[ts]:\r\n                head = edge[0]\r\n                tail = edge[1]\r\n                relation_type = edge[2]\r\n                row = [ts, head, tail, relation_type]\r\n                writer.writerow(row)\r\n\r\n\r\n\r\ndef main():\r\n    total_edge_dict = {}\r\n    for file in 
glob.glob(\"*.json.gz\"):\r\n        print (\"processing,\", file)\r\n        edge_dict = parse_file(file)\r\n        # print ('check for edge overlap')\r\n        # print(edge_dict.keys() & total_edge_dict.keys())\r\n        # print (\"-------------------------\")\r\n        #! write to csv after each file is processed. \r\n        # total_edge_dict.update(edge_dict)\r\n        outname = \"github_02_2024.csv\"\r\n        write2csv(outname, edge_dict)\r\n        \r\n\r\n\r\n\r\n\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()"
  },
  {
    "path": "tgb/datasets/thgl_github/2024_03/github_extract.py",
    "content": "import json\r\nfrom datetime import datetime\r\nimport glob\r\nimport gzip\r\nimport csv\r\n\"\"\"\r\ngo to https://www.gharchive.org/\r\n\r\nwget https://data.gharchive.org/2024-03-{01..31}-{0..23}.json.gz\r\n\r\nCreates (src, edge_type, dst, time) edges from the GitHub archive JSON file.\r\nUsing the rules from https://arxiv.org/pdf/2007.01231 (page 11)\r\nThe parser creates 18 rules that are in the GITHUB-SE-1Y-Repo dataset. I wrote the meaning of the rules and sources and destination types here.\r\n\"\"\"\r\n\r\n\r\n\r\nrels = {\r\n    \"IC_Created_IC_I\": \"IC_AO_C_I\",\r\n    \"IC_Created_U_IC\": \"U_SO_C_IC\",\r\n    \"I_Opened_U_I\": \"U_SE_O_I\",\r\n    \"I_Opened_I_R\": \"I_AO_O_R\",\r\n    \"I_Closed_U_I\": \"U_SE_C_I\",\r\n    \"I_Closed_I_R\": \"I_AO_C_R\",\r\n    \"I_Reopened_U_I\": \"U_SE_RO_I\",\r\n    \"I_Reopened_I_R\": \"I_AO_RO_R\",\r\n    \"PR_Opened_U_PR\": \"U_SO_O_P\",\r\n    \"PR_Opened_PR_R\": \"P_AO_O_R\",\r\n    \"PR_Closed_U_PR\": \"U_SO_C_P\",\r\n    \"PR_Closed_PR_R\": \"P_AO_C_R\",\r\n    \"PR_Reopened_U_PR\": \"U_SO_R_P\",\r\n    \"PR_Reopened_PR_R\": \"P_AO_R_R\",\r\n    \"PRRC_Created_U_PRC\": \"U_SO_C_PRC\",\r\n    \"PRRC_Created_PRC_PR\": \"PRC_AO_C_P\",\r\n    \"Forked_R_R\": \"R_FO_R\",\r\n    \"AddMember_U_R\": \"U_CO_A_R\",\r\n}\r\n\r\nissue_comment_format = \"/issue_comment/{}\"\r\nissue_format = \"/issue/{}\"\r\nuser_format = \"/user/{}\"\r\nrepo_format = \"/repo/{}\"\r\npull_request_format = \"/pr/{}\"\r\npull_request_review_comment_format = \"/pr_review_comment/{}\"\r\n\r\n\r\ndef str_to_timestamp(time_str):\r\n    dt = datetime.strptime(time_str, \"%Y-%m-%dT%H:%M:%SZ\")\r\n    return int(dt.timestamp())\r\n\r\n\r\ndef parse_issue_comment_events(event):\r\n    try:\r\n        if \"action\" not in event[\"payload\"]:\r\n            return []\r\n        if event[\"payload\"][\"action\"] == \"created\":\r\n            issue_comment_id = event[\"payload\"][\"comment\"][\"id\"]\r\n            issue_id = 
event[\"payload\"][\"issue\"][\"id\"]\r\n            user_id = event[\"actor\"][\"id\"]\r\n            created_at = str_to_timestamp(event[\"created_at\"])\r\n\r\n            ici_event = [\r\n                issue_comment_format.format(issue_comment_id),\r\n                rels[\"IC_Created_IC_I\"],\r\n                issue_format.format(issue_id),\r\n                created_at,\r\n            ]\r\n            uic_event = [\r\n                user_format.format(user_id),\r\n                rels[\"IC_Created_U_IC\"],\r\n                issue_comment_format.format(issue_comment_id),\r\n                created_at,\r\n            ]\r\n            return [ici_event, uic_event]\r\n        return []\r\n    except:\r\n        return []\r\n\r\n\r\ndef parse_issue_event(event):\r\n    try:\r\n        issue_id = event[\"payload\"][\"issue\"][\"id\"]\r\n        user_id = event[\"actor\"][\"id\"]\r\n        repo_id = event[\"repo\"][\"id\"]\r\n        created_at = str_to_timestamp(event[\"created_at\"])\r\n        action_map = {\r\n            \"opened\": (\"I_Opened_U_I\", \"I_Opened_I_R\"),\r\n            \"closed\": (\"I_Closed_U_I\", \"I_Closed_I_R\"),\r\n            \"reopened\": (\"I_Reopened_U_I\", \"I_Reopened_I_R\"),\r\n        }\r\n        for action, event_rels in action_map.items():\r\n            if event[\"payload\"][\"action\"] == action:\r\n                ui_event = [\r\n                    user_format.format(user_id),\r\n                    rels[event_rels[0]],\r\n                    issue_format.format(issue_id),\r\n                    created_at,\r\n                ]\r\n\r\n                ir_event = [\r\n                    issue_format.format(issue_id),\r\n                    rels[event_rels[1]],\r\n                    repo_format.format(repo_id),\r\n                    created_at,\r\n                ]\r\n                return [ui_event, ir_event]\r\n        return []\r\n    except:\r\n        return []\r\n\r\n\r\ndef parse_pull_request_event(event):\r\n  
  try:\r\n        pull_request_id = event[\"payload\"][\"pull_request\"][\"id\"]\r\n        user_id = event[\"actor\"][\"id\"]\r\n        repo_id = event[\"repo\"][\"id\"]\r\n        created_at = str_to_timestamp(event[\"created_at\"])\r\n        action_map = {\r\n            \"opened\": (\"PR_Opened_U_PR\", \"PR_Opened_PR_R\"),\r\n            \"closed\": (\"PR_Closed_U_PR\", \"PR_Closed_PR_R\"),\r\n            \"reopened\": (\"PR_Reopened_U_PR\", \"PR_Reopened_PR_R\"),\r\n        }\r\n        for action, event_rels in action_map.items():\r\n            if event[\"payload\"][\"action\"] == action:\r\n                upr_event = [\r\n                    user_format.format(user_id),\r\n                    rels[event_rels[0]],\r\n                    pull_request_format.format(pull_request_id),\r\n                    created_at,\r\n                ]\r\n\r\n                prr_event = [\r\n                    pull_request_format.format(pull_request_id),\r\n                    rels[event_rels[1]],\r\n                    repo_format.format(repo_id),\r\n                    created_at,\r\n                ]\r\n                return [upr_event, prr_event]\r\n        return []\r\n    except:\r\n        return []\r\n\r\ndef parse_pull_request_review_comment_event(event):\r\n    try:\r\n        pull_request_review_comment_id = event[\"payload\"][\"comment\"][\"id\"]\r\n        pull_request_id = event[\"payload\"][\"pull_request\"][\"id\"]\r\n        user_id = event[\"actor\"][\"id\"]\r\n        created_at = str_to_timestamp(event[\"created_at\"])\r\n        if event[\"payload\"][\"action\"] == \"created\":\r\n            uprc_event = [\r\n                user_format.format(user_id),\r\n                rels[\"PRRC_Created_U_PRC\"],\r\n                pull_request_review_comment_format.format(pull_request_review_comment_id),\r\n                created_at,\r\n            ]\r\n\r\n            prcpr_event = [\r\n                
pull_request_review_comment_format.format(pull_request_review_comment_id),\r\n                rels[\"PRRC_Created_PRC_PR\"],\r\n                pull_request_format.format(pull_request_id),\r\n                created_at,\r\n            ]\r\n            return [uprc_event, prcpr_event]\r\n        return []\r\n    except:\r\n        return []\r\n\r\n\r\ndef parse_fork_event(event):\r\n    try:\r\n        forkee_repo_id = event[\"payload\"][\"forkee\"][\"id\"]\r\n        forked_repo_id = event[\"repo\"][\"id\"]\r\n        created_at = str_to_timestamp(event[\"created_at\"])\r\n        return [\r\n            [\r\n                repo_format.format(forkee_repo_id),\r\n                rels[\"Forked_R_R\"],\r\n                repo_format.format(forked_repo_id),\r\n                created_at,\r\n            ]\r\n        ]\r\n    except:\r\n        return []\r\n\r\n\r\ndef parse_member_event(event):\r\n    try:\r\n        user_id = event[\"payload\"][\"member\"][\"id\"]\r\n        repo_id = event[\"repo\"][\"id\"]\r\n        created_at = str_to_timestamp(event[\"created_at\"])\r\n        return [\r\n            [\r\n                user_format.format(user_id),\r\n                rels[\"AddMember_U_R\"],\r\n                repo_format.format(repo_id),\r\n                created_at,\r\n            ]\r\n        ]\r\n    except:\r\n        return []\r\n\r\n\r\nevent_handler_dict = {\r\n    \"IssueCommentEvent\": parse_issue_comment_events,\r\n    \"IssuesEvent\": parse_issue_event,\r\n    \"PullRequestEvent\": parse_pull_request_event,\r\n    \"PullRequestReviewCommentEvent\": parse_pull_request_review_comment_event,\r\n    \"ForkEvent\": parse_fork_event,\r\n    \"MemberEvent\": parse_member_event,\r\n}\r\n\r\n\r\ndef parse_event(event):\r\n    event_type = event[\"type\"]\r\n    if event_type in event_handler_dict:\r\n        output_list = event_handler_dict[event_type](event)\r\n        # print(\"Got {} outputs for event type {}\".format(len(output_list), event_type))\r\n    
else:\r\n        # print(\"Unknown event type: {}\".format(event_type))\r\n        output_list = []\r\n    return output_list\r\n\r\n\r\ndef parse_file(filename):\r\n    # events = []\r\n    output_dict = {}\r\n    num_edge = 1\r\n    #with open(filename) as f:\r\n    with gzip.open(filename, 'r') as f:\r\n        for i, line in enumerate(f):\r\n            djson = json.loads(line)\r\n            parsed_events = parse_event(djson)\r\n            if (len(parsed_events) > 0):\r\n                for edge in parsed_events:\r\n                    #? ['/user/41898282', 'U_SE_O_I', '/issue/2061196208', 1704085558]\r\n                    ts = int(edge[3])\r\n                    head = edge[0]\r\n                    rel = edge[1]\r\n                    tail = edge[2]\r\n                    if ts not in output_dict:\r\n                        output_dict[ts] = {}\r\n                        output_dict[ts][(head,tail,rel)] = 1\r\n                        num_edge += 1\r\n                    else:\r\n                        if (head,tail,rel) in output_dict[ts]:\r\n                            output_dict[ts][(head,tail,rel)] += 1\r\n                        else:\r\n                            output_dict[ts][(head,tail,rel)] = 1\r\n                            num_edge += 1\r\n    print(\"Parsed {} events\".format(num_edge))\r\n    return output_dict\r\n\r\ndef write2csv(outname, out_dict):\r\n    with open(outname, 'a') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        # writer.writerow(['ts', 'head', 'tail', 'relation_type'])\r\n        ts_list = list(out_dict.keys())\r\n        ts_list.sort()\r\n\r\n        for ts in ts_list:\r\n            for edge in out_dict[ts]:\r\n                head = edge[0]\r\n                tail = edge[1]\r\n                relation_type = edge[2]\r\n                row = [ts, head, tail, relation_type]\r\n                writer.writerow(row)\r\n\r\n\r\n\r\ndef main():\r\n    total_edge_dict = {}\r\n    for file in 
glob.glob(\"*.json.gz\"):\r\n        print (\"processing,\", file)\r\n        edge_dict = parse_file(file)\r\n        # print ('check for edge overlap')\r\n        # print(edge_dict.keys() & total_edge_dict.keys())\r\n        # print (\"-------------------------\")\r\n        #! write to csv after each file is processed. \r\n        # total_edge_dict.update(edge_dict)\r\n        outname = \"github_03_2024.csv\"\r\n        write2csv(outname, edge_dict)\r\n        \r\n\r\n\r\n\r\n\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()"
  },
  {
    "path": "tgb/datasets/thgl_github/extract_subset.py",
    "content": "import csv\r\n\r\n\r\ndef load_edgelist(file_path, freq_threshold=5):\r\n    \"\"\"\r\n    ts, head, tail, relation_type\r\n    1704085200,/user/34452971,/pr/1660752740,U_SO_C_P\r\n    \"\"\"\r\n    first_row = True\r\n    edge_dict = {}\r\n    num_nodes = 0\r\n    num_edges = 0\r\n    num_rels = 0\r\n    node_dict = {}\r\n    edge_freq_dict = {}\r\n    num_lines = 0\r\n\r\n\r\n    #! identify node type with least amount of edges\r\n    node_type_freq = {}\r\n\r\n    with open(file_path, 'r') as f:\r\n        reader = csv.reader(f, delimiter =',')\r\n        for row in reader: \r\n            if first_row:\r\n                first_row = False\r\n                continue\r\n            ts = int(row[0])\r\n            head = row[1]\r\n            head_type = head.split(\"/\")[1]\r\n            if (head_type not in node_type_freq):\r\n                node_type_freq[head_type] = 1\r\n            else:\r\n                node_type_freq[head_type] += 1\r\n            tail = row[2]\r\n            tail_type = tail.split(\"/\")[1]\r\n            if (tail_type not in node_type_freq):\r\n                node_type_freq[tail_type] = 1\r\n            else:\r\n                node_type_freq[tail_type] += 1\r\n            relation_type = row[3]\r\n            if head not in node_dict:\r\n                node_dict[head] = 1\r\n                num_nodes += 1\r\n            else:\r\n                node_dict[head] += 1\r\n\r\n            if tail not in node_dict:\r\n                node_dict[tail] = 1\r\n                num_nodes += 1\r\n            else:\r\n                node_dict[tail] += 1\r\n\r\n            if relation_type not in edge_freq_dict:\r\n                edge_freq_dict[relation_type] = 1\r\n                num_rels += 1\r\n            else:\r\n                edge_freq_dict[relation_type] += 1\r\n            num_lines += 1\r\n    print (\"there are \", num_lines, \" edges\")\r\n    print (\"there are \", num_nodes, \" nodes\")\r\n    print (\"there 
are \", num_rels, \" relations\")\r\n\r\n    node_freq5 = 0\r\n    node_freq10 = 0\r\n    node_freq100 = 0\r\n    node_freq1000 = 0\r\n    low_freq_dict = {}\r\n    for k, v in node_dict.items():\r\n        if v <= freq_threshold:\r\n            low_freq_dict[k] = 1\r\n            node_freq5 += 1\r\n        if v >= 10:\r\n            node_freq10 += 1\r\n        if v >= 100:\r\n            node_freq100 += 1\r\n        if v >= 1000:\r\n            node_freq1000 += 1\r\n    print (\"there are \", node_freq5, \" nodes with frequency <= \", freq_threshold, \" (inclusive)\")\r\n    print (\"there are \", node_freq10, \" nodes with frequency >= 10\")\r\n    print (\"there are \", node_freq100, \" nodes with frequency >= 100\")\r\n    print (\"there are \", node_freq1000, \" nodes with frequency >= 1000\")\r\n    # return node_freq10_dict\r\n    return low_freq_dict, node_type_freq\r\n\r\n\r\n\r\ndef subset_by_node(file_path, low_freq_dict):\r\n    first_row = True\r\n    edge_dict = {}\r\n    with open(file_path, 'r') as f:\r\n        reader = csv.reader(f, delimiter =',')\r\n        for row in reader: \r\n            if first_row:\r\n                first_row = False\r\n                continue\r\n            ts = int(row[0])\r\n            head = row[1]\r\n            tail = row[2]\r\n            relation_type = row[3]\r\n\r\n            #! 
remove any edges that belongs any node with degree one\r\n            if (head in low_freq_dict) or (tail in low_freq_dict):\r\n                continue\r\n\r\n            # if (head in node_dict) or (tail in node_dict):\r\n            if ts not in edge_dict:\r\n                edge_dict[ts] = {}\r\n            if (head,tail,relation_type) not in edge_dict[ts]:\r\n                edge_dict[ts][(head,tail,relation_type)] = 1\r\n            else:\r\n                edge_dict[ts][(head,tail,relation_type)] += 1\r\n    return edge_dict\r\n\r\n\r\ndef subset_by_node_type(file_path, remove_node_type_dict, low_freq_dict=None):\r\n    first_row = True\r\n    edge_dict = {}\r\n    node_dict = {}\r\n    num_edges = 0\r\n    if (low_freq_dict is not None):\r\n        check_low_freq = True\r\n\r\n    with open(file_path, 'r') as f:\r\n        reader = csv.reader(f, delimiter =',')\r\n        for row in reader: \r\n            if first_row:\r\n                first_row = False\r\n                continue\r\n            ts = int(row[0])\r\n            head = row[1]\r\n            tail = row[2]\r\n            if (head in low_freq_dict) or (tail in low_freq_dict):\r\n                continue\r\n\r\n            head_type = head.split(\"/\")[1]\r\n            tail_type = tail.split(\"/\")[1]\r\n            relation_type = row[3]\r\n\r\n\r\n\r\n            if (head_type in remove_node_type_dict) or (tail_type in remove_node_type_dict):\r\n                continue\r\n\r\n            if (head not in node_dict):\r\n                node_dict[head] = 1\r\n            if (tail not in node_dict):\r\n                node_dict[tail] = 1\r\n            num_edges += 1\r\n            if ts not in edge_dict:\r\n                edge_dict[ts] = {}\r\n            if (head,tail,relation_type) not in edge_dict[ts]:\r\n                edge_dict[ts][(head,tail,relation_type)] = 1\r\n            else:\r\n                edge_dict[ts][(head,tail,relation_type)] += 1\r\n\r\n    print (\"there are \", 
num_edges, \" edges in the output file\")\r\n    print (\"there are \", len(node_dict), \" nodes in the output file\")\r\n    return edge_dict\r\n\r\n\r\n\r\n\r\ndef write2csv(outname, out_dict):\r\n    num_edges = 0\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['ts', 'head', 'tail', 'relation_type'])\r\n        ts_list = list(out_dict.keys())\r\n        ts_list.sort()\r\n\r\n        for ts in ts_list:\r\n            for edge in out_dict[ts]:\r\n                head = edge[0]\r\n                tail = edge[1]\r\n                relation_type = edge[2]\r\n                row = [ts, head, tail, relation_type]\r\n                writer.writerow(row)\r\n                num_edges += 1\r\n    print (\"there are \", num_edges, \" edges in the output file\")\r\n\r\n\r\n\r\n\r\n\r\ndef combine_edgelist(file_paths, outname):\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['ts', 'head', 'tail', 'relation_type'])\r\n        for file_path in file_paths:\r\n            first_row = True\r\n            with open(file_path, 'r') as f:\r\n                reader = csv.reader(f, delimiter =',')\r\n                for row in reader: \r\n                    if first_row:\r\n                        first_row = False\r\n                        continue\r\n                    ts = int(row[0])\r\n                    head = row[1]\r\n                    tail = row[2]\r\n                    relation_type = row[3]\r\n                    writer.writerow([ts, head, tail, relation_type])\r\n\r\n\r\n\r\n\r\ndef main():\r\n    file_path = \"github_03_2024.csv\"\r\n    freq_threshold = 2\r\n    low_freq_dict, node_type_dict = load_edgelist(file_path, freq_threshold=freq_threshold)\r\n\r\n    remove_node_type_dict = {'issue_comment':1, 'pr_review_comment':1} #{'issue_comment':1, 'pr_review_comment':1, 'issue':1} \r\n    edge_dict = subset_by_node_type(file_path, 
remove_node_type_dict, low_freq_dict=low_freq_dict)\r\n    # edge_dict = subset_by_node(file_path, low_freq_dict=low_freq_dict)\r\n    outname = \"github_03_2024_subset.csv\"\r\n    write2csv(outname, edge_dict)\r\n\r\n    # file_paths = [\"github_01_2024_subset.csv\", \"github_02_2024_subset.csv\", \"github_03_2024_subset.csv\"]\r\n    # outname = \"thgl-github_edges.csv\"\r\n    # combine_edgelist(file_paths, outname)\r\n\r\n\r\n\r\n\r\n\r\n    \r\nif __name__ == \"__main__\":\r\n    main()"
  },
  {
    "path": "tgb/datasets/thgl_github/thgl_github.py",
    "content": "import csv\r\nimport datetime\r\nimport glob, os\r\n\r\ndef load_csv_raw(fname):\r\n    \"\"\"\r\n    load the raw csv file and merge them into one\r\n    \"\"\"\r\n    out_dict = {}\r\n    num_lines = 0\r\n    with open(fname, 'r') as f:\r\n        reader = csv.reader(f, delimiter ='\\t')\r\n        #* /user/10746682\tU_SO_C_IC\t/issue_comment/455195715\t1547754198\r\n        for row in reader: \r\n            head = row[0]\r\n            relation_type = row[1]\r\n            tail = row[2]\r\n            ts = int(row[3])\r\n            if (ts in out_dict):\r\n                if (head, tail, relation_type) in out_dict[ts]:\r\n                    out_dict[ts][(head, tail, relation_type)] += 1\r\n                else:\r\n                    out_dict[ts][(head, tail, relation_type)] = 1\r\n            else:\r\n                out_dict[ts] = {}\r\n                out_dict[ts][(head, tail, relation_type)] = 1\r\n            num_lines += 1\r\n    return out_dict, num_lines\r\n\r\n\r\ndef write2csv(outname, out_dict):\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['ts', 'head', 'tail', 'relation_type'])\r\n        ts_list = list(out_dict.keys())\r\n        ts_list.sort()\r\n\r\n        for ts in ts_list:\r\n            for edge in out_dict[ts]:\r\n                head = edge[0]\r\n                tail = edge[1]\r\n                relation_type = edge[2]\r\n                row = [ts, head, tail, relation_type]\r\n                writer.writerow(row)\r\n\r\n\r\ndef load_edgelist(fname):\r\n    \"\"\"\r\n    load the edgelist\r\n    \"\"\"\r\n    node_dict = {} # {node_name: node_id}\r\n    node_type_dict = {} # {node_id: node_type}\r\n    rel_type_dict = {}\r\n    edge_dict = {} # {edge: edge_type}\r\n    node_type_mapping = {}\r\n    num_edges = 0\r\n    with open(fname, 'r') as f:\r\n        reader = csv.reader(f, delimiter =',')\r\n        first_row = True\r\n        for row in 
reader:\r\n            if first_row:\r\n                first_row = False\r\n                continue\r\n            ts = int(row[0])\r\n            head = row[1]\r\n            tail = row[2]\r\n            relation_type = row[3]\r\n\r\n            head_strs = head.split('/')\r\n            tail_strs = tail.split('/')\r\n            head_type = head_strs[1]\r\n            tail_type = tail_strs[1]\r\n\r\n            if head_type not in node_type_mapping:\r\n                node_type_mapping[head_type] = len(node_type_mapping)\r\n            if tail_type not in node_type_mapping:\r\n                node_type_mapping[tail_type] = len(node_type_mapping)\r\n\r\n            if head not in node_dict:\r\n                node_dict[head] = len(node_dict)\r\n                node_type_dict[node_dict[head]] = node_type_mapping[head_type]\r\n            if tail not in node_dict:\r\n                node_dict[tail] = len(node_dict)\r\n                node_type_dict[node_dict[tail]] = node_type_mapping[tail_type]\r\n            if relation_type not in rel_type_dict:\r\n                rel_type_dict[relation_type] = len(rel_type_dict)\r\n            if ts not in edge_dict:\r\n                edge_dict[ts] = {}\r\n            edge_dict[ts][(node_dict[head], node_dict[tail], rel_type_dict[relation_type])] = 1\r\n            num_edges += 1\r\n    print (\"there are {} nodes\".format(len(node_dict)))\r\n    print (\"there are {} edges\".format(num_edges))\r\n\r\n    return node_dict, node_type_dict, edge_dict, rel_type_dict, node_type_mapping\r\n\r\n\r\n\r\ndef writeNodeType(node_type_dict, outname):\r\n    r\"\"\"\r\n    write the node type mapping to a file\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['node_id', 'type'])\r\n        for key in node_type_dict:\r\n            writer.writerow([key, node_type_dict[key]])\r\n\r\n\r\ndef writeEdgeTypeMapping(edge_type_dict, outname):\r\n    r\"\"\"\r\n    
write the edge type mapping to a file\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['edge_id', 'type'])\r\n        for key in edge_type_dict:\r\n            writer.writerow([key, edge_type_dict[key]])\r\n\r\n\r\ndef writeNodeTypeMapping(node_type_dict, outname):\r\n    r\"\"\"\r\n    write the edge type mapping to a file\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['node_type_id', 'type'])\r\n        for key in node_type_dict:\r\n            writer.writerow([key, node_type_dict[key]])\r\n\r\n\r\ndef write2edgelist(out_dict, outname):\r\n    r\"\"\"\r\n    Write the dictionary to a csv file\r\n    \"\"\"\r\n    num_lines = 0\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['timestamp', 'head', 'tail', 'relation_type'])\r\n        dates = list(out_dict.keys())\r\n        dates.sort()\r\n        for date in dates:\r\n            for edge in out_dict[date]:\r\n                head = edge[0]\r\n                tail = edge[1]\r\n                relation_type = int(edge[2])\r\n                row = [date, head, tail, relation_type]\r\n                writer.writerow(row)\r\n                num_lines += 1\r\n    print (\"there are {} lines in the file\".format(num_lines))\r\n\r\n\r\n\r\ndef main():\r\n    # \"\"\"\r\n    # concatenate edgelists\r\n    # \"\"\"\r\n    # total_lines = 0\r\n    # total_edge_dict = {} \r\n    # #1. 
find all files with .txt in the folder\r\n    # for file in glob.glob(\"*.txt\"):\r\n    #     # outname = file[7:11] + \"_edgelist.csv\"\r\n    #     print (\"processing\", file)\r\n    #     edge_dict, num_lines = load_csv_raw(file)\r\n    #     total_lines += num_lines\r\n    #     print (\"-----------------------------------\")\r\n    #     print (\"file, \", file)\r\n    #     print (\"number of lines, \", num_lines)\r\n    #     print (\"number of ts, \", len(edge_dict))\r\n    #     print (\"-----------------------------------\")\r\n    #     total_edge_dict.update(edge_dict)\r\n    # outname = \"all_edgelist.csv\"\r\n    # write2csv(outname, total_edge_dict)\r\n\r\n\r\n    fname =\"github_03_2024_subset.csv\"#\"github_01_2024_subset.csv\" #\"thgl-github_edges.csv\" #\"all_edgelist.csv\" \r\n    node_dict, node_type_dict, edge_dict, edge_type_dict, node_type_mapping = load_edgelist(fname)\r\n    write2edgelist (edge_dict, \"thgl-github_edgelist.csv\")\r\n    writeNodeType(node_type_dict, \"thgl-github_nodetype.csv\")\r\n    writeEdgeTypeMapping(edge_type_dict, \"thgl-github_edgemapping.csv\")\r\n    writeNodeTypeMapping(node_type_mapping, \"thgl-github_nodemapping.csv\")\r\n    \r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()"
  },
  {
    "path": "tgb/datasets/thgl_github/thgl_github_ns_gen.py",
    "content": "import time\r\nfrom tgb.linkproppred.thg_negative_generator import THGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = 20 #1000\r\n    neg_sample_strategy = \"node-type-filtered\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"thgl-github\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    min_node_idx = min(int(data.src.min()), int(data.dst.min()))\r\n    max_node_idx = max(int(data.src.max()), int(data.dst.max()))\r\n\r\n    neg_sampler = THGNegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_node_id=min_node_idx,\r\n        last_node_id=max_node_idx,\r\n        node_type=dataset.node_type,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        edge_data=data,\r\n    )\r\n\r\n   \r\n    # generate evaluation set\r\n    partial_path = \".\"\r\n    # generate validation negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. 
Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/thgl_myket/thgl_myket.py",
    "content": "import dateutil.parser as dparser\r\nimport csv\r\nimport matplotlib.pyplot as plt\r\nimport numpy as np\r\nfrom tqdm import tqdm\r\nfrom os import listdir\r\nfrom datetime import datetime\r\n\r\n\r\n\r\ndef date2ts(date_str: str) -> float:\r\n    r\"\"\"\r\n    convert date string to timestamp\r\n    \"\"\"\r\n    TIME_FORMAT = \"%Y-%m-%d %H:%M:%S.%f\"\r\n    date_cur = datetime.strptime(date_str, TIME_FORMAT)\r\n    return int(date_cur.timestamp())\r\n\r\n\r\n\"\"\"\r\napp_name\tuser_id\tdatetime\tis_update\r\ncom.cocoplay.erpetvet\t392863962\t2020-06-17 23:55:17.460\t0\r\ncom.titan.royal\t790103760\t2020-06-17 23:55:19.583\t0\r\ncom.tencent.ig\t-1651723014\t2020-06-17 23:55:20.647\t0\r\ncom.cyberlink.youperfect\t-2116095669\t2020-06-17 23:55:20.723\t1\r\ncom.whatsapp\t1591275459\t2020-06-17 23:55:20.820\t0\r\ncom.nexttechgamesstudio.house.paint.craft.coloring.book.pages\t-984956295\t2020-06-17 23:55:21.840\t0\r\ncom.lenovo.anyshare.gps\t1643649087\t2020-06-17 23:55:21.853\t1\r\ncom.kurankarim.mp3\t1316745267\t2020-06-17 23:55:22.537\t0\r\ncom.google.android.dialer\t239675079\t2020-06-17 23:55:22.950\t1\r\ncom.ma.textgraphy\t-951808761\t2020-06-17 23:55:22.977\t0\r\nir.shahbaz.SHZToolBox\t1643649087\t2020-06-17 23:55:22.987\t1\r\npicture.instagram.makers\t-1898448882\t2020-06-17 23:55:23.010\t0\r\nir.shahbaz.SHZToolBox\t780669111\t2020-06-17 23:55:23.600\t1\r\nfantasy.survival.game.rpg\t1849120437\t2020-06-17 23:55:23.980\t0\r\ncom.ags.flying.muscle.car.transform.robot.war.robot.games\t1751574033\t2020-06-17 23:55:24.680\t0\r\n\"\"\"\r\ndef read_csv2dict(fname):\r\n    r\"\"\"\r\n    load from the raw data and retrieve, timestamp, head, tail, relation \r\n    also return a mapping from node text to node id\r\n    convert all dates into unix timestamps\r\n    \"\"\"\r\n    out_dict = {}\r\n    first_row = True\r\n    num_lines = 0\r\n    with open(fname, 'r') as f:\r\n        reader = csv.reader(f, delimiter ='\\t')\r\n        for row in reader: 
\r\n            if first_row:\r\n                first_row = False\r\n                continue\r\n            app = row[0]\r\n            user = row[1]\r\n            date = row[2]\r\n            is_update = int(row[3])\r\n            if (len(date) == 0 or date is None):\r\n                continue\r\n            else:\r\n                ts = date2ts(date)\r\n                head = user\r\n                tail = app\r\n                if (ts not in out_dict):\r\n                    out_dict[ts] = {(head,tail,is_update): 1}\r\n                else:\r\n                    out_dict[ts][(head,tail,is_update)] = 1\r\n                num_lines += 1\r\n    print (\"there are {} lines in the file\".format(num_lines))\r\n    return out_dict\r\n\r\n\r\n# def writeIDmapping(id_dict, outname):\r\n#     r\"\"\"\r\n#     write the id mapping to a file\r\n#     \"\"\"\r\n#     with open(outname, 'w') as f:\r\n#         writer = csv.writer(f, delimiter =',')\r\n#         writer.writerow(['ID', 'name'])\r\n#         for key in id_dict:\r\n#             writer.writerow([key, id_dict[key]])\r\n\r\n\r\ndef edge2nodetype(out_dict):\r\n    r\"\"\"\r\n    1. remap node id of nodes\r\n    2. 
output the node_type file\r\n    \"\"\"\r\n    node_dict = {} # {node_name: node_id}\r\n    node_type_dict = {} # {node_id: node_type}\r\n    edge_dict = {} # {edge: edge_type}\r\n    dates = list(out_dict.keys())\r\n    dates.sort()\r\n    for date in dates:\r\n        for edge in out_dict[date]:\r\n                head = edge[0] # user node\r\n                tail = edge[1] # app node\r\n                relation_type = int(edge[2])\r\n                if head not in node_dict:\r\n                    node_dict[head] = len(node_dict)\r\n                    node_type_dict[node_dict[head]] = 0 #user\r\n                if tail not in node_dict:\r\n                    node_dict[tail] = len(node_dict)\r\n                    node_type_dict[node_dict[tail]] = 1 #app\r\n                if date not in edge_dict:\r\n                    edge_dict[date] = {}\r\n                edge_dict[date][(node_dict[head], node_dict[tail], relation_type)] = 1\r\n    return node_dict, node_type_dict, edge_dict\r\n                \r\n\r\ndef writeNodeType(node_type_dict, outname):\r\n    r\"\"\"\r\n    write the node type mapping to a file\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['node_id', 'type'])\r\n        for key in node_type_dict:\r\n            writer.writerow([key, node_type_dict[key]])\r\n\r\n\r\n\r\ndef write2edgelist(out_dict, outname):\r\n    r\"\"\"\r\n    Write the dictionary to a csv file\r\n    \"\"\"\r\n    num_lines = 0\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['timestamp', 'head', 'tail', 'relation_type'])\r\n        dates = list(out_dict.keys())\r\n        dates.sort()\r\n        for date in dates:\r\n            for edge in out_dict[date]:\r\n                head = edge[0]\r\n                tail = edge[1]\r\n                relation_type = int(edge[2])\r\n                row = [date, head, tail, 
relation_type]\r\n                writer.writerow(row)\r\n                num_lines += 1\r\n    print (\"there are {} lines in the file\".format(num_lines))\r\n        \r\n\r\n\"\"\"\r\nneed to have edgelist with n_ids \r\nneed to have a node_type file to document which nodes are which type\r\n\"\"\"\r\n\r\n\r\ndef main():\r\n    fname = \"raw_myket_input-001.csv\"\r\n    out_dict = read_csv2dict(fname)\r\n    # write2edgelist (out_dict, \"thgl-myket_edgelist.csv\")\r\n    node_dict, node_type_dict, edge_dict = edge2nodetype(out_dict)\r\n\r\n    write2edgelist (edge_dict, \"thgl-myket_edgelist.csv\")\r\n    writeNodeType(node_type_dict, \"thgl-myket_nodetype.csv\")\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()"
  },
  {
    "path": "tgb/datasets/thgl_myket/thgl_myket_ns_gen.py",
    "content": "import time\r\nfrom tgb.linkproppred.thg_negative_generator import THGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = 20 #-1 \r\n    neg_sample_strategy = \"node-type-filtered\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"thgl-myket\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    min_node_idx = min(int(data.src.min()), int(data.dst.min()))\r\n    max_node_idx = max(int(data.src.max()), int(data.dst.max()))\r\n\r\n    neg_sampler = THGNegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_node_id=min_node_idx,\r\n        last_node_id=max_node_idx,\r\n        node_type=dataset.node_type,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        edge_data=data,\r\n    )\r\n\r\n   \r\n    # generate evaluation set\r\n    partial_path = \".\"\r\n    # generate validation negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. 
Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/thgl_software/thgl_software.py",
    "content": "import csv\r\nimport datetime\r\nimport glob, os\r\n\r\ndef load_csv_raw(fname):\r\n    \"\"\"\r\n    load the raw csv file and merge them into one\r\n    \"\"\"\r\n    out_dict = {}\r\n    num_lines = 0\r\n    with open(fname, 'r') as f:\r\n        reader = csv.reader(f, delimiter ='\\t')\r\n        #* /user/10746682\tU_SO_C_IC\t/issue_comment/455195715\t1547754198\r\n        for row in reader: \r\n            head = row[0]\r\n            relation_type = row[1]\r\n            tail = row[2]\r\n            ts = int(row[3])\r\n            if (ts in out_dict):\r\n                if (head, tail, relation_type) in out_dict[ts]:\r\n                    out_dict[ts][(head, tail, relation_type)] += 1\r\n                else:\r\n                    out_dict[ts][(head, tail, relation_type)] = 1\r\n            else:\r\n                out_dict[ts] = {}\r\n                out_dict[ts][(head, tail, relation_type)] = 1\r\n            num_lines += 1\r\n    return out_dict, num_lines\r\n\r\n\r\ndef write2csv(outname, out_dict):\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['ts', 'head', 'tail', 'relation_type'])\r\n        ts_list = list(out_dict.keys())\r\n        ts_list.sort()\r\n\r\n        for ts in ts_list:\r\n            for edge in out_dict[ts]:\r\n                head = edge[0]\r\n                tail = edge[1]\r\n                relation_type = edge[2]\r\n                row = [ts, head, tail, relation_type]\r\n                writer.writerow(row)\r\n\r\n\r\ndef load_edgelist(fname):\r\n    \"\"\"\r\n    load the edgelist\r\n    \"\"\"\r\n    node_dict = {} # {node_name: node_id}\r\n    node_type_dict = {} # {node_id: node_type}\r\n    rel_type_dict = {}\r\n    edge_dict = {} # {edge: edge_type}\r\n    node_type_mapping = {}\r\n    num_edges = 0\r\n    with open(fname, 'r') as f:\r\n        reader = csv.reader(f, delimiter =',')\r\n        first_row = True\r\n        for row in 
reader:\r\n            if first_row:\r\n                first_row = False\r\n                continue\r\n            ts = int(row[0])\r\n            head = row[1]\r\n            tail = row[2]\r\n            relation_type = row[3]\r\n\r\n            head_strs = head.split('/')\r\n            tail_strs = tail.split('/')\r\n            head_type = head_strs[1]\r\n            tail_type = tail_strs[1]\r\n\r\n            if head_type not in node_type_mapping:\r\n                node_type_mapping[head_type] = len(node_type_mapping)\r\n            if tail_type not in node_type_mapping:\r\n                node_type_mapping[tail_type] = len(node_type_mapping)\r\n\r\n            if head not in node_dict:\r\n                node_dict[head] = len(node_dict)\r\n                node_type_dict[node_dict[head]] = node_type_mapping[head_type]\r\n            if tail not in node_dict:\r\n                node_dict[tail] = len(node_dict)\r\n                node_type_dict[node_dict[tail]] = node_type_mapping[tail_type]\r\n            if relation_type not in rel_type_dict:\r\n                rel_type_dict[relation_type] = len(rel_type_dict)\r\n            if ts not in edge_dict:\r\n                edge_dict[ts] = {}\r\n            edge_dict[ts][(node_dict[head], node_dict[tail], rel_type_dict[relation_type])] = 1\r\n            num_edges += 1\r\n    print (\"there are {} nodes\".format(len(node_dict)))\r\n    print (\"there are {} edges\".format(num_edges))\r\n\r\n    return node_dict, node_type_dict, edge_dict, rel_type_dict, node_type_mapping\r\n\r\n\r\n\r\ndef writeNodeType(node_type_dict, outname):\r\n    r\"\"\"\r\n    write the node type mapping to a file\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['node_id', 'type'])\r\n        for key in node_type_dict:\r\n            writer.writerow([key, node_type_dict[key]])\r\n\r\n\r\ndef writeEdgeTypeMapping(edge_type_dict, outname):\r\n    r\"\"\"\r\n    
write the edge type mapping to a file\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['edge_id', 'type'])\r\n        for key in edge_type_dict:\r\n            writer.writerow([key, edge_type_dict[key]])\r\n\r\n\r\ndef writeNodeTypeMapping(node_type_dict, outname):\r\n    r\"\"\"\r\n    write the edge type mapping to a file\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['node_type_id', 'type'])\r\n        for key in node_type_dict:\r\n            writer.writerow([key, node_type_dict[key]])\r\n\r\n\r\ndef write2edgelist(out_dict, outname):\r\n    r\"\"\"\r\n    Write the dictionary to a csv file\r\n    \"\"\"\r\n    num_lines = 0\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['timestamp', 'head', 'tail', 'relation_type'])\r\n        dates = list(out_dict.keys())\r\n        dates.sort()\r\n        for date in dates:\r\n            for edge in out_dict[date]:\r\n                head = edge[0]\r\n                tail = edge[1]\r\n                relation_type = int(edge[2])\r\n                row = [date, head, tail, relation_type]\r\n                writer.writerow(row)\r\n                num_lines += 1\r\n    print (\"there are {} lines in the file\".format(num_lines))\r\n\r\n\r\n\r\ndef main():\r\n\r\n    fname = \"software_edgelist.csv\" #\"all_edgelist.csv\"\r\n    node_dict, node_type_dict, edge_dict, edge_type_dict, node_type_mapping = load_edgelist(fname)\r\n    write2edgelist (edge_dict, \"thgl-software_edgelist.csv\")\r\n    writeNodeType(node_type_dict, \"thgl-software_nodetype.csv\")\r\n    writeEdgeTypeMapping(edge_type_dict, \"thgl-software_edgemapping.csv\")\r\n    writeNodeTypeMapping(node_type_mapping, \"thgl-software_nodemapping.csv\")\r\n    \r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()"
  },
  {
    "path": "tgb/datasets/thgl_software/thgl_software_ns_gen.py",
    "content": "import time\r\nfrom tgb.linkproppred.thg_negative_generator import THGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = 1000 \r\n    neg_sample_strategy = \"node-type-filtered\" #\"random\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"thgl-software\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    min_node_idx = min(int(data.src.min()), int(data.dst.min()))\r\n    max_node_idx = max(int(data.src.max()), int(data.dst.max()))\r\n\r\n    neg_sampler = THGNegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_node_id=min_node_idx,\r\n        last_node_id=max_node_idx,\r\n        node_type=dataset.node_type,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        edge_data=data,\r\n    )\r\n\r\n   \r\n    # generate evaluation set\r\n    partial_path = \".\"\r\n    # generate validation negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. 
Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/tkgl_icews/tkgl_icews.py",
    "content": "import csv\r\nimport datetime\r\nimport glob, os\r\n\r\n\r\ndef load_csv_raw(fname):\r\n    r\"\"\"\r\n    load from the raw data and retrieve, timestamp, head, tail, relation \r\n    convert all dates into unix timestamps\r\n    #! Event ID\tEvent Date\tSource Name\tSource Sectors\tSource Country\tEvent Text\tCAMEO Code\tIntensity\tTarget Name\tTarget Sectors\tTarget Country\tStory ID\tSentence Number\tPublisher\tCity\tDistrict\tProvince\tCountry\tLatitude\tLongitude\r\n    \"\"\"\r\n    out_dict = {}\r\n    first_row = True\r\n    num_lines = 0\r\n    with open(fname, 'r', encoding='ISO-8859-1') as f:\r\n        reader = csv.reader(f, delimiter ='\\t')\r\n        for row in reader: \r\n            if first_row:\r\n                first_row = False\r\n                continue\r\n            date = row[1] #1995-01-01\r\n            head = row[2]\r\n            tail = row[8]\r\n            relation_type = row[6] #CAMEO code  #! not always integer in 2017 for some reason there is 13y\r\n            if (len(date) == 0):\r\n                continue\r\n            \r\n            if (\"None\" in date or \"None\" in head or \"None\" in tail or \"None\" in relation_type):\r\n                continue\r\n            else:\r\n                 #! 
remove redundant edges with same timestamps\r\n                TIME_FORMAT = \"%Y-%m-%d\" #2018-01-01\r\n                date_cur = datetime.datetime.strptime(date, TIME_FORMAT)\r\n                ts = int(date_cur.timestamp())\r\n                num_lines += 1\r\n                if (ts in out_dict):\r\n                    if (head, tail, relation_type) in out_dict[ts]:\r\n                        out_dict[ts][(head, tail, relation_type)] += 1\r\n                    else:\r\n                        out_dict[ts][(head, tail, relation_type)] = 1\r\n                else:\r\n                    out_dict[ts] = {}\r\n                    out_dict[ts][(head, tail, relation_type)] = 1\r\n    return out_dict, num_lines\r\n\r\n\r\ndef write2csv(outname, out_dict):\r\n\r\n    node_dict = {}\r\n    max_node_id = 0\r\n    edge_type_dict = {}\r\n    max_edge_type_id = 0\r\n\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['date', 'head', 'tail', 'relation_type'])\r\n\r\n        for date in out_dict:\r\n            for edge in out_dict[date]:\r\n                head = edge[0]\r\n                tail = edge[1]\r\n                relation_type = edge[2]\r\n                if head not in node_dict:\r\n                    node_dict[head] = max_node_id\r\n                    max_node_id += 1\r\n                if tail not in node_dict:\r\n                    node_dict[tail] = max_node_id\r\n                    max_node_id += 1\r\n                if relation_type not in edge_type_dict:\r\n                    edge_type_dict[relation_type] = max_edge_type_id\r\n                    max_edge_type_id += 1\r\n                row = [date, node_dict[head], node_dict[tail], edge_type_dict[relation_type]]\r\n                writer.writerow(row)\r\n    return node_dict, edge_type_dict\r\n\r\n\r\ndef writeEdgeTypeMapping(edge_type_dict, outname):\r\n    r\"\"\"\r\n    write the edge type mapping to a file\r\n    \"\"\"\r\n    with 
open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['edge_id', 'type'])\r\n        for key in edge_type_dict:\r\n            writer.writerow([key, edge_type_dict[key]])\r\n\r\n\r\n\r\n\r\ndef main():\r\n\r\n    total_lines = 0\r\n    total_edge_dict = {} \r\n    #1. find all files with .txt in the folder\r\n    for file in glob.glob(\"*.tab\"):\r\n        # outname = file[7:11] + \"_edgelist.csv\"\r\n        print (\"processing\", file)\r\n        edge_dict, num_lines = load_csv_raw(file)\r\n        total_lines += num_lines\r\n        print (\"-----------------------------------\")\r\n        print (\"file, \", file)\r\n        print (\"number of lines, \", num_lines)\r\n        print (\"number of days, \", len(edge_dict))\r\n        print (\"-----------------------------------\")\r\n        total_edge_dict.update(edge_dict)\r\n    outname = \"tkgl-icews_edgelist_tiny.csv\"\r\n    print (\"total number of lines\", total_lines)\r\n    print (\"total number of days\", len(total_edge_dict))    \r\n    node_dict, edge_type_dict = write2csv(outname, total_edge_dict)\r\n    writeEdgeTypeMapping(edge_type_dict, \"tkgl-icews_edgemapping.csv\")\r\n\r\n\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()"
  },
  {
    "path": "tgb/datasets/tkgl_icews/tkgl_icews_ns_gen.py",
    "content": "import time\r\nfrom tgb.linkproppred.tkg_negative_generator import TKGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = -1 \r\n    neg_sample_strategy = \"time-filtered\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"tkgl-icews\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    # Ensure to only sample actual destination nodes as negatives.\r\n    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n\r\n    neg_sampler = TKGNegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_dst_id=min_dst_idx,\r\n        last_dst_id=max_dst_idx,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        edge_data=data,\r\n    )\r\n\r\n    # generate evaluation set\r\n    partial_path = \".\"\r\n    # generate validation negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. 
Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/tkgl_polecat/tkgl_polecat.py",
    "content": "import csv\r\nimport datetime\r\nimport glob, os\r\n\r\n\r\n\r\n\r\ndef load_csv_raw(fname):\r\n    r\"\"\"\r\n    load from the raw data and retrieve, timestamp, head, tail, relation \r\n    convert all dates into unix timestamps\r\n    #!Event ID\tEvent Date\tEvent Type\tEvent Mode\tIntensity\tQuad Code\tContexts\tActor Name\tActor Country\tActor COW\tPrimary Actor Sector\tActor Sectors\tActor Title\tActor Name Raw\tWikipedia Actor ID\tRecipient Name\tRecipient Country\tRecipient COW\tPrimary Recipient Sector\tRecipient Sectors\tRecipient Title\tRecipient Name Raw\tWikipedia Recipient ID\tPlacename\tCity\tDistrict\tProvince\tCountry\tLatitude\tLongitude\tGeoNames ID\tRaw Placename\tFeature Type\tSource\tPublication Date\tStory People\tStory Organizations\tStory Locations\tLanguage\tVersion\r\n    \"\"\"\r\n    out_dict = {}\r\n    first_row = True\r\n    num_lines = 0\r\n    with open(fname, 'r', encoding='ISO-8859-1') as f:\r\n        reader = csv.reader(f, delimiter ='\\t')\r\n        for row in reader: \r\n            if first_row:\r\n                first_row = False\r\n                continue\r\n            date = row[1]\r\n            relation_type = row[2]\r\n            head = row[7]\r\n            tail = row[15]\r\n\r\n            if (len(date) == 0):\r\n                continue\r\n            \r\n            if (\"None\" in date or \"None\" in head or \"None\" in tail or \"None\" in relation_type):\r\n                continue\r\n            else:\r\n                 #! 
remove redundant edges with same timestamps\r\n                TIME_FORMAT = \"%Y-%m-%d\" #2018-01-01\r\n                date_cur = datetime.datetime.strptime(date, TIME_FORMAT)\r\n                ts = int(date_cur.timestamp())\r\n                num_lines += 1\r\n                if (ts in out_dict):\r\n                    if (head, tail, relation_type) in out_dict[ts]:\r\n                        out_dict[ts][(head, tail, relation_type)] += 1\r\n                    else:\r\n                        out_dict[ts][(head, tail, relation_type)] = 1\r\n                else:\r\n                    out_dict[ts] = {}\r\n                    out_dict[ts][(head, tail, relation_type)] = 1\r\n    return out_dict, num_lines\r\n\r\n\r\n#! fill in node and edge type dictionaries\r\ndef write2csv(outname: str, \r\n              out_dict: dict,\r\n              edge_type_dict: dict = None,\r\n              node_dict: dict = None,):\r\n    r\"\"\"\r\n    Write the dictionary to a csv file\r\n    also keep track of edge_type or node_dict, update the provided one too\r\n    \"\"\"\r\n    if (edge_type_dict is None):\r\n        edge_type_dict = {}\r\n    if (node_dict is None):\r\n        node_dict = {}\r\n\r\n    max_edge_type_id = len(edge_type_dict)\r\n    max_node_id = len(node_dict)    \r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['date', 'head', 'tail', 'relation_type'])\r\n\r\n        dates = list(out_dict.keys())\r\n        dates.sort()\r\n        for date in dates:\r\n            for edge in out_dict[date]:\r\n                head = edge[0]\r\n                tail = edge[1]\r\n                relation_type = edge[2]\r\n                if head not in node_dict:\r\n                    node_dict[head] = max_node_id\r\n                    max_node_id += 1\r\n                if tail not in node_dict:\r\n                    node_dict[tail] = max_node_id\r\n                    max_node_id += 1\r\n                if 
relation_type not in edge_type_dict:\r\n                    edge_type_dict[relation_type] = max_edge_type_id\r\n                    max_edge_type_id += 1\r\n                row = [date, node_dict[head], node_dict[tail], edge_type_dict[relation_type]]\r\n                writer.writerow(row)\r\n    \r\n    return edge_type_dict, node_dict\r\n\r\n\r\n# def write2csv(outname, out_dict):\r\n\r\n#     node_dict = {}\r\n#     max_node_id = 0\r\n#     edge_type_dict = {}\r\n#     max_edge_type_id = 0\r\n\r\n#     with open(outname, 'w') as f:\r\n#         writer = csv.writer(f, delimiter =',')\r\n#         writer.writerow(['date', 'head', 'tail', 'relation_type'])\r\n\r\n#         for date in out_dict:\r\n#             for edge in out_dict[date]:\r\n#                 head = edge[0]\r\n#                 tail = edge[1]\r\n#                 relation_type = edge[2]\r\n#                 if head not in node_dict:\r\n#                     node_dict[head] = max_node_id\r\n#                     max_node_id += 1\r\n#                 if tail not in node_dict:\r\n#                     node_dict[tail] = max_node_id\r\n#                     max_node_id += 1\r\n#                 if relation_type not in edge_type_dict:\r\n#                     edge_type_dict[relation_type] = max_edge_type_id\r\n#                     max_edge_type_id += 1\r\n#                 row = [date, node_dict[head], node_dict[tail], edge_type_dict[relation_type]]\r\n#                 writer.writerow(row)\r\n\r\n\r\ndef writeEdgeTypeMapping(edge_type_dict, outname):\r\n    r\"\"\"\r\n    write the edge type mapping to a file\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['edge_id', 'type'])\r\n        for key in edge_type_dict:\r\n            writer.writerow([key, edge_type_dict[key]])\r\n\r\n\r\ndef main():\r\n\r\n    #example\r\n    # fname = \"2018-Jan.txt\"\r\n    # print (\"hi\")\r\n    # lines = load_csv_raw(fname)\r\n    # outname 
= \"tkgl-polecat_edgelist.csv\"\r\n    # write2csv(outname, lines)\r\n\r\n    total_lines = 0\r\n    num_days = 0\r\n    total_edge_dict = {}\r\n\r\n    #1. find all files with .txt in the folder\r\n    for file in glob.glob(\"*.csv\"):\r\n        outname = file[0:7] + \"_edgelist.csv\"\r\n        print (\"processing\", file, \"to\", outname)\r\n        edge_dict, num_lines = load_csv_raw(file)\r\n        total_lines += num_lines\r\n        num_days += len(edge_dict)\r\n        total_edge_dict.update(edge_dict)\r\n    edge_type_dict, node_dict = write2csv(\"tkgl-polecat_edgelist.csv\", total_edge_dict)\r\n    print (\"-----------------------------------\")\r\n    print (\"total number of lines\", total_lines)\r\n    print (\"total number of days\", num_days)\r\n    print (\"there are\", len(edge_type_dict), \"unique edge types\")\r\n    print (\"there are\", len(node_dict), \"unique nodes\")\r\n    writeEdgeTypeMapping(edge_type_dict, \"tkgl-polecat_edgemapping.csv\")\r\n\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n\r\n    #* rename functions\r\n    # renames = []\r\n    # for file in glob.glob(\"*.txt\"):\r\n    #     outname = file[-12:-4] + \"_edgelist.csv\"\r\n    #     file_rename = file[-12:-4] + \"_raw.csv\"\r\n    #     if (\"Jan\" in outname):\r\n    #         outname = outname.replace(\"Jan\", \"01\")\r\n    #         renames.append((file, file_rename.replace(\"Jan\", \"01\")))\r\n    #     elif (\"Feb\" in outname):\r\n    #         outname = outname.replace(\"Feb\", \"02\")\r\n    #         renames.append((file, file_rename.replace(\"Feb\", \"02\")))\r\n    #     elif (\"Mar\" in outname):\r\n    #         outname = outname.replace(\"Mar\", \"03\")\r\n    #         renames.append((file, file_rename.replace(\"Mar\", \"03\")))\r\n    #     elif (\"Apr\" in outname):\r\n    #         outname = outname.replace(\"Apr\", \"04\")\r\n    #         renames.append((file, file_rename.replace(\"Apr\", \"04\")))\r\n    #     elif (\"May\" in 
outname):\r\n    #         outname = outname.replace(\"May\", \"05\")\r\n    #         renames.append((file, file_rename.replace(\"May\", \"05\")))\r\n    #     elif (\"Jun\" in outname):\r\n    #         outname = outname.replace(\"Jun\", \"06\")\r\n    #         renames.append((file, file_rename.replace(\"Jun\", \"06\")))\r\n    #     elif (\"Jul\" in outname):\r\n    #         outname = outname.replace(\"Jul\", \"07\")\r\n    #         renames.append((file, file_rename.replace(\"Jul\", \"07\")))\r\n    #     elif (\"Aug\" in outname):\r\n    #         outname = outname.replace(\"Aug\", \"08\")\r\n    #         renames.append((file, file_rename.replace(\"Aug\", \"08\")))\r\n    #     elif (\"Sep\" in outname):\r\n    #         outname = outname.replace(\"Sep\", \"09\")\r\n    #         renames.append((file, file_rename.replace(\"Sep\", \"09\")))\r\n    #     elif (\"Oct\" in outname):\r\n    #         outname = outname.replace(\"Oct\", \"10\")\r\n    #         renames.append((file, file_rename.replace(\"Oct\", \"10\")))\r\n    #     elif (\"Nov\" in outname):\r\n    #         outname = outname.replace(\"Nov\", \"11\")\r\n    #         renames.append((file, file_rename.replace(\"Nov\", \"11\")))\r\n    #     elif (\"Dec\" in outname):\r\n    #         outname = outname.replace(\"Dec\", \"12\")\r\n    #         renames.append((file, file_rename.replace(\"Dec\", \"12\")))\r\n    # for file, file_rename in renames:\r\n    #     os.rename(file, file_rename)"
  },
  {
    "path": "tgb/datasets/tkgl_polecat/tkgl_polecat_ns_gen.py",
    "content": "import time\r\nfrom tgb.linkproppred.tkg_negative_generator import TKGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = -1 \r\n    neg_sample_strategy = \"time-filtered\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"tkgl-polecat\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    # Ensure to only sample actual destination nodes as negatives.\r\n    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n\r\n    neg_sampler = TKGNegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_dst_id=min_dst_idx,\r\n        last_dst_id=max_dst_idx,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        edge_data=data,\r\n    )\r\n\r\n    # generate evaluation set\r\n    partial_path = \".\"\r\n    # generate validation negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. 
Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/tkgl_smallpedia/smallpedia_remove_conflict.py",
    "content": "import csv\r\n\r\ndef load_static_edgelist(file_path):\r\n    r\"\"\"\r\n    Load the static edgelist from the file_path\r\n    Args:\r\n        file_path: str, The path to the file\r\n    \"\"\"\r\n    static_dict = {}\r\n    first_row = True\r\n    with open(file_path, 'r') as f:\r\n        reader = csv.reader(f, delimiter =',')\r\n        for row in reader: \r\n            if first_row:\r\n                first_row = False\r\n                continue\r\n            head = row[0]\r\n            tail = row[1]\r\n            relation_type = row[2]\r\n            static_dict[(head, tail, relation_type)] = 1\r\n    return static_dict\r\n\r\n\r\ndef load_temporal_edgelist(file_path):\r\n    r\"\"\"\r\n    Load the temporal edgelist from the file_path\r\n    Args:\r\n        file_path: str, The path to the file\r\n    \"\"\"\r\n    temporal_dict = {}\r\n    first_row = True\r\n    with open(file_path, 'r') as f:\r\n        \"\"\"\r\n        ts,head,tail,relation_type\r\n        0,Q331755,Q1294765,P39\r\n        \"\"\"\r\n        reader = csv.reader(f, delimiter =',')\r\n        for row in reader: \r\n            if first_row:\r\n                first_row = False\r\n                continue\r\n            ts = int(row[0])\r\n            head = row[1]\r\n            tail = row[2]\r\n            relation_type = row[3]\r\n            if ts not in temporal_dict:\r\n                temporal_dict[ts] = {}\r\n                temporal_dict[ts][(head, tail, relation_type)] = 1\r\n            else:\r\n                if (head, tail, relation_type) in temporal_dict[ts]:\r\n                    temporal_dict[ts][(head, tail, relation_type)] += 1\r\n                else:\r\n                    temporal_dict[ts][(head, tail, relation_type)] = 1\r\n    return temporal_dict\r\n\r\n\r\ndef remove_conflict(static_dict, temporal_dict):\r\n    r\"\"\"\r\n    Remove the conflict between the static and temporal edgelist\r\n    Args:\r\n        static_dict: dict, The static 
edgelist\r\n        temporal_dict: dict, The temporal edgelist\r\n    \"\"\"\r\n    num_conflicts = 0\r\n    for ts in temporal_dict:\r\n        for edge in temporal_dict[ts]:\r\n            head = edge[0]\r\n            tail = edge[1]\r\n            relation_type = edge[2]\r\n            if (head, tail, relation_type) in static_dict:\r\n                num_conflicts += 1\r\n                static_dict.pop((head, tail, relation_type))\r\n    print(\"Removed {} conflicts\".format(num_conflicts))\r\n    return static_dict\r\n\r\n\r\ndef write2csv(outname: str, \r\n              out_dict: dict,):\r\n    r\"\"\"\r\n    Write the dictionary to a csv file\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        #head,tail,relation_type\r\n        writer.writerow(['head', 'tail', 'relation_type'])\r\n        for edge in out_dict:\r\n            head = edge[0]\r\n            tail = edge[1]\r\n            relation_type = edge[2]\r\n            row = [head, tail, relation_type]\r\n            writer.writerow(row)\r\n\r\n\r\n\r\n\r\ndef main():\r\n    #! remove conflict: remove all edges with the same head, tail, relation_type from the static edgelist\r\n    static_file = \"tkgl-smallpedia_static_edgelist.csv\"\r\n    temporal_file = \"tkgl-smallpedia_edgelist.csv\"\r\n    static_dict = load_static_edgelist(static_file)\r\n    print(\"constructed static dictionary\")\r\n    temporal_dict = load_temporal_edgelist(temporal_file)\r\n    print(\"constructed temporal dictionary\")\r\n    static_dict = remove_conflict(static_dict, temporal_dict)\r\n    out_name = \"tkgl-smallpedia_static_edgelist_no_conflict.csv\"\r\n    write2csv(out_name, static_dict)\r\n\r\n\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()"
  },
  {
    "path": "tgb/datasets/tkgl_smallpedia/tkgl_smallpedia_ns_gen.py",
    "content": "import time\r\nfrom tgb.linkproppred.tkg_negative_generator import TKGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = -1 #10000\r\n    neg_sample_strategy = \"time-filtered\" #\"dst-time-filtered\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"tkgl-smallpedia\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    # Ensure to only sample actual destination nodes as negatives.\r\n    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n\r\n    neg_sampler = TKGNegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_dst_id=min_dst_idx,\r\n        last_dst_id=max_dst_idx,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        partial_path=\".\",\r\n        edge_data=data,\r\n    )\r\n\r\n    # generate evaluation set\r\n    partial_path = \".\"\r\n    # generate validation negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. 
Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/tkgl_wikidata/extract.sh",
    "content": "for chunk in 5; do\nnum_chunk=25\nwhile [ $chunk -le $num_chunk ]; do\n    cmd=\"tkgl_wikidata.py \\\n    --chunk ${chunk} \\\n    --num_chunks ${num_chunk} \\\n    \"\n    python $cmd\n    chunk=$(( $chunk + 1 ))\ndone\ndone"
  },
  {
    "path": "tgb/datasets/tkgl_wikidata/time_edges/tkgl-wikidata_extract.py",
    "content": "import csv\r\nimport datetime\r\nimport glob, os\r\n\r\n\r\ndef load_time_csv_raw(fname):\r\n    r\"\"\"\r\n    load from the raw data and retrieve, timestamp, head, tail, relation, time_rel\r\n    convert all dates into unix timestamps\r\n    \"\"\"\r\n    out_dict = {}\r\n    first_row = True\r\n    num_lines = 0\r\n    #? timestamp,head,tail,relation_type,time_rel_type\r\n    #* +1999-01-01T00:00:00Z,Q31,Q4916,P38,P580\r\n    error_ctr = 0\r\n    with open(fname, 'r') as f:\r\n        reader = csv.reader(f, delimiter =',')\r\n\r\n        for row in reader: \r\n            if first_row:\r\n                first_row = False\r\n                continue\r\n            date = row[0][0:11]\r\n            head = row[1]\r\n            tail = row[2]\r\n            relation_type = row[3]\r\n            time_rel = row[4]\r\n\r\n            if (len(date) == 0):\r\n                continue\r\n            \r\n            if (\"None\" in date or \"None\" in head or \"None\" in tail or \"None\" in relation_type):\r\n                continue\r\n            else:\r\n                TIME_FORMAT = \"%Y\"\r\n                #* only keep track of year in positive BC\r\n                if (date[0] == \"+\"):\r\n                    ts = int(date[1:5])\r\n                else:\r\n                    continue\r\n\r\n                #* no scifi for knowledge graphs \r\n                if (ts > 2024):\r\n                    continue\r\n\r\n                num_lines += 1\r\n                if (ts in out_dict):\r\n                    if (head, tail, relation_type, time_rel) in out_dict[ts]:\r\n                        out_dict[ts][(head, tail, relation_type, time_rel)] += 1\r\n                    else:\r\n                        out_dict[ts][(head, tail, relation_type, time_rel)] = 1\r\n                else:\r\n                    out_dict[ts] = {}\r\n                    out_dict[ts][(head, tail, relation_type, time_rel)] = 1\r\n    return out_dict, num_lines\r\n\r\n\r\ndef 
write2csv(outname: str, \r\n              out_dict: dict,):\r\n    r\"\"\"\r\n    Write the dictionary to a csv file\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['ts', 'head', 'tail', 'relation_type','time_rel_type'])\r\n        dates = list(out_dict.keys())\r\n        dates.sort()\r\n        for date in dates:\r\n            for edge in out_dict[date]:\r\n                head = edge[0]\r\n                tail = edge[1]\r\n                relation_type = edge[2]\r\n                time_rel = edge[3]\r\n                row = [date, head, tail, relation_type, time_rel]\r\n                writer.writerow(row)\r\n\r\n\r\ndef update_dict(total_dict, new_dict):\r\n    r\"\"\"\r\n    Update the total_dict with new_dict\r\n    \"\"\"\r\n    for key in new_dict:\r\n        if key in total_dict:\r\n            for edge in new_dict[key]:\r\n                if edge in total_dict[key]:\r\n                    total_dict[key][edge] += new_dict[key][edge]\r\n                else:\r\n                    total_dict[key][edge] = new_dict[key][edge]\r\n        else:\r\n            total_dict[key] = new_dict[key]\r\n    return total_dict\r\n\r\n\r\ndef retrieve_all_entities(total_dict):\r\n    r\"\"\"\r\n    retrieve the entities from all edges of the total dictionary\r\n\r\n    Parameters:\r\n        total_dict: dictionary of all edges, {ts: {edge: count}}\r\n    \"\"\"\r\n    node_dict = {}\r\n    for key in total_dict:\r\n        for edge in total_dict[key]:\r\n            head = edge[0]\r\n            tail = edge[1]\r\n            if head not in node_dict:\r\n                node_dict[head] = 1\r\n            else:\r\n                node_dict[head] += 1\r\n            if tail not in node_dict:\r\n                node_dict[tail] = 1\r\n            else:\r\n                node_dict[tail] += 1\r\n    return node_dict\r\n\r\n\r\ndef writenode2csv(outname: str, \r\n              out_dict: dict,):\r\n   
 r\"\"\"\r\n    Write the dictionary to a csv file\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['entity', 'occurrences'])\r\n        for node in out_dict:\r\n            row = [node, out_dict[node]]\r\n            writer.writerow(row) \r\n\r\n\r\ndef main():\r\n\r\n    #! when timestamps overlap can't update dictionary\r\n\r\n    total_lines = 0\r\n    total_edge_dict = {}\r\n\r\n    #1. find all files with .txt in the folder\r\n    total_edge_file = \"tkgl-wikidata_edgelist.csv\"\r\n    for file in glob.glob(\"*.csv\"):\r\n        print (file)\r\n        edge_dict, num_lines = load_time_csv_raw(file)\r\n        print (\"processed \", num_lines, \" lines\")\r\n        total_lines += num_lines\r\n        update_dict(total_edge_dict, edge_dict)\r\n    print (\"processed a total of \", total_lines, \" lines\")\r\n    node_dict = retrieve_all_entities(total_edge_dict)\r\n    writenode2csv(\"wiki_entities.csv\", node_dict)\r\n    write2csv(total_edge_file, total_edge_dict)\r\n\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()"
  },
  {
    "path": "tgb/datasets/tkgl_wikidata/tkgl-wikidata.py",
    "content": "import csv\r\nimport datetime\r\nimport glob, os\r\n\r\n\r\ndef load_time_csv(fname):\r\n    r\"\"\"\r\n    load from data and retrieve, ts,head,tail,relation_type,time_rel_type\r\n    \"\"\"\r\n    out_dict = {} #only contain edges {ts: {(head, tail, rel_type):count}}\r\n    start_end_dict = {} #{(head, tail, rel_type): {start:year, end:year}}\r\n\r\n    first_row = True\r\n    point_in_time_lines = 0\r\n    start_end_lines = 0\r\n\r\n    #? ts,head,tail,relation_type,time_rel_type\r\n    #* 0,Q331755,Q1294765,P39,P580\r\n    with open(fname, 'r') as f:\r\n        reader = csv.reader(f, delimiter =',')\r\n        for row in reader: \r\n            if first_row:\r\n                first_row = False\r\n                continue\r\n            ts = int(row[0])\r\n            head = row[1]\r\n            tail = row[2]\r\n            relation_type = row[3]\r\n            time_rel = row[4]\r\n            if (time_rel in ['P585', 'P577', 'P574']):\r\n                if (ts in out_dict):\r\n                    if (head, tail, relation_type) in out_dict[ts]:\r\n                        out_dict[ts][(head, tail, relation_type)] += 1\r\n                    else:\r\n                        out_dict[ts][(head, tail, relation_type)] = 1\r\n                else:\r\n                    out_dict[ts] = {}\r\n                    out_dict[ts][(head, tail, relation_type)] = 1\r\n                point_in_time_lines += 1\r\n            else: # time_rel in ['P580', 'P582']\r\n                if (head, tail, relation_type) in start_end_dict:\r\n                    if (time_rel in ['P580']):\r\n                        start_end_dict[(head, tail, relation_type)]['start'] = ts\r\n                    elif (time_rel in ['P582']):\r\n                        start_end_dict[(head, tail, relation_type)]['end'] = ts\r\n                    else:\r\n                        raise ValueError(f\"Unknown time_rel: {time_rel}\")\r\n                else:\r\n                    
start_end_dict[(head, tail, relation_type)] = {}\r\n                    if (time_rel in ['P580']):\r\n                        start_end_dict[(head, tail, relation_type)]['start'] = ts\r\n                    else:\r\n                        start_end_dict[(head, tail, relation_type)]['end'] = ts\r\n                start_end_lines += 1\r\n\r\n    print (\"-----------------------------------\")\r\n    print (\"for this edgelist:\")\r\n    print (f\"point_in_time_lines: {point_in_time_lines}\")\r\n    print (f\"start_end_lines: {start_end_lines}\")\r\n    print (\"-----------------------------------\")\r\n    \r\n    \r\n\r\n    repeated_lines = 0\r\n    no_duration_lines = 0\r\n    \r\n    #* now, repeat edges from start_end_dict\r\n    for edge in start_end_dict.keys():\r\n        if 'start' not in start_end_dict[edge]:\r\n            #start_end_dict[edge]['start'] = 0 #start at year 0\r\n            #start_end_dict[edge]['start'] = start_end_dict[edge]['end']\r\n            no_duration_lines += 1\r\n            continue\r\n        if 'end' not in start_end_dict[edge]:\r\n            # start_end_dict[edge]['end'] = 2024 #end at year 2024\r\n            start_end_dict[edge]['end'] = start_end_dict[edge]['start'] #end at year 2024\r\n            no_duration_lines += 1\r\n            continue\r\n        for year in range(start_end_dict[edge]['start'], start_end_dict[edge]['end']+1):\r\n            if year not in out_dict:\r\n                out_dict[year] = {}\r\n            out_dict[year][edge] = 1\r\n            repeated_lines += 1\r\n\r\n    print (\"-----------------------------------\")\r\n    print (\"for this edgelist:\")\r\n    print (f\"point_in_time_lines: {point_in_time_lines}\")\r\n    print (f\"start_end_lines: {start_end_lines} resulting in\")\r\n    print (f\"repeated_lines: {repeated_lines}\")\r\n    print (f\"no_duration_lines: {no_duration_lines}\")\r\n    print (\"-----------------------------------\")\r\n    print (\"total lines: \", 
point_in_time_lines + repeated_lines)\r\n    num_lines = point_in_time_lines + repeated_lines\r\n    return out_dict, num_lines\r\n\r\n\r\n\r\ndef write2csv(outname: str, \r\n              out_dict: dict,):\r\n    r\"\"\"\r\n    Write the dictionary to a csv file\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['ts', 'head', 'tail', 'relation_type'])\r\n        dates = list(out_dict.keys())\r\n        dates.sort()\r\n        for date in dates:\r\n            for edge in out_dict[date]:\r\n                head = edge[0]\r\n                tail = edge[1]\r\n                relation_type = edge[2]\r\n                row = [date, head, tail, relation_type]\r\n                writer.writerow(row)\r\n\r\n\r\ndef extract_subset(fname, outname, start_year=2000, end_year=2024):\r\n    node_dict = {}\r\n    first_row = True\r\n    rel_type = {}\r\n    r\"\"\"\r\n    ts,head,tail,relation_type\r\n    0,Q331755,Q1294765,P39\r\n    0,Q116233388,Q2566630,P2348\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['ts', 'head', 'tail', 'relation_type'])\r\n        with open(fname, 'r') as f:\r\n            reader = csv.reader(f, delimiter =',')\r\n            for row in reader: \r\n                if first_row:\r\n                    first_row = False\r\n                    continue\r\n                ts = int(row[0])\r\n                head = row[1]\r\n                tail = row[2]\r\n                relation_type = row[3]\r\n                if (ts >= start_year and ts <= end_year):\r\n                    if head not in node_dict:\r\n                        node_dict[head] = 1\r\n                    if tail not in node_dict:\r\n                        node_dict[tail] = 1\r\n                    row = [ts, head, tail, relation_type]\r\n                    if (relation_type not in rel_type):\r\n                        
rel_type[relation_type] = 1\r\n                    writer.writerow(row)\r\n    print (\"there are \",len(rel_type), \" relation types\")\r\n    return node_dict\r\n\r\n\r\ndef extract_subset_nodeid(fname, outname, start_year=2000, end_year=2024, max_id=1000000):\r\n    node_dict = {}\r\n    first_row = True\r\n    rel_type = {}\r\n    r\"\"\"\r\n    ts,head,tail,relation_type\r\n    0,Q331755,Q1294765,P39\r\n    0,Q116233388,Q2566630,P2348\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['ts', 'head', 'tail', 'relation_type'])\r\n        with open(fname, 'r') as f:\r\n            reader = csv.reader(f, delimiter =',')\r\n            for row in reader: \r\n                if first_row:\r\n                    first_row = False\r\n                    continue\r\n                ts = int(row[0])\r\n                head = row[1]\r\n                head_id = int(head[1:])\r\n                tail = row[2]\r\n                tail_id = int(tail[1:])\r\n                if (head_id > max_id or tail_id > max_id):\r\n                    continue\r\n                relation_type = row[3]\r\n                if (ts >= start_year and ts <= end_year):\r\n                    if head not in node_dict:\r\n                        node_dict[head] = 1\r\n                    if tail not in node_dict:\r\n                        node_dict[tail] = 1\r\n                    row = [ts, head, tail, relation_type]\r\n                    if (relation_type not in rel_type):\r\n                        rel_type[relation_type] = 1\r\n                    writer.writerow(row)\r\n    print (\"there are \",len(rel_type), \" relation types\")\r\n    return node_dict\r\n\r\n\r\n\r\n\r\ndef extract_static_subset(fname, outname, node_dict, max_id=1000000):\r\n    r\"\"\"\r\n    extract static edges based a given node dict\r\n    \"\"\"\r\n    first_row = True\r\n    r\"\"\"\r\n    head,tail,relation_type\r\n    
Q31,Q1088364,P1344\r\n    Q31,Q3247091,P1151\r\n    \"\"\"\r\n    rel_type = {}\r\n    full_node = {}\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['head', 'tail', 'relation_type'])\r\n        with open(fname, 'r') as f:\r\n            reader = csv.reader(f, delimiter =',')\r\n            for row in reader: \r\n                if first_row:\r\n                    first_row = False\r\n                    continue\r\n                head = row[0]\r\n                head_id = int(head[1:])\r\n                tail = row[1]\r\n                tail_id = int(tail[1:])\r\n                relation_type = row[2]\r\n                if (head_id > max_id or tail_id > max_id):\r\n                    continue\r\n                if (head in node_dict) or (tail in node_dict): #need to check\r\n                    row = [head, tail, relation_type]\r\n                    writer.writerow(row)\r\n                    if (relation_type not in rel_type):\r\n                        rel_type[relation_type] = 1\r\n                    else:\r\n                        rel_type[relation_type] += 1\r\n                    if (head not in full_node):\r\n                        full_node[head] = 1\r\n                    if (tail not in full_node):\r\n                        full_node[tail] = 1\r\n    print (\"there are \",len(rel_type), \" relation types\")\r\n    print (\"there are \",len(full_node), \" nodes in static edgelist\")\r\n    return rel_type\r\n\r\n\r\n\r\n#! 
not used, filter by top edgetypes \r\ndef subset_static_edges(fname, outname, rel_type, topk=10):\r\n    #* select edges based on frequency\r\n    import operator\r\n    sorted_x = sorted(rel_type.items(), key=operator.itemgetter(1))\r\n    sorted_x = sorted_x[-topk:]\r\n    rel_kept = {}\r\n    for (u,v) in sorted_x:\r\n        rel_kept[u] = 1\r\n        print (u,v)\r\n\r\n    kept_nodes = {}\r\n    first_row = True\r\n\r\n    # rel_kept = {\"P17\":1, \"P27\":1, \"P495\":1, \"P19\": 1}\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['head', 'tail', 'relation_type'])\r\n        with open(fname, 'r') as f:\r\n            reader = csv.reader(f, delimiter =',')\r\n            for row in reader: \r\n                if first_row:\r\n                    first_row = False\r\n                    continue\r\n                head = row[0]\r\n                tail = row[1]\r\n                relation_type = row[2]\r\n                if (relation_type in rel_kept):\r\n                    row = [head, tail, relation_type]\r\n                    if (head not in kept_nodes):\r\n                        kept_nodes[head] = 1\r\n                    if (tail not in kept_nodes):\r\n                        kept_nodes[tail] = 1\r\n                    writer.writerow(row)\r\n    print (\"there are \",len(kept_nodes), \" nodes in static edgelist\")\r\n                \r\n\r\n\r\n\r\n\r\n\r\n\r\ndef main():\r\n\r\n    # #* repeat the edges of start and end dates\r\n    # \"\"\"\r\n    # P580: start time\r\n    # P582: end time\r\n    # P585: point in time\r\n    # P577: publication date\r\n    # P574: year of publication of scientific name for taxon\r\n\r\n    # we need to:\r\n    # 1. get all edges with P585, P577 and P574 \r\n    # 2. find out which edges has both start and end time\r\n    # 3. 
for those without start time, start at year 0, without end time, end at year 2024\r\n    # \"\"\"\r\n    # fname = \"tkgl-wikidata_edgelist_raw.csv\"\r\n    # out_dict, num_lines = load_time_csv(fname)\r\n\r\n    # outname = \"tkgl-wikidata_edgelist.csv\"\r\n    # write2csv(outname, out_dict)\r\n\r\n    inputfile = \"tkgl-wikidata_edgelist.csv\"\r\n    outname = \"tkgl-smallpedia_edgelist.csv\"\r\n    # start_year = 2015\r\n    start_year=1900#1700\r\n    end_year=2024#1800\r\n    max_id=1000000\r\n    # node_dict = extract_subset(inputfile, outname, start_year=start_year, end_year=end_year)\r\n    node_dict = extract_subset_nodeid(inputfile, outname, start_year=start_year, end_year=end_year, max_id=max_id)\r\n    print (\"there are \",len(node_dict), \" nodes\")\r\n\r\n    inputfile = \"tkgl-wikidata_static_edgelist.csv\"\r\n    outname = \"tkgl-smallpedia_static_edgelist.csv\"\r\n    rel_type = extract_static_subset(inputfile, outname, node_dict, max_id=max_id)\r\n\r\n    #! not used\r\n    # inputfile = \"tkgl-smallpedia_static_edgelist.csv\"\r\n    # outname = \"tkgl-smallpedia_static_edgelist_top10.csv\"\r\n    # topk=10\r\n    # subset_static_edges(inputfile, outname, rel_type, topk=topk)\r\n    \r\n\r\n\r\n\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()"
  },
  {
    "path": "tgb/datasets/tkgl_wikidata/tkgl_wikidata_mining.py",
    "content": "r\"\"\"\r\nHow to use\r\npython tkgl_wikidata.py --chunk 0 --num_chunks 25\r\n# python tkgl_wikidata.py --chunk 1 --num_chunks 25\r\n\"\"\"\r\n\r\nfrom qwikidata.entity import WikidataItem\r\nfrom qwikidata.json_dump import WikidataJsonDump\r\nfrom qwikidata.datavalue import get_datavalue_from_snak_dict, WikibaseEntityId\r\nfrom tqdm import tqdm\r\nfrom collections import defaultdict\r\nimport os.path as osp\r\nimport os\r\nimport pickle\r\nimport argparse\r\nimport numpy as np\r\nimport csv\r\n\r\ndef timeEdgeWrite2csv(outname, out_dict):\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['timestamp', 'head', 'tail', 'relation_type', 'time_rel_type'])\r\n        for edge in out_dict.keys():\r\n            ts = edge[0]\r\n            src = edge[1]\r\n            dst = edge[2]\r\n            rel_type = edge[3]\r\n            time_rel_type = edge[4]\r\n            row = [ts, src, dst, rel_type, time_rel_type]\r\n            writer.writerow(row)\r\n\r\n\r\ndef EdgeWrite2csv(outname, out_dict):\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['head', 'tail', 'relation_type'])\r\n        for edge in out_dict.keys():\r\n            src = edge[0]\r\n            dst = edge[1]\r\n            rel_type = edge[2]\r\n            row = [src, dst, rel_type]\r\n            writer.writerow(row)\r\n\r\n\r\ndef main():\r\n    parser = argparse.ArgumentParser(description='Process some integers.')\r\n    # parser.add_argument('--split', type=str, default = 'train',\r\n    #                     help='an integer for the accumulator')\r\n    parser.add_argument('--chunk', type=int, default = 0,\r\n                        help='an integer for the accumulator')\r\n    parser.add_argument('--num_chunks', type=int, default = 10,\r\n                        help='an integer for the accumulator')\r\n\r\n    args = parser.parse_args()\r\n    
print(args)\r\n    assert args.chunk < args.num_chunks\r\n\r\n    # # create an instance of WikidataJsonDump\r\n    # if args.split == 'train':\r\n    #     wjd_dump_path_original = \"wikidata-20210517-all.json.gz\"\r\n    # elif args.split == 'val':\r\n    #     wjd_dump_path_original = \"wikidata-20210607-all.json.gz\"\r\n    # elif args.split == 'test':\r\n    #     wjd_dump_path_original = 'wikidata-20210628-all.json.gz'\r\n    # else:\r\n    #     raise ValueError('Unknown split')\r\n\r\n    #* download here\r\n    #? https://dumps.wikimedia.org/wikidatawiki/entities/\r\n\r\n    wjd_dump_path_original = \"wikidata-20240220-all.json.gz\" #\"latest-all_03_Apr_2024_12_49.json\" \r\n\r\n    wjd_dump_path = osp.join('dump', wjd_dump_path_original)\r\n    wjd = WikidataJsonDump(wjd_dump_path)\r\n\r\n    print(wjd_dump_path)\r\n\r\n    \r\n\r\n    \"\"\"\r\n    # head = entity_dict['id']\r\n    # type = entity_dict['type']\r\n    # labels = entity_dict['labels']\r\n    # descriptions = entity_dict['descriptions']\r\n    # aliases = entity_dict['aliases']\r\n    # if ('claims' in entity_dict):\r\n    #     claims = entity_dict['claims']\r\n    #     for key in claims.keys():\r\n    #         print (key)\r\n    #         print (claims[key])\r\n    # sitelinks = entity_dict['sitelinks']\r\n    \"\"\"\r\n\r\n    time_edge_dict = {} #{()}\r\n    time_rel_dict = {}\r\n    static_edge_dict = {}\r\n    dummy_rel_set = ['P31','P279']  #filter out instance of and subclass of\r\n    time_rel_set = ['P585','P580', 'P582', 'P577', 'P574']  #point in time, start time, end time, publication date,year of publication of scientific name for taxon\r\n\r\n    num_totals = 100000000 #4000000 #10000000 #110000000\r\n\r\n    tmp = np.linspace(0, num_totals, args.num_chunks + 1).astype(np.int64)\r\n    start_idx = tmp[args.chunk]\r\n    end_idx = tmp[args.chunk + 1]\r\n    print('Start: ', start_idx)\r\n    print('End: ', end_idx)\r\n\r\n\r\n    #? 
output format is (timestamp, head, tail, relation_type, time_rel_type)\r\n    for i, entity_dict in enumerate(tqdm(wjd, total=(end_idx))):\r\n        #! entity_dict keys(['type', 'id', 'labels', 'descriptions', 'aliases', 'claims', 'sitelinks', 'pageid', 'ns', 'title', 'lastrevid', 'modified'])\r\n        if i > end_idx:\r\n            break\r\n\r\n        if not (start_idx <= i and i < end_idx):\r\n            continue\r\n\r\n        head = entity_dict['id']\r\n\r\n        # head needs to start from 'Q'\r\n        if head[0] == 'Q':\r\n            head_id = head\r\n            if 'claims' in entity_dict:\r\n                claim_dict = entity_dict['claims']\r\n                rel_list = list(claim_dict.keys())\r\n                for rel in rel_list:\r\n                    if (rel in dummy_rel_set):\r\n                        continue\r\n                    tail_list = claim_dict[rel]\r\n                    for tail in tail_list:\r\n                        tail_id = None\r\n                        #* first check if there is a valid tail\r\n                        if (tail['mainsnak']['datatype'] == 'wikibase-item'):\r\n                            if ('rank' in tail) and (tail['rank'] != 'deprecated') and ('datavalue' in tail['mainsnak']):\r\n                                if 'id' in tail['mainsnak']['datavalue']['value']:\r\n                                    tail_id = tail['mainsnak']['datavalue']['value']['id']\r\n                                else: \r\n                                    tail_id = 'Q' + str(tail['mainsnak']['datavalue']['value']['numeric-id'])\r\n\r\n                        #* check if there is a qualifier and if it is a time qualifier\r\n                        if (tail_id is not None):\r\n                            if (\"qualifiers\" in tail):\r\n                                time_logged = False\r\n                                for q in tail[\"qualifiers\"]:\r\n                                    for item in 
tail[\"qualifiers\"][q]:\r\n                                        if (item['datatype'] == 'time') and ('datavalue' in item):\r\n                                            timestr = item['datavalue']['value']['time']\r\n                                            time_rel_type = q\r\n                                            if (time_rel_type in time_rel_set):\r\n                                                time_edge_dict[(timestr, head_id, tail_id, rel, time_rel_type)] = 1\r\n                                                time_logged = True\r\n                                            else:\r\n                                                time_logged = False\r\n                                if not time_logged:\r\n                                    static_edge_dict[(head_id, tail_id, rel)] = 1\r\n                            else:\r\n                                static_edge_dict[(head_id, tail_id, rel)] = 1\r\n\r\n    #! write edges to file\r\n    print (\"there are \", len(time_edge_dict), \" temporal edges in the dataset\")\r\n    outname = \"time_edgelist_\" + str(args.chunk) + \".csv\" #\"tkgl-wikidata_time_edgelist.csv\"\r\n    timeEdgeWrite2csv(outname, time_edge_dict)\r\n\r\n\r\n    print (\"there are \", len(static_edge_dict), \" static edges in the dataset\")\r\n    outname = \"static_edgelist_\" + str(args.chunk) + \".csv\" #\"tkgl-wikidata_static_edgelist.csv\"\r\n    EdgeWrite2csv(outname, static_edge_dict)\r\n    \r\n\r\n                \r\n                    \r\nif __name__ == '__main__':\r\n\r\n    main()"
  },
  {
    "path": "tgb/datasets/tkgl_wikidata/tkgl_wikidata_ns_gen.py",
    "content": "import time\r\nimport sys\r\nimport os\r\nimport os.path as osp\r\nfrom pathlib import Path\r\nmodules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))\r\nsys.path.append(modules_path)\r\nfrom tgb.linkproppred.tkg_negative_generator import TKGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = 1000 #10000\r\n    neg_sample_strategy = \"dst-time-filtered\" #\"time-filtered\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"tkgl-wikidata\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    # Ensure to only sample actual destination nodes as negatives.\r\n    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n\r\n    neg_sampler = TKGNegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_dst_id=min_dst_idx,\r\n        last_dst_id=max_dst_idx,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        partial_path=\".\",\r\n        edge_data=data,\r\n    )\r\n\r\n    # generate evaluation set\r\n    partial_path = \".\"\r\n    # generate validation negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        
pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/datasets/tkgl_wikidata/wikidata_remove_conflict.py",
    "content": "import csv\r\n\r\ndef load_static_edgelist(file_path):\r\n    r\"\"\"\r\n    Load the static edgelist from the file_path\r\n    Args:\r\n        file_path: str, The path to the file\r\n    \"\"\"\r\n    static_dict = {}\r\n    first_row = True\r\n    with open(file_path, 'r') as f:\r\n        reader = csv.reader(f, delimiter =',')\r\n        for row in reader: \r\n            if first_row:\r\n                first_row = False\r\n                continue\r\n            head = row[0]\r\n            tail = row[1]\r\n            relation_type = row[2]\r\n            static_dict[(head, tail, relation_type)] = 1\r\n    return static_dict\r\n\r\n\r\ndef load_temporal_edgelist(file_path):\r\n    r\"\"\"\r\n    Load the temporal edgelist from the file_path\r\n    Args:\r\n        file_path: str, The path to the file\r\n    \"\"\"\r\n    temporal_dict = {}\r\n    first_row = True\r\n    with open(file_path, 'r') as f:\r\n        \"\"\"\r\n        ts,head,tail,relation_type\r\n        0,Q331755,Q1294765,P39\r\n        \"\"\"\r\n        reader = csv.reader(f, delimiter =',')\r\n        for row in reader: \r\n            if first_row:\r\n                first_row = False\r\n                continue\r\n            ts = int(row[0])\r\n            head = row[1]\r\n            tail = row[2]\r\n            relation_type = row[3]\r\n            if ts not in temporal_dict:\r\n                temporal_dict[ts] = {}\r\n                temporal_dict[ts][(head, tail, relation_type)] = 1\r\n            else:\r\n                if (head, tail, relation_type) in temporal_dict[ts]:\r\n                    temporal_dict[ts][(head, tail, relation_type)] += 1\r\n                else:\r\n                    temporal_dict[ts][(head, tail, relation_type)] = 1\r\n    return temporal_dict\r\n\r\n\r\ndef remove_conflict(static_dict, temporal_dict):\r\n    r\"\"\"\r\n    Remove the conflict between the static and temporal edgelist\r\n    Args:\r\n        static_dict: dict, The static 
edgelist\r\n        temporal_dict: dict, The temporal edgelist\r\n    \"\"\"\r\n    num_conflicts = 0\r\n    for ts in temporal_dict:\r\n        for edge in temporal_dict[ts]:\r\n            head = edge[0]\r\n            tail = edge[1]\r\n            relation_type = edge[2]\r\n            if (head, tail, relation_type) in static_dict:\r\n                num_conflicts += 1\r\n                static_dict.pop((head, tail, relation_type))\r\n    print(\"Removed {} conflicts\".format(num_conflicts))\r\n    return static_dict\r\n\r\n\r\ndef write2csv(outname: str, \r\n              out_dict: dict,):\r\n    r\"\"\"\r\n    Write the dictionary to a csv file\r\n    \"\"\"\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        #head,tail,relation_type\r\n        writer.writerow(['head', 'tail', 'relation_type'])\r\n        for edge in out_dict:\r\n            head = edge[0]\r\n            tail = edge[1]\r\n            relation_type = edge[2]\r\n            row = [head, tail, relation_type]\r\n            writer.writerow(row)\r\n\r\n\r\n\r\n\r\ndef main():\r\n    #! remove conflict: remove all edges with the same head, tail, relation_type from the static edgelist\r\n    static_file = \"tkgl-wikidata_static_edgelist.csv\"\r\n    temporal_file = \"tkgl-wikidata_edgelist.csv\"\r\n    static_dict = load_static_edgelist(static_file)\r\n    print(\"constructed static dictionary\")\r\n    temporal_dict = load_temporal_edgelist(temporal_file)\r\n    print(\"constructed temporal dictionary\")\r\n    static_dict = remove_conflict(static_dict, temporal_dict)\r\n    out_name = \"tkgl-wikidata_static_edgelist_no_conflict.csv\"\r\n    write2csv(out_name, static_dict)\r\n\r\n\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()"
  },
  {
    "path": "tgb/datasets/tkgl_yago/tkgl_yago.py",
    "content": "import csv\r\nimport datetime\r\nimport glob, os\r\n\r\n\r\n\r\ndef main():\r\n    train_fname = \"train.txt\"\r\n    val_fname = \"valid.txt\"\r\n    test_fname = \"test.txt\"\r\n    \r\n    train_dict, num_lines = load_csv(train_fname)\r\n    print (\"there are \", num_lines, \" lines in the train file\")\r\n    print (\"there are \", len(train_dict), \" timestamps in the train file\")\r\n    val_dict, num_lines = load_csv(val_fname)\r\n    print (\"there are \", num_lines, \" lines in the val file\")\r\n    print (\"there are \", len(val_dict), \" timestamps in the val file\")\r\n    test_dict, num_lines = load_csv(test_fname)\r\n    print (\"there are \", num_lines, \" lines in the test file\")\r\n    print (\"there are \", len(test_dict), \" timestamps in the test file\")\r\n\r\n    train_dict.update(val_dict)\r\n    train_dict.update(test_dict)\r\n    print (\"there are \", len(train_dict), \" timestamps in the combined file\")\r\n\r\n    outname = \"tkgl-yago_edgelist.csv\"\r\n    write_csv(outname, train_dict)\r\n\r\n\r\ndef write_csv(outname, out_dict):\r\n    with open(outname, 'w') as f:\r\n        writer = csv.writer(f, delimiter =',')\r\n        writer.writerow(['timestamp', 'head', 'tail', 'relation_type'])\r\n        for ts in out_dict:\r\n            for edge in out_dict[ts]:\r\n                src = edge[0]\r\n                rel_type = edge[1]\r\n                dst = edge[2]\r\n                row = [ts, src, dst, rel_type]\r\n                writer.writerow(row)\r\n\r\n\r\ndef load_csv(fname):\r\n    out_dict = {}\r\n    num_lines = 0\r\n    with open(fname, 'r') as f:\r\n        reader = csv.reader(f, delimiter ='\\t')\r\n        #! 
src rel_type dst ts \r\n        # 10289\t9\t10290\t0\t0\r\n        for row in reader: \r\n            src = int (row[0])\r\n            rel_type = int (row[1])\r\n            dst = int (row[2])\r\n            ts = int (row[3])\r\n            if ts not in out_dict:\r\n                out_dict[ts] = {(src,rel_type,dst):1}\r\n            else:\r\n                out_dict[ts][(src,rel_type,dst)] = 1\r\n            num_lines += 1\r\n    return out_dict, num_lines\r\n\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()"
  },
  {
    "path": "tgb/datasets/tkgl_yago/tkgl_yago_ns_gen.py",
    "content": "import time\r\nimport sys\r\nsys.path.insert(0,'/../../../')\r\nfrom tgb.linkproppred.tkg_negative_generator import TKGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\ndef main():\r\n    r\"\"\"\r\n    Generate negative edges for the validation or test phase\r\n    \"\"\"\r\n    print(\"*** Negative Sample Generation ***\")\r\n\r\n    # setting the required parameters\r\n    num_neg_e_per_pos = -1 \r\n    neg_sample_strategy = \"time-filtered\"\r\n    rnd_seed = 42\r\n\r\n\r\n    name = \"tkgl-yago\"\r\n    dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\r\n    train_mask = dataset.train_mask\r\n    val_mask = dataset.val_mask\r\n    test_mask = dataset.test_mask\r\n    data = dataset.get_TemporalData()\r\n\r\n\r\n\r\n    data_splits = {}\r\n    data_splits['train'] = data[train_mask]\r\n    data_splits['val'] = data[val_mask]\r\n    data_splits['test'] = data[test_mask]\r\n\r\n    # Ensure to only sample actual destination nodes as negatives.\r\n    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())\r\n\r\n\r\n    neg_sampler = TKGNegativeEdgeGenerator(\r\n        dataset_name=name,\r\n        first_dst_id=min_dst_idx,\r\n        last_dst_id=max_dst_idx,\r\n        num_neg_e=num_neg_e_per_pos,\r\n        strategy=neg_sample_strategy,\r\n        rnd_seed=rnd_seed,\r\n        edge_data=data,\r\n    )\r\n\r\n    # generate evaluation set\r\n    partial_path = \".\"\r\n    # generate validation negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"val\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. 
Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n    # generate test negative edge set\r\n    start_time = time.time()\r\n    split_mode = \"test\"\r\n    print(\r\n        f\"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}\"\r\n    )\r\n    neg_sampler.generate_negative_samples(\r\n        pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path\r\n    )\r\n    print(\r\n        f\"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}\"\r\n    )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/linkproppred/dataset.py",
    "content": "import sys\r\n\r\nfrom typing import Optional, Dict, Any, Tuple\r\nimport os\r\nimport os.path as osp\r\nimport numpy as np\r\nimport pandas as pd\r\nimport zipfile\r\nimport requests\r\nfrom clint.textui import progress\r\n\r\n\r\nfrom tgb.linkproppred.negative_sampler import NegativeEdgeSampler\r\nfrom tgb.linkproppred.tkg_negative_sampler import TKGNegativeEdgeSampler\r\nfrom tgb.linkproppred.thg_negative_sampler import THGNegativeEdgeSampler\r\nfrom tgb.utils.info import (\r\n    PROJ_DIR, \r\n    DATA_URL_DICT, \r\n    DATA_VERSION_DICT, \r\n    DATA_EVAL_METRIC_DICT, \r\n    DATA_NS_STRATEGY_DICT,\r\n    BColors\r\n)\r\nfrom tgb.utils.pre_process import (\r\n    csv_to_pd_data,\r\n    process_node_feat,\r\n    process_node_type,\r\n    csv_to_pd_data_sc,\r\n    csv_to_pd_data_rc,\r\n    load_edgelist_wiki,\r\n    csv_to_tkg_data,\r\n    csv_to_thg_data,\r\n    csv_to_forum_data,\r\n    csv_to_wikidata,\r\n    csv_to_staticdata,\r\n)\r\nfrom tgb.utils.utils import save_pkl, load_pkl\r\nfrom tgb.utils.utils import add_inverse_quadruples, vprint\r\n\r\n\r\n\r\nclass LinkPropPredDataset(object):\r\n    def __init__(\r\n        self,\r\n        name: str,\r\n        root: str = \"datasets\",\r\n        meta_dict: Optional[dict] = None,\r\n        preprocess: Optional[bool] = True,\r\n        download: Optional[bool] = True, \r\n    ):\r\n        r\"\"\"Dataset class for link prediction dataset. 
Stores meta information about each dataset such as evaluation metrics etc.\r\n        also automatically pre-processes the dataset.\r\n        Args:\r\n            name: name of the dataset\r\n            root: root directory to store the dataset folder\r\n            meta_dict: dictionary containing meta information about the dataset, should contain key 'dir_name' which is the name of the dataset folder\r\n            preprocess: whether to pre-process the dataset\r\n            download: whether to download the dataset (default: true)\r\n        \"\"\"\r\n        self.name = name  ## original name\r\n        # check if dataset url exist\r\n        if self.name in DATA_URL_DICT:\r\n            self.url = DATA_URL_DICT[self.name]\r\n        else:\r\n            self.url = None\r\n\r\n        \r\n        # check if the evaluatioin metric are specified\r\n        if self.name in DATA_EVAL_METRIC_DICT:\r\n            self.metric = DATA_EVAL_METRIC_DICT[self.name]\r\n        else:\r\n            self.metric = None\r\n            raise ValueError(f\"Dataset {self.name} default evaluation metric not found, it is not supported yet.\")\r\n        \r\n        root = PROJ_DIR + root\r\n\r\n        if meta_dict is None:\r\n            self.dir_name = \"_\".join(name.split(\"-\"))  ## replace hyphen with underline\r\n            meta_dict = {\"dir_name\": self.dir_name}\r\n        else:\r\n            self.dir_name = meta_dict[\"dir_name\"]\r\n        self.root = osp.join(root, self.dir_name)\r\n        self.meta_dict = meta_dict\r\n        if \"fname\" not in self.meta_dict:\r\n            self.meta_dict[\"fname\"] = self.root + \"/\" + self.name + \"_edgelist.csv\"\r\n            self.meta_dict[\"nodefile\"] = None\r\n\r\n        if name == \"tgbl-flight\":\r\n            self.meta_dict[\"nodefile\"] = self.root + \"/\" + \"airport_node_feat.csv\"\r\n\r\n        if name == \"tkgl-wikidata\" or name == \"tkgl-smallpedia\":\r\n            self.meta_dict[\"staticfile\"] = 
self.root + \"/\" + self.name + \"_static_edgelist.csv\"\r\n        \r\n        if \"thg\" in name:\r\n            self.meta_dict[\"nodeTypeFile\"] = self.root + \"/\" + self.name + \"_nodetype.csv\"\r\n        else:\r\n            self.meta_dict[\"nodeTypeFile\"] = None\r\n        \r\n        self.meta_dict[\"val_ns\"] = self.root + \"/\" + self.name + \"_val_ns.pkl\"\r\n        self.meta_dict[\"test_ns\"] = self.root + \"/\" + self.name + \"_test_ns.pkl\"\r\n\r\n        #! version check\r\n        self.version_passed = True\r\n        self._version_check()\r\n\r\n        # initialize\r\n        self._node_feat = None\r\n        self._edge_feat = None\r\n        self._full_data = None\r\n        self._train_data = None\r\n        self._val_data = None\r\n        self._test_data = None\r\n\r\n        # for tkg and thg\r\n        self._edge_type = None\r\n\r\n        #tkgl-wikidata and tkgl-smallpedia only\r\n        self._static_data = None\r\n\r\n        # for thg only\r\n        self._node_type = None\r\n        self._node_id = None\r\n\r\n        if download:\r\n            self.download()\r\n        else:\r\n            if osp.exists(self.meta_dict[\"fname\"]):\r\n                dir_name = self.meta_dict[\"fname\"]\r\n                vprint(f\"files found in {dir_name}\")\r\n            else:\r\n                dir_name = self.meta_dict[\"fname\"]\r\n                raise FileNotFoundError(f\"Directory not found at {dir_name}, please download the dataset\")\r\n            \r\n        \r\n        # check if the root directory exists, if not create it\r\n        if osp.isdir(self.root):\r\n            vprint(\"Dataset directory is \", self.root)\r\n        else:\r\n            raise FileNotFoundError(f\"Directory not found at {self.root}\")\r\n\r\n        if preprocess:\r\n            self.pre_process()\r\n\r\n        self.min_dst_idx, self.max_dst_idx = int(self._full_data[\"destinations\"].min()), int(self._full_data[\"destinations\"].max())\r\n\r\n        if 
('tkg' in self.name):\r\n            if self.name in DATA_NS_STRATEGY_DICT:\r\n                self.ns_sampler = TKGNegativeEdgeSampler(\r\n                    dataset_name=self.name,\r\n                    first_dst_id=self.min_dst_idx,\r\n                    last_dst_id=self.max_dst_idx,\r\n                    strategy=DATA_NS_STRATEGY_DICT[self.name],\r\n                    partial_path=self.root + \"/\" + self.name,\r\n                )\r\n            else:\r\n                raise ValueError(f\"Dataset {self.name} negative sampling strategy not found.\")\r\n        elif ('thg' in self.name):\r\n            #* need to find the smallest node id of all nodes (regardless of types)\r\n            \r\n            min_node_idx = min(int(self._full_data[\"sources\"].min()), int(self._full_data[\"destinations\"].min()))\r\n            max_node_idx = max(int(self._full_data[\"sources\"].max()), int(self._full_data[\"destinations\"].max()))\r\n            self.ns_sampler = THGNegativeEdgeSampler(\r\n                dataset_name=self.name,\r\n                first_node_id=min_node_idx,\r\n                last_node_id=max_node_idx,\r\n                node_type=self._node_type,\r\n            )\r\n        else:\r\n            self.ns_sampler = NegativeEdgeSampler(\r\n                dataset_name=self.name,\r\n                first_dst_id=self.min_dst_idx,\r\n                last_dst_id=self.max_dst_idx,\r\n            )\r\n\r\n\r\n    def _version_check(self) -> None:\r\n        r\"\"\"Implement Version checks for dataset files\r\n        updates the file names based on the current version number\r\n        prompt the user to download the new version via self.version_passed variable\r\n        \"\"\"\r\n        if (self.name in DATA_VERSION_DICT):\r\n            version = DATA_VERSION_DICT[self.name]\r\n        else:\r\n            raise ValueError(f\"Dataset {self.name} version number not found.\")\r\n        \r\n        if (version > 1):\r\n            #* check if current 
version is outdated\r\n            self.meta_dict[\"fname\"] = self.root + \"/\" + self.name + \"_edgelist_v\" + str(int(version)) + \".csv\"\r\n            self.meta_dict[\"nodefile\"] = None\r\n            if self.name == \"tgbl-flight\":\r\n                self.meta_dict[\"nodefile\"] = self.root + \"/\" + \"airport_node_feat_v\" + str(int(version)) + \".csv\"\r\n            self.meta_dict[\"val_ns\"] = self.root + \"/\" + self.name + \"_val_ns_v\" + str(int(version)) + \".pkl\"\r\n            self.meta_dict[\"test_ns\"] = self.root + \"/\" + self.name + \"_test_ns_v\" + str(int(version)) + \".pkl\"\r\n            \r\n            if (not osp.exists(self.meta_dict[\"fname\"])):\r\n                vprint(f\"Dataset {self.name} version {int(version)} not found, Please download the latest version of the dataset.\")\r\n                self.version_passed = False\r\n                return None\r\n        \r\n\r\n    def download(self) -> None:\r\n        \"\"\"\r\n        downloads this dataset from url\r\n        check if files are already downloaded\r\n        \"\"\"\r\n        # check if the file already exists\r\n        if osp.exists(self.meta_dict[\"fname\"]):\r\n            dir_name = self.meta_dict[\"fname\"]\r\n            vprint(f\"files found in {dir_name}\")\r\n            return None\r\n\r\n        vprint(\r\n            f\"{BColors.WARNING}Download started, this might take a while . . . 
{BColors.ENDC}\"\r\n        )\r\n        vprint(f\"Dataset title: {self.name}\")\r\n\r\n        if self.url is None:\r\n            raise ValueError(f\"Dataset {self.name} url not found, download not supported yet.\")\r\n        else:\r\n            r = requests.get(self.url, stream=True)\r\n            if osp.isdir(self.root):\r\n                vprint(\"Dataset directory is \", self.root)\r\n            else:\r\n                os.makedirs(self.root)\r\n\r\n            path_download = self.root + \"/\" + self.name + \".zip\"\r\n            print(f\"downloading Dataset: {self.name} to {path_download}\")\r\n            with open(path_download, \"wb\") as f:\r\n                total_length = int(r.headers.get(\"content-length\"))\r\n                for chunk in progress.bar(\r\n                    r.iter_content(chunk_size=1024),\r\n                    expected_size=(total_length / 1024) + 1,\r\n                ):\r\n                    if chunk:\r\n                        f.write(chunk)\r\n                        f.flush()\r\n            # for unzipping the file\r\n            with zipfile.ZipFile(path_download, \"r\") as zip_ref:\r\n                zip_ref.extractall(self.root)\r\n            vprint(f\"{BColors.OKGREEN}Download completed {BColors.ENDC}\")\r\n            self.version_passed = True\r\n      \r\n\r\n    def generate_processed_files(self) -> pd.DataFrame:\r\n        r\"\"\"\r\n        turns raw data .csv file into a pandas data frame, stored on disc if not already\r\n        Returns:\r\n            df: pandas data frame\r\n        \"\"\"\r\n        node_feat = None\r\n        if not osp.exists(self.meta_dict[\"fname\"]):\r\n            raise FileNotFoundError(f\"File not found at {self.meta_dict['fname']}\")\r\n\r\n        if self.meta_dict[\"nodefile\"] is not None:\r\n            if not osp.exists(self.meta_dict[\"nodefile\"]):\r\n                raise FileNotFoundError(\r\n                    f\"File not found at {self.meta_dict['nodefile']}\"\r\n  
              )\r\n        #* for thg must have nodetypes \r\n        if self.meta_dict[\"nodeTypeFile\"] is not None:\r\n            if not osp.exists(self.meta_dict[\"nodeTypeFile\"]):\r\n                raise FileNotFoundError(\r\n                    f\"File not found at {self.meta_dict['nodeTypeFile']}\"\r\n                )\r\n\r\n\r\n        OUT_DF = self.root + \"/\" + \"ml_{}.pkl\".format(self.name)\r\n        OUT_EDGE_FEAT = self.root + \"/\" + \"ml_{}.pkl\".format(self.name + \"_edge\")\r\n        OUT_NODE_ID = self.root + \"/\" + \"ml_{}.pkl\".format(self.name + \"_nodeid\")\r\n        if self.meta_dict[\"nodefile\"] is not None:\r\n            OUT_NODE_FEAT = self.root + \"/\" + \"ml_{}.pkl\".format(self.name + \"_node\")\r\n        if self.meta_dict[\"nodeTypeFile\"] is not None:\r\n            OUT_NODE_TYPE = self.root + \"/\" + \"ml_{}.pkl\".format(self.name + \"_nodeType\")\r\n\r\n        if osp.exists(OUT_DF) and self.version_passed is True:\r\n            vprint(f\"loading processed file from {OUT_DF}.\")\r\n            df = pd.read_pickle(OUT_DF)\r\n            edge_feat = load_pkl(OUT_EDGE_FEAT)\r\n            if (self.name == \"tkgl-wikidata\") or (self.name == \"tkgl-smallpedia\"):\r\n                node_id = load_pkl(OUT_NODE_ID)\r\n                self._node_id = node_id\r\n            if self.meta_dict[\"nodefile\"] is not None:\r\n                node_feat = load_pkl(OUT_NODE_FEAT)\r\n            if self.meta_dict[\"nodeTypeFile\"] is not None:\r\n                node_type = load_pkl(OUT_NODE_TYPE)\r\n                self._node_type = node_type\r\n\r\n        else:\r\n            vprint(\"file not processed, generating processed file\")\r\n            if self.name == \"tgbl-flight\":\r\n                df, edge_feat, node_ids = csv_to_pd_data(self.meta_dict[\"fname\"])\r\n            elif self.name == \"tgbl-coin\":\r\n                df, edge_feat, node_ids = csv_to_pd_data_sc(self.meta_dict[\"fname\"])\r\n            elif self.name == 
\"tgbl-comment\":\r\n                df, edge_feat, node_ids = csv_to_pd_data_rc(self.meta_dict[\"fname\"])\r\n            elif self.name == \"tgbl-review\":\r\n                df, edge_feat, node_ids = csv_to_pd_data_sc(self.meta_dict[\"fname\"])\r\n            elif self.name == \"tgbl-wiki\":\r\n                df, edge_feat, node_ids = load_edgelist_wiki(self.meta_dict[\"fname\"])\r\n            elif self.name == \"tgbl-subreddit\":\r\n                df, edge_feat, node_ids = load_edgelist_wiki(self.meta_dict[\"fname\"])\r\n            elif self.name == \"tgbl-uci\":\r\n                df, edge_feat, node_ids = load_edgelist_wiki(self.meta_dict[\"fname\"])\r\n            elif self.name == \"tgbl-enron\":\r\n                df, edge_feat, node_ids = load_edgelist_wiki(self.meta_dict[\"fname\"])\r\n            elif self.name == \"tgbl-lastfm\":\r\n                df, edge_feat, node_ids = load_edgelist_wiki(self.meta_dict[\"fname\"])\r\n            elif self.name == \"tkgl-polecat\":\r\n                df, edge_feat, node_ids = csv_to_tkg_data(self.meta_dict[\"fname\"])\r\n            elif self.name == \"tkgl-icews\":\r\n                df, edge_feat, node_ids = csv_to_tkg_data(self.meta_dict[\"fname\"])\r\n            elif self.name == \"tkgl-yago\":\r\n                df, edge_feat, node_ids = csv_to_tkg_data(self.meta_dict[\"fname\"])\r\n            elif self.name == \"tkgl-wikidata\":\r\n                df, edge_feat, node_ids = csv_to_wikidata(self.meta_dict[\"fname\"])\r\n                save_pkl(node_ids, OUT_NODE_ID)\r\n                self._node_id = node_ids\r\n            elif self.name == \"tkgl-smallpedia\":\r\n                df, edge_feat, node_ids = csv_to_wikidata(self.meta_dict[\"fname\"])\r\n                save_pkl(node_ids, OUT_NODE_ID)\r\n                self._node_id = node_ids\r\n            elif self.name == \"thgl-myket\":\r\n                df, edge_feat, node_ids = csv_to_thg_data(self.meta_dict[\"fname\"])\r\n            elif 
self.name == \"thgl-github\":\r\n                df, edge_feat, node_ids = csv_to_thg_data(self.meta_dict[\"fname\"])\r\n            elif self.name == \"thgl-forum\":\r\n                df, edge_feat, node_ids = csv_to_forum_data(self.meta_dict[\"fname\"])\r\n            elif self.name == \"thgl-software\":\r\n                df, edge_feat, node_ids = csv_to_thg_data(self.meta_dict[\"fname\"])\r\n            else:\r\n                raise ValueError(f\"Dataset {self.name} not found.\")\r\n\r\n            save_pkl(edge_feat, OUT_EDGE_FEAT)\r\n            df.to_pickle(OUT_DF)\r\n            if self.meta_dict[\"nodefile\"] is not None:\r\n                node_feat = process_node_feat(self.meta_dict[\"nodefile\"], node_ids)\r\n                save_pkl(node_feat, OUT_NODE_FEAT)\r\n            if self.meta_dict[\"nodeTypeFile\"] is not None:\r\n                node_type = process_node_type(self.meta_dict[\"nodeTypeFile\"], node_ids)\r\n                save_pkl(node_type, OUT_NODE_TYPE)\r\n                #? 
do not return node_type, simply set it\r\n                self._node_type = node_type\r\n            \r\n\r\n        return df, edge_feat, node_feat\r\n\r\n    def pre_process(self):\r\n        \"\"\"\r\n        Pre-process the dataset and generates the splits, must be run before dataset properties can be accessed\r\n        generates the edge data and different train, val, test splits\r\n        \"\"\"\r\n\r\n        # check if path to file is valid\r\n        df, edge_feat, node_feat = self.generate_processed_files()\r\n\r\n        #* design choice, only stores the original edges not the inverse relations on disc\r\n        if (\"tkgl\" in self.name):\r\n            df = add_inverse_quadruples(df)\r\n\r\n        sources = np.array(df[\"u\"])\r\n        destinations = np.array(df[\"i\"])\r\n        timestamps = np.array(df[\"ts\"])\r\n        edge_idxs = np.array(df[\"idx\"])\r\n        weights = np.array(df[\"w\"])\r\n        edge_label = np.ones(len(df))  # should be 1 for all pos edges\r\n        if (self.name == \"tgbl-coin\") or (self.name == \"tgbl-review\"):\r\n            self._edge_feat = weights.reshape(-1,1)\r\n        elif (self.name == \"tgbl-comment\"):\r\n            self._edge_feat = np.concatenate((edge_feat, weights.reshape(-1,1)), axis=1)\r\n        else:\r\n            self._edge_feat = edge_feat\r\n        self._node_feat = node_feat\r\n\r\n        full_data = {\r\n            \"sources\": sources.astype(int),\r\n            \"destinations\": destinations.astype(int),\r\n            \"timestamps\": timestamps.astype(int),\r\n            \"edge_idxs\": edge_idxs,\r\n            \"edge_feat\": self._edge_feat,\r\n            \"w\": weights,\r\n            \"edge_label\": edge_label,\r\n        }\r\n\r\n        #* for tkg and thg\r\n        if (\"edge_type\" in df):\r\n            edge_type = np.array(df[\"edge_type\"]).astype(int)\r\n            self._edge_type = edge_type\r\n            full_data[\"edge_type\"] = edge_type\r\n\r\n        
self._full_data = full_data\r\n\r\n        if (\"yago\" in self.name):\r\n            _train_mask, _val_mask, _test_mask = self.generate_splits(full_data, val_ratio=0.1, test_ratio=0.10) #99) #val_ratio=0.097, test_ratio=0.099)\r\n        else:\r\n            _train_mask, _val_mask, _test_mask = self.generate_splits(full_data, val_ratio=0.15, test_ratio=0.15)\r\n        self._train_mask = _train_mask\r\n        self._val_mask = _val_mask\r\n        self._test_mask = _test_mask\r\n\r\n    def generate_splits(\r\n        self,\r\n        full_data: Dict[str, Any],\r\n        val_ratio: float = 0.15,\r\n        test_ratio: float = 0.15,\r\n    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:\r\n        r\"\"\"Generates train, validation, and test splits from the full dataset\r\n        Args:\r\n            full_data: dictionary containing the full dataset\r\n            val_ratio: ratio of validation data\r\n            test_ratio: ratio of test data\r\n        Returns:\r\n            train_data: dictionary containing the training dataset\r\n            val_data: dictionary containing the validation dataset\r\n            test_data: dictionary containing the test dataset\r\n        \"\"\"\r\n        val_time, test_time = list(\r\n            np.quantile(\r\n                full_data[\"timestamps\"],\r\n                [(1 - val_ratio - test_ratio), (1 - test_ratio)],\r\n            )\r\n        )\r\n        timestamps = full_data[\"timestamps\"]\r\n\r\n        train_mask = timestamps <= val_time\r\n        val_mask = np.logical_and(timestamps <= test_time, timestamps > val_time)\r\n        test_mask = timestamps > test_time\r\n\r\n        return train_mask, val_mask, test_mask\r\n    \r\n    def preprocess_static_edges(self):\r\n        \"\"\"\r\n        Pre-process the static edges of the dataset\r\n        \"\"\"\r\n        if (\"staticfile\" in self.meta_dict):\r\n            OUT_DF = self.root + \"/\" + \"ml_{}.pkl\".format(self.name + \"_static\")\r\n  
          if osp.exists(OUT_DF) and self.version_passed is True:\r\n                vprint(f\"loading processed file from {OUT_DF}.\")\r\n                static_dict = load_pkl(OUT_DF)\r\n                self._static_data = static_dict\r\n            else:\r\n                vprint(\"file not processed, generating processed file\")\r\n                static_dict, node_ids =  csv_to_staticdata(self.meta_dict[\"staticfile\"], self._node_id)\r\n                save_pkl(static_dict, OUT_DF)\r\n                self._static_data = static_dict\r\n        else:\r\n            vprint (\"static edges are only for tkgl-wikidata and tkgl-smallpedia datasets\")\r\n\r\n    \r\n    @property\r\n    def eval_metric(self) -> str:\r\n        \"\"\"\r\n        the official evaluation metric for the dataset, loaded from info.py\r\n        Returns:\r\n            eval_metric: str, the evaluation metric\r\n        \"\"\"\r\n        return self.metric\r\n\r\n    @property\r\n    def negative_sampler(self) -> NegativeEdgeSampler:\r\n        r\"\"\"\r\n        Returns the negative sampler of the dataset, will load negative samples from disc\r\n        Returns:\r\n            negative_sampler: NegativeEdgeSampler\r\n        \"\"\"\r\n        return self.ns_sampler\r\n    \r\n\r\n    def load_val_ns(self) -> None:\r\n        r\"\"\"\r\n        load the negative samples for the validation set\r\n        \"\"\"\r\n        self.ns_sampler.load_eval_set(\r\n            fname=self.meta_dict[\"val_ns\"], split_mode=\"val\"\r\n        )\r\n\r\n    def load_test_ns(self) -> None:\r\n        r\"\"\"\r\n        load the negative samples for the test set\r\n        \"\"\"\r\n        self.ns_sampler.load_eval_set(\r\n            fname=self.meta_dict[\"test_ns\"], split_mode=\"test\"\r\n        )\r\n\r\n    @property\r\n    def num_nodes(self) -> int:\r\n        r\"\"\"\r\n        Returns the total number of unique nodes in the dataset \r\n        Returns:\r\n            num_nodes: int, the number of 
unique nodes\r\n        \"\"\"\r\n        src = self._full_data[\"sources\"]\r\n        dst = self._full_data[\"destinations\"]\r\n        all_nodes = np.concatenate((src, dst), axis=0)\r\n        uniq_nodes = np.unique(all_nodes, axis=0)\r\n        return uniq_nodes.shape[0]\r\n    \r\n\r\n    @property\r\n    def num_edges(self) -> int:\r\n        r\"\"\"\r\n        Returns the total number of edges in the dataset\r\n        Returns:\r\n            num_edges: int, the number of edges\r\n        \"\"\"\r\n        src = self._full_data[\"sources\"]\r\n        return src.shape[0]\r\n    \r\n\r\n    @property\r\n    def num_rels(self) -> int:\r\n        r\"\"\"\r\n        Returns the number of relation types in the dataset\r\n        Returns:\r\n            num_rels: int, the number of relation types\r\n        \"\"\"\r\n        #* if it is a homogenous graph\r\n        if (\"edge_type\" not in self._full_data):\r\n            return 1\r\n        else:\r\n            return np.unique(self._full_data[\"edge_type\"]).shape[0]\r\n\r\n    @property\r\n    def node_feat(self) -> Optional[np.ndarray]:\r\n        r\"\"\"\r\n        Returns the node features of the dataset with dim [N, feat_dim]\r\n        Returns:\r\n            node_feat: np.ndarray, [N, feat_dim] or None if there is no node feature\r\n        \"\"\"\r\n        return self._node_feat\r\n    \r\n    @property\r\n    def node_type(self) -> Optional[np.ndarray]:\r\n        r\"\"\"\r\n        Returns the node types of the dataset with dim [N], only for temporal heterogeneous graphs\r\n        Returns:\r\n            node_feat: np.ndarray, [N] or None if there is no node feature\r\n        \"\"\"\r\n        return self._node_type\r\n\r\n    @property\r\n    def edge_feat(self) -> Optional[np.ndarray]:\r\n        r\"\"\"\r\n        Returns the edge features of the dataset with dim [E, feat_dim]\r\n        Returns:\r\n            edge_feat: np.ndarray, [E, feat_dim] or None if there is no edge feature\r\n        
\"\"\"\r\n        return self._edge_feat\r\n    \r\n    @property\r\n    def edge_type(self) -> Optional[np.ndarray]:\r\n        r\"\"\"\r\n        Returns the edge types of the dataset with dim [E, 1], only for temporal knowledge graph and temporal heterogeneous graph\r\n        Returns:\r\n            edge_type: np.ndarray, [E, 1] or None if it is not a TKG or THG\r\n        \"\"\"\r\n        return self._edge_type\r\n    \r\n    @property\r\n    def static_data(self) -> Optional[np.ndarray]:\r\n        r\"\"\"\r\n        Returns the static edges related to this dataset, applies for tkgl-wikidata and tkgl-smallpedia, edges are (src, dst, rel_type)\r\n        Returns:\r\n            df: pd.DataFrame {\"head\": np.ndarray, \"tail\": np.ndarray, \"rel_type\": np.ndarray}\r\n        \"\"\"\r\n        if (self.name == \"tkgl-wikidata\") or (self.name == \"tkgl-smallpedia\"):\r\n            self.preprocess_static_edges()\r\n        return self._static_data\r\n\r\n    @property\r\n    def full_data(self) -> Dict[str, Any]:\r\n        r\"\"\"\r\n        the full data of the dataset as a dictionary with keys: 'sources', 'destinations', 'timestamps', 'edge_idxs', 'edge_feat', 'w', 'edge_label',\r\n\r\n        Returns:\r\n            full_data: Dict[str, Any]\r\n        \"\"\"\r\n        if self._full_data is None:\r\n            raise ValueError(\r\n                \"dataset has not been processed yet, please call pre_process() first\"\r\n            )\r\n        return self._full_data\r\n\r\n    @property\r\n    def train_mask(self) -> np.ndarray:\r\n        r\"\"\"\r\n        Returns the train mask of the dataset\r\n        Returns:\r\n            train_mask: training masks\r\n        \"\"\"\r\n        if self._train_mask is None:\r\n            raise ValueError(\"training split hasn't been loaded\")\r\n        return self._train_mask\r\n\r\n    @property\r\n    def val_mask(self) -> np.ndarray:\r\n        r\"\"\"\r\n        Returns the validation mask of the dataset\r\n 
       Returns:\r\n            val_mask: Dict[str, Any]\r\n        \"\"\"\r\n        if self._val_mask is None:\r\n            raise ValueError(\"validation split hasn't been loaded\")\r\n        return self._val_mask\r\n\r\n    @property\r\n    def test_mask(self) -> np.ndarray:\r\n        r\"\"\"\r\n        Returns the test mask of the dataset:\r\n        Returns:\r\n            test_mask: Dict[str, Any]\r\n        \"\"\"\r\n        if self._test_mask is None:\r\n            raise ValueError(\"test split hasn't been loaded\")\r\n        return self._test_mask\r\n\r\n\r\ndef main():\r\n\r\n    name = \"tkgl-polecat\"\r\n    dataset = LinkPropPredDataset(name=name, root=\"datasets\", preprocess=True)\r\n    dataset.edge_type\r\n\r\n\r\n\r\n    # name = \"tgbl-comment\" \r\n    # dataset = LinkPropPredDataset(name=name, root=\"datasets\", preprocess=True)\r\n\r\n    # dataset.node_feat\r\n    # dataset.edge_feat  # not the edge weights\r\n    # dataset.full_data\r\n    # dataset.full_data[\"edge_idxs\"]\r\n    # dataset.full_data[\"sources\"]\r\n    # dataset.full_data[\"destinations\"]\r\n    # dataset.full_data[\"timestamps\"]\r\n    # dataset.full_data[\"edge_label\"]\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/linkproppred/dataset_pyg.py",
    "content": "import torch\r\nfrom typing import Optional, Optional, Callable\r\n\r\nfrom torch_geometric.data import Dataset, TemporalData\r\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\r\nfrom tgb.linkproppred.negative_sampler import NegativeEdgeSampler\r\n\r\n\r\nclass PyGLinkPropPredDataset(Dataset):\r\n    def __init__(\r\n        self,\r\n        name: str,\r\n        root: str = \"datasets\",\r\n        transform: Optional[Callable] = None,\r\n        pre_transform: Optional[Callable] = None,\r\n        download: Optional[bool] = True,\r\n    ):\r\n        r\"\"\"\r\n        PyG wrapper for the LinkPropPredDataset\r\n        can return pytorch tensors for src,dst,t,msg,label\r\n        can return Temporal Data object\r\n        Parameters:\r\n            name: name of the dataset, passed to `LinkPropPredDataset`\r\n            root (string): Root directory where the dataset should be saved, passed to `LinkPropPredDataset`\r\n            transform (callable, optional): A function/transform that takes in an, not used in this case\r\n            pre_transform (callable, optional): A function/transform that takes in, not used in this case\r\n            download (optional, bool): download or not (default True)\r\n        \"\"\"\r\n        self.name = name\r\n        self.root = root\r\n        self.dataset = LinkPropPredDataset(name=name, root=root, download=download)\r\n        self._train_mask = torch.from_numpy(self.dataset.train_mask)\r\n        self._val_mask = torch.from_numpy(self.dataset.val_mask)\r\n        self._test_mask = torch.from_numpy(self.dataset.test_mask)\r\n        super().__init__(root, transform, pre_transform)\r\n        self._node_feat = self.dataset.node_feat\r\n        self._edge_type = None\r\n        self._static_data = None\r\n\r\n        if self._node_feat is None:\r\n            self._node_feat = None\r\n        else:\r\n            self._node_feat = torch.from_numpy(self._node_feat).float()\r\n        \r\n        
self._node_type = self.dataset.node_type\r\n        if self.node_type is not None:\r\n            self._node_type = torch.from_numpy(self.dataset.node_type).long()\r\n        \r\n        self.process_data()\r\n\r\n        self._ns_sampler = self.dataset.negative_sampler\r\n\r\n    @property\r\n    def eval_metric(self) -> str:\r\n        \"\"\"\r\n        the official evaluation metric for the dataset, loaded from info.py\r\n        Returns:\r\n            eval_metric: str, the evaluation metric\r\n        \"\"\"\r\n        return self.dataset.eval_metric\r\n\r\n    @property\r\n    def negative_sampler(self) -> NegativeEdgeSampler:\r\n        r\"\"\"\r\n        Returns the negative sampler of the dataset, will load negative samples from disc\r\n        Returns:\r\n            negative_sampler: NegativeEdgeSampler\r\n        \"\"\"\r\n        return self._ns_sampler\r\n    \r\n    @property\r\n    def num_nodes(self) -> int:\r\n        r\"\"\"\r\n        Returns the total number of unique nodes in the dataset \r\n        Returns:\r\n            num_nodes: int, the number of unique nodes\r\n        \"\"\"\r\n        return self.dataset.num_nodes\r\n    \r\n    @property\r\n    def num_rels(self) -> int:\r\n        r\"\"\"\r\n        Returns the total number of unique relations in the dataset \r\n        Returns:\r\n            num_rels: int, the number of unique relations\r\n        \"\"\"\r\n        return self.dataset.num_rels\r\n    \r\n    @property\r\n    def num_edges(self) -> int:\r\n        r\"\"\"\r\n        Returns the total number of edges in the dataset \r\n        Returns:\r\n            num_edges: int, the number of edges\r\n        \"\"\"\r\n        return self.dataset.num_edges\r\n\r\n    def load_val_ns(self) -> None:\r\n        r\"\"\"\r\n        load the negative samples for the validation set\r\n        \"\"\"\r\n        self.dataset.load_val_ns()\r\n\r\n    def load_test_ns(self) -> None:\r\n        r\"\"\"\r\n        load the negative samples 
for the test set\r\n        \"\"\"\r\n        self.dataset.load_test_ns()\r\n\r\n    @property\r\n    def train_mask(self) -> torch.Tensor:\r\n        r\"\"\"\r\n        Returns the train mask of the dataset\r\n        Returns:\r\n            train_mask: the mask for edges in the training set\r\n        \"\"\"\r\n        if self._train_mask is None:\r\n            raise ValueError(\"training split hasn't been loaded\")\r\n        return self._train_mask\r\n\r\n    @property\r\n    def val_mask(self) -> torch.Tensor:\r\n        r\"\"\"\r\n        Returns the validation mask of the dataset\r\n        Returns:\r\n            val_mask: the mask for edges in the validation set\r\n        \"\"\"\r\n        if self._val_mask is None:\r\n            raise ValueError(\"validation split hasn't been loaded\")\r\n        return self._val_mask\r\n\r\n    @property\r\n    def test_mask(self) -> torch.Tensor:\r\n        r\"\"\"\r\n        Returns the test mask of the dataset:\r\n        Returns:\r\n            test_mask: the mask for edges in the test set\r\n        \"\"\"\r\n        if self._test_mask is None:\r\n            raise ValueError(\"test split hasn't been loaded\")\r\n        return self._test_mask\r\n\r\n    @property\r\n    def node_feat(self) -> torch.Tensor:\r\n        r\"\"\"\r\n        Returns the node features of the dataset\r\n        Returns:\r\n            node_feat: the node features\r\n        \"\"\"\r\n        return self._node_feat\r\n    \r\n    @property\r\n    def node_type(self) -> torch.Tensor:\r\n        r\"\"\"\r\n        Returns the node types of the dataset\r\n        Returns:\r\n            node_type: the node types [N]\r\n        \"\"\"\r\n        return self._node_type\r\n\r\n    @property\r\n    def src(self) -> torch.Tensor:\r\n        r\"\"\"\r\n        Returns the source nodes of the dataset\r\n        Returns:\r\n            src: the idx of the source nodes\r\n        \"\"\"\r\n        return self._src\r\n\r\n    @property\r\n    def 
dst(self) -> torch.Tensor:\r\n        r\"\"\"\r\n        Returns the destination nodes of the dataset\r\n        Returns:\r\n            dst: the idx of the destination nodes\r\n        \"\"\"\r\n        return self._dst\r\n\r\n    @property\r\n    def ts(self) -> torch.Tensor:\r\n        r\"\"\"\r\n        Returns the timestamps of the dataset\r\n        Returns:\r\n            ts: the timestamps of the edges\r\n        \"\"\"\r\n        return self._ts\r\n    \r\n    @property\r\n    def static_data(self) -> torch.Tensor:\r\n        r\"\"\"\r\n        Returns the static data of the dataset for tkgl-wikidata and tkgl-smallpedia\r\n        Returns:\r\n            static_data: the static data of the dataset\r\n        \"\"\"\r\n        if (self._static_data is None):\r\n            static_dict = {}\r\n            static_dict[\"head\"] = torch.from_numpy(self.dataset.static_data[\"head\"]).long()\r\n            static_dict[\"tail\"] = torch.from_numpy(self.dataset.static_data[\"tail\"]).long()\r\n            static_dict[\"edge_type\"] = torch.from_numpy(self.dataset.static_data[\"edge_type\"]).long()\r\n            self._static_data = static_dict\r\n            return self._static_data\r\n        else:\r\n            return self._static_data \r\n    \r\n    @property\r\n    def edge_type(self) -> torch.Tensor:\r\n        r\"\"\"\r\n        Returns the edge types for each edge\r\n        Returns:\r\n            edge_type: edge type tensor (int)\r\n        \"\"\"\r\n        return self._edge_type\r\n\r\n    @property\r\n    def edge_feat(self) -> torch.Tensor:\r\n        r\"\"\"\r\n        Returns the edge features of the dataset\r\n        Returns:\r\n            edge_feat: the edge features\r\n        \"\"\"\r\n        return self._edge_feat\r\n\r\n    @property\r\n    def edge_label(self) -> torch.Tensor:\r\n        r\"\"\"\r\n        Returns the edge labels of the dataset\r\n        Returns:\r\n            edge_label: the labels of the edges\r\n        \"\"\"\r\n   
     return self._edge_label\r\n\r\n    def process_data(self) -> None:\r\n        r\"\"\"\r\n        convert the numpy arrays from dataset to pytorch tensors\r\n        \"\"\"\r\n        src = torch.from_numpy(self.dataset.full_data[\"sources\"])\r\n        dst = torch.from_numpy(self.dataset.full_data[\"destinations\"])\r\n        ts = torch.from_numpy(self.dataset.full_data[\"timestamps\"])\r\n        msg = torch.from_numpy(\r\n            self.dataset.full_data[\"edge_feat\"]\r\n        )  # use edge features here if available\r\n        edge_label = torch.from_numpy(\r\n            self.dataset.full_data[\"edge_label\"]\r\n        )  # this is the label indicating if an edge is a true edge, always 1 for true edges\r\n        w = torch.from_numpy(\r\n            self.dataset.full_data[\"w\"]\r\n        )\r\n\r\n\r\n        # * first check typing for all tensors\r\n        # source tensor must be of type int64\r\n        # warnings.warn(\"sources tensor is not of type int64 or int32, forcing conversion\")\r\n        if src.dtype != torch.int64:\r\n            src = src.long()\r\n\r\n        # destination tensor must be of type int64\r\n        if dst.dtype != torch.int64:\r\n            dst = dst.long()\r\n\r\n        # timestamp tensor must be of type int64\r\n        if ts.dtype != torch.int64:\r\n            ts = ts.long()\r\n\r\n        # message tensor must be of type float32\r\n        if msg.dtype != torch.float32:\r\n            msg = msg.float()\r\n\r\n        # weight tensor must be of type float32\r\n        if w.dtype != torch.float32:\r\n            w = w.float()\r\n\r\n        #* for tkg\r\n        if (\"edge_type\" in self.dataset.full_data):\r\n            edge_type = torch.from_numpy(self.dataset.full_data[\"edge_type\"])\r\n            if edge_type.dtype != torch.int64:\r\n                edge_type = edge_type.long()\r\n            self._edge_type = edge_type\r\n\r\n        self._src = src\r\n        self._dst = dst\r\n        self._ts = ts\r\n 
       self._edge_label = edge_label\r\n        self._edge_feat = msg\r\n        self._w = w\r\n\r\n    def get_TemporalData(self) -> TemporalData:\r\n        \"\"\"\r\n        return the TemporalData object for the entire dataset\r\n        \"\"\"\r\n        if (self._edge_type is not None):\r\n            data = TemporalData(\r\n                src=self._src,\r\n                dst=self._dst,\r\n                t=self._ts,\r\n                msg=self._edge_feat,\r\n                y=self._edge_label,\r\n                edge_type=self._edge_type,\r\n                w=self._w,\r\n            )\r\n        else:\r\n            data = TemporalData(\r\n                src=self._src,\r\n                dst=self._dst,\r\n                t=self._ts,\r\n                msg=self._edge_feat,\r\n                y=self._edge_label,\r\n                w=self._w,\r\n            )\r\n        return data\r\n\r\n    def len(self) -> int:\r\n        \"\"\"\r\n        size of the dataset\r\n        Returns:\r\n            size: int\r\n        \"\"\"\r\n        return self._src.shape[0]\r\n\r\n    def get(self, idx: int) -> TemporalData:\r\n        \"\"\"\r\n        construct temporal data object for a single edge\r\n        Parameters:\r\n            idx: index of the edge\r\n        Returns:\r\n            data: TemporalData object\r\n        \"\"\"\r\n        if (self._edge_type is not None):\r\n            data = TemporalData(\r\n                src=self._src[idx],\r\n                dst=self._dst[idx],\r\n                t=self._ts[idx],\r\n                msg=self._edge_feat[idx],\r\n                y=self._edge_label[idx],\r\n                edge_type=self._edge_type[idx]\r\n            )\r\n        else:\r\n            data = TemporalData(\r\n                src=self._src[idx],\r\n                dst=self._dst[idx],\r\n                t=self._ts[idx],\r\n                msg=self._edge_feat[idx],\r\n                y=self._edge_label[idx],\r\n            )\r\n        return 
data\r\n\r\n    def __repr__(self) -> str:\r\n        return f\"{self.name.capitalize()}()\"\r\n"
  },
  {
    "path": "tgb/linkproppred/evaluate.py",
    "content": "\"\"\"\r\nEvaluator Module for Dynamic Link Prediction\r\n\"\"\"\r\n\r\nimport numpy as np\r\nfrom sklearn.metrics import *\r\nfrom tgb.utils.info import DATA_EVAL_METRIC_DICT\r\nfrom tgb.utils.utils import vprint\r\n\r\ntry:\r\n    import torch\r\nexcept ImportError:\r\n    torch = None\r\n\r\n\r\nclass Evaluator(object):\r\n    r\"\"\"Evaluator for Link Property Prediction \"\"\"\r\n\r\n    def __init__(self, name: str, k_value: int = 10):\r\n        r\"\"\"\r\n        Parameters:\r\n            name: name of the dataset\r\n            k_value: the desired 'k' value for calculating metric@k\r\n        \"\"\"\r\n        self.name = name\r\n        self.k_value = k_value  # for computing `hits@k`\r\n        self.valid_metric_list = ['hits@', 'mrr']\r\n        if self.name not in DATA_EVAL_METRIC_DICT:\r\n            raise NotImplementedError(\"Dataset not supported\")\r\n    \r\n    def _parse_and_check_input(self, input_dict):\r\n        r\"\"\"\r\n        Check whether the input has the appropriate format\r\n        Parameters:\r\n            input_dict: a dictionary containing \"y_pred_pos\", \"y_pred_neg\", and \"eval_metric\"\r\n            note: \"eval_metric\" should be a list including one or more of the following metrics: [\"hits@\", \"mrr\"]\r\n        Returns:\r\n            y_pred_pos: positive predicted scores\r\n            y_pred_neg: negative predicted scores\r\n        \"\"\"\r\n\r\n        if \"eval_metric\" not in input_dict:\r\n            raise RuntimeError(\"Missing key of eval_metric!\")\r\n\r\n        for eval_metric in input_dict[\"eval_metric\"]:\r\n            if eval_metric in self.valid_metric_list:\r\n                if \"y_pred_pos\" not in input_dict:\r\n                    raise RuntimeError(\"Missing key of y_true\")\r\n                if \"y_pred_neg\" not in input_dict:\r\n                    raise RuntimeError(\"Missing key of y_pred\")\r\n\r\n                y_pred_pos, y_pred_neg = input_dict[\"y_pred_pos\"], 
input_dict[\"y_pred_neg\"]\r\n\r\n                # converting to numpy on cpu\r\n                if torch is not None and isinstance(y_pred_pos, torch.Tensor):\r\n                    y_pred_pos = y_pred_pos.detach().cpu().numpy()\r\n                if torch is not None and isinstance(y_pred_neg, torch.Tensor):\r\n                    y_pred_neg = y_pred_neg.detach().cpu().numpy()\r\n\r\n                # check type and shape\r\n                if not isinstance(y_pred_pos, np.ndarray) or not isinstance(y_pred_neg, np.ndarray):\r\n                    raise RuntimeError(\r\n                        \"Arguments to Evaluator need to be either numpy ndarray or torch tensor!\"\r\n                    )\r\n            else:\r\n                raise ValueError(f\"Unsupported eval metric: {eval_metric}, not found in {self.valid_metric_list}\")\r\n\r\n        self.eval_metric = input_dict[\"eval_metric\"]\r\n\r\n        return y_pred_pos, y_pred_neg\r\n\r\n    def _eval_hits_and_mrr(self, y_pred_pos, y_pred_neg, type_info, k_value):\r\n        r\"\"\"\r\n        compute hist@k and mrr\r\n        reference:\r\n            - https://github.com/snap-stanford/ogb/blob/d5c11d91c9e1c22ed090a2e0bbda3fe357de66e7/ogb/linkproppred/evaluate.py#L214\r\n        \r\n        Parameters:\r\n            y_pred_pos: positive predicted scores\r\n            y_pred_neg: negative predicted scores\r\n            type_info: type of the predicted scores; could be 'torch' or 'numpy'\r\n            k_value: the desired 'k' value for calculating metric@k\r\n        \r\n        Returns:\r\n            a dictionary containing the computed performance metrics\r\n        \"\"\"\r\n        if type_info == 'torch':\r\n            # calculate ranks\r\n            y_pred_pos = y_pred_pos.view(-1, 1)\r\n            # optimistic rank: \"how many negatives have a larger score than the positive?\"\r\n            # ~> the positive is ranked first among those with equal score\r\n            optimistic_rank = 
(y_pred_neg > y_pred_pos).sum(dim=1)\r\n            # pessimistic rank: \"how many negatives have at least the positive score?\"\r\n            # ~> the positive is ranked last among those with equal score\r\n            pessimistic_rank = (y_pred_neg >= y_pred_pos).sum(dim=1)\r\n            ranking_list = 0.5 * (optimistic_rank + pessimistic_rank) + 1\r\n            hitsK_list = (ranking_list <= k_value).to(torch.float)\r\n            mrr_list = 1./ranking_list.to(torch.float)\r\n\r\n            return {\r\n                    f'hits@{k_value}': hitsK_list.mean(),\r\n                    'mrr': mrr_list.mean()\r\n                    }\r\n\r\n        else:\r\n            y_pred_pos = y_pred_pos.reshape(-1, 1)\r\n            optimistic_rank = (y_pred_neg > y_pred_pos).sum(axis=1)\r\n            pessimistic_rank = (y_pred_neg >= y_pred_pos).sum(axis=1)\r\n            ranking_list = 0.5 * (optimistic_rank + pessimistic_rank) + 1\r\n            hitsK_list = (ranking_list <= k_value).astype(np.float32)\r\n            mrr_list = 1./ranking_list.astype(np.float32)\r\n\r\n            return {\r\n                    f'hits@{k_value}': hitsK_list.mean(),\r\n                    'mrr': mrr_list.mean()\r\n                    }\r\n\r\n    def eval(self, \r\n             input_dict: dict, \r\n             verbose: bool = False) -> dict:\r\n        r\"\"\"\r\n        evaluate the link prediction task\r\n        this method is callable through an instance of this object to compute the metric\r\n\r\n        Parameters:\r\n            input_dict: a dictionary containing \"y_pred_pos\", \"y_pred_neg\", and \"eval_metric\"\r\n                        the performance metric is calculated for the provided scores\r\n            verbose: whether to print out the computed metric\r\n        \r\n        Returns:\r\n            perf_dict: a dictionary containing the computed performance metric\r\n        \"\"\"\r\n        y_pred_pos, y_pred_neg = self._parse_and_check_input(input_dict)  # 
convert the predictions to numpy\r\n        perf_dict = self._eval_hits_and_mrr(y_pred_pos, y_pred_neg, type_info='numpy', k_value=self.k_value)\r\n        \r\n        return perf_dict\r\n    \r\n"
  },
  {
    "path": "tgb/linkproppred/negative_generator.py",
    "content": "\"\"\"\r\nSample and Generate negative edges that are going to be used for evaluation of a dynamic graph learning model\r\nNegative samples are generated and saved to files ONLY once; \r\n    other times, they should be loaded from file with instances of the `negative_sampler.py`.\r\n\"\"\"\r\n\r\nimport torch\r\nimport numpy as np\r\nfrom torch_geometric.data import TemporalData\r\nfrom tgb.utils.utils import save_pkl\r\nimport os\r\nfrom tqdm import tqdm\r\nfrom tgb.utils.utils import vprint\r\n\r\n\r\nclass NegativeEdgeGenerator(object):\r\n    def __init__(\r\n        self,\r\n        dataset_name: str,\r\n        first_dst_id: int,\r\n        last_dst_id: int,\r\n        num_neg_e: int = 100,  # number of negative edges sampled per positive edges --> make it constant => 1000\r\n        strategy: str = \"rnd\",\r\n        rnd_seed: int = 123,\r\n        hist_ratio: float = 0.5,\r\n        historical_data: TemporalData = None,\r\n    ) -> None:\r\n        r\"\"\"\r\n        Negative Edge Sampler class\r\n        this is a class for generating negative samples for a specific datasets\r\n        the set of the positive samples are provided, the negative samples are generated with specific strategies \r\n        and are saved for consistent evaluation across different methods\r\n        negative edges are sampled with 'oen_vs_many' strategy.\r\n        it is assumed that the destination nodes are indexed sequentially with 'first_dst_id' \r\n        and 'last_dst_id' being the first and last index, respectively.\r\n\r\n        Parameters:\r\n            dataset_name: name of the dataset\r\n            first_dst_id: identity of the first destination node\r\n            last_dst_id: indentity of the last destination node\r\n            num_neg_e: number of negative edges being generated per each positive edge\r\n            strategy: how to generate negative edges; can be 'rnd' or 'hist_rnd'\r\n            rnd_seed: random seed for consistency\r\n      
      hist_ratio: if the strategy is 'hist_rnd', how much of the negatives are historical\r\n            historical_data: previous records of the positive edges\r\n        \r\n        Returns:\r\n            None\r\n        \"\"\"\r\n        self.rnd_seed = rnd_seed\r\n        np.random.seed(self.rnd_seed)\r\n        self.dataset_name = dataset_name\r\n\r\n        self.first_dst_id = first_dst_id\r\n        self.last_dst_id = last_dst_id\r\n        self.num_neg_e = num_neg_e\r\n        assert strategy in [\r\n            \"rnd\",\r\n            \"hist_rnd\",\r\n        ], \"The supported strategies are `rnd` or `hist_rnd`!\"\r\n        self.strategy = strategy\r\n        if self.strategy == \"hist_rnd\":\r\n            assert (\r\n                historical_data != None\r\n            ), \"Train data should be passed when `hist_rnd` strategy is selected.\"\r\n            self.hist_ratio = hist_ratio\r\n            self.historical_data = historical_data\r\n\r\n    def generate_negative_samples(self, \r\n                                  data: TemporalData, \r\n                                  split_mode: str, \r\n                                  partial_path: str,\r\n                                  ) -> None:\r\n        r\"\"\"\r\n        Generate negative samples\r\n\r\n        Parameters:\r\n            data: an object containing positive edges information\r\n            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits\r\n            partial_path: in which directory save the generated negatives\r\n        \"\"\"\r\n        # file name for saving or loading...\r\n        filename = (\r\n            partial_path\r\n            + \"/\"\r\n            + self.dataset_name\r\n            + \"_\"\r\n            + split_mode\r\n            + \"_\"\r\n            + \"ns\"\r\n            + \".pkl\"\r\n        )\r\n\r\n        if self.strategy == \"rnd\":\r\n            self.generate_negative_samples_rnd(data, split_mode, 
filename)\r\n        elif self.strategy == \"hist_rnd\":\r\n            self.generate_negative_samples_hist_rnd(\r\n                self.historical_data, data, split_mode, filename\r\n            )\r\n        else:\r\n            raise ValueError(\"Unsupported negative sample generation strategy!\")\r\n\r\n    def generate_negative_samples_rnd(self, \r\n                                      data: TemporalData, \r\n                                      split_mode: str, \r\n                                      filename: str,\r\n                                      ) -> None:\r\n        r\"\"\"\r\n        Generate negative samples based on the `RND` strategy:\r\n            - for each positive edge, sample a batch of negative edges from all possible edges with the same source node\r\n            - filter actual positive edges\r\n        \r\n        Parameters:\r\n            data: an object containing positive edges information\r\n            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits\r\n            filename: name of the file containing the generated negative edges\r\n        \"\"\"\r\n        vprint(\r\n            f\"INFO: Negative Sampling Strategy: {self.strategy}, Data Split: {split_mode}\"\r\n        )\r\n        assert split_mode in [\r\n            \"val\",\r\n            \"test\",\r\n        ], \"Invalid split-mode! 
It should be `val` or `test`!\"\r\n\r\n        if os.path.exists(filename):\r\n            vprint(\r\n                f\"INFO: negative samples for '{split_mode}' evaluation are already generated!\"\r\n            )\r\n        else:\r\n            vprint(f\"INFO: Generating negative samples for '{split_mode}' evaluation!\")\r\n            # retrieve the information from the batch\r\n            pos_src, pos_dst, pos_timestamp = (\r\n                data.src.cpu().numpy(),\r\n                data.dst.cpu().numpy(),\r\n                data.t.cpu().numpy(),\r\n            )\r\n\r\n            # all possible destinations\r\n            all_dst = np.arange(self.first_dst_id, self.last_dst_id + 1)\r\n\r\n            evaluation_set = {}\r\n            # generate a list of negative destinations for each positive edge\r\n            pos_edge_tqdm = tqdm(\r\n                zip(pos_src, pos_dst, pos_timestamp), total=len(pos_src)\r\n            )\r\n            for (\r\n                pos_s,\r\n                pos_d,\r\n                pos_t,\r\n            ) in pos_edge_tqdm:\r\n                t_mask = pos_timestamp == pos_t\r\n                src_mask = pos_src == pos_s\r\n                fn_mask = np.logical_and(t_mask, src_mask)\r\n                pos_e_dst_same_src = pos_dst[fn_mask]\r\n                filtered_all_dst = np.setdiff1d(all_dst, pos_e_dst_same_src)\r\n\r\n                '''\r\n                when num_neg_e is larger than all possible destinations simple return all possible destinations\r\n                '''\r\n                if (self.num_neg_e > len(filtered_all_dst)):\r\n                    neg_d_arr = filtered_all_dst\r\n                else:\r\n                    neg_d_arr = np.random.choice(\r\n                    filtered_all_dst, self.num_neg_e, replace=False) #never replace negatives\r\n\r\n                evaluation_set[(pos_s, pos_d, pos_t)] = neg_d_arr\r\n\r\n            # save the generated evaluation set to disk\r\n            
save_pkl(evaluation_set, filename)\r\n\r\n    def generate_historical_edge_set(self, \r\n                                     historical_data: TemporalData,\r\n                                     ) -> tuple:\r\n        r\"\"\"\r\n        Generate the set of edges seen durign training or validation\r\n\r\n        ONLY `train_data` should be passed as historical data; i.e., the HISTORICAL negative edges should be selected from training data only.\r\n        \r\n        Parameters:\r\n            historical_data: contains the positive edges observed previously\r\n        \r\n        Returns:\r\n            historical_edges: distict historical positive edges\r\n            hist_edge_set_per_node: historical edges observed for each node\r\n        \"\"\"\r\n        sources = historical_data.src.cpu().numpy()\r\n        destinations = historical_data.dst.cpu().numpy()\r\n        historical_edges = {}\r\n        hist_e_per_node = {}\r\n        for src, dst in zip(sources, destinations):\r\n            # edge-centric\r\n            if (src, dst) not in historical_edges:\r\n                historical_edges[(src, dst)] = 1\r\n\r\n            # node-centric\r\n            if src not in hist_e_per_node:\r\n                hist_e_per_node[src] = [dst]\r\n            else:\r\n                hist_e_per_node[src].append(dst)\r\n\r\n        hist_edge_set_per_node = {}\r\n        for src, dst_list in hist_e_per_node.items():\r\n            hist_edge_set_per_node[src] = np.array(list(set(dst_list)))\r\n\r\n        return historical_edges, hist_edge_set_per_node\r\n\r\n    def generate_negative_samples_hist_rnd(\r\n        self, \r\n        historical_data : TemporalData, \r\n        data: TemporalData, \r\n        split_mode: str, \r\n        filename: str,\r\n    ) -> None:\r\n        r\"\"\"\r\n        Generate negative samples based on the `HIST-RND` strategy:\r\n            - up to 50% of the negative samples are selected from the set of edges seen during the training with the 
same source node.\r\n            - the rest of the negative edges are randomly sampled with the fixed source node.\r\n\r\n        Parameters:\r\n            historical_data: contains the history of the observed positive edges including \r\n                            distinct positive edges and edges observed for each positive node\r\n            data: an object containing positive edges information\r\n            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits\r\n            filename: name of the file to save generated negative edges\r\n        \r\n        Returns:\r\n            None\r\n        \"\"\"\r\n        vprint(\r\n            f\"INFO: Negative Sampling Strategy: {self.strategy}, Data Split: {split_mode}\"\r\n        )\r\n        assert split_mode in [\r\n            \"val\",\r\n            \"test\",\r\n        ], \"Invalid split-mode! It should be `val` or `test`!\"\r\n\r\n        if os.path.exists(filename):\r\n            vprint(\r\n                f\"INFO: negative samples for '{split_mode}' evaluation are already generated!\"\r\n            )\r\n        else:\r\n            vprint(f\"INFO: Generating negative samples for '{split_mode}' evaluation!\")\r\n            # retrieve the information from the batch\r\n            pos_src, pos_dst, pos_timestamp = (\r\n                data.src.cpu().numpy(),\r\n                data.dst.cpu().numpy(),\r\n                data.t.cpu().numpy(),\r\n            )\r\n\r\n            pos_ts_edge_dict = {} #{ts: {src: [dsts]}}\r\n            pos_edge_tqdm = tqdm(\r\n                zip(pos_src, pos_dst, pos_timestamp), total=len(pos_src)\r\n            )\r\n            for (\r\n                pos_s,\r\n                pos_d,\r\n                pos_t,\r\n            ) in pos_edge_tqdm:\r\n                if (pos_t not in pos_ts_edge_dict):\r\n                    pos_ts_edge_dict[pos_t] = {pos_s: [pos_d]}\r\n                else:\r\n                    if (pos_s not in 
pos_ts_edge_dict[pos_t]):\r\n                        pos_ts_edge_dict[pos_t][pos_s] = [pos_d]\r\n                    else:\r\n                        pos_ts_edge_dict[pos_t][pos_s].append(pos_d)\r\n\r\n            # all possible destinations\r\n            all_dst = np.arange(self.first_dst_id, self.last_dst_id + 1)\r\n\r\n            # get seen edge history\r\n            (\r\n                historical_edges,\r\n                hist_edge_set_per_node,\r\n            ) = self.generate_historical_edge_set(historical_data)\r\n\r\n            # sample historical edges\r\n            max_num_hist_neg_e = int(self.num_neg_e * self.hist_ratio)\r\n\r\n            evaluation_set = {}\r\n            # generate a list of negative destinations for each positive edge\r\n            pos_edge_tqdm = tqdm(\r\n                zip(pos_src, pos_dst, pos_timestamp), total=len(pos_src)\r\n            )\r\n            for (\r\n                pos_s,\r\n                pos_d,\r\n                pos_t,\r\n            ) in pos_edge_tqdm:\r\n                pos_e_dst_same_src = np.array(pos_ts_edge_dict[pos_t][pos_s])\r\n\r\n                # sample historical edges\r\n                num_hist_neg_e = 0\r\n                neg_hist_dsts = np.array([])\r\n                seen_dst = []\r\n                if pos_s in hist_edge_set_per_node:\r\n                    seen_dst = hist_edge_set_per_node[pos_s]\r\n                    if len(seen_dst) >= 1:\r\n                        filtered_all_seen_dst = np.setdiff1d(seen_dst, pos_e_dst_same_src)\r\n                        #filtered_all_seen_dst = seen_dst #! 
no collision check\r\n                        num_hist_neg_e = (\r\n                            max_num_hist_neg_e\r\n                            if max_num_hist_neg_e <= len(filtered_all_seen_dst)\r\n                            else len(filtered_all_seen_dst)\r\n                        )\r\n                        neg_hist_dsts = np.random.choice(\r\n                            filtered_all_seen_dst, num_hist_neg_e, replace=False\r\n                        )\r\n\r\n                # sample random edges\r\n                if (len(seen_dst) >= 1):\r\n                    invalid_dst = np.concatenate((np.array(pos_e_dst_same_src), seen_dst))\r\n                else:\r\n                    invalid_dst = np.array(pos_e_dst_same_src)\r\n                filtered_all_rnd_dst = np.setdiff1d(all_dst, invalid_dst)\r\n\r\n                num_rnd_neg_e = self.num_neg_e - num_hist_neg_e\r\n                '''\r\n                when num_neg_e is larger than all possible destinations simple return all possible destinations\r\n                '''\r\n                if (num_rnd_neg_e > len(filtered_all_rnd_dst)):\r\n                    neg_rnd_dsts = filtered_all_rnd_dst\r\n                else:\r\n                    neg_rnd_dsts = np.random.choice(\r\n                    filtered_all_rnd_dst, num_rnd_neg_e, replace=False\r\n                )\r\n                # concatenate the two sets: historical and random\r\n                neg_dst_arr = np.concatenate((neg_hist_dsts, neg_rnd_dsts))\r\n                evaluation_set[(pos_s, pos_d, pos_t)] = neg_dst_arr\r\n\r\n            # save the generated evaluation set to disk\r\n            save_pkl(evaluation_set, filename)\r\n"
  },
  {
    "path": "tgb/linkproppred/negative_sampler.py",
    "content": "\"\"\"\r\nSample negative edges for evaluation of dynamic link prediction\r\nLoad already generated negative edges from file, batch them based on the positive edge, and return the evaluation set\r\n\"\"\"\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nimport numpy as np\r\nfrom tgb.utils.utils import save_pkl, load_pkl\r\nfrom tgb.utils.info import PROJ_DIR\r\nimport os\r\nimport time\r\n\r\n\r\nclass NegativeEdgeSampler(object):\r\n    def __init__(\r\n        self,\r\n        dataset_name: str,\r\n        first_dst_id: int = 0,\r\n        last_dst_id: int = 0,\r\n        strategy: str = \"hist_rnd\",\r\n    ) -> None:\r\n        r\"\"\"\r\n        Negative Edge Sampler\r\n            Loads and query the negative batches based on the positive batches provided.\r\n        constructor for the negative edge sampler class\r\n\r\n        Parameters:\r\n            dataset_name: name of the dataset\r\n            first_dst_id: identity of the first destination node\r\n            last_dst_id: indentity of the last destination node\r\n            strategy: will always load the pre-generated negatives\r\n        \r\n        Returns:\r\n            None\r\n        \"\"\"\r\n        self.dataset_name = dataset_name\r\n        assert strategy in [\r\n            \"rnd\",\r\n            \"hist_rnd\",\r\n        ], \"The supported strategies are `rnd` or `hist_rnd`!\"\r\n        self.strategy = strategy\r\n        self.eval_set = {}\r\n\r\n    def load_eval_set(\r\n        self,\r\n        fname: str,\r\n        split_mode: str = \"val\",\r\n    ) -> None:\r\n        r\"\"\"\r\n        Load the evaluation set from disk, can be either val or test set ns samples\r\n        Parameters:\r\n            fname: the file name of the evaluation ns on disk\r\n            split_mode: the split mode of the evaluation set, can be either `val` or `test`\r\n        \r\n        Returns:\r\n            None\r\n        \"\"\"\r\n        assert split_mode in [\r\n            
\"val\",\r\n            \"test\",\r\n        ], \"Invalid split-mode! It should be `val`, `test`\"\r\n        if not os.path.exists(fname):\r\n            raise FileNotFoundError(f\"File not found at {fname}\")\r\n        self.eval_set[split_mode] = load_pkl(fname)\r\n\r\n    def reset_eval_set(self, \r\n                       split_mode: str = \"test\",\r\n                       ) -> None:\r\n        r\"\"\"\r\n        Reset evaluation set\r\n\r\n        Parameters:\r\n            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits\r\n\r\n        Returns:\r\n            None\r\n        \"\"\"\r\n        assert split_mode in [\r\n            \"val\",\r\n            \"test\",\r\n        ], \"Invalid split-mode! It should be `val`, `test`!\"\r\n        self.eval_set[split_mode] = None\r\n\r\n    def query_batch(self, \r\n                    pos_src: Tensor, \r\n                    pos_dst: Tensor, \r\n                    pos_timestamp: Tensor, \r\n                    edge_type: Tensor = None,\r\n                    split_mode: str = \"test\") -> list:\r\n        r\"\"\"\r\n        For each positive edge in the `pos_batch`, return a list of negative edges\r\n        `split_mode` specifies whether the valiation or test evaluation set should be retrieved.\r\n        modify now to include edge type argument\r\n\r\n        Parameters:\r\n            pos_src: list of positive source nodes\r\n            pos_dst: list of positive destination nodes\r\n            pos_timestamp: list of timestamps of the positive edges\r\n            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits\r\n\r\n        Returns:\r\n            neg_samples: a list of list; each internal list contains the set of negative edges that\r\n                        should be evaluated against each positive edge.\r\n        \"\"\"\r\n        assert split_mode in [\r\n            \"val\",\r\n            \"test\",\r\n        ], 
\"Invalid split-mode! It should be `val`, `test`!\"\r\n        if self.eval_set[split_mode] == None:\r\n            raise ValueError(\r\n                f\"Evaluation set is None! You should load the {split_mode} evaluation set first!\"\r\n            )\r\n        \r\n        # check the argument types...\r\n        if torch is not None and isinstance(pos_src, torch.Tensor):\r\n            pos_src = pos_src.detach().cpu().numpy()\r\n        if torch is not None and isinstance(pos_dst, torch.Tensor):\r\n            pos_dst = pos_dst.detach().cpu().numpy()\r\n        if torch is not None and isinstance(pos_timestamp, torch.Tensor):\r\n            pos_timestamp = pos_timestamp.detach().cpu().numpy()\r\n        if torch is not None and isinstance(edge_type, torch.Tensor):\r\n            edge_type = edge_type.detach().cpu().numpy()\r\n        \r\n        if not isinstance(pos_src, np.ndarray) or not isinstance(pos_dst, np.ndarray) or not(pos_timestamp, np.ndarray):\r\n            raise RuntimeError(\r\n                \"pos_src, pos_dst, and pos_timestamp need to be either numpy ndarray or torch tensor!\"\r\n                )\r\n\r\n        neg_samples = []\r\n        if (edge_type is None):\r\n            for pos_s, pos_d, pos_t in zip(pos_src, pos_dst, pos_timestamp):\r\n                if (pos_s, pos_d, pos_t) not in self.eval_set[split_mode]:\r\n                    raise ValueError(\r\n                        f\"The edge ({pos_s}, {pos_d}, {pos_t}) is not in the '{split_mode}' evaluation set! 
Please check the implementation.\"\r\n                    )\r\n                else:\r\n                    neg_samples.append(\r\n                        [\r\n                            int(neg_dst)\r\n                            for neg_dst in self.eval_set[split_mode][(pos_s, pos_d, pos_t)]\r\n                        ]\r\n                    )\r\n        else:\r\n            for pos_s, pos_d, pos_t, e_type in zip(pos_src, pos_dst, pos_timestamp, edge_type):\r\n                if (pos_s, pos_d, pos_t, e_type) not in self.eval_set[split_mode]:\r\n                    raise ValueError(\r\n                        f\"The edge ({pos_s}, {pos_d}, {pos_t}, {e_type}) is not in the '{split_mode}' evaluation set! Please check the implementation.\"\r\n                    )\r\n                else:\r\n                    neg_samples.append(\r\n                        [\r\n                            int(neg_dst)\r\n                            for neg_dst in self.eval_set[split_mode][(pos_s, pos_d, pos_t, e_type)]\r\n                        ]\r\n                    )\r\n\r\n        return neg_samples\r\n"
  },
  {
    "path": "tgb/linkproppred/thg_negative_generator.py",
    "content": "\"\"\"\r\nSample and Generate negative edges that are going to be used for evaluation of a dynamic graph learning model\r\nNegative samples are generated and saved to files ONLY once; \r\n    other times, they should be loaded from file with instances of the `negative_sampler.py`.\r\n\"\"\"\r\nimport os\r\nimport torch\r\nimport numpy as np\r\nfrom tqdm import tqdm\r\nfrom torch_geometric.data import TemporalData\r\nfrom tgb.utils.utils import save_pkl\r\nfrom typing import Union\r\nfrom tgb.utils.utils import vprint\r\n\r\n\r\n\r\n\"\"\"\r\nnegative sample generator for tkg datasets \r\ntemporal filterted MRR\r\n\"\"\"\r\nclass THGNegativeEdgeGenerator(object):\r\n    def __init__(\r\n        self,\r\n        dataset_name: str,\r\n        first_node_id: int,\r\n        last_node_id: int,\r\n        node_type: Union[np.ndarray, torch.Tensor],\r\n        strategy: str = \"node-type-filtered\",\r\n        num_neg_e: int = -1,  # -1 means generate all possible negatives\r\n        rnd_seed: int = 1,\r\n        edge_data: TemporalData = None,\r\n    ) -> None:\r\n        r\"\"\"\r\n        Negative Edge Generator class for Temporal Heterogeneous Graphs\r\n        this is a class for generating negative samples for a specific datasets\r\n        the set of the positive samples are provided, the negative samples are generated with specific strategies \r\n        and are saved for consistent evaluation across different methods\r\n\r\n        Parameters:\r\n            dataset_name: name of the dataset\r\n            first_node_id: the first node id\r\n            last_node_id: the last node id\r\n            node_type: the node type of each node\r\n            strategy: the strategy to generate negative samples\r\n            num_neg_e: number of negative samples to generate\r\n            rnd_seed: random seed\r\n            edge_data: the edge data object containing the positive edges\r\n        Returns:\r\n            None\r\n        \"\"\"\r\n        
self.rnd_seed = rnd_seed\r\n        np.random.seed(self.rnd_seed)\r\n        self.dataset_name = dataset_name\r\n        self.first_node_id = first_node_id\r\n        self.last_node_id = last_node_id\r\n        if isinstance(node_type, torch.Tensor):\r\n            node_type = node_type.cpu().numpy()\r\n        self.node_type = node_type\r\n        self.node_type_dict = self.get_destinations_based_on_node_type(first_node_id, last_node_id, self.node_type) # {node_type: {nid:1}}\r\n        assert isinstance(self.node_type, np.ndarray), \"node_type should be a numpy array\"\r\n        self.num_neg_e = num_neg_e  #-1 means generate all \r\n\r\n        assert strategy in [\r\n            \"node-type-filtered\",\r\n            \"random\",\r\n        ], \"The supported strategies are `node-type-filtered`\"\r\n        self.strategy = strategy\r\n        self.edge_data = edge_data\r\n\r\n    def get_destinations_based_on_node_type(self, \r\n                                            first_node_id: int,\r\n                                            last_node_id: int,\r\n                                            node_type: np.ndarray) -> dict:\r\n        r\"\"\"\r\n        get the destination node id arrays based on the node type\r\n        Parameters:\r\n            first_node_id: the first node id\r\n            last_node_id: the last node id\r\n            node_type: the node type of each node\r\n\r\n        Returns:\r\n            node_type_dict: a dictionary containing the destination node ids for each node type\r\n        \"\"\"\r\n        node_type_store = {}\r\n        assert first_node_id <= last_node_id, \"Invalid destination node ids!\"\r\n        assert len(node_type) == (last_node_id - first_node_id + 1), \"node type array must match the indices\"\r\n        for k in range(len(node_type)):\r\n            nt = int(node_type[k]) #node type must be ints\r\n            nid = k + first_node_id\r\n            if nt not in node_type_store:\r\n                
node_type_store[nt] = {nid:1}\r\n            else:\r\n                node_type_store[nt][nid] = 1\r\n        node_type_dict = {}\r\n        for ntype in node_type_store:\r\n            node_type_dict[ntype] = np.array(list(node_type_store[ntype].keys()))\r\n            assert np.all(np.diff(node_type_dict[ntype]) >= 0), \"Destination node ids for a given type must be sorted\"\r\n            assert np.all(node_type_dict[ntype] <= last_node_id), \"Destination node ids must be less than or equal to the last destination id\"\r\n        return node_type_dict\r\n\r\n    def generate_negative_samples(self, \r\n                                  pos_edges: TemporalData,\r\n                                  split_mode: str, \r\n                                  partial_path: str,\r\n                                  ) -> None:\r\n        r\"\"\"\r\n        Generate negative samples\r\n\r\n        Parameters:\r\n            pos_edges: positive edges to generate the negatives for\r\n            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits\r\n            partial_path: in which directory save the generated negatives\r\n        \"\"\"\r\n        # file name for saving or loading...\r\n        filename = (\r\n            partial_path\r\n            + \"/\"\r\n            + self.dataset_name\r\n            + \"_\"\r\n            + split_mode\r\n            + \"_\"\r\n            + \"ns\"\r\n            + \".pkl\"\r\n        )\r\n\r\n        if self.strategy == \"node-type-filtered\":\r\n            self.generate_negative_samples_nt(pos_edges, split_mode, filename)\r\n        elif self.strategy == \"random\":\r\n            self.generate_negative_samples_random(pos_edges, split_mode, filename)\r\n        else:\r\n            raise ValueError(\"Unsupported negative sample generation strategy!\")\r\n\r\n    def generate_negative_samples_nt(self, \r\n                                      data: TemporalData, \r\n                             
         split_mode: str, \r\n                                      filename: str,\r\n                                      ) -> None:\r\n        r\"\"\"\r\n        now we consider (s, d, t, edge_type) as a unique edge, also adding the node type info for the destination node for convenience so (s, d, t, edge_type): (conflict_set, d_node_type)\r\n        Generate negative samples based on the random strategy:\r\n            - for each positive edge, retrieve all possible destinations based on the node type of the destination node\r\n            - filter actual positive edges at the same timestamp with the same edge type\r\n        \r\n        Parameters:\r\n            data: an object containing positive edges information\r\n            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits\r\n            filename: name of the file containing the generated negative edges\r\n        \"\"\"\r\n        vprint(\r\n            f\"INFO: Negative Sampling Strategy: {self.strategy}, Data Split: {split_mode}\"\r\n        )\r\n        assert split_mode in [\r\n            \"val\",\r\n            \"test\",\r\n        ], \"Invalid split-mode! 
It should be `val` or `test`!\"\r\n\r\n        if os.path.exists(filename):\r\n            vprint(\r\n                f\"INFO: negative samples for '{split_mode}' evaluation are already generated!\"\r\n            )\r\n        else:\r\n            vprint(f\"INFO: Generating negative samples for '{split_mode}' evaluation!\")\r\n            # retrieve the information from the batch\r\n            pos_src, pos_dst, pos_timestamp, edge_type = (\r\n                data.src.cpu().numpy(),\r\n                data.dst.cpu().numpy(),\r\n                data.t.cpu().numpy(),\r\n                data.edge_type.cpu().numpy(),\r\n            )\r\n\r\n            # generate a list of negative destinations for each positive edge\r\n            pos_edge_tqdm = tqdm(\r\n                zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)\r\n            )\r\n\r\n            edge_t_dict = {} # {(t, u, edge_type): {v_1, v_2, ..} }\r\n            #! iterate once to put all edges into a dictionary for reference\r\n            for (\r\n                pos_s,\r\n                pos_d,\r\n                pos_t,\r\n                edge_type,\r\n            ) in pos_edge_tqdm:\r\n                if (pos_t, pos_s, edge_type) not in edge_t_dict:\r\n                    edge_t_dict[(pos_t, pos_s, edge_type)] = {pos_d:1}\r\n                else:\r\n                    edge_t_dict[(pos_t, pos_s, edge_type)][pos_d] = 1\r\n\r\n            out_dict = {}\r\n            for key in tqdm(edge_t_dict):\r\n                conflict_set = np.array(list(edge_t_dict[key].keys()))\r\n                pos_d = conflict_set[0]\r\n                #* retieve the node type of the destination node as well \r\n                #! 
assumption, same edge type = same destination node type\r\n                d_node_type = int(self.node_type[pos_d - self.first_node_id])\r\n                all_dst = self.node_type_dict[d_node_type]\r\n                if (self.num_neg_e == -1):\r\n                    filtered_all_dst = np.setdiff1d(all_dst, conflict_set)\r\n                else:\r\n                    #* lazy sampling\r\n                    neg_d_arr = np.random.choice(\r\n                        all_dst, self.num_neg_e, replace=False) #never replace negatives\r\n                    if len(np.setdiff1d(neg_d_arr, conflict_set)) < self.num_neg_e:\r\n                        neg_d_arr = np.random.choice(\r\n                            np.setdiff1d(all_dst, conflict_set), self.num_neg_e, replace=False)\r\n                    filtered_all_dst = neg_d_arr\r\n                out_dict[key] = filtered_all_dst\r\n            vprint (\"ns samples for \", len(out_dict), \" positive edges are generated\")\r\n            # save the generated evaluation set to disk\r\n            save_pkl(out_dict, filename)\r\n\r\n    def generate_negative_samples_random(self, \r\n                                      data: TemporalData, \r\n                                      split_mode: str, \r\n                                      filename: str,\r\n                                      ) -> None:\r\n        r\"\"\"\r\n        generate random negative edges for ablation study\r\n        \r\n        Parameters:\r\n            data: an object containing positive edges information\r\n            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits\r\n            filename: name of the file containing the generated negative edges\r\n        \"\"\"\r\n        vprint(\r\n            f\"INFO: Negative Sampling Strategy: {self.strategy}, Data Split: {split_mode}\"\r\n        )\r\n        assert split_mode in [\r\n            \"val\",\r\n            \"test\",\r\n        ], \"Invalid split-mode! 
It should be `val` or `test`!\"\r\n\r\n        if os.path.exists(filename):\r\n            vprint(\r\n                f\"INFO: negative samples for '{split_mode}' evaluation are already generated!\"\r\n            )\r\n        else:\r\n            vprint(f\"INFO: Generating negative samples for '{split_mode}' evaluation!\")\r\n            # retrieve the information from the batch\r\n            pos_src, pos_dst, pos_timestamp, edge_type = (\r\n                data.src.cpu().numpy(),\r\n                data.dst.cpu().numpy(),\r\n                data.t.cpu().numpy(),\r\n                data.edge_type.cpu().numpy(),\r\n            )\r\n            first_dst_id = self.edge_data.dst.min()\r\n            last_dst_id = self.edge_data.dst.max()\r\n            all_dst = np.arange(first_dst_id, last_dst_id + 1)\r\n            evaluation_set = {}\r\n            # generate a list of negative destinations for each positive edge\r\n            pos_edge_tqdm = tqdm(\r\n                zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)\r\n            )\r\n\r\n            for (\r\n                pos_s,\r\n                pos_d,\r\n                pos_t,\r\n                edge_type,\r\n            ) in pos_edge_tqdm:\r\n                t_mask = pos_timestamp == pos_t\r\n                src_mask = pos_src == pos_s\r\n                fn_mask = np.logical_and(t_mask, src_mask)\r\n                pos_e_dst_same_src = pos_dst[fn_mask]\r\n                filtered_all_dst = np.setdiff1d(all_dst, pos_e_dst_same_src)\r\n                if (self.num_neg_e > len(filtered_all_dst)):\r\n                    neg_d_arr = filtered_all_dst\r\n                else:\r\n                    neg_d_arr = np.random.choice(\r\n                    filtered_all_dst, self.num_neg_e, replace=False) #never replace negatives\r\n                evaluation_set[(pos_t, pos_s, edge_type)] = neg_d_arr\r\n            save_pkl(evaluation_set, filename)\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n"
  },
  {
    "path": "tgb/linkproppred/thg_negative_sampler.py",
    "content": "\"\"\"\r\nSample negative edges for evaluation of dynamic link prediction\r\nLoad already generated negative edges from file, batch them based on the positive edge, and return the evaluation set\r\n\"\"\"\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nimport numpy as np\r\nfrom tgb.utils.utils import load_pkl\r\nfrom typing import Union\r\nimport os\r\n\r\n\r\nclass THGNegativeEdgeSampler(object):\r\n    def __init__(\r\n        self,\r\n        dataset_name: str,\r\n        first_node_id: int,\r\n        last_node_id: int,\r\n        node_type: np.ndarray,\r\n        strategy: str = \"node-type-filtered\",\r\n    ) -> None:\r\n        r\"\"\"\r\n        Negative Edge Sampler\r\n            Loads and query the negative batches based on the positive batches provided.\r\n            constructor for the negative edge sampler class\r\n\r\n        Parameters:\r\n            dataset_name: name of the dataset\r\n            first_node_id: identity of the first node\r\n            last_node_id: indentity of the last destination node\r\n            node_type: the node type of each node\r\n            strategy: will always load the pre-generated negatives\r\n        \r\n        Returns:\r\n            None\r\n        \"\"\"\r\n        self.dataset_name = dataset_name\r\n        self.eval_set = {}\r\n        self.first_node_id = first_node_id\r\n        self.last_node_id = last_node_id\r\n        self.node_type = node_type\r\n        assert isinstance(self.node_type, np.ndarray), \"node_type should be a numpy array\"\r\n        \r\n    def load_eval_set(\r\n        self,\r\n        fname: str,\r\n        split_mode: str = \"val\",\r\n    ) -> None:\r\n        r\"\"\"\r\n        Load the evaluation set from disk, can be either val or test set ns samples\r\n        Parameters:\r\n            fname: the file name of the evaluation ns on disk\r\n            split_mode: the split mode of the evaluation set, can be either `val` or `test`\r\n        \r\n        
Returns:\r\n            None\r\n        \"\"\"\r\n        assert split_mode in [\r\n            \"val\",\r\n            \"test\",\r\n        ], \"Invalid split-mode! It should be `val`, `test`\"\r\n        if not os.path.exists(fname):\r\n            raise FileNotFoundError(f\"File not found at {fname}\")\r\n        self.eval_set[split_mode] = load_pkl(fname)\r\n\r\n    def query_batch(self, \r\n                    pos_src: Union[Tensor, np.ndarray], \r\n                    pos_dst: Union[Tensor, np.ndarray], \r\n                    pos_timestamp: Union[Tensor, np.ndarray], \r\n                    edge_type: Union[Tensor, np.ndarray],\r\n                    split_mode: str = \"test\") -> list:\r\n        r\"\"\"\r\n        For each positive edge in the `pos_batch`, return a list of negative edges\r\n        `split_mode` specifies whether the validation or test evaluation set should be retrieved.\r\n        modify now to include edge type argument\r\n\r\n        Parameters:\r\n            pos_src: list of positive source nodes\r\n            pos_dst: list of positive destination nodes\r\n            pos_timestamp: list of timestamps of the positive edges\r\n            edge_type: list of edge types of the positive edges\r\n            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits\r\n\r\n        Returns:\r\n            neg_samples: list of numpy array; each array contains the set of negative edges that\r\n                        should be evaluated against each positive edge.\r\n        \"\"\"\r\n        assert split_mode in [\r\n            \"val\",\r\n            \"test\",\r\n        ], \"Invalid split-mode! It should be `val`, `test`!\"\r\n        if self.eval_set.get(split_mode) is None:\r\n            raise ValueError(\r\n                f\"Evaluation set is None! 
You should load the {split_mode} evaluation set first!\"\r\n            )\r\n        \r\n        # check the argument types...\r\n        if torch is not None and isinstance(pos_src, torch.Tensor):\r\n            pos_src = pos_src.detach().cpu().numpy()\r\n        if torch is not None and isinstance(pos_dst, torch.Tensor):\r\n            pos_dst = pos_dst.detach().cpu().numpy()\r\n        if torch is not None and isinstance(pos_timestamp, torch.Tensor):\r\n            pos_timestamp = pos_timestamp.detach().cpu().numpy()\r\n        if torch is not None and isinstance(edge_type, torch.Tensor):\r\n            edge_type = edge_type.detach().cpu().numpy()\r\n        \r\n        if not isinstance(pos_src, np.ndarray) or not isinstance(pos_dst, np.ndarray) or not isinstance(pos_timestamp, np.ndarray) or not isinstance(edge_type, np.ndarray):\r\n            raise RuntimeError(\r\n                \"pos_src, pos_dst, pos_timestamp, and edge_type need to be either numpy ndarray or torch tensor!\"\r\n                )\r\n\r\n        neg_samples = []\r\n        for pos_s, pos_d, pos_t, e_type in zip(pos_src, pos_dst, pos_timestamp, edge_type):\r\n            if (pos_t, pos_s, e_type) not in self.eval_set[split_mode]:\r\n                raise ValueError(\r\n                    f\"The edge ({pos_s}, {pos_d}, {pos_t}, {e_type}) is not in the '{split_mode}' evaluation set! Please check the implementation.\"\r\n                )\r\n            else:\r\n                filtered_dst = self.eval_set[split_mode]\r\n                neg_d_arr = filtered_dst[(pos_t, pos_s, e_type)]\r\n                neg_samples.append(\r\n                        neg_d_arr\r\n                    )\r\n        \r\n        #? can't convert to numpy array due to different lengths of negative samples\r\n        return neg_samples\r\n"
  },
  {
    "path": "tgb/linkproppred/tkg_negative_generator.py",
    "content": "\"\"\"\r\nSample and Generate negative edges that are going to be used for evaluation of a dynamic graph learning model\r\nNegative samples are generated and saved to files ONLY once; \r\n    other times, they should be loaded from file with instances of the `negative_sampler.py`.\r\n\"\"\"\r\nimport numpy as np\r\nfrom torch_geometric.data import TemporalData\r\nimport matplotlib.pyplot as plt\r\nfrom tgb.utils.utils import save_pkl\r\nimport os\r\nfrom tqdm import tqdm\r\nfrom tgb.utils.utils import vprint\r\n\r\n\r\n\r\n\"\"\"\r\nnegative sample generator for tkg datasets \r\ntemporal filterted MRR\r\n\"\"\"\r\nclass TKGNegativeEdgeGenerator(object):\r\n    def __init__(\r\n        self,\r\n        dataset_name: str,\r\n        first_dst_id: int,\r\n        last_dst_id: int,\r\n        strategy: str = \"time-filtered\",\r\n        num_neg_e: int = -1,  # -1 means generate all possible negatives\r\n        rnd_seed: int = 1,\r\n        partial_path: str = None,\r\n        edge_data: TemporalData = None,\r\n    ) -> None:\r\n        r\"\"\"\r\n        Negative Edge Generator class for Temporal Knowledge Graphs\r\n        constructor for the negative edge generator class\r\n\r\n        Parameters:\r\n            dataset_name: name of the dataset\r\n            first_dst_id: identity of the first destination node\r\n            last_dst_id: indentity of the last destination node\r\n            num_neg_e: number of negative edges being generated per each positive edge\r\n            strategy: specifies which strategy should be used for generating the negatives\r\n            rnd_seed: random seed for reproducibility\r\n            edge_data: the positive edges to generate the negatives for, assuming sorted temporally\r\n        \r\n        Returns:\r\n            None\r\n        \"\"\"\r\n        self.rnd_seed = rnd_seed\r\n        np.random.seed(self.rnd_seed)\r\n        self.dataset_name = dataset_name\r\n        self.first_dst_id = first_dst_id\r\n 
       self.last_dst_id = last_dst_id      \r\n        self.num_neg_e = num_neg_e  #-1 means generate all \r\n        assert strategy in [\r\n            \"time-filtered\",\r\n            \"dst-time-filtered\",\r\n            \"random\"\r\n        ], \"The supported strategies are `time-filtered`, dst-time-filtered, random\"\r\n        self.strategy = strategy\r\n        self.dst_dict = None\r\n        if self.strategy == \"dst-time-filtered\":\r\n            if partial_path is None:\r\n                raise ValueError(\r\n                    \"The partial path to the directory where the dst_dict is stored is required\")\r\n            else:\r\n                self.dst_dict_name = (\r\n                    partial_path\r\n                    + \"/\"\r\n                    + self.dataset_name\r\n                    + \"_\"\r\n                    + \"dst_dict\"\r\n                    + \".pkl\"\r\n                )\r\n                self.dst_dict = self.generate_dst_dict(edge_data=edge_data, dst_name=self.dst_dict_name)\r\n        self.edge_data = edge_data\r\n\r\n    def generate_dst_dict(self, edge_data: TemporalData, dst_name: str) -> dict:\r\n        r\"\"\"\r\n        Generate a dictionary of destination nodes for each type of edge\r\n\r\n        Parameters:\r\n            edge_data: an object containing positive edges information\r\n            dst_name: name of the file to save the generated dictionary of destination nodes\r\n        \r\n        Returns:\r\n            dst_dict: a dictionary of destination nodes for each type of edge\r\n        \"\"\"\r\n\r\n        min_dst_idx, max_dst_idx = int(edge_data.dst.min()), int(edge_data.dst.max())\r\n\r\n        pos_src, pos_dst, pos_timestamp, edge_type = (\r\n            edge_data.src.cpu().numpy(),\r\n            edge_data.dst.cpu().numpy(),\r\n            edge_data.t.cpu().numpy(),\r\n            edge_data.edge_type.cpu().numpy(),\r\n        )\r\n\r\n\r\n\r\n        dst_track_dict = {} # {edge_type: {dst_1, 
dst_2, ..} }\r\n\r\n        # generate a list of negative destinations for each positive edge\r\n        pos_edge_tqdm = tqdm(\r\n            zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)\r\n        )\r\n\r\n        for (\r\n            pos_s,\r\n            pos_d,\r\n            pos_t,\r\n            edge_type,\r\n            ) in pos_edge_tqdm:\r\n            if edge_type not in dst_track_dict:\r\n                dst_track_dict[edge_type] = {pos_d:1}\r\n            else:\r\n                dst_track_dict[edge_type][pos_d] = 1\r\n        dst_dict = {}\r\n        edge_type_size = []\r\n        for key in dst_track_dict:\r\n            dst = np.array(list(dst_track_dict[key].keys()))\r\n            edge_type_size.append(len(dst))\r\n            dst_dict[key] = dst\r\n        vprint ('destination candidates generated for all edge types ', len(dst_dict))\r\n        return dst_dict\r\n\r\n    def generate_negative_samples(self, \r\n                                  pos_edges: TemporalData,\r\n                                  split_mode: str, \r\n                                  partial_path: str,\r\n                                  ) -> None:\r\n        r\"\"\"\r\n        Generate negative samples\r\n\r\n        Parameters:\r\n            pos_edges: positive edges to generate the negatives for\r\n            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits\r\n            partial_path: in which directory save the generated negatives\r\n        \"\"\"\r\n        # file name for saving or loading...\r\n        filename = (\r\n            partial_path\r\n            + \"/\"\r\n            + self.dataset_name\r\n            + \"_\"\r\n            + split_mode\r\n            + \"_\"\r\n            + \"ns\"\r\n            + \".pkl\"\r\n        )\r\n\r\n        if self.strategy == \"time-filtered\":\r\n            self.generate_negative_samples_ftr(pos_edges, split_mode, filename)\r\n        elif self.strategy 
== \"dst-time-filtered\":\r\n            self.generate_negative_samples_dst(pos_edges, split_mode, filename)\r\n        elif self.strategy == \"random\":\r\n            self.generate_negative_samples_random(pos_edges, split_mode, filename)\r\n        else:\r\n            raise ValueError(\"Unsupported negative sample generation strategy!\")\r\n        \r\n    def generate_negative_samples_ftr(self, \r\n                                      data: TemporalData, \r\n                                      split_mode: str, \r\n                                      filename: str,\r\n                                      ) -> None:\r\n        r\"\"\"\r\n        now we consider (s, d, t, edge_type) as a unique edge\r\n        Generate negative samples based on the random strategy:\r\n            - for each positive edge, sample a batch of negative edges from all possible edges with the same source node\r\n            - filter actual positive edges at the same timestamp with the same edge type\r\n        \r\n        Parameters:\r\n            data: an object containing positive edges information\r\n            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits\r\n            filename: name of the file containing the generated negative edges\r\n        \"\"\"\r\n        vprint(\r\n            f\"INFO: Negative Sampling Strategy: {self.strategy}, Data Split: {split_mode}\"\r\n        )\r\n        assert split_mode in [\r\n            \"val\",\r\n            \"test\",\r\n        ], \"Invalid split-mode! 
It should be `val` or `test`!\"\r\n\r\n        if os.path.exists(filename):\r\n            vprint(\r\n                f\"INFO: negative samples for '{split_mode}' evaluation are already generated!\"\r\n            )\r\n        else:\r\n            vprint(f\"INFO: Generating negative samples for '{split_mode}' evaluation!\")\r\n            # retrieve the information from the batch\r\n            pos_src, pos_dst, pos_timestamp, edge_type = (\r\n                data.src.cpu().numpy(),\r\n                data.dst.cpu().numpy(),\r\n                data.t.cpu().numpy(),\r\n                data.edge_type.cpu().numpy(),\r\n            )\r\n            # generate a list of negative destinations for each positive edge\r\n            pos_edge_tqdm = tqdm(\r\n                zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)\r\n            )\r\n\r\n            edge_t_dict = {} # {(t, u, edge_type): {v_1, v_2, ..} }\r\n            #! iterate once to put all edges into a dictionary for reference\r\n            for (\r\n                pos_s,\r\n                pos_d,\r\n                pos_t,\r\n                edge_type,\r\n            ) in pos_edge_tqdm:\r\n                if (pos_t, pos_s, edge_type) not in edge_t_dict:\r\n                    edge_t_dict[(pos_t, pos_s, edge_type)] = {pos_d:1}\r\n                else:\r\n                    edge_t_dict[(pos_t, pos_s, edge_type)][pos_d] = 1\r\n\r\n            conflict_dict = {}\r\n            for key in edge_t_dict:\r\n                conflict_dict[key] = np.array(list(edge_t_dict[key].keys()))\r\n            \r\n            vprint (\"conflict sets for ns samples for \", len(conflict_dict), \" positive edges are generated\")\r\n            # save the generated evaluation set to disk\r\n            save_pkl(conflict_dict, filename)\r\n\r\n\r\n    def generate_negative_samples_dst(self, \r\n                                      data: TemporalData, \r\n                                      split_mode: str, \r\n   
                                   filename: str,\r\n                                      ) -> None:\r\n        r\"\"\"\r\n        now we consider (s, d, t, edge_type) as a unique edge\r\n        Generate negative samples based on the random strategy:\r\n            - for each positive edge, sample a batch of negative edges from all possible edges with the same source node\r\n            - filter actual positive edges at the same timestamp with the same edge type\r\n        \r\n        Parameters:\r\n            data: an object containing positive edges information\r\n            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits\r\n            filename: name of the file containing the generated negative edges\r\n        \"\"\"\r\n        vprint(\r\n            f\"INFO: Negative Sampling Strategy: {self.strategy}, Data Split: {split_mode}\"\r\n        )\r\n        assert split_mode in [\r\n            \"val\",\r\n            \"test\",\r\n        ], \"Invalid split-mode! 
It should be `val` or `test`!\"\r\n\r\n        if os.path.exists(filename):\r\n            vprint(\r\n                f\"INFO: negative samples for '{split_mode}' evaluation are already generated!\"\r\n            )\r\n        else:\r\n            if self.dst_dict is None:\r\n                raise ValueError(\"The dst_dict is not generated!\")\r\n\r\n            vprint(f\"INFO: Generating negative samples for '{split_mode}' evaluation!\")\r\n            # retrieve the information from the batch\r\n            pos_src, pos_dst, pos_timestamp, edge_type = (\r\n                data.src.cpu().numpy(),\r\n                data.dst.cpu().numpy(),\r\n                data.t.cpu().numpy(),\r\n                data.edge_type.cpu().numpy(),\r\n            )\r\n            # generate a list of negative destinations for each positive edge\r\n            pos_edge_tqdm = tqdm(\r\n                zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)\r\n            )\r\n\r\n            edge_t_dict = {} # {(t, u, edge_type): {v_1, v_2, ..} }\r\n            out_dict = {}\r\n            #! 
iterate once to put all edges into a dictionary for reference\r\n            for (\r\n                pos_s,\r\n                pos_d,\r\n                pos_t,\r\n                edge_type,\r\n            ) in pos_edge_tqdm:\r\n                if (pos_t, pos_s, edge_type) not in edge_t_dict:\r\n                    edge_t_dict[(pos_t, pos_s, edge_type)] = {pos_d:1}\r\n                else:\r\n                    edge_t_dict[(pos_t, pos_s, edge_type)][pos_d] = 1\r\n\r\n\r\n            pos_src, pos_dst, pos_timestamp, edge_type = (\r\n                data.src.cpu().numpy(),\r\n                data.dst.cpu().numpy(),\r\n                data.t.cpu().numpy(),\r\n                data.edge_type.cpu().numpy(),\r\n            )\r\n\r\n            new_pos_edge_tqdm = tqdm(\r\n                zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)\r\n            )\r\n\r\n            min_dst_idx, max_dst_idx = int(self.edge_data.dst.min()), int(self.edge_data.dst.max())\r\n\r\n\r\n            for (\r\n                pos_s,\r\n                pos_d,\r\n                pos_t,\r\n                edge_type,\r\n            ) in new_pos_edge_tqdm:\r\n                #* generate based on # of ns samples\r\n                conflict_set = np.array(list(edge_t_dict[(pos_t, pos_s, edge_type)].keys()))\r\n                dst_set = self.dst_dict[edge_type]  #dst_set contains conflict set\r\n                sample_num = self.num_neg_e\r\n                filtered_dst_set = np.setdiff1d(dst_set, conflict_set) #more efficient\r\n                dst_sampled = None\r\n                all_dst = np.arange(min_dst_idx, max_dst_idx+1)\r\n                if len(filtered_dst_set) < (sample_num):\r\n                    #* with collision check\r\n                    filtered_sample_set = np.setdiff1d(all_dst, filtered_dst_set)\r\n                    dst_sampled = np.random.choice(filtered_sample_set, sample_num, replace=False)\r\n                    # #* remove the conflict set from dst 
set\r\n                    dst_sampled[0:len(filtered_dst_set)] = filtered_dst_set[:]\r\n                else:\r\n                    # dst_sampled = rng.choice(max_dst_idx+1, sample_num, replace=False)\r\n                    dst_sampled = np.random.choice(filtered_dst_set, sample_num, replace=False)\r\n\r\n                out_dict[(pos_t, pos_s, edge_type)] = dst_sampled\r\n            \r\n            vprint (\"negative samples for \", len(out_dict), \" positive edges are generated\")\r\n            # save the generated evaluation set to disk\r\n            save_pkl(out_dict, filename)\r\n\r\n    \r\n    def generate_negative_samples_random(self, \r\n                                      data: TemporalData, \r\n                                      split_mode: str, \r\n                                      filename: str,\r\n                                      ) -> None:\r\n        r\"\"\"\r\n        generate random negative edges for ablation study\r\n        \r\n        Parameters:\r\n            data: an object containing positive edges information\r\n            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits\r\n            filename: name of the file containing the generated negative edges\r\n        \"\"\"\r\n        vprint(\r\n            f\"INFO: Negative Sampling Strategy: {self.strategy}, Data Split: {split_mode}\"\r\n        )\r\n        assert split_mode in [\r\n            \"val\",\r\n            \"test\",\r\n        ], \"Invalid split-mode! 
It should be `val` or `test`!\"\r\n\r\n        if os.path.exists(filename):\r\n            vprint(\r\n                f\"INFO: negative samples for '{split_mode}' evaluation are already generated!\"\r\n            )\r\n        else:\r\n            vprint(f\"INFO: Generating negative samples for '{split_mode}' evaluation!\")\r\n            # retrieve the information from the batch\r\n            pos_src, pos_dst, pos_timestamp, edge_type = (\r\n                data.src.cpu().numpy(),\r\n                data.dst.cpu().numpy(),\r\n                data.t.cpu().numpy(),\r\n                data.edge_type.cpu().numpy(),\r\n            )\r\n            first_dst_id = self.edge_data.dst.min()\r\n            last_dst_id = self.edge_data.dst.max()\r\n            all_dst = np.arange(first_dst_id, last_dst_id + 1)\r\n            evaluation_set = {}\r\n            # generate a list of negative destinations for each positive edge\r\n            pos_edge_tqdm = tqdm(\r\n                zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)\r\n            )\r\n\r\n            for (\r\n                pos_s,\r\n                pos_d,\r\n                pos_t,\r\n                edge_type,\r\n            ) in pos_edge_tqdm:\r\n                t_mask = pos_timestamp == pos_t\r\n                src_mask = pos_src == pos_s\r\n                fn_mask = np.logical_and(t_mask, src_mask)\r\n                pos_e_dst_same_src = pos_dst[fn_mask]\r\n                filtered_all_dst = np.setdiff1d(all_dst, pos_e_dst_same_src)\r\n                if (self.num_neg_e > len(filtered_all_dst)):\r\n                    neg_d_arr = filtered_all_dst\r\n                else:\r\n                    neg_d_arr = np.random.choice(\r\n                    filtered_all_dst, self.num_neg_e, replace=False) #never replace negatives\r\n                evaluation_set[(pos_t, pos_s, edge_type)] = neg_d_arr\r\n            save_pkl(evaluation_set, filename)\r\n\r\n    "
  },
  {
    "path": "tgb/linkproppred/tkg_negative_sampler.py",
    "content": "\"\"\"\r\nSample negative edges for evaluation of dynamic link prediction\r\nLoad already generated negative edges from file, batch them based on the positive edge, and return the evaluation set\r\n\"\"\"\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nimport numpy as np\r\nfrom torch_geometric.data import TemporalData\r\nfrom tgb.utils.utils import save_pkl, load_pkl\r\nfrom tgb.utils.info import PROJ_DIR\r\nfrom typing import Union\r\nimport os\r\nimport time\r\n\r\n\r\nclass TKGNegativeEdgeSampler(object):\r\n    def __init__(\r\n        self,\r\n        dataset_name: str,\r\n        first_dst_id: int,\r\n        last_dst_id: int,\r\n        strategy: str = \"time-filtered\",\r\n        partial_path: str = PROJ_DIR + \"/data/processed\",\r\n    ) -> None:\r\n        r\"\"\"\r\n        Negative Edge Sampler\r\n            Loads and query the negative batches based on the positive batches provided.\r\n        constructor for the negative edge sampler class\r\n\r\n        Parameters:\r\n            dataset_name: name of the dataset\r\n            first_dst_id: identity of the first destination node\r\n            last_dst_id: indentity of the last destination node\r\n            strategy: will always load the pre-generated negatives\r\n            partial_path: the path to the directory where the negative edges are stored\r\n        \r\n        Returns:\r\n            None\r\n        \"\"\"\r\n        self.dataset_name = dataset_name\r\n        self.eval_set = {}\r\n        self.first_dst_id = first_dst_id\r\n        self.last_dst_id = last_dst_id\r\n        self.strategy = strategy\r\n        self.dst_dict = None\r\n      \r\n    def load_eval_set(\r\n        self,\r\n        fname: str,\r\n        split_mode: str = \"val\",\r\n    ) -> None:\r\n        r\"\"\"\r\n        Load the evaluation set from disk, can be either val or test set ns samples\r\n        Parameters:\r\n            fname: the file name of the evaluation ns on disk\r\n          
  split_mode: the split mode of the evaluation set, can be either `val` or `test`\r\n        \r\n        Returns:\r\n            None\r\n        \"\"\"\r\n        assert split_mode in [\r\n            \"val\",\r\n            \"test\",\r\n        ], \"Invalid split-mode! It should be `val`, `test`\"\r\n        if not os.path.exists(fname):\r\n            raise FileNotFoundError(f\"File not found at {fname}\")\r\n        self.eval_set[split_mode] = load_pkl(fname)\r\n\r\n    def query_batch(self, \r\n                    pos_src: Union[Tensor, np.ndarray], \r\n                    pos_dst: Union[Tensor, np.ndarray], \r\n                    pos_timestamp: Union[Tensor, np.ndarray], \r\n                    edge_type: Union[Tensor, np.ndarray],\r\n                    split_mode: str = \"test\") -> list:\r\n        r\"\"\"\r\n        For each positive edge in the `pos_batch`, return a list of negative edges\r\n        `split_mode` specifies whether the validation or test evaluation set should be retrieved.\r\n        modify now to include edge type argument\r\n\r\n        Parameters:\r\n            pos_src: list of positive source nodes\r\n            pos_dst: list of positive destination nodes\r\n            pos_timestamp: list of timestamps of the positive edges\r\n            edge_type: list of edge types of the positive edges\r\n            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits\r\n\r\n        Returns:\r\n            neg_samples: list of numpy array; each array contains the set of negative edges that\r\n                        should be evaluated against each positive edge.\r\n        \"\"\"\r\n        assert split_mode in [\r\n            \"val\",\r\n            \"test\",\r\n        ], \"Invalid split-mode! It should be `val`, `test`!\"\r\n        if self.eval_set.get(split_mode) is None:\r\n            raise ValueError(\r\n                f\"Evaluation set is None! 
You should load the {split_mode} evaluation set first!\"\r\n            )\r\n        \r\n        # check the argument types...\r\n        if torch is not None and isinstance(pos_src, torch.Tensor):\r\n            pos_src = pos_src.detach().cpu().numpy()\r\n        if torch is not None and isinstance(pos_dst, torch.Tensor):\r\n            pos_dst = pos_dst.detach().cpu().numpy()\r\n        if torch is not None and isinstance(pos_timestamp, torch.Tensor):\r\n            pos_timestamp = pos_timestamp.detach().cpu().numpy()\r\n        if torch is not None and isinstance(edge_type, torch.Tensor):\r\n            edge_type = edge_type.detach().cpu().numpy()\r\n        \r\n        if not isinstance(pos_src, np.ndarray) or not isinstance(pos_dst, np.ndarray) or not isinstance(pos_timestamp, np.ndarray) or not isinstance(edge_type, np.ndarray):\r\n            raise RuntimeError(\r\n                \"pos_src, pos_dst, pos_timestamp, and edge_type need to be either numpy ndarray or torch tensor!\"\r\n                )\r\n        \r\n        if self.strategy == \"time-filtered\":\r\n            neg_samples = []\r\n            for pos_s, pos_d, pos_t, e_type in zip(pos_src, pos_dst, pos_timestamp, edge_type):\r\n                if (pos_t, pos_s, e_type) not in self.eval_set[split_mode]:\r\n                    raise ValueError(\r\n                        f\"The edge ({pos_s}, {pos_d}, {pos_t}, {e_type}) is not in the '{split_mode}' evaluation set! Please check the implementation.\"\r\n                    )\r\n                else:\r\n                    conflict_dict = self.eval_set[split_mode]\r\n                    conflict_set = conflict_dict[(pos_t, pos_s, e_type)]\r\n                    all_dst = np.arange(self.first_dst_id, self.last_dst_id + 1)\r\n                    filtered_all_dst = np.delete(all_dst, conflict_set, axis=0)\r\n\r\n                    #! always using all possible destinations for evaluation\r\n                    neg_d_arr = filtered_all_dst\r\n\r\n                    #! 
this is very slow\r\n                    neg_samples.append(\r\n                            neg_d_arr\r\n                        )\r\n        elif self.strategy == \"dst-time-filtered\":\r\n            neg_samples = []\r\n            for pos_s, pos_d, pos_t, e_type in zip(pos_src, pos_dst, pos_timestamp, edge_type):\r\n                if (pos_t, pos_s, e_type) not in self.eval_set[split_mode]:\r\n                    raise ValueError(\r\n                        f\"The edge ({pos_s}, {pos_d}, {pos_t}, {e_type}) is not in the '{split_mode}' evaluation set! Please check the implementation.\"\r\n                    )\r\n                else:\r\n                    filtered_dst = self.eval_set[split_mode]\r\n                    neg_d_arr = filtered_dst[(pos_t, pos_s, e_type)]\r\n                    neg_samples.append(\r\n                            neg_d_arr\r\n                        )\r\n        #? can't convert to numpy array due to different lengths of negative samples\r\n        return neg_samples\r\n\r\n\r\n        "
  },
  {
    "path": "tgb/nodeproppred/dataset.py",
    "content": "from typing import Optional, Dict, Any, Tuple\r\nimport os\r\nimport os.path as osp\r\nimport numpy as np\r\nimport pandas as pd\r\nimport zipfile\r\nimport requests\r\nfrom clint.textui import progress\r\n\r\nfrom tgb.utils.info import (\r\n    PROJ_DIR,\r\n    DATA_URL_DICT,\r\n    DATA_NUM_CLASSES,\r\n    DATA_VERSION_DICT,\r\n    DATA_EVAL_METRIC_DICT,\r\n    BColors,\r\n)\r\nfrom tgb.utils.utils import save_pkl, load_pkl, vprint\r\nfrom tgb.utils.pre_process import (\r\n    load_label_dict,\r\n    load_edgelist_sr,\r\n    load_edgelist_token,\r\n    load_edgelist_datetime,\r\n    load_trade_label_dict,\r\n    load_edgelist_trade,\r\n)\r\n\r\n\r\nclass NodePropPredDataset(object):\r\n    def __init__(\r\n        self,\r\n        name: str,\r\n        root: str = \"datasets\",\r\n        meta_dict: Optional[dict] = None,\r\n        preprocess: Optional[bool] = True,\r\n        download: Optional[bool] = True, \r\n    ) -> None:\r\n        r\"\"\"Dataset class for the node property prediction task. Stores meta information about each dataset such as evaluation metrics etc.\r\n        also automatically pre-processes the dataset.\r\n        [!] 
node property prediction datasets requires the following:\r\n        self.meta_dict[\"fname\"]: path to the edge list file\r\n        self.meta_dict[\"nodefile\"]: path to the node label file\r\n\r\n        Parameters:\r\n            name: name of the dataset\r\n            root: root directory to store the dataset folder\r\n            meta_dict: dictionary containing meta information about the dataset, should contain key 'dir_name' which is the name of the dataset folder\r\n            preprocess: whether to pre-process the dataset\r\n            download: whether to download the dataset or not (default: True)\r\n        Returns:\r\n            None\r\n        \"\"\"\r\n        self.name = name  ## original name\r\n        # check if dataset url exist\r\n        if self.name in DATA_URL_DICT:\r\n            self.url = DATA_URL_DICT[self.name]\r\n        else:\r\n            self.url = None\r\n\r\n        # check if the evaluatioin metric are specified\r\n        if self.name in DATA_EVAL_METRIC_DICT:\r\n            self.metric = DATA_EVAL_METRIC_DICT[self.name]\r\n        else:\r\n            self.metric = None\r\n            raise ValueError(f\"Dataset {self.name} default evaluation metric not found, it is not supported yet.\")\r\n\r\n        root = PROJ_DIR + root\r\n\r\n        if meta_dict is None:\r\n            self.dir_name = \"_\".join(name.split(\"-\"))  ## replace hyphen with underline\r\n            meta_dict = {\"dir_name\": self.dir_name}\r\n        else:\r\n            self.dir_name = meta_dict[\"dir_name\"]\r\n        self.root = osp.join(root, self.dir_name)\r\n        self.meta_dict = meta_dict\r\n        if \"fname\" not in self.meta_dict:\r\n            self.meta_dict[\"fname\"] = self.root + \"/\" + self.name + \"_edgelist.csv\"\r\n            self.meta_dict[\"nodefile\"] = self.root + \"/\" + self.name + \"_node_labels.csv\"\r\n\r\n         #! 
version check\r\n        self.version_passed = True\r\n        self._version_check()\r\n\r\n        self._num_classes = DATA_NUM_CLASSES[self.name]\r\n\r\n        # initialize\r\n        self._node_feat = None\r\n        self._edge_feat = None\r\n        self._full_data = None\r\n\r\n        if download:\r\n            self.download()\r\n        else:\r\n            if osp.exists(self.meta_dict[\"fname\"]):\r\n                dir_name = self.meta_dict[\"fname\"]\r\n                vprint(f\"files found in {dir_name}\")\r\n            else:\r\n                dir_name = self.meta_dict[\"fname\"]\r\n                raise FileNotFoundError(f\"Directory not found at {dir_name}, please download the dataset\")\r\n            \r\n        # check if the root directory exists, if not create it\r\n        if osp.isdir(self.root):\r\n            vprint(\"Dataset directory is \", self.root)\r\n        else:\r\n            raise FileNotFoundError(f\"Directory not found at {self.root}\")\r\n\r\n        if preprocess:\r\n            self.pre_process()\r\n\r\n        self.label_ts_idx = 0  # index for which node lables to return now\r\n\r\n    def _version_check(self) -> None:\r\n        r\"\"\"Implement Version checks for dataset files\r\n        updates the file names based on the current version number\r\n        prompt the user to download the new version via self.version_passed variable\r\n        \"\"\"\r\n        if (self.name in DATA_VERSION_DICT):\r\n            version = DATA_VERSION_DICT[self.name]\r\n        else:\r\n            raise ValueError(f\"Dataset {self.name} version number not found.\")\r\n        \r\n        if (version > 1):\r\n            #* check if current version is outdated\r\n            self.meta_dict[\"fname\"] = self.root + \"/\" + self.name + \"_edgelist_v\" + str(int(version)) + \".csv\"\r\n            self.meta_dict[\"nodefile\"] = self.root + \"/\" + self.name + \"_node_labels_v\" + str(int(version)) + \".csv\"\r\n            \r\n            if 
(not osp.exists(self.meta_dict[\"fname\"])):\r\n                vprint(f\"Dataset {self.name} version {int(version)} not found, Please download the latest version of the dataset.\")\r\n                self.version_passed = False\r\n                return None\r\n\r\n    def download(self) -> None:\r\n        r\"\"\"\r\n        downloads this dataset from url\r\n        check if files are already downloaded\r\n        Returns:\r\n            None\r\n        \"\"\"\r\n        # check if the file already exists\r\n        if osp.exists(self.meta_dict[\"fname\"]) and osp.exists(\r\n            self.meta_dict[\"nodefile\"]\r\n        ):\r\n            dir_name = self.meta_dict[\"fname\"]\r\n            vprint(f\"files found in {dir_name}\")\r\n            return\r\n\r\n        else:\r\n            vprint(\r\n                f\"{BColors.WARNING}Download started, this might take a while . . . {BColors.ENDC}\"\r\n            )\r\n            vprint(f\"Dataset title: {self.name}\")\r\n\r\n            if self.url is None:\r\n                raise ValueError(f\"Dataset {self.name} url not found, download not supported yet.\")\r\n            else:\r\n                r = requests.get(self.url, stream=True)\r\n                if osp.isdir(self.root):\r\n                    vprint(\"Dataset directory is \", self.root)\r\n                else:\r\n                    os.makedirs(self.root)\r\n                path_download = self.root + \"/\" + self.name + \".zip\"\r\n                print(f\"downloading Dataset: {self.name} to {path_download}\")\r\n                with open(path_download, \"wb\") as f:\r\n                    total_length = int(r.headers.get(\"content-length\"))\r\n                    for chunk in progress.bar(\r\n                        r.iter_content(chunk_size=1024),\r\n                        expected_size=(total_length / 1024) + 1,\r\n                    ):\r\n                        if chunk:\r\n                            f.write(chunk)\r\n                    
        f.flush()\r\n                # for unzipping the file\r\n                with zipfile.ZipFile(path_download, \"r\") as zip_ref:\r\n                    zip_ref.extractall(self.root)\r\n                vprint(f\"{BColors.OKGREEN}Download completed {BColors.ENDC}\")\r\n          \r\n\r\n    def generate_processed_files(\r\n        self,\r\n    ) -> Tuple[pd.DataFrame, Dict[int, Dict[str, Any]]]:\r\n        r\"\"\"\r\n        returns an edge list of pandas data frame\r\n        Returns:\r\n            df: pandas data frame storing the temporal edge list\r\n            node_label_dict: dictionary with key as timestamp and item as dictionary of node labels\r\n        \"\"\"\r\n        OUT_DF = self.root + \"/\" + \"ml_{}.pkl\".format(self.name)\r\n        OUT_NODE_DF = self.root + \"/\" + \"ml_{}_node.pkl\".format(self.name)\r\n        OUT_LABEL_DF = self.root + \"/\" + \"ml_{}_label.pkl\".format(self.name)\r\n        OUT_EDGE_FEAT = self.root + \"/\" + \"ml_{}.pkl\".format(self.name + \"_edge\")\r\n\r\n        # * logic for large datasets, as node label file is too big to store on disc\r\n        if self.name == \"tgbn-reddit\" or self.name == \"tgbn-token\":\r\n            if osp.exists(OUT_DF) and osp.exists(OUT_NODE_DF) and osp.exists(OUT_EDGE_FEAT):\r\n                df = pd.read_pickle(OUT_DF)\r\n                edge_feat = load_pkl(OUT_EDGE_FEAT)\r\n                if (self.name == \"tgbn-token\"):\r\n                    #! 
taking log normalization for numerical stability\r\n                    vprint (\"applying log normalization for weights in tgbn-token\")\r\n                    edge_feat[:,0] = np.log(edge_feat[:,0])\r\n                node_ids = load_pkl(OUT_NODE_DF)\r\n                labels_dict = load_pkl(OUT_LABEL_DF)\r\n                node_label_dict = load_label_dict(\r\n                    self.meta_dict[\"nodefile\"], node_ids, labels_dict\r\n                )\r\n                return df, node_label_dict, edge_feat\r\n\r\n        # * load the preprocessed file if possible\r\n        if osp.exists(OUT_DF) and osp.exists(OUT_NODE_DF) and osp.exists(OUT_EDGE_FEAT):\r\n            vprint(f\"loading processed file from {OUT_DF}, edge features from {OUT_EDGE_FEAT}, node info from {OUT_NODE_DF}.\")\r\n            df = pd.read_pickle(OUT_DF)\r\n            node_label_dict = load_pkl(OUT_NODE_DF)\r\n            edge_feat = load_pkl(OUT_EDGE_FEAT)\r\n        else:  # * process the file\r\n            vprint(\"file not processed, generating processed file\")\r\n            if self.name == \"tgbn-reddit\":\r\n                df, edge_feat, node_ids, labels_dict = load_edgelist_sr(\r\n                    self.meta_dict[\"fname\"], label_size=self._num_classes\r\n                )\r\n            elif self.name == \"tgbn-token\":\r\n                df, edge_feat, node_ids, labels_dict = load_edgelist_token(\r\n                    self.meta_dict[\"fname\"], label_size=self._num_classes\r\n                )\r\n            elif self.name == \"tgbn-genre\":\r\n                df, edge_feat, node_ids, labels_dict = load_edgelist_datetime(\r\n                    self.meta_dict[\"fname\"], label_size=self._num_classes\r\n                )\r\n            elif self.name == \"tgbn-trade\":\r\n                df, edge_feat, node_ids = load_edgelist_trade(\r\n                    self.meta_dict[\"fname\"], label_size=self._num_classes\r\n                )\r\n\r\n            
df.to_pickle(OUT_DF)\r\n            save_pkl(edge_feat, OUT_EDGE_FEAT)\r\n\r\n            if self.name == \"tgbn-trade\":\r\n                node_label_dict = load_trade_label_dict(\r\n                    self.meta_dict[\"nodefile\"], node_ids\r\n                )\r\n            else:\r\n                node_label_dict = load_label_dict(\r\n                    self.meta_dict[\"nodefile\"], node_ids, labels_dict\r\n                )\r\n\r\n            if (\r\n                self.name != \"tgbn-reddit\" and self.name != \"tgbn-token\"\r\n            ):  # don't save subreddits on disc, the node label file is too big\r\n                save_pkl(node_label_dict, OUT_NODE_DF)\r\n            else:\r\n                save_pkl(node_ids, OUT_NODE_DF)\r\n                save_pkl(labels_dict, OUT_LABEL_DF)\r\n            \r\n            vprint(\"file processed and saved\")\r\n        return df, node_label_dict, edge_feat\r\n\r\n    def pre_process(self) -> None:\r\n        \"\"\"\r\n        Pre-process the dataset and generates the splits, must be run before dataset properties can be accessed\r\n        Returns:\r\n            None\r\n        \"\"\"\r\n        # first check if all files exist\r\n        if (\"fname\" not in self.meta_dict) or (\"nodefile\" not in self.meta_dict):\r\n            raise Exception(\"meta_dict does not contain all required filenames\")\r\n\r\n        df, node_label_dict, edge_feat = self.generate_processed_files()\r\n        sources = np.array(df[\"u\"])\r\n        destinations = np.array(df[\"i\"])\r\n        timestamps = np.array(df[\"ts\"])\r\n        edge_idxs = np.array(df[\"idx\"])\r\n        edge_label = np.ones(sources.shape[0])\r\n        #self._edge_feat = np.array(df[\"w\"])\r\n        self._edge_feat = edge_feat\r\n\r\n        full_data = {\r\n            \"sources\": sources,\r\n            \"destinations\": destinations,\r\n            \"timestamps\": timestamps,\r\n            \"edge_idxs\": edge_idxs,\r\n            \"edge_feat\": 
self._edge_feat,\r\n            \"edge_label\": edge_label,\r\n            \"node_label_dict\": node_label_dict,\r\n        }\r\n        self._full_data = full_data\r\n\r\n        # storing the split masks\r\n        _train_mask, _val_mask, _test_mask = self.generate_splits(full_data)\r\n\r\n        self._train_mask = _train_mask\r\n        self._val_mask = _val_mask\r\n        self._test_mask = _test_mask\r\n\r\n        self.label_dict = node_label_dict\r\n        self.label_ts = np.array(list(node_label_dict.keys()))\r\n        self.label_ts = np.sort(self.label_ts)\r\n\r\n    def generate_splits(\r\n        self,\r\n        full_data: Dict[str, Any],\r\n        val_ratio: float = 0.15,\r\n        test_ratio: float = 0.15,\r\n    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:\r\n        r\"\"\"\r\n        Generates train, validation, and test splits from the full dataset\r\n        Parameters:\r\n            full_data: dictionary containing the full dataset\r\n            val_ratio: ratio of validation data\r\n            test_ratio: ratio of test data\r\n        Returns:\r\n            train_mask: boolean mask for training data\r\n            val_mask: boolean mask for validation data\r\n            test_mask: boolean mask for test data\r\n        \"\"\"\r\n        val_time, test_time = list(\r\n            np.quantile(\r\n                full_data[\"timestamps\"],\r\n                [(1 - val_ratio - test_ratio), (1 - test_ratio)],\r\n            )\r\n        )\r\n        timestamps = full_data[\"timestamps\"]\r\n        train_mask = timestamps <= val_time\r\n        val_mask = np.logical_and(timestamps <= test_time, timestamps > val_time)\r\n        test_mask = timestamps > test_time\r\n\r\n        return train_mask, val_mask, test_mask\r\n\r\n    def find_next_labels_batch(\r\n        self,\r\n        cur_t: int,\r\n    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:\r\n        r\"\"\"\r\n        this returns the node labels closest to cur_t (for that given 
day)\r\n        Parameters:\r\n            cur_t: current timestamp of the batch of edges\r\n        Returns:\r\n            ts: timestamp of the node labels\r\n            source_idx: node ids\r\n            labels: the stacked label vectors\r\n        \"\"\"\r\n        if self.label_ts_idx >= (self.label_ts.shape[0]):\r\n            # for query that are after the last batch of labels\r\n            return None\r\n        else:\r\n            ts = self.label_ts[self.label_ts_idx]\r\n\r\n        if cur_t >= ts:\r\n            self.label_ts_idx += 1  # move to the next ts\r\n            # {ts: {node_id: label_vec}}\r\n            node_ids = np.array(list(self.label_dict[ts].keys()))\r\n\r\n            node_labels = []\r\n            for key in self.label_dict[ts]:\r\n                node_labels.append(np.array(self.label_dict[ts][key]))\r\n            node_labels = np.stack(node_labels, axis=0)\r\n            label_ts = np.full(node_ids.shape[0], ts, dtype=\"int\")\r\n            return (label_ts, node_ids, node_labels)\r\n        else:\r\n            return None\r\n\r\n    def reset_label_time(self) -> None:\r\n        r\"\"\"\r\n        reset the pointer for node label once the entire dataset has been iterated once\r\n        Returns:\r\n            None\r\n        \"\"\"\r\n        self.label_ts_idx = 0\r\n\r\n    def return_label_ts(self) -> int:\r\n        \"\"\"\r\n        return the current label timestamp that the pointer is at\r\n        Returns:\r\n            ts: int, the timestamp of the node labels\r\n        \"\"\"\r\n        if (self.label_ts_idx >= self.label_ts.shape[0]):\r\n            return self.label_ts[-1]\r\n        else:\r\n            return self.label_ts[self.label_ts_idx]\r\n\r\n    @property\r\n    def num_classes(self) -> int:\r\n        \"\"\"\r\n        number of classes in the node label\r\n        Returns:\r\n            num_classes: int, number of classes\r\n        \"\"\"\r\n        return self._num_classes\r\n\r\n    @property\r\n 
   def eval_metric(self) -> str:\r\n        \"\"\"\r\n        the official evaluation metric for the dataset, loaded from info.py\r\n        Returns:\r\n            eval_metric: str, the evaluation metric\r\n        \"\"\"\r\n        return self.metric\r\n\r\n    # TODO not sure needed, to be removed\r\n    @property\r\n    def node_feat(self) -> Optional[np.ndarray]:\r\n        r\"\"\"\r\n        Returns the node features of the dataset with dim [N, feat_dim]\r\n        Returns:\r\n            node_feat: np.ndarray, [N, feat_dim] or None if there is no node feature\r\n        \"\"\"\r\n        return self._node_feat\r\n\r\n    # TODO not sure needed, to be removed\r\n    @property\r\n    def edge_feat(self) -> Optional[np.ndarray]:\r\n        r\"\"\"\r\n        Returns the edge features of the dataset with dim [E, feat_dim]\r\n        Returns:\r\n            edge_feat: np.ndarray, [E, feat_dim] or None if there is no edge feature\r\n        \"\"\"\r\n        return self._edge_feat\r\n    \r\n    @property\r\n    def node_label_dict(self) -> Dict[int, Dict[int, Any]]:\r\n        r\"\"\"\r\n        Returns the node label dictionary of the dataset with {timestamp: {node_id: label_vec}}\r\n        Returns:\r\n            label_dict: Dict[int, Dict[int, Any]], the node label dictionary\r\n        \"\"\"\r\n        return self.label_dict\r\n\r\n    @property\r\n    def full_data(self) -> Dict[str, Any]:\r\n        r\"\"\"\r\n        the full data of the dataset as a dictionary with keys: 'sources', 'destinations', 'timestamps', 'edge_idxs', 'edge_feat', 'w', 'edge_label',\r\n\r\n        Returns:\r\n            full_data: Dict[str, Any]\r\n        \"\"\"\r\n        if self._full_data is None:\r\n            raise ValueError(\r\n                \"dataset has not been processed yet, please call pre_process() first\"\r\n            )\r\n        return self._full_data\r\n\r\n    @property\r\n    def train_mask(self) -> np.ndarray:\r\n        r\"\"\"\r\n        Returns the 
train mask of the dataset\r\n        Returns:\r\n            train_mask\r\n        \"\"\"\r\n        if self._train_mask is None:\r\n            raise ValueError(\"training split hasn't been loaded\")\r\n        return self._train_mask\r\n\r\n    @property\r\n    def val_mask(self) -> np.ndarray:\r\n        r\"\"\"\r\n        Returns the validation mask of the dataset\r\n        Returns:\r\n            val_mask: Dict[str, Any]\r\n        \"\"\"\r\n        if self._val_mask is None:\r\n            raise ValueError(\"validation split hasn't been loaded\")\r\n\r\n        return self._val_mask\r\n\r\n    @property\r\n    def test_mask(self) -> np.ndarray:\r\n        r\"\"\"\r\n        Returns the test mask of the dataset:\r\n        Returns:\r\n            test_mask: Dict[str, Any]\r\n        \"\"\"\r\n        if self._test_mask is None:\r\n            raise ValueError(\"test split hasn't been loaded\")\r\n\r\n        return self._test_mask\r\n\r\n\r\ndef main():\r\n    # download files\r\n    name = \"tgbn-trade\" \r\n    dataset = NodePropPredDataset(name=name, root=\"datasets\", preprocess=True)\r\n\r\n    dataset.node_feat\r\n    dataset.edge_feat  # not the edge weights\r\n    dataset.full_data\r\n    dataset.full_data[\"edge_idxs\"]\r\n    dataset.full_data[\"sources\"]\r\n    dataset.full_data[\"destinations\"]\r\n    dataset.full_data[\"timestamps\"]\r\n    dataset.full_data[\"y\"]\r\n\r\n    train_data = dataset.full_data[dataset.train_mask]\r\n    val_data = dataset.full_data[dataset.val_mask]\r\n    test_data = dataset.full_data[dataset.test_mask]\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/nodeproppred/dataset_pyg.py",
    "content": "import os.path as osp\nfrom typing import Optional, Dict, Any, Optional, Callable\n\n\nimport torch\n\nfrom torch_geometric.data import InMemoryDataset, TemporalData, download_url\nfrom tgb.nodeproppred.dataset import NodePropPredDataset\nimport warnings\n\n\n# TODO check https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/data/in_memory_dataset.html\n# avoid any overlapping properties\nclass PyGNodePropPredDataset(InMemoryDataset):\n    r\"\"\"\n    PyG wrapper for the NodePropPredDataset\n    can return pytorch tensors for src,dst,t,msg,label\n    can return Temporal Data object\n    also query the node labels corresponding to a timestamp from edge batch\n    Parameters:\n        name: name of the dataset, passed to `NodePropPredDataset`\n        root (string): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n        pre_transform (callable, optional): A function/transform that takes in\n        download (optional, bool): download dataset or not (default True)\n    \"\"\"\n\n    def __init__(\n        self,\n        name: str,\n        root: str = \"datasets\",\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        download: Optional[bool] = True,\n\n    ):\n        self.name = name\n        self.root = root\n        self.dataset = NodePropPredDataset(name=name, root=root, download=download)\n        self._train_mask = torch.from_numpy(self.dataset.train_mask)\n        self._val_mask = torch.from_numpy(self.dataset.val_mask)\n        self._test_mask = torch.from_numpy(self.dataset.test_mask)\n        self.__num_classes = self.dataset.num_classes\n        super().__init__(root, transform, pre_transform)\n        self.process_data()\n\n    @property\n    def num_classes(self) -> int:\n        \"\"\"\n        how many classes are in the node label\n        Returns:\n            num_classes: int\n 
       \"\"\"\n        return self.__num_classes\n\n    @property\n    def eval_metric(self) -> str:\n        \"\"\"\n        the official evaluation metric for the dataset, loaded from info.py\n        Returns:\n            eval_metric: str, the evaluation metric\n        \"\"\"\n        return self.dataset.eval_metric\n\n    @property\n    def train_mask(self) -> torch.Tensor:\n        r\"\"\"\n        Returns the train mask of the dataset\n        Returns:\n            train_mask: the mask for edges in the training set\n        \"\"\"\n        if self._train_mask is None:\n            raise ValueError(\"training split hasn't been loaded\")\n        return self._train_mask\n\n    @property\n    def val_mask(self) -> torch.Tensor:\n        r\"\"\"\n        Returns the validation mask of the dataset\n        Returns:\n            val_mask: the mask for edges in the validation set\n        \"\"\"\n        if self._val_mask is None:\n            raise ValueError(\"validation split hasn't been loaded\")\n        return self._val_mask\n\n    @property\n    def test_mask(self) -> torch.Tensor:\n        r\"\"\"\n        Returns the test mask of the dataset:\n        Returns:\n            test_mask: the mask for edges in the test set\n        \"\"\"\n        if self._test_mask is None:\n            raise ValueError(\"test split hasn't been loaded\")\n        return self._test_mask\n\n    @property\n    def src(self) -> torch.Tensor:\n        r\"\"\"\n        Returns the source nodes of the dataset\n        Returns:\n            src: the idx of the source nodes\n        \"\"\"\n        return self._src\n\n    @property\n    def dst(self) -> torch.Tensor:\n        r\"\"\"\n        Returns the destination nodes of the dataset\n        Returns:\n            dst: the idx of the destination nodes\n        \"\"\"\n        return self._dst\n\n    @property\n    def ts(self) -> torch.Tensor:\n        r\"\"\"\n        Returns the timestamps of the dataset\n        Returns:\n        
    ts: the timestamps of the edges\n        \"\"\"\n        return self._ts\n\n    @property\n    def edge_feat(self) -> torch.Tensor:\n        r\"\"\"\n        Returns the edge features of the dataset\n        Returns:\n            edge_feat: the edge features\n        \"\"\"\n        return self._edge_feat\n\n    @property\n    def edge_label(self) -> torch.Tensor:\n        r\"\"\"\n        Returns the edge labels of the dataset\n        Returns:\n            edge_label: the labels of the edges (all one tensor)\n        \"\"\"\n        return self._edge_label\n\n    def process_data(self):\n        \"\"\"\n        convert data to pytorch tensors\n        \"\"\"\n        src = torch.from_numpy(self.dataset.full_data[\"sources\"])\n        dst = torch.from_numpy(self.dataset.full_data[\"destinations\"])\n        t = torch.from_numpy(self.dataset.full_data[\"timestamps\"])\n        edge_label = torch.from_numpy(self.dataset.full_data[\"edge_label\"])\n        msg = torch.from_numpy(self.dataset.full_data[\"edge_feat\"])\n        # msg = torch.from_numpy(self.dataset.full_data[\"edge_feat\"]).reshape(\n        #     [-1, 1]\n        # ) \n        # * check typing\n        if src.dtype != torch.int64:\n            src = src.long()\n\n        if dst.dtype != torch.int64:\n            dst = dst.long()\n\n        if t.dtype != torch.int64:\n            t = t.long()\n\n        if msg.dtype != torch.float32:\n            msg = msg.float()\n\n        self._src = src\n        self._dst = dst\n        self._ts = t\n        self._edge_label = edge_label\n        self._edge_feat = msg\n\n    def get_TemporalData(\n        self,\n    ) -> TemporalData:\n        \"\"\"\n        return the TemporalData object for the entire dataset\n        Returns:\n            data: TemporalData object storing the edgelist\n        \"\"\"\n        data = TemporalData(\n            src=self._src,\n            dst=self._dst,\n            t=self._ts,\n            msg=self._edge_feat,\n            
y=self._edge_label,\n        )\n        return data\n\n    def reset_label_time(self) -> None:\n        \"\"\"\n        reset the pointer for the node labels, should be done per epoch\n        \"\"\"\n        self.dataset.reset_label_time()\n\n    def get_node_label(self, cur_t):\n        \"\"\"\n        return the node labels for the current timestamp\n        \"\"\"\n        label_tuple = self.dataset.find_next_labels_batch(cur_t)\n        if label_tuple is None:\n            return None\n        label_ts, label_srcs, labels = label_tuple[0], label_tuple[1], label_tuple[2]\n        label_ts = torch.from_numpy(label_ts).long()\n        label_srcs = torch.from_numpy(label_srcs).long()\n        labels = torch.from_numpy(labels).to(torch.float32)\n        return label_ts, label_srcs, labels\n\n    def get_label_time(self) -> int:\n        \"\"\"\n        return the timestamps of the current node labels\n        Returns:\n            t: time of the current node labels\n        \"\"\"\n        return self.dataset.return_label_ts()\n\n    def len(self) -> int:\n        \"\"\"\n        size of the dataset\n        Returns:\n            size: int\n        \"\"\"\n        return self._src.shape[0]\n\n    def get(self, idx: int) -> TemporalData:\n        \"\"\"\n        construct temporal data object for a single edge\n        Parameters:\n            idx: index of the edge\n        Returns:\n            data: TemporalData object\n        \"\"\"\n        data = TemporalData(\n            src=self._src[idx],\n            dst=self._dst[idx],\n            t=self._ts[idx],\n            msg=self._edge_feat[idx],\n            y=self._edge_label[idx],\n        )\n        return data\n\n    def __repr__(self) -> str:\n        return f\"{self.name.capitalize()}()\"\n"
  },
  {
    "path": "tgb/nodeproppred/evaluate.py",
    "content": "import numpy as np\r\nfrom sklearn.metrics import mean_squared_error\r\nfrom sklearn.metrics import ndcg_score\r\nimport math\r\n\r\nfrom tgb.utils.info import DATA_EVAL_METRIC_DICT\r\n\r\ntry:\r\n    import torch\r\nexcept ImportError:\r\n    torch = None\r\nfrom tgb.utils.utils import vprint\r\n\r\n\r\nclass Evaluator(object):\r\n    \"\"\"Evaluator for Node Property Prediction\"\"\"\r\n\r\n    def __init__(self, name: str):\r\n        r\"\"\"\r\n        Parameters:\r\n            name: name of the dataset\r\n        \"\"\"\r\n        self.name = name\r\n        self.valid_metric_list = [\"mse\", \"rmse\", \"ndcg\"]\r\n        if self.name not in DATA_EVAL_METRIC_DICT:\r\n            raise NotImplementedError(\"Dataset not supported\")\r\n\r\n    def _parse_and_check_input(self, input_dict):\r\n        \"\"\"\r\n        check whether the input has the required format\r\n        Parametrers:\r\n            -input_dict: a dictionary containing \"y_true\", \"y_pred\", and \"eval_metric\"\r\n\r\n            note: \"eval_metric\" should be a list including one or more of the followin metrics:\r\n                    [\"mse\"]\r\n        \"\"\"\r\n        # valid_metric_list = ['ap', 'au_roc_score', 'au_pr_score', 'acc', 'prec', 'rec', 'f1']\r\n\r\n        if \"eval_metric\" not in input_dict:\r\n            raise RuntimeError(\"Missing key of eval_metric\")\r\n\r\n        for eval_metric in input_dict[\"eval_metric\"]:\r\n            if eval_metric in self.valid_metric_list:\r\n                if \"y_true\" not in input_dict:\r\n                    raise RuntimeError(\"Missing key of y_true\")\r\n                if \"y_pred\" not in input_dict:\r\n                    raise RuntimeError(\"Missing key of y_pred\")\r\n\r\n                y_true, y_pred = input_dict[\"y_true\"], input_dict[\"y_pred\"]\r\n\r\n                # converting to numpy on cpu\r\n                if torch is not None and isinstance(y_true, torch.Tensor):\r\n                    
y_true = y_true.detach().cpu().numpy()\r\n                if torch is not None and isinstance(y_pred, torch.Tensor):\r\n                    y_pred = y_pred.detach().cpu().numpy()\r\n\r\n                # check type and shape\r\n                if not isinstance(y_true, np.ndarray) or not isinstance(\r\n                    y_pred, np.ndarray\r\n                ):\r\n                    raise RuntimeError(\r\n                        \"Arguments to Evaluator need to be either numpy ndarray or torch tensor!\"\r\n                    )\r\n\r\n                if not y_true.shape == y_pred.shape:\r\n                    raise RuntimeError(\"Shape of y_true and y_pred must be the same!\")\r\n\r\n            else:\r\n                raise ValueError(f\"Unsupported eval metric: {eval_metric}, not found in {self.valid_metric_list}\")\r\n\r\n        self.eval_metric = input_dict[\"eval_metric\"]\r\n\r\n        return y_true, y_pred\r\n\r\n    def _compute_metrics(self, y_true, y_pred):\r\n        \"\"\"\r\n        compute the performance metrics for the given true labels and prediction probabilities\r\n        Parameters:\r\n            -y_true: actual true labels\r\n            -y_pred: predicted probabilities\r\n        \"\"\"\r\n        perf_dict = {}\r\n        for eval_metric in self.eval_metric:\r\n            if eval_metric == \"mse\":\r\n                perf_dict = {\r\n                    \"mse\": mean_squared_error(y_true, y_pred),\r\n                    \"rmse\": math.sqrt(mean_squared_error(y_true, y_pred)),\r\n                }\r\n            elif eval_metric == \"ndcg\":\r\n                k = 10\r\n                perf_dict = {\"ndcg\": ndcg_score(y_true, y_pred, k=k)}\r\n        return perf_dict\r\n\r\n    def eval(self, input_dict, verbose=False):\r\n        \"\"\"\r\n        evaluation for edge regression task\r\n        \"\"\"\r\n        y_true, y_pred = self._parse_and_check_input(input_dict)\r\n        perf_dict = self._compute_metrics(y_true, 
y_pred)\r\n\r\n        if verbose:\r\n            print(\"INFO: Evaluation Results:\")\r\n            for eval_metric in input_dict[\"eval_metric\"]:\r\n                print(f\"\\t>>> {eval_metric}: {perf_dict[eval_metric]:.4f}\")\r\n        return perf_dict\r\n\r\n    @property\r\n    def expected_input_format(self):\r\n        desc = \"==== Expected input format of Evaluator for {}\\n\".format(self.name)\r\n        if \"mse\" in self.valid_metric_list:\r\n            desc += \"{'y_pred': y_pred}\\n\"\r\n            desc += \"- y_pred: numpy ndarray or torch tensor of shape (num_edges, ). Torch tensor on GPU is recommended for efficiency.\\n\"\r\n            desc += \"y_pred is the predicted weight for edges.\\n\"\r\n        else:\r\n            raise ValueError(\"Undefined eval metric %s\" % (self.eval_metric))\r\n        return desc\r\n\r\n    @property\r\n    def expected_output_format(self):\r\n        desc = \"==== Expected output format of Evaluator for {}\\n\".format(self.name)\r\n        if \"mse\" in self.valid_metric_list:\r\n            desc += \"{'mse': mse}\\n\"\r\n            desc += \"- mse (float): mse score\\n\"\r\n        else:\r\n            raise ValueError(\"Undefined eval metric %s\" % (self.eval_metric))\r\n        return desc\r\n\r\n\r\ndef main():\r\n    \"\"\"\r\n    simple test for evaluator\r\n    \"\"\"\r\n    name = \"tgbn-trade\"\r\n    evaluator = Evaluator(name=name)\r\n    print(evaluator.expected_input_format)\r\n    print(evaluator.expected_output_format)\r\n    y_true = np.array([1.0, 2.0, 3.0])\r\n    y_pred = np.array([1.1, 1.9, 3.2])\r\n    input_dict = {\"y_true\": y_true, \"y_pred\": y_pred, \"eval_metric\": [\"mse\"]}\r\n\r\n    result_dict = evaluator.eval(input_dict)\r\n    print(result_dict)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/utils/dataset_stats.py",
    "content": "\"\"\"\nDataset statistics\n\"\"\"\n\nimport numpy as np\nimport pandas as pd\nimport networkx as nx\nimport argparse\n\nfrom tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\nfrom tgb.linkproppred.dataset import LinkPropPredDataset\nfrom tgb.utils.utils import vprint\n\n\n\n\ndef get_unique_edges(sources, destination):\n    r\"\"\"\n    return unique edges\n    \"\"\"\n    unique_e = {}\n    for src, dst in zip(sources, destination):\n        if (src, dst) not in unique_e:\n            unique_e[(src, dst)] = True\n    return unique_e\n\n\ndef get_avg_e_per_ts(edgelist_df):\n    r\"\"\"\n    get the average number of edges per each timestamp\n    \"\"\"\n    sum_num_e_per_ts = 0\n    unique_ts = np.unique(np.array(edgelist_df['ts'].tolist()))\n    for ts in unique_ts:\n        num_e_at_this_ts = len(edgelist_df.loc[edgelist_df['ts'] == ts])\n        sum_num_e_per_ts += num_e_at_this_ts\n    avg_num_e_per_ts = (sum_num_e_per_ts * 1.0) / len(unique_ts)    \n    return avg_num_e_per_ts\n\n\ndef get_avg_degree(edgelist_df):\n    r\"\"\"\n    get average degree over the timestamps\n    \"\"\"\n    degree_avg_at_ts_list = []\n    unique_ts = np.unique(np.array(edgelist_df['ts'].tolist()))\n    for ts in unique_ts:  \n        e_at_this_ts = edgelist_df.loc[edgelist_df['ts'] == ts]\n        G = nx.MultiGraph()\n        for idx, e_row in e_at_this_ts.iterrows():\n            G.add_edge(e_row['src'], e_row['dst'], weight=e_row['ts'])\n        nodes = G.nodes()\n        degrees = [G.degree[n] for n in nodes]\n        degree_avg_at_ts_list.append(np.mean(degrees))    \n    return np.mean(degree_avg_at_ts_list)\n\n\ndef get_index_metrics(train_val_data, test_data):\n    r\"\"\"\n    compute `surprise` and `recurrence` indices\n    \"\"\"\n    train_val_e_set = {}\n    for src, dst in zip(train_val_data['sources'], train_val_data['destinations']):\n        if (src, dst) not in train_val_e_set:\n            train_val_e_set[(src, dst)] = True\n    \n    
test_e_set = {}\n    for src, dst in zip(test_data['sources'], test_data['destinations']):\n        if (src, dst) not in test_e_set:\n            test_e_set[(src, dst)] = True\n    \n    train_val_size = len(train_val_data['sources'])\n    test_size = len(test_data['sources'])\n\n    intersect = difference = 0\n    for e in test_e_set:\n        if e in train_val_e_set:\n            intersect += 1\n        else:\n            difference += 1\n\n    surprise = float(difference * 1.0 / test_size)\n    reoccurrence = float(intersect * 1.0 / train_val_size)\n    return surprise, reoccurrence\n\n\ndef get_node_ratio(history_data, eval_data):\n    r\"\"\"\n    compute the ratio of new nodes\n    \"\"\"\n    eval_uniq_nodes = set(eval_data['sources']).union(set(eval_data['destinations'])) \n    hist_uniq_nodes = set(history_data['sources']).union(set(history_data['destinations'])) \n    new_nodes = []\n    for node in eval_uniq_nodes:\n        if node not in hist_uniq_nodes:\n            new_nodes.append(node)\n    new_nodes = set(new_nodes)\n    new_node_ratio = float(len(new_nodes) * 1.0 / len(eval_uniq_nodes))\n\n    return new_node_ratio\n\n\ndef get_dataset_stats(data, temporal_stats=False):\n    r\"\"\"\n    returns simple stats based on counts\n    \"\"\"\n    # simple stats\n    sources, destinations, timestamps = data['full']['sources'], data['full']['destinations'], data['full']['timestamps']\n    edgelist_df = pd.DataFrame(zip(sources, destinations, timestamps), columns=['src', 'dst', 'ts'])\n    num_nodes = len(np.unique(np.concatenate((sources, destinations), axis=0)))\n    num_edges = len(sources)  # = len(destinations) = len(timestamps)\n    num_unique_ts = len(np.unique(timestamps))\n    unique_e = get_unique_edges(sources, destinations)\n    num_unique_e = len(unique_e)\n\n    # compute temporal stats\n    if temporal_stats:  # because it takes so long for large datasets...\n        avg_e_per_ts = get_avg_e_per_ts(edgelist_df)\n        avg_degree_per_ts = 
get_avg_degree(edgelist_df)\n    else:\n        avg_e_per_ts = -1\n        avg_degree_per_ts = -1\n    \n    # compute reoccurrence & surprise\n    surprise, reoccurrence = get_index_metrics(data['train_val'], data['test'])\n\n    # compute new node ratio \n    val_nn_ratio = get_node_ratio(data['train'], data['val'])\n    #test_nn_ratio = get_node_ratio(data['train_val'], data['test'])\n    test_nn_ratio = get_node_ratio(data['train'], data['test'])\n\n\n    stats_dict = {\n                  'num_nodes': num_nodes,\n                  'num_edges': num_edges,\n                  'num_unique_ts': num_unique_ts,\n                  'num_unique_e': num_unique_e,\n                  'avg_e_per_ts': avg_e_per_ts,\n                  'avg_degree_per_ts': avg_degree_per_ts,\n                  'surprise': surprise,\n                  'reocurrence': reoccurrence,\n                  'val_nn_ratio': val_nn_ratio,\n                  'test_nn_ratio': test_nn_ratio,\n                  }\n    return stats_dict\n\n\ndef main():\n    r\"\"\"\n    Generate dataset statistics\n    \"\"\"\n    parser = argparse.ArgumentParser(description='Dataset statistics')\n    parser.add_argument('-d', '--data', type=str, default='tgbl-wiki', help='name of the dataset')\n    parser.add_argument('--tempstats', action='store_true', default=False, help='whether compute temporal statistics')\n    args = parser.parse_args()\n\n    DATA = args.data\n    temporal_stats = args.tempstats\n\n    # data loading ...\n    if DATA in ['tgbl-wiki', 'tgbl-review', 'tgbl-flight', 'tgbl-comment', 'tgbl-coin']:\n        # load data: link prop. pred. 
with `numpy`\n        dataset = LinkPropPredDataset(name=DATA, root=\"datasets\", preprocess=True)\n        data = dataset.full_data  \n\n        # get masks\n        train_mask = dataset.train_mask\n        val_mask = dataset.val_mask\n        test_mask = dataset.test_mask\n        train_data = {'sources': data['sources'][train_mask],\n                      'destinations': data['destinations'][train_mask],\n                      }\n        val_data = {'sources': data['sources'][val_mask],\n                      'destinations': data['destinations'][val_mask],\n                      }\n        train_val_data = {'sources': np.concatenate([data['sources'][train_mask], data['sources'][val_mask]]),\n                      'destinations': np.concatenate([data['destinations'][train_mask], data['destinations'][val_mask]]),\n                      }\n        test_data = {'sources': data['sources'][test_mask],\n                      'destinations': data['destinations'][test_mask],\n                      }\n        full_data = {'sources': data['sources'], \n                     'destinations': data['destinations'], \n                     'timestamps': data['timestamps'],\n                     }\n\n    elif DATA in ['tgbn-trade', 'tgbn-genre', 'tgbn-reddit', 'tgbn-token']:\n        # load data: node prop. 
pred.\n        dataset = PyGNodePropPredDataset(name=DATA, root=\"datasets\")\n        data = dataset.get_TemporalData()\n        \n        # split data\n        train_mask = dataset.train_mask\n        val_mask = dataset.val_mask\n        test_mask = dataset.test_mask\n        train_val_mask = np.logical_or(np.array(train_mask), np.array(val_mask))\n\n        train_data = {'sources': np.array(data[train_mask].src),\n                      'destinations': np.array(data[train_mask].dst),\n                      }\n        val_data = {'sources': np.array(data[val_mask].src),\n                    'destinations': np.array(data[val_mask].dst),\n                    }\n        train_val_data = {'sources': np.concatenate([np.array(data[train_mask].src), np.array(data[val_mask].src)]),\n                          'destinations': np.concatenate([np.array(data[train_mask].dst), np.array(data[val_mask].dst)]),\n                          }\n        test_data = {'sources': np.array(data[test_mask].src),\n                     'destinations': np.array(data[test_mask].dst),\n                     } \n        full_data = {'sources': np.array(data.src), \n                     'destinations': np.array(data.dst), \n                     'timestamps': np.array(data.t),\n                     }\n\n    else:\n        raise ValueError(\"Unsupported data!\")\n\n    split_data = {'train': train_data,\n                  'val': val_data,\n                  'train_val': train_val_data,\n                  'test': test_data,\n                  'full': full_data,\n                  }\n    vprint(\"=============================\")\n    vprint(f\">>> DATA: {DATA}\")\n    dataset_stats = get_dataset_stats(split_data, temporal_stats)\n    for k, v in dataset_stats.items():\n        vprint(f\"{k}: {v}\")\n    vprint(\"=============================\")\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "tgb/utils/info.py",
    "content": "import os.path as osp\r\nimport os\r\n\r\nr\"\"\"\r\nGeneral space to store global information used elsewhere such as url links, evaluation metrics etc.\r\n\"\"\"\r\nPROJ_DIR = osp.dirname(osp.abspath(os.path.join(__file__, os.pardir))) + \"/\"\r\n\r\n\r\nclass BColors:\r\n    \"\"\"\r\n    A class to change the colors of the strings.\r\n    \"\"\"\r\n\r\n    HEADER = \"\\033[95m\"\r\n    OKBLUE = \"\\033[94m\"\r\n    OKCYAN = \"\\033[96m\"\r\n    OKGREEN = \"\\033[92m\"\r\n    WARNING = \"\\033[93m\"\r\n    FAIL = \"\\033[91m\"\r\n    ENDC = \"\\033[0m\"\r\n    BOLD = \"\\033[1m\"\r\n    UNDERLINE = \"\\033[4m\"\r\n\r\nDATA_URL_DICT = {\r\n    \"tgbl-enron\": \"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-enron.zip\",\r\n    \"tgbl-uci\": \"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-uci.zip\",\r\n    \"tgbl-wiki\":\"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-wiki-v2.zip\", #\"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-wiki.zip\", #v1\r\n    \"tgbl-subreddit\":\"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-subreddit.zip\",\r\n    \"tgbl-lastfm\":\"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-lastfm.zip\",\r\n    \"tgbl-review\": \"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-review-v2.zip\",  # \"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-review-v3.zip\" #\"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-review-v2.zip\"  #\"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-review.zip\", #v1\r\n    \"tgbl-coin\": \"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-coin-v2.zip\", #\"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-coin.zip\",\r\n    \"tgbl-flight\": \"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-flight-v2.zip\", #\"tgbl-flight\": \"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-flight_edgelist_v2_ts.zip\",\r\n    \"tgbl-comment\": \"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-comment.zip\",\r\n    
\"tgbn-trade\": \"https://object-arbutus.cloud.computecanada.ca/tgb/tgbn-trade.zip\",\r\n    \"tgbn-genre\": \"https://object-arbutus.cloud.computecanada.ca/tgb/tgbn-genre.zip\",\r\n    \"tgbn-reddit\": \"https://object-arbutus.cloud.computecanada.ca/tgb/tgbn-reddit.zip\",\r\n    \"tgbn-token\": \"https://object-arbutus.cloud.computecanada.ca/tgb/tgbn-token.zip\",\r\n    \"tkgl-polecat\": \"https://object-arbutus.cloud.computecanada.ca/tgb/tkgl-polecat.zip\",\r\n    \"tkgl-icews\": \"https://object-arbutus.cloud.computecanada.ca/tgb/tkgl-icews.zip\",\r\n    \"tkgl-yago\":\"https://object-arbutus.cloud.computecanada.ca/tgb/tkgl-yago.zip\",\r\n    \"tkgl-wikidata\": \"https://object-arbutus.cloud.computecanada.ca/tgb/tkgl-wikidata.zip\",\r\n    \"tkgl-smallpedia\": \"https://object-arbutus.cloud.computecanada.ca/tgb/tkgl-smallpedia.zip\",\r\n    \"thgl-myket\": \"https://object-arbutus.cloud.computecanada.ca/tgb/thgl-myket.zip\",\r\n    \"thgl-github\": \"https://object-arbutus.cloud.computecanada.ca/tgb/thgl-github.zip\",\r\n    \"thgl-forum\": \"https://object-arbutus.cloud.computecanada.ca/tgb/thgl-forum.zip\",\r\n    \"thgl-software\": \"https://object-arbutus.cloud.computecanada.ca/tgb/thgl-software.zip\", #\"https://object-arbutus.cloud.computecanada.ca/tgb/thgl-software_ns_random.zip\"\r\n}\r\n\r\n\r\n\r\nDATA_VERSION_DICT = {\r\n    \"tgbl-enron\": 1,\r\n    \"tgbl-uci\": 1,\r\n    \"tgbl-wiki\": 2,  \r\n    \"tgbl-subreddit\": 1,\r\n    \"tgbl-lastfm\": 1,\r\n    \"tgbl-review\": 2, #3\r\n    \"tgbl-coin\": 2,\r\n    \"tgbl-comment\": 1,\r\n    \"tgbl-flight\": 2,\r\n    \"tgbn-trade\": 1,\r\n    \"tgbn-genre\": 1,\r\n    \"tgbn-reddit\": 1,\r\n    \"tgbn-token\": 1,\r\n    \"tkgl-polecat\": 1,\r\n    \"tkgl-icews\": 1,\r\n    \"tkgl-yago\": 1,\r\n    \"tkgl-wikidata\": 1,\r\n    \"tkgl-smallpedia\": 1,\r\n    \"thgl-myket\": 1,\r\n    \"thgl-github\": 1,\r\n    \"thgl-forum\": 1,\r\n    \"thgl-software\": 1,\r\n}\r\n\r\n\r\nDATA_EVAL_METRIC_DICT = {\r\n    
\"tgbl-enron\": \"mrr\",\r\n    \"tgbl-uci\": \"mrr\",\r\n    \"tgbl-wiki\": \"mrr\",\r\n    \"tgbl-subreddit\": \"mrr\",\r\n    \"tgbl-lastfm\": \"mrr\",\r\n    \"tgbl-review\": \"mrr\",\r\n    \"tgbl-coin\": \"mrr\",\r\n    \"tgbl-comment\": \"mrr\",\r\n    \"tgbl-flight\": \"mrr\",\r\n    \"tkgl-polecat\": \"mrr\",\r\n    \"tkgl-yago\": \"mrr\",\r\n    \"tkgl-wikidata\": \"mrr\",\r\n    \"tkgl-smallpedia\": \"mrr\",\r\n    \"tkgl-icews\": \"mrr\",\r\n    \"thgl-myket\": \"mrr\",\r\n    \"thgl-github\": \"mrr\",\r\n    \"thgl-forum\": \"mrr\",\r\n    \"thgl-software\": \"mrr\",\r\n    \"tgbn-trade\": \"ndcg\",\r\n    \"tgbn-genre\": \"ndcg\",\r\n    \"tgbn-reddit\": \"ndcg\",\r\n    \"tgbn-token\": \"ndcg\",\r\n}\r\n\r\nDATA_NS_STRATEGY_DICT = {\r\n    \"tgbl-enron\": \"hist_rnd\",\r\n    \"tgbl-uci\": \"hist_rnd\",\r\n    \"tgbl-wiki\": \"hist_rnd\",\r\n    \"tgbl-subreddit\": \"hist_rnd\",\r\n    \"tgbl-lastfm\": \"hist_rnd\",\r\n    \"tgbl-review\": \"hist_rnd\",\r\n    \"tgbl-coin\": \"hist_rnd\",\r\n    \"tgbl-comment\": \"hist_rnd\",\r\n    \"tgbl-flight\": \"hist_rnd\",\r\n    \"tkgl-polecat\": \"time-filtered\",\r\n    \"tkgl-yago\": \"time-filtered\",\r\n    \"tkgl-wikidata\": \"dst-time-filtered\",\r\n    \"tkgl-smallpedia\": \"time-filtered\",\r\n    \"tkgl-icews\": \"time-filtered\",\r\n    \"thgl-myket\": \"node-type-filtered\",\r\n    \"thgl-github\": \"node-type-filtered\",\r\n    \"thgl-forum\": \"node-type-filtered\",\r\n    \"thgl-software\": \"node-type-filtered\",\r\n}\r\n\r\n\r\nDATA_NUM_CLASSES = {\r\n    \"tgbn-trade\": 255,\r\n    \"tgbn-genre\": 513,\r\n    \"tgbn-reddit\": 698,\r\n    \"tgbn-token\": 1001,\r\n}\r\n"
  },
  {
    "path": "tgb/utils/pre_process.py",
    "content": "from typing import Optional, cast, Union, List, overload, Literal\r\nfrom tqdm import tqdm\r\nimport numpy as np\r\nimport pandas as pd\r\nimport os.path as osp\r\nimport time\r\nimport csv\r\nimport datetime\r\nfrom datetime import date\r\nfrom tgb.utils.utils import vprint\r\n\r\n\"\"\"\r\nfunction to process node type for thg datasets\r\n\"\"\"\r\n\r\ndef process_node_type(\r\n    fname: str,\r\n    node_ids,\r\n):\r\n    \"\"\"\r\n    1. process the node type into integer\r\n    3. return a numpy array of node types with index corresponding to node id\r\n    \"\"\"\r\n    node_feat = np.zeros(len(node_ids))\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        # node_id,type\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n                continue\r\n            else:\r\n                nid = int(row[0])\r\n                try:\r\n                    node_type = int(row[1])\r\n                except:\r\n                    raise ValueError(row[1], \" is not an integer thus can't be a node type for thg dataset\")\r\n                try:\r\n                    node_id = node_ids[nid]\r\n                except:\r\n                    raise ValueError(nid, \" is not a valid node id\")\r\n                node_feat[node_id] = node_type\r\n    return node_feat\r\n\r\n\"\"\"\r\nfunctions for thgl-forum dataset\r\n\"\"\"\r\ndef csv_to_forum_data(\r\n    fname: str,\r\n) -> pd.DataFrame:\r\n    r\"\"\"\r\n    used by thgl-forum dataset\r\n    convert the raw .csv data to pandas dataframe and numpy array\r\n    input .csv file format should be: timestamp, head, tail, relation type\r\n    Args:\r\n        fname: the path to the raw data\r\n    \"\"\"\r\n    feat_size = 2\r\n    num_lines = sum(1 for line in open(fname)) - 1\r\n    vprint(f\"number of lines counted: {num_lines} in {fname}\")\r\n    u_list = 
np.zeros(num_lines)\r\n    i_list = np.zeros(num_lines)\r\n    ts_list = np.zeros(num_lines)\r\n    label_list = np.zeros(num_lines)\r\n    edge_type = np.zeros(num_lines)\r\n    feat_l = np.zeros((num_lines, feat_size))\r\n    idx_list = np.zeros(num_lines)\r\n    w_list = np.zeros(num_lines)\r\n    node_ids = {}\r\n    unique_id = 0\r\n\r\n    word_max = 10000\r\n    score_max = 10000\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        #timestamp, head, tail, relation type\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n                continue\r\n            else:\r\n                #! ts,src,dst,relation_type,num_words,score\r\n                ts = int(row[0]) #converted to UNIX timestamp already \r\n                src = int(row[1])\r\n                dst = int(row[2])\r\n                relation = int(row[3])\r\n                num_words = int(row[4])\r\n                score = int(row[5])\r\n                if src not in node_ids:\r\n                    node_ids[src] = unique_id\r\n                    unique_id += 1\r\n                if dst not in node_ids:\r\n                    node_ids[dst] = unique_id\r\n                    unique_id += 1\r\n                u = node_ids[src]\r\n                i = node_ids[dst]\r\n                u_list[idx - 1] = u\r\n                i_list[idx - 1] = i\r\n                ts_list[idx - 1] = ts\r\n                idx_list[idx - 1] = idx\r\n                w_list[idx - 1] = float(1)\r\n                edge_type[idx - 1] = relation\r\n                feat_l[idx - 1] = np.array([num_words/word_max, score/score_max])\r\n                idx += 1\r\n    return (\r\n        pd.DataFrame(\r\n            {\r\n                \"u\": u_list,\r\n                \"i\": i_list,\r\n                \"ts\": ts_list,\r\n                \"label\": label_list,\r\n                \"idx\": idx_list,\r\n   
             \"w\": w_list,\r\n                \"edge_type\": edge_type,\r\n            }\r\n        ),\r\n        feat_l,\r\n        node_ids,\r\n    )\r\n\r\n\r\n\r\n\r\n\"\"\"\r\nfunctions for thg dataset\r\n\"\"\"\r\ndef csv_to_thg_data(\r\n    fname: str,\r\n) -> pd.DataFrame:\r\n    r\"\"\"\r\n    used by thgl-myket dataset\r\n    convert the raw .csv data to pandas dataframe and numpy array\r\n    input .csv file format should be: timestamp, head, tail, relation type\r\n    Args:\r\n        fname: the path to the raw data\r\n    \"\"\"\r\n    feat_size = 1\r\n    num_lines = sum(1 for line in open(fname)) - 1\r\n    vprint(f\"number of lines counted: {num_lines} in {fname}\")\r\n    u_list = np.zeros(num_lines)\r\n    i_list = np.zeros(num_lines)\r\n    ts_list = np.zeros(num_lines)\r\n    label_list = np.zeros(num_lines)\r\n    edge_type = np.zeros(num_lines)\r\n    feat_l = np.zeros((num_lines, feat_size))\r\n    idx_list = np.zeros(num_lines)\r\n    w_list = np.zeros(num_lines)\r\n    node_ids = {}\r\n    unique_id = 0\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        #timestamp, head, tail, relation type\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n                continue\r\n            else:\r\n                ts = int(row[0]) #converted to UNIX timestamp already \r\n                src = int(row[1])\r\n                dst = int(row[2])\r\n                relation = int(row[3])\r\n                if src not in node_ids:\r\n                    node_ids[src] = unique_id\r\n                    unique_id += 1\r\n                if dst not in node_ids:\r\n                    node_ids[dst] = unique_id\r\n                    unique_id += 1\r\n                u = node_ids[src]\r\n                i = node_ids[dst]\r\n                u_list[idx - 1] = u\r\n                i_list[idx - 1] = i\r\n                
ts_list[idx - 1] = ts\r\n                idx_list[idx - 1] = idx\r\n                w_list[idx - 1] = float(1)\r\n                edge_type[idx - 1] = relation\r\n                idx += 1\r\n    return (\r\n        pd.DataFrame(\r\n            {\r\n                \"u\": u_list,\r\n                \"i\": i_list,\r\n                \"ts\": ts_list,\r\n                \"label\": label_list,\r\n                \"idx\": idx_list,\r\n                \"w\": w_list,\r\n                \"edge_type\": edge_type,\r\n            }\r\n        ),\r\n        feat_l,\r\n        node_ids,\r\n    )\r\n\r\n\"\"\"\r\nfunctions for tkgl-wikidata dataset\r\n\"\"\"\r\ndef csv_to_wikidata(\r\n    fname: str,\r\n) -> pd.DataFrame:\r\n    r\"\"\"\r\n    used by tkgl-wikidata and tkgl-smallpedia\r\n    convert the raw .csv data to pandas dataframe and numpy array\r\n    input .csv file format should be: timestamp, head, tail, relation type\r\n    Args:\r\n        fname: the path to the raw data\r\n    \"\"\"\r\n    feat_size = 1\r\n    num_lines = sum(1 for line in open(fname)) - 1\r\n    vprint(f\"number of lines counted: {num_lines} in {fname}\")\r\n    u_list = np.zeros(num_lines)\r\n    i_list = np.zeros(num_lines)\r\n    ts_list = np.zeros(num_lines)\r\n    label_list = np.zeros(num_lines)\r\n    edge_type = np.zeros(num_lines)\r\n    feat_l = np.zeros((num_lines, feat_size))\r\n    idx_list = np.zeros(num_lines)\r\n    w_list = np.zeros(num_lines)\r\n    node_ids = {}\r\n    edge_type_ids = {}\r\n    unique_id = 0\r\n    et_id = 0\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        #timestamp, head, tail, relation type\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n                continue\r\n            else:\r\n                ts = int(row[0]) #converted to year already\r\n                src = row[1]\r\n                dst = row[2]\r\n          
      relation = row[3]\r\n                if src not in node_ids:\r\n                    node_ids[src] = unique_id\r\n                    unique_id += 1\r\n                if dst not in node_ids:\r\n                    node_ids[dst] = unique_id\r\n                    unique_id += 1\r\n                if relation not in edge_type_ids:\r\n                    edge_type_ids[relation] = et_id\r\n                    et_id += 1\r\n                u = node_ids[src]\r\n                i = node_ids[dst]\r\n                u_list[idx - 1] = u\r\n                i_list[idx - 1] = i\r\n                ts_list[idx - 1] = ts\r\n                idx_list[idx - 1] = idx\r\n                w_list[idx - 1] = float(1)\r\n                edge_type[idx - 1] = edge_type_ids[relation]\r\n                idx += 1\r\n    return (\r\n        pd.DataFrame(\r\n            {\r\n                \"u\": u_list,\r\n                \"i\": i_list,\r\n                \"ts\": ts_list,\r\n                \"label\": label_list,\r\n                \"idx\": idx_list,\r\n                \"w\": w_list,\r\n                \"edge_type\": edge_type,\r\n            }\r\n        ),\r\n        feat_l,\r\n        node_ids,\r\n    )\r\n\r\n\r\ndef csv_to_staticdata(\r\n    fname: str,\r\n    node_ids: dict,\r\n) -> pd.DataFrame:\r\n    r\"\"\"\r\n    used by tkgl-wikidata and tkgl-smallpedia\r\n    convert the raw .csv data to pandas dataframe and numpy array for static knowledge edges\r\n    input .csv file format should be: head, tail, relation type\r\n    Args:\r\n        fname: the path to the raw data\r\n        node_ids: dictionary of node names mapped to integer node ids\r\n    \"\"\"\r\n    num_lines = sum(1 for line in open(fname)) - 1\r\n    vprint(f\"number of lines counted: {num_lines} in {fname}\")\r\n    u_list = np.zeros(num_lines)\r\n    i_list = np.zeros(num_lines)\r\n    edge_type = np.zeros(num_lines)\r\n    edge_type_ids = {}\r\n    out_dict = {}\r\n\r\n    with open(fname, \"r\") as 
csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        #timestamp, head, tail, relation type\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n                continue\r\n            else:\r\n                src = row[0]\r\n                dst = row[1]\r\n                relation = row[2]\r\n                if src not in node_ids:\r\n                    node_ids[src] = len(node_ids)\r\n                if dst not in node_ids:\r\n                    node_ids[dst] = len(node_ids)\r\n                if relation not in edge_type_ids:\r\n                    edge_type_ids[relation] = len(edge_type_ids)\r\n                u = node_ids[src]\r\n                i = node_ids[dst]\r\n                u_list[idx - 1] = u\r\n                i_list[idx - 1] = i                \r\n                edge_type[idx - 1] = edge_type_ids[relation]\r\n                idx += 1\r\n\r\n    out_dict[\"head\"] = u_list\r\n    out_dict[\"tail\"] = i_list\r\n    out_dict[\"edge_type\"] = edge_type\r\n    return out_dict, node_ids\r\n\r\n\r\n\r\n\r\n\r\n\r\n\"\"\"\r\nfunctions for tkgl-polecat, tkgl-icews dataset\r\n\"\"\"\r\ndef csv_to_tkg_data(\r\n    fname: str,\r\n) -> pd.DataFrame:\r\n    r\"\"\"\r\n    used by tkgl-polecat\r\n    convert the raw .csv data to pandas dataframe and numpy array\r\n    input .csv file format should be: timestamp, head, tail, relation type\r\n    Args:\r\n        fname: the path to the raw data\r\n    \"\"\"\r\n    feat_size = 1\r\n    num_lines = sum(1 for line in open(fname)) - 1\r\n    vprint(f\"number of lines counted: {num_lines} in {fname}\")\r\n    u_list = np.zeros(num_lines)\r\n    i_list = np.zeros(num_lines)\r\n    ts_list = np.zeros(num_lines)\r\n    label_list = np.zeros(num_lines)\r\n    edge_type = np.zeros(num_lines)\r\n    feat_l = np.zeros((num_lines, feat_size))\r\n    idx_list = np.zeros(num_lines)\r\n    w_list = np.zeros(num_lines)\r\n    node_ids = 
{}\r\n    unique_id = 0\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        #timestamp, head, tail, relation type\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n                continue\r\n            else:\r\n                ts = int(row[0]) #converted to UNIX timestamp already \r\n                src = int(row[1])\r\n                dst = int(row[2])\r\n                relation = int(row[3])\r\n                if src not in node_ids:\r\n                    node_ids[src] = unique_id\r\n                    unique_id += 1\r\n                if dst not in node_ids:\r\n                    node_ids[dst] = unique_id\r\n                    unique_id += 1\r\n                u = node_ids[src]\r\n                i = node_ids[dst]\r\n                u_list[idx - 1] = u\r\n                i_list[idx - 1] = i\r\n                ts_list[idx - 1] = ts\r\n                idx_list[idx - 1] = idx\r\n                w_list[idx - 1] = float(1)\r\n                edge_type[idx - 1] = relation\r\n                idx += 1\r\n    return (\r\n        pd.DataFrame(\r\n            {\r\n                \"u\": u_list,\r\n                \"i\": i_list,\r\n                \"ts\": ts_list,\r\n                \"label\": label_list,\r\n                \"idx\": idx_list,\r\n                \"w\": w_list,\r\n                \"edge_type\": edge_type,\r\n            }\r\n        ),\r\n        feat_l,\r\n        node_ids,\r\n    )\r\n\r\n\r\n\r\n\r\n\"\"\"\r\nfunctions for wikipedia dataset\r\n---------------------------------------\r\n\"\"\"\r\ndef load_edgelist_wiki(fname: str) -> pd.DataFrame:\r\n    \"\"\"\r\n    loading wikipedia dataset into pandas dataframe\r\n    similar processing to\r\n    https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/jodie.html\r\n\r\n    Parameters:\r\n        fname: str, name of the input file\r\n 
   Returns:\r\n        df: a pandas dataframe containing the edgelist data\r\n    \"\"\"\r\n    df = pd.read_csv(fname, skiprows=1, header=None)\r\n    src = df.iloc[:, 0].values\r\n    dst = df.iloc[:, 1].values\r\n    dst += int(src.max()) + 1\r\n    t = df.iloc[:, 2].values\r\n    msg = df.iloc[:, 4:].values\r\n    idx = np.arange(t.shape[0])\r\n    w = np.ones(t.shape[0])\r\n\r\n    return pd.DataFrame({\"u\": src, \"i\": dst, \"ts\": t, \"idx\": idx, \"w\": w}), msg, None\r\n\r\n\r\n\"\"\"\r\nfunctions for un_trade dataset\r\n---------------------------------------\r\n\"\"\"\r\n\r\n\r\ndef load_edgelist_trade(fname: str, label_size=255):\r\n    \"\"\"\r\n    load the edgelist into pandas dataframe\r\n    \"\"\"\r\n    feat_size = 1\r\n    num_lines = sum(1 for line in open(fname)) - 1\r\n\r\n    u_list = np.zeros(num_lines)\r\n    i_list = np.zeros(num_lines)\r\n    ts_list = np.zeros(num_lines)\r\n    feat_l = np.zeros((num_lines, feat_size))\r\n    idx_list = np.zeros(num_lines)\r\n    w_list = np.zeros(num_lines)\r\n    node_ids = {}  # dictionary for node ids\r\n    node_uid = 0\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n            else:\r\n                ts = int(row[0])\r\n                u = row[1]\r\n                v = row[2]\r\n                w = float(row[3])\r\n                if u not in node_ids:\r\n                    node_ids[u] = node_uid\r\n                    node_uid += 1\r\n\r\n                if v not in node_ids:\r\n                    node_ids[v] = node_uid\r\n                    node_uid += 1\r\n\r\n                u = node_ids[u]\r\n                i = node_ids[v]\r\n                u_list[idx - 1] = u\r\n                i_list[idx - 1] = i\r\n                ts_list[idx - 1] = ts\r\n                idx_list[idx - 1] = idx\r\n                
w_list[idx - 1] = w\r\n                feat_l[idx - 1] = np.array([w])\r\n                idx += 1\r\n\r\n    return (\r\n        pd.DataFrame(\r\n            {\"u\": u_list, \"i\": i_list, \"ts\": ts_list, \"idx\": idx_list, \"w\": w_list}\r\n        ),\r\n        feat_l,\r\n        node_ids,\r\n    )\r\n\r\n\r\ndef load_trade_label_dict(\r\n    fname: str,\r\n    node_ids: dict,\r\n) -> dict:\r\n    \"\"\"\r\n    load node labels into a nested dictionary instead of pandas dataobject\r\n    {ts: {node_id: label_vec}}\r\n    Parameters:\r\n        fname: str, name of the input file\r\n        node_ids: dictionary of user names mapped to integer node ids\r\n    Returns:\r\n        node_label_dict: a nested dictionary of node labels\r\n    \"\"\"\r\n    if not osp.exists(fname):\r\n        raise FileNotFoundError(f\"File not found at {fname}\")\r\n\r\n    label_size = len(node_ids)\r\n    #label_vec = np.zeros(label_size)\r\n\r\n    node_label_dict = {}  # {ts: {node_id: label_vec}}\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n            else:\r\n                ts = int(row[0])\r\n                u = node_ids[row[1]]\r\n                v = node_ids[row[2]]\r\n                weight = float(row[3])\r\n\r\n                if (ts not in node_label_dict):\r\n                    node_label_dict[ts] = {u:np.zeros(label_size)}\r\n\r\n                if (u not in node_label_dict[ts]):\r\n                    node_label_dict[ts][u] = np.zeros(label_size)\r\n\r\n                node_label_dict[ts][u][v] = weight\r\n                idx += 1\r\n        return node_label_dict\r\n\r\n\r\n\"\"\"\r\nfunctions for tgbn-token\r\n---------------------------------------\r\n\"\"\"\r\n\r\ndef load_edgelist_token(\r\n    fname: str,\r\n    label_size: int = 1001,\r\n) -> pd.DataFrame:\r\n    \"\"\"\r\n    
load the edgelist into pandas dataframe\r\n    also outputs index for the user nodes and genre nodes\r\n    Parameters:\r\n        fname: str, name of the input file\r\n        label_size: int, number of genres\r\n    Returns:\r\n        df: a pandas dataframe containing the edgelist data\r\n    \"\"\"\r\n    feat_size = 2\r\n    num_lines = sum(1 for line in open(fname)) - 1\r\n    vprint(\"there are \", num_lines, \" lines in the raw data\")\r\n    u_list = np.zeros(num_lines)\r\n    i_list = np.zeros(num_lines)\r\n    ts_list = np.zeros(num_lines)\r\n    label_list = np.zeros(num_lines)\r\n    feat_l = np.zeros((num_lines, feat_size))\r\n    idx_list = np.zeros(num_lines)\r\n    w_list = np.zeros(num_lines)\r\n\r\n    node_ids = {}\r\n    rd_dict = {}\r\n    node_uid = label_size  # node ids start after all the genres\r\n    sr_uid = 0\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        # [timestamp,user_address,token_address,value,IsSender]\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n            else:\r\n                ts = row[0]\r\n                src = row[1]\r\n                token = row[2]\r\n                w = float(row[3])\r\n                attr = float(row[4])\r\n                if src not in node_ids:\r\n                    node_ids[src] = node_uid\r\n                    node_uid += 1\r\n                if token not in rd_dict:\r\n                    rd_dict[token] = sr_uid\r\n                    sr_uid += 1\r\n                u = node_ids[src]\r\n                i = rd_dict[token]\r\n                u_list[idx - 1] = u\r\n                i_list[idx - 1] = i\r\n                ts_list[idx - 1] = ts\r\n                idx_list[idx - 1] = idx\r\n                w_list[idx - 1] = w\r\n                feat_l[idx - 1] = np.array([w,attr])\r\n                idx += 1\r\n\r\n        return (\r\n            
pd.DataFrame(\r\n                {\r\n                    \"u\": u_list,\r\n                    \"i\": i_list,\r\n                    \"ts\": ts_list,\r\n                    \"label\": label_list,\r\n                    \"idx\": idx_list,\r\n                    \"w\": w_list,\r\n                }\r\n            ),\r\n            feat_l,\r\n            node_ids,\r\n            rd_dict,\r\n        )\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n\"\"\"\r\nfunctions for subreddits dataset\r\n---------------------------------------\r\n\"\"\"\r\n\r\n\r\ndef load_edgelist_sr(\r\n    fname: str,\r\n    label_size: int = 2221,\r\n) -> pd.DataFrame:\r\n    \"\"\"\r\n    load the edgelist into pandas dataframe\r\n    also outputs index for the user nodes and genre nodes\r\n    Parameters:\r\n        fname: str, name of the input file\r\n        label_size: int, number of genres\r\n    Returns:\r\n        df: a pandas dataframe containing the edgelist data\r\n    \"\"\"\r\n    feat_size = 1 #2\r\n    num_lines = sum(1 for line in open(fname)) - 1\r\n    vprint(\"there are \", num_lines, \" lines in the raw data\")\r\n    u_list = np.zeros(num_lines)\r\n    i_list = np.zeros(num_lines)\r\n    ts_list = np.zeros(num_lines)\r\n    label_list = np.zeros(num_lines)\r\n    feat_l = np.zeros((num_lines, feat_size))\r\n    idx_list = np.zeros(num_lines)\r\n    w_list = np.zeros(num_lines)\r\n\r\n    node_ids = {}\r\n    rd_dict = {}\r\n    node_uid = label_size  # node ids start after all the genres\r\n    sr_uid = 0\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        # ['ts', 'src', 'subreddit', 'num_words', 'score']\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n            else:\r\n                ts = row[0]\r\n                src = row[1]\r\n                subreddit = row[2]\r\n                #num_words 
= int(row[3])\r\n                score = int(row[4])\r\n                if src not in node_ids:\r\n                    node_ids[src] = node_uid\r\n                    node_uid += 1\r\n                if subreddit not in rd_dict:\r\n                    rd_dict[subreddit] = sr_uid\r\n                    sr_uid += 1\r\n                w = float(score)\r\n                u = node_ids[src]\r\n                i = rd_dict[subreddit]\r\n                u_list[idx - 1] = u\r\n                i_list[idx - 1] = i\r\n                ts_list[idx - 1] = ts\r\n                idx_list[idx - 1] = idx\r\n                w_list[idx - 1] = w\r\n                feat_l[idx - 1] = np.array([w])\r\n                idx += 1\r\n\r\n        return (\r\n            pd.DataFrame(\r\n                {\r\n                    \"u\": u_list,\r\n                    \"i\": i_list,\r\n                    \"ts\": ts_list,\r\n                    \"label\": label_list,\r\n                    \"idx\": idx_list,\r\n                    \"w\": w_list,\r\n                }\r\n            ),\r\n            feat_l,\r\n            node_ids,\r\n            rd_dict,\r\n        )\r\n\r\n\r\ndef load_labels_sr(\r\n    fname,\r\n    node_ids,\r\n    rd_dict,\r\n):\r\n    \"\"\"\r\n    load the node labels for subreddit dataset\r\n    \"\"\"\r\n    if not osp.exists(fname):\r\n        raise FileNotFoundError(f\"File not found at {fname}\")\r\n\r\n    # day, user_idx, label_vec\r\n    label_size = len(rd_dict)\r\n    label_vec = np.zeros(label_size)\r\n    ts_prev = 0\r\n    prev_user = 0\r\n\r\n    ts_list = []\r\n    node_id_list = []\r\n    y_list = []\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        # ['ts', 'src', 'subreddit', 'num_words', 'score']\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n            else:\r\n                user_id = node_ids[int(row[1])]\r\n       
         ts = int(row[0])\r\n                sr_id = int(rd_dict[row[2]])\r\n                weight = float(row[3])\r\n                if idx == 1:\r\n                    ts_prev = ts\r\n                    prev_user = user_id\r\n                # the next day\r\n                if ts != ts_prev:\r\n                    ts_list.append(ts_prev)\r\n                    node_id_list.append(prev_user)\r\n                    y_list.append(label_vec)\r\n                    label_vec = np.zeros(label_size)\r\n                    ts_prev = ts\r\n                    prev_user = user_id\r\n                else:\r\n                    label_vec[sr_id] = weight\r\n\r\n                if user_id != prev_user:\r\n                    ts_list.append(ts_prev)\r\n                    node_id_list.append(prev_user)\r\n                    y_list.append(label_vec)\r\n                    prev_user = user_id\r\n                    label_vec = np.zeros(label_size)\r\n                idx += 1\r\n        return pd.DataFrame({\"ts\": ts_list, \"node_id\": node_id_list, \"y\": y_list})\r\n\r\n\r\ndef load_label_dict(fname: str, node_ids: dict, rd_dict: dict) -> dict:\r\n    \"\"\"\r\n    load node labels into a nested dictionary instead of pandas dataobject\r\n    {ts: {node_id: label_vec}}\r\n    Parameters:\r\n        fname: str, name of the input file\r\n        node_ids: dictionary of user names mapped to integer node ids\r\n        rd_dict: dictionary of subreddit names mapped to integer node ids\r\n    \"\"\"\r\n    if not osp.exists(fname):\r\n        raise FileNotFoundError(f\"File not found at {fname}\")\r\n\r\n    # day, user_idx, label_vec\r\n    label_size = len(rd_dict)\r\n    node_label_dict = {}  # {ts: {node_id: label_vec}}\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        # ['ts', 'src', 'dst', 'w']\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx 
+= 1\r\n            else:\r\n                u = node_ids[row[1]]\r\n                ts = int(row[0])\r\n                v = int(rd_dict[row[2]])\r\n                weight = float(row[3])\r\n                if (ts not in node_label_dict):\r\n                    node_label_dict[ts] = {u:np.zeros(label_size)}\r\n\r\n                if (u not in node_label_dict[ts]):\r\n                    node_label_dict[ts][u] = np.zeros(label_size)\r\n\r\n                node_label_dict[ts][u][v] = weight\r\n                idx += 1\r\n        return node_label_dict\r\n\r\n\r\n\"\"\"\r\nfunctions for redditcomments\r\n-------------------------------------------\r\n\"\"\"\r\n\r\n\r\ndef csv_to_pd_data_rc(\r\n    fname: str,\r\n) -> pd.DataFrame:\r\n    r\"\"\"\r\n    currently used by redditcomments dataset\r\n    convert the raw .csv data to pandas dataframe and numpy array\r\n    input .csv file format should be: timestamp, node u, node v, attributes\r\n    Args:\r\n        fname: the path to the raw data\r\n    \"\"\"\r\n    feat_size = 2  # 1 for subreddit, 1 for num words\r\n    num_lines = sum(1 for line in open(fname)) - 1\r\n    vprint(\"there are \", num_lines, \" lines in the raw data\")\r\n    u_list = np.zeros(num_lines)\r\n    i_list = np.zeros(num_lines)\r\n    ts_list = np.zeros(num_lines)\r\n    label_list = np.zeros(num_lines)\r\n    feat_l = np.zeros((num_lines, feat_size))\r\n    idx_list = np.zeros(num_lines)\r\n    w_list = np.zeros(num_lines)\r\n    node_ids = {}\r\n\r\n    unique_id = 0\r\n    max_words = 5000  # counted form statistics\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        # ['ts', 'src', 'dst', 'subreddit', 'num_words', 'score']\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n                continue\r\n            else:\r\n                ts = int(row[0])\r\n                src = row[1]\r\n                
dst = row[2]\r\n                num_words = int(row[3]) / max_words  # int number, normalize to [0,1]\r\n                score = int(row[4])  # int number\r\n\r\n                # reindexing node and subreddits\r\n                if src not in node_ids:\r\n                    node_ids[src] = unique_id\r\n                    unique_id += 1\r\n                if dst not in node_ids:\r\n                    node_ids[dst] = unique_id\r\n                    unique_id += 1\r\n                w = float(score)\r\n                u = node_ids[src]\r\n                i = node_ids[dst]\r\n                u_list[idx - 1] = u\r\n                i_list[idx - 1] = i\r\n                ts_list[idx - 1] = ts\r\n                idx_list[idx - 1] = idx\r\n                w_list[idx - 1] = w\r\n                feat_l[idx - 1] = np.array([num_words])\r\n                idx += 1\r\n    vprint(\"there are \", len(node_ids), \" unique nodes\")\r\n\r\n    return (\r\n        pd.DataFrame(\r\n            {\r\n                \"u\": u_list,\r\n                \"i\": i_list,\r\n                \"ts\": ts_list,\r\n                \"label\": label_list,\r\n                \"idx\": idx_list,\r\n                \"w\": w_list,\r\n            }\r\n        ),\r\n        feat_l,\r\n        node_ids,\r\n    )\r\n\r\n\r\n\"\"\"\r\nfunctions for stablecoin\r\n-------------------------------------------\r\n\"\"\"\r\ndef csv_to_pd_data_sc(\r\n    fname: str,\r\n) -> pd.DataFrame:\r\n    r\"\"\"\r\n    currently used by stablecoin dataset\r\n    convert the raw .csv data to pandas dataframe and numpy array\r\n    input .csv file format should be: timestamp, node u, node v, attributes\r\n    Parameters:\r\n        fname: the path to the raw data\r\n    Returns:\r\n        df: a pandas dataframe containing the edgelist data\r\n        feat_l: a numpy array containing the node features\r\n        node_ids: a dictionary mapping node id to integer\r\n    \"\"\"\r\n    feat_size = 1\r\n    num_lines = sum(1 for 
line in open(fname)) - 1\r\n    vprint(f\"number of lines counted: {num_lines} in {fname}\")\r\n    u_list = np.zeros(num_lines)\r\n    i_list = np.zeros(num_lines)\r\n    ts_list = np.zeros(num_lines)\r\n    label_list = np.zeros(num_lines)\r\n    feat_l = np.zeros((num_lines, feat_size))\r\n    idx_list = np.zeros(num_lines)\r\n    w_list = np.zeros(num_lines)\r\n    node_ids = {}\r\n    unique_id = 0\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        # time,src,dst,weight\r\n        # 1648811421,0x27cbb0e6885ccb1db2dab7c2314131c94795fbef,0x8426a27add8dca73548f012d92c7f8f4bbd42a3e,800.0\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n                continue\r\n            else:\r\n                ts = int(row[0])\r\n                src = row[1]\r\n                dst = row[2]\r\n\r\n                if src not in node_ids:\r\n                    node_ids[src] = unique_id\r\n                    unique_id += 1\r\n                if dst not in node_ids:\r\n                    node_ids[dst] = unique_id\r\n                    unique_id += 1\r\n\r\n                w = float(row[3])\r\n                if w == 0:\r\n                    w = 1\r\n\r\n                u = node_ids[src]\r\n                i = node_ids[dst]\r\n                u_list[idx - 1] = u\r\n                i_list[idx - 1] = i\r\n                ts_list[idx - 1] = ts\r\n                idx_list[idx - 1] = idx\r\n                w_list[idx - 1] = w\r\n                feat_l[idx - 1] = np.zeros(feat_size)\r\n                idx += 1\r\n\r\n    #! 
normalize by log 2 for stablecoin\r\n    w_list = np.log2(w_list)\r\n\r\n    return (\r\n        pd.DataFrame(\r\n            {\r\n                \"u\": u_list,\r\n                \"i\": i_list,\r\n                \"ts\": ts_list,\r\n                \"label\": label_list,\r\n                \"idx\": idx_list,\r\n                \"w\": w_list,\r\n            }\r\n        ),\r\n        feat_l,\r\n        node_ids,\r\n    )\r\n\r\n\"\"\"\r\nfunctions for opensky\r\n-------------------------------------------\r\n\"\"\"\r\n\r\n\r\ndef convert_str2int(\r\n    in_str: str,\r\n) -> np.ndarray:\r\n    \"\"\"\r\n    convert strings to vectors of integers based on individual character\r\n    each letter is converted as follows, a=10, b=11\r\n    numbers are still int\r\n    Parameters:\r\n        in_str: an input string to parse\r\n    Returns:\r\n        out: a numpy integer array\r\n    \"\"\"\r\n    out = []\r\n    for element in in_str:\r\n        if element.isnumeric():\r\n            out.append(element)\r\n        elif element == \"!\":\r\n            out.append(-1)\r\n        else:\r\n            out.append(ord(element.upper()) - 44 + 9)\r\n    out = np.array(out, dtype=np.float32)\r\n    return out\r\n\r\n\r\ndef csv_to_pd_data(\r\n    fname: str,\r\n) -> pd.DataFrame:\r\n    r\"\"\"\r\n    currently used by tgbl-flight dataset\r\n    convert the raw .csv data to pandas dataframe and numpy array\r\n    input .csv file format should be: timestamp, node u, node v, attributes\r\n    Args:\r\n        fname: the path to the raw data\r\n    \"\"\"\r\n    feat_size = 16\r\n    num_lines = sum(1 for line in open(fname)) - 1\r\n    vprint(f\"number of lines counted: {num_lines} in {fname}\")\r\n    u_list = np.zeros(num_lines)\r\n    i_list = np.zeros(num_lines)\r\n    ts_list = np.zeros(num_lines)\r\n    label_list = np.zeros(num_lines)\r\n    feat_l = np.zeros((num_lines, feat_size))\r\n    idx_list = np.zeros(num_lines)\r\n    w_list = np.zeros(num_lines)\r\n    node_ids = 
{}\r\n    unique_id = 0\r\n    ts_format = None\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        #'day','src','dst','callsign','typecode'\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n                continue\r\n            else:\r\n                ts = row[0]\r\n                if ts_format is None:\r\n                    if (ts.isdigit()):\r\n                        ts_format = True\r\n                    else:\r\n                        ts_format = False\r\n                \r\n                if ts_format:\r\n                    ts = float(int(ts)) #unix timestamp already\r\n                else:\r\n                    #convert to unix timestamp\r\n                    TIME_FORMAT = \"%Y-%m-%d\"\r\n                    date_cur = datetime.datetime.strptime(ts, TIME_FORMAT)\r\n                    ts = float(date_cur.timestamp())\r\n                    # TIME_FORMAT = \"%Y-%m-%d\" # 2019-01-01\r\n                    # date_cur  = date.fromisoformat(ts)\r\n                    # dt = datetime.datetime.combine(date_cur, datetime.datetime.min.time())\r\n                    # dt = dt.replace(tzinfo=datetime.timezone.edt)\r\n                    # ts = float(dt.timestamp())\r\n\r\n\r\n                src = row[1]\r\n                dst = row[2]\r\n\r\n                # 'callsign' has max size 8, can be 4, 5, 6, or 7\r\n                # 'typecode' has max size 8\r\n                # use ! 
as padding\r\n\r\n                # pad row[3] to size 7\r\n                if len(row[3]) == 0:\r\n                    row[3] = \"!!!!!!!!\"\r\n                while len(row[3]) < 8:\r\n                    row[3] += \"!\"\r\n\r\n                # pad row[4] to size 4\r\n                if len(row[4]) == 0:\r\n                    row[4] = \"!!!!!!!!\"\r\n                while len(row[4]) < 8:\r\n                    row[4] += \"!\"\r\n                if len(row[4]) > 8:\r\n                    row[4] = \"!!!!!!!!\"\r\n\r\n                feat_str = row[3] + row[4]\r\n\r\n                if src not in node_ids:\r\n                    node_ids[src] = unique_id\r\n                    unique_id += 1\r\n                if dst not in node_ids:\r\n                    node_ids[dst] = unique_id\r\n                    unique_id += 1\r\n                u = node_ids[src]\r\n                i = node_ids[dst]\r\n                u_list[idx - 1] = u\r\n                i_list[idx - 1] = i\r\n                ts_list[idx - 1] = ts\r\n                idx_list[idx - 1] = idx\r\n                w_list[idx - 1] = float(1)\r\n                feat_l[idx - 1] = convert_str2int(feat_str)\r\n                idx += 1\r\n    return (\r\n        pd.DataFrame(\r\n            {\r\n                \"u\": u_list,\r\n                \"i\": i_list,\r\n                \"ts\": ts_list,\r\n                \"label\": label_list,\r\n                \"idx\": idx_list,\r\n                \"w\": w_list,\r\n            }\r\n        ),\r\n        feat_l,\r\n        node_ids,\r\n    )\r\n\r\n\r\ndef process_node_feat(\r\n    fname: str,\r\n    node_ids,\r\n):\r\n    \"\"\"\r\n    1. need to have the same node id as csv_to_pd_data\r\n    2. process the various node features into a vector\r\n    3. 
return a numpy array of node features with index corresponding to node id\r\n\r\n    airport_code,type,continent,iso_region,longitude,latitude\r\n    type: onehot encoding\r\n    continent: onehot encoding\r\n    iso_region: alphabet encoding same as edge feat\r\n    longitude: float divide by 180\r\n    latitude: float divide by 90\r\n    \"\"\"\r\n    feat_size = 20\r\n    node_feat = np.zeros((len(node_ids), feat_size))\r\n    type_dict = {}\r\n    type_idx = 0\r\n    continent_dict = {}\r\n    cont_idx = 0\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        # airport_code,type,continent,iso_region,longitude,latitude\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n                continue\r\n            else:\r\n                code = row[0]\r\n                if code not in node_ids:\r\n                    continue\r\n                else:\r\n                    node_id = node_ids[code]\r\n                    airport_type = row[1]\r\n                    if airport_type not in type_dict:\r\n                        type_dict[airport_type] = type_idx\r\n                        type_idx += 1\r\n                    continent = row[2]\r\n                    if continent not in continent_dict:\r\n                        continent_dict[continent] = cont_idx\r\n                        cont_idx += 1\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        # airport_code,type,continent,iso_region,longitude,latitude\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n                continue\r\n            else:\r\n                code = row[0]\r\n                if code not in node_ids:\r\n                    continue\r\n                else:\r\n                    node_id = node_ids[code]\r\n                
    airport_type = type_dict[row[1]]\r\n                    type_vec = np.zeros(type_idx)\r\n                    type_vec[airport_type] = 1\r\n                    continent = continent_dict[row[2]]\r\n                    cont_vec = np.zeros(cont_idx)\r\n                    cont_vec[continent] = 1\r\n                    while len(row[3]) < 7:\r\n                        row[3] += \"!\"\r\n                    iso_region = convert_str2int(row[3])  # numpy float array\r\n                    lng = float(row[4])\r\n                    lat = float(row[5])\r\n                    coor_vec = np.array([lng, lat])\r\n                    final = np.concatenate(\r\n                        (type_vec, cont_vec, iso_region, coor_vec), axis=0\r\n                    )\r\n                    node_feat[node_id] = final\r\n    return node_feat\r\n\r\n\"\"\"\r\nfunctions for un trade\r\n-------------------------------------------\r\n\"\"\"\r\n\r\n\r\n#! these are helper functions\r\n# TODO cleaning the un trade csv with countries with comma in the name, to remove this function\r\ndef clean_rows(\r\n    fname: str,\r\n    outname: str,\r\n):\r\n    r\"\"\"\r\n    clean the rows with comma in the name\r\n    args:\r\n        fname: the path to the raw data\r\n        outname: the path to the cleaned data\r\n    \"\"\"\r\n\r\n    outf = open(outname, \"w\")\r\n\r\n    with open(fname) as f:\r\n        s = next(f)\r\n        outf.write(s)\r\n        for idx, line in enumerate(f):\r\n            strs = [\"China, Taiwan Province of\", \"China, mainland\"]\r\n            for str in strs:\r\n                line = line.replace(\r\n                    \"China, Taiwan Province of\", \"Taiwan Province of China\"\r\n                )\r\n                line = line.replace(\"China, mainland\", \"China mainland\")\r\n                line = line.replace(\"China, Hong Kong SAR\", \"China Hong Kong SAR\")\r\n                line = line.replace(\"China, Macao SAR\", \"China Macao SAR\")\r\n                
line = line.replace(\r\n                    \"Saint Helena, Ascension and Tristan da Cunha\",\r\n                    \"Saint Helena Ascension and Tristan da Cunha\",\r\n                )\r\n\r\n            e = line.strip().split(\",\")\r\n            if len(e) > 4:\r\n                raise ValueError(f\"line has more than 4 elements: {e}\")\r\n            outf.write(line)\r\n\r\n    outf.close()\r\n\r\n\r\n\"\"\"\r\nfunctions for last fm genre\r\n-------------------------------------------\r\n\"\"\"\r\n\r\n\r\ndef load_edgelist_datetime(fname, label_size=514):\r\n    \"\"\"\r\n    load the edgelist into a pandas dataframe\r\n    use numpy array instead of list for faster processing\r\n    assume all edges are already sorted by time\r\n    convert all time unit to unix time\r\n\r\n    time, user_id, genre, weight\r\n    \"\"\"\r\n    feat_size = 1\r\n    num_lines = sum(1 for line in open(fname)) - 1\r\n    vprint(f\"number of lines counted: {num_lines} in {fname}\")\r\n    u_list = np.zeros(num_lines)\r\n    i_list = np.zeros(num_lines)\r\n    ts_list = np.zeros(num_lines)\r\n    feat_l = np.zeros((num_lines, feat_size))\r\n    idx_list = np.zeros(num_lines)\r\n    w_list = np.zeros(num_lines)\r\n    node_ids = {}  # dictionary for node ids\r\n    label_ids = {}  # dictionary for label ids\r\n    node_uid = label_size  # node ids start after the genre nodes\r\n    label_uid = 0\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        idx = 0\r\n        for row in tqdm(csv_reader):\r\n            if idx == 0:\r\n                idx += 1\r\n            else:\r\n                ts = int(row[0])\r\n                user_id = row[1]\r\n                genre = row[2]\r\n                w = float(row[3])\r\n\r\n                if user_id not in node_ids:\r\n                    node_ids[user_id] = node_uid\r\n                    node_uid += 1\r\n\r\n                if genre not in label_ids:\r\n                 
   label_ids[genre] = label_uid\r\n                    if label_uid >= label_size:\r\n                        vprint(\"id overlap, terminate\")\r\n                    label_uid += 1\r\n\r\n                u = node_ids[user_id]\r\n                i = label_ids[genre]\r\n                u_list[idx - 1] = u\r\n                i_list[idx - 1] = i\r\n                ts_list[idx - 1] = ts\r\n                idx_list[idx - 1] = idx\r\n                w_list[idx - 1] = w\r\n                feat_l[idx - 1] = np.asarray([w])\r\n                idx += 1\r\n\r\n    return (\r\n        pd.DataFrame(\r\n            {\"u\": u_list, \"i\": i_list, \"ts\": ts_list, \"idx\": idx_list, \"w\": w_list}\r\n        ),\r\n        feat_l,\r\n        node_ids,\r\n        label_ids,\r\n    )\r\n\r\n\r\ndef load_genre_list(fname):\r\n    \"\"\"\r\n    load the list of genres\r\n    \"\"\"\r\n    if not osp.exists(fname):\r\n        raise FileNotFoundError(f\"File not found at {fname}\")\r\n\r\n    edgelist = open(fname, \"r\")\r\n    lines = list(edgelist.readlines())\r\n    edgelist.close()\r\n\r\n    genre_index = {}\r\n    ctr = 0\r\n    for i in range(1, len(lines)):\r\n        vals = lines[i].split(\",\")\r\n        genre = vals[0]\r\n        if genre not in genre_index:\r\n            genre_index[genre] = ctr\r\n            ctr += 1\r\n        else:\r\n            raise ValueError(\"duplicate in genre_index\")\r\n    return genre_index\r\n\r\n\r\n\"\"\"\r\nfunctions for wikipedia and un_trade\r\n-------------------------------------------\r\n\"\"\"\r\n\r\ndef reindex(\r\n    df: pd.DataFrame,\r\n    bipartite: Optional[bool] = False,\r\n):\r\n    r\"\"\"\r\n    reindex the nodes especially if the node ids are not integers\r\n    Args:\r\n        df: the pandas dataframe containing the graph\r\n        bipartite: whether the graph is bipartite\r\n    \"\"\"\r\n    new_df = df.copy()\r\n    if bipartite:\r\n        assert df.u.max() - df.u.min() + 1 == len(df.u.unique())\r\n        assert 
df.i.max() - df.i.min() + 1 == len(df.i.unique())\r\n\r\n        upper_u = df.u.max() + 1\r\n        new_i = df.i + upper_u\r\n\r\n        new_df.i = new_i\r\n        new_df.u += 1\r\n        new_df.i += 1\r\n        new_df.idx += 1\r\n    else:\r\n        new_df.u += 1\r\n        new_df.i += 1\r\n        new_df.idx += 1\r\n\r\n    return new_df\r\n\r\n\r\n    \r\n"
  },
  {
    "path": "tgb/utils/stats.py",
    "content": "\"\"\"\r\nscript for generating statistics from the dataset\r\n\"\"\"\r\nimport csv\r\nimport numpy as np\r\nimport matplotlib.pyplot as plt\r\nfrom tgb.utils.utils import vprint\r\n\r\n\r\n\"\"\"\r\n#! analyze statistics from the dataset\r\n#* 1). # of unique nodes, 2). # of edges. 3). # of unique edges, 4). # of timestamps 5). recurrence of nodes\r\n\"\"\"\r\n\r\n\r\ndef analyze_csv(fname):\r\n    node_dict = {}\r\n    edge_dict = {}\r\n    num_edges = 0\r\n    num_time = 0\r\n    time_dict = {}\r\n\r\n    with open(fname, \"r\") as csv_file:\r\n        csv_reader = csv.reader(csv_file, delimiter=\",\")\r\n        line_count = 0\r\n        for row in csv_reader:\r\n            if line_count == 0:\r\n                line_count += 1\r\n            else:\r\n                # t,u,v,w\r\n                t = row[0]\r\n                u = row[1]\r\n                v = row[2]\r\n\r\n                # count unique time\r\n                if t not in time_dict:\r\n                    time_dict[t] = 1\r\n                    num_time += 1\r\n\r\n                # unique nodes\r\n                if u not in node_dict:\r\n                    node_dict[u] = 1\r\n                else:\r\n                    node_dict[u] += 1\r\n\r\n                if v not in node_dict:\r\n                    node_dict[v] = 1\r\n                else:\r\n                    node_dict[v] += 1\r\n\r\n                # unique edges\r\n                num_edges += 1\r\n                if (u, v) not in edge_dict:\r\n                    edge_dict[(u, v)] = 1\r\n                else:\r\n                    edge_dict[(u, v)] += 1\r\n\r\n    vprint(\"----------------------high level statistics-------------------------\")\r\n    vprint(\"number of total edges are \", num_edges)\r\n    vprint(\"number of nodes are \", len(node_dict))\r\n    vprint(\"number of unique edges are \", len(edge_dict))\r\n    vprint(\"number of unique timestamps are \", num_time)\r\n\r\n    num_10 = 0\r\n    num_100 = 0\r\n    num_1000 = 0\r\n\r\n    for node in node_dict:\r\n        if node_dict[node] >= 10:\r\n            num_10 += 1\r\n        if node_dict[node] >= 100:\r\n            num_100 += 1\r\n        if node_dict[node] >= 1000:\r\n            num_1000 += 1\r\n    vprint(\"number of nodes with # edges >= 10 is \", num_10)\r\n    vprint(\"number of nodes with # edges >= 100 is \", num_100)\r\n    vprint(\"number of nodes with # edges >= 1000 is \", num_1000)\r\n    vprint(\"----------------------high level statistics-------------------------\")\r\n\r\n\r\ndef plot_curve(y: np.ndarray, outname: str) -> None:\r\n    \"\"\"\r\n    plot the training curve given y\r\n    Parameters:\r\n        y: np.ndarray, the training curve\r\n        outname: str, the output name\r\n    \"\"\"\r\n    plt.plot(y, color=\"#fc4e2a\")\r\n    plt.savefig(outname + \".pdf\")\r\n    plt.close()\r\n\r\n\r\ndef main():\r\n    fname = \"tgb/datasets/tgbl-wiki/tgbl-wiki_edgelist.csv\"\r\n    analyze_csv(fname)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "tgb/utils/utils.py",
    "content": "import random\r\nimport os\r\nimport pickle\r\nimport sys\r\nimport argparse\r\nimport json\r\nimport torch\r\nfrom typing import Any\r\nimport numpy as np\r\nfrom torch_geometric.data import TemporalData\r\nimport pandas as pd\r\n\r\n\r\n_VERBOSE = os.getenv(\"TGB_VERBOSE\", 'False').lower() in ['true', '1']\r\n\r\ndef set_verbose(flag: bool) -> None:\r\n    global _VERBOSE\r\n    _VERBOSE = flag\r\n\r\ndef vprint(*args, **kwargs):\r\n    global _VERBOSE\r\n    if _VERBOSE: print(*args, **kwargs)\r\n \r\n\r\n\r\ndef add_inverse_quadruples(df: pd.DataFrame) -> pd.DataFrame:\r\n    r\"\"\"\r\n    adds the inverse relations required for the model to the dataframe\r\n    \"\"\"\r\n    if (\"edge_type\" not in df):\r\n        raise ValueError(\"edge_type is required to invert relation in TKG\")\r\n    \r\n    sources = np.array(df[\"u\"])\r\n    destinations = np.array(df[\"i\"])\r\n    timestamps = np.array(df[\"ts\"])\r\n    edge_idxs = np.array(df[\"idx\"])\r\n    weights = np.array(df[\"w\"])\r\n    edge_type = np.array(df[\"edge_type\"])\r\n\r\n    num_rels = np.unique(edge_type).shape[0]\r\n    inv_edge_type = edge_type + num_rels\r\n\r\n    all_sources = np.concatenate([sources, destinations])\r\n    all_destinations = np.concatenate([destinations, sources])\r\n    all_timestamps = np.concatenate([timestamps, timestamps])\r\n    all_edge_idxs = np.concatenate([edge_idxs, edge_idxs+edge_idxs.max()+1])\r\n    all_weights = np.concatenate([weights, weights])\r\n    all_edge_types = np.concatenate([edge_type, inv_edge_type])\r\n\r\n    return pd.DataFrame(\r\n            {\r\n                \"u\": all_sources,\r\n                \"i\": all_destinations,\r\n                \"ts\": all_timestamps,\r\n                \"label\": np.ones(all_timestamps.shape[0]),\r\n                \"idx\": all_edge_idxs,\r\n                \"w\": all_weights,\r\n                \"edge_type\": all_edge_types,\r\n            }\r\n        )\r\n\r\n\r\n\r\ndef add_inverse_quadruples_np(quadruples: np.array, \r\n                              num_rels:int) -> np.array:\r\n    \"\"\"\r\n    creates an inverse quadruple for each quadruple in quadruples. inverse quadruple swaps subject and object, and increases \r\n    relation id by num_rels\r\n    :param quadruples: [np.array] dataset quadruples, [src, relation_id, dst, timestamp ]\r\n    :param num_rels: [int] number of relations that we have originally\r\n    returns all_quadruples: [np.array] quadruples including inverse quadruples\r\n    \"\"\"\r\n    inverse_quadruples = quadruples[:, [2, 1, 0, 3]]\r\n    inverse_quadruples[:, 1] = inverse_quadruples[:, 1] + num_rels  # we also need inverse quadruples\r\n    all_quadruples = np.concatenate((quadruples[:,0:4], inverse_quadruples))\r\n    return all_quadruples\r\n\r\n\r\ndef add_inverse_quadruples_pyg(data: TemporalData, num_rels:int=-1) -> list:\r\n    r\"\"\"\r\n    creates an inverse quadruple from PyG TemporalData object, returns both the original and inverse quadruples\r\n    \"\"\"\r\n    timestamp = data.t\r\n    head = data.src\r\n    tail = data.dst\r\n    msg = data.msg\r\n    edge_type = data.edge_type #relation\r\n    num_rels = torch.max(edge_type).item() + 1\r\n    inv_type = edge_type + num_rels\r\n    all_data = TemporalData(src=torch.cat([head, tail]), \r\n                            dst=torch.cat([tail, head]), \r\n                            t=torch.cat([timestamp, timestamp.clone()]), \r\n                            edge_type=torch.cat([edge_type, inv_type]), \r\n                            msg=torch.cat([msg, msg.clone()]),\r\n                            y = torch.cat([data.y, data.y.clone()]),)\r\n    return all_data\r\n\r\n\r\n\r\ndef save_pkl(obj: Any, fname: str) -> None:\r\n    r\"\"\"\r\n    save a python object as a pickle file\r\n    \"\"\"\r\n    with open(fname, \"wb\") as handle:\r\n        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)\r\n\r\n\r\ndef load_pkl(fname: str) -> Any:\r\n    r\"\"\"\r\n    load a python object from a pickle file\r\n    \"\"\"\r\n    with open(fname, \"rb\") as handle:\r\n        return pickle.load(handle)\r\n\r\ndef set_random_seed(random_seed: int):\r\n    r\"\"\"\r\n    set random seed for reproducibility\r\n    Args:\r\n        random_seed (int): random seed\r\n    \"\"\"\r\n    random.seed(random_seed)\r\n    np.random.seed(random_seed)\r\n    torch.manual_seed(random_seed)\r\n    torch.cuda.manual_seed(random_seed)\r\n    torch.cuda.manual_seed_all(random_seed)\r\n    torch.backends.cudnn.benchmark = False\r\n    torch.backends.cudnn.deterministic = True\r\n    vprint(f'INFO: fixed random seed: {random_seed}')\r\n\r\n\r\n\r\ndef find_nearest(array, value):\r\n    array = np.asarray(array)\r\n    idx = (np.abs(array - value)).argmin()\r\n    return array[idx]\r\n\r\n\r\ndef get_args():\r\n    parser = argparse.ArgumentParser('*** TGB ***')\r\n    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')\r\n    parser.add_argument('--lr', type=float, help='Learning rate', default=1e-4)\r\n    parser.add_argument('--bs', type=int, help='Batch size', default=200)\r\n    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)\r\n    parser.add_argument('--num_epoch', type=int, help='Number of epochs', default=30)\r\n    parser.add_argument('--seed', type=int, help='Random seed', default=1)\r\n    parser.add_argument('--mem_dim', type=int, help='Memory dimension', default=100)\r\n    parser.add_argument('--time_dim', type=int, help='Time dimension', default=100)\r\n    parser.add_argument('--emb_dim', type=int, help='Embedding dimension', default=100)\r\n    parser.add_argument('--tolerance', type=float, help='Early stopper tolerance', default=1e-6)\r\n    parser.add_argument('--patience', type=float, help='Early stopper patience', default=5)\r\n    parser.add_argument('--num_run', type=int, help='Number of iteration runs', default=5)\r\n\r\n    try:\r\n        args = parser.parse_args()\r\n    except:\r\n        parser.print_help()\r\n        sys.exit(0)\r\n    return args, sys.argv\r\n\r\n\r\n\r\n\r\ndef save_results(new_results: dict, filename: str):\r\n    r\"\"\"\r\n    save (new) results into a json file\r\n    :param: new_results (dictionary): a dictionary of new results to be saved\r\n    :filename: the name of the file to save the (new) results\r\n    \"\"\"\r\n    if os.path.isfile(filename):\r\n        # append to the file\r\n        with open(filename, 'r+') as json_file:\r\n            file_data = json.load(json_file)\r\n            # convert file_data to list if not\r\n            if type(file_data) is dict:\r\n                file_data = [file_data]\r\n            file_data.append(new_results)\r\n            json_file.seek(0)\r\n            json.dump(file_data, json_file, indent=4)\r\n    else:\r\n        # dump the results\r\n        with open(filename, 'w') as json_file:\r\n            json.dump(new_results, json_file, indent=4)\r\n\r\n\r\ndef split_by_time(data):\r\n    \"\"\"\r\n    https://github.com/Lee-zix/CEN/blob/main/rgcn/utils.py\r\n    create list where each entry has an entry with all triples for this timestep\r\n    \"\"\"\r\n    timesteps = list(set(data[:,3]))\r\n    timesteps.sort()\r\n    snapshot_list = [None] * len(timesteps)\r\n\r\n    for index, ts in enumerate(timesteps):\r\n        mask = np.where(data[:, 3] == ts)[0]\r\n        snapshot_list[index] = data[mask,:3]\r\n\r\n    return snapshot_list\r\n\r\n"
  }
]