Repository: shenyangHuang/TGB
Branch: main
Commit: 00d688b038b7
Files: 213
Total size: 1.7 MB
Directory structure:
gitextract_3nm00l7d/
├── .devcontainer/
│ ├── .gitignore
│ ├── Dockerfile
│ └── devcontainer.json
├── .github/
│ └── workflows/
│ ├── mkdocs.yaml
│ └── pypi.yaml
├── .gitignore
├── LICENSE
├── README.md
├── docs/
│ ├── about.md
│ ├── api/
│ │ ├── tgb.linkproppred.md
│ │ ├── tgb.nodeproppred.md
│ │ └── tgb.utils.md
│ ├── index.md
│ └── tutorials/
│ ├── Edge_data_numpy.ipynb
│ ├── Edge_data_pyg.ipynb
│ └── Node_label_tutorial.ipynb
├── examples/
│ ├── linkproppred/
│ │ ├── tgbl-coin/
│ │ │ ├── dyrep.py
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-comment/
│ │ │ ├── dyrep.py
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-enron/
│ │ │ └── edgebank.py
│ │ ├── tgbl-flight/
│ │ │ ├── dyrep.py
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-lastfm/
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-review/
│ │ │ ├── dyrep.py
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-subreddit/
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── tgbl-uci/
│ │ │ └── edgebank.py
│ │ ├── tgbl-wiki/
│ │ │ ├── dyrep.py
│ │ │ ├── edgebank.py
│ │ │ └── tgn.py
│ │ ├── thgl-forum/
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── sthn.py
│ │ │ └── tgn.py
│ │ ├── thgl-github/
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── run_seeds.sh
│ │ │ ├── sthn.py
│ │ │ └── tgn.py
│ │ ├── thgl-myket/
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── sthn.py
│ │ │ └── tgn.py
│ │ ├── thgl-software/
│ │ │ ├── STHN_README.md
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── sthn.py
│ │ │ └── tgn.py
│ │ ├── tkgl-icews/
│ │ │ ├── cen.py
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── regcn.py
│ │ │ ├── timetraveler.py
│ │ │ ├── tkgl-icews_example.py
│ │ │ └── tlogic.py
│ │ ├── tkgl-polecat/
│ │ │ ├── cen.py
│ │ │ ├── edgebank.py
│ │ │ ├── example.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── regcn.py
│ │ │ ├── timetraveler.py
│ │ │ ├── tkgl-polecat_example.py
│ │ │ └── tlogic.py
│ │ ├── tkgl-smallpedia/
│ │ │ ├── cen.py
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── regcn.py
│ │ │ ├── timetraveler.py
│ │ │ ├── tkgl-smallpedia_example.py
│ │ │ └── tlogic.py
│ │ ├── tkgl-wikidata/
│ │ │ ├── edgebank.py
│ │ │ ├── recurrencybaseline.py
│ │ │ ├── regcn.py
│ │ │ ├── tkgl-wikidata_example.py
│ │ │ └── tlogic.py
│ │ └── tkgl-yago/
│ │ ├── cen.py
│ │ ├── edgebank.py
│ │ ├── recurrencybaseline.py
│ │ ├── regcn.py
│ │ ├── timetraveler.py
│ │ ├── tkgl-yago_example.py
│ │ └── tlogic.py
│ └── nodeproppred/
│ ├── tgbn-genre/
│ │ ├── dyrep.py
│ │ ├── moving_average.py
│ │ ├── persistant_forecast.py
│ │ └── tgn.py
│ ├── tgbn-reddit/
│ │ ├── dyrep.py
│ │ ├── moving_average.py
│ │ ├── persistant_forecast.py
│ │ └── tgn.py
│ ├── tgbn-token/
│ │ ├── dyrep.py
│ │ ├── moving_average.py
│ │ ├── persistant_forecast.py
│ │ └── tgn.py
│ └── tgbn-trade/
│ ├── count_new_nodes.py
│ ├── dyrep.py
│ ├── moving_average.py
│ ├── persistant_forecast.py
│ └── tgn.py
├── mkdocs.yml
├── modules/
│ ├── decoder.py
│ ├── early_stopping.py
│ ├── edgebank_predictor.py
│ ├── emb_module.py
│ ├── heuristics.py
│ ├── memory_module.py
│ ├── msg_agg.py
│ ├── msg_func.py
│ ├── neighbor_loader.py
│ ├── nodebank.py
│ ├── recurrencybaseline_predictor.py
│ ├── rgcn_layers.py
│ ├── rgcn_model.py
│ ├── rrgcn.py
│ ├── sampler_core.cpp
│ ├── sthn.py
│ ├── sthn_sampler_setup.py
│ ├── time_enc.py
│ ├── timetraveler_agent.py
│ ├── timetraveler_dirichlet.py
│ ├── timetraveler_environment.py
│ ├── timetraveler_episode.py
│ ├── timetraveler_policygradient.py
│ ├── timetraveler_trainertester.py
│ ├── tkg_utils.py
│ ├── tkg_utils_dgl.py
│ ├── tlogic_apply_modules.py
│ └── tlogic_learn_modules.py
├── pyproject.toml
├── run.sh
├── scripts/
│ ├── env.sh
│ ├── mila.sh
│ ├── mila_install.sh
│ └── run.sh
├── setup.py
└── tgb/
├── datasets/
│ ├── ICEWS14/
│ │ ├── ent2word.py
│ │ └── icews14.py
│ ├── dataset_scripts/
│ │ ├── MAG/
│ │ │ ├── mag.py
│ │ │ └── old/
│ │ │ └── plot_stats.py
│ │ ├── dgraph.py
│ │ ├── dgraph_Readme.md
│ │ ├── process_arxiv.py
│ │ ├── process_github.py
│ │ ├── tgbl-coin.py
│ │ ├── tgbl-coin_neg_generator.py
│ │ ├── tgbl-comment.py
│ │ ├── tgbl-comment_neg_generator.py
│ │ ├── tgbl-flight.py
│ │ ├── tgbl-flight_neg_generator.py
│ │ ├── tgbl-review.py
│ │ ├── tgbl-review_neg_generator.py
│ │ ├── tgbl-wiki_neg_generator.py
│ │ ├── tgbn-genre.py
│ │ ├── tgbn-reddit.py
│ │ ├── tgbn-token.py
│ │ └── tgbn-trade.py
│ ├── tgbl_enron/
│ │ ├── tgbl-enron_neg_generator.py
│ │ └── tgbl_enron.py
│ ├── tgbl_lastfm/
│ │ └── tgbl-lastfm_neg_generator.py
│ ├── tgbl_subreddit/
│ │ └── tgbl-subreddit_neg_generator.py
│ ├── tgbl_uci/
│ │ ├── tgbl-uci_neg_generator.py
│ │ └── tgbl_uci.py
│ ├── thgl_forum/
│ │ ├── merge_files.py
│ │ ├── thgl-forum.py
│ │ └── thgl_forum_ns_gen.py
│ ├── thgl_github/
│ │ ├── 2024_01/
│ │ │ └── github_extract.py
│ │ ├── 2024_02/
│ │ │ └── github_extract.py
│ │ ├── 2024_03/
│ │ │ └── github_extract.py
│ │ ├── extract_subset.py
│ │ ├── thgl_github.py
│ │ └── thgl_github_ns_gen.py
│ ├── thgl_myket/
│ │ ├── thgl_myket.py
│ │ └── thgl_myket_ns_gen.py
│ ├── thgl_software/
│ │ ├── thgl_software.py
│ │ └── thgl_software_ns_gen.py
│ ├── tkgl_icews/
│ │ ├── tkgl_icews.py
│ │ └── tkgl_icews_ns_gen.py
│ ├── tkgl_polecat/
│ │ ├── tkgl_polecat.py
│ │ └── tkgl_polecat_ns_gen.py
│ ├── tkgl_smallpedia/
│ │ ├── smallpedia_remove_conflict.py
│ │ └── tkgl_smallpedia_ns_gen.py
│ ├── tkgl_wikidata/
│ │ ├── extract.sh
│ │ ├── time_edges/
│ │ │ └── tkgl-wikidata_extract.py
│ │ ├── tkgl-wikidata.py
│ │ ├── tkgl_wikidata_mining.py
│ │ ├── tkgl_wikidata_ns_gen.py
│ │ └── wikidata_remove_conflict.py
│ └── tkgl_yago/
│ ├── tkgl_yago.py
│ └── tkgl_yago_ns_gen.py
├── linkproppred/
│ ├── dataset.py
│ ├── dataset_pyg.py
│ ├── evaluate.py
│ ├── negative_generator.py
│ ├── negative_sampler.py
│ ├── thg_negative_generator.py
│ ├── thg_negative_sampler.py
│ ├── tkg_negative_generator.py
│ └── tkg_negative_sampler.py
├── nodeproppred/
│ ├── dataset.py
│ ├── dataset_pyg.py
│ └── evaluate.py
└── utils/
├── dataset_stats.py
├── info.py
├── pre_process.py
├── stats.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .devcontainer/.gitignore
================================================
!devcontainer.json
================================================
FILE: .devcontainer/Dockerfile
================================================
# Devcontainer image: Python 3.10 base with CUDA-enabled PyTorch/PyG preinstalled.
FROM mcr.microsoft.com/devcontainers/python:3.10
# Upgrade pip and install poetry, then install PyTorch (>=2.0.0, CUDA 11.8 wheels)
# and the PyG companion packages from their dedicated wheel indexes.
# NOTE: the three pip steps are chained in one RUN to keep a single image layer.
RUN python -m pip install --no-cache-dir --upgrade pip poetry \
&& pip install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cu118 'torch>=2.0.0' \
&& pip install --no-cache-dir --find-links https://data.pyg.org/whl/torch-2.0.0+cu118.html torch_geometric pyg_lib torch_scatter torch_sparse
# Install dependencies into the system interpreter (no per-project virtualenv).
ENV POETRY_VIRTUALENVS_CREATE=false
# Copy only the dependency manifests; poetry.lock* tolerates a missing lock file.
COPY pyproject.toml poetry.lock* /tmp/poetry/
# --no-root/--no-directory: install third-party dependencies only; the project
# source itself is mounted into the container and installed later (postCreateCommand).
RUN poetry -C /tmp/poetry --no-cache install --no-root --no-directory
================================================
FILE: .devcontainer/devcontainer.json
================================================
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/python
{
"name": "py-tgb",
"build": {
"dockerfile": "Dockerfile",
"context": ".."
},
"customizations": {
"vscode": {
"extensions": [
"editorconfig.editorconfig",
"github.vscode-pull-request-github",
"ms-azuretools.vscode-docker",
"ms-python.python",
"ms-python.vscode-pylance",
"ms-python.pylint",
"ms-python.isort",
"ms-python.flake8",
"ms-python.black-formatter",
"ms-vsliveshare.vsliveshare",
"ryanluker.vscode-coverage-gutters",
"bungcip.better-toml",
"GitHub.copilot",
"redhat.vscode-yaml"
],
"settings": {
"python.defaultInterpreterPath": "/usr/local/bin/python",
"black-formatter.path": [
"/usr/local/py-utils/bin/black"
],
"pylint.path": [
"/usr/local/py-utils/bin/pylint"
],
"flake8.path": [
"/usr/local/py-utils/bin/flake8"
],
"isort.path": [
"/usr/local/py-utils/bin/isort"
]
}
}
},
"features": {
"ghcr.io/devcontainers-contrib/features/act:1": {},
"ghcr.io/stuartleeks/dev-container-features/shell-history:0": {},
"ghcr.io/devcontainers/features/common-utils:2": {}
},
"postCreateCommand": "poetry --no-cache install --only main"
}
================================================
FILE: .github/workflows/mkdocs.yaml
================================================
# GitHub Actions workflow: build the MkDocs documentation site and publish it
# to GitHub Pages whenever a version tag (v*.*.*) is pushed.
name: mkdocs
on:
  push:
    # branches:
    #   - master
    #   - main
    tags:
      - "v*.*.*"
# The deploy job pushes the built site to the gh-pages branch, so it needs
# write access to repository contents.
permissions:
  contents: write
jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
        with:
          python-version: 3.x
      # Cache the .cache directory (mkdocs-material assets), keyed on the git
      # ref with a shared "mkdocs-material-" fallback prefix.
      - uses: actions/cache@v3
        with:
          key: mkdocs-material-${{ github.ref }}
          path: .cache
          restore-keys: |
            mkdocs-material-
      - run: pip install mkdocs-material mkdocstrings-python mkdocs-jupyter
      # --force overwrites the remote gh-pages branch with the freshly built site.
      - run: mkdocs gh-deploy --force
================================================
FILE: .github/workflows/pypi.yaml
================================================
# https://github.com/JRubics/poetry-publish
# GitHub Actions workflow: build the package with Poetry and publish it to
# PyPI whenever a version tag (v*.*.*) is pushed.
name: Publish to PyPI
on:
  push:
    tags:
      - "v*.*.*"
jobs:
  publish:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
      - name: Build and publish to pypi
        uses: JRubics/poetry-publish@v1.17
        with:
          # PyPI API token supplied via the repository's GitHub Actions secrets.
          pypi_token: ${{ secrets.PYPI_API_TOKEN }}
================================================
FILE: .gitignore
================================================
!requirements*.txt
get_croissant.py
#dataset
stats_figures/
figs/
*.xz
*.dict
*.tab
*.npz
*.xz
*.parquet
*.gz
*.tar
*.pdf
*.csv
*.zip
*.json
*.npy
*.pt
*.out
*.pkl
*.txt
*.attr
*.edge
.DS_Store
store_files/
# Byte-compiled / optimized / DLL files
__pycache__/
raw/
books/
electronics/
software/
*.py[cod]
*$py.class
saved_models/
dump/
saved_results/
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
__pycache__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
cc_env.sh
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 TGB Team
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================

**Temporal Graph Benchmark for Machine Learning on Temporal Graphs** (NeurIPS 2023 Datasets and Benchmarks Track)
**TGB 2.0: A Benchmark for Learning on Temporal Knowledge Graphs and Heterogeneous Graphs** (NeurIPS 2024 Datasets and Benchmarks Track)
Overview of the Temporal Graph Benchmark (TGB) pipeline:
- TGB includes large-scale and realistic datasets from 10 different domains with both dynamic link prediction and node property prediction tasks.
- TGB automatically downloads datasets and processes them into `numpy`, `PyTorch` and `PyG compatible TemporalData` formats.
- Novel TG models can be easily evaluated on TGB datasets via reproducible and realistic evaluation protocols.
- TGB provides public and online leaderboards to track recent developments in temporal graph learning domain.
- Now TGB supports temporal homogeneous graphs, temporal knowledge graphs and temporal heterogeneous graph datasets.

**To submit to [TGB leaderboard](https://tgb.complexdatalab.com/), please fill in this [google form](https://forms.gle/SEsXvN1QHo9tSFwx9)**
**See all version differences and update notes [here](https://tgb.complexdatalab.com/docs/update/)**
### Announcements
**Excited to announce that TGB 2.0 has been presented at the NeurIPS 2024 Datasets and Benchmarks Track**
See our [camera ready version](https://openreview.net/forum?id=EADRzNJFn1#discussion) and [arXiv version](https://arxiv.org/abs/2307.01026) for details. Please [install locally](https://tgb.complexdatalab.com/docs/home/) first. We welcome your feedback and suggestions.
**Excited to announce TGX, a companion package for analyzing temporal graphs in WSDM 2024 Demo Track**
TGX supports all TGB datasets and provides numerous temporal graph visualization plots and statistics out of the box. See our paper: [Temporal Graph Analysis with TGX](https://arxiv.org/abs/2402.03651) and [TGX website](https://complexdata-mila.github.io/TGX/).
**Please update to version `2.2.0`**
#### version `2.2.0`
Adding license for TGB software (for dataset license please check TGB website).
Printed messages are no longer sent to stdout by default; set `TGB_VERBOSE=True` in your shell to enable verbose printing.
The default is now to download datasets automatically (rather than asking for command-line confirmation as before).
#### version `2.1.0`
Includes supplementary datasets `tgbl-lastfm`, `tgbl-enron`, `tgbl-uci` and `tgbl-subreddit` for research purposes.
For more details, see the release notes
#### version `2.0.0`
Includes all new datasets from TGB 2.0 including temporal knowledge graphs and temporal heterogeneous graphs.
### Pip Install
You can install TGB via [pip](https://pypi.org/project/py-tgb/). **Requires python >= 3.9**
```
pip install py-tgb
```
### Links and Datasets
The project website can be found [here](https://tgb.complexdatalab.com/).
The API documentations can be found [here](https://shenyanghuang.github.io/TGB/).
All dataset download links can be found at [info.py](https://github.com/shenyangHuang/TGB/blob/main/tgb/utils/info.py)
TGB dataloader will also automatically download the dataset as well as the negative samples for the link property prediction datasets.
If the website is inaccessible, please use [this link](https://tgb-website.pages.dev/) instead.
### Running Example Methods
- For the dynamic link property prediction task, see the [`examples/linkproppred`](https://github.com/shenyangHuang/TGB/tree/main/examples/linkproppred) folder for example scripts to run TGN, DyRep and EdgeBank on TGB datasets.
- For the dynamic node property prediction task, see the [`examples/nodeproppred`](https://github.com/shenyangHuang/TGB/tree/main/examples/nodeproppred) folder for example scripts to run TGN, DyRep, moving average and persistent forecast baselines on TGB datasets.
- For all other baselines, please see the [TGB_Baselines](https://github.com/fpour/TGB_Baselines) repo.
### Acknowledgments
We thank the [OGB](https://ogb.stanford.edu/) team for their support throughout this project and sharing their website code for the construction of [TGB website](https://tgb.complexdatalab.com/).
### Software License
The code from this repo is licensed under the MIT License (see LICENSE)
### Citation
If code or data from this repo is useful for your project, please consider citing our TGB and TGB 2.0 paper:
```
@article{huang2023temporal,
title={Temporal graph benchmark for machine learning on temporal graphs},
author={Huang, Shenyang and Poursafaei, Farimah and Danovitch, Jacob and Fey, Matthias and Hu, Weihua and Rossi, Emanuele and Leskovec, Jure and Bronstein, Michael and Rabusseau, Guillaume and Rabbany, Reihaneh},
journal={Advances in Neural Information Processing Systems},
year={2023}
}
```
```
@article{huang2024tgb2,
title={TGB 2.0: A Benchmark for Learning on Temporal Knowledge Graphs and Heterogeneous Graphs},
author={Gastinger, Julia and Huang, Shenyang and Galkin, Mikhail and Loghmani, Erfan and Parviz, Ali and Poursafaei, Farimah and Danovitch, Jacob and Rossi, Emanuele and Koutis, Ioannis and Stuckenschmidt, Heiner and Rabbany, Reihaneh and Rabusseau, Guillaume},
journal={Advances in Neural Information Processing Systems},
year={2024}
}
```
================================================
FILE: docs/about.md
================================================
# Temporal Graph Benchmark (TGB)

## Overview
The TGB repo provides an automated ML pipeline for learning on a diverse set of temporal graph datasets:
- automatic download of datasets from url
- processing the raw files into ML ready format
- support datasets in `numpy`, `Pytorch` and `PyG TemporalData` formats
- evaluation code for each dataset
================================================
FILE: docs/api/tgb.linkproppred.md
================================================
# `tgb.linkproppred`
::: tgb.linkproppred.dataset
::: tgb.linkproppred.dataset_pyg
::: tgb.linkproppred.evaluate
::: tgb.linkproppred.negative_sampler
::: tgb.linkproppred.negative_generator
::: tgb.linkproppred.tkg_negative_generator
::: tgb.linkproppred.tkg_negative_sampler
::: tgb.linkproppred.thg_negative_generator
::: tgb.linkproppred.thg_negative_sampler
================================================
FILE: docs/api/tgb.nodeproppred.md
================================================
# `tgb.nodeproppred`
::: tgb.nodeproppred.dataset
::: tgb.nodeproppred.dataset_pyg
::: tgb.nodeproppred.evaluate
================================================
FILE: docs/api/tgb.utils.md
================================================
# `tgb.utils`
::: tgb.utils.pre_process
::: tgb.utils.utils
::: tgb.utils.info
::: tgb.utils.stats
================================================
FILE: docs/index.md
================================================
# Welcome to Temporal Graph Benchmark

### Pip Install
You can install TGB via [pip](https://pypi.org/project/py-tgb/)
```
pip install py-tgb
```
### Links and Datasets
The project website can be found [here](https://tgb.complexdatalab.com/).
The API documentations can be found [here](https://shenyanghuang.github.io/TGB/).
All dataset download links can be found at [info.py](https://github.com/shenyangHuang/TGB/blob/main/tgb/utils/info.py)
TGB dataloader will also automatically download the dataset as well as the negative samples for the link property prediction datasets.
### Install dependency
Our implementation works with python >= 3.9 and can be installed as follows
1. set up virtual environment (conda should work as well)
```
python -m venv ~/tgb_env/
source ~/tgb_env/bin/activate
```
2. install external packages
```
pip install pandas==1.5.3
pip install matplotlib==3.7.1
pip install clint==0.5.1
```
install Pytorch and PyG dependencies (needed to run the examples)
```
pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cu117
pip install torch_geometric==2.3.0
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
```
3. install local dependencies under root directory `/TGB`
```
pip install -e .
```
### Instruction for tracking new documentation and running mkdocs locally
1. first run the mkdocs server locally in your terminal
```
mkdocs serve
```
2. go to the local hosted web address similar to
```
[14:18:13] Browser connected: http://127.0.0.1:8000/
```
Example: to track documentation of a new hi.py file in tgb/edgeregression/hi.py
3. create docs/api/tgb.hi.md and add the following
```
# `tgb.edgeregression`
::: tgb.edgeregression.hi
```
4. edit mkdocs.yml
```
nav:
- Overview: index.md
- About: about.md
- API:
other *.md files
- tgb.edgeregression: api/tgb.hi.md
```
### Creating new branch ###
```
git fetch origin
git checkout -b test origin/test
```
### dependencies for mkdocs (documentation)
```
pip install mkdocs
pip install mkdocs-material
pip install mkdocstrings-python
pip install mkdocs-jupyter
pip install notebook
```
### full dependency list
Our implementation works with python >= 3.9 and has the following dependencies
```
pytorch == 2.0.0
torch-geometric == 2.3.0
torch-scatter==2.1.1
torch-sparse==0.6.17
torch-spline-conv==1.2.2
pandas==1.5.3
clint==0.5.1
```
================================================
FILE: docs/tutorials/Edge_data_numpy.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "d5e3f5a2",
"metadata": {},
"source": [
"# Access edge data as numpy arrays\n",
"\n",
"This tutorial will show you how to access various datasets and their corresponding edgelists in `tgb`\n",
"\n",
"You can directly retrieve the edge data as `numpy` arrays, `PyG` and `Pytorch` dependencies are not necessary\n",
"\n",
"The logic is implemented in `dataset.py` under `tgb/linkproppred/` and `tgb/nodeproppred/` folders respectively\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "23f00c08",
"metadata": {},
"outputs": [],
"source": [
"from tgb.linkproppred.dataset import LinkPropPredDataset"
]
},
{
"cell_type": "markdown",
"id": "60e52b7b",
"metadata": {},
"source": [
"specifying the name of the dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "48888070",
"metadata": {},
"outputs": [],
"source": [
"name = \"tgbl-wiki\" "
]
},
{
"cell_type": "markdown",
"id": "3511804a",
"metadata": {},
"source": [
"### process and loading the dataset\n",
"\n",
"if the dataset has been processed, it will be loaded from disc for fast access\n",
"\n",
"if the dataset has not been downloaded, it will be processed automatically"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8486fa82",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Will you download the dataset(s) now? (y/N)\n",
"y\n",
"\u001b[93mDownload started, this might take a while . . . \u001b[0m\n",
"Dataset title: tgbl-wiki\n",
"\u001b[92mDownload completed \u001b[0m\n",
"Dataset directory is /mnt/f/code/TGB/tgb/datasets/tgbl_wiki\n",
"file not processed, generating processed file\n"
]
},
{
"data": {
"text/plain": [
"tgb.linkproppred.dataset.LinkPropPredDataset"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset = LinkPropPredDataset(name=name, root=\"datasets\", preprocess=True)\n",
"type(dataset)"
]
},
{
"cell_type": "markdown",
"id": "47c949b4",
"metadata": {},
"source": [
"### Accessing the edge data\n",
"\n",
"the edge data can be easily accessed via the property of the method as `numpy` arrays "
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9e4e7421",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = dataset.full_data #a dictionary stores all the edge data\n",
"type(data) "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c6ec9ac0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"numpy.ndarray"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(data['sources'])\n",
"type(data['destinations'])\n",
"type(data['timestamps'])\n",
"type(data['edge_feat'])\n",
"type(data['w'])\n",
"type(data['edge_label']) #just all one array as all edges in the dataset are positive edges\n",
"type(data['edge_idxs']) #just index of the edges increment by 1 for each edge"
]
},
{
"cell_type": "markdown",
"id": "bb1bbfd6",
"metadata": {},
"source": [
"### Accessing the train, test, val split\n",
"\n",
"the masks for training, validation, and test split can be accessed directly from the `dataset` as well"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8cd3507c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"numpy.ndarray"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_mask = dataset.train_mask\n",
"val_mask = dataset.val_mask\n",
"test_mask = dataset.test_mask\n",
"\n",
"type(train_mask)\n",
"type(val_mask)\n",
"type(test_mask)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cf5eff06",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/tutorials/Edge_data_pyg.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "d5e3f5a2",
"metadata": {},
"source": [
"# Access edge data in Pytorch Geometric\n",
"\n",
"This tutorial will show you how to access various datasets and their corresponding edgelists in `tgb`\n",
"\n",
"The logic for PyG data is stored in `dataset_pyg.py` in `tgb/linkproppred` and `tgb/nodeproppred` folders\n",
"\n",
"This tutorial requires `Pytorch` and `PyG`, refer to `README.md` for installation instructions"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "23f00c08",
"metadata": {},
"outputs": [],
"source": [
"from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset"
]
},
{
"cell_type": "markdown",
"id": "60e52b7b",
"metadata": {},
"source": [
"specifying the name of the dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "48888070",
"metadata": {},
"outputs": [],
"source": [
"name = \"tgbl-wiki\""
]
},
{
"cell_type": "markdown",
"id": "3511804a",
"metadata": {},
"source": [
"### Process and load the dataset\n",
"\n",
"if the dataset has been processed, it will be loaded from disc for fast access\n",
"\n",
"if the dataset has not been downloaded, it will be processed automatically"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8486fa82",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"file found, skipping download\n",
"Dataset directory is /mnt/f/code/TGB/tgb/datasets/tgbl_wiki\n",
"loading processed file\n"
]
},
{
"data": {
"text/plain": [
"tgb.linkproppred.dataset_pyg.PyGLinkPropPredDataset"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset = PyGLinkPropPredDataset(name=name, root=\"datasets\")\n",
"type(dataset)"
]
},
{
"cell_type": "markdown",
"id": "47c949b4",
"metadata": {},
"source": [
"### Access edge data from TemporalData object \n",
"\n",
"You can retrieve `torch_geometric.data.temporal.TemporalData` directly from `PyGLinkPropPredDataset`"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9e4e7421",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch_geometric.data.temporal.TemporalData"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = dataset.get_TemporalData()\n",
"type(data)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c6ec9ac0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Tensor"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(data.src)\n",
"type(data.dst)\n",
"type(data.t)\n",
"type(data.msg)"
]
},
{
"cell_type": "markdown",
"id": "52fd601f",
"metadata": {},
"source": [
"### Directly access edge data as Pytorch tensors\n",
"\n",
"the edge data can be easily accessed via the property of the method, these are converted into pytorch tensors (from `PyGLinkPropPredDataset`)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "56fb3347",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Tensor"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(dataset.src) #same as src from above\n",
"type(dataset.dst) #same as dst\n",
"type(dataset.ts) #same as t\n",
"type(dataset.edge_feat) #same as msg\n",
"type(dataset.edge_label) #same as label used in tgn"
]
},
{
"cell_type": "markdown",
"id": "bb1bbfd6",
"metadata": {},
"source": [
"### Accessing the train, test, val split\n",
"\n",
"the masks for training, validation, and test split can be accessed directly from the `dataset` as well"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "8cd3507c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Tensor"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_mask = dataset.train_mask\n",
"val_mask = dataset.val_mask\n",
"test_mask = dataset.test_mask\n",
"\n",
"type(train_mask)\n",
"type(val_mask)\n",
"type(test_mask)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d6ed432",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/tutorials/Node_label_tutorial.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "d5e3f5a2",
"metadata": {},
"source": [
"# Access node labels for Dynamic Node Property Prediction\n",
"\n",
"This tutorial will show you how to access node labels and edge data for the node property prediction datasets in `tgb`.\n",
"\n",
"The source code is stored in `dataset_pyg.py` in `tgb/nodeproppred` folder\n",
"\n",
"This tutorial requires `Pytorch` and `PyG`, refer to `README.md` for installation instructions\n",
"\n",
"This tutorial uses `PyG TemporalData` object, however it is possible to use `numpy` arrays as well.\n",
"\n",
"see examples in `examples/nodeproppred` folder for more details.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "23f00c08",
"metadata": {},
"outputs": [],
"source": [
"from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset\n",
"from torch_geometric.loader import TemporalDataLoader"
]
},
{
"cell_type": "markdown",
"id": "60e52b7b",
"metadata": {},
"source": [
"specifying the name of the dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "48888070",
"metadata": {},
"outputs": [],
"source": [
"name = \"tgbn-genre\""
]
},
{
"cell_type": "markdown",
"id": "3511804a",
"metadata": {},
"source": [
"### Process and load the dataset\n",
"\n",
"if the dataset has been processed, it will be loaded from disk for fast access\n",
"\n",
"if the dataset has not been downloaded, it will be processed automatically"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8486fa82",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"file found, skipping download\n",
"Dataset directory is /mnt/f/code/TGB/tgb/datasets/tgbn_genre\n",
"loading processed file\n"
]
},
{
"data": {
"text/plain": [
"tgb.nodeproppred.dataset_pyg.PyGNodePropPredDataset"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset = PyGNodePropPredDataset(name=name, root=\"datasets\")\n",
"type(dataset)"
]
},
{
"cell_type": "markdown",
"id": "31338262",
"metadata": {},
"source": [
"### Train, Validation and Test splits with dataloaders\n",
"\n",
"splitting the edges into train, val, test sets and construct dataloader for each"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "27b4f6a1",
"metadata": {},
"outputs": [],
"source": [
"train_mask = dataset.train_mask\n",
"val_mask = dataset.val_mask\n",
"test_mask = dataset.test_mask\n",
"\n",
"\n",
"data = dataset.get_TemporalData()\n",
"\n",
"train_data = data[train_mask]\n",
"val_data = data[val_mask]\n",
"test_data = data[test_mask]\n",
"\n",
"batch_size = 200\n",
"train_loader = TemporalDataLoader(train_data, batch_size=batch_size)\n",
"val_loader = TemporalDataLoader(val_data, batch_size=batch_size)\n",
"test_loader = TemporalDataLoader(test_data, batch_size=batch_size)\n"
]
},
{
"cell_type": "markdown",
"id": "47c949b4",
"metadata": {},
"source": [
"### Access node label data \n",
"\n",
"In `tgb`, the node label data are queried based on the nearest edge observed so far and retrieves the node label data for the corresponding day. \n",
"\n",
"Note that this is because the node labels often have different timestamps from the edges thus should be processed at the correct time in the edge stream.\n",
"\n",
"In the example below, we show how to iterate through the edges and retrieve the node labels of the corresponding time. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9e4e7421",
"metadata": {},
"outputs": [],
"source": [
"#query the timestamps for the first node labels\n",
"label_t = dataset.get_label_time()\n",
"\n",
"for batch in train_loader:\n",
" #access the edges in this batch\n",
" src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg\n",
" query_t = batch.t[-1]\n",
" # check if this batch moves to the next day\n",
" if query_t > label_t:\n",
" # find the node labels from the past day\n",
" label_tuple = dataset.get_node_label(query_t)\n",
" # node labels are structured as a tuple with (timestamps, source node, label) format, label is a vector\n",
" label_ts, label_srcs, labels = (\n",
" label_tuple[0],\n",
" label_tuple[1],\n",
" label_tuple[2],\n",
" )\n",
" label_t = dataset.get_label_time()\n",
"\n",
" #insert your code for backproping with node labels here\n",
" "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: examples/linkproppred/tgbl-coin/dyrep.py
================================================
"""
DyRep
This has been implemented with intuitions from the following sources:
- https://github.com/twitter-research/tgn
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
Spec.:
- Memory Updater: RNN
- Embedding Module: ID
- Message Function: ATTN
command for an example run:
python examples/linkproppred/tgbl-coin/dyrep.py --data "tgbl-coin" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from torch_geometric.loader import TemporalDataLoader
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import DyRepMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Training procedure for DyRep model.
    This function uses some objects that are globally defined in the current script:
    `model`, `optimizer`, `criterion`, `train_loader`, `neighbor_loader`, `assoc`,
    `data`, `device`, `min_dst_idx`, `max_dst_idx`, and `train_data`.
    Parameters:
        None
    Returns:
        the average training loss per event over the epoch
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()
    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        # Sample negative destination nodes, uniformly from the destination id range.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )
        # Gather the temporal neighborhood of every node touched by this batch.
        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # Map global node ids to local positions in `z` / `n_id`.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)
        # DyRep scores links directly from memory states (identity embedding).
        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])
        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))
        # update the memory with ground-truth edges; the attention embedding
        # `z` is passed into the memory update (DyRep-style message function)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(src, pos_dst, t, msg, z, assoc)
        # update neighbor loader with the observed positive edges
        neighbor_loader.insert(src, pos_dst)
        loss.backward()
        optimizer.step()
        # Detach memory so gradients do not flow across batches.
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events
    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluated the dynamic link prediction.
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges.
    Uses globally defined objects: `model`, `neighbor_loader`, `assoc`, `data`,
    `device`, `evaluator`, and `metric`.
    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: the result of the performance evaluation
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()
    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )
        # One list of sampled negatives per positive edge in this batch.
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # Build one query: the positive destination first, then all negatives,
            # each paired with the same source node.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )
            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id] = torch.arange(n_id.size(0), device=device)
            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)
            # DyRep ranks candidates directly from memory states.
            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])
            # compute MRR: index 0 is the positive edge, the rest are negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update the memory with positive edges only, after scoring the batch
        n_id = torch.cat([pos_src, pos_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg, z, assoc)
        # update the neighbor loader
        neighbor_loader.insert(pos_src, pos_dst)
    perf_metric = float(torch.tensor(perf_list).mean())
    return perf_metric
# ==========
# ==========
# ==========
# Start...
start_overall = timeit.default_timer()

# ========== set parameters...
args, _ = get_args()
print("INFO: Arguments:", args)

DATA = "tgbl-coin"
LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10
MODEL_NAME = 'DyRep'
# DyRep message function: use destination (not source) embedding in messages.
USE_SRC_EMB_IN_MSG = False
USE_DST_EMB_IN_MSG = True
# ==========

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

# chronological splits via the dataset-provided boolean masks
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

# neighborhood sampler: keeps the most recent NUM_NEIGHBORS per node
neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

# define the model end-to-end
# 1) memory
memory = DyRepMemory(
    data.num_nodes,
    data.msg.size(-1),
    MEM_DIM,
    TIME_DIM,
    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
    aggregator_module=LastAggregator(),
    memory_updater_type='rnn',
    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,
    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG
).to(device)

# 2) GNN
gnn = GraphAttentionEmbedding(
    in_channels=MEM_DIM,
    out_channels=EMB_DIM,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

# 3) link predictor
link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

model = {'memory': memory,
         'gnn': gnn,
         'link_pred': link_pred}

# define an optimizer over the union of all three modules' parameters
optimizer = torch.optim.Adam(
    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
    lr=LR,
)
criterion = torch.nn.BCEWithLogitsLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
# Path.mkdir(parents=True, exist_ok=True) also creates missing parent
# directories; the previous `os.mkdir` call raised FileNotFoundError when
# the parent was absent, defeating the Path.mkdir fallback right after it.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
# Repeat the full train/val/test cycle NUM_RUNS times with different seeds.
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # define an early stopper; checkpoints the best model by validation metric
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )

        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation Total Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")

    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    # persist per-run metrics (val curve + final test) as JSON
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-coin/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluated the dynamic link prediction.
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges.
    Uses globally defined objects: `BATCH_SIZE`, `edgebank`, `evaluator`, `metric`.
    Parameters:
        data: dict of numpy arrays with keys 'sources', 'destinations', 'timestamps'
        test_mask: boolean mask selecting the evaluation-split edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: the result of the performance evaluation
    """
    # Apply the boolean mask once up front: each fancy-index copies the whole
    # array, so re-masking inside the loop was O(total_edges) per batch.
    pos_src_all = data['sources'][test_mask]
    pos_dst_all = data['destinations'][test_mask]
    pos_t_all = data['timestamps'][test_mask]

    num_batches = math.ceil(len(pos_src_all) / BATCH_SIZE)
    perf_list = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, len(pos_src_all))
        pos_src = pos_src_all[start_idx: end_idx]
        pos_dst = pos_dst_all[start_idx: end_idx]
        pos_t = pos_t_all[start_idx: end_idx]

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # One query row per positive: positive destination first, then negatives,
            # each paired with the same source node.
            query_src = np.full(len(neg_batch) + 1, int(pos_src[idx]))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])

            y_pred = edgebank.predict_link(query_src, query_dst)

            # compute MRR: index 0 is the positive edge, the rest are negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    perf_metrics = float(np.mean(perf_list))
    return perf_metrics
def get_args():
    """Parse command-line arguments for the EdgeBank baseline.

    Returns:
        args: the parsed ``argparse.Namespace``
        sys.argv: the raw command line, kept for logging / reproducibility
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-coin')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse raises SystemExit on invalid input; print the full help and
        # exit cleanly. (The original bare `except:` also swallowed
        # KeyboardInterrupt and genuine programming errors.)
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = "tgbl-coin"
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank: EdgeBank is seeded with the training edges only
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# #! check if edges are sorted
# sorted = np.all(np.diff(data['timestamps']) >= 0)
# print (" INFO: Edges are sorted: ", sorted)

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)
print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
# Path.mkdir(parents=True, exist_ok=True) also creates missing parent
# directories; the previous `os.mkdir` call raised FileNotFoundError when
# the parent was absent, defeating the Path.mkdir fallback right after it.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'
# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# testing ...
# NOTE(review): this section reuses `perf_metric_test`/`test_time` for the
# validation numbers; values are printed before being overwritten below.
start_val = timeit.default_timer()
perf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()

print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {test_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

# testing ...
start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()

print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

# persist final test metric as JSON (EdgeBank has no training phase)
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-coin/tgn.py
================================================
"""
Dynamic Link Prediction with a TGN model with Early Stopping
Reference:
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
command for an example run:
python examples/linkproppred/tgbl-coin/tgn.py --data "tgbl-coin" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Training procedure for TGN model.
    This function uses some objects that are globally defined in the current script:
    `model`, `optimizer`, `criterion`, `train_loader`, `neighbor_loader`, `assoc`,
    `data`, `device`, `min_dst_idx`, `max_dst_idx`, and `train_data`.
    Parameters:
        None
    Returns:
        the average training loss per event over the epoch
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()
    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        # Sample negative destination nodes, uniformly from the destination id range.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )
        # Gather the temporal neighborhood of every node touched by this batch.
        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # Map global node ids to local positions in `z` / `n_id`.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        # Get updated memory of all nodes involved in the computation,
        # then refine it with graph attention over recent neighbors.
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])
        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))
        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)
        loss.backward()
        optimizer.step()
        # Detach memory so gradients do not flow across batches.
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events
    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluated the dynamic link prediction.
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges.
    Uses globally defined objects: `model`, `neighbor_loader`, `assoc`, `data`,
    `device`, `evaluator`, and `metric`.
    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: the result of the performance evaluation
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()
    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )
        # One list of sampled negatives per positive edge in this batch.
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # Build one query: the positive destination first, then all negatives,
            # each paired with the same source node.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )
            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id] = torch.arange(n_id.size(0), device=device)
            # Get updated memory of all nodes involved in the computation,
            # refined with graph attention over recent neighbors.
            z, last_update = model['memory'](n_id)
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])
            # compute MRR: index 0 is the positive edge, the rest are negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)
    perf_metrics = float(torch.tensor(perf_list).mean())
    return perf_metrics
# ==========
# ==========
# ==========
# Start...
start_overall = timeit.default_timer()

DATA = "tgbl-coin"

# ========== set parameters...
args, _ = get_args()
args.data = DATA
print("INFO: Arguments:", args)

LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10
MODEL_NAME = 'TGN'
# ==========

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

# chronological splits via the dataset-provided boolean masks
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
# Path.mkdir(parents=True, exist_ok=True) also creates missing parent
# directories; the previous `os.mkdir` call raised FileNotFoundError when
# the parent was absent, defeating the Path.mkdir fallback right after it.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
# Repeat the full train/val/test cycle NUM_RUNS times with different seeds.
# Note: model, optimizer, and neighbor loader are re-created per run.
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # neighborhood sampler: keeps the most recent NUM_NEIGHBORS per node
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

    # define the model end-to-end
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)

    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}

    # one optimizer over the union of all three modules' parameters
    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    # define an early stopper; checkpoints the best model by validation metric
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )

        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")

    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    # persist per-run metrics (val curve + final test) as JSON
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-comment/dyrep.py
================================================
"""
DyRep
This has been implemented with intuitions from the following sources:
- https://github.com/twitter-research/tgn
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
Spec.:
- Memory Updater: RNN
- Embedding Module: ID
- Message Function: ATTN
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import DyRepMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Training procedure for the DyRep model (one epoch over `train_loader`).
    This function uses objects that are globally defined in the current script
    (model, optimizer, criterion, train_loader, neighbor_loader, assoc, data,
    device, min_dst_idx, max_dst_idx, train_data).
    Parameters:
        None
    Returns:
        float: average training loss per event over the epoch
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()
    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        # Sample negative destination nodes (uniform over the observed
        # destination-id range).
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )
        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # assoc maps global node ids to their local position within n_id.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)
        # NOTE: DyRep scores links directly on the memory output (identity
        # embedding module); the GNN below is only used to build the inputs
        # for the subsequent memory update.
        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])
        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))
        # update the memory with ground-truth
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(src, pos_dst, t, msg, z, assoc)
        # update neighbor loader
        neighbor_loader.insert(src, pos_dst)
        loss.backward()
        optimizer.step()
        model['memory'].detach()  # stop gradients from flowing across batches
        total_loss += float(loss.detach()) * batch.num_events
    return total_loss / train_data.num_events
@torch.no_grad()
def test_one_vs_many(loader, neg_sampler, split_mode):
    """
    Evaluate dynamic link prediction in the one-vs-many setting: each
    positive edge is ranked against its pre-sampled negative edges.

    Parameters:
        loader: dataloader over the positive edges of the evaluation split
        neg_sampler: provides the negative destinations for each positive edge
        split_mode: 'val' or 'test'; selects which negatives to load

    Returns:
        float: mean of the evaluation metric over all queries
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()
    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # One query: the positive destination first, then all negatives,
            # all sharing the same source node.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )
            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            # assoc maps global node ids to local positions within n_id.
            assoc[n_id] = torch.arange(n_id.size(0), device=device)
            # Get updated memory of all nodes involved in the computation.
            # NOTE: DyRep scores on the raw memory output (identity embedding);
            # the GNN is only applied in the memory-update step below.
            z, last_update = model['memory'](n_id)
            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])
            # compute MRR: row 0 is the positive edge, the rest are negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])
        # update the memory with positive edges (only) of this batch
        n_id = torch.cat([pos_src, pos_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg, z, assoc)
        # update the neighbor loader
        neighbor_loader.insert(pos_src, pos_dst)
    perf_metric = float(torch.tensor(perf_list).mean())
    return perf_metric
# ==========
# ==========
# ==========

# Start...
start_overall = timeit.default_timer()

# ========== set parameters...
args, _ = get_args()
print("INFO: Arguments:", args)

DATA = "tgbl-comment"  # fixed dataset name; any --data CLI value is ignored
LR = args.lr  # learning rate
BATCH_SIZE = args.bs
K_VALUE = args.k_value  # NOTE(review): parsed but unused in this script
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim  # node-memory dimension
TIME_DIM = args.time_dim  # time-encoding dimension
EMB_DIM = args.emb_dim  # GNN output embedding dimension
TOLERANCE = args.tolerance  # early-stopping improvement tolerance
PATIENCE = args.patience  # early-stopping patience (epochs)
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10  # temporal neighbors sampled per node
MODEL_NAME = 'DyRep'
# DyRep message spec: use the destination's embedding (not the source's)
# when building memory-update messages.
USE_SRC_EMB_IN_MSG = False
USE_DST_EMB_IN_MSG = True
# ==========

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric  # dataset-prescribed metric (e.g. MRR)

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

# neighborhood sampler
neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

# define the model end-to-end
# NOTE(review): unlike the TGN script for this dataset, the model, optimizer
# and neighbor loader are built ONCE here and NOT re-initialized inside the
# per-run loop below — runs after the first continue from trained weights.
# Confirm this is intended.
memory = DyRepMemory(
    data.num_nodes,
    data.msg.size(-1),
    MEM_DIM,
    TIME_DIM,
    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
    aggregator_module=LastAggregator(),
    memory_updater_type='rnn',  # DyRep spec: RNN memory updater
    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,
    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG
).to(device)

gnn = GraphAttentionEmbedding(
    in_channels=MEM_DIM,
    out_channels=EMB_DIM,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

model = {'memory': memory,
         'gnn': gnn,
         'link_pred': link_pred}

optimizer = torch.optim.Adam(
    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
    lr=LR,
)
criterion = torch.nn.BCEWithLogitsLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
# NOTE(review): redundant with the os.mkdir above; one of the two suffices.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'

for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )
        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test_one_vs_many(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping (saves the best checkpoint internally)
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test_one_vs_many(test_loader, neg_sampler, split_mode="test")

    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-comment/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate EdgeBank for dynamic link prediction.

    Evaluation is 'one vs. many': every positive edge is ranked against its
    pre-sampled negative edges, and EdgeBank's memory is updated with the
    positive edges once each batch has been scored.

    Parameters:
        data: dataset dictionary with 'sources', 'destinations', 'timestamps'
        test_mask: boolean mask selecting the evaluation-split edges
        neg_sampler: provides the negative destinations per positive edge
        split_mode: 'val' or 'test'; selects which negatives to load
    Returns:
        perf_metrics: mean of the evaluation metric over all queries
    """
    # Materialize the evaluation split once instead of re-masking per batch.
    split_src = data['sources'][test_mask]
    split_dst = data['destinations'][test_mask]
    split_ts = data['timestamps'][test_mask]
    n_edges = len(split_src)

    perf_list = []
    for batch_idx in tqdm(range(math.ceil(n_edges / BATCH_SIZE))):
        lo = batch_idx * BATCH_SIZE
        hi = min(lo + BATCH_SIZE, n_edges)
        pos_src = split_src[lo:hi]
        pos_dst = split_dst[lo:hi]
        pos_t = split_ts[lo:hi]

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # One query: the positive destination first, then all negatives,
            # all paired with the same source node.
            query_src = np.full(len(neg_batch) + 1, int(pos_src[idx]))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])
            y_pred = edgebank.predict_link(query_src, query_dst)

            # compute MRR: entry 0 is the positive edge, the rest negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    return float(np.mean(perf_list))
def get_args():
    """Parse command-line arguments for the EdgeBank baseline.

    Returns:
        args: the parsed argument namespace
        sys.argv: the raw command-line argument list
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-comment')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit as err:
        # argparse exits on its own for `--help` (code 0) and on parse errors
        # (code 2). Previously a bare `except:` swallowed the exit status and
        # always exited 0, masking failures from shell scripts and printing
        # the help text twice on `--help`. Print the full usage only on
        # errors and preserve argparse's exit code.
        if err.code:
            parser.print_help()
        sys.exit(err.code)
    return args, sys.argv
# ==================
# ==================
# ==================

start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value  # NOTE(review): parsed but unused in this script
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = "tgbl-comment"  # fixed dataset name; any --data CLI value is ignored
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric  # dataset-prescribed metric (e.g. MRR)

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for EdgeBank's memory: initialized with the training edges only
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
# NOTE(review): redundant with the os.mkdir above; one of the two suffices.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# validating ...
start_val = timeit.default_timer()
# NOTE(review): `perf_metric_test` here holds the VALIDATION result; it is
# overwritten by the test result below and only the latter is saved.
perf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()  # NOTE(review): unused; val time is recomputed below

print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_val  # NOTE(review): holds val time here
print(f"\tval: Elapsed Time (s): {test_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

# testing ...
start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()  # NOTE(review): unused; test time is recomputed below

print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-comment/tgn.py
================================================
"""
Dynamic Link Prediction with a TGN model with Early Stopping
Reference:
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
command for an example run:
python examples/linkproppred/tgbl-comment/tgn.py --data "tgbl-comment" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Run one training epoch of the TGN model over ``train_loader``.

    Relies on objects defined at module level in this script (model,
    optimizer, criterion, train_loader, neighbor_loader, assoc, data,
    device, min_dst_idx, max_dst_idx, train_data).

    Parameters:
        None
    Returns:
        None
    """
    mem_mod, gnn_mod, decoder = model['memory'], model['gnn'], model['link_pred']
    mem_mod.train()
    gnn_mod.train()
    decoder.train()

    mem_mod.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Draw one random negative destination per positive event,
        # restricted to the range of actual destination node ids.
        neg_dst = torch.randint(
            min_dst_idx, max_dst_idx + 1, (src.size(0),),
            dtype=torch.long, device=device,
        )

        nodes = torch.cat([src, dst, neg_dst]).unique()
        nodes, edge_index, e_id = neighbor_loader(nodes)
        # assoc maps global node ids to local positions within `nodes`.
        assoc[nodes] = torch.arange(nodes.size(0), device=device)

        # Memory state + attention-based embeddings for the involved nodes.
        h, last_update = mem_mod(nodes)
        h = gnn_mod(
            h, last_update, edge_index,
            data.t[e_id].to(device), data.msg[e_id].to(device),
        )

        pos_out = decoder(h[assoc[src]], h[assoc[dst]])
        neg_out = decoder(h[assoc[src]], h[assoc[neg_dst]])

        loss = (criterion(pos_out, torch.ones_like(pos_out))
                + criterion(neg_out, torch.zeros_like(neg_out)))

        # Update memory and neighbor loader with ground-truth state.
        mem_mod.update_state(src, dst, t, msg)
        neighbor_loader.insert(src, dst)

        loss.backward()
        optimizer.step()
        mem_mod.detach()  # stop gradients from flowing across batches
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate the dynamic link prediction performance.
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges
    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metrics: the result of the performance evaluation (mean over all queries)
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()
    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # One query: the positive destination first, then all negatives,
            # all sharing the same source node.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )
            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            # assoc maps global node ids to local positions within n_id.
            assoc[n_id] = torch.arange(n_id.size(0), device=device)
            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])
            # compute MRR: row 0 is the positive edge, the rest are negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])
        # Update memory and neighbor loader with ground-truth state
        # (positive edges only) after the whole batch has been scored.
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)
    perf_metrics = float(torch.tensor(perf_list).mean())
    return perf_metrics
# ==========
# ==========
# ==========

# Start...
start_overall = timeit.default_timer()

DATA = "tgbl-comment"  # fixed dataset name for this script

# ========== set parameters...
args, _ = get_args()
args.data = DATA  # override any CLI value so the script stays tied to this dataset
print("INFO: Arguments:", args)

LR = args.lr  # learning rate
BATCH_SIZE = args.bs
K_VALUE = args.k_value  # NOTE(review): parsed but unused in this script
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim  # node-memory dimension
TIME_DIM = args.time_dim  # time-encoding dimension
EMB_DIM = args.emb_dim  # GNN output embedding dimension
TOLERANCE = args.tolerance  # early-stopping improvement tolerance
PATIENCE = args.patience  # early-stopping patience (epochs)
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10  # temporal neighbors sampled per node
MODEL_NAME = 'TGN'
# ==========

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric  # dataset-prescribed metric (e.g. MRR)

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
# NOTE(review): redundant with the os.mkdir above; one of the two suffices.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'

for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # neighborhood sampler (re-initialized for every run)
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

    # define the model end-to-end (fresh parameters for every run)
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)

    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}

    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )
        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping (saves the best checkpoint internally)
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")

    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-enron/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate EdgeBank for dynamic link prediction.
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges
    Parameters:
        data: a dataset object (dictionary with 'sources', 'destinations', 'timestamps')
        test_mask: required masks to load the test set edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metrics: the result of the performance evaluation (mean over all queries)
    """
    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)
    perf_list = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))
        pos_src, pos_dst, pos_t = (
            data['sources'][test_mask][start_idx: end_idx],
            data['destinations'][test_mask][start_idx: end_idx],
            data['timestamps'][test_mask][start_idx: end_idx],
        )
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # One query: the positive destination first, then all negatives,
            # all paired with the same source node.
            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])
            y_pred = edgebank.predict_link(query_src, query_dst)
            # compute MRR: entry 0 is the positive edge, the rest negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])
        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)
    perf_metrics = float(np.mean(perf_list))
    return perf_metrics
def get_args():
    """Parse command-line arguments for the EdgeBank baseline.

    Returns:
        args: the parsed argument namespace
        sys.argv: the raw command-line argument list
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit as err:
        # argparse exits on its own for `--help` (code 0) and on parse errors
        # (code 2). Previously a bare `except:` swallowed the exit status and
        # always exited 0, masking failures from shell scripts and printing
        # the help text twice on `--help`. Print the full usage only on
        # errors and preserve argparse's exit code.
        if err.code:
            parser.print_help()
        sys.exit(err.code)
    return args, sys.argv
# ==================
# ==================
# ==================

start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value  # NOTE(review): parsed but unused in this script
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = "tgbl-enron"  # fixed dataset name; any --data CLI value is ignored
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric  # dataset-prescribed metric (e.g. MRR)

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for EdgeBank's memory: initialized with the training edges only
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
# NOTE(review): redundant with the os.mkdir above; one of the two suffices.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# validating ...
start_val = timeit.default_timer()
# NOTE(review): `perf_metric_test` here holds the VALIDATION result; it is
# overwritten by the test result below and only the latter is saved.
perf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()  # NOTE(review): unused; val time is recomputed below

print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_val  # NOTE(review): holds val time here
print(f"\tval: Elapsed Time (s): {test_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

# testing ...
start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()  # NOTE(review): unused; test time is recomputed below

print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-flight/dyrep.py
================================================
"""
DyRep
This has been implemented with intuitions from the following sources:
- https://github.com/twitter-research/tgn
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
Spec.:
- Memory Updater: RNN
- Embedding Module: ID
- Message Function: ATTN
command for an example run:
python examples/linkproppred/tgbl-flight/dyrep.py --data "tgbl-flight" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from torch_geometric.loader import TemporalDataLoader
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import DyRepMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Train the DyRep model for one epoch over the chronological training batches.

    This function uses objects that are globally defined in the current script:
    `model`, `optimizer`, `criterion`, `train_loader`, `train_data`,
    `neighbor_loader`, `assoc`, `data`, `device`, `min_dst_idx`, `max_dst_idx`.

    Parameters:
        None
    Returns:
        float: average training loss per event over the epoch
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()
    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        # Sample negative destination nodes (uniform over actual destination ids).
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )
        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # Map global node ids to local row indices into `z`.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)
        # DyRep scores links directly on the memory states (ID embedding module):
        # the loss is computed BEFORE the attention embeddings below.
        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])
        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))
        # update the memory with ground-truth edges; unlike TGN, DyRep's memory
        # update consumes the attention-based embeddings (`z` is overwritten here).
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(src, pos_dst, t, msg, z, assoc)
        # update neighbor loader
        neighbor_loader.insert(src, pos_dst)
        loss.backward()
        optimizer.step()
        model['memory'].detach()  # stop gradients from flowing across batches
        total_loss += float(loss.detach()) * batch.num_events
    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction for DyRep.
    Evaluation happens as 'one vs. many': each positive edge is ranked against
    its pre-sampled negative edges.

    Parameters:
        loader: TemporalDataLoader over the positive edges of the evaluation split
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: 'val' or 'test'; selects which pre-generated negatives to load
    Returns:
        perf_metric: mean of the evaluation metric over all positive edges
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()
    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # one query block per positive edge: [positive dst, all negative dsts]
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )
            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id] = torch.arange(n_id.size(0), device=device)
            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)
            # scores come straight from the memory states (ID embedding module)
            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])
            # compute MRR: row 0 is the positive edge, the remaining rows are negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])
        # update the memory with the batch's positive edges only (DyRep update
        # consumes attention embeddings, hence the extra gnn pass here)
        n_id = torch.cat([pos_src, pos_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg, z, assoc)
        # update the neighbor loader
        neighbor_loader.insert(pos_src, pos_dst)
    perf_metric = float(torch.tensor(perf_list).mean())
    return perf_metric
# ==========
# ==========
# ==========
# Start...
start_overall = timeit.default_timer()
# ========== set parameters...
args, _ = get_args()
print("INFO: Arguments:", args)
DATA = "tgbl-flight"  # dataset name is fixed for this example script
LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10  # number of temporal neighbors sampled per node
MODEL_NAME = 'DyRep'
# DyRep message flavor: use the destination's embedding (not the source's) in messages.
USE_SRC_EMB_IN_MSG = False
USE_DST_EMB_IN_MSG = True
# ==========
# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric
# chronological splits provided by the dataset masks
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)
# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
# neighborhood sampler
neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
# define the model end-to-end
# 1) memory
memory = DyRepMemory(
    data.num_nodes,
    data.msg.size(-1),
    MEM_DIM,
    TIME_DIM,
    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
    aggregator_module=LastAggregator(),
    memory_updater_type='rnn',  # DyRep spec: RNN memory updater
    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,
    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG
).to(device)
# 2) GNN
gnn = GraphAttentionEmbedding(
    in_channels=MEM_DIM,
    out_channels=EMB_DIM,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)
# 3) link predictor
link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
model = {'memory': memory,
         'gnn': gnn,
         'link_pred': link_pred}
# define an optimizer over all trainable parameters of the three components
optimizer = torch.optim.Adam(
    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
    lr=LR,
)
criterion = torch.nn.BCEWithLogitsLoss()
# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
# NOTE(review): redundant with the os.mkdir above; `exist_ok=True` makes it a no-op.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
# NOTE(review): `model`, `optimizer` and `neighbor_loader` are built ONCE above this
# loop (unlike tgn.py, which rebuilds them per run); only the seed changes per run,
# and train() resets memory/neighbor state — verify parameter re-init is intended.
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()
    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)
    # define an early stopper (checkpoints the best model by validation metric)
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)
    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()
    val_perf_list = []
    train_times_l, val_times_l = [], []
    free_mem_l, total_mem_l, used_mem_l = [], [], []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        end_epoch_train = timeit.default_timer()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {end_epoch_train - start_epoch_train: .4f}"
        )
        # checking GPU memory usage (zeros when running on CPU)
        free_mem, used_mem, total_mem = 0, 0, 0
        if torch.cuda.is_available():
            print("DEBUG: device: {}".format(torch.cuda.get_device_name(0)))
            free_mem, total_mem = torch.cuda.mem_get_info()
            used_mem = total_mem - free_mem
            print("------------Epoch {}: GPU memory usage-----------".format(epoch))
            print("Free memory: {}".format(free_mem))
            print("Total available memory: {}".format(total_mem))
            print("Used memory: {}".format(used_mem))
            print("--------------------------------------------")
        train_times_l.append(end_epoch_train - start_epoch_train)
        free_mem_l.append(float((free_mem*1.0)/2**30))  # in GB
        used_mem_l.append(float((used_mem*1.0)/2**30))  # in GB
        total_mem_l.append(float((total_mem*1.0)/2**30))  # in GB
        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        end_val = timeit.default_timer()
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {end_val - start_val: .4f}")
        val_perf_list.append(perf_metric_val)
        val_times_l.append(end_val - start_val)
        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break
    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation Total Elapsed Time (s): {train_val_time: .4f}")
    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)
    # loading the test negative samples
    dataset.load_test_ns()
    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
    # persist per-run metrics and timing/memory statistics
    save_results({'data': DATA,
                  'model': MODEL_NAME,
                  'run': run_idx,
                  'seed': SEED,
                  'train_times': train_times_l,
                  'free_mem': free_mem_l,
                  'total_mem': total_mem_l,
                  'used_mem': used_mem_l,
                  'max_used_mem': max(used_mem_l),
                  'val_times': val_times_l,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'train_val_total_time': np.sum(np.array(train_times_l)) + np.sum(np.array(val_times_l)),
                  },
                 results_filename)
    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')
print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-flight/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate EdgeBank on dynamic link prediction in the 'one vs. many' setting:
    every positive edge is ranked against its pre-sampled negative edges.

    Uses module-level objects: `BATCH_SIZE`, `edgebank`, `evaluator`, `metric`.

    Parameters:
        data: a dataset object (dict of numpy arrays)
        test_mask: boolean mask selecting the evaluation-split edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: 'val' or 'test'; selects which pre-generated negatives to load
    Returns:
        perf_metrics: mean of the evaluation metric over all positive edges
    """
    split_src = data['sources'][test_mask]
    split_dst = data['destinations'][test_mask]
    split_ts = data['timestamps'][test_mask]
    n_events = len(split_src)
    perf_list = []
    for b_idx in tqdm(range(math.ceil(n_events / BATCH_SIZE))):
        lo = b_idx * BATCH_SIZE
        hi = min(lo + BATCH_SIZE, n_events)
        pos_src = split_src[lo:hi]
        pos_dst = split_dst[lo:hi]
        pos_t = split_ts[lo:hi]
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # query block: the positive destination first, then all negatives
            query_src = np.full(len(neg_batch) + 1, int(pos_src[idx]))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])
            y_pred = edgebank.predict_link(query_src, query_dst)
            # compute MRR: entry 0 is the positive edge, the rest are negatives
            perf_list.append(
                evaluator.eval({
                    "y_pred_pos": np.array([y_pred[0]]),
                    "y_pred_neg": np.array(y_pred[1:]),
                    "eval_metric": [metric],
                })[metric]
            )
        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)
    return float(np.mean(perf_list))
def get_args():
    r"""
    Parse command-line arguments for the EdgeBank baseline.

    Returns:
        args: the parsed argparse.Namespace with the script hyperparameters
        sys.argv: the raw command-line argument list
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-flight')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse signals bad arguments (and -h) by raising SystemExit; catch
        # only that, instead of a bare `except` which would also swallow
        # KeyboardInterrupt and genuine programming errors.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
start_overall = timeit.default_timer()
# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = "tgbl-flight"  # dataset is fixed for this example (the --data flag is ignored)
MODEL_NAME = 'EdgeBank'
# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric
# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
# data for memory in edgebank: EdgeBank is seeded with the training edges only
# NOTE(review): np.concatenate over a single array is a no-op copy here.
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])
# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)
print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
# NOTE(review): redundant with the os.mkdir above; `exist_ok=True` makes it a no-op.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'
# ==================================================== Test
# loading the validation negative samples
dataset.load_val_ns()
# testing ...
# NOTE: `perf_metric_test`/`test_time` are reused for val and test; only the
# final (test) values reach save_results() below.
start_val = timeit.default_timer()
perf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()  # NOTE(review): unused
print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {test_time: .4f}")
# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()
# testing ...
start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()  # NOTE(review): unused
print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
# persist the test-split performance and timing
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'  # 'NA': this baseline has no train/val phase
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-flight/tgn.py
================================================
"""
Dynamic Link Prediction with a TGN model with Early Stopping
Reference:
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
command for an example run:
python examples/linkproppred/tgbl-flight/tgn.py --data "tgbl-flight" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Run one training epoch of TGN over the chronological training batches.

    Relies on module-level objects: `model`, `optimizer`, `criterion`,
    `train_loader`, `train_data`, `neighbor_loader`, `assoc`, `data`,
    `device`, `min_dst_idx`, `max_dst_idx`.

    Parameters:
        None
    Returns:
        float: average training loss per event over the epoch
    """
    for component in model.values():
        component.train()
    model['memory'].reset_state()  # fresh memory at the start of each epoch
    neighbor_loader.reset_state()  # empty temporal graph at the start of each epoch
    epoch_loss = 0.0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        # Draw one random negative destination per positive event.
        neg_dst = torch.randint(
            min_dst_idx, max_dst_idx + 1, (src.size(0),),
            dtype=torch.long, device=device,
        )
        nodes = torch.cat([src, pos_dst, neg_dst]).unique()
        nodes, edge_index, e_id = neighbor_loader(nodes)
        # map global node ids to local row indices
        assoc[nodes] = torch.arange(nodes.size(0), device=device)
        # Refresh the memory of every involved node, then embed with attention.
        mem_state, last_update = model['memory'](nodes)
        emb = model['gnn'](
            mem_state, last_update, edge_index,
            data.t[e_id].to(device), data.msg[e_id].to(device),
        )
        pos_out = model['link_pred'](emb[assoc[src]], emb[assoc[pos_dst]])
        neg_out = model['link_pred'](emb[assoc[src]], emb[assoc[neg_dst]])
        loss = criterion(pos_out, torch.ones_like(pos_out)) + \
            criterion(neg_out, torch.zeros_like(neg_out))
        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)
        loss.backward()
        optimizer.step()
        model['memory'].detach()  # cut gradients across batch boundaries
        epoch_loss += float(loss.detach()) * batch.num_events
    return epoch_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate TGN on dynamic link prediction in the 'one vs. many' setting:
    each positive edge is ranked against its pre-sampled negative edges.

    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: 'val' or 'test'; selects which pre-generated negatives to load
    Returns:
        float: mean of the evaluation metric over all positive edges
    """
    for component in model.values():
        component.eval()
    perf_list = []
    for pos_batch in loader:
        pos_src = pos_batch.src
        pos_dst = pos_batch.dst
        pos_t = pos_batch.t
        pos_msg = pos_batch.msg
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # one query block per positive edge: [positive dst, all negative dsts]
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )
            nodes = torch.cat([src, dst]).unique()
            nodes, edge_index, e_id = neighbor_loader(nodes)
            assoc[nodes] = torch.arange(nodes.size(0), device=device)
            # Refresh memory for involved nodes, then embed and score.
            mem_state, last_update = model['memory'](nodes)
            emb = model['gnn'](
                mem_state, last_update, edge_index,
                data.t[e_id].to(device), data.msg[e_id].to(device),
            )
            y_pred = model['link_pred'](emb[assoc[src]], emb[assoc[dst]])
            # compute MRR: row 0 is the positive edge, the rest are negatives
            perf_list.append(
                evaluator.eval({
                    "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                    "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                    "eval_metric": [metric],
                })[metric]
            )
        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)
    return float(torch.tensor(perf_list).mean())
# ==========
# ==========
# ==========
# Start...
start_overall = timeit.default_timer()
DATA = "tgbl-flight"  # dataset name is fixed for this example script
# ========== set parameters...
args, _ = get_args()
args.data = DATA  # override any --data flag with the hard-coded dataset
print("INFO: Arguments:", args)
LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10  # number of temporal neighbors sampled per node
MODEL_NAME = 'TGN'
# ==========
# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric
# chronological splits provided by the dataset masks
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)
# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
# NOTE(review): redundant with the os.mkdir above; `exist_ok=True` makes it a no-op.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()
    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)
    # neighborhood sampler (rebuilt per run so every run starts independent)
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
    # define the model end-to-end (memory -> attention embedding -> link decoder)
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)
    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)
    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}
    # optimizer over all trainable parameters of the three components
    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()
    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
    # define an early stopper (checkpoints the best model by validation metric)
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)
    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()
    val_perf_list = []
    train_times_l, val_times_l = [], []
    free_mem_l, total_mem_l, used_mem_l = [], [], []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        end_epoch_train = timeit.default_timer()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {end_epoch_train - start_epoch_train: .4f}"
        )
        # checking GPU memory usage (zeros when running on CPU)
        free_mem, used_mem, total_mem = 0, 0, 0
        if torch.cuda.is_available():
            print("DEBUG: device: {}".format(torch.cuda.get_device_name(0)))
            free_mem, total_mem = torch.cuda.mem_get_info()
            used_mem = total_mem - free_mem
            print("------------Epoch {}: GPU memory usage-----------".format(epoch))
            print("Free memory: {}".format(free_mem))
            print("Total available memory: {}".format(total_mem))
            print("Used memory: {}".format(used_mem))
            print("--------------------------------------------")
        train_times_l.append(end_epoch_train - start_epoch_train)
        free_mem_l.append(float((free_mem*1.0)/2**30))  # in GB
        used_mem_l.append(float((used_mem*1.0)/2**30))  # in GB
        total_mem_l.append(float((total_mem*1.0)/2**30))  # in GB
        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        end_val = timeit.default_timer()
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {end_val - start_val: .4f}")
        val_perf_list.append(perf_metric_val)
        val_times_l.append(end_val - start_val)
        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break
    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")
    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)
    # loading the test negative samples
    dataset.load_test_ns()
    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
    # persist per-run metrics and timing/memory statistics
    save_results({'data': DATA,
                  'model': MODEL_NAME,
                  'run': run_idx,
                  'seed': SEED,
                  'train_times': train_times_l,
                  'free_mem': free_mem_l,
                  'total_mem': total_mem_l,
                  'used_mem': used_mem_l,
                  'max_used_mem': max(used_mem_l),
                  'val_times': val_times_l,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'train_val_total_time': np.sum(np.array(train_times_l)) + np.sum(np.array(val_times_l)),
                  },
                 results_filename)
    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')
print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-lastfm/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate EdgeBank on dynamic link prediction.
    Evaluation happens as 'one vs. many': each positive edge is ranked against
    its pre-sampled negative edges.

    Uses module-level objects: `BATCH_SIZE`, `edgebank`, `evaluator`, `metric`.

    Parameters:
        data: a dataset object (dict of numpy arrays)
        test_mask: required masks to load the evaluation-split edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: 'val' or 'test'; selects which pre-generated negatives to load
    Returns:
        perf_metrics: the result of the performance evaluation (mean metric)
    """
    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)
    perf_list = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))
        pos_src, pos_dst, pos_t = (
            data['sources'][test_mask][start_idx: end_idx],
            data['destinations'][test_mask][start_idx: end_idx],
            data['timestamps'][test_mask][start_idx: end_idx],
        )
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # query block: the positive destination first, then all negatives
            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])
            y_pred = edgebank.predict_link(query_src, query_dst)
            # compute MRR: entry 0 is the positive edge, the rest are negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])
        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)
    perf_metrics = float(np.mean(perf_list))
    return perf_metrics
def get_args():
    r"""
    Parse the command-line arguments for the EdgeBank baseline.

    Returns:
        args: the parsed argparse.Namespace
        sys.argv: the raw command-line argument list
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)

    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse signals a parse error (or -h) by raising SystemExit; show the
        # full help and exit.  A bare `except:` here would also have swallowed
        # KeyboardInterrupt and any other unrelated exception.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
# Main script: run the EdgeBank baseline on tgbl-lastfm, evaluating on both
# the validation and test splits, then persist the results to JSON.
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = "tgbl-lastfm"  # NOTE: hard-coded; the --data CLI argument is ignored here
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric  # dataset-specific evaluation metric name

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank: seed the predictor with the training edges only;
# val/test edges are folded in incrementally inside test() via update_memory.
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    # NOTE(review): redundant with the Path(...).mkdir(exist_ok=True) below
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# testing ...
start_val = timeit.default_timer()
perf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()

print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {test_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

# testing ...
start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()

print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

# persist only the test-split metric (the validation result is printed above)
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-lastfm/tgn.py
================================================
"""
Dynamic Link Prediction with a TGN model with Early Stopping
Reference:
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
command for an example run:
python examples/linkproppred/tgbl-lastfm/tgn.py --data "tgbl-lastfm" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from torch.nn import Linear
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Training procedure for TGN model (one epoch over the training loader)
    This function uses some objects that are globally defined in the current scrips:
    `model`, `optimizer`, `criterion`, `neighbor_loader`, `train_loader`,
    `data`, `assoc`, `device`, `min_dst_idx`, `max_dst_idx`, `train_data`.

    Parameters:
        None
    Returns:
        the average training loss per event (float)
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Sample negative destination nodes (uniformly over observed dst ids).
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # assoc maps global node ids to positions in the local n_id subgraph
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )

        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # Update memory and neighbor loader with ground-truth state.
        # NOTE: this happens before backward(), so the loss is computed from the
        # memory state *prior* to seeing the current batch's edges.
        model['memory'].update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        # Detach memory so gradients do not flow across batches.
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluated the dynamic link prediction
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges

    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: the result of the performance evaluaiton
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # Rank one positive destination (index 0) against all its negatives,
            # all queried from the same source node.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            # assoc maps global node ids to positions in the local n_id subgraph
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )

            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])

            # compute MRR
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # Update memory and neighbor loader with ground-truth state,
        # only after the whole batch has been scored (no leakage within a batch).
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)

    perf_metrics = float(torch.tensor(perf_list).mean())

    return perf_metrics
# ==========
# ==========
# ==========
# Main script: train/validate/test a TGN model on tgbl-lastfm with early
# stopping, repeating for NUM_RUNS independent seeds, and save results to JSON.
# Start...
start_overall = timeit.default_timer()

DATA = "tgbl-lastfm"

# ========== set parameters...
args, _ = get_args()
args.data = DATA  # force the dataset regardless of the --data CLI argument
print("INFO: Arguments:", args)

LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10
MODEL_NAME = 'TGN'
# ==========

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    # NOTE(review): redundant with the Path(...).mkdir(exist_ok=True) below
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'

for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # neighhorhood sampler (re-created per run so state does not leak across runs)
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

    # define the model end-to-end (fresh weights for each run)
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)

    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}

    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )
        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping (also checkpoints the best model so far)
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")

    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-review/dyrep.py
================================================
"""
DyRep
This has been implemented with intuitions from the following sources:
- https://github.com/twitter-research/tgn
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
Spec.:
- Memory Updater: RNN
- Embedding Module: ID
- Message Function: ATTN
command for an example run:
python examples/linkproppred/tgbl-review/dyrep.py --data "tgbl-review" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from torch_geometric.loader import TemporalDataLoader
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import DyRepMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Training procedure for DyRep model (one epoch over the training loader)
    This function uses some objects that are globally defined in the current scrips:
    `model`, `optimizer`, `criterion`, `neighbor_loader`, `train_loader`,
    `data`, `assoc`, `device`, `min_dst_idx`, `max_dst_idx`, `train_data`.

    Parameters:
        None
    Returns:
        the average training loss per event (float)
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Sample negative destination nodes (uniformly over observed dst ids).
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # assoc maps global node ids to positions in the local n_id subgraph
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        # NOTE: unlike TGN, DyRep predicts links directly from the memory
        # embeddings; the GNN output below is used only for the memory update.
        z, last_update = model['memory'](n_id)

        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # update the memory with ground-truth, feeding the attention-based
        # embeddings into the message computation
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(src, pos_dst, t, msg, z, assoc)

        # update neighbor loader
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        # Detach memory so gradients do not flow across batches.
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluated the dynamic link prediction
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges

    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: the result of the performance evaluation
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # Rank one positive destination (index 0) against all its negatives,
            # all queried from the same source node.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            # assoc maps global node ids to positions in the local n_id subgraph
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation.
            # DyRep scores links directly from the memory embeddings (no GNN here).
            z, last_update = model['memory'](n_id)

            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])

            # compute MRR
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update the memory with positive edges, only after the whole batch has
        # been scored; the GNN embeddings feed the DyRep memory update
        n_id = torch.cat([pos_src, pos_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg, z, assoc)

        # update the neighbor loader
        neighbor_loader.insert(pos_src, pos_dst)

    perf_metric = float(torch.tensor(perf_list).mean())

    return perf_metric
# ==========
# ==========
# ==========
# Main script: train/validate/test a DyRep model on tgbl-review with early
# stopping, repeating for NUM_RUNS independent seeds, and save results to JSON.
# Start...
start_overall = timeit.default_timer()

# ========== set parameters...
args, _ = get_args()
print("INFO: Arguments:", args)

DATA = "tgbl-review"  # NOTE: hard-coded; the --data CLI argument is ignored here
LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10
MODEL_NAME = 'DyRep'
# DyRep message options: use destination (not source) embeddings in messages
USE_SRC_EMB_IN_MSG = False
USE_DST_EMB_IN_MSG = True
# ==========

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

# neighborhood sampler
# NOTE(review): unlike the TGN script, the sampler and model are built once,
# outside the run loop, so state/weights carry over between runs — verify this
# is intentional.
neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

# define the model end-to-end
# 1) memory
memory = DyRepMemory(
    data.num_nodes,
    data.msg.size(-1),
    MEM_DIM,
    TIME_DIM,
    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
    aggregator_module=LastAggregator(),
    memory_updater_type='rnn',
    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,
    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG
).to(device)

# 2) GNN
gnn = GraphAttentionEmbedding(
    in_channels=MEM_DIM,
    out_channels=EMB_DIM,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

# 3) link predictor
link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

model = {'memory': memory,
         'gnn': gnn,
         'link_pred': link_pred}

# define an optimizer
optimizer = torch.optim.Adam(
    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
    lr=LR,
)
criterion = torch.nn.BCEWithLogitsLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    # NOTE(review): redundant with the Path(...).mkdir(exist_ok=True) below
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'

for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )
        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping (also checkpoints the best model so far)
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation Total Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")

    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-review/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluated the dynamic link prediction
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges

    Parameters:
        data: a dataset object (dict of numpy arrays keyed by 'sources', 'destinations', 'timestamps')
        test_mask: required masks to load the test set edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: the result of the performance evaluation (mean of the per-edge metric)
    """
    # Hoist the masked arrays out of the loop: boolean indexing copies the whole
    # array, so re-doing it for every batch was O(num_batches * num_edges).
    sources = data['sources'][test_mask]
    destinations = data['destinations'][test_mask]
    timestamps = data['timestamps'][test_mask]

    num_batches = math.ceil(len(sources) / BATCH_SIZE)
    perf_list = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, len(sources))
        pos_src = sources[start_idx: end_idx]
        pos_dst = destinations[start_idx: end_idx]
        pos_t = timestamps[start_idx: end_idx]

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # One query per candidate destination, all sharing the same source;
            # index 0 is the positive destination, the rest are negatives.
            query_src = np.full(len(neg_batch) + 1, int(pos_src[idx]))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])

            y_pred = edgebank.predict_link(query_src, query_dst)

            # compute MRR
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    perf_metrics = float(np.mean(perf_list))
    return perf_metrics
def get_args():
    r"""
    Parse the command-line arguments for the EdgeBank baseline.

    Returns:
        args: the parsed argparse.Namespace
        sys.argv: the raw command-line argument list
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-review')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)

    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse signals a parse error (or -h) by raising SystemExit; show the
        # full help and exit.  A bare `except:` here would also have swallowed
        # KeyboardInterrupt and any other unrelated exception.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
# Main script: run the EdgeBank baseline on tgbl-review, evaluating on both
# the validation and test splits, then persist the results to JSON.
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = "tgbl-review"  # NOTE: hard-coded; the --data CLI argument is ignored here
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric  # dataset-specific evaluation metric name

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank: seed the predictor with the training edges only;
# val/test edges are folded in incrementally inside test() via update_memory.
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    # NOTE(review): redundant with the Path(...).mkdir(exist_ok=True) below
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# testing ...
start_val = timeit.default_timer()
perf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()

print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {test_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

# testing ...
start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()

print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

# persist only the test-split metric (the validation result is printed above)
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-review/tgn.py
================================================
"""
Dynamic Link Prediction with a TGN model with Early Stopping
Reference:
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
command for an example run:
python examples/linkproppred/tgbl-review/tgn.py --data "tgbl-review" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Train the TGN model for one epoch over the training edges.
    This function uses objects that are globally defined in the current script
    (model, optimizer, criterion, train_loader, neighbor_loader, data, assoc,
    device, min_dst_idx, max_dst_idx, train_data).
    Parameters:
        None
    Returns:
        float: the training loss averaged per event over the epoch
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Sample one negative destination per positive edge, uniformly from
        # the valid destination-id range.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # map global node ids to local (contiguous) indices
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )

        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        # detach so gradients do not flow across batch boundaries
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction on the given split.
    Evaluation happens as 'one vs. many', meaning that each positive edge is
    ranked against its pre-sampled negative edges.
    Uses globally defined objects: model, neighbor_loader, data, assoc,
    device, evaluator, metric.
    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metrics: the performance metric averaged over all evaluated edges
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )
        # one list of negative destinations per positive edge in the batch
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # query edges: the positive destination first, then all negatives,
            # all paired with the same source node
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            # map global node ids to local (contiguous) indices
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )

            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])

            # compute MRR: index 0 is the positive edge, the rest are negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)

    perf_metrics = float(torch.tensor(perf_list).mean())

    return perf_metrics
# ==========
# ==========
# ==========
# Start...
# Start...
start_overall = timeit.default_timer()

DATA = "tgbl-review"  # this script is pinned to one dataset

# ========== set parameters...
args, _ = get_args()
args.data = DATA
print("INFO: Arguments:", args)

LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10  # neighbors sampled per node by the neighbor loader

MODEL_NAME = 'TGN'
# ==========

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
# NOTE(review): os.mkdir below is redundant with the unconditional
# Path(...).mkdir(parents=True, exist_ok=True) that follows it.
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'

for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # neighborhood sampler
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

    # define the model end-to-end
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)

    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}

    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )

        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")

    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    # persist per-run results (validation curve + final test score)
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-subreddit/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction with EdgeBank.
    Evaluation happens as 'one vs. many', meaning that each positive edge is
    evaluated against many negative edges.
    Uses globally defined objects: BATCH_SIZE, edgebank, evaluator, metric.
    Parameters:
        data: array-like mapping with 'sources', 'destinations', 'timestamps' entries
        test_mask: required masks to load the test set edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metrics: the performance metric averaged over all evaluated edges
    """
    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)
    perf_list = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))
        pos_src, pos_dst, pos_t = (
            data['sources'][test_mask][start_idx: end_idx],
            data['destinations'][test_mask][start_idx: end_idx],
            data['timestamps'][test_mask][start_idx: end_idx],
        )
        # one list of negative destinations per positive edge in the batch
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # query edges: the positive destination first, then all negatives,
            # all paired with the same source node
            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])

            y_pred = edgebank.predict_link(query_src, query_dst)

            # compute MRR: index 0 is the positive edge, the rest are negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    perf_metrics = float(np.mean(perf_list))

    return perf_metrics
def get_args():
    r"""
    Parse command-line arguments for the EdgeBank baseline.
    Prints the full help text and exits cleanly if parsing fails.
    Parameters:
        None
    Returns:
        args: the parsed argument namespace
        sys.argv: the raw command-line argument list
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)

    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse signals a parsing error by raising SystemExit; catching
        # only that (instead of a bare `except:`) avoids swallowing
        # KeyboardInterrupt and genuine programming errors.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()

SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio

DATA = "tgbl-subreddit"  # this script is pinned to one dataset
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank: seed the memory with the training edges
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
# NOTE(review): os.mkdir below is redundant with the unconditional
# Path(...).mkdir(parents=True, exist_ok=True) that follows it.
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# validating ...
# NOTE(review): `perf_metric_test` is reused for the validation result here
# and overwritten by the test result below; `end_val` is captured but unused.
start_val = timeit.default_timer()
perf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()

print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {test_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

# testing ...
start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()

print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

# persist the final test results (EdgeBank has no training phase in this script)
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-subreddit/tgn.py
================================================
"""
Dynamic Link Prediction with a TGN model with Early Stopping
Reference:
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
command for an example run:
python examples/linkproppred/tgbl-subreddit/tgn.py --data "tgbl-subreddit" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Train the TGN model for one epoch over the training edges.
    This function uses objects that are globally defined in the current script
    (model, optimizer, criterion, train_loader, neighbor_loader, data, assoc,
    device, min_dst_idx, max_dst_idx, train_data).
    Parameters:
        None
    Returns:
        float: the training loss averaged per event over the epoch
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Sample one negative destination per positive edge, uniformly from
        # the valid destination-id range.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # map global node ids to local (contiguous) indices
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )

        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        # detach so gradients do not flow across batch boundaries
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction on the given split.
    Evaluation happens as 'one vs. many', meaning that each positive edge is
    ranked against its pre-sampled negative edges.
    Uses globally defined objects: model, neighbor_loader, data, assoc,
    device, evaluator, metric.
    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metrics: the performance metric averaged over all evaluated edges
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )
        # one list of negative destinations per positive edge in the batch
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # query edges: the positive destination first, then all negatives,
            # all paired with the same source node
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            # map global node ids to local (contiguous) indices
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )

            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])

            # compute MRR: index 0 is the positive edge, the rest are negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)

    perf_metrics = float(torch.tensor(perf_list).mean())

    return perf_metrics
# ==========
# ==========
# ==========
# Start...
# Start...
start_overall = timeit.default_timer()

DATA = "tgbl-subreddit"  # this script is pinned to one dataset

# ========== set parameters...
args, _ = get_args()
args.data = DATA
print("INFO: Arguments:", args)

LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10  # neighbors sampled per node by the neighbor loader

MODEL_NAME = 'TGN'
# ==========

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
# NOTE(review): os.mkdir below is redundant with the unconditional
# Path(...).mkdir(parents=True, exist_ok=True) that follows it.
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'

for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # neighborhood sampler
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

    # define the model end-to-end
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)

    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}

    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )

        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")

    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    # persist per-run results (validation curve + final test score)
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-uci/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction with EdgeBank.
    Evaluation happens as 'one vs. many', meaning that each positive edge is
    evaluated against many negative edges.
    Uses globally defined objects: BATCH_SIZE, edgebank, evaluator, metric.
    Parameters:
        data: array-like mapping with 'sources', 'destinations', 'timestamps' entries
        test_mask: required masks to load the test set edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metrics: the performance metric averaged over all evaluated edges
    """
    num_batches = math.ceil(len(data['sources'][test_mask]) / BATCH_SIZE)
    perf_list = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, len(data['sources'][test_mask]))
        pos_src, pos_dst, pos_t = (
            data['sources'][test_mask][start_idx: end_idx],
            data['destinations'][test_mask][start_idx: end_idx],
            data['timestamps'][test_mask][start_idx: end_idx],
        )
        # one list of negative destinations per positive edge in the batch
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # query edges: the positive destination first, then all negatives,
            # all paired with the same source node
            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])

            y_pred = edgebank.predict_link(query_src, query_dst)

            # compute MRR: index 0 is the positive edge, the rest are negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    perf_metrics = float(np.mean(perf_list))

    return perf_metrics
def get_args():
    r"""
    Parse command-line arguments for the EdgeBank baseline.
    Prints the full help text and exits cleanly if parsing fails.
    Parameters:
        None
    Returns:
        args: the parsed argument namespace
        sys.argv: the raw command-line argument list
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)

    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse signals a parsing error by raising SystemExit; catching
        # only that (instead of a bare `except:`) avoids swallowing
        # KeyboardInterrupt and genuine programming errors.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()

SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio

DATA = "tgbl-uci"  # this script is pinned to one dataset
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank: seed the memory with the training edges
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
# NOTE(review): os.mkdir below is redundant with the unconditional
# Path(...).mkdir(parents=True, exist_ok=True) that follows it.
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# validating ...
# NOTE(review): `perf_metric_test` is reused for the validation result here
# and overwritten by the test result below; `end_val` is captured but unused.
start_val = timeit.default_timer()
perf_metric_test = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()

print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {test_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

# testing ...
start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()

print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

# persist the final test results (EdgeBank has no training phase in this script)
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-wiki/dyrep.py
================================================
"""
DyRep
This has been implemented with intuitions from the following sources:
- https://github.com/twitter-research/tgn
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
Spec.:
- Memory Updater: RNN
- Embedding Module: ID
- Message Function: ATTN
command for an example run:
python examples/linkproppred/tgbl-wiki/dyrep.py --data "tgbl-wiki" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from torch_geometric.loader import TemporalDataLoader
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import DyRepMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Training procedure for DyRep model
    Trains for one epoch over `train_loader`. This function uses objects that
    are globally defined in the current script (`model`, `optimizer`,
    `criterion`, `train_loader`, `neighbor_loader`, `assoc`, `data`,
    `train_data`, `device`, `min_dst_idx`, `max_dst_idx`).
    Parameters:
        None
    Returns:
        avg_loss: average training loss per event over the epoch
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Sample negative destination nodes, uniformly over the observed
        # destination-id range [min_dst_idx, max_dst_idx].
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # Map global node indices to local (batch-level) indices.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)

        # DyRep scores links directly from memory states (embedding module: ID;
        # no GNN pass before prediction — see the module spec in the file header).
        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # update the memory with ground-truth edges; the attention-based
        # embedding is computed here and fed into the memory update
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(src, pos_dst, t, msg, z, assoc)

        # update neighbor loader
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        # Detach the memory so gradients do not flow across batches.
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluated the dynamic link prediction
    Evaluation happens as 'one vs. many', meaning that each positive edge is
    evaluated against many negative edges.
    Also uses globally defined objects: `model`, `neighbor_loader`, `assoc`,
    `data`, `device`, `evaluator`, and `metric`.
    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: the result of the performance evaluation
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # One ranking query: the positive destination followed by all its
            # sampled negatives, each paired with the same source node.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            # Map global node indices to local (batch-level) indices.
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)

            # DyRep scores links directly from memory states (no GNN pass here).
            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])

            # compute MRR: index 0 is the positive edge's score
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update the memory with the (ground-truth) positive edges only,
        # after the whole batch has been scored
        n_id = torch.cat([pos_src, pos_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg, z, assoc)

        # update the neighbor loader
        neighbor_loader.insert(pos_src, pos_dst)

    perf_metric = float(torch.tensor(perf_list).mean())

    return perf_metric
# ==========
# ==========
# ==========
# Start...
start_overall = timeit.default_timer()

# ========== set parameters...
args, _ = get_args()
print("INFO: Arguments:", args)

DATA = "tgbl-wiki"
LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10
MODEL_NAME = 'DyRep'
# DyRep uses the destination embedding (not the source) in its messages.
USE_SRC_EMB_IN_MSG = False
USE_DST_EMB_IN_MSG = True
# ==========

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

# neighborhood sampler
neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

# define the model end-to-end
# 1) memory
memory = DyRepMemory(
    data.num_nodes,
    data.msg.size(-1),
    MEM_DIM,
    TIME_DIM,
    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
    aggregator_module=LastAggregator(),
    memory_updater_type='rnn',
    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,
    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG
).to(device)

# 2) GNN
gnn = GraphAttentionEmbedding(
    in_channels=MEM_DIM,
    out_channels=EMB_DIM,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

# 3) link predictor
link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

model = {'memory': memory,
         'gnn': gnn,
         'link_pred': link_pred}

# define an optimizer over the parameters of all three modules
optimizer = torch.optim.Adam(
    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
    lr=LR,
)
criterion = torch.nn.BCEWithLogitsLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
# Path.mkdir(parents=True, exist_ok=True) replaces the previous os.mkdir call,
# which would raise when an intermediate directory is missing.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
# Run the full train / validate / test pipeline NUM_RUNS times with
# different seeds; each run's results are appended to `results_filename`.
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )

        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation Total Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    # persist per-run results (validation curve + final test metric)
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/tgbl-wiki/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluated the dynamic link prediction
    Evaluation happens as 'one vs. many', meaning that each positive edge is
    evaluated against many negative edges.
    Also uses globally defined objects: `BATCH_SIZE`, `edgebank`, `evaluator`,
    and `metric`.
    Parameters:
        data: a dataset object (dict of numpy arrays)
        test_mask: required masks to load the test set edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: the result of the performance evaluation
    """
    # Hoist the masked arrays out of the loop: boolean-mask indexing copies
    # the array, and the original recomputed it on every batch.
    pos_src_all = data['sources'][test_mask]
    pos_dst_all = data['destinations'][test_mask]
    pos_t_all = data['timestamps'][test_mask]
    num_queries = len(pos_src_all)
    num_batches = math.ceil(num_queries / BATCH_SIZE)

    perf_list = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, num_queries)
        pos_src = pos_src_all[start_idx: end_idx]
        pos_dst = pos_dst_all[start_idx: end_idx]
        pos_t = pos_t_all[start_idx: end_idx]

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # one query: the positive destination followed by all negatives,
            # each paired with the same source node
            query_src = np.full(len(neg_batch) + 1, int(pos_src[idx]))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])

            y_pred = edgebank.predict_link(query_src, query_dst)

            # compute MRR: index 0 is the positive edge's score
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    perf_metrics = float(np.mean(perf_list))

    return perf_metrics
def get_args():
    """Parse and return the command-line arguments for an EdgeBank run.

    Returns:
        args: the parsed argument namespace
        sys.argv: the raw command-line tokens
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)

    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse reports a bad command line by raising SystemExit;
        # show the usage text and stop (the bare `except:` it replaces
        # silently swallowed every other exception type too).
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = "tgbl-wiki"
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank; boolean-mask indexing already yields fresh
# arrays, so no extra concatenate/copy is needed
hist_src = data['sources'][train_mask]
hist_dst = data['destinations'][train_mask]
hist_ts = data['timestamps'][train_mask]

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
# Path.mkdir(parents=True, exist_ok=True) replaces the previous os.mkdir
# call, which would raise when an intermediate directory is missing.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# evaluate on the validation split
start_val = timeit.default_timer()
perf_metric_val = test(data, val_mask, neg_sampler, split_mode='val')
val_time = timeit.default_timer() - start_val
print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_val: .4f}")
print(f"\tval: Elapsed Time (s): {val_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

# evaluate on the test split
start_test = timeit.default_timer()
perf_metric_test = test(data, test_mask, neg_sampler, split_mode='test')
test_time = timeit.default_timer() - start_test
print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

# persist the test results; EdgeBank has no training phase, hence 'NA'
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'test_time': test_time,
              'tot_train_val_time': 'NA'
              },
             results_filename)
================================================
FILE: examples/linkproppred/tgbl-wiki/tgn.py
================================================
"""
Dynamic Link Prediction with a TGN model with Early Stopping
Reference:
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
command for an example run:
python examples/linkproppred/tgbl-wiki/tgn.py --data "tgbl-wiki" --num_run 1 --seed 1
"""
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Training procedure for TGN model
    Trains for one epoch over `train_loader`. This function uses objects that
    are globally defined in the current script (`model`, `optimizer`,
    `criterion`, `train_loader`, `neighbor_loader`, `assoc`, `data`,
    `train_data`, `device`, `min_dst_idx`, `max_dst_idx`).
    Parameters:
        None
    Returns:
        avg_loss: average training loss per event over the epoch
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Sample negative destination nodes, uniformly over the observed
        # destination-id range [min_dst_idx, max_dst_idx].
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # Map global node indices to local (batch-level) indices.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation,
        # then embed them over the sampled temporal neighborhood.
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )

        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        # Detach the memory so gradients do not flow across batches.
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluated the dynamic link prediction
    Evaluation happens as 'one vs. many', meaning that each positive edge is
    evaluated against many negative edges.
    Also uses globally defined objects: `model`, `neighbor_loader`, `assoc`,
    `data`, `device`, `evaluator`, and `metric`.
    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: the result of the performance evaluation
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
        )

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # One ranking query: the positive destination followed by all its
            # sampled negatives, each paired with the same source node.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            # Map global node indices to local (batch-level) indices.
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation,
            # then embed them over the sampled temporal neighborhood.
            z, last_update = model['memory'](n_id)
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )

            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])

            # compute MRR: index 0 is the positive edge's score
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)

    perf_metrics = float(torch.tensor(perf_list).mean())

    return perf_metrics
# ==========
# ==========
# ==========
# Start...
start_overall = timeit.default_timer()

DATA = "tgbl-wiki"

# ========== set parameters...
args, _ = get_args()
# pin the dataset for this script regardless of the command-line value
args.data = DATA
print("INFO: Arguments:", args)

LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10
MODEL_NAME = 'TGN'
# ==========

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
# Path.mkdir(parents=True, exist_ok=True) replaces the previous os.mkdir
# call, which would raise when an intermediate directory is missing.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
# Run the full train / validate / test pipeline NUM_RUNS times with
# different seeds; each run's results are appended to `results_filename`.
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # neighborhood sampler (re-created per run so each run starts fresh)
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

    # define the model end-to-end (also re-created per run)
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)

    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}

    # optimizer over the parameters of all three modules
    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )

        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    # persist per-run results (validation curve + final test metric)
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/thgl-forum/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
tgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(tgb_modules_path)
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluated the dynamic link prediction
    Evaluation happens as 'one vs. many', meaning that each positive edge is
    evaluated against many negative edges.
    Also uses globally defined objects: `BATCH_SIZE`, `edgebank`, `evaluator`,
    and `metric`.
    Parameters:
        data: a dataset object (dict of numpy arrays, including 'edge_type')
        test_mask: required masks to load the test set edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metrics: the mean of the per-query performance metric
        perf_hits: the mean hits@10 over all queries
    """
    # Hoist the masked arrays out of the loop: boolean-mask indexing copies
    # the array, and the original recomputed it on every batch.
    pos_src_all = data['sources'][test_mask]
    pos_dst_all = data['destinations'][test_mask]
    pos_t_all = data['timestamps'][test_mask]
    pos_edge_all = data['edge_type'][test_mask]
    num_queries = len(pos_src_all)
    num_batches = math.ceil(num_queries / BATCH_SIZE)

    perf_list = []
    hits_list = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, num_queries)
        pos_src = pos_src_all[start_idx: end_idx]
        pos_dst = pos_dst_all[start_idx: end_idx]
        pos_t = pos_t_all[start_idx: end_idx]
        pos_edge = pos_edge_all[start_idx: end_idx]

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)

        for idx, neg_batch in enumerate(neg_batch_list):
            # one query: the positive destination followed by all negatives,
            # each paired with the same source node
            query_src = np.full(len(neg_batch) + 1, int(pos_src[idx]))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])

            y_pred = edgebank.predict_link(query_src, query_dst)

            # compute MRR: index 0 is the positive edge's score
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            results = evaluator.eval(input_dict)
            perf_list.append(results[metric])
            hits_list.append(results['hits@10'])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    perf_metrics = float(np.mean(perf_list))
    perf_hits = float(np.mean(hits_list))

    return perf_metrics, perf_hits
def get_args():
    """Parse and return the command-line arguments for an EdgeBank run.

    Returns:
        args: the parsed argument namespace
        sys.argv: the raw command-line tokens
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='thgl-forum')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)

    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse reports a bad command line by raising SystemExit;
        # show the usage text and stop (the bare `except:` it replaces
        # silently swallowed every other exception type too).
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = args.data
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank; boolean-mask indexing already yields fresh
# arrays, so no extra concatenate/copy is needed
hist_src = data['sources'][train_mask]
hist_dst = data['destinations'][train_mask]
hist_ts = data['timestamps'][train_mask]

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
# Path.mkdir(parents=True, exist_ok=True) replaces the previous os.mkdir
# call, which would raise when an intermediate directory is missing.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# evaluate on the validation split
start_val = timeit.default_timer()
perf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')
val_time = timeit.default_timer() - start_val
# log messages fixed for consistency with the sibling scripts
# (previously "ONE-VS--ALL" here and an empty ">>> <<<" below)
print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_val: .4f}")
print(f"\tval: Elapsed Time (s): {val_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

# evaluate on the test split
start_test = timeit.default_timer()
perf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')
test_time = timeit.default_timer() - start_test
print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

# persist the results; EdgeBank has no training phase, so the recorded
# 'tot_train_val_time' is the total evaluation time
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'val_mrr': perf_metric_val,
              'test_time': test_time,
              'tot_train_val_time': test_time + val_time,
              'hits10': perf_hits_test},
             results_filename)
================================================
FILE: examples/linkproppred/thgl-forum/recurrencybaseline.py
================================================
""" from paper: "History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting"
Julia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)
@inproceedings{gastinger2024baselines,
title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},
author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},
booktitle={33rd International Joint Conference on Artificial Intelligence (IJCAI 2024)},
year={2024},
organization={International Joint Conferences on Artificial Intelligence Organization}
}
"""
## imports
import timeit
import argparse
import numpy as np
from copy import copy
from pathlib import Path
import ray
import sys
import os
import os.path as osp
import json
#internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import set_random_seed, save_results
from modules.tkg_utils import create_basis_dict, group_by, reformat_ts
def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
            perf_list_all, hits_list_all, window, neg_sampler, split_mode):
    """ create predictions for each relation on test or valid set and compute mrr

    Parameters:
        num_processes: number of ray workers requested (1 = run in-process)
        data_c_rel: evaluation quadruples for the current relation
        all_data_c_rel: historical quadruples for the current relation
        alpha, lmbda_psi: baseline hyper-parameters
        perf_list_all, hits_list_all: accumulators, extended in place and returned
        window: history window size
        neg_sampler: negative sampler used by the evaluator
        split_mode: 'val' or 'test'
    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    first_ts = data_c_rel[0][3]
    ## use this if you want to use ray:
    num_queries = len(data_c_rel) // num_processes
    if num_queries < num_processes:  # if we do not have enough queries for all the processes
        num_processes_tmp = 1
        num_queries = len(data_c_rel)
    else:
        num_processes_tmp = num_processes
    if num_processes > 1:
        object_references = []
        for i in range(num_processes_tmp):
            # queries remaining after this chunk; the final chunk absorbs the remainder
            num_test_queries = len(data_c_rel) - (i + 1) * num_queries
            if num_test_queries >= num_queries:
                test_queries_idx = [i * num_queries, (i + 1) * num_queries]
            else:
                # BUG FIX: the original used len(test_data) (a module-level global),
                # which mis-sizes the final chunk whenever this function is called
                # on the validation split; the chunk must end at len(data_c_rel).
                test_queries_idx = [i * num_queries, len(data_c_rel)]
            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]
            # NOTE(review): num_queries is passed even though the final chunk may be
            # larger; confirm baseline_predict_remote derives its length from the data.
            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window,
                                                basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
            object_references.append(ob)
        output = ray.get(object_references)
        # merge the per-worker score lists back into the shared accumulators
        for proc_loop in range(num_processes_tmp):
            perf_list_all.extend(output[proc_loop][0])
            hits_list_all.extend(output[proc_loop][1])
    else:
        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel,
                                                window, basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
    return perf_list_all, hits_list_all
## test
def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
    """Apply the baseline relation-by-relation on the evaluation split and collect MRRs.

    Also appends one "<relation>,<mrr list>" row per relation to a per-run csv file.
    :return perf_list_all: list of mrrs, one per evaluation query
    :return hits_list_all: list of hits values, one per evaluation query
    """
    perf_list_all = []
    hits_list_all = []
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}{split_mode}.csv'
    ## loop through relations and apply baselines
    for rel in all_relations:
        tic = timeit.default_timer()
        perf_list_rel = []
        if rel in test_data_prel:
            cfg = best_config[str(rel)]
            hits_list_rel = []
            perf_list_rel, hits_list_rel = predict(num_processes, test_data_prel[rel],
                                                   all_data_prel[rel], cfg['alpha'][0],
                                                   cfg['lmbda_psi'][0], perf_list_rel,
                                                   hits_list_rel, window, neg_sampler, split_mode)
            perf_list_all.extend(perf_list_rel)
            hits_list_all.extend(hits_list_rel)
        elapsed = round(timeit.default_timer() - tic, 6)
        print("Relation {} finished in {} seconds.".format(rel, elapsed))
        with open(csv_file, 'a') as f:
            f.write("{},{}\n".format(rel, perf_list_rel))
    return perf_list_all, hits_list_all
def read_dict_compute_mrr(split_mode='test'):
    """Re-read the per-relation results csv written by test() and report MRR stats.

    Prints the overall mean MRR and a per-relation MRR dictionary; warns about
    duplicate or missing relation entries.
    """
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}{split_mode}.csv'
    results_per_rel_dict = {}
    mrr_per_rel = {}
    all_mrrs = []
    with open(csv_file, 'r') as f:
        for line in f:
            # each row is "<relation id>,<bracketed list of mrr values>"
            parts = line.strip().split(',')
            key = int(parts[0])
            values = [float(chunk.strip('[]')) for chunk in parts[1:]]
            if key in results_per_rel_dict:
                print(f"Key {key} already exists in the dictionary!!! might have duplicate entries in results csv")
            results_per_rel_dict[key] = values
            all_mrrs.extend(values)
            mrr_per_rel[key] = np.mean(values)
    if len(results_per_rel_dict) != num_rels:
        print("we do not have entries for each rel in the results csv file. only num enties: ", len(results_per_rel_dict))
    print("Split mode: " + split_mode + " Mean MRR: ", np.mean(all_mrrs))
    print("mrr per relation: ", mrr_per_rel)
## train
def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
    """ optional, find best values for lambda and alpha by looping through all relations and testing on a fixed set of params
    based on validation mrr

    :param params_dict: grid of candidate values for 'lmbda_psi' and 'alpha'
    :param rels: iterable of all relation ids
    :param val_data_prel: dict mapping relation id -> validation quadruples
    :param trainval_data_prel: dict mapping relation id -> train+valid quadruples (history)
    :param neg_sampler: negative sampler forwarded to predict()
    :param num_processes: number of parallel workers forwarded to predict()
    :param window: history window size forwarded to predict()
    :return best_config: dictionary of best params for each relation
    """
    best_config= {}
    best_mrr = 0
    for rel in rels: # loop through relations. for each relation, apply rules with selected params, compute valid mrr
        start = timeit.default_timer()
        rel_key = int(rel)
        # defaults (module-level default_lmbda_psi / default_alpha) are kept for
        # relations that have no validation data and hence are never tuned
        best_config[str(rel_key)] = {}
        best_config[str(rel_key)]['not_trained'] = 'True'
        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi,0] #default
        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))
        best_config[str(rel_key)]['alpha'] = [default_alpha,0] #default
        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))
        if rel in val_data_prel.keys():
            # valid data for this relation
            val_data_c_rel = val_data_prel[rel]
            timesteps_valid = list(set(val_data_c_rel[:,3]))
            timesteps_valid.sort()
            trainval_data_c_rel = trainval_data_prel[rel]
            ###### 1) select lambda (alpha temporarily fixed to 1) ###############
            lmbdas_psi = params_dict['lmbda_psi']
            alpha = 1
            best_lmbda_psi = 0.1
            best_mrr_psi = 0
            lmbda_mrrs = []
            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))
            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))
            best_config[str(rel_key)]['not_trained'] = 'False'
            for lmbda_psi in lmbdas_psi:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                                                   trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r,
                                                   window, neg_sampler, split_mode='val')
                # compute mrr
                mrr = np.mean(perf_list_r)
                # is new mrr better than previous best? if yes: store lmbda
                if mrr > best_mrr_psi:
                    best_mrr_psi = float(mrr)
                    best_lmbda_psi = lmbda_psi
                lmbda_mrrs.append(float(mrr))
            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]
            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs
            best_mrr = best_mrr_psi
            ##### 2) select alpha (using the lambda chosen above) ###############
            best_config[str(rel_key)]['not_trained'] = 'False'
            alphas = params_dict['alpha']
            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0] # use the best lmbda psi
            alpha_mrrs = []
            best_mrr_alpha = 0
            best_alpha=0.99
            for alpha in alphas:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                                                   trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r,
                                                   window, neg_sampler, split_mode='val')
                # compute mrr
                mrr_alpha = np.mean(perf_list_r)
                # is new mrr better than previous best? if yes: store alpha
                if mrr_alpha > best_mrr_alpha:
                    best_mrr_alpha = float(mrr_alpha)
                    best_alpha = alpha
                    best_mrr = best_mrr_alpha
                alpha_mrrs.append(float(mrr_alpha))
            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]
            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs
        end = timeit.default_timer()
        total_time = round(end - start, 6)
        print("Relation {} finished in {} seconds.".format(rel, total_time))
    return best_config
## args
def get_args():
    """Parse the command-line options of the recurrency-baseline script.

    Returns a plain dict of option name -> value.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", "-d", default="thgl-forum", type=str)
    # e.g. 200 = only consider the most recent 200 timesteps; -2 = multistep
    p.add_argument("--window", "-w", default=0, type=int)
    p.add_argument("--num_processes", "-p", default=1, type=int)
    # fixed lambda / alpha; only used when train_flag == 'False'
    p.add_argument("--lmbda", "-l", default=0.1, type=float)
    p.add_argument("--alpha", "-alpha", default=0.99, type=float)
    # 'True' to run the lambda/alpha selection on the validation split
    p.add_argument("--train_flag", "-tr", default='False')
    # if train_flag == 'True': load a previously saved best_config instead of retraining
    p.add_argument("--load_flag", "-lo", default='False')
    # save the selected lambda/alpha to a config file
    p.add_argument("--save_config", "-c", default='True')
    p.add_argument('--seed', type=int, help='Random seed', default=1)
    return vars(p.parse_args())
# ---- script entry: parse args, init ray, set up output dirs, load & split data ----
start_o = timeit.default_timer()
parsed = get_args()
if parsed['num_processes']>1:
    # CPU-only workload: one ray worker per requested process
    ray.init(num_cpus=parsed["num_processes"], num_gpus=0)
MODEL_NAME = 'RecurrencyBaseline'
SEED = parsed['seed'] # set the random seed for consistency
set_random_seed(SEED)

# directory for the per-relation result csvs written by test()
perrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'
if not osp.exists(perrel_results_path):
    os.mkdir(perrel_results_path)
    print('INFO: Create directory {}'.format(perrel_results_path))
Path(perrel_results_path).mkdir(parents=True, exist_ok=True)

## load dataset and prepare it accordingly
name = parsed["dataset"]
dataset = LinkPropPredDataset(name=name, root="datasets", preprocess=True)
DATA = name
relations = dataset.edge_type
num_rels = dataset.num_rels
rels = np.arange(0,num_rels)
subjects = dataset.full_data["sources"]
objects= dataset.full_data["destinations"]
num_nodes = dataset.num_nodes
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1
print("split train valid test data")
# quadruple layout: (subject, relation, object, remapped ts, original ts)
all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]
metric = dataset.eval_metric
evaluator = Evaluator(name=name)
neg_sampler = dataset.negative_sampler

train_val_data = np.concatenate([train_data, val_data])
all_data = np.concatenate([train_data, val_data, test_data])

# create dicts with key: relation id, values: triples for that relation id
print("grouping data by relation")
test_data_prel = group_by(test_data, 1)
all_data_prel = group_by(all_data, 1)
val_data_prel = group_by(val_data, 1)
trainval_data_prel = group_by(train_val_data, 1)

#load the ns samples
# if parsed['train_flag']:
print("loading negative samples")
dataset.load_val_ns()
dataset.load_test_ns()

# parameter options: grid searched during train(); the defaults below are the
# fallback values used for relations that never get tuned
if parsed['train_flag'] == 'True':
    params_dict = {}
    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001]
    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]
    default_lmbda_psi = params_dict['lmbda_psi'][-1]
    default_alpha = params_dict['alpha'][-2]

## load rules
print("creating rules")
basis_dict = create_basis_dict(train_val_data)
print("done with creating rules")
## train (optional) to find the best lambda and alpha per relation
start_train = timeit.default_timer()
if parsed['train_flag'] == 'True':
    if parsed['load_flag'] == 'True':
        # reuse a previously saved per-relation configuration
        with open('best_config.json', 'r') as infile:
            best_config = json.load(infile)
    else:
        print('start training')
        best_config = train(params_dict, rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'],
                            parsed['window'])
        if parsed['save_config'] == 'True':
            # fix: json is already imported at module level; the original
            # redundantly re-imported it here
            with open('best_config.json', 'w') as outfile:
                json.dump(best_config, outfile)
else: # use preset lmbda and alpha; same for all relations
    best_config = {}
    for rel in rels:
        best_config[str(rel)] = {}
        best_config[str(rel)]['lmbda_psi'] = [parsed['lmbda']]
        best_config[str(rel)]['alpha'] = [parsed['alpha']]
end_train = timeit.default_timer()

# compute validation mrr
print("Computing validation MRR")
perf_list_all_val, hits_list_all_val = test(best_config, rels, val_data_prel,
                                            trainval_data_prel, neg_sampler, parsed['num_processes'],
                                            parsed['window'], split_mode='val')
val_mrr = float(np.mean(perf_list_all_val))

# compute test mrr
print("Computing test MRR")
start_test = timeit.default_timer()
perf_list_all, hits_list_all = test(best_config, rels, test_data_prel,
                                    all_data_prel, neg_sampler, parsed['num_processes'],
                                    parsed['window'])
end_o = timeit.default_timer()

train_time_o = round(end_train - start_train, 6)
test_time_o = round(end_o - start_test, 6)
total_time_o = round(end_o - start_o, 6)
print("Running Training to find best configs finished in {} seconds.".format(train_time_o))
print("Running testing with best configs finished in {} seconds.".format(test_time_o))
print("Running all steps finished in {} seconds.".format(total_time_o))
print(f"The test MRR is {np.mean(perf_list_all)}")
print(f"The valid MRR is {val_mrr}")
print(f"The Hits@10 is {np.mean(hits_list_all)}")
print(f"We have {len(perf_list_all)} predictions")
print(f"The test set has len {len(test_data)} ")

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
# parents/exist_ok make the explicit os.mkdir the original issued first redundant
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'
metric = dataset.eval_metric
save_results({'model': MODEL_NAME,
              'train_flag': parsed['train_flag'],
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: float(np.mean(perf_list_all)),
              'val_mrr': val_mrr,
              'hits10': float(np.mean(hits_list_all)),
              'test_time': test_time_o,
              'tot_train_val_time': total_time_o
              },
             results_filename)
if parsed['num_processes'] > 1:
    ray.shutdown()
================================================
FILE: examples/linkproppred/thgl-forum/sthn.py
================================================
import timeit
import numpy as np
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator
import argparse
from modules.sthn import set_seed, pre_compute_subgraphs, get_inputs_for_ind, check_data_leakage
import torch
import pandas as pd
import itertools
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
from tgb.utils.utils import set_random_seed, save_results
# ---- script entry: load the dataset, derive splits, prepare output paths ----
start_overall = timeit.default_timer()
DATA = "thgl-forum"
MODEL_NAME = 'STHN'
# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
metric = dataset.eval_metric
print ("there are {} nodes and {} edges".format(dataset.num_nodes, dataset.num_edges))
print ("there are {} relation types".format(dataset.num_rels))
timestamp = data.t
head = data.src
tail = data.dst
edge_type = data.edge_type
neg_sampler = dataset.negative_sampler
# debug output of the loaded temporal graph
print(data)
print(timestamp)
print(head)
print(tail)
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
metric = dataset.eval_metric # NOTE(review): duplicate of the assignment above
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler # NOTE(review): duplicate of the assignment above
# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
####################################################################
####################################################################
####################################################################
def print_model_info(model):
    """Print the model architecture and the number of trainable parameters."""
    print(model)
    trainable = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)
    print('Trainable Parameters: %d' % trainable)
def get_args():
    """Parse command-line options for the STHN link-prediction script."""
    ap = argparse.ArgumentParser()
    # experiment setup
    ap.add_argument('--data', type=str, default='movie')
    ap.add_argument('--device', type=int, default=0)
    ap.add_argument('--batch_size', type=int, default=600)
    ap.add_argument('--epochs', type=int, default=2)
    ap.add_argument('--max_edges', type=int, default=50)
    ap.add_argument('--num_edgeType', type=int, default=0, help='num of edgeType')
    ap.add_argument('--lr', type=float, default=0.0005)
    ap.add_argument('--weight_decay', type=float, default=1e-4)
    ap.add_argument('--predict_class', action='store_true')
    # model hyper-parameters
    ap.add_argument('--window_size', type=int, default=5)
    ap.add_argument('--dropout', type=float, default=0.1)
    ap.add_argument('--model', type=str, default='sthn')
    ap.add_argument('--neg_samples', type=int, default=1)
    ap.add_argument('--extra_neg_samples', type=int, default=5)
    ap.add_argument('--num_neighbors', type=int, default=50)
    ap.add_argument('--channel_expansion_factor', type=int, default=2)
    ap.add_argument('--sampled_num_hops', type=int, default=1)
    ap.add_argument('--time_dims', type=int, default=100)
    ap.add_argument('--hidden_dims', type=int, default=100)
    ap.add_argument('--num_layers', type=int, default=1)
    # feature / structure switches
    ap.add_argument('--check_data_leakage', action='store_true')
    ap.add_argument('--ignore_node_feats', action='store_true')
    ap.add_argument('--node_feats_as_edge_feats', action='store_true')
    ap.add_argument('--ignore_edge_feats', action='store_true')
    ap.add_argument('--use_onehot_node_feats', action='store_true')
    ap.add_argument('--use_type_feats', action='store_true')
    ap.add_argument('--use_graph_structure', action='store_true')
    ap.add_argument('--structure_time_gap', type=int, default=2000)
    ap.add_argument('--structure_hops', type=int, default=1)
    ap.add_argument('--use_node_cls', action='store_true')
    ap.add_argument('--use_cached_subgraph', action='store_true')
    # reproducibility
    ap.add_argument('--seed', type=int, help='Random seed', default=1)
    ap.add_argument('--num_run', type=int, help='Number of iteration runs', default=5)
    return ap.parse_args()
def load_model(args):
    """Build the STHN model and return it together with its training entry point.

    Parameters:
        args: namespace with model hyper-parameters (model, time_dims,
              node_feat_dims, edge_feat_dims, predict_class, num_edgeType, ...)
    Returns:
        (model, args, link_pred_train) tuple
    Raises:
        NotImplementedError: if args.model is not 'sthn'
    """
    edge_predictor_configs = {
        'dim_in_time': args.time_dims,
        'dim_in_node': args.node_feat_dims,
        # multi-class head needs one extra output for the "no edge" class
        'predict_class': 1 if not args.predict_class else args.num_edgeType+1,
    }
    if args.model == 'sthn':
        if args.predict_class:
            from modules.sthn import Multiclass_Interface as STHN_Interface
        else:
            from modules.sthn import STHN_Interface
        from modules.sthn import link_pred_train
        mixer_configs = {
            'per_graph_size'  : args.max_edges,
            'time_channels'   : args.time_dims,
            'input_channels'  : args.edge_feat_dims,
            'hidden_channels' : args.hidden_dims,
            'out_channels'    : args.hidden_dims,
            'num_layers'      : args.num_layers,
            'dropout'         : args.dropout,
            'channel_expansion_factor': args.channel_expansion_factor,
            'window_size'     : args.window_size,
            'use_single_layer' : False
        }
    else:
        # BUG FIX: the original instantiated NotImplementedError() without raising
        # it, then crashed below with a NameError on STHN_Interface instead.
        raise NotImplementedError(f"unsupported model: {args.model}")
    model = STHN_Interface(mixer_configs, edge_predictor_configs)
    for k, v in model.named_parameters():
        print(k, v.requires_grad)
    print_model_info(model)
    return model, args, link_pred_train
def load_graph(data):
    """Build a CSR-style, per-source temporal adjacency structure from an event stream.

    Parameters:
        data: object with array-like attributes t, src, dst, edge_type (one entry
              per event, all of equal length)
    Returns:
        g: NpzFile with arrays 'indptr', 'indices', 'ts', 'eid' — for each source
           node, its destination list sorted chronologically, plus event ids
        df: DataFrame with columns idx, src, dst, time, label (one row per event)
    """
    df = pd.DataFrame({
        'idx': np.arange(len(data.t)),
        'src': data.src,
        'dst': data.dst,
        'time': data.t,
        'label': data.edge_type,
    })
    num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
    ext_full_indptr = np.zeros(num_nodes + 1, dtype=np.int32)
    ext_full_indices = [[] for _ in range(num_nodes)]
    ext_full_ts = [[] for _ in range(num_nodes)]
    ext_full_eid = [[] for _ in range(num_nodes)]
    # PERF FIX: iterate the columns directly instead of df.iterrows(), which
    # builds a Series object per row and is orders of magnitude slower on large
    # graphs. df has a default RangeIndex, so enumerate() reproduces the row idx.
    for idx, (src, dst, ts) in enumerate(zip(df['src'].values, df['dst'].values, df['time'].values)):
        src = int(src)
        ext_full_indices[src].append(int(dst))
        ext_full_ts[src].append(ts)
        ext_full_eid[src].append(idx)
    # prefix sums of the per-node degree give the CSR row pointers
    for i in range(num_nodes):
        ext_full_indptr[i + 1] = ext_full_indptr[i] + len(ext_full_indices[i])
    ext_full_indices = np.array(list(itertools.chain(*ext_full_indices)))
    ext_full_ts = np.array(list(itertools.chain(*ext_full_ts)))
    ext_full_eid = np.array(list(itertools.chain(*ext_full_eid)))
    print('Sorting...')
    def tsort(i, indptr, indices, t, eid):
        # sort node i's neighbour slice chronologically, in place
        beg = indptr[i]
        end = indptr[i + 1]
        sidx = np.argsort(t[beg:end])
        indices[beg:end] = indices[beg:end][sidx]
        t[beg:end] = t[beg:end][sidx]
        eid[beg:end] = eid[beg:end][sidx]
    for i in range(ext_full_indptr.shape[0] - 1):
        tsort(i, ext_full_indptr, ext_full_indices, ext_full_ts, ext_full_eid)
    print('saving...')
    np.savez('/tmp/ext_full.npz', indptr=ext_full_indptr,
             indices=ext_full_indices, ts=ext_full_ts, eid=ext_full_eid)
    g = np.load('/tmp/ext_full.npz')
    return g, df
def load_all_data(args):
    """Load the graph structure, split masks, and node/edge features.

    Reads the module-level data, dataset, and mask objects; stores the numpy
    masks and size metadata on args, applies the feature-selection switches,
    and moves the resulting feature tensors to args.device.

    Returns:
        (node_feats, edge_feats, g, df, args)
    """
    # load graph
    g, df = load_graph(data)
    args.train_mask = train_mask.numpy()
    args.val_mask = val_mask.numpy()
    args.test_mask = test_mask.numpy()
    args.num_edges = len(df)
    # fix: the third count previously read sum(test_mask) (the torch mask);
    # use args.test_mask like the other two for consistency (same value)
    print('Train %d, Valid %d, Test %d'%(sum(args.train_mask),
                                         sum(args.val_mask),
                                         sum(args.test_mask)))
    args.num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
    args.num_edges = len(df)
    print('Num nodes %d, num edges %d'%(args.num_nodes, args.num_edges))
    # load feats
    node_feats, edge_feats = dataset.node_feat, dataset.edge_feat
    node_feat_dims = 0 if node_feats is None else node_feats.shape[1]
    edge_feat_dims = 0 if edge_feats is None else edge_feats.shape[1]
    # feature pre-processing
    if args.use_onehot_node_feats:
        print('>>> Use one-hot node features')
        node_feats = torch.eye(args.num_nodes)
        node_feat_dims = node_feats.size(1)
    if args.ignore_node_feats:
        print('>>> Ignore node features')
        node_feats = None
        node_feat_dims = 0
    if args.use_type_feats:
        # one-hot encode the edge type as the edge feature
        edge_type = df.label.values
        print(edge_type)
        print(edge_type.sum())
        args.num_edgeType = len(set(edge_type.tolist()))
        edge_feats = torch.nn.functional.one_hot(torch.from_numpy(edge_type),
                                                 num_classes=args.num_edgeType)
        edge_feat_dims = edge_feats.size(1)
    print('Node feature dim %d, edge feature dim %d'%(node_feat_dims, edge_feat_dims))
    # double check (if data leakage then cannot continue the code)
    if args.check_data_leakage:
        check_data_leakage(args, g, df)
    args.node_feat_dims = node_feat_dims
    args.edge_feat_dims = edge_feat_dims
    # fix: use `is not None` — `!= None` triggers elementwise comparison on
    # array/tensor feature matrices, which is ambiguous in a boolean context
    if node_feats is not None:
        node_feats = node_feats.to(args.device)
    if edge_feats is not None:
        edge_feats = edge_feats.to(args.device)
    return node_feats, edge_feats, g, df, args
####################################################################
####################################################################
####################################################################
@torch.no_grad()
def test(data, test_mask, model, neg_sampler, split_mode):
    r"""
    Evaluated the dynamic link prediction
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges

    Parameters:
        data: a dataset object (unused here; batches come from the module-level df)
        test_mask: required masks to load the test set edges (unused; the rows are
            selected via args.test_mask / args.val_mask instead)
        model: the trained STHN interface used to score candidate edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'val' or 'test' set to correctly load the negatives
    Returns:
        perf_metrics_mean: mean of the per-query performance metric
        perf_metrics_std: standard deviation of the per-query metric
        perf_list: raw per-query metric values
    """
    test_subgraphs = pre_compute_subgraphs(args, g, df, mode='test' if split_mode == 'test' else 'valid', negative_sampler=neg_sampler, split_mode=split_mode)
    perf_list = []
    if split_mode == 'test':
        cur_df = df[args.test_mask]
    elif split_mode == 'val':
        cur_df = df[args.val_mask]
    # NOTE(review): negative-sample counts are hard-coded to 20 — confirm they match
    # the pre-generated negative sample files for this dataset.
    neg_samples = 20
    cached_neg_samples = 20
    # one group (= one batch) per args.batch_size consecutive rows
    test_loader = cur_df.groupby(cur_df.index // args.batch_size)
    pbar = tqdm(total=len(test_loader))
    pbar.set_description('%s mode with negative samples %d ...'%(split_mode, neg_samples))
    ###################################################
    # compute + training + fetch all scores
    cur_inds = 0
    for ind in range(len(test_loader)):
        ###################################################
        inputs, subgraph_node_feats, cur_inds = get_inputs_for_ind(test_subgraphs, 'test' if split_mode == 'test' else 'tgb-val', cached_neg_samples, neg_samples, node_feats, edge_feats, cur_df, cur_inds, ind, args)
        loss, pred, edge_label = model(inputs, neg_samples, subgraph_node_feats)
        # first score belongs to the positive edge, the rest to its negatives
        input_dict = {
            "y_pred_pos": np.array([pred.cpu()[0]]),
            "y_pred_neg": np.array(pred.cpu()[1:]),
            "eval_metric": [metric],
        }
        perf_list.append(evaluator.eval(input_dict)[metric])
    perf_metrics_mean = float(np.mean(perf_list))
    perf_metrics_std = float(np.std(perf_list))
    return perf_metrics_mean, perf_metrics_std, perf_list
args = get_args()
# hard-coded experiment configuration for this dataset (overrides CLI flags)
args.use_graph_structure = True
args.use_onehot_node_feats = False
args.ignore_node_feats = False # we only use graph structure
args.use_type_feats = True # type encoding
args.use_cached_subgraph = True
print(args)
# resolve device: fall back to CPU when no GPU is available
args.device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
args.device = torch.device(args.device)
SEED = args.seed
BATCH_SIZE = args.batch_size
NUM_RUNS = args.num_run
set_seed(SEED)
###################################################
# load feats + graph
node_feats, edge_feats, g, df, args = load_all_data(args)
###################################################
# get model
model, args, link_pred_train = load_model(args)
###################################################
print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")
# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
# one full train/val/test cycle per run, each with a different derived seed
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()
    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)
    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    # early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
    #                                  tolerance=TOLERANCE, patience=PATIENCE)
    # ==================================================== Train & Validation
    # loading the validation negative samples
    # Link prediction
    start_val = timeit.default_timer()
    print('Train link prediction task from scratch ...')
    model = link_pred_train(model.to(args.device), args, g, df, node_feats, edge_feats)
    dataset.load_val_ns()
    # Validation ...
    # NOTE(review): test_mask is passed here, but test() selects its rows via
    # args.val_mask when split_mode == 'val', so the argument is effectively unused.
    perf_metrics_val_mean, perf_metrics_val_std, perf_list_val = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='val')
    end_val = timeit.default_timer()
    print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tval: {metric}: {perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}")
    val_time = timeit.default_timer() - start_val
    print(f"\tval: Elapsed Time (s): {val_time: .4f}")
    # loading the test negative samples
    dataset.load_test_ns()
    # testing ...
    start_test = timeit.default_timer()
    perf_metrics_test_mean, perf_metrics_test_std, perf_list_test = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='test')
    end_test = timeit.default_timer()
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
    # persist this run's metrics as one json record
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': f'{perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}' ,
                  f'test {metric}': f'{perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}' ,
                  'test_time': test_time,
                  'tot_train_val_time': val_time
                  },
                 results_filename)
    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
# save_results({'model': MODEL_NAME,
#               'data': DATA,
#               'run': 1,
#               'seed': SEED,
#               metric: perf_metric_test,
#               'test_time': test_time,
#               'tot_train_val_time': 'NA'
#               },
#              results_filename)
================================================
FILE: examples/linkproppred/thgl-forum/tgn.py
================================================
import numpy as np
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import timeit
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Training procedure for TGN model
    This function uses objects that are globally defined in the current script
    (model, train_loader, optimizer, criterion, neighbor_loader, assoc, data,
    device, min_dst_idx, max_dst_idx, train_data)

    Parameters:
        None
    Returns:
        float: total training loss averaged over the number of training events
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type

        # Sample negative destination nodes.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # assoc maps global node ids to positions within the local n_id batch
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )

        # score the positive edge and its sampled negative counterpart
        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        # detach memory so gradients do not flow across batch boundaries
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction.

    Evaluation happens as 'one vs. many': each positive edge is ranked
    against the sampled negative edges returned by ``neg_sampler``.
    The model's memory and the neighbor loader are updated with the
    ground-truth edges after every positive batch, so batches must be
    processed in chronological order.

    Parameters:
        loader: loader over the positive edges of the evaluation set
        neg_sampler: gives the negative edges corresponding to each positive edge
        split_mode: 'val' or 'test'; selects which negatives to load
    Returns:
        perf_metric (float): mean of the evaluation metric over all queries
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg, pos_rel = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
            pos_batch.edge_type
        )
        # One list of negative destinations per positive edge in the batch.
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_rel, split_mode=split_mode)
        # pos_msg_new = torch.cat([pos_msg,pos_rel.unsqueeze(dim=1)], dim=1)
        for idx, neg_batch in enumerate(neg_batch_list):
            # Repeat the query source; first destination is the positive one,
            # followed by all sampled negatives.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )
            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            # Map global node ids to local positions in z.
            assoc[n_id] = torch.arange(n_id.size(0), device=device)
            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])
            # compute MRR: score of the positive (row 0) vs. all negatives.
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])
        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)

    perf_metrics = float(torch.tensor(perf_list).mean())
    return perf_metrics
# ==========
# Script entry point: parse arguments, load the thgl-forum dataset, and
# prepare the data loaders and result paths used by the run loop below.
# ==========
start_overall = timeit.default_timer()

DATA = "thgl-forum"

# ========== set parameters...
args, _ = get_args()
args.data = DATA
print("INFO: Arguments:", args)

LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10       # number of temporal neighbors sampled per node
USE_EDGE_TYPE = True     # append an edge-type embedding to each message
USE_NODE_TYPE = True
MODEL_NAME = 'TGN'

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric
print ("there are {} nodes and {} edges".format(dataset.num_nodes, dataset.num_edges))
print ("there are {} relation types".format(dataset.num_rels))
timestamp = data.t
head = data.src
tail = data.dst
edge_type = data.edge_type #relation

# Embed each edge type into a 64-dim vector and concatenate it to the edge
# messages. The embedding is looked up under no_grad here; NOTE(review): it
# appears to be a fixed random embedding (not added to any optimizer in this
# file) — confirm this is intended.
edge_type_dim = len(torch.unique(edge_type))
embed_edge_type = torch.nn.Embedding(edge_type_dim, 64).to(device)
with torch.no_grad():
    edge_type_embeddings = embed_edge_type(edge_type)
if USE_EDGE_TYPE:
    data.msg = torch.cat([data.msg, edge_type_embeddings], dim=1)

#! node type is a property of the dataset not the temporal data as temporal data has one entry per edge
node_type = dataset.node_type #node type
neg_sampler = dataset.negative_sampler  # NOTE(review): re-assigned again below; one of the two is redundant
data.__setattr__("node_type", node_type)
print ("shape of edge type is", edge_type.shape)
print ("shape of node type is", node_type.shape)

# chronological train / val / test splits
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
print ("finished loading PyG data")

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)
start_time = timeit.default_timer()

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
# NOTE(review): the os.mkdir branch is redundant with the Path(...).mkdir
# call right after it (which already handles existence and parents).
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
# Run the full train / validate / test cycle NUM_RUNS times, re-seeding and
# re-initializing the model for each run, and append each run's results to
# the results file.
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # neighborhood sampler (keeps the most recent NUM_NEIGHBORS per node)
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

    # define the model end-to-end: memory module + attention GNN + decoder
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)

    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}

    # one optimizer over the parameters of all three components
    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    # define an early stopper (also checkpoints the best model)
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )

        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    # ==================================================== Test
    # first, load the best model found during validation
    early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")

    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
# #* load numpy arrays instead
# from tgb.linkproppred.dataset import LinkPropPredDataset
# # data loading
# dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
# data = dataset.full_data
# metric = dataset.eval_metric
# sources = dataset.full_data['sources']
# print ("finished loading numpy arrays")
================================================
FILE: examples/linkproppred/thgl-github/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
tgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(tgb_modules_path)
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate EdgeBank on dynamic link prediction.

    Each positive edge is ranked against its sampled negative edges
    ('one vs. many'); EdgeBank's memory is updated after every positive
    batch. Uses the global ``edgebank``, ``evaluator``, ``metric`` and
    ``BATCH_SIZE`` objects.

    Parameters:
        data: dataset dict of numpy arrays
        test_mask: mask selecting the evaluation edges
        neg_sampler: gives the negative edges for each positive edge
        split_mode: 'val' or 'test'; selects which negatives to load
    Returns:
        perf_metrics: mean value of the primary metric over all queries
        perf_hits: mean hits@10 over all queries
    """
    # Hoist the masked arrays once instead of re-masking per batch.
    srcs = data['sources'][test_mask]
    dsts = data['destinations'][test_mask]
    ts = data['timestamps'][test_mask]
    etypes = data['edge_type'][test_mask]

    n_batches = math.ceil(len(srcs) / BATCH_SIZE)
    mrr_scores = []
    hits_scores = []

    for b in tqdm(range(n_batches)):
        lo = b * BATCH_SIZE
        hi = min(lo + BATCH_SIZE, len(srcs))
        pos_src = srcs[lo:hi]
        pos_dst = dsts[lo:hi]
        pos_t = ts[lo:hi]
        pos_edge = etypes[lo:hi]

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)

        for q_idx, negs in enumerate(neg_batch_list):
            # Rank the positive destination (first entry) against all negatives.
            q_src = np.full(len(negs) + 1, int(pos_src[q_idx]))
            q_dst = np.concatenate([np.array([int(pos_dst[q_idx])]), negs])
            scores = edgebank.predict_link(q_src, q_dst)

            res = evaluator.eval({
                "y_pred_pos": np.array([scores[0]]),
                "y_pred_neg": np.array(scores[1:]),
                "eval_metric": [metric],
            })
            mrr_scores.append(res[metric])
            hits_scores.append(res['hits@10'])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    return float(np.mean(mrr_scores)), float(np.mean(hits_scores))
def get_args():
    """Parse command-line arguments for the EdgeBank experiment.

    Returns:
        args: parsed argument namespace
        sys.argv: the raw command-line arguments
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='thgl-github')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)

    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse signals bad arguments via SystemExit; show usage and exit
        # cleanly. (Was a bare `except:`, which would also swallow unrelated
        # errors such as KeyboardInterrupt.)
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
# Script entry point: configure EdgeBank from CLI arguments, build its memory
# from the training split, then evaluate on the validation and test splits.
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = args.data
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank: only the training edges seed the memory
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# validation (EdgeBank has no training phase; memory updates happen inside test())
start_val = timeit.default_timer()
perf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()

print(f"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< ")
print(f"\tval: {metric}: {perf_metric_val: .4f}")
val_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {val_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

# testing ...
start_test = timeit.default_timer()
perf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()

print(f"INFO: Test: Evaluation Setting: >>> <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'val_mrr': perf_metric_val,
              'test_time': test_time,
              'tot_train_val_time': test_time+val_time,
              'hits10': perf_hits_test},
             results_filename)
================================================
FILE: examples/linkproppred/thgl-github/recurrencybaseline.py
================================================
""" from paper: "History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting"
Julia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)
@inproceedings{gastinger2024baselines,
title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},
author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},
booktitle={33nd International Joint Conference on Artificial Intelligence (IJCAI 2024)},
year={2024},
organization={International Joint Conferences on Artificial Intelligence Organization}
}
python recurrencybaseline.py --seed 1 --num_processes 1 -tr False
"""
## imports
import timeit
import argparse
import numpy as np
from copy import copy
from pathlib import Path
import ray
import sys
import os
import os.path as osp
import json
#internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import set_random_seed, save_results
from modules.tkg_utils import create_basis_dict, group_by, reformat_ts
def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
            perf_list_all, hits_list_all, window, neg_sampler, split_mode):
    """ create predictions for each relation on test or valid set and compute mrr

    Splits the queries of this relation across `num_processes` ray workers
    (or runs locally when num_processes == 1) and extends the given result
    lists in place. Relies on the global `basis_dict`, `num_nodes`,
    `num_rels` and `evaluator` objects.

    :param data_c_rel: quadruples for this relation in the evaluation split
    :param all_data_c_rel: quadruples for this relation in the history data
    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    first_ts = data_c_rel[0][3]

    ## use this if you wanna use ray:
    num_queries = len(data_c_rel) // num_processes
    if num_queries < num_processes:  # if we do not have enough queries for all the processes
        num_processes_tmp = 1
        num_queries = len(data_c_rel)
    else:
        num_processes_tmp = num_processes

    if num_processes > 1:
        object_references = []
        for i in range(num_processes_tmp):
            num_test_queries = len(data_c_rel) - (i + 1) * num_queries
            if num_test_queries >= num_queries:
                test_queries_idx = [i * num_queries, (i + 1) * num_queries]
            else:
                # BUGFIX: the final chunk must end at the number of queries
                # for *this* relation's data, not at the length of the global
                # `test_data` (which is wrong e.g. on the validation split).
                test_queries_idx = [i * num_queries, len(data_c_rel)]

            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]
            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window,
                                                basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
            object_references.append(ob)

        output = ray.get(object_references)

        # updates the scores and logging dict for each process
        for proc_loop in range(num_processes_tmp):
            perf_list_all.extend(output[proc_loop][0])
            hits_list_all.extend(output[proc_loop][1])
    else:
        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel,
                                                window, basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)

    return perf_list_all, hits_list_all
## test
def test(best_config, all_relations, test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
    """ create predictions by looping through all relations on test or valid set and compute mrr

    For each relation, runs the recurrency baseline with that relation's best
    (lmbda_psi, alpha) and appends its per-query scores to a csv file so
    partial results survive interruption.

    :param best_config: per-relation dict with the selected 'lmbda_psi' and 'alpha'
    :param all_relations: iterable of all relation ids to evaluate
    :param test_data_prel: dict relation id -> quadruples of the evaluation split
    :param all_data_prel: dict relation id -> quadruples of the history data
    :param split_mode: 'val' or 'test'
    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    perf_list_all = []
    hits_list_all = []
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'

    ## loop through relations and apply baselines
    for rel in all_relations:
        start = timeit.default_timer()
        if rel in test_data_prel.keys():
            lmbda_psi = best_config[str(rel)]['lmbda_psi'][0]
            alpha = best_config[str(rel)]['alpha'][0]

            # (removed: a sorted list of this relation's timestamps was
            # computed here but never used — wasted work per relation)
            perf_list_rel, hits_list_rel = predict(num_processes, test_data_prel[rel],
                        all_data_prel[rel], alpha, lmbda_psi, [], [],
                        window, neg_sampler, split_mode)
            perf_list_all.extend(perf_list_rel)
            hits_list_all.extend(hits_list_rel)
        else:
            # no evaluation queries for this relation; log an empty result
            perf_list_rel = []

        end = timeit.default_timer()
        total_time = round(end - start, 6)
        print("Relation {} finished in {} seconds.".format(rel, total_time))
        with open(csv_file, 'a') as f:
            f.write("{},{}\n".format(rel, perf_list_rel))

    return perf_list_all, hits_list_all
def read_dict_compute_mrr(split_mode='test'):
    """ Read the per-relation results from the previously written csv file
    and print the overall mean MRR and the per-relation MRRs.

    Note: results are only printed; nothing is returned.
    """
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'

    results_per_rel_dict = {}
    mrr_per_rel = {}
    all_mrrs = []

    with open(csv_file, 'r') as f:
        for raw_line in f:
            fields = raw_line.strip().split(',')
            rel_id = int(fields[0])  # first field is the relation id
            # remaining fields are the scores; strip the list brackets
            scores = [float(tok.strip('[]')) for tok in fields[1:]]
            if rel_id in results_per_rel_dict:
                print(f"Key {rel_id} already exists in the dictionary!!! might have duplicate entries in results csv")
            results_per_rel_dict[rel_id] = scores
            all_mrrs.extend(scores)
            mrr_per_rel[rel_id] = np.mean(scores)

    # sanity check: there should be one csv row per relation
    if len(list(results_per_rel_dict.keys())) != num_rels:
        print("we do not have entries for each rel in the results csv file. only num enties: ", len(list(results_per_rel_dict.keys())))

    print("Split mode: "+split_mode +" Mean MRR: ", np.mean(all_mrrs))
    print("mrr per relation: ", mrr_per_rel)
## train
## train
def train(params_dict, rels, val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
    """ optional, find best values for lambda and alpha by looping through all relations and testing on a fixed set of params
    based on validation mrr.

    For each relation: first sweep lmbda_psi (with alpha fixed to 1), keep
    the best; then sweep alpha using that best lmbda_psi. Relations without
    validation data keep the global defaults and are marked 'not_trained'.

    :param params_dict: dict with candidate lists for 'lmbda_psi' and 'alpha'
    :param rels: iterable of all relation ids
    :param val_data_prel: dict relation id -> validation quadruples
    :param trainval_data_prel: dict relation id -> train+valid quadruples (history)
    :return best_config: dictionary of best params for each relation
    """
    best_config = {}
    best_mrr = 0
    for rel in rels:  # loop through relations. for each relation, apply rules with selected params, compute valid mrr
        start = timeit.default_timer()
        rel_key = int(rel)
        best_config[str(rel_key)] = {}
        # defaults used when this relation has no validation data
        best_config[str(rel_key)]['not_trained'] = 'True'
        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi,0] #default
        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))
        best_config[str(rel_key)]['alpha'] = [default_alpha,0] #default
        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))

        if rel in val_data_prel.keys():
            # valid data for this relation
            val_data_c_rel = val_data_prel[rel]
            timesteps_valid = list(set(val_data_c_rel[:,3]))
            timesteps_valid.sort()
            trainval_data_c_rel = trainval_data_prel[rel]

            ###### 1) select lambda ###############
            lmbdas_psi = params_dict['lmbda_psi']
            alpha = 1
            best_lmbda_psi = 0.1
            best_mrr_psi = 0
            lmbda_mrrs = []

            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))
            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))
            best_config[str(rel_key)]['not_trained'] = 'False'

            for lmbda_psi in lmbdas_psi:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                            trainval_data_c_rel, alpha, lmbda_psi, perf_list_r, hits_list_r,
                            window, neg_sampler, split_mode='val')
                # compute mrr
                mrr = np.mean(perf_list_r)
                # is new mrr better than previous best? if yes: store lmbda
                if mrr > best_mrr_psi:
                    best_mrr_psi = float(mrr)
                    best_lmbda_psi = lmbda_psi
                lmbda_mrrs.append(float(mrr))
            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]
            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs
            best_mrr = best_mrr_psi

            ##### 2) select alpha ###############
            best_config[str(rel_key)]['not_trained'] = 'False'
            alphas = params_dict['alpha']
            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0]  # use the best lmbda psi
            alpha_mrrs = []
            best_mrr_alpha = 0
            best_alpha = 0.99
            for alpha in alphas:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                            trainval_data_c_rel, alpha, lmbda_psi, perf_list_r, hits_list_r,
                            window, neg_sampler, split_mode='val')
                # compute mrr
                mrr_alpha = np.mean(perf_list_r)
                # is new mrr better than previous best? if yes: store alpha
                if mrr_alpha > best_mrr_alpha:
                    best_mrr_alpha = float(mrr_alpha)
                    best_alpha = alpha
                    best_mrr = best_mrr_alpha
                alpha_mrrs.append(float(mrr_alpha))
            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]
            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs

        end = timeit.default_timer()
        total_time = round(end - start, 6)
        print("Relation {} finished in {} seconds.".format(rel, total_time))
    return best_config
## args
def get_args():
    """Parse all command-line arguments for the script; returns them as a dict."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", "-d", default="thgl-github", type=str)
    # e.g. 200 to only consider the most recent 200 timesteps; -2 for multistep
    parser.add_argument("--window", "-w", default=0, type=int)
    parser.add_argument("--num_processes", "-p", default=1, type=int)
    # fixed lambda, used if train_flag == 'False'
    parser.add_argument("--lmbda", "-l", default=0.1, type=float)
    # fixed alpha, used if train_flag == 'False'
    parser.add_argument("--alpha", "-alpha", default=0.99, type=float)
    # whether to run training, i.e. the selection of lambda and alpha
    parser.add_argument("--train_flag", "-tr", default='False')
    # if train_flag == 'True': load a precomputed best_config?
    parser.add_argument("--load_flag", "-lo", default='False')
    # save the selected lambda and alpha to a config file?
    parser.add_argument("--save_config", "-c", default='True')
    parser.add_argument('--seed', type=int, help='Random seed', default=1)  # not needed
    return vars(parser.parse_args())
# Script entry point: load the dataset, group quadruples per relation,
# optionally tune (lambda, alpha) per relation on the validation split, then
# evaluate the recurrency baseline on validation and test and save results.
start_o = timeit.default_timer()

parsed = get_args()
if parsed['num_processes']>1:
    # only start ray when parallel workers were requested
    ray.init(num_cpus=parsed["num_processes"], num_gpus=0)

MODEL_NAME = 'RecurrencyBaseline'
SEED = parsed['seed']  # set the random seed for consistency
set_random_seed(SEED)

# directory for the per-relation csv results
perrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'
if not osp.exists(perrel_results_path):
    os.mkdir(perrel_results_path)
    print('INFO: Create directory {}'.format(perrel_results_path))
Path(perrel_results_path).mkdir(parents=True, exist_ok=True)

## load dataset and prepare it accordingly
name = parsed["dataset"]
dataset = LinkPropPredDataset(name=name, root="datasets", preprocess=True)
DATA = name

relations = dataset.edge_type
num_rels = dataset.num_rels
rels = np.arange(0,num_rels)
subjects = dataset.full_data["sources"]
objects= dataset.full_data["destinations"]
num_nodes = dataset.num_nodes
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA)  # stepsize:1

print("split train valid test data")
# quadruple layout: (subject, relation, object, discretized ts, original ts)
all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]

metric = dataset.eval_metric
evaluator = Evaluator(name=name)
neg_sampler = dataset.negative_sampler

train_val_data = np.concatenate([train_data, val_data])
all_data = np.concatenate([train_data, val_data, test_data])

# create dicts with key: relation id, values: triples for that relation id
print("grouping data by relation")
test_data_prel = group_by(test_data, 1)
all_data_prel = group_by(all_data, 1)
val_data_prel = group_by(val_data, 1)
trainval_data_prel = group_by(train_val_data, 1)

# load the ns samples
# if parsed['train_flag']:
print("loading negative samples")
dataset.load_val_ns()
dataset.load_test_ns()

# parameter options for the per-relation grid search
if parsed['train_flag'] == 'True':
    params_dict = {}
    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001]
    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]
    default_lmbda_psi = params_dict['lmbda_psi'][-1]
    default_alpha = params_dict['alpha'][-2]

## load rules
print("creating rules")
basis_dict = create_basis_dict(train_val_data)
print("done with creating rules")

## init
# rb_predictor = RecurrencyBaselinePredictor(rels)

## train to find best lambda and alpha
start_train = timeit.default_timer()
if parsed['train_flag'] == 'True':
    if parsed['load_flag'] == 'True':
        # reuse a previously saved per-relation configuration
        with open('best_config.json', 'r') as infile:
            best_config = json.load(infile)
    else:
        print('start training')
        best_config = train(params_dict, rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'],
                    parsed['window'])
        if parsed['save_config'] == 'True':
            # NOTE(review): json is already imported at the top of the file;
            # this local import is redundant.
            import json
            with open('best_config.json', 'w') as outfile:
                json.dump(best_config, outfile)
else:  # use preset lmbda and alpha; same for all relations
    best_config = {}
    for rel in rels:
        best_config[str(rel)] = {}
        best_config[str(rel)]['lmbda_psi'] = [parsed['lmbda']]
        best_config[str(rel)]['alpha'] = [parsed['alpha']]
end_train = timeit.default_timer()

# compute validation mrr
print("Computing validation MRR")
perf_list_all_val, hits_list_all_val = test(best_config,rels, val_data_prel,
                    trainval_data_prel, neg_sampler, parsed['num_processes'],
                    parsed['window'], split_mode='val')
val_mrr = float(np.mean(perf_list_all_val))

# compute test mrr
print("Computing test MRR")
start_test = timeit.default_timer()
perf_list_all, hits_list_all = test(best_config,rels, test_data_prel,
                    all_data_prel, neg_sampler, parsed['num_processes'],
                    parsed['window'])

end_o = timeit.default_timer()
train_time_o = round(end_train- start_train, 6)
test_time_o = round(end_o- start_test, 6)
total_time_o = round(end_o- start_o, 6)
print("Running Training to find best configs finished in {} seconds.".format(train_time_o))
print("Running testing with best configs finished in {} seconds.".format(test_time_o))
print("Running all steps finished in {} seconds.".format(total_time_o))
print(f"The test MRR is {np.mean(perf_list_all)}")
print(f"The valid MRR is {val_mrr}")
print(f"The Hits@10 is {np.mean(hits_list_all)}")
print(f"We have {len(perf_list_all)} predictions")
print(f"The test set has len {len(test_data)} ")

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'

metric = dataset.eval_metric
save_results({'model': MODEL_NAME,
              'train_flag': parsed['train_flag'],
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: float(np.mean(perf_list_all)),
              'val_mrr': val_mrr,
              'hits10': float(np.mean(hits_list_all)),
              'test_time': test_time_o,
              'tot_train_val_time': total_time_o
              },
             results_filename)

if parsed['num_processes']>1:
    ray.shutdown()
================================================
FILE: examples/linkproppred/thgl-github/run_seeds.sh
================================================
# Run the TGN example five times with identical hyperparameters but different
# random seeds; each run's unbuffered stdout is mirrored to its own log file.
python -u tgn.py --seed 1 --mem_dim 16 --time_dim 16 --emb_dim 16 --num_epoch 5 --num_run 1 | tee tgn_s1_new_output.txt
python -u tgn.py --seed 2 --mem_dim 16 --time_dim 16 --emb_dim 16 --num_epoch 5 --num_run 1 | tee tgn_s2_new_output.txt
python -u tgn.py --seed 3 --mem_dim 16 --time_dim 16 --emb_dim 16 --num_epoch 5 --num_run 1 | tee tgn_s3_new_output.txt
python -u tgn.py --seed 4 --mem_dim 16 --time_dim 16 --emb_dim 16 --num_epoch 5 --num_run 1 | tee tgn_s4_new_output.txt
python -u tgn.py --seed 5 --mem_dim 16 --time_dim 16 --emb_dim 16 --num_epoch 5 --num_run 1 | tee tgn_s5_new_output.txt
================================================
FILE: examples/linkproppred/thgl-github/sthn.py
================================================
import timeit
import numpy as np
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator
import argparse
from modules.sthn import set_seed, pre_compute_subgraphs, get_inputs_for_ind, check_data_leakage
import torch
import pandas as pd
import itertools
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
from tgb.utils.utils import set_random_seed, save_results
# Start...
start_overall = timeit.default_timer()

DATA = "thgl-github"
MODEL_NAME = 'STHN'

# ---------------- data loading ----------------
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
# `metric` and `neg_sampler` were each assigned twice before; once is enough
metric = dataset.eval_metric
neg_sampler = dataset.negative_sampler

print("there are {} nodes and {} edges".format(dataset.num_nodes, dataset.num_edges))
print("there are {} relation types".format(dataset.num_rels))

timestamp = data.t
head = data.src
tail = data.dst
edge_type = data.edge_type
print(data)
print(timestamp)
print(head)
print(tail)

# split the temporal data chronologically via the dataset-provided masks
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

evaluator = Evaluator(name=DATA)

# for saving the results...
# Path.mkdir(parents=True, exist_ok=True) subsumes the old exists()/mkdir()
# pair, which was redundant and race-prone
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
####################################################################
####################################################################
####################################################################
def print_model_info(model):
    """Print the model's repr followed by its trainable-parameter count."""
    print(model)
    n_trainable = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)
    print('Trainable Parameters: %d' % n_trainable)
def get_args():
    """Build the CLI parser for the STHN link-prediction run and parse argv."""
    p = argparse.ArgumentParser()
    # ----- data / training -----
    p.add_argument('--data', type=str, default='movie')
    p.add_argument('--device', type=int, default=0)
    p.add_argument('--batch_size', type=int, default=600)
    p.add_argument('--epochs', type=int, default=2)
    p.add_argument('--max_edges', type=int, default=50)
    p.add_argument('--num_edgeType', type=int, default=0, help='num of edgeType')
    p.add_argument('--lr', type=float, default=0.0005)
    p.add_argument('--weight_decay', type=float, default=1e-4)
    p.add_argument('--predict_class', action='store_true')
    # ----- model -----
    p.add_argument('--window_size', type=int, default=5)
    p.add_argument('--dropout', type=float, default=0.1)
    p.add_argument('--model', type=str, default='sthn')
    p.add_argument('--neg_samples', type=int, default=1)
    p.add_argument('--extra_neg_samples', type=int, default=5)
    p.add_argument('--num_neighbors', type=int, default=50)
    p.add_argument('--channel_expansion_factor', type=int, default=2)
    p.add_argument('--sampled_num_hops', type=int, default=1)
    p.add_argument('--time_dims', type=int, default=100)
    p.add_argument('--hidden_dims', type=int, default=100)
    p.add_argument('--num_layers', type=int, default=1)
    # ----- feature handling flags -----
    p.add_argument('--check_data_leakage', action='store_true')
    p.add_argument('--ignore_node_feats', action='store_true')
    p.add_argument('--node_feats_as_edge_feats', action='store_true')
    p.add_argument('--ignore_edge_feats', action='store_true')
    p.add_argument('--use_onehot_node_feats', action='store_true')
    p.add_argument('--use_type_feats', action='store_true')
    # ----- graph-structure options -----
    p.add_argument('--use_graph_structure', action='store_true')
    p.add_argument('--structure_time_gap', type=int, default=2000)
    p.add_argument('--structure_hops', type=int, default=1)
    p.add_argument('--use_node_cls', action='store_true')
    p.add_argument('--use_cached_subgraph', action='store_true')
    # ----- experiment control -----
    p.add_argument('--seed', type=int, help='Random seed', default=1)
    p.add_argument('--num_run', type=int, help='Number of iteration runs', default=5)
    return p.parse_args()
def load_model(args):
    """Instantiate the STHN model and return it with its training routine.

    Parameters:
        args: parsed namespace; reads time_dims, node_feat_dims, predict_class,
              num_edgeType, model, and the mixer hyper-parameters.
    Returns:
        (model, args, link_pred_train): the interface module, the (unchanged)
        args, and the training entry point imported from modules.sthn.
    Raises:
        NotImplementedError: if args.model names an unknown architecture.
    """
    edge_predictor_configs = {
        'dim_in_time': args.time_dims,
        'dim_in_node': args.node_feat_dims,
        # single logit for link prediction unless multiclass edge-type
        # prediction was requested
        'predict_class': 1 if not args.predict_class else args.num_edgeType + 1,
    }
    if args.model == 'sthn':
        if args.predict_class:
            from modules.sthn import Multiclass_Interface as STHN_Interface
        else:
            from modules.sthn import STHN_Interface
        from modules.sthn import link_pred_train
        mixer_configs = {
            'per_graph_size': args.max_edges,
            'time_channels': args.time_dims,
            'input_channels': args.edge_feat_dims,
            'hidden_channels': args.hidden_dims,
            'out_channels': args.hidden_dims,
            'num_layers': args.num_layers,
            'dropout': args.dropout,
            'channel_expansion_factor': args.channel_expansion_factor,
            'window_size': args.window_size,
            'use_single_layer': False
        }
    else:
        # BUG FIX: the exception was previously constructed but never raised,
        # so an unknown model name fell through to a NameError below
        raise NotImplementedError(f"unknown model: {args.model!r}")

    model = STHN_Interface(mixer_configs, edge_predictor_configs)

    for k, v in model.named_parameters():
        print(k, v.requires_grad)
    print_model_info(model)

    return model, args, link_pred_train
def load_graph(data):
    """Build a CSR-style temporal adjacency structure from a TemporalData object.

    Parameters:
        data: object with per-edge arrays t, src, dst, edge_type (one entry per edge).
    Returns:
        (g, df): g is an NpzFile with keys indptr/indices/ts/eid (per-source
        neighbor lists sorted by timestamp); df is the flat edge table.
    """
    import tempfile  # local import: only needed for the scratch file below

    df = pd.DataFrame({
        'idx': np.arange(len(data.t)),
        'src': data.src,
        'dst': data.dst,
        'time': data.t,
        'label': data.edge_type,
    })
    num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1

    ext_full_indptr = np.zeros(num_nodes + 1, dtype=np.int32)
    ext_full_indices = [[] for _ in range(num_nodes)]
    ext_full_ts = [[] for _ in range(num_nodes)]
    ext_full_eid = [[] for _ in range(num_nodes)]

    # iterate over plain column arrays with zip instead of DataFrame.iterrows(),
    # which materialises a Series per row and is orders of magnitude slower
    for idx, src, dst, ts in tqdm(zip(df['idx'].values, df['src'].values,
                                      df['dst'].values, df['time'].values),
                                  total=len(df)):
        s = int(src)
        ext_full_indices[s].append(int(dst))
        ext_full_ts[s].append(ts)
        ext_full_eid[s].append(idx)

    for i in tqdm(range(num_nodes)):
        ext_full_indptr[i + 1] = ext_full_indptr[i] + len(ext_full_indices[i])

    ext_full_indices = np.array(list(itertools.chain(*ext_full_indices)))
    ext_full_ts = np.array(list(itertools.chain(*ext_full_ts)))
    ext_full_eid = np.array(list(itertools.chain(*ext_full_eid)))

    print('Sorting...')

    def tsort(i, indptr, indices, t, eid):
        # sort node i's neighbor slice in-place by timestamp
        beg, end = indptr[i], indptr[i + 1]
        sidx = np.argsort(t[beg:end])
        indices[beg:end] = indices[beg:end][sidx]
        t[beg:end] = t[beg:end][sidx]
        eid[beg:end] = eid[beg:end][sidx]

    for i in tqdm(range(ext_full_indptr.shape[0] - 1)):
        tsort(i, ext_full_indptr, ext_full_indices, ext_full_ts, ext_full_eid)

    print('saving...')
    # BUG FIX: '/tmp' is POSIX-only; use the platform temp directory instead.
    # NOTE(review): concurrent runs still share this fixed filename -- consider
    # tempfile.NamedTemporaryFile if parallel invocations are expected.
    npz_path = os.path.join(tempfile.gettempdir(), 'ext_full.npz')
    np.savez(npz_path, indptr=ext_full_indptr,
             indices=ext_full_indices, ts=ext_full_ts, eid=ext_full_eid)
    g = np.load(npz_path)
    return g, df
def load_all_data(args):
    """Load the temporal graph, split masks and node/edge features.

    Mutates `args` with the split masks, node/edge counts and feature dims,
    then returns (node_feats, edge_feats, g, df, args).
    Relies on module-level globals: data, dataset, train_mask, val_mask,
    test_mask, load_graph, check_data_leakage.
    """
    # load graph
    g, df = load_graph(data)

    args.train_mask = train_mask.numpy()
    args.val_mask = val_mask.numpy()
    args.test_mask = test_mask.numpy()
    args.num_edges = len(df)
    # consistency fix: report the test count from args.test_mask like the
    # other two splits (previously used the global test_mask directly)
    print('Train %d, Valid %d, Test %d' % (sum(args.train_mask),
                                           sum(args.val_mask),
                                           sum(args.test_mask)))

    args.num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
    # (duplicate args.num_edges assignment removed)
    print('Num nodes %d, num edges %d' % (args.num_nodes, args.num_edges))

    # load feats
    node_feats, edge_feats = dataset.node_feat, dataset.edge_feat
    node_feat_dims = 0 if node_feats is None else node_feats.shape[1]
    edge_feat_dims = 0 if edge_feats is None else edge_feats.shape[1]

    # feature pre-processing
    if args.use_onehot_node_feats:
        print('>>> Use one-hot node features')
        node_feats = torch.eye(args.num_nodes)
        node_feat_dims = node_feats.size(1)

    if args.ignore_node_feats:
        print('>>> Ignore node features')
        node_feats = None
        node_feat_dims = 0

    if args.use_type_feats:
        # one-hot encode the relation type of every edge as its feature
        edge_type = df.label.values
        print(edge_type)
        print(edge_type.sum())
        args.num_edgeType = len(set(edge_type.tolist()))
        edge_feats = torch.nn.functional.one_hot(torch.from_numpy(edge_type),
                                                 num_classes=args.num_edgeType)
        edge_feat_dims = edge_feats.size(1)

    print('Node feature dim %d, edge feature dim %d' % (node_feat_dims, edge_feat_dims))

    # double check (if data leakage then cannot continue the code)
    if args.check_data_leakage:
        check_data_leakage(args, g, df)

    args.node_feat_dims = node_feat_dims
    args.edge_feat_dims = edge_feat_dims

    # BUG FIX: `x != None` triggers element-wise comparison on tensors and is
    # not a valid truth test; identity comparison is the correct check here
    if node_feats is not None:
        node_feats = node_feats.to(args.device)
    if edge_feats is not None:
        edge_feats = edge_feats.to(args.device)

    return node_feats, edge_feats, g, df, args
####################################################################
####################################################################
####################################################################
@torch.no_grad()
def test(data, test_mask, model, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction.

    Evaluation happens as 'one vs. many': each positive edge is ranked
    against many negative edges.

    Parameters:
        data: a dataset object
        test_mask: required masks to load the test set edges
            (NOTE(review): rows are actually selected via args.test_mask /
            args.val_mask below -- confirm this parameter is intentionally unused)
        model: the trained STHN interface module
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: 'val' or 'test' -- selects which cached negatives to load
    Returns:
        (mean, std, per-query list) of the performance metric
    """
    # pre-compute the temporal subgraph around every evaluation query up front
    test_subgraphs = pre_compute_subgraphs(args, g, df, mode='test' if split_mode == 'test' else 'valid', negative_sampler=neg_sampler, split_mode=split_mode)
    perf_list = []
    if split_mode == 'test':
        cur_df = df[args.test_mask]
    elif split_mode == 'val':
        cur_df = df[args.val_mask]
    # NOTE(review): the number of negatives per positive is hard-coded to 20
    # here -- confirm it matches the dataset's negative sampler
    neg_samples = 20
    cached_neg_samples = 20
    # batching groups rows by their (global) dataframe index; cur_df retains
    # the original index of df after masking
    test_loader = cur_df.groupby(cur_df.index // args.batch_size)
    pbar = tqdm(total=len(test_loader))
    pbar.set_description('%s mode with negative samples %d ...'%(split_mode, neg_samples))

    ###################################################
    # compute + training + fetch all scores
    cur_inds = 0
    for ind in range(len(test_loader)):
        ###################################################
        inputs, subgraph_node_feats, cur_inds = get_inputs_for_ind(test_subgraphs, 'test' if split_mode == 'test' else 'tgb-val', cached_neg_samples, neg_samples, node_feats, edge_feats, cur_df, cur_inds, ind, args)
        loss, pred, edge_label = model(inputs, neg_samples, subgraph_node_feats)
        # print(ind, [l for l in inputs], pred.shape)
        # pred[0] is the positive score; pred[1:] are the negative scores
        input_dict = {
            "y_pred_pos": np.array([pred.cpu()[0]]),
            "y_pred_neg": np.array(pred.cpu()[1:]),
            "eval_metric": [metric],
        }
        perf_list.append(evaluator.eval(input_dict)[metric])

    perf_metrics_mean = float(np.mean(perf_list))
    perf_metrics_std = float(np.std(perf_list))
    return perf_metrics_mean, perf_metrics_std, perf_list
args = get_args()
# hard-coded STHN configuration for thgl-github (overrides the CLI flags)
args.use_graph_structure = True
args.use_onehot_node_feats = False
args.ignore_node_feats = False  # we only use graph structure
args.use_type_feats = True      # type encoding
args.use_cached_subgraph = True
print(args)

# resolve the device in one step (was a two-step string-then-device dance)
args.device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")

SEED = args.seed
BATCH_SIZE = args.batch_size
NUM_RUNS = args.num_run
set_seed(SEED)

###################################################
# load feats + graph
node_feats, edge_feats, g, df, args = load_all_data(args)

###################################################
# get model
model, args, link_pred_train = load_model(args)

###################################################
print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

# for saving the results...
# Path.mkdir(parents=True, exist_ok=True) subsumes the old exists()/mkdir()
# pair, which was redundant and race-prone
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # define an early stopper (currently disabled)
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    # early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
    #                                  tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    # Link prediction
    start_val = timeit.default_timer()
    print('Train link prediction task from scratch ...')
    model = link_pred_train(model.to(args.device), args, g, df, node_feats, edge_feats)
    dataset.load_val_ns()

    # Validation ...
    perf_metrics_val_mean, perf_metrics_val_std, perf_list_val = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='val')
    end_val = timeit.default_timer()
    print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tval: {metric}: {perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}")
    # NOTE(review): start_val is taken BEFORE training, so val_time includes
    # the full training time -- it is stored below as 'tot_train_val_time',
    # which appears intentional
    val_time = timeit.default_timer() - start_val
    print(f"\tval: Elapsed Time (s): {val_time: .4f}")

    dataset.load_test_ns()
    # testing ...
    start_test = timeit.default_timer()
    perf_metrics_test_mean, perf_metrics_test_std, perf_list_test = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='test')
    end_test = timeit.default_timer()
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    # persist one JSON record per run (metric values stored as 'mean ± std' strings)
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': f'{perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}' ,
                  f'test {metric}': f'{perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}' ,
                  'test_time': test_time,
                  'tot_train_val_time': val_time
                  },
                 results_filename)

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
# (older single-shot save_results call kept for reference; superseded by the
# per-run save above)
# save_results({'model': MODEL_NAME,
#               'data': DATA,
#               'run': 1,
#               'seed': SEED,
#               metric: perf_metric_test,
#               'test_time': test_time,
#               'tot_train_val_time': 'NA'
#               },
#              results_filename)
================================================
FILE: examples/linkproppred/thgl-github/tgn.py
================================================
"""
python -u tgn.py --seed 1 --mem_dim 16 --time_dim 16 --emb_dim 16 --num_epoch 5 | tee tgn_s1_github_output.txt
"""
import numpy as np
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import timeit
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Train the TGN model for one epoch over the training split.

    Uses module-level globals: model, train_loader, optimizer, criterion,
    neighbor_loader, assoc, data, device, min_dst_idx, max_dst_idx, train_data.

    Returns:
        Average loss per event over the training split.
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # edge_type is not needed during training (it is already folded into
        # batch.msg at load time), so the formerly-unused `rel` unpack is gone
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        # Sample negative destination nodes uniformly from the dst-id range.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # assoc maps global node ids -> row positions in z
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )

        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction.

    Evaluation happens as 'one vs. many': each positive edge is ranked
    against the negative destinations supplied by ``neg_sampler``.

    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'val' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: mean of the evaluation metric over all queries
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()

    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg, pos_rel = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
            pos_batch.edge_type
        )
        # negatives are looked up per (src, dst, t, rel) query for this split
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_rel, split_mode=split_mode)
        # pos_msg_new = torch.cat([pos_msg,pos_rel.unsqueeze(dim=1)], dim=1)
        for idx, neg_batch in enumerate(neg_batch_list):
            # candidate destinations: 1 positive followed by all negatives
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )

            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            # assoc maps global node ids -> row positions in z
            assoc[n_id] = torch.arange(n_id.size(0), device=device)

            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])

            # compute MRR: row 0 is the positive score, rows 1: the negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])

        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)

    perf_metrics = float(torch.tensor(perf_list).mean())
    return perf_metrics
# ==========
# ==========
# ==========
# ==========
# ==========
# ==========
# Start...
start_overall = timeit.default_timer()

DATA = "thgl-github"

# ========== set parameters...
args, _ = get_args()
args.data = DATA
print("INFO: Arguments:", args)

LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
# NOTE: the model dimensions and the run count are deliberately hard-coded
# for this example and override the corresponding CLI flags
MEM_DIM = 16   # args.mem_dim
TIME_DIM = 16  # args.time_dim
EMB_DIM = 16   # args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = 1   # args.num_run
NUM_NEIGHBORS = 10
USE_EDGE_TYPE = True
USE_NODE_TYPE = True
MODEL_NAME = 'TGN'

# ==========
# set the device (GPU is deliberately disabled in this example)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"

torch.manual_seed(SEED)
set_random_seed(SEED)

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

print("there are {} nodes and {} edges".format(dataset.num_nodes, dataset.num_edges))
print("there are {} relation types".format(dataset.num_rels))

timestamp = data.t
head = data.src
tail = data.dst

# encode each edge's relation type with a (randomly initialised, untrained)
# embedding table and append it to the edge messages
edge_type = data.edge_type  # relation
edge_type_dim = len(torch.unique(edge_type))
embed_edge_type = torch.nn.Embedding(edge_type_dim, EMB_DIM).to(device)
with torch.no_grad():
    edge_type_embeddings = embed_edge_type(edge_type)
if USE_EDGE_TYPE:
    data.msg = torch.cat([data.msg, edge_type_embeddings], dim=1)

#! node type is a property of the dataset not the temporal data as temporal data has one entry per edge
node_type = dataset.node_type  # node type
neg_sampler = dataset.negative_sampler
data.node_type = node_type  # plain attribute assignment (was __setattr__)
print("shape of edge type is", edge_type.shape)
print("shape of node type is", node_type.shape)

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
print("finished loading PyG data")

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

start_time = timeit.default_timer()

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
# (duplicate `neg_sampler = dataset.negative_sampler` assignment removed)

best_test = 0
best_val = 0
best_epoch = 0

# for saving the results...
# Path.mkdir(parents=True, exist_ok=True) subsumes the old exists()/mkdir()
# pair, which was redundant and race-prone
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()

    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)

    # neighhorhood sampler
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

    # define the model end-to-end: memory module + attention-based embedding
    # module + MLP link predictor
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)

    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)

    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}

    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

    # # define an early stopper (currently disabled)
    # save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    # save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    # early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
    #                                  tolerance=TOLERANCE, patience=PATIENCE)

    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()

    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )
        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)

        # # check for early stopping (disabled)
        # if early_stopper.step_check(perf_metric_val, model):
        #     break

    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")

    # # ==================================================== Test
    # # first, load the best model (disabled along with the early stopper)
    # early_stopper.load_checkpoint(model)

    # loading the test negative samples
    dataset.load_test_ns()

    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    # (per-run result persistence currently disabled)
    # save_results({'model': MODEL_NAME,
    #               'data': DATA,
    #               'run': run_idx,
    #               'seed': SEED,
    #               f'val {metric}': val_perf_list,
    #               f'test {metric}': perf_metric_test,
    #               'test_time': test_time,
    #               'tot_train_val_time': train_val_time
    #               },
    #              results_filename)

    # NOTE(review): perf_metric_val is the LAST epoch's validation score (not
    # max(val_perf_list)), and `epoch` is the last epoch index -- with
    # NUM_RUNS == 1 this records the final epoch as "best"; confirm intended
    if (perf_metric_val > best_val):
        best_val = perf_metric_val
        best_epoch = epoch
        best_test = perf_metric_test

    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print ("INFO: Best Epoch: ", best_epoch)
print ("INFO: Best Validation Performance: ", best_val)
print ("INFO: Best Test Performance: ", best_test)
print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
================================================
FILE: examples/linkproppred/thgl-myket/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
tgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(tgb_modules_path)
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction in a 'one vs. many' setting: each
    positive edge is ranked against its sampled negative destinations.

    Parameters:
        data: a dataset dict with 'sources'/'destinations'/'timestamps'/'edge_type' arrays
        test_mask: boolean mask selecting the evaluation split's edges
        neg_sampler: yields the negative destinations for each positive edge
        split_mode: 'val' or 'test' -- selects which cached negatives to use
    Returns:
        (mean metric, mean hits@10) over all evaluation queries
    """
    # slice the split once up front instead of re-masking every batch
    srcs = data['sources'][test_mask]
    dsts = data['destinations'][test_mask]
    tss = data['timestamps'][test_mask]
    etypes = data['edge_type'][test_mask]

    n_events = len(srcs)
    n_batches = math.ceil(n_events / BATCH_SIZE)
    mrr_scores = []
    hits_scores = []

    for b in tqdm(range(n_batches)):
        lo = b * BATCH_SIZE
        hi = min(lo + BATCH_SIZE, n_events)
        pos_src = srcs[lo:hi]
        pos_dst = dsts[lo:hi]
        pos_t = tss[lo:hi]
        pos_edge = etypes[lo:hi]

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)
        for i, negs in enumerate(neg_batch_list):
            # query order: positive destination first, then all negatives
            query_src = np.array([int(pos_src[i]) for _ in range(len(negs) + 1)])
            query_dst = np.concatenate([np.array([int(pos_dst[i])]), negs])
            scores = edgebank.predict_link(query_src, query_dst)

            # compute MRR
            res = evaluator.eval({
                "y_pred_pos": np.array([scores[0]]),
                "y_pred_neg": np.array(scores[1:]),
                "eval_metric": [metric],
            })
            mrr_scores.append(res[metric])
            hits_scores.append(res['hits@10'])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    return float(np.mean(mrr_scores)), float(np.mean(hits_scores))
def get_args():
    """Parse command-line arguments for the EdgeBank baseline.

    Returns:
        (args, argv): the parsed namespace and the raw ``sys.argv`` list.
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='thgl-myket')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited',
                        choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit:
        # BUG FIX: the former bare `except:` caught everything and exited with
        # status 0, so shell scripts could not detect a bad invocation.
        # argparse signals parse errors via SystemExit; show the full help and
        # exit non-zero instead.
        parser.print_help()
        sys.exit(1)
    return args, sys.argv
# ==================
# ==================
# ==================
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = args.data
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for edgebank's memory; boolean-mask indexing already returns fresh
# arrays, so the former np.concatenate([...]) single-element wrappers were
# redundant
hist_src = data['sources'][train_mask]
hist_dst = data['destinations'][train_mask]
hist_ts = data['timestamps'][train_mask]

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
# Path.mkdir(parents=True, exist_ok=True) subsumes the old exists()/mkdir()
# pair, which was redundant and race-prone
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()
start_val = timeit.default_timer()
perf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()
# message fixed: the evaluation is one-vs-many (was misspelled 'ONE-VS--ALL')
print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tval: {metric}: {perf_metric_val: .4f}")
val_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {val_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()
start_test = timeit.default_timer()
perf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()
# message fixed: setting name was missing entirely ('>>> <<<')
print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'val_mrr': perf_metric_val,
              'test_time': test_time,
              'tot_train_val_time': test_time + val_time,
              'hits10': perf_hits_test},
             results_filename)
================================================
FILE: examples/linkproppred/thgl-myket/recurrencybaseline.py
================================================
""" from paper: "History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting"
Julia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)
@inproceedings{gastinger2024baselines,
title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},
author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},
booktitle={33rd International Joint Conference on Artificial Intelligence (IJCAI 2024)},
year={2024},
organization={International Joint Conferences on Artificial Intelligence Organization}
}
"""
## imports
import timeit
import argparse
import numpy as np
from copy import copy
from pathlib import Path
import ray
import sys
import os
import os.path as osp
import json
#internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import set_random_seed, save_results
from modules.tkg_utils import create_basis_dict, group_by, reformat_ts
def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
            perf_list_all, hits_list_all, window, neg_sampler, split_mode):
    """Create predictions for one relation on the test or valid set and compute MRR.

    Dispatches the queries either to a pool of ray workers (num_processes > 1)
    or to a single local call of ``baseline_predict``. Relies on the
    module-level globals ``basis_dict``, ``num_nodes``, ``num_rels`` and
    ``evaluator``.

    :param data_c_rel: quadruples of the current relation for the eval split
    :param all_data_c_rel: all quadruples of the current relation (history)
    :param perf_list_all: accumulator list, extended in place and returned
    :param hits_list_all: accumulator list, extended in place and returned
    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    first_ts = data_c_rel[0][3]
    ## use this if you wanna use ray:
    num_queries = len(data_c_rel) // num_processes
    if num_queries < num_processes:  # if we do not have enough queries for all the processes
        num_processes_tmp = 1
        num_queries = len(data_c_rel)
    else:
        num_processes_tmp = num_processes
    if num_processes > 1:
        object_references = []
        for i in range(num_processes_tmp):
            num_test_queries = len(data_c_rel) - (i + 1) * num_queries
            if num_test_queries >= num_queries:
                test_queries_idx = [i * num_queries, (i + 1) * num_queries]
            else:
                # last worker takes the remainder of this relation's queries.
                # BUG FIX: previously sliced with the *global* len(test_data),
                # which is wrong whenever data_c_rel is not the full test set
                # (e.g. per-relation slices or the validation split).
                test_queries_idx = [i * num_queries, len(data_c_rel)]
            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]
            ob = baseline_predict_remote.remote(
                num_queries, valid_data_b, all_data_c_rel, window,
                basis_dict, num_nodes, num_rels, lmbda_psi,
                alpha, evaluator, first_ts, neg_sampler, split_mode)
            object_references.append(ob)
        output = ray.get(object_references)
        # updates the scores and logging dict for each process
        for proc_loop in range(num_processes_tmp):
            perf_list_all.extend(output[proc_loop][0])
            hits_list_all.extend(output[proc_loop][1])
    else:
        perf_list, hits_list = baseline_predict(
            len(data_c_rel), data_c_rel, all_data_c_rel,
            window, basis_dict, num_nodes, num_rels, lmbda_psi,
            alpha, evaluator, first_ts, neg_sampler, split_mode)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
    return perf_list_all, hits_list_all
## test
def test(best_config, all_relations, test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
    """Loop over all relations, predict on the test or valid set, and compute MRR.

    Per-relation score lists are also appended to a csv file under the
    module-level ``perrel_results_path``.

    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    perf_list_all, hits_list_all = [], []
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}' + split_mode + '.csv'
    ## loop through relations and apply baselines
    for rel in all_relations:
        tic = timeit.default_timer()
        if rel not in test_data_prel.keys():
            # no eval queries for this relation: record an empty score list
            perf_list_rel = []
        else:
            cfg = best_config[str(rel)]
            lmbda_psi = cfg['lmbda_psi'][0]
            alpha = cfg['alpha'][0]
            # eval data for this relation (timesteps computed as in original, unused)
            rel_eval_data = test_data_prel[rel]
            timesteps_test = sorted(set(rel_eval_data[:, 3]))
            rel_all_data = all_data_prel[rel]
            perf_list_rel, hits_list_rel = predict(
                num_processes, rel_eval_data, rel_all_data, alpha, lmbda_psi,
                [], [], window, neg_sampler, split_mode)
            perf_list_all.extend(perf_list_rel)
            hits_list_all.extend(hits_list_rel)
        toc = timeit.default_timer()
        total_time = round(toc - tic, 6)
        print("Relation {} finished in {} seconds.".format(rel, total_time))
        with open(csv_file, 'a') as f:
            f.write("{},{}\n".format(rel, perf_list_rel))
    return perf_list_all, hits_list_all
def read_dict_compute_mrr(split_mode='test'):
    """Read the per-relation results from a previously written csv file and
    print the mean MRR, overall and per relation.

    Note: results are only printed; nothing is returned.
    """
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}' + split_mode + '.csv'
    rel_scores = {}     # relation id -> list of mrr values
    mrr_per_rel = {}    # relation id -> mean mrr
    all_mrrs = []
    with open(csv_file, 'r') as f:
        for line in f:
            fields = line.strip().split(',')
            rel_id = int(fields[0])
            # remaining fields are "[v1" ... "vk]" — strip the brackets
            scores = [float(field.strip('[]')) for field in fields[1:]]
            if rel_id in rel_scores.keys():
                print(f"Key {rel_id} already exists in the dictionary!!! might have duplicate entries in results csv")
            rel_scores[rel_id] = scores
            all_mrrs.extend(scores)
            mrr_per_rel[rel_id] = np.mean(scores)
    if len(list(rel_scores.keys())) != num_rels:
        print("we do not have entries for each rel in the results csv file. only num enties: ", len(list(rel_scores.keys())))
    print("Split mode: " + split_mode + " Mean MRR: ", np.mean(all_mrrs))
    print("mrr per relation: ", mrr_per_rel)
## train
def train(params_dict, rels, val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
    """ optional, find best values for lambda and alpha by looping through all relations and testing an a fixed set of params
    based on validation mrr

    For each relation the search is sequential: first the best ``lmbda_psi`` is
    picked (with alpha fixed to 1), then the best ``alpha`` is picked with that
    lambda. Relies on module-level ``default_lmbda_psi`` / ``default_alpha``
    for relations without validation data, plus the globals used by predict().

    :return best_config: dictionary of best params for each relation
    """
    best_config = {}
    best_mrr = 0
    for rel in rels:  # loop through relations. for each relation, apply rules with selected params, compute valid mrr
        start = timeit.default_timer()
        rel_key = int(rel)
        # defaults; overwritten below when this relation has validation data
        best_config[str(rel_key)] = {}
        best_config[str(rel_key)]['not_trained'] = 'True'
        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi, 0]  # default
        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))
        best_config[str(rel_key)]['alpha'] = [default_alpha, 0]  # default
        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))
        if rel in val_data_prel.keys():
            # valid data for this relation
            val_data_c_rel = val_data_prel[rel]
            timesteps_valid = list(set(val_data_c_rel[:, 3]))
            timesteps_valid.sort()
            trainval_data_c_rel = trainval_data_prel[rel]
            ###### 1) select lambda ###############
            lmbdas_psi = params_dict['lmbda_psi']
            alpha = 1  # alpha fixed while searching lambda
            best_lmbda_psi = 0.1
            best_mrr_psi = 0
            lmbda_mrrs = []
            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))
            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))
            best_config[str(rel_key)]['not_trained'] = 'False'
            for lmbda_psi in lmbdas_psi:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                                                   trainval_data_c_rel, alpha, lmbda_psi, perf_list_r, hits_list_r,
                                                   window, neg_sampler, split_mode='val')
                # compute mrr
                mrr = np.mean(perf_list_r)
                # # is new mrr better than previous best? if yes: store lmbda
                if mrr > best_mrr_psi:
                    best_mrr_psi = float(mrr)
                    best_lmbda_psi = lmbda_psi
                lmbda_mrrs.append(float(mrr))
            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]
            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs
            best_mrr = best_mrr_psi
            ##### 2) select alpha ###############
            best_config[str(rel_key)]['not_trained'] = 'False'
            alphas = params_dict['alpha']
            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0]  # use the best lmbda psi
            alpha_mrrs = []
            # perf_list_all = []
            best_mrr_alpha = 0
            best_alpha = 0.99
            for alpha in alphas:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                                                   trainval_data_c_rel, alpha, lmbda_psi, perf_list_r, hits_list_r,
                                                   window, neg_sampler, split_mode='val')
                # compute mrr
                mrr_alpha = np.mean(perf_list_r)
                # is new mrr better than previous best? if yes: store alpha
                if mrr_alpha > best_mrr_alpha:
                    best_mrr_alpha = float(mrr_alpha)
                    best_alpha = alpha
                    best_mrr = best_mrr_alpha
                alpha_mrrs.append(float(mrr_alpha))
            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]
            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs
        end = timeit.default_timer()
        total_time = round(end - start, 6)
        print("Relation {} finished in {} seconds.".format(rel, total_time))
    return best_config
## args
def get_args():
    """Parse the command-line arguments for this script and return them as a dict."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset", "-d", default="thgl-myket", type=str)
    cli.add_argument("--window", "-w", default=0, type=int)  # e.g. 200 to only consider the 200 most recent timesteps; -2 for multistep
    cli.add_argument("--num_processes", "-p", default=1, type=int)
    cli.add_argument("--lmbda", "-l", default=0.1, type=float)  # fixed lambda, used when train_flag == 'False'
    cli.add_argument("--alpha", "-alpha", default=0.99, type=float)  # fixed alpha, used when train_flag == 'False'
    cli.add_argument("--train_flag", "-tr", default='False')  # select lambda and alpha on the validation set?
    cli.add_argument("--load_flag", "-lo", default='False')  # if training: load a previously saved best_config?
    cli.add_argument("--save_config", "-c", default='True')  # save the selected lambda and alpha to a config file?
    cli.add_argument('--seed', type=int, help='Random seed', default=1)  # not needed
    return vars(cli.parse_args())
start_o = timeit.default_timer()

parsed = get_args()
# ray is only needed when predictions are parallelised across processes
if parsed['num_processes'] > 1:
    ray.init(num_cpus=parsed["num_processes"], num_gpus=0)

MODEL_NAME = 'RecurrencyBaseline'
SEED = parsed['seed']  # set the random seed for consistency
set_random_seed(SEED)

# directory where the per-relation result csv files are written
perrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'
if not osp.exists(perrel_results_path):
    os.mkdir(perrel_results_path)
    print('INFO: Create directory {}'.format(perrel_results_path))
Path(perrel_results_path).mkdir(parents=True, exist_ok=True)

## load dataset and prepare it accordingly
name = parsed["dataset"]
dataset = LinkPropPredDataset(name=name, root="datasets", preprocess=True)
DATA = name

relations = dataset.edge_type
num_rels = dataset.num_rels
rels = np.arange(0, num_rels)
subjects = dataset.full_data["sources"]
objects = dataset.full_data["destinations"]
num_nodes = dataset.num_nodes
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA)  # stepsize:1

print("split train valid test data")
# quadruple columns: [subject, relation, object, remapped ts, original ts]
all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]

metric = dataset.eval_metric
evaluator = Evaluator(name=name)
neg_sampler = dataset.negative_sampler

train_val_data = np.concatenate([train_data, val_data])
all_data = np.concatenate([train_data, val_data, test_data])

# create dicts with key: relation id, values: triples for that relation id
print("grouping data by relation")
test_data_prel = group_by(test_data, 1)
all_data_prel = group_by(all_data, 1)
val_data_prel = group_by(val_data, 1)
trainval_data_prel = group_by(train_val_data, 1)

#load the ns samples
# if parsed['train_flag']:
print("loading negative samples")
dataset.load_val_ns()
dataset.load_test_ns()

# parameter options (grid searched during training)
if parsed['train_flag'] == 'True':
    params_dict = {}
    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001]
    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]
    # fallback values for relations that have no validation data
    default_lmbda_psi = params_dict['lmbda_psi'][-1]
    default_alpha = params_dict['alpha'][-2]

## load rules
print("creating rules")
basis_dict = create_basis_dict(train_val_data)
print("done with creating rules")

## init
# rb_predictor = RecurrencyBaselinePredictor(rels)
## train to find best lambda and alpha
start_train = timeit.default_timer()
if parsed['train_flag'] == 'True':
    if parsed['load_flag'] == 'True':
        with open('best_config.json', 'r') as infile:
            best_config = json.load(infile)
    else:
        print('start training')
        best_config = train(params_dict, rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'],
                            parsed['window'])
        if parsed['save_config'] == 'True':
            import json  # NOTE(review): redundant, json is already imported at file top
            with open('best_config.json', 'w') as outfile:
                json.dump(best_config, outfile)
else:  # use preset lmbda and alpha; same for all relations
    best_config = {}
    for rel in rels:
        best_config[str(rel)] = {}
        best_config[str(rel)]['lmbda_psi'] = [parsed['lmbda']]
        best_config[str(rel)]['alpha'] = [parsed['alpha']]
end_train = timeit.default_timer()

# compute validation mrr
print("Computing validation MRR")
perf_list_all_val, hits_list_all_val = test(best_config, rels, val_data_prel,
                                            trainval_data_prel, neg_sampler, parsed['num_processes'],
                                            parsed['window'], split_mode='val')
val_mrr = float(np.mean(perf_list_all_val))

# compute test mrr
print("Computing test MRR")
start_test = timeit.default_timer()
perf_list_all, hits_list_all = test(best_config, rels, test_data_prel,
                                    all_data_prel, neg_sampler, parsed['num_processes'],
                                    parsed['window'])
end_o = timeit.default_timer()

train_time_o = round(end_train - start_train, 6)
test_time_o = round(end_o - start_test, 6)
total_time_o = round(end_o - start_o, 6)
print("Running Training to find best configs finished in {} seconds.".format(train_time_o))
print("Running testing with best configs finished in {} seconds.".format(test_time_o))
print("Running all steps finished in {} seconds.".format(total_time_o))
print(f"The test MRR is {np.mean(perf_list_all)}")
print(f"The valid MRR is {val_mrr}")
print(f"The Hits@10 is {np.mean(hits_list_all)}")
print(f"We have {len(perf_list_all)} predictions")
print(f"The test set has len {len(test_data)} ")

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'

metric = dataset.eval_metric
save_results({'model': MODEL_NAME,
              'train_flag': parsed['train_flag'],
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: float(np.mean(perf_list_all)),
              'val_mrr': val_mrr,
              'hits10': float(np.mean(hits_list_all)),
              'test_time': test_time_o,
              'tot_train_val_time': total_time_o
              },
             results_filename)

# shut down the worker pool if one was started
if parsed['num_processes'] > 1:
    ray.shutdown()
================================================
FILE: examples/linkproppred/thgl-myket/sthn.py
================================================
import timeit
import numpy as np
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator
import argparse
from modules.sthn import set_seed, pre_compute_subgraphs, get_inputs_for_ind, check_data_leakage
import torch
import pandas as pd
import itertools
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
from tgb.utils.utils import set_random_seed, save_results
# Start...
start_overall = timeit.default_timer()

DATA = "thgl-myket"
MODEL_NAME = 'STHN'

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
metric = dataset.eval_metric
print("there are {} nodes and {} edges".format(dataset.num_nodes, dataset.num_edges))
print("there are {} relation types".format(dataset.num_rels))

timestamp = data.t
head = data.src
tail = data.dst
edge_type = data.edge_type
neg_sampler = dataset.negative_sampler
print(data)
print(timestamp)
print(head)
print(tail)

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
# NOTE(review): `metric` and `neg_sampler` are re-assigned below with the same
# values as above — harmless duplication
metric = dataset.eval_metric
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'

####################################################################
####################################################################
####################################################################
def print_model_info(model):
    """Print the model architecture followed by its trainable-parameter count."""
    print(model)
    trainable = [p for p in model.parameters() if p.requires_grad]
    n_params = sum(np.prod(p.size()) for p in trainable)
    print('Trainable Parameters: %d' % n_params)
def get_args():
    """Parse and return the command-line arguments as an argparse.Namespace."""
    cli = argparse.ArgumentParser()
    # data / optimisation
    cli.add_argument('--data', type=str, default='movie')
    cli.add_argument('--device', type=int, default=0)
    cli.add_argument('--batch_size', type=int, default=600)
    cli.add_argument('--epochs', type=int, default=2)
    cli.add_argument('--max_edges', type=int, default=50)
    cli.add_argument('--num_edgeType', type=int, default=0, help='num of edgeType')
    cli.add_argument('--lr', type=float, default=0.0005)
    cli.add_argument('--weight_decay', type=float, default=1e-4)
    cli.add_argument('--predict_class', action='store_true')
    # model
    cli.add_argument('--window_size', type=int, default=5)
    cli.add_argument('--dropout', type=float, default=0.1)
    cli.add_argument('--model', type=str, default='sthn')
    cli.add_argument('--neg_samples', type=int, default=1)
    cli.add_argument('--extra_neg_samples', type=int, default=5)
    cli.add_argument('--num_neighbors', type=int, default=50)
    cli.add_argument('--channel_expansion_factor', type=int, default=2)
    cli.add_argument('--sampled_num_hops', type=int, default=1)
    cli.add_argument('--time_dims', type=int, default=100)
    cli.add_argument('--hidden_dims', type=int, default=100)
    cli.add_argument('--num_layers', type=int, default=1)
    # feature / structure switches
    cli.add_argument('--check_data_leakage', action='store_true')
    cli.add_argument('--ignore_node_feats', action='store_true')
    cli.add_argument('--node_feats_as_edge_feats', action='store_true')
    cli.add_argument('--ignore_edge_feats', action='store_true')
    cli.add_argument('--use_onehot_node_feats', action='store_true')
    cli.add_argument('--use_type_feats', action='store_true')
    cli.add_argument('--use_graph_structure', action='store_true')
    cli.add_argument('--structure_time_gap', type=int, default=2000)
    cli.add_argument('--structure_hops', type=int, default=1)
    cli.add_argument('--use_node_cls', action='store_true')
    cli.add_argument('--use_cached_subgraph', action='store_true')
    # bookkeeping
    cli.add_argument('--seed', type=int, help='Random seed', default=1)
    cli.add_argument('--num_run', type=int, help='Number of iteration runs', default=5)
    return cli.parse_args()
def load_model(args):
    """Build the STHN model and return it with the training entry point.

    Parameters:
        args: parsed argument namespace; must provide time_dims,
            node_feat_dims, edge_feat_dims, predict_class, num_edgeType and
            the mixer hyper-parameters.
    Returns:
        (model, args, link_pred_train) — the STHN interface module, the (unchanged)
        args, and the training function imported from modules.sthn.
    Raises:
        NotImplementedError: if args.model is not 'sthn'.
    """
    # get model
    edge_predictor_configs = {
        'dim_in_time': args.time_dims,
        'dim_in_node': args.node_feat_dims,
        # multi-class head needs one logit per edge type (+1); binary otherwise
        'predict_class': 1 if not args.predict_class else args.num_edgeType + 1,
    }
    if args.model == 'sthn':
        if args.predict_class:
            from modules.sthn import Multiclass_Interface as STHN_Interface
        else:
            from modules.sthn import STHN_Interface
        from modules.sthn import link_pred_train
        mixer_configs = {
            'per_graph_size': args.max_edges,
            'time_channels': args.time_dims,
            'input_channels': args.edge_feat_dims,
            'hidden_channels': args.hidden_dims,
            'out_channels': args.hidden_dims,
            'num_layers': args.num_layers,
            'dropout': args.dropout,
            'channel_expansion_factor': args.channel_expansion_factor,
            'window_size': args.window_size,
            'use_single_layer': False
        }
    else:
        # BUG FIX: the exception was previously only instantiated
        # (`NotImplementedError()`) and never raised, so unsupported models
        # fell through and crashed later with a confusing NameError.
        raise NotImplementedError(f"unsupported model: {args.model}")

    model = STHN_Interface(mixer_configs, edge_predictor_configs)
    for k, v in model.named_parameters():
        print(k, v.requires_grad)
    print_model_info(model)
    return model, args, link_pred_train
def load_graph(data):
    """Build a CSR-style temporal adjacency structure from the edge stream.

    Parameters:
        data: an object exposing per-edge arrays ``t``, ``src``, ``dst`` and
            ``edge_type`` (one entry per edge, aligned by index).
    Returns:
        g: numpy .npz archive with ``indptr``/``indices``/``ts``/``eid``
           (per-source neighbor lists sorted by timestamp)
        df: pandas DataFrame view of the edge stream
    """
    df = pd.DataFrame({
        'idx': np.arange(len(data.t)),
        'src': data.src,
        'dst': data.dst,
        'time': data.t,
        'label': data.edge_type,
    })
    num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
    ext_full_indptr = np.zeros(num_nodes + 1, dtype=np.int32)
    ext_full_indices = [[] for _ in range(num_nodes)]
    ext_full_ts = [[] for _ in range(num_nodes)]
    ext_full_eid = [[] for _ in range(num_nodes)]
    # collect outgoing neighbors, timestamps and edge ids per source node
    # (NOTE(review): iterrows() is slow for large edge streams)
    for idx, row in tqdm(df.iterrows(), total=len(df)):
        src = int(row['src'])
        dst = int(row['dst'])
        ext_full_indices[src].append(dst)
        ext_full_ts[src].append(row['time'])
        ext_full_eid[src].append(idx)
    # prefix sums over per-node degrees -> CSR index pointer
    for i in tqdm(range(num_nodes)):
        ext_full_indptr[i + 1] = ext_full_indptr[i] + len(ext_full_indices[i])
    # flatten the per-node lists into contiguous arrays
    ext_full_indices = np.array(list(itertools.chain(*ext_full_indices)))
    ext_full_ts = np.array(list(itertools.chain(*ext_full_ts)))
    ext_full_eid = np.array(list(itertools.chain(*ext_full_eid)))
    print('Sorting...')

    def tsort(i, indptr, indices, t, eid):
        # sort node i's neighbor segment by timestamp, in place
        beg = indptr[i]
        end = indptr[i + 1]
        sidx = np.argsort(t[beg:end])
        indices[beg:end] = indices[beg:end][sidx]
        t[beg:end] = t[beg:end][sidx]
        eid[beg:end] = eid[beg:end][sidx]

    for i in tqdm(range(ext_full_indptr.shape[0] - 1)):
        tsort(i, ext_full_indptr, ext_full_indices, ext_full_ts, ext_full_eid)
    print('saving...')
    # NOTE(review): hardcoded shared path — concurrent runs on the same machine
    # would overwrite each other's file; consider tempfile. TODO confirm
    np.savez('/tmp/ext_full.npz', indptr=ext_full_indptr,
             indices=ext_full_indices, ts=ext_full_ts, eid=ext_full_eid)
    g = np.load('/tmp/ext_full.npz')
    return g, df
def load_all_data(args):
    """Load graph structure, split masks and node/edge features.

    Uses the module-level ``dataset``, ``data``, ``train_mask``, ``val_mask``
    and ``test_mask`` objects; derived sizes and numpy masks are attached to
    ``args`` (mutated in place).

    Parameters:
        args: parsed argument namespace (feature switches, device, ...).
    Returns:
        (node_feats, edge_feats, g, df, args); node_feats / edge_feats may be
        None when absent or explicitly ignored.
    """
    # load graph
    g, df = load_graph(data)
    args.train_mask = train_mask.numpy()
    args.val_mask = val_mask.numpy()
    args.test_mask = test_mask.numpy()
    args.num_edges = len(df)
    # consistency fix: report the numpy mask like the other two splits
    # (previously summed the global torch `test_mask` tensor)
    print('Train %d, Valid %d, Test %d'%(sum(args.train_mask),
                                         sum(args.val_mask),
                                         sum(args.test_mask)))

    args.num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
    args.num_edges = len(df)
    print('Num nodes %d, num edges %d'%(args.num_nodes, args.num_edges))

    # load feats
    node_feats, edge_feats = dataset.node_feat, dataset.edge_feat
    node_feat_dims = 0 if node_feats is None else node_feats.shape[1]
    edge_feat_dims = 0 if edge_feats is None else edge_feats.shape[1]

    # feature pre-processing
    if args.use_onehot_node_feats:
        print('>>> Use one-hot node features')
        node_feats = torch.eye(args.num_nodes)
        node_feat_dims = node_feats.size(1)
    if args.ignore_node_feats:
        print('>>> Ignore node features')
        node_feats = None
        node_feat_dims = 0
    if args.use_type_feats:
        # replace edge features with a one-hot encoding of the edge type
        edge_type = df.label.values
        print(edge_type)
        print(edge_type.sum())
        args.num_edgeType = len(set(edge_type.tolist()))
        edge_feats = torch.nn.functional.one_hot(torch.from_numpy(edge_type),
                                                 num_classes=args.num_edgeType)
        edge_feat_dims = edge_feats.size(1)
    print('Node feature dim %d, edge feature dim %d'%(node_feat_dims, edge_feat_dims))

    # double check (if data leakage then cannot continue the code)
    if args.check_data_leakage:
        check_data_leakage(args, g, df)

    args.node_feat_dims = node_feat_dims
    args.edge_feat_dims = edge_feat_dims

    # idiom fix: `!= None` on a tensor is ambiguous/deprecated; use `is not None`
    if node_feats is not None:
        node_feats = node_feats.to(args.device)
    if edge_feats is not None:
        edge_feats = edge_feats.to(args.device)
    return node_feats, edge_feats, g, df, args
####################################################################
####################################################################
####################################################################
@torch.no_grad()
def test(data, test_mask, model, neg_sampler, split_mode):
    r"""
    Evaluates the dynamic link prediction
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges
    Parameters:
        data: a dataset object (NOTE(review): not used inside this function)
        test_mask: required masks to load the test set edges (NOTE(review):
            also unused — the split is selected via the module-level
            ``args.test_mask`` / ``args.val_mask`` instead)
        model: the trained STHN interface module
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'val' or 'test' set to correctly load the negatives
    Returns:
        perf_metrics_mean: mean of the per-query performance metric
        perf_metrics_std: standard deviation of the per-query metric
        perf_list: the raw per-query metric values
    """
    test_subgraphs = pre_compute_subgraphs(args, g, df, mode='test' if split_mode == 'test' else 'valid', negative_sampler=neg_sampler, split_mode=split_mode)
    perf_list = []
    # select the evaluation rows for the requested split
    if split_mode == 'test':
        cur_df = df[args.test_mask]
    elif split_mode == 'val':
        cur_df = df[args.val_mask]

    # 20 cached negatives are scored per positive query
    neg_samples = 20
    cached_neg_samples = 20

    test_loader = cur_df.groupby(cur_df.index // args.batch_size)
    pbar = tqdm(total=len(test_loader))
    pbar.set_description('%s mode with negative samples %d ...'%(split_mode, neg_samples))

    ###################################################
    # compute + training + fetch all scores
    cur_inds = 0
    for ind in range(len(test_loader)):
        ###################################################
        inputs, subgraph_node_feats, cur_inds = get_inputs_for_ind(test_subgraphs, 'test' if split_mode == 'test' else 'tgb-val', cached_neg_samples, neg_samples, node_feats, edge_feats, cur_df, cur_inds, ind, args)
        loss, pred, edge_label = model(inputs, neg_samples, subgraph_node_feats)
        # print(ind, [l for l in inputs], pred.shape)
        # assumes pred[0] is the positive edge and pred[1:] the negatives —
        # TODO confirm against get_inputs_for_ind's ordering
        input_dict = {
            "y_pred_pos": np.array([pred.cpu()[0]]),
            "y_pred_neg": np.array(pred.cpu()[1:]),
            "eval_metric": [metric],
        }
        perf_list.append(evaluator.eval(input_dict)[metric])

    perf_metrics_mean = float(np.mean(perf_list))
    perf_metrics_std = float(np.std(perf_list))
    return perf_metrics_mean, perf_metrics_std, perf_list
args = get_args()
# fixed experiment settings for this dataset (override the CLI values)
args.use_graph_structure = True
args.use_onehot_node_feats = False
args.ignore_node_feats = False  # we only use graph structure
args.use_type_feats = True  # type encoding
args.use_cached_subgraph = True
print(args)

args.device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
args.device = torch.device(args.device)
SEED = args.seed
BATCH_SIZE = args.batch_size
NUM_RUNS = args.num_run
set_seed(SEED)

###################################################
# load feats + graph
node_feats, edge_feats, g, df, args = load_all_data(args)
###################################################
# get model
model, args, link_pred_train = load_model(args)
###################################################
print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'

for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()
    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)
    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    # early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
    #                                  tolerance=TOLERANCE, patience=PATIENCE)
    # ==================================================== Train & Validation
    # loading the validation negative samples
    # Link prediction
    start_val = timeit.default_timer()
    print('Train link prediction task from scratch ...')
    # NOTE: training happens here, so `val_time` below includes training time
    model = link_pred_train(model.to(args.device), args, g, df, node_feats, edge_feats)
    dataset.load_val_ns()
    # Validation ...
    # NOTE(review): the mask argument is unused inside test() (the split is
    # selected via args.val_mask / args.test_mask), but passing `test_mask`
    # for the 'val' split is confusing — verify before relying on it
    perf_metrics_val_mean, perf_metrics_val_std, perf_list_val = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='val')
    end_val = timeit.default_timer()
    print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tval: {metric}: {perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}")
    val_time = timeit.default_timer() - start_val
    print(f"\tval: Elapsed Time (s): {val_time: .4f}")

    dataset.load_test_ns()
    # testing ...
    start_test = timeit.default_timer()
    perf_metrics_test_mean, perf_metrics_test_std, perf_list_test = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='test')
    end_test = timeit.default_timer()
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': f'{perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}' ,
                  f'test {metric}': f'{perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}' ,
                  'test_time': test_time,
                  'tot_train_val_time': val_time
                  },
                 results_filename)
    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')

print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")

# save_results({'model': MODEL_NAME,
#               'data': DATA,
#               'run': 1,
#               'seed': SEED,
#               metric: perf_metric_test,
#               'test_time': test_time,
#               'tot_train_val_time': 'NA'
#               },
#              results_filename)
================================================
FILE: examples/linkproppred/thgl-myket/tgn.py
================================================
import numpy as np
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import timeit
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Training procedure for TGN model
    This function uses some objects that are globally defined in the current
    script: ``model``, ``train_loader``, ``train_data``, ``optimizer``,
    ``criterion``, ``neighbor_loader``, ``assoc``, ``data``, ``device`` and
    ``min_dst_idx`` / ``max_dst_idx``.
    Parameters:
        None
    Returns:
        The average training loss per event over the epoch.
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type

        # Sample negative destination nodes uniformly from the destination id range.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # map global node ids to local row indices into z
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])

        # binary loss: positives toward 1, sampled negatives toward 0
        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)

        loss.backward()
        optimizer.step()
        # detach memory so gradients do not flow across batches
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events

    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction.
    Evaluation happens as 'one vs. many', meaning that each positive edge is
    evaluated against many negative edges.
    Uses globally defined objects: model, neighbor_loader, assoc, data,
    device, evaluator, metric.
    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        float: mean of the performance metric over all evaluation queries
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()
    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg, pos_rel = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
            pos_batch.edge_type
        )
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_rel, split_mode=split_mode)
        # pos_msg_new = torch.cat([pos_msg,pos_rel.unsqueeze(dim=1)], dim=1)
        for idx, neg_batch in enumerate(neg_batch_list):
            # Score one positive destination followed by all its negatives;
            # index 0 of y_pred below is therefore the positive edge.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )
            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id] = torch.arange(n_id.size(0), device=device)
            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])
            # compute MRR
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])
        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)
    perf_metrics = float(torch.tensor(perf_list).mean())
    return perf_metrics
# ==========
# ========== Script setup: arguments, device, data, loaders, bookkeeping
# ==========
# Start...
start_overall = timeit.default_timer()
DATA = "thgl-myket"
# ========== set parameters...
args, _ = get_args()
args.data = DATA
print("INFO: Arguments:", args)
LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10
USE_EDGE_TYPE = True
USE_NODE_TYPE = True
MODEL_NAME = 'TGN'
# ==========
# set the device
# BUGFIX: use the default CUDA device rather than the hard-coded "cuda:2",
# which crashes on machines with fewer than three GPUs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric
print ("there are {} nodes and {} edges".format(dataset.num_nodes, dataset.num_edges))
print ("there are {} relation types".format(dataset.num_rels))
timestamp = data.t
head = data.src
tail = data.dst
edge_type = data.edge_type #relation
# Embed the integer relation types and (optionally) append the embedding to
# the edge messages so the model can condition on the edge type.
edge_type_dim = len(torch.unique(edge_type))
embed_edge_type = torch.nn.Embedding(edge_type_dim, 128).to(device)
with torch.no_grad():
    edge_type_embeddings = embed_edge_type(edge_type)
if USE_EDGE_TYPE:
    data.msg = torch.cat([data.msg, edge_type_embeddings], dim=1)
#! node type is a property of the dataset not the temporal data as temporal data has one entry per edge
node_type = dataset.node_type #node type
neg_sampler = dataset.negative_sampler
data.__setattr__("node_type", node_type)
print ("shape of edge type is", edge_type.shape)
print ("shape of node type is", node_type.shape)
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
print ("finished loading PyG data")
train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)
start_time = timeit.default_timer()
# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")
evaluator = Evaluator(name=DATA)
# for saving the results...
# (removed a duplicate `neg_sampler = dataset.negative_sampler`; it is
#  already assigned above)
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    print('INFO: Create directory {}'.format(results_path))
# Path.mkdir(exist_ok=True) creates the directory if needed; the previous
# os.mkdir call was redundant with it.
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
# Main experiment loop: each run re-seeds, rebuilds the model from scratch,
# trains with early stopping on validation MRR, then tests the best checkpoint.
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()
    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)
    # neighhorhood sampler
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
    # define the model end-to-end
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)
    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)
    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}
    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()
    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)
    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()
    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )
        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)
        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break
    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")
    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)
    # loading the test negative samples
    dataset.load_test_ns()
    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)
    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')
print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
# #* load numpy arrays instead
# from tgb.linkproppred.dataset import LinkPropPredDataset
# # data loading
# dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
# data = dataset.full_data
# metric = dataset.eval_metric
# sources = dataset.full_data['sources']
# print ("finished loading numpy arrays")
================================================
FILE: examples/linkproppred/thgl-software/STHN_README.md
================================================
STHN method adopted from: https://github.com/celi52/STHN/tree/main
To run:
1. Install requirements. The two new additional requirements for STHN are `pybind11` and `torchmetrics==0.11.0`
2. Compile the sampler
```bash
python sthn_sampler_setup.py build_ext --inplace
```
3. Run the example code
```bash
python sthn.py
```
If the code runs correctly, the output should end with
```
INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<<
Test: mrr: X
Test: Elapsed Time (s): Y
```
================================================
FILE: examples/linkproppred/thgl-software/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
tgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(tgb_modules_path)
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction in a 'one vs. many' setting: each
    positive edge is ranked against its sampled negative edges.

    Uses globally defined objects: BATCH_SIZE, edgebank, evaluator, metric.

    Parameters:
        data: a dataset object (dict of numpy arrays)
        test_mask: required masks to load the evaluation set edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metrics: mean of the primary metric over all queries
        perf_hits: mean hits@10 over all queries
    """
    # Slice the masked arrays once up front instead of re-masking per batch.
    srcs = data['sources'][test_mask]
    dsts = data['destinations'][test_mask]
    times = data['timestamps'][test_mask]
    etypes = data['edge_type'][test_mask]
    n_edges = len(srcs)
    num_batches = math.ceil(n_edges / BATCH_SIZE)
    perf_list = []
    hits_list = []
    for batch_idx in tqdm(range(num_batches)):
        lo = batch_idx * BATCH_SIZE
        hi = min(lo + BATCH_SIZE, n_edges)
        pos_src = srcs[lo:hi]
        pos_dst = dsts[lo:hi]
        pos_t = times[lo:hi]
        pos_edge = etypes[lo:hi]
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # One query per positive edge: the positive destination first,
            # then all of its negatives.
            n_queries = len(neg_batch) + 1
            query_src = np.full(n_queries, int(pos_src[idx]))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])
            y_pred = edgebank.predict_link(query_src, query_dst)
            # compute MRR (and hits@10)
            results = evaluator.eval({
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            })
            perf_list.append(results[metric])
            hits_list.append(results['hits@10'])
        # update edgebank memory only after the whole positive batch is scored
        edgebank.update_memory(pos_src, pos_dst, pos_t)
    return float(np.mean(perf_list)), float(np.mean(hits_list))
def get_args():
    """Parse command-line arguments for the EdgeBank run.

    Returns:
        tuple: (parsed argparse.Namespace, raw sys.argv list)
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='thgl-software')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit:
        # BUGFIX: was a bare `except:` which also swallowed KeyboardInterrupt
        # and unrelated errors. parse_args() only raises SystemExit on failure;
        # show the full help and exit cleanly, as before.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
# Script setup: parse arguments, load the dataset with numpy, and build the
# EdgeBank predictor seeded with the training-set edges.
start_overall = timeit.default_timer()
# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = args.data
MODEL_NAME = 'EdgeBank'
# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric
# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
#data for memory in edgebank: the training split seeds its edge memory
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])
# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)
print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
os.mkdir(results_path)
print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'
# ==================================================== Test
# loading the validation negative samples
dataset.load_val_ns()
# testing ...
start_val = timeit.default_timer()
perf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()
print(f"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< ")
print(f"\tval: {metric}: {perf_metric_val: .4f}")
val_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {val_time: .4f}")
# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()
# testing ...
start_test = timeit.default_timer()
perf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()
print(f"INFO: Test: Evaluation Setting: >>> <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
save_results({'model': MODEL_NAME,
'memory_mode': MEMORY_MODE,
'data': DATA,
'run': 1,
'seed': SEED,
metric: perf_metric_test,
'val_mrr': perf_metric_val,
'test_time': test_time,
'tot_train_val_time': test_time+val_time,
'hits10': perf_hits_test},
results_filename)
================================================
FILE: examples/linkproppred/thgl-software/recurrencybaseline.py
================================================
""" from paper: "History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting"
Julia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)
@inproceedings{gastinger2024baselines,
title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},
author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},
booktitle={33rd International Joint Conference on Artificial Intelligence (IJCAI 2024)},
year={2024},
organization={International Joint Conferences on Artificial Intelligence Organization}
}
"""
## imports
import timeit
import argparse
import numpy as np
from copy import copy
from pathlib import Path
import ray
import sys
import os
import os.path as osp
import json
#internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import set_random_seed, save_results
from modules.tkg_utils import create_basis_dict, group_by, reformat_ts
def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
            perf_list_all, hits_list_all, window, neg_sampler, split_mode):
    """ Create predictions for one relation on the test or valid set and compute the MRR.

    Parameters:
        num_processes: number of ray worker processes (1 = run in-process)
        data_c_rel: evaluation quadruples for the current relation
        all_data_c_rel: full-history quadruples for the current relation
        alpha: baseline hyperparameter for this relation
        lmbda_psi: baseline hyperparameter for this relation
        perf_list_all: accumulator list of MRRs, extended in place
        hits_list_all: accumulator list of hits, extended in place
        window: history window (e.g. 0 = use all history)
        neg_sampler: negative sampler for the evaluation queries
        split_mode: 'val' or 'test'
    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    first_ts = data_c_rel[0][3]
    ## use this if you wanna use ray:
    num_queries = len(data_c_rel) // num_processes
    if num_queries < num_processes:  # if we do not have enough queries for all the processes
        num_processes_tmp = 1
        num_queries = len(data_c_rel)
    else:
        num_processes_tmp = num_processes
    if num_processes > 1:
        object_references = []
        for i in range(num_processes_tmp):
            num_test_queries = len(data_c_rel) - (i + 1) * num_queries
            if num_test_queries >= num_queries:
                test_queries_idx = [i * num_queries, (i + 1) * num_queries]
            else:
                # BUGFIX: the last chunk must end at the length of the data for
                # THIS relation/split, not at the length of the global test set.
                # The previous `len(test_data)` was wrong for the validation
                # split and for per-relation subsets.
                test_queries_idx = [i * num_queries, len(data_c_rel)]
            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]
            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window,
                                                basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
            object_references.append(ob)
        output = ray.get(object_references)
        # updates the scores and logging dict for each process
        for proc_loop in range(num_processes_tmp):
            perf_list_all.extend(output[proc_loop][0])
            hits_list_all.extend(output[proc_loop][1])
    else:
        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel,
                                                window, basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
    return perf_list_all, hits_list_all
## test
def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
    """ Create predictions by looping through all relations on the test or valid set and compute the MRR.
    Appends per-relation results to a csv file as a side effect.
    Uses globally defined objects: perrel_results_path, MODEL_NAME, DATA, SEED.
    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    perf_list_all = []
    hits_list_all =[]
    # per-relation results are appended here, one line per relation
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'
    ## loop through relations and apply baselines
    for rel in all_relations:
        start = timeit.default_timer()
        if rel in test_data_prel.keys():
            lmbda_psi = best_config[str(rel)]['lmbda_psi'][0]
            alpha = best_config[str(rel)]['alpha'][0]
            # test data for this relation
            test_data_c_rel = test_data_prel[rel]
            timesteps_test = list(set(test_data_c_rel[:,3]))
            timesteps_test.sort()
            all_data_c_rel = all_data_prel[rel]
            perf_list_rel = []
            hits_list_rel = []
            perf_list_rel, hits_list_rel = predict(num_processes, test_data_c_rel,
                                                   all_data_c_rel, alpha, lmbda_psi,perf_list_rel, hits_list_rel,
                                                   window, neg_sampler, split_mode)
            perf_list_all.extend(perf_list_rel)
            hits_list_all.extend(hits_list_rel)
        else:
            # no queries for this relation in this split; log an empty list
            perf_list_rel =[]
        end = timeit.default_timer()
        total_time = round(end - start, 6)
        print("Relation {} finished in {} seconds.".format(rel, total_time))
        with open(csv_file, 'a') as f:
            f.write("{},{}\n".format(rel, perf_list_rel))
    return perf_list_all, hits_list_all
def read_dict_compute_mrr(split_mode='test'):
    """ Read the per-relation results from the pre-created csv file and compute the MRR.
    Prints the mean MRR over all queries and the per-relation MRRs.
    NOTE: despite earlier documentation, this function returns nothing; it
    only prints. Uses globals: perrel_results_path, MODEL_NAME, DATA, SEED, num_rels.
    """
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'
    # Initialize an empty dictionary to store the data
    results_per_rel_dict = {}
    mrr_per_rel = {}
    all_mrrs = []
    # Open the file for reading
    with open(csv_file, 'r') as f:
        # Read each line in the file
        for line in f:
            # Split the line at the comma
            parts = line.strip().split(',')
            # Extract the key (the first part)
            key = int(parts[0])
            # Extract the values (the rest of the parts), remove square brackets
            values = [float(value.strip('[]')) for value in parts[1:]]
            # Add the key-value pair to the dictionary
            if key in results_per_rel_dict.keys():
                # duplicate relation row: the later entry overwrites the earlier one
                print(f"Key {key} already exists in the dictionary!!! might have duplicate entries in results csv")
            results_per_rel_dict[key] = values
            all_mrrs.extend(values)
            mrr_per_rel[key] = np.mean(values)
    if len(list(results_per_rel_dict.keys())) != num_rels:
        print("we do not have entries for each rel in the results csv file. only num enties: ", len(list(results_per_rel_dict.keys())))
    print("Split mode: "+split_mode +" Mean MRR: ", np.mean(all_mrrs))
    print("mrr per relation: ", mrr_per_rel)
## train
def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
    """ Optional: find the best values for lambda and alpha per relation by evaluating a fixed
    grid of params and selecting on validation MRR. Lambda is tuned first (with
    alpha fixed to 1), then alpha is tuned with the best lambda.
    Uses globals: default_lmbda_psi, default_alpha.
    :return best_config: dictionary of best params for each relation
    """
    best_config= {}
    best_mrr = 0
    for rel in rels: # loop through relations. for each relation, apply rules with selected params, compute valid mrr
        start = timeit.default_timer()
        rel_key = int(rel)
        # initialize defaults; overwritten below if the relation has valid data
        best_config[str(rel_key)] = {}
        best_config[str(rel_key)]['not_trained'] = 'True'
        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi,0] #default
        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))
        best_config[str(rel_key)]['alpha'] = [default_alpha,0] #default
        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))
        if rel in val_data_prel.keys():
            # valid data for this relation
            val_data_c_rel = val_data_prel[rel]
            timesteps_valid = list(set(val_data_c_rel[:,3]))
            timesteps_valid.sort()
            trainval_data_c_rel = trainval_data_prel[rel]
            ###### 1) select lambda ###############
            lmbdas_psi = params_dict['lmbda_psi']
            alpha = 1
            best_lmbda_psi = 0.1
            best_mrr_psi = 0
            lmbda_mrrs = []
            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))
            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))
            best_config[str(rel_key)]['not_trained'] = 'False'
            for lmbda_psi in lmbdas_psi:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                                                   trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r,
                                                   window, neg_sampler, split_mode='val')
                # compute mrr
                mrr = np.mean(perf_list_r)
                # # is new mrr better than previous best? if yes: store lmbda
                if mrr > best_mrr_psi:
                    best_mrr_psi = float(mrr)
                    best_lmbda_psi = lmbda_psi
                lmbda_mrrs.append(float(mrr))
            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]
            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs
            best_mrr = best_mrr_psi
            ##### 2) select alpha ###############
            best_config[str(rel_key)]['not_trained'] = 'False'
            alphas = params_dict['alpha']
            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0] # use the best lmbda psi
            alpha_mrrs = []
            # perf_list_all = []
            best_mrr_alpha = 0
            best_alpha=0.99
            for alpha in alphas:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                                                   trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r,
                                                   window, neg_sampler, split_mode='val')
                # compute mrr
                mrr_alpha = np.mean(perf_list_r)
                # is new mrr better than previous best? if yes: store alpha
                if mrr_alpha > best_mrr_alpha:
                    best_mrr_alpha = float(mrr_alpha)
                    best_alpha = alpha
                    best_mrr = best_mrr_alpha
                alpha_mrrs.append(float(mrr_alpha))
            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]
            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs
        end = timeit.default_timer()
        total_time = round(end - start, 6)
        print("Relation {} finished in {} seconds.".format(rel, total_time))
    return best_config
## args
def get_args():
    """Parse all command-line arguments for the script and return them as a dict."""
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", "-d", default="thgl-software", type=str)
    # set to e.g. 200 if only the most recent 200 timesteps should be considered; -2 = multistep
    p.add_argument("--window", "-w", default=0, type=int)
    p.add_argument("--num_processes", "-p", default=1, type=int)
    # fixed lambda, used when train_flag == 'False'
    p.add_argument("--lmbda", "-l", default=0.1, type=float)
    # fixed alpha, used when train_flag == 'False'
    p.add_argument("--alpha", "-alpha", default=0.99, type=float)
    # do we need training, i.e. selection of lambda and alpha?
    p.add_argument("--train_flag", "-tr", default='False')
    # if train_flag is 'True': should best_config be loaded from disk?
    p.add_argument("--load_flag", "-lo", default='False')
    # should the selected lambda/alpha be saved to a config file?
    p.add_argument("--save_config", "-c", default='True')
    p.add_argument('--seed', type=int, help='Random seed', default=1)  # not needed
    namespace = p.parse_args()
    return vars(namespace)
# Script body: load data, optionally tune per-relation hyperparameters on the
# validation split, then evaluate the recurrency baseline and save results.
start_o = timeit.default_timer()
parsed = get_args()
if parsed['num_processes']>1:
    ray.init(num_cpus=parsed["num_processes"], num_gpus=0)
MODEL_NAME = 'RecurrencyBaseline'
SEED = parsed['seed'] # set the random seed for consistency
set_random_seed(SEED)
perrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'
if not osp.exists(perrel_results_path):
    os.mkdir(perrel_results_path)
    print('INFO: Create directory {}'.format(perrel_results_path))
Path(perrel_results_path).mkdir(parents=True, exist_ok=True)
## load dataset and prepare it accordingly
name = parsed["dataset"]
dataset = LinkPropPredDataset(name=name, root="datasets", preprocess=True)
DATA = name
relations = dataset.edge_type
num_rels = dataset.num_rels
rels = np.arange(0,num_rels)
subjects = dataset.full_data["sources"]
objects= dataset.full_data["destinations"]
num_nodes = dataset.num_nodes
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1
print("split train valid test data")
# quadruple layout: (subject, relation, object, reformatted ts, original ts)
all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]
metric = dataset.eval_metric
evaluator = Evaluator(name=name)
neg_sampler = dataset.negative_sampler
train_val_data = np.concatenate([train_data, val_data])
all_data = np.concatenate([train_data, val_data, test_data])
# create dicts with key: relation id, values: triples for that relation id
print("grouping data by relation")
test_data_prel = group_by(test_data, 1)
all_data_prel = group_by(all_data, 1)
val_data_prel = group_by(val_data, 1)
trainval_data_prel = group_by(train_val_data, 1)
#load the ns samples
# if parsed['train_flag']:
print("loading negative samples")
dataset.load_val_ns()
dataset.load_test_ns()
# parameter options
if parsed['train_flag'] == 'True':
    params_dict = {}
    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001]
    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]
    # defaults used by train() for relations without validation data
    default_lmbda_psi = params_dict['lmbda_psi'][-1]
    default_alpha = params_dict['alpha'][-2]
## load rules
print("creating rules")
basis_dict = create_basis_dict(train_val_data)
print("done with creating rules")
## init
# rb_predictor = RecurrencyBaselinePredictor(rels)
## train to find best lambda and alpha
start_train = timeit.default_timer()
if parsed['train_flag'] == 'True':
    if parsed['load_flag'] == 'True':
        with open('best_config.json', 'r') as infile:
            best_config = json.load(infile)
    else:
        print('start training')
        best_config = train(params_dict, rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'],
                    parsed['window'])
        if parsed['save_config'] == 'True':
            import json
            with open('best_config.json', 'w') as outfile:
                json.dump(best_config, outfile)
else: # use preset lmbda and alpha; same for all relations
    best_config = {}
    for rel in rels:
        best_config[str(rel)] = {}
        best_config[str(rel)]['lmbda_psi'] = [parsed['lmbda']]
        best_config[str(rel)]['alpha'] = [parsed['alpha']]
end_train = timeit.default_timer()
# compute validation mrr
print("Computing validation MRR")
perf_list_all_val, hits_list_all_val = test(best_config,rels, val_data_prel,
                    trainval_data_prel, neg_sampler, parsed['num_processes'],
                    parsed['window'], split_mode='val')
val_mrr = float(np.mean(perf_list_all_val))
# compute test mrr
print("Computing test MRR")
start_test = timeit.default_timer()
perf_list_all, hits_list_all = test(best_config,rels, test_data_prel,
                    all_data_prel, neg_sampler, parsed['num_processes'],
                    parsed['window'])
end_o = timeit.default_timer()
train_time_o = round(end_train- start_train, 6)
test_time_o = round(end_o- start_test, 6)
total_time_o = round(end_o- start_o, 6)
print("Running Training to find best configs finished in {} seconds.".format(train_time_o))
print("Running testing with best configs finished in {} seconds.".format(test_time_o))
print("Running all steps finished in {} seconds.".format(total_time_o))
print(f"The test MRR is {np.mean(perf_list_all)}")
print(f"The valid MRR is {val_mrr}")
print(f"The Hits@10 is {np.mean(hits_list_all)}")
print(f"We have {len(perf_list_all)} predictions")
print(f"The test set has len {len(test_data)} ")
# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'
metric = dataset.eval_metric
save_results({'model': MODEL_NAME,
              'train_flag': parsed['train_flag'],
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: float(np.mean(perf_list_all)),
              'val_mrr': val_mrr,
              'hits10': float(np.mean(hits_list_all)),
              'test_time': test_time_o,
              'tot_train_val_time': total_time_o
              },
             results_filename)
if parsed['num_processes']>1:
    ray.shutdown()
================================================
FILE: examples/linkproppred/thgl-software/sthn.py
================================================
import timeit
import numpy as np
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator
import argparse
from modules.sthn import set_seed, pre_compute_subgraphs, get_inputs_for_ind, check_data_leakage
import torch
import pandas as pd
import itertools
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
from tgb.utils.utils import set_random_seed, save_results
# Start...
start_overall = timeit.default_timer()

# Dataset / model identifiers used in log messages and result file names.
DATA = "thgl-software"
MODEL_NAME = 'STHN'

# data loading: TGB PyG dataset with pre-defined chronological splits
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
metric = dataset.eval_metric
print ("there are {} nodes and {} edges".format(dataset.num_nodes, dataset.num_edges))
print ("there are {} relation types".format(dataset.num_rels))
# Convenience aliases into the TemporalData object.
timestamp = data.t
head = data.src
tail = data.dst
edge_type = data.edge_type
neg_sampler = dataset.negative_sampler
print(data)
print(timestamp)
print(head)
print(tail)
# Chronological split of the edge stream.
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
# NOTE(review): `metric` and `neg_sampler` are re-assigned here with the same
# values as above — redundant but harmless.
metric = dataset.eval_metric
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
# NOTE(review): redundant with the os.mkdir above (exist_ok=True makes it a no-op)
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
####################################################################
####################################################################
####################################################################
def print_model_info(model):
    """Print the model's architecture followed by its trainable parameter count."""
    print(model)
    trainable = (p for p in model.parameters() if p.requires_grad)
    n_params = sum(int(np.prod(p.size())) for p in trainable)
    print('Trainable Parameters: %d' % n_params)
def get_args():
    """Build the CLI parser for the STHN run script and parse sys.argv."""
    ap = argparse.ArgumentParser()

    # --- run / optimisation setup -----------------------------------------
    ap.add_argument('--data', type=str, default='movie')
    ap.add_argument('--device', type=int, default=0)
    ap.add_argument('--batch_size', type=int, default=600)
    ap.add_argument('--epochs', type=int, default=2)
    ap.add_argument('--max_edges', type=int, default=50)
    ap.add_argument('--num_edgeType', type=int, default=0, help='num of edgeType')
    ap.add_argument('--lr', type=float, default=0.0005)
    ap.add_argument('--weight_decay', type=float, default=1e-4)
    ap.add_argument('--predict_class', action='store_true')

    # --- model hyper-parameters -------------------------------------------
    ap.add_argument('--window_size', type=int, default=5)
    ap.add_argument('--dropout', type=float, default=0.1)
    ap.add_argument('--model', type=str, default='sthn')
    ap.add_argument('--neg_samples', type=int, default=1)
    ap.add_argument('--extra_neg_samples', type=int, default=5)
    ap.add_argument('--num_neighbors', type=int, default=50)
    ap.add_argument('--channel_expansion_factor', type=int, default=2)
    ap.add_argument('--sampled_num_hops', type=int, default=1)
    ap.add_argument('--time_dims', type=int, default=100)
    ap.add_argument('--hidden_dims', type=int, default=100)
    ap.add_argument('--num_layers', type=int, default=1)

    # --- feature / structure switches -------------------------------------
    ap.add_argument('--check_data_leakage', action='store_true')
    ap.add_argument('--ignore_node_feats', action='store_true')
    ap.add_argument('--node_feats_as_edge_feats', action='store_true')
    ap.add_argument('--ignore_edge_feats', action='store_true')
    ap.add_argument('--use_onehot_node_feats', action='store_true')
    ap.add_argument('--use_type_feats', action='store_true')
    ap.add_argument('--use_graph_structure', action='store_true')
    ap.add_argument('--structure_time_gap', type=int, default=2000)
    ap.add_argument('--structure_hops', type=int, default=1)
    ap.add_argument('--use_node_cls', action='store_true')
    ap.add_argument('--use_cached_subgraph', action='store_true')

    # --- reproducibility ---------------------------------------------------
    ap.add_argument('--seed', type=int, help='Random seed', default=1)
    ap.add_argument('--num_run', type=int, help='Number of iteration runs', default=5)

    return ap.parse_args()
def load_model(args):
    """Instantiate the STHN link-prediction model described by `args`.

    Parameters:
        args: parsed argument namespace; must provide time_dims,
            node_feat_dims, edge_feat_dims, predict_class, num_edgeType,
            model and the mixer hyper-parameters referenced below.
    Returns:
        (model, args, link_pred_train): the model, the (unmodified) args,
        and the training entry point imported from modules.sthn.
    Raises:
        NotImplementedError: if args.model is not 'sthn'.
    """
    edge_predictor_configs = {
        'dim_in_time': args.time_dims,
        'dim_in_node': args.node_feat_dims,
        # binary link prediction unless multi-class edge-type prediction is on
        'predict_class': 1 if not args.predict_class else args.num_edgeType+1,
    }
    if args.model == 'sthn':
        # Multi-class and binary variants share the same constructor signature.
        if args.predict_class:
            from modules.sthn import Multiclass_Interface as STHN_Interface
        else:
            from modules.sthn import STHN_Interface
        from modules.sthn import link_pred_train

        mixer_configs = {
            'per_graph_size' : args.max_edges,
            'time_channels' : args.time_dims,
            'input_channels' : args.edge_feat_dims,
            'hidden_channels' : args.hidden_dims,
            'out_channels' : args.hidden_dims,
            'num_layers' : args.num_layers,
            'dropout' : args.dropout,
            'channel_expansion_factor': args.channel_expansion_factor,
            'window_size' : args.window_size,
            'use_single_layer' : False
        }
    else:
        # BUG FIX: the original evaluated `NotImplementedError()` without
        # raising it, then fell through to an unbound STHN_Interface and
        # crashed with a confusing NameError instead.
        raise NotImplementedError(f"unsupported model: {args.model!r}")

    model = STHN_Interface(mixer_configs, edge_predictor_configs)

    for k, v in model.named_parameters():
        print(k, v.requires_grad)
    print_model_info(model)
    return model, args, link_pred_train
def load_graph(data):
    """Build a CSR-style temporal adjacency structure from the edge stream.

    Parameters:
        data: a TemporalData-like object exposing .t, .src, .dst and
            .edge_type arrays of equal length.
    Returns:
        (g, df): `g` is a loaded .npz archive with keys 'indptr', 'indices',
        'ts' and 'eid' (per-source neighbor lists sorted by timestamp);
        `df` is the edge list as a pandas DataFrame.
    """
    import tempfile  # local import: only needed for the scratch .npz file

    df = pd.DataFrame({
        'idx': np.arange(len(data.t)),
        'src': data.src,
        'dst': data.dst,
        'time': data.t,
        'label': data.edge_type,
    })
    num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
    ext_full_indptr = np.zeros(num_nodes + 1, dtype=np.int32)
    ext_full_indices = [[] for _ in range(num_nodes)]
    ext_full_ts = [[] for _ in range(num_nodes)]
    ext_full_eid = [[] for _ in range(num_nodes)]
    # PERF: iterate over raw numpy columns instead of DataFrame.iterrows,
    # which boxes every row into a Series and is dramatically slower.
    srcs = df['src'].to_numpy()
    dsts = df['dst'].to_numpy()
    times = df['time'].to_numpy()
    for idx in tqdm(range(len(df))):
        src = int(srcs[idx])
        ext_full_indices[src].append(int(dsts[idx]))
        ext_full_ts[src].append(times[idx])
        ext_full_eid[src].append(idx)
    # Prefix-sum of per-node degrees -> CSR index pointer.
    for i in tqdm(range(num_nodes)):
        ext_full_indptr[i + 1] = ext_full_indptr[i] + len(ext_full_indices[i])
    ext_full_indices = np.array(list(itertools.chain(*ext_full_indices)))
    ext_full_ts = np.array(list(itertools.chain(*ext_full_ts)))
    ext_full_eid = np.array(list(itertools.chain(*ext_full_eid)))
    print('Sorting...')

    def tsort(i, indptr, indices, t, eid):
        # Sort node i's neighborhood slice in place by timestamp.
        beg = indptr[i]
        end = indptr[i + 1]
        sidx = np.argsort(t[beg:end])
        indices[beg:end] = indices[beg:end][sidx]
        t[beg:end] = t[beg:end][sidx]
        eid[beg:end] = eid[beg:end][sidx]

    for i in tqdm(range(ext_full_indptr.shape[0] - 1)):
        tsort(i, ext_full_indptr, ext_full_indices, ext_full_ts, ext_full_eid)
    print('saving...')
    # BUG FIX: the original wrote to the fixed path /tmp/ext_full.npz, which
    # collides between concurrent runs and is not portable; use a unique
    # temporary file instead (interface unchanged: callers still get the
    # lazily-loaded npz object).
    fd, npz_path = tempfile.mkstemp(prefix='ext_full_', suffix='.npz')
    os.close(fd)
    np.savez(npz_path, indptr=ext_full_indptr,
             indices=ext_full_indices, ts=ext_full_ts, eid=ext_full_eid)
    g = np.load(npz_path)
    return g, df
def load_all_data(args):
    """Load graph structure and node/edge features, recording sizes on `args`.

    NOTE(review): reads the module-level globals `data`, `dataset`,
    `train_mask`, `val_mask` and `test_mask` defined in this script.

    Parameters:
        args: parsed argument namespace; mutated in place with masks,
            num_nodes/num_edges and feature dimensions.
    Returns:
        (node_feats, edge_feats, g, df, args): features (moved to
        args.device, or None), the CSR graph, the edge DataFrame and args.
    """
    # load graph
    g, df = load_graph(data)
    args.train_mask = train_mask.numpy()
    args.val_mask = val_mask.numpy()
    args.test_mask = test_mask.numpy()
    # (the original assigned num_edges twice; once is enough)
    args.num_edges = len(df)
    print('Train %d, Valid %d, Test %d'%(sum(args.train_mask),
                                         sum(args.val_mask),
                                         sum(args.test_mask)))
    args.num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
    print('Num nodes %d, num edges %d'%(args.num_nodes, args.num_edges))
    # load feats
    node_feats, edge_feats = dataset.node_feat, dataset.edge_feat
    node_feat_dims = 0 if node_feats is None else node_feats.shape[1]
    edge_feat_dims = 0 if edge_feats is None else edge_feats.shape[1]
    # feature pre-processing
    if args.use_onehot_node_feats:
        print('>>> Use one-hot node features')
        node_feats = torch.eye(args.num_nodes)
        node_feat_dims = node_feats.size(1)
    if args.ignore_node_feats:
        print('>>> Ignore node features')
        node_feats = None
        node_feat_dims = 0
    if args.use_type_feats:
        # Replace raw edge features with a one-hot encoding of the edge type.
        edge_type = df.label.values
        print(edge_type)
        print(edge_type.sum())
        args.num_edgeType = len(set(edge_type.tolist()))
        edge_feats = torch.nn.functional.one_hot(torch.from_numpy(edge_type),
                                                 num_classes=args.num_edgeType)
        edge_feat_dims = edge_feats.size(1)
    print('Node feature dim %d, edge feature dim %d'%(node_feat_dims, edge_feat_dims))
    # double check (if data leakage then cannot continue the code)
    if args.check_data_leakage:
        check_data_leakage(args, g, df)
    args.node_feat_dims = node_feat_dims
    args.edge_feat_dims = edge_feat_dims
    # BUG FIX: use `is not None` instead of `!= None` — comparing an array
    # type with != is elementwise and can raise for numpy inputs.
    if node_feats is not None:
        node_feats = node_feats.to(args.device)
    if edge_feats is not None:
        edge_feats = edge_feats.to(args.device)
    return node_feats, edge_feats, g, df, args
####################################################################
####################################################################
####################################################################
@torch.no_grad()
def test(data, test_mask, model, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction in the one-vs-many setting: each
    positive edge is ranked against its pre-generated negative edges.

    NOTE(review): besides its parameters, this function reads the
    module-level globals args, g, df, node_feats, edge_feats, metric and
    evaluator defined in this script; `data` and `test_mask` are accepted
    but not used directly in the body.

    Parameters:
        data: a dataset object
        test_mask: required masks to load the test set edges
        model: the trained STHN model to evaluate
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'val' or 'test' set to correctly load the negatives
    Returns:
        (mean, std, perf_list): mean and std of the per-edge metric plus
        the raw list of per-edge scores
    """
    # Pre-compute the temporal subgraphs for the chosen evaluation split.
    test_subgraphs = pre_compute_subgraphs(args, g, df, mode='test' if split_mode == 'test' else 'valid', negative_sampler=neg_sampler, split_mode=split_mode)
    perf_list = []
    if split_mode == 'test':
        cur_df = df[args.test_mask]
    elif split_mode == 'val':
        cur_df = df[args.val_mask]
    # 20 negatives per positive edge (matches the cached negative samples).
    neg_samples = 20
    cached_neg_samples = 20
    # Group the split's edges into mini-batches by integer-dividing the index.
    test_loader = cur_df.groupby(cur_df.index // args.batch_size)
    pbar = tqdm(total=len(test_loader))
    pbar.set_description('%s mode with negative samples %d ...'%(split_mode, neg_samples))
    ###################################################
    # compute + training + fetch all scores
    cur_inds = 0
    for ind in range(len(test_loader)):
        ###################################################
        # Fetch the model inputs for this mini-batch; cur_inds tracks the
        # running offset into the pre-computed subgraphs.
        inputs, subgraph_node_feats, cur_inds = get_inputs_for_ind(test_subgraphs, 'test' if split_mode == 'test' else 'tgb-val', cached_neg_samples, neg_samples, node_feats, edge_feats, cur_df, cur_inds, ind, args)
        loss, pred, edge_label = model(inputs, neg_samples, subgraph_node_feats)
        # print(ind, [l for l in inputs], pred.shape)
        # Position 0 holds the positive edge's score; the rest are negatives.
        input_dict = {
            "y_pred_pos": np.array([pred.cpu()[0]]),
            "y_pred_neg": np.array(pred.cpu()[1:]),
            "eval_metric": [metric],
        }
        perf_list.append(evaluator.eval(input_dict)[metric])
    perf_metrics_mean = float(np.mean(perf_list))
    perf_metrics_std = float(np.std(perf_list))
    return perf_metrics_mean, perf_metrics_std, perf_list
args = get_args()
# Hard-coded overrides of the CLI flags for this THGL experiment:
# graph structure + edge-type one-hot features, no raw node features.
args.use_graph_structure = True
args.use_onehot_node_feats = False
args.ignore_node_feats = False # we only use graph structure
args.use_type_feats = True # type encoding
args.use_cached_subgraph = True
print(args)
args.device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
args.device = torch.device(args.device)
SEED = args.seed
BATCH_SIZE = args.batch_size
NUM_RUNS = args.num_run
set_seed(SEED)
###################################################
# load feats + graph
node_feats, edge_feats, g, df, args = load_all_data(args)
###################################################
# get model
model, args, link_pred_train = load_model(args)
###################################################
print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")
# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
# NOTE(review): redundant with the os.mkdir above (exist_ok=True)
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
# One full train+val+test cycle per run, each with its own derived seed.
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()
    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)
    # define an early stopper
    # NOTE(review): save_model_dir/save_model_id are currently unused since
    # the early stopper below is commented out.
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    # early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
    #                                  tolerance=TOLERANCE, patience=PATIENCE)
    # ==================================================== Train & Validation
    # loading the validation negative samples
    # Link prediction
    start_val = timeit.default_timer()
    print('Train link prediction task from scratch ...')
    model = link_pred_train(model.to(args.device), args, g, df, node_feats, edge_feats)
    dataset.load_val_ns()
    # Validation ...
    perf_metrics_val_mean, perf_metrics_val_std, perf_list_val = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='val')
    end_val = timeit.default_timer()
    print(f"INFO: val: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tval: {metric}: {perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}")
    val_time = timeit.default_timer() - start_val
    print(f"\tval: Elapsed Time (s): {val_time: .4f}")
    dataset.load_test_ns()
    # testing ...
    start_test = timeit.default_timer()
    perf_metrics_test_mean, perf_metrics_test_std, perf_list_test = test(data.to(args.device), test_mask, model.to(args.device), neg_sampler, split_mode='test')
    end_test = timeit.default_timer()
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
    # Persist this run's metrics as one JSON record.
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': f'{perf_metrics_val_mean: .4f} ± {perf_metrics_val_std: .4f}' ,
                  f'test {metric}': f'{perf_metrics_test_mean: .4f} ± {perf_metrics_test_std: .4f}' ,
                  'test_time': test_time,
                  'tot_train_val_time': val_time
                  },
                 results_filename)
    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')
print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
# save_results({'model': MODEL_NAME,
#               'data': DATA,
#               'run': 1,
#               'seed': SEED,
#               metric: perf_metric_test,
#               'test_time': test_time,
#               'tot_train_val_time': 'NA'
#               },
#              results_filename)
================================================
FILE: examples/linkproppred/thgl-software/tgn.py
================================================
import numpy as np
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import timeit
import math
import timeit
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TransformerConv
# internal imports
from tgb.utils.utils import get_args, set_random_seed, save_results
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import TGNMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
# ==========
# ========== Define helper function...
# ==========
def train():
    r"""
    Training procedure for TGN model (one epoch over the training stream).

    This function uses objects that are globally defined in the current
    script: model, neighbor_loader, train_loader, optimizer, criterion,
    assoc, data, device, min_dst_idx/max_dst_idx and train_data.

    Parameters:
        None
    Returns:
        float: average training loss per event over the epoch
    """
    model['memory'].train()
    model['gnn'].train()
    model['link_pred'].train()
    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        # NOTE(review): `rel` (edge_type) is unpacked but not used below.
        src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type
        # Sample negative destination nodes uniformly from the destination range.
        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )
        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # Map global node ids to local positions within this batch's subgraph.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        # Get updated memory of all nodes involved in the computation.
        z, last_update = model['memory'](n_id)
        z = model['gnn'](
            z,
            last_update,
            edge_index,
            data.t[e_id].to(device),
            data.msg[e_id].to(device),
        )
        # Score the positive edge against the sampled negative edge.
        pos_out = model['link_pred'](z[assoc[src]], z[assoc[pos_dst]])
        neg_out = model['link_pred'](z[assoc[src]], z[assoc[neg_dst]])
        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))
        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)
        loss.backward()
        optimizer.step()
        # Detach memory so gradients don't flow across batch boundaries.
        model['memory'].detach()
        total_loss += float(loss.detach()) * batch.num_events
    return total_loss / train_data.num_events
@torch.no_grad()
def test(loader, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction in the one-vs-many setting: each
    positive edge is ranked against its pre-generated negative edges.

    NOTE(review): besides its parameters, this function reads the
    module-level globals model, neighbor_loader, assoc, data, device,
    metric and evaluator defined in this script.

    Parameters:
        loader: an object containing positive attributes of the positive edges of the evaluation set
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metric: the mean of the per-edge evaluation metric
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['link_pred'].eval()
    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t, pos_msg, pos_rel = (
            pos_batch.src,
            pos_batch.dst,
            pos_batch.t,
            pos_batch.msg,
            pos_batch.edge_type
        )
        # One list of negative destinations per positive edge in the batch.
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_rel, split_mode=split_mode)
        # pos_msg_new = torch.cat([pos_msg,pos_rel.unsqueeze(dim=1)], dim=1)
        for idx, neg_batch in enumerate(neg_batch_list):
            # Build one (1 positive + N negatives) candidate set; the true
            # destination occupies position 0 of `dst`.
            src = torch.full((1 + len(neg_batch),), pos_src[idx], device=device)
            dst = torch.tensor(
                np.concatenate(
                    ([np.array([pos_dst.cpu().numpy()[idx]]), np.array(neg_batch)]),
                    axis=0,
                ),
                device=device,
            )
            n_id = torch.cat([src, dst]).unique()
            n_id, edge_index, e_id = neighbor_loader(n_id)
            # Map global node ids to local positions within this subgraph.
            assoc[n_id] = torch.arange(n_id.size(0), device=device)
            # Get updated memory of all nodes involved in the computation.
            z, last_update = model['memory'](n_id)
            z = model['gnn'](
                z,
                last_update,
                edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            y_pred = model['link_pred'](z[assoc[src]], z[assoc[dst]])
            # compute MRR: row 0 is the positive edge, the rest are negatives
            input_dict = {
                "y_pred_pos": np.array([y_pred[0, :].squeeze(dim=-1).cpu()]),
                "y_pred_neg": np.array(y_pred[1:, :].squeeze(dim=-1).cpu()),
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])
        # Update memory and neighbor loader with ground-truth state.
        model['memory'].update_state(pos_src, pos_dst, pos_t, pos_msg)
        neighbor_loader.insert(pos_src, pos_dst)
    perf_metrics = float(torch.tensor(perf_list).mean())
    return perf_metrics
# ==========
# ==========
# ==========
# ==========
# ==========
# ==========
# Start...
start_overall = timeit.default_timer()
DATA = "thgl-software"
# ========== set parameters...
args, _ = get_args()
args.data = DATA
print("INFO: Arguments:", args)
LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10
# NOTE(review): USE_NODE_TYPE is declared but node types are only attached
# to `data` below; only USE_EDGE_TYPE gates a feature change here.
USE_EDGE_TYPE = True
USE_NODE_TYPE = True
MODEL_NAME = 'TGN'
# ==========
# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric
print ("there are {} nodes and {} edges".format(dataset.num_nodes, dataset.num_edges))
print ("there are {} relation types".format(dataset.num_rels))
timestamp = data.t
head = data.src
tail = data.dst
edge_type = data.edge_type #relation
# Learn a 128-d embedding per relation type and append it to the edge message.
edge_type_dim = len(torch.unique(edge_type))
embed_edge_type = torch.nn.Embedding(edge_type_dim, 128).to(device)
with torch.no_grad():
    edge_type_embeddings = embed_edge_type(edge_type)
if USE_EDGE_TYPE:
    data.msg = torch.cat([data.msg, edge_type_embeddings], dim=1)
#! node type is a property of the dataset not the temporal data as temporal data has one entry per edge
node_type = dataset.node_type #node type
neg_sampler = dataset.negative_sampler
data.__setattr__("node_type", node_type)
print ("shape of edge type is", edge_type.shape)
print ("shape of node type is", node_type.shape)
# Chronological split of the edge stream.
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
print ("finished loading PyG data")
train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)
start_time = timeit.default_timer()
# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
print("==========================================================")
print(f"=================*** {MODEL_NAME}: LinkPropPred: {DATA} ***=============")
print("==========================================================")
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
# NOTE(review): redundant with the os.mkdir above (exist_ok=True)
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
# One full train+val+test cycle per run, each with its own derived seed.
for run_idx in range(NUM_RUNS):
    print('-------------------------------------------------------------------------------')
    print(f"INFO: >>>>> Run: {run_idx} <<<<<")
    start_run = timeit.default_timer()
    # set the seed for deterministic results...
    torch.manual_seed(run_idx + SEED)
    set_random_seed(run_idx + SEED)
    # neighhorhood sampler
    neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)
    # define the model end-to-end
    memory = TGNMemory(
        data.num_nodes,
        data.msg.size(-1),
        MEM_DIM,
        TIME_DIM,
        message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
        aggregator_module=LastAggregator(),
    ).to(device)
    gnn = GraphAttentionEmbedding(
        in_channels=MEM_DIM,
        out_channels=EMB_DIM,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    ).to(device)
    link_pred = LinkPredictor(in_channels=EMB_DIM).to(device)
    model = {'memory': memory,
             'gnn': gnn,
             'link_pred': link_pred}
    optimizer = torch.optim.Adam(
        set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['link_pred'].parameters()),
        lr=LR,
    )
    criterion = torch.nn.BCEWithLogitsLoss()
    # Helper vector to map global node indices to local ones.
    assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
    # define an early stopper
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    save_model_id = f'{MODEL_NAME}_{DATA}_{SEED}_{run_idx}'
    early_stopper = EarlyStopMonitor(save_model_dir=save_model_dir, save_model_id=save_model_id,
                                     tolerance=TOLERANCE, patience=PATIENCE)
    # ==================================================== Train & Validation
    # loading the validation negative samples
    dataset.load_val_ns()
    val_perf_list = []
    start_train_val = timeit.default_timer()
    for epoch in range(1, NUM_EPOCH + 1):
        # training
        start_epoch_train = timeit.default_timer()
        loss = train()
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Training elapsed Time (s): {timeit.default_timer() - start_epoch_train: .4f}"
        )
        # validation
        start_val = timeit.default_timer()
        perf_metric_val = test(val_loader, neg_sampler, split_mode="val")
        print(f"\tValidation {metric}: {perf_metric_val: .4f}")
        print(f"\tValidation: Elapsed time (s): {timeit.default_timer() - start_val: .4f}")
        val_perf_list.append(perf_metric_val)
        # check for early stopping
        if early_stopper.step_check(perf_metric_val, model):
            break
    train_val_time = timeit.default_timer() - start_train_val
    print(f"Train & Validation: Elapsed Time (s): {train_val_time: .4f}")
    # ==================================================== Test
    # first, load the best model
    early_stopper.load_checkpoint(model)
    # loading the test negative samples
    dataset.load_test_ns()
    # final testing
    start_test = timeit.default_timer()
    perf_metric_test = test(test_loader, neg_sampler, split_mode="test")
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")
    test_time = timeit.default_timer() - start_test
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
    # Persist this run's metrics as one JSON record.
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': run_idx,
                  'seed': SEED,
                  f'val {metric}': val_perf_list,
                  f'test {metric}': perf_metric_test,
                  'test_time': test_time,
                  'tot_train_val_time': train_val_time
                  },
                 results_filename)
    print(f"INFO: >>>>> Run: {run_idx}, elapsed time: {timeit.default_timer() - start_run: .4f} <<<<<")
    print('-------------------------------------------------------------------------------')
print(f"Overall Elapsed Time (s): {timeit.default_timer() - start_overall: .4f}")
print("==============================================================")
# #* load numpy arrays instead
# from tgb.linkproppred.dataset import LinkPropPredDataset
# # data loading
# dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
# data = dataset.full_data
# metric = dataset.eval_metric
# sources = dataset.full_data['sources']
# print ("finished loading numpy arrays")
FILE: examples/linkproppred/tkgl-icews/cen.py
================================================
"""
Complex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning
Reference:
- https://github.com/Lee-zix/CEN
Zixuan Li, Saiping Guan, Xiaolong Jin, Weihua Peng, Yajuan Lyu , Yong Zhu, Long Bai, Wei Li, Jiafeng Guo, Xueqi Cheng.
Complex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning. ACL 2022.
"""
import timeit
import os
import sys
import os.path as osp
from pathlib import Path
import numpy as np
import torch
import random
from tqdm import tqdm
import json
# internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.rrgcn import RecurrentRGCNCEN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from modules.tkg_utils import get_args_cen, reformat_ts
from modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
def test(model, history_len, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, mode, split_mode):
    """
    Test the CEN model over a sequence of snapshots, maintaining a rolling
    history window of the most recent `history_len` snapshots.

    NOTE(review): besides its parameters, this function reads the
    module-level globals test_timestamps_orig, val_timestamps_orig,
    neg_sampler, evaluator, METRIC and args defined in this script.

    :param model: model used to test
    :param history_len: number of history snapshots fed to the model per step
    :param history_list: all input history snap shot list, not include output label train list or valid list
    :param test_list: test triple snap shot list
    :param num_rels: number of relations
    :param num_nodes: number of nodes
    :param use_cuda: whether to run on the GPU given by args.gpu
    :param model_name: checkpoint path to load when mode == "test"
    :param mode: "test" loads model parameters from model_name first
    :param split_mode: 'test' or 'val' to state which negative samples to load
    :return: (mrr, perf_per_rel, hits10) — mean reciprocal rank, per-relation
        MRR dict (only populated for split_mode 'test' with args.log_per_rel),
        and mean Hits@10
    """
    print("Testing for mode: ", split_mode)
    # Pick the original (un-reformatted) timestamps matching the split.
    if split_mode == 'test':
        timesteps_to_eval = test_timestamps_orig
    else:
        timesteps_to_eval = val_timestamps_orig
    idx = 0
    if mode == "test":
        # test mode: load parameter form file
        if use_cuda:
            checkpoint = torch.load(model_name, map_location=torch.device(args.gpu))
        else:
            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))
        # use best stat checkpoint:
        print("Load Model name: {}. Using best epoch : {}".format(model_name, checkpoint['epoch']))
        print("\n"+"-"*10+"start testing"+"-"*10+"\n")
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    # Seed the rolling window with the last `history_len` history snapshots.
    input_list = [snap for snap in history_list[-history_len:]]
    perf_list_all = []
    hits_list_all = []
    perf_per_rel = {}
    for rel in range(num_rels):
        perf_per_rel[rel] = []
    for time_idx, test_snap in enumerate(tqdm(test_list)):
        # Build DGL subgraphs for the current history window.
        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]
        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)
        # All triples in one snapshot share the same original timestamp.
        timesteps_batch =timesteps_to_eval[time_idx]*np.ones(len(test_triples_input[:,0]))
        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:,0], test_triples_input[:,2],
                                                    timesteps_batch, edge_type=test_triples_input[:,1], split_mode=split_mode)
        pos_samples_batch = test_triples_input[:,2]
        _, perf_list, hits_list = model.predict(history_glist, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch,
                                                evaluator, METRIC)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
        if split_mode == "test":
            if args.log_per_rel:
                # Collect per-relation scores for optional reporting.
                for score, rel in zip(perf_list, test_triples_input[:,1].tolist()):
                    perf_per_rel[rel].append(score)
        # reconstruct history graph list: slide the window forward by one.
        input_list.pop(0)
        input_list.append(test_snap)
        idx += 1
    if split_mode == "test":
        if args.log_per_rel:
            # Average the per-relation scores; drop relations never observed.
            for rel in range(num_rels):
                if len(perf_per_rel[rel]) > 0:
                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))
                else:
                    perf_per_rel.pop(rel)
    mrr = np.mean(perf_list_all)
    hits10 = np.mean(hits_list_all)
    return mrr, perf_per_rel, hits10
def run_experiment(args, trainvalidtest_id=0, n_hidden=None, n_layers=None, dropout=None, n_bases=None):
    '''
    Run experiment for CEN model.
    :param args: arguments for the model
    :param trainvalidtest_id: -1: pretraining, 0: curriculum training (to find best test history len), 1: test on valid set, 2: test on test set
    :param n_hidden: number of hidden units (overrides args.n_hidden if given)
    :param n_layers: number of layers (overrides args.n_layers if given)
    :param dropout: dropout rate (overrides args.dropout if given)
    :param n_bases: number of bases (overrides args.n_bases if given)
    :return: (mrr, test_history_len, perf_per_rel, hits10) — note perf_per_rel
             is only populated by the outer test() call in mode 2.
    '''
    # 1) load configuration for grid search the best configuration
    # (falsy overrides are ignored, so 0 cannot be forced here — pre-existing quirk)
    if n_hidden:
        args.n_hidden = n_hidden
    if n_layers:
        args.n_layers = n_layers
    if dropout:
        args.dropout = dropout
    if n_bases:
        args.n_bases = n_bases
    test_history_len = args.test_history_len
    # 2) set save model path
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    if not osp.exists(save_model_dir):
        os.mkdir(save_model_dir)
        print('INFO: Create directory {}'.format(save_model_dir))
    # checkpoint name encodes model/data/seed/run/history-length so curriculum
    # stages with different history lengths do not overwrite each other
    test_model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.test_history_len}'
    test_state_file = save_model_dir+test_model_name
    perf_per_rel ={}
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    # create stat: recurrent RGCN encoder + decoder as in the CEN paper
    model = RecurrentRGCNCEN(args.decoder,
                             args.encoder,
                             num_nodes,
                             num_rels,
                             args.n_hidden,
                             args.opn,
                             sequence_len=args.train_history_len,
                             num_bases=args.n_bases,
                             num_basis=args.n_basis,
                             num_hidden_layers=args.n_layers,
                             dropout=args.dropout,
                             self_loop=args.self_loop,
                             skip_connect=args.skip_connect,
                             layer_norm=args.layer_norm,
                             input_dropout=args.input_dropout,
                             hidden_dropout=args.hidden_dropout,
                             feat_dropout=args.feat_dropout,
                             entity_prediction=args.entity_prediction,
                             relation_prediction=args.relation_prediction,
                             use_cuda=use_cuda,
                             gpu = args.gpu)
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        model.cuda()
    # optimizer (recreated with a smaller lr inside the curriculum loop below)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)

    if trainvalidtest_id == 1:  # normal test on validation set. Note that mode="test"
        if os.path.exists(test_state_file):
            mrr, _, hits10 = test(model, args.test_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
                                  test_state_file, "test", split_mode="val")
        else:
            print('Cannot do testing because model does not exist: ', test_state_file)
            mrr = 0
            hits10 = 0
    elif trainvalidtest_id == 2:  # normal test on test set; history = train+valid snapshots
        if os.path.exists(test_state_file):
            mrr, perf_per_rel, hits10 = test(model, args.test_history_len, train_list+valid_list, test_list, num_rels, num_nodes, use_cuda,
                                             test_state_file, "test", split_mode="test")
        else:
            print('Cannot do testing because model does not exist: ', test_state_file)
            mrr = 0
            hits10 = 0
    elif trainvalidtest_id == -1:  # pretraining with the shortest history length
        print("-------------start pre training model with history length {}----------\n".format(args.start_history_len))
        model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'
        model_state_file = save_model_dir + model_name
        print("Sanity Check: stat name : {}".format(model_state_file))
        print("Sanity Check: Is cuda available ? {}".format(torch.cuda.is_available()))
        best_mrr = 0
        best_epoch = 0
        best_hits10= 0
        ## training loop
        for epoch in range(args.n_epochs):
            model.train()
            losses = []
            # shuffle the order in which training snapshots are visited
            idx = [_ for _ in range(len(train_list))]
            random.shuffle(idx)
            for train_sample_num in idx:
                # need at least 2 preceding snapshots to form a history window
                if train_sample_num == 0 or train_sample_num == 1: continue
                if train_sample_num - args.start_history_len<0:
                    # not enough history yet: use everything up to this snapshot
                    input_list = train_list[0: train_sample_num]
                    output = train_list[1:train_sample_num+1]
                else:
                    input_list = train_list[train_sample_num-args.start_history_len: train_sample_num]
                    output = train_list[train_sample_num-args.start_history_len+1:train_sample_num+1]
                # generate history graph (one DGL graph per snapshot)
                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]
                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]
                # only the most recent snapshot (output[-1]) is the prediction target
                loss= model.get_loss(history_glist, output[-1], None, use_cuda)
                losses.append(loss.item())
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients
                optimizer.step()
                optimizer.zero_grad()
            print("His {:04d}, Epoch {:04d} | Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} "
                  .format(args.start_history_len, epoch, np.mean(losses), best_mrr, model_name))
            #! checking GPU usage
            free_mem, total_mem = torch.cuda.mem_get_info()
            print ("--------------GPU memory usage-----------")
            print ("there are ", free_mem, " free memory")
            print ("there are ", total_mem, " total available memory")
            print ("there are ", total_mem - free_mem, " used memory")
            print ("--------------GPU memory usage-----------")
            # validation; checkpoint only kept when the valid MRR improves
            if epoch % args.evaluate_every == 0:
                mrr, _, hits10 = test(model, args.start_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
                                      model_state_file, mode="train", split_mode= "val")
                if mrr< best_mrr:
                    # early stopping after 5 epochs without improvement
                    if epoch >= args.n_epochs or epoch - best_epoch > 5:
                        break
                else:
                    best_mrr = mrr
                    best_epoch = epoch
                    best_hits10 = hits10
                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
    elif trainvalidtest_id == 0:  # curriculum training: grow history length step by step
        model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'
        init_state_file = save_model_dir + model_name
        init_checkpoint = torch.load(init_state_file, map_location=torch.device(args.gpu))
        # use best stat checkpoint:
        print("Load Previous Model name: {}. Using best epoch : {}".format(init_state_file, init_checkpoint['epoch']))
        print("\n"+"-"*10+"Load model with history length {}".format(args.start_history_len)+"-"*10+"\n")
        model.load_state_dict(init_checkpoint['state_dict'])
        test_history_len = args.start_history_len
        # baseline valid MRR of the pretrained (shortest-history) model
        mrr, _, hits10 = test(model,
                              args.start_history_len,
                              train_list,
                              valid_list,
                              num_rels,
                              num_nodes,
                              use_cuda,
                              init_state_file,
                              mode="test", split_mode= "val")
        best_mrr_list = [mrr.item()]
        best_hits_list = [hits10.item()]
        # start knowledge distillation: fine-tune with progressively longer histories
        ks_idx = 0
        for history_len in range(args.start_history_len+1, args.train_history_len+1, 1):
            # current model
            print("best mrr list :", best_mrr_list)
            # lr = 0.1*args.lr - 0.002*args.lr*ks_idx
            # fine-tuning uses a 10x smaller learning rate than pretraining
            optimizer = torch.optim.Adam(model.parameters(), lr=0.1*args.lr, weight_decay=0.00001)
            model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len}'
            model_state_file = save_model_dir + model_name
            print("Sanity Check: stat name : {}".format(model_state_file))
            # load model with the least history length (previous curriculum stage)
            prev_model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len-1}'
            prev_state_file = save_model_dir + prev_model_name
            checkpoint = torch.load(prev_state_file, map_location=torch.device(args.gpu))
            model.load_state_dict(checkpoint['state_dict'])
            print("\n"+"-"*10+"start knowledge distillation for history length at "+ str(history_len)+"-"*10+"\n")
            best_mrr = 0
            best_hits10 = 0
            best_epoch = 0
            for epoch in range(args.n_epochs):
                model.train()
                losses = []
                idx = [_ for _ in range(len(train_list))]
                random.shuffle(idx)
                for train_sample_num in idx:
                    if train_sample_num == 0 or train_sample_num == 1: continue
                    if train_sample_num - history_len<0:
                        input_list = train_list[0: train_sample_num]
                        output = train_list[1:train_sample_num+1]
                    else:
                        input_list = train_list[train_sample_num - history_len: train_sample_num]
                        output = train_list[train_sample_num-history_len+1:train_sample_num+1]
                    # generate history graph
                    history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]
                    output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]
                    loss= model.get_loss(history_glist, output[-1], None, use_cuda)
                    # print(loss)
                    losses.append(loss.item())
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients
                    optimizer.step()
                    optimizer.zero_grad()
                print("His {:04d}, Epoch {:04d} | Ave Loss: {:.4f} |Best MRR {:.4f} | Model {} "
                      .format(history_len, epoch, np.mean(losses), best_mrr, model_name))
                #! checking GPU usage
                free_mem, total_mem = torch.cuda.mem_get_info()
                print ("--------------GPU memory usage-----------")
                print ("there are ", free_mem, " free memory")
                print ("there are ", total_mem, " total available memory")
                print ("there are ", total_mem - free_mem, " used memory")
                print ("--------------GPU memory usage-----------")
                # validation; shorter patience (2 epochs) during fine-tuning
                if epoch % args.evaluate_every == 0:
                    mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
                                          model_state_file, mode="train", split_mode= "val")
                    if mrr< best_mrr:
                        if epoch >= args.n_epochs or epoch-best_epoch>2:
                            break
                    else:
                        best_mrr = mrr
                        best_epoch = epoch
                        best_hits10 = hits10
                        torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
            # evaluate the best checkpoint of this stage on the validation set
            mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
                                  model_state_file, mode="test", split_mode= "val")
            ks_idx += 1
            if mrr.item() < max(best_mrr_list):
                # longer history no longer helps: keep the previous length
                test_history_len = history_len-1
                print("early stopping, best history length: ", test_history_len)
                break
            else:
                best_mrr_list.append(mrr.item())
                best_hits_list.append(hits10.item())
    return mrr, test_history_len, perf_per_rel, hits10
# ==================
# ==================
# ==================

start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args_cen()
args.dataset = 'tkgl-icews'
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
DATA = args.dataset
MODEL_NAME = 'CEN'
print("logging mrrs per relation: ", args.log_per_rel)
print("do test and valid? do only test no validation?: ", args.validtest, args.test_only)

# load data
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
num_rels = dataset.num_rels
num_nodes = dataset.num_nodes
subjects = dataset.full_data["sources"]
objects = dataset.full_data["destinations"]
relations = dataset.edge_type
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA)  # remapped to stepsize 1

all_quads = np.stack((subjects, relations, objects, timestamps), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]

# original (unremapped) timestamps are needed for querying the negative samples
val_timestamps_orig = sorted(set(timestamps_orig[dataset.val_mask]))
test_timestamps_orig = sorted(set(timestamps_orig[dataset.test_mask]))

# one entry per timestep: list of snapshots
train_list = split_by_time(train_data)
valid_list = split_by_time(val_data)
test_list = split_by_time(test_data)

# evaluation metric
METRIC = dataset.eval_metric
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
# load the negative samples for both evaluation splits
dataset.load_val_ns()
dataset.load_test_ns()

if args.grid_search:
    print("TODO: implement hyperparameter grid search")
# single run
else:
    start_train = timeit.default_timer()
    if args.validtest:
        print('directly start testing')
        if args.test_history_len_2 != args.test_history_len:
            args.test_history_len = args.test_history_len_2  # hyperparameter value as given in original paper
    else:
        print('running pretrain and train')
        # pretrain with the shortest history length
        mrr, _, _, hits10 = run_experiment(args, trainvalidtest_id=-1)
        # curriculum training; overwrites test_history_len with the best
        # history length (by valid mrr)
        mrr, args.test_history_len, _, hits10 = run_experiment(args, trainvalidtest_id=0)
    if not args.test_only:
        print("running test (on val and test dataset) with test_history_len of: ", args.test_history_len)
        # test on val set
        val_mrr, _, _, val_hits10 = run_experiment(args, trainvalidtest_id=1)
    else:
        val_mrr = 0
        val_hits10 = 0
    # test on test set
    start_test = timeit.default_timer()
    test_mrr, _, perf_per_rel, test_hits10 = run_experiment(args, trainvalidtest_id=2)
    test_time = timeit.default_timer() - start_test
    all_time = timeit.default_timer() - start_train
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
    # BUG FIX: original printed f"\Train ..." — "\T" is not an escape, so a
    # literal backslash was printed; "\t" (tab) was clearly intended.
    print(f"\tTrain and Test: Elapsed Time (s): {all_time: .4f}")
    print(f"\tTest: {METRIC}: {test_mrr: .4f}")
    print(f"\tValid: {METRIC}: {val_mrr: .4f}")

    # saving the results...
    results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
    if not osp.exists(results_path):
        print('INFO: Create directory {}'.format(results_path))
    # single idempotent mkdir (the original called os.mkdir AND Path.mkdir)
    Path(results_path).mkdir(parents=True, exist_ok=True)
    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': args.run_nr,
                  'seed': SEED,
                  'test_history_len': args.test_history_len,
                  f'val {METRIC}': float(val_mrr),
                  f'test {METRIC}': float(test_mrr),
                  'test_time': test_time,
                  'tot_train_val_time': all_time,
                  'test_hits10': float(test_hits10)
                  },
                 results_filename)
    if args.log_per_rel:
        results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'
        with open(results_filename, 'w') as json_file:
            json.dump(perf_per_rel, json_file)

sys.exit()
================================================
FILE: examples/linkproppred/tkgl-icews/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
tgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(tgb_modules_path)
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction with EdgeBank.

    Evaluation happens as 'one vs. many': each positive edge is ranked
    against its precomputed negative edges.

    Parameters:
        data: dataset dict with 'sources', 'destinations', 'timestamps',
              'edge_type' numpy arrays
        test_mask: mask selecting the evaluation edges
        neg_sampler: gives the negative edges corresponding to each positive edge
        split_mode: 'val' or 'test', to load the matching negatives

    Returns:
        perf_metrics: mean of the primary metric over all queries
        perf_hits: mean hits@10 over all queries
    """
    # PERF FIX: hoist the mask indexing out of the batch loop. The original
    # re-evaluated `data[...][test_mask]` for every batch, copying the full
    # column each time — quadratic in the number of edges overall.
    pos_src_all = data['sources'][test_mask]
    pos_dst_all = data['destinations'][test_mask]
    pos_t_all = data['timestamps'][test_mask]
    pos_edge_all = data['edge_type'][test_mask]
    num_queries = len(pos_src_all)
    num_batches = math.ceil(num_queries / BATCH_SIZE)
    perf_list = []
    hits_list = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, num_queries)
        pos_src = pos_src_all[start_idx: end_idx]
        pos_dst = pos_dst_all[start_idx: end_idx]
        pos_t = pos_t_all[start_idx: end_idx]
        pos_edge = pos_edge_all[start_idx: end_idx]
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # query = positive destination first, then all negatives,
            # all paired with the same source
            query_src = np.array([int(pos_src[idx]) for _ in range(len(neg_batch) + 1)])
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])
            y_pred = edgebank.predict_link(query_src, query_dst)
            # compute MRR (y_pred[0] is the positive edge's score)
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            results = evaluator.eval(input_dict)
            perf_list.append(results[metric])
            hits_list.append(results['hits@10'])
        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)
    perf_metrics = float(np.mean(perf_list))
    perf_hits = float(np.mean(hits_list))
    return perf_metrics, perf_hits
def get_args():
    """Parse command-line arguments for the EdgeBank baseline.

    Returns:
        args: the parsed argparse Namespace
        sys.argv: the raw command line, for logging
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tkgl-icews')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit:
        # BUG FIX: was a bare `except:`, which also swallowed KeyboardInterrupt
        # and genuine errors. argparse signals bad arguments by raising
        # SystemExit (after printing usage); show the full help and exit cleanly.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================

start_overall = timeit.default_timer()

# set hyperparameters from the command line
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = args.data
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank: only the training split seeds the memory;
# val/test edges are added incrementally inside test() via update_memory
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()
# evaluating on the validation split (also warms up the edgebank memory)
start_val = timeit.default_timer()
perf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()
print(f"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< ")
print(f"\tval: {metric}: {perf_metric_val: .4f}")
val_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {val_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()
# testing ...
start_test = timeit.default_timer()
perf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()
print(f"INFO: Test: Evaluation Setting: >>> <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

# persist metrics as json
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'val_mrr': perf_metric_val,
              'test_time': test_time,
              'tot_train_val_time': test_time+val_time,
              'hits10': perf_hits_test},
             results_filename)
================================================
FILE: examples/linkproppred/tkgl-icews/recurrencybaseline.py
================================================
""" from paper: "History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting"
Julia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)
@inproceedings{gastinger2024baselines,
title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},
author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},
booktitle={33nd International Joint Conference on Artificial Intelligence (IJCAI 2024)},
year={2024},
organization={International Joint Conferences on Artificial Intelligence Organization}
}
"""
## imports
import timeit
import argparse
import numpy as np
from copy import copy
from pathlib import Path
import ray
import sys
import os
import os.path as osp
import json
#internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import set_random_seed, save_results
from modules.tkg_utils import create_basis_dict, group_by, reformat_ts
def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
            perf_list_all, hits_list_all, window, neg_sampler, split_mode):
    """ create predictions for each relation on test or valid set and compute mrr

    :param num_processes: >1 distributes the queries over ray workers
    :param data_c_rel: queries (quadruples) for the current relation
    :param all_data_c_rel: full history for the current relation
    :param alpha, lmbda_psi: baseline hyperparameters
    :param perf_list_all, hits_list_all: accumulators, extended in place
    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    first_ts = data_c_rel[0][3]
    ## use this if you wanna use ray:
    num_queries = len(data_c_rel) // num_processes
    if num_queries < num_processes:  # if we do not have enough queries for all the processes
        num_processes_tmp = 1
        num_queries = len(data_c_rel)
    else:
        num_processes_tmp = num_processes
    if num_processes > 1:
        object_references = []
        for i in range(num_processes_tmp):
            num_test_queries = len(data_c_rel) - (i + 1) * num_queries
            if num_test_queries >= num_queries:
                test_queries_idx = [i * num_queries, (i + 1) * num_queries]
            else:
                # BUG FIX: the last chunk must end at len(data_c_rel). The
                # original used the module-global `test_data`, which slices the
                # wrong range whenever this is called on the validation split
                # (or on any per-relation subset shorter than the test set).
                test_queries_idx = [i * num_queries, len(data_c_rel)]
            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]
            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window,
                                                basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
            object_references.append(ob)
        output = ray.get(object_references)
        # updates the scores and logging dict for each process
        for proc_loop in range(num_processes_tmp):
            perf_list_all.extend(output[proc_loop][0])
            hits_list_all.extend(output[proc_loop][1])
    else:
        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel,
                                                window, basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
    return perf_list_all, hits_list_all
## test
def test(best_config, all_relations, test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
    """Loop over all relations, apply the recurrency baseline to each one's
    queries on the test (or validation) set, and collect per-query scores.

    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    perf_list_all, hits_list_all = [], []
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'
    ## loop through relations and apply baselines
    for rel in all_relations:
        tic = timeit.default_timer()
        if rel in test_data_prel:
            rel_cfg = best_config[str(rel)]
            lmbda_psi = rel_cfg['lmbda_psi'][0]
            alpha = rel_cfg['alpha'][0]
            # queries and full history for this relation
            test_data_c_rel = test_data_prel[rel]
            timesteps_test = sorted(set(test_data_c_rel[:, 3]))
            all_data_c_rel = all_data_prel[rel]
            perf_list_rel, hits_list_rel = predict(
                num_processes, test_data_c_rel, all_data_c_rel, alpha,
                lmbda_psi, [], [], window, neg_sampler, split_mode)
            perf_list_all.extend(perf_list_rel)
            hits_list_all.extend(hits_list_rel)
        else:
            # no queries for this relation: log an empty score list
            perf_list_rel = []
        elapsed = round(timeit.default_timer() - tic, 6)
        print("Relation {} finished in {} seconds.".format(rel, elapsed))
        # append this relation's raw scores to the per-relation csv
        with open(csv_file, 'a') as f:
            f.write("{},{}\n".format(rel, perf_list_rel))
    return perf_list_all, hits_list_all
def read_dict_compute_mrr(split_mode='test'):
    """ read the results per relation from a precreated file and compute mrr

    Each csv line has the form "<rel>,[mrr1, mrr2, ...]" as written by test().
    :return mrr_per_rel: dictionary of mrrs for each relation
    :return all_mrrs: list of mrrs for all relations
    """
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'
    # Initialize an empty dictionary to store the data
    results_per_rel_dict = {}
    mrr_per_rel = {}
    all_mrrs = []
    # Open the file for reading
    with open(csv_file, 'r') as f:
        for line in f:
            parts = line.strip().split(',')
            if not parts[0]:
                continue  # tolerate blank lines
            key = int(parts[0])
            # Extract the values, removing the square brackets.
            # BUG FIX: a relation with no queries is logged as "<rel>,[]",
            # which stripped down to '' and crashed float(''); drop empties.
            values = [float(value.strip('[]')) for value in parts[1:] if value.strip('[]')]
            if key in results_per_rel_dict:
                print(f"Key {key} already exists in the dictionary!!! might have duplicate entries in results csv")
            results_per_rel_dict[key] = values
            all_mrrs.extend(values)
            if values:  # avoid np.mean([]) -> nan for empty relations
                mrr_per_rel[key] = np.mean(values)
    if len(results_per_rel_dict) != num_rels:
        print("we do not have entries for each rel in the results csv file. only num enties: ", len(list(results_per_rel_dict.keys())))
    print("Split mode: "+split_mode+" Mean MRR: ", np.mean(all_mrrs))
    print("mrr per relation: ", mrr_per_rel)
    # BUG FIX: the docstring promised these returns but the function
    # returned None; callers ignoring the result are unaffected.
    return mrr_per_rel, all_mrrs
## train
def train(params_dict, rels, val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
    """ optional, find best values for lambda and alpha by looping through all relations and testing on a fixed set of params
    based on validation mrr

    Selection is greedy and two-stage per relation: first pick lmbda_psi
    (with alpha fixed to 1), then pick alpha using the chosen lmbda_psi.
    :return best_config: dictionary of best params for each relation
    """
    best_config= {}
    best_mrr = 0
    for rel in rels:  # loop through relations. for each relation, apply rules with selected params, compute valid mrr
        start = timeit.default_timer()
        rel_key = int(rel)
        # defaults; kept as-is when the relation never occurs in the validation set
        best_config[str(rel_key)] = {}
        best_config[str(rel_key)]['not_trained'] = 'True'
        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi,0]  # default
        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))
        best_config[str(rel_key)]['alpha'] = [default_alpha,0]  # default
        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))
        if rel in val_data_prel.keys():
            # valid data for this relation
            val_data_c_rel = val_data_prel[rel]
            timesteps_valid = list(set(val_data_c_rel[:,3]))
            timesteps_valid.sort()
            trainval_data_c_rel = trainval_data_prel[rel]
            ###### 1) select lambda ###############
            lmbdas_psi = params_dict['lmbda_psi']
            alpha = 1  # alpha fixed while sweeping lambda
            best_lmbda_psi = 0.1
            best_mrr_psi = 0
            lmbda_mrrs = []
            # bookkeeping: how often the relation appears in each split
            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))
            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))
            best_config[str(rel_key)]['not_trained'] = 'False'
            for lmbda_psi in lmbdas_psi:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                                                   trainval_data_c_rel, alpha, lmbda_psi, perf_list_r, hits_list_r,
                                                   window, neg_sampler, split_mode='val')
                # compute mrr
                mrr = np.mean(perf_list_r)
                # is new mrr better than previous best? if yes: store lmbda
                if mrr > best_mrr_psi:
                    best_mrr_psi = float(mrr)
                    best_lmbda_psi = lmbda_psi
                lmbda_mrrs.append(float(mrr))
            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]
            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs
            best_mrr = best_mrr_psi
            ##### 2) select alpha ###############
            best_config[str(rel_key)]['not_trained'] = 'False'
            alphas = params_dict['alpha']
            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0]  # use the best lmbda psi
            alpha_mrrs = []
            best_mrr_alpha = 0
            best_alpha=0.99
            for alpha in alphas:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                                                   trainval_data_c_rel, alpha, lmbda_psi, perf_list_r, hits_list_r,
                                                   window, neg_sampler, split_mode='val')
                # compute mrr
                mrr_alpha = np.mean(perf_list_r)
                # is new mrr better than previous best? if yes: store alpha
                if mrr_alpha > best_mrr_alpha:
                    best_mrr_alpha = float(mrr_alpha)
                    best_alpha = alpha
                    best_mrr = best_mrr_alpha
                alpha_mrrs.append(float(mrr_alpha))
            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]
            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs
        end = timeit.default_timer()
        total_time = round(end - start, 6)
        print("Relation {} finished in {} seconds.".format(rel, total_time))
    return best_config
## args
def get_args():
    """parse all arguments for the script"""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", "-d", default="tkgl-icews", type=str)
    # e.g. 200 to only consider the most recent 200 timesteps; -2 for multistep
    parser.add_argument("--window", "-w", default=0, type=int)
    parser.add_argument("--num_processes", "-p", default=1, type=int)
    # fixed lambda, used when train_flag == 'False'
    parser.add_argument("--lmbda", "-l", default=0.1, type=float)
    # fixed alpha, used when train_flag == 'False'
    parser.add_argument("--alpha", "-alpha", default=0.99, type=float)
    # 'True' -> run selection of lambda and alpha on the validation set
    parser.add_argument("--train_flag", "-tr", default='False')
    # 'True' (with train_flag) -> load a previously saved best_config
    parser.add_argument("--load_flag", "-lo", default='False')
    # 'True' -> persist the selected lambda/alpha to a config file
    parser.add_argument("--save_config", "-c", default='True')
    parser.add_argument('--seed', type=int, help='Random seed', default=1)  # not needed
    return vars(parser.parse_args())
start_o = timeit.default_timer()

parsed = get_args()
# ray is only initialised when predictions are parallelised across processes
if parsed['num_processes']>1:
    ray.init(num_cpus=parsed["num_processes"], num_gpus=0)

MODEL_NAME = 'RecurrencyBaseline'
SEED = parsed['seed']  # set the random seed for consistency
set_random_seed(SEED)

# per-relation result csvs are written next to the saved models
perrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'
if not osp.exists(perrel_results_path):
    os.mkdir(perrel_results_path)
    print('INFO: Create directory {}'.format(perrel_results_path))
Path(perrel_results_path).mkdir(parents=True, exist_ok=True)

## load dataset and prepare it accordingly
name = parsed["dataset"]
dataset = LinkPropPredDataset(name=name, root="datasets", preprocess=True)
DATA = name

relations = dataset.edge_type
num_rels = dataset.num_rels
rels = np.arange(0,num_rels)
subjects = dataset.full_data["sources"]
objects= dataset.full_data["destinations"]
num_nodes = dataset.num_nodes
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA)  # remapped to stepsize 1

print("split train valid test data")
# quintuples: (subject, relation, object, remapped ts, original ts)
all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]

metric = dataset.eval_metric
evaluator = Evaluator(name=name)
neg_sampler = dataset.negative_sampler

train_val_data = np.concatenate([train_data, val_data])
all_data = np.concatenate([train_data, val_data, test_data])

# create dicts with key: relation id, values: triples for that relation id
print("grouping data by relation")
test_data_prel = group_by(test_data, 1)
all_data_prel = group_by(all_data, 1)
val_data_prel = group_by(val_data, 1)
trainval_data_prel = group_by(train_val_data, 1)

# load the ns samples
# if parsed['train_flag']:
print("loading negative samples")
dataset.load_val_ns()
dataset.load_test_ns()

# parameter options: the grids swept during (optional) training
if parsed['train_flag'] == 'True':
    params_dict = {}
    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001]
    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]
    # defaults used for relations absent from the validation set
    default_lmbda_psi = params_dict['lmbda_psi'][-1]
    default_alpha = params_dict['alpha'][-2]

## load rules
print("creating rules")
basis_dict = create_basis_dict(train_val_data)
print("done with creating rules")

## init
# rb_predictor = RecurrencyBaselinePredictor(rels)
## train to find best lambda and alpha
start_train = timeit.default_timer()
if parsed['train_flag'] == 'True':
    if parsed['load_flag'] == 'True':
        # reuse a previously selected configuration
        with open('best_config.json', 'r') as infile:
            best_config = json.load(infile)
    else:
        print('start training')
        best_config = train(params_dict, rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'],
                            parsed['window'])
        if parsed['save_config'] == 'True':
            import json  # NOTE: json is already imported at the top of the file; redundant but harmless
            with open('best_config.json', 'w') as outfile:
                json.dump(best_config, outfile)
else:  # use preset lmbda and alpha; same for all relations
    best_config = {}
    for rel in rels:
        best_config[str(rel)] = {}
        best_config[str(rel)]['lmbda_psi'] = [parsed['lmbda']]
        best_config[str(rel)]['alpha'] = [parsed['alpha']]
end_train = timeit.default_timer()

# compute validation mrr
print("Computing validation MRR")
perf_list_all_val, hits_list_all_val = test(best_config, rels, val_data_prel,
                                            trainval_data_prel, neg_sampler, parsed['num_processes'],
                                            parsed['window'], split_mode='val')
val_mrr = float(np.mean(perf_list_all_val))

# compute test mrr
print("Computing test MRR")
start_test = timeit.default_timer()
perf_list_all, hits_list_all = test(best_config, rels, test_data_prel,
                                    all_data_prel, neg_sampler, parsed['num_processes'],
                                    parsed['window'])
end_o = timeit.default_timer()

train_time_o = round(end_train- start_train, 6)
test_time_o = round(end_o- start_test, 6)
total_time_o = round(end_o- start_o, 6)
print("Running Training to find best configs finished in {} seconds.".format(train_time_o))
print("Running testing with best configs finished in {} seconds.".format(test_time_o))
print("Running all steps finished in {} seconds.".format(total_time_o))
print(f"The test MRR is {np.mean(perf_list_all)}")
print(f"The valid MRR is {val_mrr}")
print(f"The Hits@10 is {np.mean(hits_list_all)}")
print(f"We have {len(perf_list_all)} predictions")
print(f"The test set has len {len(test_data)} ")

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'
metric = dataset.eval_metric
save_results({'model': MODEL_NAME,
'train_flag': parsed['train_flag'],
'data': DATA,
'run': 1,
'seed': SEED,
metric: float(np.mean(perf_list_all)),
'val_mrr': val_mrr,
'hits10': float(np.mean(hits_list_all)),
'test_time': test_time_o,
'tot_train_val_time': total_time_o
},
results_filename)
if parsed['num_processes']>1:
ray.shutdown()
================================================
FILE: examples/linkproppred/tkgl-icews/regcn.py
================================================
"""
Temporal Knowledge Graph Reasoning Based on Evolutional Representation Learning
Reference:
- https://github.com/Lee-zix/RE-GCN
Zixuan Li, Xiaolong Jin, Wei Li, Saiping Guan, Jiafeng Guo, Huawei Shen, Yuanzhuo Wang and Xueqi Cheng. Temporal
Knowledge Graph Reasoning Based on Evolutional Representation Learning. SIGIR 2021.
"""
import sys
import timeit
import os
import sys
import os.path as osp
from pathlib import Path
import numpy as np
import torch
import random
from tqdm import tqdm
# internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.rrgcn import RecurrentRGCNREGCN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from modules.tkg_utils import get_args_regcn, reformat_ts
from modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
import json
def test(model, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, static_graph, mode, split_mode):
    """
    Test the model on either test or validation set.

    Relies on module globals: ``args``, ``neg_sampler``, ``evaluator``, ``METRIC``,
    ``test_timestamps_orig`` / ``val_timestamps_orig`` (original timestamps per
    evaluation snapshot, used to query the negative sampler).

    :param model: model used to test
    :param history_list: all input history snapshot list, not including the output label train/valid list
    :param test_list: test triple snapshot list (one entry per timestep)
    :param num_rels: number of relations
    :param num_nodes: number of nodes
    :param use_cuda: whether to run on GPU args.gpu
    :param model_name: checkpoint path (loaded when mode == "test")
    :param static_graph: static-graph input forwarded to model.predict (None for this dataset)
    :param mode: "test" loads model weights from the checkpoint file; otherwise the model is used as-is
    :param split_mode: 'test' or 'val' to state which negative samples to load
    :return: (mrr, perf_per_rel, hits10)
    """
    print("Testing for mode: ", split_mode)
    # pick the original timestamps matching the split; one timestamp per snapshot
    if split_mode == 'test':
        timesteps_to_eval = test_timestamps_orig
    else:
        timesteps_to_eval = val_timestamps_orig
    idx = 0
    if mode == "test":
        # test mode: load parameters from checkpoint file
        if use_cuda:
            checkpoint = torch.load(model_name, map_location=torch.device(args.gpu))
        else:
            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))
        # use best stat checkpoint:
        print("Load Model name: {}. Using best epoch : {}".format(model_name, checkpoint['epoch']))
        print("\n" + "-" * 10 + "start testing" + "-" * 10 + "\n")
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    # seed the rolling history window with the most recent training snapshots
    input_list = [snap for snap in history_list[-args.test_history_len:]]
    perf_list_all = []
    hits_list_all = []
    perf_per_rel = {}
    for rel in range(num_rels):
        perf_per_rel[rel] = []
    for time_idx, test_snap in enumerate(tqdm(test_list)):
        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]
        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)
        # all triples in this snapshot share the same (original) timestamp
        timesteps_batch = timesteps_to_eval[time_idx] * np.ones(len(test_triples_input[:, 0]))
        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:, 0], test_triples_input[:, 2],
                                                    timesteps_batch, edge_type=test_triples_input[:, 1], split_mode=split_mode)
        pos_samples_batch = test_triples_input[:, 2]
        _, perf_list, hits_list = model.predict(history_glist, num_rels, static_graph, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch,
                                                evaluator, METRIC)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
        if split_mode == "test":
            if args.log_per_rel:
                # accumulate per-relation scores for the optional per-relation log
                for score, rel in zip(perf_list, test_triples_input[:, 1].tolist()):
                    perf_per_rel[rel].append(score)
        # reconstruct history graph list: slide the window forward by one snapshot
        input_list.pop(0)
        input_list.append(test_snap)
        idx += 1
    if split_mode == "test":
        if args.log_per_rel:
            # reduce per-relation score lists to their mean; drop unseen relations
            for rel in range(num_rels):
                if len(perf_per_rel[rel]) > 0:
                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))
                else:
                    perf_per_rel.pop(rel)
    mrr = np.mean(perf_list_all)
    hits10 = np.mean(hits_list_all)
    return mrr, perf_per_rel, hits10
def run_experiment(args, n_hidden=None, n_layers=None, dropout=None, n_bases=None):
    """
    Run the experiment with the given configuration.

    Uses module globals: ``train_list``, ``valid_list``, ``test_list``,
    ``num_rels``, ``num_nodes``, ``MODEL_NAME``, ``DATA``, ``SEED``.

    :param args: arguments
    :param n_hidden: hidden dimension (overrides args.n_hidden when given)
    :param n_layers: number of layers (overrides args.n_layers when given)
    :param dropout: dropout rate (overrides args.dropout when given)
    :param n_bases: number of bases (overrides args.n_bases when given)
    :return: (mrr, perf_per_rel, hits10)
    """
    # optional hyperparameter overrides (used by a grid search)
    if n_hidden:
        args.n_hidden = n_hidden
    if n_layers:
        args.n_layers = n_layers
    if dropout:
        args.dropout = dropout
    if n_bases:
        args.n_bases = n_bases
    mrr = 0
    hits10 = 0
    # set save model path
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    if not osp.exists(save_model_dir):
        os.mkdir(save_model_dir)
        print('INFO: Create directory {}'.format(save_model_dir))
    model_name = f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}'
    model_state_file = save_model_dir + model_name
    perf_per_rel = {}
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    # this dataset has no static background graph
    num_static_rels, num_words, static_triples, static_graph = 0, 0, [], None
    # create model
    model = RecurrentRGCNREGCN(args.decoder,
                               args.encoder,
                               num_nodes,
                               int(num_rels / 2),
                               num_static_rels,  # DIFFERENT
                               num_words,  # DIFFERENT
                               args.n_hidden,
                               args.opn,
                               sequence_len=args.train_history_len,
                               num_bases=args.n_bases,
                               num_basis=args.n_basis,
                               num_hidden_layers=args.n_layers,
                               dropout=args.dropout,
                               self_loop=args.self_loop,
                               skip_connect=args.skip_connect,
                               layer_norm=args.layer_norm,
                               input_dropout=args.input_dropout,
                               hidden_dropout=args.hidden_dropout,
                               feat_dropout=args.feat_dropout,
                               aggregation=args.aggregation,  # DIFFERENT
                               weight=args.weight,  # DIFFERENT
                               discount=args.discount,  # DIFFERENT
                               angle=args.angle,  # DIFFERENT
                               use_static=args.add_static_graph,  # DIFFERENT
                               entity_prediction=args.entity_prediction,
                               relation_prediction=args.relation_prediction,
                               use_cuda=use_cuda,
                               gpu=args.gpu,
                               analysis=args.run_analysis)  # DIFFERENT
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        model.cuda()
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    if args.test and os.path.exists(model_state_file):
        mrr, perf_per_rel, hits10 = test(model,
                                         train_list + valid_list,
                                         test_list,
                                         num_rels,
                                         num_nodes,
                                         use_cuda,
                                         model_state_file,
                                         static_graph,
                                         "test",
                                         "test")
        return mrr, perf_per_rel, hits10
    elif args.test and not os.path.exists(model_state_file):
        print("--------------{} not exist, Change mode to train and generate stat for testing----------------\n".format(model_state_file))
        # BUG FIX: used to `return 0, 0` (a 2-tuple) while every caller unpacks
        # three values; return a consistent 3-tuple instead.
        return 0, perf_per_rel, 0
    else:
        print("----------------------------------------start training----------------------------------------\n")
        best_mrr = 0
        best_hits = 0
        for epoch in range(args.n_epochs):
            model.train()
            losses = []
            # visit training snapshots in a random order each epoch
            idx = [_ for _ in range(len(train_list))]
            random.shuffle(idx)
            for train_sample_num in tqdm(idx):
                if train_sample_num == 0:
                    continue  # the first snapshot has no history to condition on
                output = train_list[train_sample_num:train_sample_num + 1]
                if train_sample_num - args.train_history_len < 0:
                    input_list = train_list[0:train_sample_num]
                else:
                    input_list = train_list[train_sample_num - args.train_history_len:
                                            train_sample_num]
                # generate history graph
                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]
                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]
                loss_e, loss_r, loss_static = model.get_loss(history_glist, output[0], static_graph, use_cuda)
                loss = args.task_weight * loss_e + (1 - args.task_weight) * loss_r + loss_static
                losses.append(loss.item())
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients
                optimizer.step()
                optimizer.zero_grad()
            print("Epoch {:04d} | Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} "
                  .format(epoch, np.mean(losses), best_mrr, model_name))
            #! checking GPU usage
            # BUG FIX: mem_get_info() requires CUDA; guard so CPU-only runs don't crash
            if use_cuda:
                free_mem, total_mem = torch.cuda.mem_get_info()
                print("--------------GPU memory usage-----------")
                print("there are ", free_mem, " free memory")
                print("there are ", total_mem, " total available memory")
                print("there are ", total_mem - free_mem, " used memory")
                print("--------------GPU memory usage-----------")
            # validation
            if epoch and epoch % args.evaluate_every == 0:
                mrr, perf_per_rel, hits10 = test(model, train_list,
                                                 valid_list,
                                                 num_rels,
                                                 num_nodes,
                                                 use_cuda,
                                                 model_state_file,
                                                 static_graph,
                                                 mode="train", split_mode='val')
                if mrr < best_mrr:
                    if epoch >= args.n_epochs:
                        break
                else:
                    best_mrr = mrr
                    best_hits = hits10
                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
        # BUG FIX: return the hits@10 of the best checkpoint (best_hits was
        # tracked but never returned; the last validation's hits10 was used).
        return best_mrr, perf_per_rel, best_hits
# ==================
# main: load data, train/test REGCN, save results
# ==================
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args_regcn()
args.dataset = 'tkgl-icews'
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
DATA = args.dataset
MODEL_NAME = 'REGCN'
print("logging mrrs per relation: ", args.log_per_rel)

# load data
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
num_rels = dataset.num_rels
num_nodes = dataset.num_nodes
subjects = dataset.full_data["sources"]
objects = dataset.full_data["destinations"]
relations = dataset.edge_type
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA)  # stepsize: 1
all_quads = np.stack((subjects, relations, objects, timestamps), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]
# sorted original timestamps per split; needed for querying the negative sampler
val_timestamps_orig = sorted(set(timestamps_orig[dataset.val_mask]))
test_timestamps_orig = sorted(set(timestamps_orig[dataset.test_mask]))
# one snapshot (list entry) per timestep
train_list = split_by_time(train_data)
valid_list = split_by_time(val_data)
test_list = split_by_time(test_data)

# evaluation metric
METRIC = dataset.eval_metric
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
# load the negative samples
dataset.load_val_ns()
dataset.load_test_ns()

## run training and testing
val_mrr, test_mrr = 0, 0
test_hits10 = 0
if args.grid_search:
    print("hyperparameter grid search not implemented. Exiting.")
# single run
else:
    start_train = timeit.default_timer()
    if args.test == False:  # if True: directly test a previously trained and stored model
        print('start training')
        val_mrr, perf_per_rel, val_hits10 = run_experiment(args)  # do training
    start_test = timeit.default_timer()
    args.test = True
    print('start testing')
    test_mrr, perf_per_rel, test_hits10 = run_experiment(args)  # do testing
    test_time = timeit.default_timer() - start_test
    all_time = timeit.default_timer() - start_train
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
    # BUG FIX: was "\Train ..." — "\T" is not a valid escape and printed a literal backslash
    print(f"\tTrain and Test: Elapsed Time (s): {all_time: .4f}")
    print(f"\tTest: {METRIC}: {test_mrr: .4f}")
    print(f"\tValid: {METRIC}: {val_mrr: .4f}")

    # saving the results...
    results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
    if not osp.exists(results_path):
        # single mkdir call (previously os.mkdir followed by a redundant Path.mkdir)
        Path(results_path).mkdir(parents=True, exist_ok=True)
        print('INFO: Create directory {}'.format(results_path))
    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': args.run_nr,
                  'seed': SEED,
                  f'val {METRIC}': float(val_mrr),
                  f'test {METRIC}': float(test_mrr),
                  'test_time': test_time,
                  'tot_train_val_time': all_time,
                  'test_hits10': float(test_hits10)
                  },
                 results_filename)
    if args.log_per_rel:
        results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'
        with open(results_filename, 'w') as json_file:
            json.dump(perf_per_rel, json_file)
sys.exit()
================================================
FILE: examples/linkproppred/tkgl-icews/timetraveler.py
================================================
"""
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting
Reference:
- https://github.com/JHL-HUST/TITer
Haohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021
"""
import sys
import timeit
import torch
from torch.utils.data import Dataset,DataLoader
import logging
import numpy as np
import pickle
from tqdm import tqdm
import os.path as osp
from pathlib import Path
import os
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.timetraveler_agent import Agent
from modules.timetraveler_environment import Env
from modules.timetraveler_dirichlet import Dirichlet, MLE_Dirchlet
from modules.timetraveler_episode import Episode
from modules.timetraveler_policygradient import PG
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator
from modules.timetraveler_trainertester import Trainer, Tester, getRelEntCooccurrence
from tgb.utils.utils import set_random_seed,save_results
from modules.tkg_utils import get_args_timetraveler, reformat_ts, get_model_config_timetraveler
class QuadruplesDataset(Dataset):
    """Map-style dataset over (subject, relation, object, ts, ts_orig) quadruples,
    the internal representation used by Timetraveler.
    """

    def __init__(self, examples):
        """
        examples: a list of quadruples.
        num_r: number of relations
        """
        # keep a private copy so later mutation of `examples` cannot affect us
        self.quadruples = examples.copy()

    def __len__(self):
        return len(self.quadruples)

    def __getitem__(self, item):
        # unpack the five fields explicitly so the DataLoader collates them
        # into five separate batch tensors
        quad = self.quadruples[item]
        return quad[0], quad[1], quad[2], quad[3], quad[4]
def set_logger(save_path):
    """Write logs to checkpoint and console"""
    # train and test runs log to different files; `args` is a module-level global
    log_name = 'train.log' if args.do_train else 'test.log'
    log_file = os.path.join(save_path, log_name)
    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w',
    )
    # mirror everything to the console as well
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)-8s %(message)s'))
    logging.getLogger('').addHandler(console_handler)
def preprocess_data(args, config, timestamps, save_path, all_quads):
    """
    Preprocess the data and save the state-action space (pickle dump)
    """
    env = Env(all_quads, config)
    state_actions_space = {}
    with tqdm(total=len(all_quads)) as bar:
        for head, rel, tail, t, _ in all_quads:
            # cache the complete forward/backward action space for both
            # endpoints of the edge at this timestamp, once per (entity, t)
            for entity in (head, tail):
                if (entity, t, True) not in state_actions_space:
                    state_actions_space[(entity, t, True)] = env.get_state_actions_space_complete(entity, t, True, args.store_actions_num)
                    state_actions_space[(entity, t, False)] = env.get_state_actions_space_complete(entity, t, False, args.store_actions_num)
            bar.update(1)
    with open(os.path.join(save_path, args.state_actions_path), 'wb') as out_file:
        pickle.dump(state_actions_space, out_file)
def log_metrics(mode, step, metrics):
    """Print the evaluation logs"""
    for metric, value in metrics.items():
        logging.info('%s %s at epoch %d: %f' % (mode, metric, step, value))
def main(args):
    """
    Main function to train and test the TimeTraveler model.

    Side effects: writes logs and checkpoints under ./saved_models, a results
    JSON under ./saved_results, and pickles the precomputed state-action space
    and (optionally) the Dirichlet alphas for reward shaping.
    """
    start_overall = timeit.default_timer()
    ####################### Set Logger #################################
    save_path = f'{os.path.dirname(os.path.abspath(__file__))}/saved_models/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # normalize args.cuda against actual device availability
    if args.cuda and torch.cuda.is_available():
        args.cuda = True
    else:
        args.cuda = False
    set_logger(save_path)
    ####################### Create DataLoader #################################
    # set hyperparameters
    args.dataset = 'tkgl-icews'
    SEED = args.seed  # set the random seed for consistency
    set_random_seed(SEED)
    DATA = args.dataset
    MODEL_NAME = 'TIMETRAVELER'
    # load data
    dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
    num_rels = dataset.num_rels
    num_nodes = dataset.num_nodes
    subjects = dataset.full_data["sources"]
    objects = dataset.full_data["destinations"]
    relations = dataset.edge_type
    timestamps_orig = dataset.full_data["timestamps"]
    timestamps = reformat_ts(timestamps_orig, DATA)  # stepsize: 1
    # quintuples: (subject, relation, object, remapped_ts, original_ts)
    all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
    train_data = all_quads[dataset.train_mask]
    train_entities = np.unique(np.concatenate((train_data[:, 0], train_data[:, 2])))
    RelEntCooccurrence = getRelEntCooccurrence(train_data, num_rels)
    train_data = QuadruplesDataset(train_data)
    val_data = QuadruplesDataset(all_quads[dataset.val_mask])
    test_data = QuadruplesDataset(all_quads[dataset.test_mask])
    METRIC = dataset.eval_metric
    evaluator = Evaluator(name=DATA)
    neg_sampler = dataset.negative_sampler
    # load the negative samples
    dataset.load_val_ns()
    dataset.load_test_ns()
    train_dataloader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )
    valid_dataloader = DataLoader(
        val_data,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )
    test_dataloader = DataLoader(
        test_data,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )
    ###################### Create the agent and the environment ###########################
    config = get_model_config_timetraveler(args, num_nodes, num_rels)
    logging.info(config)
    logging.info(args)
    # create the agent
    agent = Agent(config)
    # create the environment
    state_actions_path = os.path.join(save_path, args.state_actions_path)
    ###################### preprocessing ###########################
    if not os.path.exists(state_actions_path):
        if args.preprocess:
            print("preprocessing data...")
            preprocess_data(args, config, timestamps, save_path, list(all_quads))
            state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))
        else:
            state_action_space = None
    else:
        print("load preprocessed data...")
        state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))
    env = Env(list(all_quads), config, state_action_space)
    # Create episode controller
    episode = Episode(env, agent, config)
    if args.cuda:
        episode = episode.cuda()
    pg = PG(config)  # Policy Gradient
    optimizer = torch.optim.Adam(episode.parameters(), lr=args.lr, weight_decay=0.00001)
    ###################### Reward Shaping: MLE DIRICHLET alphas ###########################
    if args.reward_shaping:
        try:
            print("load alphas from pickle file")
            alphas = pickle.load(open(os.path.join(save_path, args.alphas_pkl), 'rb'))
        # NOTE(review): bare except — also swallows errors other than a missing pickle
        except:
            print('running MLE dirichlet now')
            mle_d = MLE_Dirchlet(all_quads, num_rels, args.k, args.time_span,
                                 args.tol, args.method, args.maxiter)
            pickle.dump(mle_d.alphas, open(os.path.join(save_path, args.alphas_pkl), 'wb'))
            print('dumped alphas')
            alphas = mle_d.alphas
        distributions = Dirichlet(alphas, args.k)
    else:
        distributions = None
    ###################### Training and Testing ###########################
    trainer = Trainer(episode, pg, optimizer, args, distributions)
    tester = Tester(episode, args, train_entities, RelEntCooccurrence, dataset.metric)
    test_metrics = {}
    val_metrics = {}
    test_metrics[METRIC] = None
    val_metrics[METRIC] = None
    if args.do_train:
        start_train = timeit.default_timer()
        logging.info('Start Training......')
        for i in range(args.max_epochs):
            loss, reward = trainer.train_epoch(train_dataloader, len(train_data))
            logging.info('Epoch {}/{} Loss: {}, reward: {}'.format(i, args.max_epochs, loss, reward))
            #! checking GPU usage
            # NOTE(review): mem_get_info requires CUDA — this raises on CPU-only runs; confirm intended
            free_mem, total_mem = torch.cuda.mem_get_info()
            print("--------------GPU memory usage-----------")
            print("there are ", free_mem, " free memory")
            print("there are ", total_mem, " total available memory")
            print("there are ", total_mem - free_mem, " used memory")
            print("--------------GPU memory usage-----------")
            if i % args.save_epoch == 0 and i != 0:
                trainer.save_model(save_path, 'checkpoint_{}.pth'.format(i))
                logging.info('Save Model in {}'.format(save_path))
            if i % args.valid_epoch == 0 and i != 0:
                logging.info('Start Val......')
                val_metrics = tester.test(valid_dataloader,
                                          len(val_data), num_nodes, neg_sampler, evaluator, split_mode='val')
                for mode in val_metrics.keys():
                    logging.info('{} at epoch {}: {}'.format(mode, i, val_metrics[mode]))
        # final checkpoint after the last epoch
        trainer.save_model(save_path)
        logging.info('Save Model in {}'.format(save_path))
    else:
        # Load the model parameters
        if os.path.isfile(save_path):
            params = torch.load(save_path)
            episode.load_state_dict(params['model_state_dict'])
            optimizer.load_state_dict(params['optimizer_state_dict'])
            logging.info('Load pretrain model: {}'.format(save_path))
    if args.do_test:
        logging.info('Start Testing......')
        start_test = timeit.default_timer()
        test_metrics = tester.test(test_dataloader,
                                   len(test_data), num_nodes, neg_sampler, evaluator, split_mode='test')
        for mode in test_metrics.keys():
            logging.info('Test {} : {}'.format(mode, test_metrics[mode]))
        # saving the results...
        results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
        if not osp.exists(results_path):
            os.mkdir(results_path)
            print('INFO: Create directory {}'.format(results_path))
        Path(results_path).mkdir(parents=True, exist_ok=True)
        results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
        test_time = timeit.default_timer() - start_test
        # NOTE(review): start_train is only bound when args.do_train is True;
        # running with do_test alone raises NameError here — confirm intended usage.
        all_time = timeit.default_timer() - start_train
        all_time_preprocess = timeit.default_timer() - start_overall
        save_results({'model': MODEL_NAME,
                      'data': DATA,
                      'seed': SEED,
                      # NOTE(review): val_metrics[METRIC] stays None if validation never ran,
                      # which makes float(None) raise — verify do_train/valid_epoch settings
                      f'val {METRIC}': float(val_metrics[METRIC]),
                      f'test {METRIC}': float(test_metrics[METRIC]),
                      'test_time': test_time,
                      'tot_train_val_time': all_time,
                      'tot_preprocess_train_val_time': all_time_preprocess
                      },
                     results_filename)
if __name__ == '__main__':
    # NOTE: `args` must stay a module-level global — set_logger() reads it directly.
    args = get_args_timetraveler()
    main(args)
================================================
FILE: examples/linkproppred/tkgl-icews/tkgl-icews_example.py
================================================
import numpy as np
import timeit
from tqdm import tqdm
import sys
import os
import os.path as osp
#internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from torch_geometric.loader import TemporalDataLoader
from tgb.linkproppred.evaluate import Evaluator
DATA = "tkgl-icews"

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
metric = dataset.eval_metric
print("there are {} nodes and {} edges".format(dataset.num_nodes, dataset.num_edges))
print("there are {} relation types".format(dataset.num_rels))

# per-edge views of the temporal graph (shown for illustration)
timestamp = data.t
head = data.src
tail = data.dst
edge_type = data.edge_type  # relation

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
evaluator = Evaluator(name=DATA)
# FIX: `metric` and `neg_sampler` were each assigned twice; keep a single assignment
neg_sampler = dataset.negative_sampler

BATCH_SIZE = 200
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

start_time = timeit.default_timer()
# load the ns samples first
dataset.load_val_ns()
# query the pre-generated negative samples batch by batch; numpy arrays are required
for batch in tqdm(val_loader):
    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type
    neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='val')
print("loading ns samples from validation", timeit.default_timer() - start_time)

start_time = timeit.default_timer()
dataset.load_test_ns()
for batch in test_loader:
    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type
    neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='test')
print("loading ns samples from test", timeit.default_timer() - start_time)
print("retrieved all negative samples")
================================================
FILE: examples/linkproppred/tkgl-icews/tlogic.py
================================================
"""
https://github.com/liu-yushan/TLogic/tree/main/mycode
TLogic: Temporal Logical Rules for Explainable Link Forecasting on Temporal Knowledge Graphs.
Yushan Liu, Yunpu Ma, Marcel Hildebrandt, Mitchell Joblin, Volker Tresp
"""
# imports
import sys
import os
import os.path as osp
from pathlib import Path
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
import timeit
import argparse
import numpy as np
import json
from joblib import Parallel, delayed
import itertools
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
from modules.tlogic_learn_modules import Temporal_Walk, Rule_Learner, store_edges
import modules.tlogic_apply_modules as ra
from tgb.utils.utils import set_random_seed, save_results
from modules.tkg_utils import reformat_ts, get_inv_relation_id, create_scores_array
def learn_rules(i, num_relations):
    """
    Learn rules (multiprocessing possible).

    Uses module globals: all_relations, rule_lengths, num_walks, temporal_walk, rl.

    Parameters:
        i (int): process number
        num_relations (int): minimum number of relations for each process

    Returns:
        rl.rules_dict (dict): rules dictionary
    """
    # slice of relations handled by this process; the last process takes the remainder
    first = i * num_relations
    remaining = len(all_relations) - (i + 1) * num_relations
    if remaining >= num_relations:
        relations_idx = range(first, first + num_relations)
    else:
        relations_idx = range(first, len(all_relations))
    num_rules = [0]
    for k in relations_idx:
        rel = all_relations[k]
        for length in rule_lengths:
            it_start = timeit.default_timer()
            # sample temporal random walks and turn every successful one into a rule
            for _ in range(num_walks):
                walk_successful, walk = temporal_walk.sample_walk(length + 1, rel)
                if walk_successful:
                    rl.create_rule(walk)
            it_time = round(timeit.default_timer() - it_start, 6)
            # rules are stored for relation and inverse relation, hence // 2
            num_rules.append(sum(len(v) for v in rl.rules_dict.values()) // 2)
            num_new_rules = num_rules[-1] - num_rules[-2]
            print(
                "Process {0}: relation {1}/{2}, length {3}: {4} sec, {5} rules".format(
                    i,
                    k - relations_idx[0] + 1,
                    len(relations_idx),
                    length,
                    it_time,
                    num_new_rules,
                )
            )
    return rl.rules_dict
def apply_rules(i, num_queries, rules_dict, neg_sampler, data, window, learn_edges, all_quads, args, split_mode,
                log_per_rel=False, num_rels=0):
    """
    Apply rules (multiprocessing possible).

    Also uses module globals: score_func, top_k, evaluator, num_nodes, itertools, np, ra.
    `data` rows are (subject, relation, object, remapped_ts, original_ts).
    `args` is a list of score-function parameter tuples; one candidate dict is kept per entry.

    Parameters:
        i (int): process number
        num_queries (int): minimum number of queries for each process

    Returns:
        perf_list (list): performance list (mrr per sample)
        hits_list (list): hits list (hits@10 per sample)
        perf_per_rel (dict): mean mrr per relation (only filled for split_mode=="test" with log_per_rel)
    """
    perf_per_rel = {}
    for rel in range(num_rels):
        perf_per_rel[rel] = []
    print("Start process", i, "...")
    all_candidates = [dict() for _ in range(len(args))]
    no_cands_counter = 0
    # slice of queries handled by this process; the last process takes the remainder
    num_rest_queries = len(data) - (i + 1) * num_queries
    if num_rest_queries >= num_queries:
        test_queries_idx = range(i * num_queries, (i + 1) * num_queries)
    else:
        test_queries_idx = range(i * num_queries, len(data))
    cur_ts = data[test_queries_idx[0]][3]
    # edges visible at cur_ts, restricted to the configured time window
    edges = ra.get_window_edges(all_quads[:, 0:4], cur_ts, learn_edges, window)
    it_start = timeit.default_timer()
    hits_list = [0] * len(test_queries_idx)
    perf_list = [0] * len(test_queries_idx)
    for index, j in enumerate(test_queries_idx):
        # negative samples are queried with the ORIGINAL timestamp (column 4)
        neg_sample_el = neg_sampler.query_batch(np.expand_dims(np.array(data[j, 0]), axis=0),
                                                np.expand_dims(np.array(data[j, 2]), axis=0),
                                                np.expand_dims(np.array(data[j, 4]), axis=0),
                                                np.expand_dims(np.array(data[j, 1]), axis=0),
                                                split_mode=split_mode)[0]
        # neg_samples_batch[j]
        pos_sample_el = data[j, 2]
        test_query = data[j]
        assert pos_sample_el == test_query[2]
        cands_dict = [dict() for _ in range(len(args))]
        # advance the edge window when the query timestamp changes
        if test_query[3] != cur_ts:
            cur_ts = test_query[3]
            edges = ra.get_window_edges(all_quads[:, 0:4], cur_ts, learn_edges, window)
        if test_query[1] in rules_dict:
            dicts_idx = list(range(len(args)))
            for rule in rules_dict[test_query[1]]:
                walk_edges = ra.match_body_relations(rule, edges, test_query[0])
                # only proceed if every body relation matched at least one edge
                if 0 not in [len(x) for x in walk_edges]:
                    rule_walks = ra.get_walks(rule, walk_edges)
                    if rule["var_constraints"]:
                        rule_walks = ra.check_var_constraints(
                            rule["var_constraints"], rule_walks
                        )
                    if not rule_walks.empty:
                        cands_dict = ra.get_candidates(
                            rule,
                            rule_walks,
                            cur_ts,
                            cands_dict,
                            score_func,
                            args,
                            dicts_idx,
                        )
                        for s in dicts_idx:
                            # sort each candidate's scores and the candidates themselves, best first
                            cands_dict[s] = {
                                x: sorted(cands_dict[s][x], reverse=True)
                                for x in cands_dict[s].keys()
                            }
                            cands_dict[s] = dict(
                                sorted(
                                    cands_dict[s].items(),
                                    key=lambda item: item[1],
                                    reverse=True,
                                )
                            )
                            top_k_scores = [v for _, v in cands_dict[s].items()][:top_k]
                            unique_scores = list(
                                scores for scores, _ in itertools.groupby(top_k_scores)
                            )
                            # stop refining this dict once top_k distinct scores exist
                            if len(unique_scores) >= top_k:
                                dicts_idx.remove(s)
                        if not dicts_idx:
                            break
            if cands_dict[0]:
                for s in range(len(args)):
                    # Calculate noisy-or scores: 1 - prod(1 - score) over all rules per candidate
                    scores = list(
                        map(
                            lambda x: 1 - np.product(1 - np.array(x)),
                            cands_dict[s].values(),
                        )
                    )
                    cands_scores = dict(zip(cands_dict[s].keys(), scores))
                    noisy_or_cands = dict(
                        sorted(cands_scores.items(), key=lambda x: x[1], reverse=True)
                    )
                    all_candidates[s][j] = noisy_or_cands
            else:  # No candidates found by applying rules
                no_cands_counter += 1
                for s in range(len(args)):
                    all_candidates[s][j] = dict()
        else:  # No rules exist for this relation
            no_cands_counter += 1
            for s in range(len(args)):
                all_candidates[s][j] = dict()
        # progress report every 100 queries
        if not (j - test_queries_idx[0] + 1) % 100:
            it_end = timeit.default_timer()
            it_time = round(it_end - it_start, 6)
            print(
                "Process {0}: test samples finished: {1}/{2}, {3} sec".format(
                    i, j - test_queries_idx[0] + 1, len(test_queries_idx), it_time
                )
            )
            it_start = timeit.default_timer()
        # NOTE(review): `s` here is the last value left over from the loops above
        # (len(args)-1 on the common paths) — confirm that scoring the last
        # parameter set is intended rather than a fixed index.
        predictions = create_scores_array(all_candidates[s][j], num_nodes)
        predictions_of_interest_pos = np.array(predictions[pos_sample_el])
        predictions_of_interest_neg = predictions[neg_sample_el]
        input_dict = {
            "y_pred_pos": predictions_of_interest_pos,
            "y_pred_neg": predictions_of_interest_neg,
            "eval_metric": ['mrr'],
        }
        predictions = evaluator.eval(input_dict)
        perf_list[index] = predictions['mrr']
        hits_list[index] = predictions['hits@10']
        if split_mode == "test":
            if log_per_rel:
                perf_per_rel[test_query[1]].append(perf_list[index])  # test_query[1] is the relation index
    if split_mode == "test":
        if log_per_rel:
            # reduce per-relation lists to their mean; drop relations without predictions
            for rel in range(num_rels):
                if len(perf_per_rel[rel]) > 0:
                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))
                else:
                    perf_per_rel.pop(rel)
    return perf_list, hits_list, perf_per_rel
## args
def get_args():
    """Parse command-line arguments for the TLogic run.

    Returns:
        dict: mapping of argument name -> parsed value (via ``vars()``).
    """
    def str2bool(value):
        # argparse's `type=bool` treats ANY non-empty string as True (so
        # `--log_per_rel False` would silently enable logging); parse the
        # common textual spellings explicitly instead.
        if isinstance(value, bool):
            return value
        if value.lower() in ("yes", "true", "t", "1"):
            return True
        if value.lower() in ("no", "false", "f", "0"):
            return False
        raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", "-d", default="tkgl-icews", type=str)
    parser.add_argument("--rule_lengths", "-l", default=1, type=int, nargs="+")
    parser.add_argument("--num_walks", "-n", default=100, type=int)
    parser.add_argument("--transition_distr", default="exp", type=str)
    parser.add_argument("--window", "-w", default=0, type=int)  # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep
    parser.add_argument("--top_k", default=20, type=int)
    parser.add_argument("--num_processes", "-p", default=1, type=int)
    parser.add_argument("--alpha", "-alpha", default=0.99, type=float)  # fix alpha. used if trainflag == false
    parser.add_argument("--save_config", "-c", type=str2bool, default=True)  # do we need to save the selection of lambda and alpha in config file?
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--run_nr', type=int, help='Run Number', default=1)
    parser.add_argument('--learn_rules_flag', type=str2bool, help='Do we want to learn the rules', default=True)
    parser.add_argument('--rule_filename', type=str, help='if rules not learned: where are they stored', default='0_r[3]_n100_exp_s1_rules.json')
    parser.add_argument('--log_per_rel', type=str2bool, help='Do we want to log mrr per relation', default=False)
    parser.add_argument('--compute_valid_mrr', type=str2bool, help='Do we want to compute mrr for valid set', default=True)
    parsed = vars(parser.parse_args())
    return parsed
# ----- TLogic driver script -----
# NOTE(review): several names defined here (num_nodes, evaluator, neg_sampler,
# args, ...) are read as module-level globals by apply_rules/learn_rules above.
start_o = timeit.default_timer()
## get args
parsed = get_args()
dataset = parsed["dataset"]
rule_lengths = parsed["rule_lengths"]
# argparse may deliver a single int (default) or a list (nargs="+"): normalize to list
rule_lengths = [rule_lengths] if (type(rule_lengths) == int) else rule_lengths
num_walks = parsed["num_walks"]
transition_distr = parsed["transition_distr"]
num_processes = parsed["num_processes"]
window = parsed["window"]
top_k = parsed["top_k"]
log_per_rel = parsed['log_per_rel']
MODEL_NAME = 'TLogic'
SEED = parsed['seed'] # set the random seed for consistency
set_random_seed(SEED)
## load dataset and prepare it accordingly
name = parsed["dataset"]
compute_valid_mrr = parsed["compute_valid_mrr"]
dataset = LinkPropPredDataset(name=name, root="datasets", preprocess=True)
DATA = name
relations = dataset.edge_type
num_rels = dataset.num_rels
subjects = dataset.full_data["sources"]
objects= dataset.full_data["destinations"]
num_nodes = dataset.num_nodes
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1
# quadruple columns: (subject, relation, object, remapped ts, original ts)
all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
train_data = all_quads[dataset.train_mask,0:4] # we do not need the original timestamps
val_data = all_quads[dataset.val_mask,0:5]
test_data = all_quads[dataset.test_mask,0:5]
all_data = all_quads[:,0:4]
metric = dataset.eval_metric
evaluator = Evaluator(name=name)
neg_sampler = dataset.negative_sampler
inv_relation_id = get_inv_relation_id(num_rels)
#load the ns samples
dataset.load_val_ns()
dataset.load_test_ns()
output_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
learn_rules_flag = parsed['learn_rules_flag']
## 1. learn rules
start_train = timeit.default_timer()
if learn_rules_flag:
    print("start learning rules")
    # edges (dict): edges for each relation
    # inv_relation_id (dict): mapping of relation to inverse relation
    temporal_walk = Temporal_Walk(train_data, inv_relation_id, transition_distr)
    rl = Rule_Learner(edges=temporal_walk.edges, id2relation=None, inv_relation_id=inv_relation_id,
                      output_dir=output_dir)
    all_relations = sorted(temporal_walk.edges) # Learn for all relations
    start = timeit.default_timer()
    # split the relations evenly across worker processes
    num_relations = len(all_relations) // num_processes
    output = Parallel(n_jobs=num_processes)(
        delayed(learn_rules)(i, num_relations) for i in range(num_processes)
    )
    end = timeit.default_timer()
    # merge the per-process rule dicts into a single dict
    all_rules = output[0]
    for i in range(1, num_processes):
        all_rules.update(output[i])
    total_time = round(end - start, 6)
    print("Learning finished in {} seconds.".format(total_time))
    rl.rules_dict = all_rules
    rl.sort_rules_dict()
    rule_filename = rl.save_rules(0, rule_lengths, num_walks, transition_distr, SEED)
    # rl.save_rules_verbalized(0, rule_lengths, num_walks, transition_distr, seed)
    # rules_statistics(rl.rules_dict)
else:
    rule_filename = parsed['rule_filename']
    print("Loading rules from file {}".format(parsed['rule_filename']))
end_train = timeit.default_timer()
## 2. Apply rules
rules_dict = json.load(open(output_dir + rule_filename))
rules_dict = {int(k): v for k, v in rules_dict.items()}
rules_dict = ra.filter_rules(
    rules_dict, min_conf=0.01, min_body_supp=2, rule_lengths=rule_lengths
) # filter rules for minimum confidence, body support and rule length
learn_edges = store_edges(train_data)
score_func = ra.score_12
# It is possible to specify a list of list of arguments for tuning
args = [[0.1, 0.5]]
# compute valid mrr
start_valid = timeit.default_timer()
if compute_valid_mrr:
    print('Computing valid MRR')
    num_queries = len(val_data) // num_processes
    output = Parallel(n_jobs=num_processes)(
        delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, val_data, window, learn_edges,
                             all_quads, args, split_mode='val') for i in range(num_processes))
    end = timeit.default_timer()
    # gather (mrr list, hits list) from each worker
    perf_list_val = []
    hits_list_val = []
    for i in range(num_processes):
        perf_list_val.extend(output[i][0])
        hits_list_val.extend(output[i][1])
else:
    perf_list_val = [0]
    hits_list_val = [0]
end_valid = timeit.default_timer()
# compute test mrr
if log_per_rel ==True:
    num_processes = 1 #otherwise logging per rel does not work for our implementation
start_test = timeit.default_timer()
print('Computing test MRR')
start = timeit.default_timer()
num_queries = len(test_data) // num_processes
output = Parallel(n_jobs=num_processes)(
    delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, test_data, window, learn_edges,
                         all_quads, args, split_mode='test', log_per_rel=log_per_rel, num_rels=num_rels) for i in range(num_processes))
end = timeit.default_timer()
perf_list_all = []
hits_list_all = []
for i in range(num_processes):
    perf_list_all.extend(output[i][0])
    hits_list_all.extend(output[i][1])
if log_per_rel == True:
    # per-relation dict only exists when a single process was used (see above)
    perf_per_rel = output[0][2]
total_time = round(end - start, 6)
total_valid_time = round(end_valid - start_valid, 6)
print("Application finished in {} seconds.".format(total_time))
print(f"The valid MRR is {np.mean(perf_list_val)}")
print(f"The MRR is {np.mean(perf_list_all)}")
print(f"The Hits@10 is {np.mean(hits_list_all)}")
print(f"We have {len(perf_list_all)} predictions")
print(f"The test set has len {len(test_data)} ")
end_o = timeit.default_timer()
train_time_o = round(end_train- start_train, 6)
test_time_o = round(end_o- start_test, 6)
total_time_o = round(end_o- start_o, 6)
print("Running Training to find best configs finished in {} seconds.".format(train_time_o))
print("Running testing with best configs finished in {} seconds.".format(test_time_o))
print("Running all steps finished in {} seconds.".format(total_time_o))
## 3. save results to json
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
if log_per_rel == True:
    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'
    with open(results_filename, 'w') as json_file:
        json.dump(perf_per_rel, json_file)
results_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'
metric = dataset.eval_metric
save_results({'model': MODEL_NAME,
              'train_flag': None,
              'rule_len': rule_lengths,
              'window': window,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: float(np.mean(perf_list_all)),
              'hits10': float(np.mean(hits_list_all)),
              'val_mrr': float(np.mean(perf_list_val)),
              'test_time': test_time_o,
              'tot_train_val_time': total_time_o,
              'valid_time': total_valid_time
              },
             results_filename)
================================================
FILE: examples/linkproppred/tkgl-polecat/cen.py
================================================
"""
Complex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning
Reference:
- https://github.com/Lee-zix/CEN
Zixuan Li, Saiping Guan, Xiaolong Jin, Weihua Peng, Yajuan Lyu , Yong Zhu, Long Bai, Wei Li, Jiafeng Guo, Xueqi Cheng.
Complex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning. ACL 2022.
"""
import timeit
import os
import sys
import os.path as osp
from pathlib import Path
import numpy as np
import torch
import random
from tqdm import tqdm
import json
# internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.rrgcn import RecurrentRGCNCEN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from modules.tkg_utils import get_args_cen, reformat_ts
from modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
def test(model, history_len, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, mode, split_mode):
    """
    Test the CEN model on a list of snapshots.
    :param model: model used to test
    :param history_len: number of history snapshots fed to the model per prediction
    :param history_list: all input history snap shot list, not include output label train list or valid list
    :param test_list: test triple snap shot list
    :param num_rels: number of relations
    :param num_nodes: number of nodes
    :param use_cuda: whether to run tensors/graphs on GPU
    :param model_name: checkpoint path; loaded when mode == "test"
    :param mode: "test" loads the checkpoint from disk first; "train" evaluates the in-memory model
    :param split_mode: 'test' or 'val' to state which negative samples to load
    :return: (mrr, perf_per_rel, hits10); perf_per_rel is only populated when
        split_mode == "test" and args.log_per_rel is set (module-level globals)
    """
    print("Testing for mode: ", split_mode)
    # NOTE(review): test_timestamps_orig / val_timestamps_orig, neg_sampler,
    # evaluator, METRIC and args are module-level globals of this script.
    if split_mode == 'test':
        timesteps_to_eval = test_timestamps_orig
    else:
        timesteps_to_eval = val_timestamps_orig
    idx = 0
    if mode == "test":
        # test mode: load parameter form file
        if use_cuda:
            checkpoint = torch.load(model_name, map_location=torch.device(args.gpu))
        else:
            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))
        # use best stat checkpoint:
        print("Load Model name: {}. Using best epoch : {}".format(model_name, checkpoint['epoch']))
        print("\n"+"-"*10+"start testing"+"-"*10+"\n")
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    # sliding window of the most recent `history_len` snapshots
    input_list = [snap for snap in history_list[-history_len:]]
    perf_list_all = []
    hits_list_all = []
    perf_per_rel = {}
    for rel in range(num_rels):
        perf_per_rel[rel] = []
    for time_idx, test_snap in enumerate(tqdm(test_list)):
        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]
        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)
        # all triples of one snapshot share the same (original) timestamp
        timesteps_batch =timesteps_to_eval[time_idx]*np.ones(len(test_triples_input[:,0]))
        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:,0], test_triples_input[:,2],
                                                    timesteps_batch, edge_type=test_triples_input[:,1], split_mode=split_mode)
        pos_samples_batch = test_triples_input[:,2]
        _, perf_list, hits_list = model.predict(history_glist, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch,
                                                evaluator, METRIC)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
        if split_mode == "test":
            if args.log_per_rel:
                # bucket each query's score by its relation id
                for score, rel in zip(perf_list, test_triples_input[:,1].tolist()):
                    perf_per_rel[rel].append(score)
        # reconstruct history graph list: slide the window forward by one snapshot
        input_list.pop(0)
        input_list.append(test_snap)
        idx += 1
    if split_mode == "test":
        if args.log_per_rel:
            # reduce per-relation score lists to their mean; drop unseen relations
            for rel in range(num_rels):
                if len(perf_per_rel[rel]) > 0:
                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))
                else:
                    perf_per_rel.pop(rel)
    mrr = np.mean(perf_list_all)
    hits10 = np.mean(hits_list_all)
    return mrr, perf_per_rel, hits10
def run_experiment(args, trainvalidtest_id=0, n_hidden=None, n_layers=None, dropout=None, n_bases=None):
    '''
    Run experiment for CEN model
    :param args: arguments for the model
    :param trainvalidtest_id: -1: pretrainig, 0: curriculum training (to find best test history len), 1: test on valid set, 2: test on test set
    :param n_hidden: number of hidden units (overrides args.n_hidden when given)
    :param n_layers: number of layers (overrides args.n_layers when given)
    :param dropout: dropout rate (overrides args.dropout when given)
    :param n_bases: number of bases (overrides args.n_bases when given)
    :return: (mrr, test_history_len, perf_per_rel, hits10); perf_per_rel is
        only filled for trainvalidtest_id == 2, otherwise an empty dict
    '''
    # 1) load configuration for grid search the best configuration
    if n_hidden:
        args.n_hidden = n_hidden
    if n_layers:
        args.n_layers = n_layers
    if dropout:
        args.dropout = dropout
    if n_bases:
        args.n_bases = n_bases
    test_history_len = args.test_history_len
    # 2) set save model path
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    if not osp.exists(save_model_dir):
        os.mkdir(save_model_dir)
        print('INFO: Create directory {}'.format(save_model_dir))
    # checkpoint names encode model, dataset, seed, run number and history length
    test_model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.test_history_len}'
    test_state_file = save_model_dir+test_model_name
    perf_per_rel ={}
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    # create stat
    model = RecurrentRGCNCEN(args.decoder,
                             args.encoder,
                             num_nodes,
                             num_rels,
                             args.n_hidden,
                             args.opn,
                             sequence_len=args.train_history_len,
                             num_bases=args.n_bases,
                             num_basis=args.n_basis,
                             num_hidden_layers=args.n_layers,
                             dropout=args.dropout,
                             self_loop=args.self_loop,
                             skip_connect=args.skip_connect,
                             layer_norm=args.layer_norm,
                             input_dropout=args.input_dropout,
                             hidden_dropout=args.hidden_dropout,
                             feat_dropout=args.feat_dropout,
                             entity_prediction=args.entity_prediction,
                             relation_prediction=args.relation_prediction,
                             use_cuda=use_cuda,
                             gpu = args.gpu)
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        model.cuda()
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    if trainvalidtest_id == 1: # normal test on validation set Note that mode=test
        if os.path.exists(test_state_file):
            mrr, _, hits10 = test(model, args.test_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
                                  test_state_file, "test", split_mode="val")
        else:
            print('Cannot do testing because model does not exist: ', test_state_file)
            mrr = 0
            hits10 = 0
    elif trainvalidtest_id == 2: # normal test on test set
        if os.path.exists(test_state_file):
            # history = train + valid snapshots when evaluating on test
            mrr, perf_per_rel, hits10 = test(model, args.test_history_len, train_list+valid_list, test_list, num_rels, num_nodes, use_cuda,
                                             test_state_file, "test", split_mode="test")
        else:
            print('Cannot do testing because model does not exist: ', test_state_file)
            mrr = 0
            hits10 = 0
    elif trainvalidtest_id == -1:
        # pretraining with the shortest history length
        print("-------------start pre training model with history length {}----------\n".format(args.start_history_len))
        model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'
        model_state_file = save_model_dir + model_name
        print("Sanity Check: stat name : {}".format(model_state_file))
        print("Sanity Check: Is cuda available ? {}".format(torch.cuda.is_available()))
        best_mrr = 0
        best_epoch = 0
        best_hits10= 0
        ## training loop
        for epoch in range(args.n_epochs):
            model.train()
            losses = []
            # visit the training snapshots in random order each epoch
            idx = [_ for _ in range(len(train_list))]
            random.shuffle(idx)
            for train_sample_num in idx:
                # skip the first two snapshots: not enough history to predict from
                if train_sample_num == 0 or train_sample_num == 1: continue
                if train_sample_num - args.start_history_len<0:
                    input_list = train_list[0: train_sample_num]
                    output = train_list[1:train_sample_num+1]
                else:
                    input_list = train_list[train_sample_num-args.start_history_len: train_sample_num]
                    output = train_list[train_sample_num-args.start_history_len+1:train_sample_num+1]
                # generate history graph
                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]
                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]
                loss= model.get_loss(history_glist, output[-1], None, use_cuda)
                losses.append(loss.item())
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm) # clip gradients
                optimizer.step()
                optimizer.zero_grad()
            print("His {:04d}, Epoch {:04d} | Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} "
                  .format(args.start_history_len, epoch, np.mean(losses), best_mrr, model_name))
            #! checking GPU usage
            free_mem, total_mem = torch.cuda.mem_get_info()
            print ("--------------GPU memory usage-----------")
            print ("there are ", free_mem, " free memory")
            print ("there are ", total_mem, " total available memory")
            print ("there are ", total_mem - free_mem, " used memory")
            print ("--------------GPU memory usage-----------")
            # validation
            if epoch % args.evaluate_every == 0:
                mrr, _, hits10 = test(model, args.start_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
                                      model_state_file, mode="train", split_mode= "val")
                if mrr< best_mrr:
                    # no improvement: stop if patience (5 evals) is exhausted
                    if epoch >= args.n_epochs or epoch - best_epoch > 5:
                        break
                else:
                    # improvement: remember stats and checkpoint the model
                    best_mrr = mrr
                    best_epoch = epoch
                    best_hits10 = hits10
                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
    elif trainvalidtest_id == 0: #curriculum training
        # start from the pretrained model with the shortest history length
        model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'
        init_state_file = save_model_dir + model_name
        init_checkpoint = torch.load(init_state_file, map_location=torch.device(args.gpu))
        # use best stat checkpoint:
        print("Load Previous Model name: {}. Using best epoch : {}".format(init_state_file, init_checkpoint['epoch']))
        print("\n"+"-"*10+"Load model with history length {}".format(args.start_history_len)+"-"*10+"\n")
        model.load_state_dict(init_checkpoint['state_dict'])
        test_history_len = args.start_history_len
        mrr, _, hits10 = test(model,
                              args.start_history_len,
                              train_list,
                              valid_list,
                              num_rels,
                              num_nodes,
                              use_cuda,
                              init_state_file,
                              mode="test", split_mode= "val")
        best_mrr_list = [mrr.item()]
        best_hits_list = [hits10.item()]
        # start knowledge distillation
        ks_idx = 0
        # grow the history length one step at a time, fine-tuning each time
        for history_len in range(args.start_history_len+1, args.train_history_len+1, 1):
            # current model
            print("best mrr list :", best_mrr_list)
            # lr = 0.1*args.lr - 0.002*args.lr*ks_idx
            optimizer = torch.optim.Adam(model.parameters(), lr=0.1*args.lr, weight_decay=0.00001)
            model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len}'
            model_state_file = save_model_dir + model_name
            print("Sanity Check: stat name : {}".format(model_state_file))
            # load model with the least history length
            prev_model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len-1}'
            prev_state_file = save_model_dir + prev_model_name
            checkpoint = torch.load(prev_state_file, map_location=torch.device(args.gpu))
            model.load_state_dict(checkpoint['state_dict'])
            print("\n"+"-"*10+"start knowledge distillation for history length at "+ str(history_len)+"-"*10+"\n")
            best_mrr = 0
            best_hits10 = 0
            best_epoch = 0
            for epoch in range(args.n_epochs):
                model.train()
                losses = []
                idx = [_ for _ in range(len(train_list))]
                random.shuffle(idx)
                for train_sample_num in idx:
                    if train_sample_num == 0 or train_sample_num == 1: continue
                    if train_sample_num - history_len<0:
                        input_list = train_list[0: train_sample_num]
                        output = train_list[1:train_sample_num+1]
                    else:
                        input_list = train_list[train_sample_num - history_len: train_sample_num]
                        output = train_list[train_sample_num-history_len+1:train_sample_num+1]
                    # generate history graph
                    history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]
                    output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]
                    loss= model.get_loss(history_glist, output[-1], None, use_cuda)
                    # print(loss)
                    losses.append(loss.item())
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm) # clip gradients
                    optimizer.step()
                    optimizer.zero_grad()
                print("His {:04d}, Epoch {:04d} | Ave Loss: {:.4f} |Best MRR {:.4f} | Model {} "
                      .format(history_len, epoch, np.mean(losses), best_mrr, model_name))
                #! checking GPU usage
                free_mem, total_mem = torch.cuda.mem_get_info()
                print ("--------------GPU memory usage-----------")
                print ("there are ", free_mem, " free memory")
                print ("there are ", total_mem, " total available memory")
                print ("there are ", total_mem - free_mem, " used memory")
                print ("--------------GPU memory usage-----------")
                # validation
                if epoch % args.evaluate_every == 0:
                    mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
                                          model_state_file, mode="train", split_mode= "val")
                    if mrr< best_mrr:
                        # shorter patience (2 evals) during distillation
                        if epoch >= args.n_epochs or epoch-best_epoch>2:
                            break
                    else:
                        best_mrr = mrr
                        best_epoch = epoch
                        best_hits10 = hits10
                        torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
            # evaluate the best checkpoint for this history length
            mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
                                  model_state_file, mode="test", split_mode= "val")
            ks_idx += 1
            if mrr.item() < max(best_mrr_list):
                # longer history did not help: keep the previous length and stop
                test_history_len = history_len-1
                print("early stopping, best history length: ", test_history_len)
                break
            else:
                best_mrr_list.append(mrr.item())
                best_hits_list.append(hits10.item())
    return mrr, test_history_len, perf_per_rel, hits10
# ==================
# ==================
# ==================
# ----- CEN driver script: load data, (pre)train, evaluate, save results -----
start_overall = timeit.default_timer()
# set hyperparameters
args, _ = get_args_cen()
args.dataset = 'tkgl-polecat'
SEED = args.seed # set the random seed for consistency
set_random_seed(SEED)
DATA=args.dataset
MODEL_NAME = 'CEN'
print("logging mrrs per relation: ", args.log_per_rel)
print("do test and valid? do only test no validation?: ", args.validtest, args.test_only)
# load data
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
num_rels = dataset.num_rels
num_nodes = dataset.num_nodes
subjects = dataset.full_data["sources"]
objects= dataset.full_data["destinations"]
relations = dataset.edge_type
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1
all_quads = np.stack((subjects, relations, objects, timestamps), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]
val_timestamps_orig = list(set(timestamps_orig[dataset.val_mask])) # needed for getting the negative samples
val_timestamps_orig.sort()
test_timestamps_orig = list(set(timestamps_orig[dataset.test_mask])) # needed for getting the negative samples
test_timestamps_orig.sort()
# group quadruples into one snapshot per timestep
train_list = split_by_time(train_data)
valid_list = split_by_time(val_data)
test_list = split_by_time(test_data)
# evaluation metric
METRIC = dataset.eval_metric
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
#load the ns samples
dataset.load_val_ns()
dataset.load_test_ns()
if args.grid_search:
    print("TODO: implement hyperparameter grid search")
# single run
else:
    start_train = timeit.default_timer()
    if args.validtest:
        print('directly start testing')
        if args.test_history_len_2 != args.test_history_len:
            args.test_history_len = args.test_history_len_2 # hyperparameter value as given in original paper
    else:
        print('running pretrain and train')
        # pretrain
        mrr, _, _, hits10 = run_experiment(args, trainvalidtest_id=-1)
        # train
        mrr, args.test_history_len, _, hits10 = run_experiment(args, trainvalidtest_id=0) # overwrite test_history_len with
        # the best history len (for valid mrr)
    if args.test_only == False:
        print("running test (on val and test dataset) with test_history_len of: ", args.test_history_len)
        # test on val set
        val_mrr, _, _, val_hits10 = run_experiment(args, trainvalidtest_id=1)
    else:
        val_mrr = 0
        val_hits10 = 0
    # test on test set
    start_test = timeit.default_timer()
    test_mrr, _, perf_per_rel, test_hits10 = run_experiment(args, trainvalidtest_id=2)
    test_time = timeit.default_timer() - start_test
    all_time = timeit.default_timer() - start_train
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
    # BUGFIX: "\Train" was an invalid escape that printed a literal backslash;
    # "\t" is the intended tab, matching the surrounding prints.
    print(f"\tTrain and Test: Elapsed Time (s): {all_time: .4f}")
    print(f"\tTest: {METRIC}: {test_mrr: .4f}")
    print(f"\tValid: {METRIC}: {val_mrr: .4f}")
    # saving the results...
    results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
    if not osp.exists(results_path):
        os.mkdir(results_path)
        print('INFO: Create directory {}'.format(results_path))
    Path(results_path).mkdir(parents=True, exist_ok=True)
    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': args.run_nr,
                  'seed': SEED,
                  'test_history_len': args.test_history_len,
                  f'val {METRIC}': float(val_mrr),
                  f'test {METRIC}': float(test_mrr),
                  'test_time': test_time,
                  'tot_train_val_time': all_time,
                  'test_hits10': float(test_hits10)
                  },
                 results_filename)
    if args.log_per_rel:
        # dump the per-relation MRR dict collected during the test run
        results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'
        with open(results_filename, 'w') as json_file:
            json.dump(perf_per_rel, json_file)
    sys.exit()
================================================
FILE: examples/linkproppred/tkgl-polecat/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
tgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(tgb_modules_path)
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluated the dynamic link prediction
    Evaluation happens as 'one vs. many', meaning that each positive edge is evaluated against many negative edges
    Parameters:
        data: a dataset object
        test_mask: required masks to load the test set edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: specifies whether it is the 'validation' or 'test' set to correctly load the negatives
    Returns:
        perf_metrics: mean of the per-query eval metric (e.g. MRR)
        perf_hits: mean of the per-query hits@10
    """
    # PERF: apply the boolean mask once up front. The original re-evaluated
    # data['sources'][test_mask] (a full masked copy) on every batch, four
    # times per batch for the slices, i.e. O(len(data) * num_batches) work.
    src_split = data['sources'][test_mask]
    dst_split = data['destinations'][test_mask]
    ts_split = data['timestamps'][test_mask]
    type_split = data['edge_type'][test_mask]
    num_batches = math.ceil(len(src_split) / BATCH_SIZE)
    perf_list = []
    hits_list = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, len(src_split))
        pos_src, pos_dst, pos_t, pos_edge = (
            src_split[start_idx: end_idx],
            dst_split[start_idx: end_idx],
            ts_split[start_idx: end_idx],
            type_split[start_idx: end_idx],
        )
        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # query: the same source against the positive destination followed by all negatives
            query_src = np.full(len(neg_batch) + 1, int(pos_src[idx]))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])
            y_pred = edgebank.predict_link(query_src, query_dst)
            # compute MRR
            input_dict = {
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            }
            results = evaluator.eval(input_dict)
            perf_list.append(results[metric])
            hits_list.append(results['hits@10'])
        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)
    perf_metrics = float(np.mean(perf_list))
    perf_hits = float(np.mean(hits_list))
    return perf_metrics, perf_hits
def get_args():
    """Parse command-line arguments for the EdgeBank run.

    Returns:
        tuple: (parsed ``argparse.Namespace``, ``sys.argv``).
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tkgl-polecat')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse raises SystemExit on invalid arguments; show the full help
        # and stop. (A bare `except:` here would also swallow KeyboardInterrupt
        # and genuine programming errors.)
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
# ----- EdgeBank driver: build memory from the train split, evaluate val/test -----
start_overall = timeit.default_timer()
# set hyperparameters
args, _ = get_args()
SEED = args.seed # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = args.data
MODEL_NAME = 'EdgeBank'
# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric
# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
#data for memory in edgebank: seed it with the training edges only
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])
# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
        hist_src,
        hist_dst,
        hist_ts,
        memory_mode=MEMORY_MODE,
        time_window_ratio=TIME_WINDOW_RATIO)
print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'
# ==================================================== Test
# loading the validation negative samples
dataset.load_val_ns()
# testing ...
start_val = timeit.default_timer()
perf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()
print(f"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< ")
print(f"\tval: {metric}: {perf_metric_val: .4f}")
val_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {val_time: .4f}")
# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()
# testing ... (EdgeBank memory now also contains the validation edges)
start_test = timeit.default_timer()
perf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()
print(f"INFO: Test: Evaluation Setting: >>> <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
# persist the run summary as json
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'val_mrr': perf_metric_val,
              'test_time': test_time,
              'tot_train_val_time': test_time+val_time,
              'hits10': perf_hits_test},
             results_filename)
================================================
FILE: examples/linkproppred/tkgl-polecat/example.py
================================================
# Minimal example: load tkgl-polecat through the PyG wrapper and slice it
# into train/val/test TemporalData splits.
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
DATA = "tkgl-polecat"
# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
metric = dataset.eval_metric
# unpack the columns of the temporal graph
timestamp = data.t
head = data.src
tail = data.dst
edge_type = data.edge_type #relation
print (edge_type)
# chronological splits via the dataset-provided boolean masks
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
================================================
FILE: examples/linkproppred/tkgl-polecat/recurrencybaseline.py
================================================
""" from paper: "History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting"
Julia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)
@inproceedings{gastinger2024baselines,
title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},
author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},
booktitle={33rd International Joint Conference on Artificial Intelligence (IJCAI 2024)},
year={2024},
organization={International Joint Conferences on Artificial Intelligence Organization}
}
"""
## imports
import timeit
import argparse
import numpy as np
from copy import copy
from pathlib import Path
import ray
import sys
import os
import os.path as osp
import json
#internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import set_random_seed, save_results
from modules.tkg_utils import create_basis_dict, group_by, reformat_ts
def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
            perf_list_all, hits_list_all, window, neg_sampler, split_mode):
    """ create predictions for one relation on the test or valid set and compute mrr
    :param num_processes: number of ray worker processes (1 = run in-process)
    :param data_c_rel: quadruples of the evaluated split for the current relation
    :param all_data_c_rel: all quadruples (history) for the current relation
    :param alpha: interpolation weight between the baseline rules
    :param lmbda_psi: time-decay parameter for the psi rule
    :param perf_list_all: accumulator of mrrs, extended in place
    :param hits_list_all: accumulator of hits, extended in place
    :param window: history window size (0 = all history, -2 = multistep)
    :param neg_sampler: negative sampler providing evaluation candidates
    :param split_mode: 'val' or 'test' (selects which negative samples to query)
    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    first_ts = data_c_rel[0][3]
    # split the queries evenly across processes; fall back to a single process
    # when there are fewer queries than workers
    num_queries = len(data_c_rel) // num_processes
    if num_queries < num_processes: # if we do not have enough queries for all the processes
        num_processes_tmp = 1
        num_queries = len(data_c_rel)
    else:
        num_processes_tmp = num_processes
    if num_processes > 1:
        object_references = []
        for i in range(num_processes_tmp):
            num_test_queries = len(data_c_rel) - (i + 1) * num_queries
            if num_test_queries >= num_queries:
                test_queries_idx = [i * num_queries, (i + 1) * num_queries]
            else:
                # BUGFIX: the last chunk must end at len(data_c_rel); the original
                # referenced the global `test_data`, which is wrong for the
                # validation split and for per-relation slices
                test_queries_idx = [i * num_queries, len(data_c_rel)]
            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]
            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window,
                                                basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
            object_references.append(ob)
        output = ray.get(object_references)
        # merge the scores from each worker process
        for proc_loop in range(num_processes_tmp):
            perf_list_all.extend(output[proc_loop][0])
            hits_list_all.extend(output[proc_loop][1])
    else:
        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel,
                                                window, basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
    return perf_list_all, hits_list_all
## test
def test(best_config, all_relations,test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
    """ create predictions by looping through all relations on the test or valid
    set and compute the mrr
    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    perf_list_all, hits_list_all = [], []
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'
    # Apply the baseline relation by relation.
    for rel in all_relations:
        start = timeit.default_timer()
        if rel in test_data_prel.keys():
            # per-relation hyperparameters found during training (or presets)
            lmbda_psi = best_config[str(rel)]['lmbda_psi'][0]
            alpha = best_config[str(rel)]['alpha'][0]
            rel_eval_data = test_data_prel[rel]
            timesteps_test = sorted(set(rel_eval_data[:, 3]))
            rel_history = all_data_prel[rel]
            perf_list_rel, hits_list_rel = predict(
                num_processes, rel_eval_data, rel_history, alpha, lmbda_psi,
                [], [], window, neg_sampler, split_mode)
            perf_list_all.extend(perf_list_rel)
            hits_list_all.extend(hits_list_rel)
        else:
            perf_list_rel = []
        elapsed = round(timeit.default_timer() - start, 6)
        print("Relation {} finished in {} seconds.".format(rel, elapsed))
        # append the raw per-query scores for this relation to the csv log
        with open(csv_file, 'a') as f:
            f.write("{},{}\n".format(rel, perf_list_rel))
    return perf_list_all, hits_list_all
def read_dict_compute_mrr(split_mode='test'):
    """ read the results per relation from a precreated file and compute mrr
    :param split_mode: 'val' or 'test' — selects which results csv to read
    :return mrr_per_rel: dictionary of mrrs for each relation
    :return all_mrrs: list of mrrs for all relations
    """
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}'+split_mode+'.csv'
    # Initialize an empty dictionary to store the data
    results_per_rel_dict = {}
    mrr_per_rel = {}
    all_mrrs = []
    # Open the file for reading
    with open(csv_file, 'r') as f:
        # Each line is "<rel>,<score>,<score>,..." where the scores may still
        # carry square brackets from the python list repr.
        for line in f:
            parts = line.strip().split(',')
            # Extract the key (the first part)
            key = int(parts[0])
            # Extract the values (the rest of the parts), remove square brackets
            values = [float(value.strip('[]')) for value in parts[1:]]
            # Add the key-value pair to the dictionary (later lines win)
            if key in results_per_rel_dict.keys():
                print(f"Key {key} already exists in the dictionary!!! might have duplicate entries in results csv")
            results_per_rel_dict[key] = values
            all_mrrs.extend(values)
            mrr_per_rel[key] = np.mean(values)
    if len(list(results_per_rel_dict.keys())) != num_rels:
        print("we do not have entries for each rel in the results csv file. only num enties: ", len(list(results_per_rel_dict.keys())))
    print("Split mode: "+split_mode +" Mean MRR: ", np.mean(all_mrrs))
    print("mrr per relation: ", mrr_per_rel)
    # BUGFIX: the docstring promises these return values but the original
    # returned None; callers that ignore the result are unaffected
    return mrr_per_rel, all_mrrs
## train
def train(params_dict, rels,val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
    """ optional, find best values for lambda and alpha by looping through all relations and testing on a fixed set of params
    based on validation mrr
    :param params_dict: candidate values for 'lmbda_psi' and 'alpha'
    :param rels: all relation ids to tune
    :param val_data_prel: dict relation id -> validation quadruples
    :param trainval_data_prel: dict relation id -> train+valid quadruples (history)
    :param neg_sampler: negative sampler used by predict() for evaluation
    :param num_processes: number of ray processes forwarded to predict()
    :param window: history window size forwarded to predict()
    :return best_config: dictionary of best params for each relation
    """
    best_config= {}
    best_mrr = 0
    for rel in rels: # loop through relations. for each relation, apply rules with selected params, compute valid mrr
        start = timeit.default_timer()
        rel_key = int(rel)
        # defaults below are kept when the relation never appears in the valid set
        best_config[str(rel_key)] = {}
        best_config[str(rel_key)]['not_trained'] = 'True'
        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi,0] #default
        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))
        best_config[str(rel_key)]['alpha'] = [default_alpha,0] #default
        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))
        if rel in val_data_prel.keys():
            # valid data for this relation
            val_data_c_rel = val_data_prel[rel]
            timesteps_valid = list(set(val_data_c_rel[:,3]))
            timesteps_valid.sort()
            trainval_data_c_rel = trainval_data_prel[rel]
            ###### 1) select lambda ###############
            # alpha is pinned to 1 while searching lambda
            lmbdas_psi = params_dict['lmbda_psi']
            alpha = 1
            best_lmbda_psi = 0.1
            best_mrr_psi = 0
            lmbda_mrrs = []
            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))
            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))
            best_config[str(rel_key)]['not_trained'] = 'False'
            for lmbda_psi in lmbdas_psi:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r,
                    window, neg_sampler, split_mode='val')
                # compute mrr
                mrr = np.mean(perf_list_r)
                # # is new mrr better than previous best? if yes: store lmbda
                if mrr > best_mrr_psi:
                    best_mrr_psi = float(mrr)
                    best_lmbda_psi = lmbda_psi
                lmbda_mrrs.append(float(mrr))
            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]
            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs
            best_mrr = best_mrr_psi
            ##### 2) select alpha ###############
            # lambda is now pinned to its best value while searching alpha
            best_config[str(rel_key)]['not_trained'] = 'False'
            alphas = params_dict['alpha']
            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0] # use the best lmbda psi
            alpha_mrrs = []
            # perf_list_all = []
            best_mrr_alpha = 0
            best_alpha=0.99
            for alpha in alphas:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                    trainval_data_c_rel, alpha, lmbda_psi,perf_list_r, hits_list_r,
                    window, neg_sampler, split_mode='val')
                # compute mrr
                mrr_alpha = np.mean(perf_list_r)
                # is new mrr better than previous best? if yes: store alpha
                if mrr_alpha > best_mrr_alpha:
                    best_mrr_alpha = float(mrr_alpha)
                    best_alpha = alpha
                    best_mrr = best_mrr_alpha
                alpha_mrrs.append(float(mrr_alpha))
            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]
            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs
        end = timeit.default_timer()
        total_time = round(end - start, 6)
        print("Relation {} finished in {} seconds.".format(rel, total_time))
    return best_config
## args
def get_args():
    """Parse all command line arguments and return them as a dict."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", "-d", default="tkgl-polecat", type=str)
    # e.g. 200 to only consider the most recent 200 timesteps; -2 for multistep
    parser.add_argument("--window", "-w", default=0, type=int)
    parser.add_argument("--num_processes", "-p", default=1, type=int)
    # fixed lambda / alpha, used when train_flag == 'False'
    parser.add_argument("--lmbda", "-l", default=0.1, type=float)
    parser.add_argument("--alpha", "-alpha", default=0.99, type=float)
    # whether to select lambda and alpha on the validation set
    parser.add_argument("--train_flag", "-tr", default='False')
    # if training: load a previously saved best_config instead?
    parser.add_argument("--load_flag", "-lo", default='False')
    # save the selected lambda and alpha to a config file?
    parser.add_argument("--save_config", "-c", default='True')
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    return vars(parser.parse_args())
# ----------------------------------------------------------------------
# Script entry: load the dataset, optionally tune per-relation params on
# the validation set, then evaluate the recurrency baseline on val + test.
# ----------------------------------------------------------------------
start_o = timeit.default_timer()
parsed = get_args()
if parsed['num_processes']>1:
    ray.init(num_cpus=parsed["num_processes"], num_gpus=0)
MODEL_NAME = 'RecurrencyBaseline'
SEED = parsed['seed'] # set the random seed for consistency
set_random_seed(SEED)
# directory for the per-relation csv logs written by test()
perrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'
if not osp.exists(perrel_results_path):
    os.mkdir(perrel_results_path)
    print('INFO: Create directory {}'.format(perrel_results_path))
Path(perrel_results_path).mkdir(parents=True, exist_ok=True)
## load dataset and prepare it accordingly
name = parsed["dataset"]
dataset = LinkPropPredDataset(name=name, root="datasets", preprocess=True)
DATA = name
relations = dataset.edge_type
num_rels = dataset.num_rels
rels = np.arange(0,num_rels)
subjects = dataset.full_data["sources"]
objects= dataset.full_data["destinations"]
num_nodes = dataset.num_nodes
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1
print("split train valid test data")
# quadruple layout: (subject, relation, object, remapped ts, original ts)
all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]
metric = dataset.eval_metric
evaluator = Evaluator(name=name)
neg_sampler = dataset.negative_sampler
train_val_data = np.concatenate([train_data, val_data])
all_data = np.concatenate([train_data, val_data, test_data])
# create dicts with key: relation id, values: triples for that relation id
print("grouping data by relation")
test_data_prel = group_by(test_data, 1)
all_data_prel = group_by(all_data, 1)
val_data_prel = group_by(val_data, 1)
trainval_data_prel = group_by(train_val_data, 1)
#load the ns samples
# if parsed['train_flag']:
print("loading negative samples")
dataset.load_val_ns()
dataset.load_test_ns()
# parameter options
# candidate grids for the per-relation hyperparameter search (train_flag only)
if parsed['train_flag'] == 'True':
    params_dict = {}
    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001]
    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]
    default_lmbda_psi = params_dict['lmbda_psi'][-1]
    default_alpha = params_dict['alpha'][-2]
## load rules
print("creating rules")
basis_dict = create_basis_dict(train_val_data)
print("done with creating rules")
## init
# rb_predictor = RecurrencyBaselinePredictor(rels)
## train to find best lambda and alpha
start_train = timeit.default_timer()
if parsed['train_flag'] == 'True':
    if parsed['load_flag'] == 'True':
        # reuse a previously stored per-relation configuration
        with open('best_config.json', 'r') as infile:
            best_config = json.load(infile)
    else:
        print('start training')
        best_config = train(params_dict, rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'],
            parsed['window'])
        if parsed['save_config'] == 'True':
            import json
            with open('best_config.json', 'w') as outfile:
                json.dump(best_config, outfile)
else: # use preset lmbda and alpha; same for all relations
    best_config = {}
    for rel in rels:
        best_config[str(rel)] = {}
        best_config[str(rel)]['lmbda_psi'] = [parsed['lmbda']]
        best_config[str(rel)]['alpha'] = [parsed['alpha']]
end_train = timeit.default_timer()
# compute validation mrr
print("Computing validation MRR")
perf_list_all_val, hits_list_all_val = test(best_config,rels, val_data_prel,
    trainval_data_prel, neg_sampler, parsed['num_processes'],
    parsed['window'], split_mode='val')
val_mrr = float(np.mean(perf_list_all_val))
# compute test mrr
print("Computing test MRR")
start_test = timeit.default_timer()
perf_list_all, hits_list_all = test(best_config,rels, test_data_prel,
    all_data_prel, neg_sampler, parsed['num_processes'],
    parsed['window'])
end_o = timeit.default_timer()
train_time_o = round(end_train- start_train, 6)
test_time_o = round(end_o- start_test, 6)
total_time_o = round(end_o- start_o, 6)
print("Running Training to find best configs finished in {} seconds.".format(train_time_o))
print("Running testing with best configs finished in {} seconds.".format(test_time_o))
print("Running all steps finished in {} seconds.".format(total_time_o))
print(f"The test MRR is {np.mean(perf_list_all)}")
print(f"The valid MRR is {val_mrr}")
print(f"The Hits@10 is {np.mean(hits_list_all)}")
print(f"We have {len(perf_list_all)} predictions")
print(f"The test set has len {len(test_data)} ")
# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'
metric = dataset.eval_metric
save_results({'model': MODEL_NAME,
    'train_flag': parsed['train_flag'],
    'data': DATA,
    'run': 1,
    'seed': SEED,
    metric: float(np.mean(perf_list_all)),
    'val_mrr': val_mrr,
    'hits10': float(np.mean(hits_list_all)),
    'test_time': test_time_o,
    'tot_train_val_time': total_time_o
    },
    results_filename)
# release the ray workers if multiprocessing was used
if parsed['num_processes']>1:
    ray.shutdown()
================================================
FILE: examples/linkproppred/tkgl-polecat/regcn.py
================================================
"""
Temporal Knowledge Graph Reasoning Based on Evolutional Representation Learning
Reference:
- https://github.com/Lee-zix/RE-GCN
Zixuan Li, Xiaolong Jin, Wei Li, Saiping Guan, Jiafeng Guo, Huawei Shen, Yuanzhuo Wang and Xueqi Cheng. Temporal
Knowledge Graph Reasoning Based on Evolutional Representation Learning. SIGIR 2021.
"""
import sys
import timeit
import os
import sys
import os.path as osp
from pathlib import Path
import numpy as np
import torch
import random
from tqdm import tqdm
# internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.rrgcn import RecurrentRGCNREGCN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from modules.tkg_utils import get_args_regcn, reformat_ts
from modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
import json
def test(model, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, static_graph, mode, split_mode):
    """
    Test the model on either test or validation set
    :param model: model used to test
    :param history_list: all input history snap shot list, not include output label train list or valid list
    :param test_list: test triple snap shot list (one entry per timestep)
    :param num_rels: number of relations
    :param num_nodes: number of nodes
    :param use_cuda: whether to run on the GPU given by args.gpu
    :param model_name: checkpoint path, loaded when mode == "test"
    :param static_graph: static background graph (None in this example)
    :param mode: "test" loads weights from model_name; otherwise the in-memory model is evaluated
    :param split_mode: 'test' or 'val' to state which negative samples to load
    :return mrr, perf_per_rel, hits10
    """
    print("Testing for mode: ", split_mode)
    # original (unmapped) timestamps are needed to query the pre-generated
    # negative samples; both lists are module-level globals
    if split_mode == 'test':
        timesteps_to_eval = test_timestamps_orig
    else:
        timesteps_to_eval = val_timestamps_orig
    idx = 0
    if mode == "test":
        # test mode: load parameter form file
        if use_cuda:
            checkpoint = torch.load(model_name, map_location=torch.device(args.gpu))
        else:
            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))
        # use best stat checkpoint:
        print("Load Model name: {}. Using best epoch : {}".format(model_name, checkpoint['epoch']))
        print("\n"+"-"*10+"start testing"+"-"*10+"\n")
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    # seed the rolling history window with the most recent snapshots
    input_list = [snap for snap in history_list[-args.test_history_len:]]
    perf_list_all = []
    hits_list_all = []
    perf_per_rel = {}
    for rel in range(num_rels):
        perf_per_rel[rel] = []
    for time_idx, test_snap in enumerate(tqdm(test_list)):
        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]
        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)
        # every query of a snapshot shares the same original timestamp
        timesteps_batch =timesteps_to_eval[time_idx]*np.ones(len(test_triples_input[:,0]))
        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:,0], test_triples_input[:,2],
            timesteps_batch, edge_type=test_triples_input[:,1], split_mode=split_mode)
        pos_samples_batch = test_triples_input[:,2]
        _, perf_list, hits_list = model.predict(history_glist, num_rels, static_graph, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch,
            evaluator, METRIC)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
        # optionally bucket the per-query scores by relation id
        if split_mode == "test":
            if args.log_per_rel:
                for score, rel in zip(perf_list, test_triples_input[:,1].tolist()):
                    perf_per_rel[rel].append(score)
        # reconstruct history graph list: slide the window forward one snapshot
        input_list.pop(0)
        input_list.append(test_snap)
        idx += 1
    if split_mode == "test":
        if args.log_per_rel:
            # aggregate per-relation scores; drop relations without test queries
            for rel in range(num_rels):
                if len(perf_per_rel[rel]) > 0:
                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))
                else:
                    perf_per_rel.pop(rel)
    mrr = np.mean(perf_list_all)
    hits10 = np.mean(hits_list_all)
    return mrr, perf_per_rel, hits10
def run_experiment(args, n_hidden=None, n_layers=None, dropout=None, n_bases=None):
    """
    Run the experiment with the given configuration: train a REGCN model, or
    load a stored one when args.test is set, and evaluate it.

    :param args: arguments
    :param n_hidden: hidden dimension (overrides args when given)
    :param n_layers: number of layers (overrides args when given)
    :param dropout: dropout rate (overrides args when given)
    :param n_bases: number of bases (overrides args when given)
    :return: (mrr, perf_per_rel, hits10)
    """
    # load configuration for grid search the best configuration
    if n_hidden:
        args.n_hidden = n_hidden
    if n_layers:
        args.n_layers = n_layers
    if dropout:
        args.dropout = dropout
    if n_bases:
        args.n_bases = n_bases
    mrr = 0
    hits10 = 0
    # 2) set save model path
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    if not osp.exists(save_model_dir):
        os.mkdir(save_model_dir)
        print('INFO: Create directory {}'.format(save_model_dir))
    model_name = f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}'
    model_state_file = save_model_dir + model_name
    perf_per_rel = {}
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    # this example does not use a static background graph
    num_static_rels, num_words, static_triples, static_graph = 0, 0, [], None
    # create stat
    model = RecurrentRGCNREGCN(args.decoder,
                               args.encoder,
                               num_nodes,
                               int(num_rels/2),
                               num_static_rels,  # DIFFERENT
                               num_words,  # DIFFERENT
                               args.n_hidden,
                               args.opn,
                               sequence_len=args.train_history_len,
                               num_bases=args.n_bases,
                               num_basis=args.n_basis,
                               num_hidden_layers=args.n_layers,
                               dropout=args.dropout,
                               self_loop=args.self_loop,
                               skip_connect=args.skip_connect,
                               layer_norm=args.layer_norm,
                               input_dropout=args.input_dropout,
                               hidden_dropout=args.hidden_dropout,
                               feat_dropout=args.feat_dropout,
                               aggregation=args.aggregation,  # DIFFERENT
                               weight=args.weight,  # DIFFERENT
                               discount=args.discount,  # DIFFERENT
                               angle=args.angle,  # DIFFERENT
                               use_static=args.add_static_graph,  # DIFFERENT
                               entity_prediction=args.entity_prediction,
                               relation_prediction=args.relation_prediction,
                               use_cuda=use_cuda,
                               gpu = args.gpu,
                               analysis=args.run_analysis)  # DIFFERENT
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        model.cuda()
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    if args.test and os.path.exists(model_state_file):
        mrr, perf_per_rel, hits10 = test(model,
                                         train_list+valid_list,
                                         test_list,
                                         num_rels,
                                         num_nodes,
                                         use_cuda,
                                         model_state_file,
                                         static_graph,
                                         "test",
                                         "test")
        return mrr, perf_per_rel, hits10
    elif args.test and not os.path.exists(model_state_file):
        print("--------------{} not exist, Change mode to train and generate stat for testing----------------\n".format(model_state_file))
        # BUGFIX: the original returned only two values here while every caller
        # unpacks three -> ValueError at the call site
        return 0, perf_per_rel, 0
    else:
        print("----------------------------------------start training----------------------------------------\n")
        best_mrr = 0
        best_hits = 0
        for epoch in range(args.n_epochs):
            model.train()
            losses = []
            # visit the training snapshots in random order each epoch
            idx = [_ for _ in range(len(train_list))]
            random.shuffle(idx)
            for train_sample_num in tqdm(idx):
                if train_sample_num == 0: continue  # the first snapshot has no history
                output = train_list[train_sample_num:train_sample_num+1]
                if train_sample_num - args.train_history_len<0:
                    input_list = train_list[0: train_sample_num]
                else:
                    input_list = train_list[train_sample_num - args.train_history_len:
                                            train_sample_num]
                # generate history graph
                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]
                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]
                loss_e, loss_r, loss_static = model.get_loss(history_glist, output[0], static_graph, use_cuda)
                # weighted sum of entity and relation prediction losses
                loss = args.task_weight*loss_e + (1-args.task_weight)*loss_r + loss_static
                losses.append(loss.item())
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients
                optimizer.step()
                optimizer.zero_grad()
            print("Epoch {:04d} | Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} "
                  .format(epoch, np.mean(losses), best_mrr, model_name))
            #! checking GPU usage
            free_mem, total_mem = torch.cuda.mem_get_info()
            print ("--------------GPU memory usage-----------")
            print ("there are ", free_mem, " free memory")
            print ("there are ", total_mem, " total available memory")
            print ("there are ", total_mem - free_mem, " used memory")
            print ("--------------GPU memory usage-----------")
            # validation every args.evaluate_every epochs (skipping epoch 0)
            if epoch and epoch % args.evaluate_every == 0:
                mrr, perf_per_rel, hits10 = test(model, train_list,
                                                 valid_list,
                                                 num_rels,
                                                 num_nodes,
                                                 use_cuda,
                                                 model_state_file,
                                                 static_graph,
                                                 mode="train", split_mode='val')
                if mrr < best_mrr:
                    # no improvement: stop once the epoch budget is exhausted
                    if epoch >= args.n_epochs:
                        break
                else:
                    # improvement: remember the metrics and checkpoint the model
                    best_mrr = mrr
                    best_hits = hits10
                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
        # BUGFIX: return the hits that correspond to the best (checkpointed)
        # epoch; the original returned the last evaluation's hits10 instead
        return best_mrr, perf_per_rel, best_hits
# ==================
# Script entry: load tkgl-polecat, train/evaluate REGCN, save the results.
# ==================
start_overall = timeit.default_timer()
# set hyperparameters
args, _ = get_args_regcn()
args.dataset = 'tkgl-polecat'
SEED = args.seed # set the random seed for consistency
set_random_seed(SEED)
DATA=args.dataset
MODEL_NAME = 'REGCN'
print("logging mrrs per relation: ", args.log_per_rel)

# load data
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
num_rels = dataset.num_rels
num_nodes = dataset.num_nodes
subjects = dataset.full_data["sources"]
objects= dataset.full_data["destinations"]
relations = dataset.edge_type
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1
# quadruple layout: (subject, relation, object, remapped timestamp)
all_quads = np.stack((subjects, relations, objects, timestamps), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]
# original (unmapped) per-split timestamps, needed for getting the negative samples
val_timestamps_orig = list(set(timestamps_orig[dataset.val_mask]))
val_timestamps_orig.sort()
test_timestamps_orig = list(set(timestamps_orig[dataset.test_mask]))
test_timestamps_orig.sort()
# one snapshot (list entry) per timestep
train_list = split_by_time(train_data)
valid_list = split_by_time(val_data)
test_list = split_by_time(test_data)
# evaluation metric
METRIC = dataset.eval_metric
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
#load the ns samples
dataset.load_val_ns()
dataset.load_test_ns()
## run training and testing
val_mrr, test_mrr = 0, 0
test_hits10 = 0
if args.grid_search:
    print("hyperparameter grid search not implemented. Exiting.")
    # BUGFIX: without exiting here, the reporting code below referenced
    # variables that are only defined in the else-branch (NameError)
    sys.exit()
# single run
else:
    start_train = timeit.default_timer()
    if args.test == False: #if they are true: directly test on a previously trained and stored model
        print('start training')
        val_mrr, perf_per_rel, val_hits10 = run_experiment(args) # do training
    start_test = timeit.default_timer()
    args.test = True
    print('start testing')
    test_mrr, perf_per_rel, test_hits10 = run_experiment(args) # do testing
    test_time = timeit.default_timer() - start_test
    all_time = timeit.default_timer() - start_train

print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
# BUGFIX: the original f-string started with "\Train", printing a literal
# backslash; a tab ("\t") was intended, matching the surrounding lines
print(f"\tTrain and Test: Elapsed Time (s): {all_time: .4f}")
print(f"\tTest: {METRIC}: {test_mrr: .4f}")
print(f"\tValid: {METRIC}: {val_mrr: .4f}")

# saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
save_results({'model': MODEL_NAME,
              'data': DATA,
              'run': args.run_nr,
              'seed': SEED,
              f'val {METRIC}': float(val_mrr),
              f'test {METRIC}': float(test_mrr),
              'test_time': test_time,
              'tot_train_val_time': all_time,
              'test_hits10': float(test_hits10)
              },
             results_filename)
# optionally dump the per-relation mrrs computed during testing
if args.log_per_rel:
    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'
    with open(results_filename, 'w') as json_file:
        json.dump(perf_per_rel, json_file)
sys.exit()
================================================
FILE: examples/linkproppred/tkgl-polecat/timetraveler.py
================================================
"""
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting
Reference:
- https://github.com/JHL-HUST/TITer
Haohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021
"""
import sys
import timeit
import torch
from torch.utils.data import Dataset,DataLoader
import logging
import numpy as np
import pickle
from tqdm import tqdm
import os.path as osp
from pathlib import Path
import os
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.timetraveler_agent import Agent
from modules.timetraveler_environment import Env
from modules.timetraveler_dirichlet import Dirichlet, MLE_Dirchlet
from modules.timetraveler_episode import Episode
from modules.timetraveler_policygradient import PG
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator
from modules.timetraveler_trainertester import Trainer, Tester, getRelEntCooccurrence
from tgb.utils.utils import set_random_seed,save_results
from modules.tkg_utils import get_args_timetraveler, reformat_ts, get_model_config_timetraveler
class QuadruplesDataset(Dataset):
    """Thin torch Dataset wrapper around a list of 5-element quadruples.

    This is an internal way how Timetraveler represents the data.
    """
    def __init__(self, examples):
        """
        examples: a list of quadruples.
        num_r: number of relations
        """
        # keep our own copy so later mutation of `examples` does not leak in
        self.quadruples = examples.copy()

    def __len__(self):
        return len(self.quadruples)

    def __getitem__(self, item):
        # return the five components as a tuple
        quad = self.quadruples[item]
        return quad[0], quad[1], quad[2], quad[3], quad[4]
def set_logger(save_path):
    """Write logs to checkpoint and console"""
    # pick the log file name from the global run mode
    if args.do_train:
        log_file = os.path.join(save_path, 'train.log')
    else:
        log_file = os.path.join(save_path, 'test.log')
    fmt = '%(asctime)s %(levelname)-8s %(message)s'
    logging.basicConfig(
        level=logging.INFO,
        format=fmt,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'
    )
    # mirror every log record to the console as well
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter(fmt))
    logging.getLogger('').addHandler(console)
def preprocess_data(args, config, timestamps, save_path, all_quads):
    """
    Precompute the state-action space for every (entity, timestamp) pair seen
    in the data and save it (pickle dump).
    """
    env = Env(all_quads, config)
    state_actions_space = {}
    with tqdm(total=len(all_quads)) as bar:
        for (head, rel, tail, t, _) in all_quads:
            # compute both the forward (True) and backward (False) action
            # spaces once per (entity, timestamp) pair
            for entity in (head, tail):
                if (entity, t, True) not in state_actions_space:
                    state_actions_space[(entity, t, True)] = env.get_state_actions_space_complete(
                        entity, t, True, args.store_actions_num)
                    state_actions_space[(entity, t, False)] = env.get_state_actions_space_complete(
                        entity, t, False, args.store_actions_num)
            bar.update(1)
    pickle.dump(state_actions_space, open(os.path.join(save_path, args.state_actions_path), 'wb'))
def log_metrics(mode, step, metrics):
    """Print the evaluation logs"""
    # one info line per metric: "<mode> <name> at epoch <step>: <value>"
    for name, value in metrics.items():
        logging.info('%s %s at epoch %d: %f' % (mode, name, step, value))
def main(args):
    """
    Train and evaluate the TimeTraveler model on tkgl-yago.

    Parameters:
        args: parsed command-line arguments (see get_args_timetraveler).

    Side effects:
        - writes checkpoints under ./saved_models/
        - writes a JSON result summary under ./saved_results/
    """
    start_overall = timeit.default_timer()
    # Fallback so that 'tot_train_val_time' is still defined when --do_train
    # is off: previously start_train was only bound inside the training
    # branch, raising NameError in a test-only run.
    start_train = start_overall
    ####################### Set Logger #################################
    save_path = f'{os.path.dirname(os.path.abspath(__file__))}/saved_models/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # Only enable CUDA when both requested and actually available.
    if args.cuda and torch.cuda.is_available():
        args.cuda = True
    else:
        args.cuda = False
    set_logger(save_path)
    ####################### Create DataLoader #################################
    # set hyperparameters
    args.dataset = 'tkgl-yago'
    SEED = args.seed  # set the random seed for consistency
    set_random_seed(SEED)
    DATA = args.dataset
    MODEL_NAME = 'TIMETRAVELER'
    # load data
    dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
    num_rels = dataset.num_rels
    num_nodes = dataset.num_nodes
    subjects = dataset.full_data["sources"]
    objects = dataset.full_data["destinations"]
    relations = dataset.edge_type
    timestamps_orig = dataset.full_data["timestamps"]
    timestamps = reformat_ts(timestamps_orig, DATA)  # stepsize:1
    # Quintuples: (head, relation, tail, remapped ts, original ts).
    all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
    train_data = all_quads[dataset.train_mask]
    train_entities = np.unique(np.concatenate((train_data[:, 0], train_data[:, 2])))
    RelEntCooccurrence = getRelEntCooccurrence(train_data, num_rels)
    train_data = QuadruplesDataset(train_data)
    val_data = QuadruplesDataset(all_quads[dataset.val_mask])
    test_data = QuadruplesDataset(all_quads[dataset.test_mask])
    METRIC = dataset.eval_metric
    evaluator = Evaluator(name=DATA)
    neg_sampler = dataset.negative_sampler
    # load the pre-generated negative samples for evaluation
    dataset.load_val_ns()
    dataset.load_test_ns()
    train_dataloader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )
    valid_dataloader = DataLoader(
        val_data,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )
    test_dataloader = DataLoader(
        test_data,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )
    ###################### Create the agent and the environment ###########################
    config = get_model_config_timetraveler(args, num_nodes, num_rels)
    logging.info(config)
    logging.info(args)
    # create the agent
    agent = Agent(config)
    # create the environment
    state_actions_path = os.path.join(save_path, args.state_actions_path)
    ###################### preprocessing ###########################
    if not os.path.exists(state_actions_path):
        if args.preprocess:
            print("preprocessing data...")
            preprocess_data(args, config, timestamps, save_path, list(all_quads))
            state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))
        else:
            state_action_space = None
    else:
        print("load preprocessed data...")
        state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))
    env = Env(list(all_quads), config, state_action_space)
    # Create episode controller
    episode = Episode(env, agent, config)
    if args.cuda:
        episode = episode.cuda()
    pg = PG(config)  # Policy Gradient
    optimizer = torch.optim.Adam(episode.parameters(), lr=args.lr, weight_decay=0.00001)
    ###################### Reward Shaping: MLE DIRICHLET alphas ###########################
    if args.reward_shaping:
        try:
            print("load alphas from pickle file")
            alphas = pickle.load(open(os.path.join(save_path, args.alphas_pkl), 'rb'))
        except (OSError, EOFError, pickle.UnpicklingError):
            # Cached alphas are missing or unreadable: fit the Dirichlet
            # parameters now and cache them for future runs. (Previously a
            # bare `except:`, which also swallowed KeyboardInterrupt.)
            print('running MLE dirichlet now')
            mle_d = MLE_Dirchlet(all_quads, num_rels, args.k, args.time_span,
                                 args.tol, args.method, args.maxiter)
            pickle.dump(mle_d.alphas, open(os.path.join(save_path, args.alphas_pkl), 'wb'))
            print('dumped alphas')
            alphas = mle_d.alphas
        distributions = Dirichlet(alphas, args.k)
    else:
        distributions = None
    ###################### Training and Testing ###########################
    trainer = Trainer(episode, pg, optimizer, args, distributions)
    tester = Tester(episode, args, train_entities, RelEntCooccurrence, dataset.metric)
    test_metrics = {}
    val_metrics = {}
    test_metrics[METRIC] = None
    val_metrics[METRIC] = None
    if args.do_train:
        start_train = timeit.default_timer()
        logging.info('Start Training......')
        for i in range(args.max_epochs):
            loss, reward = trainer.train_epoch(train_dataloader, len(train_data))
            logging.info('Epoch {}/{} Loss: {}, reward: {}'.format(i, args.max_epochs, loss, reward))
            #! checking GPU usage
            free_mem, total_mem = torch.cuda.mem_get_info()
            print ("--------------GPU memory usage-----------")
            print ("there are ", free_mem, " free memory")
            print ("there are ", total_mem, " total available memory")
            print ("there are ", total_mem - free_mem, " used memory")
            print ("--------------GPU memory usage-----------")
            if i % args.save_epoch == 0 and i != 0:
                trainer.save_model(save_path, 'checkpoint_{}.pth'.format(i))
                logging.info('Save Model in {}'.format(save_path))
            if i % args.valid_epoch == 0 and i != 0:
                logging.info('Start Val......')
                val_metrics = tester.test(valid_dataloader,
                                          len(val_data), num_nodes, neg_sampler, evaluator, split_mode='val')
                for mode in val_metrics.keys():
                    logging.info('{} at epoch {}: {}'.format(mode, i, val_metrics[mode]))
        trainer.save_model(save_path)
        logging.info('Save Model in {}'.format(save_path))
    else:
        # Load pretrained parameters.
        # NOTE(review): save_path is a directory, so os.path.isfile(save_path)
        # is always False and this branch never loads a checkpoint — a concrete
        # checkpoint file path is presumably intended; confirm against how
        # trainer.save_model names its output files.
        if os.path.isfile(save_path):
            params = torch.load(save_path)
            episode.load_state_dict(params['model_state_dict'])
            optimizer.load_state_dict(params['optimizer_state_dict'])
            logging.info('Load pretrain model: {}'.format(save_path))
    if args.do_test:
        logging.info('Start Testing......')
        start_test = timeit.default_timer()
        test_metrics = tester.test(test_dataloader,
                                   len(test_data), num_nodes, neg_sampler, evaluator, split_mode='test')
        for mode in test_metrics.keys():
            logging.info('Test {} : {}'.format(mode, test_metrics[mode]))
        # saving the results...
        results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
        if not osp.exists(results_path):
            os.mkdir(results_path)
            print('INFO: Create directory {}'.format(results_path))
        Path(results_path).mkdir(parents=True, exist_ok=True)
        results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
        test_time = timeit.default_timer() - start_test
        all_time = timeit.default_timer() - start_train
        all_time_preprocess = timeit.default_timer() - start_overall
        save_results({'model': MODEL_NAME,
                      'data': DATA,
                      'seed': SEED,
                      f'val {METRIC}': float(val_metrics[METRIC]),
                      f'test {METRIC}': float(test_metrics[METRIC]),
                      'test_time': test_time,
                      'tot_train_val_time': all_time,
                      'tot_preprocess_train_val_time': all_time_preprocess
                      },
                     results_filename)
if __name__ == '__main__':
    # Entry point: parse CLI arguments and launch training/evaluation.
    main(get_args_timetraveler())
================================================
FILE: examples/linkproppred/tkgl-polecat/tkgl-polecat_example.py
================================================
"""
Example: load the tkgl-polecat dataset with PyG and iterate over the
pre-generated negative samples of the validation and test splits.
"""
import os.path as osp
import sys
# Make the repository root importable regardless of the working directory.
# (The previous insert used the absolute path '/../../../', which resolves
# to '/' and never pointed at the repository.)
sys.path.insert(0, osp.abspath(osp.join(osp.dirname(__file__), '..', '..', '..')))
import numpy as np
import timeit
from tqdm import tqdm
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from torch_geometric.loader import TemporalDataLoader
from tgb.linkproppred.evaluate import Evaluator

DATA = "tkgl-polecat"

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
metric = dataset.eval_metric
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
print ("there are {} nodes and {} edges".format(dataset.num_nodes, dataset.num_edges))
print ("there are {} relation types".format(dataset.num_rels))

# convenience views over the full temporal data
timestamp = data.t
head = data.src
tail = data.dst
edge_type = data.edge_type  # relation

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

BATCH_SIZE = 200
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# load the ns samples first, then query them batch by batch
start_time = timeit.default_timer()
dataset.load_val_ns()
for batch in tqdm(val_loader):
    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type
    # the sampler expects numpy arrays, not tensors
    neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='val')
print ("loading ns samples from validation", timeit.default_timer() - start_time)

start_time = timeit.default_timer()
dataset.load_test_ns()
for batch in test_loader:
    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type
    neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='test')
print ("loading ns samples from test", timeit.default_timer() - start_time)
print ("retrieved all negative samples")
================================================
FILE: examples/linkproppred/tkgl-polecat/tlogic.py
================================================
"""
https://github.com/liu-yushan/TLogic/tree/main/mycode
TLogic: Temporal Logical Rules for Explainable Link Forecasting on Temporal Knowledge Graphs.
Yushan Liu, Yunpu Ma, Marcel Hildebrandt, Mitchell Joblin, Volker Tresp
"""
# imports
import sys
import os
import os.path as osp
from pathlib import Path
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
import timeit
import argparse
import numpy as np
import json
from joblib import Parallel, delayed
import itertools
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
from modules.tlogic_learn_modules import Temporal_Walk, Rule_Learner, store_edges
import modules.tlogic_apply_modules as ra
from tgb.utils.utils import set_random_seed, save_results
from modules.tkg_utils import reformat_ts, get_inv_relation_id, create_scores_array
def learn_rules(i, num_relations):
    """
    Learn rules (multiprocessing possible).

    Relies on module-level globals: all_relations, rule_lengths, num_walks,
    temporal_walk and rl (the Rule_Learner instance).

    Parameters:
        i (int): process number
        num_relations (int): minimum number of relations for each process

    Returns:
        rl.rules_dict (dict): rules dictionary
    """
    # The last process also handles the remainder of the relations.
    num_rest_relations = len(all_relations) - (i + 1) * num_relations
    if num_rest_relations >= num_relations:
        relations_idx = range(i * num_relations, (i + 1) * num_relations)
    else:
        relations_idx = range(i * num_relations, len(all_relations))
    num_rules = [0]
    for k in relations_idx:
        rel = all_relations[k]
        for length in rule_lengths:
            it_start = timeit.default_timer()
            for _ in range(num_walks):
                walk_successful, walk = temporal_walk.sample_walk(length + 1, rel)
                if walk_successful:
                    rl.create_rule(walk)
            it_end = timeit.default_timer()
            it_time = round(it_end - it_start, 6)
            # Cumulative rule count; halved presumably because each rule is
            # also stored under the inverse relation — confirm in Rule_Learner.
            # (Iterate values() directly instead of unused items() keys.)
            num_rules.append(sum(len(v) for v in rl.rules_dict.values()) // 2)
            num_new_rules = num_rules[-1] - num_rules[-2]
            print(
                "Process {0}: relation {1}/{2}, length {3}: {4} sec, {5} rules".format(
                    i,
                    k - relations_idx[0] + 1,
                    len(relations_idx),
                    length,
                    it_time,
                    num_new_rules,
                )
            )
    return rl.rules_dict
def apply_rules(i, num_queries, rules_dict, neg_sampler, data, window, learn_edges, all_quads, args, split_mode,
                log_per_rel=False, num_rels=0):
    """
    Apply rules (multiprocessing possible).

    Relies on module-level globals: score_func, top_k, num_nodes, evaluator.

    Parameters:
        i (int): process number
        num_queries (int): minimum number of queries for each process
        rules_dict (dict): relation id -> learned rules
        neg_sampler: TGB negative sampler for this dataset
        data: quintuples (head, rel, tail, remapped ts, original ts) of the split
        window (int): time window for edge retrieval (0 = use all history)
        learn_edges (dict): training edges grouped per relation
        all_quads: all quintuples of the dataset
        args (list): list of score-function parameter pairs [lambda, a]
        split_mode (str): 'val' or 'test' (selects which negative samples to query)
        log_per_rel (bool): collect per-relation MRR (test split only)
        num_rels (int): number of relations (for per-relation logging)

    Returns:
        perf_list (list): performance list (mrr per sample)
        hits_list (list): hits list (hits@10 per sample)
        perf_per_rel (dict): per-relation mean MRR (filled on test split only)
    """
    perf_per_rel = {}
    for rel in range(num_rels):
        perf_per_rel[rel] = []
    print("Start process", i, "...")
    # One candidate dict per score-function parameter set in `args`.
    all_candidates = [dict() for _ in range(len(args))]
    no_cands_counter = 0
    # The last process also handles the remainder of the queries.
    num_rest_queries = len(data) - (i + 1) * num_queries
    if num_rest_queries >= num_queries:
        test_queries_idx = range(i * num_queries, (i + 1) * num_queries)
    else:
        test_queries_idx = range(i * num_queries, len(data))
    cur_ts = data[test_queries_idx[0]][3]
    # Edges usable at the current timestamp, restricted to `window`.
    edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)
    it_start = timeit.default_timer()
    hits_list = [0] * len(test_queries_idx)
    perf_list = [0] * len(test_queries_idx)
    for index, j in enumerate(test_queries_idx):
        # Negative destinations for this (head, tail, original-ts, rel) query.
        neg_sample_el = neg_sampler.query_batch(np.expand_dims(np.array(data[j,0]), axis=0),
                                                np.expand_dims(np.array(data[j,2]), axis=0),
                                                np.expand_dims(np.array(data[j,4]), axis=0),
                                                np.expand_dims(np.array(data[j,1]), axis=0),
                                                split_mode=split_mode)[0]
        # neg_samples_batch[j]
        pos_sample_el = data[j,2]
        test_query = data[j]
        assert pos_sample_el == test_query[2]
        cands_dict = [dict() for _ in range(len(args))]
        if test_query[3] != cur_ts:
            # New timestamp reached: refresh the window of usable edges.
            cur_ts = test_query[3]
            edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)
        if test_query[1] in rules_dict:
            dicts_idx = list(range(len(args)))
            for rule in rules_dict[test_query[1]]:
                walk_edges = ra.match_body_relations(rule, edges, test_query[0])
                if 0 not in [len(x) for x in walk_edges]:
                    rule_walks = ra.get_walks(rule, walk_edges)
                    if rule["var_constraints"]:
                        rule_walks = ra.check_var_constraints(
                            rule["var_constraints"], rule_walks
                        )
                    if not rule_walks.empty:
                        cands_dict = ra.get_candidates(
                            rule,
                            rule_walks,
                            cur_ts,
                            cands_dict,
                            score_func,
                            args,
                            dicts_idx,
                        )
                        for s in dicts_idx:
                            # Keep each candidate's score list sorted descending,
                            # then order candidates by their score lists.
                            cands_dict[s] = {
                                x: sorted(cands_dict[s][x], reverse=True)
                                for x in cands_dict[s].keys()
                            }
                            cands_dict[s] = dict(
                                sorted(
                                    cands_dict[s].items(),
                                    key=lambda item: item[1],
                                    reverse=True,
                                )
                            )
                            # Stop applying further rules for this parameter set
                            # once top_k distinct scores have been collected.
                            top_k_scores = [v for _, v in cands_dict[s].items()][:top_k]
                            unique_scores = list(
                                scores for scores, _ in itertools.groupby(top_k_scores)
                            )
                            if len(unique_scores) >= top_k:
                                dicts_idx.remove(s)
                        if not dicts_idx:
                            break
            if cands_dict[0]:
                for s in range(len(args)):
                    # Calculate noisy-or scores
                    scores = list(
                        map(
                            lambda x: 1 - np.product(1 - np.array(x)),
                            cands_dict[s].values(),
                        )
                    )
                    cands_scores = dict(zip(cands_dict[s].keys(), scores))
                    noisy_or_cands = dict(
                        sorted(cands_scores.items(), key=lambda x: x[1], reverse=True)
                    )
                    all_candidates[s][j] = noisy_or_cands
            else: # No candidates found by applying rules
                no_cands_counter += 1
                for s in range(len(args)):
                    all_candidates[s][j] = dict()
        else: # No rules exist for this relation
            no_cands_counter += 1
            for s in range(len(args)):
                all_candidates[s][j] = dict()
        if not (j - test_queries_idx[0] + 1) % 100:
            it_end = timeit.default_timer()
            it_time = round(it_end - it_start, 6)
            print(
                "Process {0}: test samples finished: {1}/{2}, {3} sec".format(
                    i, j - test_queries_idx[0] + 1, len(test_queries_idx), it_time
                )
            )
            it_start = timeit.default_timer()
        # NOTE(review): `s` here is the leftover value of the loops above
        # (the last parameter set, i.e. always 0 when len(args) == 1 as in
        # this script) — confirm intent if more than one parameter set is
        # ever passed in `args`.
        predictions = create_scores_array(all_candidates[s][j], num_nodes)
        predictions_of_interest_pos = np.array(predictions[pos_sample_el])
        predictions_of_interest_neg = predictions[neg_sample_el]
        input_dict = {
            "y_pred_pos": predictions_of_interest_pos,
            "y_pred_neg": predictions_of_interest_neg,
            "eval_metric": ['mrr'],
        }
        predictions = evaluator.eval(input_dict)
        perf_list[index] = predictions['mrr']
        hits_list[index] = predictions['hits@10']
        if split_mode == "test":
            if log_per_rel:
                perf_per_rel[test_query[1]].append(perf_list[index]) #test_query[1] is the relation index
    if split_mode == "test":
        if log_per_rel:
            # Collapse each relation's score list into its mean; drop
            # relations that received no test queries.
            for rel in range(num_rels):
                if len(perf_per_rel[rel]) > 0:
                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))
                else:
                    perf_per_rel.pop(rel)
    return perf_list, hits_list, perf_per_rel
## args
def get_args(arg_list=None):
    """Parse command-line arguments for TLogic.

    Parameters:
        arg_list (list[str] | None): explicit argument list; None (the
            default) falls back to sys.argv, preserving the old behavior.

    Returns:
        dict: parsed arguments as a name -> value mapping.
    """
    def str2bool(value):
        # argparse's `type=bool` treats any non-empty string — including
        # "False" — as True; parse textual booleans explicitly instead.
        if isinstance(value, bool):
            return value
        if value.lower() in ("true", "t", "yes", "y", "1"):
            return True
        if value.lower() in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("invalid boolean value: {!r}".format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", "-d", default="tkgl-polecat", type=str)
    parser.add_argument("--rule_lengths", "-l", default="1", type=int, nargs="+")
    parser.add_argument("--num_walks", "-n", default="100", type=int)
    parser.add_argument("--transition_distr", default="exp", type=str)
    parser.add_argument("--window", "-w", default=0, type=int)  # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep
    parser.add_argument("--top_k", default=20, type=int)
    parser.add_argument("--num_processes", "-p", default=1, type=int)
    parser.add_argument("--alpha", "-alpha", default=0.99, type=float)  # fix alpha. used if trainflag == false
    parser.add_argument("--save_config", "-c", default=True)  # do we need to save the selection of lambda and alpha in config file?
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--run_nr', type=int, help='Run Number', default=1)
    parser.add_argument('--learn_rules_flag', type=str2bool, help='Do we want to learn the rules', default=True)
    parser.add_argument('--rule_filename', type=str, help='if rules not learned: where are they stored', default='0_r[3]_n100_exp_s1_rules.json')
    parser.add_argument('--log_per_rel', type=str2bool, help='Do we want to log mrr per relation', default=False)
    parser.add_argument('--compute_valid_mrr', type=str2bool, help='Do we want to compute mrr for valid set', default=True)
    parsed = vars(parser.parse_args(arg_list))
    return parsed
# ---- script body: learn TLogic rules, apply them, evaluate, save results ----
start_o = timeit.default_timer()
## get args
parsed = get_args()
dataset = parsed["dataset"]
rule_lengths = parsed["rule_lengths"]
# argparse may yield a single int (default) or a list (nargs="+"); normalize.
rule_lengths = [rule_lengths] if (type(rule_lengths) == int) else rule_lengths
num_walks = parsed["num_walks"]
transition_distr = parsed["transition_distr"]
num_processes = parsed["num_processes"]
window = parsed["window"]
top_k = parsed["top_k"]
log_per_rel = parsed['log_per_rel']
MODEL_NAME = 'TLogic'
SEED = parsed['seed'] # set the random seed for consistency
set_random_seed(SEED)
## load dataset and prepare it accordingly
name = parsed["dataset"]
compute_valid_mrr = parsed["compute_valid_mrr"]
dataset = LinkPropPredDataset(name=name, root="datasets", preprocess=True)
DATA = name
relations = dataset.edge_type
num_rels = dataset.num_rels
subjects = dataset.full_data["sources"]
objects= dataset.full_data["destinations"]
num_nodes = dataset.num_nodes
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1
# quintuples: (head, relation, tail, remapped ts, original ts)
all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
train_data = all_quads[dataset.train_mask,0:4] # we do not need the original timestamps
val_data = all_quads[dataset.val_mask,0:5]
test_data = all_quads[dataset.test_mask,0:5]
all_data = all_quads[:,0:4]
metric = dataset.eval_metric
evaluator = Evaluator(name=name)
neg_sampler = dataset.negative_sampler
inv_relation_id = get_inv_relation_id(num_rels)
#load the ns samples
dataset.load_val_ns()
dataset.load_test_ns()
output_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
learn_rules_flag = parsed['learn_rules_flag']
## 1. learn rules
start_train = timeit.default_timer()
if learn_rules_flag:
    print("start learning rules")
    # edges (dict): edges for each relation
    # inv_relation_id (dict): mapping of relation to inverse relation
    temporal_walk = Temporal_Walk(train_data, inv_relation_id, transition_distr)
    rl = Rule_Learner(edges=temporal_walk.edges, id2relation=None, inv_relation_id=inv_relation_id,
                      output_dir=output_dir)
    all_relations = sorted(temporal_walk.edges) # Learn for all relations
    start = timeit.default_timer()
    # split the relations evenly across the worker processes
    num_relations = len(all_relations) // num_processes
    output = Parallel(n_jobs=num_processes)(
        delayed(learn_rules)(i, num_relations) for i in range(num_processes)
    )
    end = timeit.default_timer()
    # merge the per-process rule dictionaries into one
    all_rules = output[0]
    for i in range(1, num_processes):
        all_rules.update(output[i])
    total_time = round(end - start, 6)
    print("Learning finished in {} seconds.".format(total_time))
    rl.rules_dict = all_rules
    rl.sort_rules_dict()
    # persist the learned rules; returns the file name used
    rule_filename = rl.save_rules(0, rule_lengths, num_walks, transition_distr, SEED)
    # rl.save_rules_verbalized(0, rule_lengths, num_walks, transition_distr, seed)
    # rules_statistics(rl.rules_dict)
else:
    rule_filename = parsed['rule_filename']
    print("Loading rules from file {}".format(parsed['rule_filename']))
end_train = timeit.default_timer()
## 2. Apply rules
rules_dict = json.load(open(output_dir + rule_filename))
rules_dict = {int(k): v for k, v in rules_dict.items()}
rules_dict = ra.filter_rules(
    rules_dict, min_conf=0.01, min_body_supp=2, rule_lengths=rule_lengths
) # filter rules for minimum confidence, body support and rule length
learn_edges = store_edges(train_data)
score_func = ra.score_12
# It is possible to specify a list of list of arguments for tuning
args = [[0.1, 0.5]]
# compute valid mrr
start_valid = timeit.default_timer()
if compute_valid_mrr:
    print('Computing valid MRR')
    num_queries = len(val_data) // num_processes
    output = Parallel(n_jobs=num_processes)(
        delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, val_data, window, learn_edges,
                             all_quads, args, split_mode='val') for i in range(num_processes))
    end = timeit.default_timer()
    # concatenate the per-process (mrr, hits) lists
    perf_list_val = []
    hits_list_val = []
    for i in range(num_processes):
        perf_list_val.extend(output[i][0])
        hits_list_val.extend(output[i][1])
else:
    perf_list_val = [0]
    hits_list_val = [0]
end_valid = timeit.default_timer()
# compute test mrr
if log_per_rel ==True:
    num_processes = 1 #otherwise logging per rel does not work for our implementation
start_test = timeit.default_timer()
print('Computing test MRR')
start = timeit.default_timer()
num_queries = len(test_data) // num_processes
output = Parallel(n_jobs=num_processes)(
    delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, test_data, window, learn_edges,
                         all_quads, args, split_mode='test', log_per_rel=log_per_rel, num_rels=num_rels) for i in range(num_processes))
end = timeit.default_timer()
perf_list_all = []
hits_list_all = []
for i in range(num_processes):
    perf_list_all.extend(output[i][0])
    hits_list_all.extend(output[i][1])
if log_per_rel == True:
    perf_per_rel = output[0][2]
total_time = round(end - start, 6)
total_valid_time = round(end_valid - start_valid, 6)
print("Application finished in {} seconds.".format(total_time))
print(f"The valid MRR is {np.mean(perf_list_val)}")
print(f"The MRR is {np.mean(perf_list_all)}")
print(f"The Hits@10 is {np.mean(hits_list_all)}")
print(f"We have {len(perf_list_all)} predictions")
print(f"The test set has len {len(test_data)} ")
end_o = timeit.default_timer()
train_time_o = round(end_train- start_train, 6)
test_time_o = round(end_o- start_test, 6)
total_time_o = round(end_o- start_o, 6)
print("Running Training to find best configs finished in {} seconds.".format(train_time_o))
print("Running testing with best configs finished in {} seconds.".format(test_time_o))
print("Running all steps finished in {} seconds.".format(total_time_o))
# persist the evaluation summary
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
if log_per_rel == True:
    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'
    with open(results_filename, 'w') as json_file:
        json.dump(perf_per_rel, json_file)
results_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'
metric = dataset.eval_metric
save_results({'model': MODEL_NAME,
              'train_flag': None,
              'rule_len': rule_lengths,
              'window': window,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: float(np.mean(perf_list_all)),
              'hits10': float(np.mean(hits_list_all)),
              'val_mrr': float(np.mean(perf_list_val)),
              'test_time': test_time_o,
              'tot_train_val_time': total_time_o,
              'valid_time': total_valid_time
              },
             results_filename)
================================================
FILE: examples/linkproppred/tkgl-smallpedia/cen.py
================================================
"""
Complex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning
Reference:
- https://github.com/Lee-zix/CEN
Zixuan Li, Saiping Guan, Xiaolong Jin, Weihua Peng, Yajuan Lyu , Yong Zhu, Long Bai, Wei Li, Jiafeng Guo, Xueqi Cheng.
Complex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning. ACL 2022.
"""
import timeit
import os
import sys
import os.path as osp
from pathlib import Path
import numpy as np
import torch
import random
from tqdm import tqdm
import json
# internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.rrgcn import RecurrentRGCNCEN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from modules.tkg_utils import get_args_cen, reformat_ts
from modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
def test(model, history_len, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, mode, split_mode):
    """
    Test the CEN model on one split, one snapshot at a time.

    Relies on module-level globals: test_timestamps_orig, val_timestamps_orig,
    neg_sampler, evaluator, METRIC, args.

    :param model: model used to test
    :param history_len: number of history snapshots fed to the model
    :param history_list: all input history snap shot list, not include output label train list or valid list
    :param test_list: test triple snap shot list
    :param num_rels: number of relations
    :param num_nodes: number of nodes
    :param use_cuda: whether to run on GPU
    :param model_name: checkpoint path to load when mode == "test"
    :param mode: "test" loads the checkpoint from file; otherwise the
        in-memory model parameters are used (validation during training)
    :param split_mode: 'test' or 'val' to state which negative samples to load
    :return: (mrr, perf_per_rel, hits10)
    """
    print("Testing for mode: ", split_mode)
    # Pick the original (un-remapped) timestamps matching the split.
    if split_mode == 'test':
        timesteps_to_eval = test_timestamps_orig
    else:
        timesteps_to_eval = val_timestamps_orig
    idx = 0
    if mode == "test":
        # test mode: load parameter form file
        if use_cuda:
            checkpoint = torch.load(model_name, map_location=torch.device(args.gpu))
        else:
            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))
        # use best stat checkpoint:
        print("Load Model name: {}. Using best epoch : {}".format(model_name, checkpoint['epoch']))
        print("\n"+"-"*10+"start testing"+"-"*10+"\n")
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    # Seed the rolling history window with the most recent snapshots.
    input_list = [snap for snap in history_list[-history_len:]]
    perf_list_all = []
    hits_list_all = []
    perf_per_rel = {}
    for rel in range(num_rels):
        perf_per_rel[rel] = []
    for time_idx, test_snap in enumerate(tqdm(test_list)):
        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]
        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)
        # All triples in a snapshot share the same original timestamp.
        timesteps_batch =timesteps_to_eval[time_idx]*np.ones(len(test_triples_input[:,0]))
        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:,0], test_triples_input[:,2],
                                                    timesteps_batch, edge_type=test_triples_input[:,1], split_mode=split_mode)
        pos_samples_batch = test_triples_input[:,2]
        _, perf_list, hits_list = model.predict(history_glist, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch,
                                                evaluator, METRIC)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
        if split_mode == "test":
            if args.log_per_rel:
                for score, rel in zip(perf_list, test_triples_input[:,1].tolist()):
                    perf_per_rel[rel].append(score)
        # reconstruct history graph list
        input_list.pop(0)
        input_list.append(test_snap)
        idx += 1
    if split_mode == "test":
        if args.log_per_rel:
            # Collapse each relation's score list into its mean; drop
            # relations with no test triples.
            for rel in range(num_rels):
                if len(perf_per_rel[rel]) > 0:
                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))
                else:
                    perf_per_rel.pop(rel)
    mrr = np.mean(perf_list_all)
    hits10 = np.mean(hits_list_all)
    return mrr, perf_per_rel, hits10
def run_experiment(args, trainvalidtest_id=0, n_hidden=None, n_layers=None, dropout=None, n_bases=None):
'''
Run experiment for CEN model
:param args: arguments for the model
:param trainvalidtest_id: -1: pretrainig, 0: curriculum training (to find best test history len), 1: test on valid set, 2: test on test set
:param n_hidden: number of hidden units
:param n_layers: number of layers
:param dropout: dropout rate
:param n_bases: number of bases
return: mrr, perf_per_rel: mean reciprocal rank and performance per relation
'''
# 1) load configuration for grid search the best configuration
if n_hidden:
args.n_hidden = n_hidden
if n_layers:
args.n_layers = n_layers
if dropout:
args.dropout = dropout
if n_bases:
args.n_bases = n_bases
test_history_len = args.test_history_len
# 2) set save model path
save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
if not osp.exists(save_model_dir):
os.mkdir(save_model_dir)
print('INFO: Create directory {}'.format(save_model_dir))
test_model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.test_history_len}'
test_state_file = save_model_dir+test_model_name
perf_per_rel ={}
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
# create stat
model = RecurrentRGCNCEN(args.decoder,
args.encoder,
num_nodes,
num_rels,
args.n_hidden,
args.opn,
sequence_len=args.train_history_len,
num_bases=args.n_bases,
num_basis=args.n_basis,
num_hidden_layers=args.n_layers,
dropout=args.dropout,
self_loop=args.self_loop,
skip_connect=args.skip_connect,
layer_norm=args.layer_norm,
input_dropout=args.input_dropout,
hidden_dropout=args.hidden_dropout,
feat_dropout=args.feat_dropout,
entity_prediction=args.entity_prediction,
relation_prediction=args.relation_prediction,
use_cuda=use_cuda,
gpu = args.gpu)
if use_cuda:
torch.cuda.set_device(args.gpu)
model.cuda()
# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
if trainvalidtest_id == 1: # normal test on validation set Note that mode=test
if os.path.exists(test_state_file):
mrr, _, hits10 = test(model, args.test_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
test_state_file, "test", split_mode="val")
else:
print('Cannot do testing because model does not exist: ', test_state_file)
mrr = 0
hits10 = 0
elif trainvalidtest_id == 2: # normal test on test set
if os.path.exists(test_state_file):
mrr, perf_per_rel, hits10 = test(model, args.test_history_len, train_list+valid_list, test_list, num_rels, num_nodes, use_cuda,
test_state_file, "test", split_mode="test")
else:
print('Cannot do testing because model does not exist: ', test_state_file)
mrr = 0
hits10 = 0
elif trainvalidtest_id == -1:
print("-------------start pre training model with history length {}----------\n".format(args.start_history_len))
model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'
model_state_file = save_model_dir + model_name
print("Sanity Check: stat name : {}".format(model_state_file))
print("Sanity Check: Is cuda available ? {}".format(torch.cuda.is_available()))
best_mrr = 0
best_epoch = 0
best_hits10= 0
## training loop
for epoch in range(args.n_epochs):
model.train()
losses = []
idx = [_ for _ in range(len(train_list))]
random.shuffle(idx)
for train_sample_num in idx:
if train_sample_num == 0 or train_sample_num == 1: continue
if train_sample_num - args.start_history_len<0:
input_list = train_list[0: train_sample_num]
output = train_list[1:train_sample_num+1]
else:
input_list = train_list[train_sample_num-args.start_history_len: train_sample_num]
output = train_list[train_sample_num-args.start_history_len+1:train_sample_num+1]
# generate history graph
history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]
output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]
loss= model.get_loss(history_glist, output[-1], None, use_cuda)
losses.append(loss.item())
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm) # clip gradients
optimizer.step()
optimizer.zero_grad()
print("His {:04d}, Epoch {:04d} | Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} "
.format(args.start_history_len, epoch, np.mean(losses), best_mrr, model_name))
#! checking GPU usage
free_mem, total_mem = torch.cuda.mem_get_info()
print ("--------------GPU memory usage-----------")
print ("there are ", free_mem, " free memory")
print ("there are ", total_mem, " total available memory")
print ("there are ", total_mem - free_mem, " used memory")
print ("--------------GPU memory usage-----------")
# validation
if epoch % args.evaluate_every == 0:
mrr, _, hits10 = test(model, args.start_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
model_state_file, mode="train", split_mode= "val")
if mrr< best_mrr:
if epoch >= args.n_epochs or epoch - best_epoch > 5:
break
else:
best_mrr = mrr
best_epoch = epoch
best_hits10 = hits10
torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
elif trainvalidtest_id == 0: #curriculum training
model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'
init_state_file = save_model_dir + model_name
init_checkpoint = torch.load(init_state_file, map_location=torch.device(args.gpu))
# use best stat checkpoint:
print("Load Previous Model name: {}. Using best epoch : {}".format(init_state_file, init_checkpoint['epoch']))
print("\n"+"-"*10+"Load model with history length {}".format(args.start_history_len)+"-"*10+"\n")
model.load_state_dict(init_checkpoint['state_dict'])
test_history_len = args.start_history_len
mrr, _, hits10 = test(model,
args.start_history_len,
train_list,
valid_list,
num_rels,
num_nodes,
use_cuda,
init_state_file,
mode="test", split_mode= "val")
best_mrr_list = [mrr.item()]
best_hits_list = [hits10.item()]
# start knowledge distillation
ks_idx = 0
for history_len in range(args.start_history_len+1, args.train_history_len+1, 1):
# current model
print("best mrr list :", best_mrr_list)
# lr = 0.1*args.lr - 0.002*args.lr*ks_idx
optimizer = torch.optim.Adam(model.parameters(), lr=0.1*args.lr, weight_decay=0.00001)
model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len}'
model_state_file = save_model_dir + model_name
print("Sanity Check: stat name : {}".format(model_state_file))
# load model with the least history length
prev_model_name= f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len-1}'
prev_state_file = save_model_dir + prev_model_name
checkpoint = torch.load(prev_state_file, map_location=torch.device(args.gpu))
model.load_state_dict(checkpoint['state_dict'])
print("\n"+"-"*10+"start knowledge distillation for history length at "+ str(history_len)+"-"*10+"\n")
best_mrr = 0
best_hits10 = 0
best_epoch = 0
for epoch in range(args.n_epochs):
model.train()
losses = []
idx = [_ for _ in range(len(train_list))]
random.shuffle(idx)
for train_sample_num in idx:
if train_sample_num == 0 or train_sample_num == 1: continue
if train_sample_num - history_len<0:
input_list = train_list[0: train_sample_num]
output = train_list[1:train_sample_num+1]
else:
input_list = train_list[train_sample_num - history_len: train_sample_num]
output = train_list[train_sample_num-history_len+1:train_sample_num+1]
# generate history graph
history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]
output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]
loss= model.get_loss(history_glist, output[-1], None, use_cuda)
# print(loss)
losses.append(loss.item())
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm) # clip gradients
optimizer.step()
optimizer.zero_grad()
print("His {:04d}, Epoch {:04d} | Ave Loss: {:.4f} |Best MRR {:.4f} | Model {} "
.format(history_len, epoch, np.mean(losses), best_mrr, model_name))
#! checking GPU usage
free_mem, total_mem = torch.cuda.mem_get_info()
print ("--------------GPU memory usage-----------")
print ("there are ", free_mem, " free memory")
print ("there are ", total_mem, " total available memory")
print ("there are ", total_mem - free_mem, " used memory")
print ("--------------GPU memory usage-----------")
# validation
if epoch % args.evaluate_every == 0:
mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
model_state_file, mode="train", split_mode= "val")
if mrr< best_mrr:
if epoch >= args.n_epochs or epoch-best_epoch>2:
break
else:
best_mrr = mrr
best_epoch = epoch
best_hits10 = hits10
torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
model_state_file, mode="test", split_mode= "val")
ks_idx += 1
if mrr.item() < max(best_mrr_list):
test_history_len = history_len-1
print("early stopping, best history length: ", test_history_len)
break
else:
best_mrr_list.append(mrr.item())
best_hits_list.append(hits10.item())
return mrr, test_history_len, perf_per_rel, hits10
# ==================
# ==================
# ==================
# ------------------------------------------------------------------
# Main script: parse arguments, load the tkgl-smallpedia dataset,
# run CEN pretraining / curriculum training / evaluation, and persist
# the results to a JSON file.
# ------------------------------------------------------------------
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args_cen()
args.dataset = 'tkgl-smallpedia'  # this example is pinned to the smallpedia TKG dataset

SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)

DATA = args.dataset
MODEL_NAME = 'CEN'
print("logging mrrs per relation: ", args.log_per_rel)
print("do test and valid? do only test no validation?: ", args.validtest, args.test_only)

# load data
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
num_rels = dataset.num_rels
num_nodes = dataset.num_nodes
subjects = dataset.full_data["sources"]
objects = dataset.full_data["destinations"]
relations = dataset.edge_type
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA)  # remap raw timestamps to consecutive steps (stepsize: 1)

all_quads = np.stack((subjects, relations, objects, timestamps), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]

# original (un-remapped) timestamps are needed for querying the negative sampler
val_timestamps_orig = sorted(set(timestamps_orig[dataset.val_mask]))
test_timestamps_orig = sorted(set(timestamps_orig[dataset.test_mask]))

train_list = split_by_time(train_data)
valid_list = split_by_time(val_data)
test_list = split_by_time(test_data)

# evaluation metric
METRIC = dataset.eval_metric
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# load the negative samples
dataset.load_val_ns()
dataset.load_test_ns()

if args.grid_search:
    print("TODO: implement hyperparameter grid search")
# single run
else:
    start_train = timeit.default_timer()
    if args.validtest:
        print('directly start testing')
        if args.test_history_len_2 != args.test_history_len:
            args.test_history_len = args.test_history_len_2  # hyperparameter value as given in original paper
    else:
        print('running pretrain and train')
        # pretrain
        mrr, _, _, hits10 = run_experiment(args, trainvalidtest_id=-1)
        # train; overwrite test_history_len with the best history len (for valid mrr)
        mrr, args.test_history_len, _, hits10 = run_experiment(args, trainvalidtest_id=0)
    if not args.test_only:
        print("running test (on val and test dataset) with test_history_len of: ", args.test_history_len)
        # test on val set
        val_mrr, _, _, val_hits10 = run_experiment(args, trainvalidtest_id=1)
    else:
        val_mrr = 0
        val_hits10 = 0
    # test on test set
    start_test = timeit.default_timer()
    test_mrr, _, perf_per_rel, test_hits10 = run_experiment(args, trainvalidtest_id=2)
    test_time = timeit.default_timer() - start_test
    all_time = timeit.default_timer() - start_train
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
    # BUGFIX: was "\Train ..." — "\T" is an invalid escape sequence; "\t" was intended
    print(f"\tTrain and Test: Elapsed Time (s): {all_time: .4f}")
    print(f"\tTest: {METRIC}: {test_mrr: .4f}")
    print(f"\tValid: {METRIC}: {val_mrr: .4f}")

    # saving the results...
    results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
    if not osp.exists(results_path):
        os.mkdir(results_path)
        print('INFO: Create directory {}'.format(results_path))
    Path(results_path).mkdir(parents=True, exist_ok=True)
    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': args.run_nr,
                  'seed': SEED,
                  'test_history_len': args.test_history_len,
                  f'val {METRIC}': float(val_mrr),
                  f'test {METRIC}': float(test_mrr),
                  'test_time': test_time,
                  'tot_train_val_time': all_time,
                  'test_hits10': float(test_hits10)
                  },
                 results_filename)
    if args.log_per_rel:
        # additionally dump the per-relation MRR breakdown
        results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'
        with open(results_filename, 'w') as json_file:
            json.dump(perf_per_rel, json_file)
sys.exit()
================================================
FILE: examples/linkproppred/tkgl-smallpedia/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
tgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(tgb_modules_path)
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction on the selected split.

    Evaluation is 'one vs. many': every positive edge is ranked against
    its pre-sampled negative destinations.

    Parameters:
        data: a dataset object (dict of numpy arrays)
        test_mask: boolean mask selecting the evaluation edges
        neg_sampler: provides the negative edges for each positive edge
        split_mode: 'val' or 'test', so the correct negatives are loaded

    Returns:
        perf_metrics: mean value of the dataset's ranking metric
        perf_hits: mean hits@10
    """
    n_edges = len(data['sources'][test_mask])
    n_batches = math.ceil(n_edges / BATCH_SIZE)
    perf_list, hits_list = [], []

    for b_idx in tqdm(range(n_batches)):
        lo = b_idx * BATCH_SIZE
        hi = min(lo + BATCH_SIZE, n_edges)
        pos_src = data['sources'][test_mask][lo:hi]
        pos_dst = data['destinations'][test_mask][lo:hi]
        pos_t = data['timestamps'][test_mask][lo:hi]
        pos_edge = data['edge_type'][test_mask][lo:hi]

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)
        for q_idx, neg_batch in enumerate(neg_batch_list):
            # query = positive destination first, then all sampled negatives
            src_queries = np.full(len(neg_batch) + 1, int(pos_src[q_idx]))
            dst_queries = np.concatenate([np.array([int(pos_dst[q_idx])]), neg_batch])
            y_pred = edgebank.predict_link(src_queries, dst_queries)
            # compute MRR
            results = evaluator.eval({
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            })
            perf_list.append(results[metric])
            hits_list.append(results['hits@10'])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    return float(np.mean(perf_list)), float(np.mean(hits_list))
def get_args():
    """Parse command-line arguments for the EdgeBank example.

    Returns:
        args: the parsed argument namespace
        sys.argv: the raw command-line argument list
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tkgl-smallpedia')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse raises SystemExit on invalid arguments; print the full help
        # and exit cleanly. (Was a bare `except:`, which also swallowed
        # KeyboardInterrupt and genuine programming errors.)
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================

# Main script: load the dataset, seed EdgeBank's memory with the training
# edges, then evaluate on the validation and test splits and save results.
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = args.data
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank: only training edges seed the initial memory
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()
# validating ...
start_val = timeit.default_timer()
perf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()
print(f"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< ")
print(f"\tval: {metric}: {perf_metric_val: .4f}")
val_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {val_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()
# testing ...
start_test = timeit.default_timer()
perf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()
print(f"INFO: Test: Evaluation Setting: >>> <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

# persist the run summary as JSON
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'val_mrr': perf_metric_val,
              'test_time': test_time,
              'tot_train_val_time': test_time+val_time,
              'hits10': perf_hits_test},
             results_filename)
================================================
FILE: examples/linkproppred/tkgl-smallpedia/recurrencybaseline.py
================================================
""" from paper: "History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting"
Julia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)
@inproceedings{gastinger2024baselines,
title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},
author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},
booktitle={33rd International Joint Conference on Artificial Intelligence (IJCAI 2024)},
year={2024},
organization={International Joint Conferences on Artificial Intelligence Organization}
}
"""
## imports
import timeit
import argparse
import numpy as np
from copy import copy
from pathlib import Path
import ray
import sys
import os
import os.path as osp
import json
#internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import set_random_seed, save_results
from modules.tkg_utils import create_basis_dict, group_by, reformat_ts
def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
            perf_list_all, hits_list_all, window, neg_sampler, split_mode):
    """ create predictions for each relation on test or valid set and compute mrr

    The queries of the current relation are optionally sharded across
    `num_processes` ray workers; otherwise a single local call is made.

    :param num_processes: number of ray workers (1 = run locally, no ray)
    :param data_c_rel: query quadruples of the current relation
    :param all_data_c_rel: full history quadruples of the current relation
    :param alpha: weighting parameter of the baseline
    :param lmbda_psi: decay parameter of the baseline
    :param perf_list_all: accumulator list, extended in place and returned
    :param hits_list_all: accumulator list, extended in place and returned
    :param window: history window passed through to the predictor
    :param neg_sampler: negative sampler for evaluation
    :param split_mode: 'val' or 'test', selects the negative samples
    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    first_ts = data_c_rel[0][3]
    ## use this if you wanna use ray:
    num_queries = len(data_c_rel) // num_processes
    if num_queries < num_processes:  # if we do not have enough queries for all the processes
        num_processes_tmp = 1
        num_queries = len(data_c_rel)
    else:
        num_processes_tmp = num_processes
    if num_processes > 1:
        object_references = []
        for i in range(num_processes_tmp):
            num_test_queries = len(data_c_rel) - (i + 1) * num_queries
            if num_test_queries >= num_queries:
                test_queries_idx = [i * num_queries, (i + 1) * num_queries]
            else:
                # last shard takes all remaining queries of THIS relation.
                # BUGFIX: was `len(test_data)` (a module-level global), which
                # could slice past the end of data_c_rel or drop queries.
                test_queries_idx = [i * num_queries, len(data_c_rel)]
            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]
            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window,
                                                basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
            object_references.append(ob)
        output = ray.get(object_references)
        # updates the scores and logging dict for each process
        for proc_loop in range(num_processes_tmp):
            perf_list_all.extend(output[proc_loop][0])
            hits_list_all.extend(output[proc_loop][1])
    else:
        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel,
                                                window, basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
    return perf_list_all, hits_list_all
## test
def test(best_config, all_relations, test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
    """Loop over all relations, predict with each relation's best params,
    and collect the per-query MRRs and hits.

    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    perf_list_all = []
    hits_list_all = []
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}' + split_mode + '.csv'
    ## loop through relations and apply baselines
    for rel in all_relations:
        tic = timeit.default_timer()
        if rel in test_data_prel.keys():
            lmbda_psi = best_config[str(rel)]['lmbda_psi'][0]
            alpha = best_config[str(rel)]['alpha'][0]
            # queries and full history of the current relation
            rel_test_data = test_data_prel[rel]
            timesteps_test = sorted(set(rel_test_data[:, 3]))
            rel_all_data = all_data_prel[rel]
            perf_list_rel, hits_list_rel = predict(num_processes, rel_test_data,
                                                   rel_all_data, alpha, lmbda_psi, [], [],
                                                   window, neg_sampler, split_mode)
            perf_list_all.extend(perf_list_rel)
            hits_list_all.extend(hits_list_rel)
        else:
            perf_list_rel = []
        toc = timeit.default_timer()
        print("Relation {} finished in {} seconds.".format(rel, round(toc - tic, 6)))
        # append the per-relation scores to the results csv
        with open(csv_file, 'a') as f:
            f.write("{},{}\n".format(rel, perf_list_rel))
    return perf_list_all, hits_list_all
def read_dict_compute_mrr(split_mode='test'):
    """ read the results per relation from a precreated file and compute mrr
    :return mrr_per_rel: dictionary of mrrs for each relation
    :return all_mrrs: list of mrrs for all relations
    """
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}' + split_mode + '.csv'
    results_per_rel_dict = {}  # relation id -> list of per-query mrrs
    mrr_per_rel = {}
    all_mrrs = []
    with open(csv_file, 'r') as f:
        for line in f:
            # line layout: "<rel_id>,<score>,<score>,..." (scores may carry brackets)
            fields = line.strip().split(',')
            rel_key = int(fields[0])
            scores = [float(v.strip('[]')) for v in fields[1:]]
            if rel_key in results_per_rel_dict:
                print(f"Key {rel_key} already exists in the dictionary!!! might have duplicate entries in results csv")
            results_per_rel_dict[rel_key] = scores
            all_mrrs.extend(scores)
            mrr_per_rel[rel_key] = np.mean(scores)
    if len(results_per_rel_dict) != num_rels:
        print("we do not have entries for each rel in the results csv file. only num enties: ", len(list(results_per_rel_dict.keys())))
    print("Split mode: " + split_mode + " Mean MRR: ", np.mean(all_mrrs))
    print("mrr per relation: ", mrr_per_rel)
## train
def train(params_dict, rels, val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
    """ optional, find best values for lambda and alpha by looping through all relations and testing on a fixed set of params
    based on validation mrr

    :param params_dict: candidate values for 'lmbda_psi' and 'alpha'
    :param rels: all relation ids to tune
    :param val_data_prel: dict rel -> validation quadruples for that relation
    :param trainval_data_prel: dict rel -> train+valid quadruples for that relation
    :param neg_sampler: negative sampler used for validation scoring
    :param num_processes: number of ray workers (1 = run locally)
    :param window: history window passed through to the predictor
    :return best_config: dictionary of best params for each relation
    """
    best_config = {}
    best_mrr = 0
    for rel in rels:  # loop through relations. for each relation, apply rules with selected params, compute valid mrr
        start = timeit.default_timer()
        rel_key = int(rel)
        # defaults below are kept for relations that never occur in the validation set
        best_config[str(rel_key)] = {}
        best_config[str(rel_key)]['not_trained'] = 'True'
        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi, 0]  # default
        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))
        best_config[str(rel_key)]['alpha'] = [default_alpha, 0]  # default
        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))
        if rel in val_data_prel.keys():
            # valid data for this relation
            val_data_c_rel = val_data_prel[rel]
            timesteps_valid = list(set(val_data_c_rel[:, 3]))
            timesteps_valid.sort()
            trainval_data_c_rel = trainval_data_prel[rel]
            ###### 1) select lambda ###############
            # alpha is held fixed at 1 while searching for the best lambda
            lmbdas_psi = params_dict['lmbda_psi']
            alpha = 1
            best_lmbda_psi = 0.1
            best_mrr_psi = 0
            lmbda_mrrs = []
            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))
            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))
            best_config[str(rel_key)]['not_trained'] = 'False'
            for lmbda_psi in lmbdas_psi:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                                                   trainval_data_c_rel, alpha, lmbda_psi, perf_list_r, hits_list_r,
                                                   window, neg_sampler, split_mode='val')
                # compute mrr
                mrr = np.mean(perf_list_r)
                # is new mrr better than previous best? if yes: store lmbda
                if mrr > best_mrr_psi:
                    best_mrr_psi = float(mrr)
                    best_lmbda_psi = lmbda_psi
                lmbda_mrrs.append(float(mrr))
            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]
            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs
            best_mrr = best_mrr_psi
            ##### 2) select alpha ###############
            # second sweep: alpha varies while lambda is pinned to the winner above
            best_config[str(rel_key)]['not_trained'] = 'False'
            alphas = params_dict['alpha']
            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0]  # use the best lmbda psi
            alpha_mrrs = []
            best_mrr_alpha = 0
            best_alpha = 0.99
            for alpha in alphas:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                                                   trainval_data_c_rel, alpha, lmbda_psi, perf_list_r, hits_list_r,
                                                   window, neg_sampler, split_mode='val')
                # compute mrr
                mrr_alpha = np.mean(perf_list_r)
                # is new mrr better than previous best? if yes: store alpha
                if mrr_alpha > best_mrr_alpha:
                    best_mrr_alpha = float(mrr_alpha)
                    best_alpha = alpha
                    best_mrr = best_mrr_alpha
                alpha_mrrs.append(float(mrr_alpha))
            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]
            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs
        end = timeit.default_timer()
        total_time = round(end - start, 6)
        print("Relation {} finished in {} seconds.".format(rel, total_time))
    return best_config
## args
def get_args():
    """Parse all command-line arguments for the script and return them as a dict."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", "-d", default="tkgl-smallpedia", type=str)
    # set to e.g. 200 if only the most recent 200 timesteps should be considered; -2 for multistep
    parser.add_argument("--window", "-w", default=0, type=int)
    parser.add_argument("--num_processes", "-p", default=1, type=int)
    parser.add_argument("--lmbda", "-l", default=0.1, type=float)  # fixed lambda; used if train_flag == 'False'
    parser.add_argument("--alpha", "-alpha", default=0.99, type=float)  # fixed alpha; used if train_flag == 'False'
    parser.add_argument("--train_flag", "-tr", default='False')  # do we need training, ie selection of lambda and alpha
    parser.add_argument("--load_flag", "-lo", default='False')  # if train_flag is 'True': load best_config from disk?
    parser.add_argument("--save_config", "-c", default='True')  # save the selected lambda and alpha to a config file?
    parser.add_argument('--seed', type=int, help='Random seed', default=1)  # not needed
    return vars(parser.parse_args())
# Main script: load the dataset, (optionally) tune lambda/alpha per relation on
# the validation split, then evaluate the recurrency baseline and save results.
start_o = timeit.default_timer()

parsed = get_args()
# ray is only initialized when queries are sharded across multiple processes
if parsed['num_processes'] > 1:
    ray.init(num_cpus=parsed["num_processes"], num_gpus=0)
MODEL_NAME = 'RecurrencyBaseline'
SEED = parsed['seed']  # set the random seed for consistency
set_random_seed(SEED)

# directory for the per-relation result csv files written by test()
perrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'
if not osp.exists(perrel_results_path):
    os.mkdir(perrel_results_path)
    print('INFO: Create directory {}'.format(perrel_results_path))
Path(perrel_results_path).mkdir(parents=True, exist_ok=True)

## load dataset and prepare it accordingly
name = parsed["dataset"]
dataset = LinkPropPredDataset(name=name, root="datasets", preprocess=True)
DATA = name
relations = dataset.edge_type
num_rels = dataset.num_rels
rels = np.arange(0, num_rels)
subjects = dataset.full_data["sources"]
objects = dataset.full_data["destinations"]
num_nodes = dataset.num_nodes
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA)  # remap raw timestamps to consecutive steps (stepsize: 1)
print("split train valid test data")
# quintuple layout: (subject, relation, object, remapped ts, original ts)
all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]
metric = dataset.eval_metric
evaluator = Evaluator(name=name)
neg_sampler = dataset.negative_sampler
train_val_data = np.concatenate([train_data, val_data])
all_data = np.concatenate([train_data, val_data, test_data])

# create dicts with key: relation id, values: triples for that relation id
print("grouping data by relation")
test_data_prel = group_by(test_data, 1)
all_data_prel = group_by(all_data, 1)
val_data_prel = group_by(val_data, 1)
trainval_data_prel = group_by(train_val_data, 1)

# load the negative samples
# if parsed['train_flag']:
print("loading negative samples")
dataset.load_val_ns()
dataset.load_test_ns()

# parameter options: candidate grid used only when training is requested
if parsed['train_flag'] == 'True':
    params_dict = {}
    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001]
    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]
    default_lmbda_psi = params_dict['lmbda_psi'][-1]
    default_alpha = params_dict['alpha'][-2]

## load rules
print("creating rules")
basis_dict = create_basis_dict(train_val_data)
print("done with creating rules")

## init
# rb_predictor = RecurrencyBaselinePredictor(rels)

## train to find best lambda and alpha
start_train = timeit.default_timer()
if parsed['train_flag'] == 'True':
    if parsed['load_flag'] == 'True':
        # reuse a previously saved per-relation config instead of re-tuning
        with open('best_config.json', 'r') as infile:
            best_config = json.load(infile)
    else:
        print('start training')
        best_config = train(params_dict, rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'],
                            parsed['window'])
        if parsed['save_config'] == 'True':
            import json
            with open('best_config.json', 'w') as outfile:
                json.dump(best_config, outfile)
else:  # use preset lmbda and alpha; same for all relations
    best_config = {}
    for rel in rels:
        best_config[str(rel)] = {}
        best_config[str(rel)]['lmbda_psi'] = [parsed['lmbda']]
        best_config[str(rel)]['alpha'] = [parsed['alpha']]
end_train = timeit.default_timer()

# compute validation mrr
print("Computing validation MRR")
perf_list_all_val, hits_list_all_val = test(best_config, rels, val_data_prel,
                                            trainval_data_prel, neg_sampler, parsed['num_processes'],
                                            parsed['window'], split_mode='val')
val_mrr = float(np.mean(perf_list_all_val))

# compute test mrr
print("Computing test MRR")
start_test = timeit.default_timer()
perf_list_all, hits_list_all = test(best_config, rels, test_data_prel,
                                    all_data_prel, neg_sampler, parsed['num_processes'],
                                    parsed['window'])
end_o = timeit.default_timer()

train_time_o = round(end_train - start_train, 6)
test_time_o = round(end_o - start_test, 6)
total_time_o = round(end_o - start_o, 6)
print("Running Training to find best configs finished in {} seconds.".format(train_time_o))
print("Running testing with best configs finished in {} seconds.".format(test_time_o))
print("Running all steps finished in {} seconds.".format(total_time_o))
print(f"The test MRR is {np.mean(perf_list_all)}")
print(f"The valid MRR is {val_mrr}")
print(f"The Hits@10 is {np.mean(hits_list_all)}")
print(f"We have {len(perf_list_all)} predictions")
print(f"The test set has len {len(test_data)} ")

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'
metric = dataset.eval_metric
save_results({'model': MODEL_NAME,
              'train_flag': parsed['train_flag'],
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: float(np.mean(perf_list_all)),
              'val_mrr': val_mrr,
              'hits10': float(np.mean(hits_list_all)),
              'test_time': test_time_o,
              'tot_train_val_time': total_time_o
              },
             results_filename)
if parsed['num_processes'] > 1:
    ray.shutdown()
================================================
FILE: examples/linkproppred/tkgl-smallpedia/regcn.py
================================================
"""
Temporal Knowledge Graph Reasoning Based on Evolutional Representation Learning
Reference:
- https://github.com/Lee-zix/RE-GCN
Zixuan Li, Xiaolong Jin, Wei Li, Saiping Guan, Jiafeng Guo, Huawei Shen, Yuanzhuo Wang and Xueqi Cheng. Temporal
Knowledge Graph Reasoning Based on Evolutional Representation Learning. SIGIR 2021.
"""
import sys
import timeit
import os
import sys
import os.path as osp
from pathlib import Path
import numpy as np
import torch
import random
from tqdm import tqdm
# internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.rrgcn import RecurrentRGCNREGCN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from modules.tkg_utils import get_args_regcn, reformat_ts
from modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
import json
def test(model, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, static_graph, mode, split_mode):
    """
    Test the model on either test or validation set
    :param model: model used to test
    :param history_list: all input history snap shot list, not include output label train list or valid list
    :param test_list: test triple snap shot list
    :param num_rels: number of relations
    :param num_nodes: number of nodes
    :param use_cuda: whether to run tensors/graphs on the GPU
    :param model_name: checkpoint path; loaded only when mode == "test"
    :param static_graph: static background graph passed through to model.predict
    :param mode: "test" loads parameters from `model_name`; otherwise the in-memory model is used
    :param split_mode: 'test' or 'val' to state which negative samples to load
    :return mrr, perf_per_rel, hits10
    """
    print("Testing for mode: ", split_mode)
    # original (un-remapped) timestamps are needed to query the negative sampler
    if split_mode == 'test':
        timesteps_to_eval = test_timestamps_orig
    else:
        timesteps_to_eval = val_timestamps_orig
    idx = 0
    if mode == "test":
        # test mode: load parameter from file
        if use_cuda:
            checkpoint = torch.load(model_name, map_location=torch.device(args.gpu))
        else:
            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))
        # use best stat checkpoint:
        print("Load Model name: {}. Using best epoch : {}".format(model_name, checkpoint['epoch']))
        print("\n"+"-"*10+"start testing"+"-"*10+"\n")
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    # seed the rolling history window with the most recent test_history_len snapshots
    input_list = [snap for snap in history_list[-args.test_history_len:]]
    perf_list_all = []
    hits_list_all = []
    perf_per_rel = {}
    for rel in range(num_rels):
        perf_per_rel[rel] = []
    for time_idx, test_snap in enumerate(tqdm(test_list)):
        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]
        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)
        # all queries of one snapshot share the same original timestamp
        timesteps_batch = timesteps_to_eval[time_idx]*np.ones(len(test_triples_input[:,0]))
        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:,0], test_triples_input[:,2],
                                                    timesteps_batch, edge_type=test_triples_input[:,1], split_mode=split_mode)
        pos_samples_batch = test_triples_input[:,2]
        _, perf_list, hits_list = model.predict(history_glist, num_rels, static_graph, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch,
                                                evaluator, METRIC)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
        if split_mode == "test":
            if args.log_per_rel:
                # collect the per-relation scores for later aggregation
                for score, rel in zip(perf_list, test_triples_input[:,1].tolist()):
                    perf_per_rel[rel].append(score)
        # reconstruct history graph list: slide the window forward by one snapshot
        input_list.pop(0)
        input_list.append(test_snap)
        idx += 1
    if split_mode == "test":
        if args.log_per_rel:
            # average the collected scores; drop relations without any test query
            for rel in range(num_rels):
                if len(perf_per_rel[rel]) > 0:
                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))
                else:
                    perf_per_rel.pop(rel)
    mrr = np.mean(perf_list_all)
    hits10 = np.mean(hits_list_all)
    return mrr, perf_per_rel, hits10
def run_experiment(args, n_hidden=None, n_layers=None, dropout=None, n_bases=None):
    """
    Run the experiment with the given configuration.

    :param args: arguments
    :param n_hidden: hidden dimension (overrides args when given, for grid search)
    :param n_layers: number of layers (overrides args when given)
    :param dropout: dropout rate (overrides args when given)
    :param n_bases: number of bases (overrides args when given)
    :return: (mrr, perf_per_rel, hits10) — mean reciprocal rank, performance
        per relation, and hits@10. Always a 3-tuple so callers can unpack
        uniformly.
    """
    # load configuration for grid search the best configuration
    if n_hidden:
        args.n_hidden = n_hidden
    if n_layers:
        args.n_layers = n_layers
    if dropout:
        args.dropout = dropout
    if n_bases:
        args.n_bases = n_bases

    mrr = 0
    hits10 = 0
    # set save model path
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    if not osp.exists(save_model_dir):
        os.mkdir(save_model_dir)
        print('INFO: Create directory {}'.format(save_model_dir))
    model_name = f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}'
    model_state_file = save_model_dir + model_name
    perf_per_rel = {}

    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    # no static graph is used for this dataset
    num_static_rels, num_words, static_triples, static_graph = 0, 0, [], None

    # create model
    model = RecurrentRGCNREGCN(args.decoder,
                               args.encoder,
                               num_nodes,
                               int(num_rels / 2),
                               num_static_rels,  # DIFFERENT
                               num_words,  # DIFFERENT
                               args.n_hidden,
                               args.opn,
                               sequence_len=args.train_history_len,
                               num_bases=args.n_bases,
                               num_basis=args.n_basis,
                               num_hidden_layers=args.n_layers,
                               dropout=args.dropout,
                               self_loop=args.self_loop,
                               skip_connect=args.skip_connect,
                               layer_norm=args.layer_norm,
                               input_dropout=args.input_dropout,
                               hidden_dropout=args.hidden_dropout,
                               feat_dropout=args.feat_dropout,
                               aggregation=args.aggregation,  # DIFFERENT
                               weight=args.weight,  # DIFFERENT
                               discount=args.discount,  # DIFFERENT
                               angle=args.angle,  # DIFFERENT
                               use_static=args.add_static_graph,  # DIFFERENT
                               entity_prediction=args.entity_prediction,
                               relation_prediction=args.relation_prediction,
                               use_cuda=use_cuda,
                               gpu=args.gpu,
                               analysis=args.run_analysis)  # DIFFERENT

    if use_cuda:
        torch.cuda.set_device(args.gpu)
        model.cuda()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)

    if args.test and os.path.exists(model_state_file):
        mrr, perf_per_rel, hits10 = test(model,
                                         train_list + valid_list,
                                         test_list,
                                         num_rels,
                                         num_nodes,
                                         use_cuda,
                                         model_state_file,
                                         static_graph,
                                         "test",
                                         "test")
        return mrr, perf_per_rel, hits10
    elif args.test and not os.path.exists(model_state_file):
        print("--------------{} not exist, Change mode to train and generate stat for testing----------------\n".format(model_state_file))
        # BUG FIX: callers always unpack three values; the original returned
        # only two here, which raised ValueError at every call site.
        return 0, perf_per_rel, 0
    else:
        print("----------------------------------------start training----------------------------------------\n")
        best_mrr = 0
        best_hits = 0
        for epoch in range(args.n_epochs):
            model.train()
            losses = []
            # shuffle training snapshot order each epoch
            idx = [_ for _ in range(len(train_list))]
            random.shuffle(idx)

            for train_sample_num in tqdm(idx):
                if train_sample_num == 0: continue
                output = train_list[train_sample_num:train_sample_num + 1]
                # history window: up to train_history_len preceding snapshots
                if train_sample_num - args.train_history_len < 0:
                    input_list = train_list[0: train_sample_num]
                else:
                    input_list = train_list[train_sample_num - args.train_history_len:
                                            train_sample_num]

                # generate history graph
                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]
                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]

                loss_e, loss_r, loss_static = model.get_loss(history_glist, output[0], static_graph, use_cuda)
                loss = args.task_weight * loss_e + (1 - args.task_weight) * loss_r + loss_static

                losses.append(loss.item())
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients
                optimizer.step()
                optimizer.zero_grad()

            print("Epoch {:04d} | Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} "
                  .format(epoch, np.mean(losses), best_mrr, model_name))

            #! checking GPU usage
            # BUG FIX: mem_get_info() raises when CUDA is unavailable; only
            # report memory stats on GPU runs.
            if use_cuda:
                free_mem, total_mem = torch.cuda.mem_get_info()
                print("--------------GPU memory usage-----------")
                print("there are ", free_mem, " free memory")
                print("there are ", total_mem, " total available memory")
                print("there are ", total_mem - free_mem, " used memory")
                print("--------------GPU memory usage-----------")

            # validation
            if epoch and epoch % args.evaluate_every == 0:
                mrr, perf_per_rel, hits10 = test(model, train_list,
                                                 valid_list,
                                                 num_rels,
                                                 num_nodes,
                                                 use_cuda,
                                                 model_state_file,
                                                 static_graph,
                                                 mode="train", split_mode='val')
                if mrr < best_mrr:
                    if epoch >= args.n_epochs:
                        break
                else:
                    # checkpoint only when validation MRR improves
                    best_mrr = mrr
                    best_hits = hits10
                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
        # BUG FIX: return the hits@10 of the BEST epoch (tracked in
        # best_hits) instead of the last validation's hits10.
        return best_mrr, perf_per_rel, best_hits
# ==================
# main script: train and evaluate REGCN on tkgl-smallpedia
# ==================
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args_regcn()
args.dataset = 'tkgl-smallpedia'
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
DATA = args.dataset
MODEL_NAME = 'REGCN'
print("logging mrrs per relation: ", args.log_per_rel)

# load data
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
num_rels = dataset.num_rels
num_nodes = dataset.num_nodes
subjects = dataset.full_data["sources"]
objects = dataset.full_data["destinations"]
relations = dataset.edge_type
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA)  # stepsize:1
all_quads = np.stack((subjects, relations, objects, timestamps), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]
# original (un-remapped) timestamps per split; needed for querying negatives
val_timestamps_orig = list(set(timestamps_orig[dataset.val_mask]))
val_timestamps_orig.sort()
test_timestamps_orig = list(set(timestamps_orig[dataset.test_mask]))
test_timestamps_orig.sort()
train_list = split_by_time(train_data)
valid_list = split_by_time(val_data)
test_list = split_by_time(test_data)

# evaluation metric
METRIC = dataset.eval_metric
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
# load the negative samples
dataset.load_val_ns()
dataset.load_test_ns()

## run training and testing
val_mrr, test_mrr = 0, 0
test_hits10 = 0
if args.grid_search:
    print("hyperparameter grid search not implemented. Exiting.")
# single run
else:
    start_train = timeit.default_timer()
    if args.test == False:  # if True: directly test a previously trained and stored model
        print('start training')
        val_mrr, perf_per_rel, val_hits10 = run_experiment(args)  # do training
    start_test = timeit.default_timer()
    args.test = True
    print('start testing')
    test_mrr, perf_per_rel, test_hits10 = run_experiment(args)  # do testing
    test_time = timeit.default_timer() - start_test
    all_time = timeit.default_timer() - start_train
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
    # BUG FIX: the original literal started with "\Train" (a stray backslash
    # before T); a leading tab was intended, matching the sibling prints.
    print(f"\tTrain and Test: Elapsed Time (s): {all_time: .4f}")
    print(f"\tTest: {METRIC}: {test_mrr: .4f}")
    print(f"\tValid: {METRIC}: {val_mrr: .4f}")

    # saving the results...
    results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
    if not osp.exists(results_path):
        os.mkdir(results_path)
        print('INFO: Create directory {}'.format(results_path))
    Path(results_path).mkdir(parents=True, exist_ok=True)
    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': args.run_nr,
                  'seed': SEED,
                  f'val {METRIC}': float(val_mrr),
                  f'test {METRIC}': float(test_mrr),
                  'test_time': test_time,
                  'tot_train_val_time': all_time,
                  'test_hits10': float(test_hits10)
                  },
                 results_filename)
    if args.log_per_rel:
        results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'
        with open(results_filename, 'w') as json_file:
            json.dump(perf_per_rel, json_file)
sys.exit()
================================================
FILE: examples/linkproppred/tkgl-smallpedia/timetraveler.py
================================================
"""
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting
Reference:
- https://github.com/JHL-HUST/TITer
Haohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021
"""
import sys
import timeit
import torch
from torch.utils.data import Dataset,DataLoader
import logging
import numpy as np
import pickle
from tqdm import tqdm
import os.path as osp
from pathlib import Path
import os
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.timetraveler_agent import Agent
from modules.timetraveler_environment import Env
from modules.timetraveler_dirichlet import Dirichlet, MLE_Dirchlet
from modules.timetraveler_episode import Episode
from modules.timetraveler_policygradient import PG
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator
from modules.timetraveler_trainertester import Trainer, Tester, getRelEntCooccurrence
from tgb.utils.utils import set_random_seed,save_results
from modules.tkg_utils import get_args_timetraveler, reformat_ts, get_model_config_timetraveler
class QuadruplesDataset(Dataset):
    """Thin ``Dataset`` wrapper over TimeTraveler's internal quadruple format.

    Each stored record has five fields (subject, relation, object, remapped
    timestamp, original timestamp) and is returned as a 5-tuple.
    """

    def __init__(self, examples):
        """
        examples: a sequence of quadruples, each with five fields.
        """
        # keep a private copy so later mutation of `examples` by the caller
        # cannot affect this dataset
        self.quadruples = examples.copy()

    def __len__(self):
        return len(self.quadruples)

    def __getitem__(self, item):
        record = self.quadruples[item]
        return record[0], record[1], record[2], record[3], record[4]
def set_logger(save_path):
    """Route log records to a file under `save_path` and echo them to the console.

    Note: reads the module-level `args` to decide the log file name.
    """
    log_name = 'train.log' if args.do_train else 'test.log'
    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=os.path.join(save_path, log_name),
        filemode='w'
    )
    # mirror everything at INFO level to stderr as well
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)-8s %(message)s'))
    logging.getLogger('').addHandler(console_handler)
def preprocess_data(args, config, timestamps, save_path, all_quads):
    """
    Build the complete state-action space for every (entity, timestamp) pair
    seen in `all_quads` and pickle it to `save_path`/`args.state_actions_path`.
    """
    env = Env(all_quads, config)
    state_actions_space = {}
    with tqdm(total=len(all_quads)) as progress:
        for head, rel, tail, t, _ in all_quads:
            # query both the head and the tail entity at this timestamp,
            # in both temporal directions, skipping already-seen pairs
            for entity in (head, tail):
                if (entity, t, True) not in state_actions_space:
                    state_actions_space[(entity, t, True)] = env.get_state_actions_space_complete(entity, t, True, args.store_actions_num)
                    state_actions_space[(entity, t, False)] = env.get_state_actions_space_complete(entity, t, False, args.store_actions_num)
            progress.update(1)
    pickle.dump(state_actions_space, open(os.path.join(save_path, args.state_actions_path), 'wb'))
def log_metrics(mode, step, metrics):
    """Emit one INFO line per evaluation metric for the given split and epoch."""
    for name, value in metrics.items():
        logging.info('%s %s at epoch %d: %f' % (mode, name, step, value))
def main(args):
    """
    Train and/or test the TimeTraveler model.

    Depending on ``args.do_train`` / ``args.do_test``, runs the policy-gradient
    training loop, evaluates on the test split, and saves a JSON results file.

    :param args: parsed command-line arguments (see get_args_timetraveler)
    """
    start_overall = timeit.default_timer()
    # BUG FIX: initialise the per-phase timers up front so the elapsed-time
    # computations at the bottom do not raise NameError when training or
    # testing is skipped via --do_train / --do_test.
    start_train = start_overall
    start_test = start_overall
    ####################### Set Logger #################################
    save_path = f'{os.path.dirname(os.path.abspath(__file__))}/saved_models/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # only keep CUDA enabled when it was requested AND is actually available
    if args.cuda and torch.cuda.is_available():
        args.cuda = True
    else:
        args.cuda = False
    set_logger(save_path)

    ####################### Create DataLoader #################################
    # set hyperparameters
    args.dataset = 'tkgl-smallpedia'
    SEED = args.seed  # set the random seed for consistency
    set_random_seed(SEED)
    DATA = args.dataset
    MODEL_NAME = 'TIMETRAVELER'

    # load data
    dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
    num_rels = dataset.num_rels
    num_nodes = dataset.num_nodes
    subjects = dataset.full_data["sources"]
    objects = dataset.full_data["destinations"]
    relations = dataset.edge_type
    timestamps_orig = dataset.full_data["timestamps"]
    timestamps = reformat_ts(timestamps_orig, DATA)  # stepsize:1
    # each row: (subject, relation, object, remapped ts, original ts)
    all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
    train_data = all_quads[dataset.train_mask]
    train_entities = np.unique(np.concatenate((train_data[:, 0], train_data[:, 2])))
    RelEntCooccurrence = getRelEntCooccurrence(train_data, num_rels)
    train_data = QuadruplesDataset(train_data)
    val_data = QuadruplesDataset(all_quads[dataset.val_mask])
    test_data = QuadruplesDataset(all_quads[dataset.test_mask])

    METRIC = dataset.eval_metric
    evaluator = Evaluator(name=DATA)
    neg_sampler = dataset.negative_sampler
    # load the negative samples
    dataset.load_val_ns()
    dataset.load_test_ns()

    train_dataloader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )
    valid_dataloader = DataLoader(
        val_data,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )
    test_dataloader = DataLoader(
        test_data,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

    ###################### Create the agent and the environment ###########################
    config = get_model_config_timetraveler(args, num_nodes, num_rels)
    logging.info(config)
    logging.info(args)

    # create the agent
    agent = Agent(config)

    # create the environment
    state_actions_path = os.path.join(save_path, args.state_actions_path)
    ###################### preprocessing ###########################
    if not os.path.exists(state_actions_path):
        if args.preprocess:
            print("preprocessing data...")
            preprocess_data(args, config, timestamps, save_path, list(all_quads))
            state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))
        else:
            state_action_space = None
    else:
        print("load preprocessed data...")
        state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))
    env = Env(list(all_quads), config, state_action_space)
    # Create episode controller
    episode = Episode(env, agent, config)
    if args.cuda:
        episode = episode.cuda()
    pg = PG(config)  # Policy Gradient
    optimizer = torch.optim.Adam(episode.parameters(), lr=args.lr, weight_decay=0.00001)

    ###################### Reward Shaping: MLE DIRICHLET alphas ###########################
    if args.reward_shaping:
        try:
            print("load alphas from pickle file")
            alphas = pickle.load(open(os.path.join(save_path, args.alphas_pkl), 'rb'))
        except Exception:  # narrowed from bare `except:` so Ctrl-C still works
            print('running MLE dirichlet now')
            mle_d = MLE_Dirchlet(all_quads, num_rels, args.k, args.time_span,
                                 args.tol, args.method, args.maxiter)
            pickle.dump(mle_d.alphas, open(os.path.join(save_path, args.alphas_pkl), 'wb'))
            print('dumped alphas')
            alphas = mle_d.alphas
        distributions = Dirichlet(alphas, args.k)
    else:
        distributions = None

    ###################### Training and Testing ###########################
    trainer = Trainer(episode, pg, optimizer, args, distributions)
    tester = Tester(episode, args, train_entities, RelEntCooccurrence, dataset.metric)
    test_metrics = {}
    val_metrics = {}
    test_metrics[METRIC] = None
    val_metrics[METRIC] = None

    if args.do_train:
        start_train = timeit.default_timer()
        logging.info('Start Training......')
        for i in range(args.max_epochs):
            loss, reward = trainer.train_epoch(train_dataloader, len(train_data))
            logging.info('Epoch {}/{} Loss: {}, reward: {}'.format(i, args.max_epochs, loss, reward))

            #! checking GPU usage
            # BUG FIX: mem_get_info() raises on CPU-only runs; guard on args.cuda
            if args.cuda:
                free_mem, total_mem = torch.cuda.mem_get_info()
                print("--------------GPU memory usage-----------")
                print("there are ", free_mem, " free memory")
                print("there are ", total_mem, " total available memory")
                print("there are ", total_mem - free_mem, " used memory")
                print("--------------GPU memory usage-----------")

            if i % args.save_epoch == 0 and i != 0:
                trainer.save_model(save_path, 'checkpoint_{}.pth'.format(i))
                logging.info('Save Model in {}'.format(save_path))

            if i % args.valid_epoch == 0 and i != 0:
                logging.info('Start Val......')
                val_metrics = tester.test(valid_dataloader,
                                          len(val_data), num_nodes, neg_sampler, evaluator, split_mode='val')
                for mode in val_metrics.keys():
                    logging.info('{} at epoch {}: {}'.format(mode, i, val_metrics[mode]))
        trainer.save_model(save_path)
        logging.info('Save Model in {}'.format(save_path))
    else:
        # Load the model parameters
        # NOTE(review): save_path is a directory, so os.path.isfile() is always
        # False and a pretrained checkpoint is never loaded here — confirm the
        # intended checkpoint filename.
        if os.path.isfile(save_path):
            params = torch.load(save_path)
            episode.load_state_dict(params['model_state_dict'])
            optimizer.load_state_dict(params['optimizer_state_dict'])
            logging.info('Load pretrain model: {}'.format(save_path))

    if args.do_test:
        logging.info('Start Testing......')
        start_test = timeit.default_timer()
        test_metrics = tester.test(test_dataloader,
                                   len(test_data), num_nodes, neg_sampler, evaluator, split_mode='test')
        for mode in test_metrics.keys():
            logging.info('Test {} : {}'.format(mode, test_metrics[mode]))

    # saving the results...
    results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
    if not osp.exists(results_path):
        os.mkdir(results_path)
        print('INFO: Create directory {}'.format(results_path))
    Path(results_path).mkdir(parents=True, exist_ok=True)
    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
    test_time = timeit.default_timer() - start_test
    all_time = timeit.default_timer() - start_train
    all_time_preprocess = timeit.default_timer() - start_overall
    # BUG FIX: a metric stays None when its phase was skipped; float(None)
    # would raise TypeError, so convert conditionally (json handles null).
    val_metric_value = float(val_metrics[METRIC]) if val_metrics[METRIC] is not None else None
    test_metric_value = float(test_metrics[METRIC]) if test_metrics[METRIC] is not None else None
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'seed': SEED,
                  f'val {METRIC}': val_metric_value,
                  f'test {METRIC}': test_metric_value,
                  'test_time': test_time,
                  'tot_train_val_time': all_time,
                  'tot_preprocess_train_val_time': all_time_preprocess
                  },
                 results_filename)
if __name__ == '__main__':
    # Parse CLI arguments at module scope: set_logger() reads the global
    # `args`, so this binding must happen before main() configures logging.
    args = get_args_timetraveler()
    main(args)
================================================
FILE: examples/linkproppred/tkgl-smallpedia/tkgl-smallpedia_example.py
================================================
import numpy as np
import timeit
from tqdm import tqdm
import sys
import os.path as osp
import os
from pathlib import Path
# internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from torch_geometric.loader import TemporalDataLoader
from tgb.linkproppred.evaluate import Evaluator
DATA = "tkgl-smallpedia"

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
metric = dataset.eval_metric
print("there are {} nodes and {} edges".format(dataset.num_nodes, dataset.num_edges))
print("there are {} relation types".format(dataset.num_rels))

#! must run in this order
static_data = dataset.static_data
static_head = static_data["head"]
static_tail = static_data["tail"]
static_edge_type = static_data["edge_type"]
print('static edges processed')
print("static data has ", static_head.shape[0], " edges")

timestamp = data.t
head = data.src
tail = data.dst
edge_type = data.edge_type  # relation
neg_sampler = dataset.negative_sampler

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
# BUG FIX: `metric` and `neg_sampler` were redundantly re-assigned here with
# identical values (dead code); the duplicates were removed.
evaluator = Evaluator(name=DATA)

BATCH_SIZE = 200
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

start_time = timeit.default_timer()
# load the negative samples first
dataset.load_val_ns()
for batch in tqdm(val_loader):
    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type
    neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='val')
print("loading ns samples from validation", timeit.default_timer() - start_time)

start_time = timeit.default_timer()
dataset.load_test_ns()
for batch in test_loader:
    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type
    neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='test')
print("loading ns samples from test", timeit.default_timer() - start_time)
print("retrieved all negative samples")

# #* load numpy arrays instead
# from tgb.linkproppred.dataset import LinkPropPredDataset
# # data loading
# dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
# data = dataset.full_data
# metric = dataset.eval_metric
# sources = dataset.full_data['sources']
# print (sources.dtype)
================================================
FILE: examples/linkproppred/tkgl-smallpedia/tlogic.py
================================================
"""
https://github.com/liu-yushan/TLogic/tree/main/mycode
TLogic: Temporal Logical Rules for Explainable Link Forecasting on Temporal Knowledge Graphs.
Yushan Liu, Yunpu Ma, Marcel Hildebrandt, Mitchell Joblin, Volker Tresp
"""
# imports
import sys
import os
import os.path as osp
from pathlib import Path
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
import timeit
import argparse
import numpy as np
import json
from joblib import Parallel, delayed
import itertools
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
from modules.tlogic_learn_modules import Temporal_Walk, Rule_Learner, store_edges
import modules.tlogic_apply_modules as ra
from tgb.utils.utils import set_random_seed, save_results
from modules.tkg_utils import reformat_ts, get_inv_relation_id, create_scores_array
def learn_rules(i, num_relations):
    """
    Learn rules (multiprocessing possible).

    Parameters:
        i (int): process number
        num_relations (int): minimum number of relations for each process

    Returns:
        rl.rules_dict (dict): rules dictionary
    """
    # decide which slice of `all_relations` this worker is responsible for;
    # the last worker picks up any remainder
    remaining = len(all_relations) - (i + 1) * num_relations
    if remaining >= num_relations:
        rel_indices = range(i * num_relations, (i + 1) * num_relations)
    else:
        rel_indices = range(i * num_relations, len(all_relations))

    rule_counts = [0]
    for k in rel_indices:
        rel = all_relations[k]
        for length in rule_lengths:
            tic = timeit.default_timer()
            for _ in range(num_walks):
                walk_ok, walk = temporal_walk.sample_walk(length + 1, rel)
                if walk_ok:
                    rl.create_rule(walk)
            elapsed = round(timeit.default_timer() - tic, 6)

            # rules are stored for both a relation and its inverse,
            # hence the division by two
            rule_counts.append(sum(len(v) for v in rl.rules_dict.values()) // 2)
            print(
                "Process {0}: relation {1}/{2}, length {3}: {4} sec, {5} rules".format(
                    i,
                    k - rel_indices[0] + 1,
                    len(rel_indices),
                    length,
                    elapsed,
                    rule_counts[-1] - rule_counts[-2],
                )
            )

    return rl.rules_dict
def apply_rules(i, num_queries, rules_dict, neg_sampler, data, window, learn_edges, all_quads, args, split_mode,
                log_per_rel=False, num_rels=0):
    """
    Apply rules (multiprocessing possible).

    Parameters:
        i (int): process number
        num_queries (int): minimum number of queries for each process
        rules_dict (dict): learned rules, keyed by relation id
        neg_sampler: TGB negative sampler (queried with ORIGINAL timestamps)
        data (np.ndarray): queries for this split, rows (s, r, o, ts, ts_orig)
        window (int): time window for candidate edges
        learn_edges (dict): training edges per relation
        all_quads (np.ndarray): all quadruples (source for the edge window)
        args (list): list of score-function parameter lists
        split_mode (str): 'val' or 'test'
        log_per_rel (bool): record per-relation MRR (test split only)
        num_rels (int): number of relations (for per-relation logging)

    Returns:
        perf_list (list): performance list (mrr per sample)
        hits_list (list): hits list (hits@10 per sample)
        perf_per_rel (dict): mean MRR per relation (test split only, else empty)
    """
    perf_per_rel = {}
    for rel in range(num_rels):
        perf_per_rel[rel] = []
    print("Start process", i, "...")
    all_candidates = [dict() for _ in range(len(args))]
    no_cands_counter = 0

    # slice of queries handled by this worker; last worker takes the remainder
    num_rest_queries = len(data) - (i + 1) * num_queries
    if num_rest_queries >= num_queries:
        test_queries_idx = range(i * num_queries, (i + 1) * num_queries)
    else:
        test_queries_idx = range(i * num_queries, len(data))

    cur_ts = data[test_queries_idx[0]][3]
    edges = ra.get_window_edges(all_quads[:, 0:4], cur_ts, learn_edges, window)

    it_start = timeit.default_timer()
    hits_list = [0] * len(test_queries_idx)
    perf_list = [0] * len(test_queries_idx)
    for index, j in enumerate(test_queries_idx):
        # negatives are fetched with the original timestamp (column 4)
        neg_sample_el = neg_sampler.query_batch(np.expand_dims(np.array(data[j, 0]), axis=0),
                                                np.expand_dims(np.array(data[j, 2]), axis=0),
                                                np.expand_dims(np.array(data[j, 4]), axis=0),
                                                np.expand_dims(np.array(data[j, 1]), axis=0),
                                                split_mode=split_mode)[0]
        pos_sample_el = data[j, 2]
        test_query = data[j]
        assert pos_sample_el == test_query[2]
        cands_dict = [dict() for _ in range(len(args))]

        if test_query[3] != cur_ts:
            # new timestep reached: refresh the window of usable edges
            cur_ts = test_query[3]
            edges = ra.get_window_edges(all_quads[:, 0:4], cur_ts, learn_edges, window)

        if test_query[1] in rules_dict:
            dicts_idx = list(range(len(args)))
            for rule in rules_dict[test_query[1]]:
                walk_edges = ra.match_body_relations(rule, edges, test_query[0])
                if 0 not in [len(x) for x in walk_edges]:
                    rule_walks = ra.get_walks(rule, walk_edges)
                    if rule["var_constraints"]:
                        rule_walks = ra.check_var_constraints(
                            rule["var_constraints"], rule_walks
                        )
                    if not rule_walks.empty:
                        cands_dict = ra.get_candidates(
                            rule,
                            rule_walks,
                            cur_ts,
                            cands_dict,
                            score_func,
                            args,
                            dicts_idx,
                        )
                        # NOTE(review): dicts_idx is mutated while being
                        # iterated; harmless with the single-element `args`
                        # used here, but confirm before tuning several
                        # parameter sets at once.
                        for s in dicts_idx:
                            cands_dict[s] = {
                                x: sorted(cands_dict[s][x], reverse=True)
                                for x in cands_dict[s].keys()
                            }
                            cands_dict[s] = dict(
                                sorted(
                                    cands_dict[s].items(),
                                    key=lambda item: item[1],
                                    reverse=True,
                                )
                            )
                            top_k_scores = [v for _, v in cands_dict[s].items()][:top_k]
                            unique_scores = list(
                                scores for scores, _ in itertools.groupby(top_k_scores)
                            )
                            if len(unique_scores) >= top_k:
                                dicts_idx.remove(s)
                        if not dicts_idx:
                            break
            if cands_dict[0]:
                for s in range(len(args)):
                    # Calculate noisy-or scores
                    # BUG FIX: np.product was removed in NumPy 2.0; np.prod is
                    # the supported, behaviorally identical spelling.
                    scores = list(
                        map(
                            lambda x: 1 - np.prod(1 - np.array(x)),
                            cands_dict[s].values(),
                        )
                    )
                    cands_scores = dict(zip(cands_dict[s].keys(), scores))
                    noisy_or_cands = dict(
                        sorted(cands_scores.items(), key=lambda x: x[1], reverse=True)
                    )
                    all_candidates[s][j] = noisy_or_cands
            else:  # No candidates found by applying rules
                no_cands_counter += 1
                for s in range(len(args)):
                    all_candidates[s][j] = dict()
        else:  # No rules exist for this relation
            no_cands_counter += 1
            for s in range(len(args)):
                all_candidates[s][j] = dict()

        # progress report every 100 processed queries
        if not (j - test_queries_idx[0] + 1) % 100:
            it_end = timeit.default_timer()
            it_time = round(it_end - it_start, 6)
            print(
                "Process {0}: test samples finished: {1}/{2}, {3} sec".format(
                    i, j - test_queries_idx[0] + 1, len(test_queries_idx), it_time
                )
            )
            it_start = timeit.default_timer()

        # NOTE(review): `s` is the leftover loop variable (== len(args) - 1 on
        # every path above), so only the LAST parameter set is evaluated —
        # confirm this is intended when len(args) > 1.
        predictions = create_scores_array(all_candidates[s][j], num_nodes)
        predictions_of_interest_pos = np.array(predictions[pos_sample_el])
        predictions_of_interest_neg = predictions[neg_sample_el]
        input_dict = {
            "y_pred_pos": predictions_of_interest_pos,
            "y_pred_neg": predictions_of_interest_neg,
            "eval_metric": ['mrr'],
        }
        predictions = evaluator.eval(input_dict)
        perf_list[index] = predictions['mrr']
        hits_list[index] = predictions['hits@10']
        if split_mode == "test":
            if log_per_rel:
                perf_per_rel[test_query[1]].append(perf_list[index])  # test_query[1] is the relation index

    if split_mode == "test":
        if log_per_rel:
            # collapse lists to means; drop relations with no test queries
            for rel in range(num_rels):
                if len(perf_per_rel[rel]) > 0:
                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))
                else:
                    perf_per_rel.pop(rel)

    return perf_list, hits_list, perf_per_rel
## args
def get_args():
    """Parse command-line arguments for the TLogic run.

    Returns:
        dict: parsed arguments as a name -> value mapping (via vars()).
    """
    def str2bool(value):
        # BUG FIX: argparse's `type=bool` turns ANY non-empty string —
        # including "False" — into True. Parse common textual spellings
        # instead; non-string defaults (True/False) pass through untouched.
        if isinstance(value, bool):
            return value
        return value.strip().lower() in ("true", "1", "yes", "y", "t")

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", "-d", default="tkgl-smallpedia", type=str)
    parser.add_argument("--rule_lengths", "-l", default="1", type=int, nargs="+")
    parser.add_argument("--num_walks", "-n", default="100", type=int)
    parser.add_argument("--transition_distr", default="exp", type=str)
    parser.add_argument("--window", "-w", default=0, type=int)  # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep
    parser.add_argument("--top_k", default=20, type=int)
    parser.add_argument("--num_processes", "-p", default=1, type=int)
    parser.add_argument("--alpha", "-alpha", default=0.99, type=float)  # fix alpha. used if trainflag == false
    # parser.add_argument("--train_flag", "-tr", default=True) # do we need training, ie selection of lambda and alpha
    parser.add_argument("--save_config", "-c", default=True)  # do we need to save the selection of lambda and alpha in config file?
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--run_nr', type=int, help='Run Number', default=1)
    parser.add_argument('--learn_rules_flag', type=str2bool, help='Do we want to learn the rules', default=True)
    parser.add_argument('--rule_filename', type=str, help='if rules not learned: where are they stored', default='0_r[3]_n100_exp_s1_rules.json')
    parser.add_argument('--log_per_rel', type=str2bool, help='Do we want to log mrr per relation', default=False)
    parser.add_argument('--compute_valid_mrr', type=str2bool, help='Do we want to compute mrr for valid set', default=True)
    parsed = vars(parser.parse_args())
    return parsed
start_o = timeit.default_timer()

## get args
parsed = get_args()
dataset = parsed["dataset"]
rule_lengths = parsed["rule_lengths"]
# a bare int (the default) becomes a one-element list so it can be iterated
rule_lengths = [rule_lengths] if (type(rule_lengths) == int) else rule_lengths
print('rule_lengths', rule_lengths)
num_walks = parsed["num_walks"]
transition_distr = parsed["transition_distr"]
num_processes = parsed["num_processes"]
window = parsed["window"]
top_k = parsed["top_k"]
log_per_rel = parsed['log_per_rel']
MODEL_NAME = 'TLogic'
SEED = parsed['seed']  # set the random seed for consistency
set_random_seed(SEED)

## load dataset and prepare it accordingly
name = parsed["dataset"]
compute_valid_mrr = parsed["compute_valid_mrr"]
dataset = LinkPropPredDataset(name=name, root="datasets", preprocess=True)
DATA = name
relations = dataset.edge_type
num_rels = dataset.num_rels
subjects = dataset.full_data["sources"]
objects= dataset.full_data["destinations"]
num_nodes = dataset.num_nodes
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1
# each row: (subject, relation, object, remapped ts, original ts)
all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
train_data = all_quads[dataset.train_mask,0:4] # we do not need the original timestamps
val_data = all_quads[dataset.val_mask,0:5]
test_data = all_quads[dataset.test_mask,0:5]
all_data = all_quads[:,0:4]
metric = dataset.eval_metric
evaluator = Evaluator(name=name)
neg_sampler = dataset.negative_sampler
inv_relation_id = get_inv_relation_id(num_rels)

# load the negative samples for both evaluation splits
dataset.load_val_ns()
dataset.load_test_ns()

output_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
learn_rules_flag = parsed['learn_rules_flag']

## 1. learn rules
start_train = timeit.default_timer()
if learn_rules_flag:
    print("start learning rules")
    # edges (dict): edges for each relation
    # inv_relation_id (dict): mapping of relation to inverse relation
    temporal_walk = Temporal_Walk(train_data, inv_relation_id, transition_distr)
    rl = Rule_Learner(edges=temporal_walk.edges, id2relation=None, inv_relation_id=inv_relation_id,
                      output_dir=output_dir)
    all_relations = sorted(temporal_walk.edges) # Learn for all relations

    # fan rule learning out over num_processes workers, then merge results
    start = timeit.default_timer()
    num_relations = len(all_relations) // num_processes
    output = Parallel(n_jobs=num_processes)(
        delayed(learn_rules)(i, num_relations) for i in range(num_processes)
    )
    end = timeit.default_timer()

    all_rules = output[0]
    for i in range(1, num_processes):
        all_rules.update(output[i])

    total_time = round(end - start, 6)
    print("Learning finished in {} seconds.".format(total_time))

    rl.rules_dict = all_rules
    rl.sort_rules_dict()
    rule_filename = rl.save_rules(0, rule_lengths, num_walks, transition_distr, SEED)
    # rl.save_rules_verbalized(0, rule_lengths, num_walks, transition_distr, seed)
    # rules_statistics(rl.rules_dict)
else:
    # reuse previously learned rules from disk
    rule_filename = parsed['rule_filename']
    print("Loading rules from file {}".format(parsed['rule_filename']))
end_train = timeit.default_timer()

## 2. Apply rules
rules_dict = json.load(open(output_dir + rule_filename))
rules_dict = {int(k): v for k, v in rules_dict.items()}
rules_dict = ra.filter_rules(
    rules_dict, min_conf=0.01, min_body_supp=2, rule_lengths=rule_lengths
) # filter rules for minimum confidence, body support and rule length
learn_edges = store_edges(train_data)
score_func = ra.score_12
# It is possible to specify a list of list of arguments for tuning
args = [[0.1, 0.5]]

# compute valid mrr
start_valid = timeit.default_timer()
if compute_valid_mrr:
print('Computing valid MRR')
num_queries = len(val_data) // num_processes
output = Parallel(n_jobs=num_processes)(
delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, val_data, window, learn_edges,
all_quads, args, split_mode='val') for i in range(num_processes))
end = timeit.default_timer()
perf_list_val = []
hits_list_val = []
for i in range(num_processes):
perf_list_val.extend(output[i][0])
hits_list_val.extend(output[i][1])
else:
perf_list_val = [0]
hits_list_val = [0]
end_valid = timeit.default_timer()
# compute test mrr
if log_per_rel ==True:
num_processes = 1 #otherwise logging per rel does not work for our implementation
start_test = timeit.default_timer()
print('Computing test MRR')
start = timeit.default_timer()
num_queries = len(test_data) // num_processes
output = Parallel(n_jobs=num_processes)(
delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, test_data, window, learn_edges,
all_quads, args, split_mode='test', log_per_rel=log_per_rel, num_rels=num_rels) for i in range(num_processes))
end = timeit.default_timer()
perf_list_all = []
hits_list_all = []
for i in range(num_processes):
perf_list_all.extend(output[i][0])
hits_list_all.extend(output[i][1])
if log_per_rel == True:
perf_per_rel = output[0][2]
total_time = round(end - start, 6)
total_valid_time = round(end_valid - start_valid, 6)
print("Application finished in {} seconds.".format(total_time))
print(f"The valid MRR is {np.mean(perf_list_val)}")
print(f"The MRR is {np.mean(perf_list_all)}")
print(f"The Hits@10 is {np.mean(hits_list_all)}")
print(f"We have {len(perf_list_all)} predictions")
print(f"The test set has len {len(test_data)} ")
end_o = timeit.default_timer()
train_time_o = round(end_train- start_train, 6)
test_time_o = round(end_o- start_test, 6)
total_time_o = round(end_o- start_o, 6)
print("Running Training to find best configs finished in {} seconds.".format(train_time_o))
print("Running testing with best configs finished in {} seconds.".format(test_time_o))
print("Running all steps finished in {} seconds.".format(total_time_o))
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
os.mkdir(results_path)
print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
if log_per_rel == True:
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'
with open(results_filename, 'w') as json_file:
json.dump(perf_per_rel, json_file)
results_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'
metric = dataset.eval_metric
save_results({'model': MODEL_NAME,
'train_flag': None,
'rule_len': rule_lengths,
'window': window,
'data': DATA,
'run': 1,
'seed': SEED,
metric: float(np.mean(perf_list_all)),
'hits10': float(np.mean(hits_list_all)),
'val_mrr': float(np.mean(perf_list_val)),
'test_time': test_time_o,
'tot_train_val_time': total_time_o,
'valid_time': total_valid_time
},
results_filename)
================================================
FILE: examples/linkproppred/tkgl-wikidata/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
tgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(tgb_modules_path)
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction on one split.

    Evaluation is 'one vs. many': every positive edge is ranked against its
    pre-sampled negative edges.
    Parameters:
        data: a dataset object (dict of numpy arrays)
        test_mask: boolean mask selecting the split's edges
        neg_sampler: an object that gives the negative edges corresponding to each positive edge
        split_mode: 'val' or 'test' — selects which set of negatives to load
    Returns:
        perf_metrics: mean value of the dataset's eval metric over all queries
        perf_hits: mean hits@10 over all queries
    """
    srcs = data['sources'][test_mask]
    dsts = data['destinations'][test_mask]
    ts = data['timestamps'][test_mask]
    etypes = data['edge_type'][test_mask]
    n_edges = len(srcs)

    perf_list, hits_list = [], []
    for batch_start in tqdm(range(0, n_edges, BATCH_SIZE)):
        batch_end = min(batch_start + BATCH_SIZE, n_edges)
        pos_src = srcs[batch_start:batch_end]
        pos_dst = dsts[batch_start:batch_end]
        pos_t = ts[batch_start:batch_end]
        pos_edge = etypes[batch_start:batch_end]

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)
        for idx, neg_batch in enumerate(neg_batch_list):
            # rank the true destination against all sampled negatives
            query_src = np.full(len(neg_batch) + 1, int(pos_src[idx]))
            query_dst = np.concatenate([np.array([int(pos_dst[idx])]), neg_batch])
            y_pred = edgebank.predict_link(query_src, query_dst)
            results = evaluator.eval({
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            })
            perf_list.append(results[metric])
            hits_list.append(results['hits@10'])

        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    return float(np.mean(perf_list)), float(np.mean(hits_list))
def get_args():
    """Parse command-line arguments for the EdgeBank run.

    Returns:
        args: the parsed ``argparse.Namespace``
        sys.argv: the raw command line (kept for logging/reproducibility)
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tkgl-wikidata')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse raises SystemExit on a bad command line; show the full help
        # text instead of only the terse usage line, then exit cleanly.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt etc.)
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
# ---------------------------------------------------------------------------
# Main script: run EdgeBank on the configured dataset, evaluate on the valid
# and test splits, and save the results to JSON.
# ---------------------------------------------------------------------------
start_overall = timeit.default_timer()
# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = args.data
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for the edgebank memory: initialized with all training edges
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()
# evaluate on the validation split (EdgeBank memory is updated as it goes)
start_val = timeit.default_timer()
perf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()
print(f"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< ")
print(f"\tval: {metric}: {perf_metric_val: .4f}")
val_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {val_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()
# evaluate on the test split
start_test = timeit.default_timer()
perf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()
print(f"INFO: Test: Evaluation Setting: >>> <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'val_mrr': perf_metric_val,
              'test_time': test_time,
              'tot_train_val_time': test_time + val_time,
              'hits10': perf_hits_test},
             results_filename)
================================================
FILE: examples/linkproppred/tkgl-wikidata/recurrencybaseline.py
================================================
""" from paper: "History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting"
Julia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)
@inproceedings{gastinger2024baselines,
title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},
author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},
booktitle={33rd International Joint Conference on Artificial Intelligence (IJCAI 2024)},
year={2024},
organization={International Joint Conferences on Artificial Intelligence Organization}
}
"""
## imports
import timeit
import argparse
import numpy as np
from copy import copy
from pathlib import Path
import ray
import sys
import os
import os.path as osp
import json
#internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import set_random_seed, save_results
from modules.tkg_utils import create_basis_dict, group_by, reformat_ts
def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
            perf_list_all, hits_list_all, window, neg_sampler, split_mode):
    """Create predictions for one relation on the test or valid set and
    collect per-query MRR / hits.

    Queries are split into chunks and dispatched to ray workers when more
    than one process is requested; otherwise prediction runs inline.

    :param num_processes: number of ray worker processes requested
    :param data_c_rel: query quadruples of the current relation (current split)
    :param all_data_c_rel: all quadruples of the current relation (history)
    :param alpha: recurrency-baseline hyperparameter
    :param lmbda_psi: recurrency-baseline hyperparameter
    :param perf_list_all: list the per-query mrrs are appended to (also returned)
    :param hits_list_all: list the per-query hits are appended to (also returned)
    :param window: history window size
    :param neg_sampler: negative sampler used during evaluation
    :param split_mode: 'val' or 'test'
    :return perf_list_all: list of mrrs for each query
    :return hits_list_all: list of hits for each query
    """
    first_ts = data_c_rel[0][3]
    num_queries = len(data_c_rel) // num_processes
    if num_queries < num_processes:  # if we do not have enough queries for all the processes
        num_processes_tmp = 1
        num_queries = len(data_c_rel)
    else:
        num_processes_tmp = num_processes
    if num_processes > 1:
        object_references = []
        for i in range(num_processes_tmp):
            # the last chunk absorbs the remainder of the queries
            num_test_queries = len(data_c_rel) - (i + 1) * num_queries
            if num_test_queries >= num_queries:
                test_queries_idx = [i * num_queries, (i + 1) * num_queries]
            else:
                # BUGFIX: bound the final chunk by this relation's own data,
                # not by the length of the global test set (`len(test_data)`);
                # the old bound was wrong for the validation split and for
                # per-relation subsets (numpy slice clipping mostly hid it).
                test_queries_idx = [i * num_queries, len(data_c_rel)]
            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]
            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window,
                                                basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
            object_references.append(ob)
        output = ray.get(object_references)
        # merge the scores returned by each worker process
        for proc_loop in range(num_processes_tmp):
            perf_list_all.extend(output[proc_loop][0])
            hits_list_all.extend(output[proc_loop][1])
    else:
        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel,
                                                window, basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
    return perf_list_all, hits_list_all
## test
def test(best_config, all_relations, test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
    """Loop over all relations, predict with each relation's best config, and
    collect the MRR / hits for every query of the given split.

    Per-relation score lists are also appended to a csv file for later analysis.

    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    perf_list_all, hits_list_all = [], []
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}' + split_mode + '.csv'

    for rel in all_relations:
        start = timeit.default_timer()
        if rel not in test_data_prel:
            # no queries for this relation in the split
            perf_list_rel = []
        else:
            cfg = best_config[str(rel)]
            lmbda_psi = cfg['lmbda_psi'][0]
            alpha = cfg['alpha'][0]
            # queries and history restricted to the current relation
            test_data_c_rel = test_data_prel[rel]
            timesteps_test = sorted(set(test_data_c_rel[:, 3]))
            all_data_c_rel = all_data_prel[rel]
            perf_list_rel, hits_list_rel = predict(
                num_processes, test_data_c_rel, all_data_c_rel, alpha, lmbda_psi,
                [], [], window, neg_sampler, split_mode)
            perf_list_all.extend(perf_list_rel)
            hits_list_all.extend(hits_list_rel)
        elapsed = round(timeit.default_timer() - start, 6)
        print("Relation {} finished in {} seconds.".format(rel, elapsed))
        with open(csv_file, 'a') as f:
            f.write("{},{}\n".format(rel, perf_list_rel))

    return perf_list_all, hits_list_all
def read_dict_compute_mrr(split_mode='test'):
    """Read the per-relation results from the previously written csv file and
    print the overall mean MRR together with the MRR per relation.

    NOTE: results are printed only; nothing is returned.
    """
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}' + split_mode + '.csv'
    results_per_rel_dict = {}
    mrr_per_rel = {}
    all_mrrs = []

    with open(csv_file, 'r') as f:
        for line in f:
            fields = line.strip().split(',')
            key = int(fields[0])  # relation id
            # remaining fields are per-query mrrs, possibly wrapped in brackets
            values = [float(entry.strip('[]')) for entry in fields[1:]]
            if key in results_per_rel_dict:
                print(f"Key {key} already exists in the dictionary!!! might have duplicate entries in results csv")
            results_per_rel_dict[key] = values
            all_mrrs.extend(values)
            mrr_per_rel[key] = np.mean(values)

    if len(results_per_rel_dict) != num_rels:
        print("we do not have entries for each rel in the results csv file. only num enties: ", len(list(results_per_rel_dict.keys())))
    print("Split mode: " + split_mode + " Mean MRR: ", np.mean(all_mrrs))
    print("mrr per relation: ", mrr_per_rel)
## train
def train(params_dict, rels, val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
    """Optional grid search: find the best lmbda_psi and alpha per relation,
    scored by validation MRR.

    The two hyperparameters are tuned sequentially: first lmbda_psi (with
    alpha fixed to 1), then alpha (with the best lmbda_psi fixed).

    :param params_dict: dict with candidate value lists for 'lmbda_psi' and 'alpha'
    :param rels: iterable of all relation ids
    :param val_data_prel: dict, relation id -> validation quadruples of that relation
    :param trainval_data_prel: dict, relation id -> train+valid quadruples of that relation
    :param neg_sampler: negative sampler used during evaluation
    :param num_processes: number of worker processes passed through to predict()
    :param window: history window passed through to predict()
    :return best_config: dict keyed by str(rel) with best params and bookkeeping per relation
    """
    best_config = {}
    best_mrr = 0
    for rel in rels:  # loop through relations. for each relation, apply rules with selected params, compute valid mrr
        start = timeit.default_timer()
        rel_key = int(rel)
        # initialize with defaults; overwritten below if the relation has validation data
        best_config[str(rel_key)] = {}
        best_config[str(rel_key)]['not_trained'] = 'True'
        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi, 0]  # default
        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))
        best_config[str(rel_key)]['alpha'] = [default_alpha, 0]  # default
        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))
        if rel in val_data_prel.keys():
            # valid data for this relation
            val_data_c_rel = val_data_prel[rel]
            timesteps_valid = list(set(val_data_c_rel[:, 3]))
            timesteps_valid.sort()
            trainval_data_c_rel = trainval_data_prel[rel]
            ###### 1) select lambda ###############
            lmbdas_psi = params_dict['lmbda_psi']
            alpha = 1  # alpha fixed while tuning lmbda_psi
            best_lmbda_psi = 0.1
            best_mrr_psi = 0
            lmbda_mrrs = []
            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))
            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))
            best_config[str(rel_key)]['not_trained'] = 'False'
            for lmbda_psi in lmbdas_psi:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                                trainval_data_c_rel, alpha, lmbda_psi, perf_list_r, hits_list_r,
                                window, neg_sampler, split_mode='val')
                # compute mrr
                mrr = np.mean(perf_list_r)
                # is new mrr better than previous best? if yes: store lmbda
                if mrr > best_mrr_psi:
                    best_mrr_psi = float(mrr)
                    best_lmbda_psi = lmbda_psi
                lmbda_mrrs.append(float(mrr))
            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]
            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs
            best_mrr = best_mrr_psi
            ##### 2) select alpha ###############
            best_config[str(rel_key)]['not_trained'] = 'False'
            alphas = params_dict['alpha']
            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0]  # use the best lmbda psi
            alpha_mrrs = []
            best_mrr_alpha = 0
            best_alpha = 0.99
            for alpha in alphas:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                                trainval_data_c_rel, alpha, lmbda_psi, perf_list_r, hits_list_r,
                                window, neg_sampler, split_mode='val')
                # compute mrr
                mrr_alpha = np.mean(perf_list_r)
                # is new mrr better than previous best? if yes: store alpha
                if mrr_alpha > best_mrr_alpha:
                    best_mrr_alpha = float(mrr_alpha)
                    best_alpha = alpha
                    best_mrr = best_mrr_alpha
                alpha_mrrs.append(float(mrr_alpha))
            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]
            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs
        end = timeit.default_timer()
        total_time = round(end - start, 6)
        print("Relation {} finished in {} seconds.".format(rel, total_time))
    return best_config
## args
def get_args():
    """Parse all command-line arguments for the script and return them as a dict."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", "-d", default="tkgl-wikidata", type=str)
    parser.add_argument("--window", "-w", default=0, type=int)  # e.g. 200 to only use the 200 most recent timesteps; -2 for multistep
    parser.add_argument("--num_processes", "-p", default=1, type=int)
    parser.add_argument("--lmbda", "-l", default=0.1, type=float)  # fixed lambda, used if train_flag == 'False'
    parser.add_argument("--alpha", "-alpha", default=0.99, type=float)  # fixed alpha, used if train_flag == 'False'
    parser.add_argument("--train_flag", "-tr", default='False')  # do we need training, i.e. selection of lambda and alpha
    parser.add_argument("--load_flag", "-lo", default='False')  # if train_flag is 'True': load best_config from disk?
    parser.add_argument("--save_config", "-c", default='True')  # save the selected lambda and alpha to a config file?
    parser.add_argument('--seed', type=int, help='Random seed', default=1)  # not needed
    return vars(parser.parse_args())
# ---------------------------------------------------------------------------
# Main script (RecurrencyBaseline): optionally tune per-relation params on the
# validation set, then evaluate on the valid and test splits and save results.
# ---------------------------------------------------------------------------
start_o = timeit.default_timer()
parsed = get_args()
if parsed['num_processes'] > 1:
    # start a local ray cluster for parallel per-relation prediction
    ray.init(num_cpus=parsed["num_processes"], num_gpus=0)
MODEL_NAME = 'RecurrencyBaseline'
SEED = parsed['seed']  # set the random seed for consistency
set_random_seed(SEED)
perrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'
if not osp.exists(perrel_results_path):
    os.mkdir(perrel_results_path)
    print('INFO: Create directory {}'.format(perrel_results_path))
Path(perrel_results_path).mkdir(parents=True, exist_ok=True)

## load dataset and prepare it accordingly
name = parsed["dataset"]
dataset = LinkPropPredDataset(name=name, root="datasets", preprocess=True)
DATA = name
relations = dataset.edge_type
num_rels = dataset.num_rels
rels = np.arange(0, num_rels)
subjects = dataset.full_data["sources"]
objects = dataset.full_data["destinations"]
num_nodes = dataset.num_nodes
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA)  # stepsize:1
print("split train valid test data")
# quadruple columns: (subject, relation, object, discretized ts, original ts)
all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]
metric = dataset.eval_metric
evaluator = Evaluator(name=name)
neg_sampler = dataset.negative_sampler
train_val_data = np.concatenate([train_data, val_data])
all_data = np.concatenate([train_data, val_data, test_data])

# create dicts with key: relation id, values: triples for that relation id
print("grouping data by relation")
test_data_prel = group_by(test_data, 1)
all_data_prel = group_by(all_data, 1)
val_data_prel = group_by(val_data, 1)
trainval_data_prel = group_by(train_val_data, 1)

# load the negative samples for evaluation
print("loading negative samples")
dataset.load_val_ns()
dataset.load_test_ns()

# parameter options
# NOTE: train_flag / load_flag / save_config are string flags ('True'/'False'),
# not booleans — they are compared against the string values below.
if parsed['train_flag'] == 'True':
    params_dict = {}
    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001]
    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]
    default_lmbda_psi = params_dict['lmbda_psi'][-1]
    default_alpha = params_dict['alpha'][-2]

## load rules
print("creating rules")
basis_dict = create_basis_dict(train_val_data)
print("done with creating rules")

## train to find best lambda and alpha
start_train = timeit.default_timer()
if parsed['train_flag'] == 'True':
    if parsed['load_flag'] == 'True':
        # reuse a previously saved hyperparameter search
        with open('best_config.json', 'r') as infile:
            best_config = json.load(infile)
    else:
        print('start training')
        best_config = train(params_dict, rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'],
                            parsed['window'])
        if parsed['save_config'] == 'True':
            import json
            with open('best_config.json', 'w') as outfile:
                json.dump(best_config, outfile)
else:  # use preset lmbda and alpha; same for all relations
    best_config = {}
    for rel in rels:
        best_config[str(rel)] = {}
        best_config[str(rel)]['lmbda_psi'] = [parsed['lmbda']]
        best_config[str(rel)]['alpha'] = [parsed['alpha']]
end_train = timeit.default_timer()

# compute validation mrr
print("Computing validation MRR")
perf_list_all_val, hits_list_all_val = test(best_config, rels, val_data_prel,
                        trainval_data_prel, neg_sampler, parsed['num_processes'],
                        parsed['window'], split_mode='val')
val_mrr = float(np.mean(perf_list_all_val))

# compute test mrr
print("Computing test MRR")
start_test = timeit.default_timer()
perf_list_all, hits_list_all = test(best_config, rels, test_data_prel,
                        all_data_prel, neg_sampler, parsed['num_processes'],
                        parsed['window'])
end_o = timeit.default_timer()
train_time_o = round(end_train - start_train, 6)
test_time_o = round(end_o - start_test, 6)
total_time_o = round(end_o - start_o, 6)
print("Running Training to find best configs finished in {} seconds.".format(train_time_o))
print("Running testing with best configs finished in {} seconds.".format(test_time_o))
print("Running all steps finished in {} seconds.".format(total_time_o))
print(f"The test MRR is {np.mean(perf_list_all)}")
print(f"The valid MRR is {val_mrr}")
print(f"The Hits@10 is {np.mean(hits_list_all)}")
print(f"We have {len(perf_list_all)} predictions")
print(f"The test set has len {len(test_data)} ")

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'
metric = dataset.eval_metric
save_results({'model': MODEL_NAME,
              'train_flag': parsed['train_flag'],
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: float(np.mean(perf_list_all)),
              'val_mrr': val_mrr,
              'hits10': float(np.mean(hits_list_all)),
              'test_time': test_time_o,
              'tot_train_val_time': total_time_o
              },
             results_filename)
if parsed['num_processes'] > 1:
    ray.shutdown()
================================================
FILE: examples/linkproppred/tkgl-wikidata/regcn.py
================================================
"""
Temporal Knowledge Graph Reasoning Based on Evolutional Representation Learning
Reference:
- https://github.com/Lee-zix/RE-GCN
Zixuan Li, Xiaolong Jin, Wei Li, Saiping Guan, Jiafeng Guo, Huawei Shen, Yuanzhuo Wang and Xueqi Cheng. Temporal
Knowledge Graph Reasoning Based on Evolutional Representation Learning. SIGIR 2021.
"""
import sys
import timeit
import os
import sys
import os.path as osp
from pathlib import Path
import numpy as np
import torch
import random
from tqdm import tqdm
# internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.rrgcn import RecurrentRGCNREGCN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from modules.tkg_utils import get_args_regcn, reformat_ts
from modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
import json
def test(model, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, static_graph, mode, split_mode):
    """
    Test the model on either the test or validation set.
    :param model: model used to test
    :param history_list: all input history snapshot list, not including the output label train list or valid list
    :param test_list: test triple snapshot list
    :param num_rels: number of relations
    :param num_nodes: number of nodes
    :param use_cuda: whether to run on GPU
    :param model_name: checkpoint file path, loaded when mode == 'test'
    :param static_graph: optional static graph input forwarded to the model
    :param mode: 'test' loads model weights from `model_name` before evaluating
    :param split_mode: 'test' or 'val' to state which negative samples to load
    :return: (mrr, perf_per_rel, hits10)
    """
    print("Testing for mode: ", split_mode)
    # pick the original (undiscretized) timestamps of the split being evaluated
    if split_mode == 'test':
        timesteps_to_eval = test_timestamps_orig
    else:
        timesteps_to_eval = val_timestamps_orig
    idx = 0
    if mode == "test":
        # test mode: load parameters from file
        if use_cuda:
            checkpoint = torch.load(model_name, map_location=torch.device(args.gpu))
        else:
            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))
        # use best stat checkpoint:
        print("Load Model name: {}. Using best epoch : {}".format(model_name, checkpoint['epoch']))
        print("\n" + "-" * 10 + "start testing" + "-" * 10 + "\n")
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    # seed the rolling history with the most recent snapshots before the split
    input_list = [snap for snap in history_list[-args.test_history_len:]]
    perf_list_all = []
    hits_list_all = []
    perf_per_rel = {}
    for rel in range(num_rels):
        perf_per_rel[rel] = []
    for time_idx, test_snap in enumerate(tqdm(test_list)):
        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]
        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)
        # all triples of one snapshot share the same evaluation timestamp
        timesteps_batch = timesteps_to_eval[time_idx] * np.ones(len(test_triples_input[:, 0]))
        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:, 0], test_triples_input[:, 2],
                            timesteps_batch, edge_type=test_triples_input[:, 1], split_mode=split_mode)
        pos_samples_batch = test_triples_input[:, 2]
        _, perf_list, hits_list = model.predict(history_glist, num_rels, static_graph, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch,
                                                evaluator, METRIC)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
        if split_mode == "test":
            if args.log_per_rel:
                # collect per-relation scores for optional per-relation logging
                for score, rel in zip(perf_list, test_triples_input[:, 1].tolist()):
                    perf_per_rel[rel].append(score)
        # reconstruct history graph list: slide the window forward by one snapshot
        input_list.pop(0)
        input_list.append(test_snap)
        idx += 1
    if split_mode == "test":
        if args.log_per_rel:
            # average per-relation scores; drop relations without test queries
            for rel in range(num_rels):
                if len(perf_per_rel[rel]) > 0:
                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))
                else:
                    perf_per_rel.pop(rel)
    mrr = np.mean(perf_list_all)
    hits10 = np.mean(hits_list_all)
    return mrr, perf_per_rel, hits10
def run_experiment(args, n_hidden=None, n_layers=None, dropout=None, n_bases=None):
    """
    Run the experiment with the given configuration.

    Trains REGCN on the globally loaded dataset splits or, when ``args.test``
    is set and a checkpoint exists, directly evaluates the stored model.

    :param args: arguments namespace from get_args_regcn()
    :param n_hidden: optional override for args.n_hidden (hidden dimension)
    :param n_layers: optional override for args.n_layers (number of layers)
    :param dropout: optional override for args.dropout (dropout rate)
    :param n_bases: optional override for args.n_bases (number of bases)
    :return: (mrr, perf_per_rel, hits10) — mean reciprocal rank, per-relation
        performance dict, and hits@10

    NOTE(review): relies on module-level globals (train_list, valid_list,
    test_list, num_nodes, num_rels, MODEL_NAME, DATA, SEED, test, ...).
    """
    # load configuration for grid search the best configuration
    if n_hidden:
        args.n_hidden = n_hidden
    if n_layers:
        args.n_layers = n_layers
    if dropout:
        args.dropout = dropout
    if n_bases:
        args.n_bases = n_bases

    mrr = 0
    hits10 = 0

    # 2) set save model path
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    if not osp.exists(save_model_dir):
        os.mkdir(save_model_dir)
        print('INFO: Create directory {}'.format(save_model_dir))
    model_name = f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}'
    model_state_file = save_model_dir + model_name
    perf_per_rel = {}
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    # this variant runs without a static background graph
    num_static_rels, num_words, static_triples, static_graph = 0, 0, [], None

    # create model
    model = RecurrentRGCNREGCN(args.decoder,
                               args.encoder,
                               num_nodes,
                               int(num_rels/2),
                               num_static_rels,  # DIFFERENT
                               num_words,  # DIFFERENT
                               args.n_hidden,
                               args.opn,
                               sequence_len=args.train_history_len,
                               num_bases=args.n_bases,
                               num_basis=args.n_basis,
                               num_hidden_layers=args.n_layers,
                               dropout=args.dropout,
                               self_loop=args.self_loop,
                               skip_connect=args.skip_connect,
                               layer_norm=args.layer_norm,
                               input_dropout=args.input_dropout,
                               hidden_dropout=args.hidden_dropout,
                               feat_dropout=args.feat_dropout,
                               aggregation=args.aggregation,  # DIFFERENT
                               weight=args.weight,  # DIFFERENT
                               discount=args.discount,  # DIFFERENT
                               angle=args.angle,  # DIFFERENT
                               use_static=args.add_static_graph,  # DIFFERENT
                               entity_prediction=args.entity_prediction,
                               relation_prediction=args.relation_prediction,
                               use_cuda=use_cuda,
                               gpu=args.gpu,
                               analysis=args.run_analysis)  # DIFFERENT
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        model.cuda()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)

    if args.test and os.path.exists(model_state_file):
        mrr, perf_per_rel, hits10 = test(model,
                                         train_list+valid_list,
                                         test_list,
                                         num_rels,
                                         num_nodes,
                                         use_cuda,
                                         model_state_file,
                                         static_graph,
                                         "test",
                                         "test")
        return mrr, perf_per_rel, hits10
    elif args.test and not os.path.exists(model_state_file):
        print("--------------{} not exist, Change mode to train and generate stat for testing----------------\n".format(model_state_file))
        # BUGFIX: previously returned only two values (0, 0) although every
        # caller unpacks three; return the full triple instead.
        return 0, perf_per_rel, 0
    else:
        print("----------------------------------------start training----------------------------------------\n")
        best_mrr = 0
        best_hits = 0
        for epoch in range(args.n_epochs):
            model.train()
            losses = []

            # shuffle the order of the training snapshots each epoch
            idx = [_ for _ in range(len(train_list))]
            random.shuffle(idx)

            for train_sample_num in tqdm(idx):
                if train_sample_num == 0: continue  # the very first snapshot has no history
                output = train_list[train_sample_num:train_sample_num+1]
                # take up to train_history_len preceding snapshots as input
                if train_sample_num - args.train_history_len < 0:
                    input_list = train_list[0: train_sample_num]
                else:
                    input_list = train_list[train_sample_num - args.train_history_len:
                                            train_sample_num]

                # generate history graph
                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]
                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]

                loss_e, loss_r, loss_static = model.get_loss(history_glist, output[0], static_graph, use_cuda)
                # weighted sum of entity- and relation-prediction losses
                loss = args.task_weight*loss_e + (1-args.task_weight)*loss_r + loss_static

                losses.append(loss.item())
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients
                optimizer.step()
                optimizer.zero_grad()

            print("Epoch {:04d} | Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} "
                  .format(epoch, np.mean(losses), best_mrr, model_name))

            #! checking GPU usage
            # BUGFIX: torch.cuda.mem_get_info() raises when no CUDA device is
            # available; only report memory statistics when running on GPU.
            if use_cuda:
                free_mem, total_mem = torch.cuda.mem_get_info()
                print("--------------GPU memory usage-----------")
                print("there are ", free_mem, " free memory")
                print("there are ", total_mem, " total available memory")
                print("there are ", total_mem - free_mem, " used memory")
                print("--------------GPU memory usage-----------")

            # validation
            if epoch and epoch % args.evaluate_every == 0:
                mrr, perf_per_rel, hits10 = test(model, train_list,
                                                 valid_list,
                                                 num_rels,
                                                 num_nodes,
                                                 use_cuda,
                                                 model_state_file,
                                                 static_graph,
                                                 mode="train", split_mode='val')
                if mrr < best_mrr:
                    # NOTE(review): epoch is always < args.n_epochs inside this
                    # loop, so this early-stop branch never fires (kept for
                    # parity with the upstream REGCN implementation).
                    if epoch >= args.n_epochs:
                        break
                else:
                    # validation improved (or tied): keep this checkpoint
                    best_mrr = mrr
                    best_hits = hits10
                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
    # BUGFIX: return the hits@10 recorded together with best_mrr instead of
    # whatever the last validation round happened to produce.
    return best_mrr, perf_per_rel, best_hits
# ==================
# Script entry point: train and evaluate REGCN on tkgl-wikidata
# ==================
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args_regcn()
args.dataset = 'tkgl-wikidata'
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
DATA = args.dataset
MODEL_NAME = 'REGCN'
print("logging mrrs per relation: ", args.log_per_rel)

# load data
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
num_rels = dataset.num_rels
num_nodes = dataset.num_nodes
subjects = dataset.full_data["sources"]
objects = dataset.full_data["destinations"]
relations = dataset.edge_type
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA)  # stepsize:1
all_quads = np.stack((subjects, relations, objects, timestamps), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]
# distinct original timestamps per split, sorted ascending; the negative
# sampler is queried with these original (non-remapped) timestamps
val_timestamps_orig = list(set(timestamps_orig[dataset.val_mask]))
val_timestamps_orig.sort()
test_timestamps_orig = list(set(timestamps_orig[dataset.test_mask]))
test_timestamps_orig.sort()
train_list = split_by_time(train_data)
valid_list = split_by_time(val_data)
test_list = split_by_time(test_data)

# evaluation metric
METRIC = dataset.eval_metric
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
# load the pre-generated negative samples
dataset.load_val_ns()
dataset.load_test_ns()

## run training and testing
val_mrr, test_mrr = 0, 0
test_hits10 = 0
if args.grid_search:
    print("hyperparameter grid search not implemented. Exiting.")
# single run
else:
    start_train = timeit.default_timer()
    if args.test == False:  # if True: directly test on a previously trained and stored model
        print('start training')
        val_mrr, perf_per_rel, val_hits10 = run_experiment(args)  # do training
    start_test = timeit.default_timer()
    args.test = True
    print('start testing')
    test_mrr, perf_per_rel, test_hits10 = run_experiment(args)  # do testing
    test_time = timeit.default_timer() - start_test
    all_time = timeit.default_timer() - start_train
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
    # BUGFIX: was "\Train ..." — an invalid escape sequence that printed a
    # literal backslash; the intended prefix is a tab like the other lines.
    print(f"\tTrain and Test: Elapsed Time (s): {all_time: .4f}")
    print(f"\tTest: {METRIC}: {test_mrr: .4f}")
    print(f"\tValid: {METRIC}: {val_mrr: .4f}")

    # saving the results...
    results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
    if not osp.exists(results_path):
        os.mkdir(results_path)
        print('INFO: Create directory {}'.format(results_path))
    Path(results_path).mkdir(parents=True, exist_ok=True)
    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': args.run_nr,
                  'seed': SEED,
                  f'val {METRIC}': float(val_mrr),
                  f'test {METRIC}': float(test_mrr),
                  'test_time': test_time,
                  'tot_train_val_time': all_time,
                  'test_hits10': float(test_hits10)
                  },
                 results_filename)
    if args.log_per_rel:
        results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'
        with open(results_filename, 'w') as json_file:
            json.dump(perf_per_rel, json_file)
sys.exit()
================================================
FILE: examples/linkproppred/tkgl-wikidata/tkgl-wikidata_example.py
================================================
import numpy as np
import timeit
from tqdm import tqdm
import os.path as osp
import sys
import os

# make the repository root importable so `tgb` can be found when running this
# example directly from its subdirectory
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)

from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from torch_geometric.loader import TemporalDataLoader
from tgb.linkproppred.evaluate import Evaluator

DATA = "tkgl-wikidata"

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
metric = dataset.eval_metric
print ("there are {} nodes and {} edges".format(dataset.num_nodes, dataset.num_edges))
print ("there are {} relation types".format(dataset.num_rels))

timestamp = data.t
head = data.src
tail = data.dst
edge_type = data.edge_type  # relation
neg_sampler = dataset.negative_sampler
evaluator = Evaluator(name=DATA)
# NOTE: duplicate re-assignments of `metric` and `neg_sampler` were removed;
# they were assigned the same values twice in the original script.

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

BATCH_SIZE = 1  ## 200
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

start_time = timeit.default_timer()
# load the ns samples first, then iterate once over each split to make sure
# every query has its negative samples available
dataset.load_val_ns()
for batch in tqdm(val_loader):
    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type
    neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='val')
    # flag queries with unusually many negative samples
    if len(neg_batch_list[0]) > 1500:
        print(rel, len(neg_batch_list[0]))
print ("loading ns samples from validation", timeit.default_timer() - start_time)

start_time = timeit.default_timer()
dataset.load_test_ns()
for batch in test_loader:
    src, pos_dst, t, msg, rel = batch.src, batch.dst, batch.t, batch.msg, batch.edge_type
    neg_batch_list = neg_sampler.query_batch(src.detach().cpu().numpy(), pos_dst.detach().cpu().numpy(), t.detach().cpu().numpy(), rel.detach().cpu().numpy(), split_mode='test')
print ("loading ns samples from test", timeit.default_timer() - start_time)
print ("retrieved all negative samples")

# #* load numpy arrays instead
# from tgb.linkproppred.dataset import LinkPropPredDataset
# # data loading
# dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
# data = dataset.full_data
# metric = dataset.eval_metric
# sources = dataset.full_data['sources']
# print (sources.dtype)
================================================
FILE: examples/linkproppred/tkgl-wikidata/tlogic.py
================================================
"""
https://github.com/liu-yushan/TLogic/tree/main/mycode
TLogic: Temporal Logical Rules for Explainable Link Forecasting on Temporal Knowledge Graphs.
Yushan Liu, Yunpu Ma, Marcel Hildebrandt, Mitchell Joblin, Volker Tresp
"""
# imports
import sys
import os
import os.path as osp
from pathlib import Path
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
import timeit
import argparse
import numpy as np
import json
from joblib import Parallel, delayed
import itertools
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
from modules.tlogic_learn_modules import Temporal_Walk, Rule_Learner, store_edges
import modules.tlogic_apply_modules as ra
from tgb.utils.utils import set_random_seed, save_results
from modules.tkg_utils import reformat_ts, get_inv_relation_id, create_scores_array
def learn_rules(i, num_relations):
    """
    Learn rules (multiprocessing possible).

    Each worker handles a contiguous slice of `all_relations`; the last
    worker additionally picks up any leftover relations.

    Parameters:
        i (int): process number
        num_relations (int): minimum number of relations for each process

    Returns:
        rl.rules_dict (dict): rules dictionary
    """
    start_idx = i * num_relations
    remaining = len(all_relations) - (i + 1) * num_relations
    # the final process takes whatever relations are left over
    if remaining >= num_relations:
        relations_idx = range(start_idx, start_idx + num_relations)
    else:
        relations_idx = range(start_idx, len(all_relations))

    rule_counts = [0]  # cumulative rule counts, used to report per-step deltas
    for k in relations_idx:
        current_rel = all_relations[k]
        for length in rule_lengths:
            tic = timeit.default_timer()
            # sample temporal random walks and turn the successful ones into rules
            for _ in range(num_walks):
                walk_ok, walk = temporal_walk.sample_walk(length + 1, current_rel)
                if walk_ok:
                    rl.create_rule(walk)
            toc = timeit.default_timer()
            elapsed = round(toc - tic, 6)
            # each rule is stored for both directions, hence the division by two
            rule_counts.append(sum(len(v) for v in rl.rules_dict.values()) // 2)
            print(
                "Process {0}: relation {1}/{2}, length {3}: {4} sec, {5} rules".format(
                    i,
                    k - relations_idx[0] + 1,
                    len(relations_idx),
                    length,
                    elapsed,
                    rule_counts[-1] - rule_counts[-2],
                )
            )
    return rl.rules_dict
def apply_rules(i, num_queries, rules_dict, neg_sampler, data, window, learn_edges, all_quads, args, split_mode,
                log_per_rel=False, num_rels=0):
    """
    Apply rules (multiprocessing possible).

    Scores each query in `data` against its TGB negative samples using the
    candidate objects produced by the learned rules (noisy-or aggregation),
    and evaluates MRR / hits@10 per query.

    Parameters:
        i (int): process number
        num_queries (int): minimum number of queries for each process
        rules_dict (dict): learned rules keyed by relation id
        neg_sampler: TGB negative sampler for the evaluated split
        data: queries as rows (subject, relation, object, remapped ts, original ts)
        window (int): time window for collecting applicable edges
        learn_edges (dict): training edges per relation
        all_quads: all quadruples; source for the window edges
        args (list): list of score-function parameter pairs (one candidate dict per entry)
        split_mode (str): 'val' or 'test'; which negative samples to query
        log_per_rel (bool): collect per-relation MRR (test split only)
        num_rels (int): number of relations (needed when log_per_rel)

    Returns:
        perf_list (list): performance list (mrr per sample)
        hits_list (list): hits list (hits@10 per sample)
        perf_per_rel (dict): mean MRR per relation (only filled on the test
            split when log_per_rel is set, otherwise empty/unaggregated)

    NOTE(review): relies on module-level globals: score_func, top_k,
    create_scores_array, num_nodes, evaluator, np, itertools, timeit, ra.
    """
    perf_per_rel = {}
    for rel in range(num_rels):
        perf_per_rel[rel] = []
    print("Start process", i, "...")
    # one candidate store per score-function parameterization in `args`
    all_candidates = [dict() for _ in range(len(args))]
    no_cands_counter = 0
    # partition the queries across processes; the last process takes the rest
    num_rest_queries = len(data) - (i + 1) * num_queries
    if num_rest_queries >= num_queries:
        test_queries_idx = range(i * num_queries, (i + 1) * num_queries)
    else:
        test_queries_idx = range(i * num_queries, len(data))
    # window edges are cached and only rebuilt when the timestamp changes
    cur_ts = data[test_queries_idx[0]][3]
    edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)
    it_start = timeit.default_timer()
    hits_list = [0] * len(test_queries_idx)
    perf_list = [0] * len(test_queries_idx)
    for index, j in enumerate(test_queries_idx):
        # negative samples are queried with the ORIGINAL timestamp (column 4)
        neg_sample_el = neg_sampler.query_batch(np.expand_dims(np.array(data[j,0]), axis=0),
                                                np.expand_dims(np.array(data[j,2]), axis=0),
                                                np.expand_dims(np.array(data[j,4]), axis=0),
                                                np.expand_dims(np.array(data[j,1]), axis=0),
                                                split_mode=split_mode)[0]
        pos_sample_el = data[j,2]
        test_query = data[j]
        assert pos_sample_el == test_query[2]
        cands_dict = [dict() for _ in range(len(args))]
        if test_query[3] != cur_ts:
            # timestamp changed: rebuild the window of applicable edges
            cur_ts = test_query[3]
            edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)
        if test_query[1] in rules_dict:
            dicts_idx = list(range(len(args)))
            for rule in rules_dict[test_query[1]]:
                # ground the rule body starting from the query subject
                walk_edges = ra.match_body_relations(rule, edges, test_query[0])
                if 0 not in [len(x) for x in walk_edges]:
                    rule_walks = ra.get_walks(rule, walk_edges)
                    if rule["var_constraints"]:
                        rule_walks = ra.check_var_constraints(
                            rule["var_constraints"], rule_walks
                        )
                    if not rule_walks.empty:
                        cands_dict = ra.get_candidates(
                            rule,
                            rule_walks,
                            cur_ts,
                            cands_dict,
                            score_func,
                            args,
                            dicts_idx,
                        )
                        # NOTE(review): dicts_idx is mutated while being
                        # iterated below — kept as in the upstream TLogic code.
                        for s in dicts_idx:
                            # keep candidate scores sorted high-to-low
                            cands_dict[s] = {
                                x: sorted(cands_dict[s][x], reverse=True)
                                for x in cands_dict[s].keys()
                            }
                            cands_dict[s] = dict(
                                sorted(
                                    cands_dict[s].items(),
                                    key=lambda item: item[1],
                                    reverse=True,
                                )
                            )
                            # stop refining this parameterization once top_k
                            # distinct score tuples have been collected
                            top_k_scores = [v for _, v in cands_dict[s].items()][:top_k]
                            unique_scores = list(
                                scores for scores, _ in itertools.groupby(top_k_scores)
                            )
                            if len(unique_scores) >= top_k:
                                dicts_idx.remove(s)
                        if not dicts_idx:
                            break
            if cands_dict[0]:
                for s in range(len(args)):
                    # Calculate noisy-or scores
                    # NOTE(review): np.product is deprecated (removed in
                    # NumPy >= 2.0); np.prod is the replacement.
                    scores = list(
                        map(
                            lambda x: 1 - np.product(1 - np.array(x)),
                            cands_dict[s].values(),
                        )
                    )
                    cands_scores = dict(zip(cands_dict[s].keys(), scores))
                    noisy_or_cands = dict(
                        sorted(cands_scores.items(), key=lambda x: x[1], reverse=True)
                    )
                    all_candidates[s][j] = noisy_or_cands
            else:  # No candidates found by applying rules
                no_cands_counter += 1
                for s in range(len(args)):
                    all_candidates[s][j] = dict()
        else:  # No rules exist for this relation
            no_cands_counter += 1
            for s in range(len(args)):
                all_candidates[s][j] = dict()
        # periodic progress report every 100 queries
        if not (j - test_queries_idx[0] + 1) % 100:
            it_end = timeit.default_timer()
            it_time = round(it_end - it_start, 6)
            print(
                "Process {0}: test samples finished: {1}/{2}, {3} sec".format(
                    i, j - test_queries_idx[0] + 1, len(test_queries_idx), it_time
                )
            )
            it_start = timeit.default_timer()
        # NOTE(review): `s` here is left over from the loops above; this is
        # only correct because len(args) == 1 in this script — verify before
        # tuning with multiple score-function parameterizations.
        predictions = create_scores_array(all_candidates[s][j], num_nodes)
        predictions_of_interest_pos = np.array(predictions[pos_sample_el])
        predictions_of_interest_neg = predictions[neg_sample_el]
        input_dict = {
            "y_pred_pos": predictions_of_interest_pos,
            "y_pred_neg": predictions_of_interest_neg,
            "eval_metric": ['mrr'],
        }
        predictions = evaluator.eval(input_dict)
        perf_list[index] = predictions['mrr']
        hits_list[index] = predictions['hits@10']
        if split_mode == "test":
            if log_per_rel:
                perf_per_rel[test_query[1]].append(perf_list[index]) #test_query[1] is the relation index
    if split_mode == "test":
        if log_per_rel:
            # average per relation; drop relations that never occurred
            for rel in range(num_rels):
                if len(perf_per_rel[rel]) > 0:
                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))
                else:
                    perf_per_rel.pop(rel)
    return perf_list, hits_list, perf_per_rel
## args
def get_args():
    """
    Parse command-line arguments for running TLogic on a TGB dataset.

    :return: dict of parsed arguments
    """
    def str2bool(value):
        # BUGFIX helper: argparse's `type=bool` converts ANY non-empty string
        # (including "False") to True; parse textual booleans explicitly so
        # e.g. `--log_per_rel False` actually disables the flag.
        if isinstance(value, bool):
            return value
        if value.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if value.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected, got {!r}".format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", "-d", default="tkgl-wikidata", type=str)
    parser.add_argument("--rule_lengths", "-l", default="1", type=int, nargs="+")
    parser.add_argument("--num_walks", "-n", default="100", type=int)
    parser.add_argument("--transition_distr", default="exp", type=str)
    parser.add_argument("--window", "-w", default=0, type=int) # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep
    parser.add_argument("--top_k", default=20, type=int)
    parser.add_argument("--num_processes", "-p", default=1, type=int)
    parser.add_argument("--alpha", "-alpha", default=0.99, type=float) # fix alpha. used if trainflag == false
    # parser.add_argument("--train_flag", "-tr", default=True) # do we need training, ie selection of lambda and alpha
    parser.add_argument("--save_config", "-c", type=str2bool, default=True) # do we need to save the selection of lambda and alpha in config file?
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--run_nr', type=int, help='Run Number', default=1)
    parser.add_argument('--learn_rules_flag', type=str2bool, help='Do we want to learn the rules', default=True)
    parser.add_argument('--rule_filename', type=str, help='if rules not learned: where are they stored', default='0_r[3]_n100_exp_s1_rules.json')
    parser.add_argument('--log_per_rel', type=str2bool, help='Do we want to log mrr per relation', default=False)
    parser.add_argument('--compute_valid_mrr', type=str2bool, help='Do we want to compute mrr for valid set', default=True)
    parsed = vars(parser.parse_args())
    return parsed
# =====================================================================
# Script entry point: learn TLogic rules on the train split, then apply
# them to score validation and test queries against TGB negative samples.
# =====================================================================
start_o = timeit.default_timer()

## get args
parsed = get_args()
dataset = parsed["dataset"]
rule_lengths = parsed["rule_lengths"]
# argparse may hand back a single int for the nargs="+" default; normalise to a list
rule_lengths = [rule_lengths] if (type(rule_lengths) == int) else rule_lengths
num_walks = parsed["num_walks"]
transition_distr = parsed["transition_distr"]
num_processes = parsed["num_processes"]
window = parsed["window"]
top_k = parsed["top_k"]
log_per_rel = parsed['log_per_rel']
MODEL_NAME = 'TLogic'
SEED = parsed['seed'] # set the random seed for consistency
set_random_seed(SEED)

## load dataset and prepare it accordingly
name = parsed["dataset"]
compute_valid_mrr = parsed["compute_valid_mrr"]
dataset = LinkPropPredDataset(name=name, root="datasets", preprocess=True)
DATA = name
relations = dataset.edge_type
num_rels = dataset.num_rels
subjects = dataset.full_data["sources"]
objects= dataset.full_data["destinations"]
num_nodes = dataset.num_nodes
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1
# rows: (subject, relation, object, remapped timestamp, original timestamp)
all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
train_data = all_quads[dataset.train_mask,0:4] # we do not need the original timestamps
# val/test keep the original timestamps (column 4) for the negative sampler
val_data = all_quads[dataset.val_mask,0:5]
test_data = all_quads[dataset.test_mask,0:5]
all_data = all_quads[:,0:4]
metric = dataset.eval_metric
evaluator = Evaluator(name=name)
neg_sampler = dataset.negative_sampler
inv_relation_id = get_inv_relation_id(num_rels)
#load the ns samples
dataset.load_val_ns()
dataset.load_test_ns()
output_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
learn_rules_flag = parsed['learn_rules_flag']

## 1. learn rules
start_train = timeit.default_timer()
if learn_rules_flag:
    print("start learning rules")
    # edges (dict): edges for each relation
    # inv_relation_id (dict): mapping of relation to inverse relation
    temporal_walk = Temporal_Walk(train_data, inv_relation_id, transition_distr)
    rl = Rule_Learner(edges=temporal_walk.edges, id2relation=None, inv_relation_id=inv_relation_id,
                      output_dir=output_dir)
    all_relations = sorted(temporal_walk.edges) # Learn for all relations
    start = timeit.default_timer()
    # rule learning is split across processes by relation
    num_relations = len(all_relations) // num_processes
    output = Parallel(n_jobs=num_processes)(
        delayed(learn_rules)(i, num_relations) for i in range(num_processes)
    )
    end = timeit.default_timer()
    # merge the per-process rule dictionaries into a single dict
    all_rules = output[0]
    for i in range(1, num_processes):
        all_rules.update(output[i])
    total_time = round(end - start, 6)
    print("Learning finished in {} seconds.".format(total_time))
    rl.rules_dict = all_rules
    rl.sort_rules_dict()
    # persist the learned rules so they can be reloaded with --rule_filename
    rule_filename = rl.save_rules(0, rule_lengths, num_walks, transition_distr, SEED)
    # rl.save_rules_verbalized(0, rule_lengths, num_walks, transition_distr, seed)
    # rules_statistics(rl.rules_dict)
else:
    rule_filename = parsed['rule_filename']
    print("Loading rules from file {}".format(parsed['rule_filename']))
end_train = timeit.default_timer()

## 2. Apply rules
rules_dict = json.load(open(output_dir + rule_filename))
# JSON keys are strings; convert back to integer relation ids
rules_dict = {int(k): v for k, v in rules_dict.items()}
rules_dict = ra.filter_rules(
    rules_dict, min_conf=0.01, min_body_supp=2, rule_lengths=rule_lengths
) # filter rules for minimum confidence, body support and rule length
learn_edges = store_edges(train_data)
score_func = ra.score_12
# It is possible to specify a list of list of arguments for tuning
args = [[0.1, 0.5]]

# compute valid mrr
start_valid = timeit.default_timer()
if compute_valid_mrr:
    print('Computing valid MRR')
    num_queries = len(val_data) // num_processes
    output = Parallel(n_jobs=num_processes)(
        delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, val_data, window, learn_edges,
                             all_quads, args, split_mode='val') for i in range(num_processes))
    end = timeit.default_timer()
    # concatenate the per-process score lists
    perf_list_val = []
    hits_list_val = []
    for i in range(num_processes):
        perf_list_val.extend(output[i][0])
        hits_list_val.extend(output[i][1])
else:
    perf_list_val = [0]
    hits_list_val = [0]
end_valid = timeit.default_timer()

# compute test mrr
if log_per_rel ==True:
    num_processes = 1 #otherwise logging per rel does not work for our implementation
start_test = timeit.default_timer()
print('Computing test MRR')
start = timeit.default_timer()
num_queries = len(test_data) // num_processes
output = Parallel(n_jobs=num_processes)(
    delayed(apply_rules)(i, num_queries,rules_dict, neg_sampler, test_data, window, learn_edges,
                         all_quads, args, split_mode='test', log_per_rel=log_per_rel, num_rels=num_rels) for i in range(num_processes))
end = timeit.default_timer()
perf_list_all = []
hits_list_all = []
for i in range(num_processes):
    perf_list_all.extend(output[i][0])
    hits_list_all.extend(output[i][1])
if log_per_rel == True:
    # single process in this mode, so output[0] holds the full per-rel dict
    perf_per_rel = output[0][2]

total_time = round(end - start, 6)
total_valid_time = round(end_valid - start_valid, 6)
print("Application finished in {} seconds.".format(total_time))
print(f"The valid MRR is {np.mean(perf_list_val)}")
print(f"The MRR is {np.mean(perf_list_all)}")
print(f"The Hits@10 is {np.mean(hits_list_all)}")
print(f"We have {len(perf_list_all)} predictions")
print(f"The test set has len {len(test_data)} ")

end_o = timeit.default_timer()
train_time_o = round(end_train- start_train, 6)
test_time_o = round(end_o- start_test, 6)
total_time_o = round(end_o- start_o, 6)
print("Running Training to find best configs finished in {} seconds.".format(train_time_o))
print("Running testing with best configs finished in {} seconds.".format(test_time_o))
print("Running all steps finished in {} seconds.".format(total_time_o))

## 3. save the results
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
if log_per_rel == True:
    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'
    with open(results_filename, 'w') as json_file:
        json.dump(perf_per_rel, json_file)
results_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'
metric = dataset.eval_metric
save_results({'model': MODEL_NAME,
              'train_flag': None,
              'rule_len': rule_lengths,
              'window': window,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: float(np.mean(perf_list_all)),
              'hits10': float(np.mean(hits_list_all)),
              'val_mrr': float(np.mean(perf_list_val)),
              'test_time': test_time_o,
              'tot_train_val_time': total_time_o,
              'valid_time': total_valid_time
              },
             results_filename)
================================================
FILE: examples/linkproppred/tkgl-yago/cen.py
================================================
"""
Complex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning
Reference:
- https://github.com/Lee-zix/CEN
Zixuan Li, Saiping Guan, Xiaolong Jin, Weihua Peng, Yajuan Lyu , Yong Zhu, Long Bai, Wei Li, Jiafeng Guo, Xueqi Cheng.
Complex Evolutional Pattern Learning for Temporal Knowledge Graph Reasoning. ACL 2022.
"""
import timeit
import os
import sys
import os.path as osp
from pathlib import Path
import numpy as np
import torch
import random
from tqdm import tqdm
import json
# internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.rrgcn import RecurrentRGCNCEN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from modules.tkg_utils import get_args_cen, reformat_ts
from modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
def test(model, history_len, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, mode, split_mode):
    """
    Test the model
    :param model: model used to test
    :param history_len: number of most recent history snapshots used to seed the input window
    :param history_list: all input history snap shot list, not include output label train list or valid list
    :param test_list: test triple snap shot list
    :param num_rels: number of relations
    :param num_nodes: number of nodes
    :param use_cuda: run on GPU args.gpu if True, otherwise CPU
    :param model_name: checkpoint file path (loaded when mode == "test")
    :param mode: "test" loads the stored best checkpoint before evaluating
    :param split_mode: 'test' or 'val' to state which negative samples to load
    :return: (mrr, perf_per_rel, hits10)

    NOTE(review): relies on module-level globals: args, evaluator, neg_sampler,
    METRIC, test_timestamps_orig, val_timestamps_orig.
    """
    print("Testing for mode: ", split_mode)
    # pick the original (non-remapped) timestamps of the evaluated split;
    # the negative sampler is keyed on these
    if split_mode == 'test':
        timesteps_to_eval = test_timestamps_orig
    else:
        timesteps_to_eval = val_timestamps_orig
    idx = 0
    if mode == "test":
        # test mode: load parameters from file
        if use_cuda:
            checkpoint = torch.load(model_name, map_location=torch.device(args.gpu))
        else:
            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))
        # use best stat checkpoint:
        print("Load Model name: {}. Using best epoch : {}".format(model_name, checkpoint['epoch']))
        print("\n"+"-"*10+"start testing"+"-"*10+"\n")
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    # seed the rolling history window with the most recent history snapshots
    input_list = [snap for snap in history_list[-history_len:]]
    perf_list_all = []
    hits_list_all = []
    perf_per_rel = {}
    for rel in range(num_rels):
        perf_per_rel[rel] = []
    for time_idx, test_snap in enumerate(tqdm(test_list)):
        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]
        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)
        # all triples of one snapshot share the same original timestamp
        timesteps_batch =timesteps_to_eval[time_idx]*np.ones(len(test_triples_input[:,0]))
        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:,0], test_triples_input[:,2],
                                                    timesteps_batch, edge_type=test_triples_input[:,1], split_mode=split_mode)
        pos_samples_batch = test_triples_input[:,2]
        _, perf_list, hits_list = model.predict(history_glist, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch,
                                                evaluator, METRIC)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
        if split_mode == "test":
            if args.log_per_rel:
                # collect per-sample scores grouped by relation id
                for score, rel in zip(perf_list, test_triples_input[:,1].tolist()):
                    perf_per_rel[rel].append(score)
        # reconstruct history graph list: slide the window forward by one snapshot
        input_list.pop(0)
        input_list.append(test_snap)
        idx += 1  # NOTE(review): idx is incremented but never read
    if split_mode == "test":
        if args.log_per_rel:
            # average per relation; drop relations absent from this split
            for rel in range(num_rels):
                if len(perf_per_rel[rel]) > 0:
                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))
                else:
                    perf_per_rel.pop(rel)
    mrr = np.mean(perf_list_all)
    hits10 = np.mean(hits_list_all)
    return mrr, perf_per_rel, hits10
def run_experiment(args, trainvalidtest_id=0, n_hidden=None, n_layers=None, dropout=None, n_bases=None):
    '''
    Run experiment for CEN model.

    NOTE(review): relies on module-level globals set up by the surrounding script
    (train_list, valid_list, test_list, num_rels, num_nodes, MODEL_NAME, DATA,
    SEED, test, RecurrentRGCNCEN, build_sub_graph, ...).

    :param args: arguments for the model
    :param trainvalidtest_id: -1: pretraining, 0: curriculum training (to find best test history len),
                              1: test on valid set, 2: test on test set
    :param n_hidden: optional override for args.n_hidden (number of hidden units)
    :param n_layers: optional override for args.n_layers (number of layers)
    :param dropout: optional override for args.dropout (dropout rate)
    :param n_bases: optional override for args.n_bases (number of bases)
    :return: (mrr, test_history_len, perf_per_rel, hits10); perf_per_rel is only
             populated by the test-split path inside test(), otherwise stays {}
    '''
    # 1) load configuration for grid search the best configuration
    if n_hidden:
        args.n_hidden = n_hidden
    if n_layers:
        args.n_layers = n_layers
    if dropout:
        args.dropout = dropout
    if n_bases:
        args.n_bases = n_bases
    test_history_len = args.test_history_len
    # 2) set save model path
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    if not osp.exists(save_model_dir):
        os.mkdir(save_model_dir)
        print('INFO: Create directory {}'.format(save_model_dir))
    # checkpoint name encodes model, dataset, seed, run number and history length
    test_model_name = f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.test_history_len}'
    test_state_file = save_model_dir + test_model_name
    perf_per_rel = {}
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    # create model (recurrent R-GCN as used by CEN)
    model = RecurrentRGCNCEN(args.decoder,
                             args.encoder,
                             num_nodes,
                             num_rels,
                             args.n_hidden,
                             args.opn,
                             sequence_len=args.train_history_len,
                             num_bases=args.n_bases,
                             num_basis=args.n_basis,
                             num_hidden_layers=args.n_layers,
                             dropout=args.dropout,
                             self_loop=args.self_loop,
                             skip_connect=args.skip_connect,
                             layer_norm=args.layer_norm,
                             input_dropout=args.input_dropout,
                             hidden_dropout=args.hidden_dropout,
                             feat_dropout=args.feat_dropout,
                             entity_prediction=args.entity_prediction,
                             relation_prediction=args.relation_prediction,
                             use_cuda=use_cuda,
                             gpu=args.gpu)
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        model.cuda()
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    if trainvalidtest_id == 1:  # normal test on validation set. Note that mode="test" (loads the checkpoint)
        if os.path.exists(test_state_file):
            mrr, _, hits10 = test(model, args.test_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
                                  test_state_file, "test", split_mode="val")
        else:
            print('Cannot do testing because model does not exist: ', test_state_file)
            mrr = 0
            hits10 = 0
    elif trainvalidtest_id == 2:  # normal test on test set (history includes train+valid snapshots)
        if os.path.exists(test_state_file):
            mrr, perf_per_rel, hits10 = test(model, args.test_history_len, train_list + valid_list, test_list, num_rels, num_nodes, use_cuda,
                                             test_state_file, "test", split_mode="test")
        else:
            print('Cannot do testing because model does not exist: ', test_state_file)
            mrr = 0
            hits10 = 0
    elif trainvalidtest_id == -1:
        # pretraining with the shortest history length (args.start_history_len)
        print("-------------start pre training model with history length {}----------\n".format(args.start_history_len))
        model_name = f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'
        model_state_file = save_model_dir + model_name
        print("Sanity Check: stat name : {}".format(model_state_file))
        print("Sanity Check: Is cuda available ? {}".format(torch.cuda.is_available()))
        best_mrr = 0
        best_epoch = 0
        best_hits10 = 0
        ## training loop
        for epoch in range(args.n_epochs):
            model.train()
            losses = []
            # iterate snapshots in a random order each epoch
            idx = [_ for _ in range(len(train_list))]
            random.shuffle(idx)
            for train_sample_num in idx:
                # need at least two preceding snapshots to form a history
                if train_sample_num == 0 or train_sample_num == 1: continue
                if train_sample_num - args.start_history_len < 0:
                    # not enough history yet: use everything up to the current snapshot
                    input_list = train_list[0: train_sample_num]
                    output = train_list[1:train_sample_num + 1]
                else:
                    input_list = train_list[train_sample_num - args.start_history_len: train_sample_num]
                    output = train_list[train_sample_num - args.start_history_len + 1:train_sample_num + 1]
                # generate history graph
                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]
                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]
                # loss is computed on the last snapshot of the window only
                loss = model.get_loss(history_glist, output[-1], None, use_cuda)
                losses.append(loss.item())
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients
                optimizer.step()
                optimizer.zero_grad()
            print("His {:04d}, Epoch {:04d} | Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} "
                  .format(args.start_history_len, epoch, np.mean(losses), best_mrr, model_name))
            #! checking GPU usage
            free_mem, total_mem = torch.cuda.mem_get_info()
            print("--------------GPU memory usage-----------")
            print("there are ", free_mem, " free memory")
            print("there are ", total_mem, " total available memory")
            print("there are ", total_mem - free_mem, " used memory")
            print("--------------GPU memory usage-----------")
            # validation
            if epoch % args.evaluate_every == 0:
                mrr, _, hits10 = test(model, args.start_history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
                                      model_state_file, mode="train", split_mode="val")
                if mrr < best_mrr:
                    # patience-based early stopping (no improvement for > 5 epochs)
                    # NOTE(review): a worse epoch inside the patience window neither
                    # breaks nor saves — the checkpoint only updates on improvement
                    if epoch >= args.n_epochs or epoch - best_epoch > 5:
                        break
                else:
                    best_mrr = mrr
                    best_epoch = epoch
                    best_hits10 = hits10
                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
    elif trainvalidtest_id == 0:  # curriculum training: grow the history length one step at a time
        model_name = f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{args.start_history_len}'
        init_state_file = save_model_dir + model_name
        init_checkpoint = torch.load(init_state_file, map_location=torch.device(args.gpu))
        # use best stat checkpoint:
        print("Load Previous Model name: {}. Using best epoch : {}".format(init_state_file, init_checkpoint['epoch']))
        print("\n" + "-" * 10 + "Load model with history length {}".format(args.start_history_len) + "-" * 10 + "\n")
        model.load_state_dict(init_checkpoint['state_dict'])
        test_history_len = args.start_history_len
        # baseline validation score of the pretrained model
        mrr, _, hits10 = test(model,
                              args.start_history_len,
                              train_list,
                              valid_list,
                              num_rels,
                              num_nodes,
                              use_cuda,
                              init_state_file,
                              mode="test", split_mode="val")
        best_mrr_list = [mrr.item()]
        best_hits_list = [hits10.item()]
        # start knowledge distillation
        ks_idx = 0
        for history_len in range(args.start_history_len + 1, args.train_history_len + 1, 1):
            # current model
            print("best mrr list :", best_mrr_list)
            # lr = 0.1*args.lr - 0.002*args.lr*ks_idx
            # fine-tune with a reduced learning rate (0.1x) at each curriculum step
            optimizer = torch.optim.Adam(model.parameters(), lr=0.1 * args.lr, weight_decay=0.00001)
            model_name = f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len}'
            model_state_file = save_model_dir + model_name
            print("Sanity Check: stat name : {}".format(model_state_file))
            # load model with the least history length (previous curriculum step)
            prev_model_name = f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}_{history_len-1}'
            prev_state_file = save_model_dir + prev_model_name
            checkpoint = torch.load(prev_state_file, map_location=torch.device(args.gpu))
            model.load_state_dict(checkpoint['state_dict'])
            print("\n" + "-" * 10 + "start knowledge distillation for history length at " + str(history_len) + "-" * 10 + "\n")
            best_mrr = 0
            best_hits10 = 0
            best_epoch = 0
            for epoch in range(args.n_epochs):
                model.train()
                losses = []
                idx = [_ for _ in range(len(train_list))]
                random.shuffle(idx)
                for train_sample_num in idx:
                    if train_sample_num == 0 or train_sample_num == 1: continue
                    if train_sample_num - history_len < 0:
                        input_list = train_list[0: train_sample_num]
                        output = train_list[1:train_sample_num + 1]
                    else:
                        input_list = train_list[train_sample_num - history_len: train_sample_num]
                        output = train_list[train_sample_num - history_len + 1:train_sample_num + 1]
                    # generate history graph
                    history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]
                    output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]
                    loss = model.get_loss(history_glist, output[-1], None, use_cuda)
                    # print(loss)
                    losses.append(loss.item())
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients
                    optimizer.step()
                    optimizer.zero_grad()
                print("His {:04d}, Epoch {:04d} | Ave Loss: {:.4f} |Best MRR {:.4f} | Model {} "
                      .format(history_len, epoch, np.mean(losses), best_mrr, model_name))
                #! checking GPU usage
                free_mem, total_mem = torch.cuda.mem_get_info()
                print("--------------GPU memory usage-----------")
                print("there are ", free_mem, " free memory")
                print("there are ", total_mem, " total available memory")
                print("there are ", total_mem - free_mem, " used memory")
                print("--------------GPU memory usage-----------")
                # validation
                if epoch % args.evaluate_every == 0:
                    mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
                                          model_state_file, mode="train", split_mode="val")
                    if mrr < best_mrr:
                        # tighter patience (> 2 epochs) than the pretraining stage
                        if epoch >= args.n_epochs or epoch - best_epoch > 2:
                            break
                    else:
                        best_mrr = mrr
                        best_epoch = epoch
                        best_hits10 = hits10
                        torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
            # re-evaluate the best checkpoint of this curriculum step on the valid set
            mrr, _, hits10 = test(model, history_len, train_list, valid_list, num_rels, num_nodes, use_cuda,
                                  model_state_file, mode="test", split_mode="val")
            ks_idx += 1
            if mrr.item() < max(best_mrr_list):
                # longer history did not help: keep the previous length and stop
                test_history_len = history_len - 1
                print("early stopping, best history length: ", test_history_len)
                break
            else:
                best_mrr_list.append(mrr.item())
                best_hits_list.append(hits10.item())
    return mrr, test_history_len, perf_per_rel, hits10
# ==================
# ==================
# ==================
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args_cen()
args.dataset = 'tkgl-yago'
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
DATA = args.dataset
MODEL_NAME = 'CEN'
print("logging mrrs per relation: ", args.log_per_rel)
print("do test and valid? do only test no validation?: ", args.validtest, args.test_only)

# load data
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
num_rels = dataset.num_rels
num_nodes = dataset.num_nodes
subjects = dataset.full_data["sources"]
objects = dataset.full_data["destinations"]
relations = dataset.edge_type
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA)  # stepsize:1
all_quads = np.stack((subjects, relations, objects, timestamps), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]
# original (un-reformatted) timestamps per split: needed for querying the negative sampler
val_timestamps_orig = list(set(timestamps_orig[dataset.val_mask]))
val_timestamps_orig.sort()
test_timestamps_orig = list(set(timestamps_orig[dataset.test_mask]))
test_timestamps_orig.sort()
# one snapshot (list entry) per timestep
train_list = split_by_time(train_data)
valid_list = split_by_time(val_data)
test_list = split_by_time(test_data)

# evaluation metric
METRIC = dataset.eval_metric
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
# load the negative samples
dataset.load_val_ns()
dataset.load_test_ns()

if args.grid_search:
    print("TODO: implement hyperparameter grid search")
# single run
else:
    start_train = timeit.default_timer()
    if args.validtest:
        print('directly start testing')
        if args.test_history_len_2 != args.test_history_len:
            args.test_history_len = args.test_history_len_2  # hyperparameter value as given in original paper
    else:
        print('running pretrain and train')
        # pretrain
        mrr, _, _, hits10 = run_experiment(args, trainvalidtest_id=-1)
        # curriculum training: overwrite test_history_len with the best history len (for valid mrr)
        mrr, args.test_history_len, _, hits10 = run_experiment(args, trainvalidtest_id=0)
    # FIX: idiomatic truthiness test instead of "== False"
    if not args.test_only:
        print("running test (on val and test dataset) with test_history_len of: ", args.test_history_len)
        # test on val set
        val_mrr, _, _, val_hits10 = run_experiment(args, trainvalidtest_id=1)
    else:
        val_mrr = 0
        val_hits10 = 0
    # test on test set
    start_test = timeit.default_timer()
    test_mrr, _, perf_per_rel, test_hits10 = run_experiment(args, trainvalidtest_id=2)
    test_time = timeit.default_timer() - start_test
    all_time = timeit.default_timer() - start_train
    print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
    # FIX: "\T" is not a valid escape sequence (printed literally); use "\t" for a tab
    print(f"\tTrain and Test: Elapsed Time (s): {all_time: .4f}")
    print(f"\tTest: {METRIC}: {test_mrr: .4f}")
    print(f"\tValid: {METRIC}: {val_mrr: .4f}")

    # saving the results...
    results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
    if not osp.exists(results_path):
        os.mkdir(results_path)
        print('INFO: Create directory {}'.format(results_path))
    Path(results_path).mkdir(parents=True, exist_ok=True)
    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'run': args.run_nr,
                  'seed': SEED,
                  'test_history_len': args.test_history_len,
                  f'val {METRIC}': float(val_mrr),
                  f'test {METRIC}': float(test_mrr),
                  'test_time': test_time,
                  'tot_train_val_time': all_time,
                  'test_hits10': float(test_hits10)
                  },
                 results_filename)
    if args.log_per_rel:
        results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'
        with open(results_filename, 'w') as json_file:
            json.dump(perf_per_rel, json_file)
sys.exit()
================================================
FILE: examples/linkproppred/tkgl-yago/edgebank.py
================================================
"""
Dynamic Link Prediction with EdgeBank
NOTE: This implementation works only based on `numpy`
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import timeit
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
tgb_modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(tgb_modules_path)
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import save_results
# ==================
# ==================
# ==================
def test(data, test_mask, neg_sampler, split_mode):
    r"""
    Evaluate dynamic link prediction with EdgeBank.

    Evaluation is 'one vs. many': each positive edge is ranked against its
    pre-sampled negative destinations. After every batch the EdgeBank memory
    is updated with the positive edges of that batch.

    Parameters:
        data: a dataset object (dict of numpy arrays)
        test_mask: boolean mask selecting the split's edges
        neg_sampler: provides the negative edges for each positive edge
        split_mode: 'val' or 'test'; selects which negatives to load
    Returns:
        (mean of the eval metric, mean hits@10) over all queries
    """
    # hoist the masked arrays once instead of re-masking per batch
    src_all = data['sources'][test_mask]
    dst_all = data['destinations'][test_mask]
    ts_all = data['timestamps'][test_mask]
    etype_all = data['edge_type'][test_mask]
    total = len(src_all)

    mrr_scores = []
    hits_scores = []
    for batch_no in tqdm(range(math.ceil(total / BATCH_SIZE))):
        lo = batch_no * BATCH_SIZE
        hi = min(lo + BATCH_SIZE, total)
        pos_src = src_all[lo:hi]
        pos_dst = dst_all[lo:hi]
        pos_t = ts_all[lo:hi]
        pos_edge = etype_all[lo:hi]

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, pos_edge, split_mode=split_mode)
        for i, neg_batch in enumerate(neg_batch_list):
            # one query per candidate: the true destination first, then the negatives
            query_src = np.full(len(neg_batch) + 1, int(pos_src[i]))
            query_dst = np.concatenate([np.array([int(pos_dst[i])]), neg_batch])
            y_pred = edgebank.predict_link(query_src, query_dst)
            results = evaluator.eval({
                "y_pred_pos": np.array([y_pred[0]]),
                "y_pred_neg": np.array(y_pred[1:]),
                "eval_metric": [metric],
            })
            mrr_scores.append(results[metric])
            hits_scores.append(results['hits@10'])
        # update edgebank memory after each positive batch
        edgebank.update_memory(pos_src, pos_dst, pos_t)

    return float(np.mean(mrr_scores)), float(np.mean(hits_scores))
def get_args():
    """Parse command-line arguments for the EdgeBank baseline.

    Returns:
        (args, argv): the parsed argparse Namespace and the raw sys.argv list.
    On invalid arguments, prints the parser help and exits with status 0.
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tkgl-yago')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit:
        # FIX: narrowed from a bare `except:`; argparse raises SystemExit on
        # invalid input — catching everything hid unrelated errors.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================
start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
BATCH_SIZE = args.bs
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
DATA = args.data
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

# data for memory in edgebank: initialised with the training edges only
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{MEMORY_MODE}_{DATA}_results.json'

# ==================================================== Validation
# loading the validation negative samples
dataset.load_val_ns()

# evaluating on the validation split ...
start_val = timeit.default_timer()
perf_metric_val, perf_hits_val = test(data, val_mask, neg_sampler, split_mode='val')
end_val = timeit.default_timer()
print(f"INFO: val: Evaluation Setting: >>> ONE-VS--ALL <<< ")
print(f"\tval: {metric}: {perf_metric_val: .4f}")
val_time = timeit.default_timer() - start_val
print(f"\tval: Elapsed Time (s): {val_time: .4f}")

# ==================================================== Test
# loading the test negative samples
dataset.load_test_ns()

# testing ...
start_test = timeit.default_timer()
perf_metric_test, perf_hits_test = test(data, test_mask, neg_sampler, split_mode='test')
end_test = timeit.default_timer()
print(f"INFO: Test: Evaluation Setting: >>> <<< ")
print(f"\tTest: {metric}: {perf_metric_test: .4f}")
test_time = timeit.default_timer() - start_test
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")

# persist run summary as JSON
save_results({'model': MODEL_NAME,
              'memory_mode': MEMORY_MODE,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: perf_metric_test,
              'val_mrr': perf_metric_val,
              'test_time': test_time,
              'tot_train_val_time': test_time + val_time,
              'hits10': perf_hits_test},
             results_filename)
================================================
FILE: examples/linkproppred/tkgl-yago/recurrencybaseline.py
================================================
""" from paper: "History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting"
Julia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)
@inproceedings{gastinger2024baselines,
title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},
author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},
booktitle={33nd International Joint Conference on Artificial Intelligence (IJCAI 2024)},
year={2024},
organization={International Joint Conferences on Artificial Intelligence Organization}
}
"""
## imports
import timeit
import argparse
import numpy as np
from copy import copy
from pathlib import Path
import ray
import sys
import os
import os.path as osp
import json
#internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.recurrencybaseline_predictor import baseline_predict, baseline_predict_remote
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import set_random_seed, save_results
from modules.tkg_utils import create_basis_dict, group_by, reformat_ts
def predict(num_processes, data_c_rel, all_data_c_rel, alpha, lmbda_psi,
            perf_list_all, hits_list_all, window, neg_sampler, split_mode):
    """ Create predictions for one relation on the test or valid set and compute mrr.

    :param num_processes: number of parallel ray workers (1 = run in-process)
    :param data_c_rel: quadruples of the current relation in the evaluated split
    :param all_data_c_rel: quadruples of the current relation across all splits (history)
    :param alpha: interpolation weight for the baseline scores
    :param lmbda_psi: decay parameter of the recency scoring function
    :param perf_list_all: list the per-query MRRs are appended to (also returned)
    :param hits_list_all: list the per-query hits are appended to (also returned)
    :param window: history window size passed to the baseline predictor
    :param neg_sampler: negative sampler for the one-vs-many evaluation
    :param split_mode: 'val' or 'test'; selects which negative samples are used
    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    # column 3 holds the (reformatted) timestamp of the first query for this relation
    first_ts = data_c_rel[0][3]
    ## use this if you wanna use ray:
    num_queries = len(data_c_rel) // num_processes
    if num_queries < num_processes:  # if we do not have enough queries for all the processes
        num_processes_tmp = 1
        num_queries = len(data_c_rel)
    else:
        num_processes_tmp = num_processes
    if num_processes > 1:
        object_references = []
        for i in range(num_processes_tmp):
            num_test_queries = len(data_c_rel) - (i + 1) * num_queries
            if num_test_queries >= num_queries:
                test_queries_idx = [i * num_queries, (i + 1) * num_queries]
            else:
                # BUG FIX: the last chunk must end at the length of THIS relation's
                # data. The original used len(test_data) — a module-level global —
                # which is wrong whenever this function is called on validation
                # data (or any subset smaller than the full test set).
                test_queries_idx = [i * num_queries, len(data_c_rel)]
            valid_data_b = data_c_rel[test_queries_idx[0]:test_queries_idx[1]]
            ob = baseline_predict_remote.remote(num_queries, valid_data_b, all_data_c_rel, window,
                                                basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
            object_references.append(ob)
        output = ray.get(object_references)
        # collect the scores from each worker
        for proc_loop in range(num_processes_tmp):
            perf_list_all.extend(output[proc_loop][0])
            hits_list_all.extend(output[proc_loop][1])
    else:
        # single-process path: evaluate all queries locally
        perf_list, hits_list = baseline_predict(len(data_c_rel), data_c_rel, all_data_c_rel,
                                                window, basis_dict,
                                                num_nodes, num_rels, lmbda_psi,
                                                alpha, evaluator, first_ts, neg_sampler, split_mode)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
    return perf_list_all, hits_list_all
## test
def test(best_config, all_relations, test_data_prel, all_data_prel, neg_sampler, num_processes, window, split_mode='test'):
    """ create predictions by looping through all relations on test or valid set and compute mrr

    Also appends one CSV line per relation ("rel,[scores...]") to a per-run file
    under perrel_results_path (module-level global).

    :return perf_list_all: list of mrrs for each test query
    :return hits_list_all: list of hits for each test query
    """
    perf_list_all = []
    hits_list_all = []
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}' + split_mode + '.csv'
    ## loop through relations and apply baselines
    for rel in all_relations:
        start = timeit.default_timer()
        if rel in test_data_prel.keys():
            # per-relation hyperparameters chosen during training (or the presets)
            lmbda_psi = best_config[str(rel)]['lmbda_psi'][0]
            alpha = best_config[str(rel)]['alpha'][0]
            # test data for this relation
            test_data_c_rel = test_data_prel[rel]
            # NOTE(review): timesteps_test is computed but never used below
            timesteps_test = list(set(test_data_c_rel[:, 3]))
            timesteps_test.sort()
            all_data_c_rel = all_data_prel[rel]
            perf_list_rel = []
            hits_list_rel = []
            perf_list_rel, hits_list_rel = predict(num_processes, test_data_c_rel,
                                                   all_data_c_rel, alpha, lmbda_psi, perf_list_rel, hits_list_rel,
                                                   window, neg_sampler, split_mode)
            perf_list_all.extend(perf_list_rel)
            hits_list_all.extend(hits_list_rel)
        else:
            # relation has no queries in this split: log an empty score list
            perf_list_rel = []
        end = timeit.default_timer()
        total_time = round(end - start, 6)
        print("Relation {} finished in {} seconds.".format(rel, total_time))
        with open(csv_file, 'a') as f:
            f.write("{},{}\n".format(rel, perf_list_rel))
    return perf_list_all, hits_list_all
def read_dict_compute_mrr(split_mode='test'):
    """ Read the per-relation results from a previously created CSV file and report MRR.

    Prints the overall mean MRR for the split and the per-relation means;
    warns on duplicate relation entries or a relation count mismatch.
    """
    csv_file = f'{perrel_results_path}/{MODEL_NAME}_NONE_{DATA}_results_{SEED}' + split_mode + '.csv'
    # per-relation raw scores, per-relation mean, and the flat list of all scores
    results_per_rel_dict = {}
    mrr_per_rel = {}
    all_mrrs = []
    with open(csv_file, 'r') as f:
        for line in f:
            # each line is "rel,[score,score,...]" — split and strip the brackets
            fields = line.strip().split(',')
            rel_id = int(fields[0])
            scores = [float(field.strip('[]')) for field in fields[1:]]
            if rel_id in results_per_rel_dict:
                print(f"Key {rel_id} already exists in the dictionary!!! might have duplicate entries in results csv")
            results_per_rel_dict[rel_id] = scores
            all_mrrs.extend(scores)
            mrr_per_rel[rel_id] = np.mean(scores)
    if len(results_per_rel_dict) != num_rels:
        print("we do not have entries for each rel in the results csv file. only num enties: ", len(list(results_per_rel_dict.keys())))
    print("Split mode: " + split_mode + " Mean MRR: ", np.mean(all_mrrs))
    print("mrr per relation: ", mrr_per_rel)
## train
def train(params_dict, rels, val_data_prel, trainval_data_prel, neg_sampler, num_processes, window):
    """ optional, find best values for lambda and alpha by looping through all relations and testing a fixed set of params
    based on validation mrr

    Two-stage search per relation: first pick lmbda_psi (with alpha fixed to 1),
    then pick alpha using the chosen lmbda_psi. Relations without validation
    queries keep the module-level defaults (default_lmbda_psi / default_alpha).

    :return best_config: dictionary of best params for each relation
    """
    best_config = {}
    best_mrr = 0
    for rel in rels:  # loop through relations. for each relation, apply rules with selected params, compute valid mrr
        start = timeit.default_timer()
        rel_key = int(rel)
        # initialise this relation's entry with defaults; overwritten if trainable
        best_config[str(rel_key)] = {}
        best_config[str(rel_key)]['not_trained'] = 'True'
        best_config[str(rel_key)]['lmbda_psi'] = [default_lmbda_psi, 0]  # default
        best_config[str(rel_key)]['other_lmbda_mrrs'] = list(np.zeros(len(params_dict['lmbda_psi'])))
        best_config[str(rel_key)]['alpha'] = [default_alpha, 0]  # default
        best_config[str(rel_key)]['other_alpha_mrrs'] = list(np.zeros(len(params_dict['alpha'])))
        if rel in val_data_prel.keys():
            # valid data for this relation
            val_data_c_rel = val_data_prel[rel]
            # NOTE(review): timesteps_valid is computed but never used below
            timesteps_valid = list(set(val_data_c_rel[:, 3]))
            timesteps_valid.sort()
            trainval_data_c_rel = trainval_data_prel[rel]
            ###### 1) select lambda ###############
            lmbdas_psi = params_dict['lmbda_psi']
            alpha = 1
            best_lmbda_psi = 0.1
            best_mrr_psi = 0
            lmbda_mrrs = []
            best_config[str(rel_key)]['num_app_valid'] = copy(len(val_data_c_rel))
            best_config[str(rel_key)]['num_app_train_valid'] = copy(len(trainval_data_c_rel))
            best_config[str(rel_key)]['not_trained'] = 'False'
            for lmbda_psi in lmbdas_psi:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                                                   trainval_data_c_rel, alpha, lmbda_psi, perf_list_r, hits_list_r,
                                                   window, neg_sampler, split_mode='val')
                # compute mrr
                mrr = np.mean(perf_list_r)
                # # is new mrr better than previous best? if yes: store lmbda
                if mrr > best_mrr_psi:
                    best_mrr_psi = float(mrr)
                    best_lmbda_psi = lmbda_psi
                lmbda_mrrs.append(float(mrr))
            best_config[str(rel_key)]['lmbda_psi'] = [best_lmbda_psi, best_mrr_psi]
            best_config[str(rel_key)]['other_lmbda_mrrs'] = lmbda_mrrs
            best_mrr = best_mrr_psi
            ##### 2) select alpha ###############
            best_config[str(rel_key)]['not_trained'] = 'False'
            alphas = params_dict['alpha']
            lmbda_psi = best_config[str(rel_key)]['lmbda_psi'][0]  # use the best lmbda psi
            alpha_mrrs = []
            # perf_list_all = []
            best_mrr_alpha = 0
            best_alpha = 0.99
            for alpha in alphas:
                perf_list_r = []
                hits_list_r = []
                perf_list_r, hits_list_r = predict(num_processes, val_data_c_rel,
                                                   trainval_data_c_rel, alpha, lmbda_psi, perf_list_r, hits_list_r,
                                                   window, neg_sampler, split_mode='val')
                # compute mrr
                mrr_alpha = np.mean(perf_list_r)
                # is new mrr better than previous best? if yes: store alpha
                if mrr_alpha > best_mrr_alpha:
                    best_mrr_alpha = float(mrr_alpha)
                    best_alpha = alpha
                    best_mrr = best_mrr_alpha
                alpha_mrrs.append(float(mrr_alpha))
            best_config[str(rel_key)]['alpha'] = [best_alpha, best_mrr_alpha]
            best_config[str(rel_key)]['other_alpha_mrrs'] = alpha_mrrs
        end = timeit.default_timer()
        total_time = round(end - start, 6)
        print("Relation {} finished in {} seconds.".format(rel, total_time))
    return best_config
## args
def get_args():
    """Parse command-line arguments for the recurrency baseline and return them as a dict."""
    arg_parser = argparse.ArgumentParser()
    # dataset to run on
    arg_parser.add_argument("--dataset", "-d", default="tkgl-yago", type=str)
    # history window: e.g. 200 = only the most recent 200 timesteps; -2 = multistep
    arg_parser.add_argument("--window", "-w", default=0, type=int)
    arg_parser.add_argument("--num_processes", "-p", default=1, type=int)
    # fixed lambda, used when train_flag == 'False'
    arg_parser.add_argument("--lmbda", "-l", default=0.1, type=float)
    # fixed alpha, used when train_flag == 'False'
    arg_parser.add_argument("--alpha", "-alpha", default=0.99, type=float)
    # 'True' enables training, i.e. the per-relation selection of lambda and alpha
    arg_parser.add_argument("--train_flag", "-tr", default='False')
    # when training: 'True' loads a previously saved best_config instead of searching
    arg_parser.add_argument("--load_flag", "-lo", default='False')
    # 'True' saves the selected lambda/alpha to a config file
    arg_parser.add_argument("--save_config", "-c", default='True')
    # random seed (kept for interface compatibility; not needed here)
    arg_parser.add_argument('--seed', type=int, help='Random seed', default=1)
    return vars(arg_parser.parse_args())
start_o = timeit.default_timer()

parsed = get_args()
# ray is only needed for multi-process prediction
if parsed['num_processes'] > 1:
    ray.init(num_cpus=parsed["num_processes"], num_gpus=0)
MODEL_NAME = 'RecurrencyBaseline'
SEED = parsed['seed']  # set the random seed for consistency
set_random_seed(SEED)
# directory for the per-relation result CSVs written by test()
perrel_results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_models'
if not osp.exists(perrel_results_path):
    os.mkdir(perrel_results_path)
    print('INFO: Create directory {}'.format(perrel_results_path))
Path(perrel_results_path).mkdir(parents=True, exist_ok=True)

## load dataset and prepare it accordingly
name = parsed["dataset"]
dataset = LinkPropPredDataset(name=name, root="datasets", preprocess=True)
DATA = name
relations = dataset.edge_type
num_rels = dataset.num_rels
rels = np.arange(0, num_rels)
subjects = dataset.full_data["sources"]
objects = dataset.full_data["destinations"]
num_nodes = dataset.num_nodes
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA)  # stepsize:1
print("split train valid test data")
# columns: subject, relation, object, reformatted ts, original ts
all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]
metric = dataset.eval_metric
evaluator = Evaluator(name=name)
neg_sampler = dataset.negative_sampler
train_val_data = np.concatenate([train_data, val_data])
all_data = np.concatenate([train_data, val_data, test_data])

# create dicts with key: relation id, values: triples for that relation id
print("grouping data by relation")
test_data_prel = group_by(test_data, 1)
all_data_prel = group_by(all_data, 1)
val_data_prel = group_by(val_data, 1)
trainval_data_prel = group_by(train_val_data, 1)

# load the ns samples
# if parsed['train_flag']:
print("loading negative samples")
dataset.load_val_ns()
dataset.load_test_ns()

# parameter options (the grid searched over when training is enabled)
if parsed['train_flag'] == 'True':
    params_dict = {}
    params_dict['lmbda_psi'] = [0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.5, 0.9, 1.0001]
    params_dict['alpha'] = [0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999, 0.99999, 1]
    # defaults used by train() for relations without validation queries
    default_lmbda_psi = params_dict['lmbda_psi'][-1]
    default_alpha = params_dict['alpha'][-2]

## load rules
print("creating rules")
basis_dict = create_basis_dict(train_val_data)
print("done with creating rules")
## init
# rb_predictor = RecurrencyBaselinePredictor(rels)
## train to find best lambda and alpha
start_train = timeit.default_timer()
if parsed['train_flag'] == 'True':
    if parsed['load_flag'] == 'True':
        # reuse previously found per-relation hyperparameters
        with open('best_config.json', 'r') as infile:
            best_config = json.load(infile)
    else:
        print('start training')
        best_config = train(params_dict, rels, val_data_prel, trainval_data_prel, neg_sampler, parsed['num_processes'],
                            parsed['window'])
        if parsed['save_config'] == 'True':
            import json  # NOTE(review): re-import shadows the top-level import; harmless
            with open('best_config.json', 'w') as outfile:
                json.dump(best_config, outfile)
else:  # use preset lmbda and alpha; same for all relations
    best_config = {}
    for rel in rels:
        best_config[str(rel)] = {}
        best_config[str(rel)]['lmbda_psi'] = [parsed['lmbda']]
        best_config[str(rel)]['alpha'] = [parsed['alpha']]
end_train = timeit.default_timer()

# compute validation mrr
print("Computing validation MRR")
perf_list_all_val, hits_list_all_val = test(best_config, rels, val_data_prel,
                                            trainval_data_prel, neg_sampler, parsed['num_processes'],
                                            parsed['window'], split_mode='val')
val_mrr = float(np.mean(perf_list_all_val))

# compute test mrr
print("Computing test MRR")
start_test = timeit.default_timer()
perf_list_all, hits_list_all = test(best_config, rels, test_data_prel,
                                    all_data_prel, neg_sampler, parsed['num_processes'],
                                    parsed['window'])
end_o = timeit.default_timer()

train_time_o = round(end_train - start_train, 6)
test_time_o = round(end_o - start_test, 6)
total_time_o = round(end_o - start_o, 6)
print("Running Training to find best configs finished in {} seconds.".format(train_time_o))
print("Running testing with best configs finished in {} seconds.".format(test_time_o))
print("Running all steps finished in {} seconds.".format(total_time_o))
print(f"The test MRR is {np.mean(perf_list_all)}")
print(f"The valid MRR is {val_mrr}")
print(f"The Hits@10 is {np.mean(hits_list_all)}")
print(f"We have {len(perf_list_all)} predictions")
print(f"The test set has len {len(test_data)} ")

# for saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'
metric = dataset.eval_metric
save_results({'model': MODEL_NAME,
'train_flag': parsed['train_flag'],
'data': DATA,
'run': 1,
'seed': SEED,
metric: float(np.mean(perf_list_all)),
'val_mrr': val_mrr,
'hits10': float(np.mean(hits_list_all)),
'test_time': test_time_o,
'tot_train_val_time': total_time_o
},
results_filename)
if parsed['num_processes']>1:
ray.shutdown()
================================================
FILE: examples/linkproppred/tkgl-yago/regcn.py
================================================
"""
Temporal Knowledge Graph Reasoning Based on Evolutional Representation Learning
Reference:
- https://github.com/Lee-zix/RE-GCN
Zixuan Li, Xiaolong Jin, Wei Li, Saiping Guan, Jiafeng Guo, Huawei Shen, Yuanzhuo Wang and Xueqi Cheng. Temporal
Knowledge Graph Reasoning Based on Evolutional Representation Learning. SIGIR 2021.
"""
import sys
import timeit
import os
import sys
import os.path as osp
from pathlib import Path
import numpy as np
import torch
import random
from tqdm import tqdm
# internal imports
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.rrgcn import RecurrentRGCNREGCN
from tgb.utils.utils import set_random_seed, split_by_time, save_results
from modules.tkg_utils import get_args_regcn, reformat_ts
from modules.tkg_utils_dgl import build_sub_graph
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
import json
def test(model, history_list, test_list, num_rels, num_nodes, use_cuda, model_name, static_graph, mode, split_mode):
    """
    Evaluate the model on either the validation or the test snapshot list.

    Relies on module-level globals: ``args``, ``neg_sampler``, ``evaluator``,
    ``METRIC``, ``val_timestamps_orig`` and ``test_timestamps_orig``.

    :param model: model used to test
    :param history_list: all input history snapshot list, not including the output label train list or valid list
    :param test_list: test triple snapshot list
    :param num_rels: number of relations
    :param num_nodes: number of nodes
    :param use_cuda: whether to evaluate on GPU ``args.gpu``
    :param model_name: checkpoint path loaded when ``mode == "test"``
    :param static_graph: static background graph forwarded to ``model.predict`` (None for this dataset)
    :param mode: "test" loads the stored best checkpoint first; otherwise the live model is evaluated
    :param split_mode: 'test' or 'val' to state which negative samples to load
    :return: tuple ``(mrr, perf_per_rel, hits10)`` where ``perf_per_rel`` maps
        relation id -> mean MRR (only filled for split_mode == "test" when
        ``args.log_per_rel`` is set, otherwise left as empty lists/dict)
    """
    print("Testing for mode: ", split_mode)
    # the negative sampler is keyed on the ORIGINAL (un-remapped) timestamps
    if split_mode == 'test':
        timesteps_to_eval = test_timestamps_orig
    else:
        timesteps_to_eval = val_timestamps_orig
    if mode == "test":
        # test mode: load parameters from file
        if use_cuda:
            checkpoint = torch.load(model_name, map_location=torch.device(args.gpu))
        else:
            checkpoint = torch.load(model_name, map_location=torch.device('cpu'))
        # use best stat checkpoint:
        print("Load Model name: {}. Using best epoch : {}".format(model_name, checkpoint['epoch']))
        print("\n"+"-"*10+"start testing"+"-"*10+"\n")
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    # seed the history with the most recent snapshots before the eval split
    input_list = [snap for snap in history_list[-args.test_history_len:]]
    perf_list_all = []
    hits_list_all = []
    perf_per_rel = {}
    for rel in range(num_rels):
        perf_per_rel[rel] = []
    for time_idx, test_snap in enumerate(tqdm(test_list)):
        history_glist = [build_sub_graph(num_nodes, num_rels, g, use_cuda, args.gpu) for g in input_list]
        test_triples_input = torch.LongTensor(test_snap).cuda() if use_cuda else torch.LongTensor(test_snap)
        # all triples of one snapshot share the same original timestamp
        timesteps_batch = timesteps_to_eval[time_idx] * np.ones(len(test_triples_input[:, 0]))
        neg_samples_batch = neg_sampler.query_batch(test_triples_input[:, 0], test_triples_input[:, 2],
                                                    timesteps_batch, edge_type=test_triples_input[:, 1], split_mode=split_mode)
        pos_samples_batch = test_triples_input[:, 2]
        _, perf_list, hits_list = model.predict(history_glist, num_rels, static_graph, test_triples_input, use_cuda, neg_samples_batch, pos_samples_batch,
                                                evaluator, METRIC)
        perf_list_all.extend(perf_list)
        hits_list_all.extend(hits_list)
        if split_mode == "test":
            if args.log_per_rel:
                for score, rel in zip(perf_list, test_triples_input[:, 1].tolist()):
                    perf_per_rel[rel].append(score)
        # reconstruct history graph list: slide the window forward by one step
        input_list.pop(0)
        input_list.append(test_snap)
    if split_mode == "test":
        if args.log_per_rel:
            # collapse per-relation score lists to their means; drop unseen relations
            for rel in range(num_rels):
                if len(perf_per_rel[rel]) > 0:
                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))
                else:
                    perf_per_rel.pop(rel)
    mrr = np.mean(perf_list_all)
    hits10 = np.mean(hits_list_all)
    return mrr, perf_per_rel, hits10
def run_experiment(args, n_hidden=None, n_layers=None, dropout=None, n_bases=None):
    """
    Run the experiment with the given configuration.

    Relies on module-level globals: ``train_list``, ``valid_list``,
    ``test_list``, ``num_rels``, ``num_nodes``, ``MODEL_NAME``, ``DATA``,
    ``SEED``.

    :param args: arguments
    :param n_hidden: hidden dimension override (optional, for grid search)
    :param n_layers: number of layers override (optional)
    :param dropout: dropout rate override (optional)
    :param n_bases: number of bases override (optional)
    :return: tuple ``(mrr, perf_per_rel, hits10)``
    """
    # load configuration for grid search the best configuration
    if n_hidden:
        args.n_hidden = n_hidden
    if n_layers:
        args.n_layers = n_layers
    if dropout:
        args.dropout = dropout
    if n_bases:
        args.n_bases = n_bases
    mrr = 0
    hits10 = 0
    # 2) set save model path
    save_model_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
    if not osp.exists(save_model_dir):
        os.mkdir(save_model_dir)
        print('INFO: Create directory {}'.format(save_model_dir))
    model_name = f'{MODEL_NAME}_{DATA}_{SEED}_{args.run_nr}'
    model_state_file = save_model_dir + model_name
    perf_per_rel = {}
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    # no static-graph side information is used for this dataset
    num_static_rels, num_words, static_triples, static_graph = 0, 0, [], None
    # create model
    model = RecurrentRGCNREGCN(args.decoder,
                               args.encoder,
                               num_nodes,
                               int(num_rels/2),
                               num_static_rels, # DIFFERENT
                               num_words, # DIFFERENT
                               args.n_hidden,
                               args.opn,
                               sequence_len=args.train_history_len,
                               num_bases=args.n_bases,
                               num_basis=args.n_basis,
                               num_hidden_layers=args.n_layers,
                               dropout=args.dropout,
                               self_loop=args.self_loop,
                               skip_connect=args.skip_connect,
                               layer_norm=args.layer_norm,
                               input_dropout=args.input_dropout,
                               hidden_dropout=args.hidden_dropout,
                               feat_dropout=args.feat_dropout,
                               aggregation=args.aggregation, # DIFFERENT
                               weight=args.weight, # DIFFERENT
                               discount=args.discount, # DIFFERENT
                               angle=args.angle, # DIFFERENT
                               use_static=args.add_static_graph, # DIFFERENT
                               entity_prediction=args.entity_prediction,
                               relation_prediction=args.relation_prediction,
                               use_cuda=use_cuda,
                               gpu=args.gpu,
                               analysis=args.run_analysis) # DIFFERENT
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        model.cuda()
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    if args.test and os.path.exists(model_state_file):
        mrr, perf_per_rel, hits10 = test(model,
                                         train_list+valid_list,
                                         test_list,
                                         num_rels,
                                         num_nodes,
                                         use_cuda,
                                         model_state_file,
                                         static_graph,
                                         "test",
                                         "test")
        return mrr, perf_per_rel, hits10
    elif args.test and not os.path.exists(model_state_file):
        print("--------------{} not exist, Change mode to train and generate stat for testing----------------\n".format(model_state_file))
        # fix: callers unpack three values; the original `return 0, 0` raised
        # "not enough values to unpack" on this path
        return 0, perf_per_rel, 0
    else:
        print("----------------------------------------start training----------------------------------------\n")
        best_mrr = 0
        best_hits = 0
        for epoch in range(args.n_epochs):
            model.train()
            losses = []
            # visit the training snapshots in random order each epoch
            idx = [_ for _ in range(len(train_list))]
            random.shuffle(idx)
            for train_sample_num in tqdm(idx):
                if train_sample_num == 0:
                    continue  # the very first snapshot has no history to condition on
                output = train_list[train_sample_num:train_sample_num+1]
                if train_sample_num - args.train_history_len < 0:
                    input_list = train_list[0: train_sample_num]
                else:
                    input_list = train_list[train_sample_num - args.train_history_len:
                                            train_sample_num]
                # generate history graph
                history_glist = [build_sub_graph(num_nodes, num_rels, snap, use_cuda, args.gpu) for snap in input_list]
                output = [torch.from_numpy(_).long().cuda() for _ in output] if use_cuda else [torch.from_numpy(_).long() for _ in output]
                loss_e, loss_r, loss_static = model.get_loss(history_glist, output[0], static_graph, use_cuda)
                # weighted sum of entity- and relation-prediction losses
                loss = args.task_weight*loss_e + (1-args.task_weight)*loss_r + loss_static
                losses.append(loss.item())
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)  # clip gradients
                optimizer.step()
                optimizer.zero_grad()
            print("Epoch {:04d} | Ave Loss: {:.4f} | Best MRR {:.4f} | Model {} "
                  .format(epoch, np.mean(losses), best_mrr, model_name))
            #! checking GPU usage
            free_mem, total_mem = torch.cuda.mem_get_info()
            print ("--------------GPU memory usage-----------")
            print ("there are ", free_mem, " free memory")
            print ("there are ", total_mem, " total available memory")
            print ("there are ", total_mem - free_mem, " used memory")
            print ("--------------GPU memory usage-----------")
            # validation
            if epoch and epoch % args.evaluate_every == 0:
                mrr, perf_per_rel, hits10 = test(model, train_list,
                                                 valid_list,
                                                 num_rels,
                                                 num_nodes,
                                                 use_cuda,
                                                 model_state_file,
                                                 static_graph,
                                                 mode="train", split_mode='val')
                if mrr < best_mrr:
                    # no improvement; only stop once the epoch budget is exhausted
                    if epoch >= args.n_epochs:
                        break
                else:
                    # new best validation MRR: remember it and checkpoint the model
                    best_mrr = mrr
                    best_hits = hits10
                    torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
        # fix: report the hits@10 of the best epoch (the original returned the
        # hits10 of whichever validation ran last, even if it was not the best)
        return best_mrr, perf_per_rel, best_hits
# ==================
# ==================
# ==================
start_overall = timeit.default_timer()
# set hyperparameters
args, _ = get_args_regcn()
args.dataset = 'tkgl-yago'
SEED = args.seed # set the random seed for consistency
set_random_seed(SEED)
DATA=args.dataset
MODEL_NAME = 'REGCN'
print("logging mrrs per relation: ", args.log_per_rel)
# load data
dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
num_rels = dataset.num_rels
num_nodes = dataset.num_nodes
subjects = dataset.full_data["sources"]
objects= dataset.full_data["destinations"]
relations = dataset.edge_type
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA) # stepsize:1
all_quads = np.stack((subjects, relations, objects, timestamps), axis=1)
train_data = all_quads[dataset.train_mask]
val_data = all_quads[dataset.val_mask]
test_data = all_quads[dataset.test_mask]
val_timestamps_orig = list(set(timestamps_orig[dataset.val_mask])) # needed for getting the negative samples
val_timestamps_orig.sort()
test_timestamps_orig = list(set(timestamps_orig[dataset.test_mask])) # needed for getting the negative samples
test_timestamps_orig.sort()
train_list = split_by_time(train_data)
valid_list = split_by_time(val_data)
test_list = split_by_time(test_data)
# evaluation metric
METRIC = dataset.eval_metric
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
#load the ns samples
dataset.load_val_ns()
dataset.load_test_ns()
## run training and testing
val_mrr, test_mrr = 0, 0
test_hits10 = 0
if args.grid_search:
print("hyperparameter grid search not implemented. Exiting.")
# single run
else:
start_train = timeit.default_timer()
if args.test == False: #if they are true: directly test on a previously trained and stored model
print('start training')
val_mrr, perf_per_rel, val_hits10 = run_experiment(args) # do training
start_test = timeit.default_timer()
args.test = True
print('start testing')
test_mrr, perf_per_rel, test_hits10 = run_experiment(args) # do testing
test_time = timeit.default_timer() - start_test
all_time = timeit.default_timer() - start_train
print(f"\tTest: Elapsed Time (s): {test_time: .4f}")
print(f"\Train and Test: Elapsed Time (s): {all_time: .4f}")
print(f"\tTest: {METRIC}: {test_mrr: .4f}")
print(f"\tValid: {METRIC}: {val_mrr: .4f}")
# saving the results...
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
os.mkdir(results_path)
print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
save_results({'model': MODEL_NAME,
'data': DATA,
'run': args.run_nr,
'seed': SEED,
f'val {METRIC}': float(val_mrr),
f'test {METRIC}': float(test_mrr),
'test_time': test_time,
'tot_train_val_time': all_time,
'test_hits10': float(test_hits10)
},
results_filename)
if args.log_per_rel:
results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'
with open(results_filename, 'w') as json_file:
json.dump(perf_per_rel, json_file)
sys.exit()
================================================
FILE: examples/linkproppred/tkgl-yago/timetraveler.py
================================================
"""
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting
Reference:
- https://github.com/JHL-HUST/TITer
Haohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021
"""
import sys
import timeit
import torch
from torch.utils.data import Dataset,DataLoader
import logging
import numpy as np
import pickle
from tqdm import tqdm
import os.path as osp
from pathlib import Path
import os
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from modules.timetraveler_agent import Agent
from modules.timetraveler_environment import Env
from modules.timetraveler_dirichlet import Dirichlet, MLE_Dirchlet
from modules.timetraveler_episode import Episode
from modules.timetraveler_policygradient import PG
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator
from modules.timetraveler_trainertester import Trainer, Tester, getRelEntCooccurrence
from tgb.utils.utils import set_random_seed,save_results
from modules.tkg_utils import get_args_timetraveler, reformat_ts, get_model_config_timetraveler
class QuadruplesDataset(Dataset):
    """Timetraveler's internal data representation: a torch ``Dataset`` over
    5-element quadruples (subject, relation, object, timestamp, original ts).
    """

    def __init__(self, examples):
        """Keep a shallow copy of *examples* (list or ndarray of quadruples)."""
        self.quadruples = examples.copy()

    def __len__(self):
        """Number of stored quadruples."""
        return len(self.quadruples)

    def __getitem__(self, item):
        """Return the five fields of quadruple *item* as a plain tuple."""
        quad = self.quadruples[item]
        return quad[0], quad[1], quad[2], quad[3], quad[4]
def set_logger(save_path):
    """Write logs to a file under *save_path* and echo them to the console.

    Uses the module-level ``args`` to pick the log file name: ``train.log``
    when training, ``test.log`` otherwise.
    """
    log_file = os.path.join(save_path, 'train.log' if args.do_train else 'test.log')
    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'
    )
    # mirror every INFO record to stderr as well
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter('%(asctime)s %(levelname)-8s %(message)s'))
    logging.getLogger('').addHandler(console)
def preprocess_data(args, config, timestamps, save_path, all_quads):
    """
    Precompute the state-action space for every (entity, timestamp) pair that
    occurs in the data and pickle-dump it to disk.

    :param args: parsed CLI args (uses ``store_actions_num`` and ``state_actions_path``)
    :param config: environment configuration passed to ``Env``
    :param timestamps: unused here; kept for interface compatibility with callers
    :param save_path: directory the pickle file is written into
    :param all_quads: iterable of (head, rel, tail, t, t_orig) quadruples
    """
    env = Env(all_quads, config)
    state_actions_space = {}
    with tqdm(total=len(all_quads)) as bar:
        for (head, rel, tail, t, _) in all_quads:
            # cache outgoing (True) and incoming (False) action spaces once
            # per (entity, timestamp); head and tail are handled identically
            for entity in (head, tail):
                if (entity, t, True) not in state_actions_space:
                    state_actions_space[(entity, t, True)] = env.get_state_actions_space_complete(entity, t, True, args.store_actions_num)
                    state_actions_space[(entity, t, False)] = env.get_state_actions_space_complete(entity, t, False, args.store_actions_num)
            bar.update(1)
    # fix: close the dump file deterministically (the original leaked the handle)
    with open(os.path.join(save_path, args.state_actions_path), 'wb') as f:
        pickle.dump(state_actions_space, f)
def log_metrics(mode, step, metrics):
    """Print the evaluation logs: one INFO line per metric in *metrics*.

    Uses lazy %-style logging arguments (idiomatic for ``logging``) so the
    message is only formatted when the INFO level is actually enabled; the
    emitted text is identical to the previous eager formatting.
    """
    for metric in metrics:
        logging.info('%s %s at epoch %d: %f', mode, metric, step, metrics[metric])
def main(args):
    """
    Main function to train and test the TimeTraveler model.

    Loads the tkgl-yago dataset, builds the RL agent and environment,
    optionally precomputes the state-action space and Dirichlet reward-shaping
    alphas, then trains (``args.do_train``) and/or tests (``args.do_test``)
    and saves the results as json.
    """
    start_overall = timeit.default_timer()
    #######################Set Logger#################################
    save_path = f'{os.path.dirname(os.path.abspath(__file__))}/saved_models/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # only keep args.cuda enabled when a GPU is actually available
    if args.cuda and torch.cuda.is_available():
        args.cuda = True
    else:
        args.cuda = False
    set_logger(save_path)
    #######################Create DataLoader#################################
    # set hyperparameters
    args.dataset = 'tkgl-yago'
    SEED = args.seed  # set the random seed for consistency
    set_random_seed(SEED)
    DATA = args.dataset
    MODEL_NAME = 'TIMETRAVELER'
    # load data
    dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
    num_rels = dataset.num_rels
    num_nodes = dataset.num_nodes
    subjects = dataset.full_data["sources"]
    objects = dataset.full_data["destinations"]
    relations = dataset.edge_type
    timestamps_orig = dataset.full_data["timestamps"]
    timestamps = reformat_ts(timestamps_orig, DATA)  # stepsize:1
    # quintuples: (subject, relation, object, remapped ts, original ts)
    all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
    train_data = all_quads[dataset.train_mask]
    # entities seen during training (columns 0 and 2 are subject/object ids)
    train_entities = np.unique(np.concatenate((train_data[:, 0], train_data[:, 2])))
    RelEntCooccurrence = getRelEntCooccurrence(train_data, num_rels)
    train_data = QuadruplesDataset(train_data)
    val_data = QuadruplesDataset(all_quads[dataset.val_mask])
    test_data = QuadruplesDataset(all_quads[dataset.test_mask])
    METRIC = dataset.eval_metric
    evaluator = Evaluator(name=DATA)
    neg_sampler = dataset.negative_sampler
    #load the ns samples
    dataset.load_val_ns()
    dataset.load_test_ns()
    train_dataloader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )
    valid_dataloader = DataLoader(
        val_data,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )
    test_dataloader = DataLoader(
        test_data,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )
    ######################Creat the agent and the environment###########################
    config = get_model_config_timetraveler(args, num_nodes, num_rels)
    logging.info(config)
    logging.info(args)
    # creat the agent
    agent = Agent(config)
    # creat the environment
    state_actions_path = os.path.join(save_path, args.state_actions_path)
    ######################preprocessing###########################
    if not os.path.exists(state_actions_path):
        if args.preprocess:
            print("preprocessing data...")
            preprocess_data(args, config, timestamps, save_path, list(all_quads))
            state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))
        else:
            # no cached action space and preprocessing disabled: Env computes on the fly
            state_action_space = None
    else:
        print("load preprocessed data...")
        state_action_space = pickle.load(open(os.path.join(save_path, args.state_actions_path), 'rb'))
    env = Env(list(all_quads), config, state_action_space)
    # Create episode controller
    episode = Episode(env, agent, config)
    if args.cuda:
        episode = episode.cuda()
    pg = PG(config)  # Policy Gradient
    optimizer = torch.optim.Adam(episode.parameters(), lr=args.lr, weight_decay=0.00001)
    ######################Reward Shaping: MLE DIRICHLET alphas###########################
    if args.reward_shaping:
        try:
            print("load alphas from pickle file")
            alphas = pickle.load(open(os.path.join(save_path, args.alphas_pkl), 'rb'))
        # NOTE(review): bare except re-estimates the alphas on ANY failure,
        # not only a missing pickle file -- consider narrowing the exception.
        except:
            print('running MLE dirichlet now')
            mle_d = MLE_Dirchlet(all_quads, num_rels, args.k, args.time_span,
                                 args.tol, args.method, args.maxiter)
            pickle.dump(mle_d.alphas, open(os.path.join(save_path, args.alphas_pkl), 'wb'))
            print('dumped alphas')
            alphas = mle_d.alphas
        distributions = Dirichlet(alphas, args.k)
    else:
        distributions = None
    ######################Training and Testing###########################
    trainer = Trainer(episode, pg, optimizer, args, distributions)
    tester = Tester(episode, args, train_entities, RelEntCooccurrence, dataset.metric)
    # placeholders so the result dict below has the metric keys even when a
    # phase is skipped
    test_metrics = {}
    val_metrics = {}
    test_metrics[METRIC] = None
    val_metrics[METRIC] = None
    if args.do_train:
        start_train = timeit.default_timer()
        logging.info('Start Training......')
        for i in range(args.max_epochs):
            loss, reward = trainer.train_epoch(train_dataloader, len(train_data))
            logging.info('Epoch {}/{} Loss: {}, reward: {}'.format(i, args.max_epochs, loss, reward))
            #! checking GPU usage
            free_mem, total_mem = torch.cuda.mem_get_info()
            print ("--------------GPU memory usage-----------")
            print ("there are ", free_mem, " free memory")
            print ("there are ", total_mem, " total available memory")
            print ("there are ", total_mem - free_mem, " used memory")
            print ("--------------GPU memory usage-----------")
            if i % args.save_epoch == 0 and i != 0:
                trainer.save_model(save_path, 'checkpoint_{}.pth'.format(i))
                logging.info('Save Model in {}'.format(save_path))
            if i % args.valid_epoch == 0 and i != 0:
                logging.info('Start Val......')
                val_metrics = tester.test(valid_dataloader,
                                          len(val_data), num_nodes, neg_sampler, evaluator, split_mode='val')
                for mode in val_metrics.keys():
                    logging.info('{} at epoch {}: {}'.format(mode, i, val_metrics[mode]))
        # save the final model state after the last epoch
        trainer.save_model(save_path)
        logging.info('Save Model in {}'.format(save_path))
    else:
        # # Load the model parameters
        # NOTE(review): save_path is a directory, so os.path.isfile(save_path)
        # is presumably always False and no pretrained weights are loaded on
        # this branch -- verify whether a checkpoint filename was intended.
        if os.path.isfile(save_path):
            params = torch.load(save_path)
            episode.load_state_dict(params['model_state_dict'])
            optimizer.load_state_dict(params['optimizer_state_dict'])
            logging.info('Load pretrain model: {}'.format(save_path))
    if args.do_test:
        logging.info('Start Testing......')
        start_test = timeit.default_timer()
        test_metrics = tester.test(test_dataloader,
                                   len(test_data), num_nodes, neg_sampler, evaluator, split_mode='test')
        for mode in test_metrics.keys():
            logging.info('Test {} : {}'.format(mode, test_metrics[mode]))
    # saving the results...
    results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
    if not osp.exists(results_path):
        os.mkdir(results_path)
        print('INFO: Create directory {}'.format(results_path))
    Path(results_path).mkdir(parents=True, exist_ok=True)
    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results.json'
    # NOTE(review): start_test / start_train are only bound when args.do_test /
    # args.do_train are set; with either flag off the lines below raise
    # NameError, and float(...[METRIC]) fails on the None placeholders -- this
    # script appears to assume both flags are enabled. Verify against usage.
    test_time = timeit.default_timer() - start_test
    all_time = timeit.default_timer() - start_train
    all_time_preprocess = timeit.default_timer() - start_overall
    save_results({'model': MODEL_NAME,
                  'data': DATA,
                  'seed': SEED,
                  f'val {METRIC}': float(val_metrics[METRIC]),
                  f'test {METRIC}': float(test_metrics[METRIC]),
                  'test_time': test_time,
                  'tot_train_val_time': all_time,
                  'tot_preprocess_train_val_time': all_time_preprocess
                  },
                 results_filename)
if __name__ == '__main__':
    # parse the CLI arguments and run the full TimeTraveler train/test pipeline
    args = get_args_timetraveler()
    main(args)
================================================
FILE: examples/linkproppred/tkgl-yago/tkgl-yago_example.py
================================================
import numpy as np
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator

# Minimal tkgl-yago example: load the dataset via PyG, print basic statistics,
# and query the pre-generated negative samples for every validation edge.
DATA = "tkgl-yago"

# data loading
dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
metric = dataset.eval_metric
print ("there are {} nodes and {} edges".format(dataset.num_nodes, dataset.num_edges))
print ("there are {} relation types".format(dataset.num_rels))
timestamp = data.t
head = data.src
tail = data.dst
edge_type = data.edge_type #relation
neg_sampler = dataset.negative_sampler
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
metric = dataset.eval_metric
evaluator = Evaluator(name=DATA)
neg_sampler = dataset.negative_sampler
#load the ns samples first
dataset.load_val_ns()
# fetch the negative samples for each validation edge, one query at a time
for i, (src, dst, t, rel) in enumerate(zip(val_data.src, val_data.dst, val_data.t, val_data.edge_type)):
    #must use np array to query
    neg_batch_list = neg_sampler.query_batch(np.array([src]), np.array([dst]), np.array([t]), edge_type=np.array([rel]), split_mode='val')
print ("retrieved all negative samples")
# #* load numpy arrays instead
# from tgb.linkproppred.dataset import LinkPropPredDataset
# # data loading
# dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
# data = dataset.full_data
# metric = dataset.eval_metric
# sources = dataset.full_data['sources']
# print (sources.dtype)
================================================
FILE: examples/linkproppred/tkgl-yago/tlogic.py
================================================
"""
https://github.com/liu-yushan/TLogic/tree/main/mycode
TLogic: Temporal Logical Rules for Explainable Link Forecasting on Temporal Knowledge Graphs.
Yushan Liu, Yunpu Ma, Marcel Hildebrandt, Mitchell Joblin, Volker Tresp
"""
# imports
import sys
import os
import os.path as osp
from pathlib import Path
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
import timeit
import argparse
import numpy as np
import json
from joblib import Parallel, delayed
import itertools
from tgb.linkproppred.evaluate import Evaluator
from tgb.linkproppred.dataset import LinkPropPredDataset
from modules.tlogic_learn_modules import Temporal_Walk, Rule_Learner, store_edges
import modules.tlogic_apply_modules as ra
from tgb.utils.utils import set_random_seed, save_results
from modules.tkg_utils import reformat_ts, get_inv_relation_id, create_scores_array
def learn_rules(i, num_relations):
    """
    Learn temporal rules for one worker process (multiprocessing possible).

    Uses the module-level globals ``all_relations``, ``rule_lengths``,
    ``num_walks``, ``temporal_walk`` and ``rl``.

    Parameters:
        i (int): process number
        num_relations (int): minimum number of relations for each process

    Returns:
        rl.rules_dict (dict): rules dictionary
    """
    # Slice of the relation list handled by this process; the last process
    # additionally absorbs any leftover relations.
    num_rest_relations = len(all_relations) - (i + 1) * num_relations
    if num_rest_relations >= num_relations:
        relations_idx = range(i * num_relations, (i + 1) * num_relations)
    else:
        relations_idx = range(i * num_relations, len(all_relations))

    prev_rule_count = 0
    for rel_pos in relations_idx:
        rel = all_relations[rel_pos]
        for length in rule_lengths:
            tic = timeit.default_timer()
            # Sample temporal walks and turn each successful one into a rule.
            for _ in range(num_walks):
                ok, walk = temporal_walk.sample_walk(length + 1, rel)
                if ok:
                    rl.create_rule(walk)
            elapsed = round(timeit.default_timer() - tic, 6)
            # Rules are stored in both directions, hence the division by two.
            total_rules = sum(len(v) for v in rl.rules_dict.values()) // 2
            new_rules = total_rules - prev_rule_count
            prev_rule_count = total_rules
            print(
                "Process {0}: relation {1}/{2}, length {3}: {4} sec, {5} rules".format(
                    i,
                    rel_pos - relations_idx[0] + 1,
                    len(relations_idx),
                    length,
                    elapsed,
                    new_rules,
                )
            )
    return rl.rules_dict
def apply_rules(i, num_queries, rules_dict, neg_sampler, data, window, learn_edges, all_quads, args, split_mode,
                log_per_rel=False, num_rels=0):
    """
    Apply rules (multiprocessing possible).

    Also relies on module-level globals: ``score_func``, ``top_k``,
    ``evaluator`` and ``num_nodes``.

    Parameters:
        i (int): process number
        num_queries (int): minimum number of queries for each process
        rules_dict (dict): learned rules per relation id
        neg_sampler: TGB negative sampler, queried per test quadruple
        data: quadruples (s, r, o, ts, ts_orig) of the evaluated split
        window (int): time window of applicable edges
        learn_edges: edges the rules were learned on
        all_quads: all quadruples (used to build the edge window)
        args: list of rule-aggregation parameter tuples for score_func
        split_mode (str): 'val' or 'test' -- which negative samples to load
        log_per_rel (bool): collect per-relation MRRs (test split only)
        num_rels (int): number of relations (for per-relation logging)

    Returns:
        perf_list (list): performance list (mrr per sample)
        hits_list (list): hits list (hits@10 per sample)
        perf_per_rel (dict): mean mrr per relation (empty unless logging)
    """
    perf_per_rel = {}
    for rel in range(num_rels):
        perf_per_rel[rel] = []
    print("Start process", i, "...")
    all_candidates = [dict() for _ in range(len(args))]
    no_cands_counter = 0
    # this process handles queries [i*num_queries, (i+1)*num_queries);
    # the last process also takes the remainder
    num_rest_queries = len(data) - (i + 1) * num_queries
    if num_rest_queries >= num_queries:
        test_queries_idx = range(i * num_queries, (i + 1) * num_queries)
    else:
        test_queries_idx = range(i * num_queries, len(data))
    cur_ts = data[test_queries_idx[0]][3]
    # edges usable for rule application at the current timestamp
    edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)
    it_start = timeit.default_timer()
    hits_list = [0] * len(test_queries_idx)
    perf_list = [0] * len(test_queries_idx)
    for index, j in enumerate(test_queries_idx):
        # negative samples are queried with the ORIGINAL timestamp (column 4)
        neg_sample_el = neg_sampler.query_batch(np.expand_dims(np.array(data[j,0]), axis=0),
                                                np.expand_dims(np.array(data[j,2]), axis=0),
                                                np.expand_dims(np.array(data[j,4]), axis=0),
                                                np.expand_dims(np.array(data[j,1]), axis=0),
                                                split_mode=split_mode)[0]
        # neg_samples_batch[j]
        pos_sample_el = data[j,2]
        test_query = data[j]
        assert pos_sample_el == test_query[2]
        cands_dict = [dict() for _ in range(len(args))]
        if test_query[3] != cur_ts:
            # new timestep: refresh the window of applicable edges
            cur_ts = test_query[3]
            edges = ra.get_window_edges(all_quads[:,0:4], cur_ts, learn_edges, window)
        if test_query[1] in rules_dict:
            dicts_idx = list(range(len(args)))
            for rule in rules_dict[test_query[1]]:
                walk_edges = ra.match_body_relations(rule, edges, test_query[0])
                if 0 not in [len(x) for x in walk_edges]:
                    rule_walks = ra.get_walks(rule, walk_edges)
                    if rule["var_constraints"]:
                        rule_walks = ra.check_var_constraints(
                            rule["var_constraints"], rule_walks
                        )
                    if not rule_walks.empty:
                        cands_dict = ra.get_candidates(
                            rule,
                            rule_walks,
                            cur_ts,
                            cands_dict,
                            score_func,
                            args,
                            dicts_idx,
                        )
                        # NOTE(review): dicts_idx is mutated (remove) while
                        # being iterated below -- inherited from upstream
                        # TLogic; fine for the usual single-element args list,
                        # may skip entries for longer ones. Verify.
                        for s in dicts_idx:
                            cands_dict[s] = {
                                x: sorted(cands_dict[s][x], reverse=True)
                                for x in cands_dict[s].keys()
                            }
                            cands_dict[s] = dict(
                                sorted(
                                    cands_dict[s].items(),
                                    key=lambda item: item[1],
                                    reverse=True,
                                )
                            )
                            top_k_scores = [v for _, v in cands_dict[s].items()][:top_k]
                            unique_scores = list(
                                scores for scores, _ in itertools.groupby(top_k_scores)
                            )
                            # enough distinct candidate scores: stop applying
                            # further rules for this parameter set
                            if len(unique_scores) >= top_k:
                                dicts_idx.remove(s)
                        if not dicts_idx:
                            break
            if cands_dict[0]:
                for s in range(len(args)):
                    # Calculate noisy-or scores
                    scores = list(
                        map(
                            lambda x: 1 - np.product(1 - np.array(x)),
                            cands_dict[s].values(),
                        )
                    )
                    cands_scores = dict(zip(cands_dict[s].keys(), scores))
                    noisy_or_cands = dict(
                        sorted(cands_scores.items(), key=lambda x: x[1], reverse=True)
                    )
                    all_candidates[s][j] = noisy_or_cands
            else:  # No candidates found by applying rules
                no_cands_counter += 1
                for s in range(len(args)):
                    all_candidates[s][j] = dict()
        else:  # No rules exist for this relation
            no_cands_counter += 1
            for s in range(len(args)):
                all_candidates[s][j] = dict()
        if not (j - test_queries_idx[0] + 1) % 100:
            # progress report every 100 processed queries
            it_end = timeit.default_timer()
            it_time = round(it_end - it_start, 6)
            print(
                "Process {0}: test samples finished: {1}/{2}, {3} sec".format(
                    i, j - test_queries_idx[0] + 1, len(test_queries_idx), it_time
                )
            )
            it_start = timeit.default_timer()
        # NOTE(review): `s` here is the leftover loop variable from above
        # (always len(args)-1), i.e. only the LAST parameter set is scored --
        # presumably intentional since len(args) == 1 in this setup; verify.
        predictions = create_scores_array(all_candidates[s][j], num_nodes)
        predictions_of_interest_pos = np.array(predictions[pos_sample_el])
        predictions_of_interest_neg = predictions[neg_sample_el]
        input_dict = {
            "y_pred_pos": predictions_of_interest_pos,
            "y_pred_neg": predictions_of_interest_neg,
            "eval_metric": ['mrr'],
        }
        predictions = evaluator.eval(input_dict)
        perf_list[index] = predictions['mrr']
        hits_list[index] = predictions['hits@10']
        if split_mode == "test":
            if log_per_rel:
                perf_per_rel[test_query[1]].append(perf_list[index]) #test_query[1] is the relation index
    if split_mode == "test":
        if log_per_rel:
            # collapse per-relation score lists to means; drop unseen relations
            for rel in range(num_rels):
                if len(perf_per_rel[rel]) > 0:
                    perf_per_rel[rel] = float(np.mean(perf_per_rel[rel]))
                else:
                    perf_per_rel.pop(rel)
    return perf_list, hits_list, perf_per_rel
## args
def _str2bool(value):
    """Parse a command-line boolean.

    argparse's ``type=bool`` is a classic footgun: ``bool("False")`` is True,
    so ANY explicit value on the command line used to enable these flags.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got {!r}".format(value))


def get_args(argv=None):
    """Parse TLogic command-line hyperparameters into a plain dict.

    Parameters
    ----------
    argv : list[str] | None
        Argument strings to parse; defaults to ``sys.argv[1:]``, which keeps
        the original zero-argument call sites working unchanged.

    Returns
    -------
    dict
        Mapping of option name to parsed value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", "-d", default="tkgl-yago", type=str)
    parser.add_argument("--rule_lengths", "-l", default="1", type=int, nargs="+")
    parser.add_argument("--num_walks", "-n", default="100", type=int)
    parser.add_argument("--transition_distr", default="exp", type=str)
    parser.add_argument("--window", "-w", default=0, type=int)  # set to e.g. 200 if only the most recent 200 timesteps should be considered. set to -2 if multistep
    parser.add_argument("--top_k", default=20, type=int)
    parser.add_argument("--num_processes", "-p", default=1, type=int)
    parser.add_argument("--alpha", "-alpha", default=0.99, type=float)  # fix alpha. used if trainflag == false
    # BUGFIX: the boolean options below previously used `type=bool` (or no
    # type at all), which maps any non-empty string — including "False" — to
    # True; _str2bool parses the value properly.
    parser.add_argument("--save_config", "-c", default=True, type=_str2bool)  # do we need to save the selection of lambda and alpha in config file?
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--run_nr', type=int, help='Run Number', default=1)
    parser.add_argument('--learn_rules_flag', type=_str2bool, help='Do we want to learn the rules', default=True)
    parser.add_argument('--rule_filename', type=str, help='if rules not learned: where are they stored', default='0_r[3]_n100_exp_s1_rules.json')
    parser.add_argument('--log_per_rel', type=_str2bool, help='Do we want to log mrr per relation', default=False)
    parser.add_argument('--compute_valid_mrr', type=_str2bool, help='Do we want to compute mrr for valid set', default=True)
    parsed = vars(parser.parse_args(argv))
    return parsed
# ---- TLogic pipeline: parse args, learn temporal rules, apply them, save results.
start_o = timeit.default_timer()

## get args
parsed = get_args()
dataset = parsed["dataset"]
rule_lengths = parsed["rule_lengths"]
# argparse may hand back a bare int (single -l value); normalize to a list.
rule_lengths = [rule_lengths] if (type(rule_lengths) == int) else rule_lengths
num_walks = parsed["num_walks"]
transition_distr = parsed["transition_distr"]
num_processes = parsed["num_processes"]
window = parsed["window"]
top_k = parsed["top_k"]
log_per_rel = parsed['log_per_rel']
MODEL_NAME = 'TLogic'
SEED = parsed['seed']  # set the random seed for consistency
set_random_seed(SEED)

## load dataset and prepare it accordingly
name = parsed["dataset"]
compute_valid_mrr = parsed["compute_valid_mrr"]
dataset = LinkPropPredDataset(name=name, root="datasets", preprocess=True)
DATA = name
relations = dataset.edge_type
num_rels = dataset.num_rels
subjects = dataset.full_data["sources"]
objects= dataset.full_data["destinations"]
num_nodes = dataset.num_nodes
timestamps_orig = dataset.full_data["timestamps"]
timestamps = reformat_ts(timestamps_orig, DATA)  # stepsize:1
# Columns: (subject, relation, object, reindexed timestamp, original timestamp).
all_quads = np.stack((subjects, relations, objects, timestamps, timestamps_orig), axis=1)
train_data = all_quads[dataset.train_mask, 0:4]  # we do not need the original timestamps
val_data = all_quads[dataset.val_mask, 0:5]
test_data = all_quads[dataset.test_mask, 0:5]
all_data = all_quads[:, 0:4]
metric = dataset.eval_metric
evaluator = Evaluator(name=name)
neg_sampler = dataset.negative_sampler
inv_relation_id = get_inv_relation_id(num_rels)

# load the negative-sample (ns) files for val and test evaluation
dataset.load_val_ns()
dataset.load_test_ns()
output_dir = f'{osp.dirname(osp.abspath(__file__))}/saved_models/'
learn_rules_flag = parsed['learn_rules_flag']

## 1. learn rules (or reuse a previously saved rules file)
start_train = timeit.default_timer()
if learn_rules_flag:
    print("start learning rules")
    # edges (dict): edges for each relation
    # inv_relation_id (dict): mapping of relation to inverse relation
    temporal_walk = Temporal_Walk(train_data, inv_relation_id, transition_distr)
    rl = Rule_Learner(edges=temporal_walk.edges, id2relation=None, inv_relation_id=inv_relation_id,
                      output_dir=output_dir)
    all_relations = sorted(temporal_walk.edges)  # Learn for all relations
    start = timeit.default_timer()
    # Split the relations evenly across worker processes.
    num_relations = len(all_relations) // num_processes
    output = Parallel(n_jobs=num_processes)(
        delayed(learn_rules)(i, num_relations) for i in range(num_processes)
    )
    end = timeit.default_timer()
    # Merge the per-process rule dictionaries into a single one.
    all_rules = output[0]
    for i in range(1, num_processes):
        all_rules.update(output[i])
    total_time = round(end - start, 6)
    print("Learning finished in {} seconds.".format(total_time))
    rl.rules_dict = all_rules
    rl.sort_rules_dict()
    rule_filename = rl.save_rules(0, rule_lengths, num_walks, transition_distr, SEED)
else:
    rule_filename = parsed['rule_filename']
    print("Loading rules from file {}".format(parsed['rule_filename']))
end_train = timeit.default_timer()

## 2. Apply rules
rules_dict = json.load(open(output_dir + rule_filename))
rules_dict = {int(k): v for k, v in rules_dict.items()}
rules_dict = ra.filter_rules(
    rules_dict, min_conf=0.01, min_body_supp=2, rule_lengths=rule_lengths
)  # filter rules for minimum confidence, body support and rule length
learn_edges = store_edges(train_data)
score_func = ra.score_12
# It is possible to specify a list of list of arguments for tuning
args = [[0.1, 0.5]]

# compute valid mrr (optional; skipped with a placeholder score of 0)
start_valid = timeit.default_timer()
if compute_valid_mrr:
    print('Computing valid MRR')
    num_queries = len(val_data) // num_processes
    output = Parallel(n_jobs=num_processes)(
        delayed(apply_rules)(i, num_queries, rules_dict, neg_sampler, val_data, window, learn_edges,
                             all_quads, args, split_mode='val') for i in range(num_processes))
    end = timeit.default_timer()
    perf_list_val = []
    hits_list_val = []
    for i in range(num_processes):
        perf_list_val.extend(output[i][0])
        hits_list_val.extend(output[i][1])
else:
    perf_list_val = [0]
    hits_list_val = [0]
end_valid = timeit.default_timer()

# compute test mrr
if log_per_rel == True:
    num_processes = 1  # otherwise logging per rel does not work for our implementation
start_test = timeit.default_timer()
print('Computing test MRR')
start = timeit.default_timer()
num_queries = len(test_data) // num_processes
output = Parallel(n_jobs=num_processes)(
    delayed(apply_rules)(i, num_queries, rules_dict, neg_sampler, test_data, window, learn_edges,
                         all_quads, args, split_mode='test', log_per_rel=log_per_rel, num_rels=num_rels) for i in range(num_processes))
end = timeit.default_timer()
perf_list_all = []
hits_list_all = []
for i in range(num_processes):
    perf_list_all.extend(output[i][0])
    hits_list_all.extend(output[i][1])
if log_per_rel == True:
    # Safe to read output[0] only: num_processes was forced to 1 above.
    perf_per_rel = output[0][2]

total_time = round(end - start, 6)
total_valid_time = round(end_valid - start_valid, 6)
print("Application finished in {} seconds.".format(total_time))
print(f"The valid MRR is {np.mean(perf_list_val)}")
print(f"The MRR is {np.mean(perf_list_all)}")
print(f"The Hits@10 is {np.mean(hits_list_all)}")
print(f"We have {len(perf_list_all)} predictions")
print(f"The test set has len {len(test_data)} ")

end_o = timeit.default_timer()
train_time_o = round(end_train - start_train, 6)
test_time_o = round(end_o - start_test, 6)
total_time_o = round(end_o - start_o, 6)
print("Running Training to find best configs finished in {} seconds.".format(train_time_o))
print("Running testing with best configs finished in {} seconds.".format(test_time_o))
print("Running all steps finished in {} seconds.".format(total_time_o))

## 3. save results to saved_results/ next to this script
results_path = f'{osp.dirname(osp.abspath(__file__))}/saved_results'
if not osp.exists(results_path):
    os.mkdir(results_path)
    print('INFO: Create directory {}'.format(results_path))
Path(results_path).mkdir(parents=True, exist_ok=True)
if log_per_rel == True:
    results_filename = f'{results_path}/{MODEL_NAME}_{DATA}_results_per_rel.json'
    with open(results_filename, 'w') as json_file:
        json.dump(perf_per_rel, json_file)
results_filename = f'{results_path}/{MODEL_NAME}_NONE_{DATA}_results.json'
metric = dataset.eval_metric
save_results({'model': MODEL_NAME,
              'train_flag': None,
              'rule_len': rule_lengths,
              'window': window,
              'data': DATA,
              'run': 1,
              'seed': SEED,
              metric: float(np.mean(perf_list_all)),
              'hits10': float(np.mean(hits_list_all)),
              'val_mrr': float(np.mean(perf_list_val)),
              'test_time': test_time_o,
              'tot_train_val_time': total_time_o,
              'valid_time': total_valid_time
              },
             results_filename)
================================================
FILE: examples/nodeproppred/tgbn-genre/dyrep.py
================================================
"""
DyRep
This has been implemented with intuitions from the following sources:
- https://github.com/twitter-research/tgn
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
Spec.:
- Memory Updater: RNN
- Embedding Module: ID
- Message Function: ATTN
"""
import timeit
from tqdm import tqdm
import torch
from torch_geometric.loader import TemporalDataLoader
# internal imports
from tgb.utils.utils import get_args, set_random_seed
from tgb.nodeproppred.evaluate import Evaluator
from modules.decoder import NodePredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import DyRepMemory
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
def process_edges(src, dst, t, msg):
    """Feed a batch of observed edges into the DyRep memory and the
    neighbor sampler; an empty batch is a no-op."""
    if src.nelement() == 0:
        return
    model['memory'].update_state(src, dst, t, msg)
    neighbor_loader.insert(src, dst)
# ==========
# ========== Define helper function...
# ==========
def train():
    """Train DyRep for one epoch over `train_loader`.

    Streams edges in time order; whenever a batch crosses the next label
    timestamp, node labels are predicted from memory + GNN embeddings, the
    loss/metric is computed, and memory is then updated with the observed
    edges. Returns a dict with mean cross-entropy ('ce') and the eval metric,
    both averaged over label timestamps.
    """
    model['memory'].train()
    model['gnn'].train()
    model['node_pred'].train()
    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.
    total_loss = 0
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        query_t = batch.t[-1]
        # check if this batch moves to the next day
        if query_t > label_t:
            # find the node labels from the past day
            label_tuple = dataset.get_node_label(query_t)
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)
            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )
            # Node property prediction:
            # 1. sample neighbors from the neighbor loader for all nodes to be predicted
            # 2. extract memory from the sampled neighbors and the nodes
            # 3. run gnn with the extracted memory embeddings and the
            #    corresponding time and message
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)
            z, last_update = model['memory'](n_id_neighbors)
            z = model['gnn'](
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]  # keep only the embeddings of the queried nodes
            # loss and metric computation
            pred = model['node_pred'](z)
            loss = criterion(pred, labels.to(device))
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()
            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[metric]
            total_score += score
            num_label_ts += 1
            loss.backward()
            optimizer.step()
            total_loss += float(loss.detach())
        # Update memory and neighbor loader with ground-truth state.
        process_edges(src, dst, t, msg)
        model['memory'].detach()  # stop gradients from flowing across batches
    metric_dict = {
        "ce": total_loss / num_label_ts,
    }
    metric_dict[metric] = total_score / num_label_ts
    return metric_dict
@torch.no_grad()
def test(loader):
    """Evaluate DyRep on the edge stream in `loader`.

    Same streaming logic as train(), without gradient updates: predicts node
    labels at each label timestamp, then feeds the observed edges back into
    memory. NOTE: mutates model memory and the neighbor loader (streaming
    evaluation). Returns {metric: mean score over label timestamps}.
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['node_pred'].eval()
    total_score = 0
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    for batch in tqdm(loader):
        batch = batch.to(device)
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        query_t = batch.t[-1]
        if query_t > label_t:
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:  # no more labels in this split
                break
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)
            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )
            # Node property prediction:
            # 1. sample neighbors from the neighbor loader for all nodes to be predicted
            # 2. extract memory from the sampled neighbors and the nodes
            # 3. run gnn with the extracted memory embeddings and the
            #    corresponding time and message
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)
            z, last_update = model['memory'](n_id_neighbors)
            z = model['gnn'](
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]  # keep only the embeddings of the queried nodes
            # metric computation
            pred = model['node_pred'](z)
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()
            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[metric]
            total_score += score
            num_label_ts += 1
        process_edges(src, dst, t, msg)
    metric_dict = {}
    metric_dict[metric] = total_score / num_label_ts
    return metric_dict
# ==========
# ==========
# ==========
# Start...
# ---- DyRep on tgbn-genre: set up model/data, then train/val/test per epoch.
start_overall = timeit.default_timer()

# ========== set parameters...
args, _ = get_args()
print("INFO: Arguments:", args)
DATA = "tgbn-genre"
LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10

# setting random seed
torch.manual_seed(SEED)
set_random_seed(SEED)
MODEL_NAME = 'DyRep'
# DyRep message spec used here: exclude source embedding, include destination.
USE_SRC_EMB_IN_MSG = False
USE_DST_EMB_IN_MSG = True
evaluator = Evaluator(name=DATA)

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGNodePropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
num_classes = dataset.num_classes
train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

# neighborhood sampler
neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

# define the model end-to-end
memory = DyRepMemory(
    data.num_nodes,
    data.msg.size(-1),
    MEM_DIM,
    TIME_DIM,
    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
    aggregator_module=LastAggregator(),
    memory_updater_type='rnn',
    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,
    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG
).to(device)
gnn = GraphAttentionEmbedding(
    in_channels=MEM_DIM,
    out_channels=EMB_DIM,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)
node_pred = NodePredictor(in_dim=EMB_DIM, out_dim=num_classes).to(device)
model = {'memory': memory,
         'gnn': gnn,
         'node_pred': node_pred}
optimizer = torch.optim.Adam(
    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['node_pred'].parameters()),
    lr=LR,
)
criterion = torch.nn.CrossEntropyLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

train_curve = []
val_curve = []
test_curve = []
max_val_score = 0  # find the best test score based on validation score
best_test_idx = 0
for epoch in range(1, NUM_EPOCH + 1):
    start_time = timeit.default_timer()
    train_dict = train()
    print("------------------------------------")
    print(f"training Epoch: {epoch:02d}")
    print(train_dict)
    train_curve.append(train_dict[metric])
    print("Training takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    start_time = timeit.default_timer()
    val_dict = test(val_loader)
    print(val_dict)
    val_curve.append(val_dict[metric])
    if (val_dict[metric] > max_val_score):
        max_val_score = val_dict[metric]
        best_test_idx = epoch - 1
    print("Validation takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    start_time = timeit.default_timer()
    test_dict = test(test_loader)
    print(test_dict)
    test_curve.append(test_dict[metric])
    print("Test takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    print("------------------------------------")
    dataset.reset_label_time()  # rewind the label pointer for the next epoch

# Report the test score achieved at the best-validation epoch.
max_test_score = test_curve[best_test_idx]
print("------------------------------------")
print("------------------------------------")
print ("best val score: ", max_val_score)
print ("best validation epoch : ", best_test_idx + 1)
print ("best test score: ", max_test_score)
================================================
FILE: examples/nodeproppred/tgbn-genre/moving_average.py
================================================
"""
implement persistant forecast as baseline for the node prop pred task
simply predict last seen label for the node
"""
import timeit
import numpy as np
from torch_geometric.loader import TemporalDataLoader
# local imports
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from modules.heuristics import MovingAverage
from tgb.nodeproppred.evaluate import Evaluator
# ---- Moving-average baseline on tgbn-genre: data loading and setup.
window = 6  # number of past label observations averaged per node — TODO confirm against MovingAverage
device = "cpu"
name = "tgbn-genre"
dataset = PyGNodePropPredDataset(name=name, root="datasets")
num_classes = dataset.num_classes
data = dataset.get_TemporalData()
data = data.to(device)
eval_metric = dataset.eval_metric
forecaster = MovingAverage(num_classes, window=window)
evaluator = Evaluator(name=name)
# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
train_data, val_data, test_data = data.train_val_test_split(
    val_ratio=0.15, test_ratio=0.15
)
batch_size = 200
train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
test_loader = TemporalDataLoader(test_data, batch_size=batch_size)
def test_n_upate(loader):
    """Evaluate the moving-average baseline on `loader` while updating it online.

    At each label timestamp: predict every labeled node's label vector from
    the forecaster BEFORE showing it the observed labels, then update the
    forecaster with those labels. Returns {eval_metric: mean score over
    label timestamps}.
    NOTE(review): 'upate' is a typo for 'update'; kept to avoid breaking callers.
    """
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0
    for batch in loader:
        batch = batch.to(device)
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        query_t = batch.t[-1]
        if query_t > label_t:
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:  # no more labels in this split
                break
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_ts = label_ts.numpy()
            label_srcs = label_srcs.numpy()
            labels = labels.numpy()
            label_t = dataset.get_label_time()
            preds = []
            for i in range(0, label_srcs.shape[0]):
                node_id = label_srcs[i]
                pred_vec = forecaster.query_dict(node_id)  # predict before observing today's label
                preds.append(pred_vec)
                forecaster.update_dict(node_id, labels[i])
            np_pred = np.stack(preds, axis=0)
            np_true = labels
            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[eval_metric]
            total_score += score
            num_label_ts += 1
    metric_dict = {}
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict
"""
train, val and test for one epoch only
"""
start_time = timeit.default_timer()
metric_dict = test_n_upate(train_loader)
print(metric_dict)
print(
"Persistant forecast on Training takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
val_dict = test_n_upate(val_loader)
print(val_dict)
print(
"Persistant forecast on validation takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
test_dict = test_n_upate(test_loader)
print(test_dict)
print(
"Persistant forecast on Test takes--- %s seconds ---" % (timeit.default_timer() - start_time)
)
dataset.reset_label_time()
================================================
FILE: examples/nodeproppred/tgbn-genre/persistant_forecast.py
================================================
"""
implement persistant forecast as baseline for the node prop pred task
simply predict last seen label for the node
"""
import timeit
import numpy as np
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import torch
# local imports
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from modules.heuristics import PersistantForecaster
from tgb.nodeproppred.evaluate import Evaluator
device = "cpu"
name = "tgbn-genre"
dataset = PyGNodePropPredDataset(name=name, root="datasets")
num_classes = dataset.num_classes
data = dataset.get_TemporalData()
data = data.to(device)
all_nodes = torch.cat((data.src, data.dst), 0)
all_nodes = all_nodes.unique()
print (all_nodes.shape[0])
eval_metric = dataset.eval_metric
forecaster = PersistantForecaster(num_classes)
evaluator = Evaluator(name=name)
# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
train_data, val_data, test_data = data.train_val_test_split(
val_ratio=0.15, test_ratio=0.15
)
batch_size = 200
train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
test_loader = TemporalDataLoader(test_data, batch_size=batch_size)
"""
continue debug here
"""
def test_n_upate(loader):
    """Evaluate the persistent-forecast baseline on `loader` while updating it online.

    At each label timestamp: predict every labeled node's label vector from
    the forecaster (its last stored label) BEFORE showing it the observed
    labels, then update it with those labels. Returns
    {eval_metric: mean score over label timestamps}.
    NOTE(review): 'upate' is a typo for 'update'; kept to avoid breaking callers.
    """
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0
    for batch in tqdm(loader):
        batch = batch.to(device)
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        query_t = batch.t[-1]
        if query_t > label_t:
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:  # no more labels in this split
                break
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_ts = label_ts.numpy()
            label_srcs = label_srcs.numpy()
            labels = labels.numpy()
            label_t = dataset.get_label_time()
            preds = []
            for i in range(0, label_srcs.shape[0]):
                node_id = label_srcs[i]
                pred_vec = forecaster.query_dict(node_id)  # predict before observing today's label
                preds.append(pred_vec)
                forecaster.update_dict(node_id, labels[i])
            np_pred = np.stack(preds, axis=0)
            np_true = labels
            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[eval_metric]
            total_score += score
            num_label_ts += 1
    metric_dict = {}
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict
"""
train, val and test for one epoch only
"""
start_time = timeit.default_timer()
metric_dict = test_n_upate(train_loader)
print(metric_dict)
print(
"Persistant forecast on Training takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
val_dict = test_n_upate(val_loader)
print(val_dict)
print(
"Persistant forecast on validation takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
test_dict = test_n_upate(test_loader)
print(test_dict)
print(
"Persistant forecast on Test takes--- %s seconds ---" % (timeit.default_timer() - start_time)
)
dataset.reset_label_time()
================================================
FILE: examples/nodeproppred/tgbn-genre/tgn.py
================================================
from tqdm import tqdm
import torch
import timeit
import argparse
import matplotlib.pyplot as plt
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import (
IdentityMessage,
LastAggregator,
LastNeighborLoader,
)
from modules.decoder import NodePredictor
from modules.emb_module import GraphAttentionEmbedding
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from tgb.nodeproppred.evaluate import Evaluator
from tgb.utils.utils import set_random_seed
# ---- Parse command-line arguments and set basic hyperparameters.
parser = argparse.ArgumentParser(description='parsing command line arguments as hyperparameters')
parser.add_argument('-s', '--seed', type=int, default=1,
                    help='random seed to use')
# BUGFIX: parse_args() was previously called twice with the first result
# discarded; parse once and keep the namespace.
args = parser.parse_args()

# setting random seed
seed = int(args.seed)  # 1,2,3,4,5
print("setting random seed to be", seed)
torch.manual_seed(seed)
set_random_seed(seed)

# hyperparameters
lr = 0.0001
epochs = 50
# ---- TGN on tgbn-genre: data loading and model construction.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
name = "tgbn-genre"
dataset = PyGNodePropPredDataset(name=name, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
eval_metric = dataset.eval_metric
num_classes = dataset.num_classes
data = dataset.get_TemporalData()
data = data.to(device)
evaluator = Evaluator(name=name)
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
batch_size = 200
train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
test_loader = TemporalDataLoader(test_data, batch_size=batch_size)
# Keeps the 10 most recent neighbors per node for embedding computation.
neighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)
memory_dim = time_dim = embedding_dim = 100
memory = TGNMemory(
    data.num_nodes,
    data.msg.size(-1),
    memory_dim,
    time_dim,
    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)
gnn = (
    GraphAttentionEmbedding(
        in_channels=memory_dim,
        out_channels=embedding_dim,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    )
    .to(device)
    .float()
)
node_pred = NodePredictor(in_dim=embedding_dim, out_dim=num_classes).to(device)
optimizer = torch.optim.Adam(
    set(memory.parameters()) | set(gnn.parameters()) | set(node_pred.parameters()),
    lr=lr,
)
criterion = torch.nn.CrossEntropyLoss()
# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
def plot_curve(scores, out_name):
    """Render `scores` as a line plot and write it to '<out_name>.pdf'."""
    figure_path = out_name + ".pdf"
    plt.plot(scores, color="#e34a33")
    plt.ylabel("score")
    plt.savefig(figure_path)
    plt.close()
def process_edges(src, dst, t, msg):
    """Update TGN memory and the neighbor sampler with newly observed
    edges; an empty batch is ignored."""
    if src.nelement() == 0:
        return
    memory.update_state(src, dst, t, msg)
    neighbor_loader.insert(src, dst)
def train():
    """Train TGN for one epoch over `train_loader`.

    Streams edges in time order; whenever a batch crosses the next label
    timestamp, node labels are predicted from memory + GNN embeddings, the
    loss/metric is computed, and memory is then updated with the observed
    edges. Returns a dict with mean cross-entropy ('ce') and the eval metric,
    both averaged over label timestamps.
    """
    memory.train()
    gnn.train()
    node_pred.train()
    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.
    total_loss = 0
    label_t = dataset.get_label_time()  # check when does the first label start
    total_score = 0
    num_label_ts = 0
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        query_t = batch.t[-1]
        # check if this batch moves to the next day
        if query_t > label_t:
            # find the node labels from the past day
            label_tuple = dataset.get_node_label(query_t)
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)
            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )
            # Node property prediction:
            # 1. sample neighbors from the neighbor loader for all nodes to be predicted
            # 2. extract memory from the sampled neighbors and the nodes
            # 3. run gnn with the extracted memory embeddings and the
            #    corresponding time and message
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)
            z, last_update = memory(n_id_neighbors)
            z = gnn(
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]  # keep only the embeddings of the queried nodes
            # loss and metric computation
            pred = node_pred(z)
            loss = criterion(pred, labels.to(device))
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()
            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[eval_metric]
            total_score += score
            num_label_ts += 1
            loss.backward()
            optimizer.step()
            total_loss += float(loss.detach())
        # Update memory and neighbor loader with ground-truth state.
        process_edges(src, dst, t, msg)
        memory.detach()  # stop gradients from flowing across batches
    metric_dict = {
        "ce": total_loss / num_label_ts,
    }
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict
@torch.no_grad()
def test(loader):
    """Evaluate TGN on the edge stream in `loader`.

    Same streaming logic as train(), without gradient updates: predicts node
    labels at each label timestamp, then feeds the observed edges back into
    memory. NOTE: mutates TGN memory and the neighbor loader (streaming
    evaluation). Returns {eval_metric: mean score over label timestamps}.
    """
    memory.eval()
    gnn.eval()
    node_pred.eval()
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0
    for batch in tqdm(loader):
        batch = batch.to(device)
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        query_t = batch.t[-1]
        if query_t > label_t:
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:  # no more labels in this split
                break
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)
            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )
            # Node property prediction:
            # 1. sample neighbors from the neighbor loader for all nodes to be predicted
            # 2. extract memory from the sampled neighbors and the nodes
            # 3. run gnn with the extracted memory embeddings and the
            #    corresponding time and message
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)
            z, last_update = memory(n_id_neighbors)
            z = gnn(
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]  # keep only the embeddings of the queried nodes
            # metric computation
            pred = node_pred(z)
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()
            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[eval_metric]
            total_score += score
            num_label_ts += 1
        process_edges(src, dst, t, msg)
    metric_dict = {}
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict
# ---- Train/validate/test loop; track the test score at the best-val epoch.
train_curve = []
val_curve = []
test_curve = []
max_val_score = 0  # find the best test score based on validation score
best_test_idx = 0
for epoch in range(1, epochs + 1):
    start_time = timeit.default_timer()
    train_dict = train()
    print("------------------------------------")
    print(f"training Epoch: {epoch:02d}")
    print(train_dict)
    train_curve.append(train_dict[eval_metric])
    print("Training takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    start_time = timeit.default_timer()
    val_dict = test(val_loader)
    print(val_dict)
    val_curve.append(val_dict[eval_metric])
    if (val_dict[eval_metric] > max_val_score):
        max_val_score = val_dict[eval_metric]
        best_test_idx = epoch - 1
    print("Validation takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    start_time = timeit.default_timer()
    test_dict = test(test_loader)
    print(test_dict)
    test_curve.append(test_dict[eval_metric])
    print("Test takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    print("------------------------------------")
    dataset.reset_label_time()  # rewind the label pointer for the next epoch
# # code for plotting
# plot_curve(train_curve, "train_curve")
# plot_curve(val_curve, "val_curve")
# plot_curve(test_curve, "test_curve")
max_test_score = test_curve[best_test_idx]
print("------------------------------------")
print("------------------------------------")
print ("best val score: ", max_val_score)
print ("best validation epoch : ", best_test_idx + 1)
print ("best test score: ", max_test_score)
================================================
FILE: examples/nodeproppred/tgbn-reddit/dyrep.py
================================================
"""
DyRep
This has been implemented with intuitions from the following sources:
- https://github.com/twitter-research/tgn
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
Spec.:
- Memory Updater: RNN
- Embedding Module: ID
- Message Function: ATTN
"""
import timeit
from tqdm import tqdm
import torch
from torch_geometric.loader import TemporalDataLoader
import numpy as np
# internal imports
from tgb.utils.utils import get_args, set_random_seed
from tgb.nodeproppred.evaluate import Evaluator
from modules.decoder import NodePredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import DyRepMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
def process_edges(src, dst, t, msg):
    """Record a batch of edges: update DyRep memory and the neighbor loader.

    A no-op when the batch slice is empty (e.g. after day-boundary masking).
    """
    if src.nelement() == 0:
        return
    model['memory'].update_state(src, dst, t, msg)
    neighbor_loader.insert(src, dst)
# ==========
# ========== Define helper function...
# ==========
def train():
    """Train DyRep for one epoch over train_loader.

    Streams edges chronologically; whenever a batch crosses the next label
    timestamp, predicts labels for the nodes in that snapshot from the
    current memory/graph state, backpropagates, and keeps feeding edges
    into memory afterwards.

    Returns a dict with mean cross-entropy ("ce") and the mean `metric`
    score over all label timestamps seen this epoch.
    """
    model['memory'].train()
    model['gnn'].train()
    model['node_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        # check if this batch moves to the next day
        if query_t > label_t:
            # find the node labels from the past day
            label_tuple = dataset.get_node_label(query_t)
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()  # advance to the next label timestamp
            label_srcs = label_srcs.to(device)

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            """
            modified for node property prediction
            1. sample neighbors from the neighbor loader for all nodes to be predicted
            2. extract memory from the sampled neighbors and the nodes
            3. run gnn with the extracted memory embeddings and the corresponding time and message
            """
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            # assoc maps global node ids -> positions in the sampled subgraph
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = model['memory'](n_id_neighbors)
            z = model['gnn'](
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]  # keep only embeddings of the queried nodes

            # loss and metric computation
            pred = model['node_pred'](z)
            loss = criterion(pred, labels.to(device))
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()

            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[metric]
            total_score += score
            num_label_ts += 1

            loss.backward()
            optimizer.step()
            total_loss += float(loss.detach())

        # Update memory and neighbor loader with ground-truth state.
        process_edges(src, dst, t, msg)
        model['memory'].detach()  # stop gradients from flowing across batches

    # NOTE(review): assumes at least one label timestamp was crossed this
    # epoch, otherwise this divides by zero — confirm for small datasets.
    metric_dict = {
        "ce": total_loss / num_label_ts,
    }
    metric_dict[metric] = total_score / num_label_ts
    return metric_dict
@torch.no_grad()
def test(loader):
    """Evaluate DyRep on one split without gradient tracking.

    Memory and the neighbor loader are NOT reset here, so this must be
    called in chronological order after training within the same epoch.

    Returns a dict mapping `metric` to its mean over label timestamps.
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['node_pred'].eval()

    total_score = 0
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0

    for batch in tqdm(loader):
        batch = batch.to(device)
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        if query_t > label_t:
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:
                break  # no labels remain for this split
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()  # advance to the next label timestamp
            label_srcs = label_srcs.to(device)

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            """
            modified for node property prediction
            1. sample neighbors from the neighbor loader for all nodes to be predicted
            2. extract memory from the sampled neighbors and the nodes
            3. run gnn with the extracted memory embeddings and the corresponding time and message
            """
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            # assoc maps global node ids -> positions in the sampled subgraph
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = model['memory'](n_id_neighbors)
            z = model['gnn'](
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]  # keep only embeddings of the queried nodes

            # loss and metric computation
            pred = model['node_pred'](z)
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()

            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[metric]
            total_score += score
            num_label_ts += 1

        # keep updating memory/neighbors with the remaining batch edges
        process_edges(src, dst, t, msg)

    metric_dict = {}
    metric_dict[metric] = total_score / num_label_ts
    return metric_dict
# ==========
# ==========
# ==========
# Start...
start_overall = timeit.default_timer()

# ========== set parameters...
args, _ = get_args()
print("INFO: Arguments:", args)

# experiment configuration (mostly from CLI args)
DATA = "tgbn-reddit"
LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10  # neighbors sampled per node by the neighbor loader

# setting random seed
torch.manual_seed(SEED)
set_random_seed(SEED)

MODEL_NAME = 'DyRep'
# DyRep message construction: use destination embedding, not source
USE_SRC_EMB_IN_MSG = False
USE_DST_EMB_IN_MSG = True

evaluator = Evaluator(name=DATA)

# ==========
# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGNodePropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

# chronological splits defined by the dataset masks
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

num_classes = dataset.num_classes

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

# neighborhood sampler
neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

# define the model end-to-end
memory = DyRepMemory(
    data.num_nodes,
    data.msg.size(-1),
    MEM_DIM,
    TIME_DIM,
    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
    aggregator_module=LastAggregator(),
    memory_updater_type='rnn',
    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,
    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG
).to(device)

gnn = GraphAttentionEmbedding(
    in_channels=MEM_DIM,
    out_channels=EMB_DIM,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

node_pred = NodePredictor(in_dim=EMB_DIM, out_dim=num_classes).to(device)

model = {'memory': memory,
         'gnn': gnn,
         'node_pred': node_pred}

optimizer = torch.optim.Adam(
    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['node_pred'].parameters()),
    lr=LR,
)
criterion = torch.nn.CrossEntropyLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
# ---- epoch loop: train, validate, test once per epoch ----
train_curve = []
val_curve = []
test_curve = []
max_val_score = 0  #find the best test score based on validation score
best_test_idx = 0
for epoch in range(1, NUM_EPOCH + 1):
    start_time = timeit.default_timer()
    train_dict = train()
    print("------------------------------------")
    print(f"training Epoch: {epoch:02d}")
    print(train_dict)
    train_curve.append(train_dict[metric])
    print("Training takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    start_time = timeit.default_timer()
    val_dict = test(val_loader)
    print(val_dict)
    val_curve.append(val_dict[metric])
    # track the epoch with the best validation score (model selection)
    if (val_dict[metric] > max_val_score):
        max_val_score = val_dict[metric]
        best_test_idx = epoch - 1
    print("Validation takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    start_time = timeit.default_timer()
    test_dict = test(test_loader)
    print(test_dict)
    test_curve.append(test_dict[metric])
    print("Test takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    print("------------------------------------")
    # rewind the dataset's label cursor so the next epoch starts at the first label
    dataset.reset_label_time()

# report the test score from the best-validation epoch
max_test_score = test_curve[best_test_idx]
print("------------------------------------")
print("------------------------------------")
print ("best val score: ", max_val_score)
print ("best validation epoch : ", best_test_idx + 1)
print ("best test score: ", max_test_score)
================================================
FILE: examples/nodeproppred/tgbn-reddit/moving_average.py
================================================
"""
implement persistant forecast as baseline for the node prop pred task
simply predict last seen label for the node
"""
import timeit
import numpy as np
from torch_geometric.loader import TemporalDataLoader
# local imports
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from modules.heuristics import MovingAverage
from tgb.nodeproppred.evaluate import Evaluator
window = 7
device = "cpu"
name = "tgbn-reddit"
dataset = PyGNodePropPredDataset(name=name, root="datasets")
num_classes = dataset.num_classes
data = dataset.get_TemporalData()
data = data.to(device)
eval_metric = dataset.eval_metric
forecaster = MovingAverage(num_classes, window=window)
evaluator = Evaluator(name=name)
# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
train_data, val_data, test_data = data.train_val_test_split(
val_ratio=0.15, test_ratio=0.15
)
batch_size = 200
train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
test_loader = TemporalDataLoader(test_data, batch_size=batch_size)
def test_n_upate(loader):
    """Evaluate the heuristic forecaster over one split, updating it online.

    For every label snapshot crossed by the stream, each node is queried
    *before* its new label is shown to the forecaster, so the prediction
    never sees the ground truth it is scored against.

    Returns a dict mapping `eval_metric` to its mean over label snapshots.
    """
    label_t = dataset.get_label_time()  # timestamp of the first label snapshot
    num_label_ts = 0
    total_score = 0

    for batch in loader:
        batch = batch.to(device)
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        query_t = batch.t[-1]

        if query_t <= label_t:
            continue  # still inside the current label window

        label_tuple = dataset.get_node_label(query_t)
        if label_tuple is None:
            break  # no labels remain in this split

        label_ts = label_tuple[0].numpy()
        label_srcs = label_tuple[1].numpy()
        labels = label_tuple[2].numpy()
        label_t = dataset.get_label_time()  # advance to the next snapshot time

        # predict first, then reveal the true label to the forecaster
        preds = []
        for node_id, label_vec in zip(label_srcs, labels):
            preds.append(forecaster.query_dict(node_id))
            forecaster.update_dict(node_id, label_vec)

        input_dict = {
            "y_true": labels,
            "y_pred": np.stack(preds, axis=0),
            "eval_metric": [eval_metric],
        }
        total_score += evaluator.eval(input_dict)[eval_metric]
        num_label_ts += 1

    return {eval_metric: total_score / num_label_ts}
"""
train, val and test for one epoch only
"""
start_time = timeit.default_timer()
metric_dict = test_n_upate(train_loader)
print(metric_dict)
print(
"Persistant forecast on Training takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
val_dict = test_n_upate(val_loader)
print(val_dict)
print(
"Persistant forecast on validation takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
test_dict = test_n_upate(test_loader)
print(test_dict)
print(
"Persistant forecast on Test takes--- %s seconds ---" % (timeit.default_timer() - start_time)
)
dataset.reset_label_time()
================================================
FILE: examples/nodeproppred/tgbn-reddit/persistant_forecast.py
================================================
"""
implement persistant forecast as baseline for the node prop pred task
simply predict last seen label for the node
"""
import timeit
import numpy as np
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import torch
# local imports
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from modules.heuristics import PersistantForecaster
from tgb.nodeproppred.evaluate import Evaluator
device = "cpu"
name = "tgbn-reddit"
dataset = PyGNodePropPredDataset(name=name, root="datasets")
num_classes = dataset.num_classes
data = dataset.get_TemporalData()
data = data.to(device)
all_nodes = torch.cat((data.src, data.dst), 0)
all_nodes = all_nodes.unique()
print (all_nodes.shape[0])
eval_metric = dataset.eval_metric
forecaster = PersistantForecaster(num_classes)
evaluator = Evaluator(name=name)
# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
train_data, val_data, test_data = data.train_val_test_split(
val_ratio=0.15, test_ratio=0.15
)
batch_size = 200
train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
test_loader = TemporalDataLoader(test_data, batch_size=batch_size)
"""
continue debug here
"""
def test_n_upate(loader):
    """Evaluate the persistent forecaster over one split, updating it online.

    For every label snapshot crossed by the stream, each node is queried
    *before* its new label is shown to the forecaster, so the prediction
    never sees the ground truth it is scored against.

    Returns a dict mapping `eval_metric` to its mean over label snapshots.
    """
    label_t = dataset.get_label_time()  # timestamp of the first label snapshot
    num_label_ts = 0
    total_score = 0

    for batch in tqdm(loader):
        batch = batch.to(device)
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        query_t = batch.t[-1]

        if query_t <= label_t:
            continue  # still inside the current label window

        label_tuple = dataset.get_node_label(query_t)
        if label_tuple is None:
            break  # no labels remain in this split

        label_ts = label_tuple[0].numpy()
        label_srcs = label_tuple[1].numpy()
        labels = label_tuple[2].numpy()
        label_t = dataset.get_label_time()  # advance to the next snapshot time

        # predict first, then reveal the true label to the forecaster
        preds = []
        for node_id, label_vec in zip(label_srcs, labels):
            preds.append(forecaster.query_dict(node_id))
            forecaster.update_dict(node_id, label_vec)

        input_dict = {
            "y_true": labels,
            "y_pred": np.stack(preds, axis=0),
            "eval_metric": [eval_metric],
        }
        total_score += evaluator.eval(input_dict)[eval_metric]
        num_label_ts += 1

    return {eval_metric: total_score / num_label_ts}
"""
train, val and test for one epoch only
"""
start_time = timeit.default_timer()
metric_dict = test_n_upate(train_loader)
print(metric_dict)
print(
"Persistant forecast on Training takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
val_dict = test_n_upate(val_loader)
print(val_dict)
print(
"Persistant forecast on validation takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
test_dict = test_n_upate(test_loader)
print(test_dict)
print(
"Persistant forecast on Test takes--- %s seconds ---" % (timeit.default_timer() - start_time)
)
dataset.reset_label_time()
================================================
FILE: examples/nodeproppred/tgbn-reddit/tgn.py
================================================
import timeit
import argparse
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import (
IdentityMessage,
LastAggregator,
LastNeighborLoader,
)
from modules.decoder import NodePredictor
from modules.emb_module import GraphAttentionEmbedding
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from tgb.nodeproppred.evaluate import Evaluator
from tgb.utils.utils import set_random_seed
from tgb.utils.stats import plot_curve
# ---- command-line arguments ----
parser = argparse.ArgumentParser(description='parsing command line arguments as hyperparameters')
parser.add_argument('-s', '--seed', type=int, default=1,
                    help='random seed to use')
# Parse exactly once; a second, discarded parse_args() call previously ran
# here to no effect.
args = parser.parse_args()

# setting random seed
seed = int(args.seed)  # typically 1,2,3,4,5 across runs
print("setting random seed to be", seed)
torch.manual_seed(seed)
set_random_seed(seed)
# hyperparameters
lr = 0.0001
epochs = 50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
name = "tgbn-reddit"
dataset = PyGNodePropPredDataset(name=name, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
eval_metric = dataset.eval_metric
num_classes = dataset.num_classes
data = dataset.get_TemporalData()
data = data.to(device)

evaluator = Evaluator(name=name)

# chronological splits defined by the dataset masks
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

batch_size = 200
train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
test_loader = TemporalDataLoader(test_data, batch_size=batch_size)

# keeps the 10 most recent neighbors per node
neighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)

memory_dim = time_dim = embedding_dim = 100

# TGN memory + attention embedding + MLP node predictor
memory = TGNMemory(
    data.num_nodes,
    data.msg.size(-1),
    memory_dim,
    time_dim,
    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

gnn = (
    GraphAttentionEmbedding(
        in_channels=memory_dim,
        out_channels=embedding_dim,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    )
    .to(device)
    .float()
)

node_pred = NodePredictor(in_dim=embedding_dim, out_dim=num_classes).to(device)

optimizer = torch.optim.Adam(
    set(memory.parameters()) | set(gnn.parameters()) | set(node_pred.parameters()),
    lr=lr,
)
criterion = torch.nn.CrossEntropyLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
def plot_curve(scores, out_name):
    """Plot *scores* as a line chart and save it to '<out_name>.pdf'.

    NOTE(review): this local definition shadows the `plot_curve` imported
    from tgb.utils.stats at the top of the file.
    """
    out_path = out_name + ".pdf"
    plt.plot(scores, color="#e34a33")
    plt.ylabel("score")
    plt.savefig(out_path)
    plt.close()
def process_edges(src, dst, t, msg):
    """Record a batch of edges: update TGN memory and the neighbor loader.

    A no-op when the batch slice is empty (e.g. after day-boundary masking).
    """
    if src.nelement() == 0:
        return
    memory.update_state(src, dst, t, msg)
    neighbor_loader.insert(src, dst)
def train():
    """Train TGN for one epoch over train_loader.

    Streams edges chronologically; whenever a batch crosses the next label
    timestamp, predicts labels for the nodes in that snapshot from the
    current memory/graph state, backpropagates, and keeps feeding edges
    into memory afterwards.

    Returns a dict with mean cross-entropy ("ce") and the mean
    `eval_metric` score over all label timestamps seen this epoch.
    """
    memory.train()
    gnn.train()
    node_pred.train()

    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    label_t = dataset.get_label_time()  # check when does the first label start
    total_score = 0
    num_label_ts = 0

    for batch in tqdm(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        # check if this batch moves to the next day
        if query_t > label_t:
            # find the node labels from the past day
            label_tuple = dataset.get_node_label(query_t)
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()  # advance to the next label timestamp
            label_srcs = label_srcs.to(device)

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            """
            modified for node property prediction
            1. sample neighbors from the neighbor loader for all nodes to be predicted
            2. extract memory from the sampled neighbors and the nodes
            3. run gnn with the extracted memory embeddings and the corresponding time and message
            """
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            # assoc maps global node ids -> positions in the sampled subgraph
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = memory(n_id_neighbors)
            z = gnn(
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]  # keep only embeddings of the queried nodes

            # loss and metric computation
            pred = node_pred(z)
            loss = criterion(pred, labels.to(device))
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()

            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[eval_metric]
            total_score += score
            num_label_ts += 1

            loss.backward()
            optimizer.step()
            total_loss += float(loss.detach())

        # Update memory and neighbor loader with ground-truth state.
        process_edges(src, dst, t, msg)
        memory.detach()  # stop gradients from flowing across batches

    # NOTE(review): assumes at least one label timestamp was crossed this
    # epoch, otherwise this divides by zero — confirm for small datasets.
    metric_dict = {
        "ce": total_loss / num_label_ts,
    }
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict
@torch.no_grad()
def test(loader):
    """Evaluate TGN on one split without gradient tracking.

    Memory and the neighbor loader are NOT reset here, so this must be
    called in chronological order after training within the same epoch.

    Returns a dict mapping `eval_metric` to its mean over label timestamps.
    """
    memory.eval()
    gnn.eval()
    node_pred.eval()

    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0

    for batch in tqdm(loader):
        batch = batch.to(device)
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        if query_t > label_t:
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:
                break  # no labels remain for this split
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()  # advance to the next label timestamp
            label_srcs = label_srcs.to(device)

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            """
            modified for node property prediction
            1. sample neighbors from the neighbor loader for all nodes to be predicted
            2. extract memory from the sampled neighbors and the nodes
            3. run gnn with the extracted memory embeddings and the corresponding time and message
            """
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            # assoc maps global node ids -> positions in the sampled subgraph
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = memory(n_id_neighbors)
            z = gnn(
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]  # keep only embeddings of the queried nodes

            # loss and metric computation
            pred = node_pred(z)
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()

            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[eval_metric]
            total_score += score
            num_label_ts += 1

        # keep updating memory/neighbors with the remaining batch edges
        process_edges(src, dst, t, msg)

    metric_dict = {}
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict
# ---- epoch loop: train, validate, test once per epoch ----
train_curve = []
val_curve = []
test_curve = []
max_val_score = 0  #find the best test score based on validation score
best_test_idx = 0
for epoch in range(1, epochs + 1):
    start_time = timeit.default_timer()
    train_dict = train()
    print("------------------------------------")
    print(f"training Epoch: {epoch:02d}")
    print(train_dict)
    train_curve.append(train_dict[eval_metric])
    print("Training takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    start_time = timeit.default_timer()
    val_dict = test(val_loader)
    print(val_dict)
    val_curve.append(val_dict[eval_metric])
    # track the epoch with the best validation score (model selection)
    if (val_dict[eval_metric] > max_val_score):
        max_val_score = val_dict[eval_metric]
        best_test_idx = epoch - 1
    print("Validation takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    start_time = timeit.default_timer()
    test_dict = test(test_loader)
    print(test_dict)
    test_curve.append(test_dict[eval_metric])
    print("Test takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    print("------------------------------------")
    # rewind the dataset's label cursor so the next epoch starts at the first label
    dataset.reset_label_time()

# code for plotting
plot_curve(train_curve, "train_curve")
plot_curve(val_curve, "val_curve")
plot_curve(test_curve, "test_curve")

# report the test score from the best-validation epoch
max_test_score = test_curve[best_test_idx]
print("------------------------------------")
print("------------------------------------")
print ("best val score: ", max_val_score)
print ("best validation epoch : ", best_test_idx + 1)
print ("best test score: ", max_test_score)
================================================
FILE: examples/nodeproppred/tgbn-token/dyrep.py
================================================
"""
DyRep
This has been implemented with intuitions from the following sources:
- https://github.com/twitter-research/tgn
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
Spec.:
- Memory Updater: RNN
- Embedding Module: ID
- Message Function: ATTN
"""
import timeit
from tqdm import tqdm
import torch
from torch_geometric.loader import TemporalDataLoader
import numpy as np
# internal imports
from tgb.utils.utils import get_args, set_random_seed
from tgb.nodeproppred.evaluate import Evaluator
from modules.decoder import NodePredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import DyRepMemory
from modules.early_stopping import EarlyStopMonitor
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
def process_edges(src, dst, t, msg):
    """Record a batch of edges: update DyRep memory and the neighbor loader.

    A no-op when the batch slice is empty (e.g. after day-boundary masking).
    """
    if src.nelement() == 0:
        return
    model['memory'].update_state(src, dst, t, msg)
    neighbor_loader.insert(src, dst)
# ==========
# ========== Define helper function...
# ==========
def train():
    """Train DyRep for one epoch over train_loader.

    Streams edges chronologically; whenever a batch crosses the next label
    timestamp, predicts labels for the nodes in that snapshot from the
    current memory/graph state, backpropagates, and keeps feeding edges
    into memory afterwards.

    Returns a dict with mean cross-entropy ("ce") and the mean `metric`
    score over all label timestamps seen this epoch.
    """
    model['memory'].train()
    model['gnn'].train()
    model['node_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        # check if this batch moves to the next day
        if query_t > label_t:
            # find the node labels from the past day
            label_tuple = dataset.get_node_label(query_t)
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()  # advance to the next label timestamp
            label_srcs = label_srcs.to(device)

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            """
            modified for node property prediction
            1. sample neighbors from the neighbor loader for all nodes to be predicted
            2. extract memory from the sampled neighbors and the nodes
            3. run gnn with the extracted memory embeddings and the corresponding time and message
            """
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            # assoc maps global node ids -> positions in the sampled subgraph
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = model['memory'](n_id_neighbors)
            z = model['gnn'](
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]  # keep only embeddings of the queried nodes

            # loss and metric computation
            pred = model['node_pred'](z)
            loss = criterion(pred, labels.to(device))
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()

            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[metric]
            total_score += score
            num_label_ts += 1

            loss.backward()
            optimizer.step()
            total_loss += float(loss.detach())

        # Update memory and neighbor loader with ground-truth state.
        process_edges(src, dst, t, msg)
        model['memory'].detach()  # stop gradients from flowing across batches

    # NOTE(review): assumes at least one label timestamp was crossed this
    # epoch, otherwise this divides by zero — confirm for small datasets.
    metric_dict = {
        "ce": total_loss / num_label_ts,
    }
    metric_dict[metric] = total_score / num_label_ts
    return metric_dict
@torch.no_grad()
def test(loader):
    """Evaluate DyRep on one split without gradient tracking.

    Memory and the neighbor loader are NOT reset here, so this must be
    called in chronological order after training within the same epoch.

    Returns a dict mapping `metric` to its mean over label timestamps.
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['node_pred'].eval()

    total_score = 0
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0

    for batch in tqdm(loader):
        batch = batch.to(device)
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        if query_t > label_t:
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:
                break  # no labels remain for this split
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()  # advance to the next label timestamp
            label_srcs = label_srcs.to(device)

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            """
            modified for node property prediction
            1. sample neighbors from the neighbor loader for all nodes to be predicted
            2. extract memory from the sampled neighbors and the nodes
            3. run gnn with the extracted memory embeddings and the corresponding time and message
            """
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            # assoc maps global node ids -> positions in the sampled subgraph
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = model['memory'](n_id_neighbors)
            z = model['gnn'](
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]  # keep only embeddings of the queried nodes

            # loss and metric computation
            pred = model['node_pred'](z)
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()

            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[metric]
            total_score += score
            num_label_ts += 1

        # keep updating memory/neighbors with the remaining batch edges
        process_edges(src, dst, t, msg)

    metric_dict = {}
    metric_dict[metric] = total_score / num_label_ts
    return metric_dict
# ==========
# ==========
# ==========
# Start...
start_overall = timeit.default_timer()

# ========== set parameters...
args, _ = get_args()
print("INFO: Arguments:", args)

# experiment configuration (mostly from CLI args)
DATA = "tgbn-token"
LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10  # neighbors sampled per node by the neighbor loader

# setting random seed
torch.manual_seed(SEED)
set_random_seed(SEED)

MODEL_NAME = 'DyRep'
# DyRep message construction: use destination embedding, not source
USE_SRC_EMB_IN_MSG = False
USE_DST_EMB_IN_MSG = True

evaluator = Evaluator(name=DATA)

# ==========
# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGNodePropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

# chronological splits defined by the dataset masks
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

num_classes = dataset.num_classes

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

# neighborhood sampler
neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

# define the model end-to-end
memory = DyRepMemory(
    data.num_nodes,
    data.msg.size(-1),
    MEM_DIM,
    TIME_DIM,
    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
    aggregator_module=LastAggregator(),
    memory_updater_type='rnn',
    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,
    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG
).to(device)

gnn = GraphAttentionEmbedding(
    in_channels=MEM_DIM,
    out_channels=EMB_DIM,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

node_pred = NodePredictor(in_dim=EMB_DIM, out_dim=num_classes).to(device)

model = {'memory': memory,
         'gnn': gnn,
         'node_pred': node_pred}

optimizer = torch.optim.Adam(
    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['node_pred'].parameters()),
    lr=LR,
)
criterion = torch.nn.CrossEntropyLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
# ---- epoch loop: train, validate, test once per epoch ----
train_curve = []
val_curve = []
test_curve = []
max_val_score = 0  #find the best test score based on validation score
best_test_idx = 0
for epoch in range(1, NUM_EPOCH + 1):
    start_time = timeit.default_timer()
    train_dict = train()
    print("------------------------------------")
    print(f"training Epoch: {epoch:02d}")
    print(train_dict)
    train_curve.append(train_dict[metric])
    print("Training takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    start_time = timeit.default_timer()
    val_dict = test(val_loader)
    print(val_dict)
    val_curve.append(val_dict[metric])
    # track the epoch with the best validation score (model selection)
    if (val_dict[metric] > max_val_score):
        max_val_score = val_dict[metric]
        best_test_idx = epoch - 1
    print("Validation takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    start_time = timeit.default_timer()
    test_dict = test(test_loader)
    print(test_dict)
    test_curve.append(test_dict[metric])
    print("Test takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    print("------------------------------------")
    # rewind the dataset's label cursor so the next epoch starts at the first label
    dataset.reset_label_time()

# report the test score from the best-validation epoch
max_test_score = test_curve[best_test_idx]
print("------------------------------------")
print("------------------------------------")
print ("best val score: ", max_val_score)
print ("best validation epoch : ", best_test_idx + 1)
print ("best test score: ", max_test_score)
================================================
FILE: examples/nodeproppred/tgbn-token/moving_average.py
================================================
"""
implement persistant forecast as baseline for the node prop pred task
simply predict last seen label for the node
"""
import timeit
import numpy as np
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
# local imports
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from modules.heuristics import MovingAverage
from tgb.nodeproppred.evaluate import Evaluator
# Number of most recent label vectors averaged per node.
window = 7
device = "cpu"
name = "tgbn-token"
dataset = PyGNodePropPredDataset(name=name, root="datasets")
num_classes = dataset.num_classes
data = dataset.get_TemporalData()
data = data.to(device)
eval_metric = dataset.eval_metric
forecaster = MovingAverage(num_classes, window=window)
evaluator = Evaluator(name=name)
# Ensure to only sample actual destination nodes as negatives.
# NOTE(review): min/max dst indices are unused in this baseline — leftover from link-pred scripts.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
# Chronological 70/15/15 split; loaders iterate in temporal order.
train_data, val_data, test_data = data.train_val_test_split(
    val_ratio=0.15, test_ratio=0.15
)
batch_size = 200
train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
test_loader = TemporalDataLoader(test_data, batch_size=batch_size)
def test_n_upate(loader):
    """Evaluate the heuristic forecaster on `loader` while updating it online.

    Whenever a batch's last timestamp crosses the next label timestamp, the
    labels for the elapsed period are fetched; each labelled node is first
    queried (prediction) and then updated with its ground-truth label, so
    every prediction only uses information from the past.

    Parameters:
        loader: TemporalDataLoader yielding batches in chronological order.
    Returns:
        dict mapping `eval_metric` to the score averaged over label timestamps.
    """
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0

    for batch in tqdm(loader):
        batch = batch.to(device)
        query_t = batch.t[-1]
        if query_t > label_t:
            # This batch crosses a label boundary: fetch the pending labels.
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:
                break
            # (label timestamps, i.e. label_tuple[0], are not needed here)
            label_srcs = label_tuple[1].numpy()
            labels = label_tuple[2].numpy()
            label_t = dataset.get_label_time()

            preds = []
            for i in range(label_srcs.shape[0]):
                node_id = label_srcs[i]
                # Predict first, then update with the ground truth (online eval).
                preds.append(forecaster.query_dict(node_id))
                forecaster.update_dict(node_id, labels[i])

            np_pred = np.stack(preds, axis=0)
            input_dict = {
                "y_true": labels,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            total_score += result_dict[eval_metric]
            num_label_ts += 1

    # Average of the per-timestamp scores.
    metric_dict = {}
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict
"""
train, val and test for one epoch only
"""
start_time = timeit.default_timer()
metric_dict = test_n_upate(train_loader)
print(metric_dict)
print(
"Persistant forecast on Training takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
val_dict = test_n_upate(val_loader)
print(val_dict)
print(
"Persistant forecast on validation takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
test_dict = test_n_upate(test_loader)
print(test_dict)
print(
"Persistant forecast on Test takes--- %s seconds ---" % (timeit.default_timer() - start_time)
)
dataset.reset_label_time()
================================================
FILE: examples/nodeproppred/tgbn-token/persistant_forecast.py
================================================
"""
implement persistant forecast as baseline for the node prop pred task
simply predict last seen label for the node
"""
import timeit
import numpy as np
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
import torch
# local imports
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from modules.heuristics import PersistantForecaster
from tgb.nodeproppred.evaluate import Evaluator
device = "cpu"
name = "tgbn-token"
dataset = PyGNodePropPredDataset(name=name, root="datasets")
num_classes = dataset.num_classes
data = dataset.get_TemporalData()
data = data.to(device)
all_nodes = torch.cat((data.src, data.dst), 0)
all_nodes = all_nodes.unique()
print (all_nodes.shape[0])
eval_metric = dataset.eval_metric
forecaster = PersistantForecaster(num_classes)
evaluator = Evaluator(name=name)
# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
train_data, val_data, test_data = data.train_val_test_split(
val_ratio=0.15, test_ratio=0.15
)
batch_size = 200
train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
test_loader = TemporalDataLoader(test_data, batch_size=batch_size)
"""
continue debug here
"""
def test_n_upate(loader):
    """Evaluate the persistent forecaster on `loader` while updating it online.

    Whenever a batch's last timestamp crosses the next label timestamp, the
    labels for the elapsed period are fetched; each labelled node is first
    queried (prediction) and then updated with its ground-truth label, so
    every prediction only uses information from the past.

    Parameters:
        loader: TemporalDataLoader yielding batches in chronological order.
    Returns:
        dict mapping `eval_metric` to the score averaged over label timestamps.
    """
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0

    for batch in tqdm(loader):
        batch = batch.to(device)
        query_t = batch.t[-1]
        if query_t > label_t:
            # This batch crosses a label boundary: fetch the pending labels.
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:
                break
            # (label timestamps, i.e. label_tuple[0], are not needed here)
            label_srcs = label_tuple[1].numpy()
            labels = label_tuple[2].numpy()
            label_t = dataset.get_label_time()

            preds = []
            for i in range(label_srcs.shape[0]):
                node_id = label_srcs[i]
                # Predict first, then update with the ground truth (online eval).
                preds.append(forecaster.query_dict(node_id))
                forecaster.update_dict(node_id, labels[i])

            np_pred = np.stack(preds, axis=0)
            input_dict = {
                "y_true": labels,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            total_score += result_dict[eval_metric]
            num_label_ts += 1

    # Average of the per-timestamp scores.
    metric_dict = {}
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict
"""
train, val and test for one epoch only
"""
start_time = timeit.default_timer()
metric_dict = test_n_upate(train_loader)
print(metric_dict)
print(
"Persistant forecast on Training takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
val_dict = test_n_upate(val_loader)
print(val_dict)
print(
"Persistant forecast on validation takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
test_dict = test_n_upate(test_loader)
print(test_dict)
print(
"Persistant forecast on Test takes--- %s seconds ---" % (timeit.default_timer() - start_time)
)
dataset.reset_label_time()
================================================
FILE: examples/nodeproppred/tgbn-token/tgn.py
================================================
import timeit
import argparse
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import (
IdentityMessage,
LastAggregator,
LastNeighborLoader,
)
from modules.decoder import NodePredictor
from modules.emb_module import GraphAttentionEmbedding
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from tgb.nodeproppred.evaluate import Evaluator
from tgb.utils.utils import set_random_seed
from tgb.utils.stats import plot_curve
# Command-line arguments: only the random seed is configurable.
parser = argparse.ArgumentParser(description='parsing command line arguments as hyperparameters')
parser.add_argument('-s', '--seed', type=int, default=1,
                    help='random seed to use')
# Parse exactly once; the original called parse_args() twice and discarded
# the first result.
args = parser.parse_args()

# setting random seed
seed = int(args.seed)  # 1,2,3,4,5
print("setting random seed to be", seed)
torch.manual_seed(seed)
set_random_seed(seed)

# hyperparameters
lr = 0.0001
epochs = 50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset and pre-computed chronological split masks.
name = "tgbn-token"
dataset = PyGNodePropPredDataset(name=name, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
eval_metric = dataset.eval_metric
num_classes = dataset.num_classes
data = dataset.get_TemporalData()
data = data.to(device)
evaluator = Evaluator(name=name)

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

batch_size = 200
train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
test_loader = TemporalDataLoader(test_data, batch_size=batch_size)
# Keeps the 10 most recent neighbors per node for attention-based embedding.
neighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)

memory_dim = time_dim = embedding_dim = 100

memory = TGNMemory(
    data.num_nodes,
    data.msg.size(-1),
    memory_dim,
    time_dim,
    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

gnn = (
    GraphAttentionEmbedding(
        in_channels=memory_dim,
        out_channels=embedding_dim,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    )
    .to(device)
    .float()
)

node_pred = NodePredictor(in_dim=embedding_dim, out_dim=num_classes).to(device)

# One optimizer over the union of all module parameters.
optimizer = torch.optim.Adam(
    set(memory.parameters()) | set(gnn.parameters()) | set(node_pred.parameters()),
    lr=lr,
)
criterion = torch.nn.CrossEntropyLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
def plot_curve(scores, out_name):
    """Plot a score curve and save it as `<out_name>.pdf`.

    NOTE(review): this local definition shadows `plot_curve` imported from
    tgb.utils.stats at the top of the file; the local version is the one used.
    """
    plt.plot(scores, color="#e34a33")
    plt.ylabel("score")
    plt.savefig(out_name + ".pdf")
    plt.close()
def process_edges(src, dst, t, msg):
    """Feed a slice of edges into the TGN memory and the neighbor loader.

    Empty slices (e.g. from boundary masks) are a no-op.
    """
    if src.nelement() == 0:
        return
    memory.update_state(src, dst, t, msg)
    neighbor_loader.insert(src, dst)
def train():
    """Run one training epoch over train_loader.

    Returns a dict with the mean cross-entropy loss ("ce") and the mean
    eval-metric score, both averaged over label timestamps.
    """
    memory.train()
    gnn.train()
    node_pred.train()

    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    label_t = dataset.get_label_time()  # check when does the first label start
    total_score = 0
    num_label_ts = 0

    for batch in tqdm(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        # check if this batch moves to the next day
        if query_t > label_t:
            # find the node labels from the past day
            label_tuple = dataset.get_node_label(query_t)
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            """
            modified for node property prediction
            1. sample neighbors from the neighbor loader for all nodes to be predicted
            2. extract memory from the sampled neighbors and the nodes
            3. run gnn with the extracted memory embeddings and the corresponding time and message
            """
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            # assoc maps global node ids -> row positions in z below.
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = memory(n_id_neighbors)
            z = gnn(
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            # Keep only the embeddings of the labelled nodes.
            z = z[assoc[n_id]]

            # loss and metric computation
            pred = node_pred(z)
            loss = criterion(pred, labels.to(device))
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()
            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[eval_metric]
            total_score += score
            num_label_ts += 1

            # Backprop only on batches that produced a label prediction.
            loss.backward()
            optimizer.step()
            total_loss += float(loss.detach())

        # Update memory and neighbor loader with ground-truth state.
        process_edges(src, dst, t, msg)
        # Detach memory so gradients do not flow across batches.
        memory.detach()

    metric_dict = {
        "ce": total_loss / num_label_ts,
    }
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict
@torch.no_grad()
def test(loader):
    """Evaluate on `loader` (streaming): the memory and neighbor loader are
    still updated with ground-truth edges while predictions are scored.

    Returns a dict mapping `eval_metric` to the mean score over label timestamps.
    """
    memory.eval()
    gnn.eval()
    node_pred.eval()

    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0

    for batch in tqdm(loader):
        batch = batch.to(device)
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        if query_t > label_t:
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:
                # No more labels in this split.
                break
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            """
            modified for node property prediction
            1. sample neighbors from the neighbor loader for all nodes to be predicted
            2. extract memory from the sampled neighbors and the nodes
            3. run gnn with the extracted memory embeddings and the corresponding time and message
            """
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            # assoc maps global node ids -> row positions in z below.
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = memory(n_id_neighbors)
            z = gnn(
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            # Keep only the embeddings of the labelled nodes.
            z = z[assoc[n_id]]

            # loss and metric computation
            pred = node_pred(z)
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()
            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[eval_metric]
            total_score += score
            num_label_ts += 1

        process_edges(src, dst, t, msg)

    metric_dict = {}
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict
train_curve = []
val_curve = []
test_curve = []
max_val_score = 0  #find the best test score based on validation score
best_test_idx = 0
for epoch in range(1, epochs + 1):
    start_time = timeit.default_timer()
    train_dict = train()
    print("------------------------------------")
    print(f"training Epoch: {epoch:02d}")
    print(train_dict)
    train_curve.append(train_dict[eval_metric])
    print("Training takes--- %s seconds ---" % (timeit.default_timer() - start_time))

    start_time = timeit.default_timer()
    val_dict = test(val_loader)
    print(val_dict)
    val_curve.append(val_dict[eval_metric])
    # Model selection: remember the epoch with the best validation score.
    if (val_dict[eval_metric] > max_val_score):
        max_val_score = val_dict[eval_metric]
        best_test_idx = epoch - 1
    print("Validation takes--- %s seconds ---" % (timeit.default_timer() - start_time))

    start_time = timeit.default_timer()
    test_dict = test(test_loader)
    print(test_dict)
    test_curve.append(test_dict[eval_metric])
    print("Test takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    print("------------------------------------")
    # Rewind the dataset's label pointer so next epoch replays labels from the start.
    dataset.reset_label_time()

# code for plotting
plot_curve(train_curve, "train_curve")
plot_curve(val_curve, "val_curve")
plot_curve(test_curve, "test_curve")

# Report the test score at the best-validation epoch.
max_test_score = test_curve[best_test_idx]
print("------------------------------------")
print("------------------------------------")
print ("best val score: ", max_val_score)
print ("best validation epoch : ", best_test_idx + 1)
print ("best test score: ", max_test_score)
================================================
FILE: examples/nodeproppred/tgbn-trade/count_new_nodes.py
================================================
import timeit
import numpy as np
from tqdm import tqdm
import math
import os
import os.path as osp
from pathlib import Path
import sys
import argparse
# internal imports
from modules.nodebank import NodeBank
from tgb.linkproppred.evaluate import Evaluator
from modules.edgebank_predictor import EdgeBankPredictor
from tgb.utils.utils import set_random_seed
from tgb.nodeproppred.dataset import NodePropPredDataset
# ==================
# ==================
# ==================
def count_nodes(data, test_mask, nodebank):
    r"""Count new vs. total distinct nodes in a data split.

    A node is "new" if it was never seen by `nodebank` (i.e. it did not
    appear in the training edges). Both sources and destinations are counted.

    Parameters:
        data: dataset dict with 'sources', 'destinations', 'timestamps' arrays
        test_mask: boolean mask selecting the split's edges
        nodebank: NodeBank built from the training edges; `query_node(n)`
            returns True if `n` was seen during training
    Returns:
        (num_new_nodes, num_total_nodes): distinct counts over the split
    """
    # Sets replace the original dict-with-dummy-value bookkeeping.
    new_nodes = set()
    all_nodes = set()
    split_src = data['sources'][test_mask]
    split_dst = data['destinations'][test_mask]
    num_batches = math.ceil(len(split_src) / BATCH_SIZE)
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(start_idx + BATCH_SIZE, len(split_src))
        pos_src = split_src[start_idx:end_idx]
        pos_dst = split_dst[start_idx:end_idx]
        for node in pos_src:
            all_nodes.add(node)
            if not nodebank.query_node(node):
                new_nodes.add(node)
        for node in pos_dst:
            all_nodes.add(node)
            if not nodebank.query_node(node):
                new_nodes.add(node)
    return len(new_nodes), len(all_nodes)
def get_args():
    """Parse command-line arguments for the EdgeBank node-count script.

    Returns:
        (args, argv): the parsed argparse Namespace and the raw sys.argv list.
    """
    parser = argparse.ArgumentParser('*** TGB: EdgeBank ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_mode', type=str, help='Memory mode', default='unlimited', choices=['unlimited', 'fixed_time_window'])
    parser.add_argument('--time_window_ratio', type=float, help='Test window ratio', default=0.15)
    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse raises SystemExit on bad input; show the full help and exit
        # cleanly. (Was a bare `except:`, which also swallowed unrelated errors.)
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
# ==================
# ==================
# ==================

start_overall = timeit.default_timer()

# set hyperparameters
args, _ = get_args()
SEED = args.seed  # set the random seed for consistency
set_random_seed(SEED)
MEMORY_MODE = args.mem_mode  # `unlimited` or `fixed_time_window`
# NOTE(review): batch size is hard-coded here and overrides the --bs argument.
BATCH_SIZE = 10000
K_VALUE = args.k_value
TIME_WINDOW_RATIO = args.time_window_ratio
# NOTE(review): dataset name is hard-coded and overrides the --data argument.
DATA = "tgbn-token"  #"tgbl-wiki"
MODEL_NAME = 'EdgeBank'

# data loading with `numpy`
dataset = NodePropPredDataset(name=DATA, root="datasets", preprocess=True)
data = dataset.full_data
metric = dataset.eval_metric

# get masks
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

train_src = data['sources'][train_mask]
train_dst = data['destinations'][train_mask]

#data for memory in edgebank
hist_src = np.concatenate([data['sources'][train_mask]])
hist_dst = np.concatenate([data['destinations'][train_mask]])
hist_ts = np.concatenate([data['timestamps'][train_mask]])

# Set EdgeBank with memory updater
# NOTE(review): edgebank is constructed but never used in this script.
edgebank = EdgeBankPredictor(
    hist_src,
    hist_dst,
    hist_ts,
    memory_mode=MEMORY_MODE,
    time_window_ratio=TIME_WINDOW_RATIO)

print("==========================================================")
print(f"============*** {MODEL_NAME}: {MEMORY_MODE}: {DATA} ***==============")
print("==========================================================")

evaluator = Evaluator(name=DATA)
# NodeBank records which nodes appeared in the training edges.
nodebank = NodeBank(train_src, train_dst)

# Count and report nodes unseen during training in the validation split.
new_val_num, val_total = count_nodes(data, val_mask, nodebank)
print ()
print ("-------------------------------------------------------")
print ("there are ", new_val_num, " new nodes in the validation set")
print ("there are ", val_total, " total nodes in the validation set")
print (" the percentage of new nodes in the validation set is ", (new_val_num/val_total))

# Count and report nodes unseen during training in the test split.
new_test_num, test_total = count_nodes(data, test_mask, nodebank)
print ()
print ("-------------------------------------------------------")
print ("there are ", new_test_num, " new nodes in the test set")
print ("there are ", test_total, " total nodes in the test set")
print (" the percentage of new nodes in the test set is ", (new_test_num/test_total))
================================================
FILE: examples/nodeproppred/tgbn-trade/dyrep.py
================================================
"""
DyRep
This has been implemented with intuitions from the following sources:
- https://github.com/twitter-research/tgn
- https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
Spec.:
- Memory Updater: RNN
- Embedding Module: ID
- Message Function: ATTN
"""
import timeit
from tqdm import tqdm
import torch
from torch_geometric.loader import TemporalDataLoader
# internal imports
from tgb.utils.utils import get_args, set_random_seed
from tgb.nodeproppred.evaluate import Evaluator
from modules.decoder import NodePredictor
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import IdentityMessage
from modules.msg_agg import LastAggregator
from modules.neighbor_loader import LastNeighborLoader
from modules.memory_module import DyRepMemory
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
def process_edges(src, dst, t, msg):
    """Push a slice of edges into DyRep's memory module and the neighbor loader.

    Empty slices (e.g. from boundary masks) are a no-op.
    """
    if src.nelement() == 0:
        return
    model['memory'].update_state(src, dst, t, msg)
    neighbor_loader.insert(src, dst)
# ==========
# ========== Define helper function...
# ==========
def train():
    """Run one training epoch over train_loader.

    Returns a dict with the mean cross-entropy loss ("ce") and the mean
    eval-metric score, both averaged over label timestamps.
    """
    model['memory'].train()
    model['gnn'].train()
    model['node_pred'].train()

    model['memory'].reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        # check if this batch moves to the next day
        if query_t > label_t:
            # find the node labels from the past day
            label_tuple = dataset.get_node_label(query_t)
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            """
            modified for node property prediction
            1. sample neighbors from the neighbor loader for all nodes to be predicted
            2. extract memory from the sampled neighbors and the nodes
            3. run gnn with the extracted memory embeddings and the corresponding time and message
            """
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            # assoc maps global node ids -> row positions in z below.
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = model['memory'](n_id_neighbors)
            z = model['gnn'](
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            # Keep only the embeddings of the labelled nodes.
            z = z[assoc[n_id]]

            # loss and metric computation
            pred = model['node_pred'](z)
            loss = criterion(pred, labels.to(device))
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()
            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[metric]
            total_score += score
            num_label_ts += 1

            # Backprop only on batches that produced a label prediction.
            loss.backward()
            optimizer.step()
            total_loss += float(loss.detach())

        # Update memory and neighbor loader with ground-truth state.
        process_edges(src, dst, t, msg)
        # Detach memory so gradients do not flow across batches.
        model['memory'].detach()

    metric_dict = {
        "ce": total_loss / num_label_ts,
    }
    metric_dict[metric] = total_score / num_label_ts
    return metric_dict
@torch.no_grad()
def test(loader):
    """Evaluate on `loader` (streaming): the memory and neighbor loader are
    still updated with ground-truth edges while predictions are scored.

    Returns a dict mapping `metric` to the mean score over label timestamps.
    """
    model['memory'].eval()
    model['gnn'].eval()
    model['node_pred'].eval()

    total_score = 0
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0

    for batch in loader:
        batch = batch.to(device)
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        if query_t > label_t:
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:
                # No more labels in this split.
                break
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            """
            modified for node property prediction
            1. sample neighbors from the neighbor loader for all nodes to be predicted
            2. extract memory from the sampled neighbors and the nodes
            3. run gnn with the extracted memory embeddings and the corresponding time and message
            """
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            # assoc maps global node ids -> row positions in z below.
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = model['memory'](n_id_neighbors)
            z = model['gnn'](
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            # Keep only the embeddings of the labelled nodes.
            z = z[assoc[n_id]]

            # loss and metric computation
            pred = model['node_pred'](z)
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()
            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[metric]
            total_score += score
            num_label_ts += 1

        process_edges(src, dst, t, msg)

    metric_dict = {}
    metric_dict[metric] = total_score / num_label_ts
    return metric_dict
# ==========
# ==========
# ==========

# Start...
start_overall = timeit.default_timer()

# ========== set parameters...
args, _ = get_args()
print("INFO: Arguments:", args)

DATA = "tgbn-trade"
LR = args.lr
BATCH_SIZE = args.bs
K_VALUE = args.k_value
NUM_EPOCH = args.num_epoch
SEED = args.seed
MEM_DIM = args.mem_dim
TIME_DIM = args.time_dim
EMB_DIM = args.emb_dim
TOLERANCE = args.tolerance
PATIENCE = args.patience
NUM_RUNS = args.num_run
NUM_NEIGHBORS = 10

# setting random seed
torch.manual_seed(SEED)
set_random_seed(SEED)

MODEL_NAME = 'DyRep'
# DyRep message config: use destination embeddings (not source) in messages.
USE_SRC_EMB_IN_MSG = False
USE_DST_EMB_IN_MSG = True

evaluator = Evaluator(name=DATA)

# ==========
# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data loading
dataset = PyGNodePropPredDataset(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)
metric = dataset.eval_metric

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
num_classes = dataset.num_classes

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

# neighborhood sampler
neighbor_loader = LastNeighborLoader(data.num_nodes, size=NUM_NEIGHBORS, device=device)

# define the model end-to-end
memory = DyRepMemory(
    data.num_nodes,
    data.msg.size(-1),
    MEM_DIM,
    TIME_DIM,
    message_module=IdentityMessage(data.msg.size(-1), MEM_DIM, TIME_DIM),
    aggregator_module=LastAggregator(),
    memory_updater_type='rnn',
    use_src_emb_in_msg=USE_SRC_EMB_IN_MSG,
    use_dst_emb_in_msg=USE_DST_EMB_IN_MSG
).to(device)

gnn = GraphAttentionEmbedding(
    in_channels=MEM_DIM,
    out_channels=EMB_DIM,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

node_pred = NodePredictor(in_dim=EMB_DIM, out_dim=num_classes).to(device)

# Bundle the three trainable modules so train()/test() can access them.
model = {'memory': memory,
         'gnn': gnn,
         'node_pred': node_pred}

# One optimizer over the union of all module parameters.
optimizer = torch.optim.Adam(
    set(model['memory'].parameters()) | set(model['gnn'].parameters()) | set(model['node_pred'].parameters()),
    lr=LR,
)
criterion = torch.nn.CrossEntropyLoss()
# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

train_curve = []
val_curve = []
test_curve = []
max_val_score = 0  #find the best test score based on validation score
best_test_idx = 0
for epoch in range(1, NUM_EPOCH + 1):
    start_time = timeit.default_timer()
    train_dict = train()
    print("------------------------------------")
    print(f"training Epoch: {epoch:02d}")
    print(train_dict)
    train_curve.append(train_dict[metric])
    print("Training takes--- %s seconds ---" % (timeit.default_timer() - start_time))

    start_time = timeit.default_timer()
    val_dict = test(val_loader)
    print(val_dict)
    val_curve.append(val_dict[metric])
    # Model selection: remember the epoch with the best validation score.
    if (val_dict[metric] > max_val_score):
        max_val_score = val_dict[metric]
        best_test_idx = epoch - 1
    print("Validation takes--- %s seconds ---" % (timeit.default_timer() - start_time))

    start_time = timeit.default_timer()
    test_dict = test(test_loader)
    print(test_dict)
    test_curve.append(test_dict[metric])
    print("Test takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    print("------------------------------------")
    # Rewind the dataset's label pointer so next epoch replays labels from the start.
    dataset.reset_label_time()

# Report the test score at the best-validation epoch.
max_test_score = test_curve[best_test_idx]
print("------------------------------------")
print("------------------------------------")
print ("best val score: ", max_val_score)
print ("best validation epoch : ", best_test_idx + 1)
print ("best test score: ", max_test_score)
================================================
FILE: examples/nodeproppred/tgbn-trade/moving_average.py
================================================
"""
implement persistant forecast as baseline for the node prop pred task
simply predict last seen label for the node
"""
import timeit
import numpy as np
from torch_geometric.loader import TemporalDataLoader
# local imports
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from modules.heuristics import MovingAverage
from tgb.nodeproppred.evaluate import Evaluator
device = "cpu"
window = 7
name = "tgbn-trade"
dataset = PyGNodePropPredDataset(name=name, root="datasets")
num_classes = dataset.num_classes
data = dataset.get_TemporalData()
data = data.to(device)
eval_metric = dataset.eval_metric
forecaster = MovingAverage(num_classes, window=window)
evaluator = Evaluator(name=name)
# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
train_data, val_data, test_data = data.train_val_test_split(
val_ratio=0.15, test_ratio=0.15
)
batch_size = 200
train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
test_loader = TemporalDataLoader(test_data, batch_size=batch_size)
def test_n_upate(loader):
    """Evaluate the heuristic forecaster on `loader` while updating it online.

    Whenever a batch's last timestamp crosses the next label timestamp, the
    labels for the elapsed period are fetched; each labelled node is first
    queried (prediction) and then updated with its ground-truth label, so
    every prediction only uses information from the past.

    Parameters:
        loader: TemporalDataLoader yielding batches in chronological order.
    Returns:
        dict mapping `eval_metric` to the score averaged over label timestamps.
    """
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0

    for batch in loader:
        batch = batch.to(device)
        query_t = batch.t[-1]
        if query_t > label_t:
            # This batch crosses a label boundary: fetch the pending labels.
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:
                break
            # (label timestamps, i.e. label_tuple[0], are not needed here)
            label_srcs = label_tuple[1].numpy()
            labels = label_tuple[2].numpy()
            label_t = dataset.get_label_time()

            preds = []
            for i in range(label_srcs.shape[0]):
                node_id = label_srcs[i]
                # Predict first, then update with the ground truth (online eval).
                preds.append(forecaster.query_dict(node_id))
                forecaster.update_dict(node_id, labels[i])

            np_pred = np.stack(preds, axis=0)
            input_dict = {
                "y_true": labels,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            total_score += result_dict[eval_metric]
            num_label_ts += 1

    # Average of the per-timestamp scores.
    metric_dict = {}
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict
"""
train, val and test for one epoch only
"""
start_time = timeit.default_timer()
metric_dict = test_n_upate(train_loader)
print(metric_dict)
print(
"Persistant forecast on Training takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
val_dict = test_n_upate(val_loader)
print(val_dict)
print(
"Persistant forecast on validation takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
test_dict = test_n_upate(test_loader)
print(test_dict)
print(
"Persistant forecast on Test takes--- %s seconds ---" % (timeit.default_timer() - start_time)
)
dataset.reset_label_time()
================================================
FILE: examples/nodeproppred/tgbn-trade/persistant_forecast.py
================================================
"""
implement persistant forecast as baseline for the node prop pred task
simply predict last seen label for the node
"""
import timeit
import numpy as np
from torch_geometric.loader import TemporalDataLoader
# local imports
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from modules.heuristics import PersistantForecaster
from tgb.nodeproppred.evaluate import Evaluator
device = "cpu"
name = "tgbn-trade"
dataset = PyGNodePropPredDataset(name=name, root="datasets")
num_classes = dataset.num_classes
data = dataset.get_TemporalData()
data = data.to(device)
eval_metric = dataset.eval_metric
forecaster = PersistantForecaster(num_classes)
evaluator = Evaluator(name=name)
# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
train_data, val_data, test_data = data.train_val_test_split(
val_ratio=0.15, test_ratio=0.15
)
batch_size = 200
train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
test_loader = TemporalDataLoader(test_data, batch_size=batch_size)
def test_n_upate(loader):
    """Run the persistent-forecast baseline over one chronological split.

    Streams batches in time order; whenever a batch crosses the next label
    timestamp, queries the forecaster for each labeled node (it predicts the
    node's last observed label), scores the predictions, and only then reveals
    the ground-truth labels to the forecaster.

    Parameters:
        loader: a TemporalDataLoader over one chronological split.

    Returns:
        dict mapping ``eval_metric`` to the score averaged over all label
        timestamps encountered in this split.
    """
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0
    for batch in loader:
        batch = batch.to(device)
        query_t = batch.t[-1]
        if query_t > label_t:
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:
                # no labels remain in this split
                break
            # label_tuple = (timestamps, node ids, labels); timestamps unused here
            label_srcs = label_tuple[1].numpy()
            labels = label_tuple[2].numpy()
            label_t = dataset.get_label_time()
            preds = []
            for i in range(label_srcs.shape[0]):
                node_id = label_srcs[i]
                # predict first, then update with the revealed ground truth
                preds.append(forecaster.query_dict(node_id))
                forecaster.update_dict(node_id, labels[i])
            np_pred = np.stack(preds, axis=0)
            input_dict = {
                "y_true": labels,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            total_score += result_dict[eval_metric]
            num_label_ts += 1
    metric_dict = {eval_metric: total_score / num_label_ts}
    return metric_dict
"""
train, val and test for one epoch only
"""
start_time = timeit.default_timer()
metric_dict = test_n_upate(train_loader)
print(metric_dict)
print(
"Persistant forecast on Training takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
val_dict = test_n_upate(val_loader)
print(val_dict)
print(
"Persistant forecast on validation takes--- %s seconds ---"
% (timeit.default_timer() - start_time)
)
start_time = timeit.default_timer()
test_dict = test_n_upate(test_loader)
print(test_dict)
print(
"Persistant forecast on Test takes--- %s seconds ---" % (timeit.default_timer() - start_time)
)
dataset.reset_label_time()
================================================
FILE: examples/nodeproppred/tgbn-trade/tgn.py
================================================
import timeit
import argparse
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import (
IdentityMessage,
LastAggregator,
LastNeighborLoader,
)
from modules.decoder import NodePredictor
from modules.emb_module import GraphAttentionEmbedding
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from tgb.nodeproppred.evaluate import Evaluator
from tgb.utils.utils import set_random_seed
from tgb.utils.stats import plot_curve
# Command-line arguments.
parser = argparse.ArgumentParser(description='parsing command line arguments as hyperparameters')
parser.add_argument('-s', '--seed', type=int, default=1,
                    help='random seed to use')
# Parse exactly once (a redundant second parse_args() call, whose result was
# discarded, has been removed).
args = parser.parse_args()

# setting random seed
seed = int(args.seed)  # 1,2,3,4,5
torch.manual_seed(seed)
set_random_seed(seed)

# hyperparameters
lr = 0.0001
epochs = 50
# Script setup: load tgbn-trade, build the TGN memory + attention embedding +
# node decoder, and chronological train/val/test loaders.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
name = "tgbn-trade"
dataset = PyGNodePropPredDataset(name=name, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
eval_metric = dataset.eval_metric
num_classes = dataset.num_classes
data = dataset.get_TemporalData()
data = data.to(device)
evaluator = Evaluator(name=name)
# Chronological split as provided by the dataset masks.
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]
# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
batch_size = 200
train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
test_loader = TemporalDataLoader(test_data, batch_size=batch_size)
# Keep the 10 most recent neighbors per node for embedding computation.
neighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)
memory_dim = time_dim = embedding_dim = 100
memory = TGNMemory(
    data.num_nodes,
    data.msg.size(-1),
    memory_dim,
    time_dim,
    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)
gnn = (
    GraphAttentionEmbedding(
        in_channels=memory_dim,
        out_channels=embedding_dim,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    )
    .to(device)
    .float()
)
node_pred = NodePredictor(in_dim=embedding_dim, out_dim=num_classes).to(device)
# Memory, GNN and decoder are trained jointly with a single optimizer.
optimizer = torch.optim.Adam(
    set(memory.parameters()) | set(gnn.parameters()) | set(node_pred.parameters()),
    lr=lr,
)
criterion = torch.nn.CrossEntropyLoss()
# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)
def plot_curve(scores, out_name):
    """Plot a metric curve over epochs and save it as ``<out_name>.pdf``."""
    # NOTE(review): this local definition shadows `plot_curve` imported from
    # tgb.utils.stats at the top of the file — confirm which one is intended.
    plt.plot(scores, color="#e34a33")
    plt.ylabel("score")
    plt.savefig(out_name + ".pdf")
    plt.close()
def process_edges(src, dst, t, msg):
    """Feed a batch of observed edges into the TGN memory and neighbor loader."""
    if src.nelement() == 0:
        return  # nothing to process for an empty batch
    # msg = msg.to(torch.float32)
    memory.update_state(src, dst, t, msg)
    neighbor_loader.insert(src, dst)
def train():
    """Train TGN for one epoch over the chronological training split.

    Streams batches in time order; whenever a batch crosses the next label
    timestamp, computes node-property predictions for the labeled nodes,
    back-propagates the cross-entropy loss, and keeps updating the memory /
    neighbor state with the ground-truth edges.

    Returns:
        dict with the mean cross-entropy ("ce") and the mean ``eval_metric``
        over all label timestamps seen this epoch.
    """
    memory.train()
    gnn.train()
    node_pred.train()
    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.
    total_loss = 0
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        query_t = batch.t[-1]
        # check if this batch moves to the next day
        if query_t > label_t:
            # find the node labels from the past day
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:
                # guard added for consistency with test(): no labels remain,
                # indexing None below would otherwise raise
                break
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)
            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )
            """
            modified for node property prediction
            1. sample neighbors from the neighbor loader for all nodes to be predicted
            2. extract memory from the sampled neighbors and the nodes
            3. run gnn with the extracted memory embeddings and the corresponding time and message
            """
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)
            z, last_update = memory(n_id_neighbors)
            z = gnn(
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]
            # loss and metric computation
            pred = node_pred(z)
            loss = criterion(pred, labels.to(device))
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()
            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[eval_metric]
            total_score += score
            num_label_ts += 1
            loss.backward()
            optimizer.step()
            total_loss += float(loss.detach())
        # Update memory and neighbor loader with ground-truth state.
        process_edges(src, dst, t, msg)
        memory.detach()
    metric_dict = {
        "ce": total_loss / num_label_ts,
    }
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict
@torch.no_grad()
def test(loader):
    """Evaluate the TGN node-property predictor on one chronological split.

    Mirrors train(): streams batches in time order and, whenever a batch
    crosses the next label timestamp, predicts labels for the labeled nodes
    and scores them, while continuing to update the memory / neighbor state
    with the observed edges (no gradient updates).

    Parameters:
        loader: a TemporalDataLoader over one chronological split.

    Returns:
        dict mapping ``eval_metric`` to the score averaged over label timestamps.
    """
    memory.eval()
    gnn.eval()
    node_pred.eval()
    total_score = 0
    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    for batch in loader:
        batch = batch.to(device)
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        query_t = batch.t[-1]
        if query_t > label_t:
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:
                # no labels remain in this split
                break
            label_ts, label_srcs, labels = (
                label_tuple[0],
                label_tuple[1],
                label_tuple[2],
            )
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)
            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )
            """
            modified for node property prediction
            1. sample neighbors from the neighbor loader for all nodes to be predicted
            2. extract memory from the sampled neighbors and the nodes
            3. run gnn with the extracted memory embeddings and the corresponding time and message
            """
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)
            z, last_update = memory(n_id_neighbors)
            z = gnn(
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]
            # loss and metric computation
            pred = node_pred(z)
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()
            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[eval_metric]
            total_score += score
            num_label_ts += 1
        # keep memory / neighbor state up to date with the remaining edges
        process_edges(src, dst, t, msg)
    metric_dict = {}
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict
# Track per-epoch metric curves for plotting and for model selection
# (the reported test score is the one at the best-validation epoch).
train_curve = []
val_curve = []
test_curve = []
max_val_score = 0  # find the best test score based on validation score
best_test_idx = 0
for epoch in range(1, epochs + 1):
    start_time = timeit.default_timer()
    train_dict = train()
    print("------------------------------------")
    print(f"training Epoch: {epoch:02d}")
    print(train_dict)
    train_curve.append(train_dict[eval_metric])
    print("Training takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    start_time = timeit.default_timer()
    val_dict = test(val_loader)
    print(val_dict)
    val_curve.append(val_dict[eval_metric])
    if (val_dict[eval_metric] > max_val_score):
        max_val_score = val_dict[eval_metric]
        best_test_idx = epoch - 1
    print("Validation takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    start_time = timeit.default_timer()
    test_dict = test(test_loader)
    print(test_dict)
    test_curve.append(test_dict[eval_metric])
    print("Test takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    print("------------------------------------")
    # Reset the label cursor so the next epoch starts from the first label.
    dataset.reset_label_time()
# code for plotting
plot_curve(train_curve, "train_curve")
plot_curve(val_curve, "val_curve")
plot_curve(test_curve, "test_curve")
# Report the test score at the epoch with the best validation score.
max_test_score = test_curve[best_test_idx]
print("------------------------------------")
print("------------------------------------")
print ("best val score: ", max_val_score)
print ("best validation epoch : ", best_test_idx + 1)
print ("best test score: ", max_test_score)
================================================
FILE: mkdocs.yml
================================================
site_name: Temporal Graph Benchmark
nav:
- Overview: index.md
- About: about.md
- API:
- tgb.linkproppred: api/tgb.linkproppred.md
- tgb.nodeproppred: api/tgb.nodeproppred.md
- tgb.utils: api/tgb.utils.md
- Tutorials:
- Access Edge Data in PyG: tutorials/Edge_data_pyg.ipynb
- Access Edge Data in Numpy: tutorials/Edge_data_numpy.ipynb
theme:
logo: assets/logo.png
name: material
features:
- navigation.tabs
- navigation.sections
- toc.integrate
- navigation.top
- search.suggest
- search.highlight
- content.tabs.link
- content.code.annotation
- content.code.copy
language: en
palette:
- scheme: default
toggle:
icon: material/toggle-switch-off-outline
name: Switch to dark mode
primary: purple
accent: orange
- scheme: slate
toggle:
icon: material/toggle-switch
name: Switch to light mode
primary: orange
accent: lime
extra:
social:
- icon: fontawesome/brands/github-alt
link: https://github.com/shenyangHuang/TGB
- icon: fontawesome/solid/envelope
      link: mailto:shenyang.huang@mail.mcgill.ca
- icon: fontawesome/brands/twitter
link: https://twitter.com/shenyangHuang
- icon: fontawesome/brands/linkedin
link: https://www.linkedin.com/in/shenyang-huang/
markdown_extensions:
- pymdownx.highlight:
anchor_linenums: true
- pymdownx.inlinehilite
- pymdownx.snippets
- admonition
- pymdownx.arithmatex:
generic: true
- footnotes
- pymdownx.details
- pymdownx.superfences
- pymdownx.mark
- attr_list
- pymdownx.emoji:
emoji_index: !!python/name:materialx.emoji.twemoji
emoji_generator: !!python/name:materialx.emoji.to_svg
plugins:
- search
- mkdocstrings:
watch:
- tgb/
handlers:
python:
setup_commands:
- import sys
- sys.path.append("docs")
- sys.path.append("tgb")
selection:
new_path_syntax: true
rendering:
show_root_heading: false
heading_level: 3
show_root_full_path: false
- mkdocs-jupyter:
execute: false
================================================
FILE: modules/decoder.py
================================================
"""
Decoder modules for dynamic link prediction
"""
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import math
class LinkPredictor(torch.nn.Module):
    """MLP decoder producing an existence probability for a (src, dst) link.

    Reference:
    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
    """

    def __init__(self, in_channels):
        super().__init__()
        self.lin_src = Linear(in_channels, in_channels)
        self.lin_dst = Linear(in_channels, in_channels)
        self.lin_final = Linear(in_channels, 1)

    def forward(self, z_src, z_dst):
        """Return a probability in [0, 1] for each candidate (src, dst) pair."""
        combined = self.lin_src(z_src) + self.lin_dst(z_dst)
        activated = combined.relu()
        return self.lin_final(activated).sigmoid()
class NodePredictor(torch.nn.Module):
    """Two-layer MLP decoder mapping node embeddings to class logits."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin_node = Linear(in_dim, in_dim)
        self.out = Linear(in_dim, out_dim)

    def forward(self, node_embed):
        """Return raw (unnormalized) class logits for each node embedding."""
        hidden = self.lin_node(node_embed).relu()
        # Normalization (softmax / log-softmax) is left to the loss function.
        return self.out(hidden)
### for TKG:
class ConvTransE(torch.nn.Module):
    """Convolutional decoder scoring (subject, relation) queries against entity embeddings.

    https://github.com/Lee-zix/CEN/blob/main/src/decoder.py
    """

    def __init__(self, num_entities, embedding_dim, input_dropout=0, hidden_dropout=0,
                 feature_map_dropout=0, channels=50, kernel_size=3, sequence_len=1, use_bias=True, model_name='REGCN'):
        super(ConvTransE, self).__init__()
        self.model_name = model_name  # 'REGCN' or 'CEN'
        self.inp_drop = torch.nn.Dropout(input_dropout)
        self.hidden_drop = torch.nn.Dropout(hidden_dropout)
        self.feature_map_drop = torch.nn.Dropout(feature_map_dropout)
        self.embedding_dim = embedding_dim
        # One conv / batch-norm stack per history step (CEN uses sequence_len > 1).
        self.conv_list = torch.nn.ModuleList()
        self.bn0_list = torch.nn.ModuleList()
        self.bn1_list = torch.nn.ModuleList()
        self.bn2_list = torch.nn.ModuleList()
        for _ in range(sequence_len):
            # kernel size is odd, then padding = math.floor(kernel_size/2)
            self.conv_list.append(torch.nn.Conv1d(2, channels, kernel_size, stride=1,
                                                  padding=int(math.floor(kernel_size / 2))))
            self.bn0_list.append(torch.nn.BatchNorm1d(2))
            self.bn1_list.append(torch.nn.BatchNorm1d(channels))
            self.bn2_list.append(torch.nn.BatchNorm1d(embedding_dim))
        self.fc = torch.nn.Linear(embedding_dim * channels, embedding_dim)

    def forward(self, embedding, emb_rel, triplets, partial_embeding=None, samples_of_interest_emb=None):
        """forward for ConvsTransE decoder that computes scores for given triples of question

        return: score_list: list of scores for each triple in the batch (CEN),
                or a single score tensor (RE-GCN)
        """
        if self.model_name == 'CEN':  # CEN: one score tensor per history graph
            score_list = []
            for idx in range(len(embedding)):  # length of test_graph
                # `is not None` instead of `!= None`: the latter is a fragile
                # truth test on tensors and non-idiomatic for None checks
                if samples_of_interest_emb is not None:
                    x = self.forward_inner(embedding[idx], emb_rel, triplets, idx, partial_embeding, samples_of_interest_emb[idx])
                else:
                    x = self.forward_inner(embedding[idx], emb_rel, triplets, idx, partial_embeding, samples_of_interest_emb)
                score_list.append(x)
            return score_list
        else:  # RE-GCN: single embedding matrix, single score tensor
            scores = self.forward_inner(embedding, emb_rel, triplets, 0, partial_embeding, samples_of_interest_emb)
            return scores

    def forward_inner(self, embedding, emb_rel, triplets, idx=0, partial_embeding=None, samples_of_interest_emb=None):
        """forward for ConvsTransE decoder that computes scores for given triples of question for each graph in the history of test graphs

        return: x: scores for each triple in the batch
        """
        batch_size = len(triplets)
        # torch.tanh replaces the deprecated F.tanh
        e1_embedded_all = torch.tanh(embedding)
        e1_embedded = e1_embedded_all[triplets[:, 0]].unsqueeze(1)
        rel_embedded = emb_rel[triplets[:, 1]].unsqueeze(1)
        stacked_inputs = torch.cat([e1_embedded, rel_embedded], 1)
        stacked_inputs = self.bn0_list[idx](stacked_inputs)
        x = self.inp_drop(stacked_inputs)
        x = self.conv_list[idx](x)
        x = self.bn1_list[idx](x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(batch_size, -1)
        x = self.fc(x)
        x = self.hidden_drop(x)
        if batch_size > 1:
            # BatchNorm1d requires more than one sample in training mode
            x = self.bn2_list[idx](x)
        x = F.relu(x)
        if partial_embeding is not None:
            x = torch.mm(x, partial_embeding.transpose(1, 0))
        elif samples_of_interest_emb is not None:  # added tgb team: predict only for nodes of interest
            x = torch.mm(x, torch.tanh(samples_of_interest_emb).transpose(1, 0))
        else:  # predict for all nodes
            x = torch.mm(x, e1_embedded_all.transpose(1, 0))
        return x
================================================
FILE: modules/early_stopping.py
================================================
"""
An Early Stopping Module
"""
import os
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
class EarlyStopMonitor(object):
    def __init__(self, save_model_dir: str, save_model_id: str,
                 tolerance: float = 1e-10, patience: int = 5,
                 higher_better: bool = True):
        r"""
        Early Stopping Monitor
        :param: save_model_dir: str, where to save the model
        :param: save_model_id: str, an id to save the model with
        :param: tolerance: float, the amount of relative improvement required
                to count as progress
        :param: patience: int, how many rounds without improvement to wait
        :param: higher_better: whether a higher value of the metric is better
        """
        self.tolerance = tolerance
        self.patience = patience
        self.higher_better = higher_better
        self.counter = 0        # consecutive rounds without improvement
        self.best_sofar = None  # best (sign-adjusted) metric observed so far
        self.best_epoch = 0
        self.epoch_idx = 1
        self.save_model_dir = save_model_dir
        if not os.path.exists(self.save_model_dir):
            print('INFO: Create directory {}'.format(save_model_dir))
        # Single creation call that also creates missing parents. The original
        # used os.mkdir(), which raises when a parent directory is missing,
        # making the Path.mkdir(parents=True) call after it unreachable in
        # exactly the case it was meant to handle.
        Path(self.save_model_dir).mkdir(parents=True, exist_ok=True)
        self.save_model_id = save_model_id

    def get_best_model_path(self):
        r"""
        return the path of the best model
        """
        return self.save_model_dir + '/{}.pth'.format(self.save_model_id)

    def step_check(self, curr_metric: float, models_dict: dict):
        r"""
        execute the early stop strategy
        :param: curr_metric: a metric to evaluate the early stopping on
        :param: models_dict: a dictionary containing all models to be saved
        :returns: True when `patience` rounds have passed without improvement
        """
        if not self.higher_better:
            # negate so that "higher is better" holds for the comparison below
            curr_metric *= -1
        # NOTE(review): the relative-improvement test divides by |best_sofar|;
        # a best metric of exactly 0 yields a numpy divide warning — confirm
        # metrics are never exactly zero in practice.
        if (self.best_sofar is None) or ((curr_metric - self.best_sofar) / np.abs(self.best_sofar) > self.tolerance):
            # first iteration or observing an improvement
            self.best_sofar = curr_metric
            print("INFO: save a checkpoint...")
            self.save_checkpoint(models_dict)
            self.counter = 0
            self.best_epoch = self.epoch_idx
        else:
            # no improvement observed
            self.counter += 1
        self.epoch_idx += 1
        return self.counter >= self.patience

    def save_checkpoint(self, models_dict: dict):
        r"""
        save models as a checkpoint
        :param: models_dict: a dictionary containing all models to be saved
        """
        model_path = self.get_best_model_path()
        print("INFO: save the model to {}".format(model_path))
        model_names = list(models_dict.keys())
        model_components = list(models_dict.values())
        torch.save({model_names[i]: model_components[i].state_dict() for i in range(len(model_names))},
                   model_path)

    def load_checkpoint(self, models_dict: dict):
        r"""
        load models from the checkpoint
        :param: models_dict: a dictionary containing all models
        """
        model_path = self.get_best_model_path()
        print("INFO: load the model of epoch {} from {}".format(self.best_epoch, model_path))
        checkpoint = torch.load(model_path)
        for model_name, model in models_dict.items():
            model.load_state_dict(checkpoint[model_name])
================================================
FILE: modules/edgebank_predictor.py
================================================
"""
EdgeBank is a simple strong baseline for dynamic link prediction
it predicts the existence of edges based on their history of occurrence
Reference:
- https://github.com/fpour/DGB/tree/main
"""
import numpy as np
import warnings
class EdgeBankPredictor(object):
    def __init__(
        self,
        src: np.ndarray,
        dst: np.ndarray,
        ts: np.ndarray,
        memory_mode: str = 'unlimited',  # could be `unlimited` or `fixed_time_window`
        time_window_ratio: float = 0.15,
        pos_prob: float = 1.0,
    ):
        r"""
        initialize edgebank and specify the memory mode
        Parameters:
            src: source node id of the edges for initialization
            dst: destination node id of the edges for initialization
            ts: timestamp of the edges for initialization
            memory_mode: 'unlimited' or 'fixed_time_window'
            time_window_ratio: the ratio of the time window length to the total time length
            pos_prob: the probability of the link existence for the edges in memory
        """
        # "Invalid" typo ("Invalide") fixed in the error messages below
        assert memory_mode in ['unlimited', 'fixed_time_window'], "Invalid memory mode for EdgeBank!"
        self.memory_mode = memory_mode
        if self.memory_mode == 'fixed_time_window':
            self.time_window_ratio = time_window_ratio
            # determine the time window size based on ratio from the given src, dst, and ts for initialization
            duration = ts.max() - ts.min()
            self.prev_t = ts.min() + duration * (1 - time_window_ratio)  # the time window starts from the last ratio% of time
            self.cur_t = ts.max()
            self.duration = self.cur_t - self.prev_t
        else:
            self.time_window_ratio = -1
            self.prev_t = -1
            self.cur_t = -1
            self.duration = -1
        self.memory = {}  # {(u, v): last-seen ts} (or constant 1 in unlimited mode)
        self.pos_prob = pos_prob
        self.update_memory(src, dst, ts)

    def update_memory(self,
                      src: np.ndarray,
                      dst: np.ndarray,
                      ts: np.ndarray):
        r"""
        generate the current and correct state of the memory with the observed edges so far
        note that historical edges may include training, validation, and already observed test edges
        Parameters:
            src: source node id of the edges
            dst: destination node id of the edges
            ts: timestamp of the edges
        """
        if self.memory_mode == 'unlimited':
            self._update_unlimited_memory(src, dst)  # ignores time
        elif self.memory_mode == 'fixed_time_window':
            self._update_time_window_memory(src, dst, ts)
        else:
            raise ValueError("Invalid memory mode!")

    @property
    def start_time(self) -> int:
        """
        return the start of time window for edgebank `fixed_time_window` only
        Returns:
            start of time window
        """
        if (self.memory_mode == "unlimited"):
            warnings.warn("start_time is not defined for unlimited memory mode, returns -1")
        return self.prev_t

    @property
    def end_time(self) -> int:
        """
        return the end of time window for edgebank `fixed_time_window` only
        Returns:
            end of time window
        """
        if (self.memory_mode == "unlimited"):
            warnings.warn("end_time is not defined for unlimited memory mode, returns -1")
        return self.cur_t

    def _update_unlimited_memory(self,
                                 update_src: np.ndarray,
                                 update_dst: np.ndarray):
        r"""
        update self.memory with newly arrived src and dst
        Parameters:
            update_src: source node id of the edges
            update_dst: destination node id of the edges
        """
        for src, dst in zip(update_src, update_dst):
            # direct assignment: the previous membership test was a redundant
            # second dict lookup, the stored value is always 1
            self.memory[(src, dst)] = 1

    def _update_time_window_memory(self,
                                   update_src: np.ndarray,
                                   update_dst: np.ndarray,
                                   update_ts: np.ndarray) -> None:
        r"""
        move the time window forward until end of dst timestamp here
        also need to remove earlier edges from memory which is not in the time window
        Parameters:
            update_src: source node id of the edges
            update_dst: destination node id of the edges
            update_ts: timestamp of the edges
        """
        # * initialize the memory if it is empty
        if (len(self.memory) == 0):
            for src, dst, ts in zip(update_src, update_dst, update_ts):
                self.memory[(src, dst)] = ts
            return None
        # * advance the window when newer timestamps arrive
        if (update_ts.max() > self.cur_t):
            self.cur_t = update_ts.max()
            self.prev_t = self.cur_t - self.duration
        # * add new edges to the time window (stale entries are filtered at
        #   query time by comparing their timestamp against prev_t)
        for src, dst, ts in zip(update_src, update_dst, update_ts):
            self.memory[(src, dst)] = ts

    def predict_link(self,
                     query_src: np.ndarray,
                     query_dst: np.ndarray) -> np.ndarray:
        r"""
        predict the probability from query src,dst pair given the current memory,
        all edges not in memory will return 0.0 while all observed edges in memory will return self.pos_prob
        Parameters:
            query_src: source node id of the query edges
            query_dst: destination node id of the query edges
        Returns:
            pred: the prediction for all query edges
        """
        pred = np.zeros(len(query_src))
        for idx, (src, dst) in enumerate(zip(query_src, query_dst)):
            if (src, dst) not in self.memory:
                continue
            if self.memory_mode == 'fixed_time_window':
                # only edges last seen inside the current window count
                if self.memory[(src, dst)] >= self.prev_t:
                    pred[idx] = self.pos_prob
            else:
                pred[idx] = self.pos_prob
        return pred
================================================
FILE: modules/emb_module.py
================================================
"""
GNN-based modules used in the architecture of MP-TG models
"""
import math
from torch_geometric.nn import TransformerConv
import torch
class GraphAttentionEmbedding(torch.nn.Module):
    """Transformer-convolution embedding over a temporal neighborhood.

    Reference:
    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
    """

    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        # Edge features = encoded relative time concatenated with the raw message.
        edge_dim = msg_dim + time_enc.out_channels
        self.conv = TransformerConv(
            in_channels, out_channels // 2, heads=2, dropout=0.1, edge_dim=edge_dim
        )

    def forward(self, x, last_update, edge_index, t, msg):
        """Return updated node embeddings for the sampled subgraph."""
        # Time elapsed between the source node's last memory update and the edge.
        delta_t = last_update[edge_index[0]] - t
        time_feat = self.time_enc(delta_t.to(x.dtype))
        edge_attr = torch.cat([time_feat, msg], dim=-1)
        return self.conv(x, edge_index, edge_attr)
class TimeEmbedding(torch.nn.Module):
    """JODIE-style time projection: scales each node embedding by a learned
    linear function of the time elapsed since its last update."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        class NormalLinear(torch.nn.Linear):
            # From TGN code: From JODIE code
            def reset_parameters(self):
                std = 1.0 / math.sqrt(self.weight.size(1))
                self.weight.data.normal_(0, std)
                if self.bias is not None:
                    self.bias.data.normal_(0, std)

        self.embedding_layer = NormalLinear(1, self.out_channels)

    def forward(self, x, last_update, t):
        """Return the time-projected embeddings for nodes ``x``."""
        elapsed = (last_update - t).to(x.dtype).unsqueeze(1)
        scale = 1 + self.embedding_layer(elapsed)
        return x * scale
================================================
FILE: modules/heuristics.py
================================================
import numpy as np
class PersistantForecaster:
    """Baseline that predicts a node's most recently observed label."""

    def __init__(self, num_class):
        self.dict = {}          # node_id -> last observed label
        self.num_class = num_class

    def update_dict(self, node_id, label):
        """Record ``label`` as the most recent label of ``node_id``."""
        self.dict[node_id] = label

    def query_dict(self, node_id):
        r"""
        Parameters:
            node_id: the node to query
        Returns:
            the last seen label of the node if it exists, otherwise a zero vector
        """
        return self.dict.get(node_id, np.zeros(self.num_class))
class MovingAverage:
    """Baseline that predicts a running (window-weighted) average of the
    labels observed for each node."""

    def __init__(self, num_class, window=7):
        self.dict = {}          # node_id -> running average label vector
        self.num_class = num_class
        self.window = window

    def update_dict(self, node_id, label):
        """Fold ``label`` into the running average for ``node_id``."""
        if node_id not in self.dict:
            self.dict[node_id] = label
        else:
            previous = self.dict[node_id]
            self.dict[node_id] = (previous * (self.window - 1) + label) / self.window

    def query_dict(self, node_id):
        r"""
        Parameters:
            node_id: the node to query
        Returns:
            the running average label of the node if it exists, otherwise a zero vector
        """
        return self.dict.get(node_id, np.zeros(self.num_class))
================================================
FILE: modules/memory_module.py
================================================
"""
Memory Module
Reference:
- https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/tgn.html
"""
import copy
from typing import Callable, Dict, Tuple
import torch
from torch import Tensor
from torch.nn import GRUCell, RNNCell, Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.utils import scatter
from modules.time_enc import TimeEncoder
TGNMessageStoreType = Dict[int, Tuple[Tensor, Tensor, Tensor, Tensor]]
class TGNMemory(torch.nn.Module):
r"""The Temporal Graph Network (TGN) memory model from the
`"Temporal Graph Networks for Deep Learning on Dynamic Graphs"
`_ paper.
.. note::
For an example of using TGN, see `examples/tgn.py
`_.
Args:
num_nodes (int): The number of nodes to save memories for.
raw_msg_dim (int): The raw message dimensionality.
memory_dim (int): The hidden memory dimensionality.
time_dim (int): The time encoding dimensionality.
message_module (torch.nn.Module): The message function which
combines source and destination node memory embeddings, the raw
message and the time encoding.
aggregator_module (torch.nn.Module): The message aggregator function
which aggregates messages to the same destination into a single
representation.
"""
def __init__(
self,
num_nodes: int,
raw_msg_dim: int,
memory_dim: int,
time_dim: int,
message_module: Callable,
aggregator_module: Callable,
memory_updater_cell: str = "gru",
):
super().__init__()
self.num_nodes = num_nodes
self.raw_msg_dim = raw_msg_dim
self.memory_dim = memory_dim
self.time_dim = time_dim
self.msg_s_module = message_module
self.msg_d_module = copy.deepcopy(message_module)
self.aggr_module = aggregator_module
self.time_enc = TimeEncoder(time_dim)
# self.gru = GRUCell(message_module.out_channels, memory_dim)
if memory_updater_cell == "gru": # for TGN
self.memory_updater = GRUCell(message_module.out_channels, memory_dim)
elif memory_updater_cell == "rnn": # for JODIE & DyRep
self.memory_updater = RNNCell(message_module.out_channels, memory_dim)
else:
raise ValueError(
"Undefined memory updater!!! Memory updater can be either 'gru' or 'rnn'."
)
self.register_buffer("memory", torch.empty(num_nodes, memory_dim))
last_update = torch.empty(self.num_nodes, dtype=torch.long)
self.register_buffer("last_update", last_update)
self.register_buffer("_assoc", torch.empty(num_nodes, dtype=torch.long))
self.msg_s_store = {}
self.msg_d_store = {}
self.reset_parameters()
@property
def device(self) -> torch.device:
return self.time_enc.lin.weight.device
def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
if hasattr(self.msg_s_module, "reset_parameters"):
self.msg_s_module.reset_parameters()
if hasattr(self.msg_d_module, "reset_parameters"):
self.msg_d_module.reset_parameters()
if hasattr(self.aggr_module, "reset_parameters"):
self.aggr_module.reset_parameters()
self.time_enc.reset_parameters()
self.memory_updater.reset_parameters()
self.reset_state()
def reset_state(self):
"""Resets the memory to its initial state."""
zeros(self.memory)
zeros(self.last_update)
self._reset_message_store()
def detach(self):
"""Detaches the memory from gradient computation."""
self.memory.detach_()
def forward(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:
"""Returns, for all nodes :obj:`n_id`, their current memory and their
last updated timestamp."""
if self.training:
memory, last_update = self._get_updated_memory(n_id)
else:
memory, last_update = self.memory[n_id], self.last_update[n_id]
return memory, last_update
def update_state(self, src: Tensor, dst: Tensor, t: Tensor, raw_msg: Tensor):
"""Updates the memory with newly encountered interactions
:obj:`(src, dst, t, raw_msg)`."""
n_id = torch.cat([src, dst]).unique()
if self.training:
self._update_memory(n_id)
self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store)
self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store)
else:
self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store)
self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store)
self._update_memory(n_id)
def _reset_message_store(self):
    """Installs an empty message queue for every node."""
    empty_idx = self.memory.new_empty((0,), device=self.device, dtype=torch.long)
    empty_msg = self.memory.new_empty((0, self.raw_msg_dim), device=self.device)
    # Message store format: (src, dst, t, msg).  The empty tuple can be
    # shared across nodes: entries are always replaced wholesale, never
    # mutated in place.
    blank = (empty_idx, empty_idx, empty_idx, empty_msg)
    self.msg_s_store = {node: blank for node in range(self.num_nodes)}
    self.msg_d_store = {node: blank for node in range(self.num_nodes)}
def _update_memory(self, n_id: Tensor):
    """Computes fresh memory/timestamps for ``n_id`` and writes them back
    into the registered buffers."""
    updated_memory, updated_ts = self._get_updated_memory(n_id)
    self.memory[n_id] = updated_memory
    self.last_update[n_id] = updated_ts
def _get_updated_memory(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:
    """Computes — without writing back — the memory state and last-update
    timestamp each node in ``n_id`` would have after applying all of its
    currently stored messages."""
    # Map global node ids to local positions [0, len(n_id)).
    self._assoc[n_id] = torch.arange(n_id.size(0), device=n_id.device)
    # Compute messages (src -> dst).
    msg_s, t_s, src_s, dst_s = self._compute_msg(
        n_id, self.msg_s_store, self.msg_s_module
    )
    # Compute messages (dst -> src).
    msg_d, t_d, src_d, dst_d = self._compute_msg(
        n_id, self.msg_d_store, self.msg_d_module
    )
    # Aggregate messages arriving at the same (locally relabeled) node.
    idx = torch.cat([src_s, src_d], dim=0)
    msg = torch.cat([msg_s, msg_d], dim=0)
    t = torch.cat([t_s, t_d], dim=0)
    aggr = self.aggr_module(msg, self._assoc[idx], t, n_id.size(0))
    # Get local copy of updated memory.
    memory = self.memory_updater(aggr, self.memory[n_id])
    # Get local copy of updated `last_update`: per-node maximum pending
    # message timestamp, gathered back for the nodes in `n_id`.
    dim_size = self.last_update.size(0)
    last_update = scatter(t, idx, 0, dim_size, reduce="max")[n_id]
    return memory, last_update
def _update_msg_store(
self,
src: Tensor,
dst: Tensor,
t: Tensor,
raw_msg: Tensor,
msg_store: TGNMessageStoreType,
):
n_id, perm = src.sort()
n_id, count = n_id.unique_consecutive(return_counts=True)
for i, idx in zip(n_id.tolist(), perm.split(count.tolist())):
msg_store[i] = (src[idx], dst[idx], t[idx], raw_msg[idx])
def _compute_msg(
    self, n_id: Tensor, msg_store: TGNMessageStoreType, msg_module: Callable
):
    """Gathers all stored messages for the nodes in ``n_id`` and encodes
    them with ``msg_module``; returns ``(msg, t, src, dst)``."""
    entries = [msg_store[i] for i in n_id.tolist()]
    src_parts, dst_parts, t_parts, raw_parts = zip(*entries)
    src = torch.cat(src_parts, dim=0)
    dst = torch.cat(dst_parts, dim=0)
    t = torch.cat(t_parts, dim=0)
    raw_msg = torch.cat(raw_parts, dim=0)
    # Encode the time elapsed since each source node was last updated.
    t_rel = t - self.last_update[src]
    t_enc = self.time_enc(t_rel.to(raw_msg.dtype))
    msg = msg_module(self.memory[src], self.memory[dst], raw_msg, t_enc)
    return msg, t, src, dst
def train(self, mode: bool = True):
    """Sets the module in training mode."""
    leaving_train = self.training and not mode
    if leaving_train:
        # Entering eval: apply every queued message so the buffers reflect
        # all interactions seen so far, then clear the queues.
        all_nodes = torch.arange(self.num_nodes, device=self.memory.device)
        self._update_memory(all_nodes)
        self._reset_message_store()
    super().train(mode)
class DyRepMemory(torch.nn.Module):
    r"""Memory module for DyRep-style models, based on the TGN memory.

    Differences with the original TGN memory:
        - source and/or destination node *embeddings* (not only raw memory
          states) can be used in message generation
        - either an RNN or a GRU cell can be used as the memory updater

    Args:
        num_nodes (int): The number of nodes to save memories for.
        raw_msg_dim (int): The raw message dimensionality.
        memory_dim (int): The hidden memory dimensionality.
        time_dim (int): The time encoding dimensionality.
        message_module (torch.nn.Module): The message function which
            combines source and destination node memory embeddings, the raw
            message and the time encoding.
        aggregator_module (torch.nn.Module): The message aggregator function
            which aggregates messages to the same destination into a single
            representation.
        memory_updater_type (str): whether the memory updater is 'gru' or 'rnn'.
        use_src_emb_in_msg (bool): whether to use the source embeddings
            in generation of messages.
        use_dst_emb_in_msg (bool): whether to use the destination embeddings
            in generation of messages.

    Raises:
        ValueError: if ``memory_updater_type`` is neither 'gru' nor 'rnn'.
    """

    def __init__(self, num_nodes: int, raw_msg_dim: int, memory_dim: int,
                 time_dim: int, message_module: Callable,
                 aggregator_module: Callable, memory_updater_type: str,
                 use_src_emb_in_msg: bool = False, use_dst_emb_in_msg: bool = False):
        super().__init__()

        self.num_nodes = num_nodes
        self.raw_msg_dim = raw_msg_dim
        self.memory_dim = memory_dim
        self.time_dim = time_dim

        self.msg_s_module = message_module
        self.msg_d_module = copy.deepcopy(message_module)
        self.aggr_module = aggregator_module
        self.time_enc = TimeEncoder(time_dim)

        # GRU is used by TGN, RNN by JODIE & DyRep.  The previous code
        # asserted validity first (with a typo in the message), which made
        # the `else` branch below unreachable and was stripped under
        # `python -O`; validate with a proper exception instead.
        if memory_updater_type == 'gru':
            self.memory_updater = GRUCell(message_module.out_channels, memory_dim)
        elif memory_updater_type == 'rnn':
            self.memory_updater = RNNCell(message_module.out_channels, memory_dim)
        else:
            raise ValueError("Undefined memory updater!!! Memory updater can be either 'gru' or 'rnn'.")

        self.use_src_emb_in_msg = use_src_emb_in_msg
        self.use_dst_emb_in_msg = use_dst_emb_in_msg

        self.register_buffer('memory', torch.empty(num_nodes, memory_dim))
        last_update = torch.empty(self.num_nodes, dtype=torch.long)
        self.register_buffer('last_update', last_update)
        self.register_buffer('_assoc', torch.empty(num_nodes,
                                                   dtype=torch.long))

        self.msg_s_store = {}
        self.msg_d_store = {}

        self.reset_parameters()

    @property
    def device(self) -> torch.device:
        """Device on which this memory module's parameters live."""
        return self.time_enc.lin.weight.device

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module and clears its state."""
        if hasattr(self.msg_s_module, 'reset_parameters'):
            self.msg_s_module.reset_parameters()
        if hasattr(self.msg_d_module, 'reset_parameters'):
            self.msg_d_module.reset_parameters()
        if hasattr(self.aggr_module, 'reset_parameters'):
            self.aggr_module.reset_parameters()
        self.time_enc.reset_parameters()
        self.memory_updater.reset_parameters()
        self.reset_state()

    def reset_state(self):
        """Resets the memory to its initial (zeroed, message-free) state."""
        zeros(self.memory)
        zeros(self.last_update)
        self._reset_message_store()

    def detach(self):
        """Detaches the memory from gradient computation (truncates BPTT)."""
        self.memory.detach_()

    def forward(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:
        """Returns, for all nodes :obj:`n_id`, their current memory and their
        last updated timestamp."""
        if self.training:
            # Messages are applied lazily during training.
            memory, last_update = self._get_updated_memory(n_id)
        else:
            memory, last_update = self.memory[n_id], self.last_update[n_id]
        return memory, last_update

    def update_state(self, src: Tensor, dst: Tensor, t: Tensor, raw_msg: Tensor,
                     embeddings: Tensor = None, assoc: Tensor = None):
        """Updates the memory with newly encountered interactions
        :obj:`(src, dst, t, raw_msg)`.

        Args:
            embeddings: optional node embeddings used instead of raw memory
                in message generation (see ``use_*_emb_in_msg``).
            assoc: maps global node ids to rows of ``embeddings``.
        """
        n_id = torch.cat([src, dst]).unique()

        # Training updates memory before storing the new messages;
        # evaluation stores first and updates afterwards.
        if self.training:
            self._update_memory(n_id, embeddings, assoc)
            self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store)
            self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store)
        else:
            self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store)
            self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store)
            self._update_memory(n_id, embeddings, assoc)

    def _reset_message_store(self):
        """Installs an empty message queue for every node."""
        i = self.memory.new_empty((0,), device=self.device, dtype=torch.long)
        msg = self.memory.new_empty((0, self.raw_msg_dim), device=self.device)
        # Message store format: (src, dst, t, msg)
        self.msg_s_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}
        self.msg_d_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}

    def _update_memory(self, n_id: Tensor, embeddings: Tensor = None, assoc: Tensor = None):
        """Computes fresh memory/timestamps for ``n_id`` and writes them back."""
        memory, last_update = self._get_updated_memory(n_id, embeddings, assoc)
        self.memory[n_id] = memory
        self.last_update[n_id] = last_update

    def _get_updated_memory(self, n_id: Tensor, embeddings: Tensor = None,
                            assoc: Tensor = None) -> Tuple[Tensor, Tensor]:
        """Computes — without writing back — the memory state and last-update
        timestamp each node in ``n_id`` would have after applying its
        stored messages."""
        # Map global node ids to local positions [0, len(n_id)).
        self._assoc[n_id] = torch.arange(n_id.size(0), device=n_id.device)

        # Compute messages (src -> dst).
        msg_s, t_s, src_s, dst_s = self._compute_msg(n_id, self.msg_s_store,
                                                     self.msg_s_module, embeddings, assoc)

        # Compute messages (dst -> src).
        msg_d, t_d, src_d, dst_d = self._compute_msg(n_id, self.msg_d_store,
                                                     self.msg_d_module, embeddings, assoc)

        # Aggregate messages arriving at the same (locally relabeled) node.
        idx = torch.cat([src_s, src_d], dim=0)
        msg = torch.cat([msg_s, msg_d], dim=0)
        t = torch.cat([t_s, t_d], dim=0)
        aggr = self.aggr_module(msg, self._assoc[idx], t, n_id.size(0))

        # Get local copy of updated memory.
        memory = self.memory_updater(aggr, self.memory[n_id])

        # Get local copy of updated `last_update`.
        dim_size = self.last_update.size(0)
        last_update = scatter(t, idx, 0, dim_size, reduce='max')[n_id]

        return memory, last_update

    def _update_msg_store(self, src: Tensor, dst: Tensor, t: Tensor,
                          raw_msg: Tensor, msg_store: TGNMessageStoreType):
        """Overwrites each source node's queue with its latest batch of
        outgoing interactions."""
        n_id, perm = src.sort()
        n_id, count = n_id.unique_consecutive(return_counts=True)
        for i, idx in zip(n_id.tolist(), perm.split(count.tolist())):
            msg_store[i] = (src[idx], dst[idx], t[idx], raw_msg[idx])

    def _compute_msg(self, n_id: Tensor, msg_store: TGNMessageStoreType, msg_module: Callable,
                     embeddings: Tensor = None, assoc: Tensor = None):
        """Gathers all stored messages for ``n_id`` and encodes them,
        optionally substituting node embeddings for raw memory states."""
        data = [msg_store[i] for i in n_id.tolist()]
        src, dst, t, raw_msg = list(zip(*data))
        src = torch.cat(src, dim=0)
        dst = torch.cat(dst, dim=0)
        t = torch.cat(t, dim=0)
        raw_msg = torch.cat(raw_msg, dim=0)
        t_rel = t - self.last_update[src]
        t_enc = self.time_enc(t_rel.to(raw_msg.dtype))

        # Source nodes: start from memory, optionally overwrite with the
        # externally computed embeddings (was `embeddings != None`; identity
        # check is the correct idiom for optional tensors).
        source_memory = self.memory[src]
        if self.use_src_emb_in_msg and embeddings is not None:
            if src.size(0) > 0:
                # NOTE: `s in n_id` is a linear scan per element — fine for
                # small batches, quadratic for large ones.
                curr_src, curr_src_idx = [], []
                for s_idx, s in enumerate(src):
                    if s in n_id:
                        curr_src.append(s.item())
                        curr_src_idx.append(s_idx)
                source_memory[curr_src_idx] = embeddings[assoc[curr_src]]

        # Destination nodes: same optional embedding substitution.
        destination_memory = self.memory[dst]
        if self.use_dst_emb_in_msg and embeddings is not None:
            if dst.size(0) > 0:
                curr_dst, curr_dst_idx = [], []
                for d_idx, d in enumerate(dst):
                    if d in n_id:
                        curr_dst.append(d.item())
                        curr_dst_idx.append(d_idx)
                destination_memory[curr_dst_idx] = embeddings[assoc[curr_dst]]

        msg = msg_module(source_memory, destination_memory, raw_msg, t_enc)
        return msg, t, src, dst

    def train(self, mode: bool = True):
        """Sets the module in training mode."""
        if self.training and not mode:
            # Flush message store to memory in case we just entered eval mode.
            self._update_memory(
                torch.arange(self.num_nodes, device=self.memory.device))
            self._reset_message_store()
        super().train(mode)
================================================
FILE: modules/msg_agg.py
================================================
"""
Message Aggregator Module
Reference:
- https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/tgn.html
"""
import torch
from torch import Tensor
from torch_geometric.utils import scatter
from torch_scatter import scatter_max
class LastAggregator(torch.nn.Module):
    """Keeps, for each destination index, only the most recent message
    (the one with the largest timestamp)."""

    def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):
        _, latest = scatter_max(t, index, dim=0, dim_size=dim_size)
        out = msg.new_zeros((dim_size, msg.size(-1)))
        # scatter_max marks empty groups with an out-of-range argmax;
        # leave those rows as zeros.
        filled = latest < msg.size(0)
        out[filled] = msg[latest[filled]]
        return out
class MeanAggregator(torch.nn.Module):
    """Averages all messages sent to the same destination index."""

    def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):
        # Timestamps are ignored; only the per-index mean matters here.
        mean_msg = scatter(msg, index, dim=0, dim_size=dim_size, reduce="mean")
        return mean_msg
================================================
FILE: modules/msg_func.py
================================================
"""
Message Function Module
Reference:
- https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/tgn.html
"""
import torch
from torch import Tensor
class IdentityMessage(torch.nn.Module):
    """Message function that simply concatenates both endpoint memories,
    the raw message and the time encoding."""

    def __init__(self, raw_msg_dim: int, memory_dim: int, time_dim: int):
        super().__init__()
        # Output width: two endpoint memories + raw message + time encoding.
        self.out_channels = 2 * memory_dim + raw_msg_dim + time_dim

    def forward(self, z_src: Tensor, z_dst: Tensor, raw_msg: Tensor, t_enc: Tensor):
        parts = (z_src, z_dst, raw_msg, t_enc)
        return torch.cat(parts, dim=-1)
================================================
FILE: modules/neighbor_loader.py
================================================
"""
Neighbor Loader
Reference:
- https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/tgn.html
"""
import copy
from typing import Callable, Dict, Tuple
import torch
from torch import Tensor
class LastNeighborLoader:
    """Maintains, for every node, its `size` most recent temporal neighbors
    (undirected) together with the event ids that created them."""

    def __init__(self, num_nodes: int, size: int, device=None):
        self.size = size
        # Dense [num_nodes, size] tables; an e_id of -1 marks an empty slot.
        self.neighbors = torch.empty((num_nodes, size), dtype=torch.long, device=device)
        self.e_id = torch.empty((num_nodes, size), dtype=torch.long, device=device)
        self._assoc = torch.empty(num_nodes, dtype=torch.long, device=device)
        self.reset_state()

    def __call__(self, n_id: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Samples the stored neighborhood of ``n_id``; returns the joint
        node set, a locally relabeled edge index, and the event ids."""
        nbrs = self.neighbors[n_id]
        centers = n_id.view(-1, 1).repeat(1, self.size)
        eid = self.e_id[n_id]

        # Drop empty slots (identified by e_id < 0).
        valid = eid >= 0
        nbrs, centers, eid = nbrs[valid], centers[valid], eid[valid]

        # Relabel node indices to consecutive local ids.
        n_id = torch.cat([n_id, nbrs]).unique()
        self._assoc[n_id] = torch.arange(n_id.size(0), device=n_id.device)
        edge_index = torch.stack([self._assoc[nbrs], self._assoc[centers]])

        return n_id, edge_index, eid

    def insert(self, src: Tensor, dst: Tensor):
        """Inserts a batch of interactions into the ever-growing
        (undirected) temporal graph."""
        # Each edge contributes two directed entries: src->dst and dst->src.
        nbr = torch.cat([src, dst], dim=0)
        ctr = torch.cat([dst, src], dim=0)
        eid = torch.arange(
            self.cur_e_id, self.cur_e_id + src.size(0), device=src.device
        ).repeat(2)
        self.cur_e_id += src.numel()

        # Group the new entries by their central node.
        ctr, order = ctr.sort()
        nbr, eid = nbr[order], eid[order]
        touched = ctr.unique()
        self._assoc[touched] = torch.arange(touched.numel(), device=touched.device)

        # Scatter the grouped entries into a dense [len(touched), size] layout.
        slot = torch.arange(ctr.size(0), device=ctr.device) % self.size
        slot = slot + self._assoc[ctr] * self.size

        dense_eid = eid.new_full((touched.numel() * self.size,), -1)
        dense_eid[slot] = eid
        dense_eid = dense_eid.view(-1, self.size)

        dense_nbr = eid.new_empty(touched.numel() * self.size)
        dense_nbr[slot] = nbr
        dense_nbr = dense_nbr.view(-1, self.size)

        # Merge with the stored history and keep the `size` most recent
        # entries (largest e_id) per node.
        merged_eid = torch.cat([self.e_id[touched, : self.size], dense_eid], dim=-1)
        merged_nbr = torch.cat(
            [self.neighbors[touched, : self.size], dense_nbr], dim=-1
        )
        merged_eid, keep = merged_eid.topk(self.size, dim=-1)
        self.e_id[touched] = merged_eid
        self.neighbors[touched] = torch.gather(merged_nbr, 1, keep)

    def reset_state(self):
        """Clears the whole history; -1 marks unused neighbor slots."""
        self.cur_e_id = 0
        self.e_id.fill_(-1)
================================================
FILE: modules/nodebank.py
================================================
import numpy as np
class NodeBank(object):
    r"""Tracks the set of all node ids observed so far.

    Parameters:
        src: source node id of the edges
        dst: destination node id of the edges
    """

    def __init__(self, src: np.ndarray, dst: np.ndarray):
        # Dict used as a set (value is always 1), mirroring the original API.
        self.nodebank = {}
        self.update_memory(src, dst)

    def update_memory(self, update_src: np.ndarray, update_dst: np.ndarray) -> None:
        r"""Registers every node id appearing as endpoint of the given edges.

        Parameters:
            update_src: source node id of the edges
            update_dst: destination node id of the edges
        """
        for u, v in zip(update_src, update_dst):
            self.nodebank.setdefault(u, 1)
            self.nodebank.setdefault(v, 1)

    def query_node(self, node: int) -> bool:
        r"""Returns True iff ``node`` has been seen before.

        Parameters:
            node: node id to query
        """
        return node in self.nodebank
================================================
FILE: modules/recurrencybaseline_predictor.py
================================================
"""
from paper: "History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting"
Julia Gastinger, Christian Meilicke, Federico Errica, Timo Sztyler, Anett Schuelke, Heiner Stuckenschmidt (IJCAI 2024)
@inproceedings{gastinger2024baselines,
title={History repeats itself: A Baseline for Temporal Knowledge Graph Forecasting},
author={Gastinger, Julia and Meilicke, Christian and Errica, Federico and Sztyler, Timo and Schuelke, Anett and Stuckenschmidt, Heiner},
booktitle={33rd International Joint Conference on Artificial Intelligence (IJCAI 2024)},
year={2024},
organization={International Joint Conferences on Artificial Intelligence Organization}
}
"""
import numpy as np
from collections import Counter
import ray
from modules.tkg_utils import create_scores_array
@ray.remote
def baseline_predict_remote(num_queries, test_data, all_data, window, basis_dict, num_nodes,
                num_rels, lmbda_psi, alpha, evaluator,first_test_ts, neg_sampler, split_mode='test'):
    """
    Apply baselines psi and xi (multiprocessing possible). See baseline_predict for more details."""
    # Thin ray wrapper: runs baseline_predict as a remote task so that
    # several query chunks can be scored in parallel.
    return baseline_predict(num_queries, test_data, all_data, window, basis_dict, num_nodes,
                num_rels, lmbda_psi, alpha, evaluator,first_test_ts, neg_sampler, split_mode)
def baseline_predict(num_queries, test_data, all_data, window, basis_dict, num_nodes,
                num_rels, lmbda_psi, alpha, evaluator,first_test_ts, neg_sampler, split_mode='test'):
    """
    Apply baselines psi and xi and compute scores and mrr per test or valid query (multiprocessing possible).
    Parameters:
        num_queries (int): minimum number of queries for each process
        test_data (np.array): test quadruples (only used in single-step prediction, depending on window specified);
            including inverse quadruples for subject prediction
        all_data (np.array): train valid and test quadruples (test only used in single-step prediction, depending
            on window specified); including inverse quadruples for subject prediction
        window: int, specifying which values from the past can be used for prediction. 0: all edges before the test
            query timestamp are included. -2: multistep. all edges from train and validation set used. as long as they are
            < first_test_query_ts. Int n > 0, all edges within n timestamps before the test query timestamp are included.
        basis_dict (dict): keys: rel_ids; specifies the predefined rules for each relation.
            in our case: head rel = tail rel, confidence =1 for all rels in train/valid set
        num_nodes (int): number of nodes in the dataset
        num_rels (int): number of relations in the dataset
        lmbda_psi (float): parameter for time decay function for baselinepsi. 0: no decay, >1 very steep decay
        alpha (float): parameter, weight to combine the scores from psi and xi. alpha*scores_psi + (1-alpha)*scores_xi
        evaluator (method): method to compute mrr and hits
        first_test_ts (int): timestamp of the first test query
        neg_sampler (NegSampler): negative sampler
        split_mode (str): 'test' or 'valid'
    Returns:
        performance_list and hits_list (one entry per query)
    """
    num_this_queries = len(test_data)
    cur_ts = test_data[0][3]
    first_test_query_ts = first_test_ts #test_data[0][3]
    # Get, for the current timestep, all previous quadruples per relation
    # that fulfill the time constraints implied by `window`.
    edges, all_data_ts = get_window_edges(all_data, cur_ts, window, first_test_query_ts)
    rel_obj_dist_cur_ts = update_distributions(edges, num_rels)
    if len(all_data_ts) >0:
        sum_delta_t = update_delta_t(np.min(all_data_ts[:,3]), np.max(all_data_ts[:,3]), cur_ts, lmbda_psi)

    # NOTE(review): predictions_psi is only initialized here, not reset per
    # query — a query whose relation has no rules appears to reuse the psi
    # scores of the previous query. Confirm whether this is intended.
    predictions_xi=np.zeros(num_nodes)
    predictions_psi=np.zeros(num_nodes)

    hits_list = [0] * num_this_queries
    perf_list = [0] * num_this_queries

    for j in range(num_this_queries):
        # Fetch the negative candidates for this (src, dst, ts, rel) query.
        neg_sample_el = neg_sampler.query_batch(np.expand_dims(np.array(test_data[j,0]), axis=0),
                                                np.expand_dims(np.array(test_data[j,2]), axis=0),
                                                np.expand_dims(np.array(test_data[j,4]), axis=0),
                                                np.expand_dims(np.array(test_data[j,1]), axis=0),
                                                split_mode)[0]
        pos_sample_el = test_data[j,2]
        test_query = test_data[j]
        assert(pos_sample_el == test_query[2])
        cands_dict = dict()
        cands_dict_psi = dict()

        # 1) update timestep and known triples
        if test_query[3] != cur_ts: # if we have a new timestep
            cur_ts = test_query[3]
            # Recompute the usable history for the new timestep.
            edges, all_data_ts = get_window_edges(all_data, cur_ts, window, first_test_query_ts)
            # Update the object and rel-object distributions to take into
            # account what timesteps to use.
            if window > -1: #otherwise: multistep, we do not need to update
                rel_obj_dist_cur_ts = update_distributions( edges, num_rels)
                if len(all_data_ts) >0:
                    if window > -1: #otherwise: multistep, we do not need to update
                        sum_delta_t = update_delta_t(np.min(all_data_ts[:,3]), np.max(all_data_ts[:,3]), cur_ts, lmbda_psi)

        #### BASELINE PSI
        # 2) apply rules for relation of interest, if we have any
        if str(test_query[1]) in basis_dict: # do we have rules for the given relation?
            # Find quadruples that match the rule (starting from the test
            # query subject): edges whose subject and relation match the
            # rule body. np array with [[sub, obj, ts]].
            walk_edges = match_body_relations(basis_dict[str(test_query[1])][0], edges, test_query[0])
            if 0 not in [len(x) for x in walk_edges]: # if we found at least one potential rule
                cands_dict_psi = get_candidates_psi(walk_edges[0][:,1:3], cur_ts, cands_dict, lmbda_psi, sum_delta_t)
                if len(cands_dict_psi)>0:
                    predictions_psi = create_scores_array(cands_dict_psi, num_nodes)

        #### BASELINE XI
        predictions_xi = create_scores_array(rel_obj_dist_cur_ts[test_query[1]], num_nodes)

        #### Combine Both
        # Scaled by 1000 before the convex combination of psi and xi scores.
        predictions_all = 1000*alpha*predictions_psi + 1000*(1-alpha)*predictions_xi

        # Score the positive target against its negative candidates.
        predictions_of_interest_pos = np.array(predictions_all[pos_sample_el])
        predictions_of_interest_neg = predictions_all[neg_sample_el]
        input_dict = {
            "y_pred_pos": predictions_of_interest_pos,
            "y_pred_neg": predictions_of_interest_neg,
            "eval_metric": ['mrr'],
        }
        predictions = evaluator.eval(input_dict)
        perf_list[j] = float(predictions['mrr'])
        hits_list[j] = float(predictions['hits@10'])

    return perf_list, hits_list
def match_body_relations(rule, edges, test_query_sub):
    """
    For rules of length 1: find quadruples that match the rule, i.e. edges
    whose subject equals the test query subject and whose relation matches
    the relation in the rule body. Memory-efficient implementation.

    modified from TLogic rule_application.py
    https://github.com/liu-yushan/TLogic/blob/main/mycode/rule_application.py
    shortened because we only have rules of length one

    Parameters:
        rule (dict): rule from rules_dict
        edges (dict): edges for rule application, keyed by relation id
        test_query_sub (int): test query subject
    Returns:
        walk_edges (list of np.ndarrays): [[sub, obj, ts]] rows that could
        constitute rule walks; [[]] if no edge carries the body relation
    """
    first_rel = rule["body_rels"][0]
    if first_rel not in edges:
        # The body relation never occurs in the usable history.
        return [[]]
    rel_edges = edges[first_rel]
    matching = rel_edges[rel_edges[:, 0] == test_query_sub]
    # Keep only the subject, object and timestamp columns.
    walk = np.hstack((matching[:, 0:1], matching[:, 2:4]))
    return [walk]
def score_delta(cands_ts, test_query_ts, lmbda):
    """Exponential time-decay score for a candidate, based on its distance
    to the current timestep and the decay parameter lambda.

    Parameters:
        cands_ts (int): timestep of candidate(s)
        test_query_ts (int): timestep of current test quadruple
        lmbda (float): param to specify how steep decay is
    Returns:
        score (float): score for a given candidate
    """
    # 2 ** (lambda * delta_t); delta_t <= 0, so scores lie in (0, 1].
    exponent = lmbda * (cands_ts - test_query_ts)
    return pow(2, exponent)
def get_window_edges(all_data, test_query_ts, window=-2, first_test_query_ts=0):
    """
    modified from Tlogic rule_application.py https://github.com/liu-yushan/TLogic/blob/main/mycode/rule_application.py
    introduce window -2
    Get the edges in the data (for rule application) that occur in the specified time window.
    If window is 0, all edges before the test query timestamp are included.
    If window is -2, all edges from train and validation set are used, as long as they are < first_test_query_ts.
    If window is an integer n > 0, all edges within n timestamps before the test query
    timestamp are included.

    Parameters:
        all_data (np.ndarray): complete dataset (train/valid/test)
        test_query_ts (np.ndarray): test query timestamp
        window (int): time window used for rule application
        first_test_query_ts (int): smallest timestamp from test set (eval_paper_authors)
    Returns:
        window_edges (dict): edges in the window for rule application, per relation
        all_data_ts (np.ndarray): the quadruples inside the window
    Raises:
        ValueError: if `window` is a negative value other than -2 or -200.
    """
    if window > 0:
        # Only edges within the `window` timestamps right before the query.
        mask = (all_data[:, 3] < test_query_ts) * (
            all_data[:, 3] >= test_query_ts - window
        )
    elif window == 0:
        # All edges strictly before the test query timestamp.
        mask = all_data[:, 3] < test_query_ts
    elif window == -2:  # modified eval_paper_authors: added this option
        # Multistep: everything before the first test timestamp, i.e. all
        # edges from the train and valid sets.
        mask = all_data[:, 3] < first_test_query_ts
    elif window == -200:  # modified eval_paper_authors: added this option
        # Multistep with an absolute window of 200 timestamps before the
        # first test timestamp.
        abswindow = 200
        mask = (all_data[:, 3] < first_test_query_ts) * (
            all_data[:, 3] >= first_test_query_ts - abswindow
        )
    else:
        # Previously any other value (e.g. -1) fell through with `mask` and
        # `window_edges` unbound and crashed with a NameError; fail loudly
        # with a meaningful error instead.
        raise ValueError(f"Unsupported window value: {window}")

    # Quadruples per relation that fulfill the time constraints.
    window_edges = quads_per_rel(all_data[mask])
    all_data_ts = all_data[mask]
    return window_edges, all_data_ts
def quads_per_rel(quads):
    """
    Group all quadruples by their relation.

    modified from TLogic rule_application.py
    https://github.com/liu-yushan/TLogic/blob/main/mycode/rule_application.py

    Parameters:
        quads (np.ndarray): indices of quadruples ([sub, rel, obj, ts] rows)
    Returns:
        edges (dict): for each relation id, the quadruples carrying it
    """
    # Column 1 holds the relation id.
    return {rel: quads[quads[:, 1] == rel] for rel in set(quads[:, 1])}
def get_candidates_psi(rule_walks, test_query_ts, cands_dict, lmbda, sum_delta_t):
    """
    Get answer candidates from the walks that follow the rule and score
    each one with the time-decayed psi score.
    originally from TLogic https://github.com/liu-yushan/TLogic/blob/main/mycode/apply.py but heavily modified

    Parameters:
        rule_walks (np.array): rule walks np array with [[obj, ts]] rows
        test_query_ts (int): test query timestamp
        cands_dict (dict): candidates along with the confidences of the rules that generated them
        lmbda (float): parameter to describe decay of the scoring function
        sum_delta_t: to be used in denominator of scoring fct
    Returns:
        cands_dict (dict): keys: candidates, values: score for the candidates
    """
    for cand in set(rule_walks[:, 0]):
        # All walks ending in this candidate contribute to its score.
        walks_for_cand = rule_walks[rule_walks[:, 0] == cand]
        cands_dict[cand] = score_psi(
            walks_for_cand, test_query_ts, lmbda, sum_delta_t
        ).astype(np.float64)
    return cands_dict
def update_delta_t(min_ts, max_ts, cur_ts, lmbda):
    """Compute the denominator for scoring function psi_delta: the sum of
    decay scores over every timestep in [min_ts, max_ts).

    Parameters:
        min_ts (int): minimum available timestep
        max_ts (int): maximum available timestep
        cur_ts (int): current timestep
        lmbda (float): time decay parameter
    Returns:
        delta_all (float): sum(delta_t for all available timesteps)
    """
    timesteps = np.arange(min_ts, max_ts)
    reference = np.ones(len(timesteps)) * cur_ts
    per_step = score_delta(timesteps, reference, lmbda)
    return np.sum(per_step)
def score_psi(cands_walks, test_query_ts, lmbda, sum_delta_t):
    """
    Calculate a candidate's score depending on the time difference between
    its supporting walks and the test query.

    Parameters:
        cands_walks (np.array): rule walks np array with [[obj, ts]] rows
        test_query_ts (int): test query timestamp
        lmbda (float): rate of exponential distribution
        sum_delta_t: normalizing denominator (sum of decays over all timesteps)
    Returns:
        score (float): candidate score
    """
    walk_ts = cands_walks[:, 1]
    query_ts_series = np.ones(len(walk_ts)) * test_query_ts
    decayed = score_delta(walk_ts, query_ts_series, lmbda)
    if sum_delta_t == 0:
        # Degenerate normalizer: report it and return the unnormalized sum.
        print(decayed, "sum_delta_t is zero")
        print(walk_ts)
        return np.sum(decayed)
    return np.sum(decayed) / sum_delta_t
def update_distributions(ts_edges, num_rels):
    """Recompute the per-relation object distribution from the edges that
    currently fall inside the prediction window."""
    return calculate_obj_distribution(ts_edges, num_rels)
def calculate_obj_distribution(edges, num_rels):
    """
    Calculate the object distribution for each relation in the data.

    Parameters:
        edges (dict): edges from the data, keyed by relation id
        num_rels (int): number of relations in the dataset
    Returns:
        rel_obj_dist (dict): object frequency distribution per relation;
        relations without edges map to an empty dict
    """
    # Start with an empty distribution for every known relation.
    rel_obj_dist_scaled = {rel: {} for rel in range(num_rels)}
    for rel, quad_arr in edges.items():
        objects = quad_arr[:, 2]
        counts = Counter(objects)
        total = len(objects)
        # Normalize counts to relative frequencies.
        rel_obj_dist_scaled[rel] = {obj: cnt / total for obj, cnt in counts.items()}
    return rel_obj_dist_scaled
# NOTE(review): this is an exact duplicate of `update_delta_t` defined
# earlier in this module; being defined later, it shadows the first one.
# Consider removing one of the two definitions.
def update_delta_t(min_ts, max_ts, cur_ts, lmbda):
    """ compute denominator for scoring function psi_delta
    Parameters:
        min_ts (int): minimum available timestep
        max_ts (int): maximum available timestep
        cur_ts (int): current timestep
        lmbda (float): time decay parameter
    Returns:
        delta_all (float): sum(delta_t for all available timesteps between min_ts and max_ts)
    """
    timesteps = np.arange(min_ts, max_ts)
    now = np.ones(len(timesteps))*cur_ts
    delta_all = score_delta(timesteps, now, lmbda)
    delta_all = np.sum(delta_all)
    return delta_all
================================================
FILE: modules/rgcn_layers.py
================================================
"""
https://github.com/Lee-zix/CEN/blob/main/rgcn/layers.py
"""
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
class RGCNLayer(nn.Module):
    """Base RGCN layer handling bias, self-loop, gated skip connection,
    dropout and layer normalization; message propagation itself is left to
    subclasses via `propagate`.

    from https://github.com/Lee-zix/CEN/blob/main/rgcn/layers.py
    """

    def __init__(self, in_feat, out_feat, bias=None, activation=None,
                 self_loop=False, skip_connect=False, dropout=0.0, layer_norm=False):
        """Init of the RGCN layer class.

        Args:
            in_feat (int): input feature dimensionality
            out_feat (int): output feature dimensionality
            bias: truthy to add a learnable bias of size `out_feat`
            activation: optional activation applied to the node output
            self_loop (bool): add a learnable self-loop transformation
            skip_connect (bool): gate in the previous layer's hidden state
            dropout (float): dropout rate applied to the self-loop message
            layer_norm (bool): apply non-affine LayerNorm to the output
        """
        super(RGCNLayer, self).__init__()
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        self.skip_connect = skip_connect
        self.layer_norm = layer_norm

        if self.bias:
            # Note: `self.bias` is re-bound from the boolean flag to the
            # actual parameter tensor.
            self.bias = nn.Parameter(torch.Tensor(out_feat))
            nn.init.xavier_uniform_(self.bias,
                                    gain=nn.init.calculate_gain('relu'))

        # weight for self loop
        if self.self_loop:
            self.loop_weight = nn.Parameter(torch.Tensor(in_feat, out_feat))
            nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain('relu'))

        if self.skip_connect:
            # Cross-layer gated skip connection (unlike the self-loop, this
            # mixes in the previous layer's hidden state).
            self.skip_connect_weight = nn.Parameter(torch.Tensor(out_feat, out_feat))
            nn.init.xavier_uniform_(self.skip_connect_weight,
                                    gain=nn.init.calculate_gain('relu'))
            self.skip_connect_bias = nn.Parameter(torch.Tensor(out_feat))
            nn.init.zeros_(self.skip_connect_bias)  # initialize to zero

        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        if self.layer_norm:
            self.normalization_layer = nn.LayerNorm(out_feat, elementwise_affine=False)

    # define how propagation is done in subclass
    def propagate(self, g):
        raise NotImplementedError

    def forward(self, g, prev_h=[]):
        """Runs one message-passing step over graph `g`.

        Args:
            g: graph with node features in ``g.ndata['h']``.
            prev_h: previous layer's hidden state; an empty sequence
                disables the skip connection even when `skip_connect` is set.
        Returns:
            Updated node representations (also written to ``g.ndata['h']``).
        """
        if self.self_loop:
            loop_message = torch.mm(g.ndata['h'], self.loop_weight)
            if self.dropout is not None:
                loop_message = self.dropout(loop_message)

        use_skip = len(prev_h) != 0 and self.skip_connect
        if use_skip:
            # Sigmoid gate in [0, 1] deciding how much of the previous
            # hidden state to keep.
            skip_weight = F.sigmoid(torch.mm(prev_h, self.skip_connect_weight) + self.skip_connect_bias)

        self.propagate(g)  # aggregate messages from neighboring nodes

        # apply bias and activation
        node_repr = g.ndata['h']
        # Fixed from `if self.bias:` — truth-testing a multi-element
        # nn.Parameter raises a RuntimeError whenever bias is enabled.
        if isinstance(self.bias, nn.Parameter):
            node_repr = node_repr + self.bias
        if use_skip:
            previous_node_repr = (1 - skip_weight) * prev_h
            if self.activation:
                node_repr = self.activation(node_repr)
            if self.self_loop:
                # The two branches differ in when activation is applied to
                # the (gated) self-loop message.
                if self.activation:
                    loop_message = skip_weight * self.activation(loop_message)
                else:
                    loop_message = skip_weight * loop_message
                node_repr = node_repr + loop_message
            node_repr = node_repr + previous_node_repr
        else:
            if self.self_loop:
                node_repr = node_repr + loop_message
            if self.layer_norm:
                node_repr = self.normalization_layer(node_repr)
            if self.activation:
                node_repr = self.activation(node_repr)

        g.ndata['h'] = node_repr
        return node_repr
class RGCNBasisLayer(RGCNLayer):
    """R-GCN layer with basis decomposition (Schlichtkrull et al.).

    Relation-specific weights are linear combinations of ``num_bases``
    shared basis matrices; for the input layer the matmul reduces to an
    embedding lookup on (relation type, source id).
    """

    def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None,
                 activation=None, is_input_layer=False):
        super(RGCNBasisLayer, self).__init__(in_feat, out_feat, bias, activation)
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.num_rels = num_rels
        self.num_bases = num_bases
        self.is_input_layer = is_input_layer
        # fall back to one basis per relation when the request is invalid
        if self.num_bases <= 0 or self.num_bases > self.num_rels:
            self.num_bases = self.num_rels
        # shared basis weights
        self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.in_feat,
                                                self.out_feat))
        if self.num_bases < self.num_rels:
            # per-relation linear-combination coefficients over the bases
            self.w_comp = nn.Parameter(torch.Tensor(self.num_rels,
                                                    self.num_bases))
        nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
        if self.num_bases < self.num_rels:
            nn.init.xavier_uniform_(self.w_comp,
                                    gain=nn.init.calculate_gain('relu'))

    def propagate(self, g):
        """Message passing with relation-typed weights; sums messages into ``h``."""
        if self.num_bases < self.num_rels:
            # generate all relation weights from the bases
            weight = self.weight.view(self.num_bases,
                                      self.in_feat * self.out_feat)
            weight = torch.matmul(self.w_comp, weight).view(
                self.num_rels, self.in_feat, self.out_feat)
        else:
            weight = self.weight

        if self.is_input_layer:
            def msg_func(edges):
                # for the input layer the matrix multiply is equivalent to an
                # embedding lookup keyed by (relation type, source node id)
                embed = weight.view(-1, self.out_feat)
                index = edges.data['type'] * self.in_feat + edges.src['id']
                return {'msg': embed.index_select(0, index)}
        else:
            def msg_func(edges):
                w = weight.index_select(0, edges.data['type'])
                # Fix: squeeze only the singleton middle dim. A bare
                # .squeeze() also dropped the batch dim when exactly one
                # edge was present (and the feature dim when out_feat == 1),
                # producing wrongly shaped messages.
                msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze(1)
                return {'msg': msg}

        def apply_func(nodes):
            # degree normalization
            return {'h': nodes.data['h'] * nodes.data['norm']}
        g.update_all(msg_func, fn.sum(msg='msg', out='h'), apply_func)
class RGCNBlockLayer(RGCNLayer):
    """R-GCN layer using block-diagonal weight decomposition.

    Each relation owns ``num_bases`` square-ish blocks of shape
    (submat_in, submat_out); a source feature vector is split into
    ``num_bases`` chunks which are transformed independently.
    """

    def __init__(self, in_feat, out_feat, num_rels, num_bases, bias=None,
                 activation=None, self_loop=False, dropout=0.0, skip_connect=False, layer_norm=False):
        super(RGCNBlockLayer, self).__init__(
            in_feat, out_feat, bias, activation,
            self_loop=self_loop, skip_connect=skip_connect, dropout=dropout)
        self.num_rels = num_rels
        self.num_bases = num_bases
        assert self.num_bases > 0
        self.out_feat = out_feat
        # in_feat and out_feat are assumed divisible by num_bases
        self.submat_in = in_feat // self.num_bases
        self.submat_out = out_feat // self.num_bases
        # flattened block-diagonal weights, one row per relation
        self.weight = nn.Parameter(torch.Tensor(
            self.num_rels, self.num_bases * self.submat_in * self.submat_out))
        nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))

    def msg_func(self, edges):
        """Transform each source feature chunk by its relation's block weights."""
        rel_weight = self.weight.index_select(0, edges.data['type'])
        rel_weight = rel_weight.view(-1, self.submat_in, self.submat_out)  # [edge_num * num_bases, submat_in, submat_out]
        src_chunks = edges.src['h'].view(-1, 1, self.submat_in)            # [edge_num * num_bases, 1, submat_in]
        out = torch.bmm(src_chunks, rel_weight)
        return {'msg': out.view(-1, self.out_feat)}                        # [edge_num, out_feat]

    def propagate(self, g):
        """Sum typed messages into ``h`` and apply degree normalization."""
        g.update_all(self.msg_func, fn.sum(msg='msg', out='h'), self.apply_func)

    def apply_func(self, nodes):
        return {'h': nodes.data['h'] * nodes.data['norm']}
class UnionRGCNLayer(nn.Module):
    """R-GCN layer used by RE-GCN/CEN ("uvrgcn").

    A message from a source entity is its embedding translated by the
    relation embedding (h_src + r) and projected through a shared
    neighbor weight. With ``self_loop``, nodes that receive messages use
    ``loop_weight`` while isolated nodes (in-degree 0) use
    ``evolve_loop_weight`` so their state still evolves.
    """

    def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None,
                 activation=None, self_loop=False, dropout=0.0, skip_connect=False, rel_emb=None):
        super(UnionRGCNLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        self.num_rels = num_rels
        self.skip_connect = skip_connect
        self.emb_rel = rel_emb
        self.ob = None
        self.sub = None
        # shared projection applied to every neighbor message
        self.weight_neighbor = nn.Parameter(torch.Tensor(self.in_feat, self.out_feat))
        nn.init.xavier_uniform_(self.weight_neighbor, gain=nn.init.calculate_gain('relu'))
        if self.self_loop:
            self.loop_weight = nn.Parameter(torch.Tensor(in_feat, out_feat))
            nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain('relu'))
            # used by isolated nodes instead of loop_weight
            self.evolve_loop_weight = nn.Parameter(torch.Tensor(in_feat, out_feat))
            nn.init.xavier_uniform_(self.evolve_loop_weight, gain=nn.init.calculate_gain('relu'))
        if self.skip_connect:
            # cross-layer gate (unlike the self-loop this mixes with the previous layer)
            self.skip_connect_weight = nn.Parameter(torch.Tensor(out_feat, out_feat))
            nn.init.xavier_uniform_(self.skip_connect_weight, gain=nn.init.calculate_gain('relu'))
            self.skip_connect_bias = nn.Parameter(torch.Tensor(out_feat))
            nn.init.zeros_(self.skip_connect_bias)  # gate bias starts at zero
        self.dropout = nn.Dropout(dropout) if dropout else None

    def propagate(self, g):
        """Sum translated neighbor messages into ``h`` with degree normalization."""
        g.update_all(lambda x: self.msg_func(x), fn.sum(msg='msg', out='h'), self.apply_func)

    def forward(self, g, prev_h):
        """One layer pass; ``prev_h`` (possibly empty list) feeds the skip gate."""
        if self.self_loop:
            # Fix: previously hard-coded .cuda(), which crashed CPU-only runs;
            # follow the device of the node features instead.
            device = g.ndata['h'].device
            # indices of nodes that have at least one incoming edge
            masked_index = torch.masked_select(
                torch.arange(0, g.number_of_nodes(), dtype=torch.long, device=device),
                (g.in_degrees(range(g.number_of_nodes())) > 0))
            loop_message = torch.mm(g.ndata['h'], self.evolve_loop_weight)
            loop_message[masked_index, :] = torch.mm(g.ndata['h'], self.loop_weight)[masked_index, :]
        if len(prev_h) != 0 and self.skip_connect:
            # sigmoid gate in (0, 1) mixing this layer with the previous one
            # (torch.sigmoid replaces the deprecated F.sigmoid)
            skip_weight = torch.sigmoid(torch.mm(prev_h, self.skip_connect_weight) + self.skip_connect_bias)
        # aggregate neighbor messages
        self.propagate(g)
        node_repr = g.ndata['h']
        if len(prev_h) != 0 and self.skip_connect:
            if self.self_loop:
                node_repr = node_repr + loop_message
            node_repr = skip_weight * node_repr + (1 - skip_weight) * prev_h
        else:
            if self.self_loop:
                node_repr = node_repr + loop_message
        if self.activation:
            node_repr = self.activation(node_repr)
        if self.dropout is not None:
            node_repr = self.dropout(node_repr)
        g.ndata['h'] = node_repr
        return node_repr

    def msg_func(self, edges):
        """Message = (h_src + r_type) @ weight_neighbor for each edge."""
        relation = self.emb_rel.index_select(0, edges.data['type']).view(-1, self.out_feat)
        node = edges.src['h'].view(-1, self.out_feat)
        # after adding inverse edges, messages are sent with the source as head entity
        msg = torch.mm(node + relation, self.weight_neighbor)
        return {'msg': msg}

    def apply_func(self, nodes):
        return {'h': nodes.data['h'] * nodes.data['norm']}
================================================
FILE: modules/rgcn_model.py
================================================
"""
https://github.com/nec-research/CEN/blob/main/src/model.py
"""
import torch.nn as nn
class BaseRGCN(nn.Module):
    """Skeleton for multi-layer R-GCN encoders.

    Subclasses provide the layers via ``build_input_layer`` /
    ``build_hidden_layer`` / ``build_output_layer``; ``build_model``
    assembles them into ``self.layers``.
    """

    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases=-1, num_basis=-1,
                 num_hidden_layers=1, dropout=0, self_loop=False, skip_connect=False, encoder_name="", opn="sub",
                 rel_emb=None, use_cuda=False, analysis=False):
        super(BaseRGCN, self).__init__()
        self.num_nodes = num_nodes
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.num_rels = num_rels
        self.num_bases = num_bases
        self.num_basis = num_basis
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        # fix: skip_connect was redundantly assigned twice in the original
        self.skip_connect = skip_connect
        self.self_loop = self_loop
        self.encoder_name = encoder_name
        self.use_cuda = use_cuda
        self.run_analysis = analysis
        print("use layer :{}".format(encoder_name))
        self.rel_emb = rel_emb
        self.opn = opn
        # create rgcn layers
        self.build_model()
        # create initial features
        self.features = self.create_features()

    def build_model(self):
        """Assemble input / hidden / output layers into ``self.layers``."""
        self.layers = nn.ModuleList()
        # i2h
        i2h = self.build_input_layer()
        if i2h is not None:
            self.layers.append(i2h)
        # h2h
        for idx in range(self.num_hidden_layers):
            h2h = self.build_hidden_layer(idx)
            self.layers.append(h2h)
        # h2o
        h2o = self.build_output_layer()
        if h2o is not None:
            self.layers.append(h2o)

    def create_features(self):
        """Initial per-node feature tensor, or None (default)."""
        return None

    def build_input_layer(self):
        """Optional input layer; None to skip."""
        return None

    def build_hidden_layer(self, idx):
        """Hidden layer ``idx``; must be provided by subclasses."""
        raise NotImplementedError

    def build_output_layer(self):
        """Optional output layer; None to skip."""
        return None

    def forward(self, g):
        """Run all layers over ``g`` and pop the resulting 'h' features.

        The prints below are upstream debug output, kept for behavioral
        parity; subclasses typically override this method.
        """
        if self.features is not None:
            g.ndata['id'] = self.features
        print("h before GCN message passing")
        print(g.ndata['h'])
        print("h behind GCN message passing")
        for layer in self.layers:
            layer(g)
        print(g.ndata['h'])
        return g.ndata.pop('h')
================================================
FILE: modules/rrgcn.py
================================================
"""
https://github.com/Lee-zix/CEN/blob/main/src/rrgcn.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from modules.rgcn_layers import UnionRGCNLayer, RGCNBlockLayer
from modules.rgcn_model import BaseRGCN
from modules.decoder import ConvTransE
import numpy as np
class RGCNCell(BaseRGCN):
    """R-GCN encoder cell built from UnionRGCNLayer hidden layers."""

    def build_hidden_layer(self, idx):
        """Create hidden layer ``idx``; only "uvrgcn" encoders are supported."""
        act = F.rrelu
        # only the first layer keeps the basis count
        if idx:
            self.num_basis = 0
        print("activate function: {}".format(act))
        # skip connections are never used on the very first layer
        use_skip = bool(self.skip_connect) and idx != 0
        if self.encoder_name != "uvrgcn":
            raise NotImplementedError
        return UnionRGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases,
                              activation=act, dropout=self.dropout,
                              self_loop=self.self_loop, skip_connect=use_skip,
                              rel_emb=self.rel_emb)

    def forward(self, g, init_ent_emb):
        """Seed node features from ``init_ent_emb`` and run every layer."""
        if self.encoder_name == "uvrgcn":
            node_id = g.ndata['id'].squeeze()
            g.ndata['h'] = init_ent_emb[node_id]
            for layer in self.layers:
                layer(g, [])
            return g.ndata.pop('h')
        if self.features is not None:
            print("----------------Feature is not None, Attention ------------")
            g.ndata['id'] = self.features
        node_id = g.ndata['id'].squeeze()
        g.ndata['h'] = init_ent_emb[node_id]
        if self.skip_connect:
            # thread each layer's output into the next layer's skip gate
            prev_h = []
            for layer in self.layers:
                prev_h = layer(g, prev_h)
        else:
            for layer in self.layers:
                layer(g, [])
        return g.ndata.pop('h')
class RecurrentRGCNCEN(nn.Module):
    """Recurrent R-GCN for CEN temporal knowledge-graph reasoning.

    Evolves entity embeddings over a sequence of graph snapshots with an
    RGCNCell and a learned time gate, then scores (s, r, o) triples with
    a ConvTransE decoder over an ensemble of history lengths.
    """

    def __init__(self, decoder_name, encoder_name, num_ents, num_rels, h_dim, opn, sequence_len, num_bases=-1, num_basis=-1,
                 num_hidden_layers=1, dropout=0, self_loop=False, skip_connect=False, layer_norm=False, input_dropout=0,
                 hidden_dropout=0, feat_dropout=0, entity_prediction=False, relation_prediction=False, use_cuda=False,
                 gpu=0):
        super(RecurrentRGCNCEN, self).__init__()
        self.decoder_name = decoder_name
        self.encoder_name = encoder_name
        self.num_rels = num_rels
        self.num_ents = num_ents
        self.opn = opn
        self.sequence_len = sequence_len
        self.h_dim = h_dim
        self.layer_norm = layer_norm
        self.h = None  # evolving entity-embedding state
        self.relation_prediction = relation_prediction
        self.entity_prediction = entity_prediction
        self.gpu = gpu
        # one embedding per relation and per inverse relation
        self.emb_rel = torch.nn.Parameter(torch.Tensor(self.num_rels * 2, self.h_dim), requires_grad=True).float()  # TODO: correct number?
        torch.nn.init.xavier_normal_(self.emb_rel)
        self.dynamic_emb = torch.nn.Parameter(torch.Tensor(num_ents, h_dim), requires_grad=True).float()
        torch.nn.init.normal_(self.dynamic_emb)
        self.loss_e = torch.nn.CrossEntropyLoss()
        self.rgcn = RGCNCell(num_ents,
                             h_dim,
                             h_dim,
                             num_rels * 2,
                             num_bases,
                             num_basis,
                             num_hidden_layers,
                             dropout,
                             self_loop,
                             skip_connect,
                             encoder_name,
                             self.opn,
                             self.emb_rel,
                             use_cuda)
        # learned gate mixing each snapshot's embedding with the previous state
        self.time_gate_weight = nn.Parameter(torch.Tensor(h_dim, h_dim))
        nn.init.xavier_uniform_(self.time_gate_weight, gain=nn.init.calculate_gain('relu'))
        self.time_gate_bias = nn.Parameter(torch.Tensor(h_dim))
        nn.init.zeros_(self.time_gate_bias)
        if decoder_name == "convtranse":
            self.decoder_ob = ConvTransE(num_ents, h_dim, input_dropout, hidden_dropout, feat_dropout,
                                         sequence_len=self.sequence_len, model_name='CEN')
        else:
            raise NotImplementedError

    def forward(self, g_list, use_cuda):
        """Evolve embeddings over ``g_list``; returns (per-step entity embeddings, relation embeddings)."""
        evolve_embs = []
        self.h = F.normalize(self.dynamic_emb) if self.layer_norm else self.dynamic_emb
        for g in g_list:
            g = g.to(self.gpu)
            current_h = self.rgcn.forward(g, self.h)
            current_h = F.normalize(current_h) if self.layer_norm else current_h
            # torch.sigmoid replaces the deprecated F.sigmoid
            time_weight = torch.sigmoid(torch.mm(self.h, self.time_gate_weight) + self.time_gate_bias)
            self.h = time_weight * current_h + (1 - time_weight) * self.h
            # NOTE(review): normalized unconditionally here (not gated on
            # self.layer_norm) — preserved from the original implementation.
            self.h = F.normalize(self.h)
            evolve_embs.append(self.h)
        return evolve_embs, self.emb_rel

    def predict(self, test_graph, test_triplets, use_cuda, neg_samples_batch=None, pos_samples_batch=None,
                evaluator=None, metric=None):
        """Score test triples; with negative samples, also compute per-query metrics.

        Returns (scores, perf_list, hits_list). perf_list / hits_list are
        empty when ``neg_samples_batch`` is None.
        """
        with torch.no_grad():
            # Fix: these were previously assigned only inside the
            # neg_samples_batch branch, so the final return raised
            # UnboundLocalError when neg_samples_batch was None.
            perf_list = []
            hits_list = []
            scores = torch.zeros(len(test_triplets), self.num_ents)
            if use_cuda:  # fix: was unconditionally .cuda(), breaking CPU-only runs
                scores = scores.cuda()
            # one embedding ensemble member per history length
            evolve_embeddings = []
            for idx in range(len(test_graph)):
                evolve_embs, r_emb = self.forward(test_graph[idx:], use_cuda)
                evolve_embeddings.append(evolve_embs[-1])
            evolve_embeddings.reverse()
            if neg_samples_batch is not None:  # added by tgb team
                for query_id, query in enumerate(neg_samples_batch):  # each sample separately
                    pos = pos_samples_batch[query_id]
                    neg = torch.tensor(query).to(pos.device)
                    # positive candidate first, then the negatives
                    candidates = torch.cat((pos.unsqueeze(0), neg), dim=0)
                    score_list = self.decoder_ob.forward(
                        evolve_embeddings, r_emb, test_triplets[query_id].unsqueeze(0),
                        samples_of_interest_emb=[emb[candidates] for emb in evolve_embeddings])
                    score_list = [s.unsqueeze(2) for s in score_list]
                    scores_b = torch.cat(score_list, dim=2)
                    scores_b = torch.softmax(scores_b, dim=1)
                    # ensemble: sum softmaxed scores over history lengths
                    scores_b = torch.sum(scores_b, dim=-1)
                    input_dict = {
                        "y_pred_pos": np.array([scores_b[0, 0].cpu()]),
                        "y_pred_neg": np.array(scores_b[0, 1:].cpu()),
                        "eval_metric": [metric],
                    }
                    prediction_perf = evaluator.eval(input_dict)
                    perf_list.append(prediction_perf[metric])
                    hits_list.append(prediction_perf['hits@10'])
            else:
                score_list = self.decoder_ob.forward(evolve_embeddings, r_emb, test_triplets, mode="test")
                score_list = [s.unsqueeze(2) for s in score_list]
                scores = torch.cat(score_list, dim=2)
                scores = torch.softmax(scores, dim=1)
                scores = torch.sum(scores, dim=-1)
            return scores, perf_list, hits_list

    def get_ft_loss(self, glist, triple_list, use_cuda):
        """Fine-tuning entity loss over all history lengths vs the latest triples."""
        glist = [g.to(self.gpu) for g in glist]
        loss_ent = torch.zeros(1).cuda().to(self.gpu) if use_cuda else torch.zeros(1)
        evolve_embeddings = []
        for idx in range(len(glist)):
            evolve_embs, r_emb = self.forward(glist[idx:], use_cuda)
            evolve_embeddings.append(evolve_embs[-1])
        evolve_embeddings.reverse()
        scores_ob = self.decoder_ob.forward(evolve_embeddings, r_emb, triple_list[-1])
        for idx in range(len(glist)):
            loss_ent += self.loss_e(scores_ob[idx], triple_list[-1][:, 2])
        return loss_ent

    def get_loss(self, glist, triples, prev_model, use_cuda):
        """Training entity loss; ``prev_model`` is unused here (kept for interface compatibility)."""
        loss_ent = torch.zeros(1).cuda().to(self.gpu) if use_cuda else torch.zeros(1)
        evolve_embeddings = []
        for idx in range(len(glist)):
            evolve_embs, r_emb = self.forward(glist[idx:], use_cuda)
            evolve_embeddings.append(evolve_embs[-1])
        evolve_embeddings.reverse()
        if self.entity_prediction:
            scores_ob = self.decoder_ob.forward(evolve_embeddings, r_emb, triples)
            for idx in range(len(glist)):
                loss_ent += self.loss_e(scores_ob[idx], triples[:, 2])
        return loss_ent
class RecurrentRGCNREGCN(nn.Module):
    """Recurrent R-GCN for RE-GCN temporal knowledge-graph reasoning.

    Evolves entity embeddings (time gate) and relation embeddings (GRU)
    over graph snapshots, optionally constrained by a static graph, and
    scores triples with a ConvTransE decoder.
    """

    def __init__(self, decoder_name, encoder_name, num_ents, num_rels, num_static_rels, num_words, h_dim, opn, sequence_len, num_bases=-1, num_basis=-1,
                 num_hidden_layers=1, dropout=0, self_loop=False, skip_connect=False, layer_norm=False, input_dropout=0,
                 hidden_dropout=0, feat_dropout=0, aggregation='cat', weight=1, discount=0, angle=0, use_static=False,
                 entity_prediction=False, relation_prediction=False, use_cuda=False,
                 gpu=0, analysis=False):
        super(RecurrentRGCNREGCN, self).__init__()
        self.decoder_name = decoder_name
        self.encoder_name = encoder_name
        self.num_rels = num_rels
        self.num_ents = num_ents
        self.opn = opn
        self.num_words = num_words
        self.num_static_rels = num_static_rels
        self.sequence_len = sequence_len
        self.h_dim = h_dim
        self.layer_norm = layer_norm
        self.h = None  # evolving entity-embedding state
        self.run_analysis = analysis
        self.aggregation = aggregation
        self.relation_evolve = False
        self.weight = weight      # weight of the static-constraint loss
        self.discount = discount  # 1: constraint angle grows per step; 0: fixed angle
        self.use_static = use_static
        self.angle = angle
        self.relation_prediction = relation_prediction
        self.entity_prediction = entity_prediction
        self.emb_rel = None
        self.gpu = gpu
        self.w1 = torch.nn.Parameter(torch.Tensor(self.h_dim, self.h_dim), requires_grad=True).float()
        torch.nn.init.xavier_normal_(self.w1)
        self.w2 = torch.nn.Parameter(torch.Tensor(self.h_dim, self.h_dim), requires_grad=True).float()
        torch.nn.init.xavier_normal_(self.w2)
        # one embedding per relation and per inverse relation
        self.emb_rel = torch.nn.Parameter(torch.Tensor(self.num_rels * 2, self.h_dim), requires_grad=True).float()
        torch.nn.init.xavier_normal_(self.emb_rel)
        self.dynamic_emb = torch.nn.Parameter(torch.Tensor(num_ents, h_dim), requires_grad=True).float()
        torch.nn.init.normal_(self.dynamic_emb)
        if self.use_static:
            self.words_emb = torch.nn.Parameter(torch.Tensor(self.num_words, h_dim), requires_grad=True).float()
            torch.nn.init.xavier_normal_(self.words_emb)
            # NOTE: attribute name keeps the upstream "statci" typo so that
            # existing checkpoints / state_dicts continue to load.
            self.statci_rgcn_layer = RGCNBlockLayer(self.h_dim, self.h_dim, self.num_static_rels * 2, num_bases,
                                                    activation=F.rrelu, dropout=dropout, self_loop=False, skip_connect=False)
            self.static_loss = torch.nn.MSELoss()
        self.loss_r = torch.nn.CrossEntropyLoss()
        self.loss_e = torch.nn.CrossEntropyLoss()
        self.rgcn = RGCNCell(num_ents,
                             h_dim,
                             h_dim,
                             num_rels * 2,
                             num_bases,
                             num_basis,
                             num_hidden_layers,
                             dropout,
                             self_loop,
                             skip_connect,
                             encoder_name,
                             self.opn,
                             self.emb_rel,
                             use_cuda,
                             analysis)
        # learned gate mixing each snapshot's embedding with the previous state
        self.time_gate_weight = nn.Parameter(torch.Tensor(h_dim, h_dim))
        nn.init.xavier_uniform_(self.time_gate_weight, gain=nn.init.calculate_gain('relu'))
        self.time_gate_bias = nn.Parameter(torch.Tensor(h_dim))
        nn.init.zeros_(self.time_gate_bias)
        # GRU cell for relation evolving
        self.relation_cell_1 = nn.GRUCell(self.h_dim * 2, self.h_dim)
        # decoder
        if decoder_name == "convtranse":
            self.decoder_ob = ConvTransE(num_ents, h_dim, input_dropout, hidden_dropout, feat_dropout)
        else:
            raise NotImplementedError

    def forward(self, g_list, static_graph, use_cuda):
        """Evolve embeddings over snapshots.

        Returns (per-step entity embeddings, static embeddings or None,
        relation embeddings, gate_list, degree_list); the two lists are
        always empty and kept only for interface compatibility.
        """
        gate_list = []
        degree_list = []
        if self.use_static:
            static_graph = static_graph.to(self.gpu)
            # evolved entity embeddings plus word embeddings, constrained by the static graph
            static_graph.ndata['h'] = torch.cat((self.dynamic_emb, self.words_emb), dim=0)
            self.statci_rgcn_layer(static_graph, [])
            static_emb = static_graph.ndata.pop('h')[:self.num_ents, :]
            static_emb = F.normalize(static_emb) if self.layer_norm else static_emb
            self.h = static_emb
            if torch.isnan(F.normalize(static_emb)).any() or torch.isinf(static_emb).any():
                print("static_emb is nan")
        else:
            self.h = F.normalize(self.dynamic_emb) if self.layer_norm else self.dynamic_emb[:, :]
            static_emb = None
        history_embs = []
        for i, g in enumerate(g_list):
            g = g.to(self.gpu)
            # mean-pool the entity embeddings attached to each relation of this snapshot
            temp_e = self.h[g.r_to_e]
            x_input = torch.zeros(self.num_rels * 2, self.h_dim).float().cuda() if use_cuda else torch.zeros(self.num_rels * 2, self.h_dim).float()
            for span, r_idx in zip(g.r_len, g.uniq_r):
                x = temp_e[span[0]:span[1], :]
                x_input[r_idx] = torch.mean(x, dim=0, keepdim=True)
            # GRU step for relation embeddings; the first step seeds from emb_rel,
            # later steps feed the previous step's output back in
            x_input = torch.cat((self.emb_rel, x_input), dim=1)
            self.h_0 = self.relation_cell_1(x_input, self.emb_rel if i == 0 else self.h_0)
            self.h_0 = F.normalize(self.h_0) if self.layer_norm else self.h_0
            current_h = self.rgcn.forward(g, self.h)
            current_h = F.normalize(current_h) if self.layer_norm else current_h
            # torch.sigmoid replaces the deprecated F.sigmoid
            time_weight = torch.sigmoid(torch.mm(self.h, self.time_gate_weight) + self.time_gate_bias)
            self.h = time_weight * current_h + (1 - time_weight) * self.h
            history_embs.append(self.h)
        return history_embs, static_emb, self.h_0, gate_list, degree_list

    def predict(self, test_graph, num_rels, static_graph, test_triplets, use_cuda, neg_samples_batch=None,
                pos_samples_batch=None, evaluator=None, metric=None):
        """Score test triples; with negative samples also compute per-query metrics.

        Fix: ``len(neg_samples_batch)`` was previously evaluated before the
        None check, raising TypeError when neg_samples_batch was None.
        """
        perf_list = [None] * len(neg_samples_batch) if neg_samples_batch is not None else []
        hits_list = [None] * len(neg_samples_batch) if neg_samples_batch is not None else []
        with torch.no_grad():
            all_triples = test_triplets
            evolve_embs, _, r_emb, _, _ = self.forward(test_graph, static_graph, use_cuda)
            embedding = F.normalize(evolve_embs[-1]) if self.layer_norm else evolve_embs[-1]
            if neg_samples_batch is not None:  # added by tgb team
                perf_list = []
                hits_list = []
                for query_id, query in enumerate(neg_samples_batch):  # each sample separately
                    pos = pos_samples_batch[query_id]
                    neg = torch.tensor(query).to(pos.device)
                    # positive candidate first, then the negatives
                    candidates = torch.cat((pos.unsqueeze(0), neg), dim=0)
                    score = self.decoder_ob.forward(embedding, r_emb, test_triplets[query_id].unsqueeze(0),
                                                    samples_of_interest_emb=embedding[candidates])
                    input_dict = {
                        "y_pred_pos": np.array([score[0, 0].cpu()]),
                        "y_pred_neg": np.array(score[0, 1:].cpu()),
                        "eval_metric": [metric],
                    }
                    prediction_perf = evaluator.eval(input_dict)
                    perf_list.append(prediction_perf[metric])
                    hits_list.append(prediction_perf['hits@10'])
            else:
                score = self.decoder_ob.forward(embedding, r_emb, all_triples, mode="test")
            return score, perf_list, hits_list

    def get_mask_nonzero(self, static_embedding):
        """Boolean mask: True where the row sum of ``static_embedding`` is non-zero."""
        mask = torch.sum(static_embedding, dim=1) != 0
        return mask

    def get_loss(self, glist, triples, static_graph, use_cuda):
        """Return (entity loss, relation loss, static-constraint loss).

        ``loss_rel`` is always zero here and kept for interface compatibility.
        """
        loss_ent = torch.zeros(1).cuda().to(self.gpu) if use_cuda else torch.zeros(1)
        loss_rel = torch.zeros(1).cuda().to(self.gpu) if use_cuda else torch.zeros(1)
        loss_static = torch.zeros(1).cuda().to(self.gpu) if use_cuda else torch.zeros(1)
        all_triples = triples
        all_triples = all_triples.to(self.gpu)
        evolve_embs, static_emb, r_emb, _, _ = self.forward(glist, static_graph, use_cuda)
        pre_emb = F.normalize(evolve_embs[-1]) if self.layer_norm else evolve_embs[-1]
        if self.entity_prediction:
            scores_ob = self.decoder_ob.forward(pre_emb, r_emb, all_triples).view(-1, self.num_ents)
            loss_ent += self.loss_e(scores_ob, all_triples[:, 2])
        if self.use_static:
            if self.discount == 1:
                for time_step, evolve_emb in enumerate(evolve_embs):
                    # constraint angle grows with the time step
                    step = (self.angle * math.pi / 180) * (time_step + 1)
                    if self.layer_norm:
                        if torch.isnan(F.normalize(evolve_emb)).any() or torch.isinf(evolve_emb).any():
                            print("evolve_emb is nan")
                        sim_matrix = torch.sum(static_emb * F.normalize(evolve_emb), dim=1)
                        if torch.isnan(sim_matrix).any() or torch.isinf(sim_matrix).any():
                            print("sim_matrix is nan")
                    else:
                        # cosine similarity with a zero-norm guard.
                        # Fix: the original rebound sim_matrix to zeros *before*
                        # dividing, so the guarded division operated on zeros and
                        # the similarity was always 0; divide the raw dot products.
                        raw_sim = torch.sum(static_emb * evolve_emb, dim=1)
                        c = torch.norm(static_emb, p=2, dim=1) * torch.norm(evolve_emb, p=2, dim=1)
                        non_zero_mask = c != 0
                        sim_matrix = torch.zeros_like(raw_sim)
                        sim_matrix[non_zero_mask] = raw_sim[non_zero_mask] / c[non_zero_mask]
                    # penalize pairs whose similarity falls below cos(step)
                    mask = (math.cos(step) - sim_matrix) > 0
                    loss_static += self.weight * torch.sum(torch.masked_select(math.cos(step) - sim_matrix, mask))
            elif self.discount == 0:
                for time_step, evolve_emb in enumerate(evolve_embs):
                    step = (self.angle * math.pi / 180)
                    if self.layer_norm:
                        sim_matrix = torch.sum(static_emb * F.normalize(evolve_emb), dim=1)
                    else:
                        sim_matrix = torch.sum(static_emb * evolve_emb, dim=1)
                        c = torch.norm(static_emb, p=2, dim=1) * torch.norm(evolve_emb, p=2, dim=1)
                        # NOTE(review): no zero-norm guard here (unlike discount==1);
                        # division by zero yields nan/inf entries, which the mask
                        # comparison below then silently drops. Preserved as-is.
                        sim_matrix = sim_matrix / c
                    mask = (math.cos(step) - sim_matrix) > 0
                    loss_static += self.weight * torch.sum(torch.masked_select(math.cos(step) - sim_matrix, mask))
        return loss_ent, loss_rel, loss_static
================================================
FILE: modules/sampler_core.cpp
================================================
#include
#include
#include
#include
#include
#include
#include
#include
#include
namespace py = pybind11;
typedef int NodeIDType;
typedef int EdgeIDType;
typedef float TimeStampType;
class TemporalGraphBlock
{
public:
std::vector row;
std::vector col;
std::vector eid;
std::vector ts;
std::vector dts;
std::vector nodes;
NodeIDType dim_in, dim_out;
double ptr_time = 0;
double search_time = 0;
double sample_time = 0;
double tot_time = 0;
double coo_time = 0;
TemporalGraphBlock() {}
TemporalGraphBlock(std::vector &_row, std::vector &_col,
std::vector &_eid, std::vector &_ts,
std::vector &_dts, std::vector &_nodes,
NodeIDType _dim_in, NodeIDType _dim_out) : row(_row), col(_col), eid(_eid), ts(_ts), dts(_dts),
nodes(_nodes), dim_in(_dim_in), dim_out(_dim_out) {}
};
class ParallelSampler
{
public:
std::vector indptr;
std::vector indices;
std::vector eid;
std::vector ts;
NodeIDType num_nodes;
EdgeIDType num_edges;
int num_thread_per_worker;
int num_workers;
int num_threads;
int num_layers;
std::vector num_neighbors;
bool recent;
bool prop_time;
int num_history;
TimeStampType window_duration;
std::vector::size_type>> ts_ptr;
omp_lock_t *ts_ptr_lock;
std::vector ret;
ParallelSampler(std::vector &_indptr, std::vector &_indices,
std::vector &_eid, std::vector &_ts,
int _num_thread_per_worker, int _num_workers, int _num_layers,
std::vector &_num_neighbors, bool _recent, bool _prop_time,
int _num_history, TimeStampType _window_duration) : indptr(_indptr), indices(_indices), eid(_eid), ts(_ts), prop_time(_prop_time),
num_thread_per_worker(_num_thread_per_worker), num_workers(_num_workers),
num_layers(_num_layers), num_neighbors(_num_neighbors), recent(_recent),
num_history(_num_history), window_duration(_window_duration)
{
omp_set_num_threads(num_thread_per_worker * num_workers);
num_threads = num_thread_per_worker * num_workers;
num_nodes = indptr.size() - 1;
num_edges = indices.size();
ts_ptr_lock = (omp_lock_t *)malloc(num_nodes * sizeof(omp_lock_t));
for (int i = 0; i < num_nodes; i++)
omp_init_lock(&ts_ptr_lock[i]);
ts_ptr.resize(num_history + 1);
for (auto it = ts_ptr.begin(); it != ts_ptr.end(); it++)
{
it->resize(indptr.size() - 1);
#pragma omp parallel for
for (auto itt = indptr.begin(); itt < indptr.end() - 1; itt++)
(*it)[itt - indptr.begin()] = *itt;
}
}
void reset()
{
for (auto it = ts_ptr.begin(); it != ts_ptr.end(); it++)
{
it->resize(indptr.size() - 1);
#pragma omp parallel for
for (auto itt = indptr.begin(); itt < indptr.end() - 1; itt++)
(*it)[itt - indptr.begin()] = *itt;
}
}
void update_ts_ptr(int slc, std::vector &root_nodes,
std::vector &root_ts, float offset)
{
#pragma omp parallel for schedule(static, int(ceil(static_cast (root_nodes.size()) / num_threads)))
for (std::vector::size_type i = 0; i < root_nodes.size(); i++)
{
NodeIDType n = root_nodes[i];
omp_set_lock(&(ts_ptr_lock[n]));
for (std::vector::size_type j = ts_ptr[slc][n]; j < indptr[n + 1]; j++)
{
// std::cout << "comparing " << ts[j] << " with " << root_ts[i] << std::endl;
if (ts[j] > (root_ts[i] + offset - 1e-7f))
{
if (j != ts_ptr[slc][n])
ts_ptr[slc][n] = j - 1;
break;
}
if (j == indptr[n + 1] - 1)
{
ts_ptr[slc][n] = j;
}
}
omp_unset_lock(&(ts_ptr_lock[n]));
}
}
inline void add_neighbor(std::vector *_row, std::vector *_col,
std::vector *_eid, std::vector *_ts,
std::vector *_dts, std::vector *_nodes,
EdgeIDType &k, TimeStampType &src_ts, int &row_id)
{
_row->push_back(row_id);
_col->push_back(_nodes->size());
_eid->push_back(eid[k]);
if (prop_time)
_ts->push_back(src_ts);
else
_ts->push_back(ts[k]);
_dts->push_back(src_ts - ts[k]);
_nodes->push_back(indices[k]);
// _row.push_back(0);
// _col.push_back(0);
// _eid.push_back(0);
// if (prop_time)
// _ts.push_back(src_ts);
// else
// _ts.push_back(10000);
// _nodes.push_back(100);
}
inline void combine_coo(TemporalGraphBlock &_ret, std::vector **_row,
std::vector **_col,
std::vector **_eid,
std::vector **_ts,
std::vector **_dts,
std::vector **_nodes,
std::vector &_out_nodes)
{
std::vector cum_row, cum_col;
cum_row.push_back(0);
cum_col.push_back(0);
for (int tid = 0; tid < num_threads; tid++)
{
// std::cout<size());
}
int num_root_nodes = _ret.nodes.size();
_ret.row.resize(cum_col.back());
_ret.col.resize(cum_col.back());
_ret.eid.resize(cum_col.back());
_ret.ts.resize(cum_col.back() + num_root_nodes);
_ret.dts.resize(cum_col.back() + num_root_nodes);
_ret.nodes.resize(cum_col.back() + num_root_nodes);
#pragma omp parallel for schedule(static, 1)
for (int tid = 0; tid < num_threads; tid++)
{
std::transform(_row[tid]->begin(), _row[tid]->end(), _row[tid]->begin(),
[&](auto &v)
{ return v + cum_row[tid]; });
std::transform(_col[tid]->begin(), _col[tid]->end(), _col[tid]->begin(),
[&](auto &v)
{ return v + cum_col[tid] + num_root_nodes; });
std::copy(_row[tid]->begin(), _row[tid]->end(), _ret.row.begin() + cum_col[tid]);
std::copy(_col[tid]->begin(), _col[tid]->end(), _ret.col.begin() + cum_col[tid]);
std::copy(_eid[tid]->begin(), _eid[tid]->end(), _ret.eid.begin() + cum_col[tid]);
std::copy(_ts[tid]->begin(), _ts[tid]->end(), _ret.ts.begin() + cum_col[tid] + num_root_nodes);
std::copy(_dts[tid]->begin(), _dts[tid]->end(), _ret.dts.begin() + cum_col[tid] + num_root_nodes);
std::copy(_nodes[tid]->begin(), _nodes[tid]->end(), _ret.nodes.begin() + cum_col[tid] + num_root_nodes);
delete _row[tid];
delete _col[tid];
delete _eid[tid];
delete _ts[tid];
delete _dts[tid];
delete _nodes[tid];
}
_ret.dim_in = _ret.nodes.size();
_ret.dim_out = cum_row.back();
}
// Sample one layer of temporal neighbors for every root node, in parallel.
// Fills ret[ret.size()-1-i] for each of the num_history snapshots: either the
// `neighs` most recent neighbors (recent mode) or `neighs` uniformly random
// neighbors within the valid time window, always strictly earlier than the
// root timestamp (offset by i * window_duration per snapshot).
// NOTE(review): template arguments inside <> (e.g. std::vector<NodeIDType>)
// were stripped by text extraction and some lines were merged; this listing is
// not compilable as-is — consult the original sampler_core.cpp.
void sample_layer(std::vector &_root_nodes, std::vector &_root_ts,
int neighs, bool use_ptr, bool from_root)
{
double t_s = omp_get_wtime();
std::vector *root_nodes;
std::vector *root_ts;
// layer 0 samples from the caller-provided roots; deeper layers re-use the
// nodes/timestamps produced by the previous layer (see the !from_root branch)
if (from_root)
{
root_nodes = &_root_nodes;
root_ts = &_root_ts;
}
double t_ptr_s = omp_get_wtime();
if (use_ptr)
update_ts_ptr(num_history, *root_nodes, *root_ts, 0);
ret[0].ptr_time += omp_get_wtime() - t_ptr_s;
// one pass per history snapshot, each shifted back by window_duration
for (int i = 0; i < num_history; i++)
{
if (!from_root)
{
root_nodes = &(ret[ret.size() - 1 - i - num_history].nodes);
root_ts = &(ret[ret.size() - 1 - i - num_history].ts);
}
TimeStampType offset = -i * window_duration;
t_ptr_s = omp_get_wtime();
if ((use_ptr) && (std::abs(window_duration) > 1e-7f))
update_ts_ptr(num_history - 1 - i, *root_nodes, *root_ts, offset - window_duration);
ret[0].ptr_time += omp_get_wtime() - t_ptr_s;
// per-thread output buffers, merged later by combine_coo
std::vector *_row[num_threads];
std::vector *_col[num_threads];
std::vector *_eid[num_threads];
std::vector *_ts[num_threads];
std::vector *_dts[num_threads];
std::vector *_nodes[num_threads];
std::vector _out_node(num_threads, 0);
int reserve_capacity = int(ceil((*root_nodes).size() / num_threads)) * neighs;
#pragma omp parallel
{
int tid = omp_get_thread_num();
// thread-local seed so rand_r below is reproducible per thread
unsigned int loc_seed = tid;
_row[tid] = new std::vector;
_col[tid] = new std::vector;
_eid[tid] = new std::vector;
_ts[tid] = new std::vector;
_dts[tid] = new std::vector;
_nodes[tid] = new std::vector;
_row[tid]->reserve(reserve_capacity);
_col[tid]->reserve(reserve_capacity);
_eid[tid]->reserve(reserve_capacity);
_ts[tid]->reserve(reserve_capacity);
_dts[tid]->reserve(reserve_capacity);
_nodes[tid]->reserve(reserve_capacity);
// #pragma omp critical
// std::cout<size()<<" "<((*root_nodes).size()) / num_threads)))
for (std::vector::size_type j = 0; j < (*root_nodes).size(); j++)
{
NodeIDType n = (*root_nodes)[j];
// NOTE(review): the "recent" branch header was garbled by extraction; it
// walks backwards from e_search taking up to `neighs` most recent edges.
// if (tid == 16)
// std::cout << _out_node[tid] << " " < std::max(s_search, e_search - neighs); k--)
{
// keep only edges strictly before the root time (epsilon guard)
if (ts[k] < nts + offset - 1e-7f)
{
add_neighbor(_row[tid], _col[tid], _eid[tid], _ts[tid],
_dts[tid], _nodes[tid], k, nts, _out_node[tid]);
}
}
}
else
{
// random sampling within ptr
for (int _i = 0; _i < neighs; _i++)
{
EdgeIDType picked = s_search + rand_r(&loc_seed) % (e_search - s_search + 1);
if (ts[picked] < nts + offset - 1e-7f)
{
add_neighbor(_row[tid], _col[tid], _eid[tid], _ts[tid],
_dts[tid], _nodes[tid], picked, nts, _out_node[tid]);
}
}
}
_out_node[tid] += 1;
if (tid == 0)
ret[0].sample_time += omp_get_wtime() - t_sample_s;
}
}
double t_coo_s = omp_get_wtime();
// append the roots themselves to the snapshot, then merge the per-thread
// buffers into one COO block (combine_coo also frees the buffers)
ret[ret.size() - 1 - i].ts.insert(ret[ret.size() - 1 - i].ts.end(),
root_ts->begin(), root_ts->end());
ret[ret.size() - 1 - i].nodes.insert(ret[ret.size() - 1 - i].nodes.end(),
root_nodes->begin(), root_nodes->end());
ret[ret.size() - 1 - i].dts.resize(root_nodes->size());
combine_coo(ret[ret.size() - 1 - i], _row, _col, _eid, _ts, _dts, _nodes, _out_node);
ret[0].coo_time += omp_get_wtime() - t_coo_s;
}
ret[0].tot_time += omp_get_wtime() - t_s;
}
// Entry point: sample num_layers hops for the given roots. Results accumulate
// in `ret` (num_history blocks per layer). The timestamp-pointer fast path
// (use_ptr) is only valid on the first layer, or when history/propagation
// settings make the pointers reusable, or in recent-neighbor mode.
void sample(std::vector &root_nodes, std::vector &root_ts)
{
// a weird bug, dgl library seems to modify the total number of threads
omp_set_num_threads(num_threads);
ret.resize(0);
bool first_layer = true;
bool use_ptr = false;
for (int i = 0; i < num_layers; i++)
{
// reserve num_history result blocks for this layer
ret.resize(ret.size() + num_history);
if ((first_layer) || ((prop_time) && num_history == 1) || (recent))
{
first_layer = false;
use_ptr = true;
}
else
use_ptr = false;
// layer 0 expands the caller's roots; deeper layers expand the frontier
// produced by the previous layer (from_root = false)
if (i == 0)
sample_layer(root_nodes, root_ts, num_neighbors[i], use_ptr, true);
else
sample_layer(root_nodes, root_ts, num_neighbors[i], use_ptr, false);
}
}
};
// Copy a C++ vector into a numpy array whose memory is owned by Python.
// The heap copy is tied to a py::capsule destructor so the Python garbage
// collector frees it when the array dies (no dangling pointer, no leak).
// NOTE(review): the template parameter list was stripped by text extraction.
template
inline py::array vec2npy(const std::vector &vec)
{
// need to let python garbage collector handle C++ vector memory
// see https://github.com/pybind/pybind11/issues/1042
auto v = new std::vector(vec);
auto capsule = py::capsule(v, [](void *v)
{ delete reinterpret_cast *>(v); });
return py::array(v->size(), v->data(), capsule);
// return py::array(vec.size(), vec.data());
}
// Python bindings: expose TemporalGraphBlock (accessors over the sampled COO
// arrays plus timing counters) and ParallelSampler (sample/reset/get_ret).
// NOTE(review): template arguments inside <> (e.g. py::class_<...>) were
// stripped by text extraction.
PYBIND11_MODULE(sampler_core, m)
{
py::class_(m, "TemporalGraphBlock")
.def(py::init &, std::vector &,
std::vector &, std::vector &,
std::vector &, std::vector &,
NodeIDType, NodeIDType>())
// array accessors return numpy copies owned by Python (see vec2npy)
.def("row", [](const TemporalGraphBlock &tgb)
{ return vec2npy(tgb.row); })
.def("col", [](const TemporalGraphBlock &tgb)
{ return vec2npy(tgb.col); })
.def("eid", [](const TemporalGraphBlock &tgb)
{ return vec2npy(tgb.eid); })
.def("ts", [](const TemporalGraphBlock &tgb)
{ return vec2npy(tgb.ts); })
.def("dts", [](const TemporalGraphBlock &tgb)
{ return vec2npy(tgb.dts); })
.def("nodes", [](const TemporalGraphBlock &tgb)
{ return vec2npy(tgb.nodes); })
.def("dim_in", [](const TemporalGraphBlock &tgb)
{ return tgb.dim_in; })
.def("dim_out", [](const TemporalGraphBlock &tgb)
{ return tgb.dim_out; })
// profiling counters accumulated during sampling
.def("tot_time", [](const TemporalGraphBlock &tgb)
{ return tgb.tot_time; })
.def("ptr_time", [](const TemporalGraphBlock &tgb)
{ return tgb.ptr_time; })
.def("search_time", [](const TemporalGraphBlock &tgb)
{ return tgb.search_time; })
.def("sample_time", [](const TemporalGraphBlock &tgb)
{ return tgb.sample_time; })
.def("coo_time", [](const TemporalGraphBlock &tgb)
{ return tgb.coo_time; });
py::class_(m, "ParallelSampler")
.def(py::init &, std::vector &,
std::vector &, std::vector &,
int, int, int, std::vector &, bool, bool,
int, TimeStampType>())
.def("sample", &ParallelSampler::sample)
.def("reset", &ParallelSampler::reset)
.def("get_ret", [](const ParallelSampler &ps)
{ return ps.ret; });
}
================================================
FILE: modules/sthn.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
import numpy as np
from torch import Tensor
from tqdm import tqdm
from sampler_core import ParallelSampler
import torch_sparse
import time
import copy
import random
from torch_sparse import SparseTensor
from torchmetrics.classification import MulticlassAUROC, MulticlassAveragePrecision
from torchmetrics.classification import BinaryAUROC, BinaryAveragePrecision
from sklearn.preprocessing import MinMaxScaler
import os
import pickle
"""
Source: STHN: utils.py
URL: https://github.com/celi52/STHN/blob/main/utils.py
"""
# utility function
def set_seed(seed):
    """Seed every RNG in use (python `random`, numpy, torch CPU and CUDA)."""
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
def row_norm(adj_t):
    """Row-normalise a sparse adjacency so each row sums to 1.

    Rows with zero degree stay all-zero (the 1/deg inf is masked out).
    Non-SparseTensor inputs are returned unchanged.
    """
    if not isinstance(adj_t, torch_sparse.SparseTensor):
        return adj_t
    deg = torch_sparse.sum(adj_t, dim=1)
    inv_deg = 1. / deg
    # zero-degree rows would divide by 0 -> inf; zero them instead
    inv_deg.masked_fill_(inv_deg == float('inf'), 0.)
    return torch_sparse.mul(adj_t, inv_deg.view(-1, 1))
"""
Source: STHN: construct_subgraph.py
URL: https://github.com/celi52/STHN/blob/main/construct_subgraph.py
Notes: The NegLinkSampler is only used for STHN internal sampling and not for TGB
"""
##############################################################################
##############################################################################
##############################################################################
# get sampler
class NegLinkSampler:
    """Uniform random negative-node sampler.

    From https://github.com/amazon-research/tgl/blob/main/sampler.py
    """

    def __init__(self, num_nodes):
        # size of the candidate pool [0, num_nodes)
        self.num_nodes = num_nodes

    def sample(self, n):
        """Return ``n`` node ids drawn uniformly at random from the pool."""
        return np.random.randint(self.num_nodes, size=n)
def get_parallel_sampler(g, num_neighbors=10):
    """
    Function wrapper of the C++ sampler (https://github.com/amazon-research/tgl/blob/main/sampler_core.cpp)
    Sample the 1-hop most recent neighbors of each node
    """
    sampler = ParallelSampler(
        g['indptr'],      # fixed: CSR index pointer of the temporal graph
        g['indices'],     # fixed: CSR neighbor indices
        g['eid'],         # fixed: edge ids
        g['ts'],          # fixed: edge timestamps
        32,               # num_thread_per_worker --> change based on machine's setup
        1,                # num_workers --> change based on machine's setup
        1,                # num_layers
        [num_neighbors],  # neighbors per hop --> hyper-parameter (Reddit 10, WIKI 30)
        True,             # recent --> fixed: never touch
        False,            # prop_time --> never touch
        1,                # num_history --> fixed: never touch
        0,                # window_duration --> fixed: never touch
    )
    # indptr has num_nodes + 1 entries
    neg_link_sampler = NegLinkSampler(g['indptr'].shape[0] - 1)
    return sampler, neg_link_sampler
##############################################################################
##############################################################################
##############################################################################
# sampling
def get_mini_batch(sampler, root_nodes, ts, num_hops):  # neg_samples is not used
    """
    Call function fetch_subgraph()
    Return: Subgraph of each node.
    """
    return [fetch_subgraph(sampler, node, node_time, num_hops)
            for node, node_time in zip(root_nodes, ts)]
def fetch_subgraph(sampler, root_node, root_time, num_hops):
    """
    Sample a temporal subgraph around a seed node (or node pair).

    :param sampler: C++ ParallelSampler; ``sample`` fills its internal result
        buffer, read back via ``get_ret()``.
    :param root_node: a single node id, or a list of ids (two sides of a link).
    :param root_time: timestamp of the seed(s); sampled edges are earlier.
    :param num_hops: number of 1-hop expansion rounds.
    :return: dict with re-indexed edges ('row', 'col', 'eid'), node info
        ('nodes', 'dts'), sizes, and the original root node/time.
    """
    all_row_col_times_nodes_eid = []
    # suppose sampling for both a single node and a node pair (two side of a link)
    if isinstance(root_node, list):
        nodes, ts = [i for i in root_node], [root_time for i in root_node]
    else:
        nodes, ts = [root_node], [root_time]
    # fetch all nodes+edges: each round expands the frontier returned by the
    # previous round (nodes/ts are overwritten with the sampled neighbors)
    for _ in range(num_hops):
        sampler.sample(nodes, ts)
        ret = sampler.get_ret()  # 1-hop recent neighbors
        row, col, eid = ret[0].row(), ret[0].col(), ret[0].eid()
        nodes, ts = ret[0].nodes(), ret[0].ts().astype(np.float32)
        # one record per edge: (src_time, src_node, dst_time, dst_node, eid)
        row_col_times_nodes_eid = np.stack([ts[row], nodes[row], ts[col], nodes[col], eid]).T
        all_row_col_times_nodes_eid.append(row_col_times_nodes_eid)
    all_row_col_times_nodes_eid = np.concatenate(all_row_col_times_nodes_eid, axis=0)
    # remove duplicate edges and sort according to the root node time (descending)
    all_row_col_times_nodes_eid = np.unique(all_row_col_times_nodes_eid, axis=0)[::-1]
    all_row_col_times_nodes = all_row_col_times_nodes_eid[:, :-1]
    eid = all_row_col_times_nodes_eid[:, -1]
    # remove duplicate (node+time) and sorted by time decending order
    all_row_col_times_nodes = np.array_split(all_row_col_times_nodes, 2, axis=1)
    times_nodes = np.concatenate(all_row_col_times_nodes, axis=0)
    times_nodes = np.unique(times_nodes, axis=0)[::-1]
    # each (node, time) pair identifies a node in the new subgraph index space
    node_2_ind = dict()
    for ind, (time, node) in enumerate(times_nodes):
        node_2_ind[(time, node)] = ind
    # translate the nodes into new index
    row = np.zeros(len(eid), dtype=np.int32)
    col = np.zeros(len(eid), dtype=np.int32)
    for i, ((t1, n1), (t2, n2)) in enumerate(zip(*all_row_col_times_nodes)):
        row[i] = node_2_ind[(t1, n1)]
        col[i] = node_2_ind[(t2, n2)]
    # fetch get time + node information
    eid = eid.astype(np.int32)
    ts = times_nodes[:, 0].astype(np.float32)
    nodes = times_nodes[:, 1].astype(np.int32)
    dts = root_time - ts  # make sure the root node time is 0
    return {
        # edge info: sorted with descending row (src) node temporal order
        'row': row,  # src
        'col': col,  # dst
        'eid': eid,
        # node info
        'nodes': nodes,  # sorted by the ascending order of node's dts (root_node's dts = 0)
        'dts': dts,
        # graph info
        'num_nodes': len(nodes),
        'num_edges': len(eid),
        # root info
        'root_node': root_node,
        'root_time': root_time,
    }
def construct_mini_batch_giant_graph(all_graphs, max_num_edges):
    """Merge the subgraphs from fetch_subgraph() into one disjoint giant graph.

    Node/edge indices of each subgraph are shifted by the running totals so
    they never collide, and at most ``max_num_edges`` edges are kept per
    subgraph. Indptr arrays delimit each original subgraph in the output.
    """
    rows, cols, eids, nodes, dts = [], [], [], [], []
    root_nodes, root_times = [], []
    node_indptr, edge_indptr = [0], [0]
    n_nodes_total = 0
    n_edges_total = 0
    for graph in all_graphs:
        kept_edges = min(graph['num_edges'], max_num_edges)
        # shift local node indices into the giant graph's index space
        rows.append(graph['row'][:kept_edges] + n_nodes_total)
        cols.append(graph['col'][:kept_edges] + n_nodes_total)
        eids.append(graph['eid'][:kept_edges])
        nodes.append(graph['nodes'])
        dts.append(graph['dts'])
        # update running totals + per-subgraph boundaries
        n_nodes_total += graph['num_nodes']
        node_indptr.append(n_nodes_total)
        n_edges_total += kept_edges
        edge_indptr.append(n_edges_total)
        root_nodes.append(graph['root_node'])
        root_times.append(graph['root_time'])
    all_rows = np.concatenate(rows).astype(np.int32)
    all_cols = np.concatenate(cols).astype(np.int32)
    all_dts = np.concatenate(dts).astype(np.float32)
    return {
        # for edges
        'row': all_rows,
        'col': all_cols,
        'eid': np.concatenate(eids).astype(np.int32),
        'edts': all_dts[all_cols] - all_dts[all_rows],
        # number of subgraphs + 1
        'all_node_indptr': np.array(node_indptr).astype(np.int32),
        'all_edge_indptr': np.array(edge_indptr).astype(np.int32),
        # for nodes
        'nodes': np.concatenate(nodes).astype(np.int32),
        'dts': all_dts,
        # general information
        'all_num_nodes': n_nodes_total,
        'all_num_edges': n_edges_total,
        # root nodes
        'root_nodes': np.array(root_nodes, dtype=np.int32),
        'root_times': np.array(root_times, dtype=np.float32),
    }
##############################################################################
##############################################################################
##############################################################################
def print_subgraph_data(subgraph_data):
    """
    Used to double check see if the sampled graph is as expected
    """
    for key, vals in subgraph_data.items():
        # arrays are summarised by shape; everything else printed verbatim
        shown = vals.shape if isinstance(vals, np.ndarray) else vals
        print(key, shown)
"""
Source: STHN data_process_utils.py
URL: https://github.com/celi52/STHN/blob/main/data_process_utils.py
Note:
Currently only using pre_compute_subgraphs because use_cached_subgraph is True
get_subgraph_sampler needs to be modified if use_cached_subgraph is False
The function get_all_inds is new to handle TGB evaluation
"""
class SubgraphSampler:
    """On-the-fly subgraph sampler over pre-batched root nodes/timestamps."""

    def __init__(self, all_root_nodes, all_ts, sampler, args):
        # one entry per mini-batch: root node ids and their timestamps
        self.all_root_nodes = all_root_nodes
        self.all_ts = all_ts
        self.sampler = sampler
        self.sampled_num_hops = args.sampled_num_hops

    def mini_batch(self, ind, mini_batch_inds):
        """Sample subgraphs for the selected roots of batch ``ind``."""
        batch_roots = self.all_root_nodes[ind][mini_batch_inds]
        batch_ts = self.all_ts[ind][mini_batch_inds]
        return get_mini_batch(self.sampler, batch_roots, batch_ts, self.sampled_num_hops)
def get_subgraph_sampler(args, g, df, mode):
    """Build a SubgraphSampler that draws subgraphs lazily (no caching).

    Roots of each batch are laid out as [srcs | dsts | uniform negatives];
    'train' mode uses ``args.extra_neg_samples`` negatives per edge, other
    modes use 1.
    """
    extra_neg_samples = args.extra_neg_samples if mode == 'train' else 1
    # for each node, sample its neighbors with the most recent neighbors (sorted)
    print('Sample subgraphs ... for %s mode'%mode)
    sampler, neg_link_sampler = get_parallel_sampler(g, args.num_neighbors)
    # select the dataframe split for this mode
    if mode == 'train':
        cur_df = df[args.train_mask]
    elif mode == 'valid':
        cur_df = df[args.val_mask]
    elif mode == 'test':
        cur_df = df[args.test_mask]
    loader = cur_df.groupby(cur_df.index // args.batch_size)
    print(cur_df.index, cur_df.index // args.batch_size)
    pbar = tqdm(total=len(loader))
    pbar.set_description('Pre-sampling: %s mode with negative sampleds %s ...'%(mode, extra_neg_samples))
    all_root_nodes = []
    all_ts = []
    for _, rows in loader:
        batch_roots = np.concatenate(
            [rows.src.values,
             rows.dst.values,
             neg_link_sampler.sample(len(rows) * extra_neg_samples)]
        ).astype(np.int32)
        all_root_nodes.append(batch_roots)
        # time-stamp for node = edge time-stamp
        all_ts.append(np.tile(rows.time.values, extra_neg_samples + 2).astype(np.float32))
        pbar.update(1)
    pbar.close()
    return SubgraphSampler(all_root_nodes, all_ts, sampler, args)
######################################################################################################
######################################################################################################
######################################################################################################
# for small dataset, we can cache each graph
def pre_compute_subgraphs(args, g, df, mode, negative_sampler=None, split_mode='test', cache=False):
    """Pre-sample (and optionally cache on disk) every mini-batch's subgraphs.

    :param args: needs data, batch_size, sampled_num_hops, num_neighbors,
        extra_neg_samples and the train/val/test masks.
    :param g: CSR graph dict consumed by get_parallel_sampler.
    :param df: edge dataframe with src/dst/time/label columns.
    :param mode: 'train', 'valid' or 'test' — selects the split mask.
    :param negative_sampler: optional TGB negative sampler; when given, its
        query_batch negatives replace the uniform ones.
    :param split_mode: forwarded to negative_sampler.query_batch.
    :param cache: if True, pickle the result for reuse.
    :return: tuple (all_subgraphs, all_elabel) — per-batch subgraph lists and
        per-batch edge labels.
    """
    # number of negatives stored per positive edge
    if mode == 'train':
        extra_neg_samples = args.extra_neg_samples
    else:
        extra_neg_samples = 1
    fn = os.path.join(os.getcwd(), 'DATA', args.data,
                      '%s_neg_sample_neg%d_bs%d_hops%d_neighbors%d.pickle'%(mode,
                                                                            extra_neg_samples,
                                                                            args.batch_size,
                                                                            args.sampled_num_hops,
                                                                            args.num_neighbors))
    if os.path.exists(fn):
        # FIX: use a context manager so the file handle is closed
        # deterministically (was `pickle.load(open(fn, 'rb'))`).
        with open(fn, 'rb') as f:
            subgraph_elabel = pickle.load(f)
    else:
        # for each node, sample its neighbors with the most recent neighbors (sorted)
        print('Sample subgraphs ... for %s mode'%mode)
        sampler, neg_link_sampler = get_parallel_sampler(g, args.num_neighbors)
        # setup modes
        if mode == 'train':
            cur_df = df[args.train_mask]
        elif mode == 'valid':
            cur_df = df[args.val_mask]
        elif mode == 'test':
            cur_df = df[args.test_mask]
        loader = cur_df.groupby(cur_df.index // args.batch_size)
        pbar = tqdm(total=len(loader))
        pbar.set_description('Pre-sampling: %s mode'%(mode,))
        all_subgraphs = []
        all_elabel = []
        sampler.reset()
        for _, rows in loader:
            if negative_sampler is not None:
                # TGB evaluation negatives: one list per positive edge
                neg_batch_list = negative_sampler.query_batch(
                    rows.src.values,
                    rows.dst.values,
                    rows.time.values,
                    rows.label.values,
                    split_mode=split_mode
                )
                neg_batch_list = np.concatenate(neg_batch_list)
                extra_neg_samples = neg_batch_list.shape[0] // len(rows)
            else:
                neg_batch_list = neg_link_sampler.sample(len(rows) * extra_neg_samples)
            # roots = [srcs | dsts | negatives]
            root_nodes = np.concatenate(
                [rows.src.values,
                 rows.dst.values,
                 neg_batch_list]
            ).astype(np.int32)
            # time-stamp for node = edge time-stamp
            ts = np.tile(rows.time.values, extra_neg_samples + 2).astype(np.float32)
            all_elabel.append(rows.label.values)
            all_subgraphs.append(get_mini_batch(sampler, root_nodes, ts, args.sampled_num_hops))
            pbar.update(1)
        pbar.close()
        subgraph_elabel = (all_subgraphs, all_elabel)
        if cache:
            # FIX: narrow exception types instead of a bare `except:` (which
            # also swallowed KeyboardInterrupt/SystemExit), close the file
            # handle, and report the actual failure reason.
            try:
                with open(fn, 'wb') as f:
                    pickle.dump(subgraph_elabel, f, protocol=pickle.HIGHEST_PROTOCOL)
            except (OSError, pickle.PicklingError) as e:
                print('Could not cache subgraphs to %s: %s' % (fn, e))
    return subgraph_elabel
def get_random_inds(num_subgraph, cached_neg_samples, neg_samples):
    """Select positive src/dst indices plus randomly chosen cached negatives.

    The cached layout is [srcs | dsts | neg block 2 | ... | neg block
    (1+cached_neg_samples)], each block of size ``batch_size``; one cached
    negative block index is drawn at random per positive.
    """
    batch_size = num_subgraph // (2 + cached_neg_samples)
    positives = np.arange(2 * batch_size)
    # per positive, pick which cached negative block to read from
    block_choice = np.random.randint(low=2, high=2 + cached_neg_samples,
                                     size=batch_size * neg_samples)
    negatives = batch_size * block_choice + np.arange(batch_size)
    return np.concatenate([positives, negatives]).astype(np.int32)
def get_all_inds(num_subgraph, neg_samples):
    """Deterministic index order for TGB evaluation: all positives, then
    every cached negative (no random selection)."""
    batch_size = num_subgraph // (2 + neg_samples)
    ordered = np.concatenate([
        np.arange(batch_size),                                  # positive sources
        batch_size + np.arange(batch_size),                     # positive destinations
        2 * batch_size + np.arange(batch_size * neg_samples),   # all cached negatives
    ])
    return ordered.astype(np.int32)
def check_data_leakage(args, g, df):
    """
    Sanity check: every sampled subgraph must only contain edges whose eid is
    strictly smaller than the target edge's eid — i.e. only past links are
    visible, so no information leaks from the future.
    Raises AssertionError on the first violation.
    """
    for mode in ['train', 'valid', 'test']:
        if mode == 'train':
            cur_df = df[:args.train_edge_end]
        elif mode == 'valid':
            cur_df = df[args.train_edge_end:args.val_edge_end]
        elif mode == 'test':
            cur_df = df[args.val_edge_end:]
        loader = cur_df.groupby(cur_df.index // args.batch_size)
        # FIX: pre_compute_subgraphs returns (all_subgraphs, all_elabel);
        # the old code indexed that 2-tuple with the batch index i, checking
        # the wrong object and raising IndexError for i >= 2. Unpack it.
        subgraphs, _ = pre_compute_subgraphs(args, g, df, mode)
        for i, (_, rows) in enumerate(loader):
            # positive edges only: srcs then dsts, each tagged with its eid
            eids = np.tile(rows.index.values, 2)
            cur_subgraphs = subgraphs[i][:args.batch_size * 2]
            for eid, cur_subgraph in zip(eids, cur_subgraphs):
                all_eids_in_subgraph = cur_subgraph['eid']
                if len(all_eids_in_subgraph) == 0:
                    continue
                # all edges in the sampled graph have eid smaller than the target edge's eid, i.e. sampled links never seen before
                assert sum(all_eids_in_subgraph < eid) == len(all_eids_in_subgraph)
    print('Does not detect information leakage ...')
"""
Source: STHN link_pred_train_utils.py
URL: https://github.com/celi52/STHN/blob/main/link_pred_train_utils.py
Notes: I created a separate function for get_inputs_for_ind so that we can use it for TGB evaluation as well
"""
def get_inputs_for_ind(subgraphs, mode, cached_neg_samples, neg_samples, node_feats, edge_feats, cur_df, cur_inds, ind, args):
    """
    Assemble the model inputs for mini-batch ``ind``.

    Selects the (src, dst, negative) subgraphs of the batch, merges them into
    one giant graph, gathers edge features and min-max-scaled time deltas, and
    optionally computes structural node features.

    :param subgraphs: (subgraph_list, edge_label_list) from
        pre_compute_subgraphs, or a SubgraphSampler wrapped the same way.
    :param mode: 'train', 'valid', 'test' or 'tgb-val'.
    :param cached_neg_samples: negatives stored per positive in the cache.
    :param neg_samples: negatives actually used this step.
    :param cur_inds: running dataframe offset for structural features.
    :return: (inputs list for the model, subgraph node feats or None,
        updated cur_inds)
    """
    subgraphs, elabel = subgraphs
    scaler = MinMaxScaler()
    if args.use_cached_subgraph == False and mode == 'train':
        # sample subgraphs on the fly from the SubgraphSampler
        subgraph_data_list = subgraphs.all_root_nodes[ind]
        mini_batch_inds = get_random_inds(len(subgraph_data_list), cached_neg_samples, neg_samples)
        subgraph_data = subgraphs.mini_batch(ind, mini_batch_inds)
    elif mode in ['test', 'tgb-val']:
        # TGB evaluation: use every cached negative, in deterministic order
        assert cached_neg_samples == neg_samples
        subgraph_data_list = subgraphs[ind]
        mini_batch_inds = get_all_inds(len(subgraph_data_list), cached_neg_samples)
        subgraph_data = [subgraph_data_list[i] for i in mini_batch_inds]
    else:  # sthn valid
        subgraph_data_list = subgraphs[ind]
        mini_batch_inds = get_random_inds(len(subgraph_data_list), cached_neg_samples, neg_samples)
        subgraph_data = [subgraph_data_list[i] for i in mini_batch_inds]
    subgraph_data = construct_mini_batch_giant_graph(subgraph_data, args.max_edges)
    # raw edge feats
    subgraph_edge_feats = edge_feats[subgraph_data['eid']]
    subgraph_edts = torch.from_numpy(subgraph_data['edts']).float()
    if args.use_graph_structure and node_feats:
        num_of_df_links = len(subgraph_data_list) // (cached_neg_samples+2)
        # subgraph_node_feats = compute_sign_feats(node_feats, df, cur_inds, num_of_df_links, subgraph_data['root_nodes'], args)
        # Erfan: change this part to use masked version
        subgraph_node_feats = compute_sign_feats(node_feats, cur_df, cur_inds, num_of_df_links, subgraph_data['root_nodes'], args)
        cur_inds += num_of_df_links
    else:
        subgraph_node_feats = None
    # scale the per-batch time deltas with min-max, then stretch to [0, 1000]
    scaler.fit(subgraph_edts.reshape(-1, 1))
    subgraph_edts = scaler.transform(subgraph_edts.reshape(-1, 1)).ravel().astype(np.float32) * 1000
    subgraph_edts = torch.from_numpy(subgraph_edts)
    # get mini-batch inds
    all_inds, has_temporal_neighbors = [], []
    # ignore an edge pair if (src_node, dst_node) does not have temporal neighbors
    all_edge_indptr = subgraph_data['all_edge_indptr']
    for i in range(len(all_edge_indptr)-1):
        num_edges = all_edge_indptr[i+1] - all_edge_indptr[i]
        # flat positions of this subgraph's edges in the padded (max_edges-wide) layout
        all_inds.extend([(args.max_edges * i + j) for j in range(num_edges)])
        has_temporal_neighbors.append(num_edges>0)
    if not args.predict_class:
        inputs = [
            subgraph_edge_feats.to(args.device),
            subgraph_edts.to(args.device),
            len(has_temporal_neighbors),
            torch.tensor(all_inds).long()
        ]
    else:
        # multi-class setting: also pass the batch's edge-type labels
        subgraph_edge_type = elabel[ind]
        inputs = [
            subgraph_edge_feats.to(args.device),
            subgraph_edts.to(args.device),
            len(has_temporal_neighbors),
            torch.tensor(all_inds).long(),
            torch.from_numpy(subgraph_edge_type).to(args.device)
        ]
    return inputs, subgraph_node_feats, cur_inds
def run(model, optimizer, args, subgraphs, df, node_feats, edge_feats, MLAUROC, MLAUPRC, mode):
    """Run one epoch over a split and return (AUROC, AUPRC, mean loss, time).

    :param optimizer: required only for mode='train'; pass None otherwise.
    :param MLAUROC / MLAUPRC: torchmetrics objects, reset here per epoch.
    :param mode: 'train' or 'valid'; 'test' is delegated to TGB evaluation.
    :raises NotImplementedError: for mode='test'.
    """
    time_epoch = 0
    cur_inds = 0
    if mode == 'train':
        model.train()
        cur_df = df[args.train_mask]
        neg_samples = args.neg_samples
        cached_neg_samples = args.extra_neg_samples
    elif mode == 'valid':
        model.eval()
        cur_df = df[args.val_mask]
        neg_samples = 1
        cached_neg_samples = 1
    elif mode == 'test':
        ## Erfan: remove this part use TGB evaluation
        # FIX: `raise('Use TGB evaluation')` raised a *string*, which is a
        # TypeError in Python 3; raise a real exception type instead.
        raise NotImplementedError('Use TGB evaluation')
    train_loader = cur_df.groupby(cur_df.index // args.batch_size)
    pbar = tqdm(total=len(train_loader))
    pbar.set_description('%s mode with negative samples %d ...'%(mode, neg_samples))
    # compute + training + fetch all scores
    loss_lst = []
    MLAUROC.reset()
    MLAUPRC.reset()
    for ind in range(len(train_loader)):
        inputs, subgraph_node_feats, cur_inds = get_inputs_for_ind(
            subgraphs, mode, cached_neg_samples, neg_samples,
            node_feats, edge_feats, cur_df, cur_inds, ind, args)
        start_time = time.time()
        loss, pred, edge_label = model(inputs, neg_samples, subgraph_node_feats)
        # FIX: identity comparison with None (`is not None`, not `!= None`)
        if mode == 'train' and optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        time_epoch += (time.time() - start_time)
        # torchmetrics .update() accumulates state and returns None;
        # the old unused `batch_auroc/batch_auprc` bindings are dropped
        MLAUROC.update(pred, edge_label)
        MLAUPRC.update(pred, edge_label)
        loss_lst.append(float(loss.detach()))
        pbar.update(1)
    pbar.close()
    total_auroc = MLAUROC.compute()
    total_auprc = MLAUPRC.compute()
    print('%s mode with time %.4f, AUROC %.4f, AUPRC %.4f, loss %.4f'%(mode, time_epoch, total_auroc, total_auprc, loss.item()))
    return_loss = np.mean(loss_lst)
    return total_auroc, total_auprc, return_loss, time_epoch
def link_pred_train(model, args, g, df, node_feats, edge_feats):
    """Train the link predictor and return a CPU copy of the model with the
    lowest validation loss (early stopping after 20 non-improving epochs).

    :param args: needs lr, weight_decay, epochs, use_cached_subgraph,
        predict_class (+ num_edgeType) and the options consumed downstream.
    """
    optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # get cached data
    if args.use_cached_subgraph:
        train_subgraphs = pre_compute_subgraphs(args, g, df, mode='train')
    else:
        train_subgraphs = get_subgraph_sampler(args, g, df, mode='train')
    valid_subgraphs = pre_compute_subgraphs(args, g, df, mode='valid')
    all_results = {
        'train_ap': [],
        'valid_ap': [],
        'train_auc': [],
        'valid_auc': [],
        'train_loss': [],
        'valid_loss': [],
    }
    low_loss = 100000
    # FIX: initialise best-model bookkeeping up front; previously
    # best_epoch/best_auc/best_auc_model were only bound inside the
    # improvement branch, so the early-stopping check and the final
    # print/return raised NameError if validation never improved.
    best_epoch = 0
    best_auc = 0.0
    best_auc_model = copy.deepcopy(model).cpu()
    user_train_total_time = 0
    user_epoch_num = 0
    if args.predict_class:
        # +1 class for the "no edge / background" label
        num_classes = args.num_edgeType + 1
        train_AUROC = MulticlassAUROC(num_classes, average="macro", thresholds=None)
        valid_AUROC = MulticlassAUROC(num_classes, average="macro", thresholds=None)
        train_AUPRC = MulticlassAveragePrecision(num_classes, average="macro", thresholds=None)
        valid_AUPRC = MulticlassAveragePrecision(num_classes, average="macro", thresholds=None)
    else:
        train_AUROC = BinaryAUROC(thresholds=None)
        valid_AUROC = BinaryAUROC(thresholds=None)
        train_AUPRC = BinaryAveragePrecision(thresholds=None)
        valid_AUPRC = BinaryAveragePrecision(thresholds=None)
    for epoch in range(args.epochs):
        print('>>> Epoch ', epoch+1)
        train_auc, train_ap, train_loss, time_train = run(model, optimizer, args, train_subgraphs, df,
                                                          node_feats, edge_feats, train_AUROC, train_AUPRC, mode='train')
        with torch.no_grad():
            # second variable (optimizer) is only required for training
            valid_auc, valid_ap, valid_loss, time_valid = run(copy.deepcopy(model), None, args, valid_subgraphs, df,
                                                              node_feats, edge_feats, valid_AUROC, valid_AUPRC, mode='valid')
        if valid_loss < low_loss:
            best_auc_model = copy.deepcopy(model).cpu()
            best_auc = valid_auc
            low_loss = valid_loss
            best_epoch = epoch
        user_train_total_time += time_train + time_valid
        user_epoch_num += 1
        # early stopping: no improvement for 20 epochs
        if epoch > best_epoch + 20:
            break
        all_results['train_ap'].append(train_ap)
        all_results['valid_ap'].append(valid_ap)
        all_results['valid_auc'].append(valid_auc)
        all_results['train_auc'].append(train_auc)
        all_results['train_loss'].append(train_loss)
        all_results['valid_loss'].append(valid_loss)
    print('best epoch %d, auc score %.4f'%(best_epoch, best_auc))
    return best_auc_model
def compute_sign_feats(node_feats, df, start_i, num_links, root_nodes, args):
    """
    SIGN-style structural features for the batch's root nodes.

    Builds an undirected, row-normalised adjacency from the most recent
    ``args.structure_time_gap`` edges before the batch position and sums the
    0..``args.structure_hops`` hop propagations of ``node_feats``.

    :param node_feats: (num_nodes, feat_dim) static node feature tensor.
    :param df: edge dataframe with src/dst columns (sliced positionally).
    :param start_i: dataframe offset where this mini-batch starts.
    :param num_links: number of positive links in the batch.
    :param root_nodes: root ids, laid out as num_duplicate groups of num_links.
    :param args: needs num_nodes, device, structure_hops, structure_time_gap.
    :return: (len(root_nodes), feat_dim) tensor of propagated features.
    """
    num_duplicate = len(root_nodes) // num_links
    num_nodes = args.num_nodes
    root_inds = torch.arange(len(root_nodes)).view(num_duplicate, -1)
    # chunk(1) keeps a single chunk, i.e. all roots are processed in one group
    root_inds = [arr.flatten() for arr in root_inds.chunk(1, dim=1)]
    output_feats = torch.zeros((len(root_nodes), node_feats.size(1))).to(args.device)
    i = start_i
    for _root_ind in root_inds:
        if i == 0 or args.structure_hops == 0:
            # no history available (or propagation disabled): raw features
            sign_feats = node_feats.clone()
        else:
            prev_i = max(0, i - args.structure_time_gap)
            cur_df = df[prev_i: i]  # get adj's row, col indices (as undirected)
            src = torch.from_numpy(cur_df.src.values)
            dst = torch.from_numpy(cur_df.dst.values)
            # symmetrise: add both (src, dst) and (dst, src)
            edge_index = torch.stack([
                torch.cat([src, dst]),
                torch.cat([dst, src])
            ])
            edge_index, edge_cnt = torch.unique(edge_index, dim=1, return_counts=True)
            mask = edge_index[0] != edge_index[1]  # ignore self-loops
            # binary adjacency (multiplicities collapsed to weight 1)
            adj = SparseTensor(
                value=torch.ones_like(edge_cnt[mask]).float(),
                row=edge_index[0][mask].long(),
                col=edge_index[1][mask].long(),
                sparse_sizes=(num_nodes, num_nodes)
            )
            adj_norm = row_norm(adj).to(args.device)
            # SIGN: sum of features propagated over 0..structure_hops hops
            sign_feats = [node_feats]
            for _ in range(args.structure_hops):
                sign_feats.append(adj_norm @ sign_feats[-1])
            sign_feats = torch.sum(torch.stack(sign_feats), dim=0)
        output_feats[_root_ind] = sign_feats[root_nodes[_root_ind]]
        # advance the dataframe offset by the number of links consumed
        i += len(_root_ind) // num_duplicate
    return output_feats
################################################################################################
################################################################################################
################################################################################################
"""
Source: STHN torch_encodings
URL: https://github.com/celi52/STHN/blob/main/torch_encodings.py
"""
def get_emb(sin_inp):
    """
    Gets a base embedding for one dimension with sin and cos intertwined
    """
    interleaved = torch.stack((torch.sin(sin_inp), torch.cos(sin_inp)), dim=-1)
    # (..., d, 2) -> (..., 2d): [sin0, cos0, sin1, cos1, ...]
    return interleaved.flatten(-2, -1)
class PositionalEncoding1D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding1D, self).__init__()
        self.org_channels = channels
        # round up to an even channel count (sin/cos come in pairs)
        channels = int(np.ceil(channels / 2) * 2)
        self.channels = channels
        inv_freq = 1.0 / (1000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_penc = None

    def forward(self, tensor):
        """
        :param tensor: A 3d tensor of size (batch_size, x, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, ch)
        """
        if tensor.dim() != 3:
            raise RuntimeError("The input tensor has to be 3d!")
        # reuse the cached encoding while the input shape is unchanged
        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            return self.cached_penc
        self.cached_penc = None
        batch_size, x, orig_ch = tensor.shape
        positions = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        emb = torch.zeros((x, self.channels), device=tensor.device).type(tensor.type())
        emb[:, : self.channels] = get_emb(angles)
        # trim back to the requested channel count and broadcast over the batch
        self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1)
        return self.cached_penc
class PositionalEncodingPermute1D(nn.Module):
    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x) instead of (batchsize, x, ch)
        """
        super(PositionalEncodingPermute1D, self).__init__()
        self.penc = PositionalEncoding1D(channels)

    def forward(self, tensor):
        # move channels last, encode, move channels back to dim 1
        channels_last = tensor.permute(0, 2, 1)
        encoded = self.penc(channels_last)
        return encoded.permute(0, 2, 1)

    @property
    def org_channels(self):
        return self.penc.org_channels
class PositionalEncoding2D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding2D, self).__init__()
        self.org_channels = channels
        # half the channels encode x, half encode y; each half is even
        channels = int(np.ceil(channels / 4) * 2)
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_penc = None

    def forward(self, tensor):
        """
        :param tensor: A 4d tensor of size (batch_size, x, y, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, ch)
        """
        if tensor.dim() != 4:
            raise RuntimeError("The input tensor has to be 4d!")
        # reuse the cached encoding while the input shape is unchanged
        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            return self.cached_penc
        self.cached_penc = None
        batch_size, x, y, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
        pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())
        angles_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        angles_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        emb_x = get_emb(angles_x).unsqueeze(1)  # broadcast over y
        emb_y = get_emb(angles_y)
        emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type(
            tensor.type()
        )
        emb[:, :, : self.channels] = emb_x
        emb[:, :, self.channels : 2 * self.channels] = emb_y
        self.cached_penc = emb[None, :, :, :orig_ch].repeat(tensor.shape[0], 1, 1, 1)
        return self.cached_penc
class PositionalEncodingPermute2D(nn.Module):
    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x, y) instead of (batchsize, x, y, ch)
        """
        super(PositionalEncodingPermute2D, self).__init__()
        self.penc = PositionalEncoding2D(channels)

    def forward(self, tensor):
        # channels-last for the encoder, then back to channels-first
        channels_last = tensor.permute(0, 2, 3, 1)
        encoded = self.penc(channels_last)
        return encoded.permute(0, 3, 1, 2)

    @property
    def org_channels(self):
        return self.penc.org_channels
class PositionalEncoding3D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding3D, self).__init__()
        self.org_channels = channels
        # a third of the channels per axis; kept even for sin/cos pairing
        channels = int(np.ceil(channels / 6) * 2)
        if channels % 2:
            channels += 1
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_penc = None

    def forward(self, tensor):
        """
        :param tensor: A 5d tensor of size (batch_size, x, y, z, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch)
        """
        if tensor.dim() != 5:
            raise RuntimeError("The input tensor has to be 5d!")
        # reuse the cached encoding while the input shape is unchanged
        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            return self.cached_penc
        self.cached_penc = None
        batch_size, x, y, z, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
        pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())
        pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type())
        angles_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        angles_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        angles_z = torch.einsum("i,j->ij", pos_z, self.inv_freq)
        emb_x = get_emb(angles_x).unsqueeze(1).unsqueeze(1)  # broadcast over y, z
        emb_y = get_emb(angles_y).unsqueeze(1)               # broadcast over z
        emb_z = get_emb(angles_z)
        emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type(
            tensor.type()
        )
        emb[:, :, :, : self.channels] = emb_x
        emb[:, :, :, self.channels : 2 * self.channels] = emb_y
        emb[:, :, :, 2 * self.channels :] = emb_z
        self.cached_penc = emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1)
        return self.cached_penc
class PositionalEncodingPermute3D(nn.Module):
    """Channel-first wrapper around PositionalEncoding3D.

    Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch).
    """

    def __init__(self, channels):
        super(PositionalEncodingPermute3D, self).__init__()
        self.penc = PositionalEncoding3D(channels)

    def forward(self, tensor):
        # Move channels last, encode, then restore the channel-first layout.
        channels_last = tensor.permute(0, 2, 3, 4, 1)
        encoded = self.penc(channels_last)
        return encoded.permute(0, 4, 1, 2, 3)

    @property
    def org_channels(self):
        return self.penc.org_channels
class Summer(nn.Module):
    """Adds a positional encoding to its input tensor."""

    def __init__(self, penc):
        """
        :param penc: The positional-encoding module to run the summer on.
        """
        super(Summer, self).__init__()
        self.penc = penc

    def forward(self, tensor):
        """
        :param tensor: A 3, 4 or 5d tensor that matches the model output size
        :return: Positional Encoding Matrix summed to the original tensor
        """
        encoding = self.penc(tensor)
        # The encoding must broadcast exactly; mismatches are programmer errors.
        assert tensor.size() == encoding.size(), (
            "The original tensor size {} and the positional encoding tensor size {} must match!".format(
                tensor.size(), encoding.size()
            )
        )
        return tensor + encoding
"""
Source: STHN model.py
URL: https://github.com/celi52/STHN/blob/main/model.py
"""
"""
Module: Time-encoder
"""
class TimeEncode(nn.Module):
    """Fixed (non-trainable) sinusoidal time encoder.

    out = linear(time_scatter): 1-->time_dims
    out = cos(out)
    """

    def __init__(self, dim):
        super(TimeEncode, self).__init__()
        self.dim = dim
        self.w = nn.Linear(1, dim)
        self.reset_parameters()

    def reset_parameters(self):
        # Frequencies follow a fixed log scale and are never trained.
        freqs = torch.from_numpy(1 / 10 ** np.linspace(0, 9, self.dim, dtype=np.float32))
        self.w.weight = nn.Parameter(freqs.reshape(self.dim, -1))
        self.w.bias = nn.Parameter(torch.zeros(self.dim))
        self.w.weight.requires_grad = False
        self.w.bias.requires_grad = False

    @torch.no_grad()
    def forward(self, t):
        # Flatten timestamps to a column vector, project, then take cos.
        return torch.cos(self.w(t.reshape((-1, 1))))
################################################################################################
################################################################################################
################################################################################################
"""
Module: STHN
"""
class FeedForward(nn.Module):
    """
    2-layer MLP with GeLU (fancy version of ReLU) as activation
    """

    def __init__(self, dims, expansion_factor, dropout=0, use_single_layer=False):
        super().__init__()
        self.dims = dims
        self.use_single_layer = use_single_layer
        self.expansion_factor = expansion_factor
        self.dropout = dropout
        if use_single_layer:
            self.linear_0 = nn.Linear(dims, dims)
        else:
            hidden = int(expansion_factor * dims)
            self.linear_0 = nn.Linear(dims, hidden)
            self.linear_1 = nn.Linear(hidden, dims)
        self.reset_parameters()

    def reset_parameters(self):
        self.linear_0.reset_parameters()
        if not self.use_single_layer:
            self.linear_1.reset_parameters()

    def forward(self, x):
        # Projection + GeLU, with dropout applied after every layer.
        out = F.dropout(F.gelu(self.linear_0(x)), p=self.dropout, training=self.training)
        if not self.use_single_layer:
            out = F.dropout(self.linear_1(out), p=self.dropout, training=self.training)
        return out
class TransformerBlock(nn.Module):
    """
    out = X.T + MLP_Layernorm(X.T) # apply token mixing
    out = out.T + MLP_Layernorm(out.T) # apply channel mixing
    """

    def __init__(self, dims,
                 channel_expansion_factor=4,
                 dropout=0.2,
                 module_spec=None, use_single_layer=False):
        super().__init__()
        # module_spec selects the active mixers, e.g. "token+channel".
        if module_spec is None:
            self.module_spec = ['token', 'channel']
        else:
            self.module_spec = module_spec.split('+')
        self.dims = dims
        if 'token' in self.module_spec:
            self.transformer_encoder = _MultiheadAttention(
                d_model=dims,
                n_heads=2,
                d_k=None,
                d_v=None,
                attn_dropout=dropout)
        if 'channel' in self.module_spec:
            self.channel_layernorm = nn.LayerNorm(dims)
            self.channel_forward = FeedForward(dims, channel_expansion_factor, dropout, use_single_layer)

    def reset_parameters(self):
        if 'token' in self.module_spec:
            self.transformer_encoder.reset_parameters()
        if 'channel' in self.module_spec:
            self.channel_layernorm.reset_parameters()
            self.channel_forward.reset_parameters()

    def token_mixer(self, x):
        # Self-attention across tokens (Q = K = V = x).
        return self.transformer_encoder(x, x, x)

    def channel_mixer(self, x):
        # LayerNorm then per-token MLP across channels.
        return self.channel_forward(self.channel_layernorm(x))

    def forward(self, x):
        # Residual connections around each active mixer.
        if 'token' in self.module_spec:
            x = x + self.token_mixer(x)
        if 'channel' in self.module_spec:
            x = x + self.channel_mixer(x)
        return x
class _MultiheadAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False):
        """Multi Head Attention Layer
        Input shape:
            Q: [batch_size (bs) x max_q_len x d_model]
            K, V: [batch_size (bs) x q_len x d_model]
            mask: [q_len x q_len]
        """
        super().__init__()
        # Default head dims split d_model evenly across heads.
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v
        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)
        # Scaled Dot-Product Attention (multiple heads)
        self.res_attention = res_attention
        self.sdp_attn = _ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout, res_attention=self.res_attention, lsa=lsa)
        # Project the concatenated heads back to d_model.
        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(attn_dropout))

    def reset_parameters(self):
        self.to_out[0].reset_parameters()
        self.W_Q.reset_parameters()
        self.W_K.reset_parameters()
        self.W_V.reset_parameters()

    def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,
                key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        """Return attention output of shape [bs x q_len x d_model].

        K and V default to Q (self-attention).
        """
        bs = Q.size(0)
        if K is None: K = Q
        if V is None: V = Q
        # Linear projections, then split into heads.
        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1, 2)      # q_s : [bs x n_heads x max_q_len x d_k]
        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0, 2, 3, 1)  # k_s : [bs x n_heads x d_k x q_len]
        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1, 2)      # v_s : [bs x n_heads x q_len x d_v]
        # BUG FIX: when res_attention=True, sdp_attn returns a 3-tuple
        # (output, attn_weights, attn_scores); the original 2-way unpacking
        # raised "ValueError: too many values to unpack". Unpack per mode.
        if self.res_attention:
            output, attn_weights, attn_scores = self.sdp_attn(
                q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        else:
            output, attn_weights = self.sdp_attn(
                q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        # Merge heads back to the original input dimensions.
        output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v)  # [bs x q_len x n_heads * d_v]
        output = self.to_out(output)
        return output
class _ScaledDotProductAttention(nn.Module):
r"""Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer
(Realformer: Transformer likes residual attention by He et al, 2020) and locality self sttention (Vision Transformer for Small-Size Datasets
by Lee et al, 2021)"""
def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False):
super().__init__()
self.attn_dropout = nn.Dropout(attn_dropout)
self.res_attention = res_attention
head_dim = d_model // n_heads
self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)
self.lsa = lsa
def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
'''
Input shape:
q : [bs x n_heads x max_q_len x d_k]
k : [bs x n_heads x d_k x seq_len]
v : [bs x n_heads x seq_len x d_v]
prev : [bs x n_heads x q_len x seq_len]
key_padding_mask: [bs x seq_len]
attn_mask : [1 x seq_len x seq_len]
Output shape:
output: [bs x n_heads x q_len x d_v]
attn : [bs x n_heads x q_len x seq_len]
scores : [bs x n_heads x q_len x seq_len]
'''
# Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence
attn_scores = torch.matmul(q, k) * self.scale # attn_scores : [bs x n_heads x max_q_len x q_len]
# Add pre-softmax attention scores from the previous layer (optional)
if prev is not None: attn_scores = attn_scores + prev
# Attention mask (optional)
if attn_mask is not None: # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len
if attn_mask.dtype == torch.bool:
attn_scores.masked_fill_(attn_mask, -np.inf)
else:
attn_scores += attn_mask
# Key padding mask (optional)
if key_padding_mask is not None: # mask with shape [bs x q_len] (only when max_w_len == q_len)
attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)
# normalize the attention weights
attn_weights = F.softmax(attn_scores, dim=-1) # attn_weights : [bs x n_heads x max_q_len x q_len]
attn_weights = self.attn_dropout(attn_weights)
# compute the new values given the attention weights
output = torch.matmul(attn_weights, v) # output: [bs x n_heads x max_q_len x d_v]
if self.res_attention: return output, attn_weights, attn_scores
else: return output, attn_weights
class FeatEncode(nn.Module):
    """
    Return [raw_edge_feat | TimeEncode(edge_time_stamp)]
    """

    def __init__(self, time_dims, feat_dims, out_dims):
        super().__init__()
        self.time_encoder = TimeEncode(time_dims)
        self.feat_encoder = nn.Linear(time_dims + feat_dims, out_dims)
        self.reset_parameters()

    def reset_parameters(self):
        self.time_encoder.reset_parameters()
        self.feat_encoder.reset_parameters()

    def forward(self, edge_feats, edge_ts):
        # Concatenate raw edge features with their time encoding, then project.
        time_feats = self.time_encoder(edge_ts)
        return self.feat_encoder(torch.cat([edge_feats, time_feats], dim=1))
class Patch_Encoding(nn.Module):
    """
    Input : [ batch_size, graph_size, edge_dims+time_dims]
    Output: [ batch_size, graph_size, output_dims]
    """

    def __init__(self, per_graph_size, time_channels,
                 input_channels, hidden_channels, out_channels,
                 num_layers, dropout,
                 channel_expansion_factor,
                 window_size,
                 module_spec=None,
                 use_single_layer=False
                 ):
        super().__init__()
        self.per_graph_size = per_graph_size
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        # Input encoder and output classifier head.
        self.feat_encoder = FeatEncode(time_channels, input_channels, hidden_channels)
        self.layernorm = nn.LayerNorm(hidden_channels)
        self.mlp_head = nn.Linear(hidden_channels, out_channels)
        # Stacked mixer blocks.
        # NOTE(review): the module_spec argument is accepted but not forwarded
        # (each block always receives module_spec=None) -- kept as-is; confirm intent.
        self.mixer_blocks = torch.nn.ModuleList(
            TransformerBlock(hidden_channels,
                             channel_expansion_factor,
                             dropout,
                             module_spec=None,
                             use_single_layer=use_single_layer)
            for _ in range(num_layers)
        )
        # Patch projection: window_size consecutive edges form one token.
        self.stride = window_size
        self.window_size = window_size
        self.pad_projector = nn.Linear(window_size * hidden_channels, hidden_channels)
        self.p_enc_1d_model_sum = Summer(PositionalEncoding1D(hidden_channels))
        self.reset_parameters()

    def reset_parameters(self):
        for block in self.mixer_blocks:
            block.reset_parameters()
        self.feat_encoder.reset_parameters()
        self.layernorm.reset_parameters()
        self.mlp_head.reset_parameters()

    def forward(self, edge_feats, edge_ts, batch_size, inds):
        # Encode [edge features | time encoding] for every observed edge.
        encoded = self.feat_encoder(edge_feats, edge_ts)
        # Scatter the encoded edges into a dense (batch * per_graph_size) buffer.
        x = torch.zeros((batch_size * self.per_graph_size, encoded.size(1)),
                        device=edge_feats.device)
        x[inds] = x[inds] + encoded
        # Fold window_size consecutive rows into a single patch token.
        x = x.view(-1, self.per_graph_size // self.window_size,
                   self.window_size * x.shape[-1])
        x = self.pad_projector(x)
        x = self.p_enc_1d_model_sum(x)
        for block in self.mixer_blocks:
            # Each block applies token and channel mixing with residuals.
            x = block(x)
        x = self.layernorm(x)
        # Mean-pool over patches, then classify.
        x = torch.mean(x, dim=1)
        return self.mlp_head(x)
################################################################################################
################################################################################################
################################################################################################
"""
Edge predictor
"""
class EdgePredictor_per_node(torch.nn.Module):
    """
    out = linear(src_node_feats) + linear(dst_node_feats)
    out = ReLU(out)
    """

    def __init__(self, dim_in_time, dim_in_node, predict_class):
        super().__init__()
        self.dim_in_time = dim_in_time
        self.dim_in_node = dim_in_node
        # Source and destination nodes get separate projections.
        self.src_fc = torch.nn.Linear(dim_in_time + dim_in_node, 100)
        self.dst_fc = torch.nn.Linear(dim_in_time + dim_in_node, 100)
        self.out_fc = torch.nn.Linear(100, predict_class)
        self.reset_parameters()

    def reset_parameters(self):
        self.src_fc.reset_parameters()
        self.dst_fc.reset_parameters()
        self.out_fc.reset_parameters()

    def forward(self, h, neg_samples=1):
        # h stacks [sources | positive dests | negative dests] row-wise.
        num_edge = h.shape[0] // (neg_samples + 2)
        h_src = self.src_fc(h[:num_edge])
        h_pos_dst = self.dst_fc(h[num_edge:2 * num_edge])
        h_neg_dst = self.dst_fc(h[2 * num_edge:])
        h_pos_edge = torch.nn.functional.relu(h_src + h_pos_dst)
        # Each source pairs with neg_samples negatives, hence the tiling.
        h_neg_edge = torch.nn.functional.relu(h_src.tile(neg_samples, 1) + h_neg_dst)
        return self.out_fc(h_pos_edge), self.out_fc(h_neg_edge)
class STHN_Interface(nn.Module):
    """Binary link-prediction wrapper: Patch_Encoding backbone + per-node edge predictor."""

    def __init__(self, mlp_mixer_configs, edge_predictor_configs):
        super(STHN_Interface, self).__init__()
        self.time_feats_dim = edge_predictor_configs['dim_in_time']
        self.node_feats_dim = edge_predictor_configs['dim_in_node']
        # The temporal backbone is only instantiated when time features are used.
        if self.time_feats_dim > 0:
            self.base_model = Patch_Encoding(**mlp_mixer_configs)
        self.edge_predictor = EdgePredictor_per_node(**edge_predictor_configs)
        # NOTE: attribute name "creterion" (sic) kept for backward compatibility.
        self.creterion = nn.BCEWithLogitsLoss(reduction='none')
        self.reset_parameters()

    def reset_parameters(self):
        if self.time_feats_dim > 0:
            self.base_model.reset_parameters()
        self.edge_predictor.reset_parameters()

    def forward(self, model_inputs, neg_samples, node_feats):
        """Return (loss, all_pred, all_edge_label) over positive and negative samples."""
        pred_pos, pred_neg = self.predict(model_inputs, neg_samples, node_feats)
        all_pred = torch.cat((pred_pos, pred_neg), dim=0)
        # Positives labeled 1, negatives labeled 0.
        all_edge_label = torch.cat((torch.ones_like(pred_pos),
                                    torch.zeros_like(pred_neg)), dim=0)
        loss = self.creterion(all_pred, all_edge_label).mean()
        return loss, all_pred, all_edge_label

    def predict(self, model_inputs, neg_samples, node_feats):
        """Compute (pred_pos, pred_neg) logits from time and/or node features.

        Raises:
            ValueError: if both time_feats_dim and node_feats_dim are 0.
        """
        if self.time_feats_dim > 0 and self.node_feats_dim == 0:
            x = self.base_model(*model_inputs)
        elif self.time_feats_dim > 0 and self.node_feats_dim > 0:
            x = self.base_model(*model_inputs)
            x = torch.cat([x, node_feats], dim=1)
        elif self.time_feats_dim == 0 and self.node_feats_dim > 0:
            x = node_feats
        else:
            # BUG FIX: the original only print()-ed here and then crashed with
            # UnboundLocalError on `x`; raise a clear error instead.
            raise ValueError('Either time_feats_dim or node_feats_dim must larger than 0!')
        pred_pos, pred_neg = self.edge_predictor(x, neg_samples=neg_samples)
        return pred_pos, pred_neg
class Multiclass_Interface(nn.Module):
    """Multi-class edge prediction wrapper (cross-entropy over edge classes)."""

    def __init__(self, mlp_mixer_configs, edge_predictor_configs):
        super(Multiclass_Interface, self).__init__()
        self.time_feats_dim = edge_predictor_configs['dim_in_time']
        self.node_feats_dim = edge_predictor_configs['dim_in_node']
        # The temporal backbone is only instantiated when time features are used.
        if self.time_feats_dim > 0:
            self.base_model = Patch_Encoding(**mlp_mixer_configs)
        self.edge_predictor = EdgePredictor_per_node(**edge_predictor_configs)
        # NOTE: attribute name "creterion" (sic) kept for backward compatibility.
        self.creterion = nn.CrossEntropyLoss(reduction='none')
        self.reset_parameters()

    def reset_parameters(self):
        if self.time_feats_dim > 0:
            self.base_model.reset_parameters()
        self.edge_predictor.reset_parameters()

    def forward(self, model_inputs, neg_samples, node_feats):
        """Return (loss, all_pred, all_edge_label); model_inputs[-1] carries positive labels."""
        pos_edge_label = model_inputs[-1].view(-1, 1)
        model_inputs = model_inputs[:-1]
        pred_pos, pred_neg = self.predict(model_inputs, neg_samples, node_feats)
        all_pred = torch.cat((pred_pos, pred_neg), dim=0)
        # Negative samples are assigned class 0.
        all_edge_label = torch.squeeze(torch.cat((pos_edge_label, torch.zeros_like(pos_edge_label)), dim=0))
        loss = self.creterion(all_pred, all_edge_label).mean()
        return loss, all_pred, all_edge_label

    def predict(self, model_inputs, neg_samples, node_feats):
        """Compute (pred_pos, pred_neg) logits from time and/or node features.

        Raises:
            ValueError: if both time_feats_dim and node_feats_dim are 0.
        """
        if self.time_feats_dim > 0 and self.node_feats_dim == 0:
            x = self.base_model(*model_inputs)
        elif self.time_feats_dim > 0 and self.node_feats_dim > 0:
            x = self.base_model(*model_inputs)
            x = torch.cat([x, node_feats], dim=1)
        elif self.time_feats_dim == 0 and self.node_feats_dim > 0:
            x = node_feats
        else:
            # BUG FIX: the original only print()-ed here and then crashed with
            # UnboundLocalError on `x`; raise a clear error instead.
            raise ValueError('Either time_feats_dim or node_feats_dim must larger than 0!')
        pred_pos, pred_neg = self.edge_predictor(x, neg_samples=neg_samples)
        return pred_pos, pred_neg
================================================
FILE: modules/sthn_sampler_setup.py
================================================
# Build script for the "sampler_core" pybind11 extension used by STHN.
from glob import glob
from setuptools import setup
from pybind11.setup_helpers import Pybind11Extension

# Compile the C++ sampler with OpenMP so neighbor sampling runs in parallel.
ext_modules = [
    Pybind11Extension("sampler_core",
        ['sampler_core.cpp'],
        extra_compile_args = ['-fopenmp'],
        extra_link_args = ['-fopenmp'],),
]

setup(
    name = "sampler_core",
    version = "0.0.1",
    # "XXXX-*" values appear to be anonymization placeholders -- fill in before release.
    author = "XXXX-2",
    author_email = "XXXX-3",
    url = "XXXX-4",
    description = "Parallel Sampling for Temporal Graphs",
    ext_modules = ext_modules,
)
================================================
FILE: modules/time_enc.py
================================================
"""
Time Encoding Module
Reference:
- https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/tgn.html
"""
import torch
from torch import Tensor
from torch.nn import Linear
class TimeEncoder(torch.nn.Module):
    """Learnable cosine time encoder: t -> cos(Linear(t))."""

    def __init__(self, out_channels: int):
        super().__init__()
        self.out_channels = out_channels
        self.lin = Linear(1, out_channels)

    def reset_parameters(self):
        self.lin.reset_parameters()

    def forward(self, t: Tensor) -> Tensor:
        # Timestamps become a column vector, are projected, then cosined.
        column = t.view(-1, 1)
        return self.lin(column).cos()
================================================
FILE: modules/timetraveler_agent.py
================================================
"""
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting
Reference:
- https://github.com/JHL-HUST/TITer/blob/master/model/agent.py
Haohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class HistoryEncoder(nn.Module):
    """LSTM-cell encoder of the agent's action history."""

    def __init__(self, config):
        super(HistoryEncoder, self).__init__()
        self.config = config
        self.lstm_cell = torch.nn.LSTMCell(input_size=config['action_dim'],
                                           hidden_size=config['state_dim'])

    def set_hiddenx(self, batch_size):
        """Set hidden layer parameters. Initialize to 0"""
        state_dim = self.config['state_dim']
        if self.config['cuda']:
            self.hx = torch.zeros(batch_size, state_dim, device='cuda')
            self.cx = torch.zeros(batch_size, state_dim, device='cuda')
        else:
            self.hx = torch.zeros(batch_size, state_dim)
            self.cx = torch.zeros(batch_size, state_dim)

    def forward(self, prev_action, mask):
        """mask: True if NO_OP. NO_OP does not affect history coding results"""
        self.hx_, self.cx_ = self.lstm_cell(prev_action, (self.hx, self.cx))
        # Keep the previous state wherever the previous action was NO_OP.
        self.hx = torch.where(mask, self.hx, self.hx_)
        self.cx = torch.where(mask, self.cx, self.cx_)
        return self.hx
class PolicyMLP(nn.Module):
    """Two-layer MLP mapping the agent state to an action-space representation."""

    def __init__(self, config):
        super(PolicyMLP, self).__init__()
        self.mlp_l1 = nn.Linear(config['mlp_input_dim'], config['mlp_hidden_dim'], bias=True)
        self.mlp_l2 = nn.Linear(config['mlp_hidden_dim'], config['action_dim'], bias=True)

    def forward(self, state_query):
        hidden = torch.relu(self.mlp_l1(state_query))
        # Add a singleton axis: output is [batch, 1, action_dim].
        return self.mlp_l2(hidden).unsqueeze(1)
class DynamicEmbedding(nn.Module):
    """Entity embedding concatenated with a cosine encoding of the time delta."""

    def __init__(self, n_ent, dim_ent, dim_t):
        super(DynamicEmbedding, self).__init__()
        # Reserve dim_t channels of the entity dimension for the time encoding.
        self.ent_embs = nn.Embedding(n_ent, dim_ent - dim_t)
        # Log-scale frequencies and phases, stored as trainable Parameters.
        self.w = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dim_t))).float())
        self.b = torch.nn.Parameter(torch.zeros(dim_t).float())

    def forward(self, entities, dt):
        dt = dt.unsqueeze(-1)
        n_batch = dt.size(0)
        n_seq = dt.size(1)
        dt = dt.view(n_batch, n_seq, 1)
        # Sinusoidal time feature: cos(w * dt + b).
        time_feat = torch.cos(self.w.view(1, 1, -1) * dt + self.b.view(1, 1, -1))
        time_feat = time_feat.squeeze(1)  # [batch_size, time_dim] when n_seq == 1
        ent_feat = self.ent_embs(entities)
        return torch.cat((ent_feat, time_feat), -1)
class StaticEmbedding(nn.Module):
    """Time-independent entity embeddings; `timestamps` is accepted but ignored."""

    def __init__(self, n_ent, dim_ent):
        super(StaticEmbedding, self).__init__()
        self.ent_embs = nn.Embedding(n_ent, dim_ent)

    def forward(self, entities, timestamps=None):
        # Pure lookup -- the static variant does not use time information.
        return self.ent_embs(entities)
class Agent(nn.Module):
    """Policy network for TimeTraveler: given the walk's current position on a
    temporal KG and a query (entity, relation, timestamp), it scores the
    candidate actions, samples one, and returns the policy-gradient loss terms.
    """

    def __init__(self, config):
        super(Agent, self).__init__()
        # Relation vocabulary: originals + reversed + NO_OP + padding.
        self.num_rel = config['num_rel'] * 2 + 2
        self.config = config
        # [0, num_rel) -> normal relations; num_rel -> stay in place,(num_rel, num_rel * 2] reversed relations.
        self.NO_OP = self.num_rel  # Stay in place; No Operation
        self.ePAD = config['num_ent']  # Padding entity
        self.rPAD = config['num_rel'] * 2 + 1  # Padding relation
        self.tPAD = 0  # Padding time
        if self.config['entities_embeds_method'] == 'dynamic':
            # +1 entity slot accommodates the padding entity ePAD.
            self.ent_embs = DynamicEmbedding(config['num_ent']+1, config['ent_dim'], config['time_dim'])
        else:
            self.ent_embs = StaticEmbedding(config['num_ent']+1, config['ent_dim'])
        # NOTE(review): this table is indexed by relation ids up to
        # rPAD = num_rel*2+1 but is sized by num_ent -- confirm num_ent >= num_rel*2+2.
        self.rel_embs = nn.Embedding(config['num_ent'], config['rel_dim'])
        self.policy_step = HistoryEncoder(config)
        self.policy_mlp = PolicyMLP(config)
        # Produces the scalar gate that mixes relation- vs entity-similarity scores.
        self.score_weighted_fc = nn.Linear(
            self.config['ent_dim'] * 2 + self.config['rel_dim'] * 2 + self.config['state_dim'],
            1, bias=True)

    def forward(self, prev_relation, current_entities, current_timestamps,
                query_relation, query_entity, query_timestamps, action_space):
        """
        Args:
            prev_relation: [batch_size]
            current_entities: [batch_size]
            current_timestamps: [batch_size]
            query_relation: embeddings of query relation,[batch_size, rel_dim]
            query_entity: embeddings of query entity, [batch_size, ent_dim]
            query_timestamps: [batch_size]
            action_space: [batch_size, max_actions_num, 3] (relations, entities, timestamps)
        Returns:
            loss: [batch_size] negative log-probability of the sampled action
            logits: [batch_size, action_number] log-softmax over candidate actions
            action_id: [batch_size, 1] index of the sampled action
        """
        # embeddings (entity embedding is time-shifted relative to the query)
        current_delta_time = query_timestamps - current_timestamps
        current_embds = self.ent_embs(current_entities, current_delta_time)  # [batch_size, ent_dim] #dynamic embedding
        prev_relation_embds = self.rel_embs(prev_relation)  # [batch_size, rel_dim]
        # Pad Mask: True where the action slot is padding (relation == rPAD).
        pad_mask = torch.ones_like(action_space[:, :, 0]) * self.rPAD  # [batch_size, action_number]
        pad_mask = torch.eq(action_space[:, :, 0], pad_mask)  # [batch_size, action_number]
        # History Encode: freeze the LSTM state where the previous action was NO_OP.
        NO_OP_mask = torch.eq(prev_relation, torch.ones_like(prev_relation) * self.NO_OP)  # [batch_size]
        NO_OP_mask = NO_OP_mask.repeat(self.config['state_dim'], 1).transpose(1, 0)  # [batch_size, state_dim]
        prev_action_embedding = torch.cat([prev_relation_embds, current_embds], dim=-1)  # [batch_size, rel_dim + ent_dim]
        lstm_output = self.policy_step(prev_action_embedding, NO_OP_mask)  # [batch_size, state_dim] (5) Path encoding
        # Neighbor/condidate_actions embeddings, time-shifted relative to the query.
        action_num = action_space.size(1)
        neighbors_delta_time = query_timestamps.unsqueeze(-1).repeat(1, action_num) - action_space[:, :, 2]
        neighbors_entities = self.ent_embs(action_space[:, :, 1], neighbors_delta_time)  # [batch_size, action_num, ent_dim]
        neighbors_relations = self.rel_embs(action_space[:, :, 0])  # [batch_size, action_num, rel_dim]
        # agent state representation
        agent_state = torch.cat([lstm_output, query_entity, query_relation], dim=-1)  # [batch_size, state_dim + ent_dim + rel_dim]
        output = self.policy_mlp(agent_state)  # [batch_size, 1, action_dim] action_dim == rel_dim + ent_dim
        # scoring: dot products of the policy output with each candidate's parts
        entitis_output = output[:, :, self.config['rel_dim']:]
        relation_ouput = output[:, :, :self.config['rel_dim']]
        relation_score = torch.sum(torch.mul(neighbors_relations, relation_ouput), dim=2)
        entities_score = torch.sum(torch.mul(neighbors_entities, entitis_output), dim=2)  # [batch_size, action_number]
        actions = torch.cat([neighbors_relations, neighbors_entities], dim=-1)  # [batch_size, action_number, action_dim]
        agent_state_repeats = agent_state.unsqueeze(1).repeat(1, actions.shape[1], 1)
        score_attention_input = torch.cat([actions, agent_state_repeats], dim=-1)
        a = self.score_weighted_fc(score_attention_input)  # (8)
        a = torch.sigmoid(a).squeeze()  # [batch_size, action_number] # (8)
        # Convex mix of the two score types; `a` acts as the beta gate in (6).
        scores = (1 - a) * relation_score + a * entities_score  # (6) a= beta
        # Padding mask: push padded slots to ~-inf so softmax ignores them.
        scores = scores.masked_fill(pad_mask, -1e10)  # [batch_size ,action_number]
        action_prob = torch.softmax(scores, dim=1)
        action_id = torch.multinomial(action_prob, 1)  # Randomly select an action. [batch_size, 1] # ACTION SELECTION
        logits = torch.nn.functional.log_softmax(scores, dim=1)  # [batch_size, action_number]
        # Loss is the negative log-probability of the sampled action.
        one_hot = torch.zeros_like(logits).scatter(1, action_id, 1)
        loss = - torch.sum(torch.mul(logits, one_hot), dim=1)
        return loss, logits, action_id

    def get_im_embedding(self, cooccurrence_entities):
        """Get the inductive mean representation of the co-occurrence relation.
        cooccurrence_entities: a list that contains the trained entities with the co-occurrence relation.
        return: torch.tensor, representation of the co-occurrence entities.
        """
        entities = self.ent_embs.ent_embs.weight.data[cooccurrence_entities]
        im = torch.mean(entities, dim=0)
        return im

    def update_entity_embedding(self, entity, ims, mu):
        """Update the entity representation with the co-occurrence relations in the last timestamp.
        entity: int, the entity that needs to be updated.
        ims: torch.tensor, [number of co-occurrence, -1], the IM representations of the co-occurrence relations
        mu: update ratio, the hyperparam.
        """
        # Remember the pre-update embedding so back_entities_embedding() can restore it.
        self.source_entity = self.ent_embs.ent_embs.weight.data[entity]
        self.ent_embs.ent_embs.weight.data[entity] = mu * self.source_entity + (1 - mu) * torch.mean(ims, dim=0)

    def entities_embedding_shift(self, entity, im, mu):
        """Prediction shift: temporarily blend `im` into the entity's embedding."""
        self.source_entity = self.ent_embs.ent_embs.weight.data[entity]
        self.ent_embs.ent_embs.weight.data[entity] = mu * self.source_entity + (1 - mu) * im

    def back_entities_embedding(self, entity):
        """Go back after shift ends: restore the saved pre-shift embedding."""
        self.ent_embs.ent_embs.weight.data[entity] = self.source_entity
================================================
FILE: modules/timetraveler_dirichlet.py
================================================
"""
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting
Reference:
- https://github.com/JHL-HUST/TITer/blob/master/model/dirichlet.py
Haohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021
"""
"""Dirichlet.py
Maximum likelihood estimation and likelihood ratio tests of Dirichlet
distribution models of data.
Most of this package is a port of Thomas P. Minka's wonderful Fastfit MATLAB
code. Much thanks to him for that and his clear paper "Estimating a Dirichlet
distribution". See the following URL for more information:
http://research.microsoft.com/en-us/um/people/minka/"""
import sys
import numpy as np
import scipy as sp
import scipy.stats as stats
from scipy.stats import dirichlet
from tqdm import tqdm
from numpy import (
arange,
array,
asanyarray,
asarray,
diag,
exp,
isscalar,
log,
ndarray,
ones,
vstack,
zeros,
)
from numpy.linalg import norm
from scipy.special import gammaln, polygamma, psi
# Default iteration cap for the successive-approximation fitting routines.
MAXINT = sys.maxsize

# Public API of this Dirichlet-MLE module.
__all__ = [
    "loglikelihood",
    "meanprecision",
    "mle",
    "pdf",
    "test",
]

euler = -1 * psi(1)  # Euler-Mascheroni constant
class NotConvergingError(Exception):
    """Raised when a successive-approximation method fails to converge."""
def test(D1, D2, method="meanprecision", maxiter=None):
    """Test for statistical difference between observed proportions.
    Parameters
    ----------
    D1 : (N1, K) shape array
    D2 : (N2, K) shape array
        Input observations. ``N1`` and ``N2`` are the number of observations,
        and ``K`` is the number of parameters for the Dirichlet distribution
        (i.e. the number of levels or categorical possibilities).
        Each cell is the proportion seen in that category for a particular
        observation. Rows of the matrices must add up to 1.
    method : string
        One of ``'fixedpoint'`` and ``'meanprecision'``, designates method by
        which to find MLE Dirichlet distribution. Default is
        ``'meanprecision'``, which is faster.
    maxiter : int
        Maximum number of iterations to take calculations. Default is
        ``sys.maxint``.
    Returns
    -------
    D : float
        Test statistic, which is ``-2 * log`` of likelihood ratios.
    p : float
        p-value of test.
    a0 : (K,) shape array
    a1 : (K,) shape array
    a2 : (K,) shape array
        MLE parameters for the Dirichlet distributions fit to
        ``D1`` and ``D2`` together, ``D1``, and ``D2``, respectively."""
    N1, K1 = D1.shape
    N2, K2 = D2.shape
    if K1 != K2:
        raise ValueError("D1 and D2 must have the same number of columns")
    # Fit the pooled data (null hypothesis) and each sample separately.
    D0 = vstack((D1, D2))
    a0 = mle(D0, method=method, maxiter=maxiter)
    a1 = mle(D1, method=method, maxiter=maxiter)
    a2 = mle(D2, method=method, maxiter=maxiter)
    # Likelihood-ratio statistic, compared against a chi-squared(K) tail.
    D = 2 * (loglikelihood(D1, a1) + loglikelihood(D2, a2) - loglikelihood(D0, a0))
    return (D, stats.chi2.sf(D, K1), a0, a1, a2)
def pdf(alphas):
    """Returns a Dirichlet PDF function
    Parameters
    ----------
    alphas : (K,) shape array
        The parameters for the distribution of shape ``(K,)``.
    Returns
    -------
    function
        The PDF function, takes an ``(N, K)`` shape input and gives an
        ``(N,)`` output.
    """
    exponents = alphas - 1
    # Normalization constant Gamma(sum a) / prod Gamma(a_k), via log-gammas.
    norm_const = np.exp(gammaln(alphas.sum()) - gammaln(alphas).sum())

    def dirichlet(xs):
        """Dirichlet PDF
        Parameters
        ----------
        xs : (N, K) shape array
            The ``(N, K)`` shape input matrix
        Returns
        -------
        (N,) shape array
            Point value for PDF
        """
        return norm_const * (xs ** exponents).prod(axis=1)

    return dirichlet
def meanprecision(a):
    """Mean and precision of a Dirichlet distribution.
    Parameters
    ----------
    a : (K,) shape array
        Parameters of a Dirichlet distribution.
    Returns
    -------
    mean : (K,) shape array
        Means of the Dirichlet distribution. Values are in [0,1].
    precision : float
        Precision or concentration parameter of the Dirichlet distribution."""
    # Precision is the parameter total; the mean is the normalized parameters.
    precision = a.sum()
    return (a / precision, precision)
def loglikelihood(D, a):
    """Compute log likelihood of Dirichlet distribution, i.e. log p(D|a).
    Parameters
    ----------
    D : (N, K) shape array
        ``N`` is the number of observations, ``K`` is the number of
        parameters for the Dirichlet distribution.
    a : (K,) shape array
        Parameters for the Dirichlet distribution.
    Returns
    -------
    logl : float
        The log likelihood of the Dirichlet distribution"""
    N, K = D.shape
    # Per-category mean of log-proportions is a sufficient statistic.
    mean_logp = log(D).mean(axis=0)
    return N * (gammaln(a.sum()) - gammaln(a).sum() + ((a - 1) * mean_logp).sum())
def mle(D, tol=1e-7, method="meanprecision", maxiter=None):
    """Iteratively computes maximum likelihood Dirichlet distribution
    for an observed data set, i.e. a for which log p(D|a) is maximum.
    Parameters
    ----------
    D : (N, K) shape array
        ``N`` is the number of observations, ``K`` is the number of
        parameters for the Dirichlet distribution.
    tol : float
        If Euclidean distance between successive parameter arrays is less than
        ``tol``, calculation is taken to have converged.
    method : string
        One of ``'fixedpoint'`` and ``'meanprecision'``, designates method by
        which to find MLE Dirichlet distribution. Default is
        ``'meanprecision'``, which is faster.
    maxiter : int
        Maximum number of iterations to take calculations. Default is
        ``sys.maxint``.
    Returns
    -------
    a : (K,) shape array
        Maximum likelihood parameters for Dirichlet distribution."""
    # Dispatch to the selected optimization strategy.
    if method == "meanprecision":
        return _meanprecision(D, tol=tol, maxiter=maxiter)
    return _fixedpoint(D, tol=tol, maxiter=maxiter)
def _fixedpoint(D, tol=1e-7, maxiter=None):
    """Simple fixed-point iteration for the MLE of a Dirichlet distribution.

    Parameters
    ----------
    D : (N, K) shape array
        ``N`` is the number of observations, ``K`` is the number of
        parameters for the Dirichlet distribution.
    tol : float
        Convergence threshold on the change in log-likelihood between
        successive iterations.
    maxiter : int
        Maximum number of iterations. Default is ``MAXINT``.

    Returns
    -------
    a : (K,) shape array
        Fixed-point estimated parameters for Dirichlet distribution.

    Raises
    ------
    NotConvergingError
        If the iteration does not converge within ``maxiter`` steps.
    """
    mean_logp = log(D).mean(axis=0)
    a_prev = _init_a(D)
    if maxiter is None:
        maxiter = MAXINT
    for _ in range(maxiter):
        a_next = _ipsi(psi(a_prev.sum()) + mean_logp)
        # Testing the log-likelihood gives much faster convergence than the
        # more obvious parameter-distance condition `norm(a_next - a_prev) < tol`.
        if abs(loglikelihood(D, a_next) - loglikelihood(D, a_prev)) < tol:
            return a_next
        a_prev = a_next
    raise NotConvergingError(
        "Failed to converge after {} iterations, values are {}.".format(maxiter, a_next)
    )
def _meanprecision(D, tol=1e-7, maxiter=None):
    """Mean/precision method for MLE of a Dirichlet distribution.

    Alternates between updating the precision (parameter sum) with the mean
    held fixed (``_fit_s``) and updating the mean with the precision held
    fixed (``_fit_m``).

    Parameters
    ----------
    D : (N, K) shape array
        ``N`` is the number of observations, ``K`` is the number of
        parameters for the Dirichlet distribution.
    tol : float
        Convergence threshold on the change in log-likelihood between
        successive iterations.
    maxiter : int
        Maximum number of iterations. Default is ``MAXINT``.

    Returns
    -------
    a : (K,) shape array
        Estimated parameters for Dirichlet distribution. If ``maxiter`` is
        exhausted the last iterate is returned (best-effort; the original
        NotConvergingError was deliberately disabled here).
    """
    # Small offset keeps log(D) finite when D contains exact zeros.
    D = D + 1e-9
    logp = log(D).mean(axis=0)
    a0 = _init_a(D)
    s0 = a0.sum()
    # Guard against degenerate initial guesses with non-positive precision.
    if s0 < 0:
        a0 = a0 / s0
        s0 = 1
    elif s0 == 0:
        a0 = ones(a0.shape) / len(a0)
        s0 = 1
    if maxiter is None:
        maxiter = MAXINT
    for _ in range(maxiter):
        a1 = _fit_s(D, a0, logp, tol=tol)
        a1 = _fit_m(D, a1, logp, tol=tol)
        # Testing the log-likelihood gives much faster convergence than the
        # more obvious condition `norm(a1 - a0) < tol`.
        if abs(loglikelihood(D, a1) - loglikelihood(D, a0)) < tol:
            return a1
        a0 = a1
    # Deliberately return the last iterate instead of raising
    # NotConvergingError when maxiter is exhausted.
    return a1
def _fit_s(D, a0, logp, tol=1e-7, maxiter=1000):
    """Update parameters via MLE of precision with fixed mean
    Parameters
    ----------
    D : (N, K) shape array
        ``N`` is the number of observations, ``K`` is the number of
        parameters for the Dirichlet distribution.
    a0 : (K,) shape array
        Current parameters for Dirichlet distribution
    logp : (K,) shape array
        Mean of log-transformed D across N observations
    tol : float
        If Euclidean distance between successive parameter arrays is less than
        ``tol``, calculation is taken to have converged.
    maxiter : int
        Maximum number of iterations to take calculations. Default is 1000.
    Returns
    -------
    (K,) shape array
        Updated parameters for Dirichlet distribution.

    Notes
    -----
    If ``maxiter`` is exhausted, the last iterate is returned instead of
    raising (the NotConvergingError below was deliberately commented out).
    """
    s1 = a0.sum()
    m = a0 / s1  # fixed mean; only the precision s is updated
    mlogp = (m * logp).sum()
    for i in range(maxiter):
        s0 = s1
        # g and h are the first and second derivatives of the log-likelihood
        # with respect to the precision s at the current iterate.
        g = psi(s1) - (m * psi(s1 * m)).sum() + mlogp
        h = _trigamma(s1) - ((m ** 2) * _trigamma(s1 * m)).sum()
        # Cascade of Newton-style updates in different parameterizations;
        # each fallback is tried only when the previous one left s1 <= 0
        # (precision must stay strictly positive).
        if g + s1 * h < 0:
            s1 = 1 / (1 / s0 + g / h / (s0 ** 2))
        if s1 <= 0:
            s1 = s0 * exp(-g / (s0 * h + g))  # Newton on log s
        if s1 <= 0:
            s1 = 1 / (1 / s0 + g / ((s0 ** 2) * h + 2 * s0 * g))  # Newton on 1/s
        if s1 <= 0:
            s1 = s0 - g / h  # Newton
        if s1 <= 0:
            raise NotConvergingError(f"Unable to update s from {s0}")
        a = s1 * m
        if abs(s1 - s0) < tol:
            return a
    # Best-effort: return the last iterate even without convergence.
    return a
    # raise NotConvergingError(f"Failed to converge after {maxiter} iterations, " f"s is {s1}")
def _fit_m(D, a0, logp, tol=1e-7, maxiter=1000):
    """Update parameters via MLE of the mean with the precision s fixed.

    Parameters
    ----------
    D : (N, K) shape array
        Observations (unused here; kept for a signature parallel to
        ``_fit_s``).
    a0 : (K,) shape array
        Current parameters for Dirichlet distribution
    logp : (K,) shape array
        Mean of log-transformed D across N observations
    tol : float
        Convergence threshold on the Euclidean distance between successive
        parameter arrays.
    maxiter : int
        Maximum number of iterations to take calculations. Default is 1000.

    Returns
    -------
    (K,) shape array
        Updated parameters for Dirichlet distribution. If ``maxiter`` is
        exhausted the last iterate is returned (best-effort).
    """
    s = a0.sum()
    for _ in range(maxiter):
        mean = a0 / s
        a1 = _ipsi(logp + (mean * (psi(a0) - logp)).sum())
        # Rescale so the precision (parameter sum) stays fixed at s.
        a1 = a1 / a1.sum() * s
        if norm(a1 - a0) < tol:
            return a1
        a0 = a1
    return a1
    # raise NotConvergingError(f"Failed to converge after {maxiter} iterations, " f"s is {s}")
def _init_a(D):
"""Initial guess for Dirichlet alpha parameters given data D
Parameters
----------
D : (N, K) shape array
``N`` is the number of observations, ``K`` is the number of
parameters for the Dirichlet distribution.
Returns
-------
(K,) shape array
Crude guess for parameters of Dirichlet distribution."""
E = D.mean(axis=0)
E2 = (D ** 2).mean(axis=0)
return ((E[0] - E2[0]) / ((E2[0] - E[0] ** 2) + 1e-9 ) * E)
def _ipsi(y, tol=1.48e-9, maxiter=10):
    """Inverse of psi (digamma) via Newton's method.

    For the purposes of Dirichlet MLE the parameters a[i] must satisfy
    a > 0, so ipsi is treated as a map R -> (0, inf).

    Parameters
    ----------
    y : (K,) shape array
        y-values of psi(x)
    tol : float
        Convergence threshold on the Euclidean distance between successive
        iterates.
    maxiter : int
        Maximum number of Newton steps. Default is 10.

    Returns
    -------
    (K,) shape array
        Approximate x with psi(x) ~= y. If ``maxiter`` is exhausted the
        last iterate is returned (best-effort).
    """
    y = asanyarray(y, dtype="float")
    # Standard two-regime initial guess for the inverse digamma:
    # exp(y) + 1/2 for y >= -2.22, and -1/(y + euler) below that.
    x0 = np.piecewise(
        y,
        [y >= -2.22, y < -2.22],
        [(lambda v: exp(v) + 0.5), (lambda v: -1 / (v + euler))],
    )
    for _ in range(maxiter):
        x1 = x0 - (psi(x0) - y) / _trigamma(x0)
        if norm(x1 - x0) < tol:
            return x1
        x0 = x1
    return x1
    # raise NotConvergingError(f"Failed to converge after {maxiter} iterations, " f"value is {x1}")
def _trigamma(x):
    """Trigamma function ``psi'(x)`` (first derivative of the digamma)."""
    return polygamma(1, x)
class MLE_Dirchlet(object):
    """Fit one Dirichlet distribution per relation over recency histograms
    of object-entity occurrences in the training quadruples."""

    def __init__(self, trainQuads, num_r, k, timespan,
                 tol=1e-7, method="meanprecision", maxiter=10000):
        """
        num_r:int, number of relations.
        k:int, statistics recent K historical snapshots.
        timespan:int, 24 for ICEWS, 1 for WIKI and YAGO
        tol : float, If Euclidean distance between successive parameter arrays is less than
        ``tol``, calculation is taken to have converged.
        method : string, One of ``'fixedpoint'`` and ``'meanprecision'``, designates method by
        which to find MLE Dirichlet distribution. Default is ``'meanprecision'``, which is faster.
        maxiter : int, Maximum number of iterations to take calculations. Default is ``sys.maxint``.
        """
        self.num_r = num_r
        self.k = k
        self.timespan = timespan
        self.tol = tol
        self.method = method
        self.maxiter = maxiter
        # Number of occurrences of each entity at each timestamp in the training set.
        self.entity_occ_times = self.get_entity_occ_times(trainQuads)
        self.relations_observed_data = self.get_relations_observed_data(trainQuads)
        self.alphas = self.mle_dirchlet()

    def get_entity_occ_times(self, trainQuads):
        """Count occurrences per (entity, timestamp) over subjects and objects."""
        occ = {}  # entity -> {timestamp -> count}
        for quad in trainQuads:
            ts = quad[3]
            for entity in (quad[0], quad[2]):
                per_time = occ.setdefault(entity, {})
                per_time[ts] = per_time.get(ts, 0) + 1
        return occ

    def get_relations_observed_data(self, trainQuads):
        """Build, per relation, a list of (k+1)-bucket recency histograms of
        the object entity's past occurrence counts."""
        data = {}  # relation -> list of observed histograms
        for quad in trainQuads:
            rel, obj, ts = quad[1], quad[2], quad[3]
            histogram = np.zeros([self.k + 1])
            # Only strictly earlier occurrences count.
            # NOTE(review): assumes (ts - time) // timespan <= k for every
            # past occurrence; otherwise this indexing would raise — confirm
            # upstream time bucketing.
            for time, count in self.entity_occ_times[obj].items():
                if time < ts:
                    histogram[(ts - time) // self.timespan] = count
            data.setdefault(rel, []).append(histogram)
        return data

    def mle_dirchlet(self):
        """Run the Dirichlet MLE solver once per relation."""
        alphas = {}  # relation -> fitted alpha array
        with tqdm(total=len(self.relations_observed_data)) as progress:
            for rel, observed in self.relations_observed_data.items():
                alphas[rel] = mle(np.array(observed), tol=self.tol,
                                  method=self.method, maxiter=self.maxiter)
                progress.update(1)
        return alphas
class Dirichlet(object):
    """Holds one fitted Dirichlet distribution per relation and yields a
    sampled probability for a given time gap."""

    def __init__(self, alphas, k):
        """alphas: Get from MLE_Dirchlet
        k: int, statistics recent K historical snapshots.
        """
        self.k = k
        self.distributions = {rel: dirichlet(alpha) for rel, alpha in alphas.items()}

    def __call__(self, rel, dt):
        # Time gaps outside the tracked window contribute nothing.
        if dt >= self.k:
            return 0.0
        return self.distributions[rel].rvs(1)[0][dt]
================================================
FILE: modules/timetraveler_environment.py
================================================
"""
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting
Reference:
- https://github.com/JHL-HUST/TITer/blob/master/model/environment.py
Haohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021
"""
import networkx as nx
from collections import defaultdict
import numpy as np
import torch
class Env(object):
    def __init__(self, examples, config, state_action_space=None):
        """Temporal Knowledge Graph Environment.
        examples: quadruples (subject, relation, object, timestamps);
        config: config dict (must provide 'num_rel', 'num_ent', 'cuda');
        state_action_space: Pre-processed action space;
        """
        self.config = config
        self.num_rel = config['num_rel']
        self.graph, self.label2nodes = self.build_graph(examples)
        # [0, num_rel) -> normal relations; num_rel -> stay in place,(num_rel, num_rel * 2] reversed relations.
        self.NO_OP = self.num_rel  # Stay in place; No Operation
        self.ePAD = config['num_ent']  # Padding entity
        self.rPAD = config['num_rel'] * 2  # + 1 # Padding relation.
        self.tPAD = 0  # Padding time
        self.state_action_space = state_action_space  # Pre-processed action space
        if state_action_space:
            self.state_action_space_key = self.state_action_space.keys()

    def build_graph(self, examples):
        """The graph node is represented as (entity, time), and the edges are directed and labeled relation.
        return:
            graph: nx.MultiDiGraph;
            label2nodes: a dict [keys -> entities, value-> nodes in the graph (entity, time)]
        """
        graph = nx.MultiDiGraph()
        label2nodes = defaultdict(set)
        examples.sort(key=lambda x: x[3], reverse=True)  # Reverse chronological order
        for example in examples:
            src = example[0]
            rel = example[1]
            dst = example[2]
            time = example[3]
            # Add the nodes and edges of the current quadruple
            src_node = (src, time)
            dst_node = (dst, time)
            if src_node not in label2nodes[src]:
                graph.add_node(src_node, label=src)
            if dst_node not in label2nodes[dst]:
                graph.add_node(dst_node, label=dst)
            graph.add_edge(src_node, dst_node, relation=rel)
            # Reversed edges (rel + num_rel + 1) intentionally not added here.
            label2nodes[src].add(src_node)
            label2nodes[dst].add(dst_node)
        return graph, label2nodes

    def get_state_actions_space_complete(self, entity, time, current_=False, max_action_num=None):
        """Get the action space of the current state.
        Args:
            entity: The entity of the current state;
            time: Maximum timestamp for candidate actions;
            current_: Can the current time of the event be used;
            max_action_num: Maximum number of events stored;
        Return:
            numpy array,shape: [number of events,3], (relation, dst, time)
        """
        # Fast path: use the pre-processed action space when available.
        if self.state_action_space:
            if (entity, time, current_) in self.state_action_space_key:
                return self.state_action_space[(entity, time, current_)]
        nodes = self.label2nodes[entity].copy()
        if current_:
            # Delete future events, you can see current events, before query time
            nodes = list(filter((lambda x: x[1] <= time), nodes))
        else:
            # No future events, no current events
            nodes = list(filter((lambda x: x[1] < time), nodes))
        # Most recent events first, so truncation keeps the latest ones.
        nodes.sort(key=lambda x: x[1], reverse=True)
        actions_space = []
        i = 0
        for node in nodes:
            for src, dst, rel in self.graph.out_edges(node, data=True):
                actions_space.append((rel['relation'], dst[0], dst[1]))
                i += 1
                if max_action_num and i >= max_action_num:
                    break
            if max_action_num and i >= max_action_num:
                break
        return np.array(list(actions_space), dtype=np.dtype('int32'))

    def next_actions(self, entites, times, query_times, max_action_num=200, first_step=False):
        """Get the current action space. There must be an action that stays at the current position in the action space.
        Args:
            entites: torch.tensor, shape: [batch_size], the entity where the agent is currently located;
            times: torch.tensor, shape: [batch_size], the timestamp of the current entity;
            query_times: torch.tensor, shape: [batch_size], the timestamp of query;
            max_action_num: The size of the action space;
            first_step: Is it the first step for the agent.
        Return: torch.tensor, shape: [batch_size, max_action_num, 3], (relation, entity, time)
        """
        if self.config['cuda']:
            entites = entites.cpu()
            times = times.cpu()
            # BUG FIX: was `query_times = times.cpu()`, which silently
            # replaced the query timestamps with the current timestamps.
            query_times = query_times.cpu()
        entites = entites.numpy()
        times = times.numpy()
        query_times = query_times.numpy()
        actions = self.get_padd_actions(entites, times, query_times, max_action_num, first_step)
        if self.config['cuda']:
            actions = torch.tensor(actions, dtype=torch.long, device='cuda')
        else:
            actions = torch.tensor(actions, dtype=torch.long)
        return actions

    def get_padd_actions(self, entites, times, query_times, max_action_num=200, first_step=False):
        """Construct the model input array.
        If the optional actions are greater than the maximum number of actions, then sample,
        otherwise all are selected, and the insufficient part is pad.
        """
        # Start from an all-padding array, then fill per batch element.
        actions = np.ones((entites.shape[0], max_action_num, 3), dtype=np.dtype('int32'))
        actions[:, :, 0] *= self.rPAD
        actions[:, :, 1] *= self.ePAD
        actions[:, :, 2] *= self.tPAD
        for i in range(entites.shape[0]):
            # NO OPERATION: slot 0 lets the agent stay at the current node.
            actions[i, 0, 0] = self.NO_OP
            actions[i, 0, 1] = entites[i]
            actions[i, 0, 2] = times[i]
            # At the query time itself, same-time events must not be visible;
            # strictly in the past they may be.
            if times[i] == query_times[i]:
                action_array = self.get_state_actions_space_complete(entites[i], times[i], False)
            else:
                action_array = self.get_state_actions_space_complete(entites[i], times[i], True)
            if action_array.shape[0] == 0:
                continue
            # Whether to keep the action NO_OPERATION
            start_idx = 1
            if first_step:
                # The first step cannot stay in place
                start_idx = 0
            if action_array.shape[0] > (max_action_num - start_idx):
                # Sample. Take the latest events.
                actions[i, start_idx:, ] = action_array[:max_action_num-start_idx]
            else:
                actions[i, start_idx:action_array.shape[0]+start_idx, ] = action_array
        return actions
================================================
FILE: modules/timetraveler_episode.py
================================================
"""
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting
Reference:
- https://github.com/JHL-HUST/TITer/blob/master/model/episode.py
Haohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021
"""
import torch
import torch.nn as nn
class Episode(nn.Module):
    """Rollout module that wires the environment and the RL agent together:
    ``forward`` runs a training rollout, ``beam_search`` decodes at test time."""

    def __init__(self, env, agent, config):
        super(Episode, self).__init__()
        self.config = config
        self.env = env
        self.agent = agent
        # Number of reasoning steps per rollout.
        self.path_length = config['path_length']
        self.num_rel = config['num_rel']
        # Upper bound on candidate actions per state.
        self.max_action_num = config['max_action_num']

    def forward(self, query_entities, query_timestamps, query_relations):
        """
        Args:
            query_entities: [batch_size]
            query_timestamps: [batch_size]
            query_relations: [batch_size]
        Return:
            all_loss: list
            all_logits: list
            all_actions_idx: list
            current_entities: torch.tensor, [batch_size]
            current_timestamps: torch.tensor, [batch_size]
        """
        # Entity embeddings taken at time 0 for the query representation.
        query_entities_embeds = self.agent.ent_embs(query_entities, torch.zeros_like(query_timestamps))
        query_relations_embeds = self.agent.rel_embs(query_relations)
        current_entites = query_entities
        current_timestamps = query_timestamps
        # The rollout starts with the NO_OP relation (index num_rel).
        prev_relations = torch.ones_like(query_relations) * self.num_rel  # NO_OP
        all_loss = []
        all_logits = []
        all_actions_idx = []
        self.agent.policy_step.set_hiddenx(query_relations.shape[0])
        for t in range(self.path_length):
            # On the first step the agent is not allowed to stay in place.
            if t == 0:
                first_step = True
            else:
                first_step = False
            action_space = self.env.next_actions(
                current_entites,
                current_timestamps,
                query_timestamps,
                self.max_action_num,
                first_step
            )
            loss, logits, action_id = self.agent(
                prev_relations,
                current_entites,
                current_timestamps,
                query_relations_embeds,
                query_entities_embeds,
                query_timestamps,
                action_space,
            )
            # Look up the (relation, entity, time) triple the agent selected.
            chosen_relation = torch.gather(action_space[:, :, 0], dim=1, index=action_id).reshape(action_space.shape[0])
            chosen_entity = torch.gather(action_space[:, :, 1], dim=1, index=action_id).reshape(action_space.shape[0])
            chosen_entity_timestamps = torch.gather(action_space[:, :, 2], dim=1, index=action_id).reshape(action_space.shape[0])
            all_loss.append(loss)
            all_logits.append(logits)
            all_actions_idx.append(action_id)
            # Advance the agent to the chosen node before the next step.
            current_entites = chosen_entity
            current_timestamps = chosen_entity_timestamps
            prev_relations = chosen_relation
        return all_loss, all_logits, all_actions_idx, current_entites, current_timestamps

    def beam_search(self, query_entities, query_timestamps, query_relations):
        """
        Args:
            query_entities: [batch_size]
            query_timestamps: [batch_size]
            query_relations: [batch_size]
        Return:
            current_entites: [batch_size, test_rollouts_num]
            beam_prob: [batch_size, test_rollouts_num]
        """
        batch_size = query_entities.shape[0]
        query_entities_embeds = self.agent.ent_embs(query_entities, torch.zeros_like(query_timestamps))
        query_relations_embeds = self.agent.rel_embs(query_relations)
        self.agent.policy_step.set_hiddenx(batch_size)
        # In the first step, if rollouts_num is greater than the maximum number of actions, select all actions
        current_entites = query_entities
        current_timestamps = query_timestamps
        prev_relations = torch.ones_like(query_relations) * self.num_rel  # NO_OP
        action_space = self.env.next_actions(current_entites, current_timestamps,
                                             query_timestamps, self.max_action_num, True)
        loss, logits, action_id = self.agent(
            prev_relations,
            current_entites,
            current_timestamps,
            query_relations_embeds,
            query_entities_embeds,
            query_timestamps,
            action_space
        )  # logits.shape: [batch_size, max_action_num]
        action_space_size = action_space.shape[1]
        # Clamp the beam width to the number of available actions.
        if self.config['beam_size'] > action_space_size:
            beam_size = action_space_size
        else:
            beam_size = self.config['beam_size']
        beam_log_prob, top_k_action_id = torch.topk(logits, beam_size, dim=1)  # beam_log_prob.shape [batch_size, beam_size]
        beam_log_prob = beam_log_prob.reshape(-1)  # [batch_size * beam_size]
        # Expand each batch element into beam_size hypotheses.
        current_entites = torch.gather(action_space[:, :, 1], dim=1, index=top_k_action_id).reshape(-1)  # [batch_size * beam_size]
        current_timestamps = torch.gather(action_space[:, :, 2], dim=1, index=top_k_action_id).reshape(-1)  # [batch_size * beam_size]
        prev_relations = torch.gather(action_space[:, :, 0], dim=1, index=top_k_action_id).reshape(-1)  # [batch_size * beam_size]
        # Replicate the policy LSTM state for each beam hypothesis.
        self.agent.policy_step.hx = self.agent.policy_step.hx.repeat(1, 1, beam_size).reshape([batch_size * beam_size, -1])  # [batch_size * beam_size, state_dim]
        self.agent.policy_step.cx = self.agent.policy_step.cx.repeat(1, 1, beam_size).reshape([batch_size * beam_size, -1])  # [batch_size * beam_size, state_dim]
        beam_tmp = beam_log_prob.repeat([action_space_size, 1]).transpose(1, 0)  # [batch_size * beam_size, max_action_num]
        for t in range(1, self.path_length):
            # Roll the per-query tensors out to one row per beam hypothesis.
            query_timestamps_roll = query_timestamps.repeat(beam_size, 1).permute(1, 0).reshape(-1)
            query_entities_embeds_roll = query_entities_embeds.repeat(1, 1, beam_size)
            query_entities_embeds_roll = query_entities_embeds_roll.reshape([batch_size * beam_size, -1])  # [batch_size * beam_size, ent_dim]
            query_relations_embeds_roll = query_relations_embeds.repeat(1, 1, beam_size)
            query_relations_embeds_roll = query_relations_embeds_roll.reshape([batch_size * beam_size, -1])  # [batch_size * beam_size, rel_dim]
            action_space = self.env.next_actions(current_entites, current_timestamps,
                                                 query_timestamps_roll, self.max_action_num)
            loss, logits, action_id = self.agent(
                prev_relations,
                current_entites,
                current_timestamps,
                query_relations_embeds_roll,
                query_entities_embeds_roll,
                query_timestamps_roll,
                action_space
            )  # logits.shape [bs * rollouts_num, max_action_num]
            hx_tmp = self.agent.policy_step.hx.reshape(batch_size, beam_size, -1)
            cx_tmp = self.agent.policy_step.cx.reshape(batch_size, beam_size, -1)
            # Accumulate path log-probabilities: previous beam score + step logits.
            beam_tmp = beam_log_prob.repeat([action_space_size, 1]).transpose(1, 0)  # [batch_size * beam_size, max_action_num]
            beam_tmp += logits
            beam_tmp = beam_tmp.reshape(batch_size, -1)  # [batch_size, beam_size * max_actions_num]
            # The beam may grow until it reaches the configured beam_size.
            if action_space_size * beam_size >= self.config['beam_size']:
                beam_size = self.config['beam_size']
            else:
                beam_size = action_space_size * beam_size
            top_k_log_prob, top_k_action_id = torch.topk(beam_tmp, beam_size, dim=1)  # [batch_size, beam_size]
            # offset = which parent hypothesis each surviving candidate came from.
            offset = top_k_action_id // action_space_size  # [batch_size, beam_size]
            offset = offset.unsqueeze(-1).repeat(1, 1, self.config['state_dim'])  # [batch_size, beam_size]
            # Re-gather the LSTM state of the surviving parent hypotheses.
            self.agent.policy_step.hx = torch.gather(hx_tmp, dim=1, index=offset)
            self.agent.policy_step.hx = self.agent.policy_step.hx.reshape([batch_size * beam_size, -1])
            self.agent.policy_step.cx = torch.gather(cx_tmp, dim=1, index=offset)
            self.agent.policy_step.cx = self.agent.policy_step.cx.reshape([batch_size * beam_size, -1])
            current_entites = torch.gather(action_space[:, :, 1].reshape(batch_size, -1), dim=1, index=top_k_action_id).reshape(-1)
            current_timestamps = torch.gather(action_space[:, :, 2].reshape(batch_size, -1), dim=1, index=top_k_action_id).reshape(-1)
            prev_relations = torch.gather(action_space[:, :, 0].reshape(batch_size, -1), dim=1, index=top_k_action_id).reshape(-1)
            beam_log_prob = top_k_log_prob.reshape(-1)  # [batch_size * beam_size]
        # Return all candidate entities of the final step with their scores.
        return action_space[:, :, 1].reshape(batch_size, -1), beam_tmp
================================================
FILE: modules/timetraveler_policygradient.py
================================================
"""
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting
Reference:
- https://github.com/JHL-HUST/TITer/blob/master/model/policyGradient.py and
https://github.com/JHL-HUST/TITer/blob/master/model/baseline.py
Haohai Sun, Jialun Zhong, Yunpu Ma, Zhen Han, Kun He.
TimeTraveler: Reinforcement Learning for Temporal Knowledge Graph Forecasting EMNLP 2021
"""
import torch
import numpy as np
import math
class ReactiveBaseline(object):
    """Exponential-moving-average baseline used for REINFORCE variance
    reduction; tracks the running reward with rate ``update_rate``."""

    def __init__(self, config, update_rate):
        # update_rate controls how quickly the baseline follows new targets.
        self.update_rate = update_rate
        self.value = torch.zeros(1)
        if config['cuda']:
            self.value = self.value.cuda()

    def get_baseline_value(self):
        """Return the current baseline estimate (1-element tensor)."""
        return self.value

    def update(self, target):
        """Blend the stored value towards ``target`` by ``update_rate``."""
        carried = (1 - self.update_rate) * self.value
        self.value = torch.add(carried, self.update_rate * target)
class PG(object):
    """Policy-gradient (REINFORCE) utilities: binary terminal reward,
    discounted reward-to-go, and a baseline-normalized loss with an
    entropy regularizer that decays with ``now_epoch``."""

    def __init__(self, config):
        self.config = config
        self.positive_reward = 1.0
        self.negative_reward = 0.0
        self.baseline = ReactiveBaseline(config, config['lambda'])
        self.now_epoch = 0

    def get_reward(self, current_entites, answers):
        """Return 1.0 where the agent stopped on the answer entity, else 0.0."""
        positive = torch.ones_like(current_entites, dtype=torch.float32) * self.positive_reward
        negative = torch.ones_like(current_entites, dtype=torch.float32) * self.negative_reward
        return torch.where(current_entites == answers, positive, negative)

    def calc_cum_discounted_reward(self, rewards):
        """Discounted reward-to-go per step; the terminal reward sits at the
        last step and is propagated backwards with factor gamma."""
        path_length = self.config['path_length']
        running_add = torch.zeros([rewards.shape[0]])
        cum_disc_reward = torch.zeros([rewards.shape[0], path_length])
        if self.config['cuda']:
            running_add = running_add.cuda()
            cum_disc_reward = cum_disc_reward.cuda()
        cum_disc_reward[:, path_length - 1] = rewards
        for t in reversed(range(path_length)):
            running_add = self.config['gamma'] * running_add + cum_disc_reward[:, t]
            cum_disc_reward[:, t] = running_add
        return cum_disc_reward

    def entropy_reg_loss(self, all_logits):
        """Negative policy entropy, averaged over the batch."""
        stacked = torch.stack(all_logits, dim=2)
        return -torch.mean(torch.sum(torch.mul(torch.exp(stacked), stacked), dim=1))

    def calc_reinforce_loss(self, all_loss, all_logits, cum_discounted_reward):
        """REINFORCE loss with baseline subtraction, advantage normalization,
        and an epoch-decayed entropy bonus."""
        loss = torch.stack(all_loss, dim=1)
        # Advantage: subtract the baseline, then normalize for stability.
        final_reward = cum_discounted_reward - self.baseline.get_baseline_value()
        reward_mean = torch.mean(final_reward)
        reward_std = torch.std(final_reward) + 1e-6
        final_reward = torch.div(final_reward - reward_mean, reward_std)
        loss = torch.mul(loss, final_reward)
        decay = math.pow(self.config['zita'], self.now_epoch)
        entropy_loss = self.config['ita'] * decay * self.entropy_reg_loss(all_logits)
        return torch.mean(loss) - entropy_loss
================================================
FILE: modules/timetraveler_trainertester.py
================================================
import torch
import json
import os
import tqdm
import numpy as np
class Trainer(object):
    """Training driver for the TimeTraveler agent: runs policy-gradient
    epochs and checkpoints the model and optimizer."""

    def __init__(self, model, pg, optimizer, args, distribution=None):
        self.model = model
        self.pg = pg
        self.optimizer = optimizer
        self.args = args
        # Dirichlet time-gap distribution used only for reward shaping.
        self.distribution = distribution

    def train_epoch(self, dataloader, ntriple):
        """Run one training epoch and return (mean loss, mean reward)."""
        self.model.train()
        total_loss = 0.0
        total_reward = 0.0
        counter = 0
        with tqdm.tqdm(total=ntriple, unit='ex') as bar:
            bar.set_description('Train')
            for src_batch, rel_batch, dst_batch, time_batch, time_orig_batch in dataloader:
                if self.args.cuda:
                    src_batch = src_batch.cuda()
                    rel_batch = rel_batch.cuda()
                    dst_batch = dst_batch.cuda()
                    time_batch = time_batch.cuda()
                # Roll out the agent; reward is 1 if it ends on the answer.
                all_loss, all_logits, _, current_entities, current_time = self.model(src_batch, time_batch, rel_batch)
                reward = self.pg.get_reward(current_entities, dst_batch)
                if self.args.reward_shaping:
                    # reward shaping
                    # Scale the binary reward by the fitted Dirichlet
                    # probability of the time gap between query and stop.
                    delta_time = time_batch - current_time
                    p_dt = []
                    for i in range(rel_batch.shape[0]):
                        rel = rel_batch[i].item()
                        dt = delta_time[i].item()
                        p_dt.append(self.distribution(rel, dt // self.args.time_span))
                    p_dt = torch.tensor(p_dt)
                    if self.args.cuda:
                        p_dt = p_dt.cuda()
                    shaped_reward = (1 + p_dt) * reward
                    cum_discounted_reward = self.pg.calc_cum_discounted_reward(shaped_reward)
                else:
                    cum_discounted_reward = self.pg.calc_cum_discounted_reward(reward)
                reinfore_loss = self.pg.calc_reinforce_loss(all_loss, all_logits, cum_discounted_reward)
                self.pg.baseline.update(torch.mean(cum_discounted_reward))
                # NOTE(review): now_epoch is incremented once per *batch*,
                # not per epoch, so the entropy decay zita**now_epoch shrinks
                # per batch — confirm this matches the intended schedule.
                self.pg.now_epoch += 1
                self.optimizer.zero_grad()
                reinfore_loss.backward()
                if self.args.clip_gradient:
                    total_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_gradient)
                self.optimizer.step()
                total_loss += reinfore_loss
                total_reward += torch.mean(reward)
                counter += 1
                bar.update(self.args.batch_size)
                bar.set_postfix(loss='%.4f' % reinfore_loss, reward='%.4f' % torch.mean(reward).item())
        return total_loss / counter, total_reward / counter

    def save_model(self, save_path, checkpoint_path='checkpoint.pth'):
        """Save the parameters of the model and the optimizer,"""
        # Persist the run configuration next to the checkpoint for later
        # reproducibility.
        argparse_dict = vars(self.args)
        with open(os.path.join(save_path, 'config.json'), 'w') as fjson:
            json.dump(argparse_dict, fjson)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()},
            os.path.join(save_path, checkpoint_path)
        )
class Tester(object):
    """Evaluation driver: beam-searches the agent on each test example and
    scores the results with the TGB negative-sampling evaluator."""

    def __init__(self, model, args, train_entities, RelEntCooccurrence, metric='mrr'):
        self.model = model
        self.args = args
        # Entities seen during training; used for inductive-mean (IM) updates.
        self.train_entities = train_entities
        self.RelEntCooccurrence = RelEntCooccurrence
        self.metric = metric

    def get_rank(self, score, answer, entities_space, num_ent):
        """Get the location of the answer, if the answer is not in the array,
        the ranking will be the total number of entities.
        Args:
            score: list, entity score
            answer: int, the ground truth entity
            entities_space: corresponding entity with the score
            num_ent: the total number of entities
        Return: the rank of the ground truth.
        """
        if answer not in entities_space:
            rank = num_ent
        else:
            answer_prob = score[entities_space.index(answer)]
            # Rank = 1-based position of the answer's score after sorting
            # descending; note this mutates `score` in place.
            score.sort(reverse=True)
            rank = score.index(answer_prob) + 1
        return rank

    def test(self, dataloader, ntriple, num_nodes, neg_sampler, evaluator, split_mode='test'):
        """Get time-aware filtered metrics(MRR, Hits@1/3/10).
        Args:
            ntriple: number of the test examples.
            num_nodes: number of the entities.
            neg_sampler: TGB negative sampler used to query negatives.
            evaluator: TGB evaluator producing the configured metric.
            split_mode: which split's negatives to query ('test' or 'val').
        Return: a dict (key -> metric name, value -> float)
        """
        self.model.eval()
        logs = []
        perf_list =[]
        with torch.no_grad():
            with tqdm.tqdm(total=ntriple, unit='ex') as bar:
                current_time = 0
                cache_IM = {}  # key -> entity, values: list, IM representations of the co-o relations.
                for src_batch, rel_batch, dst_batch, time_batch,time_orig_batch in dataloader:
                    batch_size = dst_batch.size(0)
                    if self.args.IM:
                        # Inductive-mean: adapt embeddings of unseen entities.
                        # NOTE(review): this branch reads only element [0] of
                        # the batch, so IM appears to assume batch_size == 1 —
                        # confirm with the caller.
                        src = src_batch[0].item()
                        rel = rel_batch[0].item()
                        dst = dst_batch[0].item()
                        time = time_batch[0].item()
                        # representation update
                        # Flush cached IM vectors when the timestamp advances.
                        if current_time != time:
                            current_time = time
                            for k, v in cache_IM.items():
                                ims = torch.stack(v, dim=0)
                                self.model.agent.update_entity_embedding(k, ims, self.args.mu)
                            cache_IM = {}
                        if src not in self.train_entities and rel in self.RelEntCooccurrence['subject'].keys():
                            im = self.model.agent.get_im_embedding(list(self.RelEntCooccurrence['subject'][rel]))
                            if src in cache_IM.keys():
                                cache_IM[src].append(im)
                            else:
                                cache_IM[src] = [im]
                            # prediction shift
                            self.model.agent.entities_embedding_shift(src, im, self.args.mu)
                    if self.args.cuda:
                        src_batch = src_batch.cuda()
                        rel_batch = rel_batch.cuda()
                        dst_batch = dst_batch.cuda()
                        time_batch = time_batch.cuda()
                    current_entities, beam_prob = \
                        self.model.beam_search(src_batch, time_batch, rel_batch)
                    if self.args.IM and src not in self.train_entities:
                        # We do this
                        # because events that happen at the same time in the future cannot see each other.
                        self.model.agent.back_entities_embedding(src)
                    if self.args.cuda:
                        current_entities = current_entities.cpu()
                        beam_prob = beam_prob.cpu()
                    current_entities = current_entities.numpy()
                    beam_prob = beam_prob.numpy()
                    MRR = 0
                    for i in range(batch_size):
                        candidate_answers = current_entities[i]
                        candidate_score = beam_prob[i]
                        # Dense score vector over all entities; entities the
                        # beam never reached keep a very large negative score.
                        scores_eval_paper_authors = -10000000000.0*np.ones(num_nodes, dtype=np.float32)
                        # sort by score from largest to smallest
                        idx = np.argsort(-candidate_score)
                        candidate_answers = candidate_answers[idx]
                        candidate_score = candidate_score[idx]
                        # remove duplicate entities
                        # np.unique keeps the first (highest-scoring) copy
                        # since candidates were sorted descending above.
                        candidate_answers, idx = np.unique(candidate_answers, return_index=True)
                        candidate_answers = list(candidate_answers)
                        candidate_score = list(candidate_score[idx])
                        src = src_batch[i].item()
                        rel = rel_batch[i].item()
                        dst = dst_batch[i].item()
                        time = time_batch[i].item()
                        time_orig = time_orig_batch[i].item()
                        # Drop the padding entity (id == num_nodes) if present;
                        # after np.unique it can only be the last element.
                        if np.max(candidate_answers) >= num_nodes:
                            if candidate_answers[-1] == num_nodes:
                                logging_score_answers = candidate_answers[0:-1]
                                logging_score = candidate_score[0:-1]
                            else:
                                print("Problem with the score ids", np.max(candidate_answers))
                        else:
                            logging_score_answers = candidate_answers
                            logging_score = candidate_score
                        neg_samples_batch = neg_sampler.query_batch(np.expand_dims(np.array(src), axis=0),
                                                                    np.expand_dims(np.array(dst), axis=0),
                                                                    np.expand_dims(np.array(time_orig), axis=0),
                                                                    edge_type=np.expand_dims(np.array(rel), axis=0),
                                                                    split_mode=split_mode)
                        pos_samples_batch = dst
                        # get inductive inference performance.
                        # Only count the results of the example containing new entities.
                        if self.args.test_inductive and src in self.train_entities and dst in self.train_entities:
                            continue
                        scores_eval_paper_authors[logging_score_answers] = logging_score
                        neg_scores = scores_eval_paper_authors[neg_samples_batch]
                        pos_scores = scores_eval_paper_authors[pos_samples_batch]
                        input_dict = {
                            "y_pred_pos": np.array([pos_scores]),
                            "y_pred_neg": np.array(neg_scores),
                            "eval_metric": [self.metric],
                        }
                        perf_list.append(evaluator.eval(input_dict)[self.metric])
                    bar.update(batch_size)
                    bar.set_postfix(MRR='{}'.format(perf_list[-1] / batch_size))
        metrics = {}
        metrics[self.metric] = np.mean(perf_list)
        return metrics
def getRelEntCooccurrence(quadruples, num_rels):
    """Used for Inductive-Mean. Get co-occurrence in the training set.
    https://github.com/JHL-HUST/TITer/blob/master/dataset/baseDataset.py
    from Timetraveler

    For every relation (and its inverse, offset by num_rels + 1) collect the
    sets of subject and object entities it co-occurs with.
    return:
    {'subject': a dict[key -> relation, values -> a set of co-occurrence subject entities],
     'object': a dict[key -> relation, values -> a set of co-occurrence object entities],}
    """
    subjects_by_rel = {}
    objects_by_rel = {}
    for quad in quadruples:
        subj, rel, obj = quad[0], quad[1], quad[2]
        inv_rel = rel + num_rels + 1
        # forward direction: rel links subj -> obj
        subjects_by_rel.setdefault(rel, set()).add(subj)
        objects_by_rel.setdefault(rel, set()).add(obj)
        # inverse direction: roles of subject and object swap
        subjects_by_rel.setdefault(inv_rel, set()).add(obj)
        objects_by_rel.setdefault(inv_rel, set()).add(subj)
    return {'subject': subjects_by_rel, 'object': objects_by_rel}
================================================
FILE: modules/tkg_utils.py
================================================
from itertools import groupby
from operator import itemgetter
from collections import defaultdict
import sys
import argparse
import numpy as np
def get_args_timetraveler(args=None):
    """Build and parse the command-line arguments for the "timetraveler" model.

    Parameters:
        args (list | None): argument strings to parse; None falls back to sys.argv.
    Returns:
        argparse.Namespace: the parsed arguments.
    """
    p = argparse.ArgumentParser(
        description='Timetraveler',
        usage='main.py [] [-h | --help]'
    )
    p.add_argument('--seed', type=int, help='Random seed', default=1)
    p.add_argument('--cuda', action='store_true', help='whether to use GPU or not.')
    p.add_argument('--do_train', default=True, action='store_true', help='whether to train.')
    p.add_argument('--do_test', default=True, action='store_true', help='whether to test.')

    # training hyperparameters
    p.add_argument('--batch_size', default=512, type=int, help='training batch size.')
    p.add_argument('--max_epochs', default=400, type=int, help='max training epochs.')
    p.add_argument('--num_workers', default=8, type=int, help='workers number used for dataloader.')
    p.add_argument('--valid_epoch', default=30, type=int, help='validation frequency.')
    p.add_argument('--lr', default=0.001, type=float, help='learning rate.')
    p.add_argument('--save_epoch', default=30, type=int, help='model saving frequency.')
    p.add_argument('--clip_gradient', default=10.0, type=float, help='for gradient crop.')

    # test-time settings
    p.add_argument('--test_batch_size', default=1, type=int,
                   help='test batch size, it needs to be set to 1 when using IM module.')
    p.add_argument('--beam_size', default=100, type=int, help='the beam number of the beam search.')
    p.add_argument('--test_inductive', action='store_true', help='whether to verify inductive inference performance.')
    p.add_argument('--IM', default=True, action='store_true', help='whether to use IM module.')
    p.add_argument('--mu', default=0.1, type=float, help='the hyperparameter of IM module.')

    # agent (policy network) dimensions
    p.add_argument('--ent_dim', default=80, type=int, help='Embedding dimension of the entities')
    p.add_argument('--rel_dim', default=100, type=int, help='Embedding dimension of the relations')
    p.add_argument('--state_dim', default=100, type=int, help='dimension of the LSTM hidden state')
    p.add_argument('--hidden_dim', default=100, type=int, help='dimension of the MLP hidden layer')
    p.add_argument('--time_dim', default=20, type=int, help='Embedding dimension of the timestamps')
    p.add_argument('--entities_embeds_method', default='dynamic', type=str,
                   help='representation method of the entities, dynamic or static')

    # environment
    p.add_argument('--state_actions_path', default='state_actions_space.pkl', type=str,
                   help='the file stores preprocessed candidate action array.')

    # episode
    p.add_argument('--path_length', default=3, type=int, help='the agent search path length.')
    p.add_argument('--max_action_num', default=30, type=int, help='the max candidate actions number.')

    # policy-gradient hyperparameters
    p.add_argument('--Lambda', default=0.0, type=float, help='update rate of baseline.')
    p.add_argument('--Gamma', default=0.95, type=float, help='discount factor of Bellman Eq.')
    p.add_argument('--Ita', default=0.01, type=float, help='regular proportionality constant.')
    p.add_argument('--Zita', default=0.9, type=float, help='attenuation factor of entropy regular term.')

    # reward shaping
    p.add_argument('--reward_shaping', default=False, help='whether to use reward shaping.')
    p.add_argument('--time_span', default=1, type=int, help='24 for ICEWS, 1 for WIKI and YAGO')
    p.add_argument('--alphas_pkl', default='dirchlet_alphas.pkl', type=str,
                   help='the file storing the alpha parameters of the Dirichlet distribution.')
    p.add_argument('--k', default=12000, type=int, help='statistics recent K historical snapshots.')

    # preprocessing
    p.add_argument('--store_actions_num', default=0, type=int,
                   help='maximum number of stored neighbors, 0 means store all.')
    p.add_argument('--preprocess', default=True,
                   help="Do we want preprocessing for the actionspace")

    # Dirichlet fitting
    p.add_argument('--tol', default=1e-7, type=float)
    p.add_argument('--method', default='meanprecision', type=str)
    p.add_argument('--maxiter', default=100, type=int)
    return p.parse_args(args)
def get_model_config_timetraveler(args, num_ent, num_rel):
    """Assemble the configuration dict consumed by the "timetraveler" agent.

    Parameters:
        args: parsed argparse namespace (see get_args_timetraveler)
        num_ent (int): number of entities in the graph
        num_rel (int): number of relations in the graph
    Returns:
        dict: model configuration (derived dims included, e.g. action_dim).
    """
    config = {}
    config['cuda'] = args.cuda                # whether to use GPU or not
    config['batch_size'] = args.batch_size    # training batch size
    config['num_ent'] = num_ent               # number of entities
    config['num_rel'] = num_rel               # number of relations
    config['ent_dim'] = args.ent_dim          # entity embedding dimension
    config['rel_dim'] = args.rel_dim          # relation embedding dimension
    config['time_dim'] = args.time_dim        # timestamp embedding dimension
    config['state_dim'] = args.state_dim      # LSTM hidden-state dimension
    # an action is an (entity, relation) pair, so its dim is the sum of both
    config['action_dim'] = args.ent_dim + args.rel_dim
    # the MLP sees the action concatenated with the agent state
    config['mlp_input_dim'] = args.ent_dim + args.rel_dim + args.state_dim
    config['mlp_hidden_dim'] = args.hidden_dim
    config['path_length'] = args.path_length          # agent search path length
    config['max_action_num'] = args.max_action_num    # max candidate action number
    config['lambda'] = args.Lambda            # update rate of baseline
    config['gamma'] = args.Gamma              # discount factor of Bellman Eq.
    config['ita'] = args.Ita                  # regular proportionality constant
    config['zita'] = args.Zita                # attenuation factor of entropy regular term
    config['beam_size'] = args.beam_size      # beam size for beam search
    # default 'dynamic'; otherwise a static encoder will be used
    config['entities_embeds_method'] = args.entities_embeds_method
    return config
def get_args_cen(args=None):
    """Parse the command-line arguments for the "CEN" model.

    Parameters:
        args (list | None): argument strings to parse; None falls back to
            sys.argv (standard argparse behavior, backward compatible).
    Returns:
        (argparse.Namespace, list): parsed arguments and the raw sys.argv.
    """
    parser = argparse.ArgumentParser(description='CEN')
    parser.add_argument("--gpu", type=int, default=0,
                        help="gpu")
    parser.add_argument("--batch-size", type=int, default=1,
                        help="batch-size")
    parser.add_argument("-d", "--dataset", type=str, default='tkgl-yago',
                        help="dataset to use")
    parser.add_argument("--test", type=int, default=0,
                        help="1: formal test 2: continual test")
    parser.add_argument("--validtest", default=False,
                        help="load stat from dir and directly valid and test")
    parser.add_argument("--test-only", type=bool, default=False,
                        help="do we want to compute valid mrr or only test")
    parser.add_argument("--run-statistic", action='store_true', default=False,
                        help="statistic the result")
    parser.add_argument("--relation-evaluation", action='store_true', default=False,
                        help="save model accordding to the relation evalution")
    parser.add_argument("--log-per-rel", action='store_true', default=False,
                        help="log mrr per relation in json")
    # configuration for encoder RGCN stat
    parser.add_argument("--weight", type=float, default=1,
                        help="weight of static constraint")
    parser.add_argument("--task-weight", type=float, default=1,
                        help="weight of entity prediction task")
    parser.add_argument("--kl-weight", type=float, default=0.7,
                        help="weight of entity prediction task")
    parser.add_argument("--encoder", type=str, default="uvrgcn",
                        help="method of encoder")
    parser.add_argument("--dropout", type=float, default=0.2,
                        help="dropout probability")
    parser.add_argument("--skip-connect", action='store_true', default=False,
                        help="whether to use skip connect in a RGCN Unit")
    parser.add_argument("--n-hidden", type=int, default=200,
                        help="number of hidden units")
    parser.add_argument("--opn", type=str, default="sub",
                        help="opn of compgcn")
    parser.add_argument("--n-bases", type=int, default=100,
                        help="number of weight blocks for each relation")
    parser.add_argument("--n-basis", type=int, default=100,
                        help="number of basis vector for compgcn")
    parser.add_argument("--n-layers", type=int, default=2,
                        help="number of propagation rounds")
    parser.add_argument("--self-loop", action='store_true', default=True,
                        help="perform layer normalization in every layer of gcn ")
    parser.add_argument("--layer-norm", action='store_true', default=True,
                        help="perform layer normalization in every layer of gcn ")
    parser.add_argument("--relation-prediction", action='store_true', default=False,
                        help="add relation prediction loss")
    parser.add_argument("--entity-prediction", action='store_true', default=True,
                        help="add entity prediction loss")
    # configuration for stat training
    parser.add_argument("--n-epochs", type=int, default=30,
                        help="number of minimum training epochs on each time step")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="learning rate")
    parser.add_argument("--ft_epochs", type=int, default=30,
                        help="number of minimum fine-tuning epoch")
    parser.add_argument("--ft_lr", type=float, default=0.001,
                        help="learning rate")
    parser.add_argument("--norm_weight", type=float, default=1,
                        help="learning rate")
    parser.add_argument("--grad-norm", type=float, default=1.0,
                        help="norm to clip gradient to")
    # configuration for evaluating
    parser.add_argument("--evaluate-every", type=int, default=1,
                        help="perform evaluation every n epochs")
    # configuration for decoder
    parser.add_argument("--decoder", type=str, default="convtranse",
                        help="method of decoder")
    parser.add_argument("--input-dropout", type=float, default=0.2,
                        help="input dropout for decoder ")
    parser.add_argument("--hidden-dropout", type=float, default=0.2,
                        help="hidden dropout for decoder")
    parser.add_argument("--feat-dropout", type=float, default=0.2,
                        help="feat dropout for decoder")
    # configuration for sequences stat
    parser.add_argument("--train-history-len", type=int, default=3,
                        help="history length")
    parser.add_argument("--test-history-len", type=int, default=10,
                        help="history length for test")
    parser.add_argument("--test-history-len-2", type=int, default=2,
                        help="history length for test")
    parser.add_argument("--start-history-len", type=int, default=3,
                        help="start history length")
    parser.add_argument("--dilate-len", type=int, default=1,
                        help="dilate history graph")
    # configuration for optimal parameters
    parser.add_argument("--grid-search", action='store_true', default=False,
                        help="perform grid search for best configuration")
    parser.add_argument("-tune", "--tune", type=str, default="n_hidden,n_layers,dropout,n_bases",
                        help="stat to use")
    parser.add_argument("--num-k", type=int, default=500,
                        help="number of triples generated")
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--run-nr', type=int, help='Run Number', default=1)
    try:
        args = parser.parse_args(args)
    except SystemExit:
        # argparse exits on parse errors (and on --help); print the full usage
        # and exit cleanly. Previously a bare `except:` also swallowed
        # KeyboardInterrupt and genuine programming errors.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
def get_args_regcn(args=None):
    """Parse the command-line arguments for the REGCN model.

    Parameters:
        args (list | None): argument strings to parse; None falls back to
            sys.argv (standard argparse behavior, backward compatible).
    Returns:
        (argparse.Namespace, list): parsed arguments and the raw sys.argv.
    """
    parser = argparse.ArgumentParser(description='REGCN')
    parser.add_argument("--gpu", type=int, default=0,
                        help="gpu")
    parser.add_argument("--batch-size", type=int, default=1,
                        help="batch-size")
    parser.add_argument("-d", "--dataset", type=str, default='tkgl-yago',
                        help="dataset to use")
    parser.add_argument("--test", default=False,
                        help="load stat from dir and directly test")
    parser.add_argument("--run-analysis", action='store_true', default=False,
                        help="print log info")
    parser.add_argument("--run-statistic", action='store_true', default=False,
                        help="statistic the result")
    parser.add_argument("--multi-step", action='store_true', default=False,
                        help="do multi-steps inference without ground truth")
    parser.add_argument("--topk", type=int, default=10,
                        help="choose top k entities as results when do multi-steps without ground truth")
    parser.add_argument("--add-static-graph", action='store_true', default=False,
                        help="use the info of static graph")
    parser.add_argument("--add-rel-word", action='store_true', default=False,
                        help="use words in relaitons")
    parser.add_argument("--relation-evaluation", action='store_true', default=False,
                        help="save model accordding to the relation evalution")
    # configuration for encoder RGCN stat
    parser.add_argument("--weight", type=float, default=0.5,
                        help="weight of static constraint")
    parser.add_argument("--task-weight", type=float, default=0.7,
                        help="weight of entity prediction task")
    parser.add_argument("--discount", type=float, default=1.0,
                        help="discount of weight of static constraint")
    parser.add_argument("--angle", type=int, default=10,
                        help="evolution speed")
    parser.add_argument("--encoder", type=str, default="uvrgcn",
                        help="method of encoder")
    parser.add_argument("--aggregation", type=str, default="none",
                        help="method of aggregation")
    parser.add_argument("--dropout", type=float, default=0.2,
                        help="dropout probability")
    parser.add_argument("--skip-connect", action='store_true', default=False,
                        help="whether to use skip connect in a RGCN Unit")
    parser.add_argument("--n-hidden", type=int, default=200,
                        help="number of hidden units")
    parser.add_argument("--opn", type=str, default="sub",
                        help="opn of compgcn")
    parser.add_argument("--n-bases", type=int, default=100,
                        help="number of weight blocks for each relation")
    parser.add_argument("--n-basis", type=int, default=100,
                        help="number of basis vector for compgcn")
    parser.add_argument("--n-layers", type=int, default=2,
                        help="number of propagation rounds")
    parser.add_argument("--self-loop", action='store_true', default=True,
                        help="perform layer normalization in every layer of gcn ")
    parser.add_argument("--layer-norm", action='store_true', default=True,
                        help="perform layer normalization in every layer of gcn ")
    parser.add_argument("--relation-prediction", action='store_true', default=False,
                        help="add relation prediction loss")
    parser.add_argument("--entity-prediction", action='store_true', default=True,
                        help="add entity prediction loss")
    parser.add_argument("--split_by_relation", action='store_true', default=False,
                        help="do relation prediction")
    # configuration for stat training
    parser.add_argument("--n-epochs", type=int, default=10,
                        help="number of minimum training epochs on each time step")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="learning rate")
    parser.add_argument("--grad-norm", type=float, default=1.0,
                        help="norm to clip gradient to")
    # configuration for evaluating
    parser.add_argument("--evaluate-every", type=int, default=1,
                        help="perform evaluation every n epochs")
    parser.add_argument("--log-per-rel", action='store_true', default=False,
                        help="log mrr per relation in json")
    # configuration for decoder
    parser.add_argument("--decoder", type=str, default="convtranse",
                        help="method of decoder")
    parser.add_argument("--input-dropout", type=float, default=0.2,
                        help="input dropout for decoder ")
    parser.add_argument("--hidden-dropout", type=float, default=0.2,
                        help="hidden dropout for decoder")
    parser.add_argument("--feat-dropout", type=float, default=0.2,
                        help="feat dropout for decoder")
    # configuration for sequences stat
    parser.add_argument("--train-history-len", type=int, default=3,
                        help="history length")
    parser.add_argument("--test-history-len", type=int, default=3,
                        help="history length for test")
    parser.add_argument("--dilate-len", type=int, default=1,
                        help="dilate history graph")
    # configuration for optimal parameters
    parser.add_argument("--grid-search", action='store_true', default=False,
                        help="perform grid search for best configuration")
    parser.add_argument("-tune", "--tune", type=str, default="n_hidden,n_layers,dropout,n_bases",
                        help="stat to use")
    parser.add_argument("--num-k", type=int, default=500,
                        help="number of triples generated")
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--run-nr', type=int, help='Run Number', default=1)
    try:
        args = parser.parse_args(args)
    except SystemExit:
        # argparse exits on parse errors (and on --help); print the full usage
        # and exit cleanly. Previously a bare `except:` also swallowed
        # KeyboardInterrupt and genuine programming errors.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
def compute_min_distance(unique_sorted_timestamps):
    """Return the smallest gap between consecutive timestamps.

    :param unique_sorted_timestamps: sorted sequence of unique timestamps
    returns: minimum consecutive difference, or np.inf if fewer than 2 entries
    """
    smallest = np.inf
    for earlier, later in zip(unique_sorted_timestamps, unique_sorted_timestamps[1:]):
        gap = later - earlier
        if gap < smallest:
            smallest = gap
    return smallest
def compute_maxminmean_distances(unique_sorted_timestamps):
    """Compute the maximum, minimum and mean gap between consecutive timestamps.

    :param unique_sorted_timestamps: sorted sequence of unique timestamps
    returns: (max_gap, min_gap, mean_gap) over consecutive differences
    raises: ValueError if fewer than two timestamps are given (the original
        implementation crashed here with an accidental ZeroDivisionError)
    """
    # np.diff replaces the manual successive-difference loop; the original also
    # computed an unused local mean (dead code), removed here.
    differences = np.diff(np.asarray(unique_sorted_timestamps))
    if differences.size == 0:
        raise ValueError("need at least two timestamps to compute distances")
    return np.max(differences), np.min(differences), np.mean(differences)
def group_by(data: np.array, key_idx: int) -> dict:
    """
    group data in an np array to dict; where key is specified by key_idx. for example groups elements of array by relations
    :param data: [np.array] data to be grouped
    :param key_idx: [int] index for element of interest
    returns data_dict: dict with key: values of element at index key_idx, values: all elements in data that have that value
    """
    key_of = itemgetter(key_idx)
    grouped = {}
    # groupby requires its input sorted by the same key
    for key_value, members in groupby(sorted(data, key=key_of), key=key_of):
        grouped[key_value] = np.array(list(members))
    return grouped
def tkg_granularity_lookup(dataset_name, ts_distmean):
    """Look up the known timestamp granularity (in seconds) of a TKG dataset.

    icews/polecat datasets are daily (86400 s), wiki/yago are yearly
    (31536000 s); any other dataset falls back to the supplied mean distance.
    """
    daily_tags = ('icews', 'polecat')
    yearly_tags = ('wiki', 'yago')
    if any(tag in dataset_name for tag in daily_tags):
        return 86400
    if any(tag in dataset_name for tag in yearly_tags):
        return 31536000
    return ts_distmean
def reformat_ts(timestamps, dataset_name='tkgl'):
    """ reformat timestamps s.t. they start with 0, and have stepsize 1.
    :param timestamps: np.array() with timestamps
    :param dataset_name: used to pick the step size for 'tkgl' datasets
    returns: np.array(ts_new)
    """
    unique_ts = sorted(set(timestamps))
    ts_min = np.min(unique_ts)
    if 'tkgl' in dataset_name:
        _, dist_min, dist_mean = compute_maxminmean_distances(unique_ts)
        if dist_mean == dist_min:
            # evenly spaced timestamps: the mean distance is the step size
            step = dist_mean
        else:
            # irregular spacing: fall back to the dataset's known granularity,
            # unless it deviates too much (>10%) from the observed mean
            step = tkg_granularity_lookup(dataset_name, dist_mean)
            if step - dist_mean > 0.1 * dist_mean:
                print('PROBLEM: the distances are somehwat off from the granularity of the dataset. using original mean distance')
                step = dist_mean
    else:
        step = compute_min_distance(unique_ts)  # all_ts[1] - all_ts[0]
    shifted = timestamps - ts_min
    return np.ceil(shifted / step).astype(int)
def get_original_ts(reformatted_ts, ts_dist, min_ts):
    """ get original timestamps from reformatted timestamps
    :param reformatted_ts: np.array() with reformatted timestamps
    :param ts_dist: step size used during reformatting
    :param min_ts: minimum original timestamp (offset)
    returns: np.array(ts_new) with the unique, sorted original timestamps
    """
    unique_steps = sorted(set(reformatted_ts))
    return np.array([min_ts + step * ts_dist for step in unique_steps])
def create_basis_dict(data):
    """
    Create basis dictionary for the recurrency baseline model with rules of confidence 1
    data: concatenated train and vali data, INCLUDING INVERSE QUADRUPLES. we need it for the relation ids.
        Only column 1 (the relation ids) is read.
    returns:
        dict: str(relation id) -> list containing the single recurrency rule
        for that relation ("what happened before happens again"), conf 1.
    """
    basis_dict = {}
    for rel in set(data[:, 1]):
        # one rule per relation: same body and head relation, confidence 1.
        # (the original built this via a redundant alias and append loop)
        basis_dict[str(rel)] = [{
            "head_rel": int(rel),
            "body_rels": [int(rel)],
            "conf": 1,
        }]
    return basis_dict
def get_inv_relation_id(num_rels):
    """
    Get inverse relation id.

    The first half of relation ids map to the second half and vice versa
    (relation i <-> relation i + num_rels/2).
    parameters:
        num_rels (int): number of relations
    returns:
        inv_relation_id (dict): mapping of relation to inverse relation
    """
    half = int(num_rels / 2)
    inv_relation_id = {rel: rel + half for rel in range(half)}
    inv_relation_id.update({rel: rel % half for rel in range(half, num_rels)})
    return inv_relation_id
def create_scores_array(predictions_dict, num_nodes):
    """
    Scatter a {node index: score} mapping into a dense score vector.
    predictions_dict: a dictionary mapping indices to values
    num_nodes: the size of the array
    returns: an array of scores (zeros at indices with no prediction)
    """
    scores = np.zeros(num_nodes)
    indices = np.array(list(predictions_dict.keys())).astype(int)
    values = np.array(list(predictions_dict.values())).astype(float)
    scores[indices] = values  # advanced indexing: one vectorized scatter
    return scores
================================================
FILE: modules/tkg_utils_dgl.py
================================================
import dgl
import torch
import numpy as np
from collections import defaultdict
def build_sub_graph(num_nodes, num_rels, triples, use_cuda, gpu, mode='dyn'):
    """
    Build a DGL graph from an array of (src, rel, dst) triples; adapted from
    https://github.com/Lee-zix/CEN/blob/main/rgcn/utils.py

    :param num_nodes: total number of nodes added to the graph
    :param num_rels: number of relations; in 'static' mode inverse relations
        get ids offset by num_rels
    :param triples: np.ndarray of shape (n, 3) with columns (src, rel, dst)
    :param use_cuda: whether to move the graph to the given GPU device
    :param gpu: device id used when use_cuda is True
    :param mode: 'static' additionally adds reversed edges with offset relation
        ids; any other value (default 'dyn') uses the triples as given
    :return: dgl.DGLGraph with node data 'id' and 'norm', edge data 'type' and
        'norm', and relation bookkeeping attached as plain attributes
        (uniq_r, r_to_e, r_len)
    """
    def comp_deg_norm(g):
        # 1 / in-degree normalization; zero-degree nodes are clamped to 1
        # to avoid division by zero
        in_deg = g.in_degrees(range(g.number_of_nodes())).float()
        in_deg[torch.nonzero(in_deg == 0).view(-1)] = 1
        norm = 1.0 / in_deg
        return norm
    src, rel, dst = triples.transpose()
    if mode =='static':
        # add inverse edges; their relation ids are offset by num_rels
        src, dst = np.concatenate((src, dst)), np.concatenate((dst, src))
        rel = np.concatenate((rel, rel + num_rels))
    g = dgl.DGLGraph()
    g.add_nodes(num_nodes)
    #g.ndata['original_id'] = np.unique(np.concatenate((np.unique(triples[:,0]), np.unique(triples[:,2]))))
    g.add_edges(src, dst)
    norm = comp_deg_norm(g)
    #node_id =torch.arange(0, g.num_nodes(), dtype=torch.long).view(-1, 1) #updated to deal with the fact that ot only the first k nodes of our graph have static infos
    node_id = torch.arange(0, num_nodes, dtype=torch.long).view(-1, 1)
    g.ndata.update({'id': node_id, 'norm': norm.view(-1, 1)})
    # per-edge norm = dst_norm * src_norm (product of endpoint normalizations)
    g.apply_edges(lambda edges: {'norm': edges.dst['norm'] * edges.src['norm']})
    g.edata['type'] = torch.LongTensor(rel)
    # relation -> entity bookkeeping used by the CEN/REGCN encoders
    uniq_r, r_len, r_to_e = r2e(triples, num_rels)
    g.uniq_r = uniq_r
    g.r_to_e = r_to_e
    g.r_len = r_len
    if use_cuda:
        g = g.to(gpu)
        # NOTE(review): r_to_e stays a CPU tensor here (torch.from_numpy);
        # presumably moved to the device later by the caller — confirm
        g.r_to_e = torch.from_numpy(np.array(r_to_e))
    return g
def r2e(triplets, num_rels):
    """ get the mapping from relation to entities helper function for build_sub_graph()

    parameters:
        triplets (np.ndarray): (n, 3) array of (src, rel, dst); inverse triples
            are assumed to already be included (see commented line below)
        num_rels (int): number of relations, used to key inverse-relation buckets
    returns:
        uniq_r: set of unique relations
        r_len: list of tuples, where each tuple is the start and end index of entities for a relation
        e_idx: indices of entities"""
    # get all relations (the original unpacked all three columns but only used rel)
    uniq_r = np.unique(triplets[:, 1])
    # uniq_r = np.concatenate((uniq_r, uniq_r+num_rels)) #we already have the inverse triples
    # generate r2e
    r_to_e = defaultdict(set)
    for src, rel, dst in triplets:  # removed unused enumerate index
        r_to_e[rel].add(src)
        r_to_e[rel].add(dst)
        # NOTE: the inverse-relation buckets below are populated but never read,
        # since uniq_r only holds the original relation ids; kept for parity
        # with the upstream implementation.
        r_to_e[rel + num_rels].add(src)
        r_to_e[rel + num_rels].add(dst)
    r_len = []
    e_idx = []
    idx = 0
    for r in uniq_r:
        entities = r_to_e[r]
        r_len.append((idx, idx + len(entities)))
        e_idx.extend(list(entities))
        idx += len(entities)
    return uniq_r, r_len, e_idx
================================================
FILE: modules/tlogic_apply_modules.py
================================================
"""
https://github.com/liu-yushan/TLogic/blob/main/mycode/rule_application.py
TLogic: Temporal Logical Rules for Explainable Link Forecasting on Temporal Knowledge Graphs.
Yushan Liu, Yunpu Ma, Marcel Hildebrandt, Mitchell Joblin, Volker Tresp
"""
import json
import numpy as np
import pandas as pd
from modules.tlogic_learn_modules import store_edges
def filter_rules(rules_dict, min_conf, min_body_supp, rule_lengths):
    """
    Filter for rules with a minimum confidence, minimum body support, and
    specified rule lengths.

    Parameters.
        rules_dict (dict): rules
        min_conf (float): minimum confidence value
        min_body_supp (int): minimum body support value
        rule_lengths (list): rule lengths
    Returns:
        new_rules_dict (dict): filtered rules
    """
    filtered = dict()
    for key, rules in rules_dict.items():
        filtered[key] = [
            rule
            for rule in rules
            if rule["conf"] >= min_conf
            and rule["body_supp"] >= min_body_supp
            and len(rule["body_rels"]) in rule_lengths
        ]
    return filtered
def get_window_edges(all_data, test_query_ts, learn_edges, window=-1, first_test_query_ts=0): #modified eval_paper_authors: added first_test_query_ts for validation set usage
    """
    Get the edges in the data (for rule application) that occur in the specified time window.
    If window is 0, all edges before the test query timestamp are included.
    If window is -1, the edges on which the rules are learned are used.
    If window is -2, all edges from train and validation set are used. modified by eval_paper_authors.
    If window is -200, all train/valid edges within 200 timestamps before the
    first test query are used. modified by eval_paper_authors.
    If window is an integer n > 0, all edges within n timestamps before the test query
    timestamp are included.

    Note: modified according to Julia Gastinger, Timo Sztyler, Lokesh Sharma, Anett Schuelke, Heiner Stuckenschmidt.
    Comparing Apples and Oranges? On the Evaluation of Methods for Temporal Knowledge Graphs. In ECML PKDD, 2023.
    https://github.com/nec-research/TLogic/blob/374c7e34f5949f98b2eccc9628f98125a63763f1/mycode/rule_application.py

    Parameters:
        all_data (np.ndarray): complete dataset (train/valid/test)
        test_query_ts (np.ndarray): test query timestamp
        learn_edges (dict): edges on which the rules are learned
        window (int): time window used for rule application
        first_test_query_ts (int): smallest timestamp from test set (eval_paper_authors)
    Returns:
        window_edges (dict): edges in the window for rule application
    Raises:
        ValueError: if window is a negative value other than -1, -2 or -200
            (previously this fell through and crashed with UnboundLocalError)
    """
    if window > 0:
        mask = (all_data[:, 3] < test_query_ts) * (
            all_data[:, 3] >= test_query_ts - window
        )
        window_edges = store_edges(all_data[mask])
    elif window == 0:
        mask = all_data[:, 3] < test_query_ts  # all edges strictly before the query
        window_edges = store_edges(all_data[mask])
    elif window == -1:
        window_edges = learn_edges
    elif window == -2:  # modified eval_paper_authors: added this option
        # all edges at timesteps smaller than the test queries,
        # meaning all from train and valid set
        mask = all_data[:, 3] < first_test_query_ts
        window_edges = store_edges(all_data[mask])
    elif window == -200:  # modified eval_paper_authors: added this option
        abswindow = 200
        mask = (all_data[:, 3] < first_test_query_ts) * (
            all_data[:, 3] >= first_test_query_ts - abswindow  # test queries - 200
        )
        window_edges = store_edges(all_data[mask])
    else:
        raise ValueError(f"Unsupported window value: {window}")
    return window_edges
def match_body_relations(rule, edges, test_query_sub):
    """
    Find edges that could constitute walks (starting from the test query subject)
    that match the rule.

    First, find edges whose subject match the query subject and the relation matches
    the first relation in the rule body. Then, find edges whose subjects match the
    current targets and the relation the next relation in the rule body.
    Memory-efficient implementation: only (sub, obj, ts) columns are kept.

    Parameters:
        rule (dict): rule from rules_dict
        edges (dict): edges for rule application
        test_query_sub (int): test query subject
    Returns:
        walk_edges (list of np.ndarrays): edges that could constitute rule walks
    """
    body = rule["body_rels"]
    # match query subject against the first body relation
    try:
        first_edges = edges[body[0]]
    except KeyError:
        return [[]]
    matched = first_edges[first_edges[:, 0] == test_query_sub]
    walk_edges = [np.hstack((matched[:, 0:1], matched[:, 2:4]))]  # [sub, obj, ts]
    cur_targets = np.array(list(set(walk_edges[0][:, 1])))
    for step in range(1, len(body)):
        # match current targets against the next body relation
        try:
            rel_edges = edges[body[step]]
        except KeyError:
            walk_edges.append([])
            break
        hit_mask = np.any(rel_edges[:, 0] == cur_targets[:, None], axis=0)
        hits = rel_edges[hit_mask]
        walk_edges.append(np.hstack((hits[:, 0:1], hits[:, 2:4])))  # [sub, obj, ts]
        cur_targets = np.array(list(set(walk_edges[step][:, 1])))
    return walk_edges
def match_body_relations_complete(rule, edges, test_query_sub):
    """
    Find edges that could constitute walks (starting from the test query subject)
    that match the rule.

    First, find edges whose subject match the query subject and the relation matches
    the first relation in the rule body. Then, find edges whose subjects match the
    current targets and the relation the next relation in the rule body.
    Unlike match_body_relations, the full (sub, rel, obj, ts) rows are kept.

    Parameters:
        rule (dict): rule from rules_dict
        edges (dict): edges for rule application
        test_query_sub (int): test query subject
    Returns:
        walk_edges (list of np.ndarrays): edges that could constitute rule walks
    """
    body = rule["body_rels"]
    # match query subject against the first body relation
    try:
        first_edges = edges[body[0]]
    except KeyError:
        return [[]]
    walk_edges = [first_edges[first_edges[:, 0] == test_query_sub]]
    cur_targets = np.array(list(set(walk_edges[0][:, 2])))
    for step in range(1, len(body)):
        # match current targets against the next body relation
        try:
            rel_edges = edges[body[step]]
        except KeyError:
            walk_edges.append([])
            break
        hit_mask = np.any(rel_edges[:, 0] == cur_targets[:, None], axis=0)
        walk_edges.append(rel_edges[hit_mask])
        cur_targets = np.array(list(set(walk_edges[step][:, 2])))
    return walk_edges
def get_walks(rule, walk_edges):
    """
    Get walks for a given rule. Take the time constraints into account.
    Memory-efficient implementation: intermediate entity columns are dropped
    as soon as possible when the rule has no variable constraints.

    Parameters:
        rule (dict): rule from rules_dict
        walk_edges (list of np.ndarrays): edges from match_body_relations
    Returns:
        rule_walks (pd.DataFrame): all walks matching the rule
    """
    keep_entities = bool(rule["var_constraints"])
    frames = []
    # first hop: no explicit dtype (matches upstream implementation)
    first = pd.DataFrame(
        walk_edges[0],
        columns=["entity_0", "entity_1", "timestamp_0"],
    )
    if not keep_entities:
        del first["entity_0"]
    frames.append(first)
    for hop in range(1, len(walk_edges)):
        frames.append(
            pd.DataFrame(
                walk_edges[hop],
                columns=["entity_" + str(hop), "entity_" + str(hop + 1), "timestamp_" + str(hop)],
                dtype=np.uint16,  # change type if necessary
            )
        )
    rule_walks = frames[0]
    for hop in range(1, len(frames)):
        # join consecutive hops on the shared entity, then enforce the
        # non-decreasing timestamp constraint of a temporal walk
        rule_walks = pd.merge(rule_walks, frames[hop], on=["entity_" + str(hop)])
        rule_walks = rule_walks[
            rule_walks["timestamp_" + str(hop - 1)] <= rule_walks["timestamp_" + str(hop)]
        ]
        if not keep_entities:
            del rule_walks["entity_" + str(hop)]
        frames[hop] = frames[hop][0:0]  # release memory
    # only the first timestamp is needed downstream
    for hop in range(1, len(rule["body_rels"])):
        del rule_walks["timestamp_" + str(hop)]
    return rule_walks
def get_walks_complete(rule, walk_edges):
    """
    Get complete walks for a given rule. Take the time constraints into account.
    Keeps all (entity, relation, entity, timestamp) columns of every hop.

    Parameters:
        rule (dict): rule from rules_dict
        walk_edges (list of np.ndarrays): edges from match_body_relations
    Returns:
        rule_walks (pd.DataFrame): all walks matching the rule
    """
    frames = []
    for hop, hop_edges in enumerate(walk_edges):
        frames.append(
            pd.DataFrame(
                hop_edges,
                columns=[
                    "entity_" + str(hop),
                    "relation_" + str(hop),
                    "entity_" + str(hop + 1),
                    "timestamp_" + str(hop),
                ],
                dtype=np.uint16,  # change type if necessary
            )
        )
    rule_walks = frames[0]
    for hop in range(1, len(frames)):
        # join consecutive hops on the shared entity, then enforce the
        # non-decreasing timestamp constraint of a temporal walk
        rule_walks = pd.merge(rule_walks, frames[hop], on=["entity_" + str(hop)])
        rule_walks = rule_walks[
            rule_walks["timestamp_" + str(hop - 1)] <= rule_walks["timestamp_" + str(hop)]
        ]
    return rule_walks
def check_var_constraints(var_constraints, rule_walks):
    """
    Check variable constraints of the rule.

    Parameters:
        var_constraints (list): variable constraints from the rule
        rule_walks (pd.DataFrame): all walks matching the rule

    Returns:
        rule_walks (pd.DataFrame): all walks matching the rule including the variable constraints
    """
    # Each constraint lists walk positions that must hold the same entity;
    # filtering consecutive pairs is enough to enforce the whole chain.
    for positions in var_constraints:
        for first, second in zip(positions, positions[1:]):
            same_entity = (
                rule_walks["entity_" + str(first)]
                == rule_walks["entity_" + str(second)]
            )
            rule_walks = rule_walks[same_entity]
    return rule_walks
def get_candidates(
    rule, rule_walks, test_query_ts, cands_dict, score_func, args, dicts_idx
):
    """
    Get from the walks that follow the rule the answer candidates.
    Add the confidence of the rule that leads to these candidates.

    Parameters:
        rule (dict): rule from rules_dict
        rule_walks (pd.DataFrame): rule walks (satisfying all constraints from the rule)
        test_query_ts (int): test query timestamp
        cands_dict (dict): candidates along with the confidences of the rules that generated these candidates
        score_func (function): function for calculating the candidate score
        args (list): arguments for the scoring function
        dicts_idx (list): indices for candidate dictionaries

    Returns:
        cands_dict (dict): updated candidates
    """
    # The candidate is the entity at the tail position of the rule body.
    tail_col = "entity_" + str(len(rule["body_rels"]))
    for cand in set(rule_walks[tail_col]):
        walks_to_cand = rule_walks[rule_walks[tail_col] == cand]
        for s in dicts_idx:
            score = score_func(rule, walks_to_cand, test_query_ts, *args[s]).astype(
                np.float32
            )
            # Accumulate one score per rule that produced this candidate.
            cands_dict[s].setdefault(cand, []).append(score)
    return cands_dict
def save_candidates(
    rules_file, dir_path, all_candidates, rule_lengths, window, score_func_str
):
    """
    Save the candidates.

    Parameters:
        rules_file (str): name of rules file
        dir_path (str): path to output directory
        all_candidates (dict): candidates for all test queries
        rule_lengths (list): rule lengths
        window (int): time window used for rule application
        score_func_str (str): scoring function

    Returns:
        None
    """
    # JSON cannot serialize numpy integer keys, so normalize them first.
    serializable = {
        int(query): {int(cand): scores for cand, scores in cands.items()}
        for query, cands in all_candidates.items()
    }
    # Strip the trailing "_rules.json" (11 chars) from the rules file name.
    filename = "{0}_cands_r{1}_w{2}_{3}.json".format(
        rules_file[:-11], rule_lengths, window, score_func_str
    ).replace(" ", "")
    with open(dir_path + filename, "w", encoding="utf-8") as fout:
        json.dump(serializable, fout)
def verbalize_walk(walk, data):
    """
    Verbalize walk from rule application.

    Parameters:
        walk (pandas.core.series.Series): walk that matches the rule body from get_walks
        data (grapher.Grapher): graph data

    Returns:
        walk_str (str): verbalized walk
    """
    # A walk stores (relation, entity, timestamp) triples after the start
    # entity, so the number of hops is len(walk) // 3.
    hops = len(walk) // 3
    ids = walk.values.tolist()
    parts = [data.id2entity[ids[0]]]
    for hop in range(hops):
        parts.append(data.id2relation[ids[3 * hop + 1]])
        parts.append(data.id2entity[ids[3 * hop + 2]])
        parts.append(data.id2ts[ids[3 * hop + 3]])
    return "\t".join(parts)
def score1(rule, c=0):
    """
    Calculate candidate score depending on the rule's confidence.

    Parameters:
        rule (dict): rule from rules_dict
        c (int): constant for smoothing

    Returns:
        score (float): candidate score
    """
    # Smoothed confidence: rule support over (optionally inflated) body support.
    return rule["rule_supp"] / (rule["body_supp"] + c)
def score2(cands_walks, test_query_ts, lmbda):
    """
    Calculate candidate score depending on the time difference.

    Parameters:
        cands_walks (pd.DataFrame): walks leading to the candidate
        test_query_ts (int): test query timestamp
        lmbda (float): rate of exponential distribution

    Returns:
        score (float): candidate score
    """
    # Exponential decay in the age of the most recent supporting walk.
    newest_ts = cands_walks["timestamp_0"].max()
    return np.exp(lmbda * (newest_ts - test_query_ts))
def score_12(rule, cands_walks, test_query_ts, lmbda, a):
    """
    Combined score function.

    Parameters:
        rule (dict): rule from rules_dict
        cands_walks (pd.DataFrame): walks leading to the candidate
        test_query_ts (int): test query timestamp
        lmbda (float): rate of exponential distribution
        a (float): value between 0 and 1

    Returns:
        score (float): candidate score
    """
    # Convex combination of the confidence score and the recency score.
    confidence_part = score1(rule)
    recency_part = score2(cands_walks, test_query_ts, lmbda)
    return a * confidence_part + (1 - a) * recency_part
================================================
FILE: modules/tlogic_learn_modules.py
================================================
"""
https://github.com/liu-yushan/TLogic/blob/main/mycode/temporal_walk.py
AND
https://github.com/liu-yushan/TLogic/blob/main/mycode/rule_learning.py
TLogic: Temporal Logical Rules for Explainable Link Forecasting on Temporal Knowledge Graphs.
Yushan Liu, Yunpu Ma, Marcel Hildebrandt, Mitchell Joblin, Volker Tresp
"""
import os
import json
import itertools
import numpy as np
class Temporal_Walk(object):
    """Samples cyclic temporal random walks over a temporal knowledge graph."""

    def __init__(self, learn_data, inv_relation_id, transition_distr):
        """
        Initialize temporal random walk object.

        Parameters:
            learn_data (np.ndarray): data on which the rules should be learned
            inv_relation_id (dict): mapping of relation to inverse relation
            transition_distr (str): transition distribution
                                    "unif" - uniform distribution
                                    "exp" - exponential distribution

        Returns:
            None
        """
        self.learn_data = learn_data
        self.inv_relation_id = inv_relation_id
        self.transition_distr = transition_distr
        # Index the quadruples once so each sampling step is a dict lookup.
        self.neighbors = store_neighbors(learn_data)
        self.edges = store_edges(learn_data)

    def sample_start_edge(self, rel_idx):
        """
        Define start edge distribution.

        Parameters:
            rel_idx (int): relation index

        Returns:
            start_edge (np.ndarray): start edge
        """
        # Uniform choice among all edges of the given relation.
        rel_edges = self.edges[rel_idx]
        start_edge = rel_edges[np.random.choice(len(rel_edges))]
        return start_edge

    def sample_next_edge(self, filtered_edges, cur_ts):
        """
        Define next edge distribution.

        Parameters:
            filtered_edges (np.ndarray): filtered (according to time) edges
            cur_ts (int): current timestamp

        Returns:
            next_edge (np.ndarray): next edge
        """
        if self.transition_distr == "unif":
            next_edge = filtered_edges[np.random.choice(len(filtered_edges))]
        elif self.transition_distr == "exp":
            # Weight candidate edges by exp(ts - cur_ts): more recent edges
            # (closer to cur_ts) get a higher transition probability.
            tss = filtered_edges[:, 3]
            prob = np.exp(tss - cur_ts)
            try:
                prob = prob / np.sum(prob)
                next_edge = filtered_edges[
                    np.random.choice(range(len(filtered_edges)), p=prob)
                ]
            except ValueError:  # All timestamps are far away (weights underflow)
                next_edge = filtered_edges[np.random.choice(len(filtered_edges))]
        return next_edge

    def transition_step(self, cur_node, cur_ts, prev_edge, start_node, step, L):
        """
        Sample a neighboring edge given the current node and timestamp.
        In the second step (step == 1), the next timestamp should be smaller than the current timestamp.
        In the other steps, the next timestamp should be smaller than or equal to the current timestamp.
        In the last step (step == L-1), the edge should connect to the source of the walk (cyclic walk).
        It is not allowed to go back using the inverse edge.

        Parameters:
            cur_node (int): current node
            cur_ts (int): current timestamp
            prev_edge (np.ndarray): previous edge
            start_node (int): start node
            step (int): number of current step
            L (int): length of random walk

        Returns:
            next_edge (np.ndarray): next edge
        """
        next_edges = self.neighbors[cur_node]

        if step == 1:  # The next timestamp should be smaller than the current timestamp
            filtered_edges = next_edges[next_edges[:, 3] < cur_ts]
        else:  # The next timestamp should be smaller than or equal to the current timestamp
            filtered_edges = next_edges[next_edges[:, 3] <= cur_ts]

        # Delete inverse edge so the walk cannot immediately backtrack.
        inv_edge = [
            cur_node,
            self.inv_relation_id[prev_edge[1]],
            prev_edge[0],
            cur_ts,
        ]
        row_idx = np.where(np.all(filtered_edges == inv_edge, axis=1))
        filtered_edges = np.delete(filtered_edges, row_idx, axis=0)

        if step == L - 1:  # Find an edge that connects to the source of the walk
            filtered_edges = filtered_edges[filtered_edges[:, 2] == start_node]

        if len(filtered_edges):
            next_edge = self.sample_next_edge(filtered_edges, cur_ts)
        else:
            next_edge = []  # Empty list signals a dead end to the caller
        return next_edge

    def sample_walk(self, L, rel_idx):
        """
        Try to sample a cyclic temporal random walk of length L (for a rule of length L-1).

        Parameters:
            L (int): length of random walk
            rel_idx (int): relation index

        Returns:
            walk_successful (bool): if a cyclic temporal random walk has been successfully sampled
            walk (dict): information about the walk (entities, relations, timestamps)
        """
        walk_successful = True
        walk = dict()
        prev_edge = self.sample_start_edge(rel_idx)
        start_node = prev_edge[0]
        cur_node = prev_edge[2]
        cur_ts = prev_edge[3]
        walk["entities"] = [start_node, cur_node]
        walk["relations"] = [prev_edge[1]]
        walk["timestamps"] = [cur_ts]

        for step in range(1, L):
            next_edge = self.transition_step(
                cur_node, cur_ts, prev_edge, start_node, step, L
            )
            if len(next_edge):
                cur_node = next_edge[2]
                cur_ts = next_edge[3]
                walk["relations"].append(next_edge[1])
                walk["entities"].append(cur_node)
                walk["timestamps"].append(cur_ts)
                prev_edge = next_edge
            else:  # No valid neighbors (due to temporal or cyclic constraints)
                walk_successful = False
                break
        return walk_successful, walk
def store_neighbors(quads):
    """
    Store all neighbors (outgoing edges) for each node.

    Parameters:
        quads (np.ndarray): indices of quadruples

    Returns:
        neighbors (dict): neighbors for each node
    """
    # Group the quadruples by subject entity (column 0).
    return {node: quads[quads[:, 0] == node] for node in set(quads[:, 0])}
def store_edges(quads):
    """
    Store all edges for each relation.

    Parameters:
        quads (np.ndarray): indices of quadruples

    Returns:
        edges (dict): edges for each relation
    """
    # Group the quadruples by relation (column 1).
    return {rel: quads[quads[:, 1] == rel] for rel in set(quads[:, 1])}
class Rule_Learner(object):
    """Creates rules from cyclic temporal walks, estimates their confidence,
    and persists the learned rules to disk."""

    def __init__(self, edges, id2relation, inv_relation_id, output_dir):
        """
        Initialize rule learner object.

        Parameters:
            edges (dict): edges for each relation
            id2relation (dict): mapping of index to relation
            inv_relation_id (dict): mapping of relation to inverse relation
            output_dir (str): directory name where to store learned rules

        Returns:
            None
        """
        self.edges = edges
        self.id2relation = id2relation
        self.inv_relation_id = inv_relation_id

        self.found_rules = []  # rules seen so far (structure only, no statistics)
        self.rules_dict = dict()  # head relation -> list of learned rules
        self.output_dir = output_dir
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

    def create_rule(self, walk):
        """
        Create a rule given a cyclic temporal random walk.
        The rule contains information about head relation, body relations,
        variable constraints, confidence, rule support, and body support.
        A rule is a dictionary with the content
        {"head_rel": int, "body_rels": list, "var_constraints": list,
         "conf": float, "rule_supp": int, "body_supp": int}

        Parameters:
            walk (dict): cyclic temporal random walk
                         {"entities": list, "relations": list, "timestamps": list}

        Returns:
            rule (dict): created rule
        """
        rule = dict()
        rule["head_rel"] = int(walk["relations"][0])
        # The walk is sampled backwards in time, so the body is the reversed
        # remainder of the walk with every relation replaced by its inverse.
        rule["body_rels"] = [
            self.inv_relation_id[x] for x in walk["relations"][1:][::-1]
        ]
        rule["var_constraints"] = self.define_var_constraints(
            walk["entities"][1:][::-1]
        )

        # Only estimate statistics for structurally new rules, and only keep
        # rules with nonzero confidence.
        if rule not in self.found_rules:
            self.found_rules.append(rule.copy())
            (
                rule["conf"],
                rule["rule_supp"],
                rule["body_supp"],
            ) = self.estimate_confidence(rule)

            if rule["conf"]:
                self.update_rules_dict(rule)

    def define_var_constraints(self, entities):
        """
        Define variable constraints, i.e., state the indices of reoccurring entities in a walk.

        Parameters:
            entities (list): entities in the temporal walk

        Returns:
            var_constraints (list): list of indices for reoccurring entities
        """
        var_constraints = []
        for ent in set(entities):
            all_idx = [idx for idx, x in enumerate(entities) if x == ent]
            var_constraints.append(all_idx)
        # Keep only entities that occur more than once.
        var_constraints = [x for x in var_constraints if len(x) > 1]
        return sorted(var_constraints)

    def estimate_confidence(self, rule, num_samples=500):
        """
        Estimate the confidence of the rule by sampling bodies and checking the rule support.

        Parameters:
            rule (dict): rule
                         {"head_rel": int, "body_rels": list, "var_constraints": list}
            num_samples (int): number of samples

        Returns:
            confidence (float): confidence of the rule, rule_support/body_support
            rule_support (int): rule support
            body_support (int): body support
        """
        all_bodies = []
        for _ in range(num_samples):
            sample_successful, body_ents_tss = self.sample_body(
                rule["body_rels"], rule["var_constraints"]
            )
            if sample_successful:
                all_bodies.append(body_ents_tss)

        # Sort so groupby can collapse duplicate bodies.
        all_bodies.sort()
        unique_bodies = list(x for x, _ in itertools.groupby(all_bodies))
        body_support = len(unique_bodies)

        confidence, rule_support = 0, 0
        if body_support:
            rule_support = self.calculate_rule_support(unique_bodies, rule["head_rel"])
            confidence = round(rule_support / body_support, 6)

        return confidence, rule_support, body_support

    def sample_body(self, body_rels, var_constraints):
        """
        Sample a walk according to the rule body.
        The sequence of timesteps should be non-decreasing.

        Parameters:
            body_rels (list): relations in the rule body
            var_constraints (list): variable constraints for the entities

        Returns:
            sample_successful (bool): if a body has been successfully sampled
            body_ents_tss (list): entities and timestamps (alternately entity and timestamp)
                                  of the sampled body
        """
        sample_successful = True
        body_ents_tss = []
        cur_rel = body_rels[0]
        rel_edges = self.edges[cur_rel]
        # Random start edge for the first body relation.
        next_edge = rel_edges[np.random.choice(len(rel_edges))]
        cur_ts = next_edge[3]
        cur_node = next_edge[2]
        body_ents_tss.append(next_edge[0])
        body_ents_tss.append(cur_ts)
        body_ents_tss.append(cur_node)

        for cur_rel in body_rels[1:]:
            next_edges = self.edges[cur_rel]
            # Continue from the current node, never going back in time.
            mask = (next_edges[:, 0] == cur_node) * (next_edges[:, 3] >= cur_ts)
            filtered_edges = next_edges[mask]

            if len(filtered_edges):
                next_edge = filtered_edges[np.random.choice(len(filtered_edges))]
                cur_ts = next_edge[3]
                cur_node = next_edge[2]
                body_ents_tss.append(cur_ts)
                body_ents_tss.append(cur_node)
            else:
                sample_successful = False
                break

        if sample_successful and var_constraints:
            # Check variable constraints (entities are at the even positions).
            body_var_constraints = self.define_var_constraints(body_ents_tss[::2])
            if body_var_constraints != var_constraints:
                sample_successful = False

        return sample_successful, body_ents_tss

    def calculate_rule_support(self, unique_bodies, head_rel):
        """
        Calculate the rule support. Check for each body if there is a timestamp
        (larger than the timestamps in the rule body) for which the rule head holds.

        Parameters:
            unique_bodies (list): bodies from self.sample_body
            head_rel (int): head relation

        Returns:
            rule_support (int): rule support
        """
        rule_support = 0
        head_rel_edges = self.edges[head_rel]
        for body in unique_bodies:
            # body[0]: first entity, body[-1]: last entity, body[-2]: last timestamp.
            mask = (
                (head_rel_edges[:, 0] == body[0])
                * (head_rel_edges[:, 2] == body[-1])
                * (head_rel_edges[:, 3] > body[-2])
            )

            if True in mask:
                rule_support += 1

        return rule_support

    def update_rules_dict(self, rule):
        """
        Update the rules if a new rule has been found.

        Parameters:
            rule (dict): generated rule from self.create_rule

        Returns:
            None
        """
        try:
            self.rules_dict[rule["head_rel"]].append(rule)
        except KeyError:
            self.rules_dict[rule["head_rel"]] = [rule]

    def sort_rules_dict(self):
        """
        Sort the found rules for each head relation by decreasing confidence.

        Parameters:
            None

        Returns:
            None
        """
        for rel in self.rules_dict:
            self.rules_dict[rel] = sorted(
                self.rules_dict[rel], key=lambda x: x["conf"], reverse=True
            )

    def save_rules(self, dt, rule_lengths, num_walks, transition_distr, seed):
        """
        Save all rules.

        Parameters:
            dt (str): time now
            rule_lengths (list): rule lengths
            num_walks (int): number of walks
            transition_distr (str): transition distribution
            seed (int): random seed

        Returns:
            filename (str): name of the JSON file the rules were written to
        """
        # JSON keys must be plain ints, not numpy integers.
        rules_dict = {int(k): v for k, v in self.rules_dict.items()}

        filename = "{0}_r{1}_n{2}_{3}_s{4}_rules.json".format(
            dt, rule_lengths, num_walks, transition_distr, seed
        )
        filename = filename.replace(" ", "")
        with open(self.output_dir + filename, "w", encoding="utf-8") as fout:
            json.dump(rules_dict, fout)
        return filename

    def save_rules_verbalized(
        self, dt, rule_lengths, num_walks, transition_distr, seed
    ):
        """
        Save all rules in a human-readable format.

        Parameters:
            dt (str): time now
            rule_lengths (list): rule lengths
            num_walks (int): number of walks
            transition_distr (str): transition distribution
            seed (int): random seed

        Returns:
            None
        """
        rules_str = ""
        for rel in self.rules_dict:
            for rule in self.rules_dict[rel]:
                rules_str += verbalize_rule(rule, self.id2relation) + "\n"

        filename = "{0}_r{1}_n{2}_{3}_s{4}_rules.txt".format(
            dt, rule_lengths, num_walks, transition_distr, seed
        )
        filename = filename.replace(" ", "")
        with open(self.output_dir + filename, "w", encoding="utf-8") as fout:
            fout.write(rules_str)
def verbalize_rule(rule, id2relation):
    """
    Verbalize the rule to be in a human-readable format.

    Parameters:
        rule (dict): rule from Rule_Learner.create_rule
        id2relation (dict): mapping of index to relation

    Returns:
        rule_str (str): human-readable rule
    """
    if rule["var_constraints"]:
        # Work on a deep copy: the previous implementation aliased
        # rule["var_constraints"] and appended singleton groups to it,
        # mutating the rule dict as a side effect of verbalization.
        var_constraints = [list(group) for group in rule["var_constraints"]]
        constraints = [x for sublist in var_constraints for x in sublist]
        # Add a singleton group for every entity position without a constraint.
        for i in range(len(rule["body_rels"]) + 1):
            if i not in constraints:
                var_constraints.append([i])
        var_constraints = sorted(var_constraints)
    else:
        var_constraints = [[x] for x in range(len(rule["body_rels"]) + 1)]

    rule_str = "{0:8.6f} {1:4} {2:4} {3}(X0,X{4},T{5}) <- "
    # The head object variable is the group containing the last entity position.
    obj_idx = [
        idx
        for idx in range(len(var_constraints))
        if len(rule["body_rels"]) in var_constraints[idx]
    ][0]
    rule_str = rule_str.format(
        rule["conf"],
        rule["rule_supp"],
        rule["body_supp"],
        id2relation[rule["head_rel"]],
        obj_idx,
        len(rule["body_rels"]),
    )

    for i in range(len(rule["body_rels"])):
        # Subject/object variables of atom i are the groups containing
        # positions i and i+1, respectively.
        sub_idx = [
            idx for idx in range(len(var_constraints)) if i in var_constraints[idx]
        ][0]
        obj_idx = [
            idx for idx in range(len(var_constraints)) if i + 1 in var_constraints[idx]
        ][0]
        rule_str += "{0}(X{1},X{2},T{3}), ".format(
            id2relation[rule["body_rels"][i]], sub_idx, obj_idx, i
        )

    return rule_str[:-2]  # Drop the trailing ", "
================================================
FILE: pyproject.toml
================================================
[tool.poetry]
name = "py-tgb"
version = "2.2.0"
description = "Temporal Graph Benchmark project repo"
authors = ["shenyang Huang ", "Julia Gastinger", "Farimah Poursafaei", "Emanuele Rossi ", "Jacob Danovitch "]
readme = "README.md"
packages = [{include = "tgb"}]
[tool.poetry.dependencies]
python = "^3.9"
torch-geometric = "^2.3.0"
tqdm = "^4.65.0"
numpy = "^2.0.2"
clint = "^0.5.1"
requests = "^2.28.2"
pandas = ">=2.2.3"
scikit-learn = "^1.2.2"
[tool.poetry.group.dev.dependencies]
mkdocs = "^1.4.3"
mkdocs-material = "^9.1.15"
mkdocstrings-python = "^1.1.2"
mkdocs-jupyter = "^0.24.1"
poetry = "^1.5.1"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
================================================
FILE: run.sh
================================================
#!/bin/bash
# SLURM submission script: runs the TGN example on the tgbn-genre node
# property prediction dataset with seed 5.  Alternative runs are kept
# below as commented-out commands.
# NOTE(review): --time=48:00:00 requests 48 hours, not "1 day" as the
# inline comment below claims.
#SBATCH --partition=long #unkillable #main #long
#SBATCH --output=tgnlog_genre_s5.txt #tgn_lastfmgenre_s5.txt
#SBATCH --error=tgnlog_genre_s5error.txt #tgn_lastfmgenre_s5_error.txt
#SBATCH --cpus-per-task=4 # Ask for 4 CPUs
#SBATCH --gres=gpu:rtx8000:1 # Ask for 1 titan xp
#SBATCH --mem=32G # Ask for 32 GB of RAM
#SBATCH --time=48:00:00 # The job will run for 1 day

# Cluster-specific environment setup (hard-coded Mila user home).
export HOME="/home/mila/h/huangshe"
module load python/3.9
source $HOME/tgbenv/bin/activate
pwd
CUDA_VISIBLE_DEVICES=0 python examples/nodeproppred/tgbn-genre/tgn.py --seed 5
# CUDA_VISIBLE_DEVICES=0 python examples/nodeproppred/lastfmgenre/dyrep.py --seed 5
# CUDA_VISIBLE_DEVICES=0 python examples/nodeproppred/un_trade/tgn.py -s 5
# CUDA_VISIBLE_DEVICES=0 python examples/linkproppred/amazonreview/tgn.py -s 1
================================================
FILE: scripts/env.sh
================================================
# Load the cluster's Python module and activate the TGB virtualenv
# (created by scripts/mila_install.sh).
module load python/3.9
source $HOME/tgbenv/bin/activate
================================================
FILE: scripts/mila.sh
================================================
# Request an interactive GPU allocation on the Mila SLURM cluster.
salloc --partition=unkillable --cpus-per-task=4 --gres=gpu:1 --mem=32G
================================================
FILE: scripts/mila_install.sh
================================================
# One-time setup on the Mila cluster: create the virtualenv and install
# TGB in editable mode together with its requirements.
module load python/3.9
python -m venv $HOME/tgbenv
source $HOME/tgbenv/bin/activate
pip3 install -r requirements.txt
pip3 install -e .
================================================
FILE: scripts/run.sh
================================================
#!/bin/bash
# SLURM submission script: runs the DyRep example on the un_trade node
# property prediction dataset with seed 5.  Alternative runs are kept
# below as commented-out commands.
# NOTE(review): --time=48:00:00 requests 48 hours, not "1 day" as the
# inline comment below claims.
#SBATCH --partition=long #unkillable #main #long
#SBATCH --output=dyrep_trade_s5.txt #tgn_lastfmgenre_s5.txt
#SBATCH --error=dyrep_trade_s5error.txt #tgn_lastfmgenre_s5_error.txt
#SBATCH --cpus-per-task=4 # Ask for 4 CPUs
#SBATCH --gres=gpu:rtx8000:1 # Ask for 1 titan xp
#SBATCH --mem=32G # Ask for 32 GB of RAM
#SBATCH --time=48:00:00 # The job will run for 1 day

# Cluster-specific environment setup (hard-coded Mila user home).
export HOME="/home/mila/h/huangshe"
module load python/3.9
source $HOME/tgbenv/bin/activate
pwd
CUDA_VISIBLE_DEVICES=0 python examples/nodeproppred/un_trade/dyrep.py --seed 5
# CUDA_VISIBLE_DEVICES=0 python examples/nodeproppred/lastfmgenre/dyrep.py --seed 5
# CUDA_VISIBLE_DEVICES=0 python examples/nodeproppred/un_trade/tgn.py -s 5
# CUDA_VISIBLE_DEVICES=0 python examples/linkproppred/amazonreview/tgn.py -s 1
# CUDA_VISIBLE_DEVICES=0 python examples/nodeproppred/lastfmgenre/tgn.py -s 5
================================================
FILE: setup.py
================================================
# Minimal setuptools shim; the canonical project metadata lives in
# pyproject.toml (both declare version 2.2.0 — keep them in sync).
from setuptools import setup, find_packages

setup(name="py-tgb", version="2.2.0", packages=find_packages())
================================================
FILE: tgb/datasets/ICEWS14/ent2word.py
================================================
# -*- coding: utf-8 -*-
# @Time : 2019/12/5 4:20 下午
# @Author : Lee_zix
# @Email : Lee_zix@163.com
# @File : ent2word.py.py
# @Software: PyCharm
import os
def load_index(input_path):
    """Read a tab-separated name/id file into forward and reverse mappings."""
    index, rev_index = {}, {}
    # IDs in relation2id.txt and entity2id.txt are listed in order.
    with open(input_path) as f:
        for line in f:
            name, idx = line.strip().split("\t")
            index[name] = idx
            rev_index[idx] = name
    return index, rev_index
# Build a word-level vocabulary from the ICEWS entity names.  Entities of
# the form "Name (Affiliation)" are split into two words; all other
# entities are kept as a single word.
entity2id, id2entity = load_index(os.path.join('entity2id.txt'))
relation2id, id2relation = load_index(os.path.join('relation2id.txt'))

count = 0   # number of entities of the form "w1 (w2)"
count1 = 0  # number of parenthesized words w2 that are not entities themselves
word_list = set()
for entity_str in entity2id.keys():
    if "(" in entity_str and ")" in entity_str:
        count += 1
        begin = entity_str.find('(')
        end = entity_str.find(')')
        w1 = entity_str[:begin].strip()
        w2 = entity_str[begin+1: end]
        if w2 not in entity2id.keys():
            print(w2)
            count1 += 1
        word_list.add(w1)
        word_list.add(w2)
    else:
        word_list.add(entity_str)

num_word = len(word_list)
word2id = {word: id for id, word in enumerate(word_list)}
id2word = {id: word for id, word in enumerate(word_list)}
# print(word2id)
# print(id2word)
print("words num: {}, enity_num: {}".format(num_word, len(entity2id.keys())))
print(float(count)/len(entity2id.keys()))
print(float(count1)/float(count))

# Persist the word vocabulary.
with open("word2id.txt", "w") as f:
    for word in word2id.keys():
        f.write(word + "\t" + str(word2id[word])+'\n')

# Build the entity-word graph: each row links an entity id to a word id
# with an edge-type column encoding how they are related.
eid2wid = []
for id in range(len(id2entity.keys())):
    entity_str = id2entity[str(id)]
    if "(" in entity_str and ")" in entity_str:
        count += 1
        begin = entity_str.find('(')
        end = entity_str.find(')')
        w1 = entity_str[:begin].strip()
        w2 = entity_str[begin+1: end]
        eid2wid.append([str(entity2id[entity_str]), "0", str(word2id[w1])])  # "isA" relation
        eid2wid.append([str(entity2id[entity_str]), "1", str(word2id[w2])])  # affiliation relation
    else:
        eid2wid.append([str(entity2id[entity_str]), "2", str(word2id[entity_str])])

with open("e-w-graph.txt", "w") as f:
    for line in eid2wid:
        f.write("\t".join(line)+'\n')
================================================
FILE: tgb/datasets/ICEWS14/icews14.py
================================================
import csv
def load_index(input_path):
    """Read a tab-separated name/id file into forward and reverse mappings."""
    # IDs in relation/entity dictionary files are listed in order.
    forward, backward = {}, {}
    with open(input_path) as handle:
        for row in handle:
            name, identifier = row.strip().split("\t")
            forward[name] = identifier
            backward[identifier] = name
    return forward, backward
def load_tab_list(input_path):
    """Read a tab-separated (head, relation, tail, time) edgelist and
    reorder each row to [time, head, tail, relation]."""
    rows = []
    with open(input_path) as handle:
        for row in handle:
            head, relation, tail, t = row.strip().split("\t")
            rows.append([t, head, tail, relation])
    return rows
def write2csv(rows, output_path):
    """Write the reordered edge rows to a CSV file with the TGB header."""
    header = ["timestamp", "head", "tail", "relation_type"]
    with open(output_path, "w") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(rows)
def main():
    """Concatenate the train/valid/test edgelists into one CSV
    (changing tab separators to commas)."""
    all_rows = []
    # Order matters: train first, then validation, then test.
    for split_file in ("train.txt", "valid.txt", "test.txt"):
        all_rows += load_tab_list(split_file)
    write2csv(all_rows, "icews14.csv")


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/dataset_scripts/MAG/mag.py
================================================
import pandas as pd

if __name__ == "__main__":
    # Quick sanity check: print the first rows of the MAG node table.
    df = pd.read_parquet("nodes.parquet/nodes.parquet", engine="pyarrow")
    data_top = df.head()
    print(data_top)
================================================
FILE: tgb/datasets/dataset_scripts/MAG/old/plot_stats.py
================================================
import networkx as nx
import matplotlib.pyplot as plt
def load_csv(fname: str):
    """
    Plot the number of papers per year for the MAG dataset.

    Reads a comma-separated text file whose data rows are ``year,count``
    (the first line is a header) and saves the plot to ``paper_count.pdf``.

    Parameters:
        fname (str): path to the input file

    Returns:
        None
    """
    years = []
    cites = []
    # Stream the file with a context manager instead of materializing all
    # lines and closing manually.
    with open(fname, "r") as f:
        for i, line in enumerate(f):
            if i == 0:  # skip the header row
                continue
            fields = line.split(",")
            try:
                year = int(fields[0])
            except ValueError:  # was a bare except: only skip non-numeric rows
                continue
            num_citations = int(fields[1])
            years.append(year)
            cites.append(num_citations)
    plt.plot(years, cites, color="#e34a33")
    plt.xlabel("Year")
    plt.ylabel("Paper Count")
    plt.savefig("paper_count.pdf")
    plt.close()
if __name__ == "__main__":
load_csv("paper_year.txt")
================================================
FILE: tgb/datasets/dataset_scripts/dgraph.py
================================================
import dateutil.parser as dparser
import csv
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from os import listdir
from datetime import datetime
"""
# Description of DGraphFin datafile.
#! File **dgraphfin.npz** including below keys:
#* **x**: 17-dimensional node features.
#* **y**: node label.
There four classes. Below are the nodes counts of each class.
0: 1210092
1: 15509
2: 1620851
3: 854098
Nodes of Class 1 are fraud users and nodes of 0 are normal users, and they the two classes to be predicted.
Nodes of Class 2 and Class 3 are background users.
#* **edge_index**: shape (4300999, 2).
Each edge is in the form (id_a, id_b), where ids are the indices in x.
#* **edge_type**: 11 types of edges.
#* **edge_timestamp**: the desensitized timestamp of each edge.
#* **train_mask, valid_mask, test_mask**:
Nodes of Class 0 and Class 1 are randomly splitted by 70/15/15.
"""
def main():
    """Load the raw DGraphFin numpy archive and print basic shape checks."""
    #* load the raw data from numpy
    with np.load('dgraphfin.npz') as data:
        x = data['x']
        print ("shape of the node feature vectors are")
        print (x.shape)
        y = data['y']
        print ("shape of the node labels are")
        print (y.shape)
        edge_index = data['edge_index']
        print ("shape of the edge index are")
        print (edge_index.shape)
        edge_type = data['edge_type']
        print ("shape of the edge type are")
        print (edge_type.shape)
        edge_timestamp = data['edge_timestamp']
        print ("shape of the edge timestamp are")
        print (edge_timestamp.shape)
        print ("check if the timestamps are sorted")
        # True iff edge_timestamp is non-decreasing.
        print(np.all(edge_timestamp[:-1] <= edge_timestamp[1:]))


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/dataset_scripts/dgraph_Readme.md
================================================
# Description of DGraphFin datafile.
File **dgraphfin.npz** including below keys:
- **x**: 17-dimensional node features.
- **y**: node label.
There are four classes. Below are the node counts of each class.
0: 1210092
1: 15509
2: 1620851
3: 854098
Nodes of Class 1 are fraud users and nodes of Class 0 are normal users; these are the two classes to be predicted.
Nodes of Class 2 and Class 3 are background users.
- **edge_index**: shape (4300999, 2).
Each edge is in the form (id_a, id_b), where ids are the indices in x.
- **edge_type**: 11 types of edges.
- **edge_timestamp**: the desensitized timestamp of each edge.
- **train_mask, valid_mask, test_mask**:
Nodes of Class 0 and Class 1 are randomly split 70/15/15.
================================================
FILE: tgb/datasets/dataset_scripts/process_arxiv.py
================================================
import json
import networkx as nx
import numpy as np
import csv
from datetime import date
def load_full_json(fname):
    """Inspect the arxiv nodes file line by line.

    Deliberate stub: prints the first parsed JSON object and exits so the
    file format can be inspected before the real parser is written.
    """
    json_str = ""
    ctr = 0
    with open(fname, "r", encoding='utf-8') as f:
        #TODO need to determine how many lines form a json object
        for line in f:
            data = json.loads(line)
            print (data)
            quit() #remove this when you write the code
def main():
    # Entry point: inspect the raw arxiv node dump.
    fname = "nodes.json"
    load_full_json(fname)


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/dataset_scripts/process_github.py
================================================
import json
from datetime import datetime, timezone
# Maps each verbose GitHub event name (and edge direction) to the short
# relation label used in the final heterogeneous edgelist.
# NOTE(review): the decoding of the short codes is not documented here —
# presumably actor/action/object abbreviations; confirm against the
# thgl-github dataset documentation.
rels = {
    "IC_Created_IC_I": "IC_AO_C_I",
    "IC_Created_U_IC": "U_SO_C_IC",
    "I_Opened_U_I": "U_SE_O_I",
    "I_Opened_I_R": "I_AO_O_R",
    "I_Closed_U_I": "U_SE_C_I",
    "I_Closed_I_R": "I_AO_C_R",
    "I_Reopened_U_I": "U_SE_RO_I",
    "I_Reopened_I_R": "I_AO_RO_R",
    "PR_Opened_U_PR": "U_SO_O_P",
    "PR_Opened_PR_R": "P_AO_O_R",
    "PR_Closed_U_PR": "U_SO_C_P",
    "PR_Closed_PR_R": "P_AO_C_R",
    "PR_Reopened_U_PR": "U_SO_R_P",
    "PR_Reopened_PR_R": "P_AO_R_R",
    "PRRC_Created_U_PRC": "U_SO_C_PRC",
    "PRRC_Created_PRC_PR": "PRC_AO_C_P",
    "Forked_R_R": "R_FO_R",
    "AddMember_U_R": "U_CO_A_R",
}

# Node-id templates: every node id is namespaced by its node type so ids
# from different GitHub object types cannot collide.
issue_comment_format = "/issue_comment/{}"
issue_format = "/issue/{}"
user_format = "/user/{}"
repo_format = "/repo/{}"
pull_request_format = "/pr/{}"
pull_request_review_comment_format = "/pr_review_comment/{}"
def str_to_timestamp(time_str):
    """
    Convert a GitHub API timestamp (e.g. "2015-01-01T15:00:00Z") to Unix time.

    The trailing "Z" designates UTC. The naive datetime returned by strptime
    must be tagged with timezone.utc before calling .timestamp(); otherwise
    the conversion uses the machine's local timezone and the emitted edge
    timestamps differ depending on where the script is run.
    """
    dt = datetime.strptime(time_str, "%Y-%m-%dT%H:%M:%SZ").replace(
        tzinfo=timezone.utc
    )
    return int(dt.timestamp())
def parse_issue_comment_events(event):
    """Translate an IssueCommentEvent into (comment -> issue) and
    (user -> comment) edges; non-"created" actions yield no edges."""
    if event["payload"]["action"] != "created":
        return []
    payload = event["payload"]
    ts = str_to_timestamp(event["created_at"])
    comment_node = issue_comment_format.format(payload["comment"]["id"])
    issue_node = issue_format.format(payload["issue"]["id"])
    user_node = user_format.format(event["actor"]["id"])
    return [
        [comment_node, rels["IC_Created_IC_I"], issue_node, ts],
        [user_node, rels["IC_Created_U_IC"], comment_node, ts],
    ]
def parse_issue_event(event):
    """Translate an IssuesEvent into (user -> issue) and (issue -> repo)
    edges; unrecognized actions yield no edges."""
    action_map = {
        "opened": ("I_Opened_U_I", "I_Opened_I_R"),
        "closed": ("I_Closed_U_I", "I_Closed_I_R"),
        "reopened": ("I_Reopened_U_I", "I_Reopened_I_R"),
    }
    action = event["payload"]["action"]
    if action not in action_map:
        return []
    user_rel, repo_rel = action_map[action]
    ts = str_to_timestamp(event["created_at"])
    issue_node = issue_format.format(event["payload"]["issue"]["id"])
    user_node = user_format.format(event["actor"]["id"])
    repo_node = repo_format.format(event["repo"]["id"])
    return [
        [user_node, rels[user_rel], issue_node, ts],
        [issue_node, rels[repo_rel], repo_node, ts],
    ]
def parse_pull_request_event(event):
    """Translate a PullRequestEvent into (user -> PR) and (PR -> repo)
    edges; unrecognized actions yield no edges."""
    action_map = {
        "opened": ("PR_Opened_U_PR", "PR_Opened_PR_R"),
        "closed": ("PR_Closed_U_PR", "PR_Closed_PR_R"),
        "reopened": ("PR_Reopened_U_PR", "PR_Reopened_PR_R"),
    }
    action = event["payload"]["action"]
    if action not in action_map:
        return []
    user_rel, repo_rel = action_map[action]
    ts = str_to_timestamp(event["created_at"])
    pr_node = pull_request_format.format(event["payload"]["pull_request"]["id"])
    user_node = user_format.format(event["actor"]["id"])
    repo_node = repo_format.format(event["repo"]["id"])
    return [
        [user_node, rels[user_rel], pr_node, ts],
        [pr_node, rels[repo_rel], repo_node, ts],
    ]
def parse_pull_request_review_comment_event(event):
    """Turn a PullRequestReviewCommentEvent ("created" only) into
    user->review-comment and review-comment->PR quadruples."""
    payload = event["payload"]
    # ids are read unconditionally (matching the original behavior) so a
    # malformed payload raises regardless of the action value
    comment_node = pull_request_review_comment_format.format(payload["comment"]["id"])
    pr_node = pull_request_format.format(payload["pull_request"]["id"])
    user_node = user_format.format(event["actor"]["id"])
    ts = str_to_timestamp(event["created_at"])
    if payload["action"] != "created":
        return []
    return [
        [user_node, rels["PRRC_Created_U_PRC"], comment_node, ts],
        [comment_node, rels["PRRC_Created_PRC_PR"], pr_node, ts],
    ]
def parse_fork_event(event):
    """Turn a ForkEvent into one repo->repo quadruple (fork -> origin)."""
    fork_node = repo_format.format(event["payload"]["forkee"]["id"])
    origin_node = repo_format.format(event["repo"]["id"])
    ts = str_to_timestamp(event["created_at"])
    return [[fork_node, rels["Forked_R_R"], origin_node, ts]]
def parse_member_event(event):
    """Turn a MemberEvent into one user->repo quadruple (member added)."""
    user_node = user_format.format(event["payload"]["member"]["id"])
    repo_node = repo_format.format(event["repo"]["id"])
    ts = str_to_timestamp(event["created_at"])
    return [[user_node, rels["AddMember_U_R"], repo_node, ts]]
# Dispatch table: GitHub Archive event type -> parser returning a list of
# [src, rel, dst, ts] quadruples (possibly empty) for that event.
event_handler_dict = {
    "IssueCommentEvent": parse_issue_comment_events,
    "IssuesEvent": parse_issue_event,
    "PullRequestEvent": parse_pull_request_event,
    "PullRequestReviewCommentEvent": parse_pull_request_review_comment_event,
    "ForkEvent": parse_fork_event,
    "MemberEvent": parse_member_event,
}
def parse_event(event):
    """Dispatch a raw event dict to its type handler.

    Unknown event types produce an empty list of quadruples.
    """
    handler = event_handler_dict.get(event["type"])
    if handler is None:
        # print("Unknown event type: {}".format(event["type"]))
        return []
    return handler(event)
def parse_file(filename):
    """Parse one GitHub Archive JSON-lines file into a flat quadruple list.

    Parameters:
        filename (str): path to a .json file with one event object per line
    Returns:
        list: flattened [src, rel, dst, ts] quadruples from all known events
    """
    events = []
    with open(filename) as f:
        for line in f:
            # extend instead of the previous append-then-flatten pass;
            # each line is an independent JSON event object
            events.extend(parse_event(json.loads(line)))
    print("Parsed {} events".format(len(events)))
    return events
if __name__ == "__main__":
    # guard so importing this module does not trigger a full parse,
    # consistent with the other dataset scripts in this repository
    filename = "2015-01-01-15.json"
    parse_file(filename)
================================================
FILE: tgb/datasets/dataset_scripts/tgbl-coin.py
================================================
import csv
"""
#! analyze statistics from the dataset
#* 1). # of unique nodes, 2). # of edges. 3). # of unique edges, 4). # of timestamps 5). min & max of edge weights, 6). recurrence of nodes
"""
def analyze_csv(fname):
    """Print high-level statistics for a t,u,v,w edgelist CSV.

    Reports edge/node/unique-edge/timestamp counts, the min/max edge
    weight, and how many nodes touch at least 10/100/1000 edges.
    """
    node_dict = {}
    edge_dict = {}
    num_edges = 0
    num_time = 0
    prev_t = "none"
    min_w = 100000
    max_w = 0
    with open(fname, "r") as csv_file:
        reader = csv.reader(csv_file, delimiter=",")
        next(reader, None)  # header row
        for row in reader:
            # t,u,v,w
            t, u, v = row[0], row[1], row[2]
            w = float(row[3].strip())
            # track min & max edge weights
            max_w = max(max_w, w)
            min_w = min(min_w, w)
            # count distinct consecutive timestamps
            if t != prev_t:
                num_time += 1
                prev_t = t
            # endpoint occurrence counts
            node_dict[u] = node_dict.get(u, 0) + 1
            node_dict[v] = node_dict.get(v, 0) + 1
            # edge multiplicity
            num_edges += 1
            edge_dict[(u, v)] = edge_dict.get((u, v), 0) + 1
    print("----------------------high level statistics-------------------------")
    print("number of total edges are ", num_edges)
    print("number of nodes are ", len(node_dict))
    print("number of unique edges are ", len(edge_dict))
    print("number of unique timestamps are ", num_time)
    print("maximum edge weight is ", max_w)
    print("minimum edge weight is ", min_w)
    num_10 = sum(1 for cnt in node_dict.values() if cnt >= 10)
    num_100 = sum(1 for cnt in node_dict.values() if cnt >= 100)
    num_1000 = sum(1 for cnt in node_dict.values() if cnt >= 1000)
    print("number of nodes with # edges >= 10 is ", num_10)
    print("number of nodes with # edges >= 100 is ", num_100)
    print("number of nodes with # edges >= 1000 is ", num_1000)
    print("----------------------high level statistics-------------------------")
"""
return a node dict keeping only nodes with at least `freq` edges (default 10)
"""
def extract_node_dict(fname, freq=10):
    """Count node occurrences in a t,u,v,w edgelist and keep frequent nodes.

    Parameters:
        fname (str): path to the CSV file (first row is a header)
        freq (int): minimum occurrence count for a node to be kept
    Returns:
        dict: node id -> occurrence count, for nodes with count >= freq
    """
    node_dict = {}
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        next(csv_reader, None)  # skip header
        for row in csv_reader:
            # t,u,v,w -- only the endpoints matter for counting; the old
            # code also parsed t and float(w) per row and never used them
            u, v = row[1], row[2]
            node_dict[u] = node_dict.get(u, 0) + 1
            node_dict[v] = node_dict.get(v, 0) + 1
    return {node: cnt for node, cnt in node_dict.items() if cnt >= freq}
"""
remove edges whose src or dst is not in the node dict
"""
def clean_edgelist(fname, outname, node_dict):
    """Copy the t,u,v,w edgelist to outname, keeping only edges whose both
    endpoints are keys of node_dict."""
    with open(outname, "w") as outf:
        writer = csv.writer(outf)
        writer.writerow(["time", "src", "dst", "weight"])
        with open(fname, "r") as csv_file:
            reader = csv.reader(csv_file, delimiter=",")
            next(reader, None)  # skip header
            for row in reader:
                # t,u,v,w
                t, u, v = row[0], row[1], row[2]
                w = float(row[3].strip())
                if u in node_dict and v in node_dict:
                    writer.writerow([t, u, v, w])
def sort_edgelist(in_file, outname):
    """
    Sort the edges by timestamp (stable: input order kept within a timestamp).

    Parameters:
        in_file (str): input CSV with integer timestamps in column 0
        outname (str): path of the sorted output CSV
    """
    row_dict = {}  # {ts: {line_idx: row}}
    line_idx = 0
    with open(outname, "w") as outf:
        write = csv.writer(outf)
        # header must match the t,u,v,w rows this coin edgelist contains;
        # the previous "day,src,dst,callsign,typecode" header was copied
        # from the flight script and did not match the data
        fields = ["time", "src", "dst", "weight"]
        write.writerow(fields)
        with open(in_file, "r") as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=",")
            for row in csv_reader:
                if line_idx == 0:  # header
                    line_idx += 1
                    continue
                ts = int(row[0])
                # setdefault replaces the old duplicated if/else branches
                row_dict.setdefault(ts, {})[line_idx] = row
                line_idx += 1
        for ts in sorted(row_dict.keys()):
            for idx in row_dict[ts]:
                write.writerow(row_dict[ts][idx])
def main():
    """Driver for the tgbl-coin preprocessing steps.

    The subgraph-extraction stage is kept below (commented out) for
    reproducibility; only the time-sort stage is currently active.
    """
    # --- keeping subgraph of most active nodes ---
    # freq = 10
    # fname = "stablecoin_edgelist.csv"
    # node_dict = extract_node_dict(fname, freq=freq)
    # outname = "stablecoin_freq10.csv"
    # clean_edgelist(fname, outname, node_dict)
    # fname = "stablecoin_freq10.csv"
    # analyze_csv(fname)

    # --- sort edgelist by time ---
    sort_edgelist("tgbl-coin_edgelist.csv", "tgbl-coin_edgelist_sorted.csv")


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/dataset_scripts/tgbl-coin_neg_generator.py
================================================
import time
from tgb.linkproppred.negative_generator import NegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate negative edges for the validation or test phase
    """
    print("*** Negative Sample Generation ***")

    # configuration
    num_neg_e_per_pos = 20  #100
    neg_sample_strategy = "hist_rnd"  #"rnd"
    rnd_seed = 42
    name = "tgbl-coin"

    dataset = PyGLinkPropPredDataset(name=name, root="datasets")
    data = dataset.get_TemporalData()
    data_splits = {
        "train": data[dataset.train_mask],
        "val": data[dataset.val_mask],
        "test": data[dataset.test_mask],
    }

    # Ensure to only sample actual destination nodes as negatives.
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

    # the historical strategy additionally samples from training edges
    historical_data = data_splits["train"] if neg_sample_strategy == "hist_rnd" else None

    neg_sampler = NegativeEdgeGenerator(
        dataset_name=name,
        first_dst_id=min_dst_idx,
        last_dst_id=max_dst_idx,
        num_neg_e=num_neg_e_per_pos,
        strategy=neg_sample_strategy,
        rnd_seed=rnd_seed,
        historical_data=historical_data,
    )

    # generate the evaluation sets for both phases
    partial_path = "."
    for split_mode in ["val", "test"]:
        start_time = time.time()
        print(
            f"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}"
        )
        neg_sampler.generate_negative_samples(
            data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}"
        )


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/dataset_scripts/tgbl-comment.py
================================================
import csv
from tqdm import tqdm
from os import listdir
from tgb.utils.stats import analyze_csv
def find_filenames(path_to_dir):
    r"""
    List every entry (files and subdirectories) in a folder.

    Parameters:
        path_to_dir (str): path to the directory
    """
    return listdir(path_to_dir)
def read_edgelist(fname, outfname, write_header=False):
    """
    Convert a space separated edgelist into CSV rows appended to outfname.

    Input lines are "u v t edge_id" (comment author, parent author,
    creation time, edge id); output rows are ts,src,dst,edge_id.
    Lines with fewer than 4 fields are skipped.

    Parameters:
        fname (str): path to the edgelist
        outfname (str): path to the output file
        write_header (bool): emit the CSV header before the rows
    """
    with open(fname, "r") as infile:
        raw_lines = infile.readlines()
    with open(outfname, "a") as outf:
        writer = csv.writer(outf)
        if write_header:
            writer.writerow(["ts", "src", "dst", "edge_id"])
        for raw in raw_lines:
            parts = raw.split()
            if len(parts) < 4:
                continue  # malformed / short line
            # input order is u v t id; output order is ts src dst id
            writer.writerow([parts[2], parts[0], parts[1], parts[3]])
def read_nodeattr(fname, outfname, write_header=False):
    """
    Convert a space separated attribute file into CSV rows appended to
    outfname with columns edge_id,subreddit,num_characters,num_words,
    score,edited_flag.

    Input columns (space separated): edge_id, reddit comment id,
    reddit parent id, submission id, subreddit, num_characters,
    num_words, score, edited_flag.

    Parameters:
        fname (str): path to the attribute file
        outfname (str): path to the output file
        write_header (bool): emit the CSV header before the rows
    """
    with open(fname, "r") as infile:
        lines = infile.readlines()
    with open(outfname, "a") as outf:
        write = csv.writer(outf)
        if write_header:
            write.writerow(
                [
                    "edge_id",
                    "subreddit",
                    "num_characters",
                    "num_words",
                    "score",
                    "edited_flag",
                ]
            )
        for line in lines:
            parts = line.split()
            if len(parts) < 4:
                continue
            # NOTE: the old code called .strip("/n") on the flag, which
            # strips the characters '/' and 'n' (not a newline) and could
            # mangle values such as "nan"; split() already removes
            # surrounding whitespace, so no strip is needed.
            write.writerow(
                [parts[0], parts[4], parts[5], parts[6], parts[7], parts[8]]
            )
def combine_edgelist_edgefeat(edgefname, featfname, outname):
    """
    Merge the edgelist CSV (ts,src,dst,edge_id) with the edge feature CSV
    (edge_id,subreddit,num_characters,num_words,score,edited_flag) into
    rows of ts,src,dst,num_words,score.

    The two inputs are walked in lockstep and must list the same edge_ids
    in the same order; a mismatch prints a diagnostic and aborts. Rows with
    ts/src/dst encoded as -1 (missing) are dropped.

    #! subreddit is intentionally removed from the output features
    """
    subreddit_ids = {}
    missing_ts = 0
    missing_src = 0
    missing_dst = 0
    line_idx = 0
    # use context managers: the old code leaked an open() used only for an
    # unused total_lines count and never closed the two input handles
    with open(outname, "w") as outf, open(edgefname, "r") as edgelist, open(featfname, "r") as edgefeat:
        write = csv.writer(outf)
        # header now matches the 5 values actually written per row (the old
        # header still listed "subreddit" although rows no longer carry it)
        fields = ["ts", "src", "dst", "num_words", "score"]
        write.writerow(fields)
        sub_id = 0
        edgelist.readline()  # skip headers
        edgefeat.readline()
        while True:
            #'ts', 'src', 'dst', 'edge_id'
            edge_line = edgelist.readline().split(",")
            if len(edge_line) < 4:
                break  # end of file (or malformed tail)
            edge_id = int(edge_line[3])
            ts = int(edge_line[0])
            src = int(edge_line[1])
            dst = int(edge_line[2])
            #'edge_id', 'subreddit', 'num_characters', 'num_words', 'score', 'edited_flag'
            feat_line = edgefeat.readline().split(",")
            edge_id_feat = int(feat_line[0])
            subreddit = feat_line[1]
            if subreddit not in subreddit_ids:
                subreddit_ids[subreddit] = sub_id
                sub_id += 1
            subreddit = subreddit_ids[subreddit]
            num_words = int(feat_line[3])
            score = int(feat_line[4])
            #! check if ts, src, dst is -1 (missing); the feature line has
            #! already been consumed, so the lockstep stays aligned
            if ts == -1:
                missing_ts += 1
                continue
            if src == -1:
                missing_src += 1
                continue
            if dst == -1:
                missing_dst += 1
                continue
            if edge_id != edge_id_feat:
                print("edge_id != edge_id_feat")
                print(edge_id)
                print(edge_id_feat)
                break
            write.writerow([ts, src, dst, num_words, score])
            line_idx += 1
    print("processed", line_idx, "lines")
    # print ("there are lines", missing_ts, " missing timestamps")
    # print ("there are lines", missing_src, " missing src")
    # print ("there are lines", missing_dst, " missing dst")
def main():
    """Driver for tgbl-comment preprocessing: extract node attributes.

    Other pipeline stages (edgelist extraction, merging with features,
    analysis) are referenced below as commented-out calls.
    """
    # #! unzip all xz files by $ unxz *.xz
    # stage 1 (edgelists): read_edgelist over raw/raw_2008_2010/* into
    #   redditcomments_edgelist_2008_2010.csv (header on the first file only)
    # stage 3: combine_edgelist_edgefeat(edgelist_csv, edgefeat_csv, out_csv)
    #   to merge edgelist and edge feat, checking that edge_ids match
    # stage 4: analyze_csv(out_csv)

    # #! extract the node attributes
    f_dir = "raw/node_2008_2010/"  # "raw/node_2005_2010/"
    fnames = find_filenames(f_dir)
    outname = "redditcomments_edgefeat_2008_2010.csv"
    for idx, fname in enumerate(tqdm(fnames)):
        # write the CSV header only for the very first file
        read_nodeattr(f_dir + fname, outname, write_header=(idx == 0))


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/dataset_scripts/tgbl-comment_neg_generator.py
================================================
import time
from tgb.linkproppred.negative_generator import NegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate negative edges for the validation or test phase
    """
    print("*** Negative Sample Generation ***")

    # configuration
    num_neg_e_per_pos = 20  #100
    neg_sample_strategy = "hist_rnd"  #"rnd"
    rnd_seed = 42
    name = "tgbl-comment"

    dataset = PyGLinkPropPredDataset(name=name, root="datasets")
    data = dataset.get_TemporalData()
    data_splits = {
        "train": data[dataset.train_mask],
        "val": data[dataset.val_mask],
        "test": data[dataset.test_mask],
    }

    # Ensure to only sample actual destination nodes as negatives.
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

    # the historical strategy additionally samples from training edges
    historical_data = data_splits["train"] if neg_sample_strategy == "hist_rnd" else None

    neg_sampler = NegativeEdgeGenerator(
        dataset_name=name,
        first_dst_id=min_dst_idx,
        last_dst_id=max_dst_idx,
        num_neg_e=num_neg_e_per_pos,
        strategy=neg_sample_strategy,
        rnd_seed=rnd_seed,
        historical_data=historical_data,
    )

    # generate the evaluation sets for both phases
    partial_path = "."
    for split_mode in ["val", "test"]:
        start_time = time.time()
        print(
            f"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}"
        )
        neg_sampler.generate_negative_samples(
            data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}"
        )


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/dataset_scripts/tgbl-flight.py
================================================
import dateutil.parser as dparser
import csv
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from os import listdir
from datetime import datetime
def find_csv_filenames(path_to_dir, suffix=".csv"):
    r"""
    List the files in a directory whose names end with `suffix`.

    Parameters:
        path_to_dir (str): path to the directory
        suffix (str): required filename ending (default ".csv")
    """
    return [name for name in listdir(path_to_dir) if name.endswith(suffix)]
def flight2edgelist(
    fname,
    outname,
    node_dict=None,
):
    """
    Convert one raw OpenSky flight CSV into an edgelist CSV with columns
    day, src, dst, callsign, number, icao24, registration, typecode.

    Rows with an empty origin or destination are skipped; when node_dict
    is given, rows whose origin or destination is not one of its keys are
    skipped too. Returns (line_count, skip_lines, miss_node_lines), where
    line_count counts the input header plus every written row.
    """
    miss_node_lines = 0
    skip_lines = 0
    line_count = 0
    print("processing ", outname)
    with open(outname, "w") as outf:
        writer = csv.writer(outf)
        writer.writerow(
            [
                "day",
                "src",
                "dst",
                "callsign",
                "number",
                "icao24",
                "registration",
                "typecode",
            ]
        )
        with open(fname, "r") as csv_file:
            reader = csv.reader(csv_file, delimiter=",")
            # raw columns: callsign,number,icao24,registration,typecode,
            # origin,destination,firstseen,lastseen,day,lat1,lon1,alt1,...
            for row in reader:
                if line_count == 0:
                    line_count += 1  # input header row
                    continue
                src, dst = row[5], row[6]
                if src == "" or dst == "":
                    skip_lines += 1
                    continue
                if node_dict is not None and (src not in node_dict or dst not in node_dict):
                    miss_node_lines += 1
                    continue
                day = row[9][0:10]  # keep only the YYYY-MM-DD prefix
                writer.writerow([day, src, dst, row[0], row[1], row[2], row[3], row[4]])
                line_count += 1
    print(f"Processed {line_count} lines.")
    print(f"Skipped {skip_lines} lines.")
    print(f"missing node {miss_node_lines} lines.")
    return line_count, skip_lines, miss_node_lines
def load_icao_airports(fname="airport_codes.csv"):
    """
    Build {icao: continent} and {icao: country} lookups from the airport
    codes CSV (continent at column index 4, country at index 5).

    Uses csv.reader instead of the previous naive str.split(","), so that
    quoted fields containing commas (e.g. airport names) do not shift the
    column indices. Like the old code, the header row is not skipped.
    """
    airports_continent = {}
    airports_country = {}
    with open(fname, "r") as csv_file:
        for values in csv.reader(csv_file, delimiter=","):
            icao = values[0]
            airports_continent[icao] = values[4]
            airports_country[icao] = values[5]
    return airports_continent, airports_country
def merge_edgelist(input_names: list, in_dir: str, outname: str):
    """
    merge a list of edgefiles into one file

    Parameters:
        input_names: list of CSV filenames (relative to in_dir) to merge
        in_dir: directory containing the input CSV files
        outname: path of the merged output CSV
    """
    line_count = 0
    total = 0  # total data rows written across all files
    with open(outname, "w") as outf:
        write = csv.writer(outf)
        fields = ["day", "src", "dst", "callsign", "typecode"]
        write.writerow(fields)
        for csv_name in tqdm(input_names):
            in_name = in_dir + csv_name
            line_count = 0  # reset so each file's header row is skipped
            with open(in_name, "r") as csv_file:
                csv_reader = csv.reader(csv_file, delimiter=",")
                for row in csv_reader:
                    if line_count == 0:  # header
                        line_count += 1
                    else:
                        # Day, src, dst, callsign, number, icao24, registration, typecode
                        day = row[0]
                        src = row[1]
                        dst = row[2]
                        callsign = row[3]
                        typecode = row[-1]
                        out = [day, src, dst, callsign, typecode]
                        write.writerow(out)
                        total += 1
def clean_node_feat(in_file, outname):
    """
    Reduce the raw airport-codes CSV to the node-feature columns:
    airport_code, type, continent, iso_region, longitude, latitude.
    The last input column holds "lon, lat" and is split into two floats.
    """
    with open(outname, "w") as outf:
        writer = csv.writer(outf)
        writer.writerow(
            [
                "airport_code",
                "type",
                "continent",
                "iso_region",
                "longitude",
                "latitude",
            ]
        )
        with open(in_file, "r") as csv_file:
            reader = csv.reader(csv_file, delimiter=",")
            first = True
            for row in reader:
                if first:
                    first = False  # skip input header
                    continue
                # ident,type,name,elevation_ft,continent,iso_country,
                # iso_region,municipality,gps_code,iata_code,local_code,coordinates
                coords = row[-1].split(",")
                writer.writerow(
                    [row[0], row[1], row[4], row[6], float(coords[0]), float(coords[1])]
                )
def sort_edgelist(in_file, outname):
    """
    Sort the edges by day (stable: input order is kept within a day).

    Parameters:
        in_file (str): input CSV with header day,src,dst,callsign,typecode
        outname (str): path of the sorted output CSV
    """
    TIME_FORMAT = "%Y-%m-%d"
    row_dict = {}  # {day timestamp: {line_idx: row}}
    line_idx = 0
    with open(outname, "w") as outf:
        write = csv.writer(outf)
        fields = ["day", "src", "dst", "callsign", "typecode"]
        write.writerow(fields)
        with open(in_file, "r") as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=",")
            for row in csv_reader:
                if line_idx == 0:  # header
                    line_idx += 1
                    continue
                ts = datetime.strptime(row[0], TIME_FORMAT).timestamp()
                # setdefault replaces the previous duplicated if/else
                # branches, which executed identical code in both arms
                row_dict.setdefault(ts, {})[line_idx] = row
                line_idx += 1
        for ts in sorted(row_dict.keys()):
            for idx in row_dict[ts]:
                write.writerow(row_dict[ts][idx])
def date2ts(date_str: str) -> float:
    r"""
    Convert a "%Y-%m-%d-%z" date string (e.g. "2021-11-29-+0000") to a
    POSIX timestamp; the explicit offset makes the result timezone-safe.
    """
    parsed = datetime.strptime(date_str, "%Y-%m-%d-%z")
    return float(parsed.timestamp())
def main():
    """
    instructions for recompiling the dataset from
    https://zenodo.org/record/7323875#.ZD1-43ZKguX
    1. download all datasets into a folder specified by in_dir (such as full_dataset)
    2. run the commented-out stages below in order
    """
    # stage 1 (extraction): flight2edgelist over full_dataset/*.csv with
    #   node_dict from load_icao_airports("airport_codes.csv"), writing
    #   per-month files into edgelists/ and accumulating the
    #   processed/skipped/missing-node line counts
    # stage 2 (merge): merge_edgelist(csv_names, "edgelists/",
    #   "opensky_edgelist.csv")
    # stage 3 (node features): clean_node_feat("edgelists/airport_codes.csv",
    #   "airport_node_feat.csv")
    # stage 4 (sort): sort_edgelist("tgbl-flight_edgelist.csv",
    #   "tgbl-flight_edgelist_sorted.csv")

    # demonstrate how date2ts handles differing strptime time-zone offsets
    for tz_offset in ("-0500", "+0000"):
        print(date2ts("2021-11-29" + "-" + tz_offset))


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/dataset_scripts/tgbl-flight_neg_generator.py
================================================
import time
from tgb.linkproppred.negative_generator import NegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate negative edges for the validation or test phase
    """
    print("*** Negative Sample Generation ***")

    # configuration
    num_neg_e_per_pos = 20
    neg_sample_strategy = "hist_rnd"  #"rnd"
    rnd_seed = 42
    name = "tgbl-flight"

    dataset = PyGLinkPropPredDataset(name=name, root="datasets")
    data = dataset.get_TemporalData()
    data_splits = {
        "train": data[dataset.train_mask],
        "val": data[dataset.val_mask],
        "test": data[dataset.test_mask],
    }

    # Ensure to only sample actual destination nodes as negatives.
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

    # the historical strategy additionally samples from training edges
    historical_data = data_splits["train"] if neg_sample_strategy == "hist_rnd" else None

    neg_sampler = NegativeEdgeGenerator(
        dataset_name=name,
        first_dst_id=min_dst_idx,
        last_dst_id=max_dst_idx,
        num_neg_e=num_neg_e_per_pos,
        strategy=neg_sample_strategy,
        rnd_seed=rnd_seed,
        historical_data=historical_data,
    )

    # generate the evaluation sets for both phases
    partial_path = "."
    for split_mode in ["val", "test"]:
        start_time = time.time()
        print(
            f"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}"
        )
        neg_sampler.generate_negative_samples(
            data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}"
        )


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/dataset_scripts/tgbl-review.py
================================================
import pyarrow.dataset as ds
import csv
import numpy as np
from tgb.utils.stats import analyze_csv
import pandas as pd
from tqdm import tqdm
def collect_csv(dir_name="software"):
    """Read a directory of CSV shards via pyarrow and dump one combined
    CSV named "<dir_name>.csv" (with the index column kept)."""
    table = ds.dataset(dir_name, format="csv").to_table()
    table.to_pandas().to_csv(dir_name + ".csv", index=True)
def reorder_column(fname: str, outname: str):
    """Rewrite edgeid,SourceId,TargetId,Weight,Timestamp rows as
    ts,source,target,weight (dropping the edge id)."""
    with open(outname, "w") as outf:
        writer = csv.writer(outf)
        writer.writerow(["ts", "source", "target", "weight"])
        with open(fname, "r") as csv_file:
            reader = csv.reader(csv_file, delimiter=",")
            seen_header = False
            for row in reader:
                if not seen_header:
                    seen_header = True  # skip input header
                    continue
                # edgeid, SourceId, TargetId, Weight, Timestamp
                writer.writerow([row[4], row[1], row[2], row[3]])
def sort_edgelist(fname: str, outname: str):
    """Sort a ts,source,target,weight CSV by timestamp and rewrite it."""
    with open(outname, "w") as outf:
        writer = csv.writer(outf)
        writer.writerow(["ts", "source", "target", "weight"])
        rows = []
        ts_values = []
        with open(fname, "r") as csv_file:
            reader = csv.reader(csv_file, delimiter=",")
            seen_header = False
            for row in reader:
                if not seen_header:
                    seen_header = True  # skip input header
                    continue
                ts = int(row[0])
                ts_values.append(ts)
                rows.append([ts, row[1], row[2], row[3]])
        # np.argsort keeps the same ordering semantics (including its
        # tie-handling) as the original implementation
        for i in np.argsort(np.array(ts_values)).tolist():
            writer.writerow(rows[i])
def count_degree(fname: str):
    """
    Count how many edge endpoints touch each node in a
    ts,source,target,weight CSV.

    Parameters:
        fname: path to the edgelist CSV (first row is a header)
    Returns:
        dict: node id (str) -> number of edge endpoints involving it
    """
    node_counts = {}
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        next(csv_reader, None)  # skip header
        for row in csv_reader:
            # only the endpoints matter; the old code also parsed int(ts)
            # and the weight per row without using them
            src, dst = row[1], row[2]
            node_counts[src] = node_counts.get(src, 0) + 1
            node_counts[dst] = node_counts.get(dst, 0) + 1
    return node_counts
def reduce_edgelist(fname: str, outname: str, node10_id: dict):
    """Copy the ts,source,target,weight edgelist, keeping only edges whose
    both endpoints are keys of node10_id."""
    with open(outname, "w") as outf:
        writer = csv.writer(outf)
        writer.writerow(["ts", "source", "target", "weight"])
        with open(fname, "r") as csv_file:
            reader = csv.reader(csv_file, delimiter=",")
            seen_header = False
            for row in reader:
                if not seen_header:
                    seen_header = True  # skip input header
                    continue
                ts = int(row[0])
                src, dst, w = row[1], row[2], row[3]
                if src in node10_id and dst in node10_id:
                    writer.writerow([ts, src, dst, w])
"""
helper function for re-indexing the review dataset
"""
def csv_process_review(
    fname: str,
    outname: str = "review.csv",
) -> None:
    r"""
    Re-index the review edgelist so that source and destination node ids
    occupy disjoint integer ranges, writing the result to outname.

    Helper for preprocessing; not used in actual dataloading.
    Input .csv file format: ts,source,target,weight, e.g.
        929232000,137139,30122,5.0

    Parameters:
        fname: the path to the raw data
        outname: the path of the re-indexed output csv
    Returns:
        None -- the previous docstring promised (df, feat_l, node_ids),
        but nothing was ever returned; the output goes to disk instead.
    """
    # pass 1: assign consecutive ids to sources and destinations separately
    src_ids = {}
    dst_ids = {}
    src_ctr = 0
    dst_ctr = 0
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        for row in tqdm(csv_reader):
            if idx == 0:
                idx += 1
                continue
            # only the endpoint columns are needed here; the old pass also
            # parsed int(ts) and float(w) without using them
            src = row[1]
            dst = row[2]
            if src not in src_ids:
                src_ids[src] = src_ctr
                src_ctr += 1
            if dst not in dst_ids:
                dst_ids[dst] = dst_ctr
                dst_ctr += 1
    #! offset destination ids past the source range so that source and
    #! destination node ids are unique and non-overlapping
    src_ctr += 1
    dst_ids = {k: v + src_ctr for k, v in dst_ids.items()}
    # pass 2: rewrite the edgelist with the new ids
    with open(outname, "w") as outf:
        write = csv.writer(outf)
        write.writerow(["ts", "source", "target", "weight"])
        with open(fname, "r") as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=",")
            idx = 0
            for row in tqdm(csv_reader):
                if idx == 0:
                    idx += 1
                    continue
                ts = int(row[0])
                src = src_ids[row[1]]
                dst = dst_ids[row[2]]
                w = float(row[3])
                write.writerow([ts, src, dst, w])
def main():
    """Driver for tgbl-review preprocessing: re-index node ids.

    Earlier stages are kept below as commented references.
    """
    # stage 1: collect_csv(dir_name="books")  (also "software"/"electronics")
    # stage 2: reorder_column("electronics.csv", "amazonreview_edgelist.csv")
    # stage 3: sort_edgelist("amazonreview_edgelist.csv",
    #   "amazonreview_edgelist_sort.csv")
    # stage 4: count_degree + reduce_edgelist to keep nodes with > 10 edges
    #   -> "amazonreview_edgelist_reduce.csv"
    # stage 5: analyze_csv("amazonreview_edgelist_reduce.csv")

    csv_process_review("tgbl-review_edgelist_v2.csv", "review.csv")


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/dataset_scripts/tgbl-review_neg_generator.py
================================================
import time
from tgb.linkproppred.negative_generator import NegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate negative edges for the validation or test phase
    """
    print("*** Negative Sample Generation ***")

    # configuration
    num_neg_e_per_pos = 100  #20 #100
    neg_sample_strategy = "hist_rnd"  #"rnd"
    rnd_seed = 42
    name = "tgbl-review"

    dataset = PyGLinkPropPredDataset(name=name, root="datasets")
    data = dataset.get_TemporalData()
    data_splits = {
        "train": data[dataset.train_mask],
        "val": data[dataset.val_mask],
        "test": data[dataset.test_mask],
    }

    min_src_idx = int(data.src.min())
    print(f"min_src_idx: {min_src_idx}")

    # Ensure to only sample actual destination nodes as negatives.
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

    # the historical strategy additionally samples from training edges
    historical_data = data_splits["train"] if neg_sample_strategy == "hist_rnd" else None

    neg_sampler = NegativeEdgeGenerator(
        dataset_name=name,
        first_dst_id=min_dst_idx,
        last_dst_id=max_dst_idx,
        num_neg_e=num_neg_e_per_pos,
        strategy=neg_sample_strategy,
        rnd_seed=rnd_seed,
        historical_data=historical_data,
    )

    # generate the evaluation sets for both phases
    partial_path = "."
    for split_mode in ["val", "test"]:
        start_time = time.time()
        print(
            f"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}"
        )
        neg_sampler.generate_negative_samples(
            data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}"
        )


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/dataset_scripts/tgbl-wiki_neg_generator.py
================================================
import timeit
from tgb.linkproppred.negative_generator import NegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate negative edge sets for the validation and test splits of the
    `tgbl-wiki` dataset and save them under `partial_path`.
    """
    print("*** Negative Sample Generation ***")

    # setting the required parameters
    num_neg_e_per_pos = 12000  # 11000 #10000 #20 #100
    neg_sample_strategy = "hist_rnd"  # "rnd"
    rnd_seed = 42

    name = "tgbl-wiki"
    dataset = PyGLinkPropPredDataset(name=name, root="datasets")
    train_mask = dataset.train_mask
    val_mask = dataset.val_mask
    test_mask = dataset.test_mask
    data = dataset.get_TemporalData()
    data_splits = {}
    data_splits['train'] = data[train_mask]
    data_splits['val'] = data[val_mask]
    data_splits['test'] = data[test_mask]

    # Ensure to only sample actual destination nodes as negatives.
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

    # "hist_rnd" needs the training edges to sample historical negatives;
    # plain "rnd" does not.
    if neg_sample_strategy == "hist_rnd":
        historical_data = data_splits["train"]
    else:
        historical_data = None

    neg_sampler = NegativeEdgeGenerator(
        dataset_name=name,
        first_dst_id=min_dst_idx,
        last_dst_id=max_dst_idx,
        num_neg_e=num_neg_e_per_pos,
        strategy=neg_sample_strategy,
        rnd_seed=rnd_seed,
        historical_data=historical_data,
    )

    # directory where the generated negative-sample files are written
    partial_path = "./"

    # generate validation negative edge set
    # BUG FIX: this module only imports `timeit`, so the previous
    # `time.time()` calls raised NameError; use timeit.default_timer(),
    # consistent with the test-split timing below.
    start_time = timeit.default_timer()
    split_mode = "val"
    print(
        f"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}"
    )
    neg_sampler.generate_negative_samples(
        data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path
    )
    print(
        f"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer() - start_time: .4f}"
    )

    # generate test negative edge set
    start_time = timeit.default_timer()
    split_mode = "test"
    print(
        f"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}"
    )
    neg_sampler.generate_negative_samples(
        data=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path
    )
    print(
        f"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer()- start_time: .4f}"
    )
# Standard script entry point: run the negative-sample generation.
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/dataset_scripts/tgbn-genre.py
================================================
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import csv
from typing import Optional, Dict, Any, Tuple
import datetime
from datetime import date, timedelta
from difflib import SequenceMatcher
# similarity_dict = {('electronic', 'electronica'): 0.9523809523809523, ('electronic', 'electro'): 0.8235294117647058, ('alternative', 'alternative rock'): 0.8148148148148148, ('nu jazz', 'nu-jazz'): 0.8571428571428571,
# ('funky', 'funk'): 0.8888888888888888, ('funky', 'funny'): 0.8, ('post rock', 'pop rock'): 0.8235294117647058, ('post rock', 'post-rock'): 0.8888888888888888,
# ('instrumental', 'instrumental rock'): 0.8275862068965517, ('chill', 'chile'): 0.8, ('Drum and bass', 'Drum n Bass'): 0.8333333333333334, ('female vocalists', 'female vocalist'): 0.967741935483871,
# ('female vocalists', 'male vocalists'): 0.9333333333333333, ('female vocalists', 'male vocalist'): 0.896551724137931, ('electro', 'electropop'): 0.8235294117647058, ('funk', 'fun'): 0.8571428571428571,
# ('hip hop', 'trip hop'): 0.8, ('hip hop', 'hiphop'): 0.9230769230769231, ('trip-hop', 'trip hop'): 0.875, ('indie rock', 'indie folk'): 0.8, ('new age', 'new wave'): 0.8, ('new age', 'new rave'): 0.8,
# ('synthpop', 'synth pop'): 0.9411764705882353, ('industrial', 'industrial rock'): 0.8, ('cover', 'covers'): 0.9090909090909091, ('post hardcore', 'post-hardcore'): 0.9230769230769231, ('mathcore', 'deathcore'): 0.8235294117647058,
# ('deutsch', 'dutch'): 0.8333333333333334, ('swing', 'sting'): 0.8, ('female vocalist', 'male vocalists'): 0.896551724137931, ('female vocalist', 'male vocalist'): 0.9285714285714286, ('new wave', 'new rave'): 0.875,
# ('male vocalists', 'male vocalist'): 0.9629629629629629, ('Progressive rock', 'Progressive'): 0.8148148148148148, ('Alt-country', 'alt country'): 0.8181818181818182, ('favorites', 'Favourites'): 0.8421052631578947,
# ('favorites', 'favourite'): 0.8888888888888888, ('favorites', 'Favorite'): 0.8235294117647058, ('1970s', '1980s'): 0.8, ('1970s', '1990s'): 0.8, ('proto-punk', 'post-punk'): 0.8421052631578947,
# ('folk rock', 'folk-rock'): 0.8888888888888888, ('1980s', '1990s'): 0.8, ('favorite songs', 'Favourite Songs'): 0.8275862068965517, ('melancholic', 'melancholy'): 0.8571428571428571,
# ('Favourites', 'favourite'): 0.8421052631578947, ('Favourites', 'Favorite'): 0.8888888888888888, ('Favourites', 'Favourite Songs'): 0.8, ('favourite', 'Favorite'): 0.8235294117647058,
# ('american', 'americana'): 0.9411764705882353, ('american', 'african'): 0.8, ('american', 'mexican'): 0.8, ('rock en español', 'Rock en Espanol'): 0.8, ('trance', 'psytrance'): 0.8,
# ('power pop', 'powerpop'): 0.9411764705882353, ('psychill', 'psychobilly'): 0.8421052631578947, ('Progressive metal', 'progressive death metal'): 0.8, ('Progressive metal', 'progressive black metal'): 0.8,
# ('progressive death metal', 'progressive black metal'): 0.8260869565217391, ('romantic', 'new romantic'): 0.8, ('hair metal', 'Dark metal'): 0.8, ('melodic metal', 'melodic black metal'): 0.8125,
# ('funk metal', 'folk metal'): 0.8, ('death metal', 'math metal'): 0.8571428571428571, ('Technical Metal', 'Technical Death Metal'): 0.8333333333333334, ('speed metal', 'sid metal'): 0.8}
#! map different spellings and near-duplicate genre names to a single canonical one, preferring the spaced form where possible
# ? key = to replace, value = to keep
# Canonical-name mapping for near-duplicate genre tags (derived from the
# fuzzy-match pairs listed above): key = spelling to replace, value = spelling to keep.
similarity_dict = {
    "nu-jazz": "nu jazz",
    "funky": "funk",
    "post-rock": "post rock",
    "Drum n Bass": "Drum and bass",
    "female vocalists": "female vocalist",
    "male vocalists": "male vocalist",
    "hiphop": "hip hop",
    "trip-hop": "trip hop",
    "synthpop": "synth pop",
    "covers": "cover",
    "post-hardcore": "post hardcore",
    "Favourites": "favorites",
    "favourite": "favorites",
    "Favorite": "favorites",
    "folk-rock": "folk rock",
    "favorite songs": "favorites",
    "Favourite Songs": "favorites",
    "americana": "american",
    "Rock en Espanol": "rock en español",
    "melancholy": "melancholic",
    "powerpop": "power pop",
}
def filter_genre_edgelist(fname, genres_dict):
    """
    Rewrite the edgelist keeping only the high-frequency genres in
    `genres_dict`, collapsing near-duplicate spellings via `similarity_dict`.

    Parameters:
        fname: input edgelist csv (edge_id, user_id, timestamp, tags)
        genres_dict: dict whose keys are the genres to keep

    Output:
        writes `lastfm_edgelist_clean.csv` in the current directory.
    """
    # context manager so the input file is closed even if parsing fails
    with open(fname, "r") as edgelist:
        lines = list(edgelist.readlines())
    with open("lastfm_edgelist_clean.csv", "w") as f:
        write = csv.writer(f)
        fields = ["user_id", "timestamp", "tags", "weight"]
        write.writerow(fields)
        for i in range(1, len(lines)):  # skip the header line
            vals = lines[i].split(",")
            user_id = vals[1]
            time = vals[2]
            # tags field looks like "['genre' -- strip the quote/bracket prefix
            genre = vals[3].strip('"').strip("['")
            w = vals[4][:-3]  # drop the trailing ]"\n
            if genre in genres_dict:
                # collapse near-duplicate spellings to one canonical genre
                if genre in similarity_dict:
                    genre = similarity_dict[genre]
                write.writerow([user_id, time, genre, w])
def get_genre_list(fname):
    """
    Count genre frequencies in an edgelist, report how many genres survive
    several frequency thresholds, and write the genres occurring more than
    1000 times to `genre_list_1000.csv`.

    Input rows look like:
    edge_id, user_id, timestamp, tags
    0,user_000001,2006-08-13 14:59:59+00:00,"['electronic', 0.5319148936170213]"

    Parameters:
        fname: input edgelist csv path
    """
    # context manager so the input file is closed even if parsing fails
    with open(fname, "r") as edgelist:
        lines = list(edgelist.readlines())

    genre_dict = {}
    for i in range(1, len(lines)):  # skip header
        vals = lines[i].split(",")
        # tags field looks like "['genre' -- strip the quote/bracket prefix
        genre = vals[3].strip('"').strip("['")
        genre_dict[genre] = genre_dict.get(genre, 0) + 1

    # report threshold statistics to help pick a sensible cutoff
    thresholds = (10, 100, 1000, 2000)
    genre_lists = {
        t: [[g] for g, freq in genre_dict.items() if freq > t] for t in thresholds
    }
    for t in thresholds:
        print("number of genres with frequency > " + str(t) + ": " + str(len(genre_lists[t])))

    fields = ["genre"]
    with open("genre_list_1000.csv", "w") as f:
        write = csv.writer(f)
        write.writerow(fields)
        write.writerows(genre_lists[1000])
def find_unique_genres(fname: str, threshold: float = 0.8):
    """
    Report pairs of genre names that are probably the same genre written
    differently (spacing, typos, etc.) using pairwise fuzzy matching.

    Parameters:
        fname: csv with a header line followed by one genre name per line
        threshold: minimum SequenceMatcher ratio to count as similar
    """
    with open(fname, "r") as fin:
        raw_lines = fin.readlines()

    # first line is the csv header; the rest are genre names
    all_genres = [line.strip("\n") for line in raw_lines[1:]]

    similar_pairs = {}
    for idx, left in enumerate(all_genres):
        for right in all_genres[idx + 1:]:
            ratio = SequenceMatcher(None, left, right).ratio()
            if ratio >= threshold:
                similar_pairs[(left, right)] = ratio

    print("there are " + str(len(similar_pairs)) + " similar genres")
    print(similar_pairs)
def load_genre_dict(
    fname: str,
) -> Dict[str, Any]:
    """
    Load the genre-list csv into a dict mapping each genre to 1.

    Parameters:
        fname: file name of the genre list (one genre per row, first column)
    Returns:
        genre_dict: dict whose keys are all genres in the file
    """
    with open(fname, "r") as f:
        return {row[0]: 1 for row in csv.reader(f)}
def generate_daily_node_labels(fname: str):
    r"""
    Read a temporal edgelist and write per-user daily node labels to
    `daily_labels.csv`: for each user and day, every genre heard that day
    together with its accumulated weight (node label = favourite genres of
    the day).

    Note: only genres from the genre_list are considered.

    Input rows look like:
    user_000001,2006-08-13 14:59:59+00:00,"['electronic', 0.5319148936170213]"
    user_000001,2006-08-13 14:59:59+00:00,"['alternative', 0.46808510638297873]"
    user_000001,2006-08-13 15:36:22+00:00,"['electronic', 0.6410256410256411]"
    user_000001,2006-08-15 13:41:18+00:00,"['electronica', 1.0]"

    Parameters:
        fname: input edgelist csv (user_id, timestamp, genre, weight)
    """
    edgelist = open(fname, "r")
    lines = list(edgelist.readlines())
    edgelist.close()
    format = "%Y-%m-%d %H:%M:%S"
    day_dict = {}  # store the weights of genres on this day
    cur_day = -1
    with open("daily_labels.csv", "w") as outf:
        write = csv.writer(outf)
        fields = ["user_id", "year", "month", "day", "genre", "weight"]
        write.writerow(fields)
        # generate daily labels for users
        for i in range(1, len(lines)):
            vals = lines[i].split(",")
            user_id = vals[0]
            # drop the trailing "+00:00" timezone suffix before parsing
            time = vals[1][:-7]
            date_object = datetime.datetime.strptime(time, format)
            if i == 1:
                cur_day = date_object.day
            genre = vals[2]
            w = float(vals[3].strip())
            # NOTE(review): days are compared by day-of-month only, so records
            # exactly one month apart on the same day number would be pooled;
            # also the row that triggers the day change is itself discarded and
            # the final day is never flushed -- confirm both are acceptable.
            if date_object.day != cur_day:
                #! normalize the weights in the day_dict to sum 1
                # * remove normalization for future aggregation
                # total = sum(day_dict.values())
                # day_dict = {k: v / total for k, v in day_dict.items()}
                #! user,time,genre,weight # genres = # of weights
                # the flushed labels carry the *new* record's date -- presumably
                # intentional (labels shifted forward to the prediction day);
                # verify against the aggregation step.
                out = [
                    user_id,
                    str(date_object.year),
                    str(date_object.month),
                    str(date_object.day),
                ]
                for genre, w in day_dict.items():
                    write.writerow(out + [genre] + [w])
                cur_day = date_object.day
                day_dict = {}
            else:
                if genre not in day_dict:
                    day_dict[genre] = w
                else:
                    day_dict[genre] += w
def generate_aggregate_labels(fname: str, days: int = 7):
    """
    Aggregate the daily genre labels over a window of `days` days and write
    the normalized weights to `<days>days_labels.csv`.

    #! current generation includes edges from the day of the label, thus the
    #  label should be set to be the beginning of the day
    Prediction should always be at the first second of the day.

    Parameters:
        fname: daily-label csv (user_id, year, month, day, genre, weight)
        days: length of the aggregation window in days
    """
    edgelist = open(fname, "r")
    lines = list(edgelist.readlines())
    edgelist.close()
    date_prev = 0
    genre_dict = {}
    user_prev = 0
    # "user_id", "year", "month", "day", "genre", "weight"
    with open(str(days) + "days_labels.csv", "w") as outf:
        write = csv.writer(outf)
        fields = ["user_id", "year", "month", "day", "genre", "weight"]
        write.writerow(fields)
        for i in range(1, len(lines)):
            vals = lines[i].split(",")
            user_id = vals[0]
            year = int(vals[1])
            month = int(vals[2])
            day = int(vals[3])
            genre = vals[4]
            w = float(vals[5])
            if i == 1:
                date_prev = date(year, month, day)
                user_prev = user_id
            date_cur = date(year, month, day)
            # restart the window whenever a new user starts
            # NOTE(review): genre_dict is not cleared here, so weights may
            # carry over from the previous user -- verify this is intended.
            if user_id != user_prev:
                date_prev = date(year, month, day)
                user_prev = user_id
            if (
                date_cur - date_prev
            ).days <= days:  #! this means that the date = [0,7] which includes the current day
                if genre not in genre_dict:
                    genre_dict[genre] = w
                else:
                    genre_dict[genre] += w
            else:
                # start a new week
                # normalize the weight to sum 1
                total = sum(genre_dict.values())
                genre_dict = {k: v / total for k, v in genre_dict.items()}
                out = [
                    user_id,
                    str(date_prev.year),
                    str(date_prev.month),
                    str(date_prev.day),
                ]
                for genre, w in genre_dict.items():
                    write.writerow(out + [genre] + [w])
                # NOTE(review): the window start advances by a single day, not
                # by `days`, so consecutive windows overlap -- confirm.
                date_prev = date_prev + datetime.timedelta(days=1)
                genre_dict = {}
def most_frequent(List):
    """
    Return the most frequent element of a non-empty list.
    Ties are broken by choosing the element that appears earlier.

    Parameters:
        List: non-empty list of hashable elements
    Raises:
        IndexError: if `List` is empty (same as the original behavior).
    """
    from collections import Counter

    # Counter preserves first-seen insertion order and most_common() uses a
    # stable sort, so on ties the earlier element wins -- the same
    # tie-breaking as the original O(n^2) `List.count` loop, but in O(n).
    return Counter(List).most_common(1)[0][0]
def convert_ts_unix(fname: str, outname: str):
    """
    Convert the date column of a node-label file from "%Y-%m-%d" strings to
    unix timestamps, shifting each label to the start of the next day.

    Parameters:
        fname: input csv (time, user_id, genre, weight) with a header row
        outname: output csv (ts, user_id, genre, weight)
    """
    TIME_FORMAT = "%Y-%m-%d"
    with open(fname, "r") as csv_file, open(outname, "w") as outf:
        writer = csv.writer(outf)
        writer.writerow(["ts", "user_id", "genre", "weight"])
        reader = csv.reader(csv_file, delimiter=",")
        for row_idx, row in enumerate(reader):
            if row_idx == 0:
                continue  # skip the input header
            # the label applies to the *next* day, hence the one-day shift
            label_day = datetime.datetime.strptime(row[0], TIME_FORMAT) + timedelta(days=1)
            writer.writerow([int(label_day.timestamp()), row[1], row[2], float(row[3])])
def convert_ts_edgelist(fname: str, outname: str):
    """
    Convert the datetime column of an edgelist from "%Y-%m-%d %H:%M:%S"
    strings to unix timestamps.

    Parameters:
        fname: input csv (time, user_id, genre, weight) with a header row
        outname: output csv (ts, user_id, genre, weight)
    """
    TIME_FORMAT = "%Y-%m-%d %H:%M:%S"
    with open(fname, "r") as csv_file, open(outname, "w") as outf:
        writer = csv.writer(outf)
        writer.writerow(["ts", "user_id", "genre", "weight"])
        reader = csv.reader(csv_file, delimiter=",")
        for row_idx, row in enumerate(reader):
            if row_idx == 0:
                continue  # skip the input header
            stamp = int(datetime.datetime.strptime(row[0], TIME_FORMAT).timestamp())
            writer.writerow([stamp, row[1], row[2], float(row[3])])
def sort_node_labels(fname, outname):
    r"""
    Sort node labels chronologically and write them as
    (time, user_id, genre, weight) rows.

    Parameters:
        fname: input csv (user_id, year, month, day, genre, weight)
        outname: output csv sorted by the "%Y-%m-%d" date string
    """
    with open(fname, "r") as edgelist:
        lines = list(edgelist.readlines())
    with open(outname, "w") as outf:
        write = csv.writer(outf)
        fields = ["time", "user_id", "genre", "weight"]
        write.writerow(fields)
        rows_dict = {}
        for i in range(1, len(lines)):  # skip header
            vals = lines[i].split(",")
            user_id = vals[0]
            year = int(vals[1])
            month = int(vals[2])
            day = int(vals[3])
            genre = vals[4]
            w = float(vals[5])
            # BUG FIX: the original called `datetime(year, month, day)`, which
            # invokes the *module* (this file does `import datetime`) and
            # raises TypeError; build the date via datetime.datetime instead.
            date_cur = datetime.datetime(year, month, day)
            time_ts = date_cur.strftime("%Y-%m-%d")
            rows_dict.setdefault(time_ts, []).append((user_id, genre, w))
        # zero-padded "%Y-%m-%d" strings sort lexicographically == chronologically
        for ts in sorted(rows_dict):
            for user_id, genre, w in rows_dict[ts]:
                write.writerow([ts, user_id, genre, w])
def sort_edgelist(fname, outname="sorted_lastfm_edgelist.csv"):
    r"""
    Sort a lastfm edgelist by its timestamp column and write it out as
    (time, user_id, genre, weight) rows.

    Parameters:
        fname: input csv (user_id, timestamp, tags, weight) with header
        outname: output csv path
    """
    with open(fname, "r") as fin:
        raw = fin.readlines()

    buckets = {}
    for line in raw[1:]:  # skip header
        fields_in = line.split(",")
        uid = fields_in[0]
        # drop the trailing timezone suffix from the timestamp column
        stamp = fields_in[1][:-7]
        tag = fields_in[2]
        weight = float(fields_in[3].strip())
        buckets.setdefault(stamp, []).append((uid, tag, weight))

    with open(outname, "w") as outf:
        writer = csv.writer(outf)
        writer.writerow(["time", "user_id", "genre", "weight"])
        for stamp in sorted(buckets):
            for uid, tag, weight in buckets[stamp]:
                writer.writerow([stamp, uid, tag, weight])
if __name__ == "__main__":
    # Offline preprocessing pipeline for the tgbn-genre (lastfm) dataset.
    # Earlier stages are kept commented out to document the order in which
    # the dataset files were produced; only the final conversion is active.
    #! step 1: generate the list of genres by frequency
    # get_genre_list("/mnt/c/Users/sheny/Desktop/TGB/tgb/datasets/lastfmGenre/dataset.csv")
    # genre_dict = load_genre_dict("/mnt/c/Users/sheny/Desktop/TGB/tgb/datasets/lastfmGenre/genre_list.csv")
    #! step 2: find similar (near-duplicate) genre names
    # find_unique_genres("genre_list_1000.csv",threshold= 0.8)
    #! step 3: filter the edgelist keeping only the frequent genres
    # genres_dict = load_genre_dict("genre_list_1000.csv")
    # filter_genre_edgelist("dataset.csv", genres_dict)
    #! step 4: generate the daily node labels
    # generate_daily_node_labels("lastfm_edgelist_clean.csv")
    # generate_daily_node_labels("/mnt/c/Users/sheny/Desktop/TGB/tgb/datasets/lastfmGenre/dataset.csv")
    # load_node_labels("/mnt/c/Users/sheny/Desktop/TGB/tgb/datasets/lastfmGenre/daily_labels.csv")
    #! step 5: generate normalized weekly node labels
    # generate_aggregate_labels("daily_labels.csv", days=7)
    #! step 6: sort the edgelist by time for the lastfm dataset
    # fname = "../datasets/lastfmGenre/lastfm_edgelist_clean.csv"
    # outname = '../datasets/lastfmGenre/sorted_lastfm_edgelist.csv'
    # sort_edgelist(fname,
    #               outname = outname)
    #! step 7: sort the node labels by time for the lastfm dataset
    # fname = "../datasets/lastfmGenre/7days_labels.csv"
    # outname = '../datasets/lastfmGenre/sorted_7days_node_labels.csv'
    # sort_node_labels(fname,
    #                  outname)
    #! step 8: convert node-label timestamps from date to unix time
    # convert_ts_unix("lastfmgenre_node_labels_datetime.csv",
    #                 "lastfmgenre_node_labels.csv")
    #! step 9 (active): convert edgelist timestamps to unix time
    convert_ts_edgelist("lastfmgenre_edgelist.csv", "lastfmgenre_edgelist_ts.csv")
================================================
FILE: tgb/datasets/dataset_scripts/tgbn-reddit.py
================================================
import csv
from tqdm import tqdm
from os import listdir
from tgb.utils.stats import analyze_csv
def find_filenames(path_to_dir):
    r"""
    List every entry in a directory.

    Parameters:
        path_to_dir (str): path to the directory
    Returns:
        list of entry names in the directory
    """
    return listdir(path_to_dir)
def combine_edgelist_edgefeat2subreddits(edgefname, featfname, outname):
    """
    Merge an edgelist file and an edge-feature file (aligned line by line
    via edge_id) into one csv with columns
    'ts', 'src', 'subreddit', 'num_words', 'score'.

    Parameters:
        edgefname: edgelist csv with ts, src, dst, edge_id
        featfname: feature csv with edge_id, subreddit, num_characters,
                   num_words, score, edited_flag
        outname: output csv path
    """
    line_idx = 0
    # BUG FIX: the two input files were never closed; context managers
    # close all three files even if the merge aborts.
    with open(outname, "w") as outf, open(edgefname, "r") as edgelist, open(
        featfname, "r"
    ) as edgefeat:
        write = csv.writer(outf)
        fields = ["ts", "src", "subreddit", "num_words", "score"]
        write.writerow(fields)
        edgelist.readline()  # skip both headers
        edgefeat.readline()
        while True:
            # 'ts', 'src', 'dst', 'edge_id'
            edge_line = edgelist.readline().split(",")
            if len(edge_line) < 4:
                break  # end of the edgelist
            edge_id = int(edge_line[3])
            ts = int(edge_line[0])
            src = int(edge_line[1])
            # 'edge_id', 'subreddit', 'num_characters', 'num_words', 'score', 'edited_flag'
            feat_line = edgefeat.readline().split(",")
            edge_id_feat = int(feat_line[0])
            subreddit = feat_line[1]
            num_words = int(feat_line[3])
            score = int(feat_line[4])
            # the two files must stay line-aligned; bail out loudly if not
            if edge_id != edge_id_feat:
                print("edge_id != edge_id_feat")
                print(edge_id)
                print(edge_id_feat)
                break
            write.writerow([ts, src, subreddit, num_words, score])
            line_idx += 1
    print("processed", line_idx, "lines")
def filter_subreddits(fname):
    """
    Count how often every subreddit and every source node appears in
    the edgelist.

    Parameters:
        fname: csv with columns ts, src, subreddit, num_words, score
    Returns:
        (subreddit_count, node_count): frequency dicts keyed by
        subreddit name and by source node id, respectively.
    """
    subreddit_count = {}
    node_count = {}
    with open(fname, "r") as csv_file:
        # ts, src, subreddit, num_words, score
        for row_idx, row in enumerate(csv.reader(csv_file, delimiter=",")):
            if row_idx == 0:
                continue  # header
            src, subreddit = row[1], row[2]
            subreddit_count[subreddit] = subreddit_count.get(subreddit, 0) + 1
            node_count[src] = node_count.get(src, 0) + 1
    return subreddit_count, node_count
def clean_edgelist(fname, node_counts, outname, threshold=1000):
    """
    Copy the edgelist keeping only rows whose source node occurs at least
    `threshold` times according to `node_counts`.

    Parameters:
        fname: input csv (ts, src, subreddit, num_words, score)
        node_counts: dict mapping node -> frequency
        outname: output csv path
        threshold: minimum frequency for a node to be kept
    """
    frequent_nodes = {n: 1 for n, c in node_counts.items() if c >= threshold}
    with open(outname, "w") as outf:
        writer = csv.writer(outf)
        writer.writerow(["ts", "user", "subreddit", "num_words", "score"])
        with open(fname, "r") as csv_file:
            # ts, src, subreddit, num_words, score
            for row_idx, row in enumerate(csv.reader(csv_file, delimiter=",")):
                if row_idx == 0:
                    continue  # header
                num_words = int(row[3])
                score = int(row[4])
                if row[1] in frequent_nodes:
                    writer.writerow([row[0], row[1], row[2], num_words, score])
def clean_edgelist_reddits(fname, reddit_counts, outname, threshold=50):
    """
    Copy the edgelist keeping only rows whose subreddit occurs at least
    `threshold` times according to `reddit_counts`.

    Parameters:
        fname: input csv (ts, src, subreddit, num_words, score)
        reddit_counts: dict mapping subreddit -> frequency
        outname: output csv path
        threshold: minimum frequency for a subreddit to be kept
    """
    frequent_reddits = {r: 1 for r, c in reddit_counts.items() if c >= threshold}
    print("there remains, ", len(frequent_reddits), " subreddits")
    with open(outname, "w") as outf:
        writer = csv.writer(outf)
        writer.writerow(["ts", "user", "subreddit", "num_words", "score"])
        with open(fname, "r") as csv_file:
            # ts, src, subreddit, num_words, score
            for row_idx, row in enumerate(csv.reader(csv_file, delimiter=",")):
                if row_idx == 0:
                    continue  # header
                num_words = int(row[3])
                score = int(row[4])
                if row[2] in frequent_reddits:
                    writer.writerow([row[0], row[1], row[2], num_words, score])
def remove_missing_user(fname, outname):
    """
    Copy the edgelist, dropping every row whose user id is -1
    (i.e. the user is missing).

    Parameters:
        fname: input csv (ts, src, subreddit, num_words, score)
        outname: output csv path
    """
    with open(outname, "w") as outf:
        writer = csv.writer(outf)
        writer.writerow(["ts", "user", "subreddit", "num_words", "score"])
        with open(fname, "r") as csv_file:
            # ts, src, subreddit, num_words, score
            for row_idx, row in enumerate(csv.reader(csv_file, delimiter=",")):
                if row_idx == 0:
                    continue  # header
                src = int(row[1])
                num_words = int(row[3])
                score = int(row[4])
                if src != -1:
                    writer.writerow([row[0], src, row[2], num_words, score])
def generate_daily_node_labels(
    fname: str,
    outname: str,
):
    r"""
    Generate daily node labels from the subreddit edgelist; the output
    can then be aggregated over longer windows.

    Parameters:
        fname: input csv (ts, src, subreddit, num_words, score), sorted by ts
        outname: output csv (ts, user, subreddit, weight)
    """
    day_dict = {}  # store the counts of subreddits seen in the current day
    prev_t = -1
    DAY_IN_SEC = 86400
    # WEEK_IN_SEC = 604800
    with open(outname, "w") as outf:
        write = csv.writer(outf)
        fields = ["ts", "user", "subreddit", "weight"]
        write.writerow(fields)
        with open(fname, "r") as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=",")
            line_count = 0
            # ts, src, subreddit, num_words, score
            for row in csv_reader:
                if line_count == 0:
                    line_count += 1
                else:
                    ts = int(row[0])
                    user_id = row[1]
                    subreddit = row[2]
                    if line_count == 1:
                        prev_t = ts
                    # flush once more than a day has passed since the window start
                    if (prev_t + DAY_IN_SEC) < ts:
                        #! user,time,genre,weight # genres = # of weights
                        # NOTE(review): the flushed rows all carry the *current*
                        # row's user_id and ts, and counts are pooled across all
                        # users in the window (flagged by the note below this
                        # function); the triggering row itself is not counted and
                        # the last window is never flushed -- confirm downstream
                        # handling.
                        out = [user_id, ts]
                        for subreddit, w in day_dict.items():
                            write.writerow(out + [subreddit] + [w])
                        prev_t = ts
                        day_dict = {}
                    else:
                        if subreddit not in day_dict:
                            day_dict[subreddit] = 1
                        else:
                            day_dict[subreddit] += 1
                    line_count += 1
#! note that the edgelist are not sorted by users then by time, should keep multiple users when aggregating
def generate_aggregate_labels(fname: str, outname: str, days: int = 7):
    """
    Aggregate subreddit interaction weights over a window of `days` days,
    normalize them per user, and write them out as node labels.

    Prediction should always be at the first second of the day.
    #! daily labels are always shifted by 1 day

    Parameters:
        fname: input csv (ts, src, subreddit, num_words, score), sorted by ts
        outname: output csv (ts, user, subreddit, weight)
        days: length of the aggregation window in days
    """
    ts_prev = 0
    DAY_IN_SEC = 86400
    timespan = days * DAY_IN_SEC
    user_dict = {}
    # ts, src, subreddit, num_words, score
    with open(outname, "w") as outf:
        write = csv.writer(outf)
        fields = ["ts", "user", "subreddit", "weight"]
        write.writerow(fields)
        with open(fname, "r") as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=",")
            line_count = 0
            # ts, src, subreddit, num_words, score
            for row in csv_reader:
                if line_count == 0:
                    line_count += 1
                else:
                    ts = int(row[0])
                    user = row[1]
                    subreddit = row[2]
                    # weight = num_words of the comment
                    w = int(row[3])
                    if line_count == 1:
                        ts_prev = ts
                    if (ts - ts_prev) > timespan:
                        # flush: emit normalized per-user subreddit weights,
                        # time-stamped one day after the window start.
                        # NOTE(review): the loop variable shadows the outer
                        # `user` from the current row, and the row that triggers
                        # the flush is itself dropped -- confirm both are
                        # acceptable.
                        for user in user_dict:
                            total = sum(user_dict[user].values())
                            subreddit_dict = {
                                k: v / total for k, v in user_dict[user].items()
                            }
                            for subreddit, w in subreddit_dict.items():
                                write.writerow(
                                    [ts_prev + DAY_IN_SEC, user, subreddit, w]
                                )
                        user_dict = {}
                        # NOTE(review): the window start advances by one day,
                        # not by `timespan` -- confirm this sliding behavior.
                        ts_prev = ts_prev + DAY_IN_SEC  #! move label to the next day
                    else:
                        if user in user_dict:
                            if subreddit in user_dict[user]:
                                user_dict[user][subreddit] += w
                            else:
                                user_dict[user][subreddit] = w
                        else:
                            user_dict[user] = {}
                            user_dict[user][subreddit] = w
                    line_count += 1
def main():
    """
    Driver for the tgbn-reddit preprocessing pipeline. Earlier stages are
    kept commented out below to document the order in which the dataset
    files were produced; only the csv-analysis step is currently active.
    """
    # ? see redditcomments.py for the extraction from the raw files
    #! step 1: combine edgelist and edge-feat file, checking that the edge_id matches
    # edgefname = "redditcomments_edgelist_2008_2010.csv"
    # featfname = "redditcomments_edgefeat_2008_2010.csv"
    # outname = "subreddits_edgelist.csv"
    # combine_edgelist_edgefeat2subreddits(edgefname, featfname, outname)
    #! step 2: remove all edges missing the user
    # fname = "subreddits_edgelist.csv"
    # outname = "subreddits_edgelist_filtered.csv"
    # remove_missing_user(fname,
    #                     outname)
    #! step 3: clean subreddits first -- frequency count of subreddits
    # fname = "subreddits_edgelist.csv"
    # outname = "subreddits_edgelist_filter.csv"
    # subreddit_count, node_count = filter_subreddits(fname)
    # threshold = 1000 #200 #100
    # clean_edgelist_reddits(fname, subreddit_count, outname, threshold=threshold)
    #! step 4: filter out nodes with low frequency -- frequency count of nodes
    # fname = "subreddits_edgelist.csv"
    # outname = "subreddits_edgelist_clean.csv"
    # subreddit_count, node_count = filter_subreddits(fname)
    # threshold = 1000
    # clean_edgelist(fname, node_count, outname, threshold=threshold)
    # print ("finish cleaning")
    #! step 5: generate aggregate labels; each day's label is shifted by 1 day
    #  as it uses the edges from that day
    # fname = "subreddits_edgelist.csv"
    # outname = "subreddits_node_labels.csv"
    # generate_aggregate_labels(fname, outname, days=7)
    #! active step: analyze the extracted csv
    fname = "subreddits_edgelist.csv"
    analyze_csv(fname)
    #! optional: generate daily node labels
    # outname = 'subreddits_daily_labels.csv'
    # fname = "subreddits_edgelist.csv"
    # generate_daily_node_labels(fname,outname)
# Standard script entry point.
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/dataset_scripts/tgbn-token.py
================================================
import csv
import datetime
def count_node_freq(fname, filter_size=100):
    """
    Count how many edges touch each node (as source or destination), print
    threshold statistics, and return only the nodes with at least
    `filter_size` edges.

    Parameters:
        fname: edgelist csv (token, src, dst, ...) with a header row
        filter_size: minimum edge count for a node to be kept
    Returns:
        dict mapping node -> edge count, restricted to frequent nodes
    """
    node_dict = {}
    with open(fname, "r") as csv_file:
        for row_idx, row in enumerate(csv.reader(csv_file, delimiter=",")):
            if row_idx == 0:
                continue  # header
            # both endpoints of the edge contribute to the node's degree
            for endpoint in (row[1], row[2]):
                node_dict[endpoint] = node_dict.get(endpoint, 0) + 1

    # threshold statistics to help pick a sensible filter size
    counts = {t: 0 for t in (10, 100, 1000, 2000, 5000)}
    for freq in node_dict.values():
        for t in counts:
            if freq >= t:
                counts[t] += 1
    print("number of nodes with # edges >= 10 is ", counts[10])
    print("number of nodes with # edges >= 100 is ", counts[100])
    print("number of nodes with # edges >= 1000 is ", counts[1000])
    print("number of nodes with # edges >= 2000 is ", counts[2000])
    print("number of nodes with # edges >= 5000 is ", counts[5000])
    print("----------------------high level statistics-------------------------")

    #! keep nodes with at least `filter_size` edges
    return {n: c for n, c in node_dict.items() if c >= filter_size}
def filter_edgelist(token_fname, edgefile, outname):
    """
    Copy the edgelist keeping only the rows whose token appears in the
    token file.

    Parameters:
        token_fname: csv listing the tokens to keep (first column, with header)
        edgefile: the edgelist file name
        outname: the output filtered edgelist name
    """
    # read the set of tokens to preserve (skipping the header row)
    with open(token_fname, "r") as csv_file:
        token_dict = {
            row[0]: 1
            for row_idx, row in enumerate(csv.reader(csv_file, delimiter=","))
            if row_idx > 0
        }
    with open(edgefile, "r") as in_file, open(outname, "w") as out_file:
        csv_writer = csv.writer(out_file, delimiter=",")
        csv_writer.writerow(["token_address", "from_address", "to_address", "value", "block_timestamp"])
        for row_idx, row in enumerate(csv.reader(in_file, delimiter=",")):
            if row_idx > 0 and row[0] in token_dict:
                csv_writer.writerow(row)
def filter_by_node(node_dict, edgefile, outname):
    """
    Copy the edgelist keeping only rows where the source or the destination
    node appears in `node_dict`.

    Parameters:
        node_dict: dict whose keys are the nodes to keep
        edgefile: input edgelist csv path
        outname: output csv path
    """
    with open(edgefile, "r") as in_file, open(outname, "w") as out_file:
        csv_writer = csv.writer(out_file, delimiter=",")
        csv_writer.writerow(["token_address", "from_address", "to_address", "value", "block_timestamp"])
        for row_idx, row in enumerate(csv.reader(in_file, delimiter=",")):
            if row_idx == 0:
                continue  # header
            if row[1] in node_dict or row[2] in node_dict:
                csv_writer.writerow(row)
def store_node_list(node_dict, outname):
    """
    Write node frequencies to a csv file.

    Parameters:
        node_dict: dict mapping node -> frequency
        outname: name of the output csv file
    Output:
        csv file with one (node_list, frequency) row per node
    """
    with open(outname, "w") as csv_file:
        csv_writer = csv.writer(csv_file, delimiter=",")
        csv_writer.writerow(["node_list", "frequency"])
        csv_writer.writerows(node_dict.items())
def load_node_dict(fname):
    """
    Load a (node, frequency) csv (with a header row) back into a dict.

    Parameters:
        fname: input csv path
    Returns:
        dict mapping node -> int frequency
    """
    with open(fname, "r") as csv_file:
        return {
            row[0]: int(row[1])
            for row_idx, row in enumerate(csv.reader(csv_file, delimiter=","))
            if row_idx > 0
        }
def store_token_address(token_dict, outname, topk=1000):
    """
    Write the `topk` most frequent token addresses to a csv file.

    Parameters:
        token_dict: dict mapping token address -> frequency
        outname: name of the output csv file
        topk: number of token addresses to keep
    Output:
        csv file with the topk token addresses by descending frequency
    """
    # stable sort keeps insertion order on frequency ties
    sorted_tokens = sorted(token_dict.items(), key=lambda item: item[1], reverse=True)
    with open(outname, "w") as csv_file:
        csv_writer = csv.writer(csv_file, delimiter=",")
        csv_writer.writerow(["token_address", "frequency"])
        # BUG FIX: the original `ctr <= topk` check wrote topk + 1 rows;
        # slicing writes exactly the topk most frequent tokens.
        csv_writer.writerows(sorted_tokens[:topk])
def analyze_token_frequency(fname):
    """
    Print summary statistics (edge count, unique tokens/nodes/timestamps,
    weight range) for a token-transfer edgelist.

    Input columns: token_address, from_address, to_address, value, block_timestamp

    Parameters:
        fname: input csv path with a header row
    """
    token_dict = {}
    node_dict = {}
    time_dict = {}
    max_w = 0
    min_w = 100000
    num_edges = 0
    with open(fname, "r") as csv_file:
        for row_idx, row in enumerate(csv.reader(csv_file, delimiter=",")):
            if row_idx == 0:
                continue  # header
            token_type = row[0]
            token_dict[token_type] = token_dict.get(token_type, 0) + 1
            src = row[1]
            node_dict[src] = node_dict.get(src, 0) + 1
            dst = row[2]
            node_dict[dst] = node_dict.get(dst, 0) + 1
            w = float(row[3])
            # BUG FIX: these were `if/elif`, so a weight that raised the
            # maximum could never update the minimum (e.g. a strictly
            # increasing weight sequence left min_w at its 100000 sentinel).
            if w > max_w:
                max_w = w
            if w < min_w:
                min_w = w
            timestamp = row[4]
            if timestamp not in time_dict:
                time_dict[timestamp] = 1
            num_edges += 1
    print("number of edges are ", num_edges)
    print(" number of unique tokens are ", len(token_dict))
    print(" number of unique nodes are ", len(node_dict))
    print(" number of unique timestamps are ", len(time_dict))
    print(" max weight is ", max_w)
    print(" min weight is ", min_w)
    # topk = 1000
    # store_token_address(token_dict, "token_list.csv", topk=topk)
def to_bipartite(in_name, out_name, node_dict):
    """
    Convert a user-user token-transfer edgelist into a user-token bipartite
    edgelist: each transfer produces one row for the sender (IsSender=1) and
    one for the receiver (IsSender=0), restricted to users in `node_dict`.

    Parameters:
        in_name: input csv (token, src, dst, value, timestamp) with header
        out_name: output csv path
        node_dict: dict whose keys are the users to keep
    """
    with open(in_name, "r") as csv_file, open(out_name, "w") as out_file:
        writer = csv.writer(out_file, delimiter=",")
        writer.writerow(["timestamp", "user_address", "token_address", "value", "IsSender"])
        for row_idx, row in enumerate(csv.reader(csv_file, delimiter=",")):
            if row_idx == 0:
                continue  # header
            token, sender, receiver = row[0], row[1], row[2]
            value = float(row[3])
            stamp = row[4]
            if sender in node_dict:
                writer.writerow([stamp, sender, token, value, 1])
            if receiver in node_dict:
                writer.writerow([stamp, receiver, token, value, 0])
def analyze_csv(fname):
    """
    Print high-level statistics (edge / node / timestamp counts and
    node-degree buckets) for a `t,u,v,w` edgelist csv with a header row.
    """
    node_dict = {}
    edge_dict = {}
    time_dict = {}
    num_edges = 0
    num_time = 0
    with open(fname, "r") as csv_file:
        reader = csv.reader(csv_file, delimiter=",")
        next(reader, None)  # header row
        for row in reader:
            t, u, v = row[0], row[1], row[2]
            # count unique timestamps
            if t not in time_dict:
                time_dict[t] = 1
                num_time += 1
            # per-node degree counts
            node_dict[u] = node_dict.get(u, 0) + 1
            node_dict[v] = node_dict.get(v, 0) + 1
            # unique (directed) edge counts
            num_edges += 1
            edge_dict[(u, v)] = edge_dict.get((u, v), 0) + 1
    print("----------------------high level statistics-------------------------")
    print("number of total edges are ", num_edges)
    print("number of nodes are ", len(node_dict))
    print("number of unique edges are ", len(edge_dict))
    print("number of unique timestamps are ", num_time)
    degrees = node_dict.values()
    num_10 = sum(1 for deg in degrees if deg >= 10)
    num_100 = sum(1 for deg in degrees if deg >= 100)
    num_1000 = sum(1 for deg in degrees if deg >= 1000)
    print("number of nodes with # edges >= 10 is ", num_10)
    print("number of nodes with # edges >= 100 is ", num_100)
    print("number of nodes with # edges >= 1000 is ", num_1000)
    print("----------------------high level statistics-------------------------")
def convert_2_sec(fname, outname):
    """
    Rewrite an edgelist csv, converting its datetime column
    ("%Y-%m-%d %H:%M:%S"; e.g. "2017-07-24 17:48:15+00:00" with the offset
    chopped off) into epoch seconds. Rows with a zero value are dropped.

    NOTE(review): the "+00:00" offset is discarded and the naive datetime is
    converted via .timestamp(), which interprets it in the LOCAL timezone —
    confirm this matches how downstream timestamps are consumed.
    """
    time_fmt = "%Y-%m-%d %H:%M:%S"
    with open(fname, "r") as csv_file, open(outname, "w") as out_file:
        reader = csv.reader(csv_file, delimiter=",")
        writer = csv.writer(out_file, delimiter=",")
        writer.writerow(["timestamp", "user_address", "token_address", "value", "IsSender"])
        next(reader, None)  # skip header
        for row in reader:
            # keep only "YYYY-mm-dd HH:MM:SS" (19 chars), dropping the offset
            parsed = datetime.datetime.strptime(row[0][:19], time_fmt)
            seconds = int(parsed.timestamp())
            sender = row[1]
            receiver = row[2]
            value = float(row[3])
            flag = int(row[4])
            if value != 0:
                writer.writerow([seconds, sender, receiver, value, flag])
def print_csv(fname):
    """Print the total number of csv rows in a file (header included)."""
    with open(fname, "r") as csv_file:
        total = sum(1 for _ in csv.reader(csv_file, delimiter=","))
    print("there are ", total, " rows in the csv file")
def sort_edgelist_by_time(fname, outname):
    """
    Sort an edgelist csv by its integer timestamp column.

    The sort is stable within a timestamp: rows sharing a timestamp keep
    their original file order.
    """
    buckets = {}
    with open(fname, "r") as csv_file:
        reader = csv.reader(csv_file, delimiter=",")
        with open(outname, "w") as out_file:
            writer = csv.writer(out_file, delimiter=",")
            writer.writerow(["timestamp", "user_address", "token_address", "value", "IsSender"])
            next(reader, None)  # skip header
            for row in reader:
                buckets.setdefault(int(row[0]), []).append(row)
            for ts in sorted(buckets):
                writer.writerows(buckets[ts])
#! aggregate node labels
def generate_aggregate_labels(fname: str, outname: str, days: int = 7):
    """
    aggregate the genres over a number of days, as specified by days
    prediction should always be at the first second of the day
    #! daily labels are always shifted by 1 day

    Parameters:
        fname (str): input csv with columns ts, user, item, weight
            (header row expected; rows assumed ordered by ts)
        outname (str): output csv for the aggregated, normalized labels
        days (int): size of the aggregation window in days
    """
    ts_prev = 0
    DAY_IN_SEC = 86400
    timespan = days * DAY_IN_SEC  # aggregation window length in seconds
    user_dict = {}  # user -> {item -> accumulated weight} for the open window
    # ts, src, subreddit, num_words, score
    with open(outname, "w") as outf:
        write = csv.writer(outf)
        fields = ["ts", "user_address", "token_address", "weight"] #["ts", "user", "subreddit", "weight"]
        write.writerow(fields)
        with open(fname, "r") as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=",")
            line_count = 0
            # ts, src, subreddit, num_words, score
            for row in csv_reader:
                if line_count == 0:
                    # header row
                    line_count += 1
                else:
                    ts = float(row[0])
                    ts = int(ts)
                    user = row[1]
                    item = row[2]
                    w = float(row[3])
                    if (w == 0):
                        # zero weights should have been filtered upstream; flag them
                        print (row)
                    if line_count == 1:
                        # first data row opens the first aggregation window
                        ts_prev = ts
                    if (ts - ts_prev) > timespan:
                        # window complete: emit each user's item distribution,
                        # normalized so a user's weights sum to 1.
                        # NOTE: the loop below reuses the names `user`, `item`
                        # and `w`, shadowing the current row's values.
                        for user in user_dict:
                            total = sum(user_dict[user].values())
                            item_dict = {
                                k: v / total for k, v in user_dict[user].items()
                            }
                            for item, w in item_dict.items():
                                write.writerow(
                                    [ts_prev + DAY_IN_SEC, user, item, w]
                                )
                        user_dict = {}
                        ts_prev = ts_prev + DAY_IN_SEC #! move label to the next day
                        # NOTE(review): the row that triggered this flush is not
                        # accumulated into the new window, and ts_prev advances
                        # by one day rather than jumping to ts — confirm this is
                        # the intended label-shifting behavior.
                    else:
                        # accumulate this row into the open window
                        if user in user_dict:
                            if item in user_dict[user]:
                                user_dict[user][item] += w
                            else:
                                user_dict[user][item] = w
                        else:
                            user_dict[user] = {}
                            user_dict[user][item] = w
                    line_count += 1
def main():
    """
    Driver for the tgbn-token preprocessing pipeline. Earlier pipeline
    stages are kept below, commented out, as provenance.
    """
    # --- processing token types ---
    # fname = "ERC20_token_network.csv"
    # #analyze_token_frequency(fname)
    # token_file = "token_list.csv"
    # outname = "filtered_token_edgelist.csv"
    #! filter by token frequency
    # filter_edgelist(token_file, fname, outname)
    # #print_csv(fname)
    # #analyze_csv(fname)
    # --- processing node dict ---
    # fname = "filtered_token_edgelist.csv"
    # #! filter by node frequency
    # node_dict = count_node_freq(fname, filter_size=100)
    # store_node_list(node_dict, "node_list.csv")
    # #store_token_address(node_dict, "node_list.csv", topk=0)
    # outname = "tgbl-token-edgelist_100.csv"
    # filter_by_node(node_dict, fname, outname)
    # analyze_token_frequency('tgbl-token-edgelist_100.csv')
    #! converting user-user graph to user-token bipartite graph
    # out_name = "tgbl-token_edgelist.csv"
    # node_dict = load_node_dict("node_list.csv")
    # to_bipartite('tgbl-token-edgelist_100.csv', out_name, node_dict)
    # analyze_csv(out_name)
    #! convert datetime to seconds
    # convert_2_sec("tgbl-token_edgelist_old.csv", "tgbn-token_edgelist.csv")
    #! sort the timestamps in the edgelist
    # fname = "tgbn-token_edgelist.csv"
    # outname = "tgbn-token_edgelist_sorted.csv"
    # sort_edgelist_by_time(fname, outname)
    #! generate node labels
    label_src = "tgbn-token_edgelist.csv"
    label_out = "tgbn-token_node_labels.csv"
    window_days = 7
    generate_aggregate_labels(label_src, label_out, days=window_days)
# Entry point guard: run main() only when executed as a script, not on import.
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/dataset_scripts/tgbn-trade.py
================================================
import dateutil.parser as dparser
import csv
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from os import listdir
def count_unique_countries(fname):
    """
    Count and print the number of distinct nations appearing as either
    endpoint in a `year,u,v,w` trade edgelist (header row expected).
    """
    seen = {}
    with open(fname, "r") as csv_file:
        reader = csv.reader(csv_file, delimiter=",")
        next(reader, None)  # header row
        for row in reader:
            int(row[0])    # parse the year column (value unused, validates format)
            float(row[3])  # parse the weight column (value unused, validates format)
            for nation in (row[1], row[2]):
                if nation not in seen:
                    seen[nation] = 1
    print("there are {} unique countries".format(len(seen)))
#! incorrect, do not use
def normalize_edgelist(fname: str, outname: str):
    """
    need to track id for nodes
    normalize the edgelist by row for each year

    Marked incorrect by the authors. Issues visible in the code: the row that
    triggers a year change is neither written nor accumulated, the final year
    is never flushed, and the per-nation weight vector is hard-coded to 255
    partner slots.

    Parameters:
        fname (str): input csv with columns year,u,v,w (header row expected)
        outname (str): output csv path
    """
    prev_t = 0        # year currently being accumulated
    uid = 0           # next integer id to assign to a newly-seen nation
    node_dict = {}    # nation name -> integer id
    year_dict = {}    # nation name -> np weight vector indexed by partner id
    with open(outname, "w") as outf:
        write = csv.writer(outf)
        fields = ["year", "nation", "trading nation", "weight"]
        write.writerow(fields)
        with open(fname, "r") as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=",")
            line_count = 0
            for row in csv_reader:
                if line_count == 0:
                    # header row
                    line_count += 1
                else:
                    ts = int(row[0])
                    u = row[1]
                    v = row[2]
                    w = float(row[3])
                    if line_count == 1:
                        # first data row fixes the starting year
                        prev_t = ts
                    if u not in node_dict:
                        node_dict[u] = uid
                        uid += 1
                    if v not in node_dict:
                        node_dict[v] = uid
                        uid += 1
                    if w == 0:
                        # zero-weight edges contribute nothing to the normalization
                        line_count += 1
                        continue
                    if ts != prev_t:  # a new year now, write everything
                        # normalize the counts
                        # NOTE: the loop reuses the names `u` and `v`, shadowing
                        # the current row's values
                        for u in year_dict:
                            if np.sum(year_dict[u]) == 0:
                                continue
                            year_dict[u] = year_dict[u] / np.sum(year_dict[u])
                            # id -> nation name (rebuilt per nation; inefficient,
                            # but this function is flagged do-not-use)
                            invert_dict = {v: k for k, v in node_dict.items()}
                            for v in range(len(year_dict[u])):
                                if year_dict[u][v] > 0:
                                    write.writerow(
                                        [prev_t, u, invert_dict[v], year_dict[u][v]]
                                    )
                        year_dict = {}
                        prev_t = ts
                        # NOTE(review): the row that triggered this year change
                        # is dropped — it is not accumulated into the new year.
                    else:
                        # accumulate this year's weight vector
                        # (255 partner slots, hard-coded)
                        if u not in year_dict:
                            year_dict[u] = np.zeros(255)
                            year_dict[u][node_dict[v]] = w
                        else:
                            year_dict[u][node_dict[v]] = w
                    line_count += 1
def generate_aggregate_labels(fname: str, outname: str):
    """
    Emit the node labels for the trade dataset: every edge whose year is
    strictly greater than 1986 is copied through (a year's labels are
    simply the following years' edges).
    """
    first_year = 1986
    with open(outname, "w") as outf:
        writer = csv.writer(outf)
        writer.writerow(["year", "nation", "trading nation", "weight"])
        with open(fname, "r") as csv_file:
            reader = csv.reader(csv_file, delimiter=",")
            next(reader, None)  # header row
            for row in reader:
                year = int(row[0])
                if year > first_year:
                    writer.writerow([year, row[1], row[2], float(row[3])])
def check_sum_to_one(fname: str):
    """
    Sanity check: for every year except the last, print each nation's
    accumulated weight sum (each should print 1.0 for a normalized file).
    """
    sums = {}
    ts_prev = 0
    with open(fname, "r") as csv_file:
        reader = csv.reader(csv_file, delimiter=",")
        line_count = 0
        for row in reader:
            if line_count == 0:
                # header row
                line_count += 1
                continue
            ts = int(row[0])
            if line_count == 1:
                ts_prev = ts
            if ts != ts_prev:
                # year changed: dump the finished year's sums, then reset
                ts_prev = ts
                for nation in sums:
                    print(sums[nation])
                sums = {}
            nation = row[1]
            weight = float(row[3])
            sums[nation] = sums.get(nation, 0) + weight
            line_count += 1
def main():
    """
    Driver for the tgbn-trade preprocessing pipeline. All stages are kept
    below, commented out, as provenance — uncomment the stage you need.

    BUGFIX: the original body contained only comments, which is a
    SyntaxError for a function definition; this docstring provides a
    valid (no-op) body.
    """
    #! should have the normalized version on the edgelist
    # #find the number of unique countries
    # fname = "un_trade_edgelist.csv"
    # count_unique_countries(fname)
    #! normalize edgelist by row for each year
    # fname = "un_trade_edgelist.csv"
    # outname = "un_trade_edgelist_normalized.csv"
    # normalize_edgelist(fname, outname)
    #! find the node label for next year
    # * the node labels are simply the edgelist in this case
    # fname = "un_trade_edgelist.csv"
    # outname = "un_trade_node_labels.csv"
    # generate_aggregate_labels(fname, outname)
    # #! check if all sums are correct
    # fname = "un_trade_node_labels.csv"
    # check_sum_to_one(fname)
# Entry point guard: run main() only when executed as a script, not on import.
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tgbl_enron/tgbl-enron_neg_generator.py
================================================
import timeit
from tgb.linkproppred.negative_generator import NegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate and save negative edge sets for the validation and test
    splits of tgbl-enron.
    """
    print("*** Negative Sample Generation ***")

    # configuration
    neg_per_pos = 1000  # more than half the nodes in the graph
    strategy = "hist_rnd"  # alternative: "rnd"
    seed = 42
    dataset_name = "tgbl-enron"

    dataset = PyGLinkPropPredDataset(name=dataset_name, root="datasets")
    full_data = dataset.get_TemporalData()
    splits = {
        "train": full_data[dataset.train_mask],
        "val": full_data[dataset.val_mask],
        "test": full_data[dataset.test_mask],
    }

    # Ensure to only sample actual destination nodes as negatives.
    first_dst = int(full_data.dst.min())
    last_dst = int(full_data.dst.max())

    # the historical strategy additionally draws from edges seen in training
    history = splits["train"] if strategy == "hist_rnd" else None

    sampler = NegativeEdgeGenerator(
        dataset_name=dataset_name,
        first_dst_id=first_dst,
        last_dst_id=last_dst,
        num_neg_e=neg_per_pos,
        strategy=strategy,
        rnd_seed=seed,
        historical_data=history,
    )

    # generate the evaluation sets, validation first then test
    out_dir = "./"
    for split in ("val", "test"):
        start = timeit.default_timer()
        print(f"INFO: Start generating negative samples: {split} --- {strategy}")
        sampler.generate_negative_samples(
            data=splits[split], split_mode=split, partial_path=out_dir
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer() - start: .4f}"
        )
# Entry point guard: run main() only when executed as a script, not on import.
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tgbl_enron/tgbl_enron.py
================================================
import csv

# Strip the leading index column from the raw ml_enron export and save the
# remaining columns as the tgbl-enron edgelist.
with open('ml_enron.csv', 'r', newline='\n') as infile, \
        open('tgbl-enron_edgelist.csv', 'w', newline='\n') as outfile:
    writer = csv.writer(outfile)
    for record in csv.reader(infile):
        writer.writerow(record[1:])
================================================
FILE: tgb/datasets/tgbl_lastfm/tgbl-lastfm_neg_generator.py
================================================
import timeit
from tgb.linkproppred.negative_generator import NegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate and save negative edge sets for the validation and test
    splits of tgbl-lastfm.
    """
    print("*** Negative Sample Generation ***")

    # configuration
    neg_per_pos = 2000  # this is all nodes in the dataset
    strategy = "hist_rnd"  # alternative: "rnd"
    seed = 42
    dataset_name = "tgbl-lastfm"

    dataset = PyGLinkPropPredDataset(name=dataset_name, root="datasets")
    full_data = dataset.get_TemporalData()
    splits = {
        "train": full_data[dataset.train_mask],
        "val": full_data[dataset.val_mask],
        "test": full_data[dataset.test_mask],
    }

    # Ensure to only sample actual destination nodes as negatives.
    first_dst = int(full_data.dst.min())
    last_dst = int(full_data.dst.max())

    # the historical strategy additionally draws from edges seen in training
    history = splits["train"] if strategy == "hist_rnd" else None

    sampler = NegativeEdgeGenerator(
        dataset_name=dataset_name,
        first_dst_id=first_dst,
        last_dst_id=last_dst,
        num_neg_e=neg_per_pos,
        strategy=strategy,
        rnd_seed=seed,
        historical_data=history,
    )

    # generate the evaluation sets, validation first then test
    out_dir = "./"
    for split in ("val", "test"):
        start = timeit.default_timer()
        print(f"INFO: Start generating negative samples: {split} --- {strategy}")
        sampler.generate_negative_samples(
            data=splits[split], split_mode=split, partial_path=out_dir
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer() - start: .4f}"
        )
# Entry point guard: run main() only when executed as a script, not on import.
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tgbl_subreddit/tgbl-subreddit_neg_generator.py
================================================
import timeit
from tgb.linkproppred.negative_generator import NegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate and save negative edge sets for the validation and test
    splits of tgbl-subreddit.
    """
    print("*** Negative Sample Generation ***")

    # configuration
    neg_per_pos = 1000
    strategy = "hist_rnd"  # alternative: "rnd"
    seed = 42
    dataset_name = "tgbl-subreddit"

    dataset = PyGLinkPropPredDataset(name=dataset_name, root="datasets")
    full_data = dataset.get_TemporalData()
    splits = {
        "train": full_data[dataset.train_mask],
        "val": full_data[dataset.val_mask],
        "test": full_data[dataset.test_mask],
    }

    # Ensure to only sample actual destination nodes as negatives.
    first_dst = int(full_data.dst.min())
    last_dst = int(full_data.dst.max())

    # the historical strategy additionally draws from edges seen in training
    history = splits["train"] if strategy == "hist_rnd" else None

    sampler = NegativeEdgeGenerator(
        dataset_name=dataset_name,
        first_dst_id=first_dst,
        last_dst_id=last_dst,
        num_neg_e=neg_per_pos,
        strategy=strategy,
        rnd_seed=seed,
        historical_data=history,
    )

    # generate the evaluation sets, validation first then test
    out_dir = "./"
    for split in ("val", "test"):
        start = timeit.default_timer()
        print(f"INFO: Start generating negative samples: {split} --- {strategy}")
        sampler.generate_negative_samples(
            data=splits[split], split_mode=split, partial_path=out_dir
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer() - start: .4f}"
        )
# Entry point guard: run main() only when executed as a script, not on import.
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tgbl_uci/tgbl-uci_neg_generator.py
================================================
import timeit
from tgb.linkproppred.negative_generator import NegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate and save negative edge sets for the validation and test
    splits of tgbl-uci.
    """
    print("*** Negative Sample Generation ***")

    # configuration
    neg_per_pos = 1000  # more than half the nodes in the graph
    strategy = "hist_rnd"  # alternative: "rnd"
    seed = 42
    dataset_name = "tgbl-uci"

    dataset = PyGLinkPropPredDataset(name=dataset_name, root="datasets")
    full_data = dataset.get_TemporalData()
    splits = {
        "train": full_data[dataset.train_mask],
        "val": full_data[dataset.val_mask],
        "test": full_data[dataset.test_mask],
    }

    # Ensure to only sample actual destination nodes as negatives.
    first_dst = int(full_data.dst.min())
    last_dst = int(full_data.dst.max())

    # the historical strategy additionally draws from edges seen in training
    history = splits["train"] if strategy == "hist_rnd" else None

    sampler = NegativeEdgeGenerator(
        dataset_name=dataset_name,
        first_dst_id=first_dst,
        last_dst_id=last_dst,
        num_neg_e=neg_per_pos,
        strategy=strategy,
        rnd_seed=seed,
        historical_data=history,
    )

    # generate the evaluation sets, validation first then test
    out_dir = "./"
    for split in ("val", "test"):
        start = timeit.default_timer()
        print(f"INFO: Start generating negative samples: {split} --- {strategy}")
        sampler.generate_negative_samples(
            data=splits[split], split_mode=split, partial_path=out_dir
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {timeit.default_timer() - start: .4f}"
        )
# Entry point guard: run main() only when executed as a script, not on import.
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tgbl_uci/tgbl_uci.py
================================================
import csv

# Strip the leading index column from the raw ml_uci export and save the
# remaining columns as the tgbl-uci edgelist.
with open('ml_uci.csv', 'r', newline='\n') as infile, \
        open('tgbl-uci_edgelist.csv', 'w', newline='\n') as outfile:
    writer = csv.writer(outfile)
    for record in csv.reader(infile):
        writer.writerow(record[1:])
================================================
FILE: tgb/datasets/thgl_forum/merge_files.py
================================================
import csv
from tqdm import tqdm
from os import listdir
import glob
def find_filenames(path_to_dir):
    r"""
    Return the names of all entries in a directory.

    Parameters:
        path_to_dir (str): path to the directory
    """
    return listdir(path_to_dir)
def read_edgelist(fname, outfname, write_header=False):
    """
    Append a whitespace-separated edgelist to a csv file.

    Input lines look like `u v t edge_id`, e.g.
    `3746738 1637382 1551398391 31534079835`; lines with fewer than four
    fields are skipped. Output columns are ts, src, dst, edge_id.

    Parameters:
        fname (str): path to the edgelist
        outfname (str): path to the output file (opened in append mode)
        write_header (bool): whether to write the csv header first
    """
    with open(fname, "r") as edge_file:
        raw_lines = edge_file.readlines()
    with open(outfname, "a") as outf:
        writer = csv.writer(outf)
        if write_header:
            writer.writerow(["ts", "src", "dst", "edge_id"])
        for raw in raw_lines:
            fields = raw.split()
            if len(fields) < 4:
                continue
            u, v, t, edge_id = fields[0], fields[1], fields[2], fields[3]
            writer.writerow([t, u, v, edge_id])
def read_nodeattr(fname, outfname, write_header=False):
    """
    Append a whitespace-separated comment-attribute file to a csv file.

    Each input line carries (in order): the comment's edge id, Reddit's
    identifier of the comment, Reddit's identifier of the parent (the post
    the comment replies to), Reddit's identifier of the submission, the
    subreddit name, the number of characters in the body, the number of
    words in the body, the comment score, and an edited flag. The
    submission identifier (field 3) is not written to the output.

    Output columns:
    edge_id, reddit_id, reddit_parent_id, subreddit, num_characters,
    num_words, score, edited_flag

    Also prints the max word count and the min/max score seen in this file.

    Parameters:
        fname (str): path to the attribute file
        outfname (str): path to the output file (opened in append mode)
        write_header (bool): whether to write the csv header first
    """
    with open(fname, "r") as attr_file:
        lines = attr_file.readlines()
    max_words = 0
    max_score = 0
    min_score = 100000000
    with open(outfname, "a") as outf:
        write = csv.writer(outf)
        if write_header:
            write.writerow([
                "edge_id",
                "reddit_id",
                "reddit_parent_id",
                "subreddit",
                "num_characters",
                "num_words",
                "score",
                "edited_flag",
            ])
        for line in lines:
            fields = line.split()
            if len(fields) < 4:
                continue
            edge_id = fields[0]
            reddit_id = fields[1]
            reddit_parent_id = fields[2]
            # fields[3] is the submission id, intentionally skipped
            subreddit = fields[4]
            num_characters = fields[5]
            num_words = fields[6]
            if int(num_words) > max_words:
                max_words = int(num_words)
            score = fields[7]
            if int(score) < min_score:
                min_score = int(score)
            # BUGFIX: max_score was printed but never updated in the original
            if int(score) > max_score:
                max_score = int(score)
            # BUGFIX: the original called .strip("/n") — that strips the
            # character set {'/', 'n'}, not a newline ("\n" was intended).
            # split() has already removed trailing whitespace, so no
            # stripping is needed at all.
            edited_flag = fields[8]
            write.writerow(
                [edge_id, reddit_id, reddit_parent_id, subreddit, num_characters, num_words, score, edited_flag]
            )
    print("max # words", max_words)
    print("min score", min_score)
    print("max score", max_score)
def combine_edgelist_edgefeat(edgefname, featfname, outname):
    """
    Join an edge file with its attribute file line-by-line (the two files
    must list the same edges in the same order) and write one combined csv.

    Rows with a -1 timestamp, source, or destination are dropped (these mark
    deleted/missing users). If the edge ids of the two files ever disagree,
    processing stops with a diagnostic message. Prints the number of rows
    written and the counts of dropped rows.

    Parameters:
        edgefname (str): csv with columns ts, src, dst, edge_id
        featfname (str): csv with columns edge_id, reddit_id,
            reddit_parent_id, subreddit, num_characters, num_words, score,
            edited_flag
        outname (str): path of the combined output csv
    """
    missing_ts = 0
    missing_src = 0
    missing_dst = 0
    line_idx = 0
    # Context managers close every handle on all exit paths; the original
    # leaked both input handles and an extra open() used only to count lines.
    with open(outname, "w") as outf, \
            open(edgefname, "r") as edgelist, \
            open(featfname, "r") as edgefeat:
        write = csv.writer(outf)
        # BUGFIX: the original header listed reddit_id/reddit_parent_id before
        # subreddit, but the data rows are written subreddit-first; the header
        # now matches the data (and the downstream loader, which reads
        # row[3] as the subreddit).
        fields = ["ts", "src", "dst", "subreddit", "reddit_id", "reddit_parent_id", "num_words", "score"]
        write.writerow(fields)
        # skip both header rows
        edgelist.readline()
        edgefeat.readline()
        while True:
            # 'ts', 'src', 'dst', 'edge_id'
            edge_line = edgelist.readline().split(",")
            if len(edge_line) < 4:
                break  # EOF (readline returned "") or malformed tail
            ts = int(edge_line[0])
            src = int(edge_line[1])
            dst = int(edge_line[2])
            edge_id = int(edge_line[3])
            # read the matching feature line BEFORE any skip so the two
            # files stay in lock-step
            feat_line = edgefeat.readline().split(",")
            edge_id_feat = int(feat_line[0])
            reddit_id = feat_line[1]
            reddit_parent_id = feat_line[2]
            subreddit = feat_line[3]
            num_words = int(feat_line[5])
            score = int(feat_line[6])
            #! drop rows whose ts, src, or dst is -1 (deleted/missing user)
            if ts == -1:
                missing_ts += 1
                continue
            if src == -1:
                missing_src += 1
                continue
            if dst == -1:
                missing_dst += 1
                continue
            if edge_id != edge_id_feat:
                print("edge_id != edge_id_feat")
                print(edge_id)
                print(edge_id_feat)
                break
            #? ts: int, src: int (user_id), dst: int (user_id), subreddit: str, reddit_id: str (comment_id), reddit_parent_id: str (post_id), num_words: int, score: int
            write.writerow([ts, src, dst, subreddit, reddit_id, reddit_parent_id, num_words, score])
            line_idx += 1
    print("processed", line_idx, "lines")
    print("there are lines", missing_ts, " missing timestamps")
    print("there are lines", missing_src, " missing src")
    print("there are lines", missing_dst, " missing dst")
def main():
    """
    Build the combined reddit edgelist: merge the monthly edge files,
    merge the monthly attribute files, then join the two on edge id.
    """
    # #! unzip all xz files by $ unxz *.xz
    edge_dir = "edge_files/"  # "raw/raw_2005_2010/" #"raw/raw_2013_2014/"
    edgefname = "reddit_edgefile_2014_01.csv"  # "redditcomments_edgelist_2013_2014.csv"
    # only the first file gets the csv header
    for idx, fname in enumerate(tqdm(find_filenames(edge_dir))):
        print("processing, ", fname)
        read_edgelist(edge_dir + fname, edgefname, write_header=(idx == 0))

    # #! extract the node attributes
    attr_dir = "attribute_files/"
    featfname = "reddit_attribute_2014_01.csv"
    for idx, fname in enumerate(tqdm(find_filenames(attr_dir))):
        print("processing, ", fname)
        read_nodeattr(attr_dir + fname, featfname, write_header=(idx == 0))

    #! combine edgelist and edge feat file check if the edge_id matches
    # edgefname = "reddit_edgefile_2019_01_03.csv"
    # featfname = "reddit_attribute_2019_01_03.csv"
    outname = "reddit_edgelist.csv"
    combine_edgelist_edgefeat(edgefname, featfname, outname)
# Entry point guard: run main() only when executed as a script, not on import.
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/thgl_forum/thgl-forum.py
================================================
"""
Data source:
# https://surfdrive.surf.nl/files/index.php/s/M09RDerAMZrQy8q#editor
# https://dl.acm.org/doi/abs/10.1145/3487553.3524699
also see here: https://arxiv.org/pdf/1803.03697
# Temporal Social Network Dataset of Reddit
### Dataset to accompany “A large-scale temporal analysis of user lifespan durability on the Reddit social media platform” (WWW 2022).
## Overview
This dataset consists of more than 6.7 billion Reddit comment interactions made from the beginning of Reddit in 2005 until the end of 2019.
### Nodes
Nodes in the network represent users who posted at least one comment or one submission until the end of 2019 and have not deleted their accounts by the time of the data ingestion.
Each user is assigned a unique identifier starting from 0, and -1 is the identifier of the node representing deleted users. The `nodes` file maintains the node-identifier assignment to each user’s username.
### Edges
For each month of our data, we maintain two separate files, an edge file that consists of temporal edges data and an attribute file that consists of attributes of each interaction. All these files are in a tab-separated format. The compressed edge files and compressed attribute files are available in the `edges` and the `attributes` directory, respectively. The name of files indicates the timeframe they belong to.
Each line in an edge file corresponds to a comment and includes:
- comment’s author
- author of the parent (the post that the comment is replied to)
- comment’s creation time
- comment’s edge id
Each line in an attribute file corresponds to the line with the same line number in the corresponding edge file and includes:
- comment’s edge id
- Reddit’s identifier of the comment
- Reddit’s identifier of the parent (the post that the comment is replied to)
- Reddit’s identifier of the submission that the comment is in
- name of the subreddit that the comment is in
- number of characters in the comment’s body
- number of words in the comment’s body
- score of the comment
- a flag indicating if the comment has been edited
### Stats
Size (compressed): 125GB
Size (uncompressed): 652GB
Number of nodes: 62,402,844
Number of edges: 6,728,759,080
### Notes
Reddit banned the subreddit `/r/Incels` in November of 2017, and its data is no longer available via the Reddit API. This has resulted in the loss of score data for 119,111 comments made in October and November of 2017 in this subreddit. The affected entries have a null value as their score.
## Citation
If you want to reuse this dataset, you can reference it as follows:
A. Nadiri and F.W. Takes, A large-scale temporal analysis of user lifespan durability on the Reddit social media platform, in Proceedings of the 28th ACM International Web Conference (TheWebConf) Workshops, 2022.
## Online repository
The dataset is available for download at [**LINK**](https://surfdrive.surf.nl/files/index.php/s/M09RDerAMZrQy8q)
## Acknowledgments
The dataset is constructed using data provided by [The Pushshift Reddit Dataset](https://ojs.aaai.org/index.php/ICWSM/article/view/7347)
"""
"""
ideas for a temporal heterogeneous graph in reddit data:
node types:
1. user
2. subreddit
edge types
1. user post in subreddit (top level)
2. user replies to another user
3. user replies in subreddit
# node types:
# 1. user
# 2. subreddit
# 3. comment
# edge types
# 1. user makes comment in subreddit (top level comment)
# 2. user replies to comment in subreddit (comments that has a parent)
# 2. comment is child of comment (comments that has a parent)
# 3. comment belongs to subreddit
"""
import csv
from tgb.utils.utils import save_pkl, load_pkl
def load_csv_raw(fname):
    """
    Read the merged reddit csv and build the heterogeneous temporal graph
    structures in memory, printing word/score extrema.

    Input columns (header row expected):
    ts, src, dst, subreddit, reddit_id, reddit_parent_id, num_words, score

    Relation types: 0 = user replies to user, 1 = user replies in subreddit.
    Node types: 0 = user, 1 = subreddit.

    Returns:
        out_dict: {ts: {(src_id, dst_id, relation_type): (num_words, score)}}
        num_lines: number of data rows read
        node_dict: raw identifier -> integer node id
        node_type_dict: node id -> node type (0 user, 1 subreddit)
        reddit_deg_dict: subreddit name -> occurrence count
        node_deg_dict: raw user identifier -> degree count
    """
    node_dict = {}
    node_type_dict = {}
    reddit_deg_dict = {}
    node_deg_dict = {}
    out_dict = {}
    num_lines = 0
    max_words = 0
    min_words = 10000000
    max_score = 0
    min_score = 1000000

    def _node_id(key, ntype):
        # assign the next integer id on first sight and record its type
        if key not in node_dict:
            node_dict[key] = len(node_dict)
            node_type_dict[node_dict[key]] = ntype
        return node_dict[key]

    with open(fname, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        #* ts, src, dst, subreddit, reddit_id, reddit_parent_id, num_words, score
        #? 1388534400,32183137,51851117,AskReddit,t1_ceefsvy,t3_1u4kbf,32,1
        header = True
        for row in reader:
            if header:
                header = False
                continue
            ts = int(row[0])
            src, dst, subreddit = row[1], row[2], row[3]
            src_id = _node_id(src, 0)
            dst_id = _node_id(dst, 0)
            node_deg_dict[src] = node_deg_dict.get(src, 0) + 1
            node_deg_dict[dst] = node_deg_dict.get(dst, 0) + 1
            sub_id = _node_id(subreddit, 1)
            reddit_deg_dict[subreddit] = reddit_deg_dict.get(subreddit, 0) + 1
            num_words = int(row[6])
            max_words = max(max_words, num_words)
            min_words = min(min_words, num_words)
            score = int(row[7])
            max_score = max(max_score, score)
            min_score = min(min_score, score)
            # one user->user edge and one user->subreddit edge per comment;
            # duplicate (src, dst, relation) keys within a ts are overwritten
            edges = out_dict.setdefault(ts, {})
            edges[(src_id, dst_id, 0)] = (num_words, score)
            edges[(src_id, sub_id, 1)] = (num_words, score)
            num_lines += 1
    print("max words: ", max_words)
    print("min words: ", min_words)
    print("max score: ", max_score)
    print("min score: ", min_score)
    return out_dict, num_lines, node_dict, node_type_dict, reddit_deg_dict, node_deg_dict
def load_csv_filtered_node(fname, low_deg_dict):
    """
    Load the raw csv file, dropping every edge whose user endpoints
    appear in ``low_deg_dict``.

    Columns: ts, src, dst, subreddit, reddit_id, reddit_parent_id, num_words, score

    Relation types: 0 = user replies to user, 1 = user replies to subreddit.
    Node types: 0 = user, 1 = subreddit.

    Returns (out_dict, num_lines, node_dict, node_type_dict) where out_dict
    maps ts -> {(src_id, dst_id, rel): (num_words, score)}.
    """
    out_dict = {}
    num_lines = 0
    node_dict = {}       # raw node name -> integer id (assigned in first-seen order)
    node_type_dict = {}  # integer id -> node type (0 user, 1 subreddit)

    def _node_id(name, ntype):
        # Register `name` on first sight and record its type.
        if name not in node_dict:
            node_dict[name] = len(node_dict)
            node_type_dict[node_dict[name]] = ntype
        return node_dict[name]

    with open(fname, 'r') as f:
        rows = csv.reader(f, delimiter=',')
        next(rows, None)  # skip the header row
        #* ts, src, dst, subreddit, reddit_id, reddit_parent_id, num_words, score
        #? 1388534400,32183137,51851117,AskReddit,t1_ceefsvy,t3_1u4kbf,32,1
        for row in rows:
            ts = int(row[0])
            src, dst = row[1], row[2]
            #* filter low degree nodes
            if src in low_deg_dict or dst in low_deg_dict:
                continue
            src_id = _node_id(src, 0)
            dst_id = _node_id(dst, 0)
            sub_id = _node_id(row[3], 1)
            feat = (int(row[6]), int(row[7]))  # (num_words, score)
            bucket = out_dict.setdefault(ts, {})
            bucket[(src_id, dst_id, 0)] = feat
            bucket[(src_id, sub_id, 1)] = feat
            num_lines += 1
    return out_dict, num_lines, node_dict, node_type_dict
def writeNodeType(node_type_dict, outname):
    r"""
    Dump the node-id -> node-type mapping to ``outname`` as a csv with
    header ``node_id,type``, in the dict's iteration order.
    """
    with open(outname, 'w') as f:
        out = csv.writer(f, delimiter=',')
        out.writerow(['node_id', 'type'])
        out.writerows([node_id, ntype] for node_id, ntype in node_type_dict.items())
def write2csv(outname, out_dict):
    """Write the ts-keyed edge dict to ``outname``, sorted by timestamp.

    Header: ts, src, dst, relation_type, num_words, score.
    """
    with open(outname, 'w') as f:
        out = csv.writer(f, delimiter=',')
        out.writerow(['ts', 'src', 'dst', 'relation_type', 'num_words', 'score'])
        for ts in sorted(out_dict):
            for (head, tail, relation_type), (num_words, score) in out_dict[ts].items():
                out.writerow([ts, head, tail, relation_type, num_words, score])
def writeNodeIDMapping(node_dict, outname):
    r"""
    Dump the raw-node-name -> integer-id mapping to ``outname`` as a csv
    with header ``node_name,node_id``.
    """
    with open(outname, 'w') as f:
        out = csv.writer(f, delimiter=',')
        out.writerow(['node_name', 'node_id'])
        out.writerows([name, node_id] for name, node_id in node_dict.items())
def node_deg_filter(node_deg_dict):
    """
    Report how many nodes fall below each of the degree thresholds
    10 / 100 / 1000 (diagnostic printout only; returns None).
    """
    thresholds = (10, 100, 1000)
    counts = dict.fromkeys(thresholds, 0)
    for deg in node_deg_dict.values():
        for t in thresholds:
            if deg < t:
                counts[t] += 1
    for t in thresholds:
        print ("nodes with degree less than " + str(t) + ": ", counts[t])
def find_low_degree_nodes(node_deg_dict, threshold=10):
    """
    Return ``{node: 1}`` for every node whose degree is strictly below
    ``threshold`` (a membership set encoded as a dict, matching the
    ``low_deg_dict`` consumed by load_csv_filtered_node).
    """
    return {node: 1 for node, deg in node_deg_dict.items() if deg < threshold}
def main():
    # Pipeline: count node degrees over the raw edgelist, then rebuild the
    # edgelist with low-degree nodes filtered out, and dump the node-type /
    # node-id mappings alongside it.
    fname = "reddit_edgelist.csv"
    _, _, _, _, _, node_deg_dict = load_csv_raw(fname)
    # One-off helpers used when picking the degree threshold:
    # node_deg_filter(node_deg_dict)          # checking node degree
    # node_deg_filter(reddit_deg_dict)        # checking reddit degree
    # low_degree_nodes = find_low_degree_nodes(node_deg_dict, threshold=100)
    # save_pkl(low_degree_nodes, 'low_degree_nodes.pkl')
    low_degree_nodes = load_pkl('low_degree_nodes.pkl')
    filtered = load_csv_filtered_node(fname, low_degree_nodes)
    out_dict, num_lines, node_dict, node_type_dict = filtered
    writeNodeType(node_type_dict, 'thgl-forum_nodetype.csv')
    writeNodeIDMapping(node_dict, 'thgl-forum_nodeIDmapping.csv')
    write2csv('thgl-forum_edgelist.csv', out_dict)


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/thgl_forum/thgl_forum_ns_gen.py
================================================
import time
from tgb.linkproppred.thg_negative_generator import THGNegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate negative edges for the validation and test phases of the
    ``thgl-forum`` dataset and save them under ``partial_path``.

    The val and test passes were previously duplicated verbatim; they only
    differ in the split mask used, so they are now driven by one loop. The
    unused ``train`` split (built but never consumed) is no longer created.
    """
    print("*** Negative Sample Generation ***")

    # setting the required parameters
    num_neg_e_per_pos = 100  # -1
    neg_sample_strategy = "node-type-filtered"
    rnd_seed = 42

    name = "thgl-forum"
    dataset = PyGLinkPropPredDataset(name=name, root="datasets")
    data = dataset.get_TemporalData()
    data_splits = {
        "val": data[dataset.val_mask],
        "test": data[dataset.test_mask],
    }

    # node ids span both endpoints of every edge
    min_node_idx = min(int(data.src.min()), int(data.dst.min()))
    max_node_idx = max(int(data.src.max()), int(data.dst.max()))

    neg_sampler = THGNegativeEdgeGenerator(
        dataset_name=name,
        first_node_id=min_node_idx,
        last_node_id=max_node_idx,
        node_type=dataset.node_type,
        num_neg_e=num_neg_e_per_pos,
        strategy=neg_sample_strategy,
        rnd_seed=rnd_seed,
        edge_data=data,
    )

    # generate the evaluation sets
    partial_path = "."
    for split_mode in ("val", "test"):
        start_time = time.time()
        print(
            f"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}"
        )
        neg_sampler.generate_negative_samples(
            pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}"
        )


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/thgl_github/2024_01/github_extract.py
================================================
import json
from datetime import datetime
import glob
import gzip
import csv
"""
go to https://www.gharchive.org/
wget https://data.gharchive.org/2024-01-{01..31}-{0..23}.json.gz
Creates (src, edge_type, dst, time) edges from the GitHub archive JSON file.
Using the rules from https://arxiv.org/pdf/2007.01231 (page 11)
The parser creates 18 rules that are in the GITHUB-SE-1Y-Repo dataset. I wrote the meaning of the rules and sources and destination types here.
"""
# Mapping from internal rule names to the relation labels of the
# GITHUB-SE-1Y-Repo dataset (18 rules). NOTE(review): label abbreviations
# appear to follow the paper's table — U=user, I=issue, IC=issue comment,
# P/PR=pull request, PRC=PR review comment, R=repo — confirm against the
# paper before relying on them.
rels = {
    "IC_Created_IC_I": "IC_AO_C_I",
    "IC_Created_U_IC": "U_SO_C_IC",
    "I_Opened_U_I": "U_SE_O_I",
    "I_Opened_I_R": "I_AO_O_R",
    "I_Closed_U_I": "U_SE_C_I",
    "I_Closed_I_R": "I_AO_C_R",
    "I_Reopened_U_I": "U_SE_RO_I",
    "I_Reopened_I_R": "I_AO_RO_R",
    "PR_Opened_U_PR": "U_SO_O_P",
    "PR_Opened_PR_R": "P_AO_O_R",
    "PR_Closed_U_PR": "U_SO_C_P",
    "PR_Closed_PR_R": "P_AO_C_R",
    "PR_Reopened_U_PR": "U_SO_R_P",
    "PR_Reopened_PR_R": "P_AO_R_R",
    "PRRC_Created_U_PRC": "U_SO_C_PRC",
    "PRRC_Created_PRC_PR": "PRC_AO_C_P",
    "Forked_R_R": "R_FO_R",
    "AddMember_U_R": "U_CO_A_R",
}
# Templates producing typed, globally-unique node names from numeric ids.
issue_comment_format = "/issue_comment/{}"
issue_format = "/issue/{}"
user_format = "/user/{}"
repo_format = "/repo/{}"
pull_request_format = "/pr/{}"
pull_request_review_comment_format = "/pr_review_comment/{}"
def str_to_timestamp(time_str):
    """Convert a GH Archive timestamp ("%Y-%m-%dT%H:%M:%SZ") to a Unix epoch int.

    The trailing "Z" marks UTC, but ``datetime.strptime`` yields a *naive*
    datetime which ``.timestamp()`` interprets in the machine's local
    timezone, silently shifting every edge time by the local UTC offset.
    Attach UTC explicitly so the conversion is machine-independent.
    """
    from datetime import timezone  # local import: the module only imports `datetime`
    dt = datetime.strptime(time_str, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
    return int(dt.timestamp())
def parse_issue_comment_events(event):
    """Parse an IssueCommentEvent into [comment->issue, user->comment] edges.

    Returns [] for non-"created" actions or malformed payloads.
    """
    try:
        if "action" not in event["payload"]:
            return []
        if event["payload"]["action"] == "created":
            issue_comment_id = event["payload"]["comment"]["id"]
            issue_id = event["payload"]["issue"]["id"]
            user_id = event["actor"]["id"]
            created_at = str_to_timestamp(event["created_at"])
            ici_event = [
                issue_comment_format.format(issue_comment_id),
                rels["IC_Created_IC_I"],
                issue_format.format(issue_id),
                created_at,
            ]
            uic_event = [
                user_format.format(user_id),
                rels["IC_Created_U_IC"],
                issue_comment_format.format(issue_comment_id),
                created_at,
            ]
            return [ici_event, uic_event]
        return []
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
def parse_issue_event(event):
    """Parse an IssuesEvent into [user->issue, issue->repo] edges.

    Only "opened"/"closed"/"reopened" actions produce edges; anything else
    (or a malformed payload) yields [].
    """
    try:
        issue_id = event["payload"]["issue"]["id"]
        user_id = event["actor"]["id"]
        repo_id = event["repo"]["id"]
        created_at = str_to_timestamp(event["created_at"])
        action_map = {
            "opened": ("I_Opened_U_I", "I_Opened_I_R"),
            "closed": ("I_Closed_U_I", "I_Closed_I_R"),
            "reopened": ("I_Reopened_U_I", "I_Reopened_I_R"),
        }
        for action, event_rels in action_map.items():
            if event["payload"]["action"] == action:
                ui_event = [
                    user_format.format(user_id),
                    rels[event_rels[0]],
                    issue_format.format(issue_id),
                    created_at,
                ]
                ir_event = [
                    issue_format.format(issue_id),
                    rels[event_rels[1]],
                    repo_format.format(repo_id),
                    created_at,
                ]
                return [ui_event, ir_event]
        return []
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
def parse_pull_request_event(event):
    """Parse a PullRequestEvent into [user->PR, PR->repo] edges.

    Only "opened"/"closed"/"reopened" actions produce edges; anything else
    (or a malformed payload) yields [].
    """
    try:
        pull_request_id = event["payload"]["pull_request"]["id"]
        user_id = event["actor"]["id"]
        repo_id = event["repo"]["id"]
        created_at = str_to_timestamp(event["created_at"])
        action_map = {
            "opened": ("PR_Opened_U_PR", "PR_Opened_PR_R"),
            "closed": ("PR_Closed_U_PR", "PR_Closed_PR_R"),
            "reopened": ("PR_Reopened_U_PR", "PR_Reopened_PR_R"),
        }
        for action, event_rels in action_map.items():
            if event["payload"]["action"] == action:
                upr_event = [
                    user_format.format(user_id),
                    rels[event_rels[0]],
                    pull_request_format.format(pull_request_id),
                    created_at,
                ]
                prr_event = [
                    pull_request_format.format(pull_request_id),
                    rels[event_rels[1]],
                    repo_format.format(repo_id),
                    created_at,
                ]
                return [upr_event, prr_event]
        return []
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
def parse_pull_request_review_comment_event(event):
    """Parse a PullRequestReviewCommentEvent into [user->review-comment,
    review-comment->PR] edges.

    Returns [] for non-"created" actions or malformed payloads.
    """
    try:
        pull_request_review_comment_id = event["payload"]["comment"]["id"]
        pull_request_id = event["payload"]["pull_request"]["id"]
        user_id = event["actor"]["id"]
        created_at = str_to_timestamp(event["created_at"])
        if event["payload"]["action"] == "created":
            uprc_event = [
                user_format.format(user_id),
                rels["PRRC_Created_U_PRC"],
                pull_request_review_comment_format.format(pull_request_review_comment_id),
                created_at,
            ]
            prcpr_event = [
                pull_request_review_comment_format.format(pull_request_review_comment_id),
                rels["PRRC_Created_PRC_PR"],
                pull_request_format.format(pull_request_id),
                created_at,
            ]
            return [uprc_event, prcpr_event]
        return []
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
def parse_fork_event(event):
    """Parse a ForkEvent into one [forkee-repo -> forked-repo] edge.

    Returns [] on malformed payloads.
    """
    try:
        forkee_repo_id = event["payload"]["forkee"]["id"]
        forked_repo_id = event["repo"]["id"]
        created_at = str_to_timestamp(event["created_at"])
        return [
            [
                repo_format.format(forkee_repo_id),
                rels["Forked_R_R"],
                repo_format.format(forked_repo_id),
                created_at,
            ]
        ]
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
def parse_member_event(event):
    """Parse a MemberEvent into one [added-user -> repo] edge.

    Returns [] on malformed payloads.
    """
    try:
        user_id = event["payload"]["member"]["id"]
        repo_id = event["repo"]["id"]
        created_at = str_to_timestamp(event["created_at"])
        return [
            [
                user_format.format(user_id),
                rels["AddMember_U_R"],
                repo_format.format(repo_id),
                created_at,
            ]
        ]
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
# Dispatch table: GH Archive event "type" -> parser returning a list of
# [head, relation, tail, unix_ts] edges (empty when not applicable).
event_handler_dict = {
    "IssueCommentEvent": parse_issue_comment_events,
    "IssuesEvent": parse_issue_event,
    "PullRequestEvent": parse_pull_request_event,
    "PullRequestReviewCommentEvent": parse_pull_request_review_comment_event,
    "ForkEvent": parse_fork_event,
    "MemberEvent": parse_member_event,
}
def parse_event(event):
    """Dispatch ``event`` to its type-specific parser; unknown types yield []."""
    handler = event_handler_dict.get(event["type"])
    if handler is None:
        return []
    return handler(event)
def parse_file(filename):
    """Parse one gzipped GH Archive JSON-lines file.

    Returns {ts: {(head, tail, relation): count}}, aggregating duplicate
    edges within the same timestamp.
    """
    output_dict = {}
    # Bug fix: the counter started at 1, so "Parsed N events" over-reported
    # the number of distinct edges by one.
    num_edge = 0
    with gzip.open(filename, 'r') as f:
        for line in f:
            for edge in parse_event(json.loads(line)):
                #? ['/user/41898282', 'U_SE_O_I', '/issue/2061196208', 1704085558]
                head, rel, tail, ts = edge[0], edge[1], edge[2], int(edge[3])
                bucket = output_dict.setdefault(ts, {})
                key = (head, tail, rel)
                if key in bucket:
                    bucket[key] += 1
                else:
                    bucket[key] = 1
                    num_edge += 1  # count distinct (ts, head, tail, rel) entries
    print("Parsed {} events".format(num_edge))
    return output_dict
def write2csv(outname, out_dict):
    """Append every (ts, head, tail, relation_type) edge to ``outname``,
    ordered by timestamp. No header is written so repeated calls can
    accumulate edges from many archive files into one csv.
    """
    with open(outname, 'a') as f:
        out = csv.writer(f, delimiter=',')
        for ts in sorted(out_dict):
            for head, tail, relation_type in out_dict[ts]:
                out.writerow([ts, head, tail, relation_type])
def main():
    # Process each hourly archive in the working directory and flush its
    # edges to the csv immediately, keeping memory bounded per file.
    outname = "github_01_2024.csv"
    for archive in glob.glob("*.json.gz"):
        print ("processing,", archive)
        write2csv(outname, parse_file(archive))


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/thgl_github/2024_02/github_extract.py
================================================
import json
from datetime import datetime
import glob
import gzip
import csv
"""
go to https://www.gharchive.org/
wget https://data.gharchive.org/2024-01-{01..31}-{0..23}.json.gz
Creates (src, edge_type, dst, time) edges from the GitHub archive JSON file.
Using the rules from https://arxiv.org/pdf/2007.01231 (page 11)
The parser creates 18 rules that are in the GITHUB-SE-1Y-Repo dataset. I wrote the meaning of the rules and sources and destination types here.
"""
# Mapping from internal rule names to the relation labels of the
# GITHUB-SE-1Y-Repo dataset (18 rules). NOTE(review): label abbreviations
# appear to follow the paper's table — U=user, I=issue, IC=issue comment,
# P/PR=pull request, PRC=PR review comment, R=repo — confirm against the
# paper before relying on them.
rels = {
    "IC_Created_IC_I": "IC_AO_C_I",
    "IC_Created_U_IC": "U_SO_C_IC",
    "I_Opened_U_I": "U_SE_O_I",
    "I_Opened_I_R": "I_AO_O_R",
    "I_Closed_U_I": "U_SE_C_I",
    "I_Closed_I_R": "I_AO_C_R",
    "I_Reopened_U_I": "U_SE_RO_I",
    "I_Reopened_I_R": "I_AO_RO_R",
    "PR_Opened_U_PR": "U_SO_O_P",
    "PR_Opened_PR_R": "P_AO_O_R",
    "PR_Closed_U_PR": "U_SO_C_P",
    "PR_Closed_PR_R": "P_AO_C_R",
    "PR_Reopened_U_PR": "U_SO_R_P",
    "PR_Reopened_PR_R": "P_AO_R_R",
    "PRRC_Created_U_PRC": "U_SO_C_PRC",
    "PRRC_Created_PRC_PR": "PRC_AO_C_P",
    "Forked_R_R": "R_FO_R",
    "AddMember_U_R": "U_CO_A_R",
}
# Templates producing typed, globally-unique node names from numeric ids.
issue_comment_format = "/issue_comment/{}"
issue_format = "/issue/{}"
user_format = "/user/{}"
repo_format = "/repo/{}"
pull_request_format = "/pr/{}"
pull_request_review_comment_format = "/pr_review_comment/{}"
def str_to_timestamp(time_str):
    """Convert a GH Archive timestamp ("%Y-%m-%dT%H:%M:%SZ") to a Unix epoch int.

    The trailing "Z" marks UTC, but ``datetime.strptime`` yields a *naive*
    datetime which ``.timestamp()`` interprets in the machine's local
    timezone, silently shifting every edge time by the local UTC offset.
    Attach UTC explicitly so the conversion is machine-independent.
    """
    from datetime import timezone  # local import: the module only imports `datetime`
    dt = datetime.strptime(time_str, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
    return int(dt.timestamp())
def parse_issue_comment_events(event):
    """Parse an IssueCommentEvent into [comment->issue, user->comment] edges.

    Returns [] for non-"created" actions or malformed payloads.
    """
    try:
        if "action" not in event["payload"]:
            return []
        if event["payload"]["action"] == "created":
            issue_comment_id = event["payload"]["comment"]["id"]
            issue_id = event["payload"]["issue"]["id"]
            user_id = event["actor"]["id"]
            created_at = str_to_timestamp(event["created_at"])
            ici_event = [
                issue_comment_format.format(issue_comment_id),
                rels["IC_Created_IC_I"],
                issue_format.format(issue_id),
                created_at,
            ]
            uic_event = [
                user_format.format(user_id),
                rels["IC_Created_U_IC"],
                issue_comment_format.format(issue_comment_id),
                created_at,
            ]
            return [ici_event, uic_event]
        return []
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
def parse_issue_event(event):
    """Parse an IssuesEvent into [user->issue, issue->repo] edges.

    Only "opened"/"closed"/"reopened" actions produce edges; anything else
    (or a malformed payload) yields [].
    """
    try:
        issue_id = event["payload"]["issue"]["id"]
        user_id = event["actor"]["id"]
        repo_id = event["repo"]["id"]
        created_at = str_to_timestamp(event["created_at"])
        action_map = {
            "opened": ("I_Opened_U_I", "I_Opened_I_R"),
            "closed": ("I_Closed_U_I", "I_Closed_I_R"),
            "reopened": ("I_Reopened_U_I", "I_Reopened_I_R"),
        }
        for action, event_rels in action_map.items():
            if event["payload"]["action"] == action:
                ui_event = [
                    user_format.format(user_id),
                    rels[event_rels[0]],
                    issue_format.format(issue_id),
                    created_at,
                ]
                ir_event = [
                    issue_format.format(issue_id),
                    rels[event_rels[1]],
                    repo_format.format(repo_id),
                    created_at,
                ]
                return [ui_event, ir_event]
        return []
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
def parse_pull_request_event(event):
    """Parse a PullRequestEvent into [user->PR, PR->repo] edges.

    Only "opened"/"closed"/"reopened" actions produce edges; anything else
    (or a malformed payload) yields [].
    """
    try:
        pull_request_id = event["payload"]["pull_request"]["id"]
        user_id = event["actor"]["id"]
        repo_id = event["repo"]["id"]
        created_at = str_to_timestamp(event["created_at"])
        action_map = {
            "opened": ("PR_Opened_U_PR", "PR_Opened_PR_R"),
            "closed": ("PR_Closed_U_PR", "PR_Closed_PR_R"),
            "reopened": ("PR_Reopened_U_PR", "PR_Reopened_PR_R"),
        }
        for action, event_rels in action_map.items():
            if event["payload"]["action"] == action:
                upr_event = [
                    user_format.format(user_id),
                    rels[event_rels[0]],
                    pull_request_format.format(pull_request_id),
                    created_at,
                ]
                prr_event = [
                    pull_request_format.format(pull_request_id),
                    rels[event_rels[1]],
                    repo_format.format(repo_id),
                    created_at,
                ]
                return [upr_event, prr_event]
        return []
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
def parse_pull_request_review_comment_event(event):
    """Parse a PullRequestReviewCommentEvent into [user->review-comment,
    review-comment->PR] edges.

    Returns [] for non-"created" actions or malformed payloads.
    """
    try:
        pull_request_review_comment_id = event["payload"]["comment"]["id"]
        pull_request_id = event["payload"]["pull_request"]["id"]
        user_id = event["actor"]["id"]
        created_at = str_to_timestamp(event["created_at"])
        if event["payload"]["action"] == "created":
            uprc_event = [
                user_format.format(user_id),
                rels["PRRC_Created_U_PRC"],
                pull_request_review_comment_format.format(pull_request_review_comment_id),
                created_at,
            ]
            prcpr_event = [
                pull_request_review_comment_format.format(pull_request_review_comment_id),
                rels["PRRC_Created_PRC_PR"],
                pull_request_format.format(pull_request_id),
                created_at,
            ]
            return [uprc_event, prcpr_event]
        return []
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
def parse_fork_event(event):
    """Parse a ForkEvent into one [forkee-repo -> forked-repo] edge.

    Returns [] on malformed payloads.
    """
    try:
        forkee_repo_id = event["payload"]["forkee"]["id"]
        forked_repo_id = event["repo"]["id"]
        created_at = str_to_timestamp(event["created_at"])
        return [
            [
                repo_format.format(forkee_repo_id),
                rels["Forked_R_R"],
                repo_format.format(forked_repo_id),
                created_at,
            ]
        ]
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
def parse_member_event(event):
    """Parse a MemberEvent into one [added-user -> repo] edge.

    Returns [] on malformed payloads.
    """
    try:
        user_id = event["payload"]["member"]["id"]
        repo_id = event["repo"]["id"]
        created_at = str_to_timestamp(event["created_at"])
        return [
            [
                user_format.format(user_id),
                rels["AddMember_U_R"],
                repo_format.format(repo_id),
                created_at,
            ]
        ]
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
# Dispatch table: GH Archive event "type" -> parser returning a list of
# [head, relation, tail, unix_ts] edges (empty when not applicable).
event_handler_dict = {
    "IssueCommentEvent": parse_issue_comment_events,
    "IssuesEvent": parse_issue_event,
    "PullRequestEvent": parse_pull_request_event,
    "PullRequestReviewCommentEvent": parse_pull_request_review_comment_event,
    "ForkEvent": parse_fork_event,
    "MemberEvent": parse_member_event,
}
def parse_event(event):
    """Dispatch ``event`` to its type-specific parser; unknown types yield []."""
    handler = event_handler_dict.get(event["type"])
    if handler is None:
        return []
    return handler(event)
def parse_file(filename):
    """Parse one gzipped GH Archive JSON-lines file.

    Returns {ts: {(head, tail, relation): count}}, aggregating duplicate
    edges within the same timestamp.
    """
    output_dict = {}
    # Bug fix: the counter started at 1, so "Parsed N events" over-reported
    # the number of distinct edges by one.
    num_edge = 0
    with gzip.open(filename, 'r') as f:
        for line in f:
            for edge in parse_event(json.loads(line)):
                #? ['/user/41898282', 'U_SE_O_I', '/issue/2061196208', 1704085558]
                head, rel, tail, ts = edge[0], edge[1], edge[2], int(edge[3])
                bucket = output_dict.setdefault(ts, {})
                key = (head, tail, rel)
                if key in bucket:
                    bucket[key] += 1
                else:
                    bucket[key] = 1
                    num_edge += 1  # count distinct (ts, head, tail, rel) entries
    print("Parsed {} events".format(num_edge))
    return output_dict
def write2csv(outname, out_dict):
    """Append every (ts, head, tail, relation_type) edge to ``outname``,
    ordered by timestamp. No header is written so repeated calls can
    accumulate edges from many archive files into one csv.
    """
    with open(outname, 'a') as f:
        out = csv.writer(f, delimiter=',')
        for ts in sorted(out_dict):
            for head, tail, relation_type in out_dict[ts]:
                out.writerow([ts, head, tail, relation_type])
def main():
    # Process each hourly archive in the working directory and flush its
    # edges to the csv immediately, keeping memory bounded per file.
    outname = "github_02_2024.csv"
    for archive in glob.glob("*.json.gz"):
        print ("processing,", archive)
        write2csv(outname, parse_file(archive))


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/thgl_github/2024_03/github_extract.py
================================================
import json
from datetime import datetime
import glob
import gzip
import csv
"""
go to https://www.gharchive.org/
wget https://data.gharchive.org/2024-01-{01..31}-{0..23}.json.gz
Creates (src, edge_type, dst, time) edges from the GitHub archive JSON file.
Using the rules from https://arxiv.org/pdf/2007.01231 (page 11)
The parser creates 18 rules that are in the GITHUB-SE-1Y-Repo dataset. I wrote the meaning of the rules and sources and destination types here.
"""
# Mapping from internal rule names to the relation labels of the
# GITHUB-SE-1Y-Repo dataset (18 rules). NOTE(review): label abbreviations
# appear to follow the paper's table — U=user, I=issue, IC=issue comment,
# P/PR=pull request, PRC=PR review comment, R=repo — confirm against the
# paper before relying on them.
rels = {
    "IC_Created_IC_I": "IC_AO_C_I",
    "IC_Created_U_IC": "U_SO_C_IC",
    "I_Opened_U_I": "U_SE_O_I",
    "I_Opened_I_R": "I_AO_O_R",
    "I_Closed_U_I": "U_SE_C_I",
    "I_Closed_I_R": "I_AO_C_R",
    "I_Reopened_U_I": "U_SE_RO_I",
    "I_Reopened_I_R": "I_AO_RO_R",
    "PR_Opened_U_PR": "U_SO_O_P",
    "PR_Opened_PR_R": "P_AO_O_R",
    "PR_Closed_U_PR": "U_SO_C_P",
    "PR_Closed_PR_R": "P_AO_C_R",
    "PR_Reopened_U_PR": "U_SO_R_P",
    "PR_Reopened_PR_R": "P_AO_R_R",
    "PRRC_Created_U_PRC": "U_SO_C_PRC",
    "PRRC_Created_PRC_PR": "PRC_AO_C_P",
    "Forked_R_R": "R_FO_R",
    "AddMember_U_R": "U_CO_A_R",
}
# Templates producing typed, globally-unique node names from numeric ids.
issue_comment_format = "/issue_comment/{}"
issue_format = "/issue/{}"
user_format = "/user/{}"
repo_format = "/repo/{}"
pull_request_format = "/pr/{}"
pull_request_review_comment_format = "/pr_review_comment/{}"
def str_to_timestamp(time_str):
    """Convert a GH Archive timestamp ("%Y-%m-%dT%H:%M:%SZ") to a Unix epoch int.

    The trailing "Z" marks UTC, but ``datetime.strptime`` yields a *naive*
    datetime which ``.timestamp()`` interprets in the machine's local
    timezone, silently shifting every edge time by the local UTC offset.
    Attach UTC explicitly so the conversion is machine-independent.
    """
    from datetime import timezone  # local import: the module only imports `datetime`
    dt = datetime.strptime(time_str, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
    return int(dt.timestamp())
def parse_issue_comment_events(event):
    """Parse an IssueCommentEvent into [comment->issue, user->comment] edges.

    Returns [] for non-"created" actions or malformed payloads.
    """
    try:
        if "action" not in event["payload"]:
            return []
        if event["payload"]["action"] == "created":
            issue_comment_id = event["payload"]["comment"]["id"]
            issue_id = event["payload"]["issue"]["id"]
            user_id = event["actor"]["id"]
            created_at = str_to_timestamp(event["created_at"])
            ici_event = [
                issue_comment_format.format(issue_comment_id),
                rels["IC_Created_IC_I"],
                issue_format.format(issue_id),
                created_at,
            ]
            uic_event = [
                user_format.format(user_id),
                rels["IC_Created_U_IC"],
                issue_comment_format.format(issue_comment_id),
                created_at,
            ]
            return [ici_event, uic_event]
        return []
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
def parse_issue_event(event):
    """Parse an IssuesEvent into [user->issue, issue->repo] edges.

    Only "opened"/"closed"/"reopened" actions produce edges; anything else
    (or a malformed payload) yields [].
    """
    try:
        issue_id = event["payload"]["issue"]["id"]
        user_id = event["actor"]["id"]
        repo_id = event["repo"]["id"]
        created_at = str_to_timestamp(event["created_at"])
        action_map = {
            "opened": ("I_Opened_U_I", "I_Opened_I_R"),
            "closed": ("I_Closed_U_I", "I_Closed_I_R"),
            "reopened": ("I_Reopened_U_I", "I_Reopened_I_R"),
        }
        for action, event_rels in action_map.items():
            if event["payload"]["action"] == action:
                ui_event = [
                    user_format.format(user_id),
                    rels[event_rels[0]],
                    issue_format.format(issue_id),
                    created_at,
                ]
                ir_event = [
                    issue_format.format(issue_id),
                    rels[event_rels[1]],
                    repo_format.format(repo_id),
                    created_at,
                ]
                return [ui_event, ir_event]
        return []
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
def parse_pull_request_event(event):
    """Parse a PullRequestEvent into [user->PR, PR->repo] edges.

    Only "opened"/"closed"/"reopened" actions produce edges; anything else
    (or a malformed payload) yields [].
    """
    try:
        pull_request_id = event["payload"]["pull_request"]["id"]
        user_id = event["actor"]["id"]
        repo_id = event["repo"]["id"]
        created_at = str_to_timestamp(event["created_at"])
        action_map = {
            "opened": ("PR_Opened_U_PR", "PR_Opened_PR_R"),
            "closed": ("PR_Closed_U_PR", "PR_Closed_PR_R"),
            "reopened": ("PR_Reopened_U_PR", "PR_Reopened_PR_R"),
        }
        for action, event_rels in action_map.items():
            if event["payload"]["action"] == action:
                upr_event = [
                    user_format.format(user_id),
                    rels[event_rels[0]],
                    pull_request_format.format(pull_request_id),
                    created_at,
                ]
                prr_event = [
                    pull_request_format.format(pull_request_id),
                    rels[event_rels[1]],
                    repo_format.format(repo_id),
                    created_at,
                ]
                return [upr_event, prr_event]
        return []
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
def parse_pull_request_review_comment_event(event):
    """Parse a PullRequestReviewCommentEvent into [user->review-comment,
    review-comment->PR] edges.

    Returns [] for non-"created" actions or malformed payloads.
    """
    try:
        pull_request_review_comment_id = event["payload"]["comment"]["id"]
        pull_request_id = event["payload"]["pull_request"]["id"]
        user_id = event["actor"]["id"]
        created_at = str_to_timestamp(event["created_at"])
        if event["payload"]["action"] == "created":
            uprc_event = [
                user_format.format(user_id),
                rels["PRRC_Created_U_PRC"],
                pull_request_review_comment_format.format(pull_request_review_comment_id),
                created_at,
            ]
            prcpr_event = [
                pull_request_review_comment_format.format(pull_request_review_comment_id),
                rels["PRRC_Created_PRC_PR"],
                pull_request_format.format(pull_request_id),
                created_at,
            ]
            return [uprc_event, prcpr_event]
        return []
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
def parse_fork_event(event):
    """Parse a ForkEvent into one [forkee-repo -> forked-repo] edge.

    Returns [] on malformed payloads.
    """
    try:
        forkee_repo_id = event["payload"]["forkee"]["id"]
        forked_repo_id = event["repo"]["id"]
        created_at = str_to_timestamp(event["created_at"])
        return [
            [
                repo_format.format(forkee_repo_id),
                rels["Forked_R_R"],
                repo_format.format(forked_repo_id),
                created_at,
            ]
        ]
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
def parse_member_event(event):
    """Parse a MemberEvent into one [added-user -> repo] edge.

    Returns [] on malformed payloads.
    """
    try:
        user_id = event["payload"]["member"]["id"]
        repo_id = event["repo"]["id"]
        created_at = str_to_timestamp(event["created_at"])
        return [
            [
                user_format.format(user_id),
                rels["AddMember_U_R"],
                repo_format.format(repo_id),
                created_at,
            ]
        ]
    # Narrowed from a bare `except:`, which would also trap SystemExit and
    # KeyboardInterrupt; malformed events still just yield [].
    except (KeyError, TypeError, ValueError):
        return []
# Dispatch table: GH Archive event "type" -> parser returning a list of
# [head, relation, tail, unix_ts] edges (empty when not applicable).
event_handler_dict = {
    "IssueCommentEvent": parse_issue_comment_events,
    "IssuesEvent": parse_issue_event,
    "PullRequestEvent": parse_pull_request_event,
    "PullRequestReviewCommentEvent": parse_pull_request_review_comment_event,
    "ForkEvent": parse_fork_event,
    "MemberEvent": parse_member_event,
}
def parse_event(event):
    """Dispatch ``event`` to its type-specific parser; unknown types yield []."""
    handler = event_handler_dict.get(event["type"])
    if handler is None:
        return []
    return handler(event)
def parse_file(filename):
    """Parse one gzipped GH Archive JSON-lines file.

    Returns {ts: {(head, tail, relation): count}}, aggregating duplicate
    edges within the same timestamp.
    """
    output_dict = {}
    # Bug fix: the counter started at 1, so "Parsed N events" over-reported
    # the number of distinct edges by one.
    num_edge = 0
    with gzip.open(filename, 'r') as f:
        for line in f:
            for edge in parse_event(json.loads(line)):
                #? ['/user/41898282', 'U_SE_O_I', '/issue/2061196208', 1704085558]
                head, rel, tail, ts = edge[0], edge[1], edge[2], int(edge[3])
                bucket = output_dict.setdefault(ts, {})
                key = (head, tail, rel)
                if key in bucket:
                    bucket[key] += 1
                else:
                    bucket[key] = 1
                    num_edge += 1  # count distinct (ts, head, tail, rel) entries
    print("Parsed {} events".format(num_edge))
    return output_dict
def write2csv(outname, out_dict):
    """Append every (ts, head, tail, relation_type) edge to ``outname``,
    ordered by timestamp. No header is written so repeated calls can
    accumulate edges from many archive files into one csv.
    """
    with open(outname, 'a') as f:
        out = csv.writer(f, delimiter=',')
        for ts in sorted(out_dict):
            for head, tail, relation_type in out_dict[ts]:
                out.writerow([ts, head, tail, relation_type])
def main():
    # Process each hourly archive in the working directory and flush its
    # edges to the csv immediately, keeping memory bounded per file.
    outname = "github_03_2024.csv"
    for archive in glob.glob("*.json.gz"):
        print ("processing,", archive)
        write2csv(outname, parse_file(archive))


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/thgl_github/extract_subset.py
================================================
import csv
def load_edgelist(file_path, freq_threshold=5):
    """
    Scan the edgelist csv, print basic statistics and return
    (low_freq_dict, node_type_freq).

    Expected rows: ts, head, tail, relation_type, e.g.
    1704085200,/user/34452971,/pr/1660752740,U_SO_C_P

    low_freq_dict maps every node with frequency <= freq_threshold to 1;
    node_type_freq counts endpoint occurrences per node type (the path
    segment after the leading '/').
    """
    node_dict = {}       # node name -> occurrence count
    edge_freq_dict = {}  # relation type -> occurrence count
    node_type_freq = {}  # node type -> occurrence count
    num_nodes = 0
    num_rels = 0
    num_lines = 0
    first_row = True
    with open(file_path, 'r') as f:
        for row in csv.reader(f, delimiter=','):
            if first_row:
                first_row = False
                continue
            ts = int(row[0])  # parsed for validation; not otherwise used here
            head, tail = row[1], row[2]
            for node_type in (head.split("/")[1], tail.split("/")[1]):
                node_type_freq[node_type] = node_type_freq.get(node_type, 0) + 1
            for node in (head, tail):
                if node not in node_dict:
                    node_dict[node] = 1
                    num_nodes += 1
                else:
                    node_dict[node] += 1
            relation_type = row[3]
            if relation_type not in edge_freq_dict:
                edge_freq_dict[relation_type] = 1
                num_rels += 1
            else:
                edge_freq_dict[relation_type] += 1
            num_lines += 1
    print ("there are ", num_lines, " edges")
    print ("there are ", num_nodes, " nodes")
    print ("there are ", num_rels, " relations")
    low_freq_dict = {}
    node_freq5 = node_freq10 = node_freq100 = node_freq1000 = 0
    for node, freq in node_dict.items():
        if freq <= freq_threshold:
            low_freq_dict[node] = 1
            node_freq5 += 1
        if freq >= 10:
            node_freq10 += 1
        if freq >= 100:
            node_freq100 += 1
        if freq >= 1000:
            node_freq1000 += 1
    print ("there are ", node_freq5, " nodes with frequency <= ", freq_threshold, " (inclusive)")
    print ("there are ", node_freq10, " nodes with frequency >= 10")
    print ("there are ", node_freq100, " nodes with frequency >= 100")
    print ("there are ", node_freq1000, " nodes with frequency >= 1000")
    return low_freq_dict, node_type_freq
def subset_by_node(file_path, low_freq_dict):
    """Re-read the edgelist, dropping any edge with an endpoint listed in
    ``low_freq_dict``; returns {ts: {(head, tail, relation_type): count}}."""
    edge_dict = {}
    with open(file_path, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader, None)  # skip header
        for row in reader:
            ts = int(row[0])
            head, tail, relation_type = row[1], row[2], row[3]
            #! remove any edges that touch a filtered (low-degree) node
            if head in low_freq_dict or tail in low_freq_dict:
                continue
            bucket = edge_dict.setdefault(ts, {})
            key = (head, tail, relation_type)
            bucket[key] = bucket.get(key, 0) + 1
    return edge_dict
def subset_by_node_type(file_path, remove_node_type_dict, low_freq_dict=None):
    """Filter the edgelist by node type and (optionally) node frequency.

    Drops any edge whose endpoint type (path segment after the leading '/')
    is in ``remove_node_type_dict``, and — when ``low_freq_dict`` is given —
    any edge touching a node listed there.

    Returns {ts: {(head, tail, relation_type): count}}.
    """
    # Bug fix: with the default low_freq_dict=None the original raised
    # TypeError on `head in low_freq_dict` (the dead `check_low_freq` flag
    # was never consulted); treat None as "no frequency filter".
    if low_freq_dict is None:
        low_freq_dict = {}
    edge_dict = {}
    node_dict = {}
    num_edges = 0
    with open(file_path, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader, None)  # skip header
        for row in reader:
            ts = int(row[0])
            head, tail, relation_type = row[1], row[2], row[3]
            if head in low_freq_dict or tail in low_freq_dict:
                continue
            head_type = head.split("/")[1]
            tail_type = tail.split("/")[1]
            if head_type in remove_node_type_dict or tail_type in remove_node_type_dict:
                continue
            node_dict[head] = 1
            node_dict[tail] = 1
            num_edges += 1
            bucket = edge_dict.setdefault(ts, {})
            key = (head, tail, relation_type)
            bucket[key] = bucket.get(key, 0) + 1
    print ("there are ", num_edges, " edges in the output file")
    print ("there are ", len(node_dict), " nodes in the output file")
    return edge_dict
def write2csv(outname, out_dict):
    """Write {ts: {(head, tail, relation_type): count}} to ``outname``
    sorted by timestamp, with a header, and report the edge count."""
    num_edges = 0
    with open(outname, 'w') as f:
        out = csv.writer(f, delimiter=',')
        out.writerow(['ts', 'head', 'tail', 'relation_type'])
        for ts in sorted(out_dict):
            for head, tail, relation_type in out_dict[ts]:
                out.writerow([ts, head, tail, relation_type])
                num_edges += 1
    print ("there are ", num_edges, " edges in the output file")
def combine_edgelist(file_paths, outname):
    """Concatenate several edgelist csvs into ``outname`` with a single
    header row (each input's own header is skipped)."""
    with open(outname, 'w') as out_f:  # renamed: the original shadowed `f` with the input handle
        out = csv.writer(out_f, delimiter=',')
        out.writerow(['ts', 'head', 'tail', 'relation_type'])
        for file_path in file_paths:
            with open(file_path, 'r') as in_f:
                reader = csv.reader(in_f, delimiter=',')
                next(reader, None)  # skip this file's header
                for row in reader:
                    out.writerow([int(row[0]), row[1], row[2], row[3]])
def main():
    # subset the raw github dump: drop nodes seen <= freq_threshold times
    # and drop comment-type nodes, then write the filtered edgelist
    file_path = "github_03_2024.csv"
    low_freq_dict, node_type_dict = load_edgelist(file_path, freq_threshold=2)
    remove_node_type_dict = {'issue_comment': 1, 'pr_review_comment': 1}
    edge_dict = subset_by_node_type(file_path, remove_node_type_dict, low_freq_dict=low_freq_dict)
    write2csv("github_03_2024_subset.csv", edge_dict)
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/thgl_github/thgl_github.py
================================================
import csv
import datetime
import glob, os
def load_csv_raw(fname):
    """
    Parse one raw tab-separated dump (head, relation, tail, unix-ts per row)
    and aggregate duplicate edges per timestamp.

    Returns:
        out_dict: {ts: {(head, tail, relation_type): count}}
        num_lines: total number of rows read
    """
    out_dict = {}
    num_lines = 0
    with open(fname, 'r') as f:
        # example row: /user/10746682	U_SO_C_IC	/issue_comment/455195715	1547754198
        for row in csv.reader(f, delimiter='\t'):
            ts = int(row[3])
            key = (row[0], row[2], row[1])  # (head, tail, relation_type)
            bucket = out_dict.setdefault(ts, {})
            bucket[key] = bucket.get(key, 0) + 1
            num_lines += 1
    return out_dict, num_lines
def write2csv(outname, out_dict):
    r"""
    Write {ts: {(head, tail, relation_type): count}} to a csv file,
    timestamps in ascending order, one row per distinct edge key.
    """
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['ts', 'head', 'tail', 'relation_type'])
        for ts in sorted(out_dict):
            for head, tail, relation_type in out_dict[ts]:
                writer.writerow([ts, head, tail, relation_type])
def load_edgelist(fname):
    """
    Load a named edgelist csv and remap node names, node types and relation
    types to consecutive integer ids (assigned in order of first appearance).

    Returns:
        node_dict: {node_name: node_id}
        node_type_dict: {node_id: node_type_id}
        edge_dict: {ts: {(head_id, tail_id, relation_type_id): 1}}
        rel_type_dict: {relation_name: relation_type_id}
        node_type_mapping: {node_type_name: node_type_id}
    """
    node_dict = {}
    node_type_dict = {}
    rel_type_dict = {}
    edge_dict = {}
    node_type_mapping = {}
    num_edges = 0

    def _register_type(type_name):
        # first occurrence of a type gets the next consecutive id
        if type_name not in node_type_mapping:
            node_type_mapping[type_name] = len(node_type_mapping)

    def _register_node(name, type_name):
        if name not in node_dict:
            node_dict[name] = len(node_dict)
            node_type_dict[node_dict[name]] = node_type_mapping[type_name]
        return node_dict[name]

    with open(fname, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader, None)  # skip header
        for row in reader:
            ts = int(row[0])
            head, tail, relation_type = row[1], row[2], row[3]
            # node type is the first path segment, e.g. "/user/123" -> "user"
            head_type = head.split('/')[1]
            tail_type = tail.split('/')[1]
            _register_type(head_type)
            _register_type(tail_type)
            head_id = _register_node(head, head_type)
            tail_id = _register_node(tail, tail_type)
            if relation_type not in rel_type_dict:
                rel_type_dict[relation_type] = len(rel_type_dict)
            edge_dict.setdefault(ts, {})[(head_id, tail_id, rel_type_dict[relation_type])] = 1
            num_edges += 1
    print ("there are {} nodes".format(len(node_dict)))
    print ("there are {} edges".format(num_edges))
    return node_dict, node_type_dict, edge_dict, rel_type_dict, node_type_mapping
def writeNodeType(node_type_dict, outname):
    r"""
    Write the {node_id: node_type_id} assignment to a csv file.
    """
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['node_id', 'type'])
        writer.writerows(node_type_dict.items())
def writeEdgeTypeMapping(edge_type_dict, outname):
    r"""
    Write the relation-type mapping to a csv file.
    NOTE(review): rows are (type_name, id) while the header reads
    ['edge_id', 'type'] — kept as-is for output compatibility.
    """
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['edge_id', 'type'])
        writer.writerows(edge_type_dict.items())
def writeNodeTypeMapping(node_type_dict, outname):
    r"""
    Write the node-type mapping (type name -> type id) to a csv file.
    NOTE(review): rows are (type_name, id) while the header reads
    ['node_type_id', 'type'] — kept as-is for output compatibility.
    """
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['node_type_id', 'type'])
        writer.writerows(node_type_dict.items())
def write2edgelist(out_dict, outname):
    r"""
    Write {ts: {(head, tail, relation_type): 1}} to a csv file sorted by
    timestamp and report how many edge rows were written.
    """
    num_lines = 0
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['timestamp', 'head', 'tail', 'relation_type'])
        for ts in sorted(out_dict):
            for head, tail, relation_type in out_dict[ts]:
                writer.writerow([ts, head, tail, int(relation_type)])
                num_lines += 1
    print ("there are {} lines in the file".format(num_lines))
def main():
    # remap the subset edgelist to integer ids and emit the dataset files
    fname = "github_03_2024_subset.csv"
    node_dict, node_type_dict, edge_dict, edge_type_dict, node_type_mapping = load_edgelist(fname)
    write2edgelist(edge_dict, "thgl-github_edgelist.csv")
    writeNodeType(node_type_dict, "thgl-github_nodetype.csv")
    writeEdgeTypeMapping(edge_type_dict, "thgl-github_edgemapping.csv")
    writeNodeTypeMapping(node_type_mapping, "thgl-github_nodemapping.csv")
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/thgl_github/thgl_github_ns_gen.py
================================================
import time
from tgb.linkproppred.thg_negative_generator import THGNegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate and persist negative edge sets for the validation and test
    splits of thgl-github.
    """
    print("*** Negative Sample Generation ***")
    # sampling configuration
    num_neg_e_per_pos = 20
    neg_sample_strategy = "node-type-filtered"
    rnd_seed = 42
    name = "thgl-github"
    dataset = PyGLinkPropPredDataset(name=name, root="datasets")
    data = dataset.get_TemporalData()
    data_splits = {
        'train': data[dataset.train_mask],
        'val': data[dataset.val_mask],
        'test': data[dataset.test_mask],
    }
    # node id range spans both sources and destinations
    min_node_idx = min(int(data.src.min()), int(data.dst.min()))
    max_node_idx = max(int(data.src.max()), int(data.dst.max()))
    neg_sampler = THGNegativeEdgeGenerator(
        dataset_name=name,
        first_node_id=min_node_idx,
        last_node_id=max_node_idx,
        node_type=dataset.node_type,
        num_neg_e=num_neg_e_per_pos,
        strategy=neg_sample_strategy,
        rnd_seed=rnd_seed,
        edge_data=data,
    )
    partial_path = "."
    # generate the validation and then the test negative edge sets
    for split_mode in ("val", "test"):
        start_time = time.time()
        print(
            f"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}"
        )
        neg_sampler.generate_negative_samples(
            pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}"
        )
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/thgl_myket/thgl_myket.py
================================================
import dateutil.parser as dparser
import csv
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from os import listdir
from datetime import datetime
def date2ts(date_str: str) -> int:
    r"""
    Convert a date string like "2020-06-17 23:55:17.460" to an integer unix
    timestamp (seconds; the fractional part is truncated).

    Note: strptime yields a naive datetime, so .timestamp() interprets it in
    the local timezone.
    """
    # fixed: annotation said float, but an int has always been returned
    TIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f"
    date_cur = datetime.strptime(date_str, TIME_FORMAT)
    return int(date_cur.timestamp())
"""
app_name user_id datetime is_update
com.cocoplay.erpetvet 392863962 2020-06-17 23:55:17.460 0
com.titan.royal 790103760 2020-06-17 23:55:19.583 0
com.tencent.ig -1651723014 2020-06-17 23:55:20.647 0
com.cyberlink.youperfect -2116095669 2020-06-17 23:55:20.723 1
com.whatsapp 1591275459 2020-06-17 23:55:20.820 0
com.nexttechgamesstudio.house.paint.craft.coloring.book.pages -984956295 2020-06-17 23:55:21.840 0
com.lenovo.anyshare.gps 1643649087 2020-06-17 23:55:21.853 1
com.kurankarim.mp3 1316745267 2020-06-17 23:55:22.537 0
com.google.android.dialer 239675079 2020-06-17 23:55:22.950 1
com.ma.textgraphy -951808761 2020-06-17 23:55:22.977 0
ir.shahbaz.SHZToolBox 1643649087 2020-06-17 23:55:22.987 1
picture.instagram.makers -1898448882 2020-06-17 23:55:23.010 0
ir.shahbaz.SHZToolBox 780669111 2020-06-17 23:55:23.600 1
fantasy.survival.game.rpg 1849120437 2020-06-17 23:55:23.980 0
com.ags.flying.muscle.car.transform.robot.war.robot.games 1751574033 2020-06-17 23:55:24.680 0
"""
def read_csv2dict(fname):
    r"""
    Load the raw tab-separated myket dump (app, user, datetime, is_update)
    and bucket edges by unix timestamp; reports the number of rows kept.

    Returns:
        out_dict: {ts: {(user, app, is_update): 1}}
    """
    out_dict = {}
    num_lines = 0
    with open(fname, 'r') as f:
        reader = csv.reader(f, delimiter='\t')
        header_skipped = False
        for row in reader:
            if not header_skipped:
                header_skipped = True
                continue
            app, user, date = row[0], row[1], row[2]
            is_update = int(row[3])
            if not date:  # skip rows without a timestamp
                continue
            ts = date2ts(date)
            # the user is the head node, the app the tail node
            out_dict.setdefault(ts, {})[(user, app, is_update)] = 1
            num_lines += 1
    print ("there are {} lines in the file".format(num_lines))
    return out_dict
# def writeIDmapping(id_dict, outname):
# r"""
# write the id mapping to a file
# """
# with open(outname, 'w') as f:
# writer = csv.writer(f, delimiter =',')
# writer.writerow(['ID', 'name'])
# for key in id_dict:
# writer.writerow([key, id_dict[key]])
def edge2nodetype(out_dict):
    r"""
    Remap node names to consecutive integer ids and record each node's type
    (0 for user/head nodes, 1 for app/tail nodes), processing timestamps in
    sorted order so ids are deterministic.

    Returns:
        node_dict: {node_name: node_id}
        node_type_dict: {node_id: node_type}
        edge_dict: {ts: {(head_id, tail_id, is_update): 1}}
    """
    node_dict = {}
    node_type_dict = {}
    edge_dict = {}

    def _register(name, node_type):
        # first occurrence gets the next consecutive id
        if name not in node_dict:
            node_dict[name] = len(node_dict)
            node_type_dict[node_dict[name]] = node_type
        return node_dict[name]

    for ts in sorted(out_dict):
        for user, app, is_update in out_dict[ts]:
            head_id = _register(user, 0)  # user node
            tail_id = _register(app, 1)   # app node
            edge_dict.setdefault(ts, {})[(head_id, tail_id, int(is_update))] = 1
    return node_dict, node_type_dict, edge_dict
def writeNodeType(node_type_dict, outname):
    r"""
    Write the {node_id: node_type} assignment to a csv file.
    """
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['node_id', 'type'])
        writer.writerows(node_type_dict.items())
def write2edgelist(out_dict, outname):
    r"""
    Write {ts: {(head, tail, relation_type): 1}} to a csv file sorted by
    timestamp and report how many edge rows were written.
    """
    num_lines = 0
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['timestamp', 'head', 'tail', 'relation_type'])
        for ts in sorted(out_dict):
            for head, tail, relation_type in out_dict[ts]:
                writer.writerow([ts, head, tail, int(relation_type)])
                num_lines += 1
    print ("there are {} lines in the file".format(num_lines))
"""
need to have edgelist with n_ids
need to have a node_type file to document which nodes are which type
"""
def main():
    # parse the raw myket dump, remap to integer ids, and emit dataset files
    out_dict = read_csv2dict("raw_myket_input-001.csv")
    node_dict, node_type_dict, edge_dict = edge2nodetype(out_dict)
    write2edgelist(edge_dict, "thgl-myket_edgelist.csv")
    writeNodeType(node_type_dict, "thgl-myket_nodetype.csv")
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/thgl_myket/thgl_myket_ns_gen.py
================================================
import time
from tgb.linkproppred.thg_negative_generator import THGNegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate and persist negative edge sets for the validation and test
    splits of thgl-myket.
    """
    print("*** Negative Sample Generation ***")
    # sampling configuration
    num_neg_e_per_pos = 20
    neg_sample_strategy = "node-type-filtered"
    rnd_seed = 42
    name = "thgl-myket"
    dataset = PyGLinkPropPredDataset(name=name, root="datasets")
    data = dataset.get_TemporalData()
    data_splits = {
        'train': data[dataset.train_mask],
        'val': data[dataset.val_mask],
        'test': data[dataset.test_mask],
    }
    # node id range spans both sources and destinations
    min_node_idx = min(int(data.src.min()), int(data.dst.min()))
    max_node_idx = max(int(data.src.max()), int(data.dst.max()))
    neg_sampler = THGNegativeEdgeGenerator(
        dataset_name=name,
        first_node_id=min_node_idx,
        last_node_id=max_node_idx,
        node_type=dataset.node_type,
        num_neg_e=num_neg_e_per_pos,
        strategy=neg_sample_strategy,
        rnd_seed=rnd_seed,
        edge_data=data,
    )
    partial_path = "."
    # generate the validation and then the test negative edge sets
    for split_mode in ("val", "test"):
        start_time = time.time()
        print(
            f"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}"
        )
        neg_sampler.generate_negative_samples(
            pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}"
        )
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/thgl_software/thgl_software.py
================================================
import csv
import datetime
import glob, os
def load_csv_raw(fname):
    """
    Parse one raw tab-separated dump (head, relation, tail, unix-ts per row)
    and aggregate duplicate edges per timestamp.

    Returns:
        out_dict: {ts: {(head, tail, relation_type): count}}
        num_lines: total number of rows read
    """
    out_dict = {}
    num_lines = 0
    with open(fname, 'r') as f:
        # example row: /user/10746682	U_SO_C_IC	/issue_comment/455195715	1547754198
        for row in csv.reader(f, delimiter='\t'):
            ts = int(row[3])
            key = (row[0], row[2], row[1])  # (head, tail, relation_type)
            bucket = out_dict.setdefault(ts, {})
            bucket[key] = bucket.get(key, 0) + 1
            num_lines += 1
    return out_dict, num_lines
def write2csv(outname, out_dict):
    r"""
    Write {ts: {(head, tail, relation_type): count}} to a csv file,
    timestamps in ascending order, one row per distinct edge key.
    """
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['ts', 'head', 'tail', 'relation_type'])
        for ts in sorted(out_dict):
            for head, tail, relation_type in out_dict[ts]:
                writer.writerow([ts, head, tail, relation_type])
def load_edgelist(fname):
    """
    Load a named edgelist csv and remap node names, node types and relation
    types to consecutive integer ids (assigned in order of first appearance).

    Returns:
        node_dict: {node_name: node_id}
        node_type_dict: {node_id: node_type_id}
        edge_dict: {ts: {(head_id, tail_id, relation_type_id): 1}}
        rel_type_dict: {relation_name: relation_type_id}
        node_type_mapping: {node_type_name: node_type_id}
    """
    node_dict = {}
    node_type_dict = {}
    rel_type_dict = {}
    edge_dict = {}
    node_type_mapping = {}
    num_edges = 0

    def _register_type(type_name):
        # first occurrence of a type gets the next consecutive id
        if type_name not in node_type_mapping:
            node_type_mapping[type_name] = len(node_type_mapping)

    def _register_node(name, type_name):
        if name not in node_dict:
            node_dict[name] = len(node_dict)
            node_type_dict[node_dict[name]] = node_type_mapping[type_name]
        return node_dict[name]

    with open(fname, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader, None)  # skip header
        for row in reader:
            ts = int(row[0])
            head, tail, relation_type = row[1], row[2], row[3]
            # node type is the first path segment, e.g. "/user/123" -> "user"
            head_type = head.split('/')[1]
            tail_type = tail.split('/')[1]
            _register_type(head_type)
            _register_type(tail_type)
            head_id = _register_node(head, head_type)
            tail_id = _register_node(tail, tail_type)
            if relation_type not in rel_type_dict:
                rel_type_dict[relation_type] = len(rel_type_dict)
            edge_dict.setdefault(ts, {})[(head_id, tail_id, rel_type_dict[relation_type])] = 1
            num_edges += 1
    print ("there are {} nodes".format(len(node_dict)))
    print ("there are {} edges".format(num_edges))
    return node_dict, node_type_dict, edge_dict, rel_type_dict, node_type_mapping
def writeNodeType(node_type_dict, outname):
    r"""
    Write the {node_id: node_type_id} assignment to a csv file.
    """
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['node_id', 'type'])
        writer.writerows(node_type_dict.items())
def writeEdgeTypeMapping(edge_type_dict, outname):
    r"""
    Write the relation-type mapping to a csv file.
    NOTE(review): rows are (type_name, id) while the header reads
    ['edge_id', 'type'] — kept as-is for output compatibility.
    """
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['edge_id', 'type'])
        writer.writerows(edge_type_dict.items())
def writeNodeTypeMapping(node_type_dict, outname):
    r"""
    Write the node-type mapping (type name -> type id) to a csv file.
    NOTE(review): rows are (type_name, id) while the header reads
    ['node_type_id', 'type'] — kept as-is for output compatibility.
    """
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['node_type_id', 'type'])
        writer.writerows(node_type_dict.items())
def write2edgelist(out_dict, outname):
    r"""
    Write {ts: {(head, tail, relation_type): 1}} to a csv file sorted by
    timestamp and report how many edge rows were written.
    """
    num_lines = 0
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['timestamp', 'head', 'tail', 'relation_type'])
        for ts in sorted(out_dict):
            for head, tail, relation_type in out_dict[ts]:
                writer.writerow([ts, head, tail, int(relation_type)])
                num_lines += 1
    print ("there are {} lines in the file".format(num_lines))
def main():
    # remap the software edgelist to integer ids and emit the dataset files
    fname = "software_edgelist.csv"
    node_dict, node_type_dict, edge_dict, edge_type_dict, node_type_mapping = load_edgelist(fname)
    write2edgelist(edge_dict, "thgl-software_edgelist.csv")
    writeNodeType(node_type_dict, "thgl-software_nodetype.csv")
    writeEdgeTypeMapping(edge_type_dict, "thgl-software_edgemapping.csv")
    writeNodeTypeMapping(node_type_mapping, "thgl-software_nodemapping.csv")
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/thgl_software/thgl_software_ns_gen.py
================================================
import time
from tgb.linkproppred.thg_negative_generator import THGNegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate and persist negative edge sets for the validation and test
    splits of thgl-software.
    """
    print("*** Negative Sample Generation ***")
    # sampling configuration
    num_neg_e_per_pos = 1000
    neg_sample_strategy = "node-type-filtered"
    rnd_seed = 42
    name = "thgl-software"
    dataset = PyGLinkPropPredDataset(name=name, root="datasets")
    data = dataset.get_TemporalData()
    data_splits = {
        'train': data[dataset.train_mask],
        'val': data[dataset.val_mask],
        'test': data[dataset.test_mask],
    }
    # node id range spans both sources and destinations
    min_node_idx = min(int(data.src.min()), int(data.dst.min()))
    max_node_idx = max(int(data.src.max()), int(data.dst.max()))
    neg_sampler = THGNegativeEdgeGenerator(
        dataset_name=name,
        first_node_id=min_node_idx,
        last_node_id=max_node_idx,
        node_type=dataset.node_type,
        num_neg_e=num_neg_e_per_pos,
        strategy=neg_sample_strategy,
        rnd_seed=rnd_seed,
        edge_data=data,
    )
    partial_path = "."
    # generate the validation and then the test negative edge sets
    for split_mode in ("val", "test"):
        start_time = time.time()
        print(
            f"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}"
        )
        neg_sampler.generate_negative_samples(
            pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}"
        )
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tkgl_icews/tkgl_icews.py
================================================
import csv
import datetime
import glob, os
def load_csv_raw(fname):
    r"""
    Parse one raw ICEWS .tab file and bucket (head, tail, CAMEO-code) edges
    by the unix timestamp of their event date, counting duplicates.

    Rows with an empty date or a "None" field are skipped.

    Returns:
        out_dict: {ts: {(head, tail, relation_type): count}}
        num_lines: number of rows kept after filtering
    """
    TIME_FORMAT = "%Y-%m-%d"  # e.g. 2018-01-01
    out_dict = {}
    num_lines = 0
    with open(fname, 'r', encoding='ISO-8859-1') as f:
        reader = csv.reader(f, delimiter='\t')
        next(reader, None)  # skip header
        for row in reader:
            date, head, tail = row[1], row[2], row[8]
            relation_type = row[6]  # CAMEO code (not always an integer)
            if not date:
                continue
            if ("None" in date or "None" in head or "None" in tail or "None" in relation_type):
                continue
            # note: naive strptime, so .timestamp() uses the local timezone
            ts = int(datetime.datetime.strptime(date, TIME_FORMAT).timestamp())
            num_lines += 1
            bucket = out_dict.setdefault(ts, {})
            key = (head, tail, relation_type)
            bucket[key] = bucket.get(key, 0) + 1
    return out_dict, num_lines
def write2csv(outname, out_dict):
    r"""
    Write the edge dict to csv while remapping node names and relation types
    to consecutive integer ids (assigned in order of first appearance).

    Note: timestamps are written in the dict's insertion order (not sorted),
    matching the original behavior.

    Returns:
        node_dict: {node_name: node_id}
        edge_type_dict: {relation_name: relation_type_id}
    """
    node_dict = {}
    edge_type_dict = {}
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['date', 'head', 'tail', 'relation_type'])
        for date in out_dict:
            for head, tail, relation_type in out_dict[date]:
                if head not in node_dict:
                    node_dict[head] = len(node_dict)
                if tail not in node_dict:
                    node_dict[tail] = len(node_dict)
                if relation_type not in edge_type_dict:
                    edge_type_dict[relation_type] = len(edge_type_dict)
                writer.writerow([date, node_dict[head], node_dict[tail], edge_type_dict[relation_type]])
    return node_dict, edge_type_dict
def writeEdgeTypeMapping(edge_type_dict, outname):
    r"""
    Write the relation-type mapping to a csv file.
    NOTE(review): rows are (type_name, id) while the header reads
    ['edge_id', 'type'] — kept as-is for output compatibility.
    """
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['edge_id', 'type'])
        writer.writerows(edge_type_dict.items())
def main():
    # merge every raw .tab file in the folder into one integer-id edgelist
    total_lines = 0
    total_edge_dict = {}
    for file in glob.glob("*.tab"):
        print ("processing", file)
        edge_dict, num_lines = load_csv_raw(file)
        total_lines += num_lines
        print ("-----------------------------------")
        print ("file, ", file)
        print ("number of lines, ", num_lines)
        print ("number of days, ", len(edge_dict))
        print ("-----------------------------------")
        # NOTE(review): update() replaces whole-day buckets when two files
        # share a timestamp rather than merging their edge counts — confirm
        total_edge_dict.update(edge_dict)
    outname = "tkgl-icews_edgelist_tiny.csv"
    print ("total number of lines", total_lines)
    print ("total number of days", len(total_edge_dict))
    node_dict, edge_type_dict = write2csv(outname, total_edge_dict)
    writeEdgeTypeMapping(edge_type_dict, "tkgl-icews_edgemapping.csv")
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tkgl_icews/tkgl_icews_ns_gen.py
================================================
import time
from tgb.linkproppred.tkg_negative_generator import TKGNegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate and persist negative edge sets for the validation and test
    splits of tkgl-icews.
    """
    print("*** Negative Sample Generation ***")
    # sampling configuration (-1 means use all candidate negatives)
    num_neg_e_per_pos = -1
    neg_sample_strategy = "time-filtered"
    rnd_seed = 42
    name = "tkgl-icews"
    dataset = PyGLinkPropPredDataset(name=name, root="datasets")
    data = dataset.get_TemporalData()
    data_splits = {
        'train': data[dataset.train_mask],
        'val': data[dataset.val_mask],
        'test': data[dataset.test_mask],
    }
    # Ensure to only sample actual destination nodes as negatives.
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
    neg_sampler = TKGNegativeEdgeGenerator(
        dataset_name=name,
        first_dst_id=min_dst_idx,
        last_dst_id=max_dst_idx,
        num_neg_e=num_neg_e_per_pos,
        strategy=neg_sample_strategy,
        rnd_seed=rnd_seed,
        edge_data=data,
    )
    partial_path = "."
    # generate the validation and then the test negative edge sets
    for split_mode in ("val", "test"):
        start_time = time.time()
        print(
            f"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}"
        )
        neg_sampler.generate_negative_samples(
            pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}"
        )
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tkgl_polecat/tkgl_polecat.py
================================================
import csv
import datetime
import glob, os
def load_csv_raw(fname):
    r"""
    Parse one raw POLECAT tab-separated file and bucket (actor, recipient,
    event-type) edges by the unix timestamp of their event date, counting
    duplicates.

    Column layout: date=row[1], relation=row[2], head=row[7], tail=row[15].
    Rows with an empty date or a "None" field are skipped.

    Returns:
        out_dict: {ts: {(head, tail, relation_type): count}}
        num_lines: number of rows kept after filtering
    """
    TIME_FORMAT = "%Y-%m-%d"  # e.g. 2018-01-01
    out_dict = {}
    num_lines = 0
    with open(fname, 'r', encoding='ISO-8859-1') as f:
        reader = csv.reader(f, delimiter='\t')
        next(reader, None)  # skip header
        for row in reader:
            date, relation_type = row[1], row[2]
            head, tail = row[7], row[15]
            if not date:
                continue
            if ("None" in date or "None" in head or "None" in tail or "None" in relation_type):
                continue
            # note: naive strptime, so .timestamp() uses the local timezone
            ts = int(datetime.datetime.strptime(date, TIME_FORMAT).timestamp())
            num_lines += 1
            bucket = out_dict.setdefault(ts, {})
            key = (head, tail, relation_type)
            bucket[key] = bucket.get(key, 0) + 1
    return out_dict, num_lines
#! fill in node and edge type dictionaries
def write2csv(outname: str,
              out_dict: dict,
              edge_type_dict: dict = None,
              node_dict: dict = None,):
    r"""
    Write the edge dict to a csv file sorted by date, remapping node names
    and relation types to consecutive integer ids.

    Existing id mappings may be passed in; they are extended in place and
    returned, so ids stay consistent across multiple calls.

    Returns:
        edge_type_dict: {relation_name: relation_type_id}
        node_dict: {node_name: node_id}
    """
    if edge_type_dict is None:
        edge_type_dict = {}
    if node_dict is None:
        node_dict = {}
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['date', 'head', 'tail', 'relation_type'])
        for date in sorted(out_dict):
            for head, tail, relation_type in out_dict[date]:
                # new names get the next consecutive id
                if head not in node_dict:
                    node_dict[head] = len(node_dict)
                if tail not in node_dict:
                    node_dict[tail] = len(node_dict)
                if relation_type not in edge_type_dict:
                    edge_type_dict[relation_type] = len(edge_type_dict)
                writer.writerow([date, node_dict[head], node_dict[tail], edge_type_dict[relation_type]])
    return edge_type_dict, node_dict
# def write2csv(outname, out_dict):
# node_dict = {}
# max_node_id = 0
# edge_type_dict = {}
# max_edge_type_id = 0
# with open(outname, 'w') as f:
# writer = csv.writer(f, delimiter =',')
# writer.writerow(['date', 'head', 'tail', 'relation_type'])
# for date in out_dict:
# for edge in out_dict[date]:
# head = edge[0]
# tail = edge[1]
# relation_type = edge[2]
# if head not in node_dict:
# node_dict[head] = max_node_id
# max_node_id += 1
# if tail not in node_dict:
# node_dict[tail] = max_node_id
# max_node_id += 1
# if relation_type not in edge_type_dict:
# edge_type_dict[relation_type] = max_edge_type_id
# max_edge_type_id += 1
# row = [date, node_dict[head], node_dict[tail], edge_type_dict[relation_type]]
# writer.writerow(row)
def writeEdgeTypeMapping(edge_type_dict, outname):
    r"""
    Write the relation-type mapping to a csv file.
    NOTE(review): rows are (type_name, id) while the header reads
    ['edge_id', 'type'] — kept as-is for output compatibility.
    """
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['edge_id', 'type'])
        writer.writerows(edge_type_dict.items())
def main():
    # merge every monthly csv dump in the folder into one integer-id edgelist
    total_lines = 0
    num_days = 0
    total_edge_dict = {}
    for file in glob.glob("*.csv"):
        outname = file[0:7] + "_edgelist.csv"
        print ("processing", file, "to", outname)
        edge_dict, num_lines = load_csv_raw(file)
        total_lines += num_lines
        num_days += len(edge_dict)
        total_edge_dict.update(edge_dict)
    edge_type_dict, node_dict = write2csv("tkgl-polecat_edgelist.csv", total_edge_dict)
    print ("-----------------------------------")
    print ("total number of lines", total_lines)
    print ("total number of days", num_days)
    print ("there are", len(edge_type_dict), "unique edge types")
    print ("there are", len(node_dict), "unique nodes")
    writeEdgeTypeMapping(edge_type_dict, "tkgl-polecat_edgemapping.csv")
if __name__ == "__main__":
    main()
#* rename functions
# renames = []
# for file in glob.glob("*.txt"):
# outname = file[-12:-4] + "_edgelist.csv"
# file_rename = file[-12:-4] + "_raw.csv"
# if ("Jan" in outname):
# outname = outname.replace("Jan", "01")
# renames.append((file, file_rename.replace("Jan", "01")))
# elif ("Feb" in outname):
# outname = outname.replace("Feb", "02")
# renames.append((file, file_rename.replace("Feb", "02")))
# elif ("Mar" in outname):
# outname = outname.replace("Mar", "03")
# renames.append((file, file_rename.replace("Mar", "03")))
# elif ("Apr" in outname):
# outname = outname.replace("Apr", "04")
# renames.append((file, file_rename.replace("Apr", "04")))
# elif ("May" in outname):
# outname = outname.replace("May", "05")
# renames.append((file, file_rename.replace("May", "05")))
# elif ("Jun" in outname):
# outname = outname.replace("Jun", "06")
# renames.append((file, file_rename.replace("Jun", "06")))
# elif ("Jul" in outname):
# outname = outname.replace("Jul", "07")
# renames.append((file, file_rename.replace("Jul", "07")))
# elif ("Aug" in outname):
# outname = outname.replace("Aug", "08")
# renames.append((file, file_rename.replace("Aug", "08")))
# elif ("Sep" in outname):
# outname = outname.replace("Sep", "09")
# renames.append((file, file_rename.replace("Sep", "09")))
# elif ("Oct" in outname):
# outname = outname.replace("Oct", "10")
# renames.append((file, file_rename.replace("Oct", "10")))
# elif ("Nov" in outname):
# outname = outname.replace("Nov", "11")
# renames.append((file, file_rename.replace("Nov", "11")))
# elif ("Dec" in outname):
# outname = outname.replace("Dec", "12")
# renames.append((file, file_rename.replace("Dec", "12")))
# for file, file_rename in renames:
# os.rename(file, file_rename)
================================================
FILE: tgb/datasets/tkgl_polecat/tkgl_polecat_ns_gen.py
================================================
import time
from tgb.linkproppred.tkg_negative_generator import TKGNegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate and persist negative edge sets for the validation and test
    splits of the tkgl-polecat dataset.

    Loads the dataset, splits it with the provided masks, then delegates to
    TKGNegativeEdgeGenerator with the time-filtered strategy. The val and
    test splits are processed by the same loop (they were previously two
    copy-pasted code paths).
    """
    print("*** Negative Sample Generation ***")

    # setting the required parameters
    num_neg_e_per_pos = -1  # -1: keep every candidate destination
    neg_sample_strategy = "time-filtered"
    rnd_seed = 42
    name = "tkgl-polecat"

    dataset = PyGLinkPropPredDataset(name=name, root="datasets")
    data = dataset.get_TemporalData()
    data_splits = {
        'train': data[dataset.train_mask],
        'val': data[dataset.val_mask],
        'test': data[dataset.test_mask],
    }

    # Ensure to only sample actual destination nodes as negatives.
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
    neg_sampler = TKGNegativeEdgeGenerator(
        dataset_name=name,
        first_dst_id=min_dst_idx,
        last_dst_id=max_dst_idx,
        num_neg_e=num_neg_e_per_pos,
        strategy=neg_sample_strategy,
        rnd_seed=rnd_seed,
        edge_data=data,
    )

    # generate the evaluation negative edge sets (val, then test)
    partial_path = "."
    for split_mode in ("val", "test"):
        start_time = time.time()
        print(
            f"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}"
        )
        neg_sampler.generate_negative_samples(
            pos_edges=data_splits[split_mode],
            split_mode=split_mode,
            partial_path=partial_path,
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}"
        )


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tkgl_smallpedia/smallpedia_remove_conflict.py
================================================
import csv
def load_static_edgelist(file_path):
    r"""
    Load the static edgelist from the file_path
    Args:
        file_path: str, The path to the file
    Returns:
        dict mapping (head, tail, relation_type) -> 1
    """
    static_dict = {}
    with open(file_path, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader, None)  # skip the header row
        for row in reader:
            static_dict[(row[0], row[1], row[2])] = 1
    return static_dict
def load_temporal_edgelist(file_path):
    r"""
    Load the temporal edgelist from the file_path
    Expected rows: ts,head,tail,relation_type (after a header line), e.g.
    0,Q331755,Q1294765,P39
    Args:
        file_path: str, The path to the file
    Returns:
        dict mapping ts -> {(head, tail, relation_type): count}
    """
    temporal_dict = {}
    with open(file_path, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader, None)  # skip the header row
        for row in reader:
            ts = int(row[0])
            key = (row[1], row[2], row[3])
            bucket = temporal_dict.setdefault(ts, {})
            bucket[key] = bucket.get(key, 0) + 1
    return temporal_dict
def remove_conflict(static_dict, temporal_dict):
    r"""
    Drop from static_dict every (head, tail, relation_type) triple that also
    appears, at any timestamp, in the temporal edgelist.
    Mutates static_dict in place and returns it.
    Args:
        static_dict: dict, The static edgelist
        temporal_dict: dict, The temporal edgelist
    """
    num_conflicts = 0
    for edges in temporal_dict.values():
        for key in edges:
            if key in static_dict:
                del static_dict[key]
                num_conflicts += 1
    print("Removed {} conflicts".format(num_conflicts))
    return static_dict
def write2csv(outname: str,
              out_dict: dict,):
    r"""
    Write the (head, tail, relation_type) keys of a dictionary to a csv file.

    Parameters:
        outname: path of the output csv file
        out_dict: dict keyed by (head, tail, relation_type) tuples
    """
    # newline='' is required by the csv module; without it the writer
    # emits blank rows on Windows
    with open(outname, 'w', newline='') as f:
        writer = csv.writer(f, delimiter=',')
        #head,tail,relation_type
        writer.writerow(['head', 'tail', 'relation_type'])
        for head, tail, relation_type in out_dict:
            writer.writerow([head, tail, relation_type])
def main():
    r"""
    Remove from the static edgelist every (head, tail, relation_type) triple
    that also occurs in the temporal edgelist, then write the cleaned static
    edgelist to disk.
    """
    #! remove conflict: drop edges sharing head, tail, relation_type with a temporal edge
    static_file = "tkgl-smallpedia_static_edgelist.csv"
    temporal_file = "tkgl-smallpedia_edgelist.csv"
    out_name = "tkgl-smallpedia_static_edgelist_no_conflict.csv"

    static_edges = load_static_edgelist(static_file)
    print("constructed static dictionary")
    temporal_edges = load_temporal_edgelist(temporal_file)
    print("constructed temporal dictionary")

    write2csv(out_name, remove_conflict(static_edges, temporal_edges))


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tkgl_smallpedia/tkgl_smallpedia_ns_gen.py
================================================
import time
from tgb.linkproppred.tkg_negative_generator import TKGNegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate and persist negative edge sets for the validation and test
    splits of the tkgl-smallpedia dataset.

    Loads the dataset, splits it with the provided masks, then delegates to
    TKGNegativeEdgeGenerator with the time-filtered strategy. The val and
    test splits are processed by the same loop (they were previously two
    copy-pasted code paths).
    """
    print("*** Negative Sample Generation ***")

    # setting the required parameters
    num_neg_e_per_pos = -1  #10000
    neg_sample_strategy = "time-filtered"  #"dst-time-filtered"
    rnd_seed = 42
    name = "tkgl-smallpedia"

    dataset = PyGLinkPropPredDataset(name=name, root="datasets")
    data = dataset.get_TemporalData()
    data_splits = {
        'train': data[dataset.train_mask],
        'val': data[dataset.val_mask],
        'test': data[dataset.test_mask],
    }

    # Ensure to only sample actual destination nodes as negatives.
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
    neg_sampler = TKGNegativeEdgeGenerator(
        dataset_name=name,
        first_dst_id=min_dst_idx,
        last_dst_id=max_dst_idx,
        num_neg_e=num_neg_e_per_pos,
        strategy=neg_sample_strategy,
        rnd_seed=rnd_seed,
        partial_path=".",
        edge_data=data,
    )

    # generate the evaluation negative edge sets (val, then test)
    partial_path = "."
    for split_mode in ("val", "test"):
        start_time = time.time()
        print(
            f"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}"
        )
        neg_sampler.generate_negative_samples(
            pos_edges=data_splits[split_mode],
            split_mode=split_mode,
            partial_path=partial_path,
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}"
        )


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tkgl_wikidata/extract.sh
================================================
# Run tkgl_wikidata.py once per dump chunk, for chunks 5 through 25 (of 25).
# NOTE(review): the outer for-loop has a single value (5) and only seeds the
# starting chunk index; the while-loop below performs the actual iteration.
for chunk in 5; do
num_chunk=25
while [ $chunk -le $num_chunk ]; do
# build the argument string for one extraction run
cmd="tkgl_wikidata.py \
--chunk ${chunk} \
--num_chunks ${num_chunk} \
"
python $cmd
chunk=$(( $chunk + 1 ))
done
done
================================================
FILE: tgb/datasets/tkgl_wikidata/time_edges/tkgl-wikidata_extract.py
================================================
import csv
import datetime
import glob, os
def load_time_csv_raw(fname):
    r"""
    load from the raw data and retrieve, timestamp, head, tail, relation, time_rel

    Keeps only rows whose date starts with '+' (positive/CE years; BCE dates
    start with '-') and whose year does not exceed 2024; the 4-digit year is
    used as the integer timestamp.

    Parameters:
        fname: path to the raw csv file
    Returns:
        (out_dict, num_lines): out_dict maps
        year -> {(head, tail, relation_type, time_rel): count};
        num_lines is the number of rows kept.
    """
    out_dict = {}
    num_lines = 0
    #? timestamp,head,tail,relation_type,time_rel_type
    #* +1999-01-01T00:00:00Z,Q31,Q4916,P38,P580
    # (the unused TIME_FORMAT and error_ctr locals were removed)
    with open(fname, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        first_row = True
        for row in reader:
            if first_row:
                first_row = False
                continue
            date = row[0][0:11]  # '+1999-01-01' portion of the timestamp
            head = row[1]
            tail = row[2]
            relation_type = row[3]
            time_rel = row[4]
            if len(date) == 0:
                continue
            if "None" in date or "None" in head or "None" in tail or "None" in relation_type:
                continue
            #* only keep positive (CE) years
            if date[0] != "+":
                continue
            ts = int(date[1:5])
            #* no scifi for knowledge graphs
            if ts > 2024:
                continue
            num_lines += 1
            edge = (head, tail, relation_type, time_rel)
            bucket = out_dict.setdefault(ts, {})
            bucket[edge] = bucket.get(edge, 0) + 1
    return out_dict, num_lines
def write2csv(outname: str,
              out_dict: dict,):
    r"""
    Write the temporal edge dictionary to a csv file, rows ordered by
    ascending timestamp.

    Parameters:
        outname: path of the output csv file
        out_dict: dict mapping ts -> {(head, tail, relation_type, time_rel): count}
    """
    # newline='' is required by the csv module; without it the writer
    # emits blank rows on Windows
    with open(outname, 'w', newline='') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['ts', 'head', 'tail', 'relation_type', 'time_rel_type'])
        for date in sorted(out_dict.keys()):
            for head, tail, relation_type, time_rel in out_dict[date]:
                writer.writerow([date, head, tail, relation_type, time_rel])
def update_dict(total_dict, new_dict):
    r"""
    Merge new_dict into total_dict, summing per-edge counts when a timestamp
    exists in both. Mutates total_dict in place and returns it.
    """
    for ts, edges in new_dict.items():
        if ts not in total_dict:
            total_dict[ts] = edges
            continue
        bucket = total_dict[ts]
        for edge, count in edges.items():
            bucket[edge] = bucket.get(edge, 0) + count
    return total_dict
def retrieve_all_entities(total_dict):
    r"""
    retrieve the entities from all edges of the total dictionary
    Parameters:
        total_dict: dictionary of all edges, {ts: {edge: count}}
    Returns:
        dict mapping entity id -> number of edge endpoints it appears in
    """
    node_dict = {}
    for edges in total_dict.values():
        for edge in edges:
            for node in (edge[0], edge[1]):
                node_dict[node] = node_dict.get(node, 0) + 1
    return node_dict
def writenode2csv(outname: str,
                  out_dict: dict,):
    r"""
    Write the entity-occurrence dictionary to a csv file.

    Parameters:
        outname: path of the output csv file
        out_dict: dict mapping entity id -> occurrence count
    """
    # newline='' is required by the csv module; without it the writer
    # emits blank rows on Windows
    with open(outname, 'w', newline='') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['entity', 'occurrences'])
        for node, count in out_dict.items():
            writer.writerow([node, count])
def main():
    r"""
    Merge every per-chunk csv in the current folder into one temporal edge
    dictionary, then write the entity occurrences and the merged edgelist.
    """
    #! NOTE(review): update_dict sums per-edge counts when the same timestamp
    #! appears in several chunks -- the original comment claimed overlapping
    #! timestamps were unsupported; confirm the merge rule is intended
    total_lines = 0
    total_edge_dict = {}
    # 1. find all files with .csv in the folder (each one is an extracted chunk)
    # NOTE(review): a rerun would also glob previously written output csvs
    # (e.g. wiki_entities.csv) -- verify the folder only holds raw chunks
    total_edge_file = "tkgl-wikidata_edgelist.csv"
    for file in glob.glob("*.csv"):
        print (file)
        edge_dict, num_lines = load_time_csv_raw(file)
        print ("processed ", num_lines, " lines")
        total_lines += num_lines
        # merge this chunk's edges into the running total (mutates in place)
        update_dict(total_edge_dict, edge_dict)
    print ("processed a total of ", total_lines, " lines")
    node_dict = retrieve_all_entities(total_edge_dict)
    writenode2csv("wiki_entities.csv", node_dict)
    write2csv(total_edge_file, total_edge_dict)
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tkgl_wikidata/tkgl-wikidata.py
================================================
import csv
import datetime
import glob, os
def load_time_csv(fname):
    r"""
    load from data and retrieve, ts,head,tail,relation_type,time_rel_type

    Point-in-time rows (P585/P577/P574) are recorded directly at their
    timestamp; start/end rows (P580/P582) are first collected per edge and
    then expanded into one row per year of the [start, end] interval.

    Parameters:
        fname: path to the input csv file
    Returns:
        (out_dict, num_lines): out_dict maps ts -> {(head, tail, rel): count};
        num_lines = point-in-time rows + expanded interval rows.
    """
    out_dict = {} #only contain edges {ts: {(head, tail, rel_type):count}}
    start_end_dict = {} #{(head, tail, rel_type): {start:year, end:year}}
    first_row = True
    point_in_time_lines = 0
    start_end_lines = 0
    #? ts,head,tail,relation_type,time_rel_type
    #* 0,Q331755,Q1294765,P39,P580
    with open(fname, 'r') as f:
        reader = csv.reader(f, delimiter =',')
        for row in reader:
            if first_row:
                first_row = False
                continue
            ts = int(row[0])
            head = row[1]
            tail = row[2]
            relation_type = row[3]
            time_rel = row[4]
            if (time_rel in ['P585', 'P577', 'P574']):
                # point-in-time relation: log the edge at this timestamp
                if (ts in out_dict):
                    if (head, tail, relation_type) in out_dict[ts]:
                        out_dict[ts][(head, tail, relation_type)] += 1
                    else:
                        out_dict[ts][(head, tail, relation_type)] = 1
                else:
                    out_dict[ts] = {}
                    out_dict[ts][(head, tail, relation_type)] = 1
                point_in_time_lines += 1
            else: # time_rel in ['P580', 'P582']
                # interval relation: remember the start/end year per edge
                if (head, tail, relation_type) in start_end_dict:
                    if (time_rel in ['P580']):
                        start_end_dict[(head, tail, relation_type)]['start'] = ts
                    elif (time_rel in ['P582']):
                        start_end_dict[(head, tail, relation_type)]['end'] = ts
                    else:
                        raise ValueError(f"Unknown time_rel: {time_rel}")
                else:
                    start_end_dict[(head, tail, relation_type)] = {}
                    if (time_rel in ['P580']):
                        start_end_dict[(head, tail, relation_type)]['start'] = ts
                    else:
                        # NOTE(review): treats ANY non-P580 value as an end time,
                        # while the known-edge branch above raises instead --
                        # confirm inputs are restricted to P580/P582 here
                        start_end_dict[(head, tail, relation_type)]['end'] = ts
                start_end_lines += 1
    print ("-----------------------------------")
    print ("for this edgelist:")
    print (f"point_in_time_lines: {point_in_time_lines}")
    print (f"start_end_lines: {start_end_lines}")
    print ("-----------------------------------")
    repeated_lines = 0
    no_duration_lines = 0
    #* now, repeat edges from start_end_dict
    for edge in start_end_dict.keys():
        if 'start' not in start_end_dict[edge]:
            # end time only: counted as no-duration and skipped entirely
            #start_end_dict[edge]['start'] = 0 #start at year 0
            #start_end_dict[edge]['start'] = start_end_dict[edge]['end']
            no_duration_lines += 1
            continue
        if 'end' not in start_end_dict[edge]:
            # start_end_dict[edge]['end'] = 2024 #end at year 2024
            # start time only: also counted as no-duration and skipped;
            # NOTE(review): the assignment below has no effect because of the
            # continue -- confirm single-year expansion was not intended
            start_end_dict[edge]['end'] = start_end_dict[edge]['start']
            no_duration_lines += 1
            continue
        # repeat the edge once per year of its duration, endpoints inclusive
        for year in range(start_end_dict[edge]['start'], start_end_dict[edge]['end']+1):
            if year not in out_dict:
                out_dict[year] = {}
            out_dict[year][edge] = 1
            repeated_lines += 1
    print ("-----------------------------------")
    print ("for this edgelist:")
    print (f"point_in_time_lines: {point_in_time_lines}")
    print (f"start_end_lines: {start_end_lines} resulting in")
    print (f"repeated_lines: {repeated_lines}")
    print (f"no_duration_lines: {no_duration_lines}")
    print ("-----------------------------------")
    print ("total lines: ", point_in_time_lines + repeated_lines)
    num_lines = point_in_time_lines + repeated_lines
    return out_dict, num_lines
def write2csv(outname: str,
              out_dict: dict,):
    r"""
    Write {ts: {(head, tail, relation_type): count}} to a csv file, rows
    ordered by ascending timestamp.
    """
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['ts', 'head', 'tail', 'relation_type'])
        for date in sorted(out_dict):
            for head, tail, relation_type in out_dict[date]:
                writer.writerow([date, head, tail, relation_type])
def extract_subset(fname, outname, start_year=2000, end_year=2024):
    r"""
    Copy the rows of a temporal edgelist whose timestamp falls within
    [start_year, end_year] into outname and return the kept nodes.

    Input format:
        ts,head,tail,relation_type
        0,Q331755,Q1294765,P39
    Returns:
        dict mapping each kept node id -> 1
    """
    node_dict = {}
    rel_type = {}
    with open(outname, 'w') as out_f:
        writer = csv.writer(out_f, delimiter=',')
        writer.writerow(['ts', 'head', 'tail', 'relation_type'])
        with open(fname, 'r') as in_f:
            reader = csv.reader(in_f, delimiter=',')
            next(reader, None)  # skip header
            for row in reader:
                ts = int(row[0])
                if not (start_year <= ts <= end_year):
                    continue
                head, tail, relation_type = row[1], row[2], row[3]
                node_dict.setdefault(head, 1)
                node_dict.setdefault(tail, 1)
                rel_type.setdefault(relation_type, 1)
                writer.writerow([ts, head, tail, relation_type])
        print ("there are ",len(rel_type), " relation types")
    return node_dict
def extract_subset_nodeid(fname, outname, start_year=2000, end_year=2024, max_id=1000000):
    r"""
    Like extract_subset, but additionally drops any edge whose head or tail
    has a numeric Q-id larger than max_id.

    Input format:
        ts,head,tail,relation_type
        0,Q331755,Q1294765,P39
    Returns:
        dict mapping each kept node id -> 1
    """
    node_dict = {}
    rel_type = {}
    with open(outname, 'w') as out_f:
        writer = csv.writer(out_f, delimiter=',')
        writer.writerow(['ts', 'head', 'tail', 'relation_type'])
        with open(fname, 'r') as in_f:
            reader = csv.reader(in_f, delimiter=',')
            next(reader, None)  # skip header
            for row in reader:
                ts = int(row[0])
                head, tail = row[1], row[2]
                # numeric part of the Q-id, e.g. 'Q42' -> 42
                if int(head[1:]) > max_id or int(tail[1:]) > max_id:
                    continue
                if start_year <= ts <= end_year:
                    relation_type = row[3]
                    node_dict.setdefault(head, 1)
                    node_dict.setdefault(tail, 1)
                    rel_type.setdefault(relation_type, 1)
                    writer.writerow([ts, head, tail, relation_type])
        print ("there are ",len(rel_type), " relation types")
    return node_dict
def extract_static_subset(fname, outname, node_dict, max_id=1000000):
    r"""
    extract static edges based a given node dict
    A row is kept when its head or tail is in node_dict and both Q-ids are
    at most max_id; kept rows are written to outname.

    Input format:
        head,tail,relation_type
        Q31,Q1088364,P1344
    Returns:
        dict mapping relation_type -> number of kept rows with that relation
    """
    rel_type = {}
    full_node = {}
    with open(outname, 'w') as out_f:
        writer = csv.writer(out_f, delimiter=',')
        writer.writerow(['head', 'tail', 'relation_type'])
        with open(fname, 'r') as in_f:
            reader = csv.reader(in_f, delimiter=',')
            next(reader, None)  # skip header
            for row in reader:
                head, tail, relation_type = row[0], row[1], row[2]
                # numeric part of the Q-id, e.g. 'Q42' -> 42
                if int(head[1:]) > max_id or int(tail[1:]) > max_id:
                    continue
                if head in node_dict or tail in node_dict:  # need to check
                    writer.writerow([head, tail, relation_type])
                    rel_type[relation_type] = rel_type.get(relation_type, 0) + 1
                    full_node.setdefault(head, 1)
                    full_node.setdefault(tail, 1)
        print ("there are ",len(rel_type), " relation types")
        print ("there are ",len(full_node), " nodes in static edgelist")
    return rel_type
#! not used, filter by top edgetypes
def subset_static_edges(fname, outname, rel_type, topk=10):
    r"""
    Keep only the edges whose relation type is among the topk most frequent
    entries of rel_type; write them to outname.
    """
    #* select edges based on frequency (ascending sort, take the last topk)
    top_rels = sorted(rel_type.items(), key=lambda kv: kv[1])[-topk:]
    rel_kept = {}
    for rel, count in top_rels:
        rel_kept[rel] = 1
        print (rel, count)
    kept_nodes = {}
    # rel_kept = {"P17":1, "P27":1, "P495":1, "P19": 1}
    with open(outname, 'w') as out_f:
        writer = csv.writer(out_f, delimiter=',')
        writer.writerow(['head', 'tail', 'relation_type'])
        with open(fname, 'r') as in_f:
            reader = csv.reader(in_f, delimiter=',')
            next(reader, None)  # skip header
            for row in reader:
                head, tail, relation_type = row[0], row[1], row[2]
                if relation_type in rel_kept:
                    kept_nodes.setdefault(head, 1)
                    kept_nodes.setdefault(tail, 1)
                    writer.writerow([head, tail, relation_type])
    print ("there are ",len(kept_nodes), " nodes in static edgelist")
def main():
    r"""
    Build the tkgl-smallpedia subset from the full tkgl-wikidata edgelists:
    filter the temporal edgelist by year range and node id, then extract the
    static edges touching the kept nodes.
    """
    # #* repeat the edges of start and end dates
    # """
    # P580: start time
    # P582: end time
    # P585: point in time
    # P577: publication date
    # P574: year of publication of scientific name for taxon
    # we need to:
    # 1. get all edges with P585, P577 and P574
    # 2. find out which edges has both start and end time
    # 3. for those without start time, start at year 0, without end time, end at year 2024
    # """
    # fname = "tkgl-wikidata_edgelist_raw.csv"
    # out_dict, num_lines = load_time_csv(fname)
    # outname = "tkgl-wikidata_edgelist.csv"
    # write2csv(outname, out_dict)
    inputfile = "tkgl-wikidata_edgelist.csv"
    outname = "tkgl-smallpedia_edgelist.csv"
    # start_year = 2015
    # keep only edges whose year is inside [start_year, end_year]
    start_year=1900#1700
    end_year=2024#1800
    # keep only nodes whose numeric Q-id is at most max_id
    max_id=1000000
    # node_dict = extract_subset(inputfile, outname, start_year=start_year, end_year=end_year)
    node_dict = extract_subset_nodeid(inputfile, outname, start_year=start_year, end_year=end_year, max_id=max_id)
    print ("there are ",len(node_dict), " nodes")
    # static edges incident to the kept temporal nodes
    inputfile = "tkgl-wikidata_static_edgelist.csv"
    outname = "tkgl-smallpedia_static_edgelist.csv"
    rel_type = extract_static_subset(inputfile, outname, node_dict, max_id=max_id)
    #! not used
    # inputfile = "tkgl-smallpedia_static_edgelist.csv"
    # outname = "tkgl-smallpedia_static_edgelist_top10.csv"
    # topk=10
    # subset_static_edges(inputfile, outname, rel_type, topk=topk)
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tkgl_wikidata/tkgl_wikidata_mining.py
================================================
r"""
How to use
python tkgl_wikidata.py --chunk 0 --num_chunks 25
# python tkgl_wikidata.py --chunk 1 --num_chunks 25
"""
from qwikidata.entity import WikidataItem
from qwikidata.json_dump import WikidataJsonDump
from qwikidata.datavalue import get_datavalue_from_snak_dict, WikibaseEntityId
from tqdm import tqdm
from collections import defaultdict
import os.path as osp
import os
import pickle
import argparse
import numpy as np
import csv
def timeEdgeWrite2csv(outname, out_dict):
    r"""
    Write (timestamp, head, tail, relation_type, time_rel_type) edge keys
    to a csv file.
    """
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['timestamp', 'head', 'tail', 'relation_type', 'time_rel_type'])
        for ts, src, dst, rel_type, time_rel_type in out_dict:
            writer.writerow([ts, src, dst, rel_type, time_rel_type])
def EdgeWrite2csv(outname, out_dict):
    r"""
    Write (head, tail, relation_type) edge keys to a csv file.
    """
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['head', 'tail', 'relation_type'])
        for src, dst, rel_type in out_dict:
            writer.writerow([src, dst, rel_type])
def main():
    r"""
    Stream one chunk of a Wikidata JSON dump and extract edges.

    Entities are split into num_chunks index ranges; only entities whose
    position falls in this process's [start_idx, end_idx) range are parsed.
    Claims carrying a recognized time qualifier are written as temporal
    edges (timestamp, head, tail, relation, time_rel); other claims are
    written as static edges (head, tail, relation).
    """
    parser = argparse.ArgumentParser(description='Process some integers.')
    # parser.add_argument('--split', type=str, default = 'train',
    #                     help='an integer for the accumulator')
    parser.add_argument('--chunk', type=int, default = 0,
                        help='an integer for the accumulator')
    parser.add_argument('--num_chunks', type=int, default = 10,
                        help='an integer for the accumulator')
    args = parser.parse_args()
    print(args)
    assert args.chunk < args.num_chunks
    # # create an instance of WikidataJsonDump
    # if args.split == 'train':
    #     wjd_dump_path_original = "wikidata-20210517-all.json.gz"
    # elif args.split == 'val':
    #     wjd_dump_path_original = "wikidata-20210607-all.json.gz"
    # elif args.split == 'test':
    #     wjd_dump_path_original = 'wikidata-20210628-all.json.gz'
    # else:
    #     raise ValueError('Unknown split')
    #* download here
    #? https://dumps.wikimedia.org/wikidatawiki/entities/
    wjd_dump_path_original = "wikidata-20240220-all.json.gz" #"latest-all_03_Apr_2024_12_49.json"
    wjd_dump_path = osp.join('dump', wjd_dump_path_original)
    wjd = WikidataJsonDump(wjd_dump_path)
    print(wjd_dump_path)
    """
    # head = entity_dict['id']
    # type = entity_dict['type']
    # labels = entity_dict['labels']
    # descriptions = entity_dict['descriptions']
    # aliases = entity_dict['aliases']
    # if ('claims' in entity_dict):
    #     claims = entity_dict['claims']
    #     for key in claims.keys():
    #         print (key)
    #         print (claims[key])
    # sitelinks = entity_dict['sitelinks']
    """
    time_edge_dict = {} #{()}
    time_rel_dict = {}  # never written or read below; kept for parity
    static_edge_dict = {}
    dummy_rel_set = ['P31','P279'] #filter out instance of and subclass of
    time_rel_set = ['P585','P580', 'P582', 'P577', 'P574'] #point in time, start time, end time, publication date,year of publication of scientific name for taxon
    num_totals = 100000000 #4000000 #10000000 #110000000
    # evenly partition [0, num_totals] into num_chunks index ranges
    tmp = np.linspace(0, num_totals, args.num_chunks + 1).astype(np.int64)
    start_idx = tmp[args.chunk]
    end_idx = tmp[args.chunk + 1]
    print('Start: ', start_idx)
    print('End: ', end_idx)
    #? output format is (timestamp, head, tail, relation_type, time_rel_type)
    for i, entity_dict in enumerate(tqdm(wjd, total=(end_idx))):
        #! entity_dict keys(['type', 'id', 'labels', 'descriptions', 'aliases', 'claims', 'sitelinks', 'pageid', 'ns', 'title', 'lastrevid', 'modified'])
        if i > end_idx:
            break
        if not (start_idx <= i and i < end_idx):
            continue
        head = entity_dict['id']
        # head needs to start from 'Q'
        if head[0] == 'Q':
            head_id = head
            if 'claims' in entity_dict:
                claim_dict = entity_dict['claims']
                rel_list = list(claim_dict.keys())
                for rel in rel_list:
                    if (rel in dummy_rel_set):
                        continue
                    tail_list = claim_dict[rel]
                    for tail in tail_list:
                        tail_id = None
                        #* first check if there is a valid tail
                        if (tail['mainsnak']['datatype'] == 'wikibase-item'):
                            if ('rank' in tail) and (tail['rank'] != 'deprecated') and ('datavalue' in tail['mainsnak']):
                                if 'id' in tail['mainsnak']['datavalue']['value']:
                                    tail_id = tail['mainsnak']['datavalue']['value']['id']
                                else:
                                    tail_id = 'Q' + str(tail['mainsnak']['datavalue']['value']['numeric-id'])
                        #* check if there is a qualifier and if it is a time qualifier
                        if (tail_id is not None):
                            if ("qualifiers" in tail):
                                time_logged = False
                                for q in tail["qualifiers"]:
                                    for item in tail["qualifiers"][q]:
                                        if (item['datatype'] == 'time') and ('datavalue' in item):
                                            timestr = item['datavalue']['value']['time']
                                            time_rel_type = q
                                            if (time_rel_type in time_rel_set):
                                                time_edge_dict[(timestr, head_id, tail_id, rel, time_rel_type)] = 1
                                                time_logged = True
                                            else:
                                                # NOTE(review): this resets time_logged even if an
                                                # earlier qualifier already logged a temporal edge,
                                                # so the claim can ALSO be stored as a static edge
                                                # below -- confirm this is intended
                                                time_logged = False
                                if not time_logged:
                                    static_edge_dict[(head_id, tail_id, rel)] = 1
                            else:
                                static_edge_dict[(head_id, tail_id, rel)] = 1
    #! write edges to file
    print ("there are ", len(time_edge_dict), " temporal edges in the dataset")
    outname = "time_edgelist_" + str(args.chunk) + ".csv" #"tkgl-wikidata_time_edgelist.csv"
    timeEdgeWrite2csv(outname, time_edge_dict)
    print ("there are ", len(static_edge_dict), " static edges in the dataset")
    outname = "static_edgelist_" + str(args.chunk) + ".csv" #"tkgl-wikidata_static_edgelist.csv"
    EdgeWrite2csv(outname, static_edge_dict)
if __name__ == '__main__':
    main()
================================================
FILE: tgb/datasets/tkgl_wikidata/tkgl_wikidata_ns_gen.py
================================================
import time
import sys
import os
import os.path as osp
from pathlib import Path
modules_path = osp.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
sys.path.append(modules_path)
from tgb.linkproppred.tkg_negative_generator import TKGNegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate and persist negative edge sets for the validation and test
    splits of the tkgl-wikidata dataset.

    Loads the dataset, splits it with the provided masks, then delegates to
    TKGNegativeEdgeGenerator with the dst-time-filtered strategy. The val
    and test splits are processed by the same loop (they were previously
    two copy-pasted code paths).
    """
    print("*** Negative Sample Generation ***")

    # setting the required parameters
    num_neg_e_per_pos = 1000  #10000
    neg_sample_strategy = "dst-time-filtered"  #"time-filtered"
    rnd_seed = 42
    name = "tkgl-wikidata"

    dataset = PyGLinkPropPredDataset(name=name, root="datasets")
    data = dataset.get_TemporalData()
    data_splits = {
        'train': data[dataset.train_mask],
        'val': data[dataset.val_mask],
        'test': data[dataset.test_mask],
    }

    # Ensure to only sample actual destination nodes as negatives.
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
    neg_sampler = TKGNegativeEdgeGenerator(
        dataset_name=name,
        first_dst_id=min_dst_idx,
        last_dst_id=max_dst_idx,
        num_neg_e=num_neg_e_per_pos,
        strategy=neg_sample_strategy,
        rnd_seed=rnd_seed,
        partial_path=".",
        edge_data=data,
    )

    # generate the evaluation negative edge sets (val, then test)
    partial_path = "."
    for split_mode in ("val", "test"):
        start_time = time.time()
        print(
            f"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}"
        )
        neg_sampler.generate_negative_samples(
            pos_edges=data_splits[split_mode],
            split_mode=split_mode,
            partial_path=partial_path,
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}"
        )


if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tkgl_wikidata/wikidata_remove_conflict.py
================================================
import csv
def load_static_edgelist(file_path):
    r"""
    Load the static edgelist from the file_path
    Args:
        file_path: str, The path to the file
    Returns:
        dict keyed by (head, tail, relation_type) with value 1
    """
    static_dict = {}
    with open(file_path, 'r') as f:
        rows = csv.reader(f, delimiter=',')
        header_skipped = False
        for row in rows:
            if not header_skipped:
                header_skipped = True
                continue
            static_dict[(row[0], row[1], row[2])] = 1
    return static_dict
def load_temporal_edgelist(file_path):
    r"""
    Load the temporal edgelist from the file_path
    Expected rows: ts,head,tail,relation_type (after a header line), e.g.
    0,Q331755,Q1294765,P39
    Args:
        file_path: str, The path to the file
    Returns:
        dict mapping ts -> {(head, tail, relation_type): count}
    """
    temporal_dict = {}
    with open(file_path, 'r') as f:
        rows = csv.reader(f, delimiter=',')
        header_skipped = False
        for row in rows:
            if not header_skipped:
                header_skipped = True
                continue
            ts = int(row[0])
            edge = (row[1], row[2], row[3])
            bucket = temporal_dict.setdefault(ts, {})
            bucket[edge] = bucket.get(edge, 0) + 1
    return temporal_dict
def remove_conflict(static_dict, temporal_dict):
    r"""
    Remove the conflict between the static and temporal edgelist: any
    (head, tail, relation_type) triple present in the temporal edgelist at
    any timestamp is deleted from static_dict (mutated in place).
    Args:
        static_dict: dict, The static edgelist
        temporal_dict: dict, The temporal edgelist
    Returns:
        the (mutated) static_dict
    """
    num_conflicts = 0
    for edges_at_ts in temporal_dict.values():
        for triple in edges_at_ts:
            if triple in static_dict:
                static_dict.pop(triple)
                num_conflicts += 1
    print("Removed {} conflicts".format(num_conflicts))
    return static_dict
def write2csv(outname: str,
              out_dict: dict,):
    r"""
    Write the (head, tail, relation_type) keys of a dictionary to a csv file.

    Parameters:
        outname: path of the output csv file
        out_dict: dict keyed by (head, tail, relation_type) tuples
    """
    # newline='' is required by the csv module; without it the writer
    # emits blank rows on Windows
    with open(outname, 'w', newline='') as f:
        writer = csv.writer(f, delimiter=',')
        #head,tail,relation_type
        writer.writerow(['head', 'tail', 'relation_type'])
        for head, tail, relation_type in out_dict:
            writer.writerow([head, tail, relation_type])
def main():
    r"""
    Remove from the tkgl-wikidata static edgelist every
    (head, tail, relation_type) triple that also occurs in the temporal
    edgelist, then write the cleaned static edgelist to disk.
    """
    #! remove conflict: remove all edges with the same head, tail, relation_type from the static edgelist
    static_file = "tkgl-wikidata_static_edgelist.csv"
    temporal_file = "tkgl-wikidata_edgelist.csv"
    static_dict = load_static_edgelist(static_file)
    print("constructed static dictionary")
    temporal_dict = load_temporal_edgelist(temporal_file)
    print("constructed temporal dictionary")
    # remove_conflict mutates static_dict in place and returns it
    static_dict = remove_conflict(static_dict, temporal_dict)
    out_name = "tkgl-wikidata_static_edgelist_no_conflict.csv"
    write2csv(out_name, static_dict)
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tkgl_yago/tkgl_yago.py
================================================
import csv
import datetime
import glob, os
def main():
    r"""
    Combine the yago train/valid/test splits into one csv edgelist.
    Later splits overwrite earlier ones per timestamp via dict.update.
    """
    combined = {}
    for split_name, fname in (("train", "train.txt"),
                              ("val", "valid.txt"),
                              ("test", "test.txt")):
        split_dict, num_lines = load_csv(fname)
        print ("there are ", num_lines, " lines in the " + split_name + " file")
        print ("there are ", len(split_dict), " timestamps in the " + split_name + " file")
        combined.update(split_dict)
    print ("there are ", len(combined), " timestamps in the combined file")
    write_csv("tkgl-yago_edgelist.csv", combined)
def write_csv(outname, out_dict):
    r"""
    Write {ts: {(src, rel_type, dst): 1}} to a csv file with columns
    timestamp, head, tail, relation_type.
    """
    with open(outname, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['timestamp', 'head', 'tail', 'relation_type'])
        for ts, edges in out_dict.items():
            for src, rel_type, dst in edges:
                writer.writerow([ts, src, dst, rel_type])
def load_csv(fname):
    r"""
    Read a tab-separated file of (src, rel_type, dst, ts) rows and return
    ({ts: {(src, rel_type, dst): 1}}, line_count).
    """
    out_dict = {}
    num_lines = 0
    with open(fname, 'r') as f:
        reader = csv.reader(f, delimiter='\t')
        #! src rel_type dst ts
        # 10289 9 10290 0 0
        for row in reader:
            src = int(row[0])
            rel_type = int(row[1])
            dst = int(row[2])
            ts = int(row[3])
            out_dict.setdefault(ts, {})[(src, rel_type, dst)] = 1
            num_lines += 1
    return out_dict, num_lines
if __name__ == "__main__":
    main()
================================================
FILE: tgb/datasets/tkgl_yago/tkgl_yago_ns_gen.py
================================================
import time
import sys
import os.path as osp

# Make the repository root importable when this script is run from its own
# directory. The original inserted the literal string '/../../../', which the
# OS resolves to the filesystem root, so the repo path was never actually
# added; build it relative to this file instead, exactly as
# tkgl_wikidata_ns_gen.py does.
sys.path.insert(0, osp.abspath(osp.join(osp.dirname(__file__), '..', '..', '..')))

from tgb.linkproppred.tkg_negative_generator import TKGNegativeEdgeGenerator
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
def main():
    r"""
    Generate negative edges for the validation or test phase
    """
    print("*** Negative Sample Generation ***")

    # fixed generation parameters; -1 negatives-per-positive means exhaustive
    num_neg_e_per_pos = -1
    neg_sample_strategy = "time-filtered"
    rnd_seed = 42

    name = "tkgl-yago"
    dataset = PyGLinkPropPredDataset(name=name, root="datasets")
    data = dataset.get_TemporalData()

    # slice the full TemporalData into the official chronological splits
    data_splits = {
        'train': data[dataset.train_mask],
        'val': data[dataset.val_mask],
        'test': data[dataset.test_mask],
    }

    # Ensure to only sample actual destination nodes as negatives.
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

    neg_sampler = TKGNegativeEdgeGenerator(
        dataset_name=name,
        first_dst_id=min_dst_idx,
        last_dst_id=max_dst_idx,
        num_neg_e=num_neg_e_per_pos,
        strategy=neg_sample_strategy,
        rnd_seed=rnd_seed,
        edge_data=data,
    )

    # generate the evaluation sets for both held-out splits
    partial_path = "."
    for split_mode in ("val", "test"):
        start_time = time.time()
        print(
            f"INFO: Start generating negative samples: {split_mode} --- {neg_sample_strategy}"
        )
        neg_sampler.generate_negative_samples(
            pos_edges=data_splits[split_mode], split_mode=split_mode, partial_path=partial_path
        )
        print(
            f"INFO: End of negative samples generation. Elapsed Time (s): {time.time() - start_time: .4f}"
        )
if __name__ == "__main__":
main()
================================================
FILE: tgb/linkproppred/dataset.py
================================================
import sys
from typing import Optional, Dict, Any, Tuple
import os
import os.path as osp
import numpy as np
import pandas as pd
import zipfile
import requests
from clint.textui import progress
from tgb.linkproppred.negative_sampler import NegativeEdgeSampler
from tgb.linkproppred.tkg_negative_sampler import TKGNegativeEdgeSampler
from tgb.linkproppred.thg_negative_sampler import THGNegativeEdgeSampler
from tgb.utils.info import (
PROJ_DIR,
DATA_URL_DICT,
DATA_VERSION_DICT,
DATA_EVAL_METRIC_DICT,
DATA_NS_STRATEGY_DICT,
BColors
)
from tgb.utils.pre_process import (
csv_to_pd_data,
process_node_feat,
process_node_type,
csv_to_pd_data_sc,
csv_to_pd_data_rc,
load_edgelist_wiki,
csv_to_tkg_data,
csv_to_thg_data,
csv_to_forum_data,
csv_to_wikidata,
csv_to_staticdata,
)
from tgb.utils.utils import save_pkl, load_pkl
from tgb.utils.utils import add_inverse_quadruples, vprint
class LinkPropPredDataset(object):
    """Core (framework-agnostic) dataset class for TGB link property prediction.

    Handles download, version checking, preprocessing and chronological
    train/val/test splitting of a named TGB dataset, and exposes the matching
    negative sampler (plain, TKG or THG depending on the dataset family).
    """

    def __init__(
        self,
        name: str,
        root: str = "datasets",
        meta_dict: Optional[dict] = None,
        preprocess: Optional[bool] = True,
        download: Optional[bool] = True,
    ):
        r"""Dataset class for link prediction dataset. Stores meta information about each dataset such as evaluation metrics etc.
        also automatically pre-processes the dataset.
        Args:
            name: name of the dataset
            root: root directory to store the dataset folder
            meta_dict: dictionary containing meta information about the dataset, should contain key 'dir_name' which is the name of the dataset folder
            preprocess: whether to pre-process the dataset
            download: whether to download the dataset (default: true)
        """
        self.name = name  ## original name
        # check if dataset url exist
        if self.name in DATA_URL_DICT:
            self.url = DATA_URL_DICT[self.name]
        else:
            self.url = None
        # check if the evaluation metric is specified; unsupported datasets are rejected
        if self.name in DATA_EVAL_METRIC_DICT:
            self.metric = DATA_EVAL_METRIC_DICT[self.name]
        else:
            self.metric = None
            raise ValueError(f"Dataset {self.name} default evaluation metric not found, it is not supported yet.")

        root = PROJ_DIR + root

        if meta_dict is None:
            self.dir_name = "_".join(name.split("-"))  ## replace hyphen with underline
            meta_dict = {"dir_name": self.dir_name}
        else:
            self.dir_name = meta_dict["dir_name"]
        self.root = osp.join(root, self.dir_name)
        self.meta_dict = meta_dict
        # default file locations (only filled in when the caller did not supply them)
        if "fname" not in self.meta_dict:
            self.meta_dict["fname"] = self.root + "/" + self.name + "_edgelist.csv"
            self.meta_dict["nodefile"] = None
            if name == "tgbl-flight":
                self.meta_dict["nodefile"] = self.root + "/" + "airport_node_feat.csv"
            if name == "tkgl-wikidata" or name == "tkgl-smallpedia":
                self.meta_dict["staticfile"] = self.root + "/" + self.name + "_static_edgelist.csv"

        # temporal heterogeneous graphs additionally ship a node-type file
        if "thg" in name:
            self.meta_dict["nodeTypeFile"] = self.root + "/" + self.name + "_nodetype.csv"
        else:
            self.meta_dict["nodeTypeFile"] = None

        # pre-generated negative-sample files used at evaluation time
        self.meta_dict["val_ns"] = self.root + "/" + self.name + "_val_ns.pkl"
        self.meta_dict["test_ns"] = self.root + "/" + self.name + "_test_ns.pkl"

        #! version check (may rewrite the file names above to versioned ones)
        self.version_passed = True
        self._version_check()

        # initialize
        self._node_feat = None
        self._edge_feat = None
        self._full_data = None
        self._train_data = None
        self._val_data = None
        self._test_data = None

        # for tkg and thg
        self._edge_type = None
        # tkgl-wikidata and tkgl-smallpedia only
        self._static_data = None

        # for thg only
        self._node_type = None
        self._node_id = None

        if download:
            self.download()
        else:
            if osp.exists(self.meta_dict["fname"]):
                dir_name = self.meta_dict["fname"]
                vprint(f"files found in {dir_name}")
            else:
                dir_name = self.meta_dict["fname"]
                raise FileNotFoundError(f"Directory not found at {dir_name}, please download the dataset")

        # the dataset directory must exist by now (download() creates it);
        # otherwise nothing can be loaded
        if osp.isdir(self.root):
            vprint("Dataset directory is ", self.root)
        else:
            raise FileNotFoundError(f"Directory not found at {self.root}")

        if preprocess:
            self.pre_process()
            self.min_dst_idx, self.max_dst_idx = int(self._full_data["destinations"].min()), int(self._full_data["destinations"].max())
            # pick the negative sampler matching the dataset family
            if ('tkg' in self.name):
                if self.name in DATA_NS_STRATEGY_DICT:
                    self.ns_sampler = TKGNegativeEdgeSampler(
                        dataset_name=self.name,
                        first_dst_id=self.min_dst_idx,
                        last_dst_id=self.max_dst_idx,
                        strategy=DATA_NS_STRATEGY_DICT[self.name],
                        partial_path=self.root + "/" + self.name,
                    )
                else:
                    raise ValueError(f"Dataset {self.name} negative sampling strategy not found.")
            elif ('thg' in self.name):
                #* need to find the smallest node id of all nodes (regardless of types)
                min_node_idx = min(int(self._full_data["sources"].min()), int(self._full_data["destinations"].min()))
                max_node_idx = max(int(self._full_data["sources"].max()), int(self._full_data["destinations"].max()))
                self.ns_sampler = THGNegativeEdgeSampler(
                    dataset_name=self.name,
                    first_node_id=min_node_idx,
                    last_node_id=max_node_idx,
                    node_type=self._node_type,
                )
            else:
                self.ns_sampler = NegativeEdgeSampler(
                    dataset_name=self.name,
                    first_dst_id=self.min_dst_idx,
                    last_dst_id=self.max_dst_idx,
                )

    def _version_check(self) -> None:
        r"""Implement Version checks for dataset files
        updates the file names based on the current version number
        prompt the user to download the new version via self.version_passed variable
        """
        if (self.name in DATA_VERSION_DICT):
            version = DATA_VERSION_DICT[self.name]
        else:
            raise ValueError(f"Dataset {self.name} version number not found.")
        if (version > 1):
            #* check if current version is outdated; point meta files at the versioned names
            self.meta_dict["fname"] = self.root + "/" + self.name + "_edgelist_v" + str(int(version)) + ".csv"
            self.meta_dict["nodefile"] = None
            if self.name == "tgbl-flight":
                self.meta_dict["nodefile"] = self.root + "/" + "airport_node_feat_v" + str(int(version)) + ".csv"
            self.meta_dict["val_ns"] = self.root + "/" + self.name + "_val_ns_v" + str(int(version)) + ".pkl"
            self.meta_dict["test_ns"] = self.root + "/" + self.name + "_test_ns_v" + str(int(version)) + ".pkl"
            if (not osp.exists(self.meta_dict["fname"])):
                vprint(f"Dataset {self.name} version {int(version)} not found, Please download the latest version of the dataset.")
                self.version_passed = False
                return None

    def download(self) -> None:
        """
        downloads this dataset from url
        check if files are already downloaded
        """
        # check if the file already exists; skip the download if so
        if osp.exists(self.meta_dict["fname"]):
            dir_name = self.meta_dict["fname"]
            vprint(f"files found in {dir_name}")
            return None
        vprint(
            f"{BColors.WARNING}Download started, this might take a while . . . {BColors.ENDC}"
        )
        vprint(f"Dataset title: {self.name}")
        if self.url is None:
            raise ValueError(f"Dataset {self.name} url not found, download not supported yet.")
        else:
            r = requests.get(self.url, stream=True)
            if osp.isdir(self.root):
                vprint("Dataset directory is ", self.root)
            else:
                os.makedirs(self.root)
            path_download = self.root + "/" + self.name + ".zip"
            print(f"downloading Dataset: {self.name} to {path_download}")
            with open(path_download, "wb") as f:
                total_length = int(r.headers.get("content-length"))
                # stream the zip in 1 KiB chunks with a console progress bar
                for chunk in progress.bar(
                    r.iter_content(chunk_size=1024),
                    expected_size=(total_length / 1024) + 1,
                ):
                    if chunk:
                        f.write(chunk)
                        f.flush()
            # for unzipping the file
            with zipfile.ZipFile(path_download, "r") as zip_ref:
                zip_ref.extractall(self.root)
            vprint(f"{BColors.OKGREEN}Download completed {BColors.ENDC}")
            self.version_passed = True

    def generate_processed_files(self) -> Tuple[pd.DataFrame, Any, Optional[np.ndarray]]:
        r"""
        turns raw data .csv file into a pandas data frame, stored on disc if not already
        Returns:
            df: pandas data frame of the edgelist
            edge_feat: edge features (dataset dependent)
            node_feat: node features, or None when the dataset has none
        """
        node_feat = None
        if not osp.exists(self.meta_dict["fname"]):
            raise FileNotFoundError(f"File not found at {self.meta_dict['fname']}")

        if self.meta_dict["nodefile"] is not None:
            if not osp.exists(self.meta_dict["nodefile"]):
                raise FileNotFoundError(
                    f"File not found at {self.meta_dict['nodefile']}"
                )
        #* for thg must have nodetypes
        if self.meta_dict["nodeTypeFile"] is not None:
            if not osp.exists(self.meta_dict["nodeTypeFile"]):
                raise FileNotFoundError(
                    f"File not found at {self.meta_dict['nodeTypeFile']}"
                )

        # cached (pickled) artifact locations
        OUT_DF = self.root + "/" + "ml_{}.pkl".format(self.name)
        OUT_EDGE_FEAT = self.root + "/" + "ml_{}.pkl".format(self.name + "_edge")
        OUT_NODE_ID = self.root + "/" + "ml_{}.pkl".format(self.name + "_nodeid")
        if self.meta_dict["nodefile"] is not None:
            OUT_NODE_FEAT = self.root + "/" + "ml_{}.pkl".format(self.name + "_node")
        if self.meta_dict["nodeTypeFile"] is not None:
            OUT_NODE_TYPE = self.root + "/" + "ml_{}.pkl".format(self.name + "_nodeType")

        if osp.exists(OUT_DF) and self.version_passed is True:
            # fast path: everything was processed before, load the pickles
            vprint(f"loading processed file from {OUT_DF}.")
            df = pd.read_pickle(OUT_DF)
            edge_feat = load_pkl(OUT_EDGE_FEAT)
            if (self.name == "tkgl-wikidata") or (self.name == "tkgl-smallpedia"):
                node_id = load_pkl(OUT_NODE_ID)
                self._node_id = node_id
            if self.meta_dict["nodefile"] is not None:
                node_feat = load_pkl(OUT_NODE_FEAT)
            if self.meta_dict["nodeTypeFile"] is not None:
                node_type = load_pkl(OUT_NODE_TYPE)
                self._node_type = node_type
        else:
            vprint("file not processed, generating processed file")
            # per-dataset raw-csv parsers
            if self.name == "tgbl-flight":
                df, edge_feat, node_ids = csv_to_pd_data(self.meta_dict["fname"])
            elif self.name == "tgbl-coin":
                df, edge_feat, node_ids = csv_to_pd_data_sc(self.meta_dict["fname"])
            elif self.name == "tgbl-comment":
                df, edge_feat, node_ids = csv_to_pd_data_rc(self.meta_dict["fname"])
            elif self.name == "tgbl-review":
                df, edge_feat, node_ids = csv_to_pd_data_sc(self.meta_dict["fname"])
            elif self.name == "tgbl-wiki":
                df, edge_feat, node_ids = load_edgelist_wiki(self.meta_dict["fname"])
            elif self.name == "tgbl-subreddit":
                df, edge_feat, node_ids = load_edgelist_wiki(self.meta_dict["fname"])
            elif self.name == "tgbl-uci":
                df, edge_feat, node_ids = load_edgelist_wiki(self.meta_dict["fname"])
            elif self.name == "tgbl-enron":
                df, edge_feat, node_ids = load_edgelist_wiki(self.meta_dict["fname"])
            elif self.name == "tgbl-lastfm":
                df, edge_feat, node_ids = load_edgelist_wiki(self.meta_dict["fname"])
            elif self.name == "tkgl-polecat":
                df, edge_feat, node_ids = csv_to_tkg_data(self.meta_dict["fname"])
            elif self.name == "tkgl-icews":
                df, edge_feat, node_ids = csv_to_tkg_data(self.meta_dict["fname"])
            elif self.name == "tkgl-yago":
                df, edge_feat, node_ids = csv_to_tkg_data(self.meta_dict["fname"])
            elif self.name == "tkgl-wikidata":
                df, edge_feat, node_ids = csv_to_wikidata(self.meta_dict["fname"])
                # wikidata keeps the raw-id -> internal-id mapping for static edges
                save_pkl(node_ids, OUT_NODE_ID)
                self._node_id = node_ids
            elif self.name == "tkgl-smallpedia":
                df, edge_feat, node_ids = csv_to_wikidata(self.meta_dict["fname"])
                save_pkl(node_ids, OUT_NODE_ID)
                self._node_id = node_ids
            elif self.name == "thgl-myket":
                df, edge_feat, node_ids = csv_to_thg_data(self.meta_dict["fname"])
            elif self.name == "thgl-github":
                df, edge_feat, node_ids = csv_to_thg_data(self.meta_dict["fname"])
            elif self.name == "thgl-forum":
                df, edge_feat, node_ids = csv_to_forum_data(self.meta_dict["fname"])
            elif self.name == "thgl-software":
                df, edge_feat, node_ids = csv_to_thg_data(self.meta_dict["fname"])
            else:
                raise ValueError(f"Dataset {self.name} not found.")

            # cache the processed artifacts for the fast path above
            save_pkl(edge_feat, OUT_EDGE_FEAT)
            df.to_pickle(OUT_DF)
            if self.meta_dict["nodefile"] is not None:
                node_feat = process_node_feat(self.meta_dict["nodefile"], node_ids)
                save_pkl(node_feat, OUT_NODE_FEAT)
            if self.meta_dict["nodeTypeFile"] is not None:
                node_type = process_node_type(self.meta_dict["nodeTypeFile"], node_ids)
                save_pkl(node_type, OUT_NODE_TYPE)
                #? do not return node_type, simply set it
                self._node_type = node_type

        return df, edge_feat, node_feat

    def pre_process(self):
        """
        Pre-process the dataset and generates the splits, must be run before dataset properties can be accessed
        generates the edge data and different train, val, test splits
        """
        # check if path to file is valid
        df, edge_feat, node_feat = self.generate_processed_files()
        #* design choice, only stores the original edges not the inverse relations on disc
        if ("tkgl" in self.name):
            df = add_inverse_quadruples(df)
        sources = np.array(df["u"])
        destinations = np.array(df["i"])
        timestamps = np.array(df["ts"])
        edge_idxs = np.array(df["idx"])
        weights = np.array(df["w"])

        edge_label = np.ones(len(df))  # should be 1 for all pos edges
        # dataset-specific choice of what serves as the edge feature
        if (self.name == "tgbl-coin") or (self.name == "tgbl-review"):
            self._edge_feat = weights.reshape(-1,1)
        elif (self.name == "tgbl-comment"):
            self._edge_feat = np.concatenate((edge_feat, weights.reshape(-1,1)), axis=1)
        else:
            self._edge_feat = edge_feat
        self._node_feat = node_feat

        full_data = {
            "sources": sources.astype(int),
            "destinations": destinations.astype(int),
            "timestamps": timestamps.astype(int),
            "edge_idxs": edge_idxs,
            "edge_feat": self._edge_feat,
            "w": weights,
            "edge_label": edge_label,
        }
        #* for tkg and thg
        if ("edge_type" in df):
            edge_type = np.array(df["edge_type"]).astype(int)
            self._edge_type = edge_type
            full_data["edge_type"] = edge_type
        self._full_data = full_data

        # yago uses a smaller val/test fraction than the default 15/15 split
        if ("yago" in self.name):
            _train_mask, _val_mask, _test_mask = self.generate_splits(full_data, val_ratio=0.1, test_ratio=0.10) #99) #val_ratio=0.097, test_ratio=0.099)
        else:
            _train_mask, _val_mask, _test_mask = self.generate_splits(full_data, val_ratio=0.15, test_ratio=0.15)
        self._train_mask = _train_mask
        self._val_mask = _val_mask
        self._test_mask = _test_mask

    def generate_splits(
        self,
        full_data: Dict[str, Any],
        val_ratio: float = 0.15,
        test_ratio: float = 0.15,
    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
        r"""Generates train, validation, and test splits from the full dataset
        splits are chronological: edges are assigned by timestamp quantiles.
        Args:
            full_data: dictionary containing the full dataset
            val_ratio: ratio of validation data
            test_ratio: ratio of test data
        Returns:
            train_data: dictionary containing the training dataset
            val_data: dictionary containing the validation dataset
            test_data: dictionary containing the test dataset
        """
        val_time, test_time = list(
            np.quantile(
                full_data["timestamps"],
                [(1 - val_ratio - test_ratio), (1 - test_ratio)],
            )
        )
        timestamps = full_data["timestamps"]
        train_mask = timestamps <= val_time
        val_mask = np.logical_and(timestamps <= test_time, timestamps > val_time)
        test_mask = timestamps > test_time
        return train_mask, val_mask, test_mask

    def preprocess_static_edges(self):
        """
        Pre-process the static edges of the dataset
        (only datasets with a 'staticfile' entry, i.e. tkgl-wikidata and tkgl-smallpedia)
        """
        if ("staticfile" in self.meta_dict):
            OUT_DF = self.root + "/" + "ml_{}.pkl".format(self.name + "_static")
            if osp.exists(OUT_DF) and self.version_passed is True:
                vprint(f"loading processed file from {OUT_DF}.")
                static_dict = load_pkl(OUT_DF)
                self._static_data = static_dict
            else:
                vprint("file not processed, generating processed file")
                static_dict, node_ids = csv_to_staticdata(self.meta_dict["staticfile"], self._node_id)
                save_pkl(static_dict, OUT_DF)
                self._static_data = static_dict
        else:
            vprint ("static edges are only for tkgl-wikidata and tkgl-smallpedia datasets")

    @property
    def eval_metric(self) -> str:
        """
        the official evaluation metric for the dataset, loaded from info.py
        Returns:
            eval_metric: str, the evaluation metric
        """
        return self.metric

    @property
    def negative_sampler(self) -> NegativeEdgeSampler:
        r"""
        Returns the negative sampler of the dataset, will load negative samples from disc
        Returns:
            negative_sampler: NegativeEdgeSampler
        """
        return self.ns_sampler

    def load_val_ns(self) -> None:
        r"""
        load the negative samples for the validation set
        """
        self.ns_sampler.load_eval_set(
            fname=self.meta_dict["val_ns"], split_mode="val"
        )

    def load_test_ns(self) -> None:
        r"""
        load the negative samples for the test set
        """
        self.ns_sampler.load_eval_set(
            fname=self.meta_dict["test_ns"], split_mode="test"
        )

    @property
    def num_nodes(self) -> int:
        r"""
        Returns the total number of unique nodes in the dataset
        Returns:
            num_nodes: int, the number of unique nodes
        """
        src = self._full_data["sources"]
        dst = self._full_data["destinations"]
        all_nodes = np.concatenate((src, dst), axis=0)
        uniq_nodes = np.unique(all_nodes, axis=0)
        return uniq_nodes.shape[0]

    @property
    def num_edges(self) -> int:
        r"""
        Returns the total number of edges in the dataset
        Returns:
            num_edges: int, the number of edges
        """
        src = self._full_data["sources"]
        return src.shape[0]

    @property
    def num_rels(self) -> int:
        r"""
        Returns the number of relation types in the dataset
        Returns:
            num_rels: int, the number of relation types
        """
        #* if it is a homogenous graph
        if ("edge_type" not in self._full_data):
            return 1
        else:
            return np.unique(self._full_data["edge_type"]).shape[0]

    @property
    def node_feat(self) -> Optional[np.ndarray]:
        r"""
        Returns the node features of the dataset with dim [N, feat_dim]
        Returns:
            node_feat: np.ndarray, [N, feat_dim] or None if there is no node feature
        """
        return self._node_feat

    @property
    def node_type(self) -> Optional[np.ndarray]:
        r"""
        Returns the node types of the dataset with dim [N], only for temporal heterogeneous graphs
        Returns:
            node_type: np.ndarray, [N] or None if there is no node type
        """
        return self._node_type

    @property
    def edge_feat(self) -> Optional[np.ndarray]:
        r"""
        Returns the edge features of the dataset with dim [E, feat_dim]
        Returns:
            edge_feat: np.ndarray, [E, feat_dim] or None if there is no edge feature
        """
        return self._edge_feat

    @property
    def edge_type(self) -> Optional[np.ndarray]:
        r"""
        Returns the edge types of the dataset with dim [E, 1], only for temporal knowledge graph and temporal heterogeneous graph
        Returns:
            edge_type: np.ndarray, [E, 1] or None if it is not a TKG or THG
        """
        return self._edge_type

    @property
    def static_data(self) -> Optional[np.ndarray]:
        r"""
        Returns the static edges related to this dataset, applies for tkgl-wikidata and tkgl-smallpedia, edges are (src, dst, rel_type)
        Returns None implicitly for any other dataset.
        Returns:
            df: pd.DataFrame {"head": np.ndarray, "tail": np.ndarray, "rel_type": np.ndarray}
        """
        if (self.name == "tkgl-wikidata") or (self.name == "tkgl-smallpedia"):
            self.preprocess_static_edges()
            return self._static_data

    @property
    def full_data(self) -> Dict[str, Any]:
        r"""
        the full data of the dataset as a dictionary with keys: 'sources', 'destinations', 'timestamps', 'edge_idxs', 'edge_feat', 'w', 'edge_label',
        Returns:
            full_data: Dict[str, Any]
        """
        if self._full_data is None:
            raise ValueError(
                "dataset has not been processed yet, please call pre_process() first"
            )
        return self._full_data

    @property
    def train_mask(self) -> np.ndarray:
        r"""
        Returns the train mask of the dataset
        Returns:
            train_mask: training masks
        """
        if self._train_mask is None:
            raise ValueError("training split hasn't been loaded")
        return self._train_mask

    @property
    def val_mask(self) -> np.ndarray:
        r"""
        Returns the validation mask of the dataset
        Returns:
            val_mask: boolean mask over all edges
        """
        if self._val_mask is None:
            raise ValueError("validation split hasn't been loaded")
        return self._val_mask

    @property
    def test_mask(self) -> np.ndarray:
        r"""
        Returns the test mask of the dataset:
        Returns:
            test_mask: boolean mask over all edges
        """
        if self._test_mask is None:
            raise ValueError("test split hasn't been loaded")
        return self._test_mask
def main():
    """Manual smoke test: build tkgl-polecat and touch its edge types."""
    dataset = LinkPropPredDataset(name="tkgl-polecat", root="datasets", preprocess=True)
    # accessing the property is enough to exercise download + preprocessing
    dataset.edge_type
if __name__ == "__main__":
main()
================================================
FILE: tgb/linkproppred/dataset_pyg.py
================================================
import torch
from typing import Optional, Optional, Callable
from torch_geometric.data import Dataset, TemporalData
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.linkproppred.negative_sampler import NegativeEdgeSampler
class PyGLinkPropPredDataset(Dataset):
    """PyTorch-Geometric wrapper around :class:`LinkPropPredDataset`.

    Converts the underlying numpy arrays to torch tensors and exposes them
    both as individual tensors (src, dst, ts, ...) and as a ``TemporalData``
    object.
    """

    def __init__(
        self,
        name: str,
        root: str = "datasets",
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        download: Optional[bool] = True,
    ):
        r"""
        PyG wrapper for the LinkPropPredDataset
        can return pytorch tensors for src,dst,t,msg,label
        can return Temporal Data object
        Parameters:
            name: name of the dataset, passed to `LinkPropPredDataset`
            root (string): Root directory where the dataset should be saved, passed to `LinkPropPredDataset`
            transform (callable, optional): A function/transform that takes in an, not used in this case
            pre_transform (callable, optional): A function/transform that takes in, not used in this case
            download (optional, bool): download or not (default True)
        """
        self.name = name
        self.root = root
        # the underlying numpy-based dataset handles download/preprocess/splits
        self.dataset = LinkPropPredDataset(name=name, root=root, download=download)
        self._train_mask = torch.from_numpy(self.dataset.train_mask)
        self._val_mask = torch.from_numpy(self.dataset.val_mask)
        self._test_mask = torch.from_numpy(self.dataset.test_mask)
        # PyG Dataset base-class init (runs transform/pre_transform machinery)
        super().__init__(root, transform, pre_transform)
        self._node_feat = self.dataset.node_feat
        self._edge_type = None
        self._static_data = None

        # convert node features to a float tensor when they exist
        if self._node_feat is None:
            self._node_feat = None
        else:
            self._node_feat = torch.from_numpy(self._node_feat).float()

        # node types (THG datasets only) become a long tensor
        self._node_type = self.dataset.node_type
        if self.node_type is not None:
            self._node_type = torch.from_numpy(self.dataset.node_type).long()

        self.process_data()

        self._ns_sampler = self.dataset.negative_sampler

    @property
    def eval_metric(self) -> str:
        """
        the official evaluation metric for the dataset, loaded from info.py
        Returns:
            eval_metric: str, the evaluation metric
        """
        return self.dataset.eval_metric

    @property
    def negative_sampler(self) -> NegativeEdgeSampler:
        r"""
        Returns the negative sampler of the dataset, will load negative samples from disc
        Returns:
            negative_sampler: NegativeEdgeSampler
        """
        return self._ns_sampler

    @property
    def num_nodes(self) -> int:
        r"""
        Returns the total number of unique nodes in the dataset
        Returns:
            num_nodes: int, the number of unique nodes
        """
        return self.dataset.num_nodes

    @property
    def num_rels(self) -> int:
        r"""
        Returns the total number of unique relations in the dataset
        Returns:
            num_rels: int, the number of unique relations
        """
        return self.dataset.num_rels

    @property
    def num_edges(self) -> int:
        r"""
        Returns the total number of edges in the dataset
        Returns:
            num_edges: int, the number of edges
        """
        return self.dataset.num_edges

    def load_val_ns(self) -> None:
        r"""
        load the negative samples for the validation set
        """
        self.dataset.load_val_ns()

    def load_test_ns(self) -> None:
        r"""
        load the negative samples for the test set
        """
        self.dataset.load_test_ns()

    @property
    def train_mask(self) -> torch.Tensor:
        r"""
        Returns the train mask of the dataset
        Returns:
            train_mask: the mask for edges in the training set
        """
        if self._train_mask is None:
            raise ValueError("training split hasn't been loaded")
        return self._train_mask

    @property
    def val_mask(self) -> torch.Tensor:
        r"""
        Returns the validation mask of the dataset
        Returns:
            val_mask: the mask for edges in the validation set
        """
        if self._val_mask is None:
            raise ValueError("validation split hasn't been loaded")
        return self._val_mask

    @property
    def test_mask(self) -> torch.Tensor:
        r"""
        Returns the test mask of the dataset:
        Returns:
            test_mask: the mask for edges in the test set
        """
        if self._test_mask is None:
            raise ValueError("test split hasn't been loaded")
        return self._test_mask

    @property
    def node_feat(self) -> torch.Tensor:
        r"""
        Returns the node features of the dataset
        Returns:
            node_feat: the node features
        """
        return self._node_feat

    @property
    def node_type(self) -> torch.Tensor:
        r"""
        Returns the node types of the dataset
        Returns:
            node_type: the node types [N]
        """
        return self._node_type

    @property
    def src(self) -> torch.Tensor:
        r"""
        Returns the source nodes of the dataset
        Returns:
            src: the idx of the source nodes
        """
        return self._src

    @property
    def dst(self) -> torch.Tensor:
        r"""
        Returns the destination nodes of the dataset
        Returns:
            dst: the idx of the destination nodes
        """
        return self._dst

    @property
    def ts(self) -> torch.Tensor:
        r"""
        Returns the timestamps of the dataset
        Returns:
            ts: the timestamps of the edges
        """
        return self._ts

    @property
    def static_data(self) -> torch.Tensor:
        r"""
        Returns the static data of the dataset for tkgl-wikidata and tkgl-smallpedia
        (converted lazily to long tensors on first access, then cached)
        Returns:
            static_data: the static data of the dataset
        """
        if (self._static_data is None):
            static_dict = {}
            static_dict["head"] = torch.from_numpy(self.dataset.static_data["head"]).long()
            static_dict["tail"] = torch.from_numpy(self.dataset.static_data["tail"]).long()
            static_dict["edge_type"] = torch.from_numpy(self.dataset.static_data["edge_type"]).long()
            self._static_data = static_dict
            return self._static_data
        else:
            return self._static_data

    @property
    def edge_type(self) -> torch.Tensor:
        r"""
        Returns the edge types for each edge
        Returns:
            edge_type: edge type tensor (int)
        """
        return self._edge_type

    @property
    def edge_feat(self) -> torch.Tensor:
        r"""
        Returns the edge features of the dataset
        Returns:
            edge_feat: the edge features
        """
        return self._edge_feat

    @property
    def edge_label(self) -> torch.Tensor:
        r"""
        Returns the edge labels of the dataset
        Returns:
            edge_label: the labels of the edges
        """
        return self._edge_label

    def process_data(self) -> None:
        r"""
        convert the numpy arrays from dataset to pytorch tensors
        """
        src = torch.from_numpy(self.dataset.full_data["sources"])
        dst = torch.from_numpy(self.dataset.full_data["destinations"])
        ts = torch.from_numpy(self.dataset.full_data["timestamps"])
        msg = torch.from_numpy(
            self.dataset.full_data["edge_feat"]
        )  # use edge features here if available
        edge_label = torch.from_numpy(
            self.dataset.full_data["edge_label"]
        )  # this is the label indicating if an edge is a true edge, always 1 for true edges
        w = torch.from_numpy(
            self.dataset.full_data["w"]
        )
        # * first check typing for all tensors
        # source tensor must be of type int64
        # warnings.warn("sources tensor is not of type int64 or int32, forcing conversion")
        if src.dtype != torch.int64:
            src = src.long()

        # destination tensor must be of type int64
        if dst.dtype != torch.int64:
            dst = dst.long()

        # timestamp tensor must be of type int64
        if ts.dtype != torch.int64:
            ts = ts.long()

        # message tensor must be of type float32
        if msg.dtype != torch.float32:
            msg = msg.float()

        # weight tensor must be of type float32
        if w.dtype != torch.float32:
            w = w.float()

        #* for tkg
        if ("edge_type" in self.dataset.full_data):
            edge_type = torch.from_numpy(self.dataset.full_data["edge_type"])
            if edge_type.dtype != torch.int64:
                edge_type = edge_type.long()
            self._edge_type = edge_type

        self._src = src
        self._dst = dst
        self._ts = ts
        self._edge_label = edge_label
        self._edge_feat = msg
        self._w = w

    def get_TemporalData(self) -> TemporalData:
        """
        return the TemporalData object for the entire dataset
        (includes edge_type only when the dataset provides one)
        """
        if (self._edge_type is not None):
            data = TemporalData(
                src=self._src,
                dst=self._dst,
                t=self._ts,
                msg=self._edge_feat,
                y=self._edge_label,
                edge_type=self._edge_type,
                w=self._w,
            )
        else:
            data = TemporalData(
                src=self._src,
                dst=self._dst,
                t=self._ts,
                msg=self._edge_feat,
                y=self._edge_label,
                w=self._w,
            )
        return data

    def len(self) -> int:
        """
        size of the dataset
        Returns:
            size: int
        """
        return self._src.shape[0]

    def get(self, idx: int) -> TemporalData:
        """
        construct temporal data object for a single edge
        Parameters:
            idx: index of the edge
        Returns:
            data: TemporalData object
        """
        if (self._edge_type is not None):
            data = TemporalData(
                src=self._src[idx],
                dst=self._dst[idx],
                t=self._ts[idx],
                msg=self._edge_feat[idx],
                y=self._edge_label[idx],
                edge_type=self._edge_type[idx]
            )
        else:
            data = TemporalData(
                src=self._src[idx],
                dst=self._dst[idx],
                t=self._ts[idx],
                msg=self._edge_feat[idx],
                y=self._edge_label[idx],
            )
        return data

    def __repr__(self) -> str:
        return f"{self.name.capitalize()}()"
================================================
FILE: tgb/linkproppred/evaluate.py
================================================
"""
Evaluator Module for Dynamic Link Prediction
"""
import numpy as np
from sklearn.metrics import *
from tgb.utils.info import DATA_EVAL_METRIC_DICT
from tgb.utils.utils import vprint
try:
import torch
except ImportError:
torch = None
class Evaluator(object):
    r"""Evaluator for Link Property Prediction: computes `hits@k` and `mrr`
    from positive and negative predicted scores."""
    def __init__(self, name: str, k_value: int = 10):
        r"""
        Parameters:
            name: name of the dataset
            k_value: the desired 'k' value for calculating metric@k
        Raises:
            NotImplementedError: if `name` is not a supported dataset
        """
        self.name = name
        self.k_value = k_value  # for computing `hits@k`
        self.valid_metric_list = ['hits@', 'mrr']
        if self.name not in DATA_EVAL_METRIC_DICT:
            raise NotImplementedError("Dataset not supported")
    def _parse_and_check_input(self, input_dict):
        r"""
        Check whether the input has the appropriate format
        Parameters:
            input_dict: a dictionary containing "y_pred_pos", "y_pred_neg", and "eval_metric"
                note: "eval_metric" should be a list including one or more of the following metrics: ["hits@", "mrr"]
        Returns:
            y_pred_pos: positive predicted scores (numpy array)
            y_pred_neg: negative predicted scores (numpy array)
        Raises:
            RuntimeError: if a required key is missing, or the scores are neither numpy arrays nor torch tensors
            ValueError: if "eval_metric" is empty or contains an unsupported metric
        """
        if "eval_metric" not in input_dict:
            raise RuntimeError("Missing key of eval_metric!")
        if not input_dict["eval_metric"]:
            # an empty metric list would otherwise leave the scores unbound below (UnboundLocalError on return)
            raise ValueError("eval_metric cannot be empty!")
        for eval_metric in input_dict["eval_metric"]:
            if eval_metric in self.valid_metric_list:
                # report the actual missing key names (previously said "y_true"/"y_pred")
                if "y_pred_pos" not in input_dict:
                    raise RuntimeError("Missing key of y_pred_pos!")
                if "y_pred_neg" not in input_dict:
                    raise RuntimeError("Missing key of y_pred_neg!")
                y_pred_pos, y_pred_neg = input_dict["y_pred_pos"], input_dict["y_pred_neg"]
                # converting to numpy on cpu
                if torch is not None and isinstance(y_pred_pos, torch.Tensor):
                    y_pred_pos = y_pred_pos.detach().cpu().numpy()
                if torch is not None and isinstance(y_pred_neg, torch.Tensor):
                    y_pred_neg = y_pred_neg.detach().cpu().numpy()
                # check type and shape
                if not isinstance(y_pred_pos, np.ndarray) or not isinstance(y_pred_neg, np.ndarray):
                    raise RuntimeError(
                        "Arguments to Evaluator need to be either numpy ndarray or torch tensor!"
                    )
            else:
                raise ValueError(f"Unsupported eval metric: {eval_metric}, not found in {self.valid_metric_list}")
        self.eval_metric = input_dict["eval_metric"]
        return y_pred_pos, y_pred_neg
    def _eval_hits_and_mrr(self, y_pred_pos, y_pred_neg, type_info, k_value):
        r"""
        compute hits@k and mrr
        reference:
            - https://github.com/snap-stanford/ogb/blob/d5c11d91c9e1c22ed090a2e0bbda3fe357de66e7/ogb/linkproppred/evaluate.py#L214
        Parameters:
            y_pred_pos: positive predicted scores; one score per positive edge
            y_pred_neg: negative predicted scores; one row of negatives per positive edge
            type_info: type of the predicted scores; could be 'torch' or 'numpy'
            k_value: the desired 'k' value for calculating metric@k
        Returns:
            a dictionary containing the computed performance metrics
        """
        if type_info == 'torch':
            # calculate ranks
            y_pred_pos = y_pred_pos.view(-1, 1)
            # optimistic rank: "how many negatives have a larger score than the positive?"
            # ~> the positive is ranked first among those with equal score
            optimistic_rank = (y_pred_neg > y_pred_pos).sum(dim=1)
            # pessimistic rank: "how many negatives have at least the positive score?"
            # ~> the positive is ranked last among those with equal score
            pessimistic_rank = (y_pred_neg >= y_pred_pos).sum(dim=1)
            # ties are broken by averaging the optimistic and pessimistic ranks
            ranking_list = 0.5 * (optimistic_rank + pessimistic_rank) + 1
            hitsK_list = (ranking_list <= k_value).to(torch.float)
            mrr_list = 1./ranking_list.to(torch.float)
            return {
                f'hits@{k_value}': hitsK_list.mean(),
                'mrr': mrr_list.mean()
            }
        else:
            # numpy equivalent of the torch branch above
            y_pred_pos = y_pred_pos.reshape(-1, 1)
            optimistic_rank = (y_pred_neg > y_pred_pos).sum(axis=1)
            pessimistic_rank = (y_pred_neg >= y_pred_pos).sum(axis=1)
            ranking_list = 0.5 * (optimistic_rank + pessimistic_rank) + 1
            hitsK_list = (ranking_list <= k_value).astype(np.float32)
            mrr_list = 1./ranking_list.astype(np.float32)
            return {
                f'hits@{k_value}': hitsK_list.mean(),
                'mrr': mrr_list.mean()
            }
    def eval(self,
             input_dict: dict,
             verbose: bool = False) -> dict:
        r"""
        evaluate the link prediction task
        this method is callable through an instance of this object to compute the metric
        Parameters:
            input_dict: a dictionary containing "y_pred_pos", "y_pred_neg", and "eval_metric";
                the performance metric is calculated for the provided scores
            verbose: whether to print out the computed metric (currently unused; kept for interface compatibility)
        Returns:
            perf_dict: a dictionary containing the computed performance metric
        """
        y_pred_pos, y_pred_neg = self._parse_and_check_input(input_dict)  # convert the predictions to numpy
        perf_dict = self._eval_hits_and_mrr(y_pred_pos, y_pred_neg, type_info='numpy', k_value=self.k_value)
        return perf_dict
================================================
FILE: tgb/linkproppred/negative_generator.py
================================================
"""
Sample and Generate negative edges that are going to be used for evaluation of a dynamic graph learning model
Negative samples are generated and saved to files ONLY once;
other times, they should be loaded from file with instances of the `negative_sampler.py`.
"""
import torch
import numpy as np
from torch_geometric.data import TemporalData
from tgb.utils.utils import save_pkl
import os
from tqdm import tqdm
from tgb.utils.utils import vprint
class NegativeEdgeGenerator(object):
    r"""Generates (and saves to disk) the fixed negative-edge evaluation sets
    used for consistent link-prediction evaluation across methods."""
    def __init__(
        self,
        dataset_name: str,
        first_dst_id: int,
        last_dst_id: int,
        num_neg_e: int = 100,  # number of negative edges sampled per positive edge
        strategy: str = "rnd",
        rnd_seed: int = 123,
        hist_ratio: float = 0.5,
        historical_data: TemporalData = None,
    ) -> None:
        r"""
        Negative Edge Sampler class
        this is a class for generating negative samples for a specific dataset;
        the set of the positive samples are provided, the negative samples are generated with specific strategies
        and are saved for consistent evaluation across different methods.
        negative edges are sampled with 'one_vs_many' strategy.
        it is assumed that the destination nodes are indexed sequentially with 'first_dst_id'
        and 'last_dst_id' being the first and last index, respectively.
        Parameters:
            dataset_name: name of the dataset
            first_dst_id: identity of the first destination node
            last_dst_id: identity of the last destination node
            num_neg_e: number of negative edges being generated per each positive edge
            strategy: how to generate negative edges; can be 'rnd' or 'hist_rnd'
            rnd_seed: random seed for consistency
            hist_ratio: if the strategy is 'hist_rnd', how much of the negatives are historical
            historical_data: previous records of the positive edges
        Returns:
            None
        """
        self.rnd_seed = rnd_seed
        # seeds numpy's global RNG so the generated negatives are reproducible
        np.random.seed(self.rnd_seed)
        self.dataset_name = dataset_name
        self.first_dst_id = first_dst_id
        self.last_dst_id = last_dst_id
        self.num_neg_e = num_neg_e
        assert strategy in [
            "rnd",
            "hist_rnd",
        ], "The supported strategies are `rnd` or `hist_rnd`!"
        self.strategy = strategy
        if self.strategy == "hist_rnd":
            # NOTE(review): `!= None` relies on TemporalData's `__ne__`; `is not None` would be safer
            assert (
                historical_data != None
            ), "Train data should be passed when `hist_rnd` strategy is selected."
            # hist_ratio / historical_data attributes only exist for the 'hist_rnd' strategy
            self.hist_ratio = hist_ratio
            self.historical_data = historical_data
    def generate_negative_samples(self,
                                  data: TemporalData,
                                  split_mode: str,
                                  partial_path: str,
                                  ) -> None:
        r"""
        Generate negative samples
        Parameters:
            data: an object containing positive edges information
            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits
            partial_path: in which directory save the generated negatives
        """
        # file name for saving or loading...
        filename = (
            partial_path
            + "/"
            + self.dataset_name
            + "_"
            + split_mode
            + "_"
            + "ns"
            + ".pkl"
        )
        # dispatch to the generator matching the configured strategy
        if self.strategy == "rnd":
            self.generate_negative_samples_rnd(data, split_mode, filename)
        elif self.strategy == "hist_rnd":
            self.generate_negative_samples_hist_rnd(
                self.historical_data, data, split_mode, filename
            )
        else:
            raise ValueError("Unsupported negative sample generation strategy!")
    def generate_negative_samples_rnd(self,
                                      data: TemporalData,
                                      split_mode: str,
                                      filename: str,
                                      ) -> None:
        r"""
        Generate negative samples based on the `RND` strategy:
            - for each positive edge, sample a batch of negative edges from all possible edges with the same source node
            - filter actual positive edges
        Parameters:
            data: an object containing positive edges information
            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits
            filename: name of the file containing the generated negative edges
        """
        vprint(
            f"INFO: Negative Sampling Strategy: {self.strategy}, Data Split: {split_mode}"
        )
        assert split_mode in [
            "val",
            "test",
        ], "Invalid split-mode! It should be `val` or `test`!"
        # negatives are generated exactly once; subsequent calls reuse the file on disk
        if os.path.exists(filename):
            vprint(
                f"INFO: negative samples for '{split_mode}' evaluation are already generated!"
            )
        else:
            vprint(f"INFO: Generating negative samples for '{split_mode}' evaluation!")
            # retrieve the information from the batch
            pos_src, pos_dst, pos_timestamp = (
                data.src.cpu().numpy(),
                data.dst.cpu().numpy(),
                data.t.cpu().numpy(),
            )
            # all possible destinations
            all_dst = np.arange(self.first_dst_id, self.last_dst_id + 1)
            evaluation_set = {}
            # generate a list of negative destinations for each positive edge
            pos_edge_tqdm = tqdm(
                zip(pos_src, pos_dst, pos_timestamp), total=len(pos_src)
            )
            for (
                pos_s,
                pos_d,
                pos_t,
            ) in pos_edge_tqdm:
                # collision set: destinations of positive edges sharing this source at this timestamp
                t_mask = pos_timestamp == pos_t
                src_mask = pos_src == pos_s
                fn_mask = np.logical_and(t_mask, src_mask)
                pos_e_dst_same_src = pos_dst[fn_mask]
                filtered_all_dst = np.setdiff1d(all_dst, pos_e_dst_same_src)
                '''
                when num_neg_e is larger than all possible destinations simple return all possible destinations
                '''
                if (self.num_neg_e > len(filtered_all_dst)):
                    neg_d_arr = filtered_all_dst
                else:
                    neg_d_arr = np.random.choice(
                        filtered_all_dst, self.num_neg_e, replace=False)  # never replace negatives
                evaluation_set[(pos_s, pos_d, pos_t)] = neg_d_arr
            # save the generated evaluation set to disk
            save_pkl(evaluation_set, filename)
    def generate_historical_edge_set(self,
                                     historical_data: TemporalData,
                                     ) -> tuple:
        r"""
        Generate the set of edges seen during training or validation
        ONLY `train_data` should be passed as historical data; i.e., the HISTORICAL negative edges should be selected from training data only.
        Parameters:
            historical_data: contains the positive edges observed previously
        Returns:
            historical_edges: distinct historical positive edges
            hist_edge_set_per_node: historical edges observed for each node
        """
        sources = historical_data.src.cpu().numpy()
        destinations = historical_data.dst.cpu().numpy()
        historical_edges = {}  # {(src, dst): 1} — a dict used as a set of distinct edges
        hist_e_per_node = {}  # {src: [dst, ...]} — may contain duplicates at this stage
        for src, dst in zip(sources, destinations):
            # edge-centric
            if (src, dst) not in historical_edges:
                historical_edges[(src, dst)] = 1
            # node-centric
            if src not in hist_e_per_node:
                hist_e_per_node[src] = [dst]
            else:
                hist_e_per_node[src].append(dst)
        # deduplicate each per-node destination list into a numpy array
        hist_edge_set_per_node = {}
        for src, dst_list in hist_e_per_node.items():
            hist_edge_set_per_node[src] = np.array(list(set(dst_list)))
        return historical_edges, hist_edge_set_per_node
    def generate_negative_samples_hist_rnd(
        self,
        historical_data : TemporalData,
        data: TemporalData,
        split_mode: str,
        filename: str,
    ) -> None:
        r"""
        Generate negative samples based on the `HIST-RND` strategy:
            - up to 50% of the negative samples are selected from the set of edges seen during the training with the same source node.
            - the rest of the negative edges are randomly sampled with the fixed source node.
        Parameters:
            historical_data: contains the history of the observed positive edges including
                            distinct positive edges and edges observed for each positive node
            data: an object containing positive edges information
            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits
            filename: name of the file to save generated negative edges
        Returns:
            None
        """
        vprint(
            f"INFO: Negative Sampling Strategy: {self.strategy}, Data Split: {split_mode}"
        )
        assert split_mode in [
            "val",
            "test",
        ], "Invalid split-mode! It should be `val` or `test`!"
        # negatives are generated exactly once; subsequent calls reuse the file on disk
        if os.path.exists(filename):
            vprint(
                f"INFO: negative samples for '{split_mode}' evaluation are already generated!"
            )
        else:
            vprint(f"INFO: Generating negative samples for '{split_mode}' evaluation!")
            # retrieve the information from the batch
            pos_src, pos_dst, pos_timestamp = (
                data.src.cpu().numpy(),
                data.dst.cpu().numpy(),
                data.t.cpu().numpy(),
            )
            # index the positive edges by timestamp then source, for fast collision lookup
            pos_ts_edge_dict = {}  # {ts: {src: [dsts]}}
            pos_edge_tqdm = tqdm(
                zip(pos_src, pos_dst, pos_timestamp), total=len(pos_src)
            )
            for (
                pos_s,
                pos_d,
                pos_t,
            ) in pos_edge_tqdm:
                if (pos_t not in pos_ts_edge_dict):
                    pos_ts_edge_dict[pos_t] = {pos_s: [pos_d]}
                else:
                    if (pos_s not in pos_ts_edge_dict[pos_t]):
                        pos_ts_edge_dict[pos_t][pos_s] = [pos_d]
                    else:
                        pos_ts_edge_dict[pos_t][pos_s].append(pos_d)
            # all possible destinations
            all_dst = np.arange(self.first_dst_id, self.last_dst_id + 1)
            # get seen edge history
            (
                historical_edges,
                hist_edge_set_per_node,
            ) = self.generate_historical_edge_set(historical_data)
            # sample historical edges
            max_num_hist_neg_e = int(self.num_neg_e * self.hist_ratio)
            evaluation_set = {}
            # generate a list of negative destinations for each positive edge
            pos_edge_tqdm = tqdm(
                zip(pos_src, pos_dst, pos_timestamp), total=len(pos_src)
            )
            for (
                pos_s,
                pos_d,
                pos_t,
            ) in pos_edge_tqdm:
                pos_e_dst_same_src = np.array(pos_ts_edge_dict[pos_t][pos_s])
                # sample historical edges
                num_hist_neg_e = 0
                neg_hist_dsts = np.array([])
                seen_dst = []
                if pos_s in hist_edge_set_per_node:
                    seen_dst = hist_edge_set_per_node[pos_s]
                    if len(seen_dst) >= 1:
                        # exclude destinations that collide with positives at this timestamp
                        filtered_all_seen_dst = np.setdiff1d(seen_dst, pos_e_dst_same_src)
                        #filtered_all_seen_dst = seen_dst #! no collision check
                        # cap the historical share at whatever candidates remain
                        num_hist_neg_e = (
                            max_num_hist_neg_e
                            if max_num_hist_neg_e <= len(filtered_all_seen_dst)
                            else len(filtered_all_seen_dst)
                        )
                        neg_hist_dsts = np.random.choice(
                            filtered_all_seen_dst, num_hist_neg_e, replace=False
                        )
                # sample random edges
                if (len(seen_dst) >= 1):
                    invalid_dst = np.concatenate((np.array(pos_e_dst_same_src), seen_dst))
                else:
                    invalid_dst = np.array(pos_e_dst_same_src)
                filtered_all_rnd_dst = np.setdiff1d(all_dst, invalid_dst)
                # the random part tops up whatever the historical part could not supply
                num_rnd_neg_e = self.num_neg_e - num_hist_neg_e
                '''
                when num_neg_e is larger than all possible destinations simple return all possible destinations
                '''
                if (num_rnd_neg_e > len(filtered_all_rnd_dst)):
                    neg_rnd_dsts = filtered_all_rnd_dst
                else:
                    neg_rnd_dsts = np.random.choice(
                        filtered_all_rnd_dst, num_rnd_neg_e, replace=False
                    )
                # concatenate the two sets: historical and random
                neg_dst_arr = np.concatenate((neg_hist_dsts, neg_rnd_dsts))
                evaluation_set[(pos_s, pos_d, pos_t)] = neg_dst_arr
            # save the generated evaluation set to disk
            save_pkl(evaluation_set, filename)
================================================
FILE: tgb/linkproppred/negative_sampler.py
================================================
"""
Sample negative edges for evaluation of dynamic link prediction
Load already generated negative edges from file, batch them based on the positive edge, and return the evaluation set
"""
import torch
from torch import Tensor
import numpy as np
from tgb.utils.utils import save_pkl, load_pkl
from tgb.utils.info import PROJ_DIR
import os
import time
class NegativeEdgeSampler(object):
    r"""Loads pre-generated negative-edge evaluation sets from disk and serves
    them batched per positive edge."""
    def __init__(
        self,
        dataset_name: str,
        first_dst_id: int = 0,
        last_dst_id: int = 0,
        strategy: str = "hist_rnd",
    ) -> None:
        r"""
        Negative Edge Sampler
        Loads and query the negative batches based on the positive batches provided.
        constructor for the negative edge sampler class
        Parameters:
            dataset_name: name of the dataset
            first_dst_id: identity of the first destination node
            last_dst_id: identity of the last destination node
            strategy: will always load the pre-generated negatives
        Returns:
            None
        """
        self.dataset_name = dataset_name
        assert strategy in [
            "rnd",
            "hist_rnd",
        ], "The supported strategies are `rnd` or `hist_rnd`!"
        self.strategy = strategy
        # {split_mode: {(src, dst, t) [or (src, dst, t, edge_type)]: negative destinations}}
        self.eval_set = {}
    def load_eval_set(
        self,
        fname: str,
        split_mode: str = "val",
    ) -> None:
        r"""
        Load the evaluation set from disk, can be either val or test set ns samples
        Parameters:
            fname: the file name of the evaluation ns on disk
            split_mode: the split mode of the evaluation set, can be either `val` or `test`
        Returns:
            None
        Raises:
            FileNotFoundError: if `fname` does not exist on disk
        """
        assert split_mode in [
            "val",
            "test",
        ], "Invalid split-mode! It should be `val`, `test`"
        if not os.path.exists(fname):
            raise FileNotFoundError(f"File not found at {fname}")
        self.eval_set[split_mode] = load_pkl(fname)
    def reset_eval_set(self,
                       split_mode: str = "test",
                       ) -> None:
        r"""
        Reset evaluation set
        Parameters:
            split_mode: specifies whether to reset the negative edges of the 'validation' or 'test' split
        Returns:
            None
        """
        assert split_mode in [
            "val",
            "test",
        ], "Invalid split-mode! It should be `val`, `test`!"
        self.eval_set[split_mode] = None
    def query_batch(self,
                    pos_src: Tensor,
                    pos_dst: Tensor,
                    pos_timestamp: Tensor,
                    edge_type: Tensor = None,
                    split_mode: str = "test") -> list:
        r"""
        For each positive edge in the `pos_batch`, return a list of negative edges
        `split_mode` specifies whether the validation or test evaluation set should be retrieved.
        Parameters:
            pos_src: list of positive source nodes
            pos_dst: list of positive destination nodes
            pos_timestamp: list of timestamps of the positive edges
            edge_type: optional list of edge types; when given, the evaluation set is keyed
                by (src, dst, t, edge_type) instead of (src, dst, t)
            split_mode: specifies whether to query negative edges for 'validation' or 'test' splits
        Returns:
            neg_samples: a list of list; each internal list contains the set of negative edges that
                should be evaluated against each positive edge.
        Raises:
            ValueError: if the evaluation set for `split_mode` was never loaded, or a positive
                edge is missing from it
            RuntimeError: if the inputs are neither numpy arrays nor torch tensors
        """
        assert split_mode in [
            "val",
            "test",
        ], "Invalid split-mode! It should be `val`, `test`!"
        # `.get` avoids a KeyError when the split was never loaded, and `is None` is the
        # correct identity check (the original `self.eval_set[split_mode] == None` raised
        # KeyError instead of the intended ValueError)
        if self.eval_set.get(split_mode) is None:
            raise ValueError(
                f"Evaluation set is None! You should load the {split_mode} evaluation set first!"
            )
        # check the argument types: convert any torch tensors to numpy arrays on cpu
        if torch is not None and isinstance(pos_src, torch.Tensor):
            pos_src = pos_src.detach().cpu().numpy()
        if torch is not None and isinstance(pos_dst, torch.Tensor):
            pos_dst = pos_dst.detach().cpu().numpy()
        if torch is not None and isinstance(pos_timestamp, torch.Tensor):
            pos_timestamp = pos_timestamp.detach().cpu().numpy()
        if torch is not None and isinstance(edge_type, torch.Tensor):
            edge_type = edge_type.detach().cpu().numpy()
        # the original wrote `not (pos_timestamp, np.ndarray)` — a non-empty tuple, which is
        # always falsy under `not` — so invalid timestamp inputs were silently accepted
        if (
            not isinstance(pos_src, np.ndarray)
            or not isinstance(pos_dst, np.ndarray)
            or not isinstance(pos_timestamp, np.ndarray)
        ):
            raise RuntimeError(
                "pos_src, pos_dst, and pos_timestamp need to be either numpy ndarray or torch tensor!"
            )
        neg_samples = []
        if (edge_type is None):
            for pos_s, pos_d, pos_t in zip(pos_src, pos_dst, pos_timestamp):
                if (pos_s, pos_d, pos_t) not in self.eval_set[split_mode]:
                    raise ValueError(
                        f"The edge ({pos_s}, {pos_d}, {pos_t}) is not in the '{split_mode}' evaluation set! Please check the implementation."
                    )
                neg_samples.append(
                    [
                        int(neg_dst)
                        for neg_dst in self.eval_set[split_mode][(pos_s, pos_d, pos_t)]
                    ]
                )
        else:
            for pos_s, pos_d, pos_t, e_type in zip(pos_src, pos_dst, pos_timestamp, edge_type):
                if (pos_s, pos_d, pos_t, e_type) not in self.eval_set[split_mode]:
                    raise ValueError(
                        f"The edge ({pos_s}, {pos_d}, {pos_t}, {e_type}) is not in the '{split_mode}' evaluation set! Please check the implementation."
                    )
                neg_samples.append(
                    [
                        int(neg_dst)
                        for neg_dst in self.eval_set[split_mode][(pos_s, pos_d, pos_t, e_type)]
                    ]
                )
        return neg_samples
================================================
FILE: tgb/linkproppred/thg_negative_generator.py
================================================
"""
Sample and Generate negative edges that are going to be used for evaluation of a dynamic graph learning model
Negative samples are generated and saved to files ONLY once;
other times, they should be loaded from file with instances of the `negative_sampler.py`.
"""
import os
import torch
import numpy as np
from tqdm import tqdm
from torch_geometric.data import TemporalData
from tgb.utils.utils import save_pkl
from typing import Union
from tgb.utils.utils import vprint
"""
negative sample generator for tkg datasets
temporal filterted MRR
"""
class THGNegativeEdgeGenerator(object):
    r"""Generates (and saves to disk) negative-edge evaluation sets for temporal
    heterogeneous graphs, filtering candidate destinations by node type."""
    def __init__(
        self,
        dataset_name: str,
        first_node_id: int,
        last_node_id: int,
        node_type: Union[np.ndarray, torch.Tensor],
        strategy: str = "node-type-filtered",
        num_neg_e: int = -1,  # -1 means generate all possible negatives
        rnd_seed: int = 1,
        edge_data: TemporalData = None,
    ) -> None:
        r"""
        Negative Edge Generator class for Temporal Heterogeneous Graphs
        this is a class for generating negative samples for a specific dataset;
        the set of the positive samples are provided, the negative samples are generated with specific strategies
        and are saved for consistent evaluation across different methods
        Parameters:
            dataset_name: name of the dataset
            first_node_id: the first node id
            last_node_id: the last node id
            node_type: the node type of each node
            strategy: the strategy to generate negative samples; 'node-type-filtered' or 'random'
            num_neg_e: number of negative samples to generate (-1 generates all possible)
            rnd_seed: random seed
            edge_data: the edge data object containing the positive edges
        Returns:
            None
        """
        self.rnd_seed = rnd_seed
        # seeds numpy's global RNG so the generated negatives are reproducible
        np.random.seed(self.rnd_seed)
        self.dataset_name = dataset_name
        self.first_node_id = first_node_id
        self.last_node_id = last_node_id
        if isinstance(node_type, torch.Tensor):
            node_type = node_type.cpu().numpy()
        self.node_type = node_type
        self.node_type_dict = self.get_destinations_based_on_node_type(first_node_id, last_node_id, self.node_type)  # {node_type: np.ndarray of node ids}
        assert isinstance(self.node_type, np.ndarray), "node_type should be a numpy array"
        self.num_neg_e = num_neg_e  # -1 means generate all
        assert strategy in [
            "node-type-filtered",
            "random",
        ], "The supported strategies are `node-type-filtered`"
        self.strategy = strategy
        self.edge_data = edge_data
    def get_destinations_based_on_node_type(self,
                                            first_node_id: int,
                                            last_node_id: int,
                                            node_type: np.ndarray) -> dict:
        r"""
        get the destination node id arrays based on the node type
        Parameters:
            first_node_id: the first node id
            last_node_id: the last node id
            node_type: the node type of each node
        Returns:
            node_type_dict: a dictionary containing the destination node ids for each node type
        """
        node_type_store = {}
        assert first_node_id <= last_node_id, "Invalid destination node ids!"
        assert len(node_type) == (last_node_id - first_node_id + 1), "node type array must match the indices"
        for k in range(len(node_type)):
            nt = int(node_type[k]) #node type must be ints
            # node ids are assumed to be contiguous starting at first_node_id
            nid = k + first_node_id
            if nt not in node_type_store:
                node_type_store[nt] = {nid:1}
            else:
                node_type_store[nt][nid] = 1
        # convert each per-type id set into a numpy array, checking ordering and bounds
        node_type_dict = {}
        for ntype in node_type_store:
            node_type_dict[ntype] = np.array(list(node_type_store[ntype].keys()))
            assert np.all(np.diff(node_type_dict[ntype]) >= 0), "Destination node ids for a given type must be sorted"
            assert np.all(node_type_dict[ntype] <= last_node_id), "Destination node ids must be less than or equal to the last destination id"
        return node_type_dict
    def generate_negative_samples(self,
                                  pos_edges: TemporalData,
                                  split_mode: str,
                                  partial_path: str,
                                  ) -> None:
        r"""
        Generate negative samples
        Parameters:
            pos_edges: positive edges to generate the negatives for
            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits
            partial_path: in which directory save the generated negatives
        """
        # file name for saving or loading...
        filename = (
            partial_path
            + "/"
            + self.dataset_name
            + "_"
            + split_mode
            + "_"
            + "ns"
            + ".pkl"
        )
        # dispatch to the generator matching the configured strategy
        if self.strategy == "node-type-filtered":
            self.generate_negative_samples_nt(pos_edges, split_mode, filename)
        elif self.strategy == "random":
            self.generate_negative_samples_random(pos_edges, split_mode, filename)
        else:
            raise ValueError("Unsupported negative sample generation strategy!")
    def generate_negative_samples_nt(self,
                                     data: TemporalData,
                                     split_mode: str,
                                     filename: str,
                                     ) -> None:
        r"""
        now we consider (s, d, t, edge_type) as a unique edge, also adding the node type info for the destination node for convenience so (s, d, t, edge_type): (conflict_set, d_node_type)
        Generate negative samples based on the random strategy:
            - for each positive edge, retrieve all possible destinations based on the node type of the destination node
            - filter actual positive edges at the same timestamp with the same edge type
        Parameters:
            data: an object containing positive edges information
            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits
            filename: name of the file containing the generated negative edges
        """
        vprint(
            f"INFO: Negative Sampling Strategy: {self.strategy}, Data Split: {split_mode}"
        )
        assert split_mode in [
            "val",
            "test",
        ], "Invalid split-mode! It should be `val` or `test`!"
        # negatives are generated exactly once; subsequent calls reuse the file on disk
        if os.path.exists(filename):
            vprint(
                f"INFO: negative samples for '{split_mode}' evaluation are already generated!"
            )
        else:
            vprint(f"INFO: Generating negative samples for '{split_mode}' evaluation!")
            # retrieve the information from the batch
            pos_src, pos_dst, pos_timestamp, edge_type = (
                data.src.cpu().numpy(),
                data.dst.cpu().numpy(),
                data.t.cpu().numpy(),
                data.edge_type.cpu().numpy(),
            )
            # generate a list of negative destinations for each positive edge
            pos_edge_tqdm = tqdm(
                zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)
            )
            edge_t_dict = {} # {(t, u, edge_type): {v_1, v_2, ..} }
            #! iterate once to put all edges into a dictionary for reference
            # NOTE(review): the loop variable `edge_type` below shadows the `edge_type` array
            # unpacked above; harmless here (the array is not reused) but fragile
            for (
                pos_s,
                pos_d,
                pos_t,
                edge_type,
            ) in pos_edge_tqdm:
                if (pos_t, pos_s, edge_type) not in edge_t_dict:
                    edge_t_dict[(pos_t, pos_s, edge_type)] = {pos_d:1}
                else:
                    edge_t_dict[(pos_t, pos_s, edge_type)][pos_d] = 1
            out_dict = {}
            for key in tqdm(edge_t_dict):
                conflict_set = np.array(list(edge_t_dict[key].keys()))
                # any destination from the conflict set works for the type lookup below
                pos_d = conflict_set[0]
                #* retrieve the node type of the destination node as well
                #! assumption, same edge type = same destination node type
                d_node_type = int(self.node_type[pos_d - self.first_node_id])
                all_dst = self.node_type_dict[d_node_type]
                if (self.num_neg_e == -1):
                    filtered_all_dst = np.setdiff1d(all_dst, conflict_set)
                else:
                    #* lazy sampling: sample first, and only re-sample from the filtered
                    #* candidate set if the first draw overlapped the conflict set
                    # NOTE(review): np.random.choice raises if fewer than num_neg_e
                    # candidates remain after filtering — TODO confirm datasets guarantee this
                    neg_d_arr = np.random.choice(
                        all_dst, self.num_neg_e, replace=False) #never replace negatives
                    if len(np.setdiff1d(neg_d_arr, conflict_set)) < self.num_neg_e:
                        neg_d_arr = np.random.choice(
                            np.setdiff1d(all_dst, conflict_set), self.num_neg_e, replace=False)
                    filtered_all_dst = neg_d_arr
                out_dict[key] = filtered_all_dst
            vprint ("ns samples for ", len(out_dict), " positive edges are generated")
            # save the generated evaluation set to disk
            save_pkl(out_dict, filename)
    def generate_negative_samples_random(self,
                                         data: TemporalData,
                                         split_mode: str,
                                         filename: str,
                                         ) -> None:
        r"""
        generate random negative edges for ablation study
        Parameters:
            data: an object containing positive edges information
            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits
            filename: name of the file containing the generated negative edges
        """
        vprint(
            f"INFO: Negative Sampling Strategy: {self.strategy}, Data Split: {split_mode}"
        )
        assert split_mode in [
            "val",
            "test",
        ], "Invalid split-mode! It should be `val` or `test`!"
        # negatives are generated exactly once; subsequent calls reuse the file on disk
        if os.path.exists(filename):
            vprint(
                f"INFO: negative samples for '{split_mode}' evaluation are already generated!"
            )
        else:
            vprint(f"INFO: Generating negative samples for '{split_mode}' evaluation!")
            # retrieve the information from the batch
            pos_src, pos_dst, pos_timestamp, edge_type = (
                data.src.cpu().numpy(),
                data.dst.cpu().numpy(),
                data.t.cpu().numpy(),
                data.edge_type.cpu().numpy(),
            )
            # candidate destinations span the full observed destination id range
            first_dst_id = self.edge_data.dst.min()
            last_dst_id = self.edge_data.dst.max()
            all_dst = np.arange(first_dst_id, last_dst_id + 1)
            evaluation_set = {}
            # generate a list of negative destinations for each positive edge
            pos_edge_tqdm = tqdm(
                zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)
            )
            # NOTE(review): this path assumes num_neg_e > 0; with the default -1 the
            # np.random.choice call below would be invalid — TODO confirm callers set it
            for (
                pos_s,
                pos_d,
                pos_t,
                edge_type,
            ) in pos_edge_tqdm:
                # collision set: destinations of positives sharing this source at this timestamp
                t_mask = pos_timestamp == pos_t
                src_mask = pos_src == pos_s
                fn_mask = np.logical_and(t_mask, src_mask)
                pos_e_dst_same_src = pos_dst[fn_mask]
                filtered_all_dst = np.setdiff1d(all_dst, pos_e_dst_same_src)
                if (self.num_neg_e > len(filtered_all_dst)):
                    neg_d_arr = filtered_all_dst
                else:
                    neg_d_arr = np.random.choice(
                        filtered_all_dst, self.num_neg_e, replace=False) #never replace negatives
                evaluation_set[(pos_t, pos_s, edge_type)] = neg_d_arr
            save_pkl(evaluation_set, filename)
================================================
FILE: tgb/linkproppred/thg_negative_sampler.py
================================================
"""
Sample negative edges for evaluation of dynamic link prediction
Load already generated negative edges from file, batch them based on the positive edge, and return the evaluation set
"""
import torch
from torch import Tensor
import numpy as np
from tgb.utils.utils import load_pkl
from typing import Union
import os
class THGNegativeEdgeSampler(object):
    r"""Loads pre-generated negative-edge evaluation sets for temporal
    heterogeneous graphs and serves them batched per positive edge."""
    def __init__(
        self,
        dataset_name: str,
        first_node_id: int,
        last_node_id: int,
        node_type: np.ndarray,
        strategy: str = "node-type-filtered",
    ) -> None:
        r"""
        Negative Edge Sampler
        Loads and query the negative batches based on the positive batches provided.
        constructor for the negative edge sampler class
        Parameters:
            dataset_name: name of the dataset
            first_node_id: identity of the first node
            last_node_id: identity of the last destination node
            node_type: the node type of each node
            strategy: will always load the pre-generated negatives
        Returns:
            None
        """
        self.dataset_name = dataset_name
        # {split_mode: {(t, src, edge_type): negative destinations}}
        self.eval_set = {}
        self.first_node_id = first_node_id
        self.last_node_id = last_node_id
        self.node_type = node_type
        assert isinstance(self.node_type, np.ndarray), "node_type should be a numpy array"
    def load_eval_set(
        self,
        fname: str,
        split_mode: str = "val",
    ) -> None:
        r"""
        Load the evaluation set from disk, can be either val or test set ns samples
        Parameters:
            fname: the file name of the evaluation ns on disk
            split_mode: the split mode of the evaluation set, can be either `val` or `test`
        Returns:
            None
        Raises:
            FileNotFoundError: if `fname` does not exist on disk
        """
        assert split_mode in [
            "val",
            "test",
        ], "Invalid split-mode! It should be `val`, `test`"
        if not os.path.exists(fname):
            raise FileNotFoundError(f"File not found at {fname}")
        self.eval_set[split_mode] = load_pkl(fname)
    def query_batch(self,
                    pos_src: Union[Tensor, np.ndarray],
                    pos_dst: Union[Tensor, np.ndarray],
                    pos_timestamp: Union[Tensor, np.ndarray],
                    edge_type: Union[Tensor, np.ndarray],
                    split_mode: str = "test") -> list:
        r"""
        For each positive edge in the `pos_batch`, return a list of negative edges
        `split_mode` specifies whether the validation or test evaluation set should be retrieved.
        The evaluation set is keyed by (timestamp, source, edge_type).
        Parameters:
            pos_src: list of positive source nodes
            pos_dst: list of positive destination nodes
            pos_timestamp: list of timestamps of the positive edges
            edge_type: list of edge types of the positive edges
            split_mode: specifies whether to query negative edges for 'validation' or 'test' splits
        Returns:
            neg_samples: list of numpy array; each array contains the set of negative edges that
                should be evaluated against each positive edge.
        Raises:
            ValueError: if the evaluation set for `split_mode` was never loaded, or a positive
                edge is missing from it
            RuntimeError: if the inputs are neither numpy arrays nor torch tensors
        """
        assert split_mode in [
            "val",
            "test",
        ], "Invalid split-mode! It should be `val`, `test`!"
        # `.get` avoids a KeyError when the split was never loaded, and `is None` is the
        # correct identity check (the original `self.eval_set[split_mode] == None` raised
        # KeyError instead of the intended ValueError)
        if self.eval_set.get(split_mode) is None:
            raise ValueError(
                f"Evaluation set is None! You should load the {split_mode} evaluation set first!"
            )
        # check the argument types: convert any torch tensors to numpy arrays on cpu
        if torch is not None and isinstance(pos_src, torch.Tensor):
            pos_src = pos_src.detach().cpu().numpy()
        if torch is not None and isinstance(pos_dst, torch.Tensor):
            pos_dst = pos_dst.detach().cpu().numpy()
        if torch is not None and isinstance(pos_timestamp, torch.Tensor):
            pos_timestamp = pos_timestamp.detach().cpu().numpy()
        if torch is not None and isinstance(edge_type, torch.Tensor):
            edge_type = edge_type.detach().cpu().numpy()
        # the original wrote `not (x, np.ndarray)` for the last two arguments — a non-empty
        # tuple, always falsy under `not` — so invalid inputs were silently accepted
        if (
            not isinstance(pos_src, np.ndarray)
            or not isinstance(pos_dst, np.ndarray)
            or not isinstance(pos_timestamp, np.ndarray)
            or not isinstance(edge_type, np.ndarray)
        ):
            raise RuntimeError(
                "pos_src, pos_dst, pos_timestamp, and edge_type need to be either numpy ndarray or torch tensor!"
            )
        neg_samples = []
        for pos_s, pos_d, pos_t, e_type in zip(pos_src, pos_dst, pos_timestamp, edge_type):
            if (pos_t, pos_s, e_type) not in self.eval_set[split_mode]:
                raise ValueError(
                    f"The edge ({pos_s}, {pos_d}, {pos_t}, {e_type}) is not in the '{split_mode}' evaluation set! Please check the implementation."
                )
            # negatives were pre-generated per (t, src, edge_type) key
            neg_samples.append(self.eval_set[split_mode][(pos_t, pos_s, e_type)])
        #? can't convert to numpy array due to different lengths of negative samples
        return neg_samples
================================================
FILE: tgb/linkproppred/tkg_negative_generator.py
================================================
"""
Sample and Generate negative edges that are going to be used for evaluation of a dynamic graph learning model
Negative samples are generated and saved to files ONLY once;
other times, they should be loaded from file with instances of the `negative_sampler.py`.
"""
import numpy as np
from torch_geometric.data import TemporalData
import matplotlib.pyplot as plt
from tgb.utils.utils import save_pkl
import os
from tqdm import tqdm
from tgb.utils.utils import vprint
"""
negative sample generator for tkg datasets
temporal filterted MRR
"""
class TKGNegativeEdgeGenerator(object):
    def __init__(
        self,
        dataset_name: str,
        first_dst_id: int,
        last_dst_id: int,
        strategy: str = "time-filtered",
        num_neg_e: int = -1,  # -1 means generate all possible negatives
        rnd_seed: int = 1,
        partial_path: str = None,
        edge_data: TemporalData = None,
    ) -> None:
        r"""
        Negative Edge Generator class for Temporal Knowledge Graphs
        constructor for the negative edge generator class
        Parameters:
            dataset_name: name of the dataset
            first_dst_id: identity of the first destination node
            last_dst_id: identity of the last destination node
            strategy: specifies which strategy should be used for generating the negatives;
                one of 'time-filtered', 'dst-time-filtered', or 'random'
            num_neg_e: number of negative edges being generated per each positive edge;
                -1 means generate all possible negatives
            rnd_seed: random seed for reproducibility
            partial_path: directory where the per-edge-type destination dictionary is cached;
                required when strategy is 'dst-time-filtered'
            edge_data: the positive edges to generate the negatives for, assuming sorted temporally
        Returns:
            None
        """
        self.rnd_seed = rnd_seed
        # seeds numpy's global RNG so the generated negatives are reproducible
        np.random.seed(self.rnd_seed)
        self.dataset_name = dataset_name
        self.first_dst_id = first_dst_id
        self.last_dst_id = last_dst_id
        self.num_neg_e = num_neg_e #-1 means generate all
        assert strategy in [
            "time-filtered",
            "dst-time-filtered",
            "random"
        ], "The supported strategies are `time-filtered`, dst-time-filtered, random"
        self.strategy = strategy
        self.dst_dict = None
        if self.strategy == "dst-time-filtered":
            if partial_path is None:
                raise ValueError(
                    "The partial path to the directory where the dst_dict is stored is required")
            else:
                # cache file name for the per-edge-type destination dictionary
                self.dst_dict_name = (
                    partial_path
                    + "/"
                    + self.dataset_name
                    + "_"
                    + "dst_dict"
                    + ".pkl"
                )
                self.dst_dict = self.generate_dst_dict(edge_data=edge_data, dst_name=self.dst_dict_name)
        self.edge_data = edge_data
def generate_dst_dict(self, edge_data: TemporalData, dst_name: str) -> dict:
r"""
Generate a dictionary of destination nodes for each type of edge
Parameters:
edge_data: an object containing positive edges information
dst_name: name of the file to save the generated dictionary of destination nodes
Returns:
dst_dict: a dictionary of destination nodes for each type of edge
"""
min_dst_idx, max_dst_idx = int(edge_data.dst.min()), int(edge_data.dst.max())
pos_src, pos_dst, pos_timestamp, edge_type = (
edge_data.src.cpu().numpy(),
edge_data.dst.cpu().numpy(),
edge_data.t.cpu().numpy(),
edge_data.edge_type.cpu().numpy(),
)
dst_track_dict = {} # {edge_type: {dst_1, dst_2, ..} }
# generate a list of negative destinations for each positive edge
pos_edge_tqdm = tqdm(
zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)
)
for (
pos_s,
pos_d,
pos_t,
edge_type,
) in pos_edge_tqdm:
if edge_type not in dst_track_dict:
dst_track_dict[edge_type] = {pos_d:1}
else:
dst_track_dict[edge_type][pos_d] = 1
dst_dict = {}
edge_type_size = []
for key in dst_track_dict:
dst = np.array(list(dst_track_dict[key].keys()))
edge_type_size.append(len(dst))
dst_dict[key] = dst
vprint ('destination candidates generated for all edge types ', len(dst_dict))
return dst_dict
def generate_negative_samples(self,
pos_edges: TemporalData,
split_mode: str,
partial_path: str,
) -> None:
r"""
Generate negative samples
Parameters:
pos_edges: positive edges to generate the negatives for
split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits
partial_path: in which directory save the generated negatives
"""
# file name for saving or loading...
filename = (
partial_path
+ "/"
+ self.dataset_name
+ "_"
+ split_mode
+ "_"
+ "ns"
+ ".pkl"
)
if self.strategy == "time-filtered":
self.generate_negative_samples_ftr(pos_edges, split_mode, filename)
elif self.strategy == "dst-time-filtered":
self.generate_negative_samples_dst(pos_edges, split_mode, filename)
elif self.strategy == "random":
self.generate_negative_samples_random(pos_edges, split_mode, filename)
else:
raise ValueError("Unsupported negative sample generation strategy!")
def generate_negative_samples_ftr(self,
data: TemporalData,
split_mode: str,
filename: str,
) -> None:
r"""
now we consider (s, d, t, edge_type) as a unique edge
Generate negative samples based on the random strategy:
- for each positive edge, sample a batch of negative edges from all possible edges with the same source node
- filter actual positive edges at the same timestamp with the same edge type
Parameters:
data: an object containing positive edges information
split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits
filename: name of the file containing the generated negative edges
"""
vprint(
f"INFO: Negative Sampling Strategy: {self.strategy}, Data Split: {split_mode}"
)
assert split_mode in [
"val",
"test",
], "Invalid split-mode! It should be `val` or `test`!"
if os.path.exists(filename):
vprint(
f"INFO: negative samples for '{split_mode}' evaluation are already generated!"
)
else:
vprint(f"INFO: Generating negative samples for '{split_mode}' evaluation!")
# retrieve the information from the batch
pos_src, pos_dst, pos_timestamp, edge_type = (
data.src.cpu().numpy(),
data.dst.cpu().numpy(),
data.t.cpu().numpy(),
data.edge_type.cpu().numpy(),
)
# generate a list of negative destinations for each positive edge
pos_edge_tqdm = tqdm(
zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)
)
edge_t_dict = {} # {(t, u, edge_type): {v_1, v_2, ..} }
#! iterate once to put all edges into a dictionary for reference
for (
pos_s,
pos_d,
pos_t,
edge_type,
) in pos_edge_tqdm:
if (pos_t, pos_s, edge_type) not in edge_t_dict:
edge_t_dict[(pos_t, pos_s, edge_type)] = {pos_d:1}
else:
edge_t_dict[(pos_t, pos_s, edge_type)][pos_d] = 1
conflict_dict = {}
for key in edge_t_dict:
conflict_dict[key] = np.array(list(edge_t_dict[key].keys()))
vprint ("conflict sets for ns samples for ", len(conflict_dict), " positive edges are generated")
# save the generated evaluation set to disk
save_pkl(conflict_dict, filename)
def generate_negative_samples_dst(self,
data: TemporalData,
split_mode: str,
filename: str,
) -> None:
r"""
now we consider (s, d, t, edge_type) as a unique edge
Generate negative samples based on the random strategy:
- for each positive edge, sample a batch of negative edges from all possible edges with the same source node
- filter actual positive edges at the same timestamp with the same edge type
Parameters:
data: an object containing positive edges information
split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits
filename: name of the file containing the generated negative edges
"""
vprint(
f"INFO: Negative Sampling Strategy: {self.strategy}, Data Split: {split_mode}"
)
assert split_mode in [
"val",
"test",
], "Invalid split-mode! It should be `val` or `test`!"
if os.path.exists(filename):
vprint(
f"INFO: negative samples for '{split_mode}' evaluation are already generated!"
)
else:
if self.dst_dict is None:
raise ValueError("The dst_dict is not generated!")
vprint(f"INFO: Generating negative samples for '{split_mode}' evaluation!")
# retrieve the information from the batch
pos_src, pos_dst, pos_timestamp, edge_type = (
data.src.cpu().numpy(),
data.dst.cpu().numpy(),
data.t.cpu().numpy(),
data.edge_type.cpu().numpy(),
)
# generate a list of negative destinations for each positive edge
pos_edge_tqdm = tqdm(
zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)
)
edge_t_dict = {} # {(t, u, edge_type): {v_1, v_2, ..} }
out_dict = {}
#! iterate once to put all edges into a dictionary for reference
for (
pos_s,
pos_d,
pos_t,
edge_type,
) in pos_edge_tqdm:
if (pos_t, pos_s, edge_type) not in edge_t_dict:
edge_t_dict[(pos_t, pos_s, edge_type)] = {pos_d:1}
else:
edge_t_dict[(pos_t, pos_s, edge_type)][pos_d] = 1
pos_src, pos_dst, pos_timestamp, edge_type = (
data.src.cpu().numpy(),
data.dst.cpu().numpy(),
data.t.cpu().numpy(),
data.edge_type.cpu().numpy(),
)
new_pos_edge_tqdm = tqdm(
zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)
)
min_dst_idx, max_dst_idx = int(self.edge_data.dst.min()), int(self.edge_data.dst.max())
for (
pos_s,
pos_d,
pos_t,
edge_type,
) in new_pos_edge_tqdm:
#* generate based on # of ns samples
conflict_set = np.array(list(edge_t_dict[(pos_t, pos_s, edge_type)].keys()))
dst_set = self.dst_dict[edge_type] #dst_set contains conflict set
sample_num = self.num_neg_e
filtered_dst_set = np.setdiff1d(dst_set, conflict_set) #more efficient
dst_sampled = None
all_dst = np.arange(min_dst_idx, max_dst_idx+1)
if len(filtered_dst_set) < (sample_num):
#* with collision check
filtered_sample_set = np.setdiff1d(all_dst, filtered_dst_set)
dst_sampled = np.random.choice(filtered_sample_set, sample_num, replace=False)
# #* remove the conflict set from dst set
dst_sampled[0:len(filtered_dst_set)] = filtered_dst_set[:]
else:
# dst_sampled = rng.choice(max_dst_idx+1, sample_num, replace=False)
dst_sampled = np.random.choice(filtered_dst_set, sample_num, replace=False)
out_dict[(pos_t, pos_s, edge_type)] = dst_sampled
vprint ("negative samples for ", len(out_dict), " positive edges are generated")
# save the generated evaluation set to disk
save_pkl(out_dict, filename)
def generate_negative_samples_random(self,
data: TemporalData,
split_mode: str,
filename: str,
) -> None:
r"""
generate random negative edges for ablation study
Parameters:
data: an object containing positive edges information
split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits
filename: name of the file containing the generated negative edges
"""
vprint(
f"INFO: Negative Sampling Strategy: {self.strategy}, Data Split: {split_mode}"
)
assert split_mode in [
"val",
"test",
], "Invalid split-mode! It should be `val` or `test`!"
if os.path.exists(filename):
vprint(
f"INFO: negative samples for '{split_mode}' evaluation are already generated!"
)
else:
vprint(f"INFO: Generating negative samples for '{split_mode}' evaluation!")
# retrieve the information from the batch
pos_src, pos_dst, pos_timestamp, edge_type = (
data.src.cpu().numpy(),
data.dst.cpu().numpy(),
data.t.cpu().numpy(),
data.edge_type.cpu().numpy(),
)
first_dst_id = self.edge_data.dst.min()
last_dst_id = self.edge_data.dst.max()
all_dst = np.arange(first_dst_id, last_dst_id + 1)
evaluation_set = {}
# generate a list of negative destinations for each positive edge
pos_edge_tqdm = tqdm(
zip(pos_src, pos_dst, pos_timestamp, edge_type), total=len(pos_src)
)
for (
pos_s,
pos_d,
pos_t,
edge_type,
) in pos_edge_tqdm:
t_mask = pos_timestamp == pos_t
src_mask = pos_src == pos_s
fn_mask = np.logical_and(t_mask, src_mask)
pos_e_dst_same_src = pos_dst[fn_mask]
filtered_all_dst = np.setdiff1d(all_dst, pos_e_dst_same_src)
if (self.num_neg_e > len(filtered_all_dst)):
neg_d_arr = filtered_all_dst
else:
neg_d_arr = np.random.choice(
filtered_all_dst, self.num_neg_e, replace=False) #never replace negatives
evaluation_set[(pos_t, pos_s, edge_type)] = neg_d_arr
save_pkl(evaluation_set, filename)
================================================
FILE: tgb/linkproppred/tkg_negative_sampler.py
================================================
"""
Sample negative edges for evaluation of dynamic link prediction
Load already generated negative edges from file, batch them based on the positive edge, and return the evaluation set
"""
import torch
from torch import Tensor
import numpy as np
from torch_geometric.data import TemporalData
from tgb.utils.utils import save_pkl, load_pkl
from tgb.utils.info import PROJ_DIR
from typing import Union
import os
import time
class TKGNegativeEdgeSampler(object):
    r"""
    Negative edge sampler for TKG datasets.

    Loads pre-generated negative samples from disk and serves them batched
    against the corresponding positive edges.
    """
    def __init__(
        self,
        dataset_name: str,
        first_dst_id: int,
        last_dst_id: int,
        strategy: str = "time-filtered",
        partial_path: str = PROJ_DIR + "/data/processed",
    ) -> None:
        r"""
        Negative Edge Sampler
        Loads and query the negative batches based on the positive batches provided.
        constructor for the negative edge sampler class
        Parameters:
            dataset_name: name of the dataset
            first_dst_id: identity of the first destination node
            last_dst_id: identity of the last destination node
            strategy: will always load the pre-generated negatives
            partial_path: the path to the directory where the negative edges are stored
        Returns:
            None
        """
        self.dataset_name = dataset_name
        self.eval_set = {}  # {split_mode: dict of pre-generated negatives}
        self.first_dst_id = first_dst_id
        self.last_dst_id = last_dst_id
        self.strategy = strategy
        self.dst_dict = None

    def load_eval_set(
        self,
        fname: str,
        split_mode: str = "val",
    ) -> None:
        r"""
        Load the evaluation set from disk, can be either val or test set ns samples
        Parameters:
            fname: the file name of the evaluation ns on disk
            split_mode: the split mode of the evaluation set, can be either `val` or `test`
        Returns:
            None
        Raises:
            FileNotFoundError: if `fname` does not exist on disk
        """
        assert split_mode in [
            "val",
            "test",
        ], "Invalid split-mode! It should be `val`, `test`"
        if not os.path.exists(fname):
            raise FileNotFoundError(f"File not found at {fname}")
        self.eval_set[split_mode] = load_pkl(fname)

    def query_batch(self,
                    pos_src: Union[Tensor, np.ndarray],
                    pos_dst: Union[Tensor, np.ndarray],
                    pos_timestamp: Union[Tensor, np.ndarray],
                    edge_type: Union[Tensor, np.ndarray],
                    split_mode: str = "test") -> list:
        r"""
        For each positive edge in the `pos_batch`, return a list of negative edges
        `split_mode` specifies whether the validation or test evaluation set should be retrieved.
        modify now to include edge type argument
        Parameters:
            pos_src: list of positive source nodes
            pos_dst: list of positive destination nodes
            pos_timestamp: list of timestamps of the positive edges
            edge_type: list of edge types of the positive edges
            split_mode: specifies whether to generate negative edges for 'validation' or 'test' splits
        Returns:
            neg_samples: list of numpy array; each array contains the set of negative edges that
            should be evaluated against each positive edge.
        Raises:
            ValueError: if the `split_mode` evaluation set is not loaded, a positive
                edge is missing from it, or the strategy is unsupported
            RuntimeError: if inputs are neither numpy arrays nor torch tensors
        """
        assert split_mode in [
            "val",
            "test",
        ], "Invalid split-mode! It should be `val`, `test`!"
        # bugfix: use .get() + `is None` so an unloaded split raises the
        # informative ValueError below instead of a bare KeyError (the original
        # also compared with `== None`)
        if self.eval_set.get(split_mode) is None:
            raise ValueError(
                f"Evaluation set is None! You should load the {split_mode} evaluation set first!"
            )
        # convert any torch tensors to numpy arrays
        if torch is not None and isinstance(pos_src, torch.Tensor):
            pos_src = pos_src.detach().cpu().numpy()
        if torch is not None and isinstance(pos_dst, torch.Tensor):
            pos_dst = pos_dst.detach().cpu().numpy()
        if torch is not None and isinstance(pos_timestamp, torch.Tensor):
            pos_timestamp = pos_timestamp.detach().cpu().numpy()
        if torch is not None and isinstance(edge_type, torch.Tensor):
            edge_type = edge_type.detach().cpu().numpy()
        # bugfix: the original wrote `not (pos_timestamp, np.ndarray)` and
        # `not (edge_type, np.ndarray)` — a non-empty tuple is always truthy, so
        # those checks silently passed; validate all four inputs with isinstance
        if not all(
            isinstance(arr, np.ndarray)
            for arr in (pos_src, pos_dst, pos_timestamp, edge_type)
        ):
            raise RuntimeError(
                "pos_src, pos_dst, and pos_timestamp need to be either numpy ndarray or torch tensor!"
            )

        if self.strategy == "time-filtered":
            neg_samples = []
            for pos_s, pos_d, pos_t, e_type in zip(pos_src, pos_dst, pos_timestamp, edge_type):
                if (pos_t, pos_s, e_type) not in self.eval_set[split_mode]:
                    raise ValueError(
                        f"The edge ({pos_s}, {pos_d}, {pos_t}, {e_type}) is not in the '{split_mode}' evaluation set! Please check the implementation."
                    )
                else:
                    conflict_dict = self.eval_set[split_mode]
                    conflict_set = conflict_dict[(pos_t, pos_s, e_type)]
                    all_dst = np.arange(self.first_dst_id, self.last_dst_id + 1)
                    # bugfix: np.delete removes by *index*, which is only correct
                    # when first_dst_id == 0; filter by node-id *value* instead
                    filtered_all_dst = np.setdiff1d(all_dst, conflict_set)
                    #! always using all possible destinations for evaluation
                    neg_samples.append(filtered_all_dst)
        elif self.strategy == "dst-time-filtered":
            neg_samples = []
            for pos_s, pos_d, pos_t, e_type in zip(pos_src, pos_dst, pos_timestamp, edge_type):
                if (pos_t, pos_s, e_type) not in self.eval_set[split_mode]:
                    raise ValueError(
                        f"The edge ({pos_s}, {pos_d}, {pos_t}, {e_type}) is not in the '{split_mode}' evaluation set! Please check the implementation."
                    )
                else:
                    # negatives were fully pre-computed by the generator
                    filtered_dst = self.eval_set[split_mode]
                    neg_samples.append(filtered_dst[(pos_t, pos_s, e_type)])
        else:
            # robustness: the original fell through and hit a NameError on
            # `neg_samples` for an unknown strategy
            raise ValueError(f"Unsupported strategy: {self.strategy}!")
        #? can't convert to numpy array due to different lengths of negative samples
        return neg_samples
================================================
FILE: tgb/nodeproppred/dataset.py
================================================
from typing import Optional, Dict, Any, Tuple
import os
import os.path as osp
import numpy as np
import pandas as pd
import zipfile
import requests
from clint.textui import progress
from tgb.utils.info import (
PROJ_DIR,
DATA_URL_DICT,
DATA_NUM_CLASSES,
DATA_VERSION_DICT,
DATA_EVAL_METRIC_DICT,
BColors,
)
from tgb.utils.utils import save_pkl, load_pkl, vprint
from tgb.utils.pre_process import (
load_label_dict,
load_edgelist_sr,
load_edgelist_token,
load_edgelist_datetime,
load_trade_label_dict,
load_edgelist_trade,
)
class NodePropPredDataset(object):
    def __init__(
        self,
        name: str,
        root: str = "datasets",
        meta_dict: Optional[dict] = None,
        preprocess: Optional[bool] = True,
        download: Optional[bool] = True,
    ) -> None:
        r"""Dataset class for the node property prediction task. Stores meta information about each dataset such as evaluation metrics etc.
        also automatically pre-processes the dataset.
        [!] node property prediction datasets requires the following:
        self.meta_dict["fname"]: path to the edge list file
        self.meta_dict["nodefile"]: path to the node label file
        Parameters:
            name: name of the dataset
            root: root directory to store the dataset folder
            meta_dict: dictionary containing meta information about the dataset, should contain key 'dir_name' which is the name of the dataset folder
            preprocess: whether to pre-process the dataset
            download: whether to download the dataset or not (default: True)
        Returns:
            None
        """
        self.name = name  ## original name
        # check if dataset url exist
        if self.name in DATA_URL_DICT:
            self.url = DATA_URL_DICT[self.name]
        else:
            self.url = None
        # check if the evaluation metric are specified
        if self.name in DATA_EVAL_METRIC_DICT:
            self.metric = DATA_EVAL_METRIC_DICT[self.name]
        else:
            # NOTE(review): this assignment is dead code — the raise below
            # aborts the constructor immediately
            self.metric = None
            raise ValueError(f"Dataset {self.name} default evaluation metric not found, it is not supported yet.")
        # resolve the dataset root relative to the project directory
        root = PROJ_DIR + root
        if meta_dict is None:
            self.dir_name = "_".join(name.split("-"))  ## replace hyphen with underline
            meta_dict = {"dir_name": self.dir_name}
        else:
            self.dir_name = meta_dict["dir_name"]
        self.root = osp.join(root, self.dir_name)
        self.meta_dict = meta_dict
        if "fname" not in self.meta_dict:
            self.meta_dict["fname"] = self.root + "/" + self.name + "_edgelist.csv"
            self.meta_dict["nodefile"] = self.root + "/" + self.name + "_node_labels.csv"
        #! version check — may rewrite the file names above to versioned ones
        self.version_passed = True
        self._version_check()
        self._num_classes = DATA_NUM_CLASSES[self.name]
        # initialize
        self._node_feat = None
        self._edge_feat = None
        self._full_data = None
        if download:
            self.download()
        else:
            # no download requested: verify that the edge list is already on disk
            if osp.exists(self.meta_dict["fname"]):
                dir_name = self.meta_dict["fname"]
                vprint(f"files found in {dir_name}")
            else:
                dir_name = self.meta_dict["fname"]
                raise FileNotFoundError(f"Directory not found at {dir_name}, please download the dataset")
        # check if the root directory exists, if not create it
        if osp.isdir(self.root):
            vprint("Dataset directory is ", self.root)
        else:
            raise FileNotFoundError(f"Directory not found at {self.root}")
        if preprocess:
            self.pre_process()
        self.label_ts_idx = 0  # index for which node labels to return now

    def _version_check(self) -> None:
        r"""Implement Version checks for dataset files
        updates the file names based on the current version number
        prompt the user to download the new version via self.version_passed variable
        """
        if (self.name in DATA_VERSION_DICT):
            version = DATA_VERSION_DICT[self.name]
        else:
            raise ValueError(f"Dataset {self.name} version number not found.")
        if (version > 1):
            #* check if current version is outdated
            self.meta_dict["fname"] = self.root + "/" + self.name + "_edgelist_v" + str(int(version)) + ".csv"
            self.meta_dict["nodefile"] = self.root + "/" + self.name + "_node_labels_v" + str(int(version)) + ".csv"
            if (not osp.exists(self.meta_dict["fname"])):
                vprint(f"Dataset {self.name} version {int(version)} not found, Please download the latest version of the dataset.")
                self.version_passed = False
                return None

    def download(self) -> None:
        r"""
        downloads this dataset from url
        check if files are already downloaded
        Returns:
            None
        """
        # check if the file already exists
        if osp.exists(self.meta_dict["fname"]) and osp.exists(
            self.meta_dict["nodefile"]
        ):
            dir_name = self.meta_dict["fname"]
            vprint(f"files found in {dir_name}")
            return
        else:
            vprint(
                f"{BColors.WARNING}Download started, this might take a while . . . {BColors.ENDC}"
            )
            vprint(f"Dataset title: {self.name}")
            if self.url is None:
                raise ValueError(f"Dataset {self.name} url not found, download not supported yet.")
            else:
                r = requests.get(self.url, stream=True)
                if osp.isdir(self.root):
                    vprint("Dataset directory is ", self.root)
                else:
                    os.makedirs(self.root)
                path_download = self.root + "/" + self.name + ".zip"
                print(f"downloading Dataset: {self.name} to {path_download}")
                # stream the zip to disk with a progress bar
                with open(path_download, "wb") as f:
                    total_length = int(r.headers.get("content-length"))
                    for chunk in progress.bar(
                        r.iter_content(chunk_size=1024),
                        expected_size=(total_length / 1024) + 1,
                    ):
                        if chunk:
                            f.write(chunk)
                            f.flush()
                # for unzipping the file
                with zipfile.ZipFile(path_download, "r") as zip_ref:
                    zip_ref.extractall(self.root)
                vprint(f"{BColors.OKGREEN}Download completed {BColors.ENDC}")

    def generate_processed_files(
        self,
    ) -> Tuple[pd.DataFrame, Dict[int, Dict[str, Any]]]:
        r"""
        returns an edge list of pandas data frame
        Loads cached pickles when present, otherwise parses the raw CSVs with the
        dataset-specific loader and caches the result.
        Returns:
            df: pandas data frame storing the temporal edge list
            node_label_dict: dictionary with key as timestamp and item as dictionary of node labels
            edge_feat: edge feature array (third element; signature annotation is stale)
        """
        OUT_DF = self.root + "/" + "ml_{}.pkl".format(self.name)
        OUT_NODE_DF = self.root + "/" + "ml_{}_node.pkl".format(self.name)
        OUT_LABEL_DF = self.root + "/" + "ml_{}_label.pkl".format(self.name)
        OUT_EDGE_FEAT = self.root + "/" + "ml_{}.pkl".format(self.name + "_edge")

        # * logic for large datasets, as node label file is too big to store on disc
        if self.name == "tgbn-reddit" or self.name == "tgbn-token":
            if osp.exists(OUT_DF) and osp.exists(OUT_NODE_DF) and osp.exists(OUT_EDGE_FEAT):
                df = pd.read_pickle(OUT_DF)
                edge_feat = load_pkl(OUT_EDGE_FEAT)
                if (self.name == "tgbn-token"):
                    #! taking log normalization for numerical stability
                    vprint ("applying log normalization for weights in tgbn-token")
                    edge_feat[:,0] = np.log(edge_feat[:,0])
                node_ids = load_pkl(OUT_NODE_DF)
                labels_dict = load_pkl(OUT_LABEL_DF)
                # node labels are re-built from the raw node file on every load
                node_label_dict = load_label_dict(
                    self.meta_dict["nodefile"], node_ids, labels_dict
                )
                return df, node_label_dict, edge_feat

        # * load the preprocessed file if possible
        if osp.exists(OUT_DF) and osp.exists(OUT_NODE_DF) and osp.exists(OUT_EDGE_FEAT):
            vprint(f"loading processed file from {OUT_DF}, edge features from {OUT_EDGE_FEAT}, node info from {OUT_NODE_DF}.")
            df = pd.read_pickle(OUT_DF)
            node_label_dict = load_pkl(OUT_NODE_DF)
            edge_feat = load_pkl(OUT_EDGE_FEAT)
        else:  # * process the file
            vprint("file not processed, generating processed file")
            # NOTE(review): if self.name matches none of the cases below, `df`
            # is undefined and the code fails with NameError further down
            if self.name == "tgbn-reddit":
                df, edge_feat, node_ids, labels_dict = load_edgelist_sr(
                    self.meta_dict["fname"], label_size=self._num_classes
                )
            elif self.name == "tgbn-token":
                df, edge_feat, node_ids, labels_dict = load_edgelist_token(
                    self.meta_dict["fname"], label_size=self._num_classes
                )
            elif self.name == "tgbn-genre":
                df, edge_feat, node_ids, labels_dict = load_edgelist_datetime(
                    self.meta_dict["fname"], label_size=self._num_classes
                )
            elif self.name == "tgbn-trade":
                df, edge_feat, node_ids = load_edgelist_trade(
                    self.meta_dict["fname"], label_size=self._num_classes
                )
            df.to_pickle(OUT_DF)
            save_pkl(edge_feat, OUT_EDGE_FEAT)
            if self.name == "tgbn-trade":
                node_label_dict = load_trade_label_dict(
                    self.meta_dict["nodefile"], node_ids
                )
            else:
                node_label_dict = load_label_dict(
                    self.meta_dict["nodefile"], node_ids, labels_dict
                )
            if (
                self.name != "tgbn-reddit" and self.name != "tgbn-token"
            ):  # don't save subreddits on disc, the node label file is too big
                save_pkl(node_label_dict, OUT_NODE_DF)
            else:
                # for the large datasets, cache the smaller building blocks instead
                save_pkl(node_ids, OUT_NODE_DF)
                save_pkl(labels_dict, OUT_LABEL_DF)
            vprint("file processed and saved")
        return df, node_label_dict, edge_feat

    def pre_process(self) -> None:
        """
        Pre-process the dataset and generates the splits, must be run before dataset properties can be accessed
        Returns:
            None
        """
        # first check if all files exist
        if ("fname" not in self.meta_dict) or ("nodefile" not in self.meta_dict):
            raise Exception("meta_dict does not contain all required filenames")

        df, node_label_dict, edge_feat = self.generate_processed_files()
        sources = np.array(df["u"])
        destinations = np.array(df["i"])
        timestamps = np.array(df["ts"])
        edge_idxs = np.array(df["idx"])
        edge_label = np.ones(sources.shape[0])  # all positive edges, label 1
        #self._edge_feat = np.array(df["w"])
        self._edge_feat = edge_feat

        full_data = {
            "sources": sources,
            "destinations": destinations,
            "timestamps": timestamps,
            "edge_idxs": edge_idxs,
            "edge_feat": self._edge_feat,
            "edge_label": edge_label,
            "node_label_dict": node_label_dict,
        }
        self._full_data = full_data

        # storing the split masks
        _train_mask, _val_mask, _test_mask = self.generate_splits(full_data)
        self._train_mask = _train_mask
        self._val_mask = _val_mask
        self._test_mask = _test_mask
        self.label_dict = node_label_dict
        # sorted array of timestamps that carry node labels
        self.label_ts = np.array(list(node_label_dict.keys()))
        self.label_ts = np.sort(self.label_ts)

    def generate_splits(
        self,
        full_data: Dict[str, Any],
        val_ratio: float = 0.15,
        test_ratio: float = 0.15,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        r"""
        Generates train, validation, and test splits from the full dataset
        The split is chronological: the quantiles of the timestamp distribution
        define the val/test boundaries.
        Parameters:
            full_data: dictionary containing the full dataset
            val_ratio: ratio of validation data
            test_ratio: ratio of test data
        Returns:
            train_mask: boolean mask for training data
            val_mask: boolean mask for validation data
            test_mask: boolean mask for test data
        """
        val_time, test_time = list(
            np.quantile(
                full_data["timestamps"],
                [(1 - val_ratio - test_ratio), (1 - test_ratio)],
            )
        )
        timestamps = full_data["timestamps"]
        train_mask = timestamps <= val_time
        val_mask = np.logical_and(timestamps <= test_time, timestamps > val_time)
        test_mask = timestamps > test_time
        return train_mask, val_mask, test_mask

    def find_next_labels_batch(
        self,
        cur_t: int,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        r"""
        this returns the node labels closest to cur_t (for that given day)
        Advances the internal pointer (`label_ts_idx`) when a batch is returned.
        Parameters:
            cur_t: current timestamp of the batch of edges
        Returns:
            ts: timestamp of the node labels
            source_idx: node ids
            labels: the stacked label vectors
            (or None when cur_t is before the next label timestamp, or all
            label batches have been consumed)
        """
        if self.label_ts_idx >= (self.label_ts.shape[0]):
            # for query that are after the last batch of labels
            return None
        else:
            ts = self.label_ts[self.label_ts_idx]
            if cur_t >= ts:
                self.label_ts_idx += 1  # move to the next ts
                # {ts: {node_id: label_vec}}
                node_ids = np.array(list(self.label_dict[ts].keys()))

                node_labels = []
                for key in self.label_dict[ts]:
                    node_labels.append(np.array(self.label_dict[ts][key]))
                node_labels = np.stack(node_labels, axis=0)
                label_ts = np.full(node_ids.shape[0], ts, dtype="int")
                return (label_ts, node_ids, node_labels)
            else:
                return None

    def reset_label_time(self) -> None:
        r"""
        reset the pointer for node label once the entire dataset has been iterated once
        Returns:
            None
        """
        self.label_ts_idx = 0

    def return_label_ts(self) -> int:
        """
        return the current label timestamp that the pointer is at
        Returns:
            ts: int, the timestamp of the node labels
        """
        if (self.label_ts_idx >= self.label_ts.shape[0]):
            # pointer exhausted: report the last label timestamp
            return self.label_ts[-1]
        else:
            return self.label_ts[self.label_ts_idx]

    @property
    def num_classes(self) -> int:
        """
        number of classes in the node label
        Returns:
            num_classes: int, number of classes
        """
        return self._num_classes

    @property
    def eval_metric(self) -> str:
        """
        the official evaluation metric for the dataset, loaded from info.py
        Returns:
            eval_metric: str, the evaluation metric
        """
        return self.metric

    # TODO not sure needed, to be removed
    @property
    def node_feat(self) -> Optional[np.ndarray]:
        r"""
        Returns the node features of the dataset with dim [N, feat_dim]
        Returns:
            node_feat: np.ndarray, [N, feat_dim] or None if there is no node feature
        """
        return self._node_feat

    # TODO not sure needed, to be removed
    @property
    def edge_feat(self) -> Optional[np.ndarray]:
        r"""
        Returns the edge features of the dataset with dim [E, feat_dim]
        Returns:
            edge_feat: np.ndarray, [E, feat_dim] or None if there is no edge feature
        """
        return self._edge_feat

    @property
    def node_label_dict(self) -> Dict[int, Dict[int, Any]]:
        r"""
        Returns the node label dictionary of the dataset with {timestamp: {node_id: label_vec}}
        Returns:
            label_dict: Dict[int, Dict[int, Any]], the node label dictionary
        """
        return self.label_dict

    @property
    def full_data(self) -> Dict[str, Any]:
        r"""
        the full data of the dataset as a dictionary with keys: 'sources', 'destinations', 'timestamps', 'edge_idxs', 'edge_feat', 'edge_label', 'node_label_dict'
        Returns:
            full_data: Dict[str, Any]
        """
        if self._full_data is None:
            raise ValueError(
                "dataset has not been processed yet, please call pre_process() first"
            )
        return self._full_data

    @property
    def train_mask(self) -> np.ndarray:
        r"""
        Returns the train mask of the dataset
        Returns:
            train_mask
        """
        if self._train_mask is None:
            raise ValueError("training split hasn't been loaded")
        return self._train_mask

    @property
    def val_mask(self) -> np.ndarray:
        r"""
        Returns the validation mask of the dataset
        Returns:
            val_mask: np.ndarray, boolean mask aligned with the edge arrays
        """
        if self._val_mask is None:
            raise ValueError("validation split hasn't been loaded")
        return self._val_mask

    @property
    def test_mask(self) -> np.ndarray:
        r"""
        Returns the test mask of the dataset:
        Returns:
            test_mask: np.ndarray, boolean mask aligned with the edge arrays
        """
        if self._test_mask is None:
            raise ValueError("test split hasn't been loaded")
        return self._test_mask
def main():
    """
    Small usage demo: download and pre-process the `tgbn-trade` dataset, then
    show how to access its arrays and chronological split masks.
    """
    name = "tgbn-trade"
    dataset = NodePropPredDataset(name=name, root="datasets", preprocess=True)

    dataset.node_feat
    dataset.edge_feat  # not the edge weights
    full_data = dataset.full_data
    # full_data is a dict of aligned numpy arrays / objects with keys:
    # 'sources', 'destinations', 'timestamps', 'edge_idxs', 'edge_feat',
    # 'edge_label', 'node_label_dict'
    full_data["edge_idxs"]
    full_data["sources"]
    full_data["destinations"]
    full_data["timestamps"]
    # bugfix: the original accessed full_data["y"], which is not a key of
    # full_data (KeyError); the edge labels live under "edge_label"
    full_data["edge_label"]
    # bugfix: the original indexed the full_data *dict* with a boolean mask
    # (TypeError); the masks must be applied to the individual arrays
    train_sources = full_data["sources"][dataset.train_mask]
    val_sources = full_data["sources"][dataset.val_mask]
    test_sources = full_data["sources"][dataset.test_mask]


if __name__ == "__main__":
    main()
================================================
FILE: tgb/nodeproppred/dataset_pyg.py
================================================
import os.path as osp
from typing import Optional, Dict, Any, Optional, Callable
import torch
from torch_geometric.data import InMemoryDataset, TemporalData, download_url
from tgb.nodeproppred.dataset import NodePropPredDataset
import warnings
# TODO check https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/data/in_memory_dataset.html
# avoid any overlapping properties
class PyGNodePropPredDataset(InMemoryDataset):
r"""
PyG wrapper for the NodePropPredDataset
can return pytorch tensors for src,dst,t,msg,label
can return Temporal Data object
also query the node labels corresponding to a timestamp from edge batch
Parameters:
name: name of the dataset, passed to `NodePropPredDataset`
root (string): Root directory where the dataset should be saved.
transform (callable, optional): A function/transform that takes in an
pre_transform (callable, optional): A function/transform that takes in
download (optional, bool): download dataset or not (default True)
"""
    def __init__(
        self,
        name: str,
        root: str = "datasets",
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        download: Optional[bool] = True,
    ):
        r"""
        Wrap a `NodePropPredDataset` (which handles download/pre-processing)
        and cache its split masks and class count before converting the edge
        arrays to torch tensors via `process_data`.
        """
        self.name = name
        self.root = root
        # the underlying numpy-based dataset does the heavy lifting
        self.dataset = NodePropPredDataset(name=name, root=root, download=download)
        self._train_mask = torch.from_numpy(self.dataset.train_mask)
        self._val_mask = torch.from_numpy(self.dataset.val_mask)
        self._test_mask = torch.from_numpy(self.dataset.test_mask)
        # double underscore -> name-mangled private attribute
        self.__num_classes = self.dataset.num_classes
        super().__init__(root, transform, pre_transform)
        self.process_data()
    @property
    def num_classes(self) -> int:
        """
        how many classes are in the node label
        Returns:
            num_classes: int
        """
        return self.__num_classes
    @property
    def eval_metric(self) -> str:
        """
        the official evaluation metric for the dataset, loaded from info.py
        (delegates to the wrapped `NodePropPredDataset`)
        Returns:
            eval_metric: str, the evaluation metric
        """
        return self.dataset.eval_metric
@property
def train_mask(self) -> torch.Tensor:
r"""
Returns the train mask of the dataset
Returns:
train_mask: the mask for edges in the training set
"""
if self._train_mask is None:
raise ValueError("training split hasn't been loaded")
return self._train_mask
@property
def val_mask(self) -> torch.Tensor:
r"""
Returns the validation mask of the dataset
Returns:
val_mask: the mask for edges in the validation set
"""
if self._val_mask is None:
raise ValueError("validation split hasn't been loaded")
return self._val_mask
@property
def test_mask(self) -> torch.Tensor:
r"""
Returns the test mask of the dataset:
Returns:
test_mask: the mask for edges in the test set
"""
if self._test_mask is None:
raise ValueError("test split hasn't been loaded")
return self._test_mask
    @property
    def src(self) -> torch.Tensor:
        r"""
        Returns the source nodes of the dataset
        Returns:
            src: the idx of the source nodes
        """
        return self._src
    @property
    def dst(self) -> torch.Tensor:
        r"""
        Returns the destination nodes of the dataset
        Returns:
            dst: the idx of the destination nodes
        """
        return self._dst
    @property
    def ts(self) -> torch.Tensor:
        r"""
        Returns the timestamps of the dataset
        Returns:
            ts: the timestamps of the edges
        """
        return self._ts
    @property
    def edge_feat(self) -> torch.Tensor:
        r"""
        Returns the edge features of the dataset
        Returns:
            edge_feat: the edge features
        """
        return self._edge_feat
    @property
    def edge_label(self) -> torch.Tensor:
        r"""
        Returns the edge labels of the dataset
        Returns:
            edge_label: the labels of the edges (all one tensor)
        """
        return self._edge_label
def process_data(self):
    """
    Convert the raw numpy arrays in ``self.dataset.full_data`` into torch
    tensors and cache them on this object.

    Index-like tensors (sources, destinations, timestamps) are coerced to
    int64 and edge features to float32; edge labels keep their raw dtype.
    """
    raw = self.dataset.full_data
    sources = torch.from_numpy(raw["sources"])
    destinations = torch.from_numpy(raw["destinations"])
    timestamps = torch.from_numpy(raw["timestamps"])
    labels = torch.from_numpy(raw["edge_label"])
    features = torch.from_numpy(raw["edge_feat"])
    # * enforce the dtypes downstream models expect
    if sources.dtype != torch.int64:
        sources = sources.long()
    if destinations.dtype != torch.int64:
        destinations = destinations.long()
    if timestamps.dtype != torch.int64:
        timestamps = timestamps.long()
    if features.dtype != torch.float32:
        features = features.float()
    self._src = sources
    self._dst = destinations
    self._ts = timestamps
    self._edge_label = labels
    self._edge_feat = features
def get_TemporalData(
    self,
) -> TemporalData:
    """
    Package the cached tensors for the entire dataset into a single
    PyG ``TemporalData`` object.

    Returns:
        data: TemporalData object storing the edgelist
    """
    return TemporalData(
        src=self._src,
        dst=self._dst,
        t=self._ts,
        msg=self._edge_feat,
        y=self._edge_label,
    )
def reset_label_time(self) -> None:
    """
    Rewind the node-label pointer on the underlying dataset; should be
    called once per training epoch.
    """
    self.dataset.reset_label_time()
def get_node_label(self, cur_t):
    """
    Fetch the next batch of node labels relative to timestamp ``cur_t``.

    Returns None when the underlying dataset has no further labels;
    otherwise a (label_ts, label_srcs, labels) triple of torch tensors
    (int64, int64, float32 respectively).
    """
    batch = self.dataset.find_next_labels_batch(cur_t)
    if batch is None:
        return None
    ts_np, srcs_np, labels_np = batch[0], batch[1], batch[2]
    return (
        torch.from_numpy(ts_np).long(),
        torch.from_numpy(srcs_np).long(),
        torch.from_numpy(labels_np).to(torch.float32),
    )
def get_label_time(self) -> int:
    """
    Timestamp of the current batch of node labels.

    Returns:
        t: time of the current node labels
    """
    return self.dataset.return_label_ts()
def len(self) -> int:
    """
    Number of edges in the dataset.

    Returns:
        size: int
    """
    return self._src.shape[0]
def get(self, idx: int) -> TemporalData:
    """
    Build a ``TemporalData`` object holding the single edge at position ``idx``.

    Parameters:
        idx: index of the edge
    Returns:
        data: TemporalData object for that one edge
    """
    return TemporalData(
        src=self._src[idx],
        dst=self._dst[idx],
        t=self._ts[idx],
        msg=self._edge_feat[idx],
        y=self._edge_label[idx],
    )
def __repr__(self) -> str:
    """Human-readable representation, e.g. ``Tgbn-trade()``."""
    return self.name.capitalize() + "()"
================================================
FILE: tgb/nodeproppred/evaluate.py
================================================
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.metrics import ndcg_score
import math
from tgb.utils.info import DATA_EVAL_METRIC_DICT
try:
import torch
except ImportError:
torch = None
from tgb.utils.utils import vprint
class Evaluator(object):
    """Evaluator for Node Property Prediction.

    Validates the evaluation input dictionary and computes the requested
    metrics ("mse", "rmse", "ndcg") over true and predicted values.
    """

    def __init__(self, name: str):
        r"""
        Parameters:
            name: name of the dataset
        Raises:
            NotImplementedError: if the dataset is not registered in info.py
        """
        self.name = name
        self.valid_metric_list = ["mse", "rmse", "ndcg"]
        if self.name not in DATA_EVAL_METRIC_DICT:
            raise NotImplementedError("Dataset not supported")

    def _parse_and_check_input(self, input_dict):
        """
        Check whether the input has the required format.

        Parameters:
            input_dict: a dictionary containing "y_true", "y_pred", and
                "eval_metric"; "eval_metric" should be a list including one
                or more of the following metrics: ["mse", "rmse", "ndcg"]
        Returns:
            (y_true, y_pred) as numpy arrays on cpu
        Raises:
            RuntimeError: on missing keys, wrong types, or shape mismatch
            ValueError: on an unsupported metric name
        """
        if "eval_metric" not in input_dict:
            raise RuntimeError("Missing key of eval_metric")
        for eval_metric in input_dict["eval_metric"]:
            if eval_metric in self.valid_metric_list:
                if "y_true" not in input_dict:
                    raise RuntimeError("Missing key of y_true")
                if "y_pred" not in input_dict:
                    raise RuntimeError("Missing key of y_pred")
                y_true, y_pred = input_dict["y_true"], input_dict["y_pred"]
                # converting to numpy on cpu
                if torch is not None and isinstance(y_true, torch.Tensor):
                    y_true = y_true.detach().cpu().numpy()
                if torch is not None and isinstance(y_pred, torch.Tensor):
                    y_pred = y_pred.detach().cpu().numpy()
                # check type and shape
                if not isinstance(y_true, np.ndarray) or not isinstance(
                    y_pred, np.ndarray
                ):
                    raise RuntimeError(
                        "Arguments to Evaluator need to be either numpy ndarray or torch tensor!"
                    )
                if not y_true.shape == y_pred.shape:
                    raise RuntimeError("Shape of y_true and y_pred must be the same!")
            else:
                raise ValueError(f"Unsupported eval metric: {eval_metric}, not found in {self.valid_metric_list}")
        self.eval_metric = input_dict["eval_metric"]
        return y_true, y_pred

    def _compute_metrics(self, y_true, y_pred):
        """
        Compute the performance metrics for the given true labels and
        prediction probabilities.

        Parameters:
            y_true: actual true labels
            y_pred: predicted probabilities
        Returns:
            perf_dict: dict mapping metric name to score
        """
        perf_dict = {}
        for eval_metric in self.eval_metric:
            # BUG FIX: accumulate with update() instead of rebinding perf_dict,
            # so requesting several metrics no longer discards earlier results
            if eval_metric == "mse":
                perf_dict.update(
                    {
                        "mse": mean_squared_error(y_true, y_pred),
                        "rmse": math.sqrt(mean_squared_error(y_true, y_pred)),
                    }
                )
            elif eval_metric == "rmse":
                # BUG FIX: "rmse" is in valid_metric_list but previously produced
                # an empty dict (and a KeyError in eval(verbose=True))
                perf_dict.update(
                    {"rmse": math.sqrt(mean_squared_error(y_true, y_pred))}
                )
            elif eval_metric == "ndcg":
                k = 10  # official cutoff for the TGB node-property ranking metric
                perf_dict.update({"ndcg": ndcg_score(y_true, y_pred, k=k)})
        return perf_dict

    def eval(self, input_dict, verbose=False):
        """
        Run the evaluation for the node property prediction task.

        Parameters:
            input_dict: dict with "y_true", "y_pred", and "eval_metric"
            verbose: when True print each metric to stdout
        Returns:
            perf_dict: dict mapping metric name to score
        """
        y_true, y_pred = self._parse_and_check_input(input_dict)
        perf_dict = self._compute_metrics(y_true, y_pred)
        if verbose:
            print("INFO: Evaluation Results:")
            for eval_metric in input_dict["eval_metric"]:
                print(f"\t>>> {eval_metric}: {perf_dict[eval_metric]:.4f}")
        return perf_dict

    @property
    def expected_input_format(self):
        """Human-readable description of the expected evaluator input."""
        desc = "==== Expected input format of Evaluator for {}\n".format(self.name)
        if "mse" in self.valid_metric_list:
            desc += "{'y_pred': y_pred}\n"
            desc += "- y_pred: numpy ndarray or torch tensor of shape (num_edges, ). Torch tensor on GPU is recommended for efficiency.\n"
            desc += "y_pred is the predicted weight for edges.\n"
        else:
            raise ValueError("Undefined eval metric %s" % (self.eval_metric))
        return desc

    @property
    def expected_output_format(self):
        """Human-readable description of the evaluator output."""
        desc = "==== Expected output format of Evaluator for {}\n".format(self.name)
        if "mse" in self.valid_metric_list:
            # BUG FIX: example dict was missing its closing brace
            desc += "{'mse': mse}\n"
            desc += "- mse (float): mse score\n"
        else:
            raise ValueError("Undefined eval metric %s" % (self.eval_metric))
        return desc
def main():
    """
    Simple smoke test for the Evaluator using randomly generated predictions.
    """
    name = "tgbn-trade"
    evaluator = Evaluator(name=name)
    print(evaluator.expected_input_format)
    print(evaluator.expected_output_format)
    # BUG FIX: y_true / y_pred were previously undefined names (NameError);
    # generate a small reproducible regression target instead
    rng = np.random.default_rng(42)
    y_true = rng.random(100)
    y_pred = rng.random(100)
    input_dict = {"y_true": y_true, "y_pred": y_pred, "eval_metric": ["mse"]}
    result_dict = evaluator.eval(input_dict)
    print(result_dict)


if __name__ == "__main__":
    main()
================================================
FILE: tgb/utils/dataset_stats.py
================================================
"""
Dataset statistics
"""
import numpy as np
import pandas as pd
import networkx as nx
import argparse
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from tgb.linkproppred.dataset import LinkPropPredDataset
from tgb.utils.utils import vprint
def get_unique_edges(sources, destination):
    r"""
    Return a dict whose keys are the unique (src, dst) pairs appearing in
    the edgelist; every value is True (the dict is used as an ordered set).
    """
    return {(src, dst): True for src, dst in zip(sources, destination)}
def get_avg_e_per_ts(edgelist_df):
    r"""
    Get the average number of edges per timestamp.

    PERF FIX: the per-timestamp counts always sum to the total number of
    rows, so the previous O(T * N) scan over every unique timestamp is
    replaced by a direct division — identical result, linear time.

    Parameters:
        edgelist_df: pandas DataFrame with a 'ts' column
    Returns:
        avg_num_e_per_ts: float, edges divided by unique timestamps
    """
    num_unique_ts = edgelist_df['ts'].nunique()
    return len(edgelist_df) / num_unique_ts
def get_avg_degree(edgelist_df):
    r"""
    Average node degree per timestamp, averaged over all timestamps.

    For each unique timestamp a multigraph snapshot is built from the edges
    at that time and the mean degree of its nodes is recorded; the mean of
    those per-snapshot values is returned.
    """
    per_ts_means = []
    for snapshot_ts in np.unique(np.array(edgelist_df['ts'].tolist())):
        snapshot = edgelist_df.loc[edgelist_df['ts'] == snapshot_ts]
        graph = nx.MultiGraph()
        for _, edge_row in snapshot.iterrows():
            graph.add_edge(edge_row['src'], edge_row['dst'], weight=edge_row['ts'])
        per_ts_means.append(np.mean([graph.degree[node] for node in graph.nodes()]))
    return np.mean(per_ts_means)
def get_index_metrics(train_val_data, test_data):
    r"""
    Compute the `surprise` and `recurrence` indices.

    surprise: unique test (src, dst) pairs never seen in train+val,
    divided by the total number of test edges.
    reoccurrence: unique test pairs also present in train+val, divided
    by the total number of train+val edges.
    """
    train_val_edges = set(zip(train_val_data['sources'], train_val_data['destinations']))
    test_edges = set(zip(test_data['sources'], test_data['destinations']))

    train_val_size = len(train_val_data['sources'])
    test_size = len(test_data['sources'])

    intersect = len(test_edges & train_val_edges)
    difference = len(test_edges - train_val_edges)

    surprise = float(difference * 1.0 / test_size)
    reoccurrence = float(intersect * 1.0 / train_val_size)
    return surprise, reoccurrence
def get_node_ratio(history_data, eval_data):
    r"""
    Fraction of nodes in ``eval_data`` that never appeared in ``history_data``.
    """
    eval_nodes = set(eval_data['sources']) | set(eval_data['destinations'])
    hist_nodes = set(history_data['sources']) | set(history_data['destinations'])
    new_nodes = eval_nodes - hist_nodes
    return float(len(new_nodes) * 1.0 / len(eval_nodes))
def get_dataset_stats(data, temporal_stats=False):
    r"""
    Compute count-based (and optionally temporal) statistics for a dataset.

    Parameters:
        data: dict with 'train', 'val', 'train_val', 'test', 'full' splits
        temporal_stats: when True also compute the (slow) per-timestamp stats
    Returns:
        stats_dict: dict of computed statistics
    """
    full = data['full']
    sources, destinations, timestamps = full['sources'], full['destinations'], full['timestamps']
    edgelist_df = pd.DataFrame(zip(sources, destinations, timestamps), columns=['src', 'dst', 'ts'])

    # simple counting stats
    num_nodes = len(np.unique(np.concatenate((sources, destinations), axis=0)))
    num_edges = len(sources)  # = len(destinations) = len(timestamps)
    num_unique_ts = len(np.unique(timestamps))
    num_unique_e = len(get_unique_edges(sources, destinations))

    # temporal stats are expensive on large datasets, so they are opt-in;
    # -1 marks "not computed"
    if temporal_stats:
        avg_e_per_ts = get_avg_e_per_ts(edgelist_df)
        avg_degree_per_ts = get_avg_degree(edgelist_df)
    else:
        avg_e_per_ts = -1
        avg_degree_per_ts = -1

    # reoccurrence & surprise indices
    surprise, reoccurrence = get_index_metrics(data['train_val'], data['test'])

    # new-node ratios (both measured against the training split)
    val_nn_ratio = get_node_ratio(data['train'], data['val'])
    #test_nn_ratio = get_node_ratio(data['train_val'], data['test'])
    test_nn_ratio = get_node_ratio(data['train'], data['test'])

    # NOTE(review): the 'reocurrence' key spelling is kept as-is because
    # downstream consumers may rely on it
    return {
        'num_nodes': num_nodes,
        'num_edges': num_edges,
        'num_unique_ts': num_unique_ts,
        'num_unique_e': num_unique_e,
        'avg_e_per_ts': avg_e_per_ts,
        'avg_degree_per_ts': avg_degree_per_ts,
        'surprise': surprise,
        'reocurrence': reoccurrence,
        'val_nn_ratio': val_nn_ratio,
        'test_nn_ratio': test_nn_ratio,
    }
def main():
    r"""
    Generate dataset statistics for a TGB dataset selected on the command line.
    """
    parser = argparse.ArgumentParser(description='Dataset statistics')
    # BUG FIX: the help text for --data previously said 'random seed to use'
    parser.add_argument('-d', '--data', type=str, default='tgbl-wiki', help='name of the dataset')
    parser.add_argument('--tempstats', action='store_true', default=False, help='whether compute temporal statistics')
    # BUG FIX: parse_args() was called twice (first result discarded)
    args = parser.parse_args()

    DATA = args.data
    temporal_stats = args.tempstats

    # data loading ...
    if DATA in ['tgbl-wiki', 'tgbl-review', 'tgbl-flight', 'tgbl-comment', 'tgbl-coin']:
        # load data: link prop. pred. with `numpy`
        dataset = LinkPropPredDataset(name=DATA, root="datasets", preprocess=True)
        data = dataset.full_data
        # get masks
        train_mask = dataset.train_mask
        val_mask = dataset.val_mask
        test_mask = dataset.test_mask
        train_data = {'sources': data['sources'][train_mask],
                      'destinations': data['destinations'][train_mask],
                      }
        val_data = {'sources': data['sources'][val_mask],
                    'destinations': data['destinations'][val_mask],
                    }
        train_val_data = {'sources': np.concatenate([data['sources'][train_mask], data['sources'][val_mask]]),
                          'destinations': np.concatenate([data['destinations'][train_mask], data['destinations'][val_mask]]),
                          }
        test_data = {'sources': data['sources'][test_mask],
                     'destinations': data['destinations'][test_mask],
                     }
        full_data = {'sources': data['sources'],
                     'destinations': data['destinations'],
                     'timestamps': data['timestamps'],
                     }
    elif DATA in ['tgbn-trade', 'tgbn-genre', 'tgbn-reddit', 'tgbn-token']:
        # load data: node prop. pred.
        dataset = PyGNodePropPredDataset(name=DATA, root="datasets")
        data = dataset.get_TemporalData()
        # split data (the unused combined train+val mask was removed)
        train_mask = dataset.train_mask
        val_mask = dataset.val_mask
        test_mask = dataset.test_mask
        train_data = {'sources': np.array(data[train_mask].src),
                      'destinations': np.array(data[train_mask].dst),
                      }
        val_data = {'sources': np.array(data[val_mask].src),
                    'destinations': np.array(data[val_mask].dst),
                    }
        train_val_data = {'sources': np.concatenate([np.array(data[train_mask].src), np.array(data[val_mask].src)]),
                          'destinations': np.concatenate([np.array(data[train_mask].dst), np.array(data[val_mask].dst)]),
                          }
        test_data = {'sources': np.array(data[test_mask].src),
                     'destinations': np.array(data[test_mask].dst),
                     }
        full_data = {'sources': np.array(data.src),
                     'destinations': np.array(data.dst),
                     'timestamps': np.array(data.t),
                     }
    else:
        raise ValueError("Unsupported data!")

    split_data = {'train': train_data,
                  'val': val_data,
                  'train_val': train_val_data,
                  'test': test_data,
                  'full': full_data,
                  }

    vprint("=============================")
    vprint(f">>> DATA: {DATA}")
    dataset_stats = get_dataset_stats(split_data, temporal_stats)
    for k, v in dataset_stats.items():
        vprint(f"{k}: {v}")
    vprint("=============================")


if __name__ == "__main__":
    main()
================================================
FILE: tgb/utils/info.py
================================================
import os.path as osp
import os
r"""
General space to store global information used elsewhere such as url links, evaluation metrics etc.
"""
PROJ_DIR = osp.dirname(osp.abspath(os.path.join(__file__, os.pardir))) + "/"
class BColors:
    """
    ANSI terminal escape sequences used to colorize / style strings.
    """
    HEADER = "\033[95m"     # bright magenta
    OKBLUE = "\033[94m"     # bright blue
    OKCYAN = "\033[96m"     # bright cyan
    OKGREEN = "\033[92m"    # bright green
    WARNING = "\033[93m"    # bright yellow
    FAIL = "\033[91m"       # bright red
    ENDC = "\033[0m"        # reset all attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
# Download URL for every dataset; commented-out URLs are superseded versions.
DATA_URL_DICT = {
    "tgbl-enron": "https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-enron.zip",
    "tgbl-uci": "https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-uci.zip",
    "tgbl-wiki": "https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-wiki-v2.zip", #"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-wiki.zip", #v1
    "tgbl-subreddit": "https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-subreddit.zip",
    "tgbl-lastfm": "https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-lastfm.zip",
    "tgbl-review": "https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-review-v2.zip", # "https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-review-v3.zip" #"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-review-v2.zip" #"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-review.zip", #v1
    "tgbl-coin": "https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-coin-v2.zip", #"https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-coin.zip",
    "tgbl-flight": "https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-flight-v2.zip", #"tgbl-flight": "https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-flight_edgelist_v2_ts.zip",
    "tgbl-comment": "https://object-arbutus.cloud.computecanada.ca/tgb/tgbl-comment.zip",
    "tgbn-trade": "https://object-arbutus.cloud.computecanada.ca/tgb/tgbn-trade.zip",
    "tgbn-genre": "https://object-arbutus.cloud.computecanada.ca/tgb/tgbn-genre.zip",
    "tgbn-reddit": "https://object-arbutus.cloud.computecanada.ca/tgb/tgbn-reddit.zip",
    "tgbn-token": "https://object-arbutus.cloud.computecanada.ca/tgb/tgbn-token.zip",
    "tkgl-polecat": "https://object-arbutus.cloud.computecanada.ca/tgb/tkgl-polecat.zip",
    "tkgl-icews": "https://object-arbutus.cloud.computecanada.ca/tgb/tkgl-icews.zip",
    "tkgl-yago": "https://object-arbutus.cloud.computecanada.ca/tgb/tkgl-yago.zip",
    "tkgl-wikidata": "https://object-arbutus.cloud.computecanada.ca/tgb/tkgl-wikidata.zip",
    "tkgl-smallpedia": "https://object-arbutus.cloud.computecanada.ca/tgb/tkgl-smallpedia.zip",
    "thgl-myket": "https://object-arbutus.cloud.computecanada.ca/tgb/thgl-myket.zip",
    "thgl-github": "https://object-arbutus.cloud.computecanada.ca/tgb/thgl-github.zip",
    "thgl-forum": "https://object-arbutus.cloud.computecanada.ca/tgb/thgl-forum.zip",
    "thgl-software": "https://object-arbutus.cloud.computecanada.ca/tgb/thgl-software.zip", #"https://object-arbutus.cloud.computecanada.ca/tgb/thgl-software_ns_random.zip"
}

# Version number of the currently-served copy of each dataset; bumped when
# the hosted zip above is replaced.
DATA_VERSION_DICT = {
    "tgbl-enron": 1,
    "tgbl-uci": 1,
    "tgbl-wiki": 2,
    "tgbl-subreddit": 1,
    "tgbl-lastfm": 1,
    "tgbl-review": 2, #3
    "tgbl-coin": 2,
    "tgbl-comment": 1,
    "tgbl-flight": 2,
    "tgbn-trade": 1,
    "tgbn-genre": 1,
    "tgbn-reddit": 1,
    "tgbn-token": 1,
    "tkgl-polecat": 1,
    "tkgl-icews": 1,
    "tkgl-yago": 1,
    "tkgl-wikidata": 1,
    "tkgl-smallpedia": 1,
    "thgl-myket": 1,
    "thgl-github": 1,
    "thgl-forum": 1,
    "thgl-software": 1,
}

# Official evaluation metric per dataset: MRR for link prediction tasks,
# NDCG for node property prediction tasks.
DATA_EVAL_METRIC_DICT = {
    "tgbl-enron": "mrr",
    "tgbl-uci": "mrr",
    "tgbl-wiki": "mrr",
    "tgbl-subreddit": "mrr",
    "tgbl-lastfm": "mrr",
    "tgbl-review": "mrr",
    "tgbl-coin": "mrr",
    "tgbl-comment": "mrr",
    "tgbl-flight": "mrr",
    "tkgl-polecat": "mrr",
    "tkgl-yago": "mrr",
    "tkgl-wikidata": "mrr",
    "tkgl-smallpedia": "mrr",
    "tkgl-icews": "mrr",
    "thgl-myket": "mrr",
    "thgl-github": "mrr",
    "thgl-forum": "mrr",
    "thgl-software": "mrr",
    "tgbn-trade": "ndcg",
    "tgbn-genre": "ndcg",
    "tgbn-reddit": "ndcg",
    "tgbn-token": "ndcg",
}

# Negative sampling strategy used when generating evaluation negatives
# for each link prediction dataset.
DATA_NS_STRATEGY_DICT = {
    "tgbl-enron": "hist_rnd",
    "tgbl-uci": "hist_rnd",
    "tgbl-wiki": "hist_rnd",
    "tgbl-subreddit": "hist_rnd",
    "tgbl-lastfm": "hist_rnd",
    "tgbl-review": "hist_rnd",
    "tgbl-coin": "hist_rnd",
    "tgbl-comment": "hist_rnd",
    "tgbl-flight": "hist_rnd",
    "tkgl-polecat": "time-filtered",
    "tkgl-yago": "time-filtered",
    "tkgl-wikidata": "dst-time-filtered",
    "tkgl-smallpedia": "time-filtered",
    "tkgl-icews": "time-filtered",
    "thgl-myket": "node-type-filtered",
    "thgl-github": "node-type-filtered",
    "thgl-forum": "node-type-filtered",
    "thgl-software": "node-type-filtered",
}

# Number of node-label classes for each node property prediction dataset.
DATA_NUM_CLASSES = {
    "tgbn-trade": 255,
    "tgbn-genre": 513,
    "tgbn-reddit": 698,
    "tgbn-token": 1001,
}
================================================
FILE: tgb/utils/pre_process.py
================================================
from typing import Optional, cast, Union, List, overload, Literal
from tqdm import tqdm
import numpy as np
import pandas as pd
import os.path as osp
import time
import csv
import datetime
from datetime import date
from tgb.utils.utils import vprint
"""
function to process node type for thg datasets
"""
def process_node_type(
    fname: str,
    node_ids,
):
    """
    Process the node-type csv into a numpy array of integer node types,
    indexed by the remapped node id.

    Parameters:
        fname: path to a csv with rows of the form ``node_id,type``
        node_ids: mapping from raw node id to remapped (contiguous) node id
    Returns:
        node_feat: numpy array where index ``node_ids[nid]`` holds the type of nid
    Raises:
        ValueError: if a type field is not an integer or a node id is unknown
    """
    node_feat = np.zeros(len(node_ids))
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        # node_id,type
        for row in tqdm(csv_reader):
            if idx == 0:
                # skip the header row
                idx += 1
                continue
            else:
                nid = int(row[0])
                # BUG FIX: catch specific exceptions instead of bare `except`,
                # which also swallowed KeyboardInterrupt / SystemExit
                try:
                    node_type = int(row[1])
                except ValueError:
                    raise ValueError(row[1], " is not an integer thus can't be a node type for thg dataset")
                try:
                    node_id = node_ids[nid]
                except KeyError:
                    raise ValueError(nid, " is not a valid node id")
                node_feat[node_id] = node_type
    return node_feat
"""
functions for thgl-forum dataset
"""
def csv_to_forum_data(
    fname: str,
) -> pd.DataFrame:
    r"""
    used by thgl-forum dataset
    convert the raw .csv data to pandas dataframe and numpy array
    input .csv file format should be: ts, src, dst, relation_type, num_words, score

    Args:
        fname: the path to the raw data
    Returns:
        (df, feat_l, node_ids): edgelist dataframe, (num_edges, 2) feature
        array of normalized (num_words, score), and raw->contiguous id map
    """
    feat_size = 2
    # BUG FIX: count lines inside a context manager so the file handle is
    # closed instead of leaked
    with open(fname) as f:
        num_lines = sum(1 for line in f) - 1
    vprint(f"number of lines counted: {num_lines} in {fname}")
    u_list = np.zeros(num_lines)
    i_list = np.zeros(num_lines)
    ts_list = np.zeros(num_lines)
    label_list = np.zeros(num_lines)
    edge_type = np.zeros(num_lines)
    feat_l = np.zeros((num_lines, feat_size))
    idx_list = np.zeros(num_lines)
    w_list = np.zeros(num_lines)
    node_ids = {}
    unique_id = 0
    # normalization constants for the two edge features
    word_max = 10000
    score_max = 10000
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        #timestamp, head, tail, relation type
        for row in tqdm(csv_reader):
            if idx == 0:
                # skip the header row
                idx += 1
                continue
            else:
                #! ts,src,dst,relation_type,num_words,score
                ts = int(row[0]) #converted to UNIX timestamp already
                src = int(row[1])
                dst = int(row[2])
                relation = int(row[3])
                num_words = int(row[4])
                score = int(row[5])
                # remap raw node ids to contiguous integer ids
                if src not in node_ids:
                    node_ids[src] = unique_id
                    unique_id += 1
                if dst not in node_ids:
                    node_ids[dst] = unique_id
                    unique_id += 1
                u = node_ids[src]
                i = node_ids[dst]
                u_list[idx - 1] = u
                i_list[idx - 1] = i
                ts_list[idx - 1] = ts
                idx_list[idx - 1] = idx
                w_list[idx - 1] = float(1)
                edge_type[idx - 1] = relation
                feat_l[idx - 1] = np.array([num_words/word_max, score/score_max])
                idx += 1
    return (
        pd.DataFrame(
            {
                "u": u_list,
                "i": i_list,
                "ts": ts_list,
                "label": label_list,
                "idx": idx_list,
                "w": w_list,
                "edge_type": edge_type,
            }
        ),
        feat_l,
        node_ids,
    )
"""
functions for thg dataset
"""
def csv_to_thg_data(
    fname: str,
) -> pd.DataFrame:
    r"""
    used by thgl-myket dataset
    convert the raw .csv data to pandas dataframe and numpy array
    input .csv file format should be: timestamp, head, tail, relation type

    Args:
        fname: the path to the raw data
    Returns:
        (df, feat_l, node_ids): edgelist dataframe, zero-filled feature
        array, and raw->contiguous node-id map
    """
    feat_size = 1
    # BUG FIX: count lines inside a context manager so the file handle is
    # closed instead of leaked
    with open(fname) as f:
        num_lines = sum(1 for line in f) - 1
    vprint(f"number of lines counted: {num_lines} in {fname}")
    u_list = np.zeros(num_lines)
    i_list = np.zeros(num_lines)
    ts_list = np.zeros(num_lines)
    label_list = np.zeros(num_lines)
    edge_type = np.zeros(num_lines)
    feat_l = np.zeros((num_lines, feat_size))
    idx_list = np.zeros(num_lines)
    w_list = np.zeros(num_lines)
    node_ids = {}
    unique_id = 0
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        #timestamp, head, tail, relation type
        for row in tqdm(csv_reader):
            if idx == 0:
                # skip the header row
                idx += 1
                continue
            else:
                ts = int(row[0]) #converted to UNIX timestamp already
                src = int(row[1])
                dst = int(row[2])
                relation = int(row[3])
                # remap raw node ids to contiguous integer ids
                if src not in node_ids:
                    node_ids[src] = unique_id
                    unique_id += 1
                if dst not in node_ids:
                    node_ids[dst] = unique_id
                    unique_id += 1
                u = node_ids[src]
                i = node_ids[dst]
                u_list[idx - 1] = u
                i_list[idx - 1] = i
                ts_list[idx - 1] = ts
                idx_list[idx - 1] = idx
                w_list[idx - 1] = float(1)
                edge_type[idx - 1] = relation
                idx += 1
    return (
        pd.DataFrame(
            {
                "u": u_list,
                "i": i_list,
                "ts": ts_list,
                "label": label_list,
                "idx": idx_list,
                "w": w_list,
                "edge_type": edge_type,
            }
        ),
        feat_l,
        node_ids,
    )
"""
functions for tkgl-wikidata dataset
"""
def csv_to_wikidata(
    fname: str,
) -> pd.DataFrame:
    r"""
    used by tkgl-wikidata and tkgl-smallpedia
    convert the raw .csv data to pandas dataframe and numpy array
    input .csv file format should be: timestamp, head, tail, relation type

    Args:
        fname: the path to the raw data
    Returns:
        (df, feat_l, node_ids): edgelist dataframe, zero-filled feature
        array, and raw->contiguous node-id map (relations are also remapped
        to contiguous integer edge types)
    """
    feat_size = 1
    # BUG FIX: count lines inside a context manager so the file handle is
    # closed instead of leaked
    with open(fname) as f:
        num_lines = sum(1 for line in f) - 1
    vprint(f"number of lines counted: {num_lines} in {fname}")
    u_list = np.zeros(num_lines)
    i_list = np.zeros(num_lines)
    ts_list = np.zeros(num_lines)
    label_list = np.zeros(num_lines)
    edge_type = np.zeros(num_lines)
    feat_l = np.zeros((num_lines, feat_size))
    idx_list = np.zeros(num_lines)
    w_list = np.zeros(num_lines)
    node_ids = {}
    edge_type_ids = {}
    unique_id = 0
    et_id = 0
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        #timestamp, head, tail, relation type
        for row in tqdm(csv_reader):
            if idx == 0:
                # skip the header row
                idx += 1
                continue
            else:
                ts = int(row[0]) #converted to year already
                src = row[1]
                dst = row[2]
                relation = row[3]
                # remap raw node / relation names to contiguous integer ids
                if src not in node_ids:
                    node_ids[src] = unique_id
                    unique_id += 1
                if dst not in node_ids:
                    node_ids[dst] = unique_id
                    unique_id += 1
                if relation not in edge_type_ids:
                    edge_type_ids[relation] = et_id
                    et_id += 1
                u = node_ids[src]
                i = node_ids[dst]
                u_list[idx - 1] = u
                i_list[idx - 1] = i
                ts_list[idx - 1] = ts
                idx_list[idx - 1] = idx
                w_list[idx - 1] = float(1)
                edge_type[idx - 1] = edge_type_ids[relation]
                idx += 1
    return (
        pd.DataFrame(
            {
                "u": u_list,
                "i": i_list,
                "ts": ts_list,
                "label": label_list,
                "idx": idx_list,
                "w": w_list,
                "edge_type": edge_type,
            }
        ),
        feat_l,
        node_ids,
    )
def csv_to_staticdata(
    fname: str,
    node_ids: dict,
) -> pd.DataFrame:
    r"""
    used by tkgl-wikidata and tkgl-smallpedia
    convert the raw .csv data to pandas dataframe and numpy array for static knowledge edges
    input .csv file format should be: head, tail, relation type

    NOTE(review): despite the annotation, this actually returns
    ``(out_dict, node_ids)``; the annotation is kept for interface stability.

    Args:
        fname: the path to the raw data
        node_ids: dictionary of node names mapped to integer node ids
                  (extended in place with any new nodes encountered here)
    """
    # BUG FIX: count lines inside a context manager so the file handle is
    # closed instead of leaked
    with open(fname) as f:
        num_lines = sum(1 for line in f) - 1
    vprint(f"number of lines counted: {num_lines} in {fname}")
    u_list = np.zeros(num_lines)
    i_list = np.zeros(num_lines)
    edge_type = np.zeros(num_lines)
    edge_type_ids = {}
    out_dict = {}
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        #timestamp, head, tail, relation type
        for row in tqdm(csv_reader):
            if idx == 0:
                # skip the header row
                idx += 1
                continue
            else:
                src = row[0]
                dst = row[1]
                relation = row[2]
                # extend the shared node-id map with unseen nodes
                if src not in node_ids:
                    node_ids[src] = len(node_ids)
                if dst not in node_ids:
                    node_ids[dst] = len(node_ids)
                if relation not in edge_type_ids:
                    edge_type_ids[relation] = len(edge_type_ids)
                u = node_ids[src]
                i = node_ids[dst]
                u_list[idx - 1] = u
                i_list[idx - 1] = i
                edge_type[idx - 1] = edge_type_ids[relation]
                idx += 1
    out_dict["head"] = u_list
    out_dict["tail"] = i_list
    out_dict["edge_type"] = edge_type
    return out_dict, node_ids
"""
functions for tkgl-polecat, tkgl-icews dataset
"""
def csv_to_tkg_data(
    fname: str,
) -> pd.DataFrame:
    r"""
    used by tkgl-polecat
    convert the raw .csv data to pandas dataframe and numpy array
    input .csv file format should be: timestamp, head, tail, relation type

    Args:
        fname: the path to the raw data
    Returns:
        (df, feat_l, node_ids): edgelist dataframe, zero-filled feature
        array, and raw->contiguous node-id map
    """
    feat_size = 1
    # BUG FIX: count lines inside a context manager so the file handle is
    # closed instead of leaked
    with open(fname) as f:
        num_lines = sum(1 for line in f) - 1
    vprint(f"number of lines counted: {num_lines} in {fname}")
    u_list = np.zeros(num_lines)
    i_list = np.zeros(num_lines)
    ts_list = np.zeros(num_lines)
    label_list = np.zeros(num_lines)
    edge_type = np.zeros(num_lines)
    feat_l = np.zeros((num_lines, feat_size))
    idx_list = np.zeros(num_lines)
    w_list = np.zeros(num_lines)
    node_ids = {}
    unique_id = 0
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        #timestamp, head, tail, relation type
        for row in tqdm(csv_reader):
            if idx == 0:
                # skip the header row
                idx += 1
                continue
            else:
                ts = int(row[0]) #converted to UNIX timestamp already
                src = int(row[1])
                dst = int(row[2])
                relation = int(row[3])
                # remap raw node ids to contiguous integer ids
                if src not in node_ids:
                    node_ids[src] = unique_id
                    unique_id += 1
                if dst not in node_ids:
                    node_ids[dst] = unique_id
                    unique_id += 1
                u = node_ids[src]
                i = node_ids[dst]
                u_list[idx - 1] = u
                i_list[idx - 1] = i
                ts_list[idx - 1] = ts
                idx_list[idx - 1] = idx
                w_list[idx - 1] = float(1)
                edge_type[idx - 1] = relation
                idx += 1
    return (
        pd.DataFrame(
            {
                "u": u_list,
                "i": i_list,
                "ts": ts_list,
                "label": label_list,
                "idx": idx_list,
                "w": w_list,
                "edge_type": edge_type,
            }
        ),
        feat_l,
        node_ids,
    )
"""
functions for wikipedia dataset
---------------------------------------
"""
def load_edgelist_wiki(fname: str) -> pd.DataFrame:
    """
    Load the wikipedia dataset into a pandas dataframe; destination ids are
    shifted past the largest source id so the two id spaces are disjoint
    (same processing as torch_geometric's JODIE loader).

    Parameters:
        fname: str, name of the input file
    Returns:
        (df, msg, None): edgelist dataframe, per-edge feature matrix, and a
        placeholder for the (unused) node-id map
    """
    raw = pd.read_csv(fname, skiprows=1, header=None)
    src = raw.iloc[:, 0].values
    dst = raw.iloc[:, 1].values
    dst += int(src.max()) + 1  # bipartite shift: keep dst ids above src ids
    t = raw.iloc[:, 2].values
    msg = raw.iloc[:, 4:].values
    n = t.shape[0]
    df = pd.DataFrame({"u": src, "i": dst, "ts": t, "idx": np.arange(n), "w": np.ones(n)})
    return df, msg, None
"""
functions for un_trade dataset
---------------------------------------
"""
def load_edgelist_trade(fname: str, label_size=255):
    """
    Load the un_trade edgelist into a pandas dataframe.

    Parameters:
        fname: path to a csv with rows ``ts,src,dst,weight`` (header skipped)
        label_size: kept for interface compatibility (unused here)
    Returns:
        (df, feat_l, node_ids): edgelist dataframe, (num_edges, 1) weight
        feature array, and raw->contiguous node-id map
    """
    feat_size = 1
    # BUG FIX: count lines inside a context manager so the file handle is
    # closed instead of leaked
    with open(fname) as f:
        num_lines = sum(1 for line in f) - 1
    u_list = np.zeros(num_lines)
    i_list = np.zeros(num_lines)
    ts_list = np.zeros(num_lines)
    feat_l = np.zeros((num_lines, feat_size))
    idx_list = np.zeros(num_lines)
    w_list = np.zeros(num_lines)
    node_ids = {}  # dictionary for node ids
    node_uid = 0
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        for row in tqdm(csv_reader):
            if idx == 0:
                # skip the header row
                idx += 1
            else:
                ts = int(row[0])
                u = row[1]
                v = row[2]
                w = float(row[3])
                # remap raw node names to contiguous integer ids
                if u not in node_ids:
                    node_ids[u] = node_uid
                    node_uid += 1
                if v not in node_ids:
                    node_ids[v] = node_uid
                    node_uid += 1
                u = node_ids[u]
                i = node_ids[v]
                u_list[idx - 1] = u
                i_list[idx - 1] = i
                ts_list[idx - 1] = ts
                idx_list[idx - 1] = idx
                w_list[idx - 1] = w
                feat_l[idx - 1] = np.array([w])
                idx += 1
    return (
        pd.DataFrame(
            {"u": u_list, "i": i_list, "ts": ts_list, "idx": idx_list, "w": w_list}
        ),
        feat_l,
        node_ids,
    )
def load_trade_label_dict(
    fname: str,
    node_ids: dict,
) -> dict:
    """
    Load node labels into a nested dictionary instead of a pandas dataobject:
    {ts: {node_id: label_vec}}.

    Parameters:
        fname: str, name of the input file
        node_ids: dictionary of user names mapped to integer node ids
    Returns:
        node_label_dict: a nested dictionary of node labels
    Raises:
        FileNotFoundError: if ``fname`` does not exist
    """
    if not osp.exists(fname):
        raise FileNotFoundError(f"File not found at {fname}")
    label_size = len(node_ids)
    node_label_dict = {}  # {ts: {node_id: label_vec}}
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        for row in tqdm(csv_reader):
            if idx == 0:
                # skip the header row
                idx += 1
            else:
                ts = int(row[0])
                u = node_ids[row[1]]
                v = node_ids[row[2]]
                weight = float(row[3])
                # lazily create the per-timestamp dict and per-node vector
                per_ts = node_label_dict.setdefault(ts, {})
                if u not in per_ts:
                    per_ts[u] = np.zeros(label_size)
                per_ts[u][v] = weight
                idx += 1
    return node_label_dict
"""
functions for tgbn-token
---------------------------------------
"""
def load_edgelist_token(
    fname: str,
    label_size: int = 1001,
) -> pd.DataFrame:
    """
    load the edgelist into pandas dataframe
    also outputs index for the user nodes and token nodes

    Parameters:
        fname: str, name of the input file
        label_size: int, number of token classes; user node ids start after it
    Returns:
        (df, feat_l, node_ids, rd_dict): edgelist dataframe, (num_edges, 2)
        feature array of (value, IsSender), user-id map, and token-id map
    """
    feat_size = 2
    # BUG FIX: count lines inside a context manager so the file handle is
    # closed instead of leaked
    with open(fname) as f:
        num_lines = sum(1 for line in f) - 1
    vprint("there are ", num_lines, " lines in the raw data")
    u_list = np.zeros(num_lines)
    i_list = np.zeros(num_lines)
    ts_list = np.zeros(num_lines)
    label_list = np.zeros(num_lines)
    feat_l = np.zeros((num_lines, feat_size))
    idx_list = np.zeros(num_lines)
    w_list = np.zeros(num_lines)
    node_ids = {}
    rd_dict = {}
    node_uid = label_size  # node ids start after all the genres
    sr_uid = 0
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        # [timestamp,user_address,token_address,value,IsSender]
        for row in tqdm(csv_reader):
            if idx == 0:
                # skip the header row
                idx += 1
            else:
                ts = row[0]
                src = row[1]
                token = row[2]
                w = float(row[3])
                attr = float(row[4])
                # users and tokens live in disjoint id spaces
                if src not in node_ids:
                    node_ids[src] = node_uid
                    node_uid += 1
                if token not in rd_dict:
                    rd_dict[token] = sr_uid
                    sr_uid += 1
                u = node_ids[src]
                i = rd_dict[token]
                u_list[idx - 1] = u
                i_list[idx - 1] = i
                ts_list[idx - 1] = ts
                idx_list[idx - 1] = idx
                w_list[idx - 1] = w
                feat_l[idx - 1] = np.array([w,attr])
                idx += 1
    return (
        pd.DataFrame(
            {
                "u": u_list,
                "i": i_list,
                "ts": ts_list,
                "label": label_list,
                "idx": idx_list,
                "w": w_list,
            }
        ),
        feat_l,
        node_ids,
        rd_dict,
    )
"""
functions for subreddits dataset
---------------------------------------
"""
def load_edgelist_sr(
    fname: str,
    label_size: int = 2221,
) -> pd.DataFrame:
    """
    load the edgelist into pandas dataframe
    also outputs index for the user nodes and subreddit nodes

    Parameters:
        fname: str, name of the input file
        label_size: int, number of subreddits; user node ids start after it
    Returns:
        (df, feat_l, node_ids, rd_dict): edgelist dataframe, (num_edges, 1)
        score feature array, user-id map, and subreddit-id map
    """
    feat_size = 1 #2
    # BUG FIX: count lines inside a context manager so the file handle is
    # closed instead of leaked
    with open(fname) as f:
        num_lines = sum(1 for line in f) - 1
    vprint("there are ", num_lines, " lines in the raw data")
    u_list = np.zeros(num_lines)
    i_list = np.zeros(num_lines)
    ts_list = np.zeros(num_lines)
    label_list = np.zeros(num_lines)
    feat_l = np.zeros((num_lines, feat_size))
    idx_list = np.zeros(num_lines)
    w_list = np.zeros(num_lines)
    node_ids = {}
    rd_dict = {}
    node_uid = label_size  # node ids start after all the genres
    sr_uid = 0
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        # ['ts', 'src', 'subreddit', 'num_words', 'score']
        for row in tqdm(csv_reader):
            if idx == 0:
                # skip the header row
                idx += 1
            else:
                ts = row[0]
                src = row[1]
                subreddit = row[2]
                #num_words = int(row[3])
                score = int(row[4])
                # users and subreddits live in disjoint id spaces
                if src not in node_ids:
                    node_ids[src] = node_uid
                    node_uid += 1
                if subreddit not in rd_dict:
                    rd_dict[subreddit] = sr_uid
                    sr_uid += 1
                w = float(score)
                u = node_ids[src]
                i = rd_dict[subreddit]
                u_list[idx - 1] = u
                i_list[idx - 1] = i
                ts_list[idx - 1] = ts
                idx_list[idx - 1] = idx
                w_list[idx - 1] = w
                feat_l[idx - 1] = np.array([w])
                idx += 1
    return (
        pd.DataFrame(
            {
                "u": u_list,
                "i": i_list,
                "ts": ts_list,
                "label": label_list,
                "idx": idx_list,
                "w": w_list,
            }
        ),
        feat_l,
        node_ids,
        rd_dict,
    )
def load_labels_sr(
    fname,
    node_ids,
    rd_dict,
):
    """
    load the node labels for subreddit dataset

    Streams the label csv (sorted by time) and flushes one label vector per
    (day, user) run into the output lists.

    Parameters:
        fname: path to the label csv file
        node_ids: dict mapping raw user ids to integer node ids
        rd_dict: dict mapping subreddit names to integer ids
    Returns:
        a pandas dataframe with columns ts, node_id, y (one label vector
        of size len(rd_dict) per row)
    """
    if not osp.exists(fname):
        raise FileNotFoundError(f"File not found at {fname}")
    # day, user_idx, label_vec
    label_size = len(rd_dict)
    label_vec = np.zeros(label_size)
    ts_prev = 0
    prev_user = 0
    ts_list = []
    node_id_list = []
    y_list = []
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        # ['ts', 'src', 'subreddit', 'num_words', 'score']
        for row in tqdm(csv_reader):
            if idx == 0:
                idx += 1
            else:
                # NOTE(review): the key is cast to int here, while the
                # edgelist loader keys node_ids by raw strings — confirm the
                # label file really stores integer user ids.
                user_id = node_ids[int(row[1])]
                ts = int(row[0])
                sr_id = int(rd_dict[row[2]])
                weight = float(row[3])
                if idx == 1:
                    # initialize the running (day, user) pair from the first row
                    ts_prev = ts
                    prev_user = user_id
                # the next day
                if ts != ts_prev:
                    # flush the accumulated vector for the previous day
                    # NOTE(review): the current row's weight is NOT written
                    # into the fresh label_vec, so the first entry of each new
                    # day appears to be dropped — verify this is intended.
                    ts_list.append(ts_prev)
                    node_id_list.append(prev_user)
                    y_list.append(label_vec)
                    label_vec = np.zeros(label_size)
                    ts_prev = ts
                    prev_user = user_id
                else:
                    # same day: accumulate this subreddit's weight
                    label_vec[sr_id] = weight
                    if user_id != prev_user:
                        # NOTE(review): the weight above was written before
                        # the flush, so the new user's first entry lands in
                        # the previous user's vector — confirm intended.
                        ts_list.append(ts_prev)
                        node_id_list.append(prev_user)
                        y_list.append(label_vec)
                        prev_user = user_id
                        label_vec = np.zeros(label_size)
                idx += 1
    return pd.DataFrame({"ts": ts_list, "node_id": node_id_list, "y": y_list})
def load_label_dict(fname: str, node_ids: dict, rd_dict: dict) -> dict:
    """
    load node labels into a nested dictionary instead of pandas dataobject
    {ts: {node_id: label_vec}}
    Parameters:
        fname: str, name of the input file
        node_ids: dictionary of user names mapped to integer node ids
        rd_dict: dictionary of subreddit names mapped to integer node ids
    Returns:
        nested dict {ts: {node_id: label_vec}} where each label_vec has one
        slot per subreddit holding the edge weight
    """
    if not osp.exists(fname):
        raise FileNotFoundError(f"File not found at {fname}")
    # day, user_idx, label_vec
    label_size = len(rd_dict)
    node_label_dict = {}  # {ts: {node_id: label_vec}}
    with open(fname, "r") as csv_file:
        reader = csv.reader(csv_file, delimiter=",")
        # ['ts', 'src', 'dst', 'w']
        for line_no, row in enumerate(tqdm(reader)):
            if line_no == 0:
                continue  # header row
            ts = int(row[0])
            u = node_ids[row[1]]
            v = int(rd_dict[row[2]])
            weight = float(row[3])
            # lazily create the per-timestamp and per-node entries
            per_ts = node_label_dict.setdefault(ts, {})
            if u not in per_ts:
                per_ts[u] = np.zeros(label_size)
            per_ts[u][v] = weight
    return node_label_dict
"""
functions for redditcomments
-------------------------------------------
"""
def csv_to_pd_data_rc(
    fname: str,
) -> pd.DataFrame:
    r"""
    currently used by redditcomments dataset
    convert the raw .csv data to pandas dataframe and numpy array
    input .csv file format should be: timestamp, node u, node v, attributes
    Args:
        fname: the path to the raw data
    Returns:
        df: dataframe with columns u, i, ts, label, idx, w
        feat_l: per-edge feature array of width feat_size
        node_ids: dict mapping raw node names to integer ids
    """
    feat_size = 2  # 1 for subreddit, 1 for num words
    num_lines = sum(1 for line in open(fname)) - 1
    vprint("there are ", num_lines, " lines in the raw data")
    # pre-allocated column buffers, one slot per data row
    u_list = np.zeros(num_lines)
    i_list = np.zeros(num_lines)
    ts_list = np.zeros(num_lines)
    label_list = np.zeros(num_lines)
    feat_l = np.zeros((num_lines, feat_size))
    idx_list = np.zeros(num_lines)
    w_list = np.zeros(num_lines)
    node_ids = {}
    unique_id = 0
    max_words = 5000  # counted from statistics
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        # ['ts', 'src', 'dst', 'subreddit', 'num_words', 'score']
        for row in tqdm(csv_reader):
            if idx == 0:
                idx += 1
                continue
            else:
                ts = int(row[0])
                src = row[1]
                dst = row[2]
                num_words = int(row[3]) / max_words  # int number, normalize to [0,1]
                score = int(row[4])  # int number
                # reindexing node and subreddits
                if src not in node_ids:
                    node_ids[src] = unique_id
                    unique_id += 1
                if dst not in node_ids:
                    node_ids[dst] = unique_id
                    unique_id += 1
                w = float(score)
                u = node_ids[src]
                i = node_ids[dst]
                u_list[idx - 1] = u
                i_list[idx - 1] = i
                ts_list[idx - 1] = ts
                idx_list[idx - 1] = idx
                w_list[idx - 1] = w
                # NOTE(review): feat_size is 2 but a length-1 array is
                # assigned, so numpy broadcasts num_words into BOTH columns —
                # confirm whether the second column was meant to hold score.
                feat_l[idx - 1] = np.array([num_words])
                idx += 1
    vprint("there are ", len(node_ids), " unique nodes")
    return (
        pd.DataFrame(
            {
                "u": u_list,
                "i": i_list,
                "ts": ts_list,
                "label": label_list,
                "idx": idx_list,
                "w": w_list,
            }
        ),
        feat_l,
        node_ids,
    )
"""
functions for stablecoin
-------------------------------------------
"""
def csv_to_pd_data_sc(
    fname: str,
) -> pd.DataFrame:
    r"""
    currently used by stablecoin dataset
    convert the raw .csv data to pandas dataframe and numpy array
    input .csv file format should be: timestamp, node u, node v, attributes
    Parameters:
        fname: the path to the raw data
    Returns:
        df: a pandas dataframe containing the edgelist data
        feat_l: a numpy array containing the node features
        node_ids: a dictionary mapping node id to integer
    """
    feat_size = 1
    num_lines = sum(1 for line in open(fname)) - 1
    vprint(f"number of lines counted: {num_lines} in {fname}")
    # pre-allocated column buffers, one slot per data row
    u_list = np.zeros(num_lines)
    i_list = np.zeros(num_lines)
    ts_list = np.zeros(num_lines)
    label_list = np.zeros(num_lines)
    feat_l = np.zeros((num_lines, feat_size))
    idx_list = np.zeros(num_lines)
    w_list = np.zeros(num_lines)
    node_ids = {}

    def _node_id(name):
        # assign the next free integer id on first sight
        if name not in node_ids:
            node_ids[name] = len(node_ids)
        return node_ids[name]

    with open(fname, "r") as csv_file:
        reader = csv.reader(csv_file, delimiter=",")
        # time,src,dst,weight
        # 1648811421,0x27cbb0e6885ccb1db2dab7c2314131c94795fbef,0x8426a27add8dca73548f012d92c7f8f4bbd42a3e,800.0
        for line_no, row in enumerate(tqdm(reader)):
            if line_no == 0:
                continue  # header row
            pos = line_no - 1
            ts_list[pos] = int(row[0])
            u_list[pos] = _node_id(row[1])
            i_list[pos] = _node_id(row[2])
            weight = float(row[3])
            if weight == 0:
                weight = 1  # guard against log2(0) below
            w_list[pos] = weight
            idx_list[pos] = line_no
            feat_l[pos] = np.zeros(feat_size)
    #! normalize by log 2 for stablecoin
    w_list = np.log2(w_list)
    return (
        pd.DataFrame(
            {
                "u": u_list,
                "i": i_list,
                "ts": ts_list,
                "label": label_list,
                "idx": idx_list,
                "w": w_list,
            }
        ),
        feat_l,
        node_ids,
    )
"""
functions for opensky
-------------------------------------------
"""
def convert_str2int(
    in_str: str,
) -> np.ndarray:
    """
    Convert a string into a vector of floats, one entry per character.

    Digits keep their numeric value, the padding character "!" maps to -1,
    and letters are mapped case-insensitively to consecutive values starting
    at 30 (a -> 30, b -> 31, ..., i.e. ord(upper) - 35). The previous
    docstring claimed a=10, which did not match the implementation; the
    mapping itself is kept unchanged because stored edge/node features
    depend on it.

    Parameters:
        in_str: an input string to parse
    Returns:
        out: a numpy float32 array with one value per input character
    """
    out = []
    for element in in_str:
        if element.isnumeric():
            out.append(int(element))  # digit keeps its value
        elif element == "!":
            out.append(-1)  # "!" is the padding character
        else:
            # letters: 'A'/'a' -> 30, 'B'/'b' -> 31, ...
            out.append(ord(element.upper()) - 44 + 9)
    return np.array(out, dtype=np.float32)
def csv_to_pd_data(
    fname: str,
) -> pd.DataFrame:
    r"""
    currently used by tgbl-flight dataset
    convert the raw .csv data to pandas dataframe and numpy array
    input .csv file format should be: timestamp, node u, node v, attributes
    Args:
        fname: the path to the raw data
    Returns:
        df: dataframe with columns u, i, ts, label, idx, w
        feat_l: (num_edges, 16) array encoding callsign + typecode chars
        node_ids: dict mapping airport code to integer id
    """
    feat_size = 16  # 8 chars of callsign + 8 chars of typecode
    num_lines = sum(1 for line in open(fname)) - 1
    vprint(f"number of lines counted: {num_lines} in {fname}")
    # pre-allocated column buffers, one slot per data row
    u_list = np.zeros(num_lines)
    i_list = np.zeros(num_lines)
    ts_list = np.zeros(num_lines)
    label_list = np.zeros(num_lines)
    feat_l = np.zeros((num_lines, feat_size))
    idx_list = np.zeros(num_lines)
    w_list = np.zeros(num_lines)
    node_ids = {}
    unique_id = 0
    # True when timestamps are raw unix ints, False for "%Y-%m-%d" strings;
    # decided once from the first data row and reused for the whole file
    ts_format = None
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        #'day','src','dst','callsign','typecode'
        for row in tqdm(csv_reader):
            if idx == 0:
                idx += 1
                continue
            else:
                ts = row[0]
                if ts_format is None:
                    if (ts.isdigit()):
                        ts_format = True
                    else:
                        ts_format = False
                if ts_format:
                    ts = float(int(ts))  # unix timestamp already
                else:
                    # convert the date string to a unix timestamp
                    # (local-time midnight of that day)
                    TIME_FORMAT = "%Y-%m-%d"
                    date_cur = datetime.datetime.strptime(ts, TIME_FORMAT)
                    ts = float(date_cur.timestamp())
                src = row[1]
                dst = row[2]
                # 'callsign' has max size 8, can be 4, 5, 6, or 7
                # 'typecode' has max size 8
                # use ! as padding; empty fields become all padding
                if len(row[3]) == 0:
                    row[3] = "!!!!!!!!"
                while len(row[3]) < 8:
                    row[3] += "!"
                # pad row[4] (typecode) to size 8
                if len(row[4]) == 0:
                    row[4] = "!!!!!!!!"
                while len(row[4]) < 8:
                    row[4] += "!"
                # overlong typecodes are replaced by pure padding so the
                # feature string stays exactly 16 chars wide
                # NOTE(review): there is no matching > 8 guard for the
                # callsign (row[3]) — an overlong callsign would make the
                # feature vector wider than 16; confirm the data never does.
                if len(row[4]) > 8:
                    row[4] = "!!!!!!!!"
                feat_str = row[3] + row[4]
                if src not in node_ids:
                    node_ids[src] = unique_id
                    unique_id += 1
                if dst not in node_ids:
                    node_ids[dst] = unique_id
                    unique_id += 1
                u = node_ids[src]
                i = node_ids[dst]
                u_list[idx - 1] = u
                i_list[idx - 1] = i
                ts_list[idx - 1] = ts
                idx_list[idx - 1] = idx
                w_list[idx - 1] = float(1)  # all flight edges have unit weight
                feat_l[idx - 1] = convert_str2int(feat_str)
                idx += 1
    return (
        pd.DataFrame(
            {
                "u": u_list,
                "i": i_list,
                "ts": ts_list,
                "label": label_list,
                "idx": idx_list,
                "w": w_list,
            }
        ),
        feat_l,
        node_ids,
    )
def process_node_feat(
    fname: str,
    node_ids,
):
    """
    1. need to have the same node id as csv_to_pd_data
    2. process the various node features into a vector
    3. return a numpy array of node features with index corresponding to node id
    airport_code,type,continent,iso_region,longitude,latitude
    type: onehot encoding
    continent: onehot encoding
    iso_region: alphabet encoding same as edge feat
    longitude: float divide by 180
    latitude: float divide by 90

    NOTE(review): despite the note above, longitude/latitude are stored raw
    below (no division) — confirm which behavior is intended.
    NOTE(review): feat_size is hard-coded to 20; the concatenated vector
    (one-hot type + one-hot continent + 7 iso chars + 2 coords) must come
    out to exactly 20 on the given data or the row assignment will fail.
    """
    feat_size = 20
    node_feat = np.zeros((len(node_ids), feat_size))
    type_dict = {}
    type_idx = 0
    continent_dict = {}
    cont_idx = 0
    # first pass: collect the sets of airport types and continents so the
    # one-hot widths are known before any vector is built
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        # airport_code,type,continent,iso_region,longitude,latitude
        for row in tqdm(csv_reader):
            if idx == 0:
                idx += 1
                continue
            else:
                code = row[0]
                if code not in node_ids:
                    continue  # airport never appears in the edgelist
                else:
                    node_id = node_ids[code]
                    airport_type = row[1]
                    if airport_type not in type_dict:
                        type_dict[airport_type] = type_idx
                        type_idx += 1
                    continent = row[2]
                    if continent not in continent_dict:
                        continent_dict[continent] = cont_idx
                        cont_idx += 1
    # second pass: build the per-node feature vectors
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        # airport_code,type,continent,iso_region,longitude,latitude
        for row in tqdm(csv_reader):
            if idx == 0:
                idx += 1
                continue
            else:
                code = row[0]
                if code not in node_ids:
                    continue
                else:
                    node_id = node_ids[code]
                    airport_type = type_dict[row[1]]
                    type_vec = np.zeros(type_idx)
                    type_vec[airport_type] = 1
                    continent = continent_dict[row[2]]
                    cont_vec = np.zeros(cont_idx)
                    cont_vec[continent] = 1
                    # pad iso_region with "!" to 7 chars, same scheme as edges
                    while len(row[3]) < 7:
                        row[3] += "!"
                    iso_region = convert_str2int(row[3])  # numpy float array
                    lng = float(row[4])
                    lat = float(row[5])
                    coor_vec = np.array([lng, lat])
                    final = np.concatenate(
                        (type_vec, cont_vec, iso_region, coor_vec), axis=0
                    )
                    node_feat[node_id] = final
    return node_feat
"""
functions for un trade
-------------------------------------------
"""
#! these are helper functions
# TODO cleaning the un trade csv with countries with comma in the name, to remove this function
def clean_rows(
    fname: str,
    outname: str,
):
    r"""
    Rewrite a raw UN-trade csv, renaming the country names that contain a
    comma so that every data row splits into at most 4 comma-separated
    fields. The header line is copied through unchanged.

    Args:
        fname: the path to the raw data
        outname: the path to the cleaned data
    Raises:
        ValueError: if a cleaned line still has more than 4 fields
    """
    # (old, new) pairs: drop the comma inside the country name.
    # The original looped over a throwaway variable (shadowing builtin
    # `str`) and applied every replacement twice; once is enough.
    replacements = [
        ("China, Taiwan Province of", "Taiwan Province of China"),
        ("China, mainland", "China mainland"),
        ("China, Hong Kong SAR", "China Hong Kong SAR"),
        ("China, Macao SAR", "China Macao SAR"),
        (
            "Saint Helena, Ascension and Tristan da Cunha",
            "Saint Helena Ascension and Tristan da Cunha",
        ),
    ]
    # context managers guarantee both handles are closed even on error
    with open(fname) as f, open(outname, "w") as outf:
        outf.write(next(f))  # header line, copied verbatim
        for line in f:
            for old, new in replacements:
                line = line.replace(old, new)
            e = line.strip().split(",")
            if len(e) > 4:
                raise ValueError(f"line has more than 4 elements: {e}")
            outf.write(line)
"""
functions for last fm genre
-------------------------------------------
"""
def load_edgelist_datetime(fname, label_size=514):
    """
    load the edgelist into a pandas dataframe
    use numpy array instead of list for faster processing
    assume all edges are already sorted by time
    convert all time unit to unix time
    time, user_id, genre, weight

    Parameters:
        fname: path to the raw csv file
        label_size: number of genre (label) nodes; user ids start after them
    Returns:
        df: dataframe with columns u, i, ts, idx, w
        feat_l: per-edge feature array (the weight)
        node_ids: dict mapping user name -> integer id (>= label_size)
        label_ids: dict mapping genre name -> integer id
    """
    feat_size = 1
    num_lines = sum(1 for line in open(fname)) - 1
    vprint(f"number of lines counted: {num_lines} in {fname}")
    # pre-allocated column buffers, one slot per data row
    u_list = np.zeros(num_lines)
    i_list = np.zeros(num_lines)
    ts_list = np.zeros(num_lines)
    feat_l = np.zeros((num_lines, feat_size))
    idx_list = np.zeros(num_lines)
    w_list = np.zeros(num_lines)
    node_ids = {}  # dictionary for node ids
    label_ids = {}  # dictionary for label ids
    node_uid = label_size  # node ids start after the genre nodes
    label_uid = 0
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        idx = 0
        for row in tqdm(csv_reader):
            if idx == 0:
                # header row
                idx += 1
            else:
                ts = int(row[0])
                user_id = row[1]
                genre = row[2]
                w = float(row[3])
                if user_id not in node_ids:
                    node_ids[user_id] = node_uid
                    node_uid += 1
                if genre not in label_ids:
                    label_ids[genre] = label_uid
                    # NOTE(review): this only prints a warning — it does NOT
                    # terminate when genre ids spill past label_size and
                    # would then collide with user ids; confirm label_size
                    # always matches the data.
                    if label_uid >= label_size:
                        vprint("id overlap, terminate")
                    label_uid += 1
                u = node_ids[user_id]
                i = label_ids[genre]
                u_list[idx - 1] = u
                i_list[idx - 1] = i
                ts_list[idx - 1] = ts
                idx_list[idx - 1] = idx
                w_list[idx - 1] = w
                feat_l[idx - 1] = np.asarray([w])
                idx += 1
    return (
        pd.DataFrame(
            {"u": u_list, "i": i_list, "ts": ts_list, "idx": idx_list, "w": w_list}
        ),
        feat_l,
        node_ids,
        label_ids,
    )
def load_genre_list(fname):
    """
    Read the genre csv and return a mapping genre name -> running index.

    The first line is treated as a header and skipped; the genre is the
    first comma-separated field of each remaining line.

    Raises:
        FileNotFoundError: when fname does not exist
        ValueError: when the same genre appears twice
    """
    if not osp.exists(fname):
        raise FileNotFoundError(f"File not found at {fname}")
    with open(fname, "r") as fh:
        lines = fh.readlines()
    genre_index = {}
    for line in lines[1:]:  # skip the header line
        genre = line.split(",")[0]
        if genre in genre_index:
            raise ValueError("duplicate in genre_index")
        genre_index[genre] = len(genre_index)
    return genre_index
"""
functions for wikipedia and un_trade
-------------------------------------------
"""
def reindex(
    df: pd.DataFrame,
    bipartite: Optional[bool] = False,
):
    r"""
    reindex the nodes especially if the node ids are not integers
    Args:
        df: the pandas dataframe containing the graph
        bipartite: whether the graph is bipartite
    Returns:
        a copy of df with 1-based node and edge ids; for bipartite graphs
        the destination ids are first shifted past the largest source id
    """
    new_df = df.copy()
    if bipartite:
        # both partitions must already be contiguous integer ranges
        assert df.u.max() - df.u.min() + 1 == len(df.u.unique())
        assert df.i.max() - df.i.min() + 1 == len(df.i.unique())
        # shift destination ids past the largest source id
        new_df.i = df.i + df.u.max() + 1
    # switch everything to 1-based indexing
    new_df.u += 1
    new_df.i += 1
    new_df.idx += 1
    return new_df
================================================
FILE: tgb/utils/stats.py
================================================
"""
script for generating statistics from the dataset
"""
import csv
import numpy as np
from tgb.utils.utils import vprint
"""
#! analyze statistics from the dataset
#* 1). # of unique nodes, 2). # of edges. 3). # of unique edges, 4). # of timestamps 5). recurrence of nodes
"""
def analyze_csv(fname):
    """
    Print high-level statistics of a t,u,v,w edgelist csv: total edges,
    unique nodes, unique (u, v) pairs, unique timestamps, and how many
    nodes participate in at least 10 / 100 / 1000 edges.
    """
    node_dict = {}
    edge_dict = {}
    time_dict = {}
    num_edges = 0
    num_time = 0
    with open(fname, "r") as csv_file:
        reader = csv.reader(csv_file, delimiter=",")
        for line_no, row in enumerate(reader):
            if line_no == 0:
                continue  # header row
            # t,u,v,w
            t, u, v = row[0], row[1], row[2]
            # count unique time
            if t not in time_dict:
                time_dict[t] = 1
                num_time += 1
            # per-node edge counts
            node_dict[u] = node_dict.get(u, 0) + 1
            node_dict[v] = node_dict.get(v, 0) + 1
            # per-edge multiplicities
            num_edges += 1
            edge_dict[(u, v)] = edge_dict.get((u, v), 0) + 1
    vprint("----------------------high level statistics-------------------------")
    vprint("number of total edges are ", num_edges)
    vprint("number of nodes are ", len(node_dict))
    vprint("number of unique edges are ", len(edge_dict))
    vprint("number of unique timestamps are ", num_time)
    num_10 = sum(1 for c in node_dict.values() if c >= 10)
    num_100 = sum(1 for c in node_dict.values() if c >= 100)
    num_1000 = sum(1 for c in node_dict.values() if c >= 1000)
    vprint("number of nodes with # edges >= 10 is ", num_10)
    vprint("number of nodes with # edges >= 100 is ", num_100)
    vprint("number of nodes with # edges >= 1000 is ", num_1000)
    vprint("----------------------high level statistics-------------------------")
def plot_curve(y: np.ndarray, outname: str) -> None:
    """
    plot the training curve given y and save it as <outname>.pdf
    Parameters:
        y: np.ndarray, the training curve
        outname: str, the output name (".pdf" is appended)
    """
    # `plt` was previously used without any matplotlib import in this
    # module (NameError at call time); import lazily so the rest of the
    # stats script still works without matplotlib installed.
    import matplotlib.pyplot as plt

    plt.plot(y, color="#fc4e2a")
    plt.savefig(outname + ".pdf")
    plt.close()
def main():
    # entry point: report statistics for the tgbl-wiki edgelist
    fname = "tgb/datasets/tgbl-wiki/tgbl-wiki_edgelist.csv"
    analyze_csv(fname)
if __name__ == "__main__":
    main()
================================================
FILE: tgb/utils/utils.py
================================================
import random
import os
import pickle
import sys
import argparse
import json
import torch
from typing import Any
import numpy as np
from torch_geometric.data import TemporalData
import pandas as pd
import torch
# module-level verbosity flag, seeded from the TGB_VERBOSE environment variable
_VERBOSE = os.getenv("TGB_VERBOSE", 'False').lower() in ['true', '1']


def set_verbose(flag: bool) -> None:
    """Turn verbose printing on or off for the whole module."""
    global _VERBOSE
    _VERBOSE = flag


def vprint(*args, **kwargs):
    """Forward to print(*args, **kwargs) only when verbose mode is enabled."""
    if _VERBOSE:
        print(*args, **kwargs)
def add_inverse_quadruples(df: pd.DataFrame) -> pd.DataFrame:
    r"""
    adds the inverse relations required for the model to the dataframe

    Each edge (u, r, i, ts) gets a companion (i, r + num_rels, u, ts);
    inverse edges receive fresh edge indices placed after the originals
    and every row's label is set to 1.
    Raises:
        ValueError: if df has no "edge_type" column
    """
    if ("edge_type" not in df):
        raise ValueError("edge_type is required to invert relation in TKG")
    sources = df["u"].to_numpy()
    destinations = df["i"].to_numpy()
    timestamps = df["ts"].to_numpy()
    edge_idxs = df["idx"].to_numpy()
    weights = df["w"].to_numpy()
    edge_type = df["edge_type"].to_numpy()
    # number of distinct relations; inverse relations are shifted by it
    num_rels = np.unique(edge_type).shape[0]
    return pd.DataFrame(
        {
            "u": np.concatenate([sources, destinations]),
            "i": np.concatenate([destinations, sources]),
            "ts": np.concatenate([timestamps, timestamps]),
            "label": np.ones(2 * timestamps.shape[0]),
            "idx": np.concatenate([edge_idxs, edge_idxs + edge_idxs.max() + 1]),
            "w": np.concatenate([weights, weights]),
            "edge_type": np.concatenate([edge_type, edge_type + num_rels]),
        }
    )
def add_inverse_quadruples_np(quadruples: np.array,
                num_rels:int) -> np.array:
    """
    Append an inverse quadruple for every quadruple in the input.

    The inverse of [src, rel, dst, ts] is [dst, rel + num_rels, src, ts].
    Columns beyond the first four are dropped from the output.
    :param quadruples: [np.array] dataset quadruples, [src, relation_id, dst, timestamp]
    :param num_rels: [int] number of relations that we have originally
    returns all_quadruples: [np.array] quadruples including inverse quadruples
    """
    # fancy indexing yields a fresh array, so the input is never mutated
    inverse = quadruples[:, [2, 1, 0, 3]]
    inverse[:, 1] += num_rels
    return np.concatenate((quadruples[:, :4], inverse))
def add_inverse_quadruples_pyg(data: TemporalData, num_rels:int=-1) -> TemporalData:
    r"""
    creates an inverse quadruple from PyG TemporalData object, returns both the original and inverse quadruples

    NOTE(review): the ``num_rels`` argument is ignored — it is always
    recomputed below as ``max(edge_type) + 1``; confirm no caller relies on
    passing a different value.
    """
    timestamp = data.t
    head = data.src
    tail = data.dst
    msg = data.msg
    edge_type = data.edge_type  # relation
    # number of relation types, assuming 0-based contiguous relation ids
    num_rels = torch.max(edge_type).item() + 1
    inv_type = edge_type + num_rels  # inverse relations get shifted ids
    all_data = TemporalData(src=torch.cat([head, tail]),
                            dst=torch.cat([tail, head]),
                            t=torch.cat([timestamp, timestamp.clone()]),
                            edge_type=torch.cat([edge_type, inv_type]),
                            msg=torch.cat([msg, msg.clone()]),
                            y = torch.cat([data.y, data.y.clone()]),)
    return all_data
# import torch
def save_pkl(obj: Any, fname: str) -> None:
    r"""
    Serialize ``obj`` to ``fname`` with the highest pickle protocol.
    """
    with open(fname, "wb") as out:
        pickle.dump(obj, out, protocol=pickle.HIGHEST_PROTOCOL)
def load_pkl(fname: str) -> Any:
    r"""
    Deserialize and return the python object pickled at ``fname``.
    """
    with open(fname, "rb") as inp:
        return pickle.load(inp)
def set_random_seed(random_seed: int):
    r"""
    set random seed for reproducibility
    Args:
        random_seed (int): random seed
    """
    # seed every RNG the training code may touch
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(random_seed)
    # trade cudnn autotuning for determinism
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    vprint(f'INFO: fixed random seed: {random_seed}')
def find_nearest(array, value):
    """Return the element of ``array`` closest to ``value`` (first on ties)."""
    arr = np.asarray(array)
    return arr[np.abs(arr - value).argmin()]
def get_args():
    """
    Parse the common TGB command-line arguments.

    Returns:
        args: the parsed argparse.Namespace
        sys.argv: the raw command line, kept for experiment logging
    """
    parser = argparse.ArgumentParser('*** TGB ***')
    parser.add_argument('-d', '--data', type=str, help='Dataset name', default='tgbl-wiki')
    parser.add_argument('--lr', type=float, help='Learning rate', default=1e-4)
    parser.add_argument('--bs', type=int, help='Batch size', default=200)
    parser.add_argument('--k_value', type=int, help='k_value for computing ranking metrics', default=10)
    parser.add_argument('--num_epoch', type=int, help='Number of epochs', default=30)
    parser.add_argument('--seed', type=int, help='Random seed', default=1)
    parser.add_argument('--mem_dim', type=int, help='Memory dimension', default=100)
    parser.add_argument('--time_dim', type=int, help='Time dimension', default=100)
    parser.add_argument('--emb_dim', type=int, help='Embedding dimension', default=100)
    parser.add_argument('--tolerance', type=float, help='Early stopper tolerance', default=1e-6)
    parser.add_argument('--patience', type=float, help='Early stopper patience', default=5)
    parser.add_argument('--num_run', type=int, help='Number of iteration runs', default=5)
    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse exits on bad arguments; the bare `except:` used before
        # also swallowed unrelated errors. Show the full usage, then exit.
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
def save_results(new_results: dict, filename: str):
    r"""
    save (new) results into a json file
    :param: new_results (dictionary): a dictionary of new results to be saved
    :filename: the name of the file to save the (new) results

    On the first call the dict is written as-is; later calls convert the
    file content to a list and append.
    """
    if os.path.isfile(filename):
        # append to the existing results
        with open(filename, 'r+') as json_file:
            file_data = json.load(json_file)
            # convert file_data to list if not
            if type(file_data) is dict:
                file_data = [file_data]
            file_data.append(new_results)
            json_file.seek(0)
            json.dump(file_data, json_file, indent=4)
            # truncate so stale bytes from a previously longer file cannot
            # corrupt the json (the old code rewrote in place without this)
            json_file.truncate()
    else:
        # first write: dump the results
        with open(filename, 'w') as json_file:
            json.dump(new_results, json_file, indent=4)
def split_by_time(data):
    """
    https://github.com/Lee-zix/CEN/blob/main/rgcn/utils.py
    Group quadruples by timestamp (column 3): returns one [src, rel, dst]
    array per distinct timestamp, in ascending time order.
    """
    snapshot_list = []
    for ts in sorted(set(data[:, 3])):
        # boolean mask keeps the original row order within a snapshot
        snapshot_list.append(data[data[:, 3] == ts, :3])
    return snapshot_list