Copy disabled (too large)
Download .txt
Showing preview only (74,471K chars total). Download the full file to get everything.
Repository: cmavro/GNN-RAG
Branch: main
Commit: 28d31bc0db3d
Files: 118
Total size: 106.0 MB
Directory structure:
gitextract_o4kycd_o/
├── .gitignore
├── README.md
├── gnn/
│ ├── .gitignore
│ ├── README.md
│ ├── dataset_load.py
│ ├── dataset_load_graft.py
│ ├── evaluate.py
│ ├── main.py
│ ├── models/
│ │ ├── GraftNet/
│ │ │ └── graftnet.py
│ │ ├── NSM/
│ │ │ └── nsm.py
│ │ ├── ReaRev/
│ │ │ └── rearev.py
│ │ └── base_model.py
│ ├── modules/
│ │ ├── kg_reasoning/
│ │ │ ├── base_gnn.py
│ │ │ ├── graft_gnn.py
│ │ │ ├── nsm_gnn.py
│ │ │ └── reasongnn.py
│ │ ├── layer_init.py
│ │ ├── query_update.py
│ │ └── question_encoding/
│ │ ├── base_encoder.py
│ │ ├── bert_encoder.py
│ │ ├── lstm_encoder.py
│ │ └── tokenizers.py
│ ├── parsing.py
│ ├── requirements.txt
│ ├── scripts/
│ │ └── rearev_cwq.sh
│ ├── train_model.py
│ └── utils.py
└── llm/
├── .gitignore
├── README.md
├── prompts/
│ ├── alpaca.txt
│ ├── general_prompt.txt
│ ├── llama2.txt
│ └── llama2_predict.txt
├── requirements.txt
├── results/
│ ├── KGQA-GNN-RAG/
│ │ ├── rearev-lmsr/
│ │ │ ├── RoG-cwq/
│ │ │ │ └── RoG/
│ │ │ │ └── test/
│ │ │ │ └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
│ │ │ │ └── False/
│ │ │ │ ├── args.txt
│ │ │ │ ├── detailed_eval_result.jsonl
│ │ │ │ ├── eval_result.txt
│ │ │ │ └── predictions.jsonl
│ │ │ └── RoG-webqsp/
│ │ │ ├── RoG/
│ │ │ │ └── test/
│ │ │ │ └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
│ │ │ │ └── False/
│ │ │ │ ├── args.txt
│ │ │ │ ├── detailed_eval_result.jsonl
│ │ │ │ ├── eval_result.txt
│ │ │ │ └── predictions.jsonl
│ │ │ └── llama2-chat-hf/
│ │ │ └── test/
│ │ │ └── no_rule/
│ │ │ └── False/
│ │ │ ├── args.txt
│ │ │ ├── detailed_eval_result.jsonl
│ │ │ ├── eval_result.txt
│ │ │ └── predictions.jsonl
│ │ └── rearev-sbert/
│ │ ├── RoG-cwq/
│ │ │ └── RoG/
│ │ │ └── test/
│ │ │ └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
│ │ │ └── False/
│ │ │ ├── args.txt
│ │ │ ├── detailed_eval_result.jsonl
│ │ │ ├── eval_result.txt
│ │ │ └── predictions.jsonl
│ │ └── RoG-webqsp/
│ │ └── RoG/
│ │ └── test/
│ │ └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
│ │ └── False/
│ │ ├── args.txt
│ │ ├── detailed_eval_result.jsonl
│ │ ├── eval_result.txt
│ │ └── predictions.jsonl
│ ├── KGQA-GNN-RAG-RA/
│ │ ├── rearev-lmsr/
│ │ │ ├── RoG-cwq/
│ │ │ │ └── RoG/
│ │ │ │ └── test/
│ │ │ │ └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
│ │ │ │ └── False/
│ │ │ │ ├── args.txt
│ │ │ │ ├── detailed_eval_result.jsonl
│ │ │ │ ├── eval_result.txt
│ │ │ │ └── predictions.jsonl
│ │ │ └── RoG-webqsp/
│ │ │ └── RoG/
│ │ │ └── test/
│ │ │ └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
│ │ │ └── False/
│ │ │ ├── args.txt
│ │ │ ├── detailed_eval_result.jsonl
│ │ │ ├── eval_result.txt
│ │ │ └── predictions.jsonl
│ │ └── rearev-sbert/
│ │ ├── RoG-cwq/
│ │ │ └── RoG/
│ │ │ └── test/
│ │ │ └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
│ │ │ └── False/
│ │ │ ├── args.txt
│ │ │ ├── detailed_eval_result.jsonl
│ │ │ ├── eval_result.txt
│ │ │ └── predictions.jsonl
│ │ └── RoG-webqsp/
│ │ └── RoG/
│ │ └── test/
│ │ └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
│ │ └── False/
│ │ ├── args.txt
│ │ ├── detailed_eval_result.jsonl
│ │ ├── eval_result.txt
│ │ └── predictions.jsonl
│ ├── gen_rule_path/
│ │ ├── RoG-cwq/
│ │ │ └── RoG/
│ │ │ └── test/
│ │ │ ├── predictions_2_False.jsonl
│ │ │ └── predictions_3_False.jsonl
│ │ └── RoG-webqsp/
│ │ └── RoG/
│ │ ├── test/
│ │ │ ├── predictions_1_False.jsonl
│ │ │ ├── predictions_2_False.jsonl
│ │ │ └── predictions_3_False.jsonl
│ │ ├── train/
│ │ │ ├── predictions_1_False.jsonl
│ │ │ └── predictions_3_False.jsonl
│ │ └── validation/
│ │ └── predictions_3_False.jsonl
│ └── gnn/
│ ├── RoG-cwq/
│ │ ├── rearev-lmsr/
│ │ │ └── test.info
│ │ └── rearev-sbert/
│ │ └── test.info
│ └── RoG-webqsp/
│ ├── rearev-lmsr/
│ │ └── test.info
│ └── rearev-sbert/
│ └── test.info
├── scripts/
│ ├── evaluate_multi_hop.sh
│ ├── interpretable_example.py
│ ├── planning.sh
│ ├── plug-and-play.sh
│ ├── rag-reasoning.sh
│ └── train.sh
└── src/
├── __init__.py
├── align_kg/
│ ├── __init__.py
│ ├── build_align_qa_dataset.py
│ └── data_loader.py
├── joint_training/
│ ├── generate_explanation_results.py
│ ├── joint_finetuning.py
│ ├── preprocess_align.py
│ └── preprocess_qa.py
├── llms/
│ ├── __init__.py
│ ├── language_models/
│ │ ├── __init__.py
│ │ ├── alpaca.py
│ │ ├── base_language_model.py
│ │ ├── chatgpt.py
│ │ ├── flan_t5.py
│ │ ├── llama.py
│ │ └── longchat/
│ │ ├── llama_condense_monkey_patch.py
│ │ ├── llama_flash_attn_monkey_patch.py
│ │ └── longchat.py
│ ├── llm_proxy.py
│ └── start_fastchat_api.py
├── qa_prediction/
│ ├── build_qa_input.py
│ ├── evaluate_multi_hop.py
│ ├── evaluate_results.py
│ ├── gen_rule_path.py
│ └── predict_answer.py
└── utils/
├── __init__.py
├── graph_utils.py
├── merge_peft.py
├── training_utils.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
================================================
FILE: README.md
================================================
This is the code for **GNN-RAG: Graph Neural Retrieval for Large Language Model Reasoning**.

The directory is the following:
|----`gnn` folder has the implementation of different KGQA GNNs.
You can train your own GNNs or you can skip this folder and use directly the GNN output (retrieved answer nodes) that we computed (`llm/results/gnn`).
|----`llm` folder has the implementation for RAG-based KGQA with LLMs.
Please see details on how to reproduce results there.
**Results**: We append all the results for Table 2: See `results/KGQA-GNN-RAG-RA` or `results/KGQA-GNN-RAG`. You can look at the actual LLM generations, as well as the KG information retrieved ("input" key) in predictions.jsonl.
================================================
FILE: gnn/.gitignore
================================================
#LM_KGQA specific
checkpoint/
checkpoint/pretrain/
data/
pretrained_lms/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
================================================
FILE: gnn/README.md
================================================
## Get Started
We have simple requirements in `requirements.txt`. You can always check if you can run the code immediately.
The datasets as well as the pretrained LM (LMsr) are uploaded here: https://drive.google.com/drive/folders/1ifgVHQDnvFEunP9hmVYT07Y3rvcpIfQp?usp=sharing
Please download them and extract them to the corresponding folders.
## Training
Please follow the guidelines and hyperparameters of the corresponding GNNs for training. See `scripts` on a training example.
Otherwise, you can download released GNN models from here: https://drive.google.com/file/d/1p7eLSsSKkZQxB32mT5lMsthVP6R_3x1j/view
## Evaluation
To evaluate them, copy the command from the above scripts, add the `--is_eval` argument, and `--load experiment` followed by the name of the corresponding `ckpt` model.
For example, for Webqsp run:
```
python main.py ReaRev --entity_dim 50 --num_epoch 200 --batch_size 8 --eval_every 2 --data_folder data/webqsp/ --lm sbert --num_iter 3 --num_ins 2 --num_gnn 3 --relation_word_emb True --load_experiment ReaRev_webqsp.ckpt --is_eval --name webqsp
```
The result is saved as a `.info` file. In order to use GNN-RAG, please move this file to the corresponding folder in `GNN-RAG/llm/results/gnn/` by renaming it to `test.info`.
================================================
FILE: gnn/dataset_load.py
================================================
import json
import numpy as np
import re
from tqdm import tqdm
import torch
from collections import Counter
import random
import warnings
import pickle
warnings.filterwarnings("ignore")
from modules.question_encoding.tokenizers import LSTMTokenizer#, BERTTokenizer
from transformers import AutoTokenizer
import time
import os
class BasicDataLoader(object):
    """
    Basic Dataloader contains all the functions to read questions and KGs from json files and
    create mappings between global entity ids and local ids that are used during GNN updates.
    """
    def __init__(self, config, word2id, relation2id, entity2id, tokenize, data_type="train"):
        """
        Parse the config, read the <data_type>.json question/subgraph file,
        and build all per-sample arrays.

        The call order matters: _parse_args initializes the counters
        (max_facts, max_local_entity, ...) that _load_file and _load_data
        then update.

        Args:
            config: experiment configuration dict.
            word2id / relation2id / entity2id: global vocabularies.
            tokenize: tokenizer key ('lstm', 'bert', 'sbert', ...).
            data_type: dataset split name, e.g. 'train' or 'test'.
        """
        self.tokenize = tokenize
        self._parse_args(config, word2id, relation2id, entity2id)
        self._load_file(config, data_type)
        self._load_data()
def _load_file(self, config, data_type="train"):
    """
    Loads lines (questions + KG subgraphs) from json files.

    Populates self.data (kept samples), self.num_data and self.batches,
    and grows self.max_facts to twice the largest subgraph size (the
    factor 2 leaves room for inverse edges added later).
    """
    data_file = config['data_folder'] + data_type + ".json"
    self.data_file = data_file
    print('loading data from', data_file)
    self.data_type = data_type
    self.data = []
    skip_index = set()
    index = 0
    with open(data_file) as f_in:
        for line in tqdm(f_in):
            # Cap the number of training questions at config['max_train'].
            if index == config['max_train'] and data_type == "train": break #break if we reach max_question_size
            line = json.loads(line)
            # Samples without topic entities cannot seed the GNN — skip them.
            if len(line['entities']) == 0:
                skip_index.add(index)
                continue
            self.data.append(line)
            self.max_facts = max(self.max_facts, 2 * len(line['subgraph']['tuples']))
            # NOTE(review): index does not advance on skipped lines, so
            # skip_index records positions among *kept* samples — confirm intended.
            index += 1
    print("skip", skip_index)
    print('max_facts: ', self.max_facts)
    self.num_data = len(self.data)
    self.batches = np.arange(self.num_data)
def _load_data(self):
    """
    Creates mappings between global entity ids and local entity ids that are
    used during GNN updates, then allocates every per-sample array and fills
    them via _prepare_data().
    """
    print('converting global to local entity index ...')
    self.global2local_entity_maps = self._build_global2local_entity_maps()
    if self.use_self_loop:
        # One extra (self-loop) fact per local entity.
        self.max_facts = self.max_facts + self.max_local_entity
    self.question_id = []
    # Unused slots are padded with len(entity2id) (an out-of-range entity id).
    self.candidate_entities = np.full((self.num_data, self.max_local_entity), len(self.entity2id), dtype=int)
    self.kb_adj_mats = np.empty(self.num_data, dtype=object)
    self.q_adj_mats = np.empty(self.num_data, dtype=object)
    # Unused fact slots are padded with num_kb_relation (an out-of-range relation id).
    self.kb_fact_rels = np.full((self.num_data, self.max_facts), self.num_kb_relation, dtype=int)
    self.query_entities = np.zeros((self.num_data, self.max_local_entity), dtype=float)
    self.seed_list = np.empty(self.num_data, dtype=object)
    self.seed_distribution = np.zeros((self.num_data, self.max_local_entity), dtype=float)
    self.answer_dists = np.zeros((self.num_data, self.max_local_entity), dtype=float)
    self.answer_lists = np.empty(self.num_data, dtype=object)
    self._prepare_data()
def _parse_args(self, config, word2id, relation2id, entity2id):
"""
Builds necessary dictionaries and stores arguments.
"""
self.data_eff = config['data_eff']
self.data_name = config['name']
if 'use_inverse_relation' in config:
self.use_inverse_relation = config['use_inverse_relation']
else:
self.use_inverse_relation = False
if 'use_self_loop' in config:
self.use_self_loop = config['use_self_loop']
else:
self.use_self_loop = False
self.rel_word_emb = config['relation_word_emb']
#self.num_step = config['num_step']
self.max_local_entity = 0
self.max_relevant_doc = 0
self.max_facts = 0
print('building word index ...')
self.word2id = word2id
self.id2word = {i: word for word, i in word2id.items()}
self.relation2id = relation2id
self.entity2id = entity2id
self.id2entity = {i: entity for entity, i in entity2id.items()}
self.q_type = config['q_type']
if self.use_inverse_relation:
self.num_kb_relation = 2 * len(relation2id)
else:
self.num_kb_relation = len(relation2id)
if self.use_self_loop:
self.num_kb_relation = self.num_kb_relation + 1
print("Entity: {}, Relation in KB: {}, Relation in use: {} ".format(len(entity2id),
len(self.relation2id),
self.num_kb_relation))
def get_quest(self, training=False):
q_list = []
sample_ids = self.sample_ids
for sample_id in sample_ids:
tp_str = self.decode_text(self.query_texts[sample_id, :])
# id2word = self.id2word
# for i in range(self.max_query_word):
# if self.query_texts[sample_id, i] in id2word:
# tp_str += id2word[self.query_texts[sample_id, i]] + " "
q_list.append(tp_str)
return q_list
def decode_text(self, np_array_x):
if self.tokenize == 'lstm':
id2word = self.id2word
tp_str = ""
for i in range(self.max_query_word):
if np_array_x[i] in id2word:
tp_str += id2word[np_array_x[i]] + " "
else:
tp_str = ""
words = self.tokenizer.convert_ids_to_tokens(np_array_x)
for w in words:
if w not in ['[CLS]', '[SEP]', '[PAD]']:
tp_str += w + " "
return tp_str
def _prepare_data(self):
    """
    Fill all per-sample arrays: question token ids, query-entity indicators,
    seed distributions, local KB triples, and answer distributions.

    global2local_entity_maps: a map from global entity id to local entity id
    adj_mats: a local adjacency matrix for each relation. relation 0 is reserved for self-connection.

    Fixes vs. the original:
      * removed a stray duplicate `self.entity2id[entity['text']]` lookup in
        the non-dict entity branch, which always raised for plain entity
        names and forced the raw value into the except-fallback (now
        consistent with _add_entity_to_map);
      * removed a duplicated, dead `elif self.tokenize == 't5'` branch;
      * narrowed bare excepts to the exceptions the lookups actually raise.
    """
    # Longest question (in whitespace tokens) determines the padded query length.
    max_count = 0
    for line in self.data:
        word_list = line["question"].split(' ')
        max_count = max(max_count, len(word_list))

    # Optionally pre-tokenize relation surface forms for relation-word embeddings.
    if self.rel_word_emb:
        self.build_rel_words(self.tokenize)
    else:
        self.rel_texts = None
        self.rel_texts_inv = None
        self.ent_texts = None

    self.max_query_word = max_count

    # Build the question tokenizer and the padded question-id matrix.
    if self.tokenize == 'lstm':
        self.num_word = len(self.word2id)
        self.tokenizer = LSTMTokenizer(self.word2id, self.max_query_word)
        self.query_texts = np.full((self.num_data, self.max_query_word), self.num_word, dtype=int)
    else:
        if self.tokenize == 'bert':
            tokenizer_name = 'bert-base-uncased'
        elif self.tokenize == 'roberta':
            tokenizer_name = 'roberta-base'
        elif self.tokenize == 'sbert':
            tokenizer_name = 'sentence-transformers/all-MiniLM-L6-v2'
        elif self.tokenize == 'sbert2':
            tokenizer_name = 'sentence-transformers/all-mpnet-base-v2'
        elif self.tokenize == 't5':
            tokenizer_name = 't5-small'
        elif self.tokenize == 'simcse':
            tokenizer_name = 'princeton-nlp/sup-simcse-bert-base-uncased'
        elif self.tokenize == 'relbert':
            tokenizer_name = 'pretrained_lms/sr-simbert/'
        self.max_query_word = max_count + 2  # room for [CLS]/[SEP]-style special tokens
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # The pad-token id doubles as the fill value for unused positions.
        self.num_word = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
        self.query_texts = np.full((self.num_data, self.max_query_word), self.num_word, dtype=int)

    next_id = 0
    num_query_entity = {}
    for sample in tqdm(self.data):
        self.question_id.append(sample["id"])
        g2l = self.global2local_entity_maps[next_id]
        if len(g2l) == 0:
            # NOTE(review): next_id is intentionally not advanced here,
            # mirroring the construction order in _build_global2local_entity_maps.
            continue

        # Mark the local ids of the entities mentioned in the question (seeds).
        tp_set = set()
        seed_list = []
        key_ent = 'entities_cid' if 'entities_cid' in sample else 'entities'
        for j, entity in enumerate(sample[key_ent]):
            try:
                if isinstance(entity, dict) and 'text' in entity:
                    global_entity = self.entity2id[entity['text']]
                else:
                    global_entity = self.entity2id[entity]
            except (KeyError, TypeError):
                # The entity is already a global id (or unmapped): use it as-is.
                global_entity = entity
            if global_entity not in g2l:
                continue
            local_ent = g2l[global_entity]
            self.query_entities[next_id, local_ent] = 1.0
            seed_list.append(local_ent)
            tp_set.add(local_ent)
        self.seed_list[next_id] = seed_list
        num_query_entity[next_id] = len(tp_set)

        # Candidate answers: cwq keeps every local entity, other datasets skip the seeds.
        for global_entity, local_entity in g2l.items():
            if self.data_name != 'cwq':
                if local_entity not in tp_set:  # skip entities in question
                    self.candidate_entities[next_id, local_entity] = global_entity
            elif self.data_name == 'cwq':
                self.candidate_entities[next_id, local_entity] = global_entity

        # Local KB triples (optionally with inverse edges).
        head_list = []
        rel_list = []
        tail_list = []
        for i, tpl in enumerate(sample['subgraph']['tuples']):
            sbj, rel, obj = tpl
            try:
                if isinstance(sbj, dict) and 'text' in sbj:
                    head = g2l[self.entity2id[sbj['text']]]
                    rel = self.relation2id[rel['text']]
                    tail = g2l[self.entity2id[obj['text']]]
                else:
                    head = g2l[self.entity2id[sbj]]
                    rel = self.relation2id[rel]
                    tail = g2l[self.entity2id[obj]]
            except (KeyError, TypeError):
                # Triples already carry global ids; relations may be ints or names.
                head = g2l[sbj]
                try:
                    rel = int(rel)
                except (ValueError, TypeError):
                    rel = self.relation2id[rel]
                tail = g2l[obj]
            head_list.append(head)
            rel_list.append(rel)
            tail_list.append(tail)
            self.kb_fact_rels[next_id, i] = rel
            if self.use_inverse_relation:
                head_list.append(tail)
                rel_list.append(rel + len(self.relation2id))
                tail_list.append(head)
                # NOTE(review): this overwrites slot i with the inverse relation id,
                # exactly as the original implementation did — confirm intended.
                self.kb_fact_rels[next_id, i] = rel + len(self.relation2id)

        # Uniform seed distribution over query entities (over all entities if none).
        if len(tp_set) > 0:
            for local_ent in tp_set:
                self.seed_distribution[next_id, local_ent] = 1.0 / len(tp_set)
        else:
            for index in range(len(g2l)):
                self.seed_distribution[next_id, index] = 1.0 / len(g2l)
        try:
            assert np.sum(self.seed_distribution[next_id]) > 0.0
        except AssertionError:
            print(next_id, len(tp_set))
            exit(-1)

        # Tokenize the question into the padded id matrix.
        if self.tokenize == 'lstm':
            self.query_texts[next_id] = self.tokenizer.tokenize(sample['question'])
        else:
            tokens = self.tokenizer.encode_plus(text=sample['question'], max_length=self.max_query_word,
                                                pad_to_max_length=True, return_attention_mask=False, truncation=True)
            self.query_texts[next_id] = np.array(tokens['input_ids'])

        # Answer distribution: 1.0 on every answer entity present in the subgraph.
        answer_list = []
        if 'answers_cid' in sample:
            for answer in sample['answers_cid']:
                answer_ent = answer
                answer_list.append(answer_ent)
                if answer_ent in g2l:
                    self.answer_dists[next_id, g2l[answer_ent]] = 1.0
        else:
            for answer in sample['answers']:
                keyword = 'text' if type(answer['kb_id']) == int else 'kb_id'
                answer_ent = self.entity2id[answer[keyword]]
                answer_list.append(answer_ent)
                if answer_ent in g2l:
                    self.answer_dists[next_id, g2l[answer_ent]] = 1.0
        self.answer_lists[next_id] = answer_list

        # Store adjacency only outside the memory-efficient mode
        # (otherwise it is rebuilt per batch by create_kb_adj_mats).
        if not self.data_eff:
            self.kb_adj_mats[next_id] = (np.array(head_list, dtype=int),
                                         np.array(rel_list, dtype=int),
                                         np.array(tail_list, dtype=int))
        next_id += 1

    # Report query-entity statistics.
    num_no_query_ent = 0
    num_one_query_ent = 0
    num_multiple_ent = 0
    for i in range(next_id):
        ct = num_query_entity[i]
        if ct == 1:
            num_one_query_ent += 1
        elif ct == 0:
            num_no_query_ent += 1
        else:
            num_multiple_ent += 1
    print("{} cases in total, {} cases without query entity, {} cases with single query entity,"
          " {} cases with multiple query entities".format(next_id, num_no_query_ent,
                                                          num_one_query_ent, num_multiple_ent))
def build_rel_words(self, tokenize):
    """
    Tokenizes relation surface forms.

    Builds self.rel_texts / self.rel_texts_inv, two
    (num_kb_relation + 1, max_rel_words) id matrices holding the relation
    words and the word-reversed variant (used for inverse relations).
    """
    max_rel_words = 0
    rel_words = []
    if 'metaqa' in self.data_file:
        # metaqa relations are plain snake_case tokens.
        for rel in self.relation2id:
            words = rel.split('_')
            max_rel_words = max(len(words), max_rel_words)
            rel_words.append(words)
    else:
        # Freebase-style relations: keep the words of the last two dotted fields.
        for rel in self.relation2id:
            rel = rel.strip()
            fields = rel.split('.')
            try:
                words = fields[-2].split('_') + fields[-1].split('_')
                max_rel_words = max(len(words), max_rel_words)
                rel_words.append(words)
            except:
                # Relations without two dotted fields fall back to a single UNK token.
                words = ['UNK']
                rel_words.append(words)
                pass
    self.max_rel_words = max_rel_words
    if tokenize == 'lstm':
        # LSTM vocabulary: pad / OOV positions carry len(word2id).
        self.rel_texts = np.full((self.num_kb_relation + 1, self.max_rel_words), len(self.word2id), dtype=int)
        self.rel_texts_inv = np.full((self.num_kb_relation + 1, self.max_rel_words), len(self.word2id), dtype=int)
        for rel_id,tokens in enumerate(rel_words):
            for j, word in enumerate(tokens):
                if j < self.max_rel_words:
                    if word in self.word2id:
                        self.rel_texts[rel_id, j] = self.word2id[word]
                        self.rel_texts_inv[rel_id, j] = self.word2id[word]
                    else:
                        self.rel_texts[rel_id, j] = len(self.word2id)
                        self.rel_texts_inv[rel_id, j] = len(self.word2id)
    else:
        # Pretrained-LM tokenizer, selected by name.
        if tokenize == 'bert':
            tokenizer_name = 'bert-base-uncased'
        elif tokenize == 'roberta':
            tokenizer_name = 'roberta-base'
        elif tokenize == 'sbert':
            tokenizer_name = 'sentence-transformers/all-MiniLM-L6-v2'
        elif tokenize == 'sbert2':
            tokenizer_name = 'sentence-transformers/all-mpnet-base-v2'
        elif tokenize == 'simcse':
            tokenizer_name = 'princeton-nlp/sup-simcse-bert-base-uncased'
        elif tokenize == 't5':
            tokenizer_name = 't5-small'
        elif tokenize == 'relbert':
            tokenizer_name = 'pretrained_lms/sr-simbert/'
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        pad_val = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
        self.rel_texts = np.full((self.num_kb_relation + 1, self.max_rel_words), pad_val, dtype=int)
        self.rel_texts_inv = np.full((self.num_kb_relation + 1, self.max_rel_words), pad_val, dtype=int)
        for rel_id,words in enumerate(rel_words):
            tokens = tokenizer.encode_plus(text=' '.join(words), max_length=self.max_rel_words, \
                pad_to_max_length=True, return_attention_mask = False, truncation=True)
            # Inverse relation text: the same words in reverse order.
            tokens_inv = tokenizer.encode_plus(text=' '.join(words[::-1]), max_length=self.max_rel_words, \
                pad_to_max_length=True, return_attention_mask = False, truncation=True)
            self.rel_texts[rel_id] = np.array(tokens['input_ids'])
            self.rel_texts_inv[rel_id] = np.array(tokens_inv['input_ids'])
    # Sanity check: one word list was produced per relation.
    assert len(rel_words) == len(self.relation2id)
def create_kb_adj_mats(self, sample_id):
    """
    Re-build local adj mats if we have data_eff == True (they are not pre-stored).

    Returns (head, rel, tail) index arrays in the sample's local entity space,
    mirroring the triple-parsing logic of _prepare_data.
    """
    sample = self.data[sample_id]
    g2l = self.global2local_entity_maps[sample_id]
    head_list = []
    rel_list = []
    tail_list = []
    for i, tpl in enumerate(sample['subgraph']['tuples']):
        sbj, rel, obj = tpl
        try:
            # Triples given as dicts/names: map through entity2id / relation2id.
            if isinstance(sbj, dict) and 'text' in sbj:
                head = g2l[self.entity2id[sbj['text']]]
                rel = self.relation2id[rel['text']]
                tail = g2l[self.entity2id[obj['text']]]
            else:
                head = g2l[self.entity2id[sbj]]
                rel = self.relation2id[rel]
                tail = g2l[self.entity2id[obj]]
        except:
            # Fallback: triples already carry global ids; relations may be ints or names.
            head = g2l[sbj]
            try:
                rel = int(rel)
            except:
                rel = self.relation2id[rel]
            tail = g2l[obj]
        head_list.append(head)
        rel_list.append(rel)
        tail_list.append(tail)
        if self.use_inverse_relation:
            # Inverse edge with the relation id shifted past the forward vocabulary.
            head_list.append(tail)
            rel_list.append(rel + len(self.relation2id))
            tail_list.append(head)
    return np.array(head_list, dtype=int), np.array(rel_list, dtype=int), np.array(tail_list, dtype=int)
def _build_fact_mat(self, sample_ids, fact_dropout):
"""
Creates local adj mats that contain entities, relations, and structure.
"""
batch_heads = np.array([], dtype=int)
batch_rels = np.array([], dtype=int)
batch_tails = np.array([], dtype=int)
batch_ids = np.array([], dtype=int)
#print(sample_ids)
for i, sample_id in enumerate(sample_ids):
index_bias = i * self.max_local_entity
if self.data_eff:
head_list, rel_list, tail_list = self.create_kb_adj_mats(sample_id) #kb_adj_mats[sample_id]
else:
(head_list, rel_list, tail_list) = self.kb_adj_mats[sample_id]
num_fact = len(head_list)
num_keep_fact = int(np.floor(num_fact * (1 - fact_dropout)))
mask_index = np.random.permutation(num_fact)[: num_keep_fact]
real_head_list = head_list[mask_index] + index_bias
real_tail_list = tail_list[mask_index] + index_bias
real_rel_list = rel_list[mask_index]
batch_heads = np.append(batch_heads, real_head_list)
batch_rels = np.append(batch_rels, real_rel_list)
batch_tails = np.append(batch_tails, real_tail_list)
batch_ids = np.append(batch_ids, np.full(len(mask_index), i, dtype=int))
if self.use_self_loop:
num_ent_now = len(self.global2local_entity_maps[sample_id])
ent_array = np.array(range(num_ent_now), dtype=int) + index_bias
rel_array = np.array([self.num_kb_relation - 1] * num_ent_now, dtype=int)
batch_heads = np.append(batch_heads, ent_array)
batch_tails = np.append(batch_tails, ent_array)
batch_rels = np.append(batch_rels, rel_array)
batch_ids = np.append(batch_ids, np.full(num_ent_now, i, dtype=int))
fact_ids = np.array(range(len(batch_heads)), dtype=int)
head_rels_ids = zip(batch_heads, batch_rels)
head_count = Counter(batch_heads)
# tail_count = Counter(batch_tails)
weight_list = [1.0 / head_count[head] for head in batch_heads]
head_rels_batch = list(zip(batch_heads, batch_rels))
#print(head_rels_batch)
head_rels_count = Counter(head_rels_batch)
weight_rel_list = [1.0 / head_rels_count[(h,r)] for (h,r) in head_rels_batch]
#print(head_rels_count)
# tail_count = Counter(batch_tails)
# entity2fact_index = torch.LongTensor([batch_heads, fact_ids])
# entity2fact_val = torch.FloatTensor(weight_list)
# entity2fact_mat = torch.sparse.FloatTensor(entity2fact_index, entity2fact_val, torch.Size(
# [len(sample_ids) * self.max_local_entity, len(batch_heads)]))
return batch_heads, batch_rels, batch_tails, batch_ids, fact_ids, weight_list, weight_rel_list
def reset_batches(self, is_sequential=True):
if is_sequential:
self.batches = np.arange(self.num_data)
else:
self.batches = np.random.permutation(self.num_data)
def _build_global2local_entity_maps(self):
    """Create a map from global entity id to local entity of each sample."""
    maps = [None] * self.num_data
    entity_total = 0.0
    count = 0
    for sample in tqdm(self.data):
        local_map = dict()
        # Question entities first (prefer the *_cid variant when present) ...
        ent_key = 'entities_cid' if 'entities_cid' in sample else 'entities'
        self._add_entity_to_map(self.entity2id, sample[ent_key], local_map)
        # ... then every entity appearing in the retrieved subgraph.
        self._add_entity_to_map(self.entity2id, sample['subgraph']['entities'], local_map)
        maps[count] = local_map
        entity_total += len(local_map)
        self.max_local_entity = max(self.max_local_entity, len(local_map))
        count += 1
    print('avg local entity: ', entity_total / count)
    print('max local entity: ', self.max_local_entity)
    return maps
@staticmethod
def _add_entity_to_map(entity2id, entities, g2l):
#print(entities)
#print(entity2id)
for entity_global_id in entities:
try:
if isinstance(entity_global_id, dict) and 'text' in entity_global_id:
ent = entity2id[entity_global_id['text']]
else:
ent = entity2id[entity_global_id]
if ent not in g2l:
g2l[ent] = len(g2l)
except:
if entity_global_id not in g2l:
g2l[entity_global_id] = len(g2l)
def deal_q_type(self, q_type=None):
    """Return the question inputs for the currently selected sample ids.

    Only the sequential ("seq") question representation is supported.
    """
    if q_type is None:
        q_type = self.q_type
    if q_type != "seq":
        raise NotImplementedError
    return self.query_texts[self.sample_ids]
class SingleDataLoader(BasicDataLoader):
    """
    Single Dataloader creates training/eval batches during KGQA.
    """

    def __init__(self, config, word2id, relation2id, entity2id, tokenize, data_type="train"):
        super(SingleDataLoader, self).__init__(config, word2id, relation2id, entity2id, tokenize, data_type)

    def get_batch(self, iteration, batch_size, fact_dropout, q_type=None, test=False):
        """Slice one batch of samples out of ``self.batches``.

        Returns a 7-tuple (candidates, query entities, adjacency mats,
        question input, seed distribution, true_batch_id placeholder,
        answer distributions); with ``test=True`` the raw answer lists are
        appended as an 8th element. The previous implementation built the
        test return via a fragile trailing ``,\\`` line continuation.
        """
        start = batch_size * iteration
        end = min(batch_size * (iteration + 1), self.num_data)
        sample_ids = self.batches[start: end]
        self.sample_ids = sample_ids
        # Kept for interface compatibility with multi-seed variants.
        true_batch_id = None
        seed_dist = self.seed_distribution[sample_ids]
        q_input = self.deal_q_type(q_type)
        kb_adj_mats = self._build_fact_mat(sample_ids, fact_dropout=fact_dropout)
        batch = (self.candidate_entities[sample_ids],
                 self.query_entities[sample_ids],
                 kb_adj_mats,
                 q_input,
                 seed_dist,
                 true_batch_id,
                 self.answer_dists[sample_ids])
        if test:
            return batch + (self.answer_lists[sample_ids],)
        return batch
def load_dict(filename):
    """Map each stripped line of *filename* to a sequential integer id."""
    mapping = {}
    with open(filename, encoding='utf-8') as fh:
        for raw_line in fh:
            mapping[raw_line.strip()] = len(mapping)
    return mapping
def load_dict_int(filename):
    """Build an identity map int(line) -> int(line) for every line.

    Used when the vocabulary file already stores numeric ids.
    """
    mapping = {}
    with open(filename, encoding='utf-8') as fh:
        for raw_line in fh:
            key = int(raw_line.strip())
            mapping[key] = key
    return mapping
def load_data(config, tokenize):
    """
    Creates train/val/test dataloaders (seperately).
    """
    # SR-CWQ subgraphs ship integer entity ids; everything else uses strings.
    folder = config['data_folder']
    if 'sr-cwq' in folder:
        entity2id = load_dict_int(folder + config['entity2id'])
    else:
        entity2id = load_dict(folder + config['entity2id'])
    word2id = load_dict(folder + config['word2id'])
    relation2id = load_dict(folder + config['relation2id'])
    if config["is_eval"]:
        train_data = None
    else:
        train_data = SingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="train")
    valid_data = SingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="dev")
    test_data = SingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="test")
    # Vocabulary size comes from whichever split was actually built.
    num_word = test_data.num_word if train_data is None else train_data.num_word
    return {
        "train": train_data,
        "valid": valid_data,
        "test": test_data,
        "entity2id": entity2id,
        "relation2id": relation2id,
        "word2id": word2id,
        "num_word": num_word,
        "rel_texts": test_data.rel_texts,
        "rel_texts_inv": test_data.rel_texts_inv,
        "ent_texts": None,
    }
if __name__ == "__main__":
    st = time.time()
    #args = get_config()
    # NOTE(review): `args` is never defined here (the get_config() call above
    # is commented out) and load_data() also expects a `tokenize` argument, so
    # running this module directly raises NameError. Looks like leftover debug
    # code — confirm before relying on it.
    load_data(args)
================================================
FILE: gnn/dataset_load_graft.py
================================================
import json
import numpy as np
import re
from tqdm import tqdm
import torch
from collections import Counter
import random
import warnings
import pickle
warnings.filterwarnings("ignore")
from modules.question_encoding.tokenizers import LSTMTokenizer#, BERTTokenizer
from transformers import AutoTokenizer
import time
import os
from dataset_load import BasicDataLoader
class GraftBasicDataLoader(BasicDataLoader):
    """
    Basic Dataloader contains all the functions to read questions and KGs from json files and
    create mappings between global entity ids and local ids that are used during GNN updates.
    """

    def __init__(self, config, word2id, relation2id, entity2id, tokenize, data_type):
        super(GraftBasicDataLoader, self).__init__(config, word2id, relation2id, entity2id, tokenize, data_type)

    def create_kb_adj_mats_facts(self, sample_id):
        """Build entity<->fact incidence index triples plus the per-fact
        relation-id array (padded to ``self.max_facts``) for one sample."""
        sample = self.data[sample_id]
        g2l = self.global2local_entity_maps[sample_id]
        entity2fact_e, entity2fact_f = [], []
        fact2entity_f, fact2entity_e = [], []
        # Unused fact slots keep the padding relation id ``self.num_kb_relation``.
        kb_fact_rels = np.full(self.max_facts, self.num_kb_relation, dtype=int)
        for i, tpl in enumerate(sample['subgraph']['tuples']):
            sbj, rel, obj = tpl
            try:
                # Triples may come as {'text': ...} dicts or as raw ids/names.
                if isinstance(sbj, dict) and 'text' in sbj:
                    head = g2l[self.entity2id[sbj['text']]]
                    rel = self.relation2id[rel['text']]
                    tail = g2l[self.entity2id[obj['text']]]
                else:
                    head = g2l[self.entity2id[sbj]]
                    rel = self.relation2id[rel]
                    tail = g2l[self.entity2id[obj]]
            except:
                # Fallback: subject/object already are local-map keys and the
                # relation is either a numeric id or a name.
                head = g2l[sbj]
                try:
                    rel = int(rel)
                except:
                    rel = self.relation2id[rel]
                tail = g2l[obj]
            if not self.use_inverse_relation:
                entity2fact_e += [head]
                entity2fact_f += [i]
                fact2entity_f += [i]
                fact2entity_e += [tail]
                kb_fact_rels[i] = rel
            else:
                # With inverse relations each triple fills two fact slots:
                # 2*i (forward) and 2*i+1 (reverse, relation id shifted by
                # the original relation-vocabulary size).
                entity2fact_e += [head, tail]
                entity2fact_f += [2 * i, 2 * i + 1]
                fact2entity_f += [2 * i, 2 * i + 1]
                fact2entity_e += [tail, head]
                kb_fact_rels[2 * i] = rel
                kb_fact_rels[2 * i + 1] = rel + len(self.relation2id)
        kb_adj_mats = (np.array(entity2fact_f, dtype=int),
                       np.array(entity2fact_e, dtype=int), np.array([1.0] * len(entity2fact_f))), \
                      (np.array(fact2entity_e, dtype=int), np.array(fact2entity_f, dtype=int), np.array([1.0] * len(fact2entity_e)))
        return kb_adj_mats, kb_fact_rels

    def _build_fact_mat_maxfacts(self, sample_ids, fact_dropout):
        """Create sparse matrix representation for batched data"""
        kb_fact_rels = np.full((len(sample_ids), self.max_facts), self.num_kb_relation, dtype=int)
        # Flat COO-style index/value arrays, extended one sample at a time.
        mats0_batch = np.array([], dtype=int)
        mats0_0 = np.array([], dtype=int)
        mats0_1 = np.array([], dtype=int)
        vals0 = np.array([], dtype=float)
        mats1_batch = np.array([], dtype=int)
        mats1_0 = np.array([], dtype=int)
        mats1_1 = np.array([], dtype=int)
        vals1 = np.array([], dtype=float)
        for i, sample_id in enumerate(sample_ids):
            ((mat0_0, mat0_1, val0), (mat1_0, mat1_1, val1)), kb_fact_rel = self.create_kb_adj_mats_facts(sample_id)
            kb_fact_rels[i] = kb_fact_rel
            assert len(val0) == len(val1)
            num_fact = len(val0)
            # Randomly keep a (1 - fact_dropout) fraction of the facts.
            num_keep_fact = int(np.floor(num_fact * (1 - fact_dropout)))
            mask_index = np.random.permutation(num_fact)[: num_keep_fact]
            # mat0
            mats0_batch = np.append(mats0_batch, np.full(len(mask_index), i, dtype=int))
            mats0_0 = np.append(mats0_0, mat0_0[mask_index])
            mats0_1 = np.append(mats0_1, mat0_1[mask_index])
            vals0 = np.append(vals0, val0[mask_index])
            # mat1
            mats1_batch = np.append(mats1_batch, np.full(len(mask_index), i, dtype=int))
            mats1_0 = np.append(mats1_0, mat1_0[mask_index])
            mats1_1 = np.append(mats1_1, mat1_1[mask_index])
            vals1 = np.append(vals1, val1[mask_index])
        return ((mats0_batch, mats0_0, mats0_1, vals0), (mats1_batch, mats1_0, mats1_1, vals1)), kb_fact_rels
class GraftSingleDataLoader(GraftBasicDataLoader):
    """
    Single Dataloader creates training/eval batches during KGQA.
    """

    def __init__(self, config, word2id, relation2id, entity2id, tokenize, data_type="train"):
        super(GraftSingleDataLoader, self).__init__(config, word2id, relation2id, entity2id, tokenize, data_type)

    def get_batch(self, iteration, batch_size, fact_dropout, q_type=None, test=False):
        """Slice one GraftNet batch out of ``self.batches``.

        Returns a 9-tuple (candidates, query entities, adjacency mats,
        graft adjacency mats, question input, per-fact relation ids, seed
        distribution, true_batch_id placeholder, answer distributions);
        with ``test=True`` the raw answer lists are appended as a 10th
        element. The previous implementation built the test return via a
        fragile trailing ``,\\`` line continuation.
        """
        start = batch_size * iteration
        end = min(batch_size * (iteration + 1), self.num_data)
        sample_ids = self.batches[start: end]
        self.sample_ids = sample_ids
        # Kept for interface compatibility with multi-seed variants.
        true_batch_id = None
        seed_dist = self.seed_distribution[sample_ids]
        q_input = self.deal_q_type(q_type)
        kb_adj_mats = self._build_fact_mat(sample_ids, fact_dropout=fact_dropout)
        kb_fact_rels = self.kb_fact_rels[sample_ids]
        kb_adj_mats_graft, _ = self._build_fact_mat_maxfacts(sample_ids, fact_dropout=fact_dropout)
        batch = (self.candidate_entities[sample_ids],
                 self.query_entities[sample_ids],
                 kb_adj_mats,
                 kb_adj_mats_graft,
                 q_input,
                 kb_fact_rels,
                 seed_dist,
                 true_batch_id,
                 self.answer_dists[sample_ids])
        if test:
            return batch + (self.answer_lists[sample_ids],)
        return batch
def load_dict(filename):
    """Read one token per line and assign ids in order of first appearance."""
    token2id = dict()
    with open(filename, encoding='utf-8') as handle:
        for line in handle:
            token2id[line.strip()] = len(token2id)
    return token2id
def load_dict_int(filename):
    """Identity map over the integer ids listed one per line in *filename*."""
    id_map = dict()
    with open(filename, encoding='utf-8') as handle:
        for line in handle:
            value = int(line.strip())
            id_map[value] = value
    return id_map
def load_data_graft(config, tokenize):
    """
    Creates train/val/test dataloaders (seperately).
    """
    folder = config['data_folder']
    # SR-CWQ subgraphs ship integer entity ids; everything else uses strings.
    if 'sr-cwq' in folder:
        entity2id = load_dict_int(folder + config['entity2id'])
    else:
        entity2id = load_dict(folder + config['entity2id'])
    word2id = load_dict(folder + config['word2id'])
    relation2id = load_dict(folder + config['relation2id'])
    if config["is_eval"]:
        train_data = None
    else:
        train_data = GraftSingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="train")
    valid_data = GraftSingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="dev")
    test_data = GraftSingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="test")
    # Vocabulary size comes from whichever split was actually built.
    num_word = test_data.num_word if train_data is None else train_data.num_word
    return {
        "train": train_data,
        "valid": valid_data,
        "test": test_data,
        "entity2id": entity2id,
        "relation2id": relation2id,
        "word2id": word2id,
        "num_word": num_word,
        "rel_texts": test_data.rel_texts,
        "rel_texts_inv": test_data.rel_texts_inv,
        "ent_texts": None,
    }
if __name__ == "__main__":
    st = time.time()
    #args = get_config()
    # NOTE(review): `args` is never defined (get_config() is commented out)
    # and load_data_graft() also expects a `tokenize` argument — running this
    # module directly raises NameError. Leftover debug code; confirm.
    load_data_graft(args)
================================================
FILE: gnn/evaluate.py
================================================
from tqdm import tqdm
tqdm.monitor_iterval = 0
import torch
import numpy as np
import math, os
import json
import pickle
def cal_accuracy(pred, answer_dist):
    """
    pred: batch_size
    answer_dist: batch_size, max_local_entity

    Returns (fraction of predictions that hit an answer, fraction of
    samples that have at least one answer at all).
    """
    total = len(pred)
    num_correct = sum(float(answer_dist[i, choice] != 0) for i, choice in enumerate(pred))
    num_answerable = sum(1.0 for row in answer_dist if np.sum(row) != 0)
    return num_correct / total, num_answerable / total
def f1_and_hits(answers, candidate2prob, id2entity, entity2name, eps=0.5):
    """Score one question's ranked candidates against its gold answers.

    Returns (precision, recall, f1, hits@1, em, case_id, retrieved, answer_names)
    where case_id encodes which branch produced the scores (0/1: no gold
    answers, 2: no retrieved candidates, 3: regular case).
    """
    def to_name(entity_id):
        surface = id2entity[entity_id]
        return surface if entity2name is None else entity2name[surface]

    ans = [to_name(a) for a in answers]
    ranked = sorted(candidate2prob, key=lambda pair: pair[1], reverse=True)
    best_ans = ranked[0][0] if ranked else -1
    retrieved = []
    correct = 0
    cum_prob = 0.0
    # Consume candidates highest-probability first until their cumulative
    # mass exceeds eps; everything consumed counts as "retrieved".
    for cand, prob in ranked:
        retrieved.append((to_name(cand), prob))
        cum_prob += prob
        if cand in answers:
            correct += 1
        if cum_prob > eps:
            break
    em = 1 if correct > 0 else 0
    if not answers:
        # No gold answers: empty retrieval is a perfect match, anything
        # retrieved scores zero precision.
        if not retrieved:
            return 1.0, 1.0, 1.0, 1.0, 1.0, 0, retrieved, ans
        return 0.0, 1.0, 0.0, 1.0, 1.0, 1, retrieved, ans
    hits = float(best_ans in answers)
    if not retrieved:
        return 1.0, 0.0, 0.0, hits, hits, 2, retrieved, ans
    p = correct / len(retrieved)
    r = correct / len(answers)
    f1 = 2.0 / (1.0 / p + 1.0 / r) if p != 0 and r != 0 else 0.0
    return p, r, f1, hits, em, 3, retrieved, ans
class Evaluator:
    """Runs a trained KGQA model over an eval split and reports
    precision/recall/F1/hits/EM, optionally dumping per-question details
    to a ``*_test.info`` file in the checkpoint directory."""

    def __init__(self, args, model, entity2id, relation2id, device):
        self.model = model
        self.args = args
        self.eps = args['eps']
        self.model_name = args["model_name"]
        # Inverse vocabulary: local id -> global entity string/id.
        id2entity = {idx: entity for entity, idx in entity2id.items()}
        self.id2entity = id2entity
        self.entity2name = None
        # SR-* datasets store integer ids; a pickled list maps them back to
        # readable entity names.
        if 'sr-' in args["data_folder"]:
            file = open('ent2id.pickle', 'rb')
            self.entity2name = list((pickle.load(file)).keys())
            file.close()
        id2relation = {idx: relation for relation, idx in relation2id.items()}
        num_rel_ori = len(relation2id)
        if 'use_inverse_relation' in args:
            self.use_inverse_relation = args['use_inverse_relation']
            if self.use_inverse_relation:
                # Inverse relations occupy ids shifted by the original
                # relation-vocabulary size.
                for i in range(len(id2relation)):
                    id2relation[i + num_rel_ori] = id2relation[i] + "_rev"
        if 'use_self_loop' in args:
            self.use_self_loop = args['use_self_loop']
            if self.use_self_loop:
                id2relation[len(id2relation)] = "self_loop"
        self.id2relation = id2relation
        self.file_write = None
        self.device = device

    def write_info(self, valid_data, tp_list, num_step):
        """Build one info dict per question in the current batch, recording
        the relation action chosen at each reasoning step (if provided)."""
        question_list = valid_data.get_quest()
        #num_step = steps
        obj_list = []
        if tp_list is not None:
            # attn_list = [tp[1] for tp in tp_list]
            action_list = [tp[0] for tp in tp_list]
        for i in range(len(question_list)):
            obj_list.append({})
        for j in range(num_step):
            if tp_list is None:
                actions = None
            else:
                actions = action_list[j]
                actions = actions.cpu().numpy()
            # if attn_list is not None:
            #     attention = attn_list[j].cpu().numpy()
            for i in range(len(question_list)):
                tp_obj = obj_list[i]
                q = question_list[i]
                # real_index = self.true_batch_id[i][0]
                tp_obj['question'] = q
                tp_obj[j] = {}
                # print(actions)
                if tp_list is not None:
                    action = actions[i]
                    rel_action = self.id2relation[action]
                    tp_obj[j]['rel_action'] = rel_action
                    tp_obj[j]['action'] = str(action)
                # if attn_list is not None:
                #     attention_tp = attention[i]
                #     tp_obj[j]['attention'] = attention_tp.tolist()
        return obj_list

    def evaluate(self, valid_data, test_batch_size=20, write_info=False):
        """Evaluate the model on *valid_data* and return mean (f1, hits, em)."""
        # NOTE(review): the parameter is unconditionally overridden here, so
        # detailed info files are always written — confirm this is intended.
        write_info = True
        self.model.eval()
        self.count = 0
        eps = self.eps
        id2entity = self.id2entity
        eval_loss, eval_acc, eval_max_acc = [], [], []
        f1s, hits, ems, precisions, recalls = [], [], [], [], []
        valid_data.reset_batches(is_sequential=True)
        num_epoch = math.ceil(valid_data.num_data / test_batch_size)
        if write_info and self.file_write is None:
            filename = os.path.join(self.args['checkpoint_dir'],
                                    "{}_test.info".format(self.args['experiment_name']))
            self.file_write = open(filename, "w")
        case_ct = {}
        max_local_entity = valid_data.max_local_entity
        # Candidates below this probability floor are treated as noise.
        ignore_prob = (1 - eps) / max_local_entity
        for iteration in tqdm(range(num_epoch)):
            batch = valid_data.get_batch(iteration, test_batch_size, fact_dropout=0.0, test=True)
            with torch.no_grad():
                # batch[:-1] drops the answer lists, which the model forward
                # does not take.
                loss, extras, pred_dist, tp_list = self.model(batch[:-1])
                pred = torch.max(pred_dist, dim=1)[1]
            # GraftNet test batches carry two extra fields (graft adjacency
            # mats and per-fact relation ids).
            if self.model_name == 'GraftNet':
                local_entity, query_entities, _, _, query_text, _, \
                    seed_dist, true_batch_id, answer_dist, answer_list = batch
            else:
                local_entity, query_entities, _, query_text, \
                    seed_dist, true_batch_id, answer_dist, answer_list = batch
            # self.true_batch_id = true_batch_id
            if write_info:
                obj_list = self.write_info(valid_data, tp_list, self.model.num_iter)
            # pred_sum = torch.sum(pred_dist, dim=1)
            # print(pred_sum)
            candidate_entities = torch.from_numpy(local_entity).type('torch.LongTensor')
            true_answers = torch.from_numpy(answer_dist).type('torch.FloatTensor')
            query_entities = torch.from_numpy(query_entities).type('torch.LongTensor')
            # acc, max_acc = cal_accuracy(pred, true_answers.cpu().numpy())
            eval_loss.append(loss.item())
            # eval_acc.append(acc)
            # eval_max_acc.append(max_acc)
            #pr_dist2 = pred_dist#.copy()
            #pred_dist = pr_dist2[-1]
            batch_size = pred_dist.size(0)
            batch_answers = answer_list
            batch_candidates = candidate_entities
            # Local id used for padding (one past the last real entity id).
            pad_ent_id = len(id2entity)
            #pr_dist2 = pred_dist.copy()
            #for pred_dist in pr_dist2:
            for batch_id in range(batch_size):
                answers = batch_answers[batch_id]
                candidates = batch_candidates[batch_id, :].tolist()
                probs = pred_dist[batch_id, :].tolist()
                seed_entities = query_entities[batch_id, :].tolist()
                #print(seed_entities)
                #print(candidates)
                candidate2prob = []
                for c, p, s in zip(candidates, probs, seed_entities):
                    if s == 1.0:
                        # ignore seed entities
                        #print(c, self.id2entity)
                        # print(c, p, s)
                        # if c < pad_ent_id:
                        #     tp_obj['seed'] = self.id2entity[c]
                        continue
                    if c == pad_ent_id:
                        continue
                    if p < ignore_prob:
                        continue
                    candidate2prob.append((c, p))
                precision, recall, f1, hit, em, case, retrived, ans = f1_and_hits(answers, candidate2prob, self.id2entity, self.entity2name, eps)
                if write_info:
                    tp_obj = obj_list[batch_id]
                    tp_obj['answers'] = ans
                    tp_obj['precison'] = precision
                    tp_obj['recall'] = recall
                    tp_obj['f1'] = f1
                    tp_obj['hit'] = hit
                    tp_obj['em'] = em
                    tp_obj['cand'] = retrived
                    self.file_write.write(json.dumps(tp_obj) + "\n")
                # Count how often each f1_and_hits branch fired.
                case_ct.setdefault(case, 0)
                case_ct[case] += 1
                f1s.append(f1)
                hits.append(hit)
                ems.append(em)
                precisions.append(precision)
                recalls.append(recall)
        print('evaluation.......')
        print('how many eval samples......', len(f1s))
        # print('avg_f1', np.mean(f1s))
        print('avg_em', np.mean(ems))
        print('avg_hits', np.mean(hits))
        print('avg_f1', np.mean(f1s))
        print('avg_precision', np.mean(precisions))
        print('avg_recall', np.mean(recalls))
        print(case_ct)
        if write_info:
            self.file_write.close()
            self.file_write = None
        return np.mean(f1s), np.mean(hits), np.mean(ems)
================================================
FILE: gnn/main.py
================================================
import argparse
from utils import create_logger
import torch
import numpy as np
import os
import time
#from Models.ReaRev.rearev import
from train_model import Trainer_KBQA
from parsing import add_parse_args
# Parse command-line options and seed all RNGs before building the trainer.
parser = argparse.ArgumentParser()
add_parse_args(parser)

args = parser.parse_args()
args.use_cuda = torch.cuda.is_available()

# Seed numpy and torch for reproducible runs.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Derive a default experiment name when none was supplied.
if args.experiment_name is None:  # fixed: identity check instead of `== None`
    timestamp = str(int(time.time()))
    args.experiment_name = "{}-{}-{}".format(
        args.dataset,
        args.model_name,
        timestamp,
    )
def main():
    """Entry point: train a KBQA model, or evaluate a saved checkpoint."""
    # makedirs(..., exist_ok=True) is race-free and also creates missing
    # parent directories (the previous exists-check + mkdir raised for
    # nested checkpoint paths and raced with concurrent runs).
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    logger = create_logger(args)
    trainer = Trainer_KBQA(args=vars(args), model_name=args.model_name, logger=logger)
    if not args.is_eval:
        trainer.train(0, args.num_epoch - 1)
    else:
        # Evaluation requires a checkpoint to load (the assert is stripped
        # under -O, hence the defensive branch below).
        assert args.load_experiment is not None
        if args.load_experiment is not None:
            ckpt_path = os.path.join(args.checkpoint_dir, args.load_experiment)
            print("Loading pre trained model from {}".format(ckpt_path))
        else:
            ckpt_path = None
        trainer.evaluate_single(ckpt_path)


if __name__ == '__main__':
    main()
================================================
FILE: gnn/models/GraftNet/graftnet.py
================================================
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
from models.base_model import BaseModel
from modules.kg_reasoning.graft_gnn import GraftLayer
from modules.question_encoding.lstm_encoder import LSTMInstruction#, BERTInstruction
from modules.question_encoding.bert_encoder import BERTInstruction#, BERTInstruction
from modules.layer_init import TypeLayer
from modules.query_update import AttnEncoder
# Numerical floor used to avoid division by zero / log(0).
VERY_SMALL_NUMBER = 1e-10
# Effectively -inf; presumably added to masked-out logits — confirm usage.
VERY_NEG_NUMBER = -100000000000
class GraftNet(BaseModel):
    """GraftNet KGQA model: encodes the question, initializes entity/relation
    features over the retrieved subgraph, and propagates an answer
    distribution through ``num_layer`` GraftLayer GNN steps."""

    def __init__(self, args, num_entity, num_relation, num_word):
        """
        num_relation: number of relation including self-connection
        """
        super(GraftNet, self).__init__(args, num_entity, num_relation, num_word)
        self.num_layer = args['num_layer']
        self.loss_type = args['loss_type']
        self.model_name = args['model_name'].lower()
        self.lm = args['lm']
        self.norm_rel = args['norm_rel']
        # One reasoning iteration per GNN layer (used by the evaluator).
        self.num_iter = self.num_layer
        self.layers(args)
        self.private_module_def(args, num_entity, num_relation)
        self.to(self.device)

    def layers(self, args):
        """Instantiate shared projections, dropout, type layer and losses."""
        # initialize entity embedding
        word_dim = self.word_dim
        kg_dim = self.kg_dim
        entity_dim = self.entity_dim
        self.linear_dropout = args['linear_dropout']
        self.entity_linear = nn.Linear(in_features=self.ent_dim, out_features=entity_dim)
        self.relation_linear1 = nn.Linear(in_features=self.rel_dim, out_features=entity_dim)
        # dropout
        self.linear_drop = nn.Dropout(p=self.linear_dropout)
        if self.encode_type:
            # Entity features derived from the relations they participate in.
            self.type_layer = TypeLayer(in_features=entity_dim, out_features=entity_dim,
                                        linear_drop=self.linear_drop, device=self.device, norm_rel=self.norm_rel)
        # Self-attention pooling over relation-text token embeddings.
        self.self_att_r = AttnEncoder(self.entity_dim)
        self.kld_loss = nn.KLDivLoss(reduction='none')
        self.bce_loss_logits = nn.BCEWithLogitsLoss(reduction='none')
        self.mse_loss = torch.nn.MSELoss()

    def private_module_def(self, args, num_entity, num_relation):
        """Instantiate the GNN reasoning layer and the question encoder."""
        # initialize entity embedding
        word_dim = self.word_dim
        kg_dim = self.kg_dim
        entity_dim = self.entity_dim
        self.reasoning = GraftLayer(args, num_entity, num_relation, entity_dim)
        if args['lm'] == 'lstm':
            self.instruction = LSTMInstruction(args, self.word_embedding, self.num_word)
        else:
            self.instruction = BERTInstruction(args, self.word_embedding, self.num_word, args['lm'])
        self.relation_linear = nn.Linear(in_features=self.word_dim, out_features=entity_dim)

    def get_ent_init(self, local_entity, kb_adj_mat, rel_features):
        """Initial entity embeddings: relation-type aggregation when
        ``encode_type`` is set, otherwise a projected embedding lookup."""
        if self.encode_type:
            local_entity_emb = self.type_layer(local_entity=local_entity,
                                               edge_list=kb_adj_mat,
                                               rel_features=rel_features)
        else:
            local_entity_emb = self.entity_embedding(local_entity)  # batch_size, max_local_entity, word_dim
            local_entity_emb = self.entity_linear(local_entity_emb)
        return local_entity_emb

    def get_rel_feature(self):
        """Return one feature vector per relation, from a learned embedding
        table or by pooling the relation surface-text token embeddings."""
        if self.rel_texts is None:
            rel_features = self.relation_embedding.weight
            rel_features = self.relation_linear1(rel_features)
        else:
            #rel_features = self.instruction.encode_question(self.rel_texts, store=False)
            rel_features = self.rel_features  # self.relation_linear(self.rel_features)
            #print(rel_features.size())
            #print(self.instruction.question_emb)
            rel_features = self.instruction.question_emb(rel_features)
            #rel_features = self.relation_linear(rel_features)
            rel_features = self.self_att_r(rel_features, (self.rel_texts != self.instruction.pad_val).float())
            if self.lm == 'lstm':
                # NOTE(review): for the LSTM encoder self-attention is applied
                # a second time with the LSTM pad id mask — confirm intended.
                rel_features = self.self_att_r(rel_features, (self.rel_texts != self.num_relation+1).float())
            # else:
            #     rel_features = self.self_att_r(rel_features, (self.rel_texts != self.instruction.pad_val).float())
        return rel_features

    def init_reason(self, curr_dist, local_entity, kb_adj_mat, kb_adj_mat_graft, kb_fact_rel, q_input):
        """Encode the question, build entity/relation features, and reset
        the reasoning module state before propagation."""
        # batch_size = local_entity.size(0)
        self.local_entity = local_entity
        self.instruction_list, self.attn_list = self.instruction(q_input)
        self.query_hidden_emb = self.instruction.query_hidden_emb
        self.query_node_emb = self.instruction.query_node_emb
        self.query_mask = self.instruction.query_mask
        rel_features = self.get_rel_feature()
        self.local_entity_emb = self.get_ent_init(local_entity, kb_adj_mat, rel_features)
        self.curr_dist = curr_dist
        self.dist_history = []
        self.action_probs = []
        self.seed_entities = curr_dist
        self.reasoning.init_reason(local_entity=local_entity,
                                   kb_adj_mat=kb_adj_mat,
                                   kb_adj_mat_graft=kb_adj_mat_graft,
                                   kb_fact_rel=kb_fact_rel,
                                   local_entity_emb=self.local_entity_emb,
                                   rel_features=rel_features,
                                   query_node_emb=self.query_node_emb)

    def calc_loss_label(self, curr_dist, teacher_dist, label_valid):
        """Distribution loss masked by ``label_valid`` (samples with no
        answers contribute zero), averaged over the batch."""
        tp_loss = self.get_loss(pred_dist=curr_dist, answer_dist=teacher_dist, reduction='none')
        tp_loss = tp_loss * label_valid
        cur_loss = torch.sum(tp_loss) / curr_dist.size(0)
        return cur_loss

    def forward(self, batch, training=False):
        """Run question encoding + num_layer GNN steps; return
        (loss, argmax prediction, final distribution, train metrics)."""
        local_entity, query_entities, kb_adj_mat, kb_adj_mat_graft, query_text, kb_fact_rel, seed_dist, true_batch_id, answer_dist = batch
        local_entity = torch.from_numpy(local_entity).type('torch.LongTensor').to(self.device)
        # local_entity_mask = (local_entity != self.num_entity).float()
        query_entities = torch.from_numpy(query_entities).type('torch.FloatTensor').to(self.device)
        answer_dist = torch.from_numpy(answer_dist).type('torch.FloatTensor').to(self.device)
        seed_dist = torch.from_numpy(seed_dist).type('torch.FloatTensor').to(self.device)
        current_dist = Variable(seed_dist, requires_grad=True)
        q_input = torch.from_numpy(query_text).type('torch.LongTensor').to(self.device)
        # BERT tokenizers pad with id 0; the LSTM vocab pads with num_word.
        if self.lm == 'bert':
            query_mask = (q_input != 0).float()
        else:
            query_mask = (q_input != self.num_word).float()
        #query_mask = (q_input != self.num_word).float()
        #instruction generation
        self.init_reason(curr_dist=current_dist, local_entity=local_entity,
                         kb_adj_mat=kb_adj_mat, kb_adj_mat_graft=kb_adj_mat_graft, kb_fact_rel=kb_fact_rel, q_input=q_input)
        self.instruction.init_reason(q_input)
        #reasoning
        self.curr_dist = current_dist
        self.ent_dist = current_dist
        self.dist_history.append(self.curr_dist)
        for i in range(self.num_layer):
            score_tp, score, self.curr_dist = self.reasoning(self.curr_dist, self.query_hidden_emb, self.query_mask, step=i, return_score=True)
            self.dist_history.append(score)
        pred_dist = self.dist_history[-1]
        # Only samples that actually have answers contribute to the loss.
        answer_number = torch.sum(answer_dist, dim=1, keepdim=True)
        case_valid = (answer_number > 0).float()
        # Loss is computed on score_tp from the final reasoning step.
        loss = self.calc_loss_label(curr_dist=score_tp, teacher_dist=answer_dist, label_valid=case_valid)
        # (redundant reassignment — pred_dist was already set above)
        pred_dist = self.dist_history[-1]
        pred = torch.max(pred_dist, dim=1)[1]
        # answer_mask = self.local_entity_mask
        # self.possible_cand.append(answer_mask)
        # score_tp = score_tp + (1 - answer_mask) * VERY_NEG_NUMBER
        if training:
            h1, f1 = self.get_eval_metric(pred_dist, answer_dist)
            tp_list = [h1.tolist(), f1.tolist()]
        else:
            tp_list = None
        return loss, pred, pred_dist, tp_list
================================================
FILE: gnn/models/NSM/nsm.py
================================================
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
from models.base_model import BaseModel
from modules.kg_reasoning.nsm_gnn import NSMLayer, NSMLayer_back
from modules.question_encoding.lstm_encoder import LSTMInstruction#, BERTInstruction
from modules.question_encoding.bert_encoder import BERTInstruction
from modules.layer_init import TypeLayer
from modules.query_update import AttnEncoder
# Numerical floor used to avoid division by zero / log(0).
VERY_SMALL_NUMBER = 1e-10
# Effectively -inf; presumably added to masked-out logits — confirm usage.
VERY_NEG_NUMBER = -100000000000
class NSM(BaseModel):
def __init__(self, args, num_entity, num_relation, num_word):
"""
num_relation: number of relation including self-connection
"""
super(NSM, self).__init__(args, num_entity, num_relation, num_word)
self.num_step = args['num_step']
self.num_iter = self.num_step
self.loss_type = args['loss_type']
self.model_name = args['model_name'].lower()
self.lambda_constrain = args['lambda_constrain']
self.lambda_back = args['lambda_back']
self.lm = args['lm']
self.norm_rel = args['norm_rel']
self.layers(args)
self.private_module_def(args, num_entity, num_relation)
self.to(self.device)
def layers(self, args):
# initialize entity embedding
word_dim = self.word_dim
kg_dim = self.kg_dim
entity_dim = self.entity_dim
self.linear_dropout = args['linear_dropout']
self.entity_linear = nn.Linear(in_features=self.ent_dim, out_features=entity_dim)
self.relation_linear1 = nn.Linear(in_features=self.rel_dim, out_features=entity_dim)
self.relation_linear2 = nn.Linear(in_features=self.rel_dim, out_features=entity_dim)
self.kg_lin = nn.Linear(in_features=entity_dim, out_features=entity_dim)
self.softmax_d1 = nn.Softmax(dim=1)
self.score_func = nn.Linear(in_features=2*entity_dim, out_features=1)
# dropout
self.linear_drop = nn.Dropout(p=self.linear_dropout)
if self.encode_type:
self.type_layer = TypeLayer(in_features=entity_dim, out_features=entity_dim,
linear_drop=self.linear_drop, device=self.device, norm_rel=self.norm_rel)
self.self_att_r = AttnEncoder(self.entity_dim)
self.self_att_r2 = AttnEncoder(self.entity_dim)
self.kld_loss = nn.KLDivLoss(reduction='none')
self.kld_loss_1 = nn.KLDivLoss(reduction='none')
self.bce_loss_logits = nn.BCEWithLogitsLoss(reduction='none')
self.mse_loss = torch.nn.MSELoss()
def private_module_def(self, args, num_entity, num_relation):
# initialize entity embedding
word_dim = self.word_dim
kg_dim = self.kg_dim
entity_dim = self.entity_dim
self.reasoning = NSMLayer(args, num_entity, num_relation, entity_dim)
self.reasoning2 = NSMLayer(args, num_entity, num_relation, entity_dim)
if self.lambda_back != 0.0 or self.lambda_constrain != 0.0:
self.reasoning_back = NSMLayer_back(args, num_entity, num_relation, entity_dim)
if args['lm'] == 'lstm':
self.instruction = LSTMInstruction(args, self.word_embedding, self.num_word)
else:
self.instruction = BERTInstruction(args, self.word_embedding, self.num_word, args['lm'])
self.relation_linear = nn.Linear(in_features=self.word_dim, out_features=entity_dim)
def get_ent_init(self, local_entity, kb_adj_mat, rel_features):
if self.encode_type:
local_entity_emb = self.type_layer(local_entity=local_entity,
edge_list=kb_adj_mat,
rel_features=rel_features)
else:
local_entity_emb = self.entity_embedding(local_entity) # batch_size, max_local_entity, word_dim
local_entity_emb = self.entity_linear(local_entity_emb)
return local_entity_emb
def get_rel_feature(self):
if self.rel_texts is None:
rel_features = self.relation_embedding.weight
rel_features = self.relation_linear1(rel_features)
else:
#rel_features = self.instruction.encode_question(self.rel_texts, store=False)
rel_features = self.instruction.question_emb(self.rel_features)
#rel_features = self.relation_linear(rel_features)
rel_features = self.self_att_r(rel_features, (self.rel_texts != self.instruction.pad_val).float())
if self.lm == 'lstm':
rel_features = self.self_att_r(rel_features, (self.rel_texts != self.num_relation+1).float())
# else:
# rel_features = self.self_att_r(rel_features, (self.rel_texts != self.instruction.pad_val).float())
return rel_features
def init_reason(self, curr_dist, local_entity, kb_adj_mat, q_input):
# batch_size = local_entity.size(0)
self.local_entity = local_entity
self.instruction_list, self.attn_list = self.instruction(q_input)
rel_features = self.get_rel_feature()
#print(self.rel_features1)
#self.rel_features2 = self.get_rel_feature2()
self.local_entity_emb = self.get_ent_init(local_entity, kb_adj_mat, rel_features)
#self.kge_entity_emb = self.get_ent_init2(local_entity, kb_adj_mat, self.rel_features)
self.curr_dist = curr_dist
self.dist_history = []
self.dist_history2 = []
self.backward_history = []
self.action_probs = []
self.seed_entities = curr_dist
self.reasoning.init_reason(local_entity=local_entity,
kb_adj_mat=kb_adj_mat,
local_entity_emb=self.local_entity_emb,
rel_features=rel_features)
if self.lambda_back != 0.0 or self.lambda_constrain != 0.0:
self.reasoning_back.init_reason(local_entity=local_entity,
kb_adj_mat=kb_adj_mat,
local_entity_emb=self.local_entity_emb,
rel_features=rel_features)
def get_js_div(self, dist_1, dist_2):
mean_dist = (dist_1 + dist_2) / 2
log_mean_dist = torch.log(mean_dist + 1e-8)
# loss_kl_1 = self.kld_loss_1(log_mean_dist, dist_1)
# loss_kl_2 = self.kld_loss_1(log_mean_dist, dist_2)
# print(loss_kl_1.item(), loss_kl_2.item())
loss = 0.5 * (self.kld_loss_1(log_mean_dist, dist_1) + self.kld_loss_1(log_mean_dist, dist_2))
return loss
def calc_loss_backward(self, case_valid):
back_loss = None
constrain_loss = None
for i in range(self.num_step):
forward_dist = self.dist_history[i]
backward_dist = self.backward_history[i]
if i == 0:
# back_loss = self.get_loss_new(backward_dist, forward_dist)
back_loss = self.calc_loss_label(curr_dist=backward_dist,
teacher_dist=forward_dist,
label_valid=case_valid)
# backward last step should be similar with seed distribution
else:
tp_loss = self.get_js_div(forward_dist, backward_dist)
tp_loss = torch.sum(tp_loss * case_valid) / forward_dist.size(0)
if constrain_loss is None:
constrain_loss = tp_loss
else:
constrain_loss += tp_loss
return back_loss, constrain_loss
def calc_loss_label(self, curr_dist, teacher_dist, label_valid):
tp_loss = self.get_loss(pred_dist=curr_dist, answer_dist=teacher_dist, reduction='none')
tp_loss = tp_loss * label_valid
cur_loss = torch.sum(tp_loss) / curr_dist.size(0)
return cur_loss
def forward(self, batch, training=False):
    """Run instruction generation, forward GNN reasoning and (optionally)
    NSM backward reasoning for one batch.

    `batch` is a tuple of numpy arrays:
    (local_entity, query_entities, kb_adj_mat, query_text, seed_dist,
     true_batch_id, answer_dist).
    Returns (loss, pred, pred_dist, tp_list) where `pred` is the argmax
    entity per case and `tp_list` holds [h1, f1] lists only when
    `training` is True.
    """
    local_entity, query_entities, kb_adj_mat, query_text, seed_dist, true_batch_id, answer_dist = batch
    # Move all batch arrays onto the model device with the expected dtypes.
    local_entity = torch.from_numpy(local_entity).type('torch.LongTensor').to(self.device)
    # local_entity_mask = (local_entity != self.num_entity).float()
    query_entities = torch.from_numpy(query_entities).type('torch.FloatTensor').to(self.device)
    answer_dist = torch.from_numpy(answer_dist).type('torch.FloatTensor').to(self.device)
    seed_dist = torch.from_numpy(seed_dist).type('torch.FloatTensor').to(self.device)
    current_dist = Variable(seed_dist, requires_grad=True)
    q_input= torch.from_numpy(query_text).type('torch.LongTensor').to(self.device)
    #ent_texts= torch.from_numpy(ent_texts).type('torch.LongTensor').to(self.device)
    #ent_texts= torch.from_numpy(ent_texts).type('torch.FloatTensor').to(self.device)
    # Padding id differs between the LM tokenizer and the LSTM vocabulary.
    if self.lm != 'lstm':
        pad_val = self.instruction.pad_val #tokenizer.convert_tokens_to_ids(self.instruction.tokenizer.pad_token)
        query_mask = (q_input != pad_val).float()
    else:
        query_mask = (q_input != self.num_word).float()
    """
    Instruction generations
    """
    self.init_reason(curr_dist=current_dist, local_entity=local_entity,
                     kb_adj_mat=kb_adj_mat, q_input=q_input)
    self.instruction.init_reason(q_input)
    # Generate one instruction vector per reasoning step.
    for i in range(self.num_step):
        relational_ins, attn_weight = self.instruction.get_instruction(self.instruction.relational_ins, step=i)
        self.instruction.instructions.append(relational_ins.unsqueeze(1))
        self.instruction.relational_ins = relational_ins
    """
    GNN reasoning
    """
    self.curr_dist = current_dist
    self.dist_history.append(self.curr_dist)
    self.dist_history2.append(self.curr_dist)
    # Propagate the entity distribution one GNN step per instruction.
    for i in range(self.num_step):
        self.curr_dist = self.reasoning(self.curr_dist, self.instruction_list[i], step=i)
        self.dist_history.append(self.curr_dist)
    """
    NSM backward learning (if used)
    """
    if self.lambda_back != 0.0 or self.lambda_constrain != 0.0:
        # Normalize the answer distribution to a probability vector
        # (guarding zero-answer rows) and reason backwards from it,
        # consuming the instructions in reverse order.
        answer_len = torch.sum(answer_dist, dim=1, keepdim=True)
        answer_len[answer_len == 0] = 1.0
        answer_prob = answer_dist.div(answer_len)
        self.curr_dist_back = answer_prob
        self.backward_history.append(self.curr_dist_back)
        for i in range(self.num_step):
            self.curr_dist_back = self.reasoning_back(self.curr_dist_back, self.instruction_list[self.num_step-i-1], step=i)
            self.backward_history.append(self.curr_dist_back)
    pred_dist = self.dist_history[-1]
    answer_number = torch.sum(answer_dist, dim=1, keepdim=True)
    case_valid = (answer_number > 0).float()
    # filter no answer training case
    # loss = torch.sum(tp_loss * case_valid) / pred_dist.size(0)
    loss =self.calc_loss_label(curr_dist=pred_dist, teacher_dist=answer_dist, label_valid=case_valid)
    if self.lambda_back > 0.0 or self.lambda_constrain > 0.0:
        back_loss, constrain_loss = self.calc_loss_backward(case_valid)
        loss = loss + self.lambda_back * back_loss + self.lambda_constrain * constrain_loss
    pred = torch.max(pred_dist, dim=1)[1]
    if training:
        h1, f1 = self.get_eval_metric(pred_dist, answer_dist)
        tp_list = [h1.tolist(), f1.tolist()]
    else:
        tp_list = None
    #return loss, pred, 0.5*(pred_dist+pred_dist2), tp_list
    return loss, pred, pred_dist, tp_list
================================================
FILE: gnn/models/ReaRev/rearev.py
================================================
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
from models.base_model import BaseModel
from modules.kg_reasoning.reasongnn import ReasonGNNLayer
from modules.question_encoding.lstm_encoder import LSTMInstruction
from modules.question_encoding.bert_encoder import BERTInstruction
from modules.layer_init import TypeLayer
from modules.query_update import AttnEncoder, Fusion, QueryReform
VERY_SMALL_NUMBER = 1e-10
VERY_NEG_NUMBER = -100000000000
class ReaRev(BaseModel):
    def __init__(self, args, num_entity, num_relation, num_word):
        """
        Init ReaRev model.

        Reads algorithm hyper-parameters from `args` (num_iter reasoning
        iterations, num_ins instruction vectors, num_gnn GNN layers per
        iteration), builds the encoder/GNN modules, and registers one
        QueryReform module per instruction.
        """
        super(ReaRev, self).__init__(args, num_entity, num_relation, num_word)
        self.norm_rel = args['norm_rel']
        self.layers(args)
        self.loss_type = args['loss_type']
        self.num_iter = args['num_iter']  # BFS-style reasoning iterations
        self.num_ins = args['num_ins']    # instruction vectors per iteration
        self.num_gnn = args['num_gnn']    # GNN steps per iteration
        self.alg = args['alg']
        assert self.alg == 'bfs'
        self.lm = args['lm']
        self.private_module_def(args, num_entity, num_relation)
        self.lin = nn.Linear(3 * self.entity_dim, self.entity_dim)
        self.fusion = Fusion(self.entity_dim)
        self.reforms = []
        for i in range(self.num_ins):
            self.add_module('reform' + str(i), QueryReform(self.entity_dim))
        # BUG FIX: move the model to the target device only after *all*
        # submodules have been registered. Previously .to(self.device) was
        # called before self.lin, self.fusion and the reform* modules were
        # created, leaving those parameters on the CPU.
        self.to(self.device)
def layers(self, args):
    """Build the non-GNN layers: entity/relation projections, dropout,
    the optional type-encoding layer, relation self-attention, and losses."""
    # initialize entity embedding
    word_dim = self.word_dim
    kg_dim = self.kg_dim
    entity_dim = self.entity_dim
    #self.lstm_dropout = args['lstm_dropout']
    self.linear_dropout = args['linear_dropout']
    # Project pretrained entity/relation embeddings into entity_dim space.
    self.entity_linear = nn.Linear(in_features=self.ent_dim, out_features=entity_dim)
    self.relation_linear = nn.Linear(in_features=self.rel_dim, out_features=entity_dim)
    # self.relation_linear_inv = nn.Linear(in_features=self.rel_dim, out_features=entity_dim)
    #self.relation_linear = nn.Linear(in_features=self.rel_dim, out_features=entity_dim)
    # dropout
    #self.lstm_drop = nn.Dropout(p=self.lstm_dropout)
    self.linear_drop = nn.Dropout(p=self.linear_dropout)
    if self.encode_type:
        # Entities initialized from their connected relation types instead
        # of a learned embedding table (see get_ent_init).
        self.type_layer = TypeLayer(in_features=entity_dim, out_features=entity_dim,
                                    linear_drop=self.linear_drop, device=self.device, norm_rel=self.norm_rel)
    # Self-attention pooling over relation-text token embeddings.
    self.self_att_r = AttnEncoder(self.entity_dim)
    #self.self_att_r_inv = AttnEncoder(self.entity_dim)
    self.kld_loss = nn.KLDivLoss(reduction='none')
    self.bce_loss_logits = nn.BCEWithLogitsLoss(reduction='none')
    self.mse_loss = torch.nn.MSELoss()
def get_ent_init(self, local_entity, kb_adj_mat, rel_features):
    """Build initial entity representations.

    With type encoding enabled, entities are initialized from the relation
    features of their incident edges; otherwise from the entity embedding
    table followed by a linear projection.
    """
    if self.encode_type:
        return self.type_layer(local_entity=local_entity,
                               edge_list=kb_adj_mat,
                               rel_features=rel_features)
    emb = self.entity_embedding(local_entity)  # batch_size, max_local_entity, word_dim
    return self.entity_linear(emb)
def get_rel_feature(self):
    """
    Encode relation tokens to vectors.

    Returns (rel_features, rel_features_inv) for forward and inverse
    relations. When relation texts are unavailable, falls back to the
    (projected) relation embedding tables.
    """
    if self.rel_texts is None:
        rel_features = self.relation_linear(self.relation_embedding.weight)
        rel_features_inv = self.relation_linear(self.relation_embedding_inv.weight)
    else:
        # Token-level features were pre-encoded (see encode_rel_texts);
        # pool them with self-attention over the non-pad positions.
        rel_features = self.instruction.question_emb(self.rel_features)
        rel_features_inv = self.instruction.question_emb(self.rel_features_inv)
        rel_features = self.self_att_r(rel_features, (self.rel_texts != self.instruction.pad_val).float())
        # BUG FIX: mask the inverse-relation features with the *inverse*
        # token ids; previously self.rel_texts was used here, applying the
        # forward-relation padding mask to the inverse features.
        rel_features_inv = self.self_att_r(rel_features_inv, (self.rel_texts_inv != self.instruction.pad_val).float())
        if self.lm == 'lstm':
            # NOTE(review): this re-applies attention to already-pooled
            # features, mirroring the original LSTM code path; kept as-is.
            rel_features = self.self_att_r(rel_features, (self.rel_texts != self.num_relation+1).float())
            rel_features_inv = self.self_att_r(rel_features_inv, (self.rel_texts_inv != self.num_relation+1).float())
    return rel_features, rel_features_inv
def private_module_def(self, args, num_entity, num_relation):
    """
    Building modules: LM encoder, GNN, etc.

    Instantiates the GNN reasoning layer and the question-instruction
    encoder; the encoder is LSTM-based or a pretrained LM depending on
    args['lm'].
    """
    # initialize entity embedding
    word_dim = self.word_dim
    kg_dim = self.kg_dim
    entity_dim = self.entity_dim
    self.reasoning = ReasonGNNLayer(args, num_entity, num_relation, entity_dim, self.alg)
    if args['lm'] == 'lstm':
        self.instruction = LSTMInstruction(args, self.word_embedding, self.num_word)
        # For the LSTM path, relation features are already entity_dim-sized.
        self.relation_linear = nn.Linear(in_features=entity_dim, out_features=entity_dim)
    else:
        self.instruction = BERTInstruction(args, self.word_embedding, self.num_word, args['lm'])
        #self.relation_linear = nn.Linear(in_features=self.instruction.word_dim, out_features=entity_dim)
    # self.relation_linear = nn.Linear(in_features=entity_dim, out_features=entity_dim)
    # self.relation_linear_inv = nn.Linear(in_features=entity_dim, out_features=entity_dim)
def init_reason(self, curr_dist, local_entity, kb_adj_mat, q_input, query_entities):
    """
    Initializing Reasoning

    Encodes the question into instruction vectors, builds relation and
    entity features, resets the per-batch state (distribution history,
    seed entities), and initializes the GNN reasoning module.
    """
    # batch_size = local_entity.size(0)
    self.local_entity = local_entity
    self.instruction_list, self.attn_list = self.instruction(q_input)
    rel_features, rel_features_inv = self.get_rel_feature()
    self.local_entity_emb = self.get_ent_init(local_entity, kb_adj_mat, rel_features)
    self.init_entity_emb = self.local_entity_emb
    self.curr_dist = curr_dist
    self.dist_history = []
    self.action_probs = []
    self.seed_entities = curr_dist
    self.reasoning.init_reason(
        local_entity=local_entity,
        kb_adj_mat=kb_adj_mat,
        local_entity_emb=self.local_entity_emb,
        rel_features=rel_features,
        rel_features_inv=rel_features_inv,
        query_entities=query_entities)
def calc_loss_label(self, curr_dist, teacher_dist, label_valid):
    """Batch-averaged prediction loss, counting only valid (answered) cases;
    invalid cases are zero-masked before summation."""
    elementwise = self.get_loss(pred_dist=curr_dist, answer_dist=teacher_dist, reduction='none')
    total = torch.sum(elementwise * label_valid)
    return total / curr_dist.size(0)
def forward(self, batch, training=False):
    """
    Forward function: creates instructions and performs GNN reasoning.

    `batch` is a tuple of numpy arrays:
    (local_entity, query_entities, kb_adj_mat, query_text, seed_dist,
     true_batch_id, answer_dist).
    Runs num_iter rounds, each consisting of num_gnn GNN steps followed by
    a QueryReform update of each of the num_ins instructions.
    Returns (loss, pred, pred_dist, tp_list); tp_list holds [h1, f1] lists
    only when `training` is True.
    """
    # local_entity, query_entities, kb_adj_mat, query_text, seed_dist, answer_dist = batch
    local_entity, query_entities, kb_adj_mat, query_text, seed_dist, true_batch_id, answer_dist = batch
    # Move batch arrays onto the model device with the expected dtypes.
    local_entity = torch.from_numpy(local_entity).type('torch.LongTensor').to(self.device)
    # local_entity_mask = (local_entity != self.num_entity).float()
    query_entities = torch.from_numpy(query_entities).type('torch.FloatTensor').to(self.device)
    answer_dist = torch.from_numpy(answer_dist).type('torch.FloatTensor').to(self.device)
    seed_dist = torch.from_numpy(seed_dist).type('torch.FloatTensor').to(self.device)
    current_dist = Variable(seed_dist, requires_grad=True)
    q_input= torch.from_numpy(query_text).type('torch.LongTensor').to(self.device)
    #query_text2 = torch.from_numpy(query_text2).type('torch.LongTensor').to(self.device)
    # Padding id differs between LM tokenizers and the LSTM vocabulary.
    if self.lm != 'lstm':
        pad_val = self.instruction.pad_val #tokenizer.convert_tokens_to_ids(self.instruction.tokenizer.pad_token)
        query_mask = (q_input != pad_val).float()
    else:
        query_mask = (q_input != self.num_word).float()
    """
    Instruction generations
    """
    self.init_reason(curr_dist=current_dist, local_entity=local_entity,
                     kb_adj_mat=kb_adj_mat, q_input=q_input, query_entities=query_entities)
    self.instruction.init_reason(q_input)
    for i in range(self.num_ins):
        relational_ins, attn_weight = self.instruction.get_instruction(self.instruction.relational_ins, step=i)
        self.instruction.instructions.append(relational_ins.unsqueeze(1))
        self.instruction.relational_ins = relational_ins
    #relation_ins = torch.cat(self.instruction.instructions, dim=1)
    #query_emb = None
    self.dist_history.append(self.curr_dist)
    """
    BFS + GNN reasoning
    """
    for t in range(self.num_iter):
        relation_ins = torch.cat(self.instruction.instructions, dim=1)
        # Each iteration restarts propagation from the seed distribution.
        self.curr_dist = current_dist
        for j in range(self.num_gnn):
            self.curr_dist, global_rep = self.reasoning(self.curr_dist, relation_ins, step=j)
        self.dist_history.append(self.curr_dist)
        qs = []
        """
        Instruction Updates
        """
        # Reform each instruction from the graph's global representation.
        for j in range(self.num_ins):
            reform = getattr(self, 'reform' + str(j))
            q = reform(self.instruction.instructions[j].squeeze(1), global_rep, query_entities, local_entity)
            qs.append(q.unsqueeze(1))
            self.instruction.instructions[j] = q.unsqueeze(1)
    """
    Answer Predictions
    """
    pred_dist = self.dist_history[-1]
    answer_number = torch.sum(answer_dist, dim=1, keepdim=True)
    case_valid = (answer_number > 0).float()
    # filter no answer training case
    # loss = 0
    # for pred_dist in self.dist_history:
    loss = self.calc_loss_label(curr_dist=pred_dist, teacher_dist=answer_dist, label_valid=case_valid)
    # NOTE(review): this reassignment is redundant — pred_dist was already
    # set to self.dist_history[-1] above and has not changed since.
    pred_dist = self.dist_history[-1]
    pred = torch.max(pred_dist, dim=1)[1]
    if training:
        h1, f1 = self.get_eval_metric(pred_dist, answer_dist)
        tp_list = [h1.tolist(), f1.tolist()]
    else:
        tp_list = None
    return loss, pred, pred_dist, tp_list
================================================
FILE: gnn/models/base_model.py
================================================
import torch
import numpy as np
import torch.nn as nn
import numpy as np
VERY_SMALL_NUMBER = 1e-10
class BaseModel(torch.nn.Module):
    """
    Base model functions: create embeddings, store relations, compute f1/h1 scores, etc.
    """
    def __init__(self, args, num_entity, num_relation, num_word):
        """Store configuration from `args`, resolve embedding-file paths and
        build the embedding tables (via embedding_def)."""
        super(BaseModel, self).__init__()
        self.num_relation = num_relation
        self.num_entity = num_entity
        self.num_word = num_word
        print('Num Word', self.num_word)
        self.kge_frozen = args['kge_frozen']
        self.kg_dim = args['kg_dim']
        #self._parse_args(args)
        self.entity_emb_file = args['entity_emb_file']
        self.relation_emb_file = args['relation_emb_file']
        self.relation_word_emb = args['relation_word_emb']
        self.word_emb_file = args['word_emb_file']
        self.entity_dim = args['entity_dim']
        self.lm = args['lm']
        if self.lm in ['bert']:
            # BERT hidden size is fixed; override whatever was configured.
            #self.word_dim = 768
            args['word_dim'] = 768
        self.word_dim = args['word_dim']
        self.rel_texts = None
        #self.share_module_def()
        #self.model_name = args['model_name'].lower()
        self.device = torch.device('cuda' if args['use_cuda'] else 'cpu')
        print("Entity: {}, Relation: {}, Word: {}".format(num_entity, num_relation, num_word))
        self.kld_loss = nn.KLDivLoss(reduction='none')
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')
        self.mse_loss = torch.nn.MSELoss()
        # Mirror every '*dim' arg as an attribute, and turn every
        # '*emb_file' / '*kge_file' arg into an absolute path under
        # the data folder (None stays None).
        for k, v in args.items():
            if k.endswith('dim'):
                setattr(self, k, v)
            if k.endswith('emb_file') or k.endswith('kge_file'):
                if v is None:
                    setattr(self, k, None)
                else:
                    setattr(self, k, args['data_folder'] + v)
        self.reset_time = 0
        # These two flags are only set when present in args; downstream code
        # (load_relation_file) reads them unconditionally.
        if 'use_inverse_relation' in args:
            self.use_inverse_relation = args['use_inverse_relation']
        if 'use_self_loop' in args:
            self.use_self_loop = args['use_self_loop']
        self.eps = args['eps']
        self.embedding_def()
        # embedding_def may have changed word_dim (e.g. from a loaded file);
        # propagate it back into args for later consumers.
        args['word_dim'] = self.word_dim
def embedding_def(self):
    """Create word, entity and relation embedding tables.

    Word embeddings come from a pretrained file (frozen), or are randomly
    initialized (768-dim for non-LSTM language models). Entity embeddings
    are loaded from a KGE file when configured (otherwise entities are
    type-encoded). Relation embeddings are loaded from file, created as
    learnable tables when relation texts are used, or randomly initialized.
    """
    num_entity = self.num_entity
    num_relation = self.num_relation
    num_word = self.num_word
    if self.lm != 'lstm':
        # Pretrained-LM path: fixed 768-dim table with a trailing pad row.
        self.word_dim = 768
        self.word_embedding = nn.Embedding(num_embeddings=num_word + 1, embedding_dim=self.word_dim,
                                           padding_idx=num_word)
    elif self.word_emb_file is not None:
        # LSTM path with pretrained word vectors: load, pad one zero row
        # for the padding index, and freeze.
        word_emb = np.load(self.word_emb_file)
        _ , self.word_dim = word_emb.shape
        print('Word emb dim', self.word_dim)
        self.word_embedding = nn.Embedding(num_embeddings=num_word + 1, embedding_dim=self.word_dim,
                                           padding_idx=num_word)
        self.word_embedding.weight = nn.Parameter(
            torch.from_numpy(
                np.pad(np.load(self.word_emb_file), ((0, 1), (0, 0)), 'constant')).type(
                'torch.FloatTensor'))
        self.word_embedding.weight.requires_grad = False
    else:
        #self.word_dim = 768
        self.word_embedding = nn.Embedding(num_embeddings=num_word + 1, embedding_dim=self.word_dim,
                                           padding_idx=num_word)
    if self.entity_emb_file is not None:
        # KGE entity embeddings: disable graph-type encoding and load them,
        # unless the entity count mismatches (then keep the random init).
        self.encode_type = False
        emb = np.load(self.entity_emb_file)
        ent_num , self.ent_dim = emb.shape
        # if ent_num != num_entity:
        #     print('Number of entities in KG embeddings do not match: Random Init.')
        self.entity_embedding = nn.Embedding(num_embeddings=num_entity + 1, embedding_dim=self.ent_dim,
                                             padding_idx=num_entity)
        if ent_num != num_entity:
            print('Number of entities in KG embeddings do not match: Random Init.')
        else:
            self.entity_embedding.weight = nn.Parameter(
                torch.from_numpy(np.pad(emb, ((0, 1), (0, 0)), 'constant')).type(
                    'torch.FloatTensor'))
        if self.kge_frozen:
            self.entity_embedding.weight.requires_grad = False
        else:
            self.entity_embedding.weight.requires_grad = True
    else:
        # No KGE file: entities will be type-encoded from relations.
        self.ent_dim = self.kg_dim
        self.encode_type = True
        #self.entity_embedding = nn.Embedding(num_embeddings=num_entity + 1, embedding_dim=self.ent_dim,
        #padding_idx=num_entity)
    # initialize relation embedding
    if self.relation_emb_file is not None:
        np_tensor = self.load_relation_file(self.relation_emb_file)
        #print('check?', np_tensor.shape)
        rel_num, self.rel_dim = np_tensor.shape
        self.relation_embedding = nn.Embedding(num_embeddings=num_relation+1, embedding_dim=self.rel_dim)
        if rel_num != num_relation:
            print('Number of relations in KG embeddings do not match: Random Init.')
        else:
            self.relation_embedding.weight = nn.Parameter(torch.from_numpy(np_tensor).type('torch.FloatTensor'))
            if self.kge_frozen:
                self.relation_embedding.weight.requires_grad = False
            else:
                self.relation_embedding.weight.requires_grad = True
    elif self.relation_word_emb:
        # Relation texts are used: learnable fallback tables sized to
        # entity_dim for the no-text code path.
        self.rel_dim = self.entity_dim
        self.relation_embedding = nn.Embedding(num_embeddings=num_relation+1, embedding_dim=self.rel_dim)
        self.relation_embedding.weight.requires_grad = True
        self.relation_embedding_inv = nn.Embedding(num_embeddings=num_relation+1, embedding_dim=self.rel_dim)
        self.relation_embedding_inv.weight.requires_grad = True
        pass
    else:
        self.rel_dim = 2*self.kg_dim
        self.relation_embedding = nn.Embedding(num_embeddings=num_relation+1, embedding_dim=self.rel_dim)
        self.relation_embedding_inv = nn.Embedding(num_embeddings=num_relation+1, embedding_dim=self.rel_dim)
    # initialize text embeddings
def load_relation_file(self, filename):
    """Load relation embeddings from a .npy file.

    Duplicates the rows when inverse relations are enabled and appends two
    zero rows for the self-loop relations when those are enabled.
    """
    base = np.load(filename)
    pad_rows = 2 if self.use_self_loop else 0
    if self.use_inverse_relation:
        stacked = np.concatenate([base, base])
    else:
        stacked = base
    return np.pad(stacked, ((0, pad_rows), (0, 0)), 'constant')
def use_rel_texts(self, rel_texts, rel_texts_inv):
    """Cache tokenized relation texts (forward and inverse) as long tensors
    on the model device, without encoding them."""
    self.rel_texts = torch.from_numpy(rel_texts).long().to(self.device)
    self.rel_texts_inv = torch.from_numpy(rel_texts_inv).long().to(self.device)
def encode_rel_texts(self, rel_texts, rel_texts_inv):
    """Pre-encode relation texts with the question encoder.

    Stores the token ids and their encoded features; encoding runs in eval
    mode under no_grad, and the resulting features are kept fixed.
    """
    self.rel_texts = torch.from_numpy(rel_texts).type('torch.LongTensor').to(self.device)
    self.rel_texts_inv = torch.from_numpy(rel_texts_inv).type('torch.LongTensor').to(self.device)
    self.instruction.eval()
    with torch.no_grad():
        self.rel_features = self.instruction.encode_question(self.rel_texts, store=False)
        self.rel_features_inv = self.instruction.encode_question(self.rel_texts_inv, store=False)
    self.rel_features.requires_grad = False
    self.rel_features_inv.requires_grad = False
def init_hidden(self, num_layer, batch_size, hidden_size):
    # Delegate hidden-state initialization to the instruction (question encoder) module.
    return self.instruction.init_hidden(num_layer, batch_size, hidden_size)
def encode_question(self, q_input):
    # Delegate question encoding to the instruction module.
    return self.instruction.encode_question(q_input)
def get_instruction(self, query_hidden_emb, query_mask, states):
    # Delegate instruction-vector computation to the instruction module.
    return self.instruction.get_instruction(query_hidden_emb, query_mask, states)
def get_loss_bce(self, pred_dist_score, answer_dist):
    """Binary cross-entropy (with logits) against a 0/1 answer mask,
    label-smoothed to 0.9. Returns elementwise losses."""
    answer_dist = (answer_dist > 0).float() * 0.9  # label smooth
    # ROBUSTNESS FIX: BaseModel.__init__ registers the loss as `bce_loss`,
    # while subclasses such as ReaRev register `bce_loss_logits`. Use
    # whichever is present so this method also works on models that only
    # ran the base constructor (both are BCEWithLogitsLoss, reduction='none').
    loss_fn = getattr(self, 'bce_loss_logits', None)
    if loss_fn is None:
        loss_fn = self.bce_loss
    loss = loss_fn(pred_dist_score, answer_dist)
    return loss
def get_loss_kl(self, pred_dist, answer_dist):
    """Elementwise KL divergence from the (row-normalized) answer
    distribution to the predicted distribution."""
    answer_mass = torch.sum(answer_dist, dim=1, keepdim=True)
    answer_mass[answer_mass == 0] = 1.0  # avoid dividing answer-less rows by zero
    target = answer_dist.div(answer_mass)
    return self.kld_loss(torch.log(pred_dist + 1e-8), target)
def get_loss(self, pred_dist, answer_dist, reduction='mean'):
    """Dispatch to the configured loss type ('bce' or KL).

    reduction='none' returns elementwise losses; otherwise BCE is averaged
    over all elements and KL is summed and divided by the batch size.
    """
    if self.loss_type == "bce":
        raw = self.get_loss_bce(pred_dist, answer_dist)
        if reduction == 'none':
            return raw
        # mean over all elements
        return torch.mean(raw)
    raw = self.get_loss_kl(pred_dist, answer_dist)
    if reduction == 'none':
        return raw
    # batch-mean
    return torch.sum(raw) / pred_dist.size(0)
def f1_and_hits(self, answers, candidate2prob, eps=0.5):
    """Compute (precision, recall, f1, hits@1) for ranked candidates.

    Candidates are scanned in descending probability order until the
    cumulative probability mass exceeds `eps`; hits@1 checks whether the
    single top-ranked candidate is a gold answer.
    """
    ranked = sorted(candidate2prob, key=lambda pair: pair[1], reverse=True)
    best_ans = ranked[0][0] if ranked else -1
    retrieved = []
    correct = 0
    cumulative = 0.0
    for cand, prob in ranked:
        retrieved.append((cand, prob))
        cumulative += prob
        if cand in answers:
            correct += 1
        if cumulative > eps:
            break
    if not answers:
        # No gold answers: recall/hits are trivially 1; precision is 1 only
        # when nothing was retrieved.
        if not retrieved:
            return 1.0, 1.0, 1.0, 1.0
        return 0.0, 1.0, 0.0, 1.0
    hits = float(best_ans in answers)
    if not retrieved:
        return 1.0, 0.0, 0.0, hits
    precision = correct / len(retrieved)
    recall = correct / len(answers)
    if precision != 0 and recall != 0:
        f1 = 2.0 / (1.0 / precision + 1.0 / recall)
    else:
        f1 = 0.0
    return precision, recall, f1, hits
def calc_f1_new(self, curr_dist, dist_ans, h1_vec):
    """Per-case F1 over the local entity candidates.

    Seed entities and pad entities are excluded; candidates with probability
    below (1 - eps) / max_local_entity are ignored. Cases with h1 == 0 are
    assigned F1 = 0 without the full computation (speed shortcut).
    Returns a FloatTensor of F1 scores on the model device.
    """
    batch_size = curr_dist.size(0)
    max_local_entity = curr_dist.size(1)
    seed_dist = self.seed_entities #self.dist_history[0]
    local_entity = self.local_entity
    ignore_prob = (1 - self.eps) / max_local_entity
    pad_ent_id = self.num_entity
    # hits_list = []
    f1_list = []
    for batch_id in range(batch_size):
        if h1_vec[batch_id].item() == 0.0:
            f1_list.append(0.0)
            # we consider cases which own hit@1 as prior to reduce computation time
            continue
        candidates = local_entity[batch_id, :].tolist()
        probs = curr_dist[batch_id, :].tolist()
        answer_prob = dist_ans[batch_id, :].tolist()
        seed_entities = seed_dist[batch_id, :].tolist()
        answer_list = []
        candidate2prob = []
        for c, p, p_a, s in zip(candidates, probs, answer_prob, seed_entities):
            if s > 0:
                # ignore seed entities
                continue
            if c == pad_ent_id:
                # ignore padding slots
                continue
            if p_a > 0:
                answer_list.append(c)
            if p < ignore_prob:
                # too unlikely to count as retrieved
                continue
            candidate2prob.append((c, p))
        precision, recall, f1, hits = self.f1_and_hits(answer_list, candidate2prob, self.eps)
        # hits_list.append(hits)
        f1_list.append(f1)
    # hits_vec = torch.FloatTensor(hits_list).to(self.device)
    f1_vec = torch.FloatTensor(f1_list).to(self.device)
    return f1_vec
def calc_h1(self, curr_dist, dist_ans, eps=0.01):
    """Hits@1 per batch row: 1.0 when the argmax entity of `curr_dist`
    carries answer mass above `eps`, else 0.0."""
    top1_idx = curr_dist.argmax(dim=-1, keepdim=True)
    top1_onehot = torch.zeros_like(curr_dist).scatter_(1, top1_idx, 1.0)
    is_answer = (dist_ans > eps).float()
    overlap = torch.sum(top1_onehot * is_answer, dim=-1)
    return (overlap > 0).float()
def get_eval_metric(self, pred_dist, answer_dist):
    """Return (hits@1, f1) score vectors for a batch, computed without gradients."""
    with torch.no_grad():
        h1 = self.calc_h1(curr_dist=pred_dist, dist_ans=answer_dist, eps=VERY_SMALL_NUMBER)
        # F1 is only fully computed where hits@1 succeeded (see calc_f1_new).
        f1 = self.calc_f1_new(pred_dist, answer_dist, h1)
    return h1, f1
================================================
FILE: gnn/modules/kg_reasoning/base_gnn.py
================================================
import torch
import numpy as np
from collections import defaultdict
VERY_NEG_NUMBER = -100000000000
class BaseGNNLayer(torch.nn.Module):
    """
    Builds sparse tensors that represent structure.
    """
    def __init__(self, args, num_entity, num_relation):
        """Store graph sizes, resolve the compute device and the
        normalized-GNN flag (use provided edge weights vs. all-ones)."""
        super(BaseGNNLayer, self).__init__()
        self.num_relation = num_relation
        self.num_entity = num_entity
        self.device = torch.device('cuda' if args['use_cuda'] else 'cpu')
        self.normalized_gnn = args['normalized_gnn']
def build_matrix(self):
    """Build the sparse adjacency matrices used by GNN message passing.

    From the batch edge list (head, relation, tail, batch id, fact id,
    weight), constructs sparse maps between facts and their head/tail
    entities, a head-to-tail map, and fact-to-relation maps (relations are
    offset by batch id so each batch entry has its own relation slots).
    """
    batch_heads, batch_rels, batch_tails, batch_ids, fact_ids, weight_list, _ = self.edge_list
    num_fact = len(fact_ids)
    num_relation = self.num_relation
    batch_size = self.batch_size
    max_local_entity = self.max_local_entity
    self.num_fact = num_fact
    # Index pairs for each sparse map (row indices, column indices).
    fact2head = torch.LongTensor([batch_heads, fact_ids]).to(self.device)
    fact2tail = torch.LongTensor([batch_tails, fact_ids]).to(self.device)
    head2fact = torch.LongTensor([fact_ids, batch_heads]).to(self.device)
    tail2fact = torch.LongTensor([fact_ids, batch_tails]).to(self.device)
    head2tail = torch.LongTensor([batch_heads, batch_tails]).to(self.device)
    rel2fact = torch.LongTensor([fact_ids, batch_rels + batch_ids * num_relation]).to(self.device)
    fact2rel = torch.LongTensor([batch_rels + batch_ids * num_relation, fact_ids]).to(self.device)
    self.batch_rels = torch.LongTensor(batch_rels).to(self.device)
    self.batch_ids = torch.LongTensor(batch_ids).to(self.device)
    self.batch_heads = torch.LongTensor(batch_heads).to(self.device)
    self.batch_tails = torch.LongTensor(batch_tails).to(self.device)
    # self.batch_ids = batch_ids
    if self.normalized_gnn:
        # Use the precomputed per-edge weights from the batch.
        vals = torch.FloatTensor(weight_list).to(self.device)
    else:
        vals = torch.ones_like(self.batch_ids).float().to(self.device)
    #vals = torch.ones_like(self.batch_ids).float().to(self.device)
    # Sparse Matrix for reason on graph
    self.fact2head_mat = self._build_sparse_tensor(fact2head, vals, (batch_size * max_local_entity, num_fact))
    self.head2fact_mat = self._build_sparse_tensor(head2fact, vals, (num_fact, batch_size * max_local_entity))
    self.fact2tail_mat = self._build_sparse_tensor(fact2tail, vals, (batch_size * max_local_entity, num_fact))
    self.tail2fact_mat = self._build_sparse_tensor(tail2fact, vals, (num_fact, batch_size * max_local_entity))
    self.head2tail_mat = self._build_sparse_tensor(head2tail, vals, (batch_size * max_local_entity, batch_size * max_local_entity))
    self.fact2rel_mat = self._build_sparse_tensor(fact2rel, vals, (batch_size * num_relation, num_fact))
    self.rel2fact_mat = self._build_sparse_tensor(rel2fact, vals, (num_fact, batch_size * num_relation))
def _build_sparse_tensor(self, indices, values, size):
return torch.sparse.FloatTensor(indices, values, size).to(self.device)
def build_adj_facts(self):
    """Build batched 3-D sparse entity<->fact adjacency tensors (GraftNet style)
    and move the per-fact relation ids to the model device."""
    batch_size = self.batch_size
    max_local_entity = self.max_local_entity
    max_fact = self.max_fact
    (e2f_batch, e2f_f, e2f_e, e2f_val), (f2e_batch, f2e_e, f2e_f, f2e_val) = self.edge_list2
    entity2fact_index = torch.LongTensor([e2f_batch, e2f_f, e2f_e]).to(self.device)
    entity2fact_val = torch.FloatTensor(e2f_val).to(self.device)
    self.entity2fact_mat =torch.sparse.FloatTensor(entity2fact_index, entity2fact_val, \
        torch.Size([batch_size, max_fact, max_local_entity])).to(self.device)  # batch_size, max_fact, max_local_entity
    fact2entity_index = torch.LongTensor([f2e_batch, f2e_e, f2e_f]).to(self.device)
    fact2entity_val = torch.FloatTensor(f2e_val).to(self.device)
    self.fact2entity_mat = torch.sparse.FloatTensor(fact2entity_index, fact2entity_val, \
        torch.Size([batch_size, max_local_entity, max_fact])).to(self.device)  # batch_size, max_local_entity, max_fact
    self.kb_fact_rel = torch.LongTensor(self.kb_fact_rel).to(self.device)
================================================
FILE: gnn/modules/kg_reasoning/graft_gnn.py
================================================
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import time
from .base_gnn import BaseGNNLayer
VERY_SMALL_NUMBER = 1e-10
VERY_NEG_NUMBER = -100000000000
class GraftLayer(BaseGNNLayer):
    def __init__(self, args, num_entity, num_relation, entity_dim):
        """GraftNet-style GNN layer stack.

        `self.k` is the number of concatenated representations (entity,
        query, neighbor) fed into the e2e/e2q linear layers in init_layers.
        """
        super(GraftLayer, self).__init__(args, num_entity, num_relation)
        self.num_entity = num_entity
        self.num_relation = num_relation
        # FIX: the original assigned entity_dim twice in a row; once suffices.
        self.entity_dim = entity_dim
        self.num_layer = args['num_layer']
        self.pagerank_lambda = args['pagerank_lambda']  # mixing weight for propagated vs. current distribution
        self.fact_scale = args['fact_scale']            # scaling of the neighbor representation
        self.k = 3
        self.init_layers(args)
def init_layers(self, args):
    """Create the scoring head, dropout, and per-layer linear modules
    (query-to-entity, entity-to-query/entity, KB head/tail/self)."""
    entity_dim = self.entity_dim
    self.softmax_d1 = nn.Softmax(dim=1)
    self.score_func = nn.Linear(in_features=entity_dim, out_features=1)
    self.linear_dropout = args['linear_dropout']
    self.linear_drop = nn.Dropout(p=self.linear_dropout)
    for i in range(self.num_layer):
        # One set of projections per GNN layer, registered by indexed name
        # and looked up via getattr in forward().
        self.add_module('q2e_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))
        self.add_module('e2q_linear' + str(i), nn.Linear(in_features=self.k * entity_dim, out_features=entity_dim))
        self.add_module('e2e_linear' + str(i), nn.Linear(in_features=self.k * entity_dim, out_features=entity_dim))
        self.add_module('kb_head_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))
        self.add_module('kb_tail_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))
        self.add_module('kb_self_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))
def init_reason(self, local_entity, kb_adj_mat, kb_adj_mat_graft, kb_fact_rel, local_entity_emb, rel_features, query_node_emb=None):
    """Cache the batch's graph structure and embeddings, then build the
    sparse adjacency tensors (build_matrix / build_adj_facts)."""
    batch_size, max_local_entity = local_entity.size()
    # Mask out padding entity slots.
    self.local_entity_mask = (local_entity != self.num_entity).float()
    self.batch_size = batch_size
    self.max_local_entity = max_local_entity
    self.edge_list = kb_adj_mat
    self.edge_list2 = kb_adj_mat_graft
    self.rel_features = rel_features
    self.local_entity_emb = local_entity_emb
    # NOTE(review): this overwrites the constructor's num_relation with the
    # number of rows in this batch's rel_features.
    self.num_relation = self.rel_features.size(0)
    self.possible_cand = []
    self.kb_fact_rel = kb_fact_rel #torch.LongTensor(kb_fact_rel).to(self.device)
    _, self.max_fact = kb_fact_rel.shape
    self.query_node_emb = query_node_emb
    self.build_matrix()
    self.build_adj_facts()
def compute_attention(self, query_hidden_emb, query_mask):
    """Compute fact-question attention weights (GraftNet).

    Sets self.W_tilde (unnormalized, numerically stabilized exp scores per
    fact) and self.e2f_softmax (per-entity normalizer over incident facts,
    clamped away from zero).
    """
    batch_size = self.batch_size
    max_local_entity = self.max_local_entity
    rel_features = self.rel_features
    num_rels = rel_features.size(0)
    #print(self.kb_fact_rel, self.kb_fact_rel.size())
    fact_rel = torch.index_select(rel_features, dim=0, index=self.batch_rels)
    #batch_rels = rel_features[self.batch_ids, self.batch_rels]
    #print(fact_rel.size())
    local_fact_emb = rel_features[self.kb_fact_rel]#torch.sparse.mm(self.fact2rel_mat, fact_rel).view(batch_size, -1, self.entity_dim)
    # attention fact2question
    # Scaled dot-product between query tokens and fact (relation) embeddings.
    div = float(np.sqrt(self.entity_dim))
    fact2query_sim = torch.bmm(query_hidden_emb, local_fact_emb.transpose(1, 2)) / div  # batch_size, max_query_word, max_fact
    fact2query_sim = self.softmax_d1(fact2query_sim + (1 - query_mask.unsqueeze(dim=2)) * VERY_NEG_NUMBER)  # batch_size, max_query_word, max_fact
    fact2query_att = torch.sum(fact2query_sim.unsqueeze(dim=3) * query_hidden_emb.unsqueeze(dim=2), dim=1)  # batch_size, max_fact, entity_dim
    W = torch.sum(fact2query_att * local_fact_emb, dim=2) / div  # batch_size, max_fact
    # Subtract the per-row max before exp for numerical stability.
    W_max = torch.max(W, dim=1, keepdim=True)[0]  # batch_size, 1
    self.W_tilde = torch.exp(W - W_max)  # batch_size, max_fact
    e2f_softmax = torch.bmm(self.entity2fact_mat.transpose(1, 2), self.W_tilde.unsqueeze(dim=2)).squeeze(dim=2)  # batch_size, max_local_entity
    self.e2f_softmax = torch.clamp(e2f_softmax, min=VERY_SMALL_NUMBER)
    #e2f_out_dim = use_cuda(Variable(torch.sum(entity2fact_mat.to_dense(), dim=1), requires_grad=False)) # batch_size, max_local_entity
    assert not torch.isnan(self.e2f_softmax).any()
def reason_layer(self, curr_dist, kb_self_linear, kb_head_linear, kb_tail_linear):
    """One GraftNet propagation step.

    Pushes entity representations and the current entity distribution
    through facts (entity -> fact -> entity), weighting facts by the
    attention from compute_attention. Returns (neighbor_rep,
    next_curr_dist), where next_curr_dist mixes propagated and current
    distributions via pagerank_lambda.
    """
    batch_size = self.batch_size
    max_local_entity = self.max_local_entity
    # num_relation = self.num_relation
    rel_features = self.rel_features
    local_fact_emb = rel_features[self.kb_fact_rel]
    # Entity-to-fact message: fact (relation) embedding plus aggregated head entities.
    e2f_emb = F.relu(kb_self_linear(local_fact_emb) + torch.bmm(self.entity2fact_mat, kb_head_linear(self.linear_drop(self.local_entity_emb))))  # batch_size, max_fact, entity_dim
    # Attention-normalized distribution mass flowing into each fact.
    e2f_softmax_normalized = self.W_tilde.unsqueeze(dim=2) * torch.bmm(self.entity2fact_mat, (curr_dist / self.e2f_softmax).unsqueeze(dim=2))  # batch_size, max_fact, 1
    e2f_emb = e2f_emb * e2f_softmax_normalized  # batch_size, max_fact, entity_dim
    # Fact-to-entity message: entity self term plus aggregated fact messages.
    f2e_emb = F.relu(kb_self_linear(self.local_entity_emb) + torch.bmm(self.fact2entity_mat, kb_tail_linear(self.linear_drop(e2f_emb))))
    next_curr_dist = torch.bmm(self.fact2entity_mat, e2f_softmax_normalized).squeeze(dim=2)
    next_curr_dist = self.pagerank_lambda * next_curr_dist + (1 - self.pagerank_lambda) * curr_dist  # batch_size, max_local_entity
    assert not torch.isnan(f2e_emb).any()
    neighbor_rep = f2e_emb
    return neighbor_rep, next_curr_dist
def forward(self, current_dist, query_hidden_emb, query_mask, step=0, return_score=True):
    """One GraftNet layer step.

    Looks up this step's linear modules, runs fact-question attention on
    step 0, propagates entity states and the distribution via reason_layer,
    updates the layer's query embedding and entity embeddings, and scores
    entities. Returns (raw scores, softmaxed scores, distribution) when
    `return_score`, else (raw scores, distribution).
    """
    # get linear transformation functions for each layer
    q2e_linear = getattr(self, 'q2e_linear' + str(step))
    e2e_linear = getattr(self, 'e2e_linear' + str(step))
    e2q_linear = getattr(self, 'e2q_linear' + str(step))
    kb_self_linear = getattr(self, 'kb_self_linear' + str(step))
    kb_head_linear = getattr(self, 'kb_head_linear' + str(step))
    kb_tail_linear = getattr(self, 'kb_tail_linear' + str(step))
    batch_size = self.batch_size
    max_local_entity = self.max_local_entity
    if step == 0:
        # First layer: use the initial query node embedding and compute
        # the fact-question attention used by every subsequent step.
        query_node_emb = self.query_node_emb#.unsqueeze(1)
        self.compute_attention(query_hidden_emb, query_mask)
        #q2e_emb = q2e_linear(self.linear_drop(self.query_node_emb)).expand(batch_size, max_local_entity, self.entity_dim)
    else:
        query_node_emb = self.query_emb#.unsqueeze(1)
    q2e_emb = q2e_linear(self.linear_drop(query_node_emb)).expand(batch_size, max_local_entity, self.entity_dim)  # batch_size, max_local_entity, entity_dim
    next_local_entity_emb = torch.cat((self.local_entity_emb, q2e_emb), dim=2)
    # score_func = getattr(self, 'score_func' + str(step))
    score_func = self.score_func
    #relational_ins = relational_ins.squeeze(1)
    neighbor_rep, next_curr_dist = self.reason_layer(current_dist, kb_self_linear, kb_head_linear, kb_tail_linear)
    next_local_entity_emb = torch.cat((next_local_entity_emb, self.fact_scale*neighbor_rep), dim=2)
    #self.query_emb = torch.bmm(init_dist.unsqueeze(dim=1), e2q_linear(self.linear_drop(next_local_entity_emb)))
    # Update the query embedding as a distribution-weighted entity summary.
    self.query_emb = torch.bmm(next_curr_dist.unsqueeze(dim=1), e2q_linear(self.linear_drop(next_local_entity_emb)))
    self.local_entity_emb = F.relu(e2e_linear(self.linear_drop(next_local_entity_emb)))
    score_tp = score_func(self.linear_drop(self.local_entity_emb)).squeeze(dim=2)
    answer_mask = self.local_entity_mask
    self.possible_cand.append(answer_mask)
    # Push padding entities to a very negative score before softmax.
    score = score_tp + (1 - answer_mask) * VERY_NEG_NUMBER
    score = self.softmax_d1(score) #F.sigmoid(score) #* self.local_entity_mask #* answer_mask #+ (1 - answer_mask) * VERY_NEG_NUMBER
    #current_dist = self.softmax_d1(score_tp)
    current_dist = next_curr_dist
    if return_score:
        return score_tp, score, current_dist
    return score_tp, current_dist
================================================
FILE: gnn/modules/kg_reasoning/nsm_gnn.py
================================================
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import time
from .base_gnn import BaseGNNLayer
VERY_SMALL_NUMBER = 1e-10
VERY_NEG_NUMBER = -100000000000
class NSMBaseLayer(BaseGNNLayer):
    """Base NSM (Neural State Machine) reasoning layer.

    Holds the per-step linear modules and the shared scoring head; subclasses
    implement `reason_layer` for a specific propagation direction.
    """

    def __init__(self, args, num_entity, num_relation, entity_dim):
        super(NSMBaseLayer, self).__init__(args, num_entity, num_relation)
        self.num_entity = num_entity
        self.num_relation = num_relation
        self.entity_dim = entity_dim
        self.num_steps = args['num_step']
        # When set, restrict answer candidates to entities reachable this step.
        self.reason_kb = args['reason_kb']
        self.init_layers(args)

    def init_layers(self, args):
        """Create the shared scoring layers and one (rel, e2e) linear pair per step."""
        entity_dim = self.entity_dim
        self.softmax_d1 = nn.Softmax(dim=1)
        self.score_func = nn.Linear(in_features=entity_dim, out_features=1)
        self.lin = nn.Linear(in_features=2*entity_dim, out_features=entity_dim)
        self.linear_dropout = args['linear_dropout']
        self.linear_drop = nn.Dropout(p=self.linear_dropout)
        for i in range(self.num_steps):
            self.add_module('rel_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))
            self.add_module('e2e_linear' + str(i), nn.Linear(in_features=entity_dim + entity_dim, out_features=entity_dim))

    def init_reason(self, local_entity, kb_adj_mat, local_entity_emb, rel_features, query_node_emb=None):
        """Cache per-batch subgraph state and build the sparse fact matrices."""
        batch_size, max_local_entity = local_entity.size()
        # Entity id == num_entity marks a padding slot.
        self.local_entity_mask = (local_entity != self.num_entity).float()
        self.batch_size = batch_size
        self.max_local_entity = max_local_entity
        self.edge_list = kb_adj_mat
        self.rel_features = rel_features
        self.local_entity_emb = local_entity_emb
        self.num_relation = self.rel_features.size(0)
        self.possible_cand = []
        self.build_matrix()

    def reason_layer(self, curr_dist, instruction, rel_linear):
        # Direction-specific propagation; implemented by subclasses.
        pass

    def forward(self, current_dist, relational_ins, step=0, return_score=False):
        """One reasoning step: propagate `current_dist` over the facts, update
        entity embeddings, and score entities to form the next distribution.

        :return: (score_tp, current_dist) if return_score else current_dist.
        """
        rel_linear = getattr(self, 'rel_linear' + str(step))
        e2e_linear = getattr(self, 'e2e_linear' + str(step))
        # score_func = getattr(self, 'score_func' + str(step))
        score_func = self.score_func
        relational_ins = relational_ins.squeeze(1)
        neighbor_rep, possible_tail = self.reason_layer(current_dist, relational_ins, rel_linear)
        next_local_entity_emb = torch.cat((self.local_entity_emb, neighbor_rep), dim=2)
        self.local_entity_emb = e2e_linear(self.linear_drop(next_local_entity_emb))
        self.local_entity_emb = F.relu(self.local_entity_emb)
        score_tp = score_func(self.linear_drop(self.local_entity_emb)).squeeze(dim=2)
        if self.reason_kb:
            # Only entities reachable via a fact this step may score as answers.
            answer_mask = self.local_entity_mask * possible_tail
        else:
            answer_mask = self.local_entity_mask
        self.possible_cand.append(answer_mask)
        score_tp = score_tp + (1 - answer_mask) * VERY_NEG_NUMBER
        current_dist = self.softmax_d1(score_tp)
        if return_score:
            return score_tp, current_dist
        return current_dist
class NSMLayer(NSMBaseLayer):
    """Forward-direction NSM reasoning layer (head -> tail propagation)."""

    def __init__(self, args, num_entity, num_relation, entity_dim):
        super(NSMLayer, self).__init__(args, num_entity, num_relation, entity_dim)

    def reason_layer(self, curr_dist, instruction, rel_linear):
        """Aggregate fact messages from head entities onto tail entities.

        Returns a (batch, max_local_entity, entity_dim) neighbor representation
        and a 0/1 mask of entities reachable as tails at this step.
        """
        bsz, n_ent = self.batch_size, self.max_local_entity
        # Per-fact relation embedding and the matching per-batch instruction.
        per_fact_rel = torch.index_select(self.rel_features, dim=0, index=self.batch_rels)
        per_fact_ins = torch.index_select(instruction, dim=0, index=self.batch_ids)
        msg = F.relu(rel_linear(per_fact_rel) * per_fact_ins)
        # Probability mass flowing into each fact from its head entity.
        head_mass = torch.sparse.mm(self.head2fact_mat, curr_dist.view(-1, 1))
        reachable = torch.sparse.mm(self.fact2tail_mat, head_mass)
        reachable = (reachable > VERY_SMALL_NUMBER).float().view(bsz, n_ent)
        # Scale every fact message by its incoming mass, then scatter to tails.
        agg = torch.sparse.mm(self.fact2tail_mat, msg * head_mass)
        assert not torch.isnan(agg).any()
        return agg.view(bsz, n_ent, self.entity_dim), reachable
class NSMLayer_back(NSMBaseLayer):
    """Backward-direction NSM reasoning layer (tail -> head propagation)."""

    def __init__(self, args, num_entity, num_relation, entity_dim):
        super(NSMLayer_back, self).__init__(args, num_entity, num_relation, entity_dim)

    def reason_layer(self, curr_dist, instruction, rel_linear):
        """Aggregate inverse-fact messages from tail entities onto head entities.

        Returns a (batch, max_local_entity, entity_dim) neighbor representation
        and a 0/1 mask of entities reachable as heads at this step.
        """
        bsz, n_ent = self.batch_size, self.max_local_entity
        # Inverse-relation embeddings and the matching per-batch instruction.
        per_fact_rel = torch.index_select(self.rel_features_inv, dim=0, index=self.batch_rels)
        per_fact_ins = torch.index_select(instruction, dim=0, index=self.batch_ids)
        msg = F.relu(rel_linear(per_fact_rel) * per_fact_ins)
        # Probability mass flowing into each fact from its tail entity.
        tail_mass = torch.sparse.mm(self.tail2fact_mat, curr_dist.view(-1, 1))
        reachable = torch.sparse.mm(self.fact2head_mat, tail_mass)
        reachable = (reachable > VERY_SMALL_NUMBER).float().view(bsz, n_ent)
        # Scale every fact message by its incoming mass, then scatter to heads.
        agg = torch.sparse.mm(self.fact2head_mat, msg * tail_mass)
        assert not torch.isnan(agg).any()
        return agg.view(bsz, n_ent, self.entity_dim), reachable
================================================
FILE: gnn/modules/kg_reasoning/reasongnn.py
================================================
import torch
import torch.nn.functional as F
import torch.nn as nn
from .base_gnn import BaseGNNLayer
VERY_NEG_NUMBER = -100000000000
class ReasonGNNLayer(BaseGNNLayer):
    """
    GNN Reasoning: propagates a probability distribution over local entities
    through KG facts (and their inverses), guided by instruction vectors.
    """
    def __init__(self, args, num_entity, num_relation, entity_dim, alg):
        super(ReasonGNNLayer, self).__init__(args, num_entity, num_relation)
        self.num_entity = num_entity
        self.num_relation = num_relation
        self.entity_dim = entity_dim
        self.alg = alg  # reasoning algorithm; only 'bfs' is supported (asserted below)
        self.num_ins = args['num_ins']  # number of instruction vectors per step
        self.num_gnn = args['num_gnn']  # number of GNN reasoning steps
        self.use_posemb = args['pos_emb']  # add learned per-relation embeddings
        self.init_layers(args)

    def init_layers(self, args):
        """Create shared scoring layers plus per-step linear/embedding modules."""
        entity_dim = self.entity_dim
        self.softmax_d1 = nn.Softmax(dim=1)
        self.score_func = nn.Linear(in_features=entity_dim, out_features=1)
        self.glob_lin = nn.Linear(in_features=entity_dim, out_features=entity_dim)
        self.lin = nn.Linear(in_features=2*entity_dim, out_features=entity_dim)
        assert self.alg == 'bfs'
        self.linear_dropout = args['linear_dropout']
        self.linear_drop = nn.Dropout(p=self.linear_dropout)
        for i in range(self.num_gnn):
            self.add_module('rel_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))
            if self.alg == 'bfs':
                # Input width: current emb + one forward and one inverse neighbor
                # representation per instruction (see forward()).
                self.add_module('e2e_linear' + str(i), nn.Linear(in_features=2*(self.num_ins)*entity_dim + entity_dim, out_features=entity_dim))
            if self.use_posemb:
                self.add_module('pos_emb' + str(i), nn.Embedding(self.num_relation, entity_dim))
                self.add_module('pos_emb_inv' + str(i), nn.Embedding(self.num_relation, entity_dim))
        self.lin_m = nn.Linear(in_features=(self.num_ins)*entity_dim, out_features=entity_dim)

    def init_reason(self, local_entity, kb_adj_mat, local_entity_emb, rel_features, rel_features_inv, query_entities, query_node_emb=None):
        """Cache per-batch subgraph state and build the sparse fact matrices."""
        batch_size, max_local_entity = local_entity.size()
        # Entity id == num_entity marks a padding slot.
        self.local_entity_mask = (local_entity != self.num_entity).float()
        self.batch_size = batch_size
        self.max_local_entity = max_local_entity
        self.edge_list = kb_adj_mat
        self.rel_features = rel_features
        self.rel_features_inv = rel_features_inv
        self.local_entity_emb = local_entity_emb
        self.num_relation = self.rel_features.size(0)
        self.possible_cand = []
        self.build_matrix()
        self.query_entities = query_entities

    def reason_layer(self, curr_dist, instruction, rel_linear, pos_emb):
        """
        Aggregates neighbor representations along forward facts (head -> tail).
        Returns (batch_size, max_local_entity, entity_dim).
        """
        batch_size = self.batch_size
        max_local_entity = self.max_local_entity
        # num_relation = self.num_relation
        rel_features = self.rel_features
        fact_rel = torch.index_select(rel_features, dim=0, index=self.batch_rels)
        fact_query = torch.index_select(instruction, dim=0, index=self.batch_ids)
        if pos_emb is not None:
            # Add the learned per-relation embedding before gating by the instruction.
            pe = pos_emb(self.batch_rels)
            # fact_rel = torch.cat([fact_rel, pe], 1)
            fact_val = F.relu((rel_linear(fact_rel)+pe) * fact_query)
        else :
            fact_val = F.relu(rel_linear(fact_rel) * fact_query)
        # Scale each fact message by the probability mass of its head entity.
        fact_prior = torch.sparse.mm(self.head2fact_mat, curr_dist.view(-1, 1))
        fact_val = fact_val * fact_prior
        f2e_emb = torch.sparse.mm(self.fact2tail_mat, fact_val)
        assert not torch.isnan(f2e_emb).any()
        neighbor_rep = f2e_emb.view(batch_size, max_local_entity, self.entity_dim)
        return neighbor_rep

    def reason_layer_inv(self, curr_dist, instruction, rel_linear, pos_emb_inv):
        """Same as reason_layer but along inverse facts (tail -> head)."""
        batch_size = self.batch_size
        max_local_entity = self.max_local_entity
        # num_relation = self.num_relation
        rel_features = self.rel_features_inv
        fact_rel = torch.index_select(rel_features, dim=0, index=self.batch_rels)
        fact_query = torch.index_select(instruction, dim=0, index=self.batch_ids)
        if pos_emb_inv is not None:
            pe = pos_emb_inv(self.batch_rels)
            # fact_rel = torch.cat([fact_rel, pe], 1)
            fact_val = F.relu((rel_linear(fact_rel)+pe) * fact_query)
        else :
            fact_val = F.relu(rel_linear(fact_rel) * fact_query)
        # Scale each fact message by the probability mass of its tail entity.
        fact_prior = torch.sparse.mm(self.tail2fact_mat, curr_dist.view(-1, 1))
        fact_val = fact_val * fact_prior
        f2e_emb = torch.sparse.mm(self.fact2head_mat, fact_val)
        assert not torch.isnan(f2e_emb).any()
        neighbor_rep = f2e_emb.view(batch_size, max_local_entity, self.entity_dim)
        return neighbor_rep

    def combine(self,emb):
        """
        Combines instruction-specific representations into one entity embedding
        and a fresh (masked, softmax-normalised) entity distribution.
        """
        local_emb = torch.cat(emb, dim=-1)
        local_emb = F.relu(self.lin_m(local_emb))
        score_func = self.score_func
        score_tp = score_func(self.linear_drop(local_emb)).squeeze(dim=2)
        answer_mask = self.local_entity_mask
        self.possible_cand.append(answer_mask)
        score_tp = score_tp + (1 - answer_mask) * VERY_NEG_NUMBER
        current_dist = self.softmax_d1(score_tp)
        return current_dist, local_emb

    def forward(self, current_dist, relational_ins, step=0, return_score=False):
        """
        Compute next probabilistic vectors and current node representations.

        :param relational_ins: (batch, num_ins, entity_dim) instruction vectors.
        :return: (score_tp, current_dist) when return_score is True,
                 otherwise (current_dist, self.local_entity_emb).
        """
        rel_linear = getattr(self, 'rel_linear' + str(step))
        e2e_linear = getattr(self, 'e2e_linear' + str(step))
        # score_func = getattr(self, 'score_func' + str(step))
        score_func = self.score_func
        neighbor_reps = []
        if self.use_posemb :
            pos_emb = getattr(self, 'pos_emb' + str(step))
            pos_emb_inv = getattr(self, 'pos_emb_inv' + str(step))
        else :
            pos_emb, pos_emb_inv = None, None
        for j in range(relational_ins.size(1)):
            # we do the same procedure for existing and inverse relations
            neighbor_rep = self.reason_layer(current_dist, relational_ins[:,j,:], rel_linear, pos_emb)
            neighbor_reps.append(neighbor_rep)
            neighbor_rep = self.reason_layer_inv(current_dist, relational_ins[:,j,:], rel_linear, pos_emb_inv)
            neighbor_reps.append(neighbor_rep)
        neighbor_reps = torch.cat(neighbor_reps, dim=2)
        next_local_entity_emb = torch.cat((self.local_entity_emb, neighbor_reps), dim=2)
        #print(next_local_entity_emb.size())
        self.local_entity_emb = F.relu(e2e_linear(self.linear_drop(next_local_entity_emb)))
        score_tp = score_func(self.linear_drop(self.local_entity_emb)).squeeze(dim=2)
        answer_mask = self.local_entity_mask
        self.possible_cand.append(answer_mask)
        score_tp = score_tp + (1 - answer_mask) * VERY_NEG_NUMBER
        current_dist = self.softmax_d1(score_tp)
        if return_score:
            return score_tp, current_dist
        return current_dist, self.local_entity_emb
================================================
FILE: gnn/modules/layer_init.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
VERY_NEG_NUMBER = -100000000000
VERY_SMALL_NUMBER = 1e-10
class TypeLayer(nn.Module):
    """Initialises entity embeddings from relation-type features via sparse
    adjacency aggregation (sparse variant in the spirit of GAT,
    https://arxiv.org/abs/1710.10903).
    """

    def __init__(self, in_features, out_features, linear_drop, device, norm_rel):
        super(TypeLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear_drop = linear_drop
        self.kb_self_linear = nn.Linear(in_features, out_features)
        self.device = device
        self.norm_rel = norm_rel

    def forward(self, local_entity, edge_list, rel_features):
        '''
        local_entity: (batch_size, max_local_entity)
        edge_list: (heads, rels, tails, batch_ids, fact_ids, weights, rel_weights)
        rel_features: (num_relation, in_features)
        returns: (batch_size, max_local_entity, in_features)
        '''
        heads, rels, tails, ids, fact_ids, weights, rel_weights = edge_list
        num_fact = len(fact_ids)
        bsz, n_ent = local_entity.size()
        dev = self.device
        # Index pairs mapping each fact to its (flattened) head/tail entity slot.
        head_idx = torch.LongTensor([heads, fact_ids]).to(dev)
        tail_idx = torch.LongTensor([tails, fact_ids]).to(dev)
        rels = torch.LongTensor(rels).to(dev)
        ids = torch.LongTensor(ids).to(dev)
        if self.norm_rel:
            vals = torch.FloatTensor(rel_weights).to(dev)
        else:
            vals = torch.ones_like(ids).float().to(dev)
        # One projected feature vector per fact, taken from its relation type.
        fact_val = self.kb_self_linear(torch.index_select(rel_features, dim=0, index=rels))
        # Scatter fact features onto their tail and head entities.
        tail_mat = self._build_sparse_tensor(tail_idx, vals, (bsz * n_ent, num_fact))
        head_mat = self._build_sparse_tensor(head_idx, vals, (bsz * n_ent, num_fact))
        ent_emb = F.relu(torch.sparse.mm(tail_mat, fact_val) + torch.sparse.mm(head_mat, fact_val))
        assert not torch.isnan(ent_emb).any()
        return ent_emb.view(bsz, n_ent, self.in_features)

    def _build_sparse_tensor(self, indices, values, size):
        # Helper: sparse COO matrix on the configured device.
        return torch.sparse.FloatTensor(indices, values, size).to(self.device)
================================================
FILE: gnn/modules/query_update.py
================================================
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
class Fusion(nn.Module):
    """Gated fusion of two representations: g * r([x,y,x-y]) + (1-g) * x."""

    def __init__(self, d_hid):
        super(Fusion, self).__init__()
        self.r = nn.Linear(d_hid*3, d_hid, bias=False)
        self.g = nn.Linear(d_hid*3, d_hid, bias=False)

    def forward(self, x, y):
        # Joint features: both inputs plus their difference.
        joint = torch.cat([x, y, x - y], dim=-1)
        candidate = self.r(joint)
        gate = torch.sigmoid(self.g(joint))
        # Gate interpolates between the fused candidate and the original x.
        return gate * candidate + (1 - gate) * x
class QueryReform(nn.Module):
    """Reformulates the query vector by fusing in seed-entity information."""

    def __init__(self, h_dim):
        super(QueryReform, self).__init__()
        self.fusion = Fusion(h_dim)
        self.q_ent_attn = nn.Linear(h_dim, h_dim)

    def forward(self, q_node, ent_emb, seed_info, ent_mask):
        '''
        q_node: (B, h_dim) query vector
        ent_emb: (B, C, h_dim) candidate entity embeddings
        seed_info: (B, C) seed-entity distribution
        ent_mask: (B, C) 1 for valid entities, 0 for padding
        '''
        # Masked attention of the query over the candidate entities. The
        # attended vector is computed but not used for the final fusion
        # (kept for parity with the original implementation).
        scores = (self.q_ent_attn(q_node).unsqueeze(1) * ent_emb).sum(2, keepdim=True)
        weights = F.softmax(scores - (1 - ent_mask.unsqueeze(2)) * 1e8, dim=1)
        attn_retrieve = (weights * ent_emb).sum(1)
        # Seed-weighted summary of the entity embeddings: (B, h_dim).
        seed_retrieve = torch.bmm(seed_info.unsqueeze(1), ent_emb).squeeze(1)
        #return self.fusion(q_node, attn_retrieve)
        return self.fusion(q_node, seed_retrieve)
class AttnEncoder(nn.Module):
    """Masked additive-attention pooling over a sequence."""

    def __init__(self, d_hid):
        super(AttnEncoder, self).__init__()
        self.attn_linear = nn.Linear(d_hid, 1, bias=False)

    def forward(self, x, x_mask):
        """
        x: (B, len, d_hid)
        x_mask: (B, len) 1 for real positions, 0 for padding
        return: (B, d_hid)
        """
        # Large negative bias removes padded positions from the softmax.
        logits = self.attn_linear(x) - (1 - x_mask.unsqueeze(2)) * 1e8
        weights = F.softmax(logits, dim=1)
        return (x * weights).sum(1)
class Attention(nn.Module):
    """ Applies attention mechanism on the `context` using the `query`.
    **Thank you** to IBM for their initial implementation of :class:`Attention`. Here is
    their `License
    <https://github.com/IBM/pytorch-seq2seq/blob/master/LICENSE>`__.
    Args:
        dimensions (int): Dimensionality of the query and context.
        attention_type (str, optional): How to compute the attention score:
            * dot: :math:`score(H_j,q) = H_j^T q`
            * general: :math:`score(H_j, q) = H_j^T W_a q`
    Example:
        >>> attention = Attention(256)
        >>> query = torch.randn(5, 1, 256)
        >>> context = torch.randn(5, 5, 256)
        >>> output, weights = attention(query, context)
        >>> output.size()
        torch.Size([5, 1, 256])
        >>> weights.size()
        torch.Size([5, 1, 5])
    """

    def __init__(self, dimensions, attention_type='general'):
        super(Attention, self).__init__()
        if attention_type not in ['dot', 'general']:
            raise ValueError('Invalid attention type selected.')
        self.attention_type = attention_type
        if self.attention_type == 'general':
            self.linear_in = nn.Linear(dimensions, dimensions, bias=False)
        self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.tanh = nn.Tanh()

    def forward(self, query, context):
        """Attend over `context` with `query`.

        query: (batch, output_len, dimensions); context: (batch, query_len, dimensions).
        Returns (output, weights): attended features (batch, output_len, dimensions)
        and attention weights (batch, output_len, query_len).
        """
        bsz, out_len, dim = query.size()
        ctx_len = context.size(1)
        if self.attention_type == "general":
            # Project the query through W_a before the dot product.
            projected = self.linear_in(query.reshape(bsz * out_len, dim))
            query = projected.reshape(bsz, out_len, dim)
        # (b, out_len, dim) x (b, dim, ctx_len) -> (b, out_len, ctx_len)
        scores = torch.bmm(query, context.transpose(1, 2).contiguous())
        flat = self.softmax(scores.view(bsz * out_len, ctx_len))
        weights = flat.view(bsz, out_len, ctx_len)
        # Blend context rows by their weights, then project the concat back down.
        mix = torch.bmm(weights, context)
        combined = torch.cat((mix, query), dim=2).view(bsz * out_len, 2 * dim)
        output = self.tanh(self.linear_out(combined).view(bsz, out_len, dim))
        return output, weights
================================================
FILE: gnn/modules/question_encoding/base_encoder.py
================================================
import torch
import torch.nn.functional as F
import torch.nn as nn
VERY_SMALL_NUMBER = 1e-10
VERY_NEG_NUMBER = -100000000000
class BaseInstruction(torch.nn.Module):
    """Base class for instruction generators: encodes a question and produces
    a sequence of attention-weighted instruction vectors over its tokens.
    """

    def __init__(self, args, constraint=False):
        """
        :param args: configuration dict (see _parse_args for the keys read).
        :param constraint: whether this instance generates constraint
            instructions. Defaults to False so subclasses that never handle
            constraints (e.g. the LSTM encoder, which calls
            ``super().__init__(args)``) do not crash with a TypeError.
        """
        super(BaseInstruction, self).__init__()
        self.constraint = constraint
        self._parse_args(args)
        self.share_module_def()

    def _parse_args(self, args):
        """Pull configuration out of the `args` dict onto attributes."""
        self.device = torch.device('cuda' if args['use_cuda'] else 'cpu')
        self.q_type = args['q_type']
        # Number of instruction vectors: first matching key wins.
        if 'num_step' in args:
            self.num_ins = args['num_step']
        elif 'num_ins' in args:
            self.num_ins = args['num_ins']
        elif 'num_layer' in args:
            self.num_ins = args['num_layer']
        elif 'num_expansion_ins' in args and 'num_backup_ins' in args:
            self.num_ins = args['num_backup_ins'] if self.constraint else args['num_expansion_ins']
        else:
            self.num_ins = 1
        self.lm_dropout = args['lm_dropout']
        self.linear_dropout = args['linear_dropout']
        self.lm_frozen = args['lm_frozen']
        # Mirror every '*dim' entry onto an attribute, and resolve
        # '*emb_file'/'*kge_file' entries relative to the data folder.
        for k, v in args.items():
            if k.endswith('dim'):
                setattr(self, k, v)
            if k.endswith('emb_file') or k.endswith('kge_file'):
                if v is None:
                    setattr(self, k, None)
                else:
                    setattr(self, k, args['data_folder'] + v)
        self.reset_time = 0

    def share_module_def(self):
        # Dropout modules shared by all encoder subclasses.
        self.lstm_drop = nn.Dropout(p=self.lm_dropout)
        self.linear_drop = nn.Dropout(p=self.linear_dropout)

    def init_hidden(self, num_layer, batch_size, hidden_size):
        """Zero-initialised (h, c) pair for an LSTM encoder."""
        return (torch.zeros(num_layer, batch_size, hidden_size).to(self.device),
                torch.zeros(num_layer, batch_size, hidden_size).to(self.device))

    def encode_question(self, *args):
        # Implemented by subclasses (constituency tree or query text).
        pass

    @staticmethod
    def get_node_emb(query_hidden_emb, action):
        '''
        :param query_hidden_emb: (batch_size, max_hyper, emb)
        :param action: (batch_size)
        :return: (batch_size, 1, emb)
        '''
        batch_size, max_hyper, _ = query_hidden_emb.size()
        row_idx = torch.arange(0, batch_size).type(torch.LongTensor)
        q_rep = query_hidden_emb[row_idx, action, :]
        return q_rep.unsqueeze(1)

    def init_reason(self, query_text):
        """Encode the question and reset the per-forward instruction state."""
        self.batch_size = query_text.size(0)
        self.max_query_word = query_text.size(1)
        self.encode_question(query_text)
        self.relational_ins = torch.zeros(self.batch_size, self.entity_dim).to(self.device)
        self.instructions = []
        self.attn_list = []

    def get_instruction(self, relational_ins, step=0, query_node_emb=None):
        """Produce the next instruction vector by attending over question tokens.

        :return: (instruction (batch, entity_dim), attention weights).
        """
        query_hidden_emb = self.query_hidden_emb
        query_mask = self.query_mask
        if query_node_emb is None:
            query_node_emb = self.query_node_emb
        relational_ins = relational_ins.unsqueeze(1)
        question_linear = getattr(self, 'question_linear' + str(step))
        q_i = question_linear(self.linear_drop(query_node_emb))
        # Combine the previous instruction with the query vector (concat of
        # both plus their difference and product), then score each token.
        cq = self.cq_linear(self.linear_drop(torch.cat((relational_ins, q_i, q_i-relational_ins, q_i*relational_ins), dim=-1)))
        ca = self.ca_linear(self.linear_drop(cq * query_hidden_emb))
        # Mask out padded tokens before normalising.
        attn_weight = F.softmax(ca + (1 - query_mask.unsqueeze(2)) * VERY_NEG_NUMBER, dim=1)
        relational_ins = torch.sum(attn_weight * query_hidden_emb, dim=1)
        return relational_ins, attn_weight

    def forward(self, query_text, lm=None):
        """Generate num_ins instruction vectors (and their attention maps)."""
        if lm is not None:
            self.node_encoder = lm
        self.init_reason(query_text)
        for i in range(self.num_ins):
            relational_ins, attn_weight = self.get_instruction(self.relational_ins, step=i)
            self.instructions.append(relational_ins)
            self.attn_list.append(attn_weight)
            self.relational_ins = relational_ins
        return self.instructions, self.attn_list
================================================
FILE: gnn/modules/question_encoding/bert_encoder.py
================================================
import torch.nn.functional as F
import torch.nn as nn
VERY_SMALL_NUMBER = 1e-10
VERY_NEG_NUMBER = -100000000000
from transformers import AutoModel, AutoTokenizer #DistilBertModel, BertModel, BertTokenizer, RobertaModel, RobertaTokenizer
from torch.nn import LayerNorm
import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TRANSFORMERS_CACHE'] = '/export/scratch/costas/home/mavro016/.cache'
from .base_encoder import BaseInstruction
class BERTInstruction(BaseInstruction):
    """Instruction generator that encodes the question with a pretrained LM."""

    # model key -> (HuggingFace weights identifier, LM hidden size)
    # NOTE(review): 't5' is listed with 768 as in the original code, although
    # t5-small's hidden size is 512 — preserved as-is; verify before relying on it.
    PRETRAINED = {
        'bert': ('bert-base-uncased', 768),
        'roberta': ('roberta-base', 768),
        'sbert': ('sentence-transformers/all-MiniLM-L6-v2', 384),
        'simcse': ('princeton-nlp/sup-simcse-bert-base-uncased', 768),
        'sbert2': ('sentence-transformers/all-mpnet-base-v2', 768),
        't5': ('t5-small', 768),
        'relbert': ('pretrained_lms/sr-simbert/', 768),
    }

    def __init__(self, args, word_embedding, num_word, model, constraint=False):
        super(BERTInstruction, self).__init__(args, constraint)
        self.word_embedding = word_embedding
        self.num_word = num_word
        self.constraint = constraint
        entity_dim = self.entity_dim
        self.model = model
        if model not in self.PRETRAINED:
            # The original elif-chain fell through silently for unknown models
            # and crashed later with UnboundLocalError on `word_dim`; fail fast.
            raise ValueError('Unsupported LM model: %s' % model)
        self.pretrained_weights, word_dim = self.PRETRAINED[model]
        self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained_weights)
        self.pad_val = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
        self.word_dim = word_dim
        print('word_dim', self.word_dim)
        self.cq_linear = nn.Linear(in_features=4 * entity_dim, out_features=entity_dim)
        self.ca_linear = nn.Linear(in_features=entity_dim, out_features=1)
        for i in range(self.num_ins):
            self.add_module('question_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))
        # Projects LM hidden states down to the entity dimension.
        self.question_emb = nn.Linear(in_features=word_dim, out_features=entity_dim)
        if not self.constraint:
            self.encoder_def()

    def encoder_def(self):
        """Instantiate the pretrained LM and optionally freeze its parameters."""
        self.node_encoder = AutoModel.from_pretrained(self.pretrained_weights)
        print('Total Params', sum(p.numel() for p in self.node_encoder.parameters()))
        if self.lm_frozen == 1:
            print('Freezing LM params')
            for param in self.node_encoder.parameters():
                param.requires_grad = False
        else:
            for param in self.node_encoder.parameters():
                param.requires_grad = True
            print('Unfrozen LM params')

    def encode_question(self, query_text, store=True):
        """Run the LM over token ids.

        :param query_text: (batch, seq) token-id tensor.
        :param store: when True, cache the projected hidden states, the
            first-token query vector and the padding mask on self.
        :return: raw LM hidden states, plus the projected query vector
            when store is True.
        """
        if self.model != 't5':
            query_hidden_emb = self.node_encoder(query_text)[0]  # (batch, seq, word_dim)
        else:
            # T5 is encoder-decoder; use only its encoder stack here.
            query_hidden_emb = self.node_encoder.encoder(query_text)[0]
        if store:
            self.query_hidden_emb = self.question_emb(query_hidden_emb)
            # First token's hidden state serves as the whole-question vector.
            self.query_node_emb = query_hidden_emb.transpose(1, 0)[0].unsqueeze(1)
            self.query_node_emb = self.question_emb(self.query_node_emb)
            self.query_mask = (query_text != self.pad_val).float()
            return query_hidden_emb, self.query_node_emb
        else:
            return query_hidden_emb
================================================
FILE: gnn/modules/question_encoding/lstm_encoder.py
================================================
import torch.nn as nn
from utils import get_dict
from .base_encoder import BaseInstruction
VERY_SMALL_NUMBER = 1e-10
VERY_NEG_NUMBER = -100000000000
class LSTMInstruction(BaseInstruction):
    """Instruction generator that encodes the question with a single-layer LSTM."""

    def __init__(self, args, word_embedding, num_word):
        # BUG FIX: BaseInstruction.__init__ takes a required `constraint`
        # argument; the original call ``super().__init__(args)`` omitted it and
        # raised a TypeError on construction. The LSTM encoder never handles
        # constraint instructions, so pass False explicitly.
        super(LSTMInstruction, self).__init__(args, False)
        self.word2id = get_dict(args['data_folder'], args['word2id'])
        self.word_embedding = word_embedding
        self.num_word = num_word
        self.encoder_def()
        entity_dim = self.entity_dim
        self.cq_linear = nn.Linear(in_features=4 * entity_dim, out_features=entity_dim)
        self.ca_linear = nn.Linear(in_features=entity_dim, out_features=1)
        for i in range(self.num_ins):
            self.add_module('question_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))

    def encoder_def(self):
        """Create the question encoder: unidirectional, batch-first LSTM."""
        word_dim = self.word_dim
        entity_dim = self.entity_dim
        self.node_encoder = nn.LSTM(input_size=word_dim, hidden_size=entity_dim,
                                    batch_first=True, bidirectional=False)

    def encode_question(self, query_text, store=True):
        """Embed token ids and run the LSTM over the question.

        :param query_text: (batch, max_query_word) word-id tensor; ids equal to
            self.num_word mark padding/OOV.
        :param store: when True, cache hidden states, final state and mask.
        :return: hidden states, plus the final-state query vector when store=True.
        """
        batch_size = query_text.size(0)
        query_word_emb = self.word_embedding(query_text)  # (batch, max_query_word, word_dim)
        query_hidden_emb, (h_n, c_n) = self.node_encoder(self.lstm_drop(query_word_emb),
                                                         self.init_hidden(1, batch_size,
                                                                          self.entity_dim))
        if store:
            self.instruction_hidden = h_n
            self.instruction_mem = c_n
            self.query_node_emb = h_n.squeeze(dim=0).unsqueeze(dim=1)  # (batch, 1, entity_dim)
            self.query_hidden_emb = query_hidden_emb
            self.query_mask = (query_text != self.num_word).float()
            return query_hidden_emb, self.query_node_emb
        else:
            return query_hidden_emb
================================================
FILE: gnn/modules/question_encoding/tokenizers.py
================================================
import re
import numpy as np
from transformers import BertTokenizer
class LSTMTokenizer():
    """Maps a question string to a fixed-length array of word ids.

    Out-of-vocabulary words and padding positions both receive the id
    ``len(word2id)``.
    """

    def __init__(self, word2id, max_query_word):
        super(LSTMTokenizer, self).__init__()
        self.word2id = word2id
        self.max_query_word = max_query_word

    def tokenize(self, question):
        """Return an int array of length max_query_word with word ids."""
        words = self.tokenize_sent(question)
        unk = len(self.word2id)
        ids = np.full(self.max_query_word, unk, dtype=int)
        for pos, word in enumerate(words[:self.max_query_word]):
            ids[pos] = self.word2id.get(word, unk)
        return ids

    @staticmethod
    def tokenize_sent(question_text):
        """Lowercase, split possessive "'s", strip non-alphanumeric edges,
        and drop empty tokens."""
        text = re.sub('\'s', ' s', question_text.strip().lower())
        cleaned = (re.sub('^[^a-z0-9]|[^a-z0-9]$', '', tok) for tok in text.split(' '))
        return [w for w in cleaned if w != '']
class BERTTokenizer():
    """Wraps the HuggingFace BERT tokenizer to produce fixed-length id arrays."""

    def __init__(self, max_query_word):
        super(BERTTokenizer, self).__init__()
        self.q_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.max_query_word = max_query_word
        # NOTE(review): encode("[UNK]") includes special tokens, so index 0 is
        # the [CLS] id rather than the [UNK] id — behavior preserved as-is;
        # verify this matches the callers' expectations.
        self.num_word = self.q_tokenizer.encode("[UNK]")[0]  # len(self.q_tokenizer.vocab.keys())

    def tokenize(self, question):
        """Encode `question` into a padded/truncated array of token ids.

        Fixes: removed a dead pre-allocated array the original never used, and
        replaced the deprecated `pad_to_max_length=True` with the equivalent
        `padding='max_length'`.
        """
        tokens = self.q_tokenizer.encode_plus(text=question, max_length=self.max_query_word,
                                              padding='max_length', return_attention_mask=False,
                                              truncation=True)
        return np.array(tokens['input_ids'])
================================================
FILE: gnn/parsing.py
================================================
import argparse
import sys
def bool_flag(v):
    """Parse a command-line string as a boolean (for argparse ``type=``).

    Accepts yes/true/t/y/1 and no/false/f/n/0 (case-insensitive); anything
    else raises argparse.ArgumentTypeError.
    """
    value = v.lower()
    if value in ('yes', 'true', 't', 'y', '1'):
        return True
    if value in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def add_shared_args(parser):
    """Register the command-line options shared by every KGQA model parser."""
    add = parser.add_argument

    # Dataset location and size cap.
    add('--name', default='webqsp', type=str)
    add('--data_folder', default='data/webqsp/', type=str)
    add('--max_train', default=200000, type=int)

    # Vocabulary / embedding files and language-model choice.
    add('--word2id', default='vocab.txt', type=str)
    add('--relation2id', default='relations.txt', type=str)
    add('--entity2id', default='entities.txt', type=str)
    add('--char2id', default='chars.txt', type=str)
    add('--entity_emb_file', default=None, type=str)
    add('--relation_emb_file', default=None, type=str)
    add('--relation_word_emb', default=True, type=bool_flag)
    add('--word_emb_file', default='word_emb.npy', type=str)
    add('--rel_word_ids', default='rel_word_idx.npy', type=str)
    add('--kge_frozen', default=0, type=int)
    add('--lm', default='lstm', type=str, choices=['lstm', 'bert', 'roberta', 'sbert', 't5','sbert2', 'dbert', 'simcse', 'relbert'])
    add('--lm_frozen', default=1, type=int)

    # Hidden dimensions and dropout rates.
    add('--entity_dim', default=50, type=int)
    add('--kg_dim', default=100, type=int)
    add('--word_dim', default=300, type=int)
    add('--lm_dropout', default=0.3, type=float)
    add('--linear_dropout', default=0.2, type=float)

    # Optimization schedule.
    add('--num_epoch', default=100, type=int)
    add('--warmup_epoch', default=0, type=int)
    add('--fact_scale', default=3, type=int)
    add('--eval_every', default=2, type=int)
    add('--batch_size', default=20, type=int)
    add('--gradient_clip', default=1.0, type=float)
    add('--lr', default=0.0005, type=float)
    add('--decay_rate', default=0.0, type=float)
    add('--seed', default=19960626, type=int)
    add('--lr_schedule', action='store_true')
    add('--label_smooth', default=0.1, type=float)
    add('--fact_drop', default=0, type=float)
    #parser.add_argument('--encode_type', action='store_true')

    # Checkpointing / evaluation options.
    add('--is_eval', action='store_true')
    add('--checkpoint_dir', default='checkpoint/pretrain/', type=str)
    add('--log_level', type=str, default='info')
    add('--experiment_name', default='', type=str)
    add('--load_experiment', default=None, type=str)
    add('--load_ckpt_file', default=None, type=str)
    add('--eps', default=0.95, type=float) # threshold for f1
    add('--test_batch_size', default=20, type=int)
    add('--q_type', default='seq', type=str)
def add_parse_args(parser):
    """Attach one sub-parser per supported KGQA model to *parser*."""
    subparsers = parser.add_subparsers(help='Reason KGQA model')
    # NOTE(review): create_parser_nutrea is not defined in this module —
    # confirm whether NuTrea support was removed or lives elsewhere.
    builders = (
        ("ReaRev", create_parser_rearev),
        ("NSM", create_parser_nsm),
        ("GraftNet", create_parser_graftnet),
        ("NuTrea", create_parser_nutrea),
    )
    for model_name, configure in builders:
        configure(subparsers.add_parser(model_name))
def create_parser_rearev(parser):
    """Register ReaRev-specific options, then the shared KGQA options."""
    add = parser.add_argument
    add('--model_name', default='ReaRev', type=str, choices=['ReaRev'])
    add('--alg', default='bfs', type=str)
    # GNN architecture: reasoning iterations, instruction steps, GNN layers.
    add('--num_iter', default=2, type=int)
    add('--num_ins', default=3, type=int)
    add('--num_gnn', default=3, type=int)
    add('--loss_type', default='kl', type=str)
    add('--use_self_loop', default=True, type=bool_flag)
    add('--normalized_gnn', default=False, type=bool_flag)
    add('--norm_rel', action='store_true')
    add('--data_eff', action='store_true')
    add('--pos_emb', action='store_true')
    add_shared_args(parser)
def create_parser_nsm(parser):
    """Register NSM-specific options, then the shared KGQA options."""
    add = parser.add_argument
    add('--model_name', default='NSM', type=str, choices=['NSM'])
    add('--num_step', default=3, type=int)
    add('--reason_kb', default=False, type=bool_flag)
    add('--loss_type', default='kl', type=str)
    # Auxiliary loss weights for the teacher/student objectives.
    add('--lambda_constrain', default=0.0, type=float)
    add('--lambda_back', default=0.0, type=float)
    add('--use_self_loop', default=True, type=bool_flag)
    add('--use_inverse_relation', action='store_true')
    add('--norm_rel', action='store_true')
    add('--normalized_gnn', default=False, type=bool_flag)
    add('--data_eff', action='store_true')
    add_shared_args(parser)
def create_parser_graftnet(parser):
    """Register GraftNet-specific options, then the shared KGQA options."""
    add = parser.add_argument
    add('--model_name', default='GraftNet', type=str, choices=['GraftNet'])
    add('--pagerank_lambda', default=0.8, type=float)
    add('--loss_type', default='bce', type=str)
    add('--num_layer', default=3, type=int)
    add('--use_inverse_relation', action='store_true')
    add('--norm_rel', action='store_true')
    add('--normalized_gnn', default=False, type=bool_flag)
    add('--data_eff', action='store_true')
    #parser.add_argument('--use_self_loop', default=True, type=bool_flag)
    add_shared_args(parser)
================================================
FILE: gnn/requirements.txt
================================================
Base==1.0.4
numpy==1.19.5
torch==1.7.1+cu110
tqdm==4.59.0
transformers==4.6.1
================================================
FILE: gnn/scripts/rearev_cwq.sh
================================================
### Commands for training / evaluating ReaRev on the CWQ dataset.
### Presumably run from the gnn/ directory (main.py and data/CWQ/ are
### referenced relative to it) — confirm before use. Uncomment the block
### you need; only the evaluation command below is active.

###ReaRev+SBERT training
# python main.py ReaRev --is_eval --load_experiment relbert-full_cwq-rearev-final.ckpt --entity_dim 50 --num_epoch 200 --batch_size 8 --eval_every 2 \
# --lm relbert --num_iter 2 --num_ins 3 --num_gnn 3 --name cwq \
# --experiment_name prn_cwq-rearev-sbert --data_folder data/CWQ/ --num_epoch 100 --warmup_epoch 80

###ReaRev+LMSR training
# python main.py ReaRev --entity_dim 50 --num_epoch 200 --batch_size 8 --eval_every 2 \
# --lm relbert --num_iter 2 --num_ins 3 --num_gnn 3 --name cwq \
# --experiment_name prn_cwq-rearev-lmsr --data_folder data/CWQ/ --num_epoch 100 #--warmup_epoch 80

###Evaluate CWQ
# Loads the ReaRev_CWQ.ckpt checkpoint and runs evaluation only (--is_eval).
python main.py ReaRev --entity_dim 50 --num_epoch 100 --batch_size 8 --eval_every 2 --data_folder data/CWQ/ --lm sbert --num_iter 2 --num_ins 3 --num_gnn 3 --relation_word_emb True --load_experiment ReaRev_CWQ.ckpt --is_eval --name cwq
================================================
FILE: gnn/train_model.py
================================================
from utils import create_logger
import time
import numpy as np
import os, math
import torch
from torch.optim.lr_scheduler import ExponentialLR
import torch.optim as optim
from tqdm import tqdm
tqdm.monitor_iterval = 0
#from dataset_load_paths import load_data
from dataset_load import load_data
from dataset_load_graft import load_data_graft
from models.ReaRev.rearev import ReaRev
from models.NSM.nsm import NSM
from models.GraftNet.graftnet import GraftNet
from evaluate import Evaluator
class Trainer_KBQA(object):
    """End-to-end training/evaluation driver for the KGQA GNN models.

    Loads the dataset, instantiates the selected model (ReaRev / NSM /
    GraftNet), wires up the optimizer and evaluator, and exposes
    train / evaluate / checkpoint helpers.
    """

    def __init__(self, args, model_name, logger=None):
        """args: dict of parsed hyper-parameters; model_name selects the model class."""
        #print('Trainer here')
        self.args = args
        self.logger = logger
        # Best-so-far validation metrics (used for checkpoint selection).
        self.best_dev_performance = 0.0
        self.best_h1 = 0.0
        self.best_f1 = 0.0
        self.best_h1b = 0.0
        self.best_f1b = 0.0
        self.eps = args['eps']  # threshold for F1-based answer selection
        self.warmup_epoch = args['warmup_epoch']  # no best-checkpoint saving before this epoch
        self.learning_rate = self.args['lr']
        self.test_batch_size = args['test_batch_size']
        self.device = torch.device('cuda' if args['use_cuda'] else 'cpu')
        self.reset_time = 0
        self.load_data(args, args['lm'])
        # Fall back to 0.98 LR decay when the key is absent.
        if 'decay_rate' in args:
            self.decay_rate = args['decay_rate']
        else:
            self.decay_rate = 0.98
        if model_name == 'ReaRev':
            self.model = ReaRev(self.args, len(self.entity2id), self.num_kb_relation,
                                self.num_word)
        elif model_name == 'NSM':
            self.model = NSM(self.args, len(self.entity2id), self.num_kb_relation,
                             self.num_word)
        elif model_name == 'GraftNet':
            self.model = GraftNet(self.args, len(self.entity2id), self.num_kb_relation,
                                  self.num_word)
        elif model_name == 'NuTrea':
            # NOTE(review): NuTrea is not imported in this file, so selecting
            # it raises NameError — confirm whether this branch is still used.
            self.model = NuTrea(self.args, len(self.entity2id), self.num_kb_relation,
                                self.num_word)
        if args['relation_word_emb']:
            # Pre-encode relation surface texts (forward and inverse) with
            # the language model before training.
            #self.model.use_rel_texts(self.rel_texts, self.rel_texts_inv)
            self.model.encode_rel_texts(self.rel_texts, self.rel_texts_inv)
        self.model.to(self.device)
        self.evaluator = Evaluator(args=args, model=self.model, entity2id=self.entity2id,
                                   relation2id=self.relation2id, device=self.device)
        self.load_pretrain()
        self.optim_def()
        self.num_relation = self.num_kb_relation
        self.num_entity = len(self.entity2id)
        # NOTE(review): this overwrites the dataset-provided num_word set in
        # load_data with the raw word2id size — confirm this is intended.
        self.num_word = len(self.word2id)
        print("Entity: {}, Relation: {}, Word: {}".format(self.num_entity, self.num_relation, self.num_word))
        # Mirror *_dim args as attributes; resolve *_emb_file/*_kge_file
        # paths relative to the data folder (None stays None).
        for k, v in args.items():
            if k.endswith('dim'):
                setattr(self, k, v)
            if k.endswith('emb_file') or k.endswith('kge_file'):
                if v is None:
                    setattr(self, k, None)
                else:
                    setattr(self, k, args['data_folder'] + v)

    def optim_def(self):
        """Create the Adam optimizer (trainable params only) and, when
        decay_rate > 0, an exponential LR scheduler."""
        trainable = filter(lambda p: p.requires_grad, self.model.parameters())
        self.optim_model = optim.Adam(trainable, lr=self.learning_rate)
        if self.decay_rate > 0:
            self.scheduler = ExponentialLR(self.optim_model, self.decay_rate)

    def load_data(self, args, tokenize):
        """Load train/valid/test splits plus vocabularies; GraftNet uses its
        own loader, all other models share load_data."""
        if args["model_name"] == "GraftNet":
            dataset = load_data_graft(args, tokenize)
        else:
            dataset = load_data(args, tokenize)
        self.train_data = dataset["train"]
        self.valid_data = dataset["valid"]
        self.test_data = dataset["test"]
        self.entity2id = dataset["entity2id"]
        self.relation2id = dataset["relation2id"]
        self.word2id = dataset["word2id"]
        self.num_word = dataset["num_word"]
        self.num_kb_relation = self.test_data.num_kb_relation
        self.num_entity = len(self.entity2id)
        self.rel_texts = dataset["rel_texts"]
        self.rel_texts_inv = dataset["rel_texts_inv"]

    def load_pretrain(self):
        """Optionally restore model weights from args['load_experiment']."""
        args = self.args
        if args['load_experiment'] is not None:
            ckpt_path = os.path.join(args['checkpoint_dir'], args['load_experiment'])
            print("Load ckpt from", ckpt_path)
            self.load_ckpt(ckpt_path)

    def evaluate(self, data, test_batch_size=20, write_info=False):
        """Delegate evaluation of *data* to the Evaluator; returns (f1, h1, em)."""
        return self.evaluator.evaluate(data, test_batch_size, write_info)

    def train(self, start_epoch, end_epoch):
        """Main training loop over epochs [start_epoch, end_epoch], with
        periodic validation, best-checkpoint saving, and a final test pass."""
        # self.load_pretrain()
        eval_every = self.args['eval_every']
        # eval_acc = inference(self.model, self.valid_data, self.entity2id, self.args)
        # self.evaluate(self.test_data, self.test_batch_size)
        print("Start Training------------------")
        for epoch in range(start_epoch, end_epoch + 1):
            st = time.time()
            loss, extras, h1_list_all, f1_list_all = self.train_epoch()
            # Decay the learning rate once per epoch when scheduling is on.
            if self.decay_rate > 0:
                self.scheduler.step()
            self.logger.info("Epoch: {}, loss : {:.4f}, time: {}".format(epoch + 1, loss, time.time() - st))
            self.logger.info("Training h1 : {:.4f}, f1 : {:.4f}".format(np.mean(h1_list_all), np.mean(f1_list_all)))
            if (epoch + 1) % eval_every == 0:
                eval_f1, eval_h1, eval_em = self.evaluate(self.valid_data, self.test_batch_size)
                self.logger.info("EVAL F1: {:.4f}, H1: {:.4f}, EM {:.4f}".format(eval_f1, eval_h1, eval_em))
                # eval_f1, eval_h1 = self.evaluate(self.test_data, self.test_batch_size)
                # self.logger.info("TEST F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_h1))
                do_test = False
                # Only track/save best checkpoints after the warmup period.
                if epoch > self.warmup_epoch:
                    if eval_h1 > self.best_h1:
                        self.best_h1 = eval_h1
                        self.save_ckpt("h1")
                        self.logger.info("BEST EVAL H1: {:.4f}".format(eval_h1))
                        do_test = True
                    if eval_f1 > self.best_f1:
                        self.best_f1 = eval_f1
                        self.save_ckpt("f1")
                        self.logger.info("BEST EVAL F1: {:.4f}".format(eval_f1))
                        do_test = True
                    # Test-set metrics are logged every eval step (not only
                    # on improvement — the do_test gate below is disabled).
                    eval_f1, eval_h1, eval_em = self.evaluate(self.test_data, self.test_batch_size)
                    self.logger.info("TEST F1: {:.4f}, H1: {:.4f}, EM {:.4f}".format(eval_f1, eval_h1, eval_em))
                # if do_test:
                #     eval_f1, eval_h1 = self.evaluate(self.test_data, self.test_batch_size)
                #     self.logger.info("TEST F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_h1))
                #     if eval_h1 > self.best_h1:
                #         self.best_h1 = eval_h1
                #         self.save_ckpt("h1")
                #     if eval_f1 > self.best_f1:
                #         self.best_f1 = eval_f1
                #         self.save_ckpt("f1")
                # self.reset_time = 0
                # else:
                #     self.logger.info('No improvement after one evaluation iter.')
                #     self.reset_time += 1
                # if self.reset_time >= 5:
                #     self.logger.info('No improvement after 5 evaluation. Early Stopping.')
                #     break
        self.save_ckpt("final")
        self.logger.info('Train Done! Evaluate on testset with saved model')
        print("End Training------------------")
        self.evaluate_best()

    def evaluate_best(self):
        """Evaluate the saved best-H1, best-F1, and final checkpoints on the test set."""
        filename = os.path.join(self.args['checkpoint_dir'], "{}-h1.ckpt".format(self.args['experiment_name']))
        self.load_ckpt(filename)
        eval_f1, eval_h1, eval_em = self.evaluate(self.test_data, self.test_batch_size, write_info=False)
        self.logger.info("Best h1 evaluation")
        self.logger.info("TEST F1: {:.4f}, H1: {:.4f}, EM {:.4f}".format(eval_f1, eval_h1, eval_em))
        filename = os.path.join(self.args['checkpoint_dir'], "{}-f1.ckpt".format(self.args['experiment_name']))
        self.load_ckpt(filename)
        eval_f1, eval_h1, eval_em = self.evaluate(self.test_data, self.test_batch_size, write_info=False)
        self.logger.info("Best f1 evaluation")
        self.logger.info("TEST F1: {:.4f}, H1: {:.4f}, EM {:.4f}".format(eval_f1, eval_h1, eval_em))
        filename = os.path.join(self.args['checkpoint_dir'], "{}-final.ckpt".format(self.args['experiment_name']))
        self.load_ckpt(filename)
        eval_f1, eval_h1, eval_em = self.evaluate(self.test_data, self.test_batch_size, write_info=False)
        self.logger.info("Final evaluation")
        self.logger.info("TEST F1: {:.4f}, H1: {:.4f}, EM {:.4f}".format(eval_f1, eval_h1, eval_em))

    def evaluate_single(self, filename):
        """Evaluate one checkpoint (or the current weights when filename is
        None) on both validation and test sets; writes test predictions."""
        if filename is not None:
            self.load_ckpt(filename)
        eval_f1, eval_hits, eval_ems = self.evaluate(self.valid_data, self.test_batch_size, write_info=False)
        self.logger.info("EVAL F1: {:.4f}, H1: {:.4f}, EM {:.4f}".format(eval_f1, eval_hits, eval_ems))
        test_f1, test_hits, test_ems = self.evaluate(self.test_data, self.test_batch_size, write_info=True)
        self.logger.info("TEST F1: {:.4f}, H1: {:.4f}, EM {:.4f}".format(test_f1, test_hits, test_ems))

    def train_epoch(self):
        """One pass over the shuffled training data.

        Returns (mean loss, extras placeholder, per-sample h1 list,
        per-sample f1 list).
        """
        self.model.train()
        self.train_data.reset_batches(is_sequential=False)
        losses = []
        actor_losses = []
        ent_losses = []
        # NOTE(review): despite the name, num_epoch here is the number of
        # batches in one epoch.
        num_epoch = math.ceil(self.train_data.num_data / self.args['batch_size'])
        h1_list_all = []
        f1_list_all = []
        for iteration in tqdm(range(num_epoch)):
            batch = self.train_data.get_batch(iteration, self.args['batch_size'], self.args['fact_drop'])
            self.optim_model.zero_grad()
            loss, _, _, tp_list = self.model(batch, training=True)
            # if tp_list is not None:
            h1_list, f1_list = tp_list
            h1_list_all.extend(h1_list)
            f1_list_all.extend(f1_list)
            loss.backward()
            # Clip the global gradient norm before the optimizer step.
            torch.nn.utils.clip_grad_norm_([param for name, param in self.model.named_parameters()],
                                           self.args['gradient_clip'])
            self.optim_model.step()
            losses.append(loss.item())
        extras = [0, 0]
        return np.mean(losses), extras, h1_list_all, f1_list_all

    def save_ckpt(self, reason="h1"):
        """Save the model state dict as <checkpoint_dir>/<experiment_name>-<reason>.ckpt."""
        model = self.model
        checkpoint = {
            'model_state_dict': model.state_dict()
        }
        model_name = os.path.join(self.args['checkpoint_dir'], "{}-{}.ckpt".format(self.args['experiment_name'],
                                                                                   reason))
        torch.save(checkpoint, model_name)
        print("Best %s, save model as %s" %(reason, model_name))

    def load_ckpt(self, filename):
        """Load model weights from *filename* (non-strict, so partial
        state dicts are tolerated)."""
        checkpoint = torch.load(filename)
        model_state_dict = checkpoint["model_state_dict"]
        model = self.model
        #self.logger.info("Load param of {} from {}.".format(", ".join(list(model_state_dict.keys())), filename))
        model.load_state_dict(model_state_dict, strict=False)
================================================
FILE: gnn/utils.py
================================================
import logging
import os
def create_logger(args):
    """Configure the root logger to write to both a file under
    args.checkpoint_dir and the console, then log every parsed argument.

    Returns the root logger.
    """
    log_path = os.path.join(args.checkpoint_dir, args.experiment_name + ".log")
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG if args.log_level == "debug" else logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # Same format to both destinations: log file first, then stderr.
    for handler in (logging.FileHandler(log_path), logging.StreamHandler()):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    # Dump the full argument namespace for reproducibility.
    logger.info("PARAMETER" + "-" * 10)
    for attr, value in sorted(args.__dict__.items()):
        logger.info("{}={}".format(attr.upper(), value))
    logger.info("---------" + "-" * 10)
    return logger
def get_dict(data_folder, filename):
    """Read one token per line from data_folder/filename and return a dict
    mapping each stripped token to its insertion index."""
    path = os.path.join(data_folder, filename)
    mapping = dict()
    with open(path, encoding='utf-8') as handle:
        for raw_line in handle:
            # Each new token gets the current dict size as its id.
            mapping[raw_line.strip()] = len(mapping)
    return mapping
================================================
FILE: llm/.gitignore
================================================
datasets/
*json
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
================================================
FILE: llm/README.md
================================================
## Get Started
We have simple requirements in `requirements.txt`. You can always check if you can run the code immediately.
Please also download `entities_names.json` file from https://drive.google.com/drive/folders/1ifgVHQDnvFEunP9hmVYT07Y3rvcpIfQp?usp=sharing, as GNNs use the dense graphs.
## Evaluation
We provide the results of GNN retrieval in `results/gnn`. To evaluate GNN-RAG performance, run `scripts/rag-reasoning.sh`.
You can also compute performance on multi-hop questions with `scripts/evaluate_multi_hop.sh`.
To test different LLMs for KGQA (ChatGPT, LLaMA2), see `scripts/plug-and-play.sh`.
## Results
We append all the results for Table 2: See `results/KGQA-GNN-RAG-RA`. You can look at the actual LLM generations, as well as the KG information retrieved ("input" key) in `predictions.jsonl`.
================================================
FILE: llm/prompts/alpaca.txt
================================================
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:
================================================
FILE: llm/prompts/general_prompt.txt
================================================
{instruction}
{input}
================================================
FILE: llm/prompts/llama2.txt
================================================
[INST] <<SYS>>
<</SYS>>
{instruction}{input} [/INST]
================================================
FILE: llm/prompts/llama2_predict.txt
================================================
[INST] <<SYS>>
<</SYS>>
{instruction}
{input} [/INST]
================================================
FILE: llm/requirements.txt
================================================
openai==0.27.9
transformers==4.32.0
trl==0.7.1
peft==0.5.0
datasets==2.14.4
accelerate==0.22.0
pybind11==2.11.1
networkx==3.1
graph-walker==1.0.6
tqdm
================================================
FILE: llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/args.txt
================================================
{
"data_path": "rmanluo",
"d": "RoG-cwq",
"split": "test",
"predict_path": "results/KGQA-GO",
"model_name": "RoG",
"prompt_path": "prompts/llama2_predict.txt",
"add_rule": true,
"use_true": false,
"cot": false,
"explain": false,
"use_random": false,
"each_line": false,
"rule_path": "results/gen_rule_path/RoG-cwq/RoG/test/predictions_3_False.jsonl",
"force": false,
"n": 1,
"filter_empty": false,
"debug": false,
"encrypt": false,
"model_path": "rmanluo/RoG",
"max_new_tokens": 512,
"dtype": "fp16"
}
================================================
FILE: llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl
================================================
{"id": "WebQTest-832_c334509bb5e02cacae1ba2e80c176499", "prediction": ["2014 World Series", "2010 World Series", "2012 World Series"], "ground_truth": ["2014 World Series"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.5, "precission": 0.3333333333333333, "recall": 1.0}
{"id": "WebQTrn-1259_1997cb4922db71983be26e6a509950f4", "prediction": ["University of Tennessee", "Cornell University", "Shortridge High School", "Butler University", "University of Oxford", "University Yale", "Georgetown University", "University College, Oxford"], "ground_truth": ["Belmont University"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-1384_744a496b907e407b16bc5d7c197dc3f0", "prediction": ["Judaism"], "ground_truth": ["Judaism"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-241_dfb6c97ac9bf2f0ac07f27dd80f9edc2", "prediction": ["Germany"], "ground_truth": ["Germany"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-1077_f4a9e5f1e0dcfb82cbadf4771eda7bb5", "prediction": ["Islam", "Pashtunism", "Afghanism"], "ground_truth": ["Shia Islam", "Sunni Islam"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-576_01e2da60a2779c4ae4b5d1547499a4f8", "prediction": ["Guatemala"], "ground_truth": ["Guatemala"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-590_6aad73acb74f304bc9acae44314164be", "prediction": ["Muammar Gaddafi"], "ground_truth": ["Abdullah al-Thani"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-3251_d8cddfe5e947e414b7735780ef1efff8", "prediction": ["University of Northern Colorado"], "ground_truth": ["University of Northern Colorado"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2069_0fa727f3b282196eb1097410b4be6818", "prediction": ["Spanish Language"], "ground_truth": ["Aymara language", "Mapudungun Language", "Rapa Nui Language", "Spanish Language", "Puquina Language"], "acc": 0.2, "hit": 1, "hit1": 1, "f1": 0.33333333333333337, "precission": 1.0, "recall": 0.2}
{"id": "WebQTrn-1938_7322a2a4d46bf36b95bfab4418c9a32b", "prediction": ["Parliamentary system"], "ground_truth": ["Parliamentary system"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-60_68f0d0ad309d64a4af858a5ef4fb5713", "prediction": ["Angola"], "ground_truth": ["Mozambique"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-100_bf66fd89b6e9fc5fcb96c8b3f7a0e616", "prediction": ["Basque Language", "Catalan language", "Proven\u00e7al Language", "Corsican Language", "Tahitian Language", "Franco-Proven\u00e7al Language", "Breton", "Yeniche Language", "Guianese Creole French Language", "Esperanto Language", "West Flemish", "Gallo language", "Antillean Creole French", "Occitan language", "French", "R\u00e9union Creole French Language", "Alsatian dialect"], "ground_truth": ["Haitian Creole", "French"], "acc": 0.5, "hit": 1, "hit1": 0, "f1": 0.10526315789473684, "precission": 0.058823529411764705, "recall": 0.5}
{"id": "WebQTrn-962_032f61bfcfed69da8b215bb8f058c24e", "prediction": ["Michael Connor Humphreys"], "ground_truth": ["Michael Connor Humphreys"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-849_586aae7703d62aa44eb79759e1563309", "prediction": ["Denmark"], "ground_truth": ["Denmark"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-710_c264a6d11d7956741926d417b94327e2", "prediction": ["2014 World Series"], "ground_truth": ["2014 World Series"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2209_c1374f388d9cc7a78365860c91218362", "prediction": ["2008 NBA Finals", "1969 NBA Finals", "1986 NBA Finals", "1984 NBA Finals", "1981 NBA Finals", "1976 NBA Finals", "1974 NBA Finals", "1968 NBA Finals", "1966 NBA Finals", "1965 NBA Finals", "1964 NBA Finals", "1963 NBA Finals", "1962 NBA Finals", "1961 NBA Finals", "1960 NBA Finals", "1957 NBA Finals"], "ground_truth": ["2008 NBA Finals", "1969 NBA Finals", "1986 NBA Finals", "1984 NBA Finals", "1981 NBA Finals", "1976 NBA Finals", "1974 NBA Finals", "1968 NBA Finals", "1966 NBA Finals", "1965 NBA Finals", "1964 NBA Finals", "1963 NBA Finals", "1962 NBA Finals", "1961 NBA Finals", "1960 NBA Finals", "1959 NBA Finals", "1957 NBA Finals"], "acc": 0.9411764705882353, "hit": 1, "hit1": 1, "f1": 0.9696969696969697, "precission": 1.0, "recall": 0.9411764705882353}
{"id": "WebQTest-1251_cbf2f20f6caf754bc49d672ca7b150b7", "prediction": ["East Germany"], "ground_truth": ["East Germany"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-3100_143c89d70679c3e5257c93d8e2bc4c67", "prediction": ["Greater Antilles", "Pico Duarte"], "ground_truth": ["Greater Antilles"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.6666666666666666, "precission": 0.5, "recall": 1.0}
{"id": "WebQTrn-662_7a992044f94b39edfc37ac5dcfcb3c26", "prediction": ["Manchester United F.C. First team", "Manchester United F.C."], "ground_truth": ["Newton Heath L&YR F.C."], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-743_0a8cdba29cf260283b7c890b3609c0b9", "prediction": ["Robert F. Kennedy"], "ground_truth": ["Robert F. Kennedy"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-1797_68a33792b0a1e18937dcd4b3f50d941e", "prediction": ["Confederate States of America"], "ground_truth": ["Confederate States of America"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-1864_9dc4e22121d3a46d45b8f9bd9e8c7013", "prediction": ["China"], "ground_truth": ["India"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-1841_b8df00139e3fa59b8633ef551ed8ca9f", "prediction": ["Football"], "ground_truth": ["Spain national football team"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-124_f3990dc9aa470fa81ec4cf2912a9924f", "prediction": ["Unbroken"], "ground_truth": ["In the Land of Blood and Honey"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-538_49b4e9304f18a0a1cbe37bb162f61131", "prediction": ["National Lampoon's TV: The Movie", "The Tournament", "Celebrity", "Life as a House", "Anatomy of a Hate Crime", "Changing Hearts", "The Rules of Attraction", "Marco Polo", "How to Make Love to a Woman", "In Enemy Hands", "Pulse", "Lost City Raiders", "Wake", "The Lost Samaritan", "Fireball", "Fearless", "The Sensation of Sight", "The Anomaly", "Time Framed", "Caught on Tape", "Recess", "The Old Man and the Studio"], "ground_truth": ["Hannah Montana: The Movie", "Valentine's Day", "The Lorax", "Jonas Brothers: The Concert Experience", "The Giver"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-2722_8babdaa9ecd05a72e3227b43b1f98771", "prediction": ["2007 World Championships in Athletics \u2013 Men's 100 metres", "2007 World Championships in Athletics \u2013 Men's 200 metres"], "ground_truth": ["Gymnastics at the 2012 Summer Olympics \u2013 Women's artistic team all-around"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-983_ccda690fb745939d0a62c3fbcf3e3769", "prediction": ["University of Missouri"], "ground_truth": ["University of Missouri"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-1770_8db36acba886620a06031d39165d78de", "prediction": ["Paris-Michael Katherine Jackson", "Michael Joseph Jackson, Jr.", "Prince Michael Jackson II"], "ground_truth": ["Blue Ivy"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-1817_89670933168c3f4e5195a241f9d46e76", "prediction": ["Communist state", "Socialist state", "Single-party state"], "ground_truth": ["Communist state", "Socialist state", "Single-party state"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-849_1eb0f1ddd5074471fbe3a7f6b575f202", "prediction": ["Denmark"], "ground_truth": ["Denmark"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2152_92fba37c9723caee68665ad9a5e4a468", "prediction": ["Oakland Athletics", "Texas Rangers"], "ground_truth": ["Texas Rangers"], "acc": 1.0, "hit": 1, "hit1": 0, "f1": 0.6666666666666666, "precission": 0.5, "recall": 1.0}
{"id": "WebQTrn-3376_0619d288bbed0ca782e60c6f841a6051", "prediction": ["Spelman College", "Sarah Lawrence College"], "ground_truth": ["Temple University", "Castlemont High School"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-567_693feb48c0515cd069014e7ca2846b37", "prediction": ["A Beautiful Mind"], "ground_truth": ["The Journey"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-1077_0c34ca057060e35aa5c74fbbca682dee", "prediction": ["Islam"], "ground_truth": ["Shia Islam", "Sunni Islam"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-1686_29e74083744b3631541b29b4094fb273", "prediction": ["Janet Napolitano"], "ground_truth": ["Janet Napolitano"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-626_8c6b952c6bd963f0ece4e401c9eb731a", "prediction": ["Iowa", "Kansas", "Missouri", "Nebraska", "North Dakota", "South Dakota"], "ground_truth": ["Iowa", "Kansas", "Missouri", "Nebraska", "North Dakota", "South Dakota"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-12_68d745a0657c86906382873e57294d6a", "prediction": ["Ted Strickland", "Mike DeWine", "Sherrod Brown"], "ground_truth": ["Return J. Meigs, Jr.", "John Kasich"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-576_906ad6be7bec9d208f4dde4f7721c261", "prediction": ["El Salvador", "El Salvador", "Guatemala", "Honduras", "Panama", "Costa Rica"], "ground_truth": ["Belize", "Costa Rica", "El Salvador", "Guatemala", "Honduras", "Panama"], "acc": 0.8333333333333334, "hit": 1, "hit1": 1, "f1": 0.8333333333333334, "precission": 0.8333333333333334, "recall": 0.8333333333333334}
{"id": "WebQTrn-2152_3cdf60c15a8355981dd92e3c57ac2eed", "prediction": ["Seattle Mariners"], "ground_truth": ["Seattle Mariners"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-61_09020cbb000c86fda5910ec084690246", "prediction": ["Islam"], "ground_truth": ["Islam"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-1864_67ecd1c247c3b2c9545fbcf1ad8d9d00", "prediction": ["China"], "ground_truth": ["India"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-3136_a2debf685e0c50491e35a9cf7a1e9ade", "prediction": ["Northern Ireland"], "ground_truth": ["Northern Ireland"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-998_f693ae2cb9aa6b8ffdf0f103a6777b62", "prediction": ["1973 NBA Finals"], "ground_truth": ["1973 NBA Finals"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-1023_1e7110e48c30a2cef3caf291e3b8d394", "prediction": ["Argentine peso"], "ground_truth": ["Argentine peso"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-810_c334509bb5e02cacae1ba2e80c176499", "prediction": ["2014 World Series", "2010 World Series", "2012 World Series"], "ground_truth": ["2014 World Series"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.5, "precission": 0.3333333333333333, "recall": 1.0}
{"id": "WebQTrn-557_960c16ffdb29e173df0577fc76c7455d", "prediction": ["Judy Garland"], "ground_truth": ["Judy Garland"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-21_660138373d19bbffdd3d3f7a30234e4a", "prediction": ["Hailemariam Desalegn"], "ground_truth": ["Hailemariam Desalegn"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2286_e0906c845ea8b2e22e08e1b0e6eb9b43", "prediction": ["United States of America", "United States, with Territories"], "ground_truth": ["Contiguous United States", "United States of America", "United States, with Territories"], "acc": 0.6666666666666666, "hit": 1, "hit1": 1, "f1": 0.8, "precission": 1.0, "recall": 0.6666666666666666}
{"id": "WebQTest-1320_c5498ca807d2e1ec30d4c8fdd41f0bf7", "prediction": ["Miller Park", "Milwaukee"], "ground_truth": ["Miller Park"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.6666666666666666, "precission": 0.5, "recall": 1.0}
{"id": "WebQTrn-2444_8f2cd432b509e5b8fe681bb55bca2767", "prediction": ["Sportsman's Park"], "ground_truth": ["Busch Stadium"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-100_524908899a8aa334a18a0ac00f8f2fe6", "prediction": ["Basque Language", "Catalan language", "Proven\u00e7al Language", "Corsican Language", "Tahitian Language", "Franco-Proven\u00e7al Language", "Breton", "Yeniche Language", "Guianese Creole French Language", "Esperanto Language", "West Flemish", "Gallo language", "Antillean Creole French", "Occitan language", "French", "R\u00e9union Creole French Language", "Alsatian dialect"], "ground_truth": ["Haitian Creole", "French"], "acc": 0.5, "hit": 1, "hit1": 0, "f1": 0.10526315789473684, "precission": 0.058823529411764705, "recall": 0.5}
{"id": "WebQTrn-568_d54918e8e89ad97237bce821087a9818", "prediction": ["Mecklenburg County"], "ground_truth": ["Mecklenburg County"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-361_e24533e28da40db99eb4b25773f9d38f", "prediction": ["Albertina", "Museum of Military History, Vienna"], "ground_truth": ["Kunsthistorisches Museum"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-2215_11c4cd5a25fd84f3980d7013c0329bad", "prediction": ["Glen Dale"], "ground_truth": ["Glen Dale"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2834_6a27420dcf0528ae017dd74e40cfd38a", "prediction": ["Islam"], "ground_truth": ["Islam"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-634_000f72b60ccaedff5f056f522ee06e98", "prediction": ["Film Score Composer", "Composer"], "ground_truth": ["Professor", "Monk", "Theologian", "Physician", "Priest", "Writer"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-857_9392f3f06e288ee4e3437a74f6bf5a37", "prediction": ["Al Gore"], "ground_truth": ["Al Gore"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-810_a188aff4a054e1ec66fafba1b8021f67", "prediction": ["2014 World Series", "2010 World Series", "2012 World Series"], "ground_truth": ["2014 World Series"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.5, "precission": 0.3333333333333333, "recall": 1.0}
{"id": "WebQTrn-2026_d059b24adec4064377b957ca598769be", "prediction": ["Dominican peso"], "ground_truth": ["Dominican peso"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-60_990f6babd500d25e3746174e6da58c84", "prediction": ["Angola", "Guyana"], "ground_truth": ["Angola"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.6666666666666666, "precission": 0.5, "recall": 1.0}
{"id": "WebQTrn-2570_374d1789f1735b6f08e1a829c0d075a2", "prediction": ["Franklin D. Roosevelt"], "ground_truth": ["Harry S. Truman"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-484_ac67410188d0f2258139a3c84773885e", "prediction": ["English Language", "Esperanto Language", "Lojban"], "ground_truth": ["English Language", "Esperanto Language", "Lojban"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-1785_1c178196ae53ffd4b09e5787a35c3950", "prediction": ["James at 15"], "ground_truth": ["Nanny and the Professor"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-3744_26401b4c33bcb760afd734acaa0a1869", "prediction": ["University of Wisconsin-Madison"], "ground_truth": ["University of Wisconsin-Madison"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-810_e3d40457273785e46c5b71732713a5f4", "prediction": ["2014 World Series", "2010 World Series", "2012 World Series"], "ground_truth": ["2014 World Series"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.5, "precission": 0.3333333333333333, "recall": 1.0}
{"id": "WebQTest-538_92e606ef9c0429ad6820797ad2950730", "prediction": ["The Lorax"], "ground_truth": ["The Lorax"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2189_f4440609f5cecb091bf8e86adb47be25", "prediction": ["Modern Standard Arabic"], "ground_truth": ["Modern Standard Arabic"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-372_a806813629e86776ae3bfe26a3e000f8", "prediction": ["Holy Spirit", "Jesus Christ", "The Father", "The Son", "God"], "ground_truth": ["Holy Spirit", "Jesus Christ", "The Father", "God"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.888888888888889, "precission": 0.8, "recall": 1.0}
{"id": "WebQTrn-3744_b9bd90569bb7912ec3ea180bf164663c", "prediction": ["University of Wisconsin-Madison"], "ground_truth": ["University of Wisconsin-Madison"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-712_6099ea03f4fd2476605c4a513318d29c", "prediction": ["New Zealand", "New Zealand", "New Zealand", "New Zealand", "New Zealand", "New Zealand"], "ground_truth": ["New Zealand", "New Zealand", "New Zealand", "New Zealand", "New Zealand"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.9090909090909091, "precission": 0.8333333333333334, "recall": 1.0}
{"id": "WebQTest-100_de15ac1f762e3ec1e1261f6d9c81ebf9", "prediction": ["Haitian Creole", "French"], "ground_truth": ["Haitian Creole", "French"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-1557_edfd3507d32929ce0e9d844f7a2674de", "prediction": ["McGill University"], "ground_truth": ["McGill University"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-60_39ba4faa87698cb0767d1a5ee7ce1827", "prediction": ["Portugal"], "ground_truth": ["South Africa"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-846_a29552911617e890ca2e1d6564e0990e", "prediction": ["Belleville"], "ground_truth": ["Belleville"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-121_333b95bd45af3e6e31e328bc8c24d84f", "prediction": ["Manchester City F.C."], "ground_truth": ["LA Galaxy"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-626_47b990334f91b6d3bd042f82b81740f6", "prediction": ["South Dakota"], "ground_truth": ["South Dakota"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2444_e5059ff268415917df330817b9c8ef8c", "prediction": ["Roger Dean Stadium"], "ground_truth": ["Busch Stadium"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-2664_471b83eade9707a4dba68e201bf29d73", "prediction": ["Madagascar"], "ground_truth": ["Sierra Leone"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-576_8dfecb6548586cf236340abadabeb86c", "prediction": ["Honduras", "Honduras"], "ground_truth": ["El Salvador", "Panama"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-124_028bb5f442b37a4af9f9fd9fa0bc5e9a", "prediction": ["By the Sea", "In the Land of Blood and Honey", "A Place in Time", "Unbroken"], "ground_truth": ["By the Sea", "In the Land of Blood and Honey", "A Place in Time", "Unbroken"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2871_d7d303efc1f901f14e6aae2bb469743c", "prediction": ["Arizona", "Phoenix", "Lake Powell"], "ground_truth": ["Lake Powell", "Phoenix"], "acc": 1.0, "hit": 1, "hit1": 0, "f1": 0.8, "precission": 0.6666666666666666, "recall": 1.0}
{"id": "WebQTrn-750_bb35e9b8023fbf9c05df55b4245a4775", "prediction": ["1988 World Series", "1981 World Series", "1965 World Series", "1959 World Series", "1963 World Series", "1964 World Series", "1956 World Series", "1955 World Series", "1953 World Series", "1952 World Series", "1951 World Series", "1950 World Series", "1949 World Series", "1947 World Series", "1946 World Series", "1945 World Series", "1943 World Series", "1942 World Series", "1941 World Series", "1939 World Series", "1938 World Series", "1937 World Series", "1936 World Series", "1935 World Series", "1934 World Series", "1933 World Series", "1932 World Series", "1931 World Series", "1930 World Series", "1929 World Series", "1928 World Series", "1927 World Series", "1925 World Series", "1923 World Series", "1922 World Series", "1921 World Series", "1916 World Series", "1915 World Series", "1913 World Series", "1912 World Series", "1911 World Series", "1910 World Series", "1909 World Series", "1908 World Series", "1907 World Series", "1905 World Series", "1903 World Series", "1902 World Series", "1901 World Series", "1883 World Series", "1882 World Series", "1881 World Series", "1879 World Series", "1877 World Series", "1876 World Series", "1875 World Series", "1874 World Series", "1873 World Series", "1872 World Series", "1871 World Series", "1870 World Series", "1868 World Series", "1867 World Series", "1866 World Series", "1865 World Series", "1864 World Series", "1863 World Series", "1862 World Series", "1861 World Series", "1860 World Series", "1859 World Series", "1857 World Series", "1856 World Series"], "ground_truth": ["1981 World Series", "1988 World Series", "1965 World Series", "1963 World Series", "1959 World Series"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.1282051282051282, "precission": 0.0684931506849315, "recall": 1.0}
{"id": "WebQTest-96_11da03aa9cec8b011619c8ea0dbfdcf9", "prediction": ["Dallas"], "ground_truth": ["New York City"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-1528_25853c768670cd164d7793f094ba7cbb", "prediction": ["Percy Jackson & the Olympians: The Lightning Thief"], "ground_truth": ["Percy Jackson & the Olympians: The Lightning Thief", "Percy Jackson: Sea of Monsters"], "acc": 0.5, "hit": 1, "hit1": 1, "f1": 0.6666666666666666, "precission": 1.0, "recall": 0.5}
{"id": "WebQTest-1260_5dd0eeca79ae03b7711252c032849eb2", "prediction": ["2001 AFC Championship Game", "2013 AFC Championship Game", "Super Bowl XLVII", "Super Bowl XXXV"], "ground_truth": ["Super Bowl XXXV"], "acc": 1.0, "hit"
gitextract_o4kycd_o/
├── .gitignore
├── README.md
├── gnn/
│ ├── .gitignore
│ ├── README.md
│ ├── dataset_load.py
│ ├── dataset_load_graft.py
│ ├── evaluate.py
│ ├── main.py
│ ├── models/
│ │ ├── GraftNet/
│ │ │ └── graftnet.py
│ │ ├── NSM/
│ │ │ └── nsm.py
│ │ ├── ReaRev/
│ │ │ └── rearev.py
│ │ └── base_model.py
│ ├── modules/
│ │ ├── kg_reasoning/
│ │ │ ├── base_gnn.py
│ │ │ ├── graft_gnn.py
│ │ │ ├── nsm_gnn.py
│ │ │ └── reasongnn.py
│ │ ├── layer_init.py
│ │ ├── query_update.py
│ │ └── question_encoding/
│ │ ├── base_encoder.py
│ │ ├── bert_encoder.py
│ │ ├── lstm_encoder.py
│ │ └── tokenizers.py
│ ├── parsing.py
│ ├── requirements.txt
│ ├── scripts/
│ │ └── rearev_cwq.sh
│ ├── train_model.py
│ └── utils.py
└── llm/
├── .gitignore
├── README.md
├── prompts/
│ ├── alpaca.txt
│ ├── general_prompt.txt
│ ├── llama2.txt
│ └── llama2_predict.txt
├── requirements.txt
├── results/
│ ├── KGQA-GNN-RAG/
│ │ ├── rearev-lmsr/
│ │ │ ├── RoG-cwq/
│ │ │ │ └── RoG/
│ │ │ │ └── test/
│ │ │ │ └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
│ │ │ │ └── False/
│ │ │ │ ├── args.txt
│ │ │ │ ├── detailed_eval_result.jsonl
│ │ │ │ ├── eval_result.txt
│ │ │ │ └── predictions.jsonl
│ │ │ └── RoG-webqsp/
│ │ │ ├── RoG/
│ │ │ │ └── test/
│ │ │ │ └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
│ │ │ │ └── False/
│ │ │ │ ├── args.txt
│ │ │ │ ├── detailed_eval_result.jsonl
│ │ │ │ ├── eval_result.txt
│ │ │ │ └── predictions.jsonl
│ │ │ └── llama2-chat-hf/
│ │ │ └── test/
│ │ │ └── no_rule/
│ │ │ └── False/
│ │ │ ├── args.txt
│ │ │ ├── detailed_eval_result.jsonl
│ │ │ ├── eval_result.txt
│ │ │ └── predictions.jsonl
│ │ └── rearev-sbert/
│ │ ├── RoG-cwq/
│ │ │ └── RoG/
│ │ │ └── test/
│ │ │ └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
│ │ │ └── False/
│ │ │ ├── args.txt
│ │ │ ├── detailed_eval_result.jsonl
│ │ │ ├── eval_result.txt
│ │ │ └── predictions.jsonl
│ │ └── RoG-webqsp/
│ │ └── RoG/
│ │ └── test/
│ │ └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
│ │ └── False/
│ │ ├── args.txt
│ │ ├── detailed_eval_result.jsonl
│ │ ├── eval_result.txt
│ │ └── predictions.jsonl
│ ├── KGQA-GNN-RAG-RA/
│ │ ├── rearev-lmsr/
│ │ │ ├── RoG-cwq/
│ │ │ │ └── RoG/
│ │ │ │ └── test/
│ │ │ │ └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
│ │ │ │ └── False/
│ │ │ │ ├── args.txt
│ │ │ │ ├── detailed_eval_result.jsonl
│ │ │ │ ├── eval_result.txt
│ │ │ │ └── predictions.jsonl
│ │ │ └── RoG-webqsp/
│ │ │ └── RoG/
│ │ │ └── test/
│ │ │ └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
│ │ │ └── False/
│ │ │ ├── args.txt
│ │ │ ├── detailed_eval_result.jsonl
│ │ │ ├── eval_result.txt
│ │ │ └── predictions.jsonl
│ │ └── rearev-sbert/
│ │ ├── RoG-cwq/
│ │ │ └── RoG/
│ │ │ └── test/
│ │ │ └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
│ │ │ └── False/
│ │ │ ├── args.txt
│ │ │ ├── detailed_eval_result.jsonl
│ │ │ ├── eval_result.txt
│ │ │ └── predictions.jsonl
│ │ └── RoG-webqsp/
│ │ └── RoG/
│ │ └── test/
│ │ └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
│ │ └── False/
│ │ ├── args.txt
│ │ ├── detailed_eval_result.jsonl
│ │ ├── eval_result.txt
│ │ └── predictions.jsonl
│ ├── gen_rule_path/
│ │ ├── RoG-cwq/
│ │ │ └── RoG/
│ │ │ └── test/
│ │ │ ├── predictions_2_False.jsonl
│ │ │ └── predictions_3_False.jsonl
│ │ └── RoG-webqsp/
│ │ └── RoG/
│ │ ├── test/
│ │ │ ├── predictions_1_False.jsonl
│ │ │ ├── predictions_2_False.jsonl
│ │ │ └── predictions_3_False.jsonl
│ │ ├── train/
│ │ │ ├── predictions_1_False.jsonl
│ │ │ └── predictions_3_False.jsonl
│ │ └── validation/
│ │ └── predictions_3_False.jsonl
│ └── gnn/
│ ├── RoG-cwq/
│ │ ├── rearev-lmsr/
│ │ │ └── test.info
│ │ └── rearev-sbert/
│ │ └── test.info
│ └── RoG-webqsp/
│ ├── rearev-lmsr/
│ │ └── test.info
│ └── rearev-sbert/
│ └── test.info
├── scripts/
│ ├── evaluate_multi_hop.sh
│ ├── interpretable_example.py
│ ├── planning.sh
│ ├── plug-and-play.sh
│ ├── rag-reasoning.sh
│ └── train.sh
└── src/
├── __init__.py
├── align_kg/
│ ├── __init__.py
│ ├── build_align_qa_dataset.py
│ └── data_loader.py
├── joint_training/
│ ├── generate_explanation_results.py
│ ├── joint_finetuning.py
│ ├── preprocess_align.py
│ └── preprocess_qa.py
├── llms/
│ ├── __init__.py
│ ├── language_models/
│ │ ├── __init__.py
│ │ ├── alpaca.py
│ │ ├── base_language_model.py
│ │ ├── chatgpt.py
│ │ ├── flan_t5.py
│ │ ├── llama.py
│ │ └── longchat/
│ │ ├── llama_condense_monkey_patch.py
│ │ ├── llama_flash_attn_monkey_patch.py
│ │ └── longchat.py
│ ├── llm_proxy.py
│ └── start_fastchat_api.py
├── qa_prediction/
│ ├── build_qa_input.py
│ ├── evaluate_multi_hop.py
│ ├── evaluate_results.py
│ ├── gen_rule_path.py
│ └── predict_answer.py
└── utils/
├── __init__.py
├── graph_utils.py
├── merge_peft.py
├── training_utils.py
└── utils.py
SYMBOL INDEX (298 symbols across 47 files)
FILE: gnn/dataset_load.py
class BasicDataLoader (line 18) | class BasicDataLoader(object):
method __init__ (line 24) | def __init__(self, config, word2id, relation2id, entity2id, tokenize, ...
method _load_file (line 31) | def _load_file(self, config, data_type="train"):
method _load_data (line 62) | def _load_data(self):
method _parse_args (line 88) | def _parse_args(self, config, word2id, relation2id, entity2id):
method get_quest (line 130) | def get_quest(self, training=False):
method decode_text (line 143) | def decode_text(self, np_array_x):
method _prepare_data (line 159) | def _prepare_data(self):
method build_rel_words (line 354) | def build_rel_words(self, tokenize):
method create_kb_adj_mats (line 432) | def create_kb_adj_mats(self, sample_id):
method _build_fact_mat (line 473) | def _build_fact_mat(self, sample_ids, fact_dropout):
method reset_batches (line 530) | def reset_batches(self, is_sequential=True):
method _build_global2local_entity_maps (line 536) | def _build_global2local_entity_maps(self):
method _add_entity_to_map (line 562) | def _add_entity_to_map(entity2id, entities, g2l):
method deal_q_type (line 577) | def deal_q_type(self, q_type=None):
class SingleDataLoader (line 592) | class SingleDataLoader(BasicDataLoader):
method __init__ (line 596) | def __init__(self, config, word2id, relation2id, entity2id, tokenize, ...
method get_batch (line 599) | def get_batch(self, iteration, batch_size, fact_dropout, q_type=None, ...
function load_dict (line 632) | def load_dict(filename):
function load_dict_int (line 640) | def load_dict_int(filename):
function load_data (line 648) | def load_data(config, tokenize):
FILE: gnn/dataset_load_graft.py
class GraftBasicDataLoader (line 19) | class GraftBasicDataLoader(BasicDataLoader):
method __init__ (line 24) | def __init__(self, config, word2id, relation2id, entity2id, tokenize, ...
method create_kb_adj_mats_facts (line 27) | def create_kb_adj_mats_facts(self, sample_id):
method _build_fact_mat_maxfacts (line 70) | def _build_fact_mat_maxfacts(self, sample_ids, fact_dropout):
class GraftSingleDataLoader (line 106) | class GraftSingleDataLoader(GraftBasicDataLoader):
method __init__ (line 110) | def __init__(self, config, word2id, relation2id, entity2id, tokenize, ...
method get_batch (line 113) | def get_batch(self, iteration, batch_size, fact_dropout, q_type=None, ...
function load_dict (line 152) | def load_dict(filename):
function load_dict_int (line 160) | def load_dict_int(filename):
function load_data_graft (line 168) | def load_data_graft(config, tokenize):
FILE: gnn/evaluate.py
function cal_accuracy (line 10) | def cal_accuracy(pred, answer_dist):
function f1_and_hits (line 25) | def f1_and_hits(answers, candidate2prob, id2entity, entity2name, eps=0.5):
class Evaluator (line 70) | class Evaluator:
method __init__ (line 72) | def __init__(self, args, model, entity2id, relation2id, device):
method write_info (line 106) | def write_info(self, valid_data, tp_list, num_step):
method evaluate (line 140) | def evaluate(self, valid_data, test_batch_size=20, write_info=False):
FILE: gnn/main.py
function main (line 29) | def main():
FILE: gnn/models/GraftNet/graftnet.py
class GraftNet (line 21) | class GraftNet(BaseModel):
method __init__ (line 22) | def __init__(self, args, num_entity, num_relation, num_word):
method layers (line 37) | def layers(self, args):
method private_module_def (line 61) | def private_module_def(self, args, num_entity, num_relation):
method get_ent_init (line 74) | def get_ent_init(self, local_entity, kb_adj_mat, rel_features):
method get_rel_feature (line 85) | def get_rel_feature(self):
method init_reason (line 105) | def init_reason(self, curr_dist, local_entity, kb_adj_mat, kb_adj_mat_...
method calc_loss_label (line 128) | def calc_loss_label(self, curr_dist, teacher_dist, label_valid):
method forward (line 135) | def forward(self, batch, training=False):
FILE: gnn/models/NSM/nsm.py
class NSM (line 19) | class NSM(BaseModel):
method __init__ (line 20) | def __init__(self, args, num_entity, num_relation, num_word):
method layers (line 37) | def layers(self, args):
method private_module_def (line 69) | def private_module_def(self, args, num_entity, num_relation):
method get_ent_init (line 85) | def get_ent_init(self, local_entity, kb_adj_mat, rel_features):
method get_rel_feature (line 97) | def get_rel_feature(self):
method init_reason (line 114) | def init_reason(self, curr_dist, local_entity, kb_adj_mat, q_input):
method get_js_div (line 142) | def get_js_div(self, dist_1, dist_2):
method calc_loss_backward (line 151) | def calc_loss_backward(self, case_valid):
method calc_loss_label (line 172) | def calc_loss_label(self, curr_dist, teacher_dist, label_valid):
method forward (line 179) | def forward(self, batch, training=False):
FILE: gnn/models/ReaRev/rearev.py
class ReaRev (line 19) | class ReaRev(BaseModel):
method __init__ (line 20) | def __init__(self, args, num_entity, num_relation, num_word):
method layers (line 51) | def layers(self, args):
method get_ent_init (line 79) | def get_ent_init(self, local_entity, kb_adj_mat, rel_features):
method get_rel_feature (line 91) | def get_rel_feature(self):
method private_module_def (line 114) | def private_module_def(self, args, num_entity, num_relation):
method init_reason (line 132) | def init_reason(self, curr_dist, local_entity, kb_adj_mat, q_input, qu...
method calc_loss_label (line 156) | def calc_loss_label(self, curr_dist, teacher_dist, label_valid):
method forward (line 163) | def forward(self, batch, training=False):
FILE: gnn/models/base_model.py
class BaseModel (line 9) | class BaseModel(torch.nn.Module):
method __init__ (line 14) | def __init__(self, args, num_entity, num_relation, num_word):
method embedding_def (line 70) | def embedding_def(self):
method load_relation_file (line 153) | def load_relation_file(self, filename):
method use_rel_texts (line 164) | def use_rel_texts(self, rel_texts, rel_texts_inv):
method encode_rel_texts (line 168) | def encode_rel_texts(self, rel_texts, rel_texts_inv):
method init_hidden (line 178) | def init_hidden(self, num_layer, batch_size, hidden_size):
method encode_question (line 181) | def encode_question(self, q_input):
method get_instruction (line 184) | def get_instruction(self, query_hidden_emb, query_mask, states):
method get_loss_bce (line 187) | def get_loss_bce(self, pred_dist_score, answer_dist):
method get_loss_kl (line 193) | def get_loss_kl(self, pred_dist, answer_dist):
method get_loss (line 201) | def get_loss(self, pred_dist, answer_dist, reduction='mean'):
method f1_and_hits (line 217) | def f1_and_hits(self, answers, candidate2prob, eps=0.5):
method calc_f1_new (line 249) | def calc_f1_new(self, curr_dist, dist_ans, h1_vec):
method calc_h1 (line 287) | def calc_h1(self, curr_dist, dist_ans, eps=0.01):
method get_eval_metric (line 294) | def get_eval_metric(self, pred_dist, answer_dist):
FILE: gnn/modules/kg_reasoning/base_gnn.py
class BaseGNNLayer (line 7) | class BaseGNNLayer(torch.nn.Module):
method __init__ (line 11) | def __init__(self, args, num_entity, num_relation):
method build_matrix (line 19) | def build_matrix(self):
method _build_sparse_tensor (line 53) | def _build_sparse_tensor(self, indices, values, size):
method build_adj_facts (line 56) | def build_adj_facts(self):
FILE: gnn/modules/kg_reasoning/graft_gnn.py
class GraftLayer (line 14) | class GraftLayer(BaseGNNLayer):
method __init__ (line 15) | def __init__(self, args, num_entity, num_relation, entity_dim):
method init_layers (line 27) | def init_layers(self, args):
method init_reason (line 45) | def init_reason(self, local_entity, kb_adj_mat, kb_adj_mat_graft, kb_f...
method compute_attention (line 64) | def compute_attention(self, query_hidden_emb, query_mask):
method reason_layer (line 89) | def reason_layer(self, curr_dist, kb_self_linear, kb_head_linear, kb_t...
method forward (line 111) | def forward(self, current_dist, query_hidden_emb, query_mask, step=0, ...
FILE: gnn/modules/kg_reasoning/nsm_gnn.py
class NSMBaseLayer (line 14) | class NSMBaseLayer(BaseGNNLayer):
method __init__ (line 15) | def __init__(self, args, num_entity, num_relation, entity_dim):
method init_layers (line 24) | def init_layers(self, args):
method init_reason (line 38) | def init_reason(self, local_entity, kb_adj_mat, local_entity_emb, rel_...
method reason_layer (line 51) | def reason_layer(self, curr_dist, instruction, rel_linear):
method forward (line 54) | def forward(self, current_dist, relational_ins, step=0, return_score=F...
class NSMLayer (line 83) | class NSMLayer(NSMBaseLayer):
method __init__ (line 84) | def __init__(self, args, num_entity, num_relation, entity_dim):
method reason_layer (line 87) | def reason_layer(self, curr_dist, instruction, rel_linear):
class NSMLayer_back (line 114) | class NSMLayer_back(NSMBaseLayer):
method __init__ (line 115) | def __init__(self, args, num_entity, num_relation, entity_dim):
method reason_layer (line 118) | def reason_layer(self, curr_dist, instruction, rel_linear):
FILE: gnn/modules/kg_reasoning/reasongnn.py
class ReasonGNNLayer (line 11) | class ReasonGNNLayer(BaseGNNLayer):
method __init__ (line 15) | def __init__(self, args, num_entity, num_relation, entity_dim, alg):
method init_layers (line 27) | def init_layers(self, args):
method init_reason (line 46) | def init_reason(self, local_entity, kb_adj_mat, local_entity_emb, rel_...
method reason_layer (line 61) | def reason_layer(self, curr_dist, instruction, rel_linear, pos_emb):
method reason_layer_inv (line 91) | def reason_layer_inv(self, curr_dist, instruction, rel_linear, pos_emb...
method combine (line 118) | def combine(self,emb):
method forward (line 134) | def forward(self, current_dist, relational_ins, step=0, return_score=F...
FILE: gnn/modules/layer_init.py
class TypeLayer (line 9) | class TypeLayer(nn.Module):
method __init__ (line 14) | def __init__(self, in_features, out_features, linear_drop, device, nor...
method forward (line 25) | def forward(self, local_entity, edge_list, rel_features):
method _build_sparse_tensor (line 64) | def _build_sparse_tensor(self, indices, values, size):
FILE: gnn/modules/query_update.py
class Fusion (line 6) | class Fusion(nn.Module):
method __init__ (line 8) | def __init__(self, d_hid):
method forward (line 13) | def forward(self, x, y):
class QueryReform (line 18) | class QueryReform(nn.Module):
method __init__ (line 20) | def __init__(self, h_dim):
method forward (line 26) | def forward(self, q_node, ent_emb, seed_info, ent_mask):
class AttnEncoder (line 46) | class AttnEncoder(nn.Module):
method __init__ (line 48) | def __init__(self, d_hid):
method forward (line 52) | def forward(self, x, x_mask):
class Attention (line 63) | class Attention(nn.Module):
method __init__ (line 89) | def __init__(self, dimensions, attention_type='general'):
method forward (line 103) | def forward(self, query, context):
FILE: gnn/modules/question_encoding/base_encoder.py
class BaseInstruction (line 8) | class BaseInstruction(torch.nn.Module):
method __init__ (line 10) | def __init__(self, args, constraint):
method _parse_args (line 16) | def _parse_args(self, args):
method share_module_def (line 48) | def share_module_def(self):
method init_hidden (line 53) | def init_hidden(self, num_layer, batch_size, hidden_size):
method encode_question (line 57) | def encode_question(self, *args):
method get_node_emb (line 62) | def get_node_emb(query_hidden_emb, action):
method init_reason (line 74) | def init_reason(self, query_text):
method get_instruction (line 82) | def get_instruction(self, relational_ins, step=0, query_node_emb=None):
method forward (line 105) | def forward(self, query_text, lm=None):
FILE: gnn/modules/question_encoding/bert_encoder.py
class BERTInstruction (line 18) | class BERTInstruction(BaseInstruction):
method __init__ (line 20) | def __init__(self, args, word_embedding, num_word, model, constraint=F...
method encoder_def (line 74) | def encoder_def(self):
method encode_question (line 89) | def encode_question(self, query_text, store=True):
FILE: gnn/modules/question_encoding/lstm_encoder.py
class LSTMInstruction (line 10) | class LSTMInstruction(BaseInstruction):
method __init__ (line 12) | def __init__(self, args, word_embedding, num_word):
method encoder_def (line 25) | def encoder_def(self):
method encode_question (line 32) | def encode_question(self, query_text, store=True):
FILE: gnn/modules/question_encoding/tokenizers.py
class LSTMTokenizer (line 5) | class LSTMTokenizer():
method __init__ (line 6) | def __init__(self, word2id, max_query_word):
method tokenize (line 11) | def tokenize(self, question):
method tokenize_sent (line 28) | def tokenize_sent(question_text):
class BERTTokenizer (line 41) | class BERTTokenizer():
method __init__ (line 42) | def __init__(self, max_query_word):
method tokenize (line 50) | def tokenize(self, question):
FILE: gnn/parsing.py
function bool_flag (line 5) | def bool_flag(v):
function add_shared_args (line 13) | def add_shared_args(parser):
function add_parse_args (line 68) | def add_parse_args(parser):
function create_parser_rearev (line 85) | def create_parser_rearev(parser):
function create_parser_nsm (line 101) | def create_parser_nsm(parser):
function create_parser_graftnet (line 115) | def create_parser_graftnet(parser):
FILE: gnn/train_model.py
class Trainer_KBQA (line 24) | class Trainer_KBQA(object):
method __init__ (line 25) | def __init__(self, args, model_name, logger=None):
method optim_def (line 89) | def optim_def(self):
method load_data (line 96) | def load_data(self, args, tokenize):
method load_pretrain (line 113) | def load_pretrain(self):
method evaluate (line 120) | def evaluate(self, data, test_batch_size=20, write_info=False):
method train (line 123) | def train(self, start_epoch, end_epoch):
method evaluate_best (line 182) | def evaluate_best(self):
method evaluate_single (line 201) | def evaluate_single(self, filename):
method train_epoch (line 209) | def train_epoch(self):
method save_ckpt (line 236) | def save_ckpt(self, reason="h1"):
method load_ckpt (line 246) | def load_ckpt(self, filename):
FILE: gnn/utils.py
function create_logger (line 5) | def create_logger(args):
function get_dict (line 29) | def get_dict(data_folder, filename):
FILE: llm/src/align_kg/build_align_qa_dataset.py
function build_data (line 13) | def build_data(args):
function process_data (line 35) | def process_data(data, remove_duplicate=False):
FILE: llm/src/align_kg/data_loader.py
function load_new_tokens (line 10) | def load_new_tokens(default_new_tokens, rel_dict_path):
function load_multiple_datasets (line 21) | def load_multiple_datasets(data_path_list, shuffle=False):
function get_test_dataset (line 41) | def get_test_dataset(dataset):
FILE: llm/src/joint_training/generate_explanation_results.py
function formatting_prompts_func (line 106) | def formatting_prompts_func(example):
FILE: llm/src/joint_training/joint_finetuning.py
class ScriptArguments (line 37) | class ScriptArguments:
class ScriptTrainingArguments (line 72) | class ScriptTrainingArguments(TrainingArguments):
function train (line 84) | def train():
FILE: llm/src/joint_training/preprocess_align.py
function formatting_prompts_func (line 29) | def formatting_prompts_func(example):
FILE: llm/src/joint_training/preprocess_qa.py
function formatting_prompts_func (line 36) | def formatting_prompts_func(example):
FILE: llm/src/llms/language_models/__init__.py
function get_registed_model (line 18) | def get_registed_model(model_name) -> BaseLanguageModel:
FILE: llm/src/llms/language_models/alpaca.py
class Alpaca (line 5) | class Alpaca(BaseLanguageModel):
method add_args (line 8) | def add_args(parser):
method __init__ (line 13) | def __init__(self, args):
method load_model (line 17) | def load_model(self, **kwargs):
method tokenize (line 20) | def tokenize(self, text):
method prepare_for_inference (line 23) | def prepare_for_inference(self, **model_kwargs):
method generate_sentence (line 28) | def generate_sentence(self, llm_input):
FILE: llm/src/llms/language_models/base_language_model.py
class BaseLanguageModel (line 4) | class BaseLanguageModel(object):
method add_args (line 12) | def add_args(parser):
method __init__ (line 15) | def __init__(self, args):
method load_model (line 18) | def load_model(self, **kwargs):
method prepare_for_inference (line 21) | def prepare_for_inference(self, **model_kwargs):
method tokenize (line 24) | def tokenize(self, text):
method generate_sentence (line 33) | def generate_sentence(self, lm_input):
FILE: llm/src/llms/language_models/chatgpt.py
function get_token_limit (line 13) | def get_token_limit(model='gpt-4'):
class ChatGPT (line 25) | class ChatGPT(BaseLanguageModel):
method add_args (line 28) | def add_args(parser):
method __init__ (line 31) | def __init__(self, args):
method tokenize (line 37) | def tokenize(self, text):
method prepare_for_inference (line 46) | def prepare_for_inference(self, model_kwargs={}):
method generate_sentence (line 52) | def generate_sentence(self, llm_input):
FILE: llm/src/llms/language_models/flan_t5.py
class FlanT5 (line 5) | class FlanT5(BaseLanguageModel):
method add_args (line 8) | def add_args(parser):
method __init__ (line 13) | def __init__(self, args):
method load_model (line 17) | def load_model(self, **kwargs):
method tokenize (line 21) | def tokenize(self, text):
method prepare_for_inference (line 24) | def prepare_for_inference(self, **model_kwargs):
method generate_sentence (line 30) | def generate_sentence(self, llm_input):
FILE: llm/src/llms/language_models/llama.py
class Llama (line 6) | class Llama(BaseLanguageModel):
method add_args (line 9) | def add_args(parser):
method __init__ (line 15) | def __init__(self, args):
method load_model (line 19) | def load_model(self, **kwargs):
method tokenize (line 23) | def tokenize(self, text):
method prepare_for_inference (line 26) | def prepare_for_inference(self, **model_kwargs):
method generate_sentence (line 34) | def generate_sentence(self, llm_input):
FILE: llm/src/llms/language_models/longchat/llama_condense_monkey_patch.py
function rank0_print (line 10) | def rank0_print(*args):
class CondenseRotaryEmbedding (line 18) | class CondenseRotaryEmbedding(torch.nn.Module):
method __init__ (line 19) | def __init__(self, dim, ratio, max_position_embeddings=2048, base=1000...
method forward (line 37) | def forward(self, x, seq_len=None):
function replace_llama_with_condense (line 53) | def replace_llama_with_condense(ratio):
FILE: llm/src/llms/language_models/longchat/llama_flash_attn_monkey_patch.py
function forward (line 14) | def forward(
function _prepare_decoder_attention_mask (line 84) | def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
function replace_llama_attn_with_flash_attn (line 90) | def replace_llama_attn_with_flash_attn():
FILE: llm/src/llms/language_models/longchat/longchat.py
function maybe_monkey_patch (line 5) | def maybe_monkey_patch(args):
class Longchat (line 14) | class Longchat(BaseLanguageModel):
method add_args (line 18) | def add_args(parser):
method __init__ (line 25) | def __init__(self, args):
method load_model (line 30) | def load_model(self, **kwargs):
method tokenize (line 35) | def tokenize(self, text):
method prepare_for_inference (line 38) | def prepare_for_inference(self, **model_kwargs):
method generate_sentence (line 44) | def generate_sentence(self, llm_input):
FILE: llm/src/llms/llm_proxy.py
class LLMProxy (line 7) | class LLMProxy(object):
method regist_args (line 10) | def regist_args(parser):
method __init__ (line 18) | def __init__(self, args) -> None:
method query (line 33) | def query(message, model_name, timeout=60, max_retry=3):
FILE: llm/src/llms/start_fastchat_api.py
function terminate_process (line 8) | def terminate_process():
function start_fastchat_api (line 18) | def start_fastchat_api(model_names, model_path, conv_template, host, port):
function exit_handler (line 49) | def exit_handler(signal, frame):
FILE: llm/src/qa_prediction/build_qa_input.py
function normalize (line 15) | def normalize(s: str) -> str:
class PromptBuilder (line 26) | class PromptBuilder(object):
method __init__ (line 40) | def __init__(self, prompt_path, encrypt=False, add_rule = False, use_t...
method _read_prompt_template (line 53) | def _read_prompt_template(self, template_file):
method apply_rules (line 58) | def apply_rules(self, graph, rules, srouce_entities):
method direct_answer (line 66) | def direct_answer(self, question_dict):
method process_input (line 83) | def process_input(self, question_dict):
method check_prompt_length (line 164) | def check_prompt_length(self, prompt, list_of_paths, maximun_token):
FILE: llm/src/qa_prediction/evaluate_multi_hop.py
function normalize (line 21) | def normalize(s: str) -> str:
function match (line 33) | def match(s1: str, s2: str) -> bool:
function eval_acc (line 38) | def eval_acc(prediction, answer):
function eval_hit (line 45) | def eval_hit(prediction, answer):
function eval_hit1 (line 51) | def eval_hit1(prediction, answer):
function eval_f1 (line 57) | def eval_f1(prediction, answer):
function extract_topk_prediction (line 72) | def extract_topk_prediction(prediction, k=-1):
function eval_result (line 84) | def eval_result(predict_file1, encrypt=False, cal_f1=True, topk = -1):
FILE: llm/src/qa_prediction/evaluate_results.py
function normalize (line 15) | def normalize(s: str) -> str:
function match (line 27) | def match(s1: str, s2: str) -> bool:
function eval_acc (line 32) | def eval_acc(prediction, answer):
function eval_hit (line 39) | def eval_hit(prediction, answer):
function eval_hit1 (line 45) | def eval_hit1(prediction, answer):
function eval_f1 (line 51) | def eval_f1(prediction, answer):
function extract_topk_prediction (line 66) | def extract_topk_prediction(prediction, k=-1):
function eval_result (line 78) | def eval_result(predict_file, encrypt=False, cal_f1=True, topk = -1):
FILE: llm/src/qa_prediction/gen_rule_path.py
function get_output_file (line 25) | def get_output_file(path, force=False):
function parse_prediction (line 42) | def parse_prediction(prediction):
function generate_seq (line 71) | def generate_seq(
function gen_prediction (line 102) | def gen_prediction(args):
FILE: llm/src/qa_prediction/predict_answer.py
function normalize (line 25) | def normalize(s: str) -> str:
function match (line 37) | def match(s1: str, s2: str) -> bool:
function load_gnn_rag (line 43) | def load_gnn_rag(g_data_file, g_data_file2=None):
function get_output_file (line 83) | def get_output_file(path, force=False):
function merge_rule_result (line 100) | def merge_rule_result(qa_dataset, rule_dataset, n_proc=1, filter_empty=F...
function prediction (line 127) | def prediction(data, processed_list, input_builder, model, encrypt=False...
function main (line 174) | def main(args, LLM):
FILE: llm/src/utils/graph_utils.py
function build_graph (line 10) | def build_graph(graph: list, entities=None, encrypt=False) -> nx.Graph:
function bfs_with_rule (line 24) | def bfs_with_rule(graph, start_node, target_rule, max_p = 10):
function get_truth_paths (line 49) | def get_truth_paths(q_entity: list, a_entity: list, graph: nx.Graph) -> ...
function get_simple_paths (line 77) | def get_simple_paths(q_entity: list, a_entity: list, graph: nx.Graph, ho...
function get_negative_paths (line 100) | def get_negative_paths(q_entity: list, a_entity: list, graph: nx.Graph, ...
function get_random_paths (line 129) | def get_random_paths(q_entity: list, graph: nx.Graph, n=3, hop=2):# -> t...
FILE: llm/src/utils/merge_peft.py
class ScriptArguments (line 8) | class ScriptArguments:
FILE: llm/src/utils/training_utils.py
function smart_tokenizer_and_embedding_resize (line 4) | def smart_tokenizer_and_embedding_resize(
FILE: llm/src/utils/utils.py
function read_prompt (line 5) | def read_prompt(prompt_path):
function load_jsonl (line 10) | def load_jsonl(file_path):
function load_multiple_jsonl (line 17) | def load_multiple_jsonl(file_path_list):
function list_to_string (line 23) | def list_to_string(l: list) -> str:
function rule_to_string (line 27) | def rule_to_string(rule: list, sep_token = "<SEP>", bop = "<PATH>", eop ...
function path_to_string (line 34) | def path_to_string(path: list) -> str:
class InstructFormater (line 46) | class InstructFormater(object):
method __init__ (line 47) | def __init__(self, prompt_path):
method format (line 57) | def format(self, instruction, message):
Copy disabled (too large)
Download .json
Condensed preview — 118 files, each showing path, character count, and a content snippet. Download the .json file for the full structured content (77,694K chars).
[
{
"path": ".gitignore",
"chars": 3078,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": "README.md",
"chars": 1044,
"preview": "This is the code for **GNN-RAG: Graph Neural Retrieval for Large Language Modeling Reasoning**.\n\n\n![alt GNN-RAG: The GNN"
},
{
"path": "gnn/.gitignore",
"chars": 3152,
"preview": "#LM_KGQA specific\ncheckpoint/\ncheckpoint/pretrain/\ndata/\npretrained_lms/\n\n# Byte-compiled / optimized / DLL files\n__pyca"
},
{
"path": "gnn/README.md",
"chars": 1262,
"preview": "## Get Started\nWe have simple requirements in `requirements.txt`. You can always check if you can run the code immediate"
},
{
"path": "gnn/dataset_load.py",
"chars": 28915,
"preview": "import json\nimport numpy as np\nimport re\nfrom tqdm import tqdm\nimport torch\nfrom collections import Counter\nimport rando"
},
{
"path": "gnn/dataset_load_graft.py",
"chars": 8621,
"preview": "import json\nimport numpy as np\nimport re\nfrom tqdm import tqdm\nimport torch\nfrom collections import Counter\nimport rando"
},
{
"path": "gnn/evaluate.py",
"chars": 9583,
"preview": "\nfrom tqdm import tqdm\ntqdm.monitor_iterval = 0\nimport torch\nimport numpy as np\nimport math, os\nimport json\nimport pickl"
},
{
"path": "gnn/main.py",
"chars": 1256,
"preview": "import argparse\n\nfrom utils import create_logger\nimport torch\nimport numpy as np\nimport os\nimport time\n#from Models.ReaR"
},
{
"path": "gnn/models/GraftNet/graftnet.py",
"chars": 8232,
"preview": "import torch\nimport numpy as np\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\nimport torch.nn as n"
},
{
"path": "gnn/models/NSM/nsm.py",
"chars": 11802,
"preview": "import torch\nimport numpy as np\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\nimport torch.nn as n"
},
{
"path": "gnn/models/ReaRev/rearev.py",
"chars": 10649,
"preview": "import torch\nimport numpy as np\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\nimport torch.nn as n"
},
{
"path": "gnn/models/base_model.py",
"chars": 12609,
"preview": "import torch\nimport numpy as np\nimport torch.nn as nn\n\nimport numpy as np\n\nVERY_SMALL_NUMBER = 1e-10\n\nclass BaseModel(to"
},
{
"path": "gnn/modules/kg_reasoning/base_gnn.py",
"chars": 4119,
"preview": "import torch\nimport numpy as np\nfrom collections import defaultdict\n\nVERY_NEG_NUMBER = -100000000000\n\nclass BaseGNNLayer"
},
{
"path": "gnn/modules/kg_reasoning/graft_gnn.py",
"chars": 8234,
"preview": "import torch\nimport numpy as np\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\nimport torch.nn as n"
},
{
"path": "gnn/modules/kg_reasoning/nsm_gnn.py",
"chars": 5962,
"preview": "import torch\nimport numpy as np\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\nimport torch.nn as n"
},
{
"path": "gnn/modules/kg_reasoning/reasongnn.py",
"chars": 7136,
"preview": "\nimport torch\nimport torch.nn.functional as F\nimport torch.nn as nn\n\n\nfrom .base_gnn import BaseGNNLayer\n\nVERY_NEG_NUMBE"
},
{
"path": "gnn/modules/layer_init.py",
"chars": 3008,
"preview": "\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nVERY_NEG_NUMBER = -100000000000\nVERY_SMALL_NUMBER = "
},
{
"path": "gnn/modules/query_update.py",
"chars": 5739,
"preview": "import torch\nimport numpy as np\nimport torch.nn.functional as F\nimport torch.nn as nn\n\nclass Fusion(nn.Module):\n \"\"\"d"
},
{
"path": "gnn/modules/question_encoding/base_encoder.py",
"chars": 4327,
"preview": "import torch\nimport torch.nn.functional as F\nimport torch.nn as nn\n\nVERY_SMALL_NUMBER = 1e-10\nVERY_NEG_NUMBER = -1000000"
},
{
"path": "gnn/modules/question_encoding/bert_encoder.py",
"chars": 4674,
"preview": "\nimport torch.nn.functional as F\nimport torch.nn as nn\nVERY_SMALL_NUMBER = 1e-10\nVERY_NEG_NUMBER = -100000000000\n\n\nfrom "
},
{
"path": "gnn/modules/question_encoding/lstm_encoder.py",
"chars": 2053,
"preview": "\n\nimport torch.nn as nn\nfrom utils import get_dict\nfrom .base_encoder import BaseInstruction\n\nVERY_SMALL_NUMBER = 1e-10\n"
},
{
"path": "gnn/modules/question_encoding/tokenizers.py",
"chars": 1943,
"preview": "import re\nimport numpy as np\nfrom transformers import BertTokenizer\n\nclass LSTMTokenizer():\n def __init__(self, word2"
},
{
"path": "gnn/parsing.py",
"chars": 6021,
"preview": "import argparse\nimport sys\n\n\ndef bool_flag(v):\n if v.lower() in ('yes', 'true', 't', 'y', '1'):\n return True\n "
},
{
"path": "gnn/requirements.txt",
"chars": 77,
"preview": "Base==1.0.4\nnumpy==1.19.5\ntorch==1.7.1+cu110\ntqdm==4.59.0\ntransformers==4.6.1"
},
{
"path": "gnn/scripts/rearev_cwq.sh",
"chars": 877,
"preview": "\n###ReaRev+SBERT training\n# python main.py ReaRev --is_eval --load_experiment relbert-full_cwq-rearev-final.ckpt --entit"
},
{
"path": "gnn/train_model.py",
"chars": 11260,
"preview": "\nfrom utils import create_logger\nimport time\nimport numpy as np\nimport os, math\n\nimport torch\nfrom torch.optim.lr_schedu"
},
{
"path": "gnn/utils.py",
"chars": 1177,
"preview": "import logging\nimport os\n\n\ndef create_logger(args):\n log_file = os.path.join(args.checkpoint_dir, args.experiment_nam"
},
{
"path": "llm/.gitignore",
"chars": 3095,
"preview": "datasets/\n*json\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distr"
},
{
"path": "llm/README.md",
"chars": 816,
"preview": "## Get Started\nWe have simple requirements in `requirements.txt`. You can always check if you can run the code immediate"
},
{
"path": "llm/prompts/alpaca.txt",
"chars": 224,
"preview": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that"
},
{
"path": "llm/prompts/general_prompt.txt",
"chars": 22,
"preview": "{instruction}\n\n{input}"
},
{
"path": "llm/prompts/llama2.txt",
"chars": 52,
"preview": "[INST] <<SYS>>\n<</SYS>>\n{instruction}{input} [/INST]"
},
{
"path": "llm/prompts/llama2_predict.txt",
"chars": 54,
"preview": "[INST] <<SYS>>\n<</SYS>>\n{instruction}\n\n{input} [/INST]"
},
{
"path": "llm/requirements.txt",
"chars": 151,
"preview": "openai==0.27.9\ntransformers==4.32.0\ntrl==0.7.1\npeft==0.5.0\ndatasets==2.14.4\naccelerate==0.22.0\npybind11==2.11.1\nnetworkx"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/args.txt",
"chars": 543,
"preview": "{\n \"data_path\": \"rmanluo\",\n \"d\": \"RoG-cwq\",\n \"split\": \"test\",\n \"predict_path\": \"results/KGQA-GO\",\n \"model_name\": \"R"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
"chars": 925418,
"preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"prediction\": [\"2014 World Series\", \"2010 World Series\", \"2012 W"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
"chars": 152,
"preview": "Accuracy: 62.30464023153203 Hit: 66.24185783064287 Hit1: 61.31407533276692 F1: 58.963107192119594 Precision: 59.90911544"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/predictions.jsonl",
"chars": 10081521,
"preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"question\": \"Lou Seal is the mascot for the team that last won t"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/args.txt",
"chars": 549,
"preview": "{\n \"data_path\": \"rmanluo\",\n \"d\": \"RoG-webqsp\",\n \"split\": \"test\",\n \"predict_path\": \"results/KGQA-GO\",\n \"model_name\":"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
"chars": 870257,
"preview": "{\"id\": \"WebQTest-0\", \"prediction\": [\"Jamaican English\", \"Jamaican Creole English Language\"], \"ground_truth\": [\"Jamaican "
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
"chars": 149,
"preview": "Accuracy: 72.63363902772625 Hit: 85.012285012285 Hit1: 80.28255528255528 F1: 71.53705252039563 Precision: 78.34699702921"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/predictions.jsonl",
"chars": 3224996,
"preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": \"Jamaican English\\nJamaican Creole Eng"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/llama2-chat-hf/test/no_rule/False/args.txt",
"chars": 716,
"preview": "{\n \"data_path\": \"rmanluo\",\n \"d\": \"RoG-webqsp\",\n \"split\": \"test\",\n \"predict_path\": \"results/KGQA-G-lmsr\",\n \"model_na"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/llama2-chat-hf/test/no_rule/False/detailed_eval_result.jsonl",
"chars": 1431238,
"preview": "{\"id\": \"WebQTest-0\", \"prediction\": [\"Based on the reasoning paths provided, the possible answers to the question \\\"what "
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/llama2-chat-hf/test/no_rule/False/eval_result.txt",
"chars": 152,
"preview": "Accuracy: 74.90245578760596 Hit: 86.79361179361179 Hit1: 6.142506142506143 F1: 36.256619819061264 Precision: 31.38534839"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/llama2-chat-hf/test/no_rule/False/predictions.jsonl",
"chars": 4595076,
"preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": \"Based on the reasoning paths provided"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/args.txt",
"chars": 548,
"preview": "{\n \"data_path\": \"rmanluo\",\n \"d\": \"RoG-cwq\",\n \"split\": \"test\",\n \"predict_path\": \"results/KGQA-G-sbert\",\n \"model_name"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
"chars": 923622,
"preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"prediction\": [\"2014 World Series\", \"2010 World Series\", \"2012 W"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
"chars": 151,
"preview": "Accuracy: 62.76966324662397 Hit: 66.80826961200793 Hit1: 61.73888416879071 F1: 59.42654256015807 Precision: 60.543456243"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/predictions.jsonl",
"chars": 10066656,
"preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"question\": \"Lou Seal is the mascot for the team that last won t"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/args.txt",
"chars": 554,
"preview": "{\n \"data_path\": \"rmanluo\",\n \"d\": \"RoG-webqsp\",\n \"split\": \"test\",\n \"predict_path\": \"results/KGQA-G-sbert\",\n \"model_n"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
"chars": 888293,
"preview": "{\"id\": \"WebQTest-0\", \"prediction\": [\"Jamaican English\", \"Jamaican Creole English Language\"], \"ground_truth\": [\"Jamaican "
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
"chars": 151,
"preview": "Accuracy: 73.76651533728501 Hit: 85.68796068796068 Hit1: 80.58968058968058 F1: 71.28257365406769 Precision: 77.185064493"
},
{
"path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/predictions.jsonl",
"chars": 3767245,
"preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": \"Jamaican English\\nJamaican Creole Eng"
},
{
"path": "llm/results/KGQA-GNN-RAG-RA/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/args.txt",
"chars": 542,
"preview": "{\n \"data_path\": \"rmanluo\",\n \"d\": \"RoG-cwq\",\n \"split\": \"test\",\n \"predict_path\": \"results/KGQA-G\",\n \"model_name\": \"Ro"
},
{
"path": "llm/results/KGQA-GNN-RAG-RA/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
"chars": 918794,
"preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"prediction\": [\"2014 World Series\", \"2012 World Series\", \"2010 W"
},
{
"path": "llm/results/KGQA-GNN-RAG-RA/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
"chars": 153,
"preview": "Accuracy: 64.56916340016058 Hit: 67.96941376380629 Hit1: 63.041631265930334 F1: 60.94993975622944 Precision: 61.13609531"
},
{
"path": "llm/results/KGQA-GNN-RAG-RA/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/args.txt",
"chars": 548,
"preview": "{\n \"data_path\": \"rmanluo\",\n \"d\": \"RoG-webqsp\",\n \"split\": \"test\",\n \"predict_path\": \"results/KGQA-G\",\n \"model_name\": "
},
{
"path": "llm/results/KGQA-GNN-RAG-RA/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
"chars": 930571,
"preview": "{\"id\": \"WebQTest-0\", \"prediction\": [\"Jamaican English\", \"Jamaican Creole English Language\"], \"ground_truth\": [\"Jamaican "
},
{
"path": "llm/results/KGQA-GNN-RAG-RA/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
"chars": 148,
"preview": "Accuracy: 82.1694156927992 Hit: 89.8034398034398 Hit1: 82.37100737100737 F1: 73.36257924133025 Precision: 74.26870476680"
},
{
"path": "llm/results/KGQA-GNN-RAG-RA/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/predictions.jsonl",
"chars": 5193412,
"preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": \"Jamaican English\\nJamaican Creole Eng"
},
{
"path": "llm/results/KGQA-GNN-RAG-RA/rearev-sbert/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/args.txt",
"chars": 544,
"preview": "{\n \"data_path\": \"rmanluo\",\n \"d\": \"RoG-cwq\",\n \"split\": \"test\",\n \"predict_path\": \"results/KGQA-GOS\",\n \"model_name\": \""
},
{
"path": "llm/results/KGQA-GNN-RAG-RA/rearev-sbert/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
"chars": 943356,
"preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"prediction\": [\"2014 World Series\", \"2010 World Series\", \"2012 W"
},
{
"path": "llm/results/KGQA-GNN-RAG-RA/rearev-sbert/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
"chars": 153,
"preview": "Accuracy: 65.00780922839806 Hit: 68.64910790144435 Hit1: 62.786745964316054 F1: 60.448091645187766 Precision: 60.4990589"
},
{
"path": "llm/results/KGQA-GNN-RAG-RA/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/args.txt",
"chars": 549,
"preview": "{\n \"data_path\": \"rmanluo\",\n \"d\": \"RoG-webqsp\",\n \"split\": \"test\",\n \"predict_path\": \"results/KGQA-G2\",\n \"model_name\":"
},
{
"path": "llm/results/KGQA-GNN-RAG-RA/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
"chars": 956024,
"preview": "{\"id\": \"WebQTest-0\", \"prediction\": [\"Jamaican English\", \"Jamaican Creole English Language\"], \"ground_truth\": [\"Jamaican "
},
{
"path": "llm/results/KGQA-GNN-RAG-RA/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
"chars": 147,
"preview": "Accuracy: 82.9532337788689 Hit: 90.72481572481573 Hit1: 82.8009828009828 F1: 73.4900197824918 Precision: 73.233417222905"
},
{
"path": "llm/results/KGQA-GNN-RAG-RA/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/predictions.jsonl",
"chars": 5618513,
"preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": \"Jamaican English\\nJamaican Creole Eng"
},
{
"path": "llm/results/gen_rule_path/RoG-cwq/RoG/test/predictions_2_False.jsonl",
"chars": 3792573,
"preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"question\": \"Lou Seal is the mascot for the team that last won t"
},
{
"path": "llm/results/gen_rule_path/RoG-cwq/RoG/test/predictions_3_False.jsonl",
"chars": 4442068,
"preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"question\": \"Lou Seal is the mascot for the team that last won t"
},
{
"path": "llm/results/gen_rule_path/RoG-webqsp/RoG/test/predictions_1_False.jsonl",
"chars": 1073009,
"preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": [[\"location.country.languages_spoken\"]"
},
{
"path": "llm/results/gen_rule_path/RoG-webqsp/RoG/test/predictions_2_False.jsonl",
"chars": 1402094,
"preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": [[\"location.country.languages_spoken\"]"
},
{
"path": "llm/results/gen_rule_path/RoG-webqsp/RoG/test/predictions_3_False.jsonl",
"chars": 1671687,
"preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": [[\"location.country.languages_spoken\"]"
},
{
"path": "llm/results/gen_rule_path/RoG-webqsp/RoG/train/predictions_1_False.jsonl",
"chars": 1954524,
"preview": "{\"id\": \"WebQTrn-0\", \"question\": \"what is the name of justin bieber brother\", \"prediction\": [[\"people.person.nationality\""
},
{
"path": "llm/results/gen_rule_path/RoG-webqsp/RoG/train/predictions_3_False.jsonl",
"chars": 3007418,
"preview": "{\"id\": \"WebQTrn-0\", \"question\": \"what is the name of justin bieber brother\", \"prediction\": [[\"people.person.nationality\""
},
{
"path": "llm/results/gen_rule_path/RoG-webqsp/RoG/validation/predictions_3_False.jsonl",
"chars": 260362,
"preview": "{\"id\": \"WebQTrn-9\", \"question\": \"how old is sacha baron cohen\", \"prediction\": [[\"people.person.nationality\"], [\"people.p"
},
{
"path": "llm/results/gnn/RoG-cwq/rearev-lmsr/test.info",
"chars": 1552571,
"preview": "{\"question\": \"lou seal is the mascot for the team that last won the world series when ? \", \"0\": {}, \"1\": {}, \"answers\": "
},
{
"path": "llm/results/gnn/RoG-cwq/rearev-sbert/test.info",
"chars": 1843012,
"preview": "{\"question\": \"lou seal is the mascot for the team that last won the world series when ? \", \"0\": {}, \"1\": {}, \"answers\": "
},
{
"path": "llm/results/gnn/RoG-webqsp/rearev-lmsr/test.info",
"chars": 861379,
"preview": "{\"question\": \"what does jamaican people speak \", \"0\": {}, \"1\": {}, \"answers\": [\"m.01428y\", \"m.04ygk0\"], \"precison\": 1.0,"
},
{
"path": "llm/results/gnn/RoG-webqsp/rearev-sbert/test.info",
"chars": 970807,
"preview": "{\"question\": \"what does jamaican people speak \", \"0\": {}, \"1\": {}, \"2\": {}, \"answers\": [\"m.01428y\", \"m.04ygk0\"], \"precis"
},
{
"path": "llm/scripts/evaluate_multi_hop.sh",
"chars": 180,
"preview": "d='results/KGQA-GNN-RAG/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_j"
},
{
"path": "llm/scripts/interpretable_example.py",
"chars": 1383,
"preview": "from transformers import pipeline, AutoTokenizer\nimport torch\n\nMODEL_PATH_OR_NAME=\"rmanluo/RoG\"\n\ntokenizer = AutoTokeniz"
},
{
"path": "llm/scripts/planning.sh",
"chars": 389,
"preview": "SPLIT=\"test\"\nDATASET_LIST=\"RoG-webqsp\"\nMODEL_NAME=RoG\nMODEL_PATH=rmanluo/RoG\n\nBEAM_LIST=\"3\" # \"1 2 3 4 5\"\nfor DATASET in"
},
{
"path": "llm/scripts/plug-and-play.sh",
"chars": 1188,
"preview": "SPLIT=\"test\"\nDATASET_LIST=\"RoG-cwq\"\nBEAM_LIST=\"3\" # \"1 2 3 4 5\"\nMODEL_LIST=\"llama2-chat-hf\"\n#PROMPT_LIST=\"prompts/genera"
},
{
"path": "llm/scripts/rag-reasoning.sh",
"chars": 1693,
"preview": "\nSPLIT=\"test\"\nDATASET_LIST=\"RoG-webqsp\"\nMODEL_NAME=RoG\nPROMPT_PATH=prompts/llama2_predict.txt\nBEAM_LIST=\"3\" # \"1 2 3 4 5"
},
{
"path": "llm/scripts/train.sh",
"chars": 1217,
"preview": "MODEL_PATH=meta-llama/Llama-2-7b-chat-hf\nDATASET_LIST=\"datasets/joint_training/align/cwq/cwq_train.jsonl datasets/joint_"
},
{
"path": "llm/src/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "llm/src/align_kg/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "llm/src/align_kg/build_align_qa_dataset.py",
"chars": 2367,
"preview": "import sys\nimport os\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nimport argparse\nimport os\nimpo"
},
{
"path": "llm/src/align_kg/data_loader.py",
"chars": 1570,
"preview": "import string\nimport sys\nimport os\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\n\nfrom datasets i"
},
{
"path": "llm/src/joint_training/generate_explanation_results.py",
"chars": 6159,
"preview": "import sys\nimport os\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nfrom utils import *\nimport dat"
},
{
"path": "llm/src/joint_training/joint_finetuning.py",
"chars": 6396,
"preview": "import sys\nimport os\n\nimport torch\n\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nimport os\nfrom "
},
{
"path": "llm/src/joint_training/preprocess_align.py",
"chars": 2121,
"preview": "import sys\nimport os\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nfrom utils import *\nfrom trans"
},
{
"path": "llm/src/joint_training/preprocess_qa.py",
"chars": 2484,
"preview": "import sys\nimport os\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nfrom utils import *\nfrom trans"
},
{
"path": "llm/src/llms/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "llm/src/llms/language_models/__init__.py",
"chars": 643,
"preview": "from .chatgpt import ChatGPT\nfrom .alpaca import Alpaca\nfrom .longchat.longchat import Longchat\nfrom .base_language_mode"
},
{
"path": "llm/src/llms/language_models/alpaca.py",
"chars": 1504,
"preview": "from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer\nimport torch\nfrom .base_language_model import Bas"
},
{
"path": "llm/src/llms/language_models/base_language_model.py",
"chars": 851,
"preview": "import collections\n\n\nclass BaseLanguageModel(object):\n \"\"\"\n Base lanuage model. Define how to generate sentence by"
},
{
"path": "llm/src/llms/language_models/chatgpt.py",
"chars": 2919,
"preview": "import time\nimport os\nimport openai\nfrom .base_language_model import BaseLanguageModel\nimport dotenv\nimport tiktoken\ndot"
},
{
"path": "llm/src/llms/language_models/flan_t5.py",
"chars": 1431,
"preview": "from transformers import pipeline, AutoModel, AutoTokenizer\nimport torch\nfrom .base_language_model import BaseLanguageMo"
},
{
"path": "llm/src/llms/language_models/llama.py",
"chars": 1797,
"preview": "from transformers import pipeline, AutoTokenizer\nimport torch\nfrom .base_language_model import BaseLanguageModel\nfrom tr"
},
{
"path": "llm/src/llms/language_models/longchat/llama_condense_monkey_patch.py",
"chars": 2838,
"preview": "# code adapted from https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test/blob/main/llama_rope_scaled_monkey_pa"
},
{
"path": "llm/src/llms/language_models/longchat/llama_flash_attn_monkey_patch.py",
"chars": 3982,
"preview": "from typing import List, Optional, Tuple\n\nimport torch\nfrom torch import nn\n\nimport transformers\nfrom transformers.model"
},
{
"path": "llm/src/llms/language_models/longchat/longchat.py",
"chars": 2375,
"preview": "from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer\nimport torch\nfrom ..base_language_model import Ba"
},
{
"path": "llm/src/llms/llm_proxy.py",
"chars": 2080,
"preview": "\nimport openai\nfrom dotenv import load_dotenv\nimport os\nimport time\nfrom .start_fastchat_api import start_fastchat_api\nc"
},
{
"path": "llm/src/llms/start_fastchat_api.py",
"chars": 2274,
"preview": "import subprocess\nimport argparse\nimport sys\nimport atexit\nimport signal\nprocesses = []\n\ndef terminate_process():\n fo"
},
{
"path": "llm/src/qa_prediction/build_qa_input.py",
"chars": 8094,
"preview": "import sys\nimport os\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nimport utils\nimport random\nfro"
},
{
"path": "llm/src/qa_prediction/evaluate_multi_hop.py",
"chars": 5426,
"preview": "import argparse\nimport glob\nimport json\nimport os\nimport re\nimport string\nfrom sklearn.metrics import precision_score\n\ni"
},
{
"path": "llm/src/qa_prediction/evaluate_results.py",
"chars": 5769,
"preview": "import argparse\nimport glob\nimport json\nimport os\nimport re\nimport string\nfrom sklearn.metrics import precision_score\n\ni"
},
{
"path": "llm/src/qa_prediction/gen_rule_path.py",
"chars": 7237,
"preview": "import json\nimport sys\nimport os\n\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nimport argparse\ni"
},
{
"path": "llm/src/qa_prediction/predict_answer.py",
"chars": 10921,
"preview": "import sys\nimport os\n\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nimport utils\nimport argparse\n"
},
{
"path": "llm/src/utils/__init__.py",
"chars": 77,
"preview": "from .graph_utils import *\nfrom .utils import *\nfrom .training_utils import *"
},
{
"path": "llm/src/utils/graph_utils.py",
"chars": 5004,
"preview": "import networkx as nx\nfrom collections import deque\n#import walker\n\nimport json\nwith open('entities_names.json') as f:\n "
},
{
"path": "llm/src/utils/merge_peft.py",
"chars": 688,
"preview": "from dataclasses import dataclass, field\nfrom typing import Optional\nfrom transformers import HfArgumentParser\nimport to"
},
{
"path": "llm/src/utils/training_utils.py",
"chars": 1126,
"preview": "import transformers\nfrom typing import Dict, List\n\ndef smart_tokenizer_and_embedding_resize(\n new_tokens: List[str],\n"
},
{
"path": "llm/src/utils/utils.py",
"chars": 1514,
"preview": "\nimport json\nimport string\n\ndef read_prompt(prompt_path):\n with open(prompt_path, 'r') as f:\n prompt_template "
}
]
// ... and 2 more files (download for full content)
About this extraction
This page contains the full source code of the cmavro/GNN-RAG GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 118 files (106.0 MB), approximately 18.6M tokens, and a symbol index with 298 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.