Showing preview only (1,808K chars total). Download the full file or copy to clipboard to get everything.
Repository: shenyangHuang/TGB
Branch: main
Commit: 00d688b038b7
Files: 213
Total size: 1.7 MB
Directory structure:
gitextract_3nm00l7d/
├── .devcontainer/
│ ├── .gitignore
│ ├── Dockerfile
│ └── devcontainer.json
├── .github/
│ └── workflows/
│ ├── mkdocs.yaml
│ └── pypi.yaml
├── .gitignore
├── LICENSE
├── README.md
├── docs/
│ ├── about.md
│ ├── api/
│ │ ├── tgb.linkproppred.md
│ │ ├── tgb.nodeproppred.md
│ │ └── tgb.utils.md
│ ├── index.md
│ └── tutorials/
│ ├── Edge_data_numpy.ipynb
│ ├── Edge_data_pyg.ipynb
│ └── Node_label_tutorial.ipynb
├── examples/
│ ├── linkproppred/
│ │ ├── tgbl-coin/
│ │ │ ├── dyrep.py
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-comment/
│ │ │ ├── dyrep.py
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-enron/
│ │ │ └── edgebank.py
│ │ ├── tgbl-flight/
│ │ │ ├── dyrep.py
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-lastfm/
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-review/
│ │ │ ├── dyrep.py
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-subreddit/
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-uci/
│ │ │ └── edgebank.py
│ │ ├── tgbl-wiki/
│ │ │ ├── dyrep.py
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── thgl-forum/
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── sthn.py
│ │ │ └── tgn.py
│ │ ├── thgl-github/
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── run_seeds.sh
│ │ │ ├── sthn.py
│ │ │ └── tgn.py
│ │ ├── thgl-myket/
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── sthn.py
│ │ │ └── tgn.py
│ │ ├── thgl-software/
│ │ │ ├── STHN_README.md
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── sthn.py
│ │ │ └── tgn.py
│ │ ├── tkgl-icews/
│ │ │ ├── cen.py
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── regcn.py
│ │ │ ├── timetraveler.py
│ │ │ ├── tkgl-icews_example.py
│ │ │ └── tlogic.py
│ │ ├── tkgl-polecat/
│ │ │ ├── cen.py
│ │ │ ├── edgebank.py
│ │ │ ├── example.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── regcn.py
│ │ │ ├── timetraveler.py
│ │ │ ├── tkgl-polecat_example.py
│ │ │ └── tlogic.py
│ │ ├── tkgl-smallpedia/
│ │ │ ├── cen.py
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── regcn.py
│ │ │ ├── timetraveler.py
│ │ │ ├── tkgl-smallpedia_example.py
│ │ │ └── tlogic.py
│ │ ├── tkgl-wikidata/
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── regcn.py
│ │ │ ├── tkgl-wikidata_example.py
│ │ │ └── tlogic.py
│ │ └── tkgl-yago/
│ │ ├── cen.py
│ │ ├── edgebank.py
│ │ ├── recurrencybaseline.py
│ │ ├── regcn.py
│ │ ├── timetraveler.py
│ │ ├── tkgl-yago_example.py
│ │ └── tlogic.py
│ └── nodeproppred/
│ ├── tgbn-genre/
│ │ ├── dyrep.py
│ │ ├── moving_average.py
│ │ ├── persistant_forecast.py
│ │ └── tgn.py
│ ├── tgbn-reddit/
│ │ ├── dyrep.py
│ │ ├── moving_average.py
│ │ ├── persistant_forecast.py
│ │ └── tgn.py
│ ├── tgbn-token/
│ │ ├── dyrep.py
│ │ ├── moving_average.py
│ │ ├── persistant_forecast.py
│ │ └── tgn.py
│ └── tgbn-trade/
│ ├── count_new_nodes.py
│ ├── dyrep.py
│ ├── moving_average.py
│ ├── persistant_forecast.py
│ └── tgn.py
├── mkdocs.yml
├── modules/
│ ├── decoder.py
│ ├── early_stopping.py
│ ├── edgebank_predictor.py
│ ├── emb_module.py
│ ├── heuristics.py
│ ├── memory_module.py
│ ├── msg_agg.py
│ ├── msg_func.py
│ ├── neighbor_loader.py
│ ├── nodebank.py
│ ├── recurrencybaseline_predictor.py
│ ├── rgcn_layers.py
│ ├── rgcn_model.py
│ ├── rrgcn.py
│ ├── sampler_core.cpp
│ ├── sthn.py
│ ├── sthn_sampler_setup.py
│ ├── time_enc.py
│ ├── timetraveler_agent.py
│ ├── timetraveler_dirichlet.py
│ ├── timetraveler_environment.py
│ ├── timetraveler_episode.py
│ ├── timetraveler_policygradient.py
│ ├── timetraveler_trainertester.py
│ ├── tkg_utils.py
│ ├── tkg_utils_dgl.py
│ ├── tlogic_apply_modules.py
│ └── tlogic_learn_modules.py
├── pyproject.toml
├── run.sh
├── scripts/
│ ├── env.sh
│ ├── mila.sh
│ ├── mila_install.sh
│ └── run.sh
├── setup.py
└── tgb/
├── datasets/
│ ├── ICEWS14/
│ │ ├── ent2word.py
│ │ └── icews14.py
│ ├── dataset_scripts/
│ │ ├── MAG/
│ │ │ ├── mag.py
│ │ │ └── old/
│ │ │ └── plot_stats.py
│ │ ├── dgraph.py
│ │ ├── dgraph_Readme.md
│ │ ├── process_arxiv.py
│ │ ├── process_github.py
│ │ ├── tgbl-coin.py
│ │ ├── tgbl-coin_neg_generator.py
│ │ ├── tgbl-comment.py
│ │ ├── tgbl-comment_neg_generator.py
│ │ ├── tgbl-flight.py
│ │ ├── tgbl-flight_neg_generator.py
│ │ ├── tgbl-review.py
│ │ ├── tgbl-review_neg_generator.py
│ │ ├── tgbl-wiki_neg_generator.py
│ │ ├── tgbn-genre.py
│ │ ├── tgbn-reddit.py
│ │ ├── tgbn-token.py
│ │ └── tgbn-trade.py
│ ├── tgbl_enron/
│ │ ├── tgbl-enron_neg_generator.py
│ │ └── tgbl_enron.py
│ ├── tgbl_lastfm/
│ │ └── tgbl-lastfm_neg_generator.py
│ ├── tgbl_subreddit/
│ │ └── tgbl-subreddit_neg_generator.py
│ ├── tgbl_uci/
│ │ ├── tgbl-uci_neg_generator.py
│ │ └── tgbl_uci.py
│ ├── thgl_forum/
│ │ ├── merge_files.py
│ │ ├── thgl-forum.py
│ │ └── thgl_forum_ns_gen.py
│ ├── thgl_github/
│ │ ├── 2024_01/
│ │ │ └── github_extract.py
│ │ ├── 2024_02/
│ │ │ └── github_extract.py
│ │ ├── 2024_03/
│ │ │ └── github_extract.py
│ │ ├── extract_subset.py
│ │ ├── thgl_github.py
│ │ └── thgl_github_ns_gen.py
│ ├── thgl_myket/
│ │ ├── thgl_myket.py
│ │ └── thgl_myket_ns_gen.py
│ ├── thgl_software/
│ │ ├── thgl_software.py
│ │ └── thgl_software_ns_gen.py
│ ├── tkgl_icews/
│ │ ├── tkgl_icews.py
│ │ └── tkgl_icews_ns_gen.py
│ ├── tkgl_polecat/
│ │ ├── tkgl_polecat.py
│ │ └── tkgl_polecat_ns_gen.py
│ ├── tkgl_smallpedia/
│ │ ├── smallpedia_remove_conflict.py
│ │ └── tkgl_smallpedia_ns_gen.py
│ ├── tkgl_wikidata/
│ │ ├── extract.sh
│ │ ├── time_edges/
│ │ │ └── tkgl-wikidata_extract.py
│ │ ├── tkgl-wikidata.py
│ │ ├── tkgl_wikidata_mining.py
│ │ ├── tkgl_wikidata_ns_gen.py
│ │ └── wikidata_remove_conflict.py
│ └── tkgl_yago/
│ ├── tkgl_yago.py
│ └── tkgl_yago_ns_gen.py
├── linkproppred/
│ ├── dataset.py
│ ├── dataset_pyg.py
│ ├── evaluate.py
│ ├── negative_generator.py
│ ├── negative_sampler.py
│ ├── thg_negative_generator.py
│ ├── thg_negative_sampler.py
│ ├── tkg_negative_generator.py
│ └── tkg_negative_sampler.py
├── nodeproppred/
│ ├── dataset.py
│ ├── dataset_pyg.py
│ └── evaluate.py
└── utils/
├── dataset_stats.py
├── info.py
├── pre_process.py
├── stats.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .devcontainer/.gitignore
================================================
!devcontainer.json
================================================
FILE: .devcontainer/Dockerfile
================================================
FROM mcr.microsoft.com/devcontainers/python:3.10
RUN python -m pip install --no-cache-dir --upgrade pip poetry \
&& pip install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cu118 'torch>=2.0.0' \
&& pip install --no-cache-dir --find-links https://data.pyg.org/whl/torch-2.0.0+cu118.html torch_geometric pyg_lib torch_scatter torch_sparse
ENV POETRY_VIRTUALENVS_CREATE=false
COPY pyproject.toml poetry.lock* /tmp/poetry/
RUN poetry -C /tmp/poetry --no-cache install --no-root --no-directory
================================================
FILE: .devcontainer/devcontainer.json
================================================
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/python
{
"name": "py-tgb",
"build": {
"dockerfile": "Dockerfile",
"context": ".."
},
"customizations": {
"vscode": {
"extensions": [
"editorconfig.editorconfig",
"github.vscode-pull-request-github",
"ms-azuretools.vscode-docker",
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.pylint",
"ms-python.isort",
"ms-python.flake8",
"ms-python.black-formatter",
"ms-vsliveshare.vsliveshare",
"ryanluker.vscode-coverage-gutters",
"bungcip.better-toml",
"GitHub.copilot",
"redhat.vscode-yaml"
],
"settings": {
"python.defaultInterpreterPath": "/usr/local/bin/python",
"black-formatter.path": [
"/usr/local/py-utils/bin/black"
],
"pylint.path": [
"/usr/local/py-utils/bin/pylint"
],
"flake8.path": [
"/usr/local/py-utils/bin/flake8"
],
"isort.path": [
"/usr/local/py-utils/bin/isort"
]
}
}
},
"features": {
"ghcr.io/devcontainers-contrib/features/act:1": {},
"ghcr.io/stuartleeks/dev-container-features/shell-history:0": {},
"ghcr.io/devcontainers/features/common-utils:2": {}
},
"postCreateCommand": "poetry --no-cache install --only main"
}
================================================
FILE: .github/workflows/mkdocs.yaml
================================================
name: mkdocs
on:
push:
# branches:
# - master
# - main
tags:
- "v*.*.*"
permissions:
contents: write
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: 3.x
- uses: actions/cache@v3
with:
key: mkdocs-material-${{ github.ref }}
path: .cache
restore-keys: |
mkdocs-material-
- run: pip install mkdocs-material mkdocstrings-python mkdocs-jupyter
- run: mkdocs gh-deploy --force
================================================
FILE: .github/workflows/pypi.yaml
================================================
# https://github.com/JRubics/poetry-publish
name: Publish to PyPI
on:
push:
tags:
- "v*.*.*"
jobs:
publish:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- name: Build and publish to pypi
uses: JRubics/poetry-publish@v1.17
with:
pypi_token: ${{ secrets.PYPI_API_TOKEN }}
================================================
FILE: .gitignore
================================================
!requirements*.txt
get_croissant.py
#dataset
stats_figures/
figs/
*.xz
*.dict
*.tab
*.npz
*.xz
*.parquet
*.gz
*.tar
*.pdf
*.csv
*.zip
*.json
*.npy
*.pt
*.out
*.pkl
*.txt
*.attr
*.edge
.DS_Store
store_files/
# Byte-compiled / optimized / DLL files
__pycache__/
raw/
books/
electronics/
software/
*.py[cod]
*$py.class
saved_models/
dump/
saved_results/
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
__pycache__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
cc_env.sh
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 TGB Team
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
<!-- # TGB -->

**Temporal Graph Benchmark for Machine Learning on Temporal Graphs** (NeurIPS 2023 Datasets and Benchmarks Track)
<h3>
<a href="https://proceedings.neurips.cc/paper_files/paper/2023/hash/066b98e63313162f6562b35962671288-Abstract-Datasets_and_Benchmarks.html"><img src="https://img.shields.io/badge/Paper-link-important"></a>
<a href="https://arxiv.org/abs/2307.01026"><img src="https://img.shields.io/badge/arXiv-pdf-yellowgreen"></a>
<a href="https://pypi.org/project/py-tgb/"><img src="https://img.shields.io/pypi/v/py-tgb.svg?color=brightgreen"></a>
<a href="https://tgb.complexdatalab.com/"><img src="https://img.shields.io/badge/website-blue"></a>
<a href="https://docs.tgb.complexdatalab.com/"><img src="https://img.shields.io/badge/docs-orange"></a>
</h3>
**TGB 2.0: A Benchmark for Learning on Temporal Knowledge Graphs and Heterogeneous Graphs** (NeurIPS 2024 Datasets and Benchmarks Track)
<h3>
<a href="https://openreview.net/forum?id=EADRzNJFn1#discussion"><img src="https://img.shields.io/badge/Paper-link-important"></a>
<a href="https://arxiv.org/abs/2406.09639v1"><img src="https://img.shields.io/badge/arXiv-pdf-yellowgreen"></a>
<a href="https://pypi.org/project/py-tgb/"><img src="https://img.shields.io/pypi/v/py-tgb.svg?color=brightgreen"></a>
<a href="https://tgb.complexdatalab.com/"><img src="https://img.shields.io/badge/website-blue"></a>
</h3>
Overview of the Temporal Graph Benchmark (TGB) pipeline:
- TGB includes large-scale and realistic datasets from 10 different domains with both dynamic link prediction and node property prediction tasks.
- TGB automatically downloads datasets and processes them into `numpy`, `PyTorch` and `PyG compatible TemporalData` formats.
- Novel TG models can be easily evaluated on TGB datasets via reproducible and realistic evaluation protocols.
- TGB provides public and online leaderboards to track recent developments in temporal graph learning domain.
- Now TGB supports temporal homogeneous graphs, temporal knowledge graphs and temporal heterogeneous graph datasets.

**To submit to [TGB leaderboard](https://tgb.complexdatalab.com/), please fill in this [google form](https://forms.gle/SEsXvN1QHo9tSFwx9)**
**See all version differences and update notes [here](https://tgb.complexdatalab.com/docs/update/)**
### Announcements
**Excited to announce that TGB 2.0 has been presented at the NeurIPS 2024 Datasets and Benchmarks Track**
See our [camera ready version](https://openreview.net/forum?id=EADRzNJFn1#discussion) and [arXiv version](https://arxiv.org/abs/2406.09639) for details. Please [install locally](https://tgb.complexdatalab.com/docs/home/) first. We welcome your feedback and suggestions.
**Excited to announce TGX, a companion package for analyzing temporal graphs in WSDM 2024 Demo Track**
TGX supports all TGB datasets and provides numerous temporal graph visualization plots and statistics out of the box. See our paper: [Temporal Graph Analysis with TGX](https://arxiv.org/abs/2402.03651) and [TGX website](https://complexdata-mila.github.io/TGX/).
<!-- **Excited to announce that TGB has been accepted to NeurIPS 2023 Datasets and Benchmarks Track**
Thanks to everyone for your help in improving TGB! we will continue to improve TGB based on your feedback and suggestions. -->
**Please update to version `2.2.0`**
#### version `2.2.0`
Adding license for TGB software (for dataset license please check TGB website).
Printed messages are no longer sent to stdout by default; set `TGB_VERBOSE=True` in your shell to enable verbose printing.
Default option is to automatically download the datasets (rather than command line input as before).
#### version `2.1.0`
Includes supplementary datasets `tgbl-lastfm` `tgbl-enron` `tgbl-uci` `tgbl-subreddit` for research purposes.
For more details, see the release notes
#### version `2.0.0`
Includes all new datasets from TGB 2.0 including temporal knowledge graphs and temporal heterogeneous graphs.
<!--
#### version `0.9.2`
Update the fix for `tgbl-flight` where now the unix timestamps are provided directly in the dataset. If you had issues with `tgbl-flight`, please remove `TGB/tgb/datasets/tgbl_flight` and redownload the dataset for a clean install -->
<!--
#### version `0.9.1`
Fixed an issue for `tgbl-flight` where the timestamp conversion is incorrect due to time zone differences. If you had issues with `tgbl-flight` before, please update your package.
#### version `0.9.0`
Added the large `tgbn-token` dataset with 72 million edges to the `nodeproppred` dataset.
Fixed errors in `tgbl-coin` and `tgbl-flight` where a small set of edges are not sorted chronologically. Please update your dataset version for them to version 2 (will be prompted in terminal). -->
### Pip Install
You can install TGB via [pip](https://pypi.org/project/py-tgb/). **Requires python >= 3.9**
```
pip install py-tgb
```
### Links and Datasets
The project website can be found [here](https://tgb.complexdatalab.com/).
The API documentations can be found [here](https://shenyanghuang.github.io/TGB/).
All dataset download links can be found at [info.py](https://github.com/shenyangHuang/TGB/blob/main/tgb/utils/info.py).
The TGB dataloader will also automatically download the dataset as well as the negative samples for the link property prediction datasets.
If the website is inaccessible, please use [this link](https://tgb-website.pages.dev/) instead.
### Running Example Methods
- For the dynamic link property prediction task, see the [`examples/linkproppred`](https://github.com/shenyangHuang/TGB/tree/main/examples/linkproppred) folder for example scripts to run TGN, DyRep and EdgeBank on TGB datasets.
- For the dynamic node property prediction task, see the [`examples/nodeproppred`](https://github.com/shenyangHuang/TGB/tree/main/examples/nodeproppred) folder for example scripts to run TGN, DyRep, moving average and persistent forecast baselines on TGB datasets.
- For all other baselines, please see the [TGB_Baselines](https://github.com/fpour/TGB_Baselines) repo.
### Acknowledgments
We thank the [OGB](https://ogb.stanford.edu/) team for their support throughout this project and sharing their website code for the construction of [TGB website](https://tgb.complexdatalab.com/).
### Software License
The code from this repo is licensed under the MIT License (see LICENSE)
### Citation
If code or data from this repo is useful for your project, please consider citing our TGB and TGB 2.0 paper:
```
@article{huang2023temporal,
title={Temporal graph benchmark for machine learning on temporal graphs},
author={Huang, Shenyang and Poursafaei, Farimah and Danovitch, Jacob and Fey, Matthias and Hu, Weihua and Rossi, Emanuele and Leskovec, Jure and Bronstein, Michael and Rabusseau, Guillaume and Rabbany, Reihaneh},
journal={Advances in Neural Information Processing Systems},
year={2023}
}
```
```
@article{huang2024tgb2,
title={TGB 2.0: A Benchmark for Learning on Temporal Knowledge Graphs and Heterogeneous Graphs},
author={Gastinger, Julia and Huang, Shenyang and Galkin, Mikhail and Loghmani, Erfan and Parviz, Ali and Poursafaei, Farimah and Danovitch, Jacob and Rossi, Emanuele and Koutis, Ioannis and Stuckenschmidt, Heiner and Rabbany, Reihaneh and Rabusseau, Guillaume},
journal={Advances in Neural Information Processing Systems},
year={2024}
}
```
<!--
### Install dependency
Our implementation works with python >= 3.9 and can be installed as follows
1. set up virtual environment (conda should work as well)
```
python -m venv ~/tgb_env/
source ~/tgb_env/bin/activate
```
2. install external packages
```
pip install pandas==1.5.3
pip install matplotlib==3.7.1
pip install clint==0.5.1
```
install Pytorch and PyG dependencies (needed to run the examples)
```
pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cu117
pip install torch_geometric==2.3.0
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
```
3. install local dependencies under root directory `/TGB`
```
pip install -e .
```
### Instruction for tracking new documentation and running mkdocs locally
1. first run the mkdocs server locally in your terminal
```
mkdocs serve
```
2. go to the local hosted web address similar to
```
[14:18:13] Browser connected: http://127.0.0.1:8000/
```
Example: to track documentation of a new hi.py file in tgb/edgeregression/hi.py
3. create docs/api/tgb.hi.md and add the following
```
# `tgb.edgeregression`
::: tgb.edgeregression.hi
```
4. edit mkdocs.yml
```
nav:
- Overview: index.md
- About: about.md
- API:
other *.md files
- tgb.edgeregression: api/tgb.hi.md
```
### Creating new branch ###
```
git fetch origin
git checkout -b test origin/test
```
### dependencies for mkdocs (documentation)
```
pip install mkdocs
pip install mkdocs-material
pip install mkdocstrings-python
pip install mkdocs-jupyter
pip install notebook
```
### full dependency list
Our implementation works with python >= 3.9 and has the following dependencies
```
pytorch == 2.0.0
torch-geometric == 2.3.0
torch-scatter==2.1.1
torch-sparse==0.6.17
torch-spline-conv==1.2.2
pandas==1.5.3
clint==0.5.1
``` -->
================================================
FILE: docs/about.md
================================================
# Temporal Graph Benchmark (TGB)

## Overview
The TGB repo provides an automated ML pipeline for learning on a diverse set of temporal graph datasets:
- automatic download of datasets from url
- processing the raw files into ML ready format
- support datasets in `numpy`, `PyTorch` and `PyG TemporalData` formats
- evaluation code for each dataset
================================================
FILE: docs/api/tgb.linkproppred.md
================================================
# `tgb.linkproppred`
::: tgb.linkproppred.dataset
::: tgb.linkproppred.dataset_pyg
::: tgb.linkproppred.evaluate
::: tgb.linkproppred.negative_sampler
::: tgb.linkproppred.negative_generator
::: tgb.linkproppred.tkg_negative_generator
::: tgb.linkproppred.tkg_negative_sampler
::: tgb.linkproppred.thg_negative_generator
::: tgb.linkproppred.thg_negative_sampler
================================================
FILE: docs/api/tgb.nodeproppred.md
================================================
# `tgb.nodeproppred`
::: tgb.nodeproppred.dataset
::: tgb.nodeproppred.dataset_pyg
::: tgb.nodeproppred.evaluate
================================================
FILE: docs/api/tgb.utils.md
================================================
# `tgb.utils`
::: tgb.utils.pre_process
::: tgb.utils.utils
::: tgb.utils.info
::: tgb.utils.stats
================================================
FILE: docs/index.md
================================================
# Welcome to Temporal Graph Benchmark

### Pip Install
You can install TGB via [pip](https://pypi.org/project/py-tgb/)
```
pip install py-tgb
```
### Links and Datasets
The project website can be found [here](https://tgb.complexdatalab.com/).
The API documentations can be found [here](https://shenyanghuang.github.io/TGB/).
All dataset download links can be found at [info.py](https://github.com/shenyangHuang/TGB/blob/main/tgb/utils/info.py).
TGB dataloader will also automatically download the dataset as well as the negative samples for the link property prediction datasets.
### Install dependency
Our implementation works with python >= 3.9 and can be installed as follows
1. set up virtual environment (conda should work as well)
```
python -m venv ~/tgb_env/
source ~/tgb_env/bin/activate
```
2. install external packages
```
pip install pandas==1.5.3
pip install matplotlib==3.7.1
pip install clint==0.5.1
```
install Pytorch and PyG dependencies (needed to run the examples)
```
pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cu117
pip install torch_geometric==2.3.0
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
```
3. install local dependencies under root directory `/TGB`
```
pip install -e .
```
### Instruction for tracking new documentation and running mkdocs locally
1. first run the mkdocs server locally in your terminal
```
mkdocs serve
```
2. go to the local hosted web address similar to
```
[14:18:13] Browser connected: http://127.0.0.1:8000/
```
Example: to track documentation of a new hi.py file in tgb/edgeregression/hi.py
3. create docs/api/tgb.hi.md and add the following
```
# `tgb.edgeregression`
::: tgb.edgeregression.hi
```
4. edit mkdocs.yml
```
nav:
- Overview: index.md
- About: about.md
- API:
other *.md files
- tgb.edgeregression: api/tgb.hi.md
```
### Creating new branch ###
```
git fetch origin
git checkout -b test origin/test
```
### dependencies for mkdocs (documentation)
```
pip install mkdocs
pip install mkdocs-material
pip install mkdocstrings-python
pip install mkdocs-jupyter
pip install notebook
```
### full dependency list
Our implementation works with python >= 3.9 and has the following dependencies
```
pytorch == 2.0.0
torch-geometric == 2.3.0
torch-scatter==2.1.1
torch-sparse==0.6.17
torch-spline-conv==1.2.2
pandas==1.5.3
clint==0.5.1
```
<!-- ## Code blocks
`pip install tgb` -->
<!--
### Plain codeblock
A plain codeblock:
```
Some code here
def myfunction()
// some comment
```
#### Code for a specific language
Some more code with the `py` at the start:
``` py
import tensorflow as tf
def whatever()
```
#### With a title
``` py title="bubble_sort.py"
def bubble_sort(items):
for i in range(len(items)):
for j in range(len(items) - 1 - i):
if items[j] > items[j + 1]:
items[j], items[j + 1] = items[j + 1], items[j]
```
#### With line numbers
``` py linenums="1"
def bubble_sort(items):
for i in range(len(items)):
for j in range(len(items) - 1 - i):
if items[j] > items[j + 1]:
items[j], items[j + 1] = items[j + 1], items[j]
```
#### Highlighting lines
``` py hl_lines="2 3"
def bubble_sort(items):
for i in range(len(items)):
for j in range(len(items) - 1 - i):
if items[j] > items[j + 1]:
items[j], items[j + 1] = items[j + 1], items[j]
```
## Icons and Emojs
:smile:
:fontawesome-regular-face-laugh-wink:
:fontawesome-brands-twitter:{ .twitter }
:octicons-heart-fill-24:{ .heart } -->
================================================
FILE: docs/tutorials/Edge_data_numpy.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "d5e3f5a2",
"metadata": {},
"source": [
"# Access edge data as numpy arrays\n",
"\n",
"This tutorial will show you how to access various datasets and their corresponding edgelists in `tgb`\n",
"\n",
"You can directly retrieve the edge data as `numpy` arrays, `PyG` and `Pytorch` dependencies are not necessary\n",
"\n",
"The logic is implemented in `dataset.py` under `tgb/linkproppred/` and `tgb/nodeproppred/` folders respectively\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "23f00c08",
"metadata": {},
"outputs": [],
"source": [
"from tgb.linkproppred.dataset import LinkPropPredDataset"
]
},
{
"cell_type": "markdown",
"id": "60e52b7b",
"metadata": {},
"source": [
"specifying the name of the dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "48888070",
"metadata": {},
"outputs": [],
"source": [
"name = \"tgbl-wiki\" "
]
},
{
"cell_type": "markdown",
"id": "3511804a",
"metadata": {},
"source": [
"### process and loading the dataset\n",
"\n",
"if the dataset has been processed, it will be loaded from disk for fast access\n",
"\n",
"if the dataset has not been downloaded, it will be processed automatically"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8486fa82",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Will you download the dataset(s) now? (y/N)\n",
"y\n",
"\u001b[93mDownload started, this might take a while . . . \u001b[0m\n",
"Dataset title: tgbl-wiki\n",
"\u001b[92mDownload completed \u001b[0m\n",
"Dataset directory is /mnt/f/code/TGB/tgb/datasets/tgbl_wiki\n",
"file not processed, generating processed file\n"
]
},
{
"data": {
"text/plain": [
"tgb.linkproppred.dataset.LinkPropPredDataset"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset = LinkPropPredDataset(name=name, root=\"datasets\", preprocess=True)\n",
"type(dataset)"
]
},
{
"cell_type": "markdown",
"id": "47c949b4",
"metadata": {},
"source": [
"### Accessing the edge data\n",
"\n",
"the edge data can be easily accessed via the properties of the dataset as `numpy` arrays "
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9e4e7421",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = dataset.full_data #a dictionary stores all the edge data\n",
"type(data) "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c6ec9ac0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"numpy.ndarray"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(data['sources'])\n",
"type(data['destinations'])\n",
"type(data['timestamps'])\n",
"type(data['edge_feat'])\n",
"type(data['w'])\n",
"type(data['edge_label']) #just all one array as all edges in the dataset are positive edges\n",
"type(data['edge_idxs']) #just index of the edges increment by 1 for each edge"
]
},
{
"cell_type": "markdown",
"id": "bb1bbfd6",
"metadata": {},
"source": [
"### Accessing the train, test, val split\n",
"\n",
"the masks for training, validation, and test split can be accessed directly from the `dataset` as well"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8cd3507c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"numpy.ndarray"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_mask = dataset.train_mask\n",
"val_mask = dataset.val_mask\n",
"test_mask = dataset.test_mask\n",
"\n",
"type(train_mask)\n",
"type(val_mask)\n",
"type(test_mask)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cf5eff06",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/tutorials/Edge_data_pyg.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "d5e3f5a2",
"metadata": {},
"source": [
"# Access edge data in Pytorch Geometric\n",
"\n",
"This tutorial will show you how to access various datasets and their corresponding edgelists in `tgb`\n",
"\n",
"The logic for PyG data is stored in `dataset_pyg.py` in `tgb/linkproppred` and `tgb/nodeproppred` folders\n",
"\n",
"This tutorial requires `Pytorch` and `PyG`, refer to `README.md` for installation instructions"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "23f00c08",
"metadata": {},
"outputs": [],
"source": [
"from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset"
]
},
{
"cell_type": "markdown",
"id": "60e52b7b",
"metadata": {},
"source": [
"specifying the name of the dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "48888070",
"metadata": {},
"outputs": [],
"source": [
"name = \"tgbl-wiki\""
]
},
{
"cell_type": "markdown",
"id": "3511804a",
"metadata": {},
"source": [
"### Process and load the dataset\n",
"\n",
"if the dataset has been processed, it will be loaded from disk for fast access\n",
"\n",
"if the dataset has not been downloaded, it will be processed automatically"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8486fa82",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"file found, skipping download\n",
"Dataset directory is /mnt/f/code/TGB/tgb/datasets/tgbl_wiki\n",
"loading processed file\n"
]
},
{
"data": {
"text/plain": [
"tgb.linkproppred.dataset_pyg.PyGLinkPropPredDataset"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\n",
"type(dataset)"
]
},
{
"cell_type": "markdown",
"id": "47c949b4",
"metadata": {},
"source": [
"### Access edge data from TemporalData object \n",
"\n",
"You can retrieve `torch_geometric.data.temporal.TemporalData` directly from `PyGLinkPropPredDataset`"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9e4e7421",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch_geometric.data.temporal.TemporalData"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = dataset.get_TemporalData()\n",
"type(data)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c6ec9ac0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Tensor"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(data.src)\n",
"type(data.dst)\n",
"type(data.t)\n",
"type(data.msg)"
]
},
{
"cell_type": "markdown",
"id": "52fd601f",
"metadata": {},
"source": [
"### Directly access edge data as Pytorch tensors\n",
"\n",
"the edge data can be easily accessed via the properties of the dataset object; these are converted into pytorch tensors (from `PyGLinkPropPredDataset`)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "56fb3347",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Tensor"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(dataset.src) #same as src from above\n",
"type(dataset.dst) #same as dst\n",
"type(dataset.ts) #same as t\n",
"type(dataset.edge_feat) #same as msg\n",
"type(dataset.edge_label) #same as label used in tgn"
]
},
{
"cell_type": "markdown",
"id": "bb1bbfd6",
"metadata": {},
"source": [
"### Accessing the train, test, val split\n",
"\n",
"the masks for training, validation, and test split can be accessed directly from the `dataset` as well"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "8cd3507c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Tensor"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_mask = dataset.train_mask\n",
"val_mask = dataset.val_mask\n",
"test_mask = dataset.test_mask\n",
"\n",
"type(train_mask)\n",
"type(val_mask)\n",
"type(test_mask)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d6ed432",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/tutorials/Node_label_tutorial.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "d5e3f5a2",
"metadata": {},
"source": [
"# Access node labels for Dynamic Node Property Prediction\n",
"\n",
"This tutorial will show you how to access node labels and edge data for the node property prediction datasets in `tgb`.\n",
"\n",
"The source code is stored in `dataset_pyg.py` in `tgb/nodeproppred` folder\n",
"\n",
"This tutorial requires `Pytorch` and `PyG`, refer to `README.md` for installation instructions\n",
"\n",
"This tutorial uses `PyG TemporalData` object, however it is possible to use `numpy` arrays as well.\n",
"\n",
"see examples in `examples/nodeproppred` folder for more details.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "23f00c08",
"metadata": {},
"outputs": [],
"source": [
"from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\n",
"from torch_geometric.loader import TemporalDataLoader"
]
},
{
"cell_type": "markdown",
"id": "60e52b7b",
"metadata": {},
"source": [
"specifying the name of the dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "48888070",
"metadata": {},
"outputs": [],
"source": [
"name = \"tgbn-genre\""
]
},
{
"cell_type": "markdown",
"id": "3511804a",
"metadata": {},
"source": [
"### Process and load the dataset\n",
"\n",
"if the dataset has been processed, it will be loaded from disk for fast access\n",
"\n",
"if the dataset has not been downloaded, it will be processed automatically"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8486fa82",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"file found, skipping download\n",
"Dataset directory is /mnt/f/code/TGB/tgb/datasets/tgbn_genre\n",
"loading processed file\n"
]
},
{
"data": {
"text/plain": [
"tgb.nodeproppred.dataset_pyg.PyGNodePropPredDataset"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset = PyGNodePropPredDataset(name=name, root=\"datasets\")\n",
"type(dataset)"
]
},
{
"cell_type": "markdown",
"id": "31338262",
"metadata": {},
"source": [
"### Train, Validation and Test splits with dataloaders\n",
"\n",
"splitting the edges into train, val, test sets and construct dataloader for each"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "27b4f6a1",
"metadata": {},
"outputs": [],
"source": [
"train_mask = dataset.train_mask\n",
"val_mask = dataset.val_mask\n",
"test_mask = dataset.test_mask\n",
"\n",
"\n",
"data = dataset.get_TemporalData()\n",
"\n",
"train_data = data[train_mask]\n",
"val_data = data[val_mask]\n",
"test_data = data[test_mask]\n",
"\n",
"batch_size = 200\n",
"train_loader = TemporalDataLoader(train_data, batch_size=batch_size)\n",
"val_loader = TemporalDataLoader(val_data, batch_size=batch_size)\n",
"test_loader = TemporalDataLoader(test_data, batch_size=batch_size)\n"
]
},
{
"cell_type": "markdown",
"id": "47c949b4",
"metadata": {},
"source": [
"### Access node label data \n",
"\n",
"In `tgb`, the node label data are queried based on the nearest edge observed so far and retrieves the node label data for the corresponding day. \n",
"\n",
"Note that this is because the node labels often have different timestamps from the edges thus should be processed at the correct time in the edge stream.\n",
"\n",
"In the example below, we show how to iterate through the edges and retrieve the node labels of the corresponding time. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9e4e7421",
"metadata": {},
"outputs": [],
"source": [
"#query the timestamps for the first node labels\n",
"label_t = dataset.get_label_time()\n",
"\n",
"for batch in train_loader:\n",
" #access the edges in this batch\n",
" src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n",
" query_t = batch.t[-1]\n",
" # check if this batch moves to the next day\n",
" if query_t > label_t:\n",
" # find the node labels from the past day\n",
" label_tuple = dataset.get_node_label(query_t)\n",
" # node labels are structured as a tuple with (timestamps, source node, label) format, label is a vector\n",
" label_ts, label_srcs, labels = (\n",
" label_tuple[0],\n",
" label_tuple[1],\n",
" label_tuple[2],\n",
" )\n",
" label_t = dataset.get_label_time()\n",
"\n",
" #insert your code for backproping with node labels here\n",
" "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: examples/linkproppred/tgbl-coin/dyrep.py
================================================
"""
DyRep
This has been implemented with intuitions from the following sources:
- https://github.com/twitter-research/tgn
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
Spec.:
- Memory Updater: RNN
- Embedding Module: ID
- Message Function: ATTN
command for an example run:
python examples/linkproppred/tgbl-coin/dyrep.py --data "tgbl-coin" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from torch_geometric.loader import TemporalDataLoader
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import DyRepMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Train the DyRep model for one epoch over the training edges.

    Uses objects that are globally defined in this script: ``model``,
    ``optimizer``, ``criterion``, ``train_loader``, ``train_data``,
    ``neighbor_loader``, ``assoc``, ``data``, ``device``,
    ``min_dst_idx`` and ``max_dst_idx``.

    Parameters:
        None
    Returns:
        float: average training loss per event over the epoch
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()
    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        # Sample one random negative destination per positive edge,
        # restricted to ids that actually occur as destinations.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )
        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # Map global node indices to local positions within this batch.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)
        # DyRep scores links directly on the memory states (identity
        # embedding module, see module docstring) — unlike TGN, the GNN
        # below is NOT applied before link prediction.
        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])
        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))
        # The GNN output is only used to update the memory with the
        # ground-truth (positive) edges.
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(src, pos_dst, t, msg, z, assoc)
        # update neighbor loader
        neighbor_loader.insert(src, pos_dst)
        loss.backward()
        optimizer.step()
        # Detach memory so gradients do not flow across batches.
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events
    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate DyRep on dynamic link prediction.

    Evaluation happens as 'one vs. many': each positive edge is ranked
    against its pre-sampled negative edges and scored by the global
    ``evaluator`` with the dataset's ``metric``.

    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: float, mean of the per-edge evaluation metric
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()
    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # Build one query per candidate: positive destination first,
            # then all negatives, all sharing the same source node.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )
            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id] = torch.arange(n_id.size(0), device=device)
            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)
            # DyRep: score on memory states directly (identity embedding).
            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])
            # compute the ranking metric (e.g. MRR) for this positive edge
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])
        # After scoring, update the memory with the positive edges only.
        n_id = torch.cat([pos_src, pos_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg, z, assoc)
        # update the neighbor loader
        neighbor_loader.insert(pos_src, pos_dst)
    perf_metric = float(torch.tensor(perf_list).mean())
    return perf_metric
# ==========
# ========== Main script: set up data, model, and run train/val/test ==========
# ==========
# Start the overall wall-clock timer for the whole script.
start_overall = timeit.default_timer()
# ========== set parameters from the command line
args, _ = get_args()
print("INFO: Arguments:", args)
DATA = "tgbl-coin"  # dataset name is fixed for this example script
LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10  # temporal neighbors kept per node by the neighbor loader
MODEL_NAME = 'DyRep'
# DyRep message construction uses the destination embedding, not the source.
USE_SRC_EMB_IN_MSG = False
USE_DST_EMB_IN_MSG = True
# ==========
# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric  # evaluation metric defined per dataset by TGB
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)
# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
# neighborhood sampler: keeps the most recent NUM_NEIGHBORS edges per node
neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
# define the model end-to-end
# 1) memory module (DyRep: RNN memory updater)
memory = DyRepMemory(
    data.num_nodes,
    data.msg.size(-1),
    MEM_DIM,
    TIME_DIM,
    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
    aggregator_module=LastAggregator(),
    memory_updater_type='rnn',
    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,
    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG
).to(device)
# 2) GNN (used only to update memory in DyRep; see train()/test())
gnn = GraphAttentionEmbedding(
    in_channels=MEM_DIM,
    out_channels=EMB_DIM,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)
# 3) link predictor (decoder)
link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
model = {'memory': memory,
         'gnn': gnn,
         'link_pred': link_pred}
# define one optimizer over the parameters of all three modules
optimizer = torch.optim.Adam(
    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
    lr=LR,
)
criterion = torch.nn.BCEWithLogitsLoss()
# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
# NOTE(review): the os.mkdir above is redundant with this Path(...).mkdir;
# both are kept to preserve the original behavior.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()
    # set the seed for deterministic results; varies per run
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)
    # define an early stopper that also checkpoints the best model
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)
    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()
    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )
        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)
        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break
    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation Total Elapsed Time (s): {train_val_time: .4f}")
    # ==================================================== Test
    # first, load the best model found during validation
    early_stopper.load_checkpoint(model)
    # loading the test negative samples
    dataset.load_test_ns()
    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
    # persist this run's results to JSON
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)
    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')
print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-coin/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate EdgeBank on dynamic link prediction in a 'one vs. many' setting.

    Each positive edge is scored against its pre-sampled negative edges and
    the ranking metric is computed by the global ``evaluator``; EdgeBank's
    memory is updated with every positive batch after it has been evaluated.

    Parameters:
        data: dataset dictionary with 'sources', 'destinations' and 'timestamps' arrays
        test_mask: mask selecting the edges of the evaluation split
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metrics: float, mean of the per-edge evaluation metric
    """
    # Slice the evaluation split once up front instead of re-masking per batch.
    srcs = data['sources'][test_mask]
    dsts = data['destinations'][test_mask]
    times = data['timestamps'][test_mask]
    n_edges = len(srcs)
    n_batches = math.ceil(n_edges / BATCH_SIZE)
    scores = []
    for batch_no in tqdm(range(n_batches)):
        lo = batch_no * BATCH_SIZE
        hi = min(lo + BATCH_SIZE, n_edges)
        pos_src = srcs[lo:hi]
        pos_dst = dsts[lo:hi]
        pos_t = times[lo:hi]
        neg_batches = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        for i, negs in enumerate(neg_batches):
            # One query per candidate: positive destination first, then negatives.
            query_src = np.full(len(negs) + 1, int(pos_src[i]))
            query_dst = np.concatenate([np.array([int(pos_dst[i])]), negs])
            preds = edgebank.predict_link(query_src, query_dst)
            # rank the positive edge against its negatives
            result = evaluator.eval({
                "y_pred_pos": np.array([preds[0]]),
                "y_pred_neg": np.array(preds[1:]),
                "eval_metric": [metric],
            })
            scores.append(result[metric])
        # Positive edges join EdgeBank's memory once they have been evaluated.
        edgebank.update_memory(pos_src, pos_dst, pos_t)
    return float(np.mean(scores))
def get_args():
    """Parse command-line arguments for the EdgeBank link prediction run.

    Returns:
        tuple: (parsed ``argparse.Namespace``, the raw ``sys.argv`` list)
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-coin')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse raises SystemExit on invalid arguments; the original bare
        # `except:` also swallowed KeyboardInterrupt and real bugs. Narrowed
        # to SystemExit while keeping the print-help-and-exit behavior.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ================== Main script: load data, build EdgeBank, evaluate ==========
# ==================
start_overall = timeit.default_timer()
# set hyperparameters from the command line
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = "tgbl-coin"  # dataset name is fixed for this example script
MODEL_NAME = 'EdgeBank'
# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric
# get masks for the chronological train/val/test splits
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
# training-split edges used to seed EdgeBank's memory
# NOTE(review): np.concatenate over a single array is a no-op copy here.
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])
# #! check if edges are sorted
# sorted = np.all(np.diff(data['timestamps']) >= 0)
# print (" INFO: Edges are sorted: ", sorted)
# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)
print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
# NOTE(review): the os.mkdir above is redundant with this Path(...).mkdir;
# both are kept to preserve the original behavior.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'
# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()
# evaluate on the validation split ...
start_val = timeit.default_timer()
# NOTE(review): the validation result is stored in `perf_metric_test` and is
# overwritten by the test result below; only the test value is saved to JSON.
perf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()
print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {test_time: .4f}")
# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()
# testing ...
start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()
print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
# persist the final test result to JSON
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-coin/tgn.py
================================================
"""
Dynamic Link Prediction with a TGN model with Early Stopping
Reference:
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
command for an example run:
python examples/linkproppred/tgbl-coin/tgn.py --data "tgbl-coin" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Train the TGN model for one epoch over the training edges.

    Uses objects that are globally defined in this script: ``model``,
    ``optimizer``, ``criterion``, ``train_loader``, ``train_data``,
    ``neighbor_loader``, ``assoc``, ``data``, ``device``,
    ``min_dst_idx`` and ``max_dst_idx``.

    Parameters:
        None
    Returns:
        float: average training loss per event over the epoch
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()
    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        # Sample one random negative destination per positive edge,
        # restricted to ids that actually occur as destinations.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )
        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # Map global node indices to local positions within this batch.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)
        # TGN: refine memory states with the attention-based GNN before scoring.
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])
        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))
        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)
        loss.backward()
        optimizer.step()
        # Detach memory so gradients do not flow across batches.
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events
    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate TGN on dynamic link prediction.

    Evaluation happens as 'one vs. many': each positive edge is ranked
    against its pre-sampled negative edges and scored by the global
    ``evaluator`` with the dataset's ``metric``.

    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: float, mean of the per-edge evaluation metric
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()
    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # Build one query per candidate: positive destination first,
            # then all negatives, all sharing the same source node.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )
            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id] = torch.arange(n_id.size(0), device=device)
            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)
            # TGN: refine memory states with the GNN before scoring.
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])
            # compute the ranking metric (e.g. MRR) for this positive edge
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])
        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)
    perf_metrics = float(torch.tensor(perf_list).mean())
    return perf_metrics
# ==========
# ==========
# ==========
# Start: TGN on tgbl-coin — train with early stopping, then evaluate
# one-vs-many MRR on validation and test splits, for NUM_RUNS seeds.
start_overall = timeit.default_timer()

DATA = "tgbl-coin"

# ========== set parameters...
args, _ = get_args()
args.data = DATA
print("INFO: Arguments:", args)

LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10
MODEL_NAME = 'TGN'
# ==========

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
# FIX: a single race-free mkdir; the previous exists-check + os.mkdir was
# redundant with this call and racy (TOCTOU) between check and creation.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'

for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # neighborhood sampler (rebuilt per run so runs stay independent)
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

    # define the model end-to-end
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)

    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}

    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )

        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")

    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-comment/dyrep.py
================================================
"""
DyRep
This has been implemented with intuitions from the following sources:
- https://github.com/twitter-research/tgn
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
Spec.:
- Memory Updater: RNN
- Embedding Module: ID
- Message Function: ATTN
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import DyRepMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Training procedure for the DyRep model; one full pass over `train_loader`.
    Uses script-level globals: `model`, `optimizer`, `criterion`,
    `neighbor_loader`, `assoc`, `data`, `train_data`, `device`,
    `min_dst_idx`, `max_dst_idx`.

    Returns:
        float: average loss per training event over the epoch
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Sample negative destination nodes.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # Map global node ids to local row indices into `z`.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)

        # DyRep spec (see module header: "Embedding Module: ID"): links are
        # scored directly on the memory output, not on a GNN embedding.
        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # update the memory with ground-truth; the attention embedding `z`
        # computed here is consumed by the DyRep memory update itself.
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(src, pos_dst, t, msg, z, assoc)
        # update neighbor loader
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        # Detach memory so gradients do not flow into past batches.
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test_one_vs_many(loader, neg_sampler, split_mode):
    """
    Evaluate dynamic link prediction for DyRep on one split.
    Evaluation is 'one vs. many': each positive edge is ranked against the
    pre-sampled negative edges returned by `neg_sampler`.
    Uses script-level globals: `model`, `neighbor_loader`, `assoc`, `data`,
    `evaluator`, `metric`, `device`.

    Parameters:
        loader: loader over the positive edges of the evaluation set
        neg_sampler: provides the negative edges for each positive edge
        split_mode: 'val' or 'test' — selects which negatives to use

    Returns:
        perf_metric: mean performance metric over all positive edges
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # Query set: positive destination first, then its negatives.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation.
            # NOTE: scoring uses the memory output directly (ID embedding).
            z, last_update = model['memory'](n_id)

            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])

            # compute MRR; row 0 holds the positive edge's score.
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update the memory with positive edges (after scoring, so the
        # model never peeks at the current batch); the attention embedding
        # is recomputed here because DyRep's memory update consumes it.
        n_id = torch.cat([pos_src, pos_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg, z, assoc)
        # update the neighbor loader
        neighbor_loader.insert(pos_src, pos_dst)

    perf_metric = float(torch.tensor(perf_list).mean())

    return perf_metric
# ==========
# ==========
# ==========
# Start: DyRep on tgbl-comment — train with early stopping, then evaluate
# one-vs-many MRR on validation and test splits, for NUM_RUNS seeds.
start_overall = timeit.default_timer()

# ========== set parameters...
args, _ = get_args()
print("INFO: Arguments:", args)

DATA = "tgbl-comment"
LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10
MODEL_NAME = 'DyRep'
USE_SRC_EMB_IN_MSG = False
USE_DST_EMB_IN_MSG = True
# ==========

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
# FIX: a single race-free mkdir; the previous exists-check + os.mkdir was
# redundant with this call and racy (TOCTOU) between check and creation.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'

for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # FIX: the neighbor loader, model, optimizer, and `assoc` buffer are now
    # re-initialized for EVERY run. Previously they were built once before
    # this loop, so runs after the first continued from the previous run's
    # trained weights — inconsistent with the sibling TGN scripts and
    # invalidating multi-run statistics.

    # neighborhood sampler
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

    # define the model end-to-end
    memory = DyRepMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
        memory_updater_type='rnn',
        use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,
        use_dst_emb_in_msg=USE_DST_EMB_IN_MSG
    ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)

    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}

    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )

        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test_one_vs_many(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test_one_vs_many(test_loader, neg_sampler, split_mode="test")

    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-comment/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate EdgeBank on one split of the dynamic link prediction task.
    Evaluation is 'one vs. many': each positive edge is ranked against the
    negative edges returned by `neg_sampler`. Relies on script-level
    globals `BATCH_SIZE`, `edgebank`, `evaluator`, and `metric`.

    Parameters:
        data: a dataset object
        test_mask: boolean mask selecting the evaluation-set edges
        neg_sampler: provides the negative edges for each positive edge
        split_mode: 'val' or 'test' — selects which negatives to load

    Returns:
        perf_metric: mean performance metric over all positive edges
    """
    # Slice out the evaluation split once, up front.
    srcs = data['sources'][test_mask]
    dsts = data['destinations'][test_mask]
    times = data['timestamps'][test_mask]
    n_events = len(srcs)

    perf_list = []
    for batch_idx in tqdm(range(math.ceil(n_events / BATCH_SIZE))):
        lo = batch_idx * BATCH_SIZE
        hi = min(lo + BATCH_SIZE, n_events)
        pos_src = srcs[lo:hi]
        pos_dst = dsts[lo:hi]
        pos_t = times[lo:hi]

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # One query per positive edge: the true destination first,
            # followed by its negatives, all paired with the same source.
            query_src = np.array([int(pos_src[idx])] * (len(neg_batch) + 1))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])

            scores = edgebank.predict_link(query_src, query_dst)

            # compute MRR; index 0 is the positive edge's score.
            perf_list.append(
                evaluator.eval({
                    "y_pred_pos": np.array([scores[0]]),
                    "y_pred_neg": np.array(scores[1:]),
                    "eval_metric": [metric],
                })[metric]
            )

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    return float(np.mean(perf_list))
def get_args():
    """
    Parse command-line arguments for the EdgeBank baseline.

    On a parsing error (argparse raises SystemExit), print the full help
    text and exit with status 0.

    Returns:
        args (argparse.Namespace): the parsed arguments with defaults applied
        sys.argv (list[str]): the raw command-line tokens
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-comment')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)

    try:
        args = parser.parse_args()
    except SystemExit:
        # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt
        # and any unrelated error. argparse signals a parse failure by
        # raising SystemExit — catch exactly that.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
# Start: EdgeBank baseline on tgbl-comment — no training; evaluate
# one-vs-many MRR on validation and test splits.
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = "tgbl-comment"
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank: the training-split edges.
# FIX: mask indexing already yields fresh arrays — the former
# np.concatenate([...]) wrappers around a single array were no-ops.
hist_src = data['sources'][train_mask]
hist_dst = data['destinations'][train_mask]
hist_ts = data['timestamps'][train_mask]

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
# FIX: a single race-free mkdir; the previous exists-check + os.mkdir was
# redundant with this call and racy (TOCTOU) between check and creation.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# FIX: the validation results were previously stored in variables named
# `perf_metric_test` / `test_time` and then silently overwritten by the
# test phase; they now use unambiguous names.
dataset.load_val_ns()
start_val = timeit.default_timer()
perf_metric_val = test(data, val_mask, neg_sampler, split_mode='val')
val_time = timeit.default_timer() - start_val
print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_val: .4f}")
print(f"\tval: Elapsed Time (s): {val_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()
start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
test_time = timeit.default_timer() - start_test
print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-comment/tgn.py
================================================
"""
Dynamic Link Prediction with a TGN model with Early Stopping
Reference:
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
command for an example run:
python examples/linkproppred/tgbl-comment/tgn.py --data "tgbl-comment" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Run one TGN training epoch over `train_loader`.
    Operates on script-level globals: `model`, `optimizer`, `criterion`,
    `neighbor_loader`, `assoc`, `data`, `train_data`, `device`,
    `min_dst_idx`, `max_dst_idx`.

    Parameters:
        None

    Returns:
        float: mean loss per training event for the epoch
    """
    for part in ('memory', 'gnn', 'link_pred'):
        model[part].train()

    model['memory'].reset_state()   # fresh memory at the start of the epoch
    neighbor_loader.reset_state()   # empty temporal graph

    loss_accum = 0.0
    for events in train_loader:
        events = events.to(device)
        optimizer.zero_grad()

        src, dst_pos, ts, feat = events.src, events.dst, events.t, events.msg

        # One uniformly-sampled negative destination per positive edge,
        # drawn from the observed destination-id range.
        dst_neg = torch.randint(
            min_dst_idx, max_dst_idx + 1, (src.size(0),),
            dtype=torch.long, device=device,
        )

        nodes = torch.cat([src, dst_pos, dst_neg]).unique()
        nodes, edge_index, e_id = neighbor_loader(nodes)
        # Map global node ids to local rows of `h`.
        assoc[nodes] = torch.arange(nodes.size(0), device=device)

        # Memory state followed by the temporal graph-attention embedding.
        h, last_update = model['memory'](nodes)
        h = model['gnn'](
            h, last_update, edge_index,
            data.t[e_id].to(device), data.msg[e_id].to(device),
        )

        out_pos = model['link_pred'](h[assoc[src]], h[assoc[dst_pos]])
        out_neg = model['link_pred'](h[assoc[src]], h[assoc[dst_neg]])

        step_loss = criterion(out_pos, torch.ones_like(out_pos))
        step_loss = step_loss + criterion(out_neg, torch.zeros_like(out_neg))

        # Commit ground-truth events to memory and neighbor loader BEFORE
        # the backward pass, matching the reference TGN implementation.
        model['memory'].update_state(src, dst_pos, ts, feat)
        neighbor_loader.insert(src, dst_pos)

        step_loss.backward()
        optimizer.step()
        model['memory'].detach()    # cut gradients to past batches
        loss_accum += float(step_loss.detach()) * events.num_events

    return loss_accum / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction on one split.
    Evaluation happens as 'one vs. many': each positive edge is ranked
    against the pre-sampled negative edges returned by `neg_sampler`.
    Also uses script-level globals: `model`, `neighbor_loader`, `assoc`,
    `data`, `evaluator`, `metric`, `device`.

    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives

    Returns:
        perf_metric: the result of the performance evaluation (mean metric over all positive edges)
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # Query set: positive destination first, then its negatives,
            # all paired with the same source node.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            # Map global node ids to local row indices into `z`.
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )

            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])

            # compute MRR; row 0 holds the positive edge's score.
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # Update memory and neighbor loader with ground-truth state —
        # only AFTER scoring, so the model never sees the current batch.
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)

    perf_metrics = float(torch.tensor(perf_list).mean())

    return perf_metrics
# ==========
# ==========
# ==========
# Start: TGN on tgbl-comment — train with early stopping, then evaluate
# one-vs-many MRR on validation and test splits, for NUM_RUNS seeds.
start_overall = timeit.default_timer()

DATA = "tgbl-comment"

# ========== set parameters...
args, _ = get_args()
args.data = DATA
print("INFO: Arguments:", args)

LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10
MODEL_NAME = 'TGN'
# ==========

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
# FIX: a single race-free mkdir; the previous exists-check + os.mkdir was
# redundant with this call and racy (TOCTOU) between check and creation.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'

for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # neighborhood sampler (rebuilt per run so runs stay independent)
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

    # define the model end-to-end
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)

    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}

    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )

        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")

    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-enron/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction with EdgeBank.
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges
    Parameters:
        data: a dataset object (dict of numpy arrays: 'sources', 'destinations', 'timestamps')
        test_mask: required masks to load the test set edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: the result of the performance evaluation
    """
    # Hoist the boolean-mask indexing out of the loop: it is loop-invariant,
    # and re-evaluating `data[...][test_mask]` costs a full-array scan + copy
    # on every batch.
    sources = data['sources'][test_mask]
    destinations = data['destinations'][test_mask]
    timestamps = data['timestamps'][test_mask]

    num_batches = math.ceil(len(sources) / BATCH_SIZE)
    perf_list = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, len(sources))
        pos_src = sources[start_idx: end_idx]
        pos_dst = destinations[start_idx: end_idx]
        pos_t = timestamps[start_idx: end_idx]

        # one array of negative destinations per positive edge in the batch
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # query the positive destination first, then all its negatives,
            # all paired with the same source node
            query_src = np.full(len(neg_batch) + 1, int(pos_src[idx]))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])

            y_pred = edgebank.predict_link(query_src, query_dst)

            # compute MRR
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    perf_metrics = float(np.mean(perf_list))
    return perf_metrics
def get_args():
    r"""
    Parse command-line arguments for the EdgeBank run.
    Returns:
        args: the parsed argument namespace
        sys.argv: the raw command-line arguments
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse signals a parsing error via SystemExit; print usage and stop.
        # (A bare `except:` here would also swallow KeyboardInterrupt and
        # genuine programming errors.)
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
start_overall = timeit.default_timer()

# ---------------- set hyperparameters from the command line ----------------
args, _ = get_args()

SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value  # NOTE(review): parsed but never used in this script
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = "tgbl-enron"  # dataset is hard-coded here; the `--data` flag is ignored
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks for the chronological train / val / test split
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank: EdgeBank is non-parametric, so "training"
# amounts to filling its memory with the training-split edges
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# NOTE(review): the validation score is stored in `perf_metric_test` and is
# overwritten by the test score below; only the printed value records it.
start_val = timeit.default_timer()
perf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()

print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {test_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

# testing ...
start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()

print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

# persist the final test score; EdgeBank has no training phase, hence 'NA'
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-flight/dyrep.py
================================================
"""
DyRep
This has been implemented with intuitions from the following sources:
- https://github.com/twitter-research/tgn
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
Spec.:
- Memory Updater: RNN
- Embedding Module: ID
- Message Function: ATTN
command for an example run:
python examples/linkproppred/tgbl-flight/dyrep.py --data "tgbl-flight" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from torch_geometric.loader import TemporalDataLoader
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import DyRepMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Run one training epoch of the DyRep model.
    This function uses objects that are globally defined in the current script
    (model, optimizer, criterion, train_loader, neighbor_loader, data, device,
    assoc, min_dst_idx, max_dst_idx, train_data).
    Parameters:
        None
    Returns:
        the average training loss per event over the epoch
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Sample one negative destination node per positive edge, uniformly
        # from the observed destination-id range.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # map global node ids to contiguous local positions in z
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)

        # DyRep scores links directly on the memory states ("ID" embedding
        # module); the GNN output below is used only for the memory update.
        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # update the memory with ground-truth edges, feeding the GNN
        # embeddings of the current neighborhood into the updater
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(src, pos_dst, t, msg, z, assoc)

        # update neighbor loader
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        model['memory'].detach()  # stop gradients from flowing across batches
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate the dynamic link prediction task.
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges
    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: the result of the performance evaluation
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        # one list of negative destinations per positive edge in the batch
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # build queries: positive destination first, then its negatives,
            # all paired with the same source node
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation.
            # Note: DyRep scores links on the memory states directly, without
            # a GNN pass (the GNN is only used for the memory update below).
            z, last_update = model['memory'](n_id)

            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])

            # compute MRR: first row is the positive edge, the rest are negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update the memory with positive edges only (ground truth), after
        # the whole batch has been scored
        n_id = torch.cat([pos_src, pos_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg, z, assoc)

        # update the neighbor loader
        neighbor_loader.insert(pos_src, pos_dst)

    perf_metric = float(torch.tensor(perf_list).mean())
    return perf_metric
# ==========
# ==========
# ==========
# Start...
start_overall = timeit.default_timer()

# ========== set parameters...
args, _ = get_args()
print("INFO: Arguments:", args)

DATA = "tgbl-flight"  # dataset is hard-coded; any `--data` flag is ignored
LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value  # NOTE(review): parsed but never used in this script
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10  # temporal neighbors sampled per node by the neighbor loader
MODEL_NAME = 'DyRep'
# DyRep message spec: use the destination embedding (not the source) in messages
USE_SRC_EMB_IN_MSG = False
USE_DST_EMB_IN_MSG = True

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

# chronological split into train / val / test
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

# neighborhood sampler
neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

# define the model end-to-end
# NOTE(review): unlike the TGN script, the model and optimizer are built once
# outside the run loop, so parameters carry over between runs.
# 1) memory
memory = DyRepMemory(
    data.num_nodes,
    data.msg.size(-1),
    MEM_DIM,
    TIME_DIM,
    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
    aggregator_module=LastAggregator(),
    memory_updater_type='rnn',
    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,
    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG
).to(device)

# 2) GNN
gnn = GraphAttentionEmbedding(
    in_channels=MEM_DIM,
    out_channels=EMB_DIM,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

# 3) link predictor
link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

model = {'memory': memory,
         'gnn': gnn,
         'link_pred': link_pred}

# define an optimizer over the union of all sub-module parameters
optimizer = torch.optim.Adam(
    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
    lr=LR,
)
criterion = torch.nn.BCEWithLogitsLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'

for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # define an early stopper (checkpoints the best model per run)
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    train_times_l, val_times_l = [], []
    free_mem_l, total_mem_l, used_mem_l = [], [], []
    start_train_val = timeit.default_timer()

    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        end_epoch_train = timeit.default_timer()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {end_epoch_train - start_epoch_train: .4f}"
        )
        # checking GPU memory usage (raw bytes from the CUDA driver)
        free_mem, used_mem, total_mem = 0, 0, 0
        if torch.cuda.is_available():
            print("DEBUG: device: {}".format(torch.cuda.get_device_name(0)))
            free_mem, total_mem = torch.cuda.mem_get_info()
            used_mem = total_mem - free_mem
            print("------------Epoch {}: GPU memory usage-----------".format(epoch))
            print("Free memory: {}".format(free_mem))
            print("Total available memory: {}".format(total_mem))
            print("Used memory: {}".format(used_mem))
            print("--------------------------------------------")
        train_times_l.append(end_epoch_train - start_epoch_train)
        free_mem_l.append(float((free_mem*1.0)/2**30))  # in GB
        used_mem_l.append(float((used_mem*1.0)/2**30))  # in GB
        total_mem_l.append(float((total_mem*1.0)/2**30))  # in GB

        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        end_val = timeit.default_timer()
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {end_val - start_val: .4f}")
        val_perf_list.append(perf_metric_val)
        val_times_l.append(end_val - start_val)

        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation Total Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    # persist per-run metrics, timings and GPU memory statistics
    save_results({'data': DATA,
                  'model': MODEL_NAME,
                  'run': run_idx,
                  'seed': SEED,
                  'train_times': train_times_l,
                  'free_mem': free_mem_l,
                  'total_mem': total_mem_l,
                  'used_mem': used_mem_l,
                  'max_used_mem': max(used_mem_l),
                  'val_times': val_times_l,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'train_val_total_time': np.sum(np.array(train_times_l)) + np.sum(np.array(val_times_l)),
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-flight/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction with EdgeBank.
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges
    Parameters:
        data: a dataset object (dict of numpy arrays: 'sources', 'destinations', 'timestamps')
        test_mask: required masks to load the test set edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: the result of the performance evaluation
    """
    # Hoist the boolean-mask indexing out of the loop: it is loop-invariant,
    # and re-evaluating `data[...][test_mask]` costs a full-array scan + copy
    # on every batch.
    sources = data['sources'][test_mask]
    destinations = data['destinations'][test_mask]
    timestamps = data['timestamps'][test_mask]

    num_batches = math.ceil(len(sources) / BATCH_SIZE)
    perf_list = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, len(sources))
        pos_src = sources[start_idx: end_idx]
        pos_dst = destinations[start_idx: end_idx]
        pos_t = timestamps[start_idx: end_idx]

        # one array of negative destinations per positive edge in the batch
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # query the positive destination first, then all its negatives,
            # all paired with the same source node
            query_src = np.full(len(neg_batch) + 1, int(pos_src[idx]))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])

            y_pred = edgebank.predict_link(query_src, query_dst)

            # compute MRR
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    perf_metrics = float(np.mean(perf_list))
    return perf_metrics
def get_args():
    r"""
    Parse command-line arguments for the EdgeBank run.
    Returns:
        args: the parsed argument namespace
        sys.argv: the raw command-line arguments
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-flight')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse signals a parsing error via SystemExit; print usage and stop.
        # (A bare `except:` here would also swallow KeyboardInterrupt and
        # genuine programming errors.)
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
start_overall = timeit.default_timer()

# ---------------- set hyperparameters from the command line ----------------
args, _ = get_args()

SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value  # NOTE(review): parsed but never used in this script
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = "tgbl-flight"  # dataset is hard-coded here; the `--data` flag is ignored
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks for the chronological train / val / test split
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank: EdgeBank is non-parametric, so "training"
# amounts to filling its memory with the training-split edges
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# NOTE(review): the validation score is stored in `perf_metric_test` and is
# overwritten by the test score below; only the printed value records it.
start_val = timeit.default_timer()
perf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()

print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {test_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

# testing ...
start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()

print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

# persist the final test score; EdgeBank has no training phase, hence 'NA'
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-flight/tgn.py
================================================
"""
Dynamic Link Prediction with a TGN model with Early Stopping
Reference:
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
command for an example run:
python examples/linkproppred/tgbl-flight/tgn.py --data "tgbl-flight" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Run one training epoch of the TGN model.
    Relies on objects defined globally in this script (model, optimizer,
    criterion, train_loader, neighbor_loader, data, device, assoc,
    min_dst_idx, max_dst_idx, train_data).
    Parameters:
        None
    Returns:
        the average training loss per event over the epoch
    """
    for component in ('memory', 'gnn', 'link_pred'):
        model[component].train()

    model['memory'].reset_state()  # fresh memory at the start of the epoch
    neighbor_loader.reset_state()  # empty temporal graph

    epoch_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        source, destination, timestamp, message = batch.src, batch.dst, batch.t, batch.msg

        # Draw one random negative destination per positive edge, uniformly
        # over the observed destination-id range.
        negative = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (source.size(0),),
            dtype=torch.long,
            device=device,
        )

        nodes = torch.cat([source, destination, negative]).unique()
        nodes, edge_index, e_id = neighbor_loader(nodes)
        # map global node ids to contiguous local positions in the embeddings
        assoc[nodes] = torch.arange(nodes.size(0), device=device)

        # Fetch the up-to-date memory for every involved node, then embed the
        # sampled temporal neighborhood with the GNN.
        memory_state, last_update = model['memory'](nodes)
        embeddings = model['gnn'](
            memory_state,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )

        positive_score = model['link_pred'](embeddings[assoc[source]], embeddings[assoc[destination]])
        negative_score = model['link_pred'](embeddings[assoc[source]], embeddings[assoc[negative]])

        batch_loss = criterion(positive_score, torch.ones_like(positive_score)) \
            + criterion(negative_score, torch.zeros_like(negative_score))

        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(source, destination, timestamp, message)
        neighbor_loader.insert(source, destination)

        batch_loss.backward()
        optimizer.step()
        model['memory'].detach()  # stop gradients from flowing across batches
        epoch_loss += float(batch_loss.detach()) * batch.num_events

    return epoch_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate the dynamic link prediction task.
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges
    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metrics: the result of the performance evaluation
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        # one list of negative destinations per positive edge in the batch
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # build queries: positive destination first, then its negatives,
            # all paired with the same source node
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation,
            # then embed the sampled temporal neighborhood with the GNN.
            z, last_update = model['memory'](n_id)
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )

            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])

            # compute MRR: first row is the positive edge, the rest are negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # Update memory and neighbor loader with ground-truth state
        # (positive edges only), after the whole batch has been scored.
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)

    perf_metrics = float(torch.tensor(perf_list).mean())
    return perf_metrics
# ==========
# ==========
# ==========
# Start...
start_overall = timeit.default_timer()

DATA = "tgbl-flight"

# ========== set parameters...
args, _ = get_args()
args.data = DATA  # dataset is fixed for this script; override any `--data` flag
print("INFO: Arguments:", args)

LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value  # NOTE(review): parsed but never used in this script
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10  # temporal neighbors sampled per node by the neighbor loader
MODEL_NAME = 'TGN'

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

# chronological split into train / val / test
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'

for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # neighborhood sampler (rebuilt per run so runs are independent)
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

    # define the model end-to-end (fresh parameters every run)
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)

    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}

    # optimizer over the union of all sub-module parameters
    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    # define an early stopper (checkpoints the best model per run)
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    train_times_l, val_times_l = [], []
    free_mem_l, total_mem_l, used_mem_l = [], [], []
    start_train_val = timeit.default_timer()

    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        end_epoch_train = timeit.default_timer()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {end_epoch_train - start_epoch_train: .4f}"
        )
        # checking GPU memory usage (raw bytes from the CUDA driver)
        free_mem, used_mem, total_mem = 0, 0, 0
        if torch.cuda.is_available():
            print("DEBUG: device: {}".format(torch.cuda.get_device_name(0)))
            free_mem, total_mem = torch.cuda.mem_get_info()
            used_mem = total_mem - free_mem
            print("------------Epoch {}: GPU memory usage-----------".format(epoch))
            print("Free memory: {}".format(free_mem))
            print("Total available memory: {}".format(total_mem))
            print("Used memory: {}".format(used_mem))
            print("--------------------------------------------")
        train_times_l.append(end_epoch_train - start_epoch_train)
        free_mem_l.append(float((free_mem*1.0)/2**30))  # in GB
        used_mem_l.append(float((used_mem*1.0)/2**30))  # in GB
        total_mem_l.append(float((total_mem*1.0)/2**30))  # in GB

        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        end_val = timeit.default_timer()
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {end_val - start_val: .4f}")
        val_perf_list.append(perf_metric_val)
        val_times_l.append(end_val - start_val)

        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    # persist per-run metrics, timings and GPU memory statistics
    save_results({'data': DATA,
                  'model': MODEL_NAME,
                  'run': run_idx,
                  'seed': SEED,
                  'train_times': train_times_l,
                  'free_mem': free_mem_l,
                  'total_mem': total_mem_l,
                  'used_mem': used_mem_l,
                  'max_used_mem': max(used_mem_l),
                  'val_times': val_times_l,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'train_val_total_time': np.sum(np.array(train_times_l)) + np.sum(np.array(val_times_l)),
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-lastfm/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction with EdgeBank.

    Evaluation happens as 'one vs. many': each positive edge is ranked against
    its pre-sampled negative edges, and the per-edge metric values are averaged.

    Uses module-level globals: BATCH_SIZE, edgebank, evaluator, metric.

    Parameters:
        data: a dataset object (dict of numpy arrays: 'sources', 'destinations', 'timestamps')
        test_mask: boolean mask selecting the evaluation-split edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metrics: the mean of the performance metric over all evaluated edges
    """
    # Hoist the boolean-mask selection out of the batch loop: applying the mask
    # copies O(N) elements of the full edge list, so doing it for every batch
    # (as before) made the whole evaluation quadratic in the number of batches.
    srcs = data['sources'][test_mask]
    dsts = data['destinations'][test_mask]
    tss = data['timestamps'][test_mask]
    num_edges = len(srcs)

    num_batches = math.ceil(num_edges / BATCH_SIZE)
    perf_list = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, num_edges)
        pos_src = srcs[start_idx:end_idx]
        pos_dst = dsts[start_idx:end_idx]
        pos_t = tss[start_idx:end_idx]

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # Query the positive destination first, then the negatives, all
            # against the same source node.
            query_src = np.full(len(neg_batch) + 1, int(pos_src[idx]))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])

            y_pred = edgebank.predict_link(query_src, query_dst)

            # compute MRR: position 0 is the positive edge's score
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    perf_metrics = float(np.mean(perf_list))
    return perf_metrics
def get_args():
    """Parse command-line arguments for the EdgeBank run.

    Returns:
        args: the parsed ``argparse.Namespace``
        sys.argv: the raw command-line argument list
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)

    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse reports a parsing error by raising SystemExit; the previous
        # bare `except:` also swallowed KeyboardInterrupt and genuine bugs.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = "tgbl-lastfm"  # NOTE(review): dataset is hard-coded; the --data CLI flag is ignored
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank: EdgeBank is seeded with the training edges only
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)  # NOTE(review): redundant with the Path.mkdir call below
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# validation ...
start_val = timeit.default_timer()
# NOTE(review): the validation score is stored in `perf_metric_test` and is
# overwritten by the test score below — only the test metric reaches save_results.
perf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()

print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {test_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

# testing ...
start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()

print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-lastfm/tgn.py
================================================
"""
Dynamic Link Prediction with a TGN model with Early Stopping
Reference:
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
command for an example run:
python examples/linkproppred/tgbl-lastfm/tgn.py --data "tgbl-lastfm" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from torch.nn import Linear
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Train the TGN model for one epoch over ``train_loader``.

    Uses module-level globals: model, train_loader, optimizer, criterion,
    neighbor_loader, assoc, data, device, min_dst_idx, max_dst_idx, train_data.

    Parameters:
        None
    Returns:
        float: average training loss per event for this epoch
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Sample negative destination nodes uniformly from the destination id range.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # assoc maps global node ids to positions in the local n_id batch.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )

        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        model['memory'].detach()  # stop gradients from flowing across batches
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction with the TGN model.

    Evaluation happens as 'one vs. many': each positive edge is ranked
    against its pre-sampled negative edges.

    Uses module-level globals: model, neighbor_loader, assoc, data, device,
    evaluator, metric.

    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: the mean of the performance metric over all evaluated edges
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    perf_list = []

    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # One query per positive edge: the positive destination first,
            # followed by all its negatives, against the same source node.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )

            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])

            # compute MRR: index 0 holds the positive edge's score
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)

    perf_metrics = float(torch.tensor(perf_list).mean())

    return perf_metrics
# ==========
# ==========
# ==========
# Start...
start_overall = timeit.default_timer()
DATA = "tgbl-lastfm"

# ========== set parameters...
args, _ = get_args()
args.data = DATA  # force the dataset name regardless of the --data CLI flag
print("INFO: Arguments:", args)

LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10  # temporal neighbors sampled per node by LastNeighborLoader
MODEL_NAME = 'TGN'

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)  # NOTE(review): redundant with the Path.mkdir call below
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'

for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # neighhorhood sampler (rebuilt per run so each run starts from an empty graph)
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

    # define the model end-to-end (fresh parameters for every run)
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)

    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}

    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )
        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")

    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-review/dyrep.py
================================================
"""
DyRep
This has been implemented with intuitions from the following sources:
- https://github.com/twitter-research/tgn
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
Spec.:
- Memory Updater: RNN
- Embedding Module: ID
- Message Function: ATTN
command for an example run:
python examples/linkproppred/tgbl-review/dyrep.py --data "tgbl-review" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from torch_geometric.loader import TemporalDataLoader
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import DyRepMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Train the DyRep model for one epoch over ``train_loader``.

    Uses module-level globals: model, train_loader, optimizer, criterion,
    neighbor_loader, assoc, data, device, min_dst_idx, max_dst_idx, train_data.

    Parameters:
        None
    Returns:
        float: average training loss per event for this epoch
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Sample negative destination nodes uniformly from the destination id range.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # assoc maps global node ids to positions in the local n_id batch.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)

        # DyRep scores links directly from the memory states (ID embedding);
        # the attention embedding below is used only for the memory update.
        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # update the memory with ground-truth
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(src, pos_dst, t, msg, z, assoc)

        # update neighbor loader
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        model['memory'].detach()  # stop gradients from flowing across batches
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction with the DyRep model.

    Evaluation happens as 'one vs. many': each positive edge is ranked
    against its pre-sampled negative edges.

    Uses module-level globals: model, neighbor_loader, assoc, data, device,
    evaluator, metric.

    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: the mean of the performance metric over all evaluated edges
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    perf_list = []

    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # One query per positive edge: positive destination first, then
            # all its negatives, against the same source node.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation.
            # DyRep scores links directly from memory states (no GNN here).
            z, last_update = model['memory'](n_id)

            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])

            # compute MRR: index 0 holds the positive edge's score
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update the memory with positive edges; the attention embedding z is
        # recomputed over the positive nodes and passed into the memory update
        n_id = torch.cat([pos_src, pos_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg, z, assoc)

        # update the neighbor loader
        neighbor_loader.insert(pos_src, pos_dst)

    perf_metric = float(torch.tensor(perf_list).mean())

    return perf_metric
# ==========
# ==========
# ==========
# Start...
start_overall = timeit.default_timer()

# ========== set parameters...
args, _ = get_args()
print("INFO: Arguments:", args)

DATA = "tgbl-review"  # NOTE(review): dataset is hard-coded; the --data CLI flag is ignored
LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10  # temporal neighbors sampled per node by LastNeighborLoader
MODEL_NAME = 'DyRep'
# DyRep message-function configuration: use the destination embedding (not the
# source) when building memory-update messages.
USE_SRC_EMB_IN_MSG = False
USE_DST_EMB_IN_MSG = True

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

# neighborhood sampler
neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

# define the model end-to-end
# NOTE(review): unlike the TGN example scripts, the model, optimizer and
# neighbor loader are built ONCE here, outside the run loop below — successive
# runs therefore start from the previous run's trained weights. Confirm this
# is intentional.
# 1) memory
memory = DyRepMemory(
    data.num_nodes,
    data.msg.size(-1),
    MEM_DIM,
    TIME_DIM,
    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
    aggregator_module=LastAggregator(),
    memory_updater_type='rnn',
    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,
    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG
).to(device)

# 2) GNN
gnn = GraphAttentionEmbedding(
    in_channels=MEM_DIM,
    out_channels=EMB_DIM,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

# 3) link predictor
link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

model = {'memory': memory,
         'gnn': gnn,
         'link_pred': link_pred}

# define an optimizer
optimizer = torch.optim.Adam(
    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
    lr=LR,
)
criterion = torch.nn.BCEWithLogitsLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)  # NOTE(review): redundant with the Path.mkdir call below
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'

for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )
        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation Total Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")

    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-review/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction with EdgeBank.

    Evaluation happens as 'one vs. many': each positive edge is ranked against
    its pre-sampled negative edges, and the per-edge metric values are averaged.

    Uses module-level globals: BATCH_SIZE, edgebank, evaluator, metric.

    Parameters:
        data: a dataset object (dict of numpy arrays: 'sources', 'destinations', 'timestamps')
        test_mask: boolean mask selecting the evaluation-split edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metrics: the mean of the performance metric over all evaluated edges
    """
    # Hoist the boolean-mask selection out of the batch loop: applying the mask
    # copies O(N) elements of the full edge list, so doing it for every batch
    # (as before) made the whole evaluation quadratic in the number of batches.
    srcs = data['sources'][test_mask]
    dsts = data['destinations'][test_mask]
    tss = data['timestamps'][test_mask]
    num_edges = len(srcs)

    num_batches = math.ceil(num_edges / BATCH_SIZE)
    perf_list = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, num_edges)
        pos_src = srcs[start_idx:end_idx]
        pos_dst = dsts[start_idx:end_idx]
        pos_t = tss[start_idx:end_idx]

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # Query the positive destination first, then the negatives, all
            # against the same source node.
            query_src = np.full(len(neg_batch) + 1, int(pos_src[idx]))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])

            y_pred = edgebank.predict_link(query_src, query_dst)

            # compute MRR: position 0 is the positive edge's score
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    perf_metrics = float(np.mean(perf_list))
    return perf_metrics
def get_args():
    """Parse command-line arguments for the EdgeBank run.

    Returns:
        args: the parsed ``argparse.Namespace``
        sys.argv: the raw command-line argument list
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-review')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)

    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse reports a parsing error by raising SystemExit; the previous
        # bare `except:` also swallowed KeyboardInterrupt and genuine bugs.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = "tgbl-review"  # NOTE(review): dataset is hard-coded; the --data CLI flag is ignored
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank: EdgeBank is seeded with the training edges only
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)  # NOTE(review): redundant with the Path.mkdir call below
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# validation ...
start_val = timeit.default_timer()
# NOTE(review): the validation score is stored in `perf_metric_test` and is
# overwritten by the test score below — only the test metric reaches save_results.
perf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()

print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {test_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

# testing ...
start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()

print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-review/tgn.py
================================================
"""
Dynamic Link Prediction with a TGN model with Early Stopping
Reference:
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
command for an example run:
python examples/linkproppred/tgbl-review/tgn.py --data "tgbl-review" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Run one training epoch of the TGN model over `train_loader`.

    Uses module-level globals: model, optimizer, criterion, train_loader,
    neighbor_loader, assoc, data, device, min_dst_idx, max_dst_idx, train_data.

    Returns:
        float: total BCE loss averaged over the number of training events.
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Sample one negative destination per positive edge, uniformly from
        # the global destination-id range [min_dst_idx, max_dst_idx].
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # Map global node ids to local (contiguous) indices for this batch.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )

        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # Update memory and neighbor loader with ground-truth state BEFORE
        # the backward pass, so later batches see the true interactions.
        model['memory'].update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        # Detach memory so gradients do not flow across batch boundaries.
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction in a 'one vs. many' setting: each
    positive edge is ranked against its sampled negative destinations.

    Uses module-level globals: model, neighbor_loader, assoc, data, device,
    evaluator, metric.

    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives

    Returns:
        perf_metric: the mean of the per-edge performance metric (float)
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # One query per positive edge: the source repeated against the
            # positive destination (index 0) followed by all negatives.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            # Map global node ids to local (contiguous) indices for this query.
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )

            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])

            # compute the ranking metric: row 0 is the positive edge score,
            # the remaining rows are the negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # Update memory and neighbor loader with ground-truth state only after
        # all queries of this batch are scored (no leakage within the batch).
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)

    perf_metrics = float(torch.tensor(perf_list).mean())

    return perf_metrics
# ==========
# ==========
# ==========
# Start...
# Start...
start_overall = timeit.default_timer()

# dataset used by this example script (overrides any --data argument)
DATA = "tgbl-review"

# ========== set parameters...
args, _ = get_args()
args.data = DATA
print("INFO: Arguments:", args)

LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10
MODEL_NAME = 'TGN'

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
# NOTE(review): the previous exists()/os.mkdir() pair was racy (TOCTOU) and
# redundant with Path.mkdir; one idempotent call is sufficient.
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
# Run the full train/val/test pipeline NUM_RUNS times with per-run seeds.
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # neighborhood sampler (keeps only the most recent NUM_NEIGHBORS per node)
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

    # define the model end-to-end: memory module -> attention embedding -> decoder
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)

    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}

    # one optimizer over the union of all three sub-modules' parameters
    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )

        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping (the monitor presumably checkpoints the
        # best model — see load_checkpoint below; TODO confirm)
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")

    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    # persist the per-epoch validation curve and the final test metric
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-subreddit/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate EdgeBank dynamic link prediction in a 'one vs. many' setting:
    each positive edge is ranked against its sampled negative edges.

    Uses module-level globals: BATCH_SIZE, edgebank, evaluator, metric.

    Parameters:
        data: a dataset dict with 'sources', 'destinations', 'timestamps' arrays
        test_mask: boolean mask selecting the evaluation split
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives

    Returns:
        perf_metric: the mean of the per-edge performance metric (float)
    """
    # Apply the boolean mask ONCE. Masking copies O(len(data)) elements and
    # the original re-applied it on every batch (and twice per batch).
    src_all = data['sources'][test_mask]
    dst_all = data['destinations'][test_mask]
    ts_all = data['timestamps'][test_mask]
    num_events = len(src_all)

    num_batches = math.ceil(num_events / BATCH_SIZE)
    perf_list = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, num_events)
        pos_src = src_all[start_idx: end_idx]
        pos_dst = dst_all[start_idx: end_idx]
        pos_t = ts_all[start_idx: end_idx]

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # index 0 is the positive destination, the rest are negatives
            query_src = np.full(len(neg_batch) + 1, int(pos_src[idx]))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])

            y_pred = edgebank.predict_link(query_src, query_dst)

            # compute MRR
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    return float(np.mean(perf_list))
def get_args():
    """Parse command-line arguments for the EdgeBank baseline.

    Returns:
        args: the parsed ``argparse.Namespace``
        sys.argv: the raw command-line argument list

    Exits with status 0 (after printing help) when argument parsing fails.
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse raises SystemExit on invalid arguments. The bare `except:`
        # this replaces also swallowed KeyboardInterrupt and genuine bugs.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = "tgbl-subreddit"
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for EdgeBank's memory: the training-split edges.
# NOTE(review): np.concatenate over a single array was a pointless extra copy;
# boolean masking already yields fresh arrays.
hist_src = data['sources'][train_mask]
hist_dst = data['destinations'][train_mask]
hist_ts = data['timestamps'][train_mask]

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
# NOTE(review): the previous exists()/os.mkdir() pair was racy (TOCTOU) and
# redundant with Path.mkdir; one idempotent call is sufficient.
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'
# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# evaluate on the validation split (EdgeBank has no training phase)
start_val = timeit.default_timer()
perf_metric_val = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()
print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_val: .4f}")
# use the captured end time instead of re-reading the clock after printing
val_time = end_val - start_val
print(f"\tval: Elapsed Time (s): {val_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()
print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = end_test - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

# persist only the test metric; 'NA' because EdgeBank is not trained
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-subreddit/tgn.py
================================================
"""
Dynamic Link Prediction with a TGN model with Early Stopping
Reference:
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
command for an example run:
python examples/linkproppred/tgbl-subreddit/tgn.py --data "tgbl-subreddit" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Run one training epoch of the TGN model over `train_loader`.

    Uses module-level globals: model, optimizer, criterion, train_loader,
    neighbor_loader, assoc, data, device, min_dst_idx, max_dst_idx, train_data.

    Returns:
        float: total BCE loss averaged over the number of training events.
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Sample one negative destination per positive edge, uniformly from
        # the global destination-id range [min_dst_idx, max_dst_idx].
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # Map global node ids to local (contiguous) indices for this batch.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )

        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # Update memory and neighbor loader with ground-truth state BEFORE
        # the backward pass, so later batches see the true interactions.
        model['memory'].update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        # Detach memory so gradients do not flow across batch boundaries.
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampl
gitextract_3nm00l7d/
├── .devcontainer/
│ ├── .gitignore
│ ├── Dockerfile
│ └── devcontainer.json
├── .github/
│ └── workflows/
│ ├── mkdocs.yaml
│ └── pypi.yaml
├── .gitignore
├── LICENSE
├── README.md
├── docs/
│ ├── about.md
│ ├── api/
│ │ ├── tgb.linkproppred.md
│ │ ├── tgb.nodeproppred.md
│ │ └── tgb.utils.md
│ ├── index.md
│ └── tutorials/
│ ├── Edge_data_numpy.ipynb
│ ├── Edge_data_pyg.ipynb
│ └── Node_label_tutorial.ipynb
├── examples/
│ ├── linkproppred/
│ │ ├── tgbl-coin/
│ │ │ ├── dyrep.py
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-comment/
│ │ │ ├── dyrep.py
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-enron/
│ │ │ └── edgebank.py
│ │ ├── tgbl-flight/
│ │ │ ├── dyrep.py
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-lastfm/
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-review/
│ │ │ ├── dyrep.py
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-subreddit/
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-uci/
│ │ │ └── edgebank.py
│ │ ├── tgbl-wiki/
│ │ │ ├── dyrep.py
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── thgl-forum/
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── sthn.py
│ │ │ └── tgn.py
│ │ ├── thgl-github/
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── run_seeds.sh
│ │ │ ├── sthn.py
│ │ │ └── tgn.py
│ │ ├── thgl-myket/
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── sthn.py
│ │ │ └── tgn.py
│ │ ├── thgl-software/
│ │ │ ├── STHN_README.md
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── sthn.py
│ │ │ └── tgn.py
│ │ ├── tkgl-icews/
│ │ │ ├── cen.py
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── regcn.py
│ │ │ ├── timetraveler.py
│ │ │ ├── tkgl-icews_example.py
│ │ │ └── tlogic.py
│ │ ├── tkgl-polecat/
│ │ │ ├── cen.py
│ │ │ ├── edgebank.py
│ │ │ ├── example.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── regcn.py
│ │ │ ├── timetraveler.py
│ │ │ ├── tkgl-polecat_example.py
│ │ │ └── tlogic.py
│ │ ├── tkgl-smallpedia/
│ │ │ ├── cen.py
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── regcn.py
│ │ │ ├── timetraveler.py
│ │ │ ├── tkgl-smallpedia_example.py
│ │ │ └── tlogic.py
│ │ ├── tkgl-wikidata/
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── regcn.py
│ │ │ ├── tkgl-wikidata_example.py
│ │ │ └── tlogic.py
│ │ └── tkgl-yago/
│ │ ├── cen.py
│ │ ├── edgebank.py
│ │ ├── recurrencybaseline.py
│ │ ├── regcn.py
│ │ ├── timetraveler.py
│ │ ├── tkgl-yago_example.py
│ │ └── tlogic.py
│ └── nodeproppred/
│ ├── tgbn-genre/
│ │ ├── dyrep.py
│ │ ├── moving_average.py
│ │ ├── persistant_forecast.py
│ │ └── tgn.py
│ ├── tgbn-reddit/
│ │ ├── dyrep.py
│ │ ├── moving_average.py
│ │ ├── persistant_forecast.py
│ │ └── tgn.py
│ ├── tgbn-token/
│ │ ├── dyrep.py
│ │ ├── moving_average.py
│ │ ├── persistant_forecast.py
│ │ └── tgn.py
│ └── tgbn-trade/
│ ├── count_new_nodes.py
│ ├── dyrep.py
│ ├── moving_average.py
│ ├── persistant_forecast.py
│ └── tgn.py
├── mkdocs.yml
├── modules/
│ ├── decoder.py
│ ├── early_stopping.py
│ ├── edgebank_predictor.py
│ ├── emb_module.py
│ ├── heuristics.py
│ ├── memory_module.py
│ ├── msg_agg.py
│ ├── msg_func.py
│ ├── neighbor_loader.py
│ ├── nodebank.py
│ ├── recurrencybaseline_predictor.py
│ ├── rgcn_layers.py
│ ├── rgcn_model.py
│ ├── rrgcn.py
│ ├── sampler_core.cpp
│ ├── sthn.py
│ ├── sthn_sampler_setup.py
│ ├── time_enc.py
│ ├── timetraveler_agent.py
│ ├── timetraveler_dirichlet.py
│ ├── timetraveler_environment.py
│ ├── timetraveler_episode.py
│ ├── timetraveler_policygradient.py
│ ├── timetraveler_trainertester.py
│ ├── tkg_utils.py
│ ├── tkg_utils_dgl.py
│ ├── tlogic_apply_modules.py
│ └── tlogic_learn_modules.py
├── pyproject.toml
├── run.sh
├── scripts/
│ ├── env.sh
│ ├── mila.sh
│ ├── mila_install.sh
│ └── run.sh
├── setup.py
└── tgb/
├── datasets/
│ ├── ICEWS14/
│ │ ├── ent2word.py
│ │ └── icews14.py
│ ├── dataset_scripts/
│ │ ├── MAG/
│ │ │ ├── mag.py
│ │ │ └── old/
│ │ │ └── plot_stats.py
│ │ ├── dgraph.py
│ │ ├── dgraph_Readme.md
│ │ ├── process_arxiv.py
│ │ ├── process_github.py
│ │ ├── tgbl-coin.py
│ │ ├── tgbl-coin_neg_generator.py
│ │ ├── tgbl-comment.py
│ │ ├── tgbl-comment_neg_generator.py
│ │ ├── tgbl-flight.py
│ │ ├── tgbl-flight_neg_generator.py
│ │ ├── tgbl-review.py
│ │ ├── tgbl-review_neg_generator.py
│ │ ├── tgbl-wiki_neg_generator.py
│ │ ├── tgbn-genre.py
│ │ ├── tgbn-reddit.py
│ │ ├── tgbn-token.py
│ │ └── tgbn-trade.py
│ ├── tgbl_enron/
│ │ ├── tgbl-enron_neg_generator.py
│ │ └── tgbl_enron.py
│ ├── tgbl_lastfm/
│ │ └── tgbl-lastfm_neg_generator.py
│ ├── tgbl_subreddit/
│ │ └── tgbl-subreddit_neg_generator.py
│ ├── tgbl_uci/
│ │ ├── tgbl-uci_neg_generator.py
│ │ └── tgbl_uci.py
│ ├── thgl_forum/
│ │ ├── merge_files.py
│ │ ├── thgl-forum.py
│ │ └── thgl_forum_ns_gen.py
│ ├── thgl_github/
│ │ ├── 2024_01/
│ │ │ └── github_extract.py
│ │ ├── 2024_02/
│ │ │ └── github_extract.py
│ │ ├── 2024_03/
│ │ │ └── github_extract.py
│ │ ├── extract_subset.py
│ │ ├── thgl_github.py
│ │ └── thgl_github_ns_gen.py
│ ├── thgl_myket/
│ │ ├── thgl_myket.py
│ │ └── thgl_myket_ns_gen.py
│ ├── thgl_software/
│ │ ├── thgl_software.py
│ │ └── thgl_software_ns_gen.py
│ ├── tkgl_icews/
│ │ ├── tkgl_icews.py
│ │ └── tkgl_icews_ns_gen.py
│ ├── tkgl_polecat/
│ │ ├── tkgl_polecat.py
│ │ └── tkgl_polecat_ns_gen.py
│ ├── tkgl_smallpedia/
│ │ ├── smallpedia_remove_conflict.py
│ │ └── tkgl_smallpedia_ns_gen.py
│ ├── tkgl_wikidata/
│ │ ├── extract.sh
│ │ ├── time_edges/
│ │ │ └── tkgl-wikidata_extract.py
│ │ ├── tkgl-wikidata.py
│ │ ├── tkgl_wikidata_mining.py
│ │ ├── tkgl_wikidata_ns_gen.py
│ │ └── wikidata_remove_conflict.py
│ └── tkgl_yago/
│ ├── tkgl_yago.py
│ └── tkgl_yago_ns_gen.py
├── linkproppred/
│ ├── dataset.py
│ ├── dataset_pyg.py
│ ├── evaluate.py
│ ├── negative_generator.py
│ ├── negative_sampler.py
│ ├── thg_negative_generator.py
│ ├── thg_negative_sampler.py
│ ├── tkg_negative_generator.py
│ └── tkg_negative_sampler.py
├── nodeproppred/
│ ├── dataset.py
│ ├── dataset_pyg.py
│ └── evaluate.py
└── utils/
├── dataset_stats.py
├── info.py
├── pre_process.py
├── stats.py
└── utils.py
SYMBOL INDEX (992 symbols across 175 files)
FILE: examples/linkproppred/tgbl-coin/dyrep.py
function train (line 42) | def train():
function test (line 111) | def test(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/tgbl-coin/edgebank.py
function test (line 34) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 79) | def get_args():
FILE: examples/linkproppred/tgbl-coin/tgn.py
function train (line 44) | def train():
function test (line 112) | def test(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/tgbl-comment/dyrep.py
function train (line 43) | def train():
function test_one_vs_many (line 102) | def test_one_vs_many(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/tgbl-comment/edgebank.py
function test (line 34) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 79) | def get_args():
FILE: examples/linkproppred/tgbl-comment/tgn.py
function train (line 44) | def train():
function test (line 112) | def test(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/tgbl-enron/edgebank.py
function test (line 32) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 77) | def get_args():
FILE: examples/linkproppred/tgbl-flight/dyrep.py
function train (line 42) | def train():
function test (line 111) | def test(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/tgbl-flight/edgebank.py
function test (line 34) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 79) | def get_args():
FILE: examples/linkproppred/tgbl-flight/tgn.py
function train (line 44) | def train():
function test (line 112) | def test(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/tgbl-lastfm/edgebank.py
function test (line 32) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 77) | def get_args():
FILE: examples/linkproppred/tgbl-lastfm/tgn.py
function train (line 41) | def train():
function test (line 109) | def test(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/tgbl-review/dyrep.py
function train (line 42) | def train():
function test (line 111) | def test(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/tgbl-review/edgebank.py
function test (line 32) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 77) | def get_args():
FILE: examples/linkproppred/tgbl-review/tgn.py
function train (line 44) | def train():
function test (line 112) | def test(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/tgbl-subreddit/edgebank.py
function test (line 32) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 77) | def get_args():
FILE: examples/linkproppred/tgbl-subreddit/tgn.py
function train (line 42) | def train():
function test (line 110) | def test(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/tgbl-uci/edgebank.py
function test (line 32) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 77) | def get_args():
FILE: examples/linkproppred/tgbl-wiki/dyrep.py
function train (line 42) | def train():
function test (line 111) | def test(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/tgbl-wiki/edgebank.py
function test (line 32) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 77) | def get_args():
FILE: examples/linkproppred/tgbl-wiki/tgn.py
function train (line 43) | def train():
function test (line 111) | def test(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/thgl-forum/edgebank.py
function test (line 36) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 86) | def get_args():
FILE: examples/linkproppred/thgl-forum/recurrencybaseline.py
function predict (line 34) | def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
function test (line 85) | def test(best_config, all_relations,test_data_prel, all_data_prel, neg_s...
function read_dict_compute_mrr (line 128) | def read_dict_compute_mrr(split_mode='test'):
function train (line 164) | def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampl...
function get_args (line 256) | def get_args():
FILE: examples/linkproppred/thgl-forum/sthn.py
function print_model_info (line 72) | def print_model_info(model):
function get_args (line 78) | def get_args():
function load_model (line 122) | def load_model(args):
function load_graph (line 160) | def load_graph(data):
function load_all_data (line 211) | def load_all_data(args):
function test (line 276) | def test(data, test_mask, model, neg_sampler, split_mode):
FILE: examples/linkproppred/thgl-forum/tgn.py
function train (line 42) | def train():
function test (line 110) | def test(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/thgl-github/edgebank.py
function test (line 36) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 86) | def get_args():
FILE: examples/linkproppred/thgl-github/recurrencybaseline.py
function predict (line 35) | def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
function test (line 86) | def test(best_config, all_relations,test_data_prel, all_data_prel, neg_s...
function read_dict_compute_mrr (line 129) | def read_dict_compute_mrr(split_mode='test'):
function train (line 165) | def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampl...
function get_args (line 257) | def get_args():
FILE: examples/linkproppred/thgl-github/sthn.py
function print_model_info (line 72) | def print_model_info(model):
function get_args (line 78) | def get_args():
function load_model (line 122) | def load_model(args):
function load_graph (line 160) | def load_graph(data):
function load_all_data (line 211) | def load_all_data(args):
function test (line 276) | def test(data, test_mask, model, neg_sampler, split_mode):
FILE: examples/linkproppred/thgl-github/tgn.py
function train (line 46) | def train():
function test (line 114) | def test(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/thgl-myket/edgebank.py
function test (line 36) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 86) | def get_args():
FILE: examples/linkproppred/thgl-myket/recurrencybaseline.py
function predict (line 34) | def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
function test (line 85) | def test(best_config, all_relations,test_data_prel, all_data_prel, neg_s...
function read_dict_compute_mrr (line 128) | def read_dict_compute_mrr(split_mode='test'):
function train (line 164) | def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampl...
function get_args (line 256) | def get_args():
FILE: examples/linkproppred/thgl-myket/sthn.py
function print_model_info (line 72) | def print_model_info(model):
function get_args (line 78) | def get_args():
function load_model (line 122) | def load_model(args):
function load_graph (line 160) | def load_graph(data):
function load_all_data (line 211) | def load_all_data(args):
function test (line 276) | def test(data, test_mask, model, neg_sampler, split_mode):
FILE: examples/linkproppred/thgl-myket/tgn.py
function train (line 42) | def train():
function test (line 110) | def test(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/thgl-software/edgebank.py
function test (line 36) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 86) | def get_args():
FILE: examples/linkproppred/thgl-software/recurrencybaseline.py
function predict (line 34) | def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
function test (line 85) | def test(best_config, all_relations,test_data_prel, all_data_prel, neg_s...
function read_dict_compute_mrr (line 128) | def read_dict_compute_mrr(split_mode='test'):
function train (line 164) | def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampl...
function get_args (line 256) | def get_args():
FILE: examples/linkproppred/thgl-software/sthn.py
function print_model_info (line 72) | def print_model_info(model):
function get_args (line 78) | def get_args():
function load_model (line 122) | def load_model(args):
function load_graph (line 160) | def load_graph(data):
function load_all_data (line 211) | def load_all_data(args):
function test (line 276) | def test(data, test_mask, model, neg_sampler, split_mode):
FILE: examples/linkproppred/thgl-software/tgn.py
function train (line 42) | def train():
function test (line 110) | def test(loader, neg_sampler, split_mode):
FILE: examples/linkproppred/tkgl-icews/cen.py
function test (line 28) | def test(model, history_len, history_list, test_list, num_rels, num_node...
function run_experiment (line 105) | def run_experiment(args, trainvalidtest_id=0, n_hidden=None, n_layers=No...
FILE: examples/linkproppred/tkgl-icews/edgebank.py
function test (line 36) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 86) | def get_args():
FILE: examples/linkproppred/tkgl-icews/recurrencybaseline.py
function predict (line 34) | def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
function test (line 85) | def test(best_config, all_relations,test_data_prel, all_data_prel, neg_s...
function read_dict_compute_mrr (line 128) | def read_dict_compute_mrr(split_mode='test'):
function train (line 164) | def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampl...
function get_args (line 256) | def get_args():
FILE: examples/linkproppred/tkgl-icews/regcn.py
function test (line 29) | def test(model, history_list, test_list, num_rels, num_nodes, use_cuda, ...
function run_experiment (line 106) | def run_experiment(args, n_hidden=None, n_layers=None, dropout=None, n_b...
FILE: examples/linkproppred/tkgl-icews/timetraveler.py
class QuadruplesDataset (line 35) | class QuadruplesDataset(Dataset):
method __init__ (line 38) | def __init__(self, examples):
method __len__ (line 46) | def __len__(self):
method __getitem__ (line 49) | def __getitem__(self, item):
function set_logger (line 56) | def set_logger(save_path):
function preprocess_data (line 76) | def preprocess_data(args, config, timestamps, save_path, all_quads):
function log_metrics (line 97) | def log_metrics(mode, step, metrics):
function main (line 102) | def main(args):
FILE: examples/linkproppred/tkgl-icews/tlogic.py
function learn_rules (line 28) | def learn_rules(i, num_relations):
function apply_rules (line 75) | def apply_rules(i, num_queries, rules_dict, neg_sampler, data, window, l...
function get_args (line 230) | def get_args():
FILE: examples/linkproppred/tkgl-polecat/cen.py
function test (line 28) | def test(model, history_len, history_list, test_list, num_rels, num_node...
function run_experiment (line 105) | def run_experiment(args, trainvalidtest_id=0, n_hidden=None, n_layers=No...
FILE: examples/linkproppred/tkgl-polecat/edgebank.py
function test (line 36) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 86) | def get_args():
FILE: examples/linkproppred/tkgl-polecat/recurrencybaseline.py
function predict (line 34) | def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
function test (line 85) | def test(best_config, all_relations,test_data_prel, all_data_prel, neg_s...
function read_dict_compute_mrr (line 128) | def read_dict_compute_mrr(split_mode='test'):
function train (line 164) | def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampl...
function get_args (line 256) | def get_args():
FILE: examples/linkproppred/tkgl-polecat/regcn.py
function test (line 29) | def test(model, history_list, test_list, num_rels, num_nodes, use_cuda, ...
function run_experiment (line 106) | def run_experiment(args, n_hidden=None, n_layers=None, dropout=None, n_b...
FILE: examples/linkproppred/tkgl-polecat/timetraveler.py
class QuadruplesDataset (line 35) | class QuadruplesDataset(Dataset):
method __init__ (line 38) | def __init__(self, examples):
method __len__ (line 46) | def __len__(self):
method __getitem__ (line 49) | def __getitem__(self, item):
function set_logger (line 56) | def set_logger(save_path):
function preprocess_data (line 76) | def preprocess_data(args, config, timestamps, save_path, all_quads):
function log_metrics (line 97) | def log_metrics(mode, step, metrics):
function main (line 102) | def main(args):
FILE: examples/linkproppred/tkgl-polecat/tlogic.py
function learn_rules (line 28) | def learn_rules(i, num_relations):
function apply_rules (line 75) | def apply_rules(i, num_queries, rules_dict, neg_sampler, data, window, l...
function get_args (line 230) | def get_args():
FILE: examples/linkproppred/tkgl-smallpedia/cen.py
function test (line 28) | def test(model, history_len, history_list, test_list, num_rels, num_node...
function run_experiment (line 105) | def run_experiment(args, trainvalidtest_id=0, n_hidden=None, n_layers=No...
FILE: examples/linkproppred/tkgl-smallpedia/edgebank.py
function test (line 36) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 86) | def get_args():
FILE: examples/linkproppred/tkgl-smallpedia/recurrencybaseline.py
function predict (line 34) | def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
function test (line 85) | def test(best_config, all_relations,test_data_prel, all_data_prel, neg_s...
function read_dict_compute_mrr (line 128) | def read_dict_compute_mrr(split_mode='test'):
function train (line 164) | def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampl...
function get_args (line 256) | def get_args():
FILE: examples/linkproppred/tkgl-smallpedia/regcn.py
function test (line 29) | def test(model, history_list, test_list, num_rels, num_nodes, use_cuda, ...
function run_experiment (line 106) | def run_experiment(args, n_hidden=None, n_layers=None, dropout=None, n_b...
FILE: examples/linkproppred/tkgl-smallpedia/timetraveler.py
class QuadruplesDataset (line 35) | class QuadruplesDataset(Dataset):
method __init__ (line 38) | def __init__(self, examples):
method __len__ (line 46) | def __len__(self):
method __getitem__ (line 49) | def __getitem__(self, item):
function set_logger (line 56) | def set_logger(save_path):
function preprocess_data (line 76) | def preprocess_data(args, config, timestamps, save_path, all_quads):
function log_metrics (line 97) | def log_metrics(mode, step, metrics):
function main (line 102) | def main(args):
FILE: examples/linkproppred/tkgl-smallpedia/tlogic.py
function learn_rules (line 28) | def learn_rules(i, num_relations):
function apply_rules (line 75) | def apply_rules(i, num_queries, rules_dict, neg_sampler, data, window, l...
function get_args (line 230) | def get_args():
FILE: examples/linkproppred/tkgl-wikidata/edgebank.py
function test (line 36) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 86) | def get_args():
FILE: examples/linkproppred/tkgl-wikidata/recurrencybaseline.py
function predict (line 34) | def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
function test (line 85) | def test(best_config, all_relations,test_data_prel, all_data_prel, neg_s...
function read_dict_compute_mrr (line 128) | def read_dict_compute_mrr(split_mode='test'):
function train (line 164) | def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampl...
function get_args (line 256) | def get_args():
FILE: examples/linkproppred/tkgl-wikidata/regcn.py
function test (line 29) | def test(model, history_list, test_list, num_rels, num_nodes, use_cuda, ...
function run_experiment (line 106) | def run_experiment(args, n_hidden=None, n_layers=None, dropout=None, n_b...
FILE: examples/linkproppred/tkgl-wikidata/tlogic.py
function learn_rules (line 28) | def learn_rules(i, num_relations):
function apply_rules (line 75) | def apply_rules(i, num_queries, rules_dict, neg_sampler, data, window, l...
function get_args (line 230) | def get_args():
FILE: examples/linkproppred/tkgl-yago/cen.py
function test (line 28) | def test(model, history_len, history_list, test_list, num_rels, num_node...
function run_experiment (line 105) | def run_experiment(args, trainvalidtest_id=0, n_hidden=None, n_layers=No...
FILE: examples/linkproppred/tkgl-yago/edgebank.py
function test (line 36) | def test(data, test_mask, neg_sampler, split_mode):
function get_args (line 86) | def get_args():
FILE: examples/linkproppred/tkgl-yago/recurrencybaseline.py
function predict (line 34) | def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
function test (line 85) | def test(best_config, all_relations,test_data_prel, all_data_prel, neg_s...
function read_dict_compute_mrr (line 128) | def read_dict_compute_mrr(split_mode='test'):
function train (line 164) | def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampl...
function get_args (line 256) | def get_args():
FILE: examples/linkproppred/tkgl-yago/regcn.py
function test (line 29) | def test(model, history_list, test_list, num_rels, num_nodes, use_cuda, ...
function run_experiment (line 106) | def run_experiment(args, n_hidden=None, n_layers=None, dropout=None, n_b...
FILE: examples/linkproppred/tkgl-yago/timetraveler.py
class QuadruplesDataset (line 35) | class QuadruplesDataset(Dataset):
method __init__ (line 38) | def __init__(self, examples):
method __len__ (line 46) | def __len__(self):
method __getitem__ (line 49) | def __getitem__(self, item):
function set_logger (line 56) | def set_logger(save_path):
function preprocess_data (line 76) | def preprocess_data(args, config, timestamps, save_path, all_quads):
function log_metrics (line 97) | def log_metrics(mode, step, metrics):
function main (line 102) | def main(args):
FILE: examples/linkproppred/tkgl-yago/tlogic.py
function learn_rules (line 28) | def learn_rules(i, num_relations):
function apply_rules (line 75) | def apply_rules(i, num_queries, rules_dict, neg_sampler, data, window, l...
function get_args (line 230) | def get_args():
FILE: examples/nodeproppred/tgbn-genre/dyrep.py
function process_edges (line 29) | def process_edges(src, dst, t, msg):
function train (line 39) | def train():
function test (line 139) | def test(loader):
FILE: examples/nodeproppred/tgbn-genre/moving_average.py
function test_n_upate (line 41) | def test_n_upate(loader):
FILE: examples/nodeproppred/tgbn-genre/persistant_forecast.py
function test_n_upate (line 52) | def test_n_upate(loader):
FILE: examples/nodeproppred/tgbn-genre/tgn.py
function plot_curve (line 100) | def plot_curve(scores, out_name):
function process_edges (line 107) | def process_edges(src, dst, t, msg):
function train (line 114) | def train():
function test (line 215) | def test(loader):
FILE: examples/nodeproppred/tgbn-reddit/dyrep.py
function process_edges (line 31) | def process_edges(src, dst, t, msg):
function train (line 41) | def train():
function test (line 141) | def test(loader):
FILE: examples/nodeproppred/tgbn-reddit/moving_average.py
function test_n_upate (line 41) | def test_n_upate(loader):
FILE: examples/nodeproppred/tgbn-reddit/persistant_forecast.py
function test_n_upate (line 52) | def test_n_upate(loader):
FILE: examples/nodeproppred/tgbn-reddit/tgn.py
function plot_curve (line 101) | def plot_curve(scores, out_name):
function process_edges (line 108) | def process_edges(src, dst, t, msg):
function train (line 115) | def train():
function test (line 215) | def test(loader):
FILE: examples/nodeproppred/tgbn-token/dyrep.py
function process_edges (line 31) | def process_edges(src, dst, t, msg):
function train (line 41) | def train():
function test (line 141) | def test(loader):
FILE: examples/nodeproppred/tgbn-token/moving_average.py
function test_n_upate (line 42) | def test_n_upate(loader):
FILE: examples/nodeproppred/tgbn-token/persistant_forecast.py
function test_n_upate (line 52) | def test_n_upate(loader):
FILE: examples/nodeproppred/tgbn-token/tgn.py
function plot_curve (line 102) | def plot_curve(scores, out_name):
function process_edges (line 109) | def process_edges(src, dst, t, msg):
function train (line 116) | def train():
function test (line 215) | def test(loader):
FILE: examples/nodeproppred/tgbn-trade/count_new_nodes.py
function count_nodes (line 23) | def count_nodes(data, test_mask, nodebank):
function get_args (line 71) | def get_args():
FILE: examples/nodeproppred/tgbn-trade/dyrep.py
function process_edges (line 29) | def process_edges(src, dst, t, msg):
function train (line 39) | def train():
function test (line 139) | def test(loader):
FILE: examples/nodeproppred/tgbn-trade/moving_average.py
function test_n_upate (line 43) | def test_n_upate(loader):
FILE: examples/nodeproppred/tgbn-trade/persistant_forecast.py
function test_n_upate (line 42) | def test_n_upate(loader):
FILE: examples/nodeproppred/tgbn-trade/tgn.py
function plot_curve (line 100) | def plot_curve(scores, out_name):
function process_edges (line 107) | def process_edges(src, dst, t, msg):
function train (line 114) | def train():
function test (line 215) | def test(loader):
FILE: modules/decoder.py
class LinkPredictor (line 12) | class LinkPredictor(torch.nn.Module):
method __init__ (line 18) | def __init__(self, in_channels):
method forward (line 24) | def forward(self, z_src, z_dst):
class NodePredictor (line 30) | class NodePredictor(torch.nn.Module):
method __init__ (line 31) | def __init__(self, in_dim, out_dim):
method forward (line 36) | def forward(self, node_embed):
class ConvTransE (line 45) | class ConvTransE(torch.nn.Module):
method __init__ (line 49) | def __init__(self, num_entities, embedding_dim, input_dropout=0, hidde...
method forward (line 75) | def forward(self, embedding, emb_rel, triplets, partial_embeding=None,...
method forward_inner (line 95) | def forward_inner(self, embedding, emb_rel, triplets, idx=0, partial_e...
FILE: modules/early_stopping.py
class EarlyStopMonitor (line 10) | class EarlyStopMonitor(object):
method __init__ (line 12) | def __init__(self, save_model_dir: str, save_model_id: str,
method get_best_model_path (line 38) | def get_best_model_path(self):
method step_check (line 44) | def step_check(self, curr_metric: float, models_dict: dict):
method save_checkpoint (line 68) | def save_checkpoint(self, models_dict: dict):
method load_checkpoint (line 80) | def load_checkpoint(self, models_dict: dict):
FILE: modules/edgebank_predictor.py
class EdgeBankPredictor (line 13) | class EdgeBankPredictor(object):
method __init__ (line 14) | def __init__(
method update_memory (line 52) | def update_memory(self,
method start_time (line 72) | def start_time(self) -> int:
method end_time (line 83) | def end_time(self) -> int:
method _update_unlimited_memory (line 93) | def _update_unlimited_memory(self,
method _update_time_window_memory (line 106) | def _update_time_window_memory(self,
method predict_link (line 135) | def predict_link(self,
FILE: modules/emb_module.py
class GraphAttentionEmbedding (line 11) | class GraphAttentionEmbedding(torch.nn.Module):
method __init__ (line 17) | def __init__(self, in_channels, out_channels, msg_dim, time_enc):
method forward (line 25) | def forward(self, x, last_update, edge_index, t, msg):
class TimeEmbedding (line 32) | class TimeEmbedding(torch.nn.Module):
method __init__ (line 33) | def __init__(self, in_channels, out_channels):
method forward (line 48) | def forward(self, x, last_update, t):
FILE: modules/heuristics.py
class PersistantForecaster (line 4) | class PersistantForecaster:
method __init__ (line 5) | def __init__(self, num_class):
method update_dict (line 9) | def update_dict(self, node_id, label):
method query_dict (line 12) | def query_dict(self, node_id):
class MovingAverage (line 25) | class MovingAverage:
method __init__ (line 26) | def __init__(self, num_class, window=7):
method update_dict (line 31) | def update_dict(self, node_id, label):
method query_dict (line 38) | def query_dict(self, node_id):
FILE: modules/memory_module.py
class TGNMemory (line 25) | class TGNMemory(torch.nn.Module):
method __init__ (line 49) | def __init__(
method device (line 91) | def device(self) -> torch.device:
method reset_parameters (line 94) | def reset_parameters(self):
method reset_state (line 106) | def reset_state(self):
method detach (line 112) | def detach(self):
method forward (line 116) | def forward(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:
method update_state (line 126) | def update_state(self, src: Tensor, dst: Tensor, t: Tensor, raw_msg: T...
method _reset_message_store (line 140) | def _reset_message_store(self):
method _update_memory (line 147) | def _update_memory(self, n_id: Tensor):
method _get_updated_memory (line 152) | def _get_updated_memory(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:
method _update_msg_store (line 180) | def _update_msg_store(
method _compute_msg (line 193) | def _compute_msg(
method train (line 209) | def train(self, mode: bool = True):
class DyRepMemory (line 218) | class DyRepMemory(torch.nn.Module):
method __init__ (line 242) | def __init__(self, num_nodes: int, raw_msg_dim: int, memory_dim: int,
method device (line 281) | def device(self) -> torch.device:
method reset_parameters (line 284) | def reset_parameters(self):
method reset_state (line 296) | def reset_state(self):
method detach (line 302) | def detach(self):
method forward (line 306) | def forward(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:
method update_state (line 316) | def update_state(self, src: Tensor, dst: Tensor, t: Tensor, raw_msg: T...
method _reset_message_store (line 331) | def _reset_message_store(self):
method _update_memory (line 338) | def _update_memory(self, n_id: Tensor, embeddings: Tensor = None, asso...
method _get_updated_memory (line 343) | def _get_updated_memory(self, n_id: Tensor, embeddings: Tensor = None,...
method _update_msg_store (line 369) | def _update_msg_store(self, src: Tensor, dst: Tensor, t: Tensor,
method _compute_msg (line 376) | def _compute_msg(self, n_id: Tensor, msg_store: TGNMessageStoreType, m...
method train (line 414) | def train(self, mode: bool = True):
FILE: modules/msg_agg.py
class LastAggregator (line 15) | class LastAggregator(torch.nn.Module):
method forward (line 16) | def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):
class MeanAggregator (line 24) | class MeanAggregator(torch.nn.Module):
method forward (line 25) | def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):
FILE: modules/msg_func.py
class IdentityMessage (line 12) | class IdentityMessage(torch.nn.Module):
method __init__ (line 13) | def __init__(self, raw_msg_dim: int, memory_dim: int, time_dim: int):
method forward (line 17) | def forward(self, z_src: Tensor, z_dst: Tensor, raw_msg: Tensor, t_enc...
FILE: modules/neighbor_loader.py
class LastNeighborLoader (line 15) | class LastNeighborLoader:
method __init__ (line 16) | def __init__(self, num_nodes: int, size: int, device=None):
method __call__ (line 25) | def __call__(self, n_id: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
method insert (line 41) | def insert(self, src: Tensor, dst: Tensor):
method reset_state (line 83) | def reset_state(self):
FILE: modules/nodebank.py
class NodeBank (line 4) | class NodeBank(object):
method __init__ (line 5) | def __init__(
method update_memory (line 21) | def update_memory(self,
method query_node (line 37) | def query_node(self, node: int) -> bool:
FILE: modules/recurrencybaseline_predictor.py
function baseline_predict_remote (line 20) | def baseline_predict_remote(num_queries, test_data, all_data, window, ba...
function baseline_predict (line 29) | def baseline_predict(num_queries, test_data, all_data, window, basis_dic...
function match_body_relations (line 134) | def match_body_relations(rule, edges, test_query_sub):
function score_delta (line 165) | def score_delta(cands_ts, test_query_ts, lmbda):
function get_window_edges (line 177) | def get_window_edges(all_data, test_query_ts, window=-2, first_test_quer...
function quads_per_rel (line 219) | def quads_per_rel(quads):
function get_candidates_psi (line 237) | def get_candidates_psi(rule_walks, test_query_ts, cands_dict,lmbda, sum_...
function update_delta_t (line 262) | def update_delta_t(min_ts, max_ts, cur_ts, lmbda):
function score_psi (line 278) | def score_psi(cands_walks, test_query_ts, lmbda, sum_delta_t):
function update_distributions (line 304) | def update_distributions(ts_edges,num_rels):
function calculate_obj_distribution (line 311) | def calculate_obj_distribution(edges, num_rels):
function update_delta_t (line 334) | def update_delta_t(min_ts, max_ts, cur_ts, lmbda):
FILE: modules/rgcn_layers.py
class RGCNLayer (line 11) | class RGCNLayer(nn.Module):
method __init__ (line 12) | def __init__(self, in_feat, out_feat, bias=None, activation=None,
method propagate (line 52) | def propagate(self, g):
method forward (line 55) | def forward(self, g, prev_h=[]):
class RGCNBasisLayer (line 101) | class RGCNBasisLayer(RGCNLayer):
method __init__ (line 102) | def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None,
method propagate (line 125) | def propagate(self, g):
class RGCNBlockLayer (line 154) | class RGCNBlockLayer(RGCNLayer):
method __init__ (line 155) | def __init__(self, in_feat, out_feat, num_rels, num_bases, bias=None,
method msg_func (line 174) | def msg_func(self, edges):
method propagate (line 181) | def propagate(self, g):
method apply_func (line 185) | def apply_func(self, nodes):
class UnionRGCNLayer (line 189) | class UnionRGCNLayer(nn.Module):
method __init__ (line 190) | def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None,
method propagate (line 226) | def propagate(self, g):
method forward (line 229) | def forward(self, g, prev_h):
method msg_func (line 263) | def msg_func(self, edges):
method apply_func (line 284) | def apply_func(self, nodes):
FILE: modules/rgcn_model.py
class BaseRGCN (line 8) | class BaseRGCN(nn.Module):
method __init__ (line 9) | def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases=-1, ...
method build_model (line 35) | def build_model(self):
method create_features (line 52) | def create_features(self):
method build_input_layer (line 55) | def build_input_layer(self):
method build_hidden_layer (line 58) | def build_hidden_layer(self, idx):
method build_output_layer (line 61) | def build_output_layer(self):
method forward (line 64) | def forward(self, g):
FILE: modules/rrgcn.py
class RGCNCell (line 13) | class RGCNCell(BaseRGCN):
method build_hidden_layer (line 14) | def build_hidden_layer(self, idx):
method forward (line 30) | def forward(self, g, init_ent_emb):
class RecurrentRGCNCEN (line 52) | class RecurrentRGCNCEN(nn.Module):
method __init__ (line 53) | def __init__(self, decoder_name, encoder_name, num_ents, num_rels, h_d...
method forward (line 109) | def forward(self, g_list, use_cuda):
method predict (line 122) | def predict(self, test_graph, test_triplets, use_cuda, neg_samples_bat...
method get_ft_loss (line 165) | def get_ft_loss(self, glist, triple_list, use_cuda):
method get_loss (line 186) | def get_loss(self, glist, triples, prev_model, use_cuda):
class RecurrentRGCNREGCN (line 208) | class RecurrentRGCNREGCN(nn.Module):
method __init__ (line 209) | def __init__(self, decoder_name, encoder_name, num_ents, num_rels, num...
method forward (line 293) | def forward(self, g_list, static_graph, use_cuda):
method predict (line 336) | def predict(self, test_graph, num_rels, static_graph, test_triplets, u...
method get_mask_nonzero (line 370) | def get_mask_nonzero(self, static_embedding):
method get_loss (line 377) | def get_loss(self, glist, triples, static_graph, use_cuda):
FILE: modules/sampler_core.cpp
class TemporalGraphBlock (line 17) | class TemporalGraphBlock
method TemporalGraphBlock (line 33) | TemporalGraphBlock() {}
method TemporalGraphBlock (line 34) | TemporalGraphBlock(std::vector<NodeIDType> &_row, std::vector<NodeIDTy...
class ParallelSampler (line 41) | class ParallelSampler
method ParallelSampler (line 63) | ParallelSampler(std::vector<EdgeIDType> &_indptr, std::vector<EdgeIDTy...
method reset (line 89) | void reset()
method update_ts_ptr (line 100) | void update_ts_ptr(int slc, std::vector<NodeIDType> &root_nodes,
method add_neighbor (line 126) | inline void add_neighbor(std::vector<NodeIDType> *_row, std::vector<No...
method combine_coo (line 150) | inline void combine_coo(TemporalGraphBlock &_ret, std::vector<NodeIDTy...
method sample_layer (line 200) | void sample_layer(std::vector<NodeIDType> &_root_nodes, std::vector<Ti...
method sample (line 336) | void sample(std::vector<NodeIDType> &root_nodes, std::vector<TimeStamp...
function vec2npy (line 362) | inline py::array vec2npy(const std::vector<T> &vec)
function PYBIND11_MODULE (line 373) | PYBIND11_MODULE(sampler_core, m)
FILE: modules/sthn.py
function set_seed (line 31) | def set_seed(seed):
function row_norm (line 37) | def row_norm(adj_t):
class NegLinkSampler (line 61) | class NegLinkSampler:
method __init__ (line 65) | def __init__(self, num_nodes):
method sample (line 68) | def sample(self, n):
function get_parallel_sampler (line 71) | def get_parallel_sampler(g, num_neighbors=10):
function get_mini_batch (line 101) | def get_mini_batch(sampler, root_nodes, ts, num_hops): # neg_samples is ...
function fetch_subgraph (line 113) | def fetch_subgraph(sampler, root_node, root_time, num_hops):
function construct_mini_batch_giant_graph (line 181) | def construct_mini_batch_giant_graph(all_graphs, max_num_edges):
function print_subgraph_data (line 255) | def print_subgraph_data(subgraph_data):
class SubgraphSampler (line 278) | class SubgraphSampler:
method __init__ (line 279) | def __init__(self, all_root_nodes, all_ts, sampler, args):
method mini_batch (line 285) | def mini_batch(self, ind, mini_batch_inds):
function get_subgraph_sampler (line 290) | def get_subgraph_sampler(args, g, df, mode):
function pre_compute_subgraphs (line 342) | def pre_compute_subgraphs(args, g, df, mode, negative_sampler=None, spli...
function get_random_inds (line 429) | def get_random_inds(num_subgraph, cached_neg_samples, neg_samples):
function get_all_inds (line 442) | def get_all_inds(num_subgraph, neg_samples):
function check_data_leakage (line 454) | def check_data_leakage(args, g, df):
function get_inputs_for_ind (line 492) | def get_inputs_for_ind(subgraphs, mode, cached_neg_samples, neg_samples,...
function run (line 555) | def run(model, optimizer, args, subgraphs, df, node_feats, edge_feats, M...
function link_pred_train (line 616) | def link_pred_train(model, args, g, df, node_feats, edge_feats):
function compute_sign_feats (line 696) | def compute_sign_feats(node_feats, df, start_i, num_links, root_nodes, a...
function get_emb (line 750) | def get_emb(sin_inp):
class PositionalEncoding1D (line 758) | class PositionalEncoding1D(nn.Module):
method __init__ (line 759) | def __init__(self, channels):
method forward (line 771) | def forward(self, tensor):
class PositionalEncodingPermute1D (line 794) | class PositionalEncodingPermute1D(nn.Module):
method __init__ (line 795) | def __init__(self, channels):
method forward (line 802) | def forward(self, tensor):
method org_channels (line 808) | def org_channels(self):
class PositionalEncoding2D (line 812) | class PositionalEncoding2D(nn.Module):
method __init__ (line 813) | def __init__(self, channels):
method forward (line 825) | def forward(self, tensor):
class PositionalEncodingPermute2D (line 854) | class PositionalEncodingPermute2D(nn.Module):
method __init__ (line 855) | def __init__(self, channels):
method forward (line 862) | def forward(self, tensor):
method org_channels (line 868) | def org_channels(self):
class PositionalEncoding3D (line 872) | class PositionalEncoding3D(nn.Module):
method __init__ (line 873) | def __init__(self, channels):
method forward (line 887) | def forward(self, tensor):
class PositionalEncodingPermute3D (line 920) | class PositionalEncodingPermute3D(nn.Module):
method __init__ (line 921) | def __init__(self, channels):
method forward (line 928) | def forward(self, tensor):
method org_channels (line 934) | def org_channels(self):
class Summer (line 938) | class Summer(nn.Module):
method __init__ (line 939) | def __init__(self, penc):
method forward (line 946) | def forward(self, tensor):
class TimeEncode (line 970) | class TimeEncode(nn.Module):
method __init__ (line 975) | def __init__(self, dim):
method reset_parameters (line 981) | def reset_parameters(self, ):
method forward (line 989) | def forward(self, t):
class FeedForward (line 1002) | class FeedForward(nn.Module):
method __init__ (line 1006) | def __init__(self, dims, expansion_factor, dropout=0, use_single_layer...
method reset_parameters (line 1023) | def reset_parameters(self):
method forward (line 1028) | def forward(self, x):
class TransformerBlock (line 1039) | class TransformerBlock(nn.Module):
method __init__ (line 1044) | def __init__(self, dims,
method reset_parameters (line 1066) | def reset_parameters(self):
method token_mixer (line 1073) | def token_mixer(self, x):
method channel_mixer (line 1077) | def channel_mixer(self, x):
method forward (line 1082) | def forward(self, x):
class _MultiheadAttention (line 1090) | class _MultiheadAttention(nn.Module):
method __init__ (line 1091) | def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention...
method reset_parameters (line 1115) | def reset_parameters(self):
method forward (line 1121) | def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor...
class _ScaledDotProductAttention (line 1143) | class _ScaledDotProductAttention(nn.Module):
method __init__ (line 1148) | def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=Fa...
method forward (line 1156) | def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=...
class FeatEncode (line 1200) | class FeatEncode(nn.Module):
method __init__ (line 1204) | def __init__(self, time_dims, feat_dims, out_dims):
method reset_parameters (line 1211) | def reset_parameters(self):
method forward (line 1215) | def forward(self, edge_feats, edge_ts):
class Patch_Encoding (line 1220) | class Patch_Encoding(nn.Module):
method __init__ (line 1225) | def __init__(self, per_graph_size, time_channels,
method reset_parameters (line 1261) | def reset_parameters(self):
method forward (line 1268) | def forward(self, edge_feats, edge_ts, batch_size, inds):
class EdgePredictor_per_node (line 1293) | class EdgePredictor_per_node(torch.nn.Module):
method __init__ (line 1298) | def __init__(self, dim_in_time, dim_in_node, predict_class):
method reset_parameters (line 1311) | def reset_parameters(self, ):
method forward (line 1316) | def forward(self, h, neg_samples=1):
class STHN_Interface (line 1328) | class STHN_Interface(nn.Module):
method __init__ (line 1329) | def __init__(self, mlp_mixer_configs, edge_predictor_configs):
method reset_parameters (line 1342) | def reset_parameters(self):
method forward (line 1347) | def forward(self, model_inputs, neg_samples, node_feats):
method predict (line 1355) | def predict(self, model_inputs, neg_samples, node_feats):
class Multiclass_Interface (line 1369) | class Multiclass_Interface(nn.Module):
method __init__ (line 1370) | def __init__(self, mlp_mixer_configs, edge_predictor_configs):
method reset_parameters (line 1383) | def reset_parameters(self):
method forward (line 1388) | def forward(self, model_inputs, neg_samples, node_feats):
method predict (line 1399) | def predict(self, model_inputs, neg_samples, node_feats):
FILE: modules/time_enc.py
class TimeEncoder (line 14) | class TimeEncoder(torch.nn.Module):
method __init__ (line 15) | def __init__(self, out_channels: int):
method reset_parameters (line 20) | def reset_parameters(self):
method forward (line 23) | def forward(self, t: Tensor) -> Tensor:
FILE: modules/timetraveler_agent.py
class HistoryEncoder (line 14) | class HistoryEncoder(nn.Module):
method __init__ (line 15) | def __init__(self, config):
method set_hiddenx (line 21) | def set_hiddenx(self, batch_size):
method forward (line 30) | def forward(self, prev_action, mask):
class PolicyMLP (line 37) | class PolicyMLP(nn.Module):
method __init__ (line 38) | def __init__(self, config):
method forward (line 43) | def forward(self, state_query):
class DynamicEmbedding (line 48) | class DynamicEmbedding(nn.Module):
method __init__ (line 49) | def __init__(self, n_ent, dim_ent, dim_t):
method forward (line 55) | def forward(self, entities, dt):
class StaticEmbedding (line 67) | class StaticEmbedding(nn.Module):
method __init__ (line 68) | def __init__(self, n_ent, dim_ent):
method forward (line 72) | def forward(self, entities, timestamps=None):
class Agent (line 75) | class Agent(nn.Module):
method __init__ (line 76) | def __init__(self, config):
method forward (line 101) | def forward(self, prev_relation, current_entities, current_timestamps,
method get_im_embedding (line 164) | def get_im_embedding(self, cooccurrence_entities):
method update_entity_embedding (line 173) | def update_entity_embedding(self, entity, ims, mu):
method entities_embedding_shift (line 182) | def entities_embedding_shift(self, entity, im, mu):
method back_entities_embedding (line 187) | def back_entities_embedding(self, entity):
FILE: modules/timetraveler_dirichlet.py
class NotConvergingError (line 54) | class NotConvergingError(Exception):
function test (line 60) | def test(D1, D2, method="meanprecision", maxiter=None):
function pdf (line 104) | def pdf(alphas):
function meanprecision (line 136) | def meanprecision(a):
function loglikelihood (line 154) | def loglikelihood(D, a):
function mle (line 172) | def mle(D, tol=1e-7, method="meanprecision", maxiter=None):
function _fixedpoint (line 201) | def _fixedpoint(D, tol=1e-7, maxiter=None):
function _meanprecision (line 236) | def _meanprecision(D, tol=1e-7, maxiter=None):
function _fit_s (line 285) | def _fit_s(D, a0, logp, tol=1e-7, maxiter=1000):
function _fit_m (line 332) | def _fit_m(D, a0, logp, tol=1e-7, maxiter=1000):
function _init_a (line 365) | def _init_a(D):
function _ipsi (line 381) | def _ipsi(y, tol=1.48e-9, maxiter=10):
function _trigamma (line 414) | def _trigamma(x):
class MLE_Dirchlet (line 418) | class MLE_Dirchlet(object):
method __init__ (line 419) | def __init__(self, trainQuads, num_r, k, timespan,
method get_entity_occ_times (line 441) | def get_entity_occ_times(self, trainQuads):
method get_relations_observed_data (line 454) | def get_relations_observed_data(self, trainQuads):
method mle_dirchlet (line 479) | def mle_dirchlet(self):
class Dirichlet (line 488) | class Dirichlet(object):
method __init__ (line 489) | def __init__(self, alphas, k):
method __call__ (line 498) | def __call__(self, rel, dt):
FILE: modules/timetraveler_environment.py
class Env (line 14) | class Env(object):
method __init__ (line 15) | def __init__(self, examples, config, state_action_space=None):
method build_graph (line 33) | def build_graph(self, examples):
method get_state_actions_space_complete (line 63) | def get_state_actions_space_complete(self, entity, time, current_=Fals...
method next_actions (line 96) | def next_actions(self, entites, times, query_times, max_action_num=200...
method get_padd_actions (line 123) | def get_padd_actions(self, entites, times, query_times, max_action_num...
FILE: modules/timetraveler_episode.py
class Episode (line 12) | class Episode(nn.Module):
method __init__ (line 13) | def __init__(self, env, agent, config):
method forward (line 22) | def forward(self, query_entities, query_timestamps, query_relations):
method beam_search (line 85) | def beam_search(self, query_entities, query_timestamps, query_relations):
FILE: modules/timetraveler_policygradient.py
class ReactiveBaseline (line 14) | class ReactiveBaseline(object):
method __init__ (line 15) | def __init__(self, config, update_rate):
method get_baseline_value (line 21) | def get_baseline_value(self):
method update (line 24) | def update(self, target):
class PG (line 27) | class PG(object):
method __init__ (line 28) | def __init__(self, config):
method get_reward (line 35) | def get_reward(self, current_entites, answers):
method calc_cum_discounted_reward (line 41) | def calc_cum_discounted_reward(self, rewards):
method entropy_reg_loss (line 54) | def entropy_reg_loss(self, all_logits):
method calc_reinforce_loss (line 59) | def calc_reinforce_loss(self, all_loss, all_logits, cum_discounted_rew...
FILE: modules/timetraveler_trainertester.py
class Trainer (line 7) | class Trainer(object):
method __init__ (line 8) | def __init__(self, model, pg, optimizer, args, distribution=None):
method train_epoch (line 15) | def train_epoch(self, dataloader, ntriple):
method save_model (line 66) | def save_model(self, save_path, checkpoint_path='checkpoint.pth'):
class Tester (line 79) | class Tester(object):
method __init__ (line 80) | def __init__(self, model, args, train_entities, RelEntCooccurrence, me...
method get_rank (line 88) | def get_rank(self, score, answer, entities_space, num_ent):
method test (line 106) | def test(self, dataloader, ntriple, num_nodes, neg_sampler, evaluator,...
function getRelEntCooccurrence (line 247) | def getRelEntCooccurrence(quadruples, num_rels):
FILE: modules/tkg_utils.py
function get_args_timetraveler (line 9) | def get_args_timetraveler(args=None):
function get_model_config_timetraveler (line 78) | def get_model_config_timetraveler(args, num_ent, num_rel):
function get_args_cen (line 103) | def get_args_cen():
function get_args_regcn (line 222) | def get_args_regcn():
function compute_min_distance (line 339) | def compute_min_distance(unique_sorted_timestamps):
function compute_maxminmean_distances (line 347) | def compute_maxminmean_distances(unique_sorted_timestamps):
function group_by (line 361) | def group_by(data: np.array, key_idx: int) -> dict:
function tkg_granularity_lookup (line 374) | def tkg_granularity_lookup(dataset_name, ts_distmean):
function reformat_ts (line 386) | def reformat_ts(timestamps, dataset_name='tkgl'):
function get_original_ts (line 412) | def get_original_ts(reformatted_ts, ts_dist, min_ts):
function create_basis_dict (line 425) | def create_basis_dict(data):
function get_inv_relation_id (line 444) | def get_inv_relation_id(num_rels):
function create_scores_array (line 460) | def create_scores_array(predictions_dict, num_nodes):
FILE: modules/tkg_utils_dgl.py
function build_sub_graph (line 7) | def build_sub_graph(num_nodes, num_rels, triples, use_cuda, gpu, mode='d...
function r2e (line 51) | def r2e(triplets, num_rels):
FILE: modules/tlogic_apply_modules.py
function filter_rules (line 14) | def filter_rules(rules_dict, min_conf, min_body_supp, rule_lengths):
function get_window_edges (line 44) | def get_window_edges(all_data, test_query_ts, learn_edges, window=-1, fi...
function match_body_relations (line 89) | def match_body_relations(rule, edges, test_query_sub):
function match_body_relations_complete (line 137) | def match_body_relations_complete(rule, edges, test_query_sub):
function get_walks (line 180) | def get_walks(rule, walk_edges):
function get_walks_complete (line 231) | def get_walks_complete(rule, walk_edges):
function check_var_constraints (line 279) | def check_var_constraints(var_constraints, rule_walks):
function get_candidates (line 301) | def get_candidates(
function save_candidates (line 338) | def save_candidates(
function verbalize_walk (line 367) | def verbalize_walk(walk, data):
function score1 (line 391) | def score1(rule, c=0):
function score2 (line 408) | def score2(cands_walks, test_query_ts, lmbda):
function score_12 (line 429) | def score_12(rule, cands_walks, test_query_ts, lmbda, a):
FILE: modules/tlogic_learn_modules.py
class Temporal_Walk (line 15) | class Temporal_Walk(object):
method __init__ (line 16) | def __init__(self, learn_data, inv_relation_id, transition_distr):
method sample_start_edge (line 37) | def sample_start_edge(self, rel_idx):
method sample_next_edge (line 53) | def sample_next_edge(self, filtered_edges, cur_ts):
method transition_step (line 80) | def transition_step(self, cur_node, cur_ts, prev_edge, start_node, ste...
method sample_walk (line 126) | def sample_walk(self, L, rel_idx):
function store_neighbors (line 167) | def store_neighbors(quads):
function store_edges (line 186) | def store_edges(quads):
class Rule_Learner (line 205) | class Rule_Learner(object):
method __init__ (line 206) | def __init__(self, edges, id2relation, inv_relation_id, output_dir):
method create_rule (line 230) | def create_rule(self, walk):
method define_var_constraints (line 267) | def define_var_constraints(self, entities):
method estimate_confidence (line 286) | def estimate_confidence(self, rule, num_samples=500):
method sample_body (line 320) | def sample_body(self, body_rels, var_constraints):
method calculate_rule_support (line 369) | def calculate_rule_support(self, unique_bodies, head_rel):
method update_rules_dict (line 396) | def update_rules_dict(self, rule):
method sort_rules_dict (line 412) | def sort_rules_dict(self):
method save_rules (line 428) | def save_rules(self, dt, rule_lengths, num_walks, transition_distr, se...
method save_rules_verbalized (line 453) | def save_rules_verbalized(
function verbalize_rule (line 483) | def verbalize_rule(rule, id2relation):
FILE: tgb/datasets/ICEWS14/ent2word.py
function load_index (line 10) | def load_index(input_path):
FILE: tgb/datasets/ICEWS14/icews14.py
function load_index (line 3) | def load_index(input_path):
function load_tab_list (line 13) | def load_tab_list(input_path):
function write2csv (line 22) | def write2csv(rows, output_path):
function main (line 29) | def main():
FILE: tgb/datasets/dataset_scripts/MAG/old/plot_stats.py
function load_csv (line 5) | def load_csv(fname: str):
FILE: tgb/datasets/dataset_scripts/dgraph.py
function main (line 39) | def main():
FILE: tgb/datasets/dataset_scripts/process_arxiv.py
function load_full_json (line 8) | def load_full_json(fname):
function main (line 20) | def main():
FILE: tgb/datasets/dataset_scripts/process_github.py
function str_to_timestamp (line 33) | def str_to_timestamp(time_str):
function parse_issue_comment_events (line 38) | def parse_issue_comment_events(event):
function parse_issue_event (line 61) | def parse_issue_event(event):
function parse_pull_request_event (line 90) | def parse_pull_request_event(event):
function parse_pull_request_review_comment_event (line 119) | def parse_pull_request_review_comment_event(event):
function parse_fork_event (line 142) | def parse_fork_event(event):
function parse_member_event (line 156) | def parse_member_event(event):
function parse_event (line 180) | def parse_event(event):
function parse_file (line 191) | def parse_file(filename):
FILE: tgb/datasets/dataset_scripts/tgbl-coin.py
function analyze_csv (line 9) | def analyze_csv(fname):
function extract_node_dict (line 91) | def extract_node_dict(fname, freq=10):
function clean_edgelist (line 127) | def clean_edgelist(fname, outname, node_dict):
function sort_edgelist (line 149) | def sort_edgelist(in_file, outname):
function main (line 181) | def main():
FILE: tgb/datasets/dataset_scripts/tgbl-coin_neg_generator.py
function main (line 9) | def main():
FILE: tgb/datasets/dataset_scripts/tgbl-comment.py
function find_filenames (line 7) | def find_filenames(path_to_dir):
function read_edgelist (line 17) | def read_edgelist(fname, outfname, write_header=False):
function read_nodeattr (line 47) | def read_nodeattr(fname, outfname, write_header=False):
function combine_edgelist_edgefeat (line 91) | def combine_edgelist_edgefeat(edgefname, featfname, outname):
function main (line 165) | def main():
FILE: tgb/datasets/dataset_scripts/tgbl-comment_neg_generator.py
function main (line 9) | def main():
FILE: tgb/datasets/dataset_scripts/tgbl-flight.py
function find_csv_filenames (line 9) | def find_csv_filenames(path_to_dir, suffix=".csv"):
function flight2edgelist (line 20) | def flight2edgelist(
function load_icao_airports (line 97) | def load_icao_airports(fname="airport_codes.csv"):
function merge_edgelist (line 118) | def merge_edgelist(input_names: str, in_dir: str, outname: str):
function clean_node_feat (line 148) | def clean_node_feat(in_file, outname):
function sort_edgelist (line 188) | def sort_edgelist(in_file, outname):
function date2ts (line 221) | def date2ts(date_str: str) -> float:
function main (line 230) | def main():
FILE: tgb/datasets/dataset_scripts/tgbl-flight_neg_generator.py
function main (line 9) | def main():
FILE: tgb/datasets/dataset_scripts/tgbl-review.py
function collect_csv (line 9) | def collect_csv(dir_name="software"):
function reorder_column (line 15) | def reorder_column(fname: str, outname: str):
function sort_edgelist (line 36) | def sort_edgelist(fname: str, outname: str):
function count_degree (line 71) | def count_degree(fname: str):
function reduce_edgelist (line 97) | def reduce_edgelist(fname: str, outname: str, node10_id: dict):
function csv_process_review (line 121) | def csv_process_review(
function main (line 195) | def main():
FILE: tgb/datasets/dataset_scripts/tgbl-review_neg_generator.py
function main (line 9) | def main():
FILE: tgb/datasets/dataset_scripts/tgbl-wiki_neg_generator.py
function main (line 9) | def main():
FILE: tgb/datasets/dataset_scripts/tgbn-genre.py
function filter_genre_edgelist (line 55) | def filter_genre_edgelist(fname, genres_dict):
function get_genre_list (line 80) | def get_genre_list(fname):
function find_unique_genres (line 133) | def find_unique_genres(fname: str, threshold: float = 0.8):
function load_genre_dict (line 162) | def load_genre_dict(
function generate_daily_node_labels (line 180) | def generate_daily_node_labels(fname: str):
function generate_aggregate_labels (line 246) | def generate_aggregate_labels(fname: str, days: int = 7):
function most_frequent (line 309) | def most_frequent(List):
function convert_ts_unix (line 325) | def convert_ts_unix(fname: str, outname: str):
function convert_ts_edgelist (line 352) | def convert_ts_edgelist(fname: str, outname: str):
function sort_node_labels (line 378) | def sort_node_labels(fname, outname):
function sort_edgelist (line 416) | def sort_edgelist(fname, outname="sorted_lastfm_edgelist.csv"):
FILE: tgb/datasets/dataset_scripts/tgbn-reddit.py
function find_filenames (line 7) | def find_filenames(path_to_dir):
function combine_edgelist_edgefeat2subreddits (line 17) | def combine_edgelist_edgefeat2subreddits(edgefname, featfname, outname):
function filter_subreddits (line 63) | def filter_subreddits(fname):
function clean_edgelist (line 91) | def clean_edgelist(fname, node_counts, outname, threshold=1000):
function clean_edgelist_reddits (line 123) | def clean_edgelist_reddits(fname, reddit_counts, outname, threshold=50):
function remove_missing_user (line 155) | def remove_missing_user(fname, outname):
function generate_daily_node_labels (line 181) | def generate_daily_node_labels(
function generate_aggregate_labels (line 229) | def generate_aggregate_labels(fname: str, outname: str, days: int = 7):
function main (line 288) | def main():
FILE: tgb/datasets/dataset_scripts/tgbn-token.py
function count_node_freq (line 4) | def count_node_freq(fname, filter_size=100):
function filter_edgelist (line 66) | def filter_edgelist(token_fname, edgefile, outname):
function filter_by_node (line 104) | def filter_by_node(node_dict, edgefile, outname):
function store_node_list (line 125) | def store_node_list(node_dict, outname):
function load_node_dict (line 139) | def load_node_dict(fname):
function store_token_address (line 161) | def store_token_address(token_dict, outname, topk=1000):
function analyze_token_frequency (line 180) | def analyze_token_frequency(fname):
function to_bipartite (line 234) | def to_bipartite(in_name, out_name, node_dict):
function analyze_csv (line 260) | def analyze_csv(fname):
function convert_2_sec (line 326) | def convert_2_sec(fname, outname):
function print_csv (line 357) | def print_csv(fname):
function sort_edgelist_by_time (line 368) | def sort_edgelist_by_time(fname, outname):
function generate_aggregate_labels (line 396) | def generate_aggregate_labels(fname: str, outname: str, days: int = 7):
function main (line 461) | def main():
FILE: tgb/datasets/dataset_scripts/tgbn-trade.py
function count_unique_countries (line 9) | def count_unique_countries(fname):
function normalize_edgelist (line 34) | def normalize_edgelist(fname: str, outname: str):
function generate_aggregate_labels (line 94) | def generate_aggregate_labels(fname: str, outname: str):
function check_sum_to_one (line 124) | def check_sum_to_one(fname: str):
function main (line 159) | def main():
FILE: tgb/datasets/tgbl_enron/tgbl-enron_neg_generator.py
function main (line 7) | def main():
FILE: tgb/datasets/tgbl_lastfm/tgbl-lastfm_neg_generator.py
function main (line 7) | def main():
FILE: tgb/datasets/tgbl_subreddit/tgbl-subreddit_neg_generator.py
function main (line 7) | def main():
FILE: tgb/datasets/tgbl_uci/tgbl-uci_neg_generator.py
function main (line 7) | def main():
FILE: tgb/datasets/thgl_forum/merge_files.py
function find_filenames (line 7) | def find_filenames(path_to_dir):
function read_edgelist (line 17) | def read_edgelist(fname, outfname, write_header=False):
function read_nodeattr (line 47) | def read_nodeattr(fname, outfname, write_header=False):
function combine_edgelist_edgefeat (line 117) | def combine_edgelist_edgefeat(edgefname, featfname, outname):
function main (line 192) | def main():
FILE: tgb/datasets/thgl_forum/thgl-forum.py
function load_csv_raw (line 106) | def load_csv_raw(fname):
function load_csv_filtered_node (line 199) | def load_csv_filtered_node(fname, low_deg_dict):
function writeNodeType (line 264) | def writeNodeType(node_type_dict, outname):
function write2csv (line 275) | def write2csv(outname, out_dict):
function writeNodeIDMapping (line 292) | def writeNodeIDMapping(node_dict, outname):
function node_deg_filter (line 302) | def node_deg_filter(node_deg_dict):
function find_low_degree_nodes (line 321) | def find_low_degree_nodes(node_deg_dict, threshold=10):
function main (line 332) | def main():
FILE: tgb/datasets/thgl_forum/thgl_forum_ns_gen.py
function main (line 6) | def main():
FILE: tgb/datasets/thgl_github/2024_01/github_extract.py
function str_to_timestamp (line 47) | def str_to_timestamp(time_str):
function parse_issue_comment_events (line 52) | def parse_issue_comment_events(event):
function parse_issue_event (line 80) | def parse_issue_event(event):
function parse_pull_request_event (line 112) | def parse_pull_request_event(event):
function parse_pull_request_review_comment_event (line 143) | def parse_pull_request_review_comment_event(event):
function parse_fork_event (line 169) | def parse_fork_event(event):
function parse_member_event (line 186) | def parse_member_event(event):
function parse_event (line 213) | def parse_event(event):
function parse_file (line 224) | def parse_file(filename):
function write2csv (line 253) | def write2csv(outname, out_dict):
function main (line 270) | def main():
FILE: tgb/datasets/thgl_github/2024_02/github_extract.py
function str_to_timestamp (line 47) | def str_to_timestamp(time_str):
function parse_issue_comment_events (line 52) | def parse_issue_comment_events(event):
function parse_issue_event (line 80) | def parse_issue_event(event):
function parse_pull_request_event (line 112) | def parse_pull_request_event(event):
function parse_pull_request_review_comment_event (line 143) | def parse_pull_request_review_comment_event(event):
function parse_fork_event (line 169) | def parse_fork_event(event):
function parse_member_event (line 186) | def parse_member_event(event):
function parse_event (line 213) | def parse_event(event):
function parse_file (line 224) | def parse_file(filename):
function write2csv (line 253) | def write2csv(outname, out_dict):
function main (line 270) | def main():
FILE: tgb/datasets/thgl_github/2024_03/github_extract.py
function str_to_timestamp (line 47) | def str_to_timestamp(time_str):
function parse_issue_comment_events (line 52) | def parse_issue_comment_events(event):
function parse_issue_event (line 80) | def parse_issue_event(event):
function parse_pull_request_event (line 112) | def parse_pull_request_event(event):
function parse_pull_request_review_comment_event (line 143) | def parse_pull_request_review_comment_event(event):
function parse_fork_event (line 169) | def parse_fork_event(event):
function parse_member_event (line 186) | def parse_member_event(event):
function parse_event (line 213) | def parse_event(event):
function parse_file (line 224) | def parse_file(filename):
function write2csv (line 253) | def write2csv(outname, out_dict):
function main (line 270) | def main():
FILE: tgb/datasets/thgl_github/extract_subset.py
function load_edgelist (line 4) | def load_edgelist(file_path, freq_threshold=5):
function subset_by_node (line 88) | def subset_by_node(file_path, low_freq_dict):
function subset_by_node_type (line 116) | def subset_by_node_type(file_path, remove_node_type_dict, low_freq_dict=...
function write2csv (line 164) | def write2csv(outname, out_dict):
function combine_edgelist (line 186) | def combine_edgelist(file_paths, outname):
function main (line 207) | def main():
FILE: tgb/datasets/thgl_github/thgl_github.py
function load_csv_raw (line 5) | def load_csv_raw(fname):
function write2csv (line 31) | def write2csv(outname, out_dict):
function load_edgelist (line 47) | def load_edgelist(fname):
function writeNodeType (line 98) | def writeNodeType(node_type_dict, outname):
function writeEdgeTypeMapping (line 109) | def writeEdgeTypeMapping(edge_type_dict, outname):
function writeNodeTypeMapping (line 120) | def writeNodeTypeMapping(node_type_dict, outname):
function write2edgelist (line 131) | def write2edgelist(out_dict, outname):
function main (line 153) | def main():
FILE: tgb/datasets/thgl_github/thgl_github_ns_gen.py
function main (line 6) | def main():
FILE: tgb/datasets/thgl_myket/thgl_myket.py
function date2ts (line 11) | def date2ts(date_str: str) -> float:
function read_csv2dict (line 38) | def read_csv2dict(fname):
function edge2nodetype (line 83) | def edge2nodetype(out_dict):
function writeNodeType (line 110) | def writeNodeType(node_type_dict, outname):
function write2edgelist (line 122) | def write2edgelist(out_dict, outname):
function main (line 149) | def main():
FILE: tgb/datasets/thgl_myket/thgl_myket_ns_gen.py
function main (line 6) | def main():
FILE: tgb/datasets/thgl_software/thgl_software.py
function load_csv_raw (line 5) | def load_csv_raw(fname):
function write2csv (line 31) | def write2csv(outname, out_dict):
function load_edgelist (line 47) | def load_edgelist(fname):
function writeNodeType (line 98) | def writeNodeType(node_type_dict, outname):
function writeEdgeTypeMapping (line 109) | def writeEdgeTypeMapping(edge_type_dict, outname):
function writeNodeTypeMapping (line 120) | def writeNodeTypeMapping(node_type_dict, outname):
function write2edgelist (line 131) | def write2edgelist(out_dict, outname):
function main (line 153) | def main():
FILE: tgb/datasets/thgl_software/thgl_software_ns_gen.py
function main (line 6) | def main():
FILE: tgb/datasets/tkgl_icews/tkgl_icews.py
function load_csv_raw (line 6) | def load_csv_raw(fname):
function write2csv (line 47) | def write2csv(outname, out_dict):
function writeEdgeTypeMapping (line 77) | def writeEdgeTypeMapping(edge_type_dict, outname):
function main (line 90) | def main():
FILE: tgb/datasets/tkgl_icews/tkgl_icews_ns_gen.py
function main (line 6) | def main():
FILE: tgb/datasets/tkgl_polecat/tkgl_polecat.py
function load_csv_raw (line 8) | def load_csv_raw(fname):
function write2csv (line 51) | def write2csv(outname: str,
function writeEdgeTypeMapping (line 121) | def writeEdgeTypeMapping(edge_type_dict, outname):
function main (line 132) | def main():
FILE: tgb/datasets/tkgl_polecat/tkgl_polecat_ns_gen.py
function main (line 6) | def main():
FILE: tgb/datasets/tkgl_smallpedia/smallpedia_remove_conflict.py
function load_static_edgelist (line 3) | def load_static_edgelist(file_path):
function load_temporal_edgelist (line 24) | def load_temporal_edgelist(file_path):
function remove_conflict (line 57) | def remove_conflict(static_dict, temporal_dict):
function write2csv (line 77) | def write2csv(outname: str,
function main (line 96) | def main():
FILE: tgb/datasets/tkgl_smallpedia/tkgl_smallpedia_ns_gen.py
function main (line 6) | def main():
FILE: tgb/datasets/tkgl_wikidata/time_edges/tkgl-wikidata_extract.py
function load_time_csv_raw (line 6) | def load_time_csv_raw(fname):
function write2csv (line 59) | def write2csv(outname: str,
function update_dict (line 79) | def update_dict(total_dict, new_dict):
function retrieve_all_entities (line 95) | def retrieve_all_entities(total_dict):
function writenode2csv (line 118) | def writenode2csv(outname: str,
function main (line 131) | def main():
FILE: tgb/datasets/tkgl_wikidata/tkgl-wikidata.py
function load_time_csv (line 6) | def load_time_csv(fname):
function write2csv (line 98) | def write2csv(outname: str,
function extract_subset (line 117) | def extract_subset(fname, outname, start_year=2000, end_year=2024):
function extract_subset_nodeid (line 152) | def extract_subset_nodeid(fname, outname, start_year=2000, end_year=2024...
function extract_static_subset (line 193) | def extract_static_subset(fname, outname, node_dict, max_id=1000000):
function subset_static_edges (line 239) | def subset_static_edges(fname, outname, rel_type, topk=10):
function main (line 280) | def main():
FILE: tgb/datasets/tkgl_wikidata/tkgl_wikidata_mining.py
function timeEdgeWrite2csv (line 19) | def timeEdgeWrite2csv(outname, out_dict):
function EdgeWrite2csv (line 33) | def EdgeWrite2csv(outname, out_dict):
function main (line 45) | def main():
FILE: tgb/datasets/tkgl_wikidata/tkgl_wikidata_ns_gen.py
function main (line 12) | def main():
FILE: tgb/datasets/tkgl_wikidata/wikidata_remove_conflict.py
function load_static_edgelist (line 3) | def load_static_edgelist(file_path):
function load_temporal_edgelist (line 24) | def load_temporal_edgelist(file_path):
function remove_conflict (line 57) | def remove_conflict(static_dict, temporal_dict):
function write2csv (line 77) | def write2csv(outname: str,
function main (line 96) | def main():
FILE: tgb/datasets/tkgl_yago/tkgl_yago.py
function main (line 7) | def main():
function write_csv (line 30) | def write_csv(outname, out_dict):
function load_csv (line 43) | def load_csv(fname):
FILE: tgb/datasets/tkgl_yago/tkgl_yago_ns_gen.py
function main (line 8) | def main():
FILE: tgb/linkproppred/dataset.py
class LinkPropPredDataset (line 42) | class LinkPropPredDataset(object):
method __init__ (line 43) | def __init__(
method _version_check (line 176) | def _version_check(self) -> None:
method download (line 201) | def download(self) -> None:
method generate_processed_files (line 244) | def generate_processed_files(self) -> pd.DataFrame:
method pre_process (line 347) | def pre_process(self):
method generate_splits (line 400) | def generate_splits(
method preprocess_static_edges (line 430) | def preprocess_static_edges(self):
method eval_metric (line 450) | def eval_metric(self) -> str:
method negative_sampler (line 459) | def negative_sampler(self) -> NegativeEdgeSampler:
method load_val_ns (line 468) | def load_val_ns(self) -> None:
method load_test_ns (line 476) | def load_test_ns(self) -> None:
method num_nodes (line 485) | def num_nodes(self) -> int:
method num_edges (line 499) | def num_edges(self) -> int:
method num_rels (line 510) | def num_rels(self) -> int:
method node_feat (line 523) | def node_feat(self) -> Optional[np.ndarray]:
method node_type (line 532) | def node_type(self) -> Optional[np.ndarray]:
method edge_feat (line 541) | def edge_feat(self) -> Optional[np.ndarray]:
method edge_type (line 550) | def edge_type(self) -> Optional[np.ndarray]:
method static_data (line 559) | def static_data(self) -> Optional[np.ndarray]:
method full_data (line 570) | def full_data(self) -> Dict[str, Any]:
method train_mask (line 584) | def train_mask(self) -> np.ndarray:
method val_mask (line 595) | def val_mask(self) -> np.ndarray:
method test_mask (line 606) | def test_mask(self) -> np.ndarray:
function main (line 617) | def main():
FILE: tgb/linkproppred/dataset_pyg.py
class PyGLinkPropPredDataset (line 9) | class PyGLinkPropPredDataset(Dataset):
method __init__ (line 10) | def __init__(
method eval_metric (line 54) | def eval_metric(self) -> str:
method negative_sampler (line 63) | def negative_sampler(self) -> NegativeEdgeSampler:
method num_nodes (line 72) | def num_nodes(self) -> int:
method num_rels (line 81) | def num_rels(self) -> int:
method num_edges (line 90) | def num_edges(self) -> int:
method load_val_ns (line 98) | def load_val_ns(self) -> None:
method load_test_ns (line 104) | def load_test_ns(self) -> None:
method train_mask (line 111) | def train_mask(self) -> torch.Tensor:
method val_mask (line 122) | def val_mask(self) -> torch.Tensor:
method test_mask (line 133) | def test_mask(self) -> torch.Tensor:
method node_feat (line 144) | def node_feat(self) -> torch.Tensor:
method node_type (line 153) | def node_type(self) -> torch.Tensor:
method src (line 162) | def src(self) -> torch.Tensor:
method dst (line 171) | def dst(self) -> torch.Tensor:
method ts (line 180) | def ts(self) -> torch.Tensor:
method static_data (line 189) | def static_data(self) -> torch.Tensor:
method edge_type (line 206) | def edge_type(self) -> torch.Tensor:
method edge_feat (line 215) | def edge_feat(self) -> torch.Tensor:
method edge_label (line 224) | def edge_label(self) -> torch.Tensor:
method process_data (line 232) | def process_data(self) -> None:
method get_TemporalData (line 286) | def get_TemporalData(self) -> TemporalData:
method len (line 311) | def len(self) -> int:
method get (line 319) | def get(self, idx: int) -> TemporalData:
method __repr__ (line 346) | def __repr__(self) -> str:
FILE: tgb/linkproppred/evaluate.py
class Evaluator (line 16) | class Evaluator(object):
method __init__ (line 19) | def __init__(self, name: str, k_value: int = 10):
method _parse_and_check_input (line 31) | def _parse_and_check_input(self, input_dict):
method _eval_hits_and_mrr (line 72) | def _eval_hits_and_mrr(self, y_pred_pos, y_pred_neg, type_info, k_value):
method eval (line 118) | def eval(self,
FILE: tgb/linkproppred/negative_generator.py
class NegativeEdgeGenerator (line 16) | class NegativeEdgeGenerator(object):
method __init__ (line 17) | def __init__(
method generate_negative_samples (line 69) | def generate_negative_samples(self,
method generate_negative_samples_rnd (line 103) | def generate_negative_samples_rnd(self,
method generate_historical_edge_set (line 172) | def generate_historical_edge_set(self,
method generate_negative_samples_hist_rnd (line 208) | def generate_negative_samples_hist_rnd(
FILE: tgb/linkproppred/negative_sampler.py
class NegativeEdgeSampler (line 15) | class NegativeEdgeSampler(object):
method __init__ (line 16) | def __init__(
method load_eval_set (line 45) | def load_eval_set(
method reset_eval_set (line 67) | def reset_eval_set(self,
method query_batch (line 85) | def query_batch(self,
FILE: tgb/linkproppred/thg_negative_generator.py
class THGNegativeEdgeGenerator (line 21) | class THGNegativeEdgeGenerator(object):
method __init__ (line 22) | def __init__(
method get_destinations_based_on_node_type (line 70) | def get_destinations_based_on_node_type(self,
method generate_negative_samples (line 101) | def generate_negative_samples(self,
method generate_negative_samples_nt (line 133) | def generate_negative_samples_nt(self,
method generate_negative_samples_random (line 212) | def generate_negative_samples_random(self,
FILE: tgb/linkproppred/thg_negative_sampler.py
class THGNegativeEdgeSampler (line 14) | class THGNegativeEdgeSampler(object):
method __init__ (line 15) | def __init__(
method load_eval_set (line 45) | def load_eval_set(
method query_batch (line 67) | def query_batch(self,
FILE: tgb/linkproppred/tkg_negative_generator.py
class TKGNegativeEdgeGenerator (line 20) | class TKGNegativeEdgeGenerator(object):
method __init__ (line 21) | def __init__(
method generate_dst_dict (line 77) | def generate_dst_dict(self, edge_data: TemporalData, dst_name: str) ->...
method generate_negative_samples (line 126) | def generate_negative_samples(self,
method generate_negative_samples_ftr (line 160) | def generate_negative_samples_ftr(self,
method generate_negative_samples_dst (line 224) | def generate_negative_samples_dst(self,
method generate_negative_samples_random (line 328) | def generate_negative_samples_random(self,
FILE: tgb/linkproppred/tkg_negative_sampler.py
class TKGNegativeEdgeSampler (line 17) | class TKGNegativeEdgeSampler(object):
method __init__ (line 18) | def __init__(
method load_eval_set (line 48) | def load_eval_set(
method query_batch (line 70) | def query_batch(self,
FILE: tgb/nodeproppred/dataset.py
class NodePropPredDataset (line 29) | class NodePropPredDataset(object):
method __init__ (line 30) | def __init__(
method _version_check (line 112) | def _version_check(self) -> None:
method download (line 132) | def download(self) -> None:
method generate_processed_files (line 178) | def generate_processed_files(
method pre_process (line 256) | def pre_process(self) -> None:
method generate_splits (line 297) | def generate_splits(
method find_next_labels_batch (line 327) | def find_next_labels_batch(
method reset_label_time (line 360) | def reset_label_time(self) -> None:
method return_label_ts (line 368) | def return_label_ts(self) -> int:
method num_classes (line 380) | def num_classes(self) -> int:
method eval_metric (line 389) | def eval_metric(self) -> str:
method node_feat (line 399) | def node_feat(self) -> Optional[np.ndarray]:
method edge_feat (line 409) | def edge_feat(self) -> Optional[np.ndarray]:
method node_label_dict (line 418) | def node_label_dict(self) -> Dict[int, Dict[int, Any]]:
method full_data (line 427) | def full_data(self) -> Dict[str, Any]:
method train_mask (line 441) | def train_mask(self) -> np.ndarray:
method val_mask (line 452) | def val_mask(self) -> np.ndarray:
method test_mask (line 464) | def test_mask(self) -> np.ndarray:
function main (line 476) | def main():
FILE: tgb/nodeproppred/dataset_pyg.py
class PyGNodePropPredDataset (line 14) | class PyGNodePropPredDataset(InMemoryDataset):
method __init__ (line 28) | def __init__(
method num_classes (line 48) | def num_classes(self) -> int:
method eval_metric (line 57) | def eval_metric(self) -> str:
method train_mask (line 66) | def train_mask(self) -> torch.Tensor:
method val_mask (line 77) | def val_mask(self) -> torch.Tensor:
method test_mask (line 88) | def test_mask(self) -> torch.Tensor:
method src (line 99) | def src(self) -> torch.Tensor:
method dst (line 108) | def dst(self) -> torch.Tensor:
method ts (line 117) | def ts(self) -> torch.Tensor:
method edge_feat (line 126) | def edge_feat(self) -> torch.Tensor:
method edge_label (line 135) | def edge_label(self) -> torch.Tensor:
method process_data (line 143) | def process_data(self):
method get_TemporalData (line 174) | def get_TemporalData(
method reset_label_time (line 191) | def reset_label_time(self) -> None:
method get_node_label (line 197) | def get_node_label(self, cur_t):
method get_label_time (line 210) | def get_label_time(self) -> int:
method len (line 218) | def len(self) -> int:
method get (line 226) | def get(self, idx: int) -> TemporalData:
method __repr__ (line 243) | def __repr__(self) -> str:
FILE: tgb/nodeproppred/evaluate.py
class Evaluator (line 15) | class Evaluator(object):
method __init__ (line 18) | def __init__(self, name: str):
method _parse_and_check_input (line 28) | def _parse_and_check_input(self, input_dict):
method _compute_metrics (line 75) | def _compute_metrics(self, y_true, y_pred):
method eval (line 94) | def eval(self, input_dict, verbose=False):
method expected_input_format (line 108) | def expected_input_format(self):
method expected_output_format (line 119) | def expected_output_format(self):
function main (line 129) | def main():
FILE: tgb/utils/dataset_stats.py
function get_unique_edges (line 17) | def get_unique_edges(sources, destination):
function get_avg_e_per_ts (line 28) | def get_avg_e_per_ts(edgelist_df):
function get_avg_degree (line 41) | def get_avg_degree(edgelist_df):
function get_index_metrics (line 58) | def get_index_metrics(train_val_data, test_data):
function get_node_ratio (line 87) | def get_node_ratio(history_data, eval_data):
function get_dataset_stats (line 103) | def get_dataset_stats(data, temporal_stats=False):
function main (line 148) | def main():
FILE: tgb/utils/info.py
class BColors (line 10) | class BColors:
FILE: tgb/utils/pre_process.py
function process_node_type (line 16) | def process_node_type(
function csv_to_forum_data (line 49) | def csv_to_forum_data(
function csv_to_thg_data (line 130) | def csv_to_thg_data(
function csv_to_wikidata (line 201) | def csv_to_wikidata(
function csv_to_staticdata (line 275) | def csv_to_staticdata(
function csv_to_tkg_data (line 333) | def csv_to_tkg_data(
function load_edgelist_wiki (line 408) | def load_edgelist_wiki(fname: str) -> pd.DataFrame:
function load_edgelist_trade (line 437) | def load_edgelist_trade(fname: str, label_size=255):
function load_trade_label_dict (line 491) | def load_trade_label_dict(
function load_edgelist_token (line 540) | def load_edgelist_token(
function load_edgelist_sr (line 639) | def load_edgelist_sr(
function load_labels_sr (line 715) | def load_labels_sr(
function load_label_dict (line 772) | def load_label_dict(fname: str, node_ids: dict, rd_dict: dict) -> dict:
function csv_to_pd_data_rc (line 817) | def csv_to_pd_data_rc(
function csv_to_pd_data_sc (line 896) | def csv_to_pd_data_sc(
function convert_str2int (line 982) | def convert_str2int(
function csv_to_pd_data (line 1006) | def csv_to_pd_data(
function process_node_feat (line 1114) | def process_node_feat(
function clean_rows (line 1200) | def clean_rows(
function load_edgelist_datetime (line 1244) | def load_edgelist_datetime(fname, label_size=514):
function load_genre_list (line 1309) | def load_genre_list(fname):
function reindex (line 1338) | def reindex(
FILE: tgb/utils/stats.py
function analyze_csv (line 15) | def analyze_csv(fname):
function plot_curve (line 80) | def plot_curve(y: np.ndarray, outname: str) -> None:
function main (line 92) | def main():
FILE: tgb/utils/utils.py
function set_verbose (line 17) | def set_verbose(flag: bool) -> None:
function vprint (line 21) | def vprint(*args, **kwargs):
function add_inverse_quadruples (line 27) | def add_inverse_quadruples(df: pd.DataFrame) -> pd.DataFrame:
function add_inverse_quadruples_np (line 65) | def add_inverse_quadruples_np(quadruples: np.array,
function add_inverse_quadruples_pyg (line 80) | def add_inverse_quadruples_pyg(data: TemporalData, num_rels:int=-1) -> l...
function save_pkl (line 102) | def save_pkl(obj: Any, fname: str) -> None:
function load_pkl (line 110) | def load_pkl(fname: str) -> Any:
function set_random_seed (line 117) | def set_random_seed(random_seed: int):
function find_nearest (line 134) | def find_nearest(array, value):
function get_args (line 140) | def get_args():
function save_results (line 165) | def save_results(new_results: dict, filename: str):
function split_by_time (line 187) | def split_by_time(data):
Condensed preview — 213 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (1,864K chars).
[
{
"path": ".devcontainer/.gitignore",
"chars": 18,
"preview": "!devcontainer.json"
},
{
"path": ".devcontainer/Dockerfile",
"chars": 520,
"preview": "FROM mcr.microsoft.com/devcontainers/python:3.10\n\nRUN python -m pip install --no-cache-dir --upgrade pip poetry \\\n &&"
},
{
"path": ".devcontainer/devcontainer.json",
"chars": 1350,
"preview": "// For format details, see https://aka.ms/devcontainer.json. For config options, see the\n// README at: https://github.co"
},
{
"path": ".github/workflows/mkdocs.yaml",
"chars": 589,
"preview": "name: mkdocs\non:\n push:\n # branches:\n # - master \n # - main\n tags:\n - \"v*.*.*\"\npermissions:\n cont"
},
{
"path": ".github/workflows/pypi.yaml",
"chars": 382,
"preview": "# https://github.com/JRubics/poetry-publish\n\nname: Publish to PyPI\non:\n push:\n tags:\n - \"v*.*.*\"\n\njobs:\n publi"
},
{
"path": ".gitignore",
"chars": 2099,
"preview": "!requirements*.txt\nget_croissant.py\n#dataset\nstats_figures/\nfigs/\n*.xz\n*.dict\n*.tab\n*.npz\n*.xz\n*.parquet\n*.gz\n*.tar\n*.pd"
},
{
"path": "LICENSE",
"chars": 1064,
"preview": "MIT License\n\nCopyright (c) 2023 TGB Team\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\no"
},
{
"path": "README.md",
"chars": 9323,
"preview": "<!-- # TGB -->\n\n\n**Temporal Graph Benchmark for Machine Learning on Temporal Graphs** (NeurIPS"
},
{
"path": "docs/about.md",
"chars": 399,
"preview": "# Temporal Graph Benchmark (TGB)\r\n\r\n\r\n## Overview\r\n\r\nThe TGB repo provides an automated ML p"
},
{
"path": "docs/api/tgb.linkproppred.md",
"chars": 375,
"preview": "# `tgb.linkproppred`\r\n\r\n::: tgb.linkproppred.dataset\r\n::: tgb.linkproppred.dataset_pyg\r\n::: tgb.linkproppred.evaluate\r\n:"
},
{
"path": "docs/api/tgb.nodeproppred.md",
"chars": 121,
"preview": "# `tgb.nodeproppred`\r\n\r\n::: tgb.nodeproppred.dataset\r\n::: tgb.nodeproppred.dataset_pyg\r\n::: tgb.nodeproppred.evaluate\r\n\r"
},
{
"path": "docs/api/tgb.utils.md",
"chars": 104,
"preview": "# `tgb.utils`\r\n\r\n::: tgb.utils.pre_process\r\n::: tgb.utils.utils\r\n::: tgb.utils.info\r\n::: tgb.utils.stats"
},
{
"path": "docs/index.md",
"chars": 3677,
"preview": "# Welcome to Temporal Graph Benchmark\n\n\n\n\n### Pip Install\n\nYou can install TGB via [pip](htt"
},
{
"path": "docs/tutorials/Edge_data_numpy.ipynb",
"chars": 4909,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"d5e3f5a2\",\n \"metadata\": {},\n \"source\": [\n \"# Access edge "
},
{
"path": "docs/tutorials/Edge_data_pyg.ipynb",
"chars": 5316,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"d5e3f5a2\",\n \"metadata\": {},\n \"source\": [\n \"# Access edge "
},
{
"path": "docs/tutorials/Node_label_tutorial.ipynb",
"chars": 5414,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"d5e3f5a2\",\n \"metadata\": {},\n \"source\": [\n \"# Access node "
},
{
"path": "examples/linkproppred/tgbl-coin/dyrep.py",
"chars": 12210,
"preview": "\"\"\"\nDyRep\n This has been implemented with intuitions from the following sources:\n - https://github.com/twitter-res"
},
{
"path": "examples/linkproppred/tgbl-coin/edgebank.py",
"chars": 6717,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/tgbl-coin/tgn.py",
"chars": 12250,
"preview": "\"\"\"\r\nDynamic Link Prediction with a TGN model with Early Stopping\r\nReference: \r\n - https://github.com/pyg-team/pytorc"
},
{
"path": "examples/linkproppred/tgbl-comment/dyrep.py",
"chars": 11529,
"preview": "\"\"\"\nDyRep\n This has been implemented with intuitions from the following sources:\n - https://github.com/twitter-res"
},
{
"path": "examples/linkproppred/tgbl-comment/edgebank.py",
"chars": 6592,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/tgbl-comment/tgn.py",
"chars": 11894,
"preview": "\"\"\"\nDynamic Link Prediction with a TGN model with Early Stopping\nReference: \n - https://github.com/pyg-team/pytorch_g"
},
{
"path": "examples/linkproppred/tgbl-enron/edgebank.py",
"chars": 6467,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/tgbl-flight/dyrep.py",
"chars": 13649,
"preview": "\"\"\"\nDyRep\n This has been implemented with intuitions from the following sources:\n - https://github.com/twitter-res"
},
{
"path": "examples/linkproppred/tgbl-flight/edgebank.py",
"chars": 6592,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/tgbl-flight/tgn.py",
"chars": 13320,
"preview": "\"\"\"\nDynamic Link Prediction with a TGN model with Early Stopping\nReference: \n - https://github.com/pyg-team/pytorch_g"
},
{
"path": "examples/linkproppred/tgbl-lastfm/edgebank.py",
"chars": 6468,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/tgbl-lastfm/tgn.py",
"chars": 11778,
"preview": "\"\"\"\nDynamic Link Prediction with a TGN model with Early Stopping\nReference: \n - https://github.com/pyg-team/pytorch_g"
},
{
"path": "examples/linkproppred/tgbl-review/dyrep.py",
"chars": 12216,
"preview": "\"\"\"\nDyRep\n This has been implemented with intuitions from the following sources:\n - https://github.com/twitter-res"
},
{
"path": "examples/linkproppred/tgbl-review/edgebank.py",
"chars": 6472,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/tgbl-review/tgn.py",
"chars": 11893,
"preview": "\"\"\"\nDynamic Link Prediction with a TGN model with Early Stopping\nReference: \n - https://github.com/pyg-team/pytorch_g"
},
{
"path": "examples/linkproppred/tgbl-subreddit/edgebank.py",
"chars": 6471,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/tgbl-subreddit/tgn.py",
"chars": 11833,
"preview": "\"\"\"\nDynamic Link Prediction with a TGN model with Early Stopping\nReference: \n - https://github.com/pyg-team/pytorch_g"
},
{
"path": "examples/linkproppred/tgbl-uci/edgebank.py",
"chars": 6465,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/tgbl-wiki/dyrep.py",
"chars": 12210,
"preview": "\"\"\"\nDyRep\n This has been implemented with intuitions from the following sources:\n - https://github.com/twitter-res"
},
{
"path": "examples/linkproppred/tgbl-wiki/edgebank.py",
"chars": 6466,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/tgbl-wiki/tgn.py",
"chars": 11884,
"preview": "\"\"\"\nDynamic Link Prediction with a TGN model with Early Stopping\nReference: \n - https://github.com/pyg-team/pytorch_g"
},
{
"path": "examples/linkproppred/thgl-forum/edgebank.py",
"chars": 7042,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/thgl-forum/recurrencybaseline.py",
"chars": 18002,
"preview": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christia"
},
{
"path": "examples/linkproppred/thgl-forum/sthn.py",
"chars": 16330,
"preview": "import timeit\nimport numpy as np\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\nfrom tgb.linkproppred.e"
},
{
"path": "examples/linkproppred/thgl-forum/tgn.py",
"chars": 13724,
"preview": "import numpy as np\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\nfrom tgb.linkproppred.evaluate impo"
},
{
"path": "examples/linkproppred/thgl-github/edgebank.py",
"chars": 7043,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/thgl-github/recurrencybaseline.py",
"chars": 18069,
"preview": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christia"
},
{
"path": "examples/linkproppred/thgl-github/run_seeds.sh",
"chars": 603,
"preview": "python -u tgn.py --seed 1 --mem_dim 16 --time_dim 16 --emb_dim 16 --num_epoch 5 --num_run 1 | tee tgn_s1_new_output.txt\n"
},
{
"path": "examples/linkproppred/thgl-github/sthn.py",
"chars": 16331,
"preview": "import timeit\nimport numpy as np\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\nfrom tgb.linkproppred.e"
},
{
"path": "examples/linkproppred/thgl-github/tgn.py",
"chars": 14104,
"preview": "\"\"\"\r\npython -u tgn.py --seed 1 --mem_dim 16 --time_dim 16 --emb_dim 16 --num_epoch 5 | tee tgn_s1_github_output.txt\r\n\"\"\""
},
{
"path": "examples/linkproppred/thgl-myket/edgebank.py",
"chars": 7042,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/thgl-myket/recurrencybaseline.py",
"chars": 18002,
"preview": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christia"
},
{
"path": "examples/linkproppred/thgl-myket/sthn.py",
"chars": 16330,
"preview": "import timeit\nimport numpy as np\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\nfrom tgb.linkproppred.e"
},
{
"path": "examples/linkproppred/thgl-myket/tgn.py",
"chars": 13742,
"preview": "import numpy as np\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\nfrom tgb.linkproppred.evaluate impo"
},
{
"path": "examples/linkproppred/thgl-software/STHN_README.md",
"chars": 497,
"preview": "STHN method adopted from: https://github.com/celi52/STHN/tree/main\n\nTo run:\n\n1. Install requirements. The two new additi"
},
{
"path": "examples/linkproppred/thgl-software/edgebank.py",
"chars": 7045,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/thgl-software/recurrencybaseline.py",
"chars": 18005,
"preview": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christia"
},
{
"path": "examples/linkproppred/thgl-software/sthn.py",
"chars": 16333,
"preview": "import timeit\nimport numpy as np\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\nfrom tgb.linkproppred.e"
},
{
"path": "examples/linkproppred/thgl-software/tgn.py",
"chars": 13728,
"preview": "import numpy as np\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\nfrom tgb.linkproppred.evaluate impo"
},
{
"path": "examples/linkproppred/tkgl-icews/cen.py",
"chars": 20625,
"preview": "\"\"\"\nComplex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning\nReference:\n- https://github.com/Lee-zix/"
},
{
"path": "examples/linkproppred/tkgl-icews/edgebank.py",
"chars": 7042,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/tkgl-icews/recurrencybaseline.py",
"chars": 18002,
"preview": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christia"
},
{
"path": "examples/linkproppred/tkgl-icews/regcn.py",
"chars": 14107,
"preview": "\"\"\"\nTemporal Knowledge Graph Reasoning Based on Evolutional Representation Learning\nReference:\n- https://github.com/Lee-"
},
{
"path": "examples/linkproppred/tkgl-icews/timetraveler.py",
"chars": 11921,
"preview": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HU"
},
{
"path": "examples/linkproppred/tkgl-icews/tkgl-icews_example.py",
"chars": 3170,
"preview": "import numpy as np\r\nimport timeit\r\nfrom tqdm import tqdm\r\nimport sys\r\nimport os\r\nimport os.path as osp\r\n\r\n#internal impo"
},
{
"path": "examples/linkproppred/tkgl-icews/tlogic.py",
"chars": 17604,
"preview": "\"\"\"\nhttps://github.com/liu-yushan/TLogic/tree/main/mycode\nTLogic: Temporal Logical Rules for Explainable Link Forecastin"
},
{
"path": "examples/linkproppred/tkgl-polecat/cen.py",
"chars": 20627,
"preview": "\"\"\"\nComplex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning\nReference:\n- https://github.com/Lee-zix/"
},
{
"path": "examples/linkproppred/tkgl-polecat/edgebank.py",
"chars": 7044,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/tkgl-polecat/example.py",
"chars": 534,
"preview": "from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\n\r\n\r\nDATA = \"tkgl-polecat\"\r\n\r\n# data loading\r\ndataset = "
},
{
"path": "examples/linkproppred/tkgl-polecat/recurrencybaseline.py",
"chars": 18004,
"preview": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christia"
},
{
"path": "examples/linkproppred/tkgl-polecat/regcn.py",
"chars": 14109,
"preview": "\"\"\"\nTemporal Knowledge Graph Reasoning Based on Evolutional Representation Learning\nReference:\n- https://github.com/Lee-"
},
{
"path": "examples/linkproppred/tkgl-polecat/timetraveler.py",
"chars": 11920,
"preview": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HU"
},
{
"path": "examples/linkproppred/tkgl-polecat/tkgl-polecat_example.py",
"chars": 3029,
"preview": "import sys\r\nsys.path.insert(0,'/../../../')\r\nimport numpy as np\r\nimport timeit\r\nfrom tqdm import tqdm\r\nfrom tgb.linkprop"
},
{
"path": "examples/linkproppred/tkgl-polecat/tlogic.py",
"chars": 17606,
"preview": "\"\"\"\nhttps://github.com/liu-yushan/TLogic/tree/main/mycode\nTLogic: Temporal Logical Rules for Explainable Link Forecastin"
},
{
"path": "examples/linkproppred/tkgl-smallpedia/cen.py",
"chars": 20630,
"preview": "\"\"\"\nComplex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning\nReference:\n- https://github.com/Lee-zix/"
},
{
"path": "examples/linkproppred/tkgl-smallpedia/edgebank.py",
"chars": 7047,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/tkgl-smallpedia/recurrencybaseline.py",
"chars": 18007,
"preview": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christia"
},
{
"path": "examples/linkproppred/tkgl-smallpedia/regcn.py",
"chars": 14112,
"preview": "\"\"\"\nTemporal Knowledge Graph Reasoning Based on Evolutional Representation Learning\nReference:\n- https://github.com/Lee-"
},
{
"path": "examples/linkproppred/tkgl-smallpedia/timetraveler.py",
"chars": 11926,
"preview": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HU"
},
{
"path": "examples/linkproppred/tkgl-smallpedia/tkgl-smallpedia_example.py",
"chars": 2901,
"preview": "import numpy as np\r\nimport timeit\r\nfrom tqdm import tqdm\r\nimport sys\r\nimport os.path as osp\r\nimport os\r\nfrom pathlib imp"
},
{
"path": "examples/linkproppred/tkgl-smallpedia/tlogic.py",
"chars": 17645,
"preview": "\"\"\"\nhttps://github.com/liu-yushan/TLogic/tree/main/mycode\nTLogic: Temporal Logical Rules for Explainable Link Forecastin"
},
{
"path": "examples/linkproppred/tkgl-wikidata/edgebank.py",
"chars": 7045,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/tkgl-wikidata/recurrencybaseline.py",
"chars": 18005,
"preview": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christia"
},
{
"path": "examples/linkproppred/tkgl-wikidata/regcn.py",
"chars": 14110,
"preview": "\"\"\"\nTemporal Knowledge Graph Reasoning Based on Evolutional Representation Learning\nReference:\n- https://github.com/Lee-"
},
{
"path": "examples/linkproppred/tkgl-wikidata/tkgl-wikidata_example.py",
"chars": 2668,
"preview": "import numpy as np\r\nimport timeit\r\nfrom tqdm import tqdm\r\nimport os.path as osp\r\nimport sys\r\nimport os\r\nmodules_path = o"
},
{
"path": "examples/linkproppred/tkgl-wikidata/tlogic.py",
"chars": 17607,
"preview": "\"\"\"\nhttps://github.com/liu-yushan/TLogic/tree/main/mycode\nTLogic: Temporal Logical Rules for Explainable Link Forecastin"
},
{
"path": "examples/linkproppred/tkgl-yago/cen.py",
"chars": 20624,
"preview": "\"\"\"\nComplex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning\nReference:\n- https://github.com/Lee-zix/"
},
{
"path": "examples/linkproppred/tkgl-yago/edgebank.py",
"chars": 7041,
"preview": "\"\"\"\nDynamic Link Prediction with EdgeBank\nNOTE: This implementation works only based on `numpy`\n\nReference: \n - https"
},
{
"path": "examples/linkproppred/tkgl-yago/recurrencybaseline.py",
"chars": 18001,
"preview": "\"\"\" from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christia"
},
{
"path": "examples/linkproppred/tkgl-yago/regcn.py",
"chars": 14106,
"preview": "\"\"\"\nTemporal Knowledge Graph Reasoning Based on Evolutional Representation Learning\nReference:\n- https://github.com/Lee-"
},
{
"path": "examples/linkproppred/tkgl-yago/timetraveler.py",
"chars": 11920,
"preview": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HU"
},
{
"path": "examples/linkproppred/tkgl-yago/tkgl-yago_example.py",
"chars": 1584,
"preview": "import numpy as np\r\nfrom tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset\r\nfrom tgb.linkproppred.evaluate impo"
},
{
"path": "examples/linkproppred/tkgl-yago/tlogic.py",
"chars": 17603,
"preview": "\"\"\"\nhttps://github.com/liu-yushan/TLogic/tree/main/mycode\nTLogic: Temporal Logical Rules for Explainable Link Forecastin"
},
{
"path": "examples/nodeproppred/tgbn-genre/dyrep.py",
"chars": 11347,
"preview": "\"\"\"\nDyRep\n This has been implemented with intuitions from the following sources:\n - https://github.com/twitter-res"
},
{
"path": "examples/nodeproppred/tgbn-genre/moving_average.py",
"chars": 3403,
"preview": "\"\"\"\nimplement persistant forecast as baseline for the node prop pred task\nsimply predict last seen label for the node\n\"\""
},
{
"path": "examples/nodeproppred/tgbn-genre/persistant_forecast.py",
"chars": 3568,
"preview": "\"\"\"\nimplement persistant forecast as baseline for the node prop pred task\nsimply predict last seen label for the node\n\"\""
},
{
"path": "examples/nodeproppred/tgbn-genre/tgn.py",
"chars": 11182,
"preview": "from tqdm import tqdm\r\nimport torch\r\nimport timeit\r\nimport argparse\r\nimport matplotlib.pyplot as plt\r\n\r\nfrom torch_geome"
},
{
"path": "examples/nodeproppred/tgbn-reddit/dyrep.py",
"chars": 11420,
"preview": "\"\"\"\nDyRep\n This has been implemented with intuitions from the following sources:\n - https://github.com/twitter-res"
},
{
"path": "examples/nodeproppred/tgbn-reddit/moving_average.py",
"chars": 3403,
"preview": "\"\"\"\nimplement persistant forecast as baseline for the node prop pred task\nsimply predict last seen label for the node\n\"\""
},
{
"path": "examples/nodeproppred/tgbn-reddit/persistant_forecast.py",
"chars": 3568,
"preview": "\"\"\"\nimplement persistant forecast as baseline for the node prop pred task\nsimply predict last seen label for the node\n\"\""
},
{
"path": "examples/nodeproppred/tgbn-reddit/tgn.py",
"chars": 11213,
"preview": "import timeit\r\nimport argparse\r\nfrom tqdm import tqdm\r\nimport torch\r\nimport matplotlib.pyplot as plt\r\n\r\nfrom torch_geome"
},
{
"path": "examples/nodeproppred/tgbn-token/dyrep.py",
"chars": 11419,
"preview": "\"\"\"\nDyRep\n This has been implemented with intuitions from the following sources:\n - https://github.com/twitter-res"
},
{
"path": "examples/nodeproppred/tgbn-token/moving_average.py",
"chars": 3430,
"preview": "\"\"\"\nimplement persistant forecast as baseline for the node prop pred task\nsimply predict last seen label for the node\n\"\""
},
{
"path": "examples/nodeproppred/tgbn-token/persistant_forecast.py",
"chars": 3567,
"preview": "\"\"\"\nimplement persistant forecast as baseline for the node prop pred task\nsimply predict last seen label for the node\n\"\""
},
{
"path": "examples/nodeproppred/tgbn-token/tgn.py",
"chars": 11230,
"preview": "import timeit\r\nimport argparse\r\nfrom tqdm import tqdm\r\nimport torch\r\nimport matplotlib.pyplot as plt\r\nimport numpy as np"
},
{
"path": "examples/nodeproppred/tgbn-trade/count_new_nodes.py",
"chars": 5162,
"preview": "\nimport timeit\nimport numpy as np\nfrom tqdm import tqdm\nimport math\nimport os\nimport os.path as osp\nfrom pathlib import "
},
{
"path": "examples/nodeproppred/tgbn-trade/dyrep.py",
"chars": 11341,
"preview": "\"\"\"\nDyRep\n This has been implemented with intuitions from the following sources:\n - https://github.com/twitter-res"
},
{
"path": "examples/nodeproppred/tgbn-trade/moving_average.py",
"chars": 3403,
"preview": "\"\"\"\nimplement persistant forecast as baseline for the node prop pred task\nsimply predict last seen label for the node\n\"\""
},
{
"path": "examples/nodeproppred/tgbn-trade/persistant_forecast.py",
"chars": 3510,
"preview": "\"\"\"\r\nimplement persistant forecast as baseline for the node prop pred task\r\nsimply predict last seen label for the node\r"
},
{
"path": "examples/nodeproppred/tgbn-trade/tgn.py",
"chars": 11157,
"preview": "import timeit\r\nimport argparse\r\nfrom tqdm import tqdm\r\nimport torch\r\nimport matplotlib.pyplot as plt\r\n\r\nfrom torch_geome"
},
{
"path": "mkdocs.yml",
"chars": 2212,
"preview": "site_name: Temporal Graph Benchmark\n\nnav:\n - Overview: index.md\n - About: about.md\n - API:\n - tgb.linkproppred: ap"
},
{
"path": "modules/decoder.py",
"chars": 4909,
"preview": "\"\"\"\nDecoder modules for dynamic link prediction\n\n\"\"\"\n\nimport torch\nfrom torch.nn import Linear\nimport torch.nn.functiona"
},
{
"path": "modules/early_stopping.py",
"chars": 3391,
"preview": "\"\"\"\nAn Early Stopping Module\n\"\"\"\nimport os\nfrom pathlib import Path\nimport torch\nimport torch.nn as nn\nimport numpy as n"
},
{
"path": "modules/edgebank_predictor.py",
"chars": 6342,
"preview": "\"\"\"\r\nEdgeBank is a simple strong baseline for dynamic link prediction\r\nit predicts the existence of edges based on their"
},
{
"path": "modules/emb_module.py",
"chars": 1668,
"preview": "\"\"\"\nGNN-based modules used in the architecture of MP-TG models\n\n\"\"\"\n\nimport math\nfrom torch_geometric.nn import Transfor"
},
{
"path": "modules/heuristics.py",
"chars": 1390,
"preview": "import numpy as np\r\n\r\n\r\nclass PersistantForecaster:\r\n def __init__(self, num_class):\r\n self.dict = {}\r\n "
},
{
"path": "modules/memory_module.py",
"chars": 17150,
"preview": "\"\"\"\nMemory Module\n\nReference:\n - https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/model"
},
{
"path": "modules/msg_agg.py",
"chars": 835,
"preview": "\"\"\"\nMessage Aggregator Module\n\nReference:\n - https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geomet"
},
{
"path": "modules/msg_func.py",
"chars": 546,
"preview": "\"\"\"\nMessage Function Module\n\nReference:\n - https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometri"
},
{
"path": "modules/neighbor_loader.py",
"chars": 3082,
"preview": "\"\"\"\nNeighbor Loader\n\nReference:\n - https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/mod"
},
{
"path": "modules/nodebank.py",
"chars": 1338,
"preview": "import numpy as np\r\n\r\n\r\nclass NodeBank(object):\r\n def __init__(\r\n self,\r\n src: np.ndarray,\r\n dst"
},
{
"path": "modules/recurrencybaseline_predictor.py",
"chars": 16328,
"preview": "\"\"\"\n from paper: \"History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting\" \nJulia Gastinger, Christi"
},
{
"path": "modules/rgcn_layers.py",
"chars": 12481,
"preview": "\"\"\"\nhttps://github.com/Lee-zix/CEN/blob/main/rgcn/layers.py\n\"\"\"\n\nimport dgl.function as fn\nimport torch\nimport torch.nn "
},
{
"path": "modules/rgcn_model.py",
"chars": 2248,
"preview": "\"\"\"\nhttps://github.com/nec-research/CEN/blob/main/src/model.py\n\"\"\"\n\nimport torch.nn as nn\n\n\nclass BaseRGCN(nn.Module):\n "
},
{
"path": "modules/rrgcn.py",
"chars": 20926,
"preview": "\"\"\"\nhttps://github.com/Lee-zix/CEN/blob/main/src/rrgcn.py\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functio"
},
{
"path": "modules/sampler_core.cpp",
"chars": 18922,
"preview": "#include <iostream>\n#include <string>\n#include <cstdlib>\n#include <random>\n#include <omp.h>\n#include <math.h>\n#include <"
},
{
"path": "modules/sthn.py",
"chars": 55022,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Optional\nimport numpy as np\nfrom t"
},
{
"path": "modules/sthn_sampler_setup.py",
"chars": 540,
"preview": "from glob import glob\nfrom setuptools import setup\nfrom pybind11.setup_helpers import Pybind11Extension\n\next_modules = ["
},
{
"path": "modules/time_enc.py",
"chars": 561,
"preview": "\"\"\"\nTime Encoding Module\n\nReference:\n - https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/n"
},
{
"path": "modules/timetraveler_agent.py",
"chars": 9291,
"preview": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HU"
},
{
"path": "modules/timetraveler_dirichlet.py",
"chars": 16944,
"preview": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HU"
},
{
"path": "modules/timetraveler_environment.py",
"chars": 7005,
"preview": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HU"
},
{
"path": "modules/timetraveler_episode.py",
"chars": 8762,
"preview": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HU"
},
{
"path": "modules/timetraveler_policygradient.py",
"chars": 2933,
"preview": "\"\"\"\nTimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting\nReference:\n- https://github.com/JHL-HU"
},
{
"path": "modules/timetraveler_trainertester.py",
"chars": 12935,
"preview": "import torch\nimport json\nimport os\nimport tqdm\nimport numpy as np\n\nclass Trainer(object):\n def __init__(self, model, "
},
{
"path": "modules/tkg_utils.py",
"chars": 24126,
"preview": "\nfrom itertools import groupby\nfrom operator import itemgetter\nfrom collections import defaultdict\nimport sys\nimport arg"
},
{
"path": "modules/tkg_utils_dgl.py",
"chars": 2633,
"preview": "\nimport dgl\nimport torch\nimport numpy as np\nfrom collections import defaultdict\n\ndef build_sub_graph(num_nodes, num_rels"
},
{
"path": "modules/tlogic_apply_modules.py",
"chars": 14908,
"preview": "\"\"\"\nhttps://github.com/liu-yushan/TLogic/blob/main/mycode/rule_application.py\nTLogic: Temporal Logical Rules for Explain"
},
{
"path": "modules/tlogic_learn_modules.py",
"chars": 17365,
"preview": "\"\"\"\nhttps://github.com/liu-yushan/TLogic/blob/main/mycode/temporal_walk.py\nAND\nhttps://github.com/liu-yushan/TLogic/blob"
},
{
"path": "pyproject.toml",
"chars": 785,
"preview": "[tool.poetry]\nname = \"py-tgb\"\nversion = \"2.2.0\"\ndescription = \"Temporal Graph Benchmark project repo\"\nauthors = [\"shenya"
},
{
"path": "run.sh",
"chars": 887,
"preview": "#!/bin/bash\n#SBATCH --partition=long #unkillable #main #long\n#SBATCH --output=tgnlog_genre_s5.txt #tgn_lastfmgenre_s5.t"
},
{
"path": "scripts/env.sh",
"chars": 57,
"preview": "module load python/3.9\nsource $HOME/tgbenv/bin/activate\n\n"
},
{
"path": "scripts/mila.sh",
"chars": 71,
"preview": "salloc --partition=unkillable --cpus-per-task=4 --gres=gpu:1 --mem=32G\n"
},
{
"path": "scripts/mila_install.sh",
"chars": 137,
"preview": "module load python/3.9\npython -m venv $HOME/tgbenv\nsource $HOME/tgbenv/bin/activate\npip3 install -r requirements.txt\npip"
},
{
"path": "scripts/run.sh",
"chars": 963,
"preview": "#!/bin/bash\n#SBATCH --partition=long #unkillable #main #long\n#SBATCH --output=dyrep_trade_s5.txt #tgn_lastfmgenre_s5.tx"
},
{
"path": "setup.py",
"chars": 112,
"preview": "from setuptools import setup, find_packages\r\n\r\nsetup(name=\"py-tgb\", version=\"2.2.0\", packages=find_packages())\r\n"
},
{
"path": "tgb/datasets/ICEWS14/ent2word.py",
"chars": 2237,
"preview": "# -*- coding: utf-8 -*-\n# @Time : 2019/12/5 4:20 下午\n# @Author : Lee_zix\n# @Email : Lee_zix@163.com\n# @File : en"
},
{
"path": "tgb/datasets/ICEWS14/icews14.py",
"chars": 1333,
"preview": "import csv\r\n\r\ndef load_index(input_path):\r\n index, rev_index = {}, {}\r\n with open(input_path) as f:\r\n for i"
},
{
"path": "tgb/datasets/dataset_scripts/MAG/mag.py",
"chars": 177,
"preview": "import pandas as pd\r\n\r\n\r\nif __name__ == \"__main__\":\r\n df = pd.read_parquet(\"nodes.parquet/nodes.parquet\", engine=\"pya"
},
{
"path": "tgb/datasets/dataset_scripts/MAG/old/plot_stats.py",
"chars": 820,
"preview": "import networkx as nx\r\nimport matplotlib.pyplot as plt\r\n\r\n\r\ndef load_csv(fname: str):\r\n \"\"\"\r\n plot the number of c"
},
{
"path": "tgb/datasets/dataset_scripts/dgraph.py",
"chars": 1987,
"preview": "import dateutil.parser as dparser\r\nimport csv\r\nimport matplotlib.pyplot as plt\r\nimport numpy as np\r\nfrom tqdm import tqd"
},
{
"path": "tgb/datasets/dataset_scripts/dgraph_Readme.md",
"chars": 832,
"preview": "# Description of DGraphFin datafile.\n\nFile **dgraphfin.npz** including below keys: \n\n- **x**: 17-dimensional node featu"
},
{
"path": "tgb/datasets/dataset_scripts/process_arxiv.py",
"chars": 539,
"preview": "import json\r\nimport networkx as nx \r\nimport numpy as np\r\nimport csv\r\nfrom datetime import date\r\n\r\n\r\ndef load_full_json(f"
},
{
"path": "tgb/datasets/dataset_scripts/process_github.py",
"chars": 6366,
"preview": "import json\nfrom datetime import datetime\n\nrels = {\n \"IC_Created_IC_I\": \"IC_AO_C_I\",\n \"IC_Created_U_IC\": \"U_SO_C_I"
},
{
"path": "tgb/datasets/dataset_scripts/tgbl-coin.py",
"chars": 5936,
"preview": "import csv\r\n\r\n\"\"\"\r\n#! analyze statistics from the dataset\r\n#* 1). # of unique nodes, 2). # of edges. 3). # of unique edg"
},
{
"path": "tgb/datasets/dataset_scripts/tgbl-coin_neg_generator.py",
"chars": 2496,
"preview": "import time\r\n\r\n\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_py"
},
{
"path": "tgb/datasets/dataset_scripts/tgbl-comment.py",
"chars": 6894,
"preview": "import csv\r\nfrom tqdm import tqdm\r\nfrom os import listdir\r\nfrom tgb.utils.stats import analyze_csv\r\n\r\n\r\ndef find_filenam"
},
{
"path": "tgb/datasets/dataset_scripts/tgbl-comment_neg_generator.py",
"chars": 2499,
"preview": "import time\r\n\r\n\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_py"
},
{
"path": "tgb/datasets/dataset_scripts/tgbl-flight.py",
"chars": 9628,
"preview": "import dateutil.parser as dparser\r\nimport csv\r\nimport matplotlib.pyplot as plt\r\nimport numpy as np\r\nfrom tqdm import tqd"
},
{
"path": "tgb/datasets/dataset_scripts/tgbl-flight_neg_generator.py",
"chars": 2493,
"preview": "import time\r\n\r\n\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_py"
},
{
"path": "tgb/datasets/dataset_scripts/tgbl-review.py",
"chars": 7378,
"preview": "import pyarrow.dataset as ds\r\nimport csv\r\nimport numpy as np\r\nfrom tgb.utils.stats import analyze_csv\r\nimport pandas as "
},
{
"path": "tgb/datasets/dataset_scripts/tgbl-review_neg_generator.py",
"chars": 2587,
"preview": "import time\r\n\r\n\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_py"
},
{
"path": "tgb/datasets/dataset_scripts/tgbl-wiki_neg_generator.py",
"chars": 2543,
"preview": "import timeit\r\n\r\n\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_"
},
{
"path": "tgb/datasets/dataset_scripts/tgbn-genre.py",
"chars": 19778,
"preview": "import networkx as nx\r\nimport numpy as np\r\nimport matplotlib.pyplot as plt\r\nimport csv\r\nfrom typing import Optional, Dic"
},
{
"path": "tgb/datasets/dataset_scripts/tgbn-reddit.py",
"chars": 11859,
"preview": "import csv\r\nfrom tqdm import tqdm\r\nfrom os import listdir\r\nfrom tgb.utils.stats import analyze_csv\r\n\r\n\r\ndef find_filenam"
},
{
"path": "tgb/datasets/dataset_scripts/tgbn-token.py",
"chars": 17005,
"preview": "import csv\r\nimport datetime\r\n\r\ndef count_node_freq(fname, filter_size=100):\r\n\r\n node_dict = {}\r\n with open(fname, "
},
{
"path": "tgb/datasets/dataset_scripts/tgbn-trade.py",
"chars": 5927,
"preview": "import dateutil.parser as dparser\r\nimport csv\r\nimport matplotlib.pyplot as plt\r\nimport numpy as np\r\nfrom tqdm import tqd"
},
{
"path": "tgb/datasets/tgbl_enron/tgbl-enron_neg_generator.py",
"chars": 2578,
"preview": "import timeit\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg "
},
{
"path": "tgb/datasets/tgbl_enron/tgbl_enron.py",
"chars": 252,
"preview": "import csv\n\n\nwith open('ml_enron.csv', 'r', newline='\\n') as infile, open('tgbl-enron_edgelist.csv', 'w', newline='\\n') "
},
{
"path": "tgb/datasets/tgbl_lastfm/tgbl-lastfm_neg_generator.py",
"chars": 2573,
"preview": "import timeit\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg "
},
{
"path": "tgb/datasets/tgbl_subreddit/tgbl-subreddit_neg_generator.py",
"chars": 2543,
"preview": "import timeit\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg "
},
{
"path": "tgb/datasets/tgbl_uci/tgbl-uci_neg_generator.py",
"chars": 2576,
"preview": "import timeit\r\nfrom tgb.linkproppred.negative_generator import NegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset_pyg "
},
{
"path": "tgb/datasets/tgbl_uci/tgbl_uci.py",
"chars": 248,
"preview": "import csv\n\n\nwith open('ml_uci.csv', 'r', newline='\\n') as infile, open('tgbl-uci_edgelist.csv', 'w', newline='\\n') as o"
},
{
"path": "tgb/datasets/thgl_forum/merge_files.py",
"chars": 8035,
"preview": "\r\nimport csv\r\nfrom tqdm import tqdm\r\nfrom os import listdir\r\nimport glob\r\n\r\ndef find_filenames(path_to_dir):\r\n r\"\"\"\r\n"
},
{
"path": "tgb/datasets/thgl_forum/thgl-forum.py",
"chars": 11967,
"preview": "\r\n\"\"\"\r\nData source: \r\n# https://surfdrive.surf.nl/files/index.php/s/M09RDerAMZrQy8q#editor\r\n\r\n# https://dl.acm.org/doi/a"
},
{
"path": "tgb/datasets/thgl_forum/thgl_forum_ns_gen.py",
"chars": 2355,
"preview": "import time\r\nfrom tgb.linkproppred.thg_negative_generator import THGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset"
},
{
"path": "tgb/datasets/thgl_github/2024_01/github_extract.py",
"chars": 9577,
"preview": "import json\r\nfrom datetime import datetime\r\nimport glob\r\nimport gzip\r\nimport csv\r\n\"\"\"\r\ngo to https://www.gharchive.org/\r"
},
{
"path": "tgb/datasets/thgl_github/2024_02/github_extract.py",
"chars": 9577,
"preview": "import json\r\nfrom datetime import datetime\r\nimport glob\r\nimport gzip\r\nimport csv\r\n\"\"\"\r\ngo to https://www.gharchive.org/\r"
},
{
"path": "tgb/datasets/thgl_github/2024_03/github_extract.py",
"chars": 9577,
"preview": "import json\r\nfrom datetime import datetime\r\nimport glob\r\nimport gzip\r\nimport csv\r\n\"\"\"\r\ngo to https://www.gharchive.org/\r"
},
{
"path": "tgb/datasets/thgl_github/extract_subset.py",
"chars": 7301,
"preview": "import csv\r\n\r\n\r\ndef load_edgelist(file_path, freq_threshold=5):\r\n \"\"\"\r\n ts, head, tail, relation_type\r\n 1704085"
},
{
"path": "tgb/datasets/thgl_github/thgl_github.py",
"chars": 6346,
"preview": "import csv\r\nimport datetime\r\nimport glob, os\r\n\r\ndef load_csv_raw(fname):\r\n \"\"\"\r\n load the raw csv file and merge t"
},
{
"path": "tgb/datasets/thgl_github/thgl_github_ns_gen.py",
"chars": 2356,
"preview": "import time\r\nfrom tgb.linkproppred.thg_negative_generator import THGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset"
},
{
"path": "tgb/datasets/thgl_myket/thgl_myket.py",
"chars": 5456,
"preview": "import dateutil.parser as dparser\r\nimport csv\r\nimport matplotlib.pyplot as plt\r\nimport numpy as np\r\nfrom tqdm import tqd"
},
{
"path": "tgb/datasets/thgl_myket/thgl_myket_ns_gen.py",
"chars": 2354,
"preview": "import time\r\nfrom tgb.linkproppred.thg_negative_generator import THGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset"
},
{
"path": "tgb/datasets/thgl_software/thgl_software.py",
"chars": 5550,
"preview": "import csv\r\nimport datetime\r\nimport glob, os\r\n\r\ndef load_csv_raw(fname):\r\n \"\"\"\r\n load the raw csv file and merge t"
},
{
"path": "tgb/datasets/thgl_software/thgl_software_ns_gen.py",
"chars": 2365,
"preview": "import time\r\nfrom tgb.linkproppred.thg_negative_generator import THGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset"
},
{
"path": "tgb/datasets/tkgl_icews/tkgl_icews.py",
"chars": 4220,
"preview": "import csv\r\nimport datetime\r\nimport glob, os\r\n\r\n\r\ndef load_csv_raw(fname):\r\n r\"\"\"\r\n load from the raw data and ret"
},
{
"path": "tgb/datasets/tkgl_icews/tkgl_icews_ns_gen.py",
"chars": 2309,
"preview": "import time\r\nfrom tgb.linkproppred.tkg_negative_generator import TKGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset"
},
{
"path": "tgb/datasets/tkgl_polecat/tkgl_polecat.py",
"chars": 8328,
"preview": "import csv\r\nimport datetime\r\nimport glob, os\r\n\r\n\r\n\r\n\r\ndef load_csv_raw(fname):\r\n r\"\"\"\r\n load from the raw data and"
},
{
"path": "tgb/datasets/tkgl_polecat/tkgl_polecat_ns_gen.py",
"chars": 2311,
"preview": "import time\r\nfrom tgb.linkproppred.tkg_negative_generator import TKGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset"
},
{
"path": "tgb/datasets/tkgl_smallpedia/smallpedia_remove_conflict.py",
"chars": 3495,
"preview": "import csv\r\n\r\ndef load_static_edgelist(file_path):\r\n r\"\"\"\r\n Load the static edgelist from the file_path\r\n Args:"
},
{
"path": "tgb/datasets/tkgl_smallpedia/tkgl_smallpedia_ns_gen.py",
"chars": 2368,
"preview": "import time\r\nfrom tgb.linkproppred.tkg_negative_generator import TKGNegativeEdgeGenerator\r\nfrom tgb.linkproppred.dataset"
},
{
"path": "tgb/datasets/tkgl_wikidata/extract.sh",
"chars": 210,
"preview": "for chunk in 5; do\nnum_chunk=25\nwhile [ $chunk -le $num_chunk ]; do\n cmd=\"tkgl_wikidata.py \\\n --chunk ${chunk} \\\n "
},
{
"path": "tgb/datasets/tkgl_wikidata/time_edges/tkgl-wikidata_extract.py",
"chars": 4741,
"preview": "import csv\r\nimport datetime\r\nimport glob, os\r\n\r\n\r\ndef load_time_csv_raw(fname):\r\n r\"\"\"\r\n load from the raw data an"
},
{
"path": "tgb/datasets/tkgl_wikidata/tkgl-wikidata.py",
"chars": 11833,
"preview": "import csv\r\nimport datetime\r\nimport glob, os\r\n\r\n\r\ndef load_time_csv(fname):\r\n r\"\"\"\r\n load from data and retrieve, "
},
{
"path": "tgb/datasets/tkgl_wikidata/tkgl_wikidata_mining.py",
"chars": 7073,
"preview": "r\"\"\"\r\nHow to use\r\npython tkgl_wikidata.py --chunk 0 --num_chunks 25\r\n# python tkgl_wikidata.py --chunk 1 --num_chunks 25"
},
{
"path": "tgb/datasets/tkgl_wikidata/tkgl_wikidata_ns_gen.py",
"chars": 2558,
"preview": "import time\r\nimport sys\r\nimport os\r\nimport os.path as osp\r\nfrom pathlib import Path\r\nmodules_path = osp.abspath(os.path."
},
{
"path": "tgb/datasets/tkgl_wikidata/wikidata_remove_conflict.py",
"chars": 3489,
"preview": "import csv\r\n\r\ndef load_static_edgelist(file_path):\r\n r\"\"\"\r\n Load the static edgelist from the file_path\r\n Args:"
},
{
"path": "tgb/datasets/tkgl_yago/tkgl_yago.py",
"chars": 2031,
"preview": "import csv\r\nimport datetime\r\nimport glob, os\r\n\r\n\r\n\r\ndef main():\r\n train_fname = \"train.txt\"\r\n val_fname = \"valid.t"
},
{
"path": "tgb/datasets/tkgl_yago/tkgl_yago_ns_gen.py",
"chars": 2353,
"preview": "import time\r\nimport sys\r\nsys.path.insert(0,'/../../../')\r\nfrom tgb.linkproppred.tkg_negative_generator import TKGNegativ"
},
{
"path": "tgb/linkproppred/dataset.py",
"chars": 25237,
"preview": "import sys\r\n\r\nfrom typing import Optional, Dict, Any, Tuple\r\nimport os\r\nimport os.path as osp\r\nimport numpy as np\r\nimpor"
},
{
"path": "tgb/linkproppred/dataset_pyg.py",
"chars": 11189,
"preview": "import torch\r\nfrom typing import Optional, Optional, Callable\r\n\r\nfrom torch_geometric.data import Dataset, TemporalData\r"
},
{
"path": "tgb/linkproppred/evaluate.py",
"chars": 5804,
"preview": "\"\"\"\r\nEvaluator Module for Dynamic Link Prediction\r\n\"\"\"\r\n\r\nimport numpy as np\r\nfrom sklearn.metrics import *\r\nfrom tgb.ut"
},
{
"path": "tgb/linkproppred/negative_generator.py",
"chars": 13889,
"preview": "\"\"\"\r\nSample and Generate negative edges that are going to be used for evaluation of a dynamic graph learning model\r\nNega"
}
]
// ... and 13 more files (download for full content)
About this extraction
This page contains the full source code of the shenyangHuang/TGB GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 213 files (1.7 MB), approximately 418.8k tokens, and a symbol index with 992 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.