Full Code of cmavro/GNN-RAG for AI

Repository: cmavro/GNN-RAG
Branch: main
Commit: 28d31bc0db3d
Files: 118
Total size: 106.0 MB

Directory structure:
GNN-RAG/

├── .gitignore
├── README.md
├── gnn/
│   ├── .gitignore
│   ├── README.md
│   ├── dataset_load.py
│   ├── dataset_load_graft.py
│   ├── evaluate.py
│   ├── main.py
│   ├── models/
│   │   ├── GraftNet/
│   │   │   └── graftnet.py
│   │   ├── NSM/
│   │   │   └── nsm.py
│   │   ├── ReaRev/
│   │   │   └── rearev.py
│   │   └── base_model.py
│   ├── modules/
│   │   ├── kg_reasoning/
│   │   │   ├── base_gnn.py
│   │   │   ├── graft_gnn.py
│   │   │   ├── nsm_gnn.py
│   │   │   └── reasongnn.py
│   │   ├── layer_init.py
│   │   ├── query_update.py
│   │   └── question_encoding/
│   │       ├── base_encoder.py
│   │       ├── bert_encoder.py
│   │       ├── lstm_encoder.py
│   │       └── tokenizers.py
│   ├── parsing.py
│   ├── requirements.txt
│   ├── scripts/
│   │   └── rearev_cwq.sh
│   ├── train_model.py
│   └── utils.py
└── llm/
    ├── .gitignore
    ├── README.md
    ├── prompts/
    │   ├── alpaca.txt
    │   ├── general_prompt.txt
    │   ├── llama2.txt
    │   └── llama2_predict.txt
    ├── requirements.txt
    ├── results/
    │   ├── KGQA-GNN-RAG/
    │   │   ├── rearev-lmsr/
    │   │   │   ├── RoG-cwq/
    │   │   │   │   └── RoG/
    │   │   │   │       └── test/
    │   │   │   │           └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
    │   │   │   │               └── False/
    │   │   │   │                   ├── args.txt
    │   │   │   │                   ├── detailed_eval_result.jsonl
    │   │   │   │                   ├── eval_result.txt
    │   │   │   │                   └── predictions.jsonl
    │   │   │   └── RoG-webqsp/
    │   │   │       ├── RoG/
    │   │   │       │   └── test/
    │   │   │       │       └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
    │   │   │       │           └── False/
    │   │   │       │               ├── args.txt
    │   │   │       │               ├── detailed_eval_result.jsonl
    │   │   │       │               ├── eval_result.txt
    │   │   │       │               └── predictions.jsonl
    │   │   │       └── llama2-chat-hf/
    │   │   │           └── test/
    │   │   │               └── no_rule/
    │   │   │                   └── False/
    │   │   │                       ├── args.txt
    │   │   │                       ├── detailed_eval_result.jsonl
    │   │   │                       ├── eval_result.txt
    │   │   │                       └── predictions.jsonl
    │   │   └── rearev-sbert/
    │   │       ├── RoG-cwq/
    │   │       │   └── RoG/
    │   │       │       └── test/
    │   │       │           └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
    │   │       │               └── False/
    │   │       │                   ├── args.txt
    │   │       │                   ├── detailed_eval_result.jsonl
    │   │       │                   ├── eval_result.txt
    │   │       │                   └── predictions.jsonl
    │   │       └── RoG-webqsp/
    │   │           └── RoG/
    │   │               └── test/
    │   │                   └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
    │   │                       └── False/
    │   │                           ├── args.txt
    │   │                           ├── detailed_eval_result.jsonl
    │   │                           ├── eval_result.txt
    │   │                           └── predictions.jsonl
    │   ├── KGQA-GNN-RAG-RA/
    │   │   ├── rearev-lmsr/
    │   │   │   ├── RoG-cwq/
    │   │   │   │   └── RoG/
    │   │   │   │       └── test/
    │   │   │   │           └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
    │   │   │   │               └── False/
    │   │   │   │                   ├── args.txt
    │   │   │   │                   ├── detailed_eval_result.jsonl
    │   │   │   │                   ├── eval_result.txt
    │   │   │   │                   └── predictions.jsonl
    │   │   │   └── RoG-webqsp/
    │   │   │       └── RoG/
    │   │   │           └── test/
    │   │   │               └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
    │   │   │                   └── False/
    │   │   │                       ├── args.txt
    │   │   │                       ├── detailed_eval_result.jsonl
    │   │   │                       ├── eval_result.txt
    │   │   │                       └── predictions.jsonl
    │   │   └── rearev-sbert/
    │   │       ├── RoG-cwq/
    │   │       │   └── RoG/
    │   │       │       └── test/
    │   │       │           └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
    │   │       │               └── False/
    │   │       │                   ├── args.txt
    │   │       │                   ├── detailed_eval_result.jsonl
    │   │       │                   ├── eval_result.txt
    │   │       │                   └── predictions.jsonl
    │   │       └── RoG-webqsp/
    │   │           └── RoG/
    │   │               └── test/
    │   │                   └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
    │   │                       └── False/
    │   │                           ├── args.txt
    │   │                           ├── detailed_eval_result.jsonl
    │   │                           ├── eval_result.txt
    │   │                           └── predictions.jsonl
    │   ├── gen_rule_path/
    │   │   ├── RoG-cwq/
    │   │   │   └── RoG/
    │   │   │       └── test/
    │   │   │           ├── predictions_2_False.jsonl
    │   │   │           └── predictions_3_False.jsonl
    │   │   └── RoG-webqsp/
    │   │       └── RoG/
    │   │           ├── test/
    │   │           │   ├── predictions_1_False.jsonl
    │   │           │   ├── predictions_2_False.jsonl
    │   │           │   └── predictions_3_False.jsonl
    │   │           ├── train/
    │   │           │   ├── predictions_1_False.jsonl
    │   │           │   └── predictions_3_False.jsonl
    │   │           └── validation/
    │   │               └── predictions_3_False.jsonl
    │   └── gnn/
    │       ├── RoG-cwq/
    │       │   ├── rearev-lmsr/
    │       │   │   └── test.info
    │       │   └── rearev-sbert/
    │       │       └── test.info
    │       └── RoG-webqsp/
    │           ├── rearev-lmsr/
    │           │   └── test.info
    │           └── rearev-sbert/
    │               └── test.info
    ├── scripts/
    │   ├── evaluate_multi_hop.sh
    │   ├── interpretable_example.py
    │   ├── planning.sh
    │   ├── plug-and-play.sh
    │   ├── rag-reasoning.sh
    │   └── train.sh
    └── src/
        ├── __init__.py
        ├── align_kg/
        │   ├── __init__.py
        │   ├── build_align_qa_dataset.py
        │   └── data_loader.py
        ├── joint_training/
        │   ├── generate_explanation_results.py
        │   ├── joint_finetuning.py
        │   ├── preprocess_align.py
        │   └── preprocess_qa.py
        ├── llms/
        │   ├── __init__.py
        │   ├── language_models/
        │   │   ├── __init__.py
        │   │   ├── alpaca.py
        │   │   ├── base_language_model.py
        │   │   ├── chatgpt.py
        │   │   ├── flan_t5.py
        │   │   ├── llama.py
        │   │   └── longchat/
        │   │       ├── llama_condense_monkey_patch.py
        │   │       ├── llama_flash_attn_monkey_patch.py
        │   │       └── longchat.py
        │   ├── llm_proxy.py
        │   └── start_fastchat_api.py
        ├── qa_prediction/
        │   ├── build_qa_input.py
        │   ├── evaluate_multi_hop.py
        │   ├── evaluate_results.py
        │   ├── gen_rule_path.py
        │   └── predict_answer.py
        └── utils/
            ├── __init__.py
            ├── graph_utils.py
            ├── merge_peft.py
            ├── training_utils.py
            └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/


================================================
FILE: README.md
================================================
This is the code for **GNN-RAG: Graph Neural Retrieval for Large Language Model Reasoning**.


![GNN-RAG: The GNN reasons over a dense subgraph to retrieve candidate answers, along with the corresponding reasoning paths (shortest paths from question entities to answers). The retrieved reasoning paths, optionally combined with retrieval augmentation (RA), are verbalized and given to the LLM for RAG.](GNN-RAG.png "GNN-RAG")
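
The reasoning paths mentioned above are shortest paths from the question entities to the GNN's candidate answers. A minimal sketch of that extraction step over a toy subgraph — the entity and relation names here are made up for illustration, not taken from the datasets:

```python
from collections import deque

def shortest_path(triples, source, target):
    """BFS over an undirected view of the KG triples; returns the first
    shortest path from source to target as a list of (head, rel, tail)."""
    adj = {}
    for h, r, t in triples:
        adj.setdefault(h, []).append((r, t))
        adj.setdefault(t, []).append((r + "^-1", h))  # traversable inverse edge
    queue = deque([(source, [])])
    seen = {source}
    while queue:
        node, path = queue.popleft()
        if node == target:
            return path
        for rel, nxt in adj.get(node, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, path + [(node, rel, nxt)]))
    return None  # target unreachable in this subgraph

# Toy subgraph (hypothetical entities/relations)
triples = [
    ("Jamaica", "location.contains", "Kingston"),
    ("Jamaica", "country.language", "Jamaican English"),
    ("Kingston", "city.mayor", "Delroy Williams"),
]
print(shortest_path(triples, "Jamaica", "Delroy Williams"))
```

A path like this would then be verbalized, e.g. "Jamaica → location.contains → Kingston → city.mayor → Delroy Williams", before being handed to the LLM as retrieved context.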

The directory structure is the following:

|----`gnn` folder has the implementation of different KGQA GNNs.

You can train your own GNNs, or you can skip this folder and directly use the GNN output (retrieved answer nodes) that we computed (`llm/results/gnn`).

|----`llm` folder has the implementation for RAG-based KGQA with LLMs.

Please see the details there on how to reproduce the results.

**Results**: We include all the results for Table 2: see `results/KGQA-GNN-RAG-RA` or `results/KGQA-GNN-RAG`. You can look at the actual LLM generations, as well as the retrieved KG information (the "input" key), in `predictions.jsonl`.
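
To browse those result files programmatically, a small sketch is below. The "input" key is the one mentioned above; any other key names (e.g. "question") are assumptions about the JSONL schema, so check a real file first:

```python
import json
import pathlib

def load_jsonl(path):
    """Read one JSON object per line, skipping blank lines."""
    records = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records

# Point this at e.g. results/KGQA-GNN-RAG/rearev-sbert/.../predictions.jsonl
path = pathlib.Path("predictions.jsonl")
if path.exists():
    for rec in load_jsonl(path)[:3]:
        # "input" holds the verbalized KG information given to the LLM
        print(rec.get("question"), "->", str(rec.get("input", ""))[:100])
```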


================================================
FILE: gnn/.gitignore
================================================
#LM_KGQA specific
checkpoint/
checkpoint/pretrain/
data/
pretrained_lms/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/


================================================
FILE: gnn/README.md
================================================
## Get Started
We have simple requirements in `requirements.txt`. You can always check if you can run the code immediately.

The datasets as well as the pretrained LM (LMsr) are uploaded here: https://drive.google.com/drive/folders/1ifgVHQDnvFEunP9hmVYT07Y3rvcpIfQp?usp=sharing

Please download them and extract them to the corresponding folders.

## Training
Please follow the guidelines and hyperparameters of the corresponding GNNs for training. See `scripts` for a training example.

Otherwise, you can download released GNN models from here: https://drive.google.com/file/d/1p7eLSsSKkZQxB32mT5lMsthVP6R_3x1j/view

## Evaluation

To evaluate them, copy the command from the above scripts, add the `--is_eval` argument, and add `--load_experiment` followed by the name of the corresponding `ckpt` model.

For example, for Webqsp run:
```
python main.py ReaRev --entity_dim 50 --num_epoch 200 --batch_size 8 --eval_every 2 --data_folder data/webqsp/ --lm sbert --num_iter 3 --num_ins 2 --num_gnn 3 --relation_word_emb True --load_experiment ReaRev_webqsp.ckpt --is_eval --name webqsp
```

The result is saved as a `.info` file. To use GNN-RAG, move this file to the corresponding folder in `GNN-RAG/llm/results/gnn/` and rename it to `test.info`.
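
The move-and-rename step can be scripted; a sketch is below, assuming the folder layout shown in the directory tree (`<dataset>/<model>/test.info` under `llm/results/gnn/`). The source checkpoint filename and the dataset/model names in the usage comment are placeholders to adapt:

```python
import shutil
from pathlib import Path

def install_gnn_output(info_file, dataset, model_name,
                       llm_results="GNN-RAG/llm/results/gnn"):
    """Copy a GNN .info result into the folder layout the llm code expects,
    renaming it to test.info. Returns the destination path."""
    dest_dir = Path(llm_results) / dataset / model_name
    dest_dir.mkdir(parents=True, exist_ok=True)
    dest = dest_dir / "test.info"
    shutil.copyfile(info_file, dest)
    return dest

# e.g. (hypothetical checkpoint name):
# install_gnn_output("checkpoint/webqsp_eval.info", "RoG-webqsp", "rearev-sbert")
```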

================================================
FILE: gnn/dataset_load.py
================================================
import json
import numpy as np
import re
from tqdm import tqdm
import torch
from collections import Counter
import random
import warnings
import pickle
warnings.filterwarnings("ignore")
from modules.question_encoding.tokenizers import LSTMTokenizer#, BERTTokenizer
from transformers import AutoTokenizer
import time

import os


class BasicDataLoader(object):
    """ 
    Basic Dataloader contains all the functions to read questions and KGs from json files and
    create mappings between global entity ids and local ids that are used during GNN updates.
    """

    def __init__(self, config, word2id, relation2id, entity2id, tokenize, data_type="train"):
        self.tokenize = tokenize
        self._parse_args(config, word2id, relation2id, entity2id)
        self._load_file(config, data_type)
        self._load_data()
        

    def _load_file(self, config, data_type="train"):

        """
        Loads lines (questions + KG subgraphs) from json files.
        """
        
        data_file = config['data_folder'] + data_type + ".json"
        self.data_file = data_file
        print('loading data from', data_file)
        self.data_type = data_type
        self.data = []
        skip_index = set()
        index = 0

        with open(data_file) as f_in:
            for line in tqdm(f_in):
                if index == config['max_train'] and data_type == "train":
                    break  # stop once we reach the max number of training questions
                line = json.loads(line)
                
                if len(line['entities']) == 0:
                    skip_index.add(index)
                    continue
                self.data.append(line)
                self.max_facts = max(self.max_facts, 2 * len(line['subgraph']['tuples']))
                index += 1

        print("skip", skip_index)
        print('max_facts: ', self.max_facts)
        self.num_data = len(self.data)
        self.batches = np.arange(self.num_data)

    def _load_data(self):

        """
        Creates mappings between global entity ids and local entity ids that are used during GNN updates.
        """

        print('converting global to local entity index ...')
        self.global2local_entity_maps = self._build_global2local_entity_maps()

        if self.use_self_loop:
            self.max_facts = self.max_facts + self.max_local_entity

        self.question_id = []
        self.candidate_entities = np.full((self.num_data, self.max_local_entity), len(self.entity2id), dtype=int)
        self.kb_adj_mats = np.empty(self.num_data, dtype=object)
        self.q_adj_mats = np.empty(self.num_data, dtype=object)
        self.kb_fact_rels = np.full((self.num_data, self.max_facts), self.num_kb_relation, dtype=int)
        self.query_entities = np.zeros((self.num_data, self.max_local_entity), dtype=float)
        self.seed_list = np.empty(self.num_data, dtype=object)
        self.seed_distribution = np.zeros((self.num_data, self.max_local_entity), dtype=float)
        # self.query_texts = np.full((self.num_data, self.max_query_word), len(self.word2id), dtype=int)
        self.answer_dists = np.zeros((self.num_data, self.max_local_entity), dtype=float)
        self.answer_lists = np.empty(self.num_data, dtype=object)

        self._prepare_data()

    def _parse_args(self, config, word2id, relation2id, entity2id):

        """
        Builds necessary dictionaries and stores arguments.
        """
        self.data_eff = config['data_eff']
        self.data_name = config['name']

        if 'use_inverse_relation' in config:
            self.use_inverse_relation = config['use_inverse_relation']
        else:
            self.use_inverse_relation = False
        if 'use_self_loop' in config:
            self.use_self_loop = config['use_self_loop']
        else:
            self.use_self_loop = False

        self.rel_word_emb = config['relation_word_emb']
        #self.num_step = config['num_step']
        self.max_local_entity = 0
        self.max_relevant_doc = 0
        self.max_facts = 0

        print('building word index ...')
        self.word2id = word2id
        self.id2word = {i: word for word, i in word2id.items()}
        self.relation2id = relation2id
        self.entity2id = entity2id
        self.id2entity = {i: entity for entity, i in entity2id.items()}
        self.q_type = config['q_type']

        if self.use_inverse_relation:
            self.num_kb_relation = 2 * len(relation2id)
        else:
            self.num_kb_relation = len(relation2id)
        if self.use_self_loop:
            self.num_kb_relation = self.num_kb_relation + 1
        print("Entity: {}, Relation in KB: {}, Relation in use: {} ".format(len(entity2id),
                                                                            len(self.relation2id),
                                                                            self.num_kb_relation))

    
    def get_quest(self, training=False):
        q_list = []
        
        sample_ids = self.sample_ids
        for sample_id in sample_ids:
            tp_str = self.decode_text(self.query_texts[sample_id, :])
            # id2word = self.id2word
            # for i in range(self.max_query_word):
            #     if self.query_texts[sample_id, i] in id2word:
            #         tp_str += id2word[self.query_texts[sample_id, i]] + " "
            q_list.append(tp_str)
        return q_list

    def decode_text(self, np_array_x):
        if self.tokenize == 'lstm':
            id2word = self.id2word
            tp_str = ""
            for i in range(self.max_query_word):
                if np_array_x[i] in id2word:
                    tp_str += id2word[np_array_x[i]] + " "
        else:
            tp_str = ""
            words = self.tokenizer.convert_ids_to_tokens(np_array_x)
            for w in words:
                if w not in ['[CLS]', '[SEP]', '[PAD]']:
                    tp_str += w + " "
        return tp_str
    

    def _prepare_data(self):
        """
        global2local_entity_maps: a map from global entity id to local entity id
        adj_mats: a local adjacency matrix for each relation. relation 0 is reserved for self-connection.
        """
        max_count = 0
        for line in self.data:
            word_list = line["question"].split(' ')
            max_count = max(max_count, len(word_list))

        
        if self.rel_word_emb:
            self.build_rel_words(self.tokenize)
        else:
            self.rel_texts = None
            self.rel_texts_inv = None
            self.ent_texts = None



        self.max_query_word = max_count
        #self.query_texts = np.full((self.num_data, self.max_query_word), len(self.word2id), dtype=int)
        #self.query_texts2 = np.full((self.num_data, self.max_query_word), len(self.word2id), dtype=int)

        #build tokenizers
        if self.tokenize == 'lstm':
            self.num_word = len(self.word2id)
            self.tokenizer = LSTMTokenizer(self.word2id, self.max_query_word)
            self.query_texts = np.full((self.num_data, self.max_query_word), self.num_word, dtype=int)
        else:
            if self.tokenize == 'bert':
                tokenizer_name = 'bert-base-uncased'    
            elif self.tokenize  == 'roberta':
                tokenizer_name = 'roberta-base'
            elif self.tokenize  == 'sbert':
                tokenizer_name = 'sentence-transformers/all-MiniLM-L6-v2'
            elif self.tokenize == 'sbert2':
                tokenizer_name = 'sentence-transformers/all-mpnet-base-v2'
            elif self.tokenize == 't5':
                tokenizer_name = 't5-small'
            elif self.tokenize == 'simcse':
                tokenizer_name = 'princeton-nlp/sup-simcse-bert-base-uncased'
            elif self.tokenize  == 'relbert':
                tokenizer_name = 'pretrained_lms/sr-simbert/'

            self.max_query_word = max_count + 2 #for cls token and sep
            #self.tokenizer = AutoTokenizer(self.max_query_word)
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
            self.num_word = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token) #self.tokenizer.q_tokenizer.encode("[UNK]")[0]
            
            self.query_texts = np.full((self.num_data, self.max_query_word), self.num_word, dtype=int)


        next_id = 0
        num_query_entity = {}
        for sample in tqdm(self.data):
            self.question_id.append(sample["id"])
            # get a list of local entities
            g2l = self.global2local_entity_maps[next_id]
            #print(g2l)
            if len(g2l) == 0:
                #print(next_id)
                continue
            # build connection between question and entities in it
            tp_set = set()
            seed_list = []
            key_ent = 'entities_cid' if 'entities_cid' in sample else 'entities'
            for j, entity in enumerate(sample[key_ent]):
                # if entity['text'] not in self.entity2id:
                #     continue
                try:
                    if isinstance(entity, dict) and 'text' in entity:
                        global_entity = self.entity2id[entity['text']]
                    else:
                        global_entity = self.entity2id[entity]
                except (KeyError, TypeError):
                    # entity is already a global entity id
                    global_entity = entity

                if global_entity not in g2l:
                    continue
                local_ent = g2l[global_entity]
                self.query_entities[next_id, local_ent] = 1.0
                seed_list.append(local_ent)
                tp_set.add(local_ent)
            
            self.seed_list[next_id] = seed_list
            num_query_entity[next_id] = len(tp_set)
            for global_entity, local_entity in g2l.items():
                if self.data_name != 'cwq':
                    # skip entities that appear in the question
                    if local_entity not in tp_set:
                        self.candidate_entities[next_id, local_entity] = global_entity
                else:
                    self.candidate_entities[next_id, local_entity] = global_entity

            # relations in local KB
            head_list = []
            rel_list = []
            tail_list = []
            for i, tpl in enumerate(sample['subgraph']['tuples']):
                sbj, rel, obj = tpl
                try:
                    if isinstance(sbj, dict) and  'text' in sbj:
                        head = g2l[self.entity2id[sbj['text']]]
                        rel = self.relation2id[rel['text']]
                        tail = g2l[self.entity2id[obj['text']]]
                    else:
                        head = g2l[self.entity2id[sbj]]
                        rel = self.relation2id[rel]
                        tail = g2l[self.entity2id[obj]]
                except:
                    head = g2l[sbj]
                    try:
                        rel = int(rel)
                    except:
                        rel = self.relation2id[rel]
                    tail = g2l[obj]
                head_list.append(head)
                rel_list.append(rel)
                tail_list.append(tail)
                self.kb_fact_rels[next_id, i] = rel
                if self.use_inverse_relation:
                    head_list.append(tail)
                    rel_list.append(rel + len(self.relation2id))
                    tail_list.append(head)
                    self.kb_fact_rels[next_id, i] = rel + len(self.relation2id)
                
            if len(tp_set) > 0:
                for local_ent in tp_set:
                    self.seed_distribution[next_id, local_ent] = 1.0 / len(tp_set)
            else:
                for index in range(len(g2l)):
                    self.seed_distribution[next_id, index] = 1.0 / len(g2l)
            try:
                assert np.sum(self.seed_distribution[next_id]) > 0.0
            except:
                print(next_id, len(tp_set))
                exit(-1)

            #tokenize question
            if self.tokenize == 'lstm':
                self.query_texts[next_id] = self.tokenizer.tokenize(sample['question'])
            else:
                tokens = self.tokenizer.encode_plus(text=sample['question'], max_length=self.max_query_word,
                    padding='max_length', return_attention_mask=False, truncation=True)
                self.query_texts[next_id] = np.array(tokens['input_ids'])


            # construct distribution for answers
            answer_list = []
            if 'answers_cid' in sample:
                for answer in sample['answers_cid']:
                    #keyword = 'text' if type(answer['kb_id']) == int else 'kb_id'
                    answer_ent = answer
                    answer_list.append(answer_ent)
                    if answer_ent in g2l:
                        self.answer_dists[next_id, g2l[answer_ent]] = 1.0
            else:
                for answer in sample['answers']:
                    keyword = 'text' if type(answer['kb_id']) == int else 'kb_id'
                    answer_ent = self.entity2id[answer[keyword]]
                    answer_list.append(answer_ent)
                    if answer_ent in g2l:
                        self.answer_dists[next_id, g2l[answer_ent]] = 1.0
            self.answer_lists[next_id] = answer_list

            if not self.data_eff:
                self.kb_adj_mats[next_id] = (np.array(head_list, dtype=int),
                                         np.array(rel_list, dtype=int),
                                         np.array(tail_list, dtype=int))

            next_id += 1
        num_no_query_ent = 0
        num_one_query_ent = 0
        num_multiple_ent = 0
        for i in range(next_id):
            ct = num_query_entity[i]
            if ct == 1:
                num_one_query_ent += 1
            elif ct == 0:
                num_no_query_ent += 1
            else:
                num_multiple_ent += 1
        print("{} cases in total, {} cases without query entity, {} cases with single query entity,"
              " {} cases with multiple query entities".format(next_id, num_no_query_ent,
                                                              num_one_query_ent, num_multiple_ent))

        
    def build_rel_words(self, tokenize):
        """ 
        Tokenizes relation surface forms.
        """

        max_rel_words = 0
        rel_words = []
        if 'metaqa' in self.data_file:
            for rel in self.relation2id:
                words = rel.split('_')
                max_rel_words = max(len(words), max_rel_words)
                rel_words.append(words)
            #print(rel_words)
        else:
            for rel in self.relation2id:
                rel = rel.strip()
                fields = rel.split('.')
                try:
                    words = fields[-2].split('_') + fields[-1].split('_')
                    max_rel_words = max(len(words), max_rel_words)
                    rel_words.append(words)
                    #print(rel, words)
                except IndexError:
                    # relations without two dot-separated fields fall back to UNK
                    rel_words.append(['UNK'])
                #words = fields[-2].split('_') + fields[-1].split('_')
            
        self.max_rel_words = max_rel_words
        if tokenize == 'lstm':
            self.rel_texts = np.full((self.num_kb_relation + 1, self.max_rel_words), len(self.word2id), dtype=int)
            self.rel_texts_inv = np.full((self.num_kb_relation + 1, self.max_rel_words), len(self.word2id), dtype=int)
            for rel_id, tokens in enumerate(rel_words):
                for j, word in enumerate(tokens):
                    if j < self.max_rel_words:
                        if word in self.word2id:
                            self.rel_texts[rel_id, j] = self.word2id[word]
                            self.rel_texts_inv[rel_id, j] = self.word2id[word]
                        else:
                            self.rel_texts[rel_id, j] = len(self.word2id)
                            self.rel_texts_inv[rel_id, j] = len(self.word2id)
        else:
            if tokenize == 'bert':
                tokenizer_name = 'bert-base-uncased'
            elif tokenize == 'roberta':
                tokenizer_name = 'roberta-base'
            elif tokenize == 'sbert':
                tokenizer_name = 'sentence-transformers/all-MiniLM-L6-v2'
            elif tokenize == 'sbert2':
                tokenizer_name = 'sentence-transformers/all-mpnet-base-v2'
            elif tokenize == 'simcse':
                tokenizer_name = 'princeton-nlp/sup-simcse-bert-base-uncased'
            elif tokenize == 't5':
                tokenizer_name = 't5-small'
            elif tokenize == 'relbert':
                tokenizer_name = 'pretrained_lms/sr-simbert/'
            else:
                raise ValueError('unsupported tokenizer: {}'.format(tokenize))
            
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
            pad_val = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
            self.rel_texts = np.full((self.num_kb_relation + 1, self.max_rel_words), pad_val, dtype=int)
            self.rel_texts_inv = np.full((self.num_kb_relation + 1, self.max_rel_words), pad_val, dtype=int)
            
            for rel_id,words in enumerate(rel_words):

                tokens = tokenizer.encode_plus(text=' '.join(words), max_length=self.max_rel_words,
                    padding='max_length', return_attention_mask=False, truncation=True)
                tokens_inv = tokenizer.encode_plus(text=' '.join(words[::-1]), max_length=self.max_rel_words,
                    padding='max_length', return_attention_mask=False, truncation=True)
                self.rel_texts[rel_id] = np.array(tokens['input_ids'])
                self.rel_texts_inv[rel_id] = np.array(tokens_inv['input_ids'])


        
        #print(rel_words)
        #print(len(rel_words), len(self.relation2id))
        assert len(rel_words) == len(self.relation2id)
        #print(self.rel_texts, self.max_rel_words)
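As a standalone sketch of the splitting rule above (the helper name is hypothetical, not part of the codebase), a Freebase-style relation keeps only the words of its last two dot-separated fields:

```python
def split_relation(rel: str):
    """Mimic build_rel_words: keep words of the last two dot-separated fields."""
    fields = rel.strip().split('.')
    try:
        return fields[-2].split('_') + fields[-1].split('_')
    except IndexError:
        # relations without two dot-separated fields fall back to UNK
        return ['UNK']

words = split_relation('people.person.place_of_birth')
# ['person', 'place', 'of', 'birth']
```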

    def create_kb_adj_mats(self, sample_id):
        """
        Re-builds local adj mats on the fly when data_eff == True (they are not pre-stored).
        """
        sample = self.data[sample_id]
        g2l = self.global2local_entity_maps[sample_id]
        
        # build local (head, rel, tail) lists from the sample's subgraph triples
        head_list = []
        rel_list = []
        tail_list = []
        for i, tpl in enumerate(sample['subgraph']['tuples']):
            sbj, rel, obj = tpl
            try:
                if isinstance(sbj, dict) and 'text' in sbj:
                    head = g2l[self.entity2id[sbj['text']]]
                    rel = self.relation2id[rel['text']]
                    tail = g2l[self.entity2id[obj['text']]]
                else:
                    head = g2l[self.entity2id[sbj]]
                    rel = self.relation2id[rel]
                    tail = g2l[self.entity2id[obj]]
            except (KeyError, TypeError):
                # entities/relations may already be integer ids in some datasets
                head = g2l[sbj]
                try:
                    rel = int(rel)
                except (ValueError, TypeError):
                    rel = self.relation2id[rel]
                tail = g2l[obj]
            head_list.append(head)
            rel_list.append(rel)
            tail_list.append(tail)
            if self.use_inverse_relation:
                head_list.append(tail)
                rel_list.append(rel + len(self.relation2id))
                tail_list.append(head)

        return np.array(head_list, dtype=int),  np.array(rel_list, dtype=int), np.array(tail_list, dtype=int)
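A minimal sketch (with made-up local ids, helper name hypothetical) of how `use_inverse_relation` doubles each triple: the inverse edge reuses the relation id shifted by the relation-vocabulary size, exactly as in the loop above:

```python
import numpy as np

def triples_with_inverse(triples, num_relations):
    """For each (head, rel, tail), also emit (tail, rel + num_relations, head)."""
    heads, rels, tails = [], [], []
    for h, r, t in triples:
        heads.append(h); rels.append(r); tails.append(t)
        heads.append(t); rels.append(r + num_relations); tails.append(h)
    return np.array(heads), np.array(rels), np.array(tails)

h, r, t = triples_with_inverse([(0, 2, 1)], num_relations=5)
# h = [0, 1], r = [2, 7], t = [1, 0]
```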

    
    def _build_fact_mat(self, sample_ids, fact_dropout):
        """
        Creates batched local adj mats (heads, rels, tails) with per-sample entity offsets and optional fact dropout.
        """
        batch_heads = np.array([], dtype=int)
        batch_rels = np.array([], dtype=int)
        batch_tails = np.array([], dtype=int)
        batch_ids = np.array([], dtype=int)
        #print(sample_ids)
        for i, sample_id in enumerate(sample_ids):
            index_bias = i * self.max_local_entity
            if self.data_eff:
                head_list, rel_list, tail_list = self.create_kb_adj_mats(sample_id) #kb_adj_mats[sample_id]
            else:
                (head_list, rel_list, tail_list) = self.kb_adj_mats[sample_id]
            num_fact = len(head_list)
            num_keep_fact = int(np.floor(num_fact * (1 - fact_dropout)))
            mask_index = np.random.permutation(num_fact)[: num_keep_fact]

            real_head_list = head_list[mask_index] + index_bias
            real_tail_list = tail_list[mask_index] + index_bias
            real_rel_list = rel_list[mask_index]
            batch_heads = np.append(batch_heads, real_head_list)
            batch_rels = np.append(batch_rels, real_rel_list)
            batch_tails = np.append(batch_tails, real_tail_list)
            batch_ids = np.append(batch_ids, np.full(len(mask_index), i, dtype=int))
            if self.use_self_loop:
                num_ent_now = len(self.global2local_entity_maps[sample_id])
                ent_array = np.array(range(num_ent_now), dtype=int) + index_bias
                rel_array = np.array([self.num_kb_relation - 1] * num_ent_now, dtype=int)
                batch_heads = np.append(batch_heads, ent_array)
                batch_tails = np.append(batch_tails, ent_array)
                batch_rels = np.append(batch_rels, rel_array)
                batch_ids = np.append(batch_ids, np.full(num_ent_now, i, dtype=int))
        fact_ids = np.array(range(len(batch_heads)), dtype=int)
        head_count = Counter(batch_heads)
        weight_list = [1.0 / head_count[head] for head in batch_heads]

        
        head_rels_batch = list(zip(batch_heads, batch_rels))
        #print(head_rels_batch)
        head_rels_count = Counter(head_rels_batch)
        weight_rel_list = [1.0 / head_rels_count[(h,r)] for (h,r) in head_rels_batch]

        #print(head_rels_count)

        # tail_count = Counter(batch_tails)

        # entity2fact_index = torch.LongTensor([batch_heads, fact_ids])
        # entity2fact_val = torch.FloatTensor(weight_list)
        # entity2fact_mat = torch.sparse.FloatTensor(entity2fact_index, entity2fact_val, torch.Size(
        #     [len(sample_ids) * self.max_local_entity, len(batch_heads)]))
        return batch_heads, batch_rels, batch_tails, batch_ids, fact_ids, weight_list, weight_rel_list
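A hedged sketch of the head-degree normalization computed above (helper name is illustrative): each fact is weighted by `1 / count(head)`, so the weights of all edges leaving one head entity sum to one:

```python
from collections import Counter

def edge_weights(batch_heads):
    """Weight each fact by the inverse frequency of its head entity."""
    head_count = Counter(batch_heads)
    return [1.0 / head_count[h] for h in batch_heads]

w = edge_weights([0, 0, 1])
# head 0 appears twice -> each of its edges weighs 0.5; head 1 once -> 1.0
```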


    def reset_batches(self, is_sequential=True):
        if is_sequential:
            self.batches = np.arange(self.num_data)
        else:
            self.batches = np.random.permutation(self.num_data)

    def _build_global2local_entity_maps(self):
        """Create a map from global entity id to local entity of each sample"""
        global2local_entity_maps = [None] * self.num_data
        total_local_entity = 0.0
        next_id = 0
        for sample in tqdm(self.data):
            g2l = dict()
            if 'entities_cid' in sample:
                self._add_entity_to_map(self.entity2id, sample['entities_cid'], g2l)
            else:
                self._add_entity_to_map(self.entity2id, sample['entities'], g2l)
            #self._add_entity_to_map(self.entity2id, sample['entities'], g2l)
            # construct a map from global entity id to local entity id
            self._add_entity_to_map(self.entity2id, sample['subgraph']['entities'], g2l)

            global2local_entity_maps[next_id] = g2l
            total_local_entity += len(g2l)
            self.max_local_entity = max(self.max_local_entity, len(g2l))
            next_id += 1
        print('avg local entity: ', total_local_entity / next_id)
        print('max local entity: ', self.max_local_entity)
        return global2local_entity_maps



    @staticmethod
    def _add_entity_to_map(entity2id, entities, g2l):
        #print(entities)
        #print(entity2id)
        for entity_global_id in entities:
            try:
                if isinstance(entity_global_id, dict) and 'text' in entity_global_id:
                    ent = entity2id[entity_global_id['text']]
                else:
                    ent = entity2id[entity_global_id]
                if ent not in g2l:
                    g2l[ent] = len(g2l)
            except KeyError:
                # entity is already a global id missing from entity2id; use it directly
                if entity_global_id not in g2l:
                    g2l[entity_global_id] = len(g2l)
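The mapping logic can be sketched in isolation (helper name hypothetical): entities receive local ids in first-seen order, so each sample's subgraph gets a compact, contiguous index space for the GNN:

```python
def add_entities(entity2id, entities, g2l):
    """Assign local ids in first-seen order (mirrors _add_entity_to_map)."""
    for e in entities:
        ent = entity2id[e]
        if ent not in g2l:
            g2l[ent] = len(g2l)
    return g2l

g2l = add_entities({'a': 10, 'b': 20, 'c': 30}, ['b', 'a', 'b', 'c'], {})
# {20: 0, 10: 1, 30: 2}
```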

    def deal_q_type(self, q_type=None):
        sample_ids = self.sample_ids
        if q_type is None:
            q_type = self.q_type
        if q_type == "seq":
            q_input = self.query_texts[sample_ids]
        else:
            raise NotImplementedError
        
        return q_input

    



class SingleDataLoader(BasicDataLoader):
    """
    Single Dataloader creates training/eval batches during KGQA.
    """
    def __init__(self, config, word2id, relation2id, entity2id, tokenize, data_type="train"):
        super(SingleDataLoader, self).__init__(config, word2id, relation2id, entity2id, tokenize, data_type)
        
    def get_batch(self, iteration, batch_size, fact_dropout, q_type=None, test=False):
        start = batch_size * iteration
        end = min(batch_size * (iteration + 1), self.num_data)
        sample_ids = self.batches[start: end]
        self.sample_ids = sample_ids
        # true_batch_id, sample_ids, seed_dist = self.deal_multi_seed(ori_sample_ids)
        # self.sample_ids = sample_ids
        # self.true_sample_ids = ori_sample_ids
        # self.batch_ids = true_batch_id
        true_batch_id = None
        seed_dist = self.seed_distribution[sample_ids]
        q_input = self.deal_q_type(q_type)
        kb_adj_mats = self._build_fact_mat(sample_ids, fact_dropout=fact_dropout)
        
        if test:
            return self.candidate_entities[sample_ids], \
                   self.query_entities[sample_ids], \
                   kb_adj_mats, \
                   q_input, \
                   seed_dist, \
                   true_batch_id, \
                   self.answer_dists[sample_ids], \
                   self.answer_lists[sample_ids]

        return self.candidate_entities[sample_ids], \
               self.query_entities[sample_ids], \
               kb_adj_mats, \
               q_input, \
               seed_dist, \
               true_batch_id, \
               self.answer_dists[sample_ids]


def load_dict(filename):
    word2id = dict()
    with open(filename, encoding='utf-8') as f_in:
        for line in f_in:
            word = line.strip()
            word2id[word] = len(word2id)
    return word2id
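A self-contained sketch of the vocabulary format `load_dict` expects (using an in-memory stream instead of a file): one token per line, with line order defining consecutive ids:

```python
import io

def load_vocab(lines):
    """Each stripped line gets the next consecutive id (mirrors load_dict)."""
    word2id = {}
    for line in lines:
        word2id[line.strip()] = len(word2id)
    return word2id

vocab = load_vocab(io.StringIO("the\ncat\nsat\n"))
# {'the': 0, 'cat': 1, 'sat': 2}
```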

def load_dict_int(filename):
    word2id = dict()
    with open(filename, encoding='utf-8') as f_in:
        for line in f_in:
            word = line.strip()
            word2id[int(word)] = int(word)
    return word2id

def load_data(config, tokenize):

    """
    Creates train/val/test dataloaders (separately).
    """
    if 'sr-cwq' in config['data_folder']:
        entity2id = load_dict_int(config['data_folder'] + config['entity2id'])
    else:
        entity2id = load_dict(config['data_folder'] + config['entity2id'])
    word2id = load_dict(config['data_folder'] + config['word2id'])
    relation2id = load_dict(config['data_folder'] + config['relation2id'])
    
    if config["is_eval"]:
        train_data = None
        valid_data = SingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="dev")
        test_data = SingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="test")
        num_word = test_data.num_word
    else:
        train_data = SingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="train")
        valid_data = SingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="dev")
        test_data = SingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="test")
        num_word = train_data.num_word
    relation_texts = test_data.rel_texts
    relation_texts_inv = test_data.rel_texts_inv
    entities_texts = None
    dataset = {
        "train": train_data,
        "valid": valid_data,
        "test": test_data, #test_data,
        "entity2id": entity2id,
        "relation2id": relation2id,
        "word2id": word2id,
        "num_word": num_word,
        "rel_texts": relation_texts,
        "rel_texts_inv": relation_texts_inv,
        "ent_texts": entities_texts
    }
    return dataset


if __name__ == "__main__":
    st = time.time()
    # NOTE: `args` is undefined in this module; a config must be built first, e.g.:
    #args = get_config()
    load_data(args)


================================================
FILE: gnn/dataset_load_graft.py
================================================
import json
import numpy as np
import re
from tqdm import tqdm
import torch
from collections import Counter
import random
import warnings
import pickle
warnings.filterwarnings("ignore")
from modules.question_encoding.tokenizers import LSTMTokenizer#, BERTTokenizer
from transformers import AutoTokenizer
import time

import os

from dataset_load import BasicDataLoader

class GraftBasicDataLoader(BasicDataLoader):
    """ 
    Extends BasicDataLoader with GraftNet-style entity-to-fact / fact-to-entity
    sparse adjacency construction.
    """
    def __init__(self, config, word2id, relation2id, entity2id, tokenize, data_type):
        super(GraftBasicDataLoader, self).__init__(config, word2id, relation2id, entity2id, tokenize, data_type)

    def create_kb_adj_mats_facts(self, sample_id):
        sample = self.data[sample_id]
        g2l = self.global2local_entity_maps[sample_id]
        entity2fact_e, entity2fact_f = [], []
        fact2entity_f, fact2entity_e = [], []
        kb_fact_rels = np.full(self.max_facts, self.num_kb_relation, dtype=int)
        for i, tpl in enumerate(sample['subgraph']['tuples']):
            sbj, rel, obj = tpl
            try:
                if isinstance(sbj, dict) and 'text' in sbj:
                    head = g2l[self.entity2id[sbj['text']]]
                    rel = self.relation2id[rel['text']]
                    tail = g2l[self.entity2id[obj['text']]]
                else:
                    head = g2l[self.entity2id[sbj]]
                    rel = self.relation2id[rel]
                    tail = g2l[self.entity2id[obj]]
            except (KeyError, TypeError):
                # entities/relations may already be integer ids in some datasets
                head = g2l[sbj]
                try:
                    rel = int(rel)
                except (ValueError, TypeError):
                    rel = self.relation2id[rel]
                tail = g2l[obj]
            if not self.use_inverse_relation:
                entity2fact_e += [head]
                entity2fact_f += [i]
                fact2entity_f += [i]
                fact2entity_e += [tail]
                kb_fact_rels[i] = rel
            else:
                entity2fact_e += [head, tail]
                entity2fact_f += [2 * i, 2 * i + 1]
                fact2entity_f += [2 * i, 2 * i + 1]
                fact2entity_e += [tail, head]
                kb_fact_rels[2 * i] = rel
                kb_fact_rels[2 * i + 1] = rel + len(self.relation2id)
        kb_adj_mats = (np.array(entity2fact_f, dtype=int), \
            np.array(entity2fact_e, dtype=int), np.array([1.0] * len(entity2fact_f))),\
            (np.array(fact2entity_e, dtype=int), np.array(fact2entity_f, dtype=int), np.array([1.0] * len(fact2entity_e)))
        return kb_adj_mats, kb_fact_rels

    def _build_fact_mat_maxfacts(self, sample_ids, fact_dropout):
        """Create sparse matrix representation for batched data"""
        kb_fact_rels = np.full((len(sample_ids), self.max_facts), self.num_kb_relation, dtype=int)

        mats0_batch = np.array([], dtype=int)
        mats0_0 = np.array([], dtype=int)
        mats0_1 = np.array([], dtype=int)
        vals0 = np.array([], dtype=float)

        mats1_batch = np.array([], dtype=int)
        mats1_0 = np.array([], dtype=int)
        mats1_1 = np.array([], dtype=int)
        vals1 = np.array([], dtype=float)

        for i, sample_id in enumerate(sample_ids):
            ((mat0_0, mat0_1, val0), (mat1_0, mat1_1, val1)), kb_fact_rel = self.create_kb_adj_mats_facts(sample_id)
            kb_fact_rels[i] = kb_fact_rel
            assert len(val0) == len(val1)
            num_fact = len(val0)
            num_keep_fact = int(np.floor(num_fact * (1 - fact_dropout)))
            mask_index = np.random.permutation(num_fact)[ : num_keep_fact]
            # mat0
            mats0_batch = np.append(mats0_batch, np.full(len(mask_index), i, dtype=int))
            mats0_0 = np.append(mats0_0, mat0_0[mask_index])
            mats0_1 = np.append(mats0_1, mat0_1[mask_index])
            vals0 = np.append(vals0, val0[mask_index])
            # mat1
            mats1_batch = np.append(mats1_batch, np.full(len(mask_index), i, dtype=int))
            mats1_0 = np.append(mats1_0, mat1_0[mask_index])
            mats1_1 = np.append(mats1_1, mat1_1[mask_index])
            vals1 = np.append(vals1, val1[mask_index])

        return ((mats0_batch, mats0_0, mats0_1, vals0), (mats1_batch, mats1_0, mats1_1, vals1)), kb_fact_rels
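A minimal sketch of the GraftNet-style bipartite layout built above (function name illustrative): each fact `i` contributes an entity-to-fact edge from its head and a fact-to-entity edge to its tail, so the adjacency is stored as index lists rather than a dense entity-by-entity matrix:

```python
def entity_fact_edges(triples):
    """Return (entity2fact, fact2entity) index pairs for facts 0..n-1."""
    entity2fact = []  # (fact_id, head_entity)
    fact2entity = []  # (tail_entity, fact_id)
    for i, (h, r, t) in enumerate(triples):
        entity2fact.append((i, h))
        fact2entity.append((t, i))
    return entity2fact, fact2entity

e2f, f2e = entity_fact_edges([(0, 5, 1), (1, 6, 2)])
# e2f = [(0, 0), (1, 1)], f2e = [(1, 0), (2, 1)]
```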



class GraftSingleDataLoader(GraftBasicDataLoader):
    """
    Single Dataloader creates training/eval batches during KGQA.
    """
    def __init__(self, config, word2id, relation2id, entity2id, tokenize, data_type="train"):
        super(GraftSingleDataLoader, self).__init__(config, word2id, relation2id, entity2id, tokenize, data_type)
        
    def get_batch(self, iteration, batch_size, fact_dropout, q_type=None, test=False):
        start = batch_size * iteration
        end = min(batch_size * (iteration + 1), self.num_data)
        sample_ids = self.batches[start: end]
        self.sample_ids = sample_ids
        # true_batch_id, sample_ids, seed_dist = self.deal_multi_seed(ori_sample_ids)
        # self.sample_ids = sample_ids
        # self.true_sample_ids = ori_sample_ids
        # self.batch_ids = true_batch_id
        true_batch_id = None
        seed_dist = self.seed_distribution[sample_ids]
        q_input = self.deal_q_type(q_type)
        kb_adj_mats = self._build_fact_mat(sample_ids, fact_dropout=fact_dropout)
        kb_fact_rels = self.kb_fact_rels[sample_ids]
        kb_adj_mats_graft, _ = self._build_fact_mat_maxfacts(sample_ids, fact_dropout=fact_dropout)
        
        if test:
            return self.candidate_entities[sample_ids], \
                   self.query_entities[sample_ids], \
                   kb_adj_mats, \
                   kb_adj_mats_graft, \
                   q_input, \
                   kb_fact_rels, \
                   seed_dist, \
                   true_batch_id, \
                   self.answer_dists[sample_ids], \
                   self.answer_lists[sample_ids]

        return self.candidate_entities[sample_ids], \
               self.query_entities[sample_ids], \
               kb_adj_mats, \
               kb_adj_mats_graft, \
               q_input, \
               kb_fact_rels, \
               seed_dist, \
               true_batch_id, \
               self.answer_dists[sample_ids]


def load_dict(filename):
    word2id = dict()
    with open(filename, encoding='utf-8') as f_in:
        for line in f_in:
            word = line.strip()
            word2id[word] = len(word2id)
    return word2id

def load_dict_int(filename):
    word2id = dict()
    with open(filename, encoding='utf-8') as f_in:
        for line in f_in:
            word = line.strip()
            word2id[int(word)] = int(word)
    return word2id

def load_data_graft(config, tokenize):

    """
    Creates train/val/test dataloaders (separately).
    """
    if 'sr-cwq' in config['data_folder']:
        entity2id = load_dict_int(config['data_folder'] + config['entity2id'])
    else:
        entity2id = load_dict(config['data_folder'] + config['entity2id'])
    word2id = load_dict(config['data_folder'] + config['word2id'])
    relation2id = load_dict(config['data_folder'] + config['relation2id'])
    
    if config["is_eval"]:
        train_data = None
        valid_data = GraftSingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="dev")
        test_data = GraftSingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="test")
        num_word = test_data.num_word
    else:
        train_data = GraftSingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="train")
        valid_data = GraftSingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="dev")
        test_data = GraftSingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="test")
        num_word = train_data.num_word
    relation_texts = test_data.rel_texts
    relation_texts_inv = test_data.rel_texts_inv
    entities_texts = None
    dataset = {
        "train": train_data,
        "valid": valid_data,
        "test": test_data, #test_data,
        "entity2id": entity2id,
        "relation2id": relation2id,
        "word2id": word2id,
        "num_word": num_word,
        "rel_texts": relation_texts,
        "rel_texts_inv": relation_texts_inv,
        "ent_texts": entities_texts
    }
    return dataset


if __name__ == "__main__":
    st = time.time()
    # NOTE: `args` is undefined in this module; a config must be built first, e.g.:
    #args = get_config()
    load_data_graft(args)


================================================
FILE: gnn/evaluate.py
================================================

from tqdm import tqdm
tqdm.monitor_iterval = 0
import torch
import numpy as np
import math, os
import json
import pickle

def cal_accuracy(pred, answer_dist):
    """
    pred: batch_size
    answer_dist: batch_size, max_local_entity
    """
    num_correct = 0.0
    num_answerable = 0.0
    for i, l in enumerate(pred):
        num_correct += (answer_dist[i, l] != 0)
    for dist in answer_dist:
        if np.sum(dist) != 0:
            num_answerable += 1
    return num_correct / len(pred), num_answerable / len(pred)


def f1_and_hits(answers, candidate2prob, id2entity, entity2name, eps=0.5):
    """Returns (precision, recall, f1, hits, em, case_id, retrieved, answers)."""
    ans = []
    retrieved = []
    for a in answers:
        if entity2name is None:
            ans.append(id2entity[a])
        else:
            ans.append(entity2name[id2entity[a]])
    correct = 0
    cand_list = sorted(candidate2prob, key=lambda x:x[1], reverse=True)
    if len(cand_list) == 0:
        best_ans = -1
    else:
        best_ans = cand_list[0][0]
    # max_prob = cand_list[0][1]
    tp_prob = 0.0
    for c, prob in cand_list:
        if entity2name is None:
            retrieved.append((id2entity[c], prob))
        else:
           retrieved.append((entity2name[id2entity[c]], prob))
        tp_prob += prob
        if c in answers:
            correct += 1
        if tp_prob > eps:
            break
    if correct > 0:
        em = 1
    else:
        em = 0
    if len(answers) == 0:
        if len(retrieved) == 0:
            return 1.0, 1.0, 1.0, 1.0, 1.0, 0, retrieved, ans
        else:
            return 0.0, 1.0, 0.0, 1.0, 1.0, 1, retrieved, ans
    else:
        hits = float(best_ans in answers)
        if len(retrieved) == 0:
            return 1.0, 0.0, 0.0, hits, hits, 2, retrieved, ans
        else:
            p, r = correct / len(retrieved), correct / len(answers)
            f1 = 2.0 / (1.0 / p + 1.0 / r) if p != 0 and r != 0 else 0.0
            return p, r, f1, hits, em, 3, retrieved, ans
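A simplified, standalone sketch of the thresholding above (function name hypothetical): candidates are taken in decreasing-probability order until their cumulative mass exceeds `eps`, and precision/recall/F1 are scored over that retrieved set:

```python
def prf1(answers, candidate2prob, eps=0.5):
    """Keep top candidates until cumulative prob > eps, then score them."""
    cand = sorted(candidate2prob, key=lambda x: x[1], reverse=True)
    retrieved, tp_prob, correct = [], 0.0, 0
    for c, p in cand:
        retrieved.append(c)
        tp_prob += p
        if c in answers:
            correct += 1
        if tp_prob > eps:
            break
    if not retrieved or not answers:
        return 0.0, 0.0, 0.0
    prec, rec = correct / len(retrieved), correct / len(answers)
    f1 = 2 * prec * rec / (prec + rec) if prec and rec else 0.0
    return prec, rec, f1

p, r, f = prf1({1, 2}, [(1, 0.4), (3, 0.3), (2, 0.2)])
# cumulative prob passes 0.5 after two candidates -> (0.5, 0.5, 0.5)
```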


class Evaluator:

    def __init__(self, args, model, entity2id, relation2id, device):
        self.model = model
        self.args = args
        self.eps = args['eps']
        self.model_name = args["model_name"]
        
        id2entity = {idx: entity for entity, idx in entity2id.items()}
        self.id2entity = id2entity

        self.entity2name = None
        if 'sr-' in args["data_folder"]:
            file = open('ent2id.pickle', 'rb')
            self.entity2name = list((pickle.load(file)).keys())
            file.close()

            
        id2relation = {idx: relation for relation, idx in relation2id.items()}
        num_rel_ori = len(relation2id)

        if 'use_inverse_relation' in args:
            self.use_inverse_relation = args['use_inverse_relation']
            if self.use_inverse_relation:
                for i in range(len(id2relation)):
                    id2relation[i + num_rel_ori] = id2relation[i] + "_rev"

        if 'use_self_loop' in args:
            self.use_self_loop = args['use_self_loop']
            if self.use_self_loop:
                id2relation[len(id2relation)] = "self_loop"

        self.id2relation = id2relation
        self.file_write = None
        self.device = device

    def write_info(self, valid_data, tp_list, num_step):
        question_list = valid_data.get_quest()
        #num_step = steps
        obj_list = []
        if tp_list is not None:
            # attn_list = [tp[1] for tp in tp_list]
            action_list = [tp[0] for tp in tp_list]
        for i in range(len(question_list)):
            obj_list.append({})
        for j in range(num_step):
            if tp_list is None:
                actions = None
            else:
                actions = action_list[j]
                actions = actions.cpu().numpy()
            # if attn_list is not None:
            #     attention = attn_list[j].cpu().numpy()
            for i in range(len(question_list)):
                tp_obj = obj_list[i]
                q = question_list[i]
                # real_index = self.true_batch_id[i][0]
                tp_obj['question'] = q
                tp_obj[j] = {}
                # print(actions)
                if tp_list is not None:
                    action = actions[i]
                    rel_action = self.id2relation[action]
                    tp_obj[j]['rel_action'] = rel_action
                    tp_obj[j]['action'] = str(action)
                    # if attn_list is not None:
                    #     attention_tp = attention[i]
                    #     tp_obj[j]['attention'] = attention_tp.tolist()
        return obj_list

    def evaluate(self, valid_data, test_batch_size=20, write_info=False):
        write_info = True  # NOTE: overrides the argument; detailed results are always written
        self.model.eval()
        self.count = 0
        eps = self.eps
        id2entity = self.id2entity
        eval_loss, eval_acc, eval_max_acc = [], [], []
        f1s, hits, ems,  precisions, recalls = [], [], [], [], []
        valid_data.reset_batches(is_sequential=True)
        num_epoch = math.ceil(valid_data.num_data / test_batch_size)
        if write_info and self.file_write is None:
            filename = os.path.join(self.args['checkpoint_dir'],
                                    "{}_test.info".format(self.args['experiment_name']))
            self.file_write = open(filename, "w")
        case_ct = {}
        max_local_entity = valid_data.max_local_entity
        ignore_prob = (1 - eps) / max_local_entity
        for iteration in tqdm(range(num_epoch)):
            batch = valid_data.get_batch(iteration, test_batch_size, fact_dropout=0.0, test=True)
            with torch.no_grad():
                loss, extras, pred_dist, tp_list = self.model(batch[:-1])
                pred = torch.max(pred_dist, dim=1)[1]
            if self.model_name == 'GraftNet':
                local_entity, query_entities, _, _, query_text, _, \
                seed_dist, true_batch_id, answer_dist, answer_list = batch
            else:
                local_entity, query_entities, _, query_text, \
                seed_dist, true_batch_id, answer_dist, answer_list = batch
            # self.true_batch_id = true_batch_id
            if write_info:
                obj_list = self.write_info(valid_data, tp_list, self.model.num_iter)
                # pred_sum = torch.sum(pred_dist, dim=1)
                # print(pred_sum)
            candidate_entities = torch.from_numpy(local_entity).type('torch.LongTensor')
            true_answers = torch.from_numpy(answer_dist).type('torch.FloatTensor')
            query_entities = torch.from_numpy(query_entities).type('torch.LongTensor')
            # acc, max_acc = cal_accuracy(pred, true_answers.cpu().numpy())
            eval_loss.append(loss.item())
            # eval_acc.append(acc)
            # eval_max_acc.append(max_acc)
            #pr_dist2 = pred_dist#.copy()
            #pred_dist = pr_dist2[-1]
            batch_size = pred_dist.size(0)
            batch_answers = answer_list
            batch_candidates = candidate_entities
            pad_ent_id = len(id2entity)
            #pr_dist2 = pred_dist.copy()
            #for pred_dist in pr_dist2:
            for batch_id in range(batch_size):
                answers = batch_answers[batch_id]
                candidates = batch_candidates[batch_id, :].tolist()
                probs = pred_dist[batch_id, :].tolist()
                seed_entities = query_entities[batch_id, :].tolist()
                #print(seed_entities)
                #print(candidates)
                candidate2prob = []
                for c, p, s in zip(candidates, probs, seed_entities):
                    if s == 1.0:
                        # ignore seed entities
                        #print(c, self.id2entity)
                        # print(c, p, s)
                        # if c < pad_ent_id:
                        #     tp_obj['seed'] = self.id2entity[c]
                        continue
                    if c == pad_ent_id:
                        continue
                    if p < ignore_prob:
                        continue
                    candidate2prob.append((c, p))
                precision, recall, f1, hit, em, case, retrieved, ans = f1_and_hits(answers, candidate2prob, self.id2entity, self.entity2name, eps)
                if write_info:
                    tp_obj = obj_list[batch_id]
                    tp_obj['answers'] = ans
                    tp_obj['precision'] = precision
                    tp_obj['recall'] = recall
                    tp_obj['f1'] = f1
                    tp_obj['hit'] = hit
                    tp_obj['em'] = em
                    tp_obj['cand'] = retrieved
                    self.file_write.write(json.dumps(tp_obj) + "\n")
                case_ct.setdefault(case, 0)
                case_ct[case] += 1
                f1s.append(f1)
                hits.append(hit)
                ems.append(em)
                precisions.append(precision)
                recalls.append(recall)
        print('evaluation.......')
        print('how many eval samples......', len(f1s))
        print('avg_em', np.mean(ems))
        print('avg_hits', np.mean(hits))
        print('avg_f1', np.mean(f1s))
        print('avg_precision', np.mean(precisions))
        print('avg_recall', np.mean(recalls))
        
        print(case_ct)
        if write_info:
            self.file_write.close()
            self.file_write = None
        return np.mean(f1s), np.mean(hits), np.mean(ems)
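The evaluation loop above drops seed entities, padding ids, and low-probability candidates before scoring each prediction set against the gold answers. A standalone sketch of that logic, with hypothetical helpers `filter_candidates` and `f1_and_hit` (the repo's `f1_and_hits` additionally maps ids to entity names and returns `em`, a case label, and the retrieved candidates, all omitted here):

```python
def filter_candidates(candidates, probs, seed_flags, pad_ent_id, ignore_prob=0.0):
    """Drop seed entities, padding ids, and low-probability candidates."""
    kept = []
    for c, p, s in zip(candidates, probs, seed_flags):
        if s == 1.0 or c == pad_ent_id or p < ignore_prob:
            continue
        kept.append((c, p))
    return kept


def f1_and_hit(answers, candidate2prob, eps=0.5):
    """Keep candidates scoring at least eps * max_prob, then compute
    precision / recall / F1 and Hit@1 against the gold answer set."""
    if not candidate2prob:
        return 0.0, 0.0, 0.0, 0.0  # precision, recall, f1, hit
    best = max(p for _, p in candidate2prob)
    preds = [c for c, p in candidate2prob if p >= eps * best]
    # Hit@1: is the single top-scored candidate a gold answer?
    top = max(candidate2prob, key=lambda cp: cp[1])[0]
    hit = 1.0 if top in answers else 0.0
    correct = len(set(preds) & set(answers))
    precision = correct / len(preds)
    recall = correct / len(answers) if answers else 0.0
    f1 = 2 * precision * recall / (precision + recall) if correct else 0.0
    return precision, recall, f1, hit
```

Here `eps` is treated as a relative cutoff on the top probability; the exact thresholding inside `f1_and_hits` may differ.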





================================================
FILE: gnn/main.py
================================================
import argparse

from utils import create_logger
import torch
import numpy as np
import os
import time
from train_model import Trainer_KBQA
from parsing import add_parse_args

parser = argparse.ArgumentParser()
add_parse_args(parser)

args = parser.parse_args()
args.use_cuda = torch.cuda.is_available()

np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.experiment_name is None:
    timestamp = str(int(time.time()))
    args.experiment_name = "{}-{}-{}".format(
        args.dataset,
        args.model_name,
        timestamp,
    )


def main():
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    logger = create_logger(args)
    trainer = Trainer_KBQA(args=vars(args), model_name=args.model_name, logger=logger)
    if not args.is_eval:
        trainer.train(0, args.num_epoch - 1)
    else:
        assert args.load_experiment is not None, "--load_experiment is required for evaluation"
        ckpt_path = os.path.join(args.checkpoint_dir, args.load_experiment)
        print("Loading pre-trained model from {}".format(ckpt_path))
        trainer.evaluate_single(ckpt_path)


if __name__ == '__main__':
    main()
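The seeding and experiment-naming steps in this script can be factored into small, testable helpers. A minimal sketch (function names are ours; only Python's `random` module is seeded here, while the script itself seeds NumPy and PyTorch):

```python
import random
import time


def set_seed(seed):
    """Seed the RNG for reproducibility (the script additionally calls
    np.random.seed and torch.manual_seed with the same value)."""
    random.seed(seed)


def default_experiment_name(dataset, model_name, now=None):
    """Fallback experiment name: <dataset>-<model>-<unix timestamp>."""
    ts = str(int(now if now is not None else time.time()))
    return "{}-{}-{}".format(dataset, model_name, ts)
```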


================================================
FILE: gnn/models/GraftNet/graftnet.py
================================================
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn

from models.base_model import BaseModel
from modules.kg_reasoning.graft_gnn import GraftLayer
from modules.question_encoding.lstm_encoder import LSTMInstruction
from modules.question_encoding.bert_encoder import BERTInstruction

from modules.layer_init import TypeLayer
from modules.query_update import AttnEncoder



VERY_SMALL_NUMBER = 1e-10
VERY_NEG_NUMBER = -100000000000


class GraftNet(BaseModel):
    def __init__(self, args, num_entity, num_relation, num_word):
        """
        num_relation: number of relations, including self-connections
        """
        super(GraftNet, self).__init__(args, num_entity, num_relation, num_word)
        self.num_layer = args['num_layer']
        self.loss_type =  args['loss_type']
        self.model_name = args['model_name'].lower()
        self.lm = args['lm']
        self.norm_rel = args['norm_rel']
        self.num_iter = self.num_layer
        self.layers(args)
        self.private_module_def(args, num_entity, num_relation)
        self.to(self.device)

    def layers(self, args):
        # initialize entity embedding
        word_dim = self.word_dim
        kg_dim = self.kg_dim
        entity_dim = self.entity_dim

        self.linear_dropout = args['linear_dropout']
        
        self.entity_linear = nn.Linear(in_features=self.ent_dim, out_features=entity_dim)
        self.relation_linear1 = nn.Linear(in_features=self.rel_dim, out_features=entity_dim)

        # dropout
        self.linear_drop = nn.Dropout(p=self.linear_dropout)

        if self.encode_type:
            self.type_layer = TypeLayer(in_features=entity_dim, out_features=entity_dim,
                                        linear_drop=self.linear_drop, device=self.device, norm_rel=self.norm_rel)

        self.self_att_r = AttnEncoder(self.entity_dim)
        self.kld_loss = nn.KLDivLoss(reduction='none')
        self.bce_loss_logits = nn.BCEWithLogitsLoss(reduction='none')
        self.mse_loss = torch.nn.MSELoss()
    
    
    def private_module_def(self, args, num_entity, num_relation):
        # initialize entity embedding
        word_dim = self.word_dim
        kg_dim = self.kg_dim
        entity_dim = self.entity_dim
        self.reasoning = GraftLayer(args, num_entity, num_relation, entity_dim)
        if args['lm'] == 'lstm':
            self.instruction = LSTMInstruction(args, self.word_embedding, self.num_word)
        else:
            self.instruction = BERTInstruction(args, self.word_embedding, self.num_word, args['lm'])
            self.relation_linear = nn.Linear(in_features=self.word_dim, out_features=entity_dim)

    def get_ent_init(self, local_entity, kb_adj_mat, rel_features):
        if self.encode_type:
            local_entity_emb = self.type_layer(local_entity=local_entity,
                                               edge_list=kb_adj_mat,
                                               rel_features=rel_features)
        else:
            local_entity_emb = self.entity_embedding(local_entity)  # batch_size, max_local_entity, word_dim
            local_entity_emb = self.entity_linear(local_entity_emb)
        
        return local_entity_emb
    
    def get_rel_feature(self):
        if self.rel_texts is None:
            rel_features = self.relation_embedding.weight
            rel_features = self.relation_linear1(rel_features)
        else:
            rel_features = self.instruction.question_emb(self.rel_features)
            rel_features = self.self_att_r(rel_features, (self.rel_texts != self.instruction.pad_val).float())
            if self.lm == 'lstm':
                rel_features = self.self_att_r(rel_features, (self.rel_texts != self.num_relation + 1).float())

        return rel_features
    
    
    def init_reason(self, curr_dist, local_entity, kb_adj_mat, kb_adj_mat_graft, kb_fact_rel, q_input):
        # batch_size = local_entity.size(0)
        self.local_entity = local_entity
        self.instruction_list, self.attn_list = self.instruction(q_input)
        self.query_hidden_emb = self.instruction.query_hidden_emb
        self.query_node_emb = self.instruction.query_node_emb
        self.query_mask = self.instruction.query_mask
        rel_features = self.get_rel_feature()
        self.local_entity_emb = self.get_ent_init(local_entity, kb_adj_mat, rel_features)
        self.curr_dist = curr_dist
        self.dist_history = []
        self.action_probs = []
        self.seed_entities = curr_dist


        self.reasoning.init_reason(local_entity=local_entity,
                                   kb_adj_mat=kb_adj_mat,
                                   kb_adj_mat_graft=kb_adj_mat_graft,
                                   kb_fact_rel = kb_fact_rel,
                                   local_entity_emb=self.local_entity_emb,
                                   rel_features=rel_features,
                                   query_node_emb=self.query_node_emb)
    
    def calc_loss_label(self, curr_dist, teacher_dist, label_valid):
        tp_loss = self.get_loss(pred_dist=curr_dist, answer_dist=teacher_dist, reduction='none')
        tp_loss = tp_loss * label_valid
        cur_loss = torch.sum(tp_loss) / curr_dist.size(0)
        return cur_loss


    def forward(self, batch, training=False):
        local_entity, query_entities, kb_adj_mat ,kb_adj_mat_graft,  query_text, kb_fact_rel, seed_dist, true_batch_id, answer_dist = batch
        local_entity = torch.from_numpy(local_entity).type('torch.LongTensor').to(self.device)

        # local_entity_mask = (local_entity != self.num_entity).float()
        query_entities = torch.from_numpy(query_entities).type('torch.FloatTensor').to(self.device)
        answer_dist = torch.from_numpy(answer_dist).type('torch.FloatTensor').to(self.device)
        seed_dist = torch.from_numpy(seed_dist).type('torch.FloatTensor').to(self.device)
        current_dist = Variable(seed_dist, requires_grad=True)

        q_input = torch.from_numpy(query_text).type('torch.LongTensor').to(self.device)
        if self.lm == 'bert':
            query_mask = (q_input != 0).float()  # BERT pad token id is 0
        else:
            query_mask = (q_input != self.num_word).float()

        
        #instruction generation
        self.init_reason(curr_dist=current_dist, local_entity=local_entity,
                         kb_adj_mat=kb_adj_mat, kb_adj_mat_graft=kb_adj_mat_graft, kb_fact_rel=kb_fact_rel, q_input=q_input)
        self.instruction.init_reason(q_input)

        
        #reasoning
        self.curr_dist = current_dist   
        self.ent_dist = current_dist
        self.dist_history.append(self.curr_dist)
        for i in range(self.num_layer):
            score_tp, score, self.curr_dist= self.reasoning(self.curr_dist, self.query_hidden_emb, self.query_mask, step=i, return_score=True)
            self.dist_history.append(score)

        pred_dist = self.dist_history[-1]
        answer_number = torch.sum(answer_dist, dim=1, keepdim=True)
        case_valid = (answer_number > 0).float()
        loss = self.calc_loss_label(curr_dist=score_tp, teacher_dist=answer_dist, label_valid=case_valid)
        pred = torch.max(pred_dist, dim=1)[1]

        if training:
            h1, f1 = self.get_eval_metric(pred_dist, answer_dist)
            tp_list = [h1.tolist(), f1.tolist()]
        else:
            tp_list = None
        return loss, pred, pred_dist, tp_list
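`calc_loss_label` zeroes the loss of training cases that have no answers (`case_valid`) and divides the summed loss by the full batch size rather than by the number of valid cases. A NumPy stand-in of that pattern (hypothetical: the repo's `get_loss` depends on `loss_type`, so a plain cross-entropy term is substituted here):

```python
import numpy as np


def calc_loss_label(pred_dist, answer_dist, eps=1e-10):
    """Cross-entropy per example, masked by answer availability, then
    averaged over the whole batch (including masked-out examples)."""
    per_example = -np.sum(answer_dist * np.log(pred_dist + eps), axis=1)
    case_valid = (answer_dist.sum(axis=1) > 0).astype(float)
    return float(np.sum(per_example * case_valid) / pred_dist.shape[0])
```

Dividing by the batch size rather than the valid-case count means no-answer cases still dilute the gradient signal, matching the original normalization.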

================================================
FILE: gnn/models/NSM/nsm.py
================================================
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn

from models.base_model import BaseModel
from modules.kg_reasoning.nsm_gnn import NSMLayer, NSMLayer_back
from modules.question_encoding.lstm_encoder import LSTMInstruction
from modules.question_encoding.bert_encoder import BERTInstruction
from modules.layer_init import TypeLayer
from modules.query_update import AttnEncoder


VERY_SMALL_NUMBER = 1e-10
VERY_NEG_NUMBER = -100000000000


class NSM(BaseModel):
    def __init__(self, args, num_entity, num_relation, num_word):
        """
        num_relation: number of relations, including self-connections
        """
        super(NSM, self).__init__(args, num_entity, num_relation, num_word)
        self.num_step = args['num_step']
        self.num_iter = self.num_step
        self.loss_type =  args['loss_type']
        self.model_name = args['model_name'].lower()
        self.lambda_constrain = args['lambda_constrain']
        self.lambda_back = args['lambda_back']
        self.lm = args['lm']
        self.norm_rel = args['norm_rel']
        self.layers(args)
        self.private_module_def(args, num_entity, num_relation)
        self.to(self.device)

    def layers(self, args):
        # initialize entity embedding
        word_dim = self.word_dim
        kg_dim = self.kg_dim
        entity_dim = self.entity_dim

        self.linear_dropout = args['linear_dropout']
        
        self.entity_linear = nn.Linear(in_features=self.ent_dim, out_features=entity_dim)
        
        self.relation_linear1 = nn.Linear(in_features=self.rel_dim, out_features=entity_dim)
        self.relation_linear2 = nn.Linear(in_features=self.rel_dim, out_features=entity_dim)

        self.kg_lin = nn.Linear(in_features=entity_dim, out_features=entity_dim)
        self.softmax_d1 = nn.Softmax(dim=1)
        self.score_func = nn.Linear(in_features=2*entity_dim, out_features=1)
        # dropout
        self.linear_drop = nn.Dropout(p=self.linear_dropout)

        if self.encode_type:
            self.type_layer = TypeLayer(in_features=entity_dim, out_features=entity_dim,
                                        linear_drop=self.linear_drop, device=self.device, norm_rel=self.norm_rel)

        
        self.self_att_r = AttnEncoder(self.entity_dim)
        self.self_att_r2 = AttnEncoder(self.entity_dim)
        self.kld_loss = nn.KLDivLoss(reduction='none')
        self.kld_loss_1 = nn.KLDivLoss(reduction='none')
        self.bce_loss_logits = nn.BCEWithLogitsLoss(reduction='none')
        self.mse_loss = torch.nn.MSELoss()
    
    
    def private_module_def(self, args, num_entity, num_relation):
        # initialize entity embedding
        word_dim = self.word_dim
        kg_dim = self.kg_dim
        entity_dim = self.entity_dim
        self.reasoning = NSMLayer(args, num_entity, num_relation, entity_dim)
        self.reasoning2 = NSMLayer(args, num_entity, num_relation, entity_dim)
        if self.lambda_back != 0.0 or self.lambda_constrain != 0.0:
            self.reasoning_back = NSMLayer_back(args, num_entity, num_relation, entity_dim)
        if args['lm'] == 'lstm':
            self.instruction = LSTMInstruction(args, self.word_embedding, self.num_word)
        else:
            self.instruction = BERTInstruction(args, self.word_embedding, self.num_word, args['lm'])
            self.relation_linear = nn.Linear(in_features=self.word_dim, out_features=entity_dim)
            
    def get_ent_init(self, local_entity, kb_adj_mat, rel_features):
        if self.encode_type:
            local_entity_emb = self.type_layer(local_entity=local_entity,
                                               edge_list=kb_adj_mat,
                                               rel_features=rel_features)
        else:
            local_entity_emb = self.entity_embedding(local_entity)  # batch_size, max_local_entity, word_dim
            local_entity_emb = self.entity_linear(local_entity_emb)
        
        return local_entity_emb


    def get_rel_feature(self):
        if self.rel_texts is None:
            rel_features = self.relation_embedding.weight
            rel_features = self.relation_linear1(rel_features)
        else:
            rel_features = self.instruction.question_emb(self.rel_features)
            rel_features = self.self_att_r(rel_features, (self.rel_texts != self.instruction.pad_val).float())
            if self.lm == 'lstm':
                rel_features = self.self_att_r(rel_features, (self.rel_texts != self.num_relation + 1).float())

        return rel_features

    
    def init_reason(self, curr_dist, local_entity, kb_adj_mat, q_input):
        # batch_size = local_entity.size(0)
        self.local_entity = local_entity
        self.instruction_list, self.attn_list = self.instruction(q_input)
        rel_features = self.get_rel_feature()
        self.local_entity_emb = self.get_ent_init(local_entity, kb_adj_mat, rel_features)
        self.curr_dist = curr_dist
        self.dist_history = []
        self.dist_history2 = []
        self.backward_history = []
        self.action_probs = []
        self.seed_entities = curr_dist


        self.reasoning.init_reason(local_entity=local_entity,
                                   kb_adj_mat=kb_adj_mat,
                                   local_entity_emb=self.local_entity_emb,
                                   rel_features=rel_features)

        if self.lambda_back != 0.0 or self.lambda_constrain != 0.0:
            self.reasoning_back.init_reason(local_entity=local_entity,
                                   kb_adj_mat=kb_adj_mat,
                                   local_entity_emb=self.local_entity_emb,
                                   rel_features=rel_features)
    
    def get_js_div(self, dist_1, dist_2):
        mean_dist = (dist_1 + dist_2) / 2
        log_mean_dist = torch.log(mean_dist + 1e-8)
        # Jensen-Shannon divergence: average KL of each distribution to the mixture
        loss = 0.5 * (self.kld_loss_1(log_mean_dist, dist_1) + self.kld_loss_1(log_mean_dist, dist_2))
        return loss
    
    def calc_loss_backward(self, case_valid):
        back_loss = None
        constrain_loss = None
        for i in range(self.num_step):
            forward_dist = self.dist_history[i]
            backward_dist = self.backward_history[i]
            if i == 0:
                # back_loss = self.get_loss_new(backward_dist, forward_dist)
                back_loss = self.calc_loss_label(curr_dist=backward_dist,
                                                 teacher_dist=forward_dist,
                                                 label_valid=case_valid)
                # backward last step should be similar with seed distribution
            else:
                tp_loss = self.get_js_div(forward_dist, backward_dist)
                tp_loss = torch.sum(tp_loss * case_valid) / forward_dist.size(0)
                if constrain_loss is None:
                    constrain_loss = tp_loss
                else:
                    constrain_loss += tp_loss
        return back_loss, constrain_loss

    def calc_loss_label(self, curr_dist, teacher_dist, label_valid):
        tp_loss = self.get_loss(pred_dist=curr_dist, answer_dist=teacher_dist, reduction='none')
        tp_loss = tp_loss * label_valid
        cur_loss = torch.sum(tp_loss) / curr_dist.size(0)
        return cur_loss


    def forward(self, batch, training=False):
        local_entity, query_entities, kb_adj_mat, query_text, seed_dist, true_batch_id,  answer_dist = batch
        local_entity = torch.from_numpy(local_entity).type('torch.LongTensor').to(self.device)

        # local_entity_mask = (local_entity != self.num_entity).float()
        query_entities = torch.from_numpy(query_entities).type('torch.FloatTensor').to(self.device)
        answer_dist = torch.from_numpy(answer_dist).type('torch.FloatTensor').to(self.device)
        seed_dist = torch.from_numpy(seed_dist).type('torch.FloatTensor').to(self.device)
        current_dist = Variable(seed_dist, requires_grad=True)

        q_input = torch.from_numpy(query_text).type('torch.LongTensor').to(self.device)
        if self.lm != 'lstm':
            pad_val = self.instruction.pad_val  # tokenizer pad token id
            query_mask = (q_input != pad_val).float()
        else:
            query_mask = (q_input != self.num_word).float()

        
        """
        Instruction generations
        """
        self.init_reason(curr_dist=current_dist, local_entity=local_entity,
                         kb_adj_mat=kb_adj_mat,  q_input=q_input)
        self.instruction.init_reason(q_input)
        
        for i in range(self.num_step):
            relational_ins, attn_weight = self.instruction.get_instruction(self.instruction.relational_ins, step=i)
            self.instruction.instructions.append(relational_ins.unsqueeze(1))
            self.instruction.relational_ins = relational_ins
        
        """
        GNN reasoning
        """
        self.curr_dist = current_dist    
        self.dist_history.append(self.curr_dist)
        self.dist_history2.append(self.curr_dist)

        for i in range(self.num_step):
            
            self.curr_dist = self.reasoning(self.curr_dist, self.instruction_list[i], step=i)
            self.dist_history.append(self.curr_dist)

        """
        NSM backward learning (if used)
        """
        if self.lambda_back != 0.0 or self.lambda_constrain != 0.0:
            answer_len = torch.sum(answer_dist, dim=1, keepdim=True)
            answer_len[answer_len == 0] = 1.0
            answer_prob = answer_dist.div(answer_len)
            self.curr_dist_back = answer_prob
            self.backward_history.append(self.curr_dist_back)
            for i in range(self.num_step):
                self.curr_dist_back = self.reasoning_back(self.curr_dist_back, self.instruction_list[self.num_step-i-1], step=i)
                self.backward_history.append(self.curr_dist_back)

        pred_dist = self.dist_history[-1]
        answer_number = torch.sum(answer_dist, dim=1, keepdim=True)
        case_valid = (answer_number > 0).float()
        # filter no answer training case
        # loss = torch.sum(tp_loss * case_valid) / pred_dist.size(0)
        loss = self.calc_loss_label(curr_dist=pred_dist, teacher_dist=answer_dist, label_valid=case_valid)
        
        if self.lambda_back > 0.0 or self.lambda_constrain > 0.0:
             back_loss, constrain_loss = self.calc_loss_backward(case_valid)
             loss = loss + self.lambda_back * back_loss + self.lambda_constrain * constrain_loss
        pred = torch.max(pred_dist, dim=1)[1]
        if training:
            h1, f1 = self.get_eval_metric(pred_dist, answer_dist)
            tp_list = [h1.tolist(), f1.tolist()]
        else:
            tp_list = None
        return loss, pred, pred_dist, tp_list
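`get_js_div` above computes the Jensen-Shannon divergence: PyTorch's `KLDivLoss(log_mean, dist)` evaluates KL(dist || mean), so the returned loss is the average KL of the forward and backward distributions to their mixture. A NumPy sketch of the same quantity (hypothetical helper name):

```python
import numpy as np


def js_div(p, q, eps=1e-8):
    """Jensen-Shannon divergence in nats: symmetric, near zero when p == q,
    and bounded above by log 2 for distributions with disjoint support."""
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0  # 0 * log 0 is taken as 0
        return np.sum(a[mask] * (np.log(a[mask]) - np.log(b[mask] + eps)))

    return 0.5 * (kl(p, m) + kl(q, m))
```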


================================================
FILE: gnn/models/ReaRev/rearev.py
================================================
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn

from models.base_model import BaseModel
from modules.kg_reasoning.reasongnn import ReasonGNNLayer
from modules.question_encoding.lstm_encoder import LSTMInstruction
from modules.question_encoding.bert_encoder import BERTInstruction
from modules.layer_init import TypeLayer
from modules.query_update import AttnEncoder, Fusion, QueryReform

VERY_SMALL_NUMBER = 1e-10
VERY_NEG_NUMBER = -100000000000



class ReaRev(BaseModel):
    def __init__(self, args, num_entity, num_relation, num_word):
        """
        Init ReaRev model.
        """
        super(ReaRev, self).__init__(args, num_entity, num_relation, num_word)
        #self.embedding_def()
        #self.share_module_def()
        self.norm_rel = args['norm_rel']
        self.layers(args)
        

        self.loss_type =  args['loss_type']
        self.num_iter = args['num_iter']
        self.num_ins = args['num_ins']
        self.num_gnn = args['num_gnn']
        self.alg = args['alg']
        assert self.alg == 'bfs'
        self.lm = args['lm']
        
        self.private_module_def(args, num_entity, num_relation)

        self.to(self.device)
        self.lin = nn.Linear(3*self.entity_dim, self.entity_dim)

        self.fusion = Fusion(self.entity_dim)
        self.reforms = []
        for i in range(self.num_ins):
            self.add_module('reform' + str(i), QueryReform(self.entity_dim))
        # self.reform_rel = QueryReform(self.entity_dim)
        # self.add_module('reform', QueryReform(self.entity_dim))

    def layers(self, args):
        # initialize entity embedding
        word_dim = self.word_dim
        kg_dim = self.kg_dim
        entity_dim = self.entity_dim

        #self.lstm_dropout = args['lstm_dropout']
        self.linear_dropout = args['linear_dropout']
        
        self.entity_linear = nn.Linear(in_features=self.ent_dim, out_features=entity_dim)
        self.relation_linear = nn.Linear(in_features=self.rel_dim, out_features=entity_dim)
        # self.relation_linear_inv = nn.Linear(in_features=self.rel_dim, out_features=entity_dim)
        #self.relation_linear = nn.Linear(in_features=self.rel_dim, out_features=entity_dim)

        # dropout
        #self.lstm_drop = nn.Dropout(p=self.lstm_dropout)
        self.linear_drop = nn.Dropout(p=self.linear_dropout)

        if self.encode_type:
            self.type_layer = TypeLayer(in_features=entity_dim, out_features=entity_dim,
                                        linear_drop=self.linear_drop, device=self.device, norm_rel=self.norm_rel)

        self.self_att_r = AttnEncoder(self.entity_dim)
        #self.self_att_r_inv = AttnEncoder(self.entity_dim)
        self.kld_loss = nn.KLDivLoss(reduction='none')
        self.bce_loss_logits = nn.BCEWithLogitsLoss(reduction='none')
        self.mse_loss = torch.nn.MSELoss()

    def get_ent_init(self, local_entity, kb_adj_mat, rel_features):
        if self.encode_type:
            local_entity_emb = self.type_layer(local_entity=local_entity,
                                               edge_list=kb_adj_mat,
                                               rel_features=rel_features)
        else:
            local_entity_emb = self.entity_embedding(local_entity)  # batch_size, max_local_entity, word_dim
            local_entity_emb = self.entity_linear(local_entity_emb)
        
        return local_entity_emb
    
   
    def get_rel_feature(self):
        """
        Encode relation tokens to vectors.
        """
        if self.rel_texts is None:
            rel_features = self.relation_embedding.weight
            rel_features_inv = self.relation_embedding_inv.weight
            rel_features = self.relation_linear(rel_features)
            rel_features_inv = self.relation_linear(rel_features_inv)
        else:
            
            rel_features = self.instruction.question_emb(self.rel_features)
            rel_features_inv = self.instruction.question_emb(self.rel_features_inv)
            
            rel_features = self.self_att_r(rel_features,  (self.rel_texts != self.instruction.pad_val).float())
            rel_features_inv = self.self_att_r(rel_features_inv,  (self.rel_texts != self.instruction.pad_val).float())
            if self.lm == 'lstm':
                rel_features = self.self_att_r(rel_features, (self.rel_texts != self.num_relation+1).float())
                rel_features_inv = self.self_att_r(rel_features_inv, (self.rel_texts_inv != self.num_relation+1).float())

        return rel_features, rel_features_inv


    def private_module_def(self, args, num_entity, num_relation):
        """
        Building modules: LM encoder, GNN, etc.
        """
        # initialize entity embedding
        word_dim = self.word_dim
        kg_dim = self.kg_dim
        entity_dim = self.entity_dim
        self.reasoning = ReasonGNNLayer(args, num_entity, num_relation, entity_dim, self.alg)
        if args['lm'] == 'lstm':
            self.instruction = LSTMInstruction(args, self.word_embedding, self.num_word)
            self.relation_linear = nn.Linear(in_features=entity_dim, out_features=entity_dim)
        else:
            self.instruction = BERTInstruction(args, self.word_embedding, self.num_word, args['lm'])
            #self.relation_linear = nn.Linear(in_features=self.instruction.word_dim, out_features=entity_dim)
        # self.relation_linear = nn.Linear(in_features=entity_dim, out_features=entity_dim)
        # self.relation_linear_inv = nn.Linear(in_features=entity_dim, out_features=entity_dim)

    def init_reason(self, curr_dist, local_entity, kb_adj_mat, q_input, query_entities):
        """
        Initializing Reasoning
        """
        # batch_size = local_entity.size(0)
        self.local_entity = local_entity
        self.instruction_list, self.attn_list = self.instruction(q_input)
        rel_features, rel_features_inv  = self.get_rel_feature()
        self.local_entity_emb = self.get_ent_init(local_entity, kb_adj_mat, rel_features)
        self.init_entity_emb = self.local_entity_emb
        self.curr_dist = curr_dist
        self.dist_history = []
        self.action_probs = []
        self.seed_entities = curr_dist
        
        self.reasoning.init_reason( 
                                   local_entity=local_entity,
                                   kb_adj_mat=kb_adj_mat,
                                   local_entity_emb=self.local_entity_emb,
                                   rel_features=rel_features,
                                   rel_features_inv=rel_features_inv,
                                   query_entities=query_entities)


    def calc_loss_label(self, curr_dist, teacher_dist, label_valid):
        tp_loss = self.get_loss(pred_dist=curr_dist, answer_dist=teacher_dist, reduction='none')
        tp_loss = tp_loss * label_valid
        cur_loss = torch.sum(tp_loss) / curr_dist.size(0)
        return cur_loss

    
    def forward(self, batch, training=False):
        """
        Forward function: creates instructions and performs GNN reasoning.
        """

        # local_entity, query_entities, kb_adj_mat, query_text, seed_dist, answer_dist = batch
        local_entity, query_entities, kb_adj_mat, query_text, seed_dist, true_batch_id, answer_dist = batch
        local_entity = torch.from_numpy(local_entity).type('torch.LongTensor').to(self.device)
        # local_entity_mask = (local_entity != self.num_entity).float()
        query_entities = torch.from_numpy(query_entities).type('torch.FloatTensor').to(self.device)
        answer_dist = torch.from_numpy(answer_dist).type('torch.FloatTensor').to(self.device)
        seed_dist = torch.from_numpy(seed_dist).type('torch.FloatTensor').to(self.device)
        current_dist = Variable(seed_dist, requires_grad=True)

        q_input = torch.from_numpy(query_text).type('torch.LongTensor').to(self.device)
        if self.lm != 'lstm':
            pad_val = self.instruction.pad_val  # tokenizer pad token id
            query_mask = (q_input != pad_val).float()
        else:
            query_mask = (q_input != self.num_word).float()

        
        """
        Instruction generations
        """
        self.init_reason(curr_dist=current_dist, local_entity=local_entity,
                         kb_adj_mat=kb_adj_mat, q_input=q_input, query_entities=query_entities)
        self.instruction.init_reason(q_input)
        for i in range(self.num_ins):
            relational_ins, attn_weight = self.instruction.get_instruction(self.instruction.relational_ins, step=i) 
            self.instruction.instructions.append(relational_ins.unsqueeze(1))
            self.instruction.relational_ins = relational_ins
        #relation_ins = torch.cat(self.instruction.instructions, dim=1)
        #query_emb = None
        self.dist_history.append(self.curr_dist)


        """
        BFS + GNN reasoning
        """

        for t in range(self.num_iter):
            relation_ins = torch.cat(self.instruction.instructions, dim=1)
            self.curr_dist = current_dist            
            for j in range(self.num_gnn):
                self.curr_dist, global_rep = self.reasoning(self.curr_dist, relation_ins, step=j)
            self.dist_history.append(self.curr_dist)
            qs = []

            """
            Instruction Updates
            """
            for j in range(self.num_ins):
                reform = getattr(self, 'reform' + str(j))
                q = reform(self.instruction.instructions[j].squeeze(1), global_rep, query_entities, local_entity)
                qs.append(q.unsqueeze(1))
                self.instruction.instructions[j] = q.unsqueeze(1)
        
        
        """
        Answer Predictions
        """
        pred_dist = self.dist_history[-1]
        answer_number = torch.sum(answer_dist, dim=1, keepdim=True)
        case_valid = (answer_number > 0).float()
        # filter no answer training case
        # loss = 0
        # for pred_dist in self.dist_history:
        loss = self.calc_loss_label(curr_dist=pred_dist, teacher_dist=answer_dist, label_valid=case_valid)

        
        pred = torch.max(pred_dist, dim=1)[1]
        if training:
            h1, f1 = self.get_eval_metric(pred_dist, answer_dist)
            tp_list = [h1.tolist(), f1.tolist()]
        else:
            tp_list = None
        return loss, pred, pred_dist, tp_list
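The query-mask computation near the top of this forward pass (padding positions get 0, real tokens get 1) can be sketched with plain lists; the token ids and pad value below are hypothetical, standing in for `tokenizer.pad_token_id` (LM case) or `num_word` (LSTM case):

```python
# Minimal sketch of the padding-mask logic above; pad_val stands in for
# the tokenizer's pad id (LM case) or num_word (LSTM case).
def query_mask(token_ids, pad_val):
    """Return 1.0 for real tokens and 0.0 for padding positions."""
    return [[1.0 if t != pad_val else 0.0 for t in row] for row in token_ids]

batch = [[101, 2054, 102, 0, 0],     # hypothetical token ids, pad id 0
         [101, 2073, 2003, 102, 0]]
print(query_mask(batch, pad_val=0))
```

In the model this mask multiplies attention scores so padded positions contribute nothing to the instruction vectors.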

    

================================================
FILE: gnn/models/base_model.py
================================================
import torch
import numpy as np
import torch.nn as nn


VERY_SMALL_NUMBER = 1e-10

class BaseModel(torch.nn.Module):
    """
    Base model functions: create embeddings, store relations, compute f1/h1 scores, etc.
    """

    def __init__(self, args, num_entity, num_relation, num_word):
        super(BaseModel, self).__init__()
        self.num_relation = num_relation
        self.num_entity = num_entity
        self.num_word = num_word
        print('Num Word', self.num_word)
        self.kge_frozen = args['kge_frozen']
        self.kg_dim = args['kg_dim']
        #self._parse_args(args)
        self.entity_emb_file = args['entity_emb_file']
        self.relation_emb_file = args['relation_emb_file']
        self.relation_word_emb = args['relation_word_emb']
        self.word_emb_file = args['word_emb_file']
        self.entity_dim = args['entity_dim']
        
        self.lm = args['lm']
        if self.lm in ['bert']:
            #self.word_dim = 768
            args['word_dim'] = 768
        
        self.word_dim = args['word_dim']

        self.rel_texts = None

        
        #self.share_module_def()
        #self.model_name = args['model_name'].lower()
        self.device = torch.device('cuda' if args['use_cuda'] else 'cpu')
       
        print("Entity: {}, Relation: {}, Word: {}".format(num_entity, num_relation, num_word))

        
        self.kld_loss = nn.KLDivLoss(reduction='none')
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')
        self.mse_loss = torch.nn.MSELoss()

        for k, v in args.items():
            if k.endswith('dim'):
                setattr(self, k, v)
            if k.endswith('emb_file') or k.endswith('kge_file'):
                if v is None:
                    setattr(self, k, None)
                else:
                    setattr(self, k, args['data_folder'] + v)

        self.reset_time = 0

        if 'use_inverse_relation' in args:
            self.use_inverse_relation = args['use_inverse_relation']
        if 'use_self_loop' in args:
            self.use_self_loop = args['use_self_loop']
        self.eps = args['eps']

        self.embedding_def()
        args['word_dim'] = self.word_dim
        
    def embedding_def(self):
        num_entity = self.num_entity
        num_relation = self.num_relation
        num_word = self.num_word

        if self.lm != 'lstm':
            self.word_dim = 768
            self.word_embedding = nn.Embedding(num_embeddings=num_word + 1, embedding_dim=self.word_dim,
                                           padding_idx=num_word)
        elif self.word_emb_file is not None:
            word_emb = np.load(self.word_emb_file)
            _ , self.word_dim = word_emb.shape
            print('Word emb dim', self.word_dim)
            self.word_embedding = nn.Embedding(num_embeddings=num_word + 1, embedding_dim=self.word_dim,
                                           padding_idx=num_word)
            self.word_embedding.weight = nn.Parameter(
                torch.from_numpy(
                    np.pad(word_emb, ((0, 1), (0, 0)), 'constant')).type(
                    'torch.FloatTensor'))
            self.word_embedding.weight.requires_grad = False
        else:
            #self.word_dim = 768
            self.word_embedding = nn.Embedding(num_embeddings=num_word + 1, embedding_dim=self.word_dim,
                                           padding_idx=num_word)


        if self.entity_emb_file is not None:
            self.encode_type = False
            emb = np.load(self.entity_emb_file)
            ent_num , self.ent_dim = emb.shape
            
            self.entity_embedding = nn.Embedding(num_embeddings=num_entity + 1, embedding_dim=self.ent_dim,
                                                padding_idx=num_entity)
            if ent_num != num_entity:
                print('Number of entities in KG embeddings do not match: Random Init.')
            else:
                self.entity_embedding.weight = nn.Parameter(
                    torch.from_numpy(np.pad(emb, ((0, 1), (0, 0)), 'constant')).type(
                        'torch.FloatTensor'))
            self.entity_embedding.weight.requires_grad = not self.kge_frozen
        else:
            self.ent_dim = self.kg_dim 
            self.encode_type = True
            #self.entity_embedding = nn.Embedding(num_embeddings=num_entity + 1, embedding_dim=self.ent_dim,
                                                #padding_idx=num_entity)

        # initialize relation embedding
        if self.relation_emb_file is not None:
            np_tensor = self.load_relation_file(self.relation_emb_file)
            #print('check?', np_tensor.shape)
            rel_num, self.rel_dim = np_tensor.shape
            self.relation_embedding = nn.Embedding(num_embeddings=num_relation+1, embedding_dim=self.rel_dim)
            if rel_num != num_relation:
                print('Number of relations in KG embeddings do not match: Random Init.')
            else:
                self.relation_embedding.weight = nn.Parameter(torch.from_numpy(np_tensor).type('torch.FloatTensor'))
            self.relation_embedding.weight.requires_grad = not self.kge_frozen

        elif self.relation_word_emb:
            self.rel_dim = self.entity_dim
            self.relation_embedding = nn.Embedding(num_embeddings=num_relation+1, embedding_dim=self.rel_dim)
            self.relation_embedding.weight.requires_grad = True
            self.relation_embedding_inv = nn.Embedding(num_embeddings=num_relation+1, embedding_dim=self.rel_dim)
            self.relation_embedding_inv.weight.requires_grad = True
        else:
            self.rel_dim = 2*self.kg_dim 
            self.relation_embedding = nn.Embedding(num_embeddings=num_relation+1, embedding_dim=self.rel_dim)
            self.relation_embedding_inv = nn.Embedding(num_embeddings=num_relation+1, embedding_dim=self.rel_dim)

        # initialize text embeddings
        
        
    

    def load_relation_file(self, filename):
        half_tensor = np.load(filename)
        num_pad = 0
        if self.use_self_loop:
            num_pad = 2
        if self.use_inverse_relation:
            load_tensor = np.concatenate([half_tensor, half_tensor])
        else:
            load_tensor = half_tensor
        return np.pad(load_tensor, ((0, num_pad), (0, 0)), 'constant')
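Shape-wise, `load_relation_file` duplicates the rows when inverse relations are in use and appends zero rows as padding for the self-loop slots. A list-based sketch of the same bookkeeping (the 2-dimensional embeddings are made up):

```python
def load_relation_rows(rows, use_inverse_relation, use_self_loop):
    """Duplicate rows for inverse relations, then append zero rows as
    padding for self-loop relations (mirrors load_relation_file)."""
    out = rows + rows if use_inverse_relation else list(rows)
    num_pad = 2 if use_self_loop else 0
    dim = len(rows[0])
    return out + [[0.0] * dim for _ in range(num_pad)]

rows = [[0.1, 0.2], [0.3, 0.4]]      # hypothetical relation embeddings
print(len(load_relation_rows(rows, True, True)))   # 2 rows doubled, plus 2 pad rows
```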

    def use_rel_texts(self, rel_texts, rel_texts_inv):
        self.rel_texts = torch.from_numpy(rel_texts).type('torch.LongTensor').to(self.device)
        self.rel_texts_inv = torch.from_numpy(rel_texts_inv).type('torch.LongTensor').to(self.device)

    def encode_rel_texts(self, rel_texts, rel_texts_inv):
        self.rel_texts = torch.from_numpy(rel_texts).type('torch.LongTensor').to(self.device)
        self.rel_texts_inv = torch.from_numpy(rel_texts_inv).type('torch.LongTensor').to(self.device)
        self.instruction.eval()
        with torch.no_grad():
            self.rel_features = self.instruction.encode_question(self.rel_texts, store=False)
            self.rel_features_inv = self.instruction.encode_question(self.rel_texts_inv, store=False)
        self.rel_features.requires_grad = False
        self.rel_features_inv.requires_grad = False

    def init_hidden(self, num_layer, batch_size, hidden_size):
        return self.instruction.init_hidden(num_layer, batch_size, hidden_size)

    def encode_question(self, q_input):
        return self.instruction.encode_question(q_input)

    def get_instruction(self, query_hidden_emb, query_mask, states):
        return self.instruction.get_instruction(query_hidden_emb, query_mask, states)

    def get_loss_bce(self, pred_dist_score, answer_dist):
        answer_dist = (answer_dist > 0).float() * 0.9   # label smooth
        # answer_dist = answer_dist * 0.9  # label smooth
        loss = self.bce_loss(pred_dist_score, answer_dist)  # BCEWithLogitsLoss defined in __init__
        return loss

    def get_loss_kl(self, pred_dist, answer_dist):
        answer_len = torch.sum(answer_dist, dim=1, keepdim=True)
        answer_len[answer_len == 0] = 1.0
        answer_prob = answer_dist.div(answer_len)
        log_prob = torch.log(pred_dist + 1e-8)
        loss = self.kld_loss(log_prob, answer_prob)
        return loss
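`get_loss_kl` row-normalizes the answer distribution and feeds `log(pred + 1e-8)` to `nn.KLDivLoss`, whose per-element term is `p * (log p - log_q)` with the convention `0 * log 0 = 0`. A plain-Python sketch of one row, under those assumptions:

```python
import math

def kl_terms(pred_dist, answer_dist, eps=1e-8):
    """Per-element KL terms p * (log p - log(q + eps)), matching
    nn.KLDivLoss(reduction='none') applied as in get_loss_kl."""
    total = sum(answer_dist) or 1.0          # guard against no-answer rows
    terms = []
    for q, p_raw in zip(pred_dist, answer_dist):
        p = p_raw / total
        terms.append(0.0 if p == 0.0 else p * (math.log(p) - math.log(q + eps)))
    return terms

# A prediction matching the normalized answers gives (near-)zero terms.
print(kl_terms([0.5, 0.5, 0.0], [1.0, 1.0, 0.0]))
```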

    def get_loss(self, pred_dist, answer_dist, reduction='mean'):
        if self.loss_type == "bce":
            tp_loss = self.get_loss_bce(pred_dist, answer_dist)
            if reduction == 'none':
                return tp_loss
            else:
                # mean
                return torch.mean(tp_loss)
        else:
            tp_loss = self.get_loss_kl(pred_dist, answer_dist)
            if reduction == 'none':
                return tp_loss
            else:
                # batchmean
                return torch.sum(tp_loss) / pred_dist.size(0)

    def f1_and_hits(self, answers, candidate2prob, eps=0.5):
        retrieved = []
        correct = 0
        cand_list = sorted(candidate2prob, key=lambda x:x[1], reverse=True)
        if len(cand_list) == 0:
            best_ans = -1
        else:
            best_ans = cand_list[0][0]
        # max_prob = cand_list[0][1]
        tp_prob = 0.0
        for c, prob in cand_list:
            retrieved.append((c, prob))
            tp_prob += prob
            if c in answers:
                correct += 1
            if tp_prob > eps:
                break
        if len(answers) == 0:
            if len(retrieved) == 0:
                return 1.0, 1.0, 1.0, 1.0  # precision, recall, f1, hits
            else:
                return 0.0, 1.0, 0.0, 1.0  # precision, recall, f1, hits
        else:
            hits = float(best_ans in answers)
            if len(retrieved) == 0:
                return 1.0, 0.0, 0.0, hits  # precision, recall, f1, hits
            else:
                p, r = correct / len(retrieved), correct / len(answers)
                f1 = 2.0 / (1.0 / p + 1.0 / r) if p != 0 and r != 0 else 0.0
                return p, r, f1, hits
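To make the thresholding explicit: candidates are consumed in descending probability until their cumulative mass exceeds `eps`, and that prefix is scored against the gold answers. A standalone restatement of the method above (same logic, list-based):

```python
def f1_and_hits_sketch(answers, candidate2prob, eps=0.5):
    """Precision/recall/F1 over the smallest probability-sorted prefix
    whose mass exceeds eps, plus hits@1 for the top candidate."""
    cand = sorted(candidate2prob, key=lambda x: x[1], reverse=True)
    best_ans = cand[0][0] if cand else -1
    retrieved, correct, mass = [], 0, 0.0
    for c, prob in cand:
        retrieved.append(c)
        mass += prob
        correct += c in answers
        if mass > eps:
            break
    if not answers:                      # no-answer cases score trivially
        return (1.0, 1.0, 1.0, 1.0) if not retrieved else (0.0, 1.0, 0.0, 1.0)
    hits = float(best_ans in answers)
    if not retrieved:
        return 1.0, 0.0, 0.0, hits
    p, r = correct / len(retrieved), correct / len(answers)
    f1 = 2 * p * r / (p + r) if p and r else 0.0
    return p, r, f1, hits

print(f1_and_hits_sketch({3, 7}, [(3, 0.6), (5, 0.2), (7, 0.1)]))
```

With `eps=0.5`, only the top candidate (mass 0.6) is retrieved here, so precision is perfect while recall covers one of the two gold answers.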


    def calc_f1_new(self, curr_dist, dist_ans, h1_vec):
        batch_size = curr_dist.size(0)
        max_local_entity = curr_dist.size(1)
        seed_dist = self.seed_entities #self.dist_history[0]
        local_entity = self.local_entity
        ignore_prob = (1 - self.eps) / max_local_entity
        pad_ent_id = self.num_entity
        # hits_list = []
        f1_list = []
        for batch_id in range(batch_size):
            if h1_vec[batch_id].item() == 0.0:
                f1_list.append(0.0)
                # we consider cases which own hit@1 as prior to reduce computation time
                continue
            candidates = local_entity[batch_id, :].tolist()
            probs = curr_dist[batch_id, :].tolist()
            answer_prob = dist_ans[batch_id, :].tolist()
            seed_entities = seed_dist[batch_id, :].tolist()
            answer_list = []
            candidate2prob = []
            for c, p, p_a, s in zip(candidates, probs, answer_prob, seed_entities):
                if s > 0:
                    # ignore seed entities
                    continue
                if c == pad_ent_id:
                    continue
                if p_a > 0:
                    answer_list.append(c)
                if p < ignore_prob:
                    continue
                candidate2prob.append((c, p))
            precision, recall, f1, hits = self.f1_and_hits(answer_list, candidate2prob, self.eps)
            # hits_list.append(hits)
            f1_list.append(f1)
        # hits_vec = torch.FloatTensor(hits_list).to(self.device)
        f1_vec = torch.FloatTensor(f1_list).to(self.device)
        return f1_vec

    def calc_h1(self, curr_dist, dist_ans, eps=0.01):
        greedy_option = curr_dist.argmax(dim=-1, keepdim=True)
        dist_top1 = torch.zeros_like(curr_dist).scatter_(1, greedy_option, 1.0)
        dist_ans = (dist_ans > eps).float()
        h1 = torch.sum(dist_top1 * dist_ans, dim=-1)
        return (h1 > 0).float()
    
    def get_eval_metric(self, pred_dist, answer_dist):
        with torch.no_grad():
            h1 = self.calc_h1(curr_dist=pred_dist, dist_ans=answer_dist, eps=VERY_SMALL_NUMBER)
            f1 = self.calc_f1_new(pred_dist, answer_dist, h1)
        return h1, f1

================================================
FILE: gnn/modules/kg_reasoning/base_gnn.py
================================================
import torch
import numpy as np
from collections import defaultdict

VERY_NEG_NUMBER = -100000000000

class BaseGNNLayer(torch.nn.Module):
    """
    Builds the sparse tensors that represent the KG structure.
    """
    def __init__(self, args, num_entity, num_relation):
        super(BaseGNNLayer, self).__init__()
        self.num_relation = num_relation
        self.num_entity = num_entity
        self.device = torch.device('cuda' if args['use_cuda'] else 'cpu')
        self.normalized_gnn = args['normalized_gnn']


    def build_matrix(self):
        batch_heads, batch_rels, batch_tails, batch_ids, fact_ids, weight_list, _ = self.edge_list
        num_fact = len(fact_ids)
        num_relation = self.num_relation
        batch_size = self.batch_size
        max_local_entity = self.max_local_entity
        self.num_fact = num_fact
        fact2head = torch.LongTensor([batch_heads, fact_ids]).to(self.device)
        fact2tail = torch.LongTensor([batch_tails, fact_ids]).to(self.device)
        head2fact = torch.LongTensor([fact_ids, batch_heads]).to(self.device)
        tail2fact = torch.LongTensor([fact_ids, batch_tails]).to(self.device)
        head2tail = torch.LongTensor([batch_heads, batch_tails]).to(self.device)
        rel2fact = torch.LongTensor([fact_ids, batch_rels + batch_ids * num_relation]).to(self.device)
        fact2rel = torch.LongTensor([batch_rels + batch_ids * num_relation, fact_ids]).to(self.device)
        self.batch_rels = torch.LongTensor(batch_rels).to(self.device)
        self.batch_ids = torch.LongTensor(batch_ids).to(self.device)
        self.batch_heads = torch.LongTensor(batch_heads).to(self.device)
        self.batch_tails = torch.LongTensor(batch_tails).to(self.device)
        # self.batch_ids = batch_ids
        if self.normalized_gnn:
            vals = torch.FloatTensor(weight_list).to(self.device)
        else:
            vals = torch.ones_like(self.batch_ids).float().to(self.device)

        #vals = torch.ones_like(self.batch_ids).float().to(self.device)
        # Sparse Matrix for reason on graph
        self.fact2head_mat = self._build_sparse_tensor(fact2head, vals, (batch_size * max_local_entity, num_fact))
        self.head2fact_mat = self._build_sparse_tensor(head2fact, vals, (num_fact, batch_size * max_local_entity))
        self.fact2tail_mat = self._build_sparse_tensor(fact2tail, vals, (batch_size * max_local_entity, num_fact))
        self.tail2fact_mat = self._build_sparse_tensor(tail2fact, vals, (num_fact, batch_size * max_local_entity))
        self.head2tail_mat = self._build_sparse_tensor(head2tail, vals, (batch_size * max_local_entity, batch_size * max_local_entity))
        self.fact2rel_mat = self._build_sparse_tensor(fact2rel, vals, (batch_size * num_relation, num_fact))
        self.rel2fact_mat = self._build_sparse_tensor(rel2fact, vals, (num_fact, batch_size * num_relation))

    def _build_sparse_tensor(self, indices, values, size):
        return torch.sparse.FloatTensor(indices, values, size).to(self.device)

    def build_adj_facts(self):
        
        batch_size = self.batch_size
        max_local_entity = self.max_local_entity
        max_fact = self.max_fact
        
        (e2f_batch, e2f_f, e2f_e, e2f_val), (f2e_batch, f2e_e, f2e_f, f2e_val) = self.edge_list2
        
        entity2fact_index = torch.LongTensor([e2f_batch, e2f_f, e2f_e]).to(self.device)
        entity2fact_val = torch.FloatTensor(e2f_val).to(self.device)
        self.entity2fact_mat = torch.sparse.FloatTensor(entity2fact_index, entity2fact_val, \
            torch.Size([batch_size, max_fact, max_local_entity])).to(self.device) # batch_size, max_fact, max_local_entity

        fact2entity_index = torch.LongTensor([f2e_batch, f2e_e, f2e_f]).to(self.device)
        fact2entity_val = torch.FloatTensor(f2e_val).to(self.device)
        self.fact2entity_mat = torch.sparse.FloatTensor(fact2entity_index, fact2entity_val, \
            torch.Size([batch_size, max_local_entity, max_fact])).to(self.device) # batch_size,  max_local_entity, max_fact


        self.kb_fact_rel = torch.LongTensor(self.kb_fact_rel).to(self.device)
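These incidence matrices make one-hop propagation a pair of sparse matmuls: `head2fact` lifts an entity distribution onto facts, and `fact2tail` drops it back onto tail entities. A COO-style sketch of that matvec, with hypothetical triples:

```python
def sparse_matvec(coords, vec, num_rows):
    """Multiply a COO 0/1 sparse matrix (list of (row, col) pairs) by a
    dense vector, as torch.sparse.mm does for head2fact / fact2tail."""
    out = [0.0] * num_rows
    for r, c in coords:
        out[r] += vec[c]
    return out

# Hypothetical KG with facts (head, rel, tail); entities 0..3.
facts = [(0, 0, 1), (0, 1, 2), (1, 0, 3)]
head2fact = [(f, h) for f, (h, _, t) in enumerate(facts)]
fact2tail = [(t, f) for f, (h, _, t) in enumerate(facts)]

dist = [1.0, 0.0, 0.0, 0.0]                    # all mass on entity 0
fact_prior = sparse_matvec(head2fact, dist, len(facts))
tail_dist = sparse_matvec(fact2tail, fact_prior, 4)
print(tail_dist)
```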


================================================
FILE: gnn/modules/kg_reasoning/graft_gnn.py
================================================
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import time

from .base_gnn import BaseGNNLayer

VERY_SMALL_NUMBER = 1e-10
VERY_NEG_NUMBER = -100000000000


class GraftLayer(BaseGNNLayer):
    def __init__(self, args, num_entity, num_relation, entity_dim):
        super(GraftLayer, self).__init__(args, num_entity, num_relation)
        self.num_entity = num_entity
        self.num_relation = num_relation
        self.entity_dim = entity_dim
        self.num_layer = args['num_layer']
        self.pagerank_lambda = args['pagerank_lambda']
        self.fact_scale = args['fact_scale']
        self.k = 3
        self.init_layers(args)

    def init_layers(self, args):
        entity_dim = self.entity_dim
        self.softmax_d1 = nn.Softmax(dim=1)
        self.score_func = nn.Linear(in_features=entity_dim, out_features=1)
        
        self.linear_dropout = args['linear_dropout']
        self.linear_drop = nn.Dropout(p=self.linear_dropout)
        for i in range(self.num_layer):
            self.add_module('q2e_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))
            self.add_module('e2q_linear' + str(i), nn.Linear(in_features=self.k * entity_dim, out_features=entity_dim))
            self.add_module('e2e_linear' + str(i), nn.Linear(in_features=self.k * entity_dim, out_features=entity_dim))

            self.add_module('kb_head_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))
            self.add_module('kb_tail_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))
            self.add_module('kb_self_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))

           

    def init_reason(self, local_entity, kb_adj_mat, kb_adj_mat_graft, kb_fact_rel, local_entity_emb, rel_features, query_node_emb=None):
        batch_size, max_local_entity = local_entity.size()
        self.local_entity_mask = (local_entity != self.num_entity).float()
        self.batch_size = batch_size
        self.max_local_entity = max_local_entity
        self.edge_list = kb_adj_mat
        self.edge_list2 = kb_adj_mat_graft
        self.rel_features = rel_features
        self.local_entity_emb = local_entity_emb
        self.num_relation = self.rel_features.size(0)
        self.possible_cand = []
        self.kb_fact_rel = kb_fact_rel #torch.LongTensor(kb_fact_rel).to(self.device)
        _, self.max_fact = kb_fact_rel.shape
        self.query_node_emb = query_node_emb
        self.build_matrix()
        self.build_adj_facts()
        
    

    def compute_attention(self, query_hidden_emb, query_mask):
        batch_size = self.batch_size
        max_local_entity = self.max_local_entity
        rel_features = self.rel_features
        num_rels = rel_features.size(0)

        #print(self.kb_fact_rel, self.kb_fact_rel.size())
        fact_rel = torch.index_select(rel_features, dim=0, index=self.batch_rels)
        #batch_rels = rel_features[self.batch_ids, self.batch_rels]
        #print(fact_rel.size())
        local_fact_emb = rel_features[self.kb_fact_rel]#torch.sparse.mm(self.fact2rel_mat, fact_rel).view(batch_size, -1, self.entity_dim)
        
        # attention fact2question
        div = float(np.sqrt(self.entity_dim))
        fact2query_sim = torch.bmm(query_hidden_emb, local_fact_emb.transpose(1, 2)) / div # batch_size, max_query_word, max_fact
        fact2query_sim = self.softmax_d1(fact2query_sim + (1 - query_mask.unsqueeze(dim=2)) * VERY_NEG_NUMBER) # batch_size, max_query_word, max_fact
        fact2query_att = torch.sum(fact2query_sim.unsqueeze(dim=3) * query_hidden_emb.unsqueeze(dim=2), dim=1) # batch_size, max_fact, entity_dim
        W = torch.sum(fact2query_att * local_fact_emb, dim=2) / div # batch_size, max_fact
        W_max = torch.max(W, dim=1, keepdim=True)[0] # batch_size, 1
        self.W_tilde = torch.exp(W - W_max) # batch_size, max_fact
        e2f_softmax = torch.bmm(self.entity2fact_mat.transpose(1, 2), self.W_tilde.unsqueeze(dim=2)).squeeze(dim=2) # batch_size, max_local_entity
        self.e2f_softmax = torch.clamp(e2f_softmax, min=VERY_SMALL_NUMBER)
        #e2f_out_dim = use_cuda(Variable(torch.sum(entity2fact_mat.to_dense(), dim=1), requires_grad=False)) # batch_size, max_local_entity
        assert not torch.isnan(self.e2f_softmax).any()

    def reason_layer(self, curr_dist, kb_self_linear, kb_head_linear, kb_tail_linear):
        batch_size = self.batch_size
        max_local_entity = self.max_local_entity
        # num_relation = self.num_relation
        rel_features = self.rel_features

        local_fact_emb = rel_features[self.kb_fact_rel]
        e2f_emb = F.relu(kb_self_linear(local_fact_emb) + torch.bmm(self.entity2fact_mat, kb_head_linear(self.linear_drop(self.local_entity_emb)))) # batch_size, max_fact, entity_dim
        e2f_softmax_normalized = self.W_tilde.unsqueeze(dim=2) * torch.bmm(self.entity2fact_mat, (curr_dist / self.e2f_softmax).unsqueeze(dim=2)) # batch_size, max_fact, 1
        e2f_emb = e2f_emb * e2f_softmax_normalized # batch_size, max_fact, entity_dim
        f2e_emb = F.relu(kb_self_linear(self.local_entity_emb) + torch.bmm(self.fact2entity_mat, kb_tail_linear(self.linear_drop(e2f_emb))))
                
        next_curr_dist = torch.bmm(self.fact2entity_mat, e2f_softmax_normalized).squeeze(dim=2)
        next_curr_dist = self.pagerank_lambda * next_curr_dist + (1 - self.pagerank_lambda) * curr_dist # batch_size, max_local_entity

        assert not torch.isnan(f2e_emb).any()
        neighbor_rep = f2e_emb

        return neighbor_rep, next_curr_dist
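The final update in `reason_layer` is a personalized-PageRank step: propagated mass is mixed with the incoming distribution via `pagerank_lambda`. In isolation:

```python
def pagerank_step(propagated, curr, lam):
    """next_dist = lam * propagated + (1 - lam) * curr, elementwise,
    with lam playing the role of pagerank_lambda."""
    return [lam * p + (1 - lam) * c for p, c in zip(propagated, curr)]

# With lam = 0.8, most mass follows the graph; the rest stays in place.
print(pagerank_step([0.0, 1.0], [1.0, 0.0], lam=0.8))
```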

    

    def forward(self, current_dist, query_hidden_emb, query_mask, step=0, return_score=True):
        # get linear transformation functions for each layer
        q2e_linear = getattr(self, 'q2e_linear' + str(step))
        e2e_linear = getattr(self, 'e2e_linear' + str(step))
        e2q_linear = getattr(self, 'e2q_linear' + str(step))
        kb_self_linear = getattr(self, 'kb_self_linear' + str(step))
        kb_head_linear = getattr(self, 'kb_head_linear' + str(step))
        kb_tail_linear = getattr(self, 'kb_tail_linear' + str(step))

        batch_size = self.batch_size
        max_local_entity = self.max_local_entity

        if step == 0:
            query_node_emb = self.query_node_emb#.unsqueeze(1)
            self.compute_attention(query_hidden_emb, query_mask)
            #q2e_emb = q2e_linear(self.linear_drop(self.query_node_emb)).expand(batch_size, max_local_entity, self.entity_dim)
        else:
            query_node_emb = self.query_emb#.unsqueeze(1)

        q2e_emb = q2e_linear(self.linear_drop(query_node_emb)).expand(batch_size, max_local_entity, self.entity_dim) # batch_size, max_local_entity, entity_dim

        next_local_entity_emb = torch.cat((self.local_entity_emb, q2e_emb), dim=2)
        # score_func = getattr(self, 'score_func' + str(step))
        score_func = self.score_func
        #relational_ins = relational_ins.squeeze(1)
        neighbor_rep, next_curr_dist = self.reason_layer(current_dist, kb_self_linear, kb_head_linear, kb_tail_linear)

        next_local_entity_emb = torch.cat((next_local_entity_emb, self.fact_scale*neighbor_rep), dim=2)
        #self.query_emb = torch.bmm(init_dist.unsqueeze(dim=1), e2q_linear(self.linear_drop(next_local_entity_emb)))
        self.query_emb = torch.bmm(next_curr_dist.unsqueeze(dim=1), e2q_linear(self.linear_drop(next_local_entity_emb)))

        self.local_entity_emb = F.relu(e2e_linear(self.linear_drop(next_local_entity_emb)))

        score_tp = score_func(self.linear_drop(self.local_entity_emb)).squeeze(dim=2)
        answer_mask = self.local_entity_mask
        self.possible_cand.append(answer_mask)
        score = score_tp + (1 - answer_mask) * VERY_NEG_NUMBER
        score = self.softmax_d1(score) #F.sigmoid(score) #* self.local_entity_mask #* answer_mask #+ (1 - answer_mask) * VERY_NEG_NUMBER
        #current_dist = self.softmax_d1(score_tp)
        current_dist = next_curr_dist
        if return_score:
            return score_tp, score, current_dist
        return score_tp, current_dist



================================================
FILE: gnn/modules/kg_reasoning/nsm_gnn.py
================================================
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import time

from .base_gnn import BaseGNNLayer

VERY_SMALL_NUMBER = 1e-10
VERY_NEG_NUMBER = -100000000000


class NSMBaseLayer(BaseGNNLayer):
    def __init__(self, args, num_entity, num_relation, entity_dim):
        super(NSMBaseLayer, self).__init__(args, num_entity, num_relation)
        self.num_entity = num_entity
        self.num_relation = num_relation
        self.entity_dim = entity_dim
        self.num_steps = args['num_step']
        self.reason_kb = args['reason_kb']
        self.init_layers(args)

    def init_layers(self, args):
        entity_dim = self.entity_dim
        self.softmax_d1 = nn.Softmax(dim=1)
        self.score_func = nn.Linear(in_features=entity_dim, out_features=1)
        
        self.lin = nn.Linear(in_features=2*entity_dim, out_features=entity_dim)

        
        self.linear_dropout = args['linear_dropout']
        self.linear_drop = nn.Dropout(p=self.linear_dropout)
        for i in range(self.num_steps):
            self.add_module('rel_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))
            self.add_module('e2e_linear' + str(i), nn.Linear(in_features=entity_dim + entity_dim, out_features=entity_dim))
            
    def init_reason(self, local_entity, kb_adj_mat, local_entity_emb, rel_features, query_node_emb=None):
        batch_size, max_local_entity = local_entity.size()
        self.local_entity_mask = (local_entity != self.num_entity).float()
        self.batch_size = batch_size
        self.max_local_entity = max_local_entity
        self.edge_list = kb_adj_mat
        self.rel_features = rel_features
        self.local_entity_emb = local_entity_emb
        self.num_relation = self.rel_features.size(0)
        self.possible_cand = []
        self.build_matrix()
        
       
    def reason_layer(self, curr_dist, instruction, rel_linear):
        raise NotImplementedError  # implemented by NSMLayer / NSMLayer_back

    def forward(self, current_dist, relational_ins, step=0, return_score=False):
        rel_linear = getattr(self, 'rel_linear' + str(step))
        e2e_linear = getattr(self, 'e2e_linear' + str(step))
        
        # score_func = getattr(self, 'score_func' + str(step))
        score_func = self.score_func
        relational_ins = relational_ins.squeeze(1)
        neighbor_rep, possible_tail = self.reason_layer(current_dist, relational_ins, rel_linear)
        next_local_entity_emb = torch.cat((self.local_entity_emb, neighbor_rep), dim=2)
        self.local_entity_emb = e2e_linear(self.linear_drop(next_local_entity_emb))
        
        self.local_entity_emb = F.relu(self.local_entity_emb)

        score_tp = score_func(self.linear_drop(self.local_entity_emb)).squeeze(dim=2)
        if self.reason_kb:
            answer_mask = self.local_entity_mask * possible_tail
        else:
            answer_mask = self.local_entity_mask
        self.possible_cand.append(answer_mask)
        score_tp = score_tp + (1 - answer_mask) * VERY_NEG_NUMBER
        current_dist = self.softmax_d1(score_tp)
        if return_score:
            return score_tp, current_dist
        return current_dist


    


class NSMLayer(NSMBaseLayer):
    def __init__(self, args, num_entity, num_relation, entity_dim):
        super(NSMLayer, self).__init__(args, num_entity, num_relation, entity_dim)

    def reason_layer(self, curr_dist, instruction, rel_linear):
        batch_size = self.batch_size
        max_local_entity = self.max_local_entity
        # num_relation = self.num_relation
        rel_features = self.rel_features
        
        fact_rel = torch.index_select(rel_features, dim=0, index=self.batch_rels) #rels (facts), entity_dim


        fact_query = torch.index_select(instruction, dim=0, index=self.batch_ids) #one query per batch entry: rels (facts), entity_dim
        fact_val = F.relu(rel_linear(fact_rel) * fact_query)
        fact_prior = torch.sparse.mm(self.head2fact_mat, curr_dist.view(-1, 1)) #rels (facts), 1 (scaling)


        possible_tail = torch.sparse.mm(self.fact2tail_mat, fact_prior) # batch_size * max_local_entity, 1
        # (batch_size *max_local_entity, num_fact) (num_fact, 1)
        possible_tail = (possible_tail > VERY_SMALL_NUMBER).float().view(batch_size, max_local_entity)

        fact_val = fact_val * fact_prior
        
        f2e_emb = torch.sparse.mm(self.fact2tail_mat, fact_val)  # batch_size * max_local_entity, entity_dim 
        assert not torch.isnan(f2e_emb).any()

        neighbor_rep = f2e_emb.view(batch_size, max_local_entity, self.entity_dim)
        
        return neighbor_rep, possible_tail
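The `possible_tail` mask computed above is one-hop reachability from the entities that currently carry probability mass; a set-based sketch over hypothetical (head, rel, tail) facts:

```python
def possible_tails(facts, curr_dist, tiny=1e-10):
    """Entities reachable via one fact whose head has mass > tiny,
    mirroring the fact_prior -> fact2tail thresholding above."""
    active = {e for e, p in enumerate(curr_dist) if p > tiny}
    return sorted({t for h, _, t in facts if h in active})

facts = [(0, 0, 1), (1, 0, 2), (0, 1, 3)]      # hypothetical triples
print(possible_tails(facts, [0.7, 0.3, 0.0, 0.0]))
```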

class NSMLayer_back(NSMBaseLayer):
    def __init__(self, args, num_entity, num_relation, entity_dim):
        super(NSMLayer_back, self).__init__(args, num_entity, num_relation, entity_dim)

    def reason_layer(self, curr_dist, instruction, rel_linear):
        batch_size = self.batch_size
        max_local_entity = self.max_local_entity
        # num_relation = self.num_relation
        rel_features = self.rel_features_inv
        
        fact_rel = torch.index_select(rel_features, dim=0, index=self.batch_rels)
        
        fact_query = torch.index_select(instruction, dim=0, index=self.batch_ids)
        fact_val = F.relu(rel_linear(fact_rel) * fact_query)
        fact_prior = torch.sparse.mm(self.tail2fact_mat, curr_dist.view(-1, 1))
        
        possible_head = torch.sparse.mm(self.fact2head_mat, fact_prior)
        # (batch_size *max_local_entity, num_fact) (num_fact, 1)
        possible_head = (possible_head > VERY_SMALL_NUMBER).float().view(batch_size, max_local_entity)

        fact_val = fact_val * fact_prior

        f2e_emb = torch.sparse.mm(self.fact2head_mat, fact_val)
        assert not torch.isnan(f2e_emb).any()

        neighbor_rep = f2e_emb.view(batch_size, max_local_entity, self.entity_dim)
        
        
        return neighbor_rep, possible_head

================================================
FILE: gnn/modules/kg_reasoning/reasongnn.py
================================================

import torch
import torch.nn.functional as F
import torch.nn as nn


from .base_gnn import BaseGNNLayer

VERY_NEG_NUMBER = -100000000000

class ReasonGNNLayer(BaseGNNLayer):
    """
    GNN Reasoning
    """
    def __init__(self, args, num_entity, num_relation, entity_dim, alg):
        super(ReasonGNNLayer, self).__init__(args, num_entity, num_relation)
        self.num_entity = num_entity
        self.num_relation = num_relation
        self.entity_dim = entity_dim
        self.alg = alg
        self.num_ins = args['num_ins']
        self.num_gnn = args['num_gnn']
        
        self.use_posemb = args['pos_emb']
        self.init_layers(args)

    def init_layers(self, args):
        entity_dim = self.entity_dim
        self.softmax_d1 = nn.Softmax(dim=1)
        self.score_func = nn.Linear(in_features=entity_dim, out_features=1)
        self.glob_lin = nn.Linear(in_features=entity_dim, out_features=entity_dim)
        self.lin = nn.Linear(in_features=2*entity_dim, out_features=entity_dim)
        assert self.alg == 'bfs'
        self.linear_dropout = args['linear_dropout']
        self.linear_drop = nn.Dropout(p=self.linear_dropout)
        for i in range(self.num_gnn):
            self.add_module('rel_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))
            if self.alg == 'bfs':
                self.add_module('e2e_linear' + str(i), nn.Linear(in_features=2*(self.num_ins)*entity_dim + entity_dim, out_features=entity_dim))

            if self.use_posemb:
                self.add_module('pos_emb' + str(i), nn.Embedding(self.num_relation, entity_dim))
                self.add_module('pos_emb_inv' + str(i), nn.Embedding(self.num_relation, entity_dim))
        self.lin_m =  nn.Linear(in_features=(self.num_ins)*entity_dim, out_features=entity_dim)

    def init_reason(self, local_entity, kb_adj_mat, local_entity_emb, rel_features, rel_features_inv, query_entities, query_node_emb=None):
        batch_size, max_local_entity = local_entity.size()
        self.local_entity_mask = (local_entity != self.num_entity).float()
        self.batch_size = batch_size
        self.max_local_entity = max_local_entity
        self.edge_list = kb_adj_mat
        self.rel_features = rel_features
        self.rel_features_inv = rel_features_inv
        self.local_entity_emb = local_entity_emb
        self.num_relation = self.rel_features.size(0)
        self.possible_cand = []
        self.build_matrix()
        self.query_entities = query_entities
       

    def reason_layer(self, curr_dist, instruction, rel_linear, pos_emb):
        """
        Aggregates neighbor representations
        """
        batch_size = self.batch_size
        max_local_entity = self.max_local_entity
        # num_relation = self.num_relation
        rel_features = self.rel_features
        
        
        fact_rel = torch.index_select(rel_features, dim=0, index=self.batch_rels)
        
        fact_query = torch.index_select(instruction, dim=0, index=self.batch_ids)
        if pos_emb is not None:
            pe = pos_emb(self.batch_rels)
            fact_val = F.relu((rel_linear(fact_rel) + pe) * fact_query)
        else:
            fact_val = F.relu(rel_linear(fact_rel) * fact_query)
        fact_prior = torch.sparse.mm(self.head2fact_mat, curr_dist.view(-1, 1))

        fact_val = fact_val * fact_prior
        
        f2e_emb = torch.sparse.mm(self.fact2tail_mat, fact_val)
        assert not torch.isnan(f2e_emb).any()

        neighbor_rep = f2e_emb.view(batch_size, max_local_entity, self.entity_dim)
        
        return neighbor_rep

    def reason_layer_inv(self, curr_dist, instruction, rel_linear, pos_emb_inv):
        batch_size = self.batch_size
        max_local_entity = self.max_local_entity
        # num_relation = self.num_relation
        rel_features = self.rel_features_inv
        
        fact_rel = torch.index_select(rel_features, dim=0, index=self.batch_rels)
        
        fact_query = torch.index_select(instruction, dim=0, index=self.batch_ids)
        if pos_emb_inv is not None:
            pe = pos_emb_inv(self.batch_rels)
            fact_val = F.relu((rel_linear(fact_rel) + pe) * fact_query)
        else:
            fact_val = F.relu(rel_linear(fact_rel) * fact_query)
        fact_prior = torch.sparse.mm(self.tail2fact_mat, curr_dist.view(-1, 1))
        

        fact_val = fact_val * fact_prior

        f2e_emb = torch.sparse.mm(self.fact2head_mat, fact_val)
        assert not torch.isnan(f2e_emb).any()

        neighbor_rep = f2e_emb.view(batch_size, max_local_entity, self.entity_dim)
        
        return neighbor_rep

    def combine(self, emb):
        """
        Combines instruction-specific representations.
        """
        local_emb = torch.cat(emb, dim=-1)
        local_emb = F.relu(self.lin_m(local_emb))

        score_func = self.score_func
        
        score_tp = score_func(self.linear_drop(local_emb)).squeeze(dim=2)
        answer_mask = self.local_entity_mask
        self.possible_cand.append(answer_mask)
        score_tp = score_tp + (1 - answer_mask) * VERY_NEG_NUMBER
        current_dist = self.softmax_d1(score_tp)
        return current_dist, local_emb

    def forward(self, current_dist, relational_ins, step=0, return_score=False):
        """
        Compute next probabilistic vectors and current node representations.
        """
        rel_linear = getattr(self, 'rel_linear' + str(step))
        e2e_linear = getattr(self, 'e2e_linear' + str(step))
        # score_func = getattr(self, 'score_func' + str(step))
        score_func = self.score_func
        neighbor_reps = []
        
        if self.use_posemb:
            pos_emb = getattr(self, 'pos_emb' + str(step))
            pos_emb_inv = getattr(self, 'pos_emb_inv' + str(step))
        else:
            pos_emb, pos_emb_inv = None, None

        for j in range(relational_ins.size(1)):
            # we do the same procedure for existing and inverse relations
            neighbor_rep = self.reason_layer(current_dist, relational_ins[:,j,:], rel_linear, pos_emb)
            neighbor_reps.append(neighbor_rep)

            neighbor_rep = self.reason_layer_inv(current_dist, relational_ins[:,j,:], rel_linear, pos_emb_inv)
            neighbor_reps.append(neighbor_rep)

        neighbor_reps = torch.cat(neighbor_reps, dim=2)
        
        
        next_local_entity_emb = torch.cat((self.local_entity_emb, neighbor_reps), dim=2)
        self.local_entity_emb = F.relu(e2e_linear(self.linear_drop(next_local_entity_emb)))

        score_tp = score_func(self.linear_drop(self.local_entity_emb)).squeeze(dim=2)
        answer_mask = self.local_entity_mask
        self.possible_cand.append(answer_mask)
        score_tp = score_tp + (1 - answer_mask) * VERY_NEG_NUMBER
        current_dist = self.softmax_d1(score_tp)
        if return_score:
            return score_tp, current_dist
        
        
        return current_dist, self.local_entity_emb 




================================================
FILE: gnn/modules/layer_init.py
================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
VERY_NEG_NUMBER = -100000000000
VERY_SMALL_NUMBER = 1e-10


class TypeLayer(nn.Module):
    """
    Sparse relation-aggregation layer for initializing entity representations,
    adapted from the sparse GAT layer of https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, linear_drop, device, norm_rel):
        super(TypeLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear_drop = linear_drop
        # self.kb_head_linear = nn.Linear(in_features, out_features)
        self.kb_self_linear = nn.Linear(in_features, out_features)
        # self.kb_tail_linear = nn.Linear(out_features, out_features)
        self.device = device
        self.norm_rel = norm_rel

    def forward(self, local_entity, edge_list, rel_features):
        '''
        local_entity: (batch_size, max_local_entity)
        edge_list: (batch_heads, batch_rels, batch_tails, batch_ids, fact_ids, weight_list, weight_rel_list)
        rel_features: (num_relation, in_features)
        '''
        batch_heads, batch_rels, batch_tails, batch_ids, fact_ids, weight_list, weight_rel_list = edge_list
        num_fact = len(fact_ids)
        batch_size, max_local_entity = local_entity.size()
        hidden_size = self.in_features
        fact2head = torch.LongTensor([batch_heads, fact_ids]).to(self.device)
        fact2tail = torch.LongTensor([batch_tails, fact_ids]).to(self.device)
        batch_rels = torch.LongTensor(batch_rels).to(self.device)
        batch_ids = torch.LongTensor(batch_ids).to(self.device)
        if self.norm_rel:
            val_one = torch.FloatTensor(weight_rel_list).to(self.device) #* torch.FloatTensor(weight_list).to(self.device)
        else:
            val_one = torch.ones_like(batch_ids).float().to(self.device)

        
        # Step 1: compute a value for every fact from its relation
        fact_rel = torch.index_select(rel_features, dim=0, index=batch_rels)
        fact_val = self.kb_self_linear(fact_rel)

        # Step 2: Edge aggregation with sparse MM
        fact2tail_mat = self._build_sparse_tensor(fact2tail, val_one, (batch_size * max_local_entity, num_fact))
        fact2head_mat = self._build_sparse_tensor(fact2head, val_one, (batch_size * max_local_entity, num_fact))

        # neighbor_rep = torch.sparse.mm(fact2tail_mat, self.kb_tail_linear(self.linear_drop(fact_val)))
        f2e_emb = F.relu(torch.sparse.mm(fact2tail_mat, fact_val) + torch.sparse.mm(fact2head_mat, fact_val))
        assert not torch.isnan(f2e_emb).any()

        f2e_emb = f2e_emb.view(batch_size, max_local_entity, hidden_size)

        return f2e_emb

    def _build_sparse_tensor(self, indices, values, size):
        # torch.sparse.FloatTensor is deprecated in recent PyTorch; sparse_coo_tensor is equivalent
        return torch.sparse_coo_tensor(indices, values, size).to(self.device)


================================================
FILE: gnn/modules/query_update.py
================================================
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn

class Fusion(nn.Module):
    """Gated fusion of two representations: returns g * r(x, y) + (1 - g) * x."""
    def __init__(self, d_hid):
        super(Fusion, self).__init__()
        self.r = nn.Linear(d_hid*3, d_hid, bias=False)
        self.g = nn.Linear(d_hid*3, d_hid, bias=False)

    def forward(self, x, y):
        r_ = self.r(torch.cat([x,y,x-y], dim=-1))#.tanh()
        g_ = torch.sigmoid(self.g(torch.cat([x,y,x-y], dim=-1)))
        return g_ * r_ + (1 - g_) * x
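A minimal sketch (assumed shapes, not from the repo) of the gate in `Fusion` above: the output is a convex combination `g * r(x, y) + (1 - g) * x`, so with a zero-initialized gate projection the gate is exactly `sigmoid(0) = 0.5` everywhere.

```python
import torch
import torch.nn as nn

d_hid = 8
r = nn.Linear(d_hid * 3, d_hid, bias=False)
g = nn.Linear(d_hid * 3, d_hid, bias=False)
nn.init.zeros_(g.weight)  # force the gate to sigmoid(0) = 0.5 for the demo

x, y = torch.randn(2, d_hid), torch.randn(2, d_hid)
cat = torch.cat([x, y, x - y], dim=-1)      # same features Fusion concatenates
gate = torch.sigmoid(g(cat))
out = gate * r(cat) + (1 - gate) * x

assert out.shape == (2, d_hid)
assert torch.allclose(out, 0.5 * r(cat) + 0.5 * x)  # gate is 0.5 at zero init
```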

class QueryReform(nn.Module):
    """Reformulates the query representation using seed-entity information."""
    def __init__(self, h_dim):
        super(QueryReform, self).__init__()
        # self.q_encoder = AttnEncoder(h_dim)
        self.fusion = Fusion(h_dim)
        self.q_ent_attn = nn.Linear(h_dim, h_dim)

    def forward(self, q_node, ent_emb, seed_info, ent_mask):
        '''
        q_node: (B, h_dim)
        ent_emb: (B, C, h_dim)
        seed_info: (B, C)
        ent_mask: (B, C)
        '''
        # q_node = self.q_encoder(q, q_mask)
        q_ent_attn = (self.q_ent_attn(q_node).unsqueeze(1) * ent_emb).sum(2, keepdim=True)
        q_ent_attn = F.softmax(q_ent_attn - (1 - ent_mask.unsqueeze(2)) * 1e8, dim=1)
        attn_retrieve = (q_ent_attn * ent_emb).sum(1)

        seed_retrieve = torch.bmm(seed_info.unsqueeze(1), ent_emb).squeeze(1)  # (B, h_dim)
        # how to calculate the gate

        #return  self.fusion(q_node, attn_retrieve)
        return  self.fusion(q_node, seed_retrieve)

class AttnEncoder(nn.Module):
    """Attention-weighted pooling over a masked sequence."""
    def __init__(self, d_hid):
        super(AttnEncoder, self).__init__()
        self.attn_linear = nn.Linear(d_hid, 1, bias=False)

    def forward(self, x, x_mask):
        """
        x: (B, len, d_hid)
        x_mask: (B, len)
        return: (B, d_hid)
        """
        x_attn = self.attn_linear(x)
        x_attn = x_attn - (1 - x_mask.unsqueeze(2))*1e8
        x_attn = F.softmax(x_attn, dim=1)
        return (x*x_attn).sum(1)
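A toy sketch (illustrative values, not from the repo) of the masked attention pooling in `AttnEncoder` above: masked positions get a `-1e8` logit, so their softmax weight vanishes and the pooled vector ignores padding.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 4)                  # (B, len, d_hid)
x_mask = torch.tensor([[1.0, 1.0, 0.0]])  # last position is padding
logits = torch.randn(1, 3, 1)             # stand-in for attn_linear(x)

# same masking trick as AttnEncoder.forward
logits = logits - (1 - x_mask.unsqueeze(2)) * 1e8
w = F.softmax(logits, dim=1)
pooled = (x * w).sum(1)                   # (B, d_hid)

assert w[0, 2, 0].item() < 1e-6           # padded position contributes ~nothing
assert pooled.shape == (1, 4)
```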

class Attention(nn.Module):
    """ Applies attention mechanism on the `context` using the `query`.

    **Thank you** to IBM for their initial implementation of :class:`Attention`. Here is
    their `License
    <https://github.com/IBM/pytorch-seq2seq/blob/master/LICENSE>`__.

    Args:
        dimensions (int): Dimensionality of the query and context.
        attention_type (str, optional): How to compute the attention score:

            * dot: :math:`score(H_j,q) = H_j^T q`
            * general: :math:`score(H_j, q) = H_j^T W_a q`

    Example:

         >>> attention = Attention(256)
         >>> query = torch.randn(5, 1, 256)
         >>> context = torch.randn(5, 5, 256)
         >>> output, weights = attention(query, context)
         >>> output.size()
         torch.Size([5, 1, 256])
         >>> weights.size()
         torch.Size([5, 1, 5])
    """

    def __init__(self, dimensions, attention_type='general'):
        super(Attention, self).__init__()

        if attention_type not in ['dot', 'general']:
            raise ValueError('Invalid attention type selected.')

        self.attention_type = attention_type
        if self.attention_type == 'general':
            self.linear_in = nn.Linear(dimensions, dimensions, bias=False)

        self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.tanh = nn.Tanh()

    def forward(self, query, context):
        """
        Args:
            query (:class:`torch.FloatTensor` [batch size, output length, dimensions]): Sequence of
                queries to query the context.
            context (:class:`torch.FloatTensor` [batch size, query length, dimensions]): Data
                over which to apply the attention mechanism.

        Returns:
            :class:`tuple` with `output` and `weights`:
            * **output** (:class:`torch.FloatTensor` [batch size, output length, dimensions]):
              Tensor containing the attended features.
            * **weights** (:class:`torch.FloatTensor` [batch size, output length, query length]):
              Tensor containing attention weights.
        """
        batch_size, output_len, dimensions = query.size()
        query_len = context.size(1)

        if self.attention_type == "general":
            query = query.reshape(batch_size * output_len, dimensions)
            query = self.linear_in(query)
            query = query.reshape(batch_size, output_len, dimensions)

        # TODO: Include mask on PADDING_INDEX?

        # (batch_size, output_len, dimensions) * (batch_size, query_len, dimensions) ->
        # (batch_size, output_len, query_len)
        attention_scores = torch.bmm(query, context.transpose(1, 2).contiguous())

        # Compute weights across every context sequence
        attention_scores = attention_scores.view(batch_size * output_len, query_len)
        attention_weights = self.softmax(attention_scores)
        attention_weights = attention_weights.view(batch_size, output_len, query_len)

        # (batch_size, output_len, query_len) * (batch_size, query_len, dimensions) ->
        # (batch_size, output_len, dimensions)
        mix = torch.bmm(attention_weights, context)

        # concat -> (batch_size * output_len, 2*dimensions)
        combined = torch.cat((mix, query), dim=2)
        combined = combined.view(batch_size * output_len, 2 * dimensions)

        # Apply linear_out on every 2nd dimension of concat
        # output -> (batch_size, output_len, dimensions)
        output = self.linear_out(combined).view(batch_size, output_len, dimensions)
        output = self.tanh(output)

        return output, attention_weights

================================================
FILE: gnn/modules/question_encoding/base_encoder.py
================================================
import torch
import torch.nn.functional as F
import torch.nn as nn

VERY_SMALL_NUMBER = 1e-10
VERY_NEG_NUMBER = -100000000000

class BaseInstruction(torch.nn.Module):

    def __init__(self, args, constraint):
        super(BaseInstruction, self).__init__()
        self.constraint = constraint
        self._parse_args(args)
        self.share_module_def()

    def _parse_args(self, args):
        self.device = torch.device('cuda' if args['use_cuda'] else 'cpu')
        

        # self.share_encoder = args['share_encoder']
        self.q_type = args['q_type']
        if 'num_step' in args:
            self.num_ins = args['num_step']
        elif 'num_ins' in args:
            self.num_ins = args['num_ins']
        elif 'num_layer' in args:
            self.num_ins = args['num_layer']
        elif 'num_expansion_ins' in args and 'num_backup_ins' in args:
            self.num_ins = args['num_backup_ins'] if self.constraint else args['num_expansion_ins']
        else:
            self.num_ins = 1
        
        self.lm_dropout = args['lm_dropout']
        self.linear_dropout = args['linear_dropout']
        self.lm_frozen = args['lm_frozen']

        for k, v in args.items():
            if k.endswith('dim'):
                setattr(self, k, v)
            if k.endswith('emb_file') or k.endswith('kge_file'):
                if v is None:
                    setattr(self, k, None)
                else:
                    setattr(self, k, args['data_folder'] + v)

        self.reset_time = 0

    def share_module_def(self):
        # dropout
        self.lstm_drop = nn.Dropout(p=self.lm_dropout)
        self.linear_drop = nn.Dropout(p=self.linear_dropout)

    def init_hidden(self, num_layer, batch_size, hidden_size):
        return (torch.zeros(num_layer, batch_size, hidden_size).to(self.device),
                torch.zeros(num_layer, batch_size, hidden_size).to(self.device))

    def encode_question(self, *args):
        # constituency tree or query_text
        pass

    @staticmethod
    def get_node_emb(query_hidden_emb, action):
        '''

        :param query_hidden_emb: (batch_size, max_hyper, emb)
        :param action: (batch_size)
        :return: (batch_size, 1, emb)
        '''
        batch_size, max_hyper, _ = query_hidden_emb.size()
        row_idx = torch.arange(0, batch_size).type(torch.LongTensor)
        q_rep = query_hidden_emb[row_idx, action, :]
        return q_rep.unsqueeze(1)

    def init_reason(self, query_text):
        self.batch_size = query_text.size(0)
        self.max_query_word = query_text.size(1)
        self.encode_question(query_text)
        self.relational_ins = torch.zeros(self.batch_size, self.entity_dim).to(self.device)
        self.instructions = []
        self.attn_list = []

    def get_instruction(self, relational_ins, step=0, query_node_emb=None):
        
        query_hidden_emb = self.query_hidden_emb
        
        query_mask = self.query_mask
        if query_node_emb is None:
            query_node_emb = self.query_node_emb
        
        relational_ins = relational_ins.unsqueeze(1)
        question_linear = getattr(self, 'question_linear' + str(step))
        q_i = question_linear(self.linear_drop(query_node_emb))
        cq = self.cq_linear(self.linear_drop(torch.cat((relational_ins, q_i, q_i-relational_ins,q_i*relational_ins), dim=-1)))
        # batch_size, 1, entity_dim
        ca = self.ca_linear(self.linear_drop(cq * query_hidden_emb))
        # batch_size, max_local_entity, 1
        # cv = self.softmax_d1(ca + (1 - query_mask.unsqueeze(2)) * VERY_NEG_NUMBER)
        attn_weight = F.softmax(ca + (1 - query_mask.unsqueeze(2)) * VERY_NEG_NUMBER, dim=1)
        # batch_size, max_local_entity, 1
        relational_ins = torch.sum(attn_weight * query_hidden_emb, dim=1)
        return relational_ins, attn_weight
        


    def forward(self, query_text, lm=None):
        if lm is not None:
            self.node_encoder = lm
        self.init_reason(query_text)
        for i in range(self.num_ins):
            relational_ins, attn_weight = self.get_instruction(self.relational_ins, step=i)
            self.instructions.append(relational_ins)
            self.attn_list.append(attn_weight)
            self.relational_ins = relational_ins
        return self.instructions, self.attn_list



================================================
FILE: gnn/modules/question_encoding/bert_encoder.py
================================================

import torch.nn.functional as F
import torch.nn as nn
VERY_SMALL_NUMBER = 1e-10
VERY_NEG_NUMBER = -100000000000


from transformers import AutoModel, AutoTokenizer #DistilBertModel, BertModel, BertTokenizer, RobertaModel, RobertaTokenizer
from torch.nn import LayerNorm
import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TRANSFORMERS_CACHE'] = '/export/scratch/costas/home/mavro016/.cache'

from .base_encoder import BaseInstruction


class BERTInstruction(BaseInstruction):

    def __init__(self, args, word_embedding, num_word, model, constraint=False):
        super(BERTInstruction, self).__init__(args, constraint)
        self.word_embedding = word_embedding
        self.num_word = num_word
        self.constraint = constraint
        
        entity_dim = self.entity_dim
        self.model = model
        
        
        if model == 'bert':
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.pretrained_weights = 'bert-base-uncased'
            word_dim = 768
        elif model == 'roberta':
            self.tokenizer = AutoTokenizer.from_pretrained('roberta-base')
            self.pretrained_weights = 'roberta-base'
            word_dim = 768
        elif model == 'sbert':
            self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
            self.pretrained_weights = 'sentence-transformers/all-MiniLM-L6-v2'
            word_dim = 384
        elif model == 'simcse':
            self.tokenizer = AutoTokenizer.from_pretrained('princeton-nlp/sup-simcse-bert-base-uncased')
            self.pretrained_weights = 'princeton-nlp/sup-simcse-bert-base-uncased'
            word_dim = 768
        elif model == 'sbert2':
            self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
            self.pretrained_weights = 'sentence-transformers/all-mpnet-base-v2'
            word_dim = 768
        elif model == 't5':
            self.tokenizer = AutoTokenizer.from_pretrained('t5-small')
            self.pretrained_weights = 't5-small'
            word_dim = 512  # t5-small hidden size (d_model) is 512, not 768
        elif model == 'relbert':
            self.tokenizer = AutoTokenizer.from_pretrained('pretrained_lms/sr-simbert/')
            self.pretrained_weights = 'pretrained_lms/sr-simbert/'
            word_dim = 768
        self.pad_val = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
        self.word_dim = word_dim

        print('word_dim', self.word_dim)
        self.cq_linear = nn.Linear(in_features=4 * entity_dim, out_features=entity_dim)
        self.ca_linear = nn.Linear(in_features=entity_dim, out_features=1)
        for i in range(self.num_ins):
            self.add_module('question_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))
        self.question_emb = nn.Linear(in_features=word_dim, out_features=entity_dim)

        if not self.constraint:
            self.encoder_def()

    def encoder_def(self):
        # initialize entity embedding
        word_dim = self.word_dim
        entity_dim = self.entity_dim
        self.node_encoder = AutoModel.from_pretrained(self.pretrained_weights)
        print('Total Params', sum(p.numel() for p in self.node_encoder.parameters()))
        if self.lm_frozen == 1:
            print('Freezing LM params')
            for param in self.node_encoder.parameters():
                param.requires_grad = False
        else:
            for param in self.node_encoder.parameters():
                param.requires_grad = True
            print('Unfrozen LM params')

    def encode_question(self, query_text, store=True):
        batch_size = query_text.size(0)
        
        if self.model != 't5':
            query_hidden_emb = self.node_encoder(query_text)[0]  # (batch_size, seq_len, word_dim)
        else:
            query_hidden_emb = self.node_encoder.encoder(query_text)[0]
        

        if store:
            self.query_hidden_emb = self.question_emb(query_hidden_emb)
            # first-token ([CLS]) embedding serves as the query node representation
            self.query_node_emb = query_hidden_emb.transpose(1, 0)[0].unsqueeze(1)
            self.query_node_emb = self.question_emb(self.query_node_emb)

            self.query_mask = (query_text != self.pad_val).float()
            return query_hidden_emb, self.query_node_emb
            return query_hidden_emb, self.query_node_emb
        else:
            return  query_hidden_emb 



================================================
FILE: gnn/modules/question_encoding/lstm_encoder.py
================================================


import torch.nn as nn
from utils import get_dict
from .base_encoder import BaseInstruction

VERY_SMALL_NUMBER = 1e-10
VERY_NEG_NUMBER = -100000000000

class LSTMInstruction(BaseInstruction):

    def __init__(self, args, word_embedding, num_word):
        super(LSTMInstruction, self).__init__(args, constraint=False)
        self.word2id = get_dict(args['data_folder'],args['word2id'])

        self.word_embedding = word_embedding
        self.num_word = num_word
        self.encoder_def()
        entity_dim = self.entity_dim
        self.cq_linear = nn.Linear(in_features=4 * entity_dim, out_features=entity_dim)
        self.ca_linear = nn.Linear(in_features=entity_dim, out_features=1)
        for i in range(self.num_ins):
            self.add_module('question_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim))

    def encoder_def(self):
        # initialize entity embedding
        word_dim = self.word_dim
        entity_dim = self.entity_dim
        self.node_encoder = nn.LSTM(input_size=word_dim, hidden_size=entity_dim,
                                    batch_first=True, bidirectional=False)

    def encode_question(self, query_text, store=True):
        batch_size = query_text.size(0)
        query_word_emb = self.word_embedding(query_text)  # batch_size, max_query_word, word_dim
        query_hidden_emb, (h_n, c_n) = self.node_encoder(self.lstm_drop(query_word_emb),
                                                         self.init_hidden(1, batch_size,
                                                                          self.entity_dim))
        # query_hidden_emb: (batch_size, max_query_word, entity_dim); h_n: (1, batch_size, entity_dim)
        if store:
            self.instruction_hidden = h_n
            self.instruction_mem = c_n
            self.query_node_emb = h_n.squeeze(dim=0).unsqueeze(dim=1)  # batch_size, 1, entity_dim
            self.query_hidden_emb = query_hidden_emb
            self.query_mask = (query_text != self.num_word).float()
            return query_hidden_emb, self.query_node_emb
        else:
            return query_hidden_emb
    

    



================================================
FILE: gnn/modules/question_encoding/tokenizers.py
================================================
import re
import numpy as np
from transformers import BertTokenizer

class LSTMTokenizer():
    def __init__(self, word2id, max_query_word):
        super(LSTMTokenizer, self).__init__()
        self.word2id = word2id
        self.max_query_word = max_query_word

    def tokenize(self, question):
        tokens = self.tokenize_sent(question)
        # positions are pre-filled with the out-of-vocabulary id len(word2id)
        query_text = np.full(self.max_query_word, len(self.word2id), dtype=int)
        for j, word in enumerate(tokens):
            if j >= self.max_query_word:
                break
            if word in self.word2id:
                query_text[j] = self.word2id[word]

        return query_text

    @staticmethod
    def tokenize_sent(question_text):
        question_text = question_text.strip().lower()
        question_text = re.sub('\'s', ' s', question_text)
        words = []
        for w in question_text.split(' '):
            # strip leading/trailing non-alphanumeric characters
            w = re.sub('^[^a-z0-9]|[^a-z0-9]$', '', w)
            if w == '':
                continue
            words.append(w)
        return words
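A hypothetical usage sketch of the `LSTMTokenizer` logic above, written without the class dependency: out-of-vocabulary words keep the sentinel id `len(word2id)` and the output is always `max_query_word` long. The tiny vocabulary here is an illustrative assumption.

```python
import re
import numpy as np

word2id = {'who': 0, 'is': 1, 'the': 2}
max_query_word = 5
unk = len(word2id)  # sentinel id for OOV / padding

tokens = "Who is the president?".strip().lower().split(' ')
query_text = np.full(max_query_word, unk, dtype=int)
for j, word in enumerate(tokens):
    if j >= max_query_word:
        break
    # same leading/trailing stripping as tokenize_sent
    w = re.sub('^[^a-z0-9]|[^a-z0-9]$', '', word)
    if w in word2id:
        query_text[j] = word2id[w]

assert query_text.tolist() == [0, 1, 2, 3, 3]  # 'president' and padding are OOV
```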

class BERTTokenizer():
    def __init__(self, max_query_word):
        super(BERTTokenizer, self).__init__()
        self.q_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.max_query_word = max_query_word
        self.num_word = self.q_tokenizer.encode("[UNK]")[0] #len(self.q_tokenizer.vocab.keys())

        
    
    def tokenize(self, question):
        tokens = self.q_tokenizer.encode_plus(text=question, max_length=self.max_query_word,
                    padding='max_length', return_attention_mask=False, truncation=True)
        return np.array(tokens['input_ids'])

================================================
FILE: gnn/parsing.py
================================================
import argparse
import sys


def bool_flag(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
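A small demo of `bool_flag` with argparse (the flag name is taken from the shared arguments below; the standalone parser here is an illustrative assumption): string values like "false" or "1" parse to real booleans instead of Python's truthy `bool('false') == True`.

```python
import argparse

def bool_flag(v):
    # mirrors the repo's bool_flag above
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser()
parser.add_argument('--relation_word_emb', default=True, type=bool_flag)
args = parser.parse_args(['--relation_word_emb', 'false'])

assert args.relation_word_emb is False  # plain type=bool would give True here
```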

def add_shared_args(parser):
    parser.add_argument('--name', default='webqsp', type=str)
    parser.add_argument('--data_folder', default='data/webqsp/', type=str)
    parser.add_argument('--max_train', default=200000, type=int)

    # embeddings
    parser.add_argument('--word2id', default='vocab.txt', type=str)
    parser.add_argument('--relation2id', default='relations.txt', type=str)
    parser.add_argument('--entity2id', default='entities.txt', type=str)
    parser.add_argument('--char2id', default='chars.txt', type=str)
    parser.add_argument('--entity_emb_file', default=None, type=str)
    parser.add_argument('--relation_emb_file', default=None, type=str)
    parser.add_argument('--relation_word_emb', default=True, type=bool_flag)
    parser.add_argument('--word_emb_file', default='word_emb.npy', type=str)
    parser.add_argument('--rel_word_ids', default='rel_word_idx.npy', type=str)
    parser.add_argument('--kge_frozen', default=0, type=int)
    parser.add_argument('--lm', default='lstm', type=str, choices=['lstm', 'bert', 'roberta', 'sbert', 't5','sbert2', 'dbert', 'simcse', 'relbert'])
    parser.add_argument('--lm_frozen', default=1, type=int)

    # dimensions, layers, dropout
    parser.add_argument('--entity_dim', default=50, type=int)
    parser.add_argument('--kg_dim', default=100, type=int)
    parser.add_argument('--word_dim', default=300, type=int)
    parser.add_argument('--lm_dropout', default=0.3, type=float)
    parser.add_argument('--linear_dropout', default=0.2, type=float)

    # optimization
    parser.add_argument('--num_epoch', default=100, type=int)
    parser.add_argument('--warmup_epoch', default=0, type=int)
    parser.add_argument('--fact_scale', default=3, type=int)
    parser.add_argument('--eval_every', default=2, type=int)
    parser.add_argument('--batch_size', default=20, type=int)
    parser.add_argument('--gradient_clip', default=1.0, type=float)
    parser.add_argument('--lr', default=0.0005, type=float)
    parser.add_argument('--decay_rate', default=0.0, type=float)
    parser.add_argument('--seed', default=19960626, type=int)
    parser.add_argument('--lr_schedule', action='store_true')
    parser.add_argument('--label_smooth', default=0.1, type=float)
    parser.add_argument('--fact_drop', default=0, type=float)
    #parser.add_argument('--encode_type', action='store_true')

    # model options

    parser.add_argument('--is_eval', action='store_true')
    parser.add_argument('--checkpoint_dir', default='checkpoint/pretrain/', type=str)
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument('--experiment_name', default='', type=str)
    parser.add_argument('--load_experiment', default=None, type=str)
    parser.add_argument('--load_ckpt_file', default=None, type=str)
    parser.add_argument('--eps', default=0.95, type=float) # threshold for f1
    parser.add_argument('--test_batch_size', default=20, type=int)
    parser.add_argument('--q_type', default='seq', type=str)



def add_parse_args(parser):
    
    subparsers = parser.add_subparsers(help='KGQA reasoning model')

    parser_rearev = subparsers.add_parser("ReaRev")
    create_parser_rearev(parser_rearev)

    parser_nsm = subparsers.add_parser("NSM")
    create_parser_nsm(parser_nsm)

    parser_graftnet = subparsers.add_parser("GraftNet")
    create_parser_graftnet(parser_graftnet)

    # NuTrea is referenced here, but create_parser_nutrea (and the NuTrea model)
    # is not included in this repository; keep the subparser disabled to avoid a NameError.
    # parser_nutrea = subparsers.add_parser("NuTrea")
    # create_parser_nutrea(parser_nutrea)


def create_parser_rearev(parser):

    parser.add_argument('--model_name', default='ReaRev', type=str, choices=['ReaRev'])
    parser.add_argument('--alg', default='bfs', type=str)
    parser.add_argument('--num_iter', default=2, type=int)
    parser.add_argument('--num_ins', default=3, type=int)
    parser.add_argument('--num_gnn', default=3, type=int)
    parser.add_argument('--loss_type', default='kl', type=str)
    parser.add_argument('--use_self_loop', default=True, type=bool_flag)
    parser.add_argument('--normalized_gnn', default=False, type=bool_flag)
    parser.add_argument('--norm_rel', action='store_true')
    parser.add_argument('--data_eff', action='store_true')
    parser.add_argument('--pos_emb', action='store_true')
    add_shared_args(parser)


def create_parser_nsm(parser):
    parser.add_argument('--model_name', default='NSM', type=str, choices=['NSM'])
    parser.add_argument('--num_step', default=3, type=int)
    parser.add_argument('--reason_kb', default=False, type=bool_flag)
    parser.add_argument('--loss_type', default='kl', type=str)
    parser.add_argument('--lambda_constrain', default=0.0, type=float)
    parser.add_argument('--lambda_back', default=0.0, type=float)
    parser.add_argument('--use_self_loop', default=True, type=bool_flag)
    parser.add_argument('--use_inverse_relation', action='store_true')
    parser.add_argument('--norm_rel', action='store_true')
    parser.add_argument('--normalized_gnn', default=False, type=bool_flag)
    parser.add_argument('--data_eff', action='store_true')
    add_shared_args(parser)

def create_parser_graftnet(parser):
    parser.add_argument('--model_name', default='GraftNet', type=str, choices=['GraftNet'])
    parser.add_argument('--pagerank_lambda', default=0.8, type=float)
    parser.add_argument('--loss_type', default='bce', type=str)
    parser.add_argument('--num_layer', default=3, type=int)
    parser.add_argument('--use_inverse_relation', action='store_true')
    parser.add_argument('--norm_rel', action='store_true')
    parser.add_argument('--normalized_gnn', default=False, type=bool_flag)
    parser.add_argument('--data_eff', action='store_true')
    #parser.add_argument('--use_self_loop', default=True, type=bool_flag)
    add_shared_args(parser)


================================================
FILE: gnn/requirements.txt
================================================
Base==1.0.4
numpy==1.19.5
torch==1.7.1+cu110
tqdm==4.59.0
transformers==4.6.1

================================================
FILE: gnn/scripts/rearev_cwq.sh
================================================

###ReaRev+SBERT training
# python main.py ReaRev --is_eval --load_experiment relbert-full_cwq-rearev-final.ckpt --entity_dim 50 --num_epoch 200 --batch_size 8 --eval_every 2  \
# --lm relbert --num_iter 2 --num_ins 3 --num_gnn 3  --name cwq \
# --experiment_name prn_cwq-rearev-sbert --data_folder data/CWQ/ --num_epoch 100 --warmup_epoch 80

###ReaRev+LMSR training
# python main.py ReaRev  --entity_dim 50 --num_epoch 200 --batch_size 8 --eval_every 2  \
# --lm relbert --num_iter 2 --num_ins 3 --num_gnn 3  --name cwq \
# --experiment_name prn_cwq-rearev-lmsr  --data_folder data/CWQ/ --num_epoch 100 #--warmup_epoch 80


###Evaluate CWQ
python main.py ReaRev --entity_dim 50 --num_epoch 100 --batch_size 8 --eval_every 2 --data_folder data/CWQ/ --lm sbert --num_iter 2 --num_ins 3 --num_gnn 3 --relation_word_emb True --load_experiment ReaRev_CWQ.ckpt --is_eval --name cwq

================================================
FILE: gnn/train_model.py
================================================

from utils import create_logger
import time
import numpy as np
import os, math

import torch
from torch.optim.lr_scheduler import ExponentialLR
import torch.optim as optim

from tqdm import tqdm
tqdm.monitor_interval = 0  # the tqdm attribute is monitor_interval (was misspelled monitor_iterval)



#from dataset_load_paths import load_data
from dataset_load import load_data
from dataset_load_graft import load_data_graft
from models.ReaRev.rearev import ReaRev
from models.NSM.nsm import NSM
from models.GraftNet.graftnet import GraftNet
from evaluate import Evaluator

class Trainer_KBQA(object):
    def __init__(self, args, model_name, logger=None):
        #print('Trainer here')
        self.args = args
        self.logger = logger
        self.best_dev_performance = 0.0
        self.best_h1 = 0.0
        self.best_f1 = 0.0
        self.best_h1b = 0.0
        self.best_f1b = 0.0
        self.eps = args['eps']
        self.warmup_epoch = args['warmup_epoch']
        self.learning_rate = self.args['lr']
        self.test_batch_size = args['test_batch_size']
        self.device = torch.device('cuda' if args['use_cuda'] else 'cpu')
        self.reset_time = 0
        self.load_data(args, args['lm'])
        


        if 'decay_rate' in args:
            self.decay_rate = args['decay_rate']
        else:
            self.decay_rate = 0.98

        if model_name == 'ReaRev':
            self.model = ReaRev(self.args,  len(self.entity2id), self.num_kb_relation,
                                  self.num_word)
        elif model_name == 'NSM':
            self.model = NSM(self.args,  len(self.entity2id), self.num_kb_relation,
                                  self.num_word)
        elif model_name == 'GraftNet':
            self.model = GraftNet(self.args,  len(self.entity2id), self.num_kb_relation,
                                  self.num_word)
        elif model_name == 'NuTrea':
            # The NuTrea model is not shipped with this repository (no import
            # available); fail explicitly instead of raising a NameError.
            raise NotImplementedError("NuTrea is not included in this repository.")
        
        if args['relation_word_emb']:
            #self.model.use_rel_texts(self.rel_texts, self.rel_texts_inv)
            self.model.encode_rel_texts(self.rel_texts, self.rel_texts_inv)


        self.model.to(self.device)
        self.evaluator = Evaluator(args=args, model=self.model, entity2id=self.entity2id,
                                       relation2id=self.relation2id, device=self.device)
        self.load_pretrain()
        self.optim_def()
        
        self.num_relation =  self.num_kb_relation
        self.num_entity = len(self.entity2id)
        self.num_word = len(self.word2id)
                                  

        print("Entity: {}, Relation: {}, Word: {}".format(self.num_entity, self.num_relation, self.num_word))

        for k, v in args.items():
            if k.endswith('dim'):
                setattr(self, k, v)
            if k.endswith('emb_file') or k.endswith('kge_file'):
                if v is None:
                    setattr(self, k, None)
                else:
                    setattr(self, k, args['data_folder'] + v)

    def optim_def(self):
        
        trainable = filter(lambda p: p.requires_grad, self.model.parameters())
        self.optim_model = optim.Adam(trainable, lr=self.learning_rate)
        if self.decay_rate > 0:
            self.scheduler = ExponentialLR(self.optim_model, self.decay_rate)

    def load_data(self, args, tokenize):
        if args["model_name"] == "GraftNet":
            dataset = load_data_graft(args, tokenize)
        else:
            dataset = load_data(args, tokenize)
        self.train_data = dataset["train"]
        self.valid_data = dataset["valid"]
        self.test_data = dataset["test"]
        self.entity2id = dataset["entity2id"]
        self.relation2id = dataset["relation2id"]
        self.word2id = dataset["word2id"]
        self.num_word = dataset["num_word"]
        self.num_kb_relation = self.test_data.num_kb_relation
        self.num_entity = len(self.entity2id)
        self.rel_texts = dataset["rel_texts"]
        self.rel_texts_inv = dataset["rel_texts_inv"]

    def load_pretrain(self):
        args = self.args
        if args['load_experiment'] is not None:
            ckpt_path = os.path.join(args['checkpoint_dir'], args['load_experiment'])
            print("Load ckpt from", ckpt_path)
            self.load_ckpt(ckpt_path)

    def evaluate(self, data, test_batch_size=20, write_info=False):
        return self.evaluator.evaluate(data, test_batch_size, write_info)

    def train(self, start_epoch, end_epoch):
        # self.load_pretrain()
        eval_every = self.args['eval_every']
        # eval_acc = inference(self.model, self.valid_data, self.entity2id, self.args)
        # self.evaluate(self.test_data, self.test_batch_size)
        print("Start Training------------------")
        for epoch in range(start_epoch, end_epoch + 1):
            st = time.time()
            loss, extras, h1_list_all, f1_list_all = self.train_epoch()

            if self.decay_rate > 0:
                self.scheduler.step()
            
            self.logger.info("Epoch: {}, loss : {:.4f}, time: {}".format(epoch + 1, loss, time.time() - st))
            self.logger.info("Training h1 : {:.4f}, f1 : {:.4f}".format(np.mean(h1_list_all), np.mean(f1_list_all)))
            
            if (epoch + 1) % eval_every == 0:
                eval_f1, eval_h1, eval_em = self.evaluate(self.valid_data, self.test_batch_size)
                self.logger.info("EVAL F1: {:.4f}, H1: {:.4f}, EM {:.4f}".format(eval_f1, eval_h1, eval_em))
                # eval_f1, eval_h1 = self.evaluate(self.test_data, self.test_batch_size)
                # self.logger.info("TEST F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_h1))
                do_test = False

                if epoch > self.warmup_epoch:
                    if eval_h1 > self.best_h1:
                        self.best_h1 = eval_h1
                        self.save_ckpt("h1")
                        self.logger.info("BEST EVAL H1: {:.4f}".format(eval_h1))
                        do_test = True
                    if eval_f1 > self.best_f1:
                        self.best_f1 = eval_f1
                        self.save_ckpt("f1")
                        self.logger.info("BEST EVAL F1: {:.4f}".format(eval_f1))
                        do_test = True

                eval_f1, eval_h1, eval_em = self.evaluate(self.test_data, self.test_batch_size)
                self.logger.info("TEST F1: {:.4f}, H1: {:.4f}, EM {:.4f}".format(eval_f1, eval_h1, eval_em))
                # if do_test:
                #     eval_f1, eval_h1 = self.evaluate(self.test_data, self.test_batch_size)
                #     self.logger.info("TEST F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_h1))
                
                # if eval_h1 > self.best_h1:
                #     self.best_h1 = eval_h1
                #     self.save_ckpt("h1")
                # if eval_f1 > self.best_f1:
                #     self.best_f1 = eval_f1
                #     self.save_ckpt("f1")
                # self.reset_time = 0
                # else:
                #     self.logger.info('No improvement after one evaluation iter.')
                #     self.reset_time += 1
                # if self.reset_time >= 5:
                #     self.logger.info('No improvement after 5 evaluation. Early Stopping.')
                #     break
        self.save_ckpt("final")
        self.logger.info('Train Done! Evaluate on testset with saved model')
        print("End Training------------------")
        self.evaluate_best()

    def evaluate_best(self):
        filename = os.path.join(self.args['checkpoint_dir'], "{}-h1.ckpt".format(self.args['experiment_name']))
        self.load_ckpt(filename)
        eval_f1, eval_h1, eval_em = self.evaluate(self.test_data, self.test_batch_size, write_info=False)
        self.logger.info("Best h1 evaluation")
        self.logger.info("TEST F1: {:.4f}, H1: {:.4f}, EM {:.4f}".format(eval_f1, eval_h1, eval_em))

        filename = os.path.join(self.args['checkpoint_dir'], "{}-f1.ckpt".format(self.args['experiment_name']))
        self.load_ckpt(filename)
        eval_f1, eval_h1, eval_em = self.evaluate(self.test_data, self.test_batch_size,  write_info=False)
        self.logger.info("Best f1 evaluation")
        self.logger.info("TEST F1: {:.4f}, H1: {:.4f}, EM {:.4f}".format(eval_f1, eval_h1, eval_em))

        filename = os.path.join(self.args['checkpoint_dir'], "{}-final.ckpt".format(self.args['experiment_name']))
        self.load_ckpt(filename)
        eval_f1, eval_h1, eval_em = self.evaluate(self.test_data, self.test_batch_size, write_info=False)
        self.logger.info("Final evaluation")
        self.logger.info("TEST F1: {:.4f}, H1: {:.4f}, EM {:.4f}".format(eval_f1, eval_h1, eval_em))

    def evaluate_single(self, filename):
        if filename is not None:
            self.load_ckpt(filename)
        eval_f1, eval_hits, eval_ems = self.evaluate(self.valid_data, self.test_batch_size, write_info=False)
        self.logger.info("EVAL F1: {:.4f}, H1: {:.4f}, EM {:.4f}".format(eval_f1, eval_hits, eval_ems))
        test_f1, test_hits, test_ems = self.evaluate(self.test_data, self.test_batch_size, write_info=True)
        self.logger.info("TEST F1: {:.4f}, H1: {:.4f}, EM {:.4f}".format(test_f1, test_hits, test_ems))

    def train_epoch(self):
        self.model.train()
        self.train_data.reset_batches(is_sequential=False)
        losses = []
        actor_losses = []
        ent_losses = []
        num_batch = math.ceil(self.train_data.num_data / self.args['batch_size'])  # batches per epoch
        h1_list_all = []
        f1_list_all = []
        for iteration in tqdm(range(num_batch)):
            batch = self.train_data.get_batch(iteration, self.args['batch_size'], self.args['fact_drop'])
            
            self.optim_model.zero_grad()
            loss, _, _, tp_list = self.model(batch, training=True)
            # if tp_list is not None:
            h1_list, f1_list = tp_list
            h1_list_all.extend(h1_list)
            f1_list_all.extend(f1_list)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.args['gradient_clip'])
            self.optim_model.step()
            losses.append(loss.item())
        extras = [0, 0]
        return np.mean(losses), extras, h1_list_all, f1_list_all

    
    def save_ckpt(self, reason="h1"):
        model = self.model
        checkpoint = {
            'model_state_dict': model.state_dict()
        }
        model_name = os.path.join(self.args['checkpoint_dir'], "{}-{}.ckpt".format(self.args['experiment_name'],
                                                                                   reason))
        torch.save(checkpoint, model_name)
        print("Best %s, save model as %s" %(reason, model_name))

    def load_ckpt(self, filename):
        # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
        checkpoint = torch.load(filename, map_location=self.device)
        model_state_dict = checkpoint["model_state_dict"]

        model = self.model
        #self.logger.info("Load param of {} from {}.".format(", ".join(list(model_state_dict.keys())), filename))
        model.load_state_dict(model_state_dict, strict=False)



================================================
FILE: gnn/utils.py
================================================
import logging
import os


def create_logger(args):
    log_file = os.path.join(args.checkpoint_dir, args.experiment_name + ".log")
    logger = logging.getLogger()
    log_level = logging.DEBUG if args.log_level == "debug" else logging.INFO
    logger.setLevel(level=log_level)
    # Formatter
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # FileHandler
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    # StreamHandler
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    logger.info("PARAMETER" + "-" * 10)
    for attr, value in sorted(args.__dict__.items()):
        logger.info("{}={}".format(attr.upper(), value))
    logger.info("---------" + "-" * 10)

    return logger


def get_dict(data_folder, filename):
    filename_true = os.path.join(data_folder, filename)
    word2id = dict()
    with open(filename_true, encoding='utf-8') as f_in:
        for line in f_in:
            word = line.strip()
            word2id[word] = len(word2id)
    return word2id



================================================
FILE: llm/.gitignore
================================================
datasets/
*json

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/


================================================
FILE: llm/README.md
================================================
## Get Started
We keep the requirements simple in `requirements.txt`; install them first and check that you can run the code.

Please also download the `entities_names.json` file from https://drive.google.com/drive/folders/1ifgVHQDnvFEunP9hmVYT07Y3rvcpIfQp?usp=sharing, as the GNNs use the dense graphs.
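A typical setup, run from inside the `llm/` directory (the destination for `entities_names.json` is an assumption; place it wherever your data layout expects):

```shell
# Install the pinned dependencies into the active Python environment.
pip install -r requirements.txt

# Move the downloaded entity-name map next to the code
# (source and target paths are assumptions; adjust to your setup).
mv ~/Downloads/entities_names.json .
```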

## Evaluation
We provide the results of GNN retrieval in `results/gnn`. To evaluate GNN-RAG performance, run `scripts/rag-reasoning.sh`. 

You can also compute performance on multi-hop questions with `scripts/evaluate_multi_hop.sh`.

To test different LLMs for KGQA (ChatGPT, LLaMA2), see `scripts/plug-and-play.sh`. 
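The evaluation entry points above, collected in one place (script names are from this repo; all other settings stay at the scripts' defaults):

```shell
bash scripts/rag-reasoning.sh        # GNN-RAG KGQA performance
bash scripts/evaluate_multi_hop.sh   # breakdown on multi-hop questions
bash scripts/plug-and-play.sh        # plug in other LLMs (ChatGPT, LLaMA2)
```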

## Results

We append all the results for Table 2: see `results/KGQA-GNN-RAG-RA`. You can inspect the actual LLM generations, as well as the retrieved KG information (the "input" key), in `predictions.jsonl`.
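The per-question metrics in `detailed_eval_result.jsonl` (`acc`, `hit`, `f1`, `precission`, `recall`) follow the usual set-based definitions over answer strings. A minimal sketch of how such numbers can be recomputed from a prediction/ground-truth pair (the helper name is ours, not from the repo):

```python
def set_prf1(prediction, ground_truth):
    """Set-based precision, recall, and F1 over answer strings."""
    pred, gold = set(prediction), set(ground_truth)
    tp = len(pred & gold)  # predicted answers that match the ground truth
    if tp == 0:
        return 0.0, 0.0, 0.0
    precision = tp / len(pred)
    recall = tp / len(gold)
    return precision, recall, 2 * precision * recall / (precision + recall)

# e.g. three predicted answers, one of which is the single gold answer:
p, r, f1 = set_prf1(
    ["2014 World Series", "2010 World Series", "2012 World Series"],
    ["2014 World Series"],
)
# p = 1/3, r = 1.0, f1 = 0.5 (matches the first line of detailed_eval_result.jsonl)
```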

================================================
FILE: llm/prompts/alpaca.txt
================================================
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:

================================================
FILE: llm/prompts/general_prompt.txt
================================================
{instruction}

{input}

================================================
FILE: llm/prompts/llama2.txt
================================================
[INST] <<SYS>>
<</SYS>>
{instruction}{input} [/INST]

================================================
FILE: llm/prompts/llama2_predict.txt
================================================
[INST] <<SYS>>
<</SYS>>
{instruction}

{input} [/INST]

================================================
FILE: llm/requirements.txt
================================================
openai==0.27.9
transformers==4.32.0
trl==0.7.1
peft==0.5.0
datasets==2.14.4
accelerate==0.22.0
pybind11==2.11.1
networkx==3.1
graph-walker==1.0.6
tqdm


================================================
FILE: llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/args.txt
================================================
{
  "data_path": "rmanluo",
  "d": "RoG-cwq",
  "split": "test",
  "predict_path": "results/KGQA-GO",
  "model_name": "RoG",
  "prompt_path": "prompts/llama2_predict.txt",
  "add_rule": true,
  "use_true": false,
  "cot": false,
  "explain": false,
  "use_random": false,
  "each_line": false,
  "rule_path": "results/gen_rule_path/RoG-cwq/RoG/test/predictions_3_False.jsonl",
  "force": false,
  "n": 1,
  "filter_empty": false,
  "debug": false,
  "encrypt": false,
  "model_path": "rmanluo/RoG",
  "max_new_tokens": 512,
  "dtype": "fp16"
}

================================================
FILE: llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl
================================================
{"id": "WebQTest-832_c334509bb5e02cacae1ba2e80c176499", "prediction": ["2014 World Series", "2010 World Series", "2012 World Series"], "ground_truth": ["2014 World Series"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.5, "precission": 0.3333333333333333, "recall": 1.0}
{"id": "WebQTrn-1259_1997cb4922db71983be26e6a509950f4", "prediction": ["University of Tennessee", "Cornell University", "Shortridge High School", "Butler University", "University of Oxford", "University Yale", "Georgetown University", "University College, Oxford"], "ground_truth": ["Belmont University"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-1384_744a496b907e407b16bc5d7c197dc3f0", "prediction": ["Judaism"], "ground_truth": ["Judaism"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-241_dfb6c97ac9bf2f0ac07f27dd80f9edc2", "prediction": ["Germany"], "ground_truth": ["Germany"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-1077_f4a9e5f1e0dcfb82cbadf4771eda7bb5", "prediction": ["Islam", "Pashtunism", "Afghanism"], "ground_truth": ["Shia Islam", "Sunni Islam"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-576_01e2da60a2779c4ae4b5d1547499a4f8", "prediction": ["Guatemala"], "ground_truth": ["Guatemala"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-590_6aad73acb74f304bc9acae44314164be", "prediction": ["Muammar Gaddafi"], "ground_truth": ["Abdullah al-Thani"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-3251_d8cddfe5e947e414b7735780ef1efff8", "prediction": ["University of Northern Colorado"], "ground_truth": ["University of Northern Colorado"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2069_0fa727f3b282196eb1097410b4be6818", "prediction": ["Spanish Language"], "ground_truth": ["Aymara language", "Mapudungun Language", "Rapa Nui Language", "Spanish Language", "Puquina Language"], "acc": 0.2, "hit": 1, "hit1": 1, "f1": 0.33333333333333337, "precission": 1.0, "recall": 0.2}
{"id": "WebQTrn-1938_7322a2a4d46bf36b95bfab4418c9a32b", "prediction": ["Parliamentary system"], "ground_truth": ["Parliamentary system"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-60_68f0d0ad309d64a4af858a5ef4fb5713", "prediction": ["Angola"], "ground_truth": ["Mozambique"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-100_bf66fd89b6e9fc5fcb96c8b3f7a0e616", "prediction": ["Basque Language", "Catalan language", "Proven\u00e7al Language", "Corsican Language", "Tahitian Language", "Franco-Proven\u00e7al Language", "Breton", "Yeniche Language", "Guianese Creole French Language", "Esperanto Language", "West Flemish", "Gallo language", "Antillean Creole French", "Occitan language", "French", "R\u00e9union Creole French Language", "Alsatian dialect"], "ground_truth": ["Haitian Creole", "French"], "acc": 0.5, "hit": 1, "hit1": 0, "f1": 0.10526315789473684, "precission": 0.058823529411764705, "recall": 0.5}
{"id": "WebQTrn-962_032f61bfcfed69da8b215bb8f058c24e", "prediction": ["Michael Connor Humphreys"], "ground_truth": ["Michael Connor Humphreys"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-849_586aae7703d62aa44eb79759e1563309", "prediction": ["Denmark"], "ground_truth": ["Denmark"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-710_c264a6d11d7956741926d417b94327e2", "prediction": ["2014 World Series"], "ground_truth": ["2014 World Series"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2209_c1374f388d9cc7a78365860c91218362", "prediction": ["2008 NBA Finals", "1969 NBA Finals", "1986 NBA Finals", "1984 NBA Finals", "1981 NBA Finals", "1976 NBA Finals", "1974 NBA Finals", "1968 NBA Finals", "1966 NBA Finals", "1965 NBA Finals", "1964 NBA Finals", "1963 NBA Finals", "1962 NBA Finals", "1961 NBA Finals", "1960 NBA Finals", "1957 NBA Finals"], "ground_truth": ["2008 NBA Finals", "1969 NBA Finals", "1986 NBA Finals", "1984 NBA Finals", "1981 NBA Finals", "1976 NBA Finals", "1974 NBA Finals", "1968 NBA Finals", "1966 NBA Finals", "1965 NBA Finals", "1964 NBA Finals", "1963 NBA Finals", "1962 NBA Finals", "1961 NBA Finals", "1960 NBA Finals", "1959 NBA Finals", "1957 NBA Finals"], "acc": 0.9411764705882353, "hit": 1, "hit1": 1, "f1": 0.9696969696969697, "precission": 1.0, "recall": 0.9411764705882353}
{"id": "WebQTest-1251_cbf2f20f6caf754bc49d672ca7b150b7", "prediction": ["East Germany"], "ground_truth": ["East Germany"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-3100_143c89d70679c3e5257c93d8e2bc4c67", "prediction": ["Greater Antilles", "Pico Duarte"], "ground_truth": ["Greater Antilles"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.6666666666666666, "precission": 0.5, "recall": 1.0}
{"id": "WebQTrn-662_7a992044f94b39edfc37ac5dcfcb3c26", "prediction": ["Manchester United F.C. First team", "Manchester United F.C."], "ground_truth": ["Newton Heath L&YR F.C."], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-743_0a8cdba29cf260283b7c890b3609c0b9", "prediction": ["Robert F. Kennedy"], "ground_truth": ["Robert F. Kennedy"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-1797_68a33792b0a1e18937dcd4b3f50d941e", "prediction": ["Confederate States of America"], "ground_truth": ["Confederate States of America"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-1864_9dc4e22121d3a46d45b8f9bd9e8c7013", "prediction": ["China"], "ground_truth": ["India"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-1841_b8df00139e3fa59b8633ef551ed8ca9f", "prediction": ["Football"], "ground_truth": ["Spain national football team"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-124_f3990dc9aa470fa81ec4cf2912a9924f", "prediction": ["Unbroken"], "ground_truth": ["In the Land of Blood and Honey"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-538_49b4e9304f18a0a1cbe37bb162f61131", "prediction": ["National Lampoon's TV: The Movie", "The Tournament", "Celebrity", "Life as a House", "Anatomy of a Hate Crime", "Changing Hearts", "The Rules of Attraction", "Marco Polo", "How to Make Love to a Woman", "In Enemy Hands", "Pulse", "Lost City Raiders", "Wake", "The Lost Samaritan", "Fireball", "Fearless", "The Sensation of Sight", "The Anomaly", "Time Framed", "Caught on Tape", "Recess", "The Old Man and the Studio"], "ground_truth": ["Hannah Montana: The Movie", "Valentine's Day", "The Lorax", "Jonas Brothers: The Concert Experience", "The Giver"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-2722_8babdaa9ecd05a72e3227b43b1f98771", "prediction": ["2007 World Championships in Athletics \u2013 Men's 100 metres", "2007 World Championships in Athletics \u2013 Men's 200 metres"], "ground_truth": ["Gymnastics at the 2012 Summer Olympics \u2013 Women's artistic team all-around"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-983_ccda690fb745939d0a62c3fbcf3e3769", "prediction": ["University of Missouri"], "ground_truth": ["University of Missouri"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-1770_8db36acba886620a06031d39165d78de", "prediction": ["Paris-Michael Katherine Jackson", "Michael Joseph Jackson, Jr.", "Prince Michael Jackson II"], "ground_truth": ["Blue Ivy"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-1817_89670933168c3f4e5195a241f9d46e76", "prediction": ["Communist state", "Socialist state", "Single-party state"], "ground_truth": ["Communist state", "Socialist state", "Single-party state"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-849_1eb0f1ddd5074471fbe3a7f6b575f202", "prediction": ["Denmark"], "ground_truth": ["Denmark"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2152_92fba37c9723caee68665ad9a5e4a468", "prediction": ["Oakland Athletics", "Texas Rangers"], "ground_truth": ["Texas Rangers"], "acc": 1.0, "hit": 1, "hit1": 0, "f1": 0.6666666666666666, "precission": 0.5, "recall": 1.0}
{"id": "WebQTrn-3376_0619d288bbed0ca782e60c6f841a6051", "prediction": ["Spelman College", "Sarah Lawrence College"], "ground_truth": ["Temple University", "Castlemont High School"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-567_693feb48c0515cd069014e7ca2846b37", "prediction": ["A Beautiful Mind"], "ground_truth": ["The Journey"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-1077_0c34ca057060e35aa5c74fbbca682dee", "prediction": ["Islam"], "ground_truth": ["Shia Islam", "Sunni Islam"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-1686_29e74083744b3631541b29b4094fb273", "prediction": ["Janet Napolitano"], "ground_truth": ["Janet Napolitano"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-626_8c6b952c6bd963f0ece4e401c9eb731a", "prediction": ["Iowa", "Kansas", "Missouri", "Nebraska", "North Dakota", "South Dakota"], "ground_truth": ["Iowa", "Kansas", "Missouri", "Nebraska", "North Dakota", "South Dakota"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-12_68d745a0657c86906382873e57294d6a", "prediction": ["Ted Strickland", "Mike DeWine", "Sherrod Brown"], "ground_truth": ["Return J. Meigs, Jr.", "John Kasich"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-576_906ad6be7bec9d208f4dde4f7721c261", "prediction": ["El Salvador", "El Salvador", "Guatemala", "Honduras", "Panama", "Costa Rica"], "ground_truth": ["Belize", "Costa Rica", "El Salvador", "Guatemala", "Honduras", "Panama"], "acc": 0.8333333333333334, "hit": 1, "hit1": 1, "f1": 0.8333333333333334, "precission": 0.8333333333333334, "recall": 0.8333333333333334}
{"id": "WebQTrn-2152_3cdf60c15a8355981dd92e3c57ac2eed", "prediction": ["Seattle Mariners"], "ground_truth": ["Seattle Mariners"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-61_09020cbb000c86fda5910ec084690246", "prediction": ["Islam"], "ground_truth": ["Islam"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-1864_67ecd1c247c3b2c9545fbcf1ad8d9d00", "prediction": ["China"], "ground_truth": ["India"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-3136_a2debf685e0c50491e35a9cf7a1e9ade", "prediction": ["Northern Ireland"], "ground_truth": ["Northern Ireland"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-998_f693ae2cb9aa6b8ffdf0f103a6777b62", "prediction": ["1973 NBA Finals"], "ground_truth": ["1973 NBA Finals"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-1023_1e7110e48c30a2cef3caf291e3b8d394", "prediction": ["Argentine peso"], "ground_truth": ["Argentine peso"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-810_c334509bb5e02cacae1ba2e80c176499", "prediction": ["2014 World Series", "2010 World Series", "2012 World Series"], "ground_truth": ["2014 World Series"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.5, "precission": 0.3333333333333333, "recall": 1.0}
{"id": "WebQTrn-557_960c16ffdb29e173df0577fc76c7455d", "prediction": ["Judy Garland"], "ground_truth": ["Judy Garland"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-21_660138373d19bbffdd3d3f7a30234e4a", "prediction": ["Hailemariam Desalegn"], "ground_truth": ["Hailemariam Desalegn"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2286_e0906c845ea8b2e22e08e1b0e6eb9b43", "prediction": ["United States of America", "United States, with Territories"], "ground_truth": ["Contiguous United States", "United States of America", "United States, with Territories"], "acc": 0.6666666666666666, "hit": 1, "hit1": 1, "f1": 0.8, "precission": 1.0, "recall": 0.6666666666666666}
{"id": "WebQTest-1320_c5498ca807d2e1ec30d4c8fdd41f0bf7", "prediction": ["Miller Park", "Milwaukee"], "ground_truth": ["Miller Park"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.6666666666666666, "precission": 0.5, "recall": 1.0}
{"id": "WebQTrn-2444_8f2cd432b509e5b8fe681bb55bca2767", "prediction": ["Sportsman's Park"], "ground_truth": ["Busch Stadium"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-100_524908899a8aa334a18a0ac00f8f2fe6", "prediction": ["Basque Language", "Catalan language", "Proven\u00e7al Language", "Corsican Language", "Tahitian Language", "Franco-Proven\u00e7al Language", "Breton", "Yeniche Language", "Guianese Creole French Language", "Esperanto Language", "West Flemish", "Gallo language", "Antillean Creole French", "Occitan language", "French", "R\u00e9union Creole French Language", "Alsatian dialect"], "ground_truth": ["Haitian Creole", "French"], "acc": 0.5, "hit": 1, "hit1": 0, "f1": 0.10526315789473684, "precission": 0.058823529411764705, "recall": 0.5}
{"id": "WebQTrn-568_d54918e8e89ad97237bce821087a9818", "prediction": ["Mecklenburg County"], "ground_truth": ["Mecklenburg County"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-361_e24533e28da40db99eb4b25773f9d38f", "prediction": ["Albertina", "Museum of Military History, Vienna"], "ground_truth": ["Kunsthistorisches Museum"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-2215_11c4cd5a25fd84f3980d7013c0329bad", "prediction": ["Glen Dale"], "ground_truth": ["Glen Dale"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2834_6a27420dcf0528ae017dd74e40cfd38a", "prediction": ["Islam"], "ground_truth": ["Islam"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-634_000f72b60ccaedff5f056f522ee06e98", "prediction": ["Film Score Composer", "Composer"], "ground_truth": ["Professor", "Monk", "Theologian", "Physician", "Priest", "Writer"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-857_9392f3f06e288ee4e3437a74f6bf5a37", "prediction": ["Al Gore"], "ground_truth": ["Al Gore"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-810_a188aff4a054e1ec66fafba1b8021f67", "prediction": ["2014 World Series", "2010 World Series", "2012 World Series"], "ground_truth": ["2014 World Series"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.5, "precission": 0.3333333333333333, "recall": 1.0}
{"id": "WebQTrn-2026_d059b24adec4064377b957ca598769be", "prediction": ["Dominican peso"], "ground_truth": ["Dominican peso"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-60_990f6babd500d25e3746174e6da58c84", "prediction": ["Angola", "Guyana"], "ground_truth": ["Angola"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.6666666666666666, "precission": 0.5, "recall": 1.0}
{"id": "WebQTrn-2570_374d1789f1735b6f08e1a829c0d075a2", "prediction": ["Franklin D. Roosevelt"], "ground_truth": ["Harry S. Truman"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-484_ac67410188d0f2258139a3c84773885e", "prediction": ["English Language", "Esperanto Language", "Lojban"], "ground_truth": ["English Language", "Esperanto Language", "Lojban"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-1785_1c178196ae53ffd4b09e5787a35c3950", "prediction": ["James at 15"], "ground_truth": ["Nanny and the Professor"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-3744_26401b4c33bcb760afd734acaa0a1869", "prediction": ["University of Wisconsin-Madison"], "ground_truth": ["University of Wisconsin-Madison"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-810_e3d40457273785e46c5b71732713a5f4", "prediction": ["2014 World Series", "2010 World Series", "2012 World Series"], "ground_truth": ["2014 World Series"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.5, "precission": 0.3333333333333333, "recall": 1.0}
{"id": "WebQTest-538_92e606ef9c0429ad6820797ad2950730", "prediction": ["The Lorax"], "ground_truth": ["The Lorax"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2189_f4440609f5cecb091bf8e86adb47be25", "prediction": ["Modern Standard Arabic"], "ground_truth": ["Modern Standard Arabic"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-372_a806813629e86776ae3bfe26a3e000f8", "prediction": ["Holy Spirit", "Jesus Christ", "The Father", "The Son", "God"], "ground_truth": ["Holy Spirit", "Jesus Christ", "The Father", "God"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.888888888888889, "precission": 0.8, "recall": 1.0}
{"id": "WebQTrn-3744_b9bd90569bb7912ec3ea180bf164663c", "prediction": ["University of Wisconsin-Madison"], "ground_truth": ["University of Wisconsin-Madison"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-712_6099ea03f4fd2476605c4a513318d29c", "prediction": ["New Zealand", "New Zealand", "New Zealand", "New Zealand", "New Zealand", "New Zealand"], "ground_truth": ["New Zealand", "New Zealand", "New Zealand", "New Zealand", "New Zealand"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.9090909090909091, "precission": 0.8333333333333334, "recall": 1.0}
{"id": "WebQTest-100_de15ac1f762e3ec1e1261f6d9c81ebf9", "prediction": ["Haitian Creole", "French"], "ground_truth": ["Haitian Creole", "French"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-1557_edfd3507d32929ce0e9d844f7a2674de", "prediction": ["McGill University"], "ground_truth": ["McGill University"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-60_39ba4faa87698cb0767d1a5ee7ce1827", "prediction": ["Portugal"], "ground_truth": ["South Africa"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-846_a29552911617e890ca2e1d6564e0990e", "prediction": ["Belleville"], "ground_truth": ["Belleville"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTest-121_333b95bd45af3e6e31e328bc8c24d84f", "prediction": ["Manchester City F.C."], "ground_truth": ["LA Galaxy"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-626_47b990334f91b6d3bd042f82b81740f6", "prediction": ["South Dakota"], "ground_truth": ["South Dakota"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2444_e5059ff268415917df330817b9c8ef8c", "prediction": ["Roger Dean Stadium"], "ground_truth": ["Busch Stadium"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-2664_471b83eade9707a4dba68e201bf29d73", "prediction": ["Madagascar"], "ground_truth": ["Sierra Leone"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-576_8dfecb6548586cf236340abadabeb86c", "prediction": ["Honduras", "Honduras"], "ground_truth": ["El Salvador", "Panama"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTrn-124_028bb5f442b37a4af9f9fd9fa0bc5e9a", "prediction": ["By the Sea", "In the Land of Blood and Honey", "A Place in Time", "Unbroken"], "ground_truth": ["By the Sea", "In the Land of Blood and Honey", "A Place in Time", "Unbroken"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 1.0, "precission": 1.0, "recall": 1.0}
{"id": "WebQTrn-2871_d7d303efc1f901f14e6aae2bb469743c", "prediction": ["Arizona", "Phoenix", "Lake Powell"], "ground_truth": ["Lake Powell", "Phoenix"], "acc": 1.0, "hit": 1, "hit1": 0, "f1": 0.8, "precission": 0.6666666666666666, "recall": 1.0}
{"id": "WebQTrn-750_bb35e9b8023fbf9c05df55b4245a4775", "prediction": ["1988 World Series", "1981 World Series", "1965 World Series", "1959 World Series", "1963 World Series", "1964 World Series", "1956 World Series", "1955 World Series", "1953 World Series", "1952 World Series", "1951 World Series", "1950 World Series", "1949 World Series", "1947 World Series", "1946 World Series", "1945 World Series", "1943 World Series", "1942 World Series", "1941 World Series", "1939 World Series", "1938 World Series", "1937 World Series", "1936 World Series", "1935 World Series", "1934 World Series", "1933 World Series", "1932 World Series", "1931 World Series", "1930 World Series", "1929 World Series", "1928 World Series", "1927 World Series", "1925 World Series", "1923 World Series", "1922 World Series", "1921 World Series", "1916 World Series", "1915 World Series", "1913 World Series", "1912 World Series", "1911 World Series", "1910 World Series", "1909 World Series", "1908 World Series", "1907 World Series", "1905 World Series", "1903 World Series", "1902 World Series", "1901 World Series", "1883 World Series", "1882 World Series", "1881 World Series", "1879 World Series", "1877 World Series", "1876 World Series", "1875 World Series", "1874 World Series", "1873 World Series", "1872 World Series", "1871 World Series", "1870 World Series", "1868 World Series", "1867 World Series", "1866 World Series", "1865 World Series", "1864 World Series", "1863 World Series", "1862 World Series", "1861 World Series", "1860 World Series", "1859 World Series", "1857 World Series", "1856 World Series"], "ground_truth": ["1981 World Series", "1988 World Series", "1965 World Series", "1963 World Series", "1959 World Series"], "acc": 1.0, "hit": 1, "hit1": 1, "f1": 0.1282051282051282, "precission": 0.0684931506849315, "recall": 1.0}
{"id": "WebQTest-96_11da03aa9cec8b011619c8ea0dbfdcf9", "prediction": ["Dallas"], "ground_truth": ["New York City"], "acc": 0.0, "hit": 0, "hit1": 0, "f1": 0, "precission": 0.0, "recall": 0.0}
{"id": "WebQTest-1528_25853c768670cd164d7793f094ba7cbb", "prediction": ["Percy Jackson & the Olympians: The Lightning Thief"], "ground_truth": ["Percy Jackson & the Olympians: The Lightning Thief", "Percy Jackson: Sea of Monsters"], "acc": 0.5, "hit": 1, "hit1": 1, "f1": 0.6666666666666666, "precission": 1.0, "recall": 0.5}
{"id": "WebQTest-1260_5dd0eeca79ae03b7711252c032849eb2", "prediction": ["2001 AFC Championship Game", "2013 AFC Championship Game", "Super Bowl XLVII", "Super Bowl XXXV"], "ground_truth": ["Super Bowl XXXV"], "acc": 1.0, "hit"
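The per-sample records above can be reproduced from each `prediction`/`ground_truth` pair. A minimal sketch of the scoring, inferred from the rows themselves (not taken from the repo's `evaluate.py`): matching is multiset-based (duplicate predictions count, see the repeated "New Zealand" row), `hit` is 1 if any prediction matches, `hit1` checks only the first-ranked prediction, and `acc` coincides with recall in every row shown. Note the files spell the key `precission`; the sketch uses the standard spelling.

```python
from collections import Counter

def score(prediction, ground_truth):
    """Recompute the metrics seen in detailed_eval_result.jsonl.

    Assumption: multiset overlap between prediction and ground_truth,
    consistent with the rows above; the repo's evaluator may differ.
    """
    pred, gold = Counter(prediction), Counter(ground_truth)
    tp = sum((pred & gold).values())  # multiset intersection size
    precision = tp / len(prediction) if prediction else 0.0
    recall = tp / len(ground_truth) if ground_truth else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0
    return {
        "acc": recall,  # equals recall in every row above
        "hit": int(tp > 0),                                   # any match
        "hit1": int(bool(prediction) and prediction[0] in gold),  # top-1 match
        "f1": f1,
        "precision": precision,
        "recall": recall,
    }
```

For example, `score(["Oakland Athletics", "Texas Rangers"], ["Texas Rangers"])` yields `hit=1`, `hit1=0`, `f1=0.666...`, matching the WebQTrn-2152 row.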
gitextract_o4kycd_o/

├── .gitignore
├── README.md
├── gnn/
│   ├── .gitignore
│   ├── README.md
│   ├── dataset_load.py
│   ├── dataset_load_graft.py
│   ├── evaluate.py
│   ├── main.py
│   ├── models/
│   │   ├── GraftNet/
│   │   │   └── graftnet.py
│   │   ├── NSM/
│   │   │   └── nsm.py
│   │   ├── ReaRev/
│   │   │   └── rearev.py
│   │   └── base_model.py
│   ├── modules/
│   │   ├── kg_reasoning/
│   │   │   ├── base_gnn.py
│   │   │   ├── graft_gnn.py
│   │   │   ├── nsm_gnn.py
│   │   │   └── reasongnn.py
│   │   ├── layer_init.py
│   │   ├── query_update.py
│   │   └── question_encoding/
│   │       ├── base_encoder.py
│   │       ├── bert_encoder.py
│   │       ├── lstm_encoder.py
│   │       └── tokenizers.py
│   ├── parsing.py
│   ├── requirements.txt
│   ├── scripts/
│   │   └── rearev_cwq.sh
│   ├── train_model.py
│   └── utils.py
└── llm/
    ├── .gitignore
    ├── README.md
    ├── prompts/
    │   ├── alpaca.txt
    │   ├── general_prompt.txt
    │   ├── llama2.txt
    │   └── llama2_predict.txt
    ├── requirements.txt
    ├── results/
    │   ├── KGQA-GNN-RAG/
    │   │   ├── rearev-lmsr/
    │   │   │   ├── RoG-cwq/
    │   │   │   │   └── RoG/
    │   │   │   │       └── test/
    │   │   │   │           └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
    │   │   │   │               └── False/
    │   │   │   │                   ├── args.txt
    │   │   │   │                   ├── detailed_eval_result.jsonl
    │   │   │   │                   ├── eval_result.txt
    │   │   │   │                   └── predictions.jsonl
    │   │   │   └── RoG-webqsp/
    │   │   │       ├── RoG/
    │   │   │       │   └── test/
    │   │   │       │       └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
    │   │   │       │           └── False/
    │   │   │       │               ├── args.txt
    │   │   │       │               ├── detailed_eval_result.jsonl
    │   │   │       │               ├── eval_result.txt
    │   │   │       │               └── predictions.jsonl
    │   │   │       └── llama2-chat-hf/
    │   │   │           └── test/
    │   │   │               └── no_rule/
    │   │   │                   └── False/
    │   │   │                       ├── args.txt
    │   │   │                       ├── detailed_eval_result.jsonl
    │   │   │                       ├── eval_result.txt
    │   │   │                       └── predictions.jsonl
    │   │   └── rearev-sbert/
    │   │       ├── RoG-cwq/
    │   │       │   └── RoG/
    │   │       │       └── test/
    │   │       │           └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
    │   │       │               └── False/
    │   │       │                   ├── args.txt
    │   │       │                   ├── detailed_eval_result.jsonl
    │   │       │                   ├── eval_result.txt
    │   │       │                   └── predictions.jsonl
    │   │       └── RoG-webqsp/
    │   │           └── RoG/
    │   │               └── test/
    │   │                   └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
    │   │                       └── False/
    │   │                           ├── args.txt
    │   │                           ├── detailed_eval_result.jsonl
    │   │                           ├── eval_result.txt
    │   │                           └── predictions.jsonl
    │   ├── KGQA-GNN-RAG-RA/
    │   │   ├── rearev-lmsr/
    │   │   │   ├── RoG-cwq/
    │   │   │   │   └── RoG/
    │   │   │   │       └── test/
    │   │   │   │           └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
    │   │   │   │               └── False/
    │   │   │   │                   ├── args.txt
    │   │   │   │                   ├── detailed_eval_result.jsonl
    │   │   │   │                   ├── eval_result.txt
    │   │   │   │                   └── predictions.jsonl
    │   │   │   └── RoG-webqsp/
    │   │   │       └── RoG/
    │   │   │           └── test/
    │   │   │               └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
    │   │   │                   └── False/
    │   │   │                       ├── args.txt
    │   │   │                       ├── detailed_eval_result.jsonl
    │   │   │                       ├── eval_result.txt
    │   │   │                       └── predictions.jsonl
    │   │   └── rearev-sbert/
    │   │       ├── RoG-cwq/
    │   │       │   └── RoG/
    │   │       │       └── test/
    │   │       │           └── results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/
    │   │       │               └── False/
    │   │       │                   ├── args.txt
    │   │       │                   ├── detailed_eval_result.jsonl
    │   │       │                   ├── eval_result.txt
    │   │       │                   └── predictions.jsonl
    │   │       └── RoG-webqsp/
    │   │           └── RoG/
    │   │               └── test/
    │   │                   └── results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/
    │   │                       └── False/
    │   │                           ├── args.txt
    │   │                           ├── detailed_eval_result.jsonl
    │   │                           ├── eval_result.txt
    │   │                           └── predictions.jsonl
    │   ├── gen_rule_path/
    │   │   ├── RoG-cwq/
    │   │   │   └── RoG/
    │   │   │       └── test/
    │   │   │           ├── predictions_2_False.jsonl
    │   │   │           └── predictions_3_False.jsonl
    │   │   └── RoG-webqsp/
    │   │       └── RoG/
    │   │           ├── test/
    │   │           │   ├── predictions_1_False.jsonl
    │   │           │   ├── predictions_2_False.jsonl
    │   │           │   └── predictions_3_False.jsonl
    │   │           ├── train/
    │   │           │   ├── predictions_1_False.jsonl
    │   │           │   └── predictions_3_False.jsonl
    │   │           └── validation/
    │   │               └── predictions_3_False.jsonl
    │   └── gnn/
    │       ├── RoG-cwq/
    │       │   ├── rearev-lmsr/
    │       │   │   └── test.info
    │       │   └── rearev-sbert/
    │       │       └── test.info
    │       └── RoG-webqsp/
    │           ├── rearev-lmsr/
    │           │   └── test.info
    │           └── rearev-sbert/
    │               └── test.info
    ├── scripts/
    │   ├── evaluate_multi_hop.sh
    │   ├── interpretable_example.py
    │   ├── planning.sh
    │   ├── plug-and-play.sh
    │   ├── rag-reasoning.sh
    │   └── train.sh
    └── src/
        ├── __init__.py
        ├── align_kg/
        │   ├── __init__.py
        │   ├── build_align_qa_dataset.py
        │   └── data_loader.py
        ├── joint_training/
        │   ├── generate_explanation_results.py
        │   ├── joint_finetuning.py
        │   ├── preprocess_align.py
        │   └── preprocess_qa.py
        ├── llms/
        │   ├── __init__.py
        │   ├── language_models/
        │   │   ├── __init__.py
        │   │   ├── alpaca.py
        │   │   ├── base_language_model.py
        │   │   ├── chatgpt.py
        │   │   ├── flan_t5.py
        │   │   ├── llama.py
        │   │   └── longchat/
        │   │       ├── llama_condense_monkey_patch.py
        │   │       ├── llama_flash_attn_monkey_patch.py
        │   │       └── longchat.py
        │   ├── llm_proxy.py
        │   └── start_fastchat_api.py
        ├── qa_prediction/
        │   ├── build_qa_input.py
        │   ├── evaluate_multi_hop.py
        │   ├── evaluate_results.py
        │   ├── gen_rule_path.py
        │   └── predict_answer.py
        └── utils/
            ├── __init__.py
            ├── graph_utils.py
            ├── merge_peft.py
            ├── training_utils.py
            └── utils.py
SYMBOL INDEX (298 symbols across 47 files)

FILE: gnn/dataset_load.py
  class BasicDataLoader (line 18) | class BasicDataLoader(object):
    method __init__ (line 24) | def __init__(self, config, word2id, relation2id, entity2id, tokenize, ...
    method _load_file (line 31) | def _load_file(self, config, data_type="train"):
    method _load_data (line 62) | def _load_data(self):
    method _parse_args (line 88) | def _parse_args(self, config, word2id, relation2id, entity2id):
    method get_quest (line 130) | def get_quest(self, training=False):
    method decode_text (line 143) | def decode_text(self, np_array_x):
    method _prepare_data (line 159) | def _prepare_data(self):
    method build_rel_words (line 354) | def build_rel_words(self, tokenize):
    method create_kb_adj_mats (line 432) | def create_kb_adj_mats(self, sample_id):
    method _build_fact_mat (line 473) | def _build_fact_mat(self, sample_ids, fact_dropout):
    method reset_batches (line 530) | def reset_batches(self, is_sequential=True):
    method _build_global2local_entity_maps (line 536) | def _build_global2local_entity_maps(self):
    method _add_entity_to_map (line 562) | def _add_entity_to_map(entity2id, entities, g2l):
    method deal_q_type (line 577) | def deal_q_type(self, q_type=None):
  class SingleDataLoader (line 592) | class SingleDataLoader(BasicDataLoader):
    method __init__ (line 596) | def __init__(self, config, word2id, relation2id, entity2id, tokenize, ...
    method get_batch (line 599) | def get_batch(self, iteration, batch_size, fact_dropout, q_type=None, ...
  function load_dict (line 632) | def load_dict(filename):
  function load_dict_int (line 640) | def load_dict_int(filename):
  function load_data (line 648) | def load_data(config, tokenize):

FILE: gnn/dataset_load_graft.py
  class GraftBasicDataLoader (line 19) | class GraftBasicDataLoader(BasicDataLoader):
    method __init__ (line 24) | def __init__(self, config, word2id, relation2id, entity2id, tokenize, ...
    method create_kb_adj_mats_facts (line 27) | def create_kb_adj_mats_facts(self, sample_id):
    method _build_fact_mat_maxfacts (line 70) | def _build_fact_mat_maxfacts(self, sample_ids, fact_dropout):
  class GraftSingleDataLoader (line 106) | class GraftSingleDataLoader(GraftBasicDataLoader):
    method __init__ (line 110) | def __init__(self, config, word2id, relation2id, entity2id, tokenize, ...
    method get_batch (line 113) | def get_batch(self, iteration, batch_size, fact_dropout, q_type=None, ...
  function load_dict (line 152) | def load_dict(filename):
  function load_dict_int (line 160) | def load_dict_int(filename):
  function load_data_graft (line 168) | def load_data_graft(config, tokenize):

FILE: gnn/evaluate.py
  function cal_accuracy (line 10) | def cal_accuracy(pred, answer_dist):
  function f1_and_hits (line 25) | def f1_and_hits(answers, candidate2prob, id2entity, entity2name, eps=0.5):
  class Evaluator (line 70) | class Evaluator:
    method __init__ (line 72) | def __init__(self, args, model, entity2id, relation2id, device):
    method write_info (line 106) | def write_info(self, valid_data, tp_list, num_step):
    method evaluate (line 140) | def evaluate(self, valid_data, test_batch_size=20, write_info=False):

FILE: gnn/main.py
  function main (line 29) | def main():

FILE: gnn/models/GraftNet/graftnet.py
  class GraftNet (line 21) | class GraftNet(BaseModel):
    method __init__ (line 22) | def __init__(self, args, num_entity, num_relation, num_word):
    method layers (line 37) | def layers(self, args):
    method private_module_def (line 61) | def private_module_def(self, args, num_entity, num_relation):
    method get_ent_init (line 74) | def get_ent_init(self, local_entity, kb_adj_mat, rel_features):
    method get_rel_feature (line 85) | def get_rel_feature(self):
    method init_reason (line 105) | def init_reason(self, curr_dist, local_entity, kb_adj_mat, kb_adj_mat_...
    method calc_loss_label (line 128) | def calc_loss_label(self, curr_dist, teacher_dist, label_valid):
    method forward (line 135) | def forward(self, batch, training=False):

FILE: gnn/models/NSM/nsm.py
  class NSM (line 19) | class NSM(BaseModel):
    method __init__ (line 20) | def __init__(self, args, num_entity, num_relation, num_word):
    method layers (line 37) | def layers(self, args):
    method private_module_def (line 69) | def private_module_def(self, args, num_entity, num_relation):
    method get_ent_init (line 85) | def get_ent_init(self, local_entity, kb_adj_mat, rel_features):
    method get_rel_feature (line 97) | def get_rel_feature(self):
    method init_reason (line 114) | def init_reason(self, curr_dist, local_entity, kb_adj_mat, q_input):
    method get_js_div (line 142) | def get_js_div(self, dist_1, dist_2):
    method calc_loss_backward (line 151) | def calc_loss_backward(self, case_valid):
    method calc_loss_label (line 172) | def calc_loss_label(self, curr_dist, teacher_dist, label_valid):
    method forward (line 179) | def forward(self, batch, training=False):

FILE: gnn/models/ReaRev/rearev.py
  class ReaRev (line 19) | class ReaRev(BaseModel):
    method __init__ (line 20) | def __init__(self, args, num_entity, num_relation, num_word):
    method layers (line 51) | def layers(self, args):
    method get_ent_init (line 79) | def get_ent_init(self, local_entity, kb_adj_mat, rel_features):
    method get_rel_feature (line 91) | def get_rel_feature(self):
    method private_module_def (line 114) | def private_module_def(self, args, num_entity, num_relation):
    method init_reason (line 132) | def init_reason(self, curr_dist, local_entity, kb_adj_mat, q_input, qu...
    method calc_loss_label (line 156) | def calc_loss_label(self, curr_dist, teacher_dist, label_valid):
    method forward (line 163) | def forward(self, batch, training=False):

FILE: gnn/models/base_model.py
  class BaseModel (line 9) | class BaseModel(torch.nn.Module):
    method __init__ (line 14) | def __init__(self, args, num_entity, num_relation, num_word):
    method embedding_def (line 70) | def embedding_def(self):
    method load_relation_file (line 153) | def load_relation_file(self, filename):
    method use_rel_texts (line 164) | def use_rel_texts(self, rel_texts, rel_texts_inv):
    method encode_rel_texts (line 168) | def encode_rel_texts(self, rel_texts, rel_texts_inv):
    method init_hidden (line 178) | def init_hidden(self, num_layer, batch_size, hidden_size):
    method encode_question (line 181) | def encode_question(self, q_input):
    method get_instruction (line 184) | def get_instruction(self, query_hidden_emb, query_mask, states):
    method get_loss_bce (line 187) | def get_loss_bce(self, pred_dist_score, answer_dist):
    method get_loss_kl (line 193) | def get_loss_kl(self, pred_dist, answer_dist):
    method get_loss (line 201) | def get_loss(self, pred_dist, answer_dist, reduction='mean'):
    method f1_and_hits (line 217) | def f1_and_hits(self, answers, candidate2prob, eps=0.5):
    method calc_f1_new (line 249) | def calc_f1_new(self, curr_dist, dist_ans, h1_vec):
    method calc_h1 (line 287) | def calc_h1(self, curr_dist, dist_ans, eps=0.01):
    method get_eval_metric (line 294) | def get_eval_metric(self, pred_dist, answer_dist):

FILE: gnn/modules/kg_reasoning/base_gnn.py
  class BaseGNNLayer (line 7) | class BaseGNNLayer(torch.nn.Module):
    method __init__ (line 11) | def __init__(self, args, num_entity, num_relation):
    method build_matrix (line 19) | def build_matrix(self):
    method _build_sparse_tensor (line 53) | def _build_sparse_tensor(self, indices, values, size):
    method build_adj_facts (line 56) | def build_adj_facts(self):

FILE: gnn/modules/kg_reasoning/graft_gnn.py
  class GraftLayer (line 14) | class GraftLayer(BaseGNNLayer):
    method __init__ (line 15) | def __init__(self, args, num_entity, num_relation, entity_dim):
    method init_layers (line 27) | def init_layers(self, args):
    method init_reason (line 45) | def init_reason(self, local_entity, kb_adj_mat, kb_adj_mat_graft, kb_f...
    method compute_attention (line 64) | def compute_attention(self, query_hidden_emb, query_mask):
    method reason_layer (line 89) | def reason_layer(self, curr_dist, kb_self_linear, kb_head_linear, kb_t...
    method forward (line 111) | def forward(self, current_dist, query_hidden_emb, query_mask, step=0, ...

FILE: gnn/modules/kg_reasoning/nsm_gnn.py
  class NSMBaseLayer (line 14) | class NSMBaseLayer(BaseGNNLayer):
    method __init__ (line 15) | def __init__(self, args, num_entity, num_relation, entity_dim):
    method init_layers (line 24) | def init_layers(self, args):
    method init_reason (line 38) | def init_reason(self, local_entity, kb_adj_mat, local_entity_emb, rel_...
    method reason_layer (line 51) | def reason_layer(self, curr_dist, instruction, rel_linear):
    method forward (line 54) | def forward(self, current_dist, relational_ins, step=0, return_score=F...
  class NSMLayer (line 83) | class NSMLayer(NSMBaseLayer):
    method __init__ (line 84) | def __init__(self, args, num_entity, num_relation, entity_dim):
    method reason_layer (line 87) | def reason_layer(self, curr_dist, instruction, rel_linear):
  class NSMLayer_back (line 114) | class NSMLayer_back(NSMBaseLayer):
    method __init__ (line 115) | def __init__(self, args, num_entity, num_relation, entity_dim):
    method reason_layer (line 118) | def reason_layer(self, curr_dist, instruction, rel_linear):

FILE: gnn/modules/kg_reasoning/reasongnn.py
  class ReasonGNNLayer (line 11) | class ReasonGNNLayer(BaseGNNLayer):
    method __init__ (line 15) | def __init__(self, args, num_entity, num_relation, entity_dim, alg):
    method init_layers (line 27) | def init_layers(self, args):
    method init_reason (line 46) | def init_reason(self, local_entity, kb_adj_mat, local_entity_emb, rel_...
    method reason_layer (line 61) | def reason_layer(self, curr_dist, instruction, rel_linear, pos_emb):
    method reason_layer_inv (line 91) | def reason_layer_inv(self, curr_dist, instruction, rel_linear, pos_emb...
    method combine (line 118) | def combine(self,emb):
    method forward (line 134) | def forward(self, current_dist, relational_ins, step=0, return_score=F...

FILE: gnn/modules/layer_init.py
  class TypeLayer (line 9) | class TypeLayer(nn.Module):
    method __init__ (line 14) | def __init__(self, in_features, out_features, linear_drop, device, nor...
    method forward (line 25) | def forward(self, local_entity, edge_list, rel_features):
    method _build_sparse_tensor (line 64) | def _build_sparse_tensor(self, indices, values, size):

FILE: gnn/modules/query_update.py
  class Fusion (line 6) | class Fusion(nn.Module):
    method __init__ (line 8) | def __init__(self, d_hid):
    method forward (line 13) | def forward(self, x, y):
  class QueryReform (line 18) | class QueryReform(nn.Module):
    method __init__ (line 20) | def __init__(self, h_dim):
    method forward (line 26) | def forward(self, q_node, ent_emb, seed_info, ent_mask):
  class AttnEncoder (line 46) | class AttnEncoder(nn.Module):
    method __init__ (line 48) | def __init__(self, d_hid):
    method forward (line 52) | def forward(self, x, x_mask):
  class Attention (line 63) | class Attention(nn.Module):
    method __init__ (line 89) | def __init__(self, dimensions, attention_type='general'):
    method forward (line 103) | def forward(self, query, context):

FILE: gnn/modules/question_encoding/base_encoder.py
  class BaseInstruction (line 8) | class BaseInstruction(torch.nn.Module):
    method __init__ (line 10) | def __init__(self, args, constraint):
    method _parse_args (line 16) | def _parse_args(self, args):
    method share_module_def (line 48) | def share_module_def(self):
    method init_hidden (line 53) | def init_hidden(self, num_layer, batch_size, hidden_size):
    method encode_question (line 57) | def encode_question(self, *args):
    method get_node_emb (line 62) | def get_node_emb(query_hidden_emb, action):
    method init_reason (line 74) | def init_reason(self, query_text):
    method get_instruction (line 82) | def get_instruction(self, relational_ins, step=0, query_node_emb=None):
    method forward (line 105) | def forward(self, query_text, lm=None):

FILE: gnn/modules/question_encoding/bert_encoder.py
  class BERTInstruction (line 18) | class BERTInstruction(BaseInstruction):
    method __init__ (line 20) | def __init__(self, args, word_embedding, num_word, model, constraint=F...
    method encoder_def (line 74) | def encoder_def(self):
    method encode_question (line 89) | def encode_question(self, query_text, store=True):

FILE: gnn/modules/question_encoding/lstm_encoder.py
  class LSTMInstruction (line 10) | class LSTMInstruction(BaseInstruction):
    method __init__ (line 12) | def __init__(self, args, word_embedding, num_word):
    method encoder_def (line 25) | def encoder_def(self):
    method encode_question (line 32) | def encode_question(self, query_text, store=True):

FILE: gnn/modules/question_encoding/tokenizers.py
  class LSTMTokenizer (line 5) | class LSTMTokenizer():
    method __init__ (line 6) | def __init__(self, word2id, max_query_word):
    method tokenize (line 11) | def tokenize(self, question):
    method tokenize_sent (line 28) | def tokenize_sent(question_text):
  class BERTTokenizer (line 41) | class BERTTokenizer():
    method __init__ (line 42) | def __init__(self, max_query_word):
    method tokenize (line 50) | def tokenize(self, question):

FILE: gnn/parsing.py
  function bool_flag (line 5) | def bool_flag(v):
  function add_shared_args (line 13) | def add_shared_args(parser):
  function add_parse_args (line 68) | def add_parse_args(parser):
  function create_parser_rearev (line 85) | def create_parser_rearev(parser):
  function create_parser_nsm (line 101) | def create_parser_nsm(parser):
  function create_parser_graftnet (line 115) | def create_parser_graftnet(parser):
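
The `bool_flag` helper listed above is used as an argparse `type` so boolean CLI flags accept string spellings. The truthy set below matches the snippet visible in the repo preview (`'yes', 'true', 't', 'y', '1'`); the falsy set and the error branch are assumptions about the truncated remainder, sketched in the conventional pattern:

```python
import argparse

def bool_flag(v):
    # Accept common string spellings of booleans on the command line.
    # Truthy values are taken from the repo's visible snippet; the
    # falsy values and the ArgumentTypeError fallback are assumed.
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
```

Typical usage would be `parser.add_argument('--is_eval', type=bool_flag, default=False)`, so `--is_eval yes` and `--is_eval 0` both parse cleanly.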

FILE: gnn/train_model.py
  class Trainer_KBQA (line 24) | class Trainer_KBQA(object):
    method __init__ (line 25) | def __init__(self, args, model_name, logger=None):
    method optim_def (line 89) | def optim_def(self):
    method load_data (line 96) | def load_data(self, args, tokenize):
    method load_pretrain (line 113) | def load_pretrain(self):
    method evaluate (line 120) | def evaluate(self, data, test_batch_size=20, write_info=False):
    method train (line 123) | def train(self, start_epoch, end_epoch):
    method evaluate_best (line 182) | def evaluate_best(self):
    method evaluate_single (line 201) | def evaluate_single(self, filename):
    method train_epoch (line 209) | def train_epoch(self):
    method save_ckpt (line 236) | def save_ckpt(self, reason="h1"):
    method load_ckpt (line 246) | def load_ckpt(self, filename):

FILE: gnn/utils.py
  function create_logger (line 5) | def create_logger(args):
  function get_dict (line 29) | def get_dict(data_folder, filename):

FILE: llm/src/align_kg/build_align_qa_dataset.py
  function build_data (line 13) | def build_data(args):
  function process_data (line 35) | def process_data(data, remove_duplicate=False):

FILE: llm/src/align_kg/data_loader.py
  function load_new_tokens (line 10) | def load_new_tokens(default_new_tokens, rel_dict_path):
  function load_multiple_datasets (line 21) | def load_multiple_datasets(data_path_list, shuffle=False):
  function get_test_dataset (line 41) | def get_test_dataset(dataset):

FILE: llm/src/joint_training/generate_explanation_results.py
  function formatting_prompts_func (line 106) | def formatting_prompts_func(example):

FILE: llm/src/joint_training/joint_finetuning.py
  class ScriptArguments (line 37) | class ScriptArguments:
  class ScriptTrainingArguments (line 72) | class ScriptTrainingArguments(TrainingArguments):
  function train (line 84) | def train():

FILE: llm/src/joint_training/preprocess_align.py
  function formatting_prompts_func (line 29) | def formatting_prompts_func(example):

FILE: llm/src/joint_training/preprocess_qa.py
  function formatting_prompts_func (line 36) | def formatting_prompts_func(example):

FILE: llm/src/llms/language_models/__init__.py
  function get_registed_model (line 18) | def get_registed_model(model_name) -> BaseLanguageModel:

FILE: llm/src/llms/language_models/alpaca.py
  class Alpaca (line 5) | class Alpaca(BaseLanguageModel):
    method add_args (line 8) | def add_args(parser):
    method __init__ (line 13) | def __init__(self, args):
    method load_model (line 17) | def load_model(self, **kwargs):
    method tokenize (line 20) | def tokenize(self, text):
    method prepare_for_inference (line 23) | def prepare_for_inference(self, **model_kwargs):
    method generate_sentence (line 28) | def generate_sentence(self, llm_input):

FILE: llm/src/llms/language_models/base_language_model.py
  class BaseLanguageModel (line 4) | class BaseLanguageModel(object):
    method add_args (line 12) | def add_args(parser):
    method __init__ (line 15) | def __init__(self, args):
    method load_model (line 18) | def load_model(self, **kwargs):
    method prepare_for_inference (line 21) | def prepare_for_inference(self, **model_kwargs):
    method tokenize (line 24) | def tokenize(self, text):
    method generate_sentence (line 33) | def generate_sentence(self, lm_input):

FILE: llm/src/llms/language_models/chatgpt.py
  function get_token_limit (line 13) | def get_token_limit(model='gpt-4'):
  class ChatGPT (line 25) | class ChatGPT(BaseLanguageModel):
    method add_args (line 28) | def add_args(parser):
    method __init__ (line 31) | def __init__(self, args):
    method tokenize (line 37) | def tokenize(self, text):
    method prepare_for_inference (line 46) | def prepare_for_inference(self, model_kwargs={}):
    method generate_sentence (line 52) | def generate_sentence(self, llm_input):

FILE: llm/src/llms/language_models/flan_t5.py
  class FlanT5 (line 5) | class FlanT5(BaseLanguageModel):
    method add_args (line 8) | def add_args(parser):
    method __init__ (line 13) | def __init__(self, args):
    method load_model (line 17) | def load_model(self, **kwargs):
    method tokenize (line 21) | def tokenize(self, text):
    method prepare_for_inference (line 24) | def prepare_for_inference(self, **model_kwargs):
    method generate_sentence (line 30) | def generate_sentence(self, llm_input):

FILE: llm/src/llms/language_models/llama.py
  class Llama (line 6) | class Llama(BaseLanguageModel):
    method add_args (line 9) | def add_args(parser):
    method __init__ (line 15) | def __init__(self, args):
    method load_model (line 19) | def load_model(self, **kwargs):
    method tokenize (line 23) | def tokenize(self, text):
    method prepare_for_inference (line 26) | def prepare_for_inference(self, **model_kwargs):
    method generate_sentence (line 34) | def generate_sentence(self, llm_input):

FILE: llm/src/llms/language_models/longchat/llama_condense_monkey_patch.py
  function rank0_print (line 10) | def rank0_print(*args):
  class CondenseRotaryEmbedding (line 18) | class CondenseRotaryEmbedding(torch.nn.Module):
    method __init__ (line 19) | def __init__(self, dim, ratio, max_position_embeddings=2048, base=1000...
    method forward (line 37) | def forward(self, x, seq_len=None):
  function replace_llama_with_condense (line 53) | def replace_llama_with_condense(ratio):

FILE: llm/src/llms/language_models/longchat/llama_flash_attn_monkey_patch.py
  function forward (line 14) | def forward(
  function _prepare_decoder_attention_mask (line 84) | def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
  function replace_llama_attn_with_flash_attn (line 90) | def replace_llama_attn_with_flash_attn():

FILE: llm/src/llms/language_models/longchat/longchat.py
  function maybe_monkey_patch (line 5) | def maybe_monkey_patch(args):
  class Longchat (line 14) | class Longchat(BaseLanguageModel):
    method add_args (line 18) | def add_args(parser):
    method __init__ (line 25) | def __init__(self, args):
    method load_model (line 30) | def load_model(self, **kwargs):
    method tokenize (line 35) | def tokenize(self, text):
    method prepare_for_inference (line 38) | def prepare_for_inference(self, **model_kwargs):
    method generate_sentence (line 44) | def generate_sentence(self, llm_input):

FILE: llm/src/llms/llm_proxy.py
  class LLMProxy (line 7) | class LLMProxy(object):
    method regist_args (line 10) | def regist_args(parser):
    method __init__ (line 18) | def __init__(self, args) -> None:
    method query (line 33) | def query(message, model_name, timeout=60, max_retry=3):

FILE: llm/src/llms/start_fastchat_api.py
  function terminate_process (line 8) | def terminate_process():
  function start_fastchat_api (line 18) | def start_fastchat_api(model_names, model_path, conv_template, host, port):
  function exit_handler (line 49) | def exit_handler(signal, frame):

FILE: llm/src/qa_prediction/build_qa_input.py
  function normalize (line 15) | def normalize(s: str) -> str:
  class PromptBuilder (line 26) | class PromptBuilder(object):
    method __init__ (line 40) | def __init__(self, prompt_path, encrypt=False, add_rule = False, use_t...
    method _read_prompt_template (line 53) | def _read_prompt_template(self, template_file):
    method apply_rules (line 58) | def apply_rules(self, graph, rules, srouce_entities):
    method direct_answer (line 66) | def direct_answer(self, question_dict):
    method process_input (line 83) | def process_input(self, question_dict):
    method check_prompt_length (line 164) | def check_prompt_length(self, prompt, list_of_paths, maximun_token):

FILE: llm/src/qa_prediction/evaluate_multi_hop.py
  function normalize (line 21) | def normalize(s: str) -> str:
  function match (line 33) | def match(s1: str, s2: str) -> bool:
  function eval_acc (line 38) | def eval_acc(prediction, answer):
  function eval_hit (line 45) | def eval_hit(prediction, answer):
  function eval_hit1 (line 51) | def eval_hit1(prediction, answer):
  function eval_f1 (line 57) | def eval_f1(prediction, answer):
  function extract_topk_prediction (line 72) | def extract_topk_prediction(prediction, k=-1):
  function eval_result (line 84) | def eval_result(predict_file1, encrypt=False, cal_f1=True, topk = -1):
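
The eval helpers above (`eval_f1` alongside the precision/recall figures in `eval_result.txt`) suggest the usual set-style F1 over predicted vs. gold answer lists. A minimal sketch of that convention follows; the repo's actual `eval_f1` may normalize strings and handle edge cases differently, so treat this as an illustration, not the repo's exact metric:

```python
def set_f1(prediction, answer):
    """Set-style F1/precision/recall between predicted and gold answers.

    A common KGQA convention; string normalization and matching rules
    in the repo's own eval_f1 may differ.
    """
    if not prediction:
        return 0.0, 0.0, 0.0  # f1, precision, recall
    matched = sum(1 for p in prediction if p in answer)
    precision = matched / len(prediction)
    recall = matched / len(answer) if answer else 0.0
    if precision + recall == 0:
        return 0.0, precision, recall
    f1 = 2 * precision * recall / (precision + recall)
    return f1, precision, recall
```

For example, predicting two answers with one correct against a two-answer gold set gives precision 0.5, recall 0.5, and F1 0.5.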

FILE: llm/src/qa_prediction/evaluate_results.py
  function normalize (line 15) | def normalize(s: str) -> str:
  function match (line 27) | def match(s1: str, s2: str) -> bool:
  function eval_acc (line 32) | def eval_acc(prediction, answer):
  function eval_hit (line 39) | def eval_hit(prediction, answer):
  function eval_hit1 (line 45) | def eval_hit1(prediction, answer):
  function eval_f1 (line 51) | def eval_f1(prediction, answer):
  function extract_topk_prediction (line 66) | def extract_topk_prediction(prediction, k=-1):
  function eval_result (line 78) | def eval_result(predict_file, encrypt=False, cal_f1=True, topk = -1):

FILE: llm/src/qa_prediction/gen_rule_path.py
  function get_output_file (line 25) | def get_output_file(path, force=False):
  function parse_prediction (line 42) | def parse_prediction(prediction):
  function generate_seq (line 71) | def generate_seq(
  function gen_prediction (line 102) | def gen_prediction(args):

FILE: llm/src/qa_prediction/predict_answer.py
  function normalize (line 25) | def normalize(s: str) -> str:
  function match (line 37) | def match(s1: str, s2: str) -> bool:
  function load_gnn_rag (line 43) | def load_gnn_rag(g_data_file, g_data_file2=None):
  function get_output_file (line 83) | def get_output_file(path, force=False):
  function merge_rule_result (line 100) | def merge_rule_result(qa_dataset, rule_dataset, n_proc=1, filter_empty=F...
  function prediction (line 127) | def prediction(data, processed_list, input_builder, model, encrypt=False...
  function main (line 174) | def main(args, LLM):

FILE: llm/src/utils/graph_utils.py
  function build_graph (line 10) | def build_graph(graph: list, entities=None, encrypt=False) -> nx.Graph:
  function bfs_with_rule (line 24) | def bfs_with_rule(graph, start_node, target_rule, max_p = 10):
  function get_truth_paths (line 49) | def get_truth_paths(q_entity: list, a_entity: list, graph: nx.Graph) -> ...
  function get_simple_paths (line 77) | def get_simple_paths(q_entity: list, a_entity: list, graph: nx.Graph, ho...
  function get_negative_paths (line 100) | def get_negative_paths(q_entity: list, a_entity: list, graph: nx.Graph, ...
  function get_random_paths (line 129) | def get_random_paths(q_entity: list, graph: nx.Graph, n=3, hop=2):# -> t...

FILE: llm/src/utils/merge_peft.py
  class ScriptArguments (line 8) | class ScriptArguments:

FILE: llm/src/utils/training_utils.py
  function smart_tokenizer_and_embedding_resize (line 4) | def smart_tokenizer_and_embedding_resize(

FILE: llm/src/utils/utils.py
  function read_prompt (line 5) | def read_prompt(prompt_path):
  function load_jsonl (line 10) | def load_jsonl(file_path):
  function load_multiple_jsonl (line 17) | def load_multiple_jsonl(file_path_list):
  function list_to_string (line 23) | def list_to_string(l: list) -> str:
  function rule_to_string (line 27) | def rule_to_string(rule: list, sep_token = "<SEP>", bop = "<PATH>", eop ...
  function path_to_string (line 34) | def path_to_string(path: list) -> str:
  class InstructFormater (line 46) | class InstructFormater(object):
    method __init__ (line 47) | def __init__(self, prompt_path):
    method format (line 57) | def format(self, instruction, message):
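
`InstructFormater.format(instruction, message)` pairs an instruction with an input using one of the templates under `llm/prompts/` (e.g. `llama2.txt`, whose content appears in the preview below). As a rough sketch of that mechanism, assuming the class simply fills the `{instruction}` and `{input}` placeholders with `str.format` (the exact loading logic is an assumption):

```python
# Template text copied from the llm/prompts/llama2.txt preview.
LLAMA2_TEMPLATE = "[INST] <<SYS>>\n<</SYS>>\n{instruction}{input} [/INST]"

def format_prompt(template: str, instruction: str, message: str) -> str:
    # Fill the {instruction} and {input} placeholders; format_prompt is
    # a hypothetical stand-in for InstructFormater.format.
    return template.format(instruction=instruction, input=message)
```

For instance, `format_prompt(LLAMA2_TEMPLATE, "Answer the question.\n", "what does jamaican people speak")` yields a single `[INST] … [/INST]`-wrapped prompt string.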
Condensed preview — 118 files, each showing path, character count, and a content snippet. Download the .json file for the full structured content (77,694K chars).
[
  {
    "path": ".gitignore",
    "chars": 3078,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "README.md",
    "chars": 1044,
    "preview": "This is the code for **GNN-RAG: Graph Neural Retrieval for Large Language Modeling Reasoning**.\n\n\n![alt GNN-RAG: The GNN"
  },
  {
    "path": "gnn/.gitignore",
    "chars": 3152,
    "preview": "#LM_KGQA specific\ncheckpoint/\ncheckpoint/pretrain/\ndata/\npretrained_lms/\n\n# Byte-compiled / optimized / DLL files\n__pyca"
  },
  {
    "path": "gnn/README.md",
    "chars": 1262,
    "preview": "## Get Started\nWe have simple requirements in `requirements.txt`. You can always check if you can run the code immediate"
  },
  {
    "path": "gnn/dataset_load.py",
    "chars": 28915,
    "preview": "import json\nimport numpy as np\nimport re\nfrom tqdm import tqdm\nimport torch\nfrom collections import Counter\nimport rando"
  },
  {
    "path": "gnn/dataset_load_graft.py",
    "chars": 8621,
    "preview": "import json\nimport numpy as np\nimport re\nfrom tqdm import tqdm\nimport torch\nfrom collections import Counter\nimport rando"
  },
  {
    "path": "gnn/evaluate.py",
    "chars": 9583,
    "preview": "\nfrom tqdm import tqdm\ntqdm.monitor_iterval = 0\nimport torch\nimport numpy as np\nimport math, os\nimport json\nimport pickl"
  },
  {
    "path": "gnn/main.py",
    "chars": 1256,
    "preview": "import argparse\n\nfrom utils import create_logger\nimport torch\nimport numpy as np\nimport os\nimport time\n#from Models.ReaR"
  },
  {
    "path": "gnn/models/GraftNet/graftnet.py",
    "chars": 8232,
    "preview": "import torch\nimport numpy as np\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\nimport torch.nn as n"
  },
  {
    "path": "gnn/models/NSM/nsm.py",
    "chars": 11802,
    "preview": "import torch\nimport numpy as np\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\nimport torch.nn as n"
  },
  {
    "path": "gnn/models/ReaRev/rearev.py",
    "chars": 10649,
    "preview": "import torch\nimport numpy as np\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\nimport torch.nn as n"
  },
  {
    "path": "gnn/models/base_model.py",
    "chars": 12609,
    "preview": "import torch\nimport numpy as np\nimport torch.nn as nn\n\nimport numpy as np\n\nVERY_SMALL_NUMBER = 1e-10\n\nclass BaseModel(to"
  },
  {
    "path": "gnn/modules/kg_reasoning/base_gnn.py",
    "chars": 4119,
    "preview": "import torch\nimport numpy as np\nfrom collections import defaultdict\n\nVERY_NEG_NUMBER = -100000000000\n\nclass BaseGNNLayer"
  },
  {
    "path": "gnn/modules/kg_reasoning/graft_gnn.py",
    "chars": 8234,
    "preview": "import torch\nimport numpy as np\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\nimport torch.nn as n"
  },
  {
    "path": "gnn/modules/kg_reasoning/nsm_gnn.py",
    "chars": 5962,
    "preview": "import torch\nimport numpy as np\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\nimport torch.nn as n"
  },
  {
    "path": "gnn/modules/kg_reasoning/reasongnn.py",
    "chars": 7136,
    "preview": "\nimport torch\nimport torch.nn.functional as F\nimport torch.nn as nn\n\n\nfrom .base_gnn import BaseGNNLayer\n\nVERY_NEG_NUMBE"
  },
  {
    "path": "gnn/modules/layer_init.py",
    "chars": 3008,
    "preview": "\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nVERY_NEG_NUMBER = -100000000000\nVERY_SMALL_NUMBER = "
  },
  {
    "path": "gnn/modules/query_update.py",
    "chars": 5739,
    "preview": "import torch\nimport numpy as np\nimport torch.nn.functional as F\nimport torch.nn as nn\n\nclass Fusion(nn.Module):\n    \"\"\"d"
  },
  {
    "path": "gnn/modules/question_encoding/base_encoder.py",
    "chars": 4327,
    "preview": "import torch\nimport torch.nn.functional as F\nimport torch.nn as nn\n\nVERY_SMALL_NUMBER = 1e-10\nVERY_NEG_NUMBER = -1000000"
  },
  {
    "path": "gnn/modules/question_encoding/bert_encoder.py",
    "chars": 4674,
    "preview": "\nimport torch.nn.functional as F\nimport torch.nn as nn\nVERY_SMALL_NUMBER = 1e-10\nVERY_NEG_NUMBER = -100000000000\n\n\nfrom "
  },
  {
    "path": "gnn/modules/question_encoding/lstm_encoder.py",
    "chars": 2053,
    "preview": "\n\nimport torch.nn as nn\nfrom utils import get_dict\nfrom .base_encoder import BaseInstruction\n\nVERY_SMALL_NUMBER = 1e-10\n"
  },
  {
    "path": "gnn/modules/question_encoding/tokenizers.py",
    "chars": 1943,
    "preview": "import re\nimport numpy as np\nfrom transformers import BertTokenizer\n\nclass LSTMTokenizer():\n    def __init__(self, word2"
  },
  {
    "path": "gnn/parsing.py",
    "chars": 6021,
    "preview": "import argparse\nimport sys\n\n\ndef bool_flag(v):\n    if v.lower() in ('yes', 'true', 't', 'y', '1'):\n        return True\n "
  },
  {
    "path": "gnn/requirements.txt",
    "chars": 77,
    "preview": "Base==1.0.4\nnumpy==1.19.5\ntorch==1.7.1+cu110\ntqdm==4.59.0\ntransformers==4.6.1"
  },
  {
    "path": "gnn/scripts/rearev_cwq.sh",
    "chars": 877,
    "preview": "\n###ReaRev+SBERT training\n# python main.py ReaRev --is_eval --load_experiment relbert-full_cwq-rearev-final.ckpt --entit"
  },
  {
    "path": "gnn/train_model.py",
    "chars": 11260,
    "preview": "\nfrom utils import create_logger\nimport time\nimport numpy as np\nimport os, math\n\nimport torch\nfrom torch.optim.lr_schedu"
  },
  {
    "path": "gnn/utils.py",
    "chars": 1177,
    "preview": "import logging\nimport os\n\n\ndef create_logger(args):\n    log_file = os.path.join(args.checkpoint_dir, args.experiment_nam"
  },
  {
    "path": "llm/.gitignore",
    "chars": 3095,
    "preview": "datasets/\n*json\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distr"
  },
  {
    "path": "llm/README.md",
    "chars": 816,
    "preview": "## Get Started\nWe have simple requirements in `requirements.txt`. You can always check if you can run the code immediate"
  },
  {
    "path": "llm/prompts/alpaca.txt",
    "chars": 224,
    "preview": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that"
  },
  {
    "path": "llm/prompts/general_prompt.txt",
    "chars": 22,
    "preview": "{instruction}\n\n{input}"
  },
  {
    "path": "llm/prompts/llama2.txt",
    "chars": 52,
    "preview": "[INST] <<SYS>>\n<</SYS>>\n{instruction}{input} [/INST]"
  },
  {
    "path": "llm/prompts/llama2_predict.txt",
    "chars": 54,
    "preview": "[INST] <<SYS>>\n<</SYS>>\n{instruction}\n\n{input} [/INST]"
  },
  {
    "path": "llm/requirements.txt",
    "chars": 151,
    "preview": "openai==0.27.9\ntransformers==4.32.0\ntrl==0.7.1\npeft==0.5.0\ndatasets==2.14.4\naccelerate==0.22.0\npybind11==2.11.1\nnetworkx"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/args.txt",
    "chars": 543,
    "preview": "{\n  \"data_path\": \"rmanluo\",\n  \"d\": \"RoG-cwq\",\n  \"split\": \"test\",\n  \"predict_path\": \"results/KGQA-GO\",\n  \"model_name\": \"R"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
    "chars": 925418,
    "preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"prediction\": [\"2014 World Series\", \"2010 World Series\", \"2012 W"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
    "chars": 152,
    "preview": "Accuracy: 62.30464023153203 Hit: 66.24185783064287 Hit1: 61.31407533276692 F1: 58.963107192119594 Precision: 59.90911544"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/predictions.jsonl",
    "chars": 10081521,
    "preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"question\": \"Lou Seal is the mascot for the team that last won t"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/args.txt",
    "chars": 549,
    "preview": "{\n  \"data_path\": \"rmanluo\",\n  \"d\": \"RoG-webqsp\",\n  \"split\": \"test\",\n  \"predict_path\": \"results/KGQA-GO\",\n  \"model_name\":"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
    "chars": 870257,
    "preview": "{\"id\": \"WebQTest-0\", \"prediction\": [\"Jamaican English\", \"Jamaican Creole English Language\"], \"ground_truth\": [\"Jamaican "
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
    "chars": 149,
    "preview": "Accuracy: 72.63363902772625 Hit: 85.012285012285 Hit1: 80.28255528255528 F1: 71.53705252039563 Precision: 78.34699702921"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/predictions.jsonl",
    "chars": 3224996,
    "preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": \"Jamaican English\\nJamaican Creole Eng"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/llama2-chat-hf/test/no_rule/False/args.txt",
    "chars": 716,
    "preview": "{\n  \"data_path\": \"rmanluo\",\n  \"d\": \"RoG-webqsp\",\n  \"split\": \"test\",\n  \"predict_path\": \"results/KGQA-G-lmsr\",\n  \"model_na"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/llama2-chat-hf/test/no_rule/False/detailed_eval_result.jsonl",
    "chars": 1431238,
    "preview": "{\"id\": \"WebQTest-0\", \"prediction\": [\"Based on the reasoning paths provided, the possible answers to the question \\\"what "
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/llama2-chat-hf/test/no_rule/False/eval_result.txt",
    "chars": 152,
    "preview": "Accuracy: 74.90245578760596 Hit: 86.79361179361179 Hit1: 6.142506142506143 F1: 36.256619819061264 Precision: 31.38534839"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-lmsr/RoG-webqsp/llama2-chat-hf/test/no_rule/False/predictions.jsonl",
    "chars": 4595076,
    "preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": \"Based on the reasoning paths provided"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/args.txt",
    "chars": 548,
    "preview": "{\n  \"data_path\": \"rmanluo\",\n  \"d\": \"RoG-cwq\",\n  \"split\": \"test\",\n  \"predict_path\": \"results/KGQA-G-sbert\",\n  \"model_name"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
    "chars": 923622,
    "preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"prediction\": [\"2014 World Series\", \"2010 World Series\", \"2012 W"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
    "chars": 151,
    "preview": "Accuracy: 62.76966324662397 Hit: 66.80826961200793 Hit1: 61.73888416879071 F1: 59.42654256015807 Precision: 60.543456243"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/predictions.jsonl",
    "chars": 10066656,
    "preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"question\": \"Lou Seal is the mascot for the team that last won t"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/args.txt",
    "chars": 554,
    "preview": "{\n  \"data_path\": \"rmanluo\",\n  \"d\": \"RoG-webqsp\",\n  \"split\": \"test\",\n  \"predict_path\": \"results/KGQA-G-sbert\",\n  \"model_n"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
    "chars": 888293,
    "preview": "{\"id\": \"WebQTest-0\", \"prediction\": [\"Jamaican English\", \"Jamaican Creole English Language\"], \"ground_truth\": [\"Jamaican "
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
    "chars": 151,
    "preview": "Accuracy: 73.76651533728501 Hit: 85.68796068796068 Hit1: 80.58968058968058 F1: 71.28257365406769 Precision: 77.185064493"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/predictions.jsonl",
    "chars": 3767245,
    "preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": \"Jamaican English\\nJamaican Creole Eng"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG-RA/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/args.txt",
    "chars": 542,
    "preview": "{\n  \"data_path\": \"rmanluo\",\n  \"d\": \"RoG-cwq\",\n  \"split\": \"test\",\n  \"predict_path\": \"results/KGQA-G\",\n  \"model_name\": \"Ro"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG-RA/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
    "chars": 918794,
    "preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"prediction\": [\"2014 World Series\", \"2012 World Series\", \"2010 W"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG-RA/rearev-lmsr/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
    "chars": 153,
    "preview": "Accuracy: 64.56916340016058 Hit: 67.96941376380629 Hit1: 63.041631265930334 F1: 60.94993975622944 Precision: 61.13609531"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG-RA/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/args.txt",
    "chars": 548,
    "preview": "{\n  \"data_path\": \"rmanluo\",\n  \"d\": \"RoG-webqsp\",\n  \"split\": \"test\",\n  \"predict_path\": \"results/KGQA-G\",\n  \"model_name\": "
  },
  {
    "path": "llm/results/KGQA-GNN-RAG-RA/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
    "chars": 930571,
    "preview": "{\"id\": \"WebQTest-0\", \"prediction\": [\"Jamaican English\", \"Jamaican Creole English Language\"], \"ground_truth\": [\"Jamaican "
  },
  {
    "path": "llm/results/KGQA-GNN-RAG-RA/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
    "chars": 148,
    "preview": "Accuracy: 82.1694156927992 Hit: 89.8034398034398 Hit1: 82.37100737100737 F1: 73.36257924133025 Precision: 74.26870476680"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG-RA/rearev-lmsr/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/predictions.jsonl",
    "chars": 5193412,
    "preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": \"Jamaican English\\nJamaican Creole Eng"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG-RA/rearev-sbert/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/args.txt",
    "chars": 544,
    "preview": "{\n  \"data_path\": \"rmanluo\",\n  \"d\": \"RoG-cwq\",\n  \"split\": \"test\",\n  \"predict_path\": \"results/KGQA-GOS\",\n  \"model_name\": \""
  },
  {
    "path": "llm/results/KGQA-GNN-RAG-RA/rearev-sbert/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
    "chars": 943356,
    "preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"prediction\": [\"2014 World Series\", \"2010 World Series\", \"2012 W"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG-RA/rearev-sbert/RoG-cwq/RoG/test/results_gen_rule_path_RoG-cwq_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
    "chars": 153,
    "preview": "Accuracy: 65.00780922839806 Hit: 68.64910790144435 Hit1: 62.786745964316054 F1: 60.448091645187766 Precision: 60.4990589"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG-RA/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/args.txt",
    "chars": 549,
    "preview": "{\n  \"data_path\": \"rmanluo\",\n  \"d\": \"RoG-webqsp\",\n  \"split\": \"test\",\n  \"predict_path\": \"results/KGQA-G2\",\n  \"model_name\":"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG-RA/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/detailed_eval_result.jsonl",
    "chars": 956024,
    "preview": "{\"id\": \"WebQTest-0\", \"prediction\": [\"Jamaican English\", \"Jamaican Creole English Language\"], \"ground_truth\": [\"Jamaican "
  },
  {
    "path": "llm/results/KGQA-GNN-RAG-RA/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/eval_result.txt",
    "chars": 147,
    "preview": "Accuracy: 82.9532337788689 Hit: 90.72481572481573 Hit1: 82.8009828009828 F1: 73.4900197824918 Precision: 73.233417222905"
  },
  {
    "path": "llm/results/KGQA-GNN-RAG-RA/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_jsonl/False/predictions.jsonl",
    "chars": 5618513,
    "preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": \"Jamaican English\\nJamaican Creole Eng"
  },
  {
    "path": "llm/results/gen_rule_path/RoG-cwq/RoG/test/predictions_2_False.jsonl",
    "chars": 3792573,
    "preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"question\": \"Lou Seal is the mascot for the team that last won t"
  },
  {
    "path": "llm/results/gen_rule_path/RoG-cwq/RoG/test/predictions_3_False.jsonl",
    "chars": 4442068,
    "preview": "{\"id\": \"WebQTest-832_c334509bb5e02cacae1ba2e80c176499\", \"question\": \"Lou Seal is the mascot for the team that last won t"
  },
  {
    "path": "llm/results/gen_rule_path/RoG-webqsp/RoG/test/predictions_1_False.jsonl",
    "chars": 1073009,
    "preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": [[\"location.country.languages_spoken\"]"
  },
  {
    "path": "llm/results/gen_rule_path/RoG-webqsp/RoG/test/predictions_2_False.jsonl",
    "chars": 1402094,
    "preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": [[\"location.country.languages_spoken\"]"
  },
  {
    "path": "llm/results/gen_rule_path/RoG-webqsp/RoG/test/predictions_3_False.jsonl",
    "chars": 1671687,
    "preview": "{\"id\": \"WebQTest-0\", \"question\": \"what does jamaican people speak\", \"prediction\": [[\"location.country.languages_spoken\"]"
  },
  {
    "path": "llm/results/gen_rule_path/RoG-webqsp/RoG/train/predictions_1_False.jsonl",
    "chars": 1954524,
    "preview": "{\"id\": \"WebQTrn-0\", \"question\": \"what is the name of justin bieber brother\", \"prediction\": [[\"people.person.nationality\""
  },
  {
    "path": "llm/results/gen_rule_path/RoG-webqsp/RoG/train/predictions_3_False.jsonl",
    "chars": 3007418,
    "preview": "{\"id\": \"WebQTrn-0\", \"question\": \"what is the name of justin bieber brother\", \"prediction\": [[\"people.person.nationality\""
  },
  {
    "path": "llm/results/gen_rule_path/RoG-webqsp/RoG/validation/predictions_3_False.jsonl",
    "chars": 260362,
    "preview": "{\"id\": \"WebQTrn-9\", \"question\": \"how old is sacha baron cohen\", \"prediction\": [[\"people.person.nationality\"], [\"people.p"
  },
  {
    "path": "llm/results/gnn/RoG-cwq/rearev-lmsr/test.info",
    "chars": 1552571,
    "preview": "{\"question\": \"lou seal is the mascot for the team that last won the world series when ? \", \"0\": {}, \"1\": {}, \"answers\": "
  },
  {
    "path": "llm/results/gnn/RoG-cwq/rearev-sbert/test.info",
    "chars": 1843012,
    "preview": "{\"question\": \"lou seal is the mascot for the team that last won the world series when ? \", \"0\": {}, \"1\": {}, \"answers\": "
  },
  {
    "path": "llm/results/gnn/RoG-webqsp/rearev-lmsr/test.info",
    "chars": 861379,
    "preview": "{\"question\": \"what does jamaican people speak \", \"0\": {}, \"1\": {}, \"answers\": [\"m.01428y\", \"m.04ygk0\"], \"precison\": 1.0,"
  },
  {
    "path": "llm/results/gnn/RoG-webqsp/rearev-sbert/test.info",
    "chars": 970807,
    "preview": "{\"question\": \"what does jamaican people speak \", \"0\": {}, \"1\": {}, \"2\": {}, \"answers\": [\"m.01428y\", \"m.04ygk0\"], \"precis"
  },
  {
    "path": "llm/scripts/evaluate_multi_hop.sh",
    "chars": 180,
    "preview": "d='results/KGQA-GNN-RAG/rearev-sbert/RoG-webqsp/RoG/test/results_gen_rule_path_RoG-webqsp_RoG_test_predictions_3_False_j"
  },
  {
    "path": "llm/scripts/interpretable_example.py",
    "chars": 1383,
    "preview": "from transformers import pipeline, AutoTokenizer\nimport torch\n\nMODEL_PATH_OR_NAME=\"rmanluo/RoG\"\n\ntokenizer = AutoTokeniz"
  },
  {
    "path": "llm/scripts/planning.sh",
    "chars": 389,
    "preview": "SPLIT=\"test\"\nDATASET_LIST=\"RoG-webqsp\"\nMODEL_NAME=RoG\nMODEL_PATH=rmanluo/RoG\n\nBEAM_LIST=\"3\" # \"1 2 3 4 5\"\nfor DATASET in"
  },
  {
    "path": "llm/scripts/plug-and-play.sh",
    "chars": 1188,
    "preview": "SPLIT=\"test\"\nDATASET_LIST=\"RoG-cwq\"\nBEAM_LIST=\"3\" # \"1 2 3 4 5\"\nMODEL_LIST=\"llama2-chat-hf\"\n#PROMPT_LIST=\"prompts/genera"
  },
  {
    "path": "llm/scripts/rag-reasoning.sh",
    "chars": 1693,
    "preview": "\nSPLIT=\"test\"\nDATASET_LIST=\"RoG-webqsp\"\nMODEL_NAME=RoG\nPROMPT_PATH=prompts/llama2_predict.txt\nBEAM_LIST=\"3\" # \"1 2 3 4 5"
  },
  {
    "path": "llm/scripts/train.sh",
    "chars": 1217,
    "preview": "MODEL_PATH=meta-llama/Llama-2-7b-chat-hf\nDATASET_LIST=\"datasets/joint_training/align/cwq/cwq_train.jsonl datasets/joint_"
  },
  {
    "path": "llm/src/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "llm/src/align_kg/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "llm/src/align_kg/build_align_qa_dataset.py",
    "chars": 2367,
    "preview": "import sys\nimport os\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nimport argparse\nimport os\nimpo"
  },
  {
    "path": "llm/src/align_kg/data_loader.py",
    "chars": 1570,
    "preview": "import string\nimport sys\nimport os\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\n\nfrom datasets i"
  },
  {
    "path": "llm/src/joint_training/generate_explanation_results.py",
    "chars": 6159,
    "preview": "import sys\nimport os\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nfrom utils import *\nimport dat"
  },
  {
    "path": "llm/src/joint_training/joint_finetuning.py",
    "chars": 6396,
    "preview": "import sys\nimport os\n\nimport torch\n\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nimport os\nfrom "
  },
  {
    "path": "llm/src/joint_training/preprocess_align.py",
    "chars": 2121,
    "preview": "import sys\nimport os\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nfrom utils import *\nfrom trans"
  },
  {
    "path": "llm/src/joint_training/preprocess_qa.py",
    "chars": 2484,
    "preview": "import sys\nimport os\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nfrom utils import *\nfrom trans"
  },
  {
    "path": "llm/src/llms/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "llm/src/llms/language_models/__init__.py",
    "chars": 643,
    "preview": "from .chatgpt import ChatGPT\nfrom .alpaca import Alpaca\nfrom .longchat.longchat import Longchat\nfrom .base_language_mode"
  },
  {
    "path": "llm/src/llms/language_models/alpaca.py",
    "chars": 1504,
    "preview": "from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer\nimport torch\nfrom .base_language_model import Bas"
  },
  {
    "path": "llm/src/llms/language_models/base_language_model.py",
    "chars": 851,
    "preview": "import collections\n\n\nclass BaseLanguageModel(object):\n    \"\"\"\n    Base lanuage model. Define how to generate sentence by"
  },
  {
    "path": "llm/src/llms/language_models/chatgpt.py",
    "chars": 2919,
    "preview": "import time\nimport os\nimport openai\nfrom .base_language_model import BaseLanguageModel\nimport dotenv\nimport tiktoken\ndot"
  },
  {
    "path": "llm/src/llms/language_models/flan_t5.py",
    "chars": 1431,
    "preview": "from transformers import pipeline, AutoModel, AutoTokenizer\nimport torch\nfrom .base_language_model import BaseLanguageMo"
  },
  {
    "path": "llm/src/llms/language_models/llama.py",
    "chars": 1797,
    "preview": "from transformers import pipeline, AutoTokenizer\nimport torch\nfrom .base_language_model import BaseLanguageModel\nfrom tr"
  },
  {
    "path": "llm/src/llms/language_models/longchat/llama_condense_monkey_patch.py",
    "chars": 2838,
    "preview": "# code adapted from https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test/blob/main/llama_rope_scaled_monkey_pa"
  },
  {
    "path": "llm/src/llms/language_models/longchat/llama_flash_attn_monkey_patch.py",
    "chars": 3982,
    "preview": "from typing import List, Optional, Tuple\n\nimport torch\nfrom torch import nn\n\nimport transformers\nfrom transformers.model"
  },
  {
    "path": "llm/src/llms/language_models/longchat/longchat.py",
    "chars": 2375,
    "preview": "from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer\nimport torch\nfrom ..base_language_model import Ba"
  },
  {
    "path": "llm/src/llms/llm_proxy.py",
    "chars": 2080,
    "preview": "\nimport openai\nfrom dotenv import load_dotenv\nimport os\nimport time\nfrom .start_fastchat_api import start_fastchat_api\nc"
  },
  {
    "path": "llm/src/llms/start_fastchat_api.py",
    "chars": 2274,
    "preview": "import subprocess\nimport argparse\nimport sys\nimport atexit\nimport signal\nprocesses = []\n\ndef terminate_process():\n    fo"
  },
  {
    "path": "llm/src/qa_prediction/build_qa_input.py",
    "chars": 8094,
    "preview": "import sys\nimport os\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nimport utils\nimport random\nfro"
  },
  {
    "path": "llm/src/qa_prediction/evaluate_multi_hop.py",
    "chars": 5426,
    "preview": "import argparse\nimport glob\nimport json\nimport os\nimport re\nimport string\nfrom sklearn.metrics import precision_score\n\ni"
  },
  {
    "path": "llm/src/qa_prediction/evaluate_results.py",
    "chars": 5769,
    "preview": "import argparse\nimport glob\nimport json\nimport os\nimport re\nimport string\nfrom sklearn.metrics import precision_score\n\ni"
  },
  {
    "path": "llm/src/qa_prediction/gen_rule_path.py",
    "chars": 7237,
    "preview": "import json\nimport sys\nimport os\n\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nimport argparse\ni"
  },
  {
    "path": "llm/src/qa_prediction/predict_answer.py",
    "chars": 10921,
    "preview": "import sys\nimport os\n\nsys.path.append(os.path.dirname(os.path.realpath(__file__)) + \"/..\")\nimport utils\nimport argparse\n"
  },
  {
    "path": "llm/src/utils/__init__.py",
    "chars": 77,
    "preview": "from .graph_utils import *\nfrom .utils import *\nfrom .training_utils import *"
  },
  {
    "path": "llm/src/utils/graph_utils.py",
    "chars": 5004,
    "preview": "import networkx as nx\nfrom collections import deque\n#import walker\n\nimport json\nwith open('entities_names.json') as f:\n "
  },
  {
    "path": "llm/src/utils/merge_peft.py",
    "chars": 688,
    "preview": "from dataclasses import dataclass, field\nfrom typing import Optional\nfrom transformers import HfArgumentParser\nimport to"
  },
  {
    "path": "llm/src/utils/training_utils.py",
    "chars": 1126,
    "preview": "import transformers\nfrom typing import Dict, List\n\ndef smart_tokenizer_and_embedding_resize(\n    new_tokens: List[str],\n"
  },
  {
    "path": "llm/src/utils/utils.py",
    "chars": 1514,
    "preview": "\nimport json\nimport string\n\ndef read_prompt(prompt_path):\n    with open(prompt_path, 'r') as f:\n        prompt_template "
  }
]

// ... and 2 more files (download for full content)
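The listing above is a JSON manifest of `{path, chars, preview}` records, one per extracted file. As a minimal sketch of how such a manifest could be consumed programmatically, the snippet below sums file sizes and ranks entries by size; the two-item `manifest` literal is a hypothetical excerpt, not the full index, and the field names simply mirror the records shown above:

```python
import json

# Hypothetical two-entry excerpt of the GitExtract manifest shown above.
# Each record stores a file path, its size in characters, and a
# truncated content preview.
manifest = json.loads("""
[
  {"path": "llm/src/utils/utils.py", "chars": 1514,
   "preview": "import json\\nimport string"},
  {"path": "llm/results/gnn/RoG-webqsp/rearev-sbert/test.info", "chars": 970807,
   "preview": "{\\"question\\": \\"what does jamaican people speak \\""}
]
""")

def total_chars(entries):
    """Sum the extracted size (in characters) of all listed files."""
    return sum(e["chars"] for e in entries)

def largest(entries, n=1):
    """Return the n largest entries by character count."""
    return sorted(entries, key=lambda e: e["chars"], reverse=True)[:n]

print(total_chars(manifest))         # 972321
print(largest(manifest)[0]["path"])  # llm/results/gnn/RoG-webqsp/rearev-sbert/test.info
```

Filtering on `chars` like this is a quick way to spot which files (here, the large `results/` artifacts) dominate the 106.0 MB extraction before feeding it to a token-limited model.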

About this extraction

This page contains the full source code of the cmavro/GNN-RAG GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 118 files (106.0 MB, approximately 18.6M tokens) and includes a symbol index of 298 extracted functions, classes, methods, constants, and types. The output can be used with any AI tool that accepts plain-text input.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
