Repository: hugochan/BAMnet
Branch: master
Commit: b693616a9241
Files: 34
Total size: 169.2 KB

Directory structure:
BAMnet/

├── .gitignore
├── LICENSE
├── README.md
├── requirements.txt
└── src/
    ├── build_all_data.py
    ├── build_pretrained_w2v.py
    ├── config/
    │   ├── bamnet_webq.yml
    │   └── entnet_webq.yml
    ├── core/
    │   ├── __init__.py
    │   ├── bamnet/
    │   │   ├── __init__.py
    │   │   ├── bamnet.py
    │   │   ├── ent_modules.py
    │   │   ├── entnet.py
    │   │   ├── modules.py
    │   │   └── utils.py
    │   ├── build_data/
    │   │   ├── __init__.py
    │   │   ├── build_all.py
    │   │   ├── build_data.py
    │   │   ├── freebase.py
    │   │   ├── utils.py
    │   │   └── webquestions.py
    │   ├── config.py
    │   └── utils/
    │       ├── __init__.py
    │       ├── freebase_utils.py
    │       ├── generic_utils.py
    │       ├── metrics.py
    │       └── utils.py
    ├── joint_test.py
    ├── run_freebase.py
    ├── run_webquestions.py
    ├── test.py
    ├── test_entnet.py
    ├── train.py
    └── train_entnet.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
data/
runs/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# BAMnet


Code & data accompanying the NAACL2019 paper ["Bidirectional Attentive Memory Networks for Question Answering over Knowledge Bases"](https://arxiv.org/abs/1903.02188)


## Get started


### Prerequisites
This code is written in Python 3. You will need to install a few Python packages in order to run the code.
We recommend using `virtualenv` to manage your Python packages and environments.
Take the following steps to create a Python virtual environment.

* If you have not installed `virtualenv`, install it with ```pip install virtualenv```.
* Create a virtual environment with ```virtualenv venv```.
* Activate the virtual environment with `source venv/bin/activate`.
* Install the package requirements with `pip install -r requirements.txt`.




### Run the KBQA system

* Download the preprocessed data from [here](https://1drv.ms/u/s!AjiSpuwVTt09gSE2niFGjdIVsqA7?e=PEf6sT) and put the data folder under the root directory.


* Create a folder (e.g., `runs/WebQ/`) to save model checkpoints. You can download the pretrained models from [here](https://1drv.ms/u/s!AjiSpuwVTt09gSLcnrp0GyKtpWBg?e=DtqYt8). (Note: if you cannot access the above data and pretrained models, please download from [here](http://academic.hugochan.net/download/BAMnet-WebQ.zip).)


* Modify the config files in the `src/config/` folder to suit your needs. You can start by modifying only the data paths (e.g., `data_dir`, `model_file`, `pre_word2vec`) and vocab sizes (e.g., `vocab_size`, `num_ent_types`, `num_relations`), and leave the other hyperparameters as they are. A quick way to sanity-check those settings is sketched below.
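
	A minimal sketch (assuming PyYAML, which is listed in `requirements.txt`; the key names come from `config/bamnet_webq.yml`):

	```
	import yaml

	# Print the config entries you will most likely need to edit.
	with open('config/bamnet_webq.yml') as f:
	    cfg = yaml.safe_load(f)
	for key in ('data_dir', 'model_file', 'pre_word2vec',
	            'vocab_size', 'num_ent_types', 'num_relations'):
	    print(key, '=', cfg.get(key))
	```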


* Go to the `BAMnet/src` folder and train the BAMnet model

	```
	python train.py -config config/bamnet_webq.yml
	```
	

*  Test the BAMnet model (with ground-truth topic entity)
	
	```
	python test.py -config config/bamnet_webq.yml
	```

*  Train the topic entity predictor

	```
	python train_entnet.py -config config/entnet_webq.yml
	```

*  Test the topic entity predictor

	```
	python test_entnet.py -config config/entnet_webq.yml
	```

*  Test the whole system (BAMnet + topic entity predictor)

	```
	python joint_test.py -bamnet_config config/bamnet_webq.yml -entnet_config config/entnet_webq.yml -raw_data ../data/WebQ
	```



### Preprocess the dataset on your own

* Go to the `BAMnet/src` folder. To prepare data for the BAMnet model, run the following command:

	```
	python build_all_data.py -data_dir ../data/WebQ -fb_dir ../data/WebQ -out_dir ../data/WebQ
	```
	
* To prepare data for the topic entity predictor model, run the following command:

	```
	python build_all_data.py -dtype ent -data_dir ../data/WebQ -fb_dir ../data/WebQ -out_dir ../data/WebQ
	```


Note that in the printed output, you will see data statistics such as `vocab_size`, `num_ent_types`, `num_relations`. These numbers will be used later when modifying the config files.


* Download the pretrained GloVe word embeddings [glove.840B.300d.zip](http://nlp.stanford.edu/data/wordvecs/glove.840B.300d.zip).

* Unzip the file and convert the GloVe format to word2vec format using the following command:

	```
	python -m gensim.scripts.glove2word2vec --input glove.840B.300d.txt --output glove.840B.300d.w2v
	```

* Fetch the pretrained GloVe vectors for our vocabulary.

	```
	python build_pretrained_w2v.py -emb glove.840B.300d.w2v -data_dir ../data/WebQ -out ../data/WebQ/glove_pretrained_300d_w2v.npy -emb_size 300
	```
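
	Conceptually, this step looks up each word in `vocab2id.json` in the word2vec-format file and saves the stacked vectors as a `.npy` matrix. The actual logic lives in `core.utils.generic_utils.dump_embeddings`; the sketch below is illustrative, not the repo's implementation:

	```
	import numpy as np
	from gensim.models import KeyedVectors

	def fetch_vectors(vocab2id, w2v_path, emb_size=300):
	    # Copy the pretrained vector for every vocabulary word found in the
	    # word2vec file; unseen words keep a small random initialization.
	    w2v = KeyedVectors.load_word2vec_format(w2v_path, binary=False)
	    emb = np.random.uniform(-0.08, 0.08,
	                            (len(vocab2id), emb_size)).astype('float32')
	    for word, idx in vocab2id.items():
	        if word in w2v:
	            emb[idx] = w2v[word]
	    return emb
	```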




## Architecture

<center><img src="images/overall_arch.png"/></center>



## Experiment results on WebQuestions


### Results on WebQuestions test set. Bold: best in-category performance. 


<center><img src="images/results.png" width="300" height="500"/></center>






### Predicted answers of BAMnet w/ and w/o bidirectional attention on the WebQuestions test set

![pred_examples](images/pred_examples.png "pred_examples")



### Attention heatmap generated by the reasoning module

![attn_heatmap](images/attn_heatmap.png "attn_heatmap")





## Reference

If you found this code useful, please consider citing the following paper:

Yu Chen, Lingfei Wu, Mohammed J. Zaki. **"Bidirectional Attentive Memory Networks for Question Answering over Knowledge Bases."** *In Proc. 2019 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL-HLT2019). June 2019.*


	@article{chen2019bidirectional,
	  title={Bidirectional Attentive Memory Networks for Question Answering over Knowledge Bases},
	  author={Chen, Yu and Wu, Lingfei and Zaki, Mohammed J},
	  journal={arXiv preprint arXiv:1903.02188},
	  year={2019}
	}


================================================
FILE: requirements.txt
================================================
rapidfuzz==0.3.0
gensim==3.5.0
nltk==3.4.5
numpy==1.14.5
PyYAML==5.1
torch==0.4.1



================================================
FILE: src/build_all_data.py
================================================
'''
Created on Oct, 2017

@author: hugo

'''
import argparse
import os

from core.build_data.build_data import build_vocab, build_data, build_seed_ent_data
from core.utils.utils import *
from core.build_data import utils as build_utils


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-data_dir', '--data_dir', required=True, type=str, help='path to the data dir')
    parser.add_argument('-fb_dir', '--fb_dir', required=True, type=str, help='path to the freebase dir')
    parser.add_argument('-out_dir', '--out_dir', required=True, type=str, help='path to the output dir')
    parser.add_argument('-dtype', '--data_type', default='qa', type=str, help='data type')
    parser.add_argument('-min_freq', '--min_freq', default=1, type=int, help='min word vocab freq')
    parser.add_argument('-topn', '--topn', default=15, type=int, help='top n candidates')
    args = parser.parse_args()

    train_data = load_ndjson(os.path.join(args.data_dir, 'raw_train.json'))
    valid_data = load_ndjson(os.path.join(args.data_dir, 'raw_valid.json'))
    test_data = load_ndjson(os.path.join(args.data_dir, 'raw_test.json'))
    freebase = load_ndjson(os.path.join(args.fb_dir, 'freebase_full.json'), return_type='dict')

    if not (os.path.exists(os.path.join(args.out_dir, 'entity2id.json')) and \
        os.path.exists(os.path.join(args.out_dir, 'entityType2id.json')) and \
        os.path.exists(os.path.join(args.out_dir, 'relation2id.json')) and \
        os.path.exists(os.path.join(args.out_dir, 'vocab2id.json'))):

        used_fbkeys = set()
        for each in train_data + valid_data:
            used_fbkeys.update(each['freebaseKeyCands'][:args.topn])
        print('# of used_fbkeys: {}'.format(len(used_fbkeys)))

        entity2id, entityType2id, relation2id, vocab2id = build_vocab(train_data + valid_data, freebase, used_fbkeys, min_freq=args.min_freq)
        dump_json(entity2id, os.path.join(args.out_dir, 'entity2id.json'))
        dump_json(entityType2id, os.path.join(args.out_dir, 'entityType2id.json'))
        dump_json(relation2id, os.path.join(args.out_dir, 'relation2id.json'))
        dump_json(vocab2id, os.path.join(args.out_dir, 'vocab2id.json'))
    else:
        entity2id = load_json(os.path.join(args.out_dir, 'entity2id.json'))
        entityType2id = load_json(os.path.join(args.out_dir, 'entityType2id.json'))
        relation2id = load_json(os.path.join(args.out_dir, 'relation2id.json'))
        vocab2id = load_json(os.path.join(args.out_dir, 'vocab2id.json'))
        print('Using pre-built vocabs stored in %s' % args.out_dir)

    if args.data_type == 'qa':
        train_vec = build_data(train_data, freebase, entity2id, entityType2id, relation2id, vocab2id)
        valid_vec = build_data(valid_data, freebase, entity2id, entityType2id, relation2id, vocab2id)
        test_vec = build_data(test_data, freebase, entity2id, entityType2id, relation2id, vocab2id)
        dump_json(train_vec, os.path.join(args.out_dir, 'train_vec.json'))
        dump_json(valid_vec, os.path.join(args.out_dir, 'valid_vec.json'))
        dump_json(test_vec, os.path.join(args.out_dir, 'test_vec.json'))
        print('Saved data to {}'.format(os.path.join(args.out_dir, 'train(valid, or test)_vec.json')))
    else:
        train_vec = build_seed_ent_data(train_data, freebase, entity2id, entityType2id, relation2id, vocab2id, args.topn, dtype='train')
        valid_vec = build_seed_ent_data(valid_data, freebase, entity2id, entityType2id, relation2id, vocab2id, args.topn, dtype='valid')
        test_vec = build_seed_ent_data(test_data, freebase, entity2id, entityType2id, relation2id, vocab2id, args.topn, dtype='test')
        dump_json(train_vec, os.path.join(args.out_dir, 'train_ent_vec.json'))
        dump_json(valid_vec, os.path.join(args.out_dir, 'valid_ent_vec.json'))
        dump_json(test_vec, os.path.join(args.out_dir, 'test_ent_vec.json'))
        print('Saved data to {}'.format(os.path.join(args.out_dir, 'train(valid, or test)_ent_vec.json')))

    # Mark the data as built.
    build_utils.mark_done(args.out_dir)


================================================
FILE: src/build_pretrained_w2v.py
================================================
'''
Created on Oct, 2017

@author: hugo

'''
import argparse
import os

from core.utils.utils import load_json
from core.utils.generic_utils import dump_embeddings


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-emb', '--embed_path', required=True, type=str, help='path to the pretrained word embeddings')
    parser.add_argument('-data_dir', '--data_dir', required=True, type=str, help='path to the data dir')
    parser.add_argument('-out', '--out_path', required=True, type=str, help='path to the output path')
    parser.add_argument('-emb_size', '--emb_size', required=True, type=int, help='embedding size')
    parser.add_argument('--binary', action='store_true', help='flag: binary file')
    args = parser.parse_args()

    vocab_dict = load_json(os.path.join(args.data_dir, 'vocab2id.json'))
    dump_embeddings(vocab_dict, args.embed_path, args.out_path, emb_size=args.emb_size, binary=args.binary)


================================================
FILE: src/config/bamnet_webq.yml
================================================
# Seed 15 Data
name: 'WebQuestions'
data_dir: '../data/WebQ/'
train_data: 'train_vec.json'
valid_data: 'valid_vec.json'
test_data: 'test_vec.json'
pre_word2vec: '../data/WebQ/glove_pretrained_300d_w2v.npy'

# Full vocab
vocab_size: 100797
num_ent_types: 1712
num_relations: 4996

num_query_words: 10

# Output
model_file: '../runs/WebQ/bamnet.md'

# Model
query_size: 32
query_markup_size: 1 # Not used
ans_bow_size: 1 # Not used
ans_path_bow_size: null
ans_ctx_entity_bow_size: 6

vocab_embed_size: 300
hidden_size: 128
o_embed_size: 128
mem_size: 96
word_emb_dropout: 0.3
que_enc_dropout: 0.3
ans_enc_dropout: 0.2
attention: 'add'
num_hops: 1

# Training
learning_rate: 0.001
batch_size: 32
num_epochs: 100
valid_patience: 10
margin: 1

# Testing
test_batch_size: 1
test_margin:
  - 0.7

# Device
no_cuda: False
gpu: 0


================================================
FILE: src/config/entnet_webq.yml
================================================
# WebQuestions Data
name: 'WebQuestions'
data_dir: '../data/WebQ/'
train_data: 'train_ent_vec.json'
valid_data: 'valid_ent_vec.json'
test_data: 'test_ent_vec.json'

# Full vocab
vocab_size: 100797
num_ent_types: 1712
num_relations: 4996
pre_word2vec: '../data/WebQ/glove_pretrained_300d_w2v.npy'


# Output
model_file: '../runs/WebQ/entnet.md'


# Model
query_size: 32
max_seed_ent_name_size: null
max_seed_type_name_size: null
max_seed_rel_name_size: null
max_seed_rel_size: null

vocab_embed_size: 300
hidden_size: 128
o_embed_size: 128
word_emb_dropout: 0.3
que_enc_dropout: 0.3
ent_enc_dropout: 0.2
attention: 'simple'
seq_enc_type: 'cnn'
num_ent_hops: 1

# Training
learning_rate: 0.001
batch_size: 32
num_epochs: 100
valid_patience: 10

# Testing
test_batch_size: 1

# Device
no_cuda: False
gpu: 0


================================================
FILE: src/core/__init__.py
================================================
'''
Created on Oct, 2017

@author: hugo

'''


================================================
FILE: src/core/bamnet/__init__.py
================================================
'''
Created on Oct, 2017

@author: hugo

'''


================================================
FILE: src/core/bamnet/bamnet.py
================================================
'''
Created on Sep, 2017

@author: hugo

'''
import os
import timeit
import numpy as np

import torch
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import MultiLabelMarginLoss
import torch.backends.cudnn as cudnn

from .modules import BAMnet
from .utils import to_cuda, next_batch
from ..utils.utils import load_ndarray
from ..utils.generic_utils import unique
from ..utils.metrics import *
from .. import config


CTX_BOW_INDEX = -5
def get_text_overlap(raw_query, query_mentions, ctx_ent_names, vocab2id, ctx_stops, query):
    def longest_common_substring(s1, s2):
        # Classic DP: m[x][y] is the length of the common suffix of
        # s1[:x] and s2[:y]; track the longest run and where it ends.
        m = [[0] * (1 + len(s2)) for _ in range(1 + len(s1))]
        longest, x_longest = 0, 0
        for x in range(1, 1 + len(s1)):
            for y in range(1, 1 + len(s2)):
                if s1[x - 1] == s2[y - 1]:
                    m[x][y] = m[x - 1][y - 1] + 1
                    if m[x][y] > longest:
                        longest = m[x][y]
                        x_longest = x
                else:
                    m[x][y] = 0
        return s1[x_longest - longest: x_longest]

    sub_seq = longest_common_substring(raw_query, ctx_ent_names)
    if len(set(sub_seq) - ctx_stops) == 0:
        return []

    men_type = None
    for men, type_ in query_mentions:
        if type_.lower() in config.constraint_mention_types:
            if '_'.join(sub_seq) in '_'.join(men):
                men_type = '__{}__'.format(type_.lower())
                break

    if men_type:
        return [vocab2id[men_type] if men_type in vocab2id else config.RESERVED_TOKENS['UNK']]
    else:
        return [vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for x in sub_seq]

class BAMnetAgent(object):
    """ Bidirectional attentive memory network agent.
    """
    def __init__(self, opt, ctx_stops, vocab2id):
        self.ctx_stops = ctx_stops
        self.vocab2id = vocab2id
        opt['cuda'] = not opt['no_cuda'] and torch.cuda.is_available()
        if opt['cuda']:
            print('[ Using CUDA ]')
            torch.cuda.set_device(opt['gpu'])
            # It enables benchmark mode in cudnn, which
            # leads to faster runtime when the input sizes do not vary.
            cudnn.benchmark = True

        self.opt = opt
        if self.opt['pre_word2vec']:
            pre_w2v = load_ndarray(self.opt['pre_word2vec'])
        else:
            pre_w2v = None

        self.model = BAMnet(opt['vocab_size'], opt['vocab_embed_size'], \
                opt['o_embed_size'], opt['hidden_size'], \
                opt['num_ent_types'], opt['num_relations'], \
                opt['num_query_words'], \
                word_emb_dropout=opt['word_emb_dropout'], \
                que_enc_dropout=opt['que_enc_dropout'], \
                ans_enc_dropout=opt['ans_enc_dropout'], \
                pre_w2v=pre_w2v, \
                num_hops=opt['num_hops'], \
                att=opt['attention'], \
                use_cuda=opt['cuda'])
        if opt['cuda']:
            self.model.cuda()

        # MultiLabelMarginLoss. For each sample in the mini-batch:
        #   loss(x, y) = sum_ij(max(0, 1 - (x[y[j]] - x[i]))) / x.size(0)
        # where j runs over the gold indices in y (terminated by -1 padding)
        # and i runs over the remaining (non-gold) indices.
        self.loss_fn = MultiLabelMarginLoss()

        optim_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizers = {'bamnet': optim.Adam(optim_params, lr=opt['learning_rate'])}
        self.scheduler = ReduceLROnPlateau(self.optimizers['bamnet'], mode='min', \
                    patience=self.opt['valid_patience'] // 3, verbose=True)

        if opt.get('model_file') and os.path.isfile(opt['model_file']):
            print('Loading existing model parameters from ' + opt['model_file'])
            self.load(opt['model_file'])
        super(BAMnetAgent, self).__init__()

    def train(self, train_X, train_y, valid_X, valid_y, valid_cand_labels, valid_gold_ans_labels, seed=1234):
        print('Training size: {}, Validation size: {}'.format(len(train_y), len(valid_y)))
        random1 = np.random.RandomState(seed)
        random2 = np.random.RandomState(seed)
        random3 = np.random.RandomState(seed)
        random4 = np.random.RandomState(seed)
        random5 = np.random.RandomState(seed)
        random6 = np.random.RandomState(seed)
        random7 = np.random.RandomState(seed)
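        # All seven generators share the same seed, so each shuffle in the
        # epoch loop applies the identical permutation and the parallel
        # lists stay aligned.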
        memories, queries, query_words, raw_queries, query_mentions, query_lengths = train_X
        gold_ans_inds = train_y

        valid_memories, valid_queries, valid_query_words, valid_raw_queries, valid_query_mentions, valid_query_lengths = valid_X
        valid_gold_ans_inds = valid_y

        n_incr_error = 0  # number of consecutive epochs without improvement
        best_loss = float("inf")
        num_batches = len(queries) // self.opt['batch_size'] + (len(queries) % self.opt['batch_size'] != 0)
        num_valid_batches = len(valid_queries) // self.opt['batch_size'] + (len(valid_queries) % self.opt['batch_size'] != 0)
        for epoch in range(1, self.opt['num_epochs'] + 1):
            start = timeit.default_timer()
            n_incr_error += 1
            random1.shuffle(memories)
            random2.shuffle(queries)
            random3.shuffle(query_words)
            random4.shuffle(raw_queries)
            random5.shuffle(query_mentions)
            random6.shuffle(query_lengths)
            random7.shuffle(gold_ans_inds)
            train_gen = next_batch(memories, queries, query_words, raw_queries, query_mentions, query_lengths, gold_ans_inds, self.opt['batch_size'])
            train_loss = 0
            for batch_xs, batch_ys in train_gen:
                train_loss += self.train_step(batch_xs, batch_ys) / num_batches

            valid_gen = next_batch(valid_memories, valid_queries, valid_query_words, valid_raw_queries, valid_query_mentions, valid_query_lengths, valid_gold_ans_inds, self.opt['batch_size'])
            valid_loss = 0
            for batch_valid_xs, batch_valid_ys in valid_gen:
                valid_loss += self.train_step(batch_valid_xs, batch_valid_ys, is_training=False) / num_valid_batches
            self.scheduler.step(valid_loss)

            # if False:
            if epoch > 0:
                pred = self.predict(valid_X, valid_cand_labels, batch_size=1, margin=self.opt['margin'], silence=True)
                predictions = [unique([x[0] for x in each]) for each in pred]
                valid_f1 = calc_avg_f1(valid_gold_ans_labels, predictions, verbose=False)[-1]
            else:
                valid_f1 = 0.
            print('Epoch {}/{}: Runtime: {}s, Train loss: {:.4}, valid loss: {:.4}, valid F1: {:.4}'.format(epoch, self.opt['num_epochs'], \
                                                    int(timeit.default_timer() - start), train_loss, valid_loss, valid_f1))

            if valid_loss < best_loss:
                best_loss = valid_loss
                n_incr_error = 0
                self.save()

            if n_incr_error >= self.opt['valid_patience']:
                print('Early stopping occurred. Optimization finished!')
                self.save(self.opt['model_file'] + '.final')
                break

    def predict(self, xs, cand_labels, batch_size=32, margin=1, ys=None, verbose=False, silence=False):
        '''Prediction scores are returned in the verbose mode.
        '''
        if not silence:
            print('Testing size: {}'.format(len(cand_labels)))
        memories, queries, query_words, raw_queries, query_mentions, query_lengths = xs
        gen = next_batch(memories, queries, query_words, raw_queries, query_mentions, query_lengths, cand_labels, batch_size)
        predictions = []
        for batch_xs, batch_cands in gen:
            batch_pred = self.predict_step(batch_xs, batch_cands, margin, verbose=verbose)
            predictions.extend(batch_pred)
        return predictions

    def train_step(self, xs, ys, is_training=True):
        # Sets the module in training mode.
        # This has any effect only on modules such as Dropout or BatchNorm.
        self.model.train(mode=is_training)
        with torch.set_grad_enabled(is_training):
            # Organize inputs for network
            selected_memories, new_ys, ctx_mask = self.dynamic_ctx_negative_sampling(xs[0], ys, self.opt['mem_size'], \
                                    self.opt['ans_ctx_entity_bow_size'], xs[3], xs[4], xs[1])
            selected_memories = [to_cuda(torch.LongTensor(np.array(x)), self.opt['cuda']) for x in zip(*selected_memories)]
            ctx_mask = to_cuda(ctx_mask, self.opt['cuda'])
            queries = to_cuda(torch.LongTensor(xs[1]), self.opt['cuda'])
            query_words = to_cuda(torch.LongTensor(xs[2]), self.opt['cuda'])
            query_lengths = to_cuda(torch.LongTensor(xs[5]), self.opt['cuda'])
            mem_hop_scores = self.model(selected_memories, queries, query_lengths, query_words, ctx_mask=None)
            # Set margin
            new_ys, mask_ys = self.pack_gold_ans(new_ys, mem_hop_scores[-1].size(1), placeholder=-1)

            loss = 0
            for _, s in enumerate(mem_hop_scores):
                s = self.set_loss_margin(s, mask_ys, self.opt['margin'])
                loss += self.loss_fn(s, new_ys)
            loss /= len(mem_hop_scores)

            if is_training:
                for o in self.optimizers.values():
                    o.zero_grad()
                loss.backward()
                for o in self.optimizers.values():
                    o.step()
            return loss.item()

    def predict_step(self, xs, cand_labels, margin, verbose=False):
        self.model.train(mode=False)
        with torch.set_grad_enabled(False):
            # Organize inputs for network
            memories, ctx_mask = self.pad_ctx_memory(xs[0], self.opt['ans_ctx_entity_bow_size'], xs[3], xs[4], xs[1])
            memories = [to_cuda(torch.LongTensor(np.array(x)), self.opt['cuda']) for x in zip(*memories)]
            ctx_mask = to_cuda(ctx_mask, self.opt['cuda'])
            queries = to_cuda(torch.LongTensor(xs[1]), self.opt['cuda'])
            query_words = to_cuda(torch.LongTensor(xs[2]), self.opt['cuda'])
            query_lengths = to_cuda(torch.LongTensor(xs[5]), self.opt['cuda'])
            mem_hop_scores = self.model(memories, queries, query_lengths, query_words, ctx_mask=None)

            predictions = self.ranked_predictions(cand_labels, mem_hop_scores[-1].data, margin)
            return predictions

    def dynamic_ctx_negative_sampling(self, memories, ys, mem_size, ctx_bow_size, raw_queries, query_mentions, queries):
        # Randomly select negative samples from the candidate answer set
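        # Memory layout (per example): memories[i] holds parallel per-candidate
        # feature arrays; slot CTX_BOW_INDEX (-5) stores, for each candidate,
        # the names of its context entities. The last candidate slot is a
        # dummy used for padding.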
        ctx_bow_size = max(min(max(map(len, (a for x in list(zip(*memories))[CTX_BOW_INDEX] for y in x for a in y)), default=0), ctx_bow_size), 1)

        selected_memories = []
        new_ys = []
        ctx_mask = []
        for i in range(len(ys)):
            n = len(memories[i][0]) - 1 # The last element is a dummy candidate
            num_gold = len(ys[i]) if mem_size > len(ys[i]) else \
                    (mem_size - min(mem_size // 2, n - len(ys[i]))) # Max possible (pos, neg) pairs
            selected_gold_inds = np.random.choice(ys[i], num_gold, replace=False).tolist() if len(ys[i]) > 0 else []
            if n > len(ys[i]):
                p = np.ones(n)
                p[ys[i]] = 0
                p = p / np.sum(p)
                selected_inds = np.random.choice(n, min(mem_size, n) - num_gold, replace=False, p=p).tolist()
            else:
                selected_inds = []
            augmented_selected_inds = selected_gold_inds + selected_inds + [-1] * max(mem_size - n, 0)
            xx = [min(mem_size, n)] + [np.array(x)[augmented_selected_inds] for x in memories[i][:CTX_BOW_INDEX]]

            ctx_bow = []
            ctx_bow_len = []
            ctx_num = []
            tmp_ctx_mask = np.zeros(mem_size)
            for _, idx in enumerate(augmented_selected_inds):
                tmp_ctx = []
                tmp_ctx_len = []
                for ctx_ent_names in memories[i][CTX_BOW_INDEX][idx]:
                    sub_seq = get_text_overlap(raw_queries[i], query_mentions[i], ctx_ent_names, self.vocab2id, self.ctx_stops, queries[i])
                    if len(sub_seq) > 0:
                        tmp_ctx_mask[_] = 1
                        tmp_ctx.append(sub_seq[:ctx_bow_size] + [config.RESERVED_TOKENS['PAD']] * max(0, ctx_bow_size - len(sub_seq)))
                        tmp_ctx_len.append(max(min(ctx_bow_size, len(sub_seq)), 1))
                ctx_bow.append(tmp_ctx)
                ctx_bow_len.append(tmp_ctx_len)
                ctx_num.append(len(tmp_ctx))

            xx += [ctx_bow, ctx_bow_len, ctx_num]
            xx += [np.array(x)[augmented_selected_inds] for x in memories[i][CTX_BOW_INDEX+1:]]
            selected_memories.append(xx)
            new_ys.append(list(range(num_gold)))
            ctx_mask.append(tmp_ctx_mask)

        max_ctx_num = max(max([y for x in selected_memories for y in x[CTX_BOW_INDEX]]), 1)
        for i in range(len(selected_memories)): # Example
            for j in range(len(selected_memories[i][-1])): # Cand
                count = selected_memories[i][CTX_BOW_INDEX][j]
                if count < max_ctx_num:
                    selected_memories[i][CTX_BOW_INDEX - 2][j] += [[config.RESERVED_TOKENS['PAD']] * ctx_bow_size] * (max_ctx_num - count)
                    selected_memories[i][CTX_BOW_INDEX - 1][j] += [1] * (max_ctx_num - count)
        return selected_memories, new_ys, torch.Tensor(np.array(ctx_mask))

    def pad_ctx_memory(self, memories, ctx_bow_size, raw_queries, query_mentions, queries):
        cand_ans_size = max(max(map(len, list(zip(*memories))[0]), default=0) - 1, 1) # The last element is a dummy candidate
        ctx_bow_size = max(min(max(map(len, (a for x in list(zip(*memories))[CTX_BOW_INDEX] for y in x for a in y)), default=0), ctx_bow_size), 1)

        pad_memories = []
        ctx_mask = []
        for i in range(len(memories)):
            n = len(memories[i][0]) - 1 # The last element is a dummy candidate
            augmented_inds = list(range(n)) + [-1] * (cand_ans_size - n)
            xx = [n] + [np.array(x)[augmented_inds] for x in memories[i][:CTX_BOW_INDEX]]

            ctx_bow = []
            ctx_bow_len = []
            ctx_num = []
            tmp_ctx_mask = np.zeros(cand_ans_size)
            for _, idx in enumerate(augmented_inds):
                tmp_ctx = []
                tmp_ctx_len = []
                for ctx_ent_names in memories[i][CTX_BOW_INDEX][idx]:
                    sub_seq = get_text_overlap(raw_queries[i], query_mentions[i], ctx_ent_names, self.vocab2id, self.ctx_stops, queries[i])
                    if len(sub_seq) > 0:
                        tmp_ctx_mask[_] = 1
                        tmp_ctx.append(sub_seq[:ctx_bow_size] + [config.RESERVED_TOKENS['PAD']] * max(0, ctx_bow_size - len(sub_seq)))
                        tmp_ctx_len.append(max(min(ctx_bow_size, len(sub_seq)), 1))
                ctx_bow.append(tmp_ctx)
                ctx_bow_len.append(tmp_ctx_len)
                ctx_num.append(len(tmp_ctx))

            xx += [ctx_bow, ctx_bow_len, ctx_num]
            xx += [np.array(x)[augmented_inds] for x in memories[i][CTX_BOW_INDEX+1:]]
            pad_memories.append(xx)
            ctx_mask.append(tmp_ctx_mask)

        max_ctx_num = max(max([y for x in pad_memories for y in x[CTX_BOW_INDEX]]), 1)
        for i in range(len(pad_memories)): # Example
            for j in range(len(pad_memories[i][-1])): # Cand
                count = pad_memories[i][CTX_BOW_INDEX][j]
                if count < max_ctx_num:
                    pad_memories[i][CTX_BOW_INDEX - 2][j] += [[config.RESERVED_TOKENS['PAD']] * ctx_bow_size] * (max_ctx_num - count)
                    pad_memories[i][CTX_BOW_INDEX - 1][j] += [1] * (max_ctx_num - count)
        return pad_memories, torch.Tensor(np.array(ctx_mask))

    def pack_gold_ans(self, x, N, placeholder=-1):
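        # Pad each gold-answer index list to length N with `placeholder` (-1),
        # the terminator value MultiLabelMarginLoss expects, and build a 0/1
        # mask marking the gold slots.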
        y = np.ones((len(x), N), dtype='int64') * placeholder
        mask = np.zeros((len(x), N))
        for i in range(len(x)):
            y[i, :len(x[i])] = x[i]
            mask[i, :len(x[i])] = 1
        return to_cuda(torch.LongTensor(y), self.opt['cuda']), to_cuda(torch.Tensor(mask), self.opt['cuda'])

    def set_loss_margin(self, scores, gold_mask, margin):
        """Since the pytorch built-in MultiLabelMarginLoss fixes the margin as 1.
        We simply work around this annoying feature by *modifying* the golden scores.
        E.g., if we want margin as 3, we decrease each golden score by 3 - 1 before
        feeding it to the built-in loss.
        """
        new_scores = scores - (margin - 1) * gold_mask
        return new_scores

    def ranked_predictions(self, cand_labels, scores, margin):
        _, sorted_inds = scores.sort(descending=True, dim=1)
        return [[(cand_labels[i][j], scores[i][j]) for j in r if scores[i][j] + margin >= scores[i][r[0]] \
                and cand_labels[i][j] != 'UNK'] \
                if len(cand_labels[i]) > 0 and scores[i][r[0]] > -1e4 else [] \
                for i, r in enumerate(sorted_inds)] # Very large negative ones are dummy candidates

    def save(self, path=None):
        path = self.opt.get('model_file', None) if path is None else path

        if path:
            checkpoint = {}
            checkpoint['bamnet'] = self.model.state_dict()
            checkpoint['bamnet_optim'] = self.optimizers['bamnet'].state_dict()
            with open(path, 'wb') as write:
                torch.save(checkpoint, write)
                print('Saved model to {}'.format(path))

    def load(self, path):
        with open(path, 'rb') as read:
            checkpoint = torch.load(read, map_location=lambda storage, loc: storage)
        self.model.load_state_dict(checkpoint['bamnet'])
        self.optimizers['bamnet'].load_state_dict(checkpoint['bamnet_optim'])


================================================
FILE: src/core/bamnet/ent_modules.py
================================================
'''
Created on Sep, 2018

@author: hugo

'''
import numpy as np

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
import torch.nn.functional as F

from .modules import SeqEncoder, SelfAttention_CoAtt, Attention
from .utils import to_cuda


INF = 1e20
VERY_SMALL_NUMBER = 1e-10
class Entnet(nn.Module):
    def __init__(self, vocab_size, vocab_embed_size, o_embed_size, \
        hidden_size, num_ent_types, num_relations, \
        seq_enc_type='cnn', \
        word_emb_dropout=None, \
        que_enc_dropout=None,\
        ent_enc_dropout=None, \
        pre_w2v=None, \
        num_hops=1, \
        att='add', \
        use_cuda=True):
        super(Entnet, self).__init__()
        self.use_cuda = use_cuda
        self.seq_enc_type = seq_enc_type
        self.que_enc_dropout = que_enc_dropout
        self.ent_enc_dropout = ent_enc_dropout
        self.num_hops = num_hops
        self.hidden_size = hidden_size
        self.que_enc = SeqEncoder(vocab_size, vocab_embed_size, hidden_size, \
                        seq_enc_type=seq_enc_type, \
                        word_emb_dropout=word_emb_dropout, \
                        bidirectional=True, \
                        cnn_kernel_size=[2, 3], \
                        init_word_embed=pre_w2v, \
                        use_cuda=use_cuda).que_enc

        self.ent_enc = EntEncoder(o_embed_size, hidden_size, \
                        num_ent_types, num_relations, \
                        vocab_size=vocab_size, \
                        vocab_embed_size=vocab_embed_size, \
                        shared_embed=self.que_enc.embed, \
                        seq_enc_type=seq_enc_type, \
                        word_emb_dropout=word_emb_dropout, \
                        ent_enc_dropout=ent_enc_dropout, \
                        use_cuda=use_cuda)
        self.batchnorm = nn.BatchNorm1d(hidden_size)

        if seq_enc_type in ('lstm', 'gru'):
            self.self_atten = SelfAttention_CoAtt(hidden_size)
            print('[ Using self-attention on question encoder ]')

        self.ent_memory_hop = EntRomHop(hidden_size, hidden_size, hidden_size, atten_type=att)
        print('[ Using {}-hop entity memory update ]'.format(num_hops))

    def forward(self, memories, queries, query_lengths):
        x_ent_names, x_ent_name_len, x_type_names, x_types, x_type_name_len, x_rel_names, x_rels, x_rel_name_len, x_rel_mask = memories
        x_rel_mask = self.create_mask_3D(x_rel_mask, x_rels.size(-1), use_cuda=self.use_cuda)

        # Question encoder
        if self.seq_enc_type in ('lstm', 'gru'):
            Q_r = self.que_enc(queries, query_lengths)[0]
            if self.que_enc_dropout:
                Q_r = F.dropout(Q_r, p=self.que_enc_dropout, training=self.training)

            query_mask = self.create_mask(query_lengths, Q_r.size(1), self.use_cuda)
            q_r = self.self_atten(Q_r, query_lengths, query_mask)
        else:
            q_r = self.que_enc(queries, query_lengths)[1]
            if self.que_enc_dropout:
                q_r = F.dropout(q_r, p=self.que_enc_dropout, training=self.training)

        # Entity encoder
        ent_val, ent_key = self.ent_enc(x_ent_names, x_ent_name_len, x_type_names, x_types, x_type_name_len, x_rel_names, x_rels, x_rel_name_len, x_rel_mask)

        ent_val = torch.cat([each.unsqueeze(2) for each in ent_val], 2)
        ent_key = torch.cat([each.unsqueeze(2) for each in ent_key], 2)
        ent_val = torch.sum(ent_val, 2)
        ent_key = torch.sum(ent_key, 2)

        mem_hop_scores = []
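        # One score vector per reasoning step: the score before any hop,
        # then one after each memory hop.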
        mid_score = self.clf_score(q_r, ent_key)
        mem_hop_scores.append(mid_score)

        for _ in range(self.num_hops):
            q_r = q_r + self.ent_memory_hop(q_r, ent_key, ent_val)
            q_r = self.batchnorm(q_r)
            mid_score = self.clf_score(q_r, ent_key)
            mem_hop_scores.append(mid_score)
        return mem_hop_scores

    def clf_score(self, q_r, ent_key):
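        # Dot-product score between each entity's key memory and the question
        # representation: (B, N, H) x (B, H, 1) -> (B, N).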
        return torch.matmul(ent_key, q_r.unsqueeze(-1)).squeeze(-1)

    def create_mask(self, x, N, use_cuda=True):
        x = x.data
        mask = np.zeros((x.size(0), N))
        for i in range(x.size(0)):
            mask[i, :x[i]] = 1
        return to_cuda(torch.Tensor(mask), use_cuda)

    def create_mask_3D(self, x, N, use_cuda=True):
        x = x.data
        mask = np.zeros((x.size(0), x.size(1), N))
        for i in range(x.size(0)):
            for j in range(x.size(1)):
                mask[i, j, :x[i, j]] = 1
        return to_cuda(torch.Tensor(mask), use_cuda)

class EntEncoder(nn.Module):
    """Entity Encoder"""
    def __init__(self, o_embed_size, hidden_size, num_ent_types, num_relations, vocab_size=None, \
                    vocab_embed_size=None, shared_embed=None, seq_enc_type='lstm', word_emb_dropout=None, \
                    ent_enc_dropout=None, use_cuda=True):
        super(EntEncoder, self).__init__()
        # shared_embed and vocab_size cannot both be None.
        self.ent_enc_dropout = ent_enc_dropout
        self.hidden_size = hidden_size
        self.relation_embed = nn.Embedding(num_relations, o_embed_size, padding_idx=0)
        self.embed = shared_embed if shared_embed is not None else nn.Embedding(vocab_size, vocab_embed_size, padding_idx=0)
        self.vocab_embed_size = self.embed.weight.data.size(1)

        self.linear_node_name_key = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_node_type_key = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_rels_key = nn.Linear(hidden_size + o_embed_size, hidden_size, bias=False)
        self.linear_node_name_val = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_node_type_val = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_rels_val = nn.Linear(hidden_size + o_embed_size, hidden_size, bias=False)

        self.kg_enc_ent = SeqEncoder(vocab_size, \
                        self.vocab_embed_size, \
                        hidden_size, \
                        seq_enc_type=seq_enc_type, \
                        word_emb_dropout=word_emb_dropout, \
                        bidirectional=True, \
                        cnn_kernel_size=[3], \
                        shared_embed=shared_embed, \
                        use_cuda=use_cuda).que_enc # entity name

        self.kg_enc_type = SeqEncoder(vocab_size, \
                        self.vocab_embed_size, \
                        hidden_size, \
                        seq_enc_type=seq_enc_type, \
                        word_emb_dropout=word_emb_dropout, \
                        bidirectional=True, \
                        cnn_kernel_size=[3], \
                        shared_embed=shared_embed, \
                        use_cuda=use_cuda).que_enc # entity type name

        self.kg_enc_rel = SeqEncoder(vocab_size, \
                        self.vocab_embed_size, \
                        hidden_size, \
                        seq_enc_type=seq_enc_type, \
                        word_emb_dropout=word_emb_dropout, \
                        bidirectional=True, \
                        cnn_kernel_size=[3], \
                        shared_embed=shared_embed, \
                        use_cuda=use_cuda).que_enc # relation name

    def forward(self, x_ent_names, x_ent_name_len, x_type_names, x_types, x_type_name_len, x_rel_names, x_rels, x_rel_name_len, x_rel_mask):
        node_ent_names, node_type_names, node_types, edge_rel_names, edge_rels = self.enc_kg_features(x_ent_names, x_ent_name_len, x_type_names, x_types, x_type_name_len, x_rel_names, x_rels, x_rel_name_len, x_rel_mask)
        node_name_key = self.linear_node_name_key(node_ent_names)
        node_type_key = self.linear_node_type_key(node_type_names)
        rel_key = self.linear_rels_key(torch.cat([edge_rel_names, edge_rels], -1))

        node_name_val = self.linear_node_name_val(node_ent_names)
        node_type_val = self.linear_node_type_val(node_type_names)
        rel_val = self.linear_rels_val(torch.cat([edge_rel_names, edge_rels], -1))

        ent_comp_val = [node_name_val, node_type_val, rel_val]
        ent_comp_key = [node_name_key, node_type_key, rel_key]
        return ent_comp_val, ent_comp_key

    def enc_kg_features(self, x_ent_names, x_ent_name_len, x_type_names, x_types, x_type_name_len, x_rel_names, x_rels, x_rel_name_len, x_rel_mask):
        node_ent_names = (self.kg_enc_ent(x_ent_names.view(-1, x_ent_names.size(-1)), x_ent_name_len.view(-1))[1]).view(x_ent_names.size(0), x_ent_names.size(1), -1)
        node_type_names = (self.kg_enc_type(x_type_names.view(-1, x_type_names.size(-1)), x_type_name_len.view(-1))[1]).view(x_type_names.size(0), x_type_names.size(1), -1)
        node_types = None
        edge_rel_names = torch.mean((self.kg_enc_rel(x_rel_names.view(-1, x_rel_names.size(-1)), x_rel_name_len.view(-1))[1]).view(x_rel_names.size(0), x_rel_names.size(1), x_rel_names.size(2), -1), 2)
        edge_rels = torch.mean(self.relation_embed(x_rels.view(-1, x_rels.size(-1))), 1).view(x_rels.size(0), x_rels.size(1), -1)

        if self.ent_enc_dropout:
            node_ent_names = F.dropout(node_ent_names, p=self.ent_enc_dropout, training=self.training)
            node_type_names = F.dropout(node_type_names, p=self.ent_enc_dropout, training=self.training)
            # node_types = F.dropout(node_types, p=self.ent_enc_dropout, training=self.training)
            edge_rel_names = F.dropout(edge_rel_names, p=self.ent_enc_dropout, training=self.training)
            edge_rels = F.dropout(edge_rels, p=self.ent_enc_dropout, training=self.training)
        return node_ent_names, node_type_names, node_types, edge_rel_names, edge_rels


class EntRomHop(nn.Module):
    def __init__(self, query_embed_size, in_memory_embed_size, hidden_size, atten_type='add'):
        super(EntRomHop, self).__init__()
        self.atten = Attention(hidden_size, query_embed_size, in_memory_embed_size, atten_type=atten_type)
        self.gru_step = GRUStep(hidden_size, in_memory_embed_size)

    def forward(self, h_state, key_memory_embed, val_memory_embed, atten_mask=None):
        attention = self.atten(h_state, key_memory_embed, atten_mask=atten_mask)
        probs = torch.softmax(attention, dim=-1)
        memory_output = torch.bmm(probs.unsqueeze(1), val_memory_embed).squeeze(1)
        h_state = self.gru_step(h_state, memory_output)
        return h_state

class GRUStep(nn.Module):
    def __init__(self, hidden_size, input_size):
        '''GRU module'''
        super(GRUStep, self).__init__()
        self.linear_z = nn.Linear(hidden_size + input_size, hidden_size, bias=False)
        self.linear_r = nn.Linear(hidden_size + input_size, hidden_size, bias=False)
        self.linear_t = nn.Linear(hidden_size + input_size, hidden_size, bias=False)

    def forward(self, h_state, input_):
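        # Gated update: z (update gate) interpolates between the previous
        # hidden state and the candidate state t; r (reset gate) modulates
        # the previous state inside the candidate computation.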
        z = torch.sigmoid(self.linear_z(torch.cat([h_state, input_], -1)))
        r = torch.sigmoid(self.linear_r(torch.cat([h_state, input_], -1)))
        t = torch.tanh(self.linear_t(torch.cat([r * h_state, input_], -1)))
        h_state = (1 - z) * h_state + z * t
        return h_state


================================================
FILE: src/core/bamnet/entnet.py
================================================
'''
Created on Sep, 2018

@author: hugo

'''
import os
import timeit
import numpy as np

import torch
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import CrossEntropyLoss, MultiLabelMarginLoss
import torch.backends.cudnn as cudnn

from .ent_modules import Entnet
from .utils import to_cuda, next_ent_batch
from ..utils.utils import load_ndarray
from ..utils.generic_utils import unique
from ..utils.metrics import *


class EntnetAgent(object):
    def __init__(self, opt):
        opt['cuda'] = not opt['no_cuda'] and torch.cuda.is_available()
        if opt['cuda']:
            print('[ Using CUDA ]')
            torch.cuda.set_device(opt['gpu'])
            # It enables benchmark mode in cudnn, which
            # leads to faster runtime when the input sizes do not vary.
            cudnn.benchmark = True

        self.opt = opt
        if self.opt['pre_word2vec']:
            pre_w2v = load_ndarray(self.opt['pre_word2vec'])
        else:
            pre_w2v = None

        self.ent_model = Entnet(opt['vocab_size'], opt['vocab_embed_size'], \
                opt['o_embed_size'], opt['hidden_size'], \
                opt['num_ent_types'], opt['num_relations'], \
                seq_enc_type=opt['seq_enc_type'], \
                word_emb_dropout=opt['word_emb_dropout'], \
                que_enc_dropout=opt['que_enc_dropout'], \
                ent_enc_dropout=opt['ent_enc_dropout'], \
                pre_w2v=pre_w2v, \
                num_hops=opt['num_ent_hops'], \
                att=opt['attention'], \
                use_cuda=opt['cuda'])
        if opt['cuda']:
            self.ent_model.cuda()

        self.loss_fn = MultiLabelMarginLoss()

        optim_params = [p for p in self.ent_model.parameters() if p.requires_grad]
        self.optimizers = {'entnet': optim.Adam(optim_params, lr=opt['learning_rate'])}
        self.scheduler = ReduceLROnPlateau(self.optimizers['entnet'], mode='min', \
                    patience=self.opt['valid_patience'] // 3, verbose=True)

        if opt.get('model_file') and os.path.isfile(opt['model_file']):
            print('Loading existing ent_model parameters from ' + opt['model_file'])
            self.load(opt['model_file'])
        else:
            self.save()
            self.load(opt['model_file'])
        super(EntnetAgent, self).__init__()

    def train(self, train_X, train_y, valid_X, valid_y, seed=1234):
        print('Training size: {}, Validation size: {}'.format(len(train_y), len(valid_y)))
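        # Four generators seeded identically so the four parallel lists stay
        # aligned after the in-place shuffles below.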
        random1 = np.random.RandomState(seed)
        random2 = np.random.RandomState(seed)
        random3 = np.random.RandomState(seed)
        random4 = np.random.RandomState(seed)
        memories, queries, query_lengths = train_X
        ent_inds = train_y

        valid_memories, valid_queries, valid_query_lengths = valid_X
        valid_ent_inds = valid_y

        n_incr_error = 0  # number of consecutive epochs without validation improvement
        best_loss = float("inf")
        best_acc = 0
        num_batches = len(queries) // self.opt['batch_size'] + (len(queries) % self.opt['batch_size'] != 0)
        num_valid_batches = len(valid_queries) // self.opt['batch_size'] + (len(valid_queries) % self.opt['batch_size'] != 0)
        for epoch in range(1, self.opt['num_epochs'] + 1):
            start = timeit.default_timer()
            n_incr_error += 1
            random1.shuffle(memories)
            random2.shuffle(queries)
            random3.shuffle(query_lengths)
            random4.shuffle(ent_inds)
            train_gen = next_ent_batch(memories, queries, query_lengths, ent_inds, self.opt['batch_size'])
            train_loss = 0
            for batch_xs, batch_ys in train_gen:
                train_loss += self.train_step(batch_xs, batch_ys) / num_batches

            valid_gen = next_ent_batch(valid_memories, valid_queries, valid_query_lengths, valid_ent_inds, self.opt['batch_size'])
            valid_loss = 0
            for batch_valid_xs, batch_valid_ys in valid_gen:
                valid_loss += self.train_step(batch_valid_xs, batch_valid_ys, is_training=False) / num_valid_batches
            self.scheduler.step(valid_loss)

            if epoch > 0:
                valid_acc = self.evaluate(valid_X, valid_ent_inds, batch_size=1, silence=True)
                # valid_acc = 0.
                print('Epoch {}/{}: Runtime: {}s, Training loss: {:.4}, validation loss: {:.4}, validation ACC: {:.4}'.format(epoch, self.opt['num_epochs'], \
                                                    int(timeit.default_timer() - start), train_loss, valid_loss, valid_acc))

                # self.scheduler.step(valid_acc)
                # if valid_acc > best_acc:
                #     best_acc = valid_acc
                #     n_incr_error = 0
                #     self.save()

                if valid_loss < best_loss:
                    best_loss = valid_loss
                    n_incr_error = 0
                    self.save()

                if n_incr_error >= self.opt['valid_patience']:
                    print('Early stopping occurred. Optimization finished!')
                    self.save(self.opt['model_file'] + '.final')
                    break

    def evaluate(self, xs, ys, batch_size=1, silence=False):
        '''Compute and return accuracy (%) over the given data.
        '''
        if not silence:
            print('Data size: {}'.format(len(xs[0])))
        memories, queries, query_lengths = xs
        gen = next_ent_batch(memories, queries, query_lengths, ys, batch_size)
        correct = 0
        num_samples = 0
        for batch_xs, batch_ys in gen:
            correct += self.evaluate_step(batch_xs, batch_ys)
            num_samples += len(batch_ys)
        acc = 100 * correct / num_samples
        return acc

    def predict(self, xs, cand_labels, batch_size=1, silence=False):
        if not silence:
            print('Data size: {}'.format(len(xs[0])))
        memories, queries, query_lengths = xs
        gen = next_ent_batch(memories, queries, query_lengths, cand_labels, batch_size)
        predictions = []
        for batch_xs, batch_cands in gen:
            batch_pred = self.predict_step(batch_xs, batch_cands)
            predictions.extend(batch_pred)
        return predictions

    def train_step(self, xs, ys, is_training=True):
        # Set the module in training or evaluation mode.
        # This only has an effect on modules such as Dropout or BatchNorm.
        self.ent_model.train(mode=is_training)
        with torch.set_grad_enabled(is_training):
            # Organize inputs for network
            memories = [to_cuda(torch.LongTensor(np.array(x)), self.opt['cuda']) for x in zip(*xs[0])]
            queries = to_cuda(torch.LongTensor(xs[1]), self.opt['cuda'])
            query_lengths = to_cuda(torch.LongTensor(xs[2]), self.opt['cuda'])
            mem_hop_scores = self.ent_model(memories, queries, query_lengths)
            # ys = to_cuda(torch.LongTensor(ys), self.opt['cuda']).squeeze(-1)
            # Pack gold answer indices into the front-packed, -1-padded layout
            # required by MultiLabelMarginLoss
            ys, mask_ys = self.pack_gold_ans(ys, mem_hop_scores[-1].size(1), placeholder=-1)

            loss = 0
            for _, s in enumerate(mem_hop_scores):
                loss += self.loss_fn(s, ys)
            loss /= len(mem_hop_scores)

            if is_training:
                for o in self.optimizers.values():
                    o.zero_grad()
                loss.backward()
                for o in self.optimizers.values():
                    o.step()
            return loss.item()

    def evaluate_step(self, xs, ys):
        self.ent_model.train(mode=False)
        with torch.set_grad_enabled(False):
            # Organize inputs for network
            memories = [to_cuda(torch.LongTensor(np.array(x)), self.opt['cuda']) for x in zip(*xs[0])]
            queries = to_cuda(torch.LongTensor(xs[1]), self.opt['cuda'])
            query_lengths = to_cuda(torch.LongTensor(xs[2]), self.opt['cuda'])
            scores = self.ent_model(memories, queries, query_lengths)[-1]
            ys = to_cuda(torch.LongTensor(ys), self.opt['cuda']).squeeze(1)

            predictions = scores.max(1)[1].type_as(ys)
            correct = predictions.eq(ys).sum()
            return correct.item()

    def predict_step(self, xs, cand_labels):
        self.ent_model.train(mode=False)
        with torch.set_grad_enabled(False):
            # Organize inputs for network
            memories = [to_cuda(torch.LongTensor(np.array(x)), self.opt['cuda']) for x in zip(*xs[0])]
            queries = to_cuda(torch.LongTensor(xs[1]), self.opt['cuda'])
            query_lengths = to_cuda(torch.LongTensor(xs[2]), self.opt['cuda'])
            scores = self.ent_model(memories, queries, query_lengths)[-1]

            predictions = self.ranked_predictions(cand_labels, scores)
            return predictions

    def pack_gold_ans(self, x, N, placeholder=-1):
        y = np.ones((len(x), N), dtype='int64') * placeholder
        mask = np.zeros((len(x), N))
        for i in range(len(x)):
            y[i, :len(x[i])] = x[i]
            mask[i, :len(x[i])] = 1
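        # e.g., x = [[2, 5]] with N = 4 yields y = [[2, 5, -1, -1]]: the
        # front-packed, -1-padded layout that MultiLabelMarginLoss expects.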
        return to_cuda(torch.LongTensor(y), self.opt['cuda']), to_cuda(torch.Tensor(mask), self.opt['cuda'])

    def ranked_predictions(self, cand_labels, scores):
        _, sorted_inds = scores.sort(descending=True, dim=1)
        return [cand_labels[i][r[0]] if len(cand_labels[i]) > 0 else '' \
                for i, r in enumerate(sorted_inds)]

    def save(self, path=None):
        path = self.opt.get('model_file', None) if path is None else path

        if path:
            checkpoint = {}
            checkpoint['entnet'] = self.ent_model.state_dict()
            checkpoint['entnet_optim'] = self.optimizers['entnet'].state_dict()
            with open(path, 'wb') as write:
                torch.save(checkpoint, write)
                print('Saved ent_model to {}'.format(path))

    def load(self, path):
        with open(path, 'rb') as read:
            checkpoint = torch.load(read, map_location=lambda storage, loc: storage)
        self.ent_model.load_state_dict(checkpoint['entnet'])
        self.optimizers['entnet'].load_state_dict(checkpoint['entnet_optim'])
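

# A minimal sketch (illustrative; not invoked by the training scripts) of the
# target layout that MultiLabelMarginLoss consumes, i.e., what pack_gold_ans
# produces: gold indices packed at the front of each row, padded with -1.
# Run as a module (python -m) because of the relative imports above.
if __name__ == '__main__':
    _scores = torch.tensor([[0.1, 0.9, 0.2, 0.3]])  # (batch, num_candidates)
    _ys = torch.tensor([[1, 2, -1, -1]])            # gold answers at indices 1 and 2
    print(MultiLabelMarginLoss()(_scores, _ys).item())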


================================================
FILE: src/core/bamnet/modules.py
================================================
'''
Created on Sep, 2017

@author: hugo

'''
import numpy as np

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
import torch.nn.functional as F

from .utils import to_cuda


INF = 1e20
VERY_SMALL_NUMBER = 1e-10
class BAMnet(nn.Module):
    def __init__(self, vocab_size, vocab_embed_size, o_embed_size, \
        hidden_size, num_ent_types, num_relations, num_query_words, \
        word_emb_dropout=None,\
        que_enc_dropout=None,\
        ans_enc_dropout=None, \
        pre_w2v=None, \
        num_hops=1, \
        att='add', \
        use_cuda=True):
        super(BAMnet, self).__init__()
        self.use_cuda = use_cuda
        self.word_emb_dropout = word_emb_dropout
        self.que_enc_dropout = que_enc_dropout
        self.ans_enc_dropout = ans_enc_dropout
        self.num_hops = num_hops
        self.hidden_size = hidden_size
        self.que_enc = SeqEncoder(vocab_size, vocab_embed_size, hidden_size, \
                        seq_enc_type='lstm', \
                        word_emb_dropout=word_emb_dropout, bidirectional=True, \
                        init_word_embed=pre_w2v, use_cuda=use_cuda).que_enc

        self.ans_enc = AnsEncoder(o_embed_size, hidden_size, \
                        num_ent_types, num_relations, \
                        vocab_size=vocab_size, \
                        vocab_embed_size=vocab_embed_size, \
                        shared_embed=self.que_enc.embed, \
                        word_emb_dropout=word_emb_dropout, \
                        ans_enc_dropout=ans_enc_dropout, \
                        use_cuda=use_cuda)

        self.qw_embed = nn.Embedding(num_query_words, o_embed_size // 8, padding_idx=0)
        self.batchnorm = nn.BatchNorm1d(hidden_size)

        self.init_atten = Attention(hidden_size, hidden_size, hidden_size, atten_type=att)
        self.self_atten = SelfAttention_CoAtt(hidden_size)
        print('[ Using self-attention on question encoder ]')

        self.memory_hop = RomHop(hidden_size, hidden_size, hidden_size, atten_type=att)
        print('[ Using {}-hop memory update ]'.format(self.num_hops))

    def kb_aware_query_enc(self, memories, queries, query_lengths, ans_mask, ctx_mask=None):
        # Question encoder
        Q_r = self.que_enc(queries, query_lengths)[0]
        if self.que_enc_dropout:
            Q_r = F.dropout(Q_r, p=self.que_enc_dropout, training=self.training)

        query_mask = create_mask(query_lengths, Q_r.size(1), self.use_cuda)
        q_r_init = self.self_atten(Q_r, query_lengths, query_mask)

        # Answer encoder
        _, _, _, x_type_bow, x_types, x_type_bow_len, x_path_bow, x_paths, x_path_bow_len, x_ctx_ent, x_ctx_ent_len, x_ctx_ent_num, _, _, _, _ = memories
        ans_comp_val, ans_comp_key = self.ans_enc(x_type_bow, x_types, x_type_bow_len, x_path_bow, x_paths, x_path_bow_len, x_ctx_ent, x_ctx_ent_len, x_ctx_ent_num)
        if self.ans_enc_dropout:
            for _ in range(len(ans_comp_key)):
                ans_comp_key[_] = F.dropout(ans_comp_key[_], p=self.ans_enc_dropout, training=self.training)
        # KB memory summary
        ans_comp_atts = [self.init_atten(q_r_init, each, atten_mask=ans_mask) for each in ans_comp_key]
        if ctx_mask is not None:
            ans_comp_atts[-1] = ctx_mask * ans_comp_atts[-1] - (1 - ctx_mask) * INF
        ans_comp_probs = [torch.softmax(each, dim=-1) for each in ans_comp_atts]
        memory_summary = []
        for i, probs in enumerate(ans_comp_probs):
            memory_summary.append(torch.bmm(probs.unsqueeze(1), ans_comp_val[i]))
        memory_summary = torch.cat(memory_summary, 1)

        # Co-attention
        CoAtt = torch.bmm(Q_r, memory_summary.transpose(1, 2)) # co-attention matrix
        CoAtt = query_mask.unsqueeze(-1) * CoAtt - (1 - query_mask).unsqueeze(-1) * INF
        if ctx_mask is not None:
            # mask over empty ctx elements
            ctx_mask_global = (ctx_mask.sum(-1, keepdim=True) > 0).float()
            CoAtt[:, :, -1] = ctx_mask_global * CoAtt[:, :, -1].clone() - (1 - ctx_mask_global) * INF

        q_att = F.max_pool1d(CoAtt, kernel_size=CoAtt.size(-1)).squeeze(-1)
        q_att = torch.softmax(q_att, dim=-1)
        return (ans_comp_val, ans_comp_key), (q_att, Q_r), query_mask

    def forward(self, memories, queries, query_lengths, query_words, ctx_mask=None):
        ctx_mask = None  # override: the context mask is disabled in this implementation
        mem_hop_scores = []
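        # One score tensor per supervision signal: the answer-type matching
        # score first, then intermediate matching scores, then one score per
        # generalization hop below.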
        ans_mask = create_mask(memories[0], memories[2].size(1), self.use_cuda)

        # Multi-task learning on answer type matching
        # question word vec
        self.qw_vec = torch.mean(self.qw_embed(query_words), 1)
        # answer type vec
        x_types = memories[4]
        ans_types = torch.mean(self.ans_enc.ent_type_embed(x_types.view(-1, x_types.size(-1))), 1).view(x_types.size(0), x_types.size(1), -1)
        qw_anstype_loss = torch.bmm(ans_types, self.qw_vec.unsqueeze(2)).squeeze(2)
        if ans_mask is not None:
            qw_anstype_loss = ans_mask * qw_anstype_loss - (1 - ans_mask) * INF # Make dummy candidates have large negative scores
        mem_hop_scores.append(qw_anstype_loss)


        # Kb-aware question attention module
        (ans_val, ans_key), (q_att, Q_r), query_mask = self.kb_aware_query_enc(memories, queries, query_lengths, ans_mask, ctx_mask=ctx_mask)
        ans_val = torch.cat([each.unsqueeze(2) for each in ans_val], 2)
        ans_key = torch.cat([each.unsqueeze(2) for each in ans_key], 2)

        q_r = torch.bmm(q_att.unsqueeze(1), Q_r).squeeze(1)
        mid_score = self.scoring(ans_key.sum(2), q_r, mask=ans_mask)
        mem_hop_scores.append(mid_score)

        Q_r, ans_key, ans_val = self.memory_hop(Q_r, ans_key, ans_val, q_att, atten_mask=ans_mask, ctx_mask=ctx_mask, query_mask=query_mask)
        q_r = torch.bmm(q_att.unsqueeze(1), Q_r).squeeze(1)
        mid_score = self.scoring(ans_key, q_r, mask=ans_mask)
        mem_hop_scores.append(mid_score)

        # Generalization module
        for _ in range(self.num_hops):
            q_r_tmp = self.memory_hop.gru_step(q_r, ans_key, ans_val, atten_mask=ans_mask)
            q_r = self.batchnorm(q_r + q_r_tmp)
            mid_score = self.scoring(ans_key, q_r, mask=ans_mask)
            mem_hop_scores.append(mid_score)
        return mem_hop_scores

    def premature_score(self, memories, queries, query_lengths, ctx_mask=None):
        ctx_mask = None  # override: the context mask is disabled in this implementation
        ans_mask = create_mask(memories[0], memories[2].size(1), self.use_cuda)

        # Kb-aware question attention module
        (ans_val, ans_key), (q_att, Q_r), query_mask = self.kb_aware_query_enc(memories, queries, query_lengths, ans_mask, ctx_mask=ctx_mask)
        ans_key = torch.cat([each.unsqueeze(2) for each in ans_key], 2)

        q_r = torch.bmm(q_att.unsqueeze(1), Q_r).squeeze(1)
        score = self.scoring(ans_key.sum(2), q_r, mask=ans_mask)
        return score

    def scoring(self, ans_r, q_r, mask=None):
        score = torch.bmm(ans_r, q_r.unsqueeze(2)).squeeze(2)
        if mask is not None:
            score = mask * score - (1 - mask) * INF # Make dummy candidates have large negative scores
        return score

class RomHop(nn.Module):
    def __init__(self, query_embed_size, in_memory_embed_size, hidden_size, atten_type='add'):
        super(RomHop, self).__init__()
        self.hidden_size = hidden_size
        self.gru_linear_z = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.gru_linear_r = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.gru_linear_t = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.gru_atten = Attention(hidden_size, query_embed_size, in_memory_embed_size, atten_type=atten_type)

    def forward(self, query_embed, in_memory_embed, out_memory_embed, query_att, \
                atten_mask=None, ctx_mask=None, query_mask=None):
        output = self.update_coatt_cat_maxpool(query_embed, in_memory_embed, out_memory_embed, query_att, \
                    atten_mask=atten_mask, ctx_mask=ctx_mask, query_mask=query_mask)
        return output

    def gru_step(self, h_state, in_memory_embed, out_memory_embed, atten_mask=None):
        attention = self.gru_atten(h_state, in_memory_embed, atten_mask=atten_mask)
        probs = torch.softmax(attention, dim=-1)

        memory_output = torch.bmm(probs.unsqueeze(1), out_memory_embed).squeeze(1)
        # GRU-like memory update
        z = torch.sigmoid(self.gru_linear_z(torch.cat([h_state, memory_output], -1)))
        r = torch.sigmoid(self.gru_linear_r(torch.cat([h_state, memory_output], -1)))
        t = torch.tanh(self.gru_linear_t(torch.cat([r * h_state, memory_output], -1)))
        output = (1 - z) * h_state + z * t
        return output

    def update_coatt_cat_maxpool(self, query_embed, in_memory_embed, out_memory_embed, query_att, atten_mask=None, ctx_mask=None, query_mask=None):
        attention = torch.bmm(query_embed, in_memory_embed.view(in_memory_embed.size(0), -1, in_memory_embed.size(-1))\
            .transpose(1, 2)).view(query_embed.size(0), query_embed.size(1), in_memory_embed.size(1), -1) # bs * N * M * k
        if ctx_mask is not None:
            attention[:, :, :, -1] = ctx_mask.unsqueeze(1) * attention[:, :, :, -1].clone() - (1 - ctx_mask).unsqueeze(1) * INF
        if atten_mask is not None:
            attention = atten_mask.unsqueeze(1).unsqueeze(-1) * attention - (1 - atten_mask).unsqueeze(1).unsqueeze(-1) * INF
        if query_mask is not None:
            attention = query_mask.unsqueeze(2).unsqueeze(-1) * attention - (1 - query_mask).unsqueeze(2).unsqueeze(-1) * INF

        # Importance module
        kb_feature_att = F.max_pool1d(attention.view(attention.size(0), attention.size(1), -1).transpose(1, 2), kernel_size=attention.size(1)).squeeze(-1).view(attention.size(0), -1, attention.size(-1))
        kb_feature_att = torch.softmax(kb_feature_att, dim=-1).view(-1, kb_feature_att.size(-1)).unsqueeze(1)
        in_memory_embed = torch.bmm(kb_feature_att, in_memory_embed.view(-1, in_memory_embed.size(2), in_memory_embed.size(-1))).squeeze(1).view(in_memory_embed.size(0), in_memory_embed.size(1), -1)
        out_memory_embed = out_memory_embed.sum(2)

        # Enhanced module
        attention = F.max_pool1d(attention.view(attention.size(0), -1, attention.size(-1)), kernel_size=attention.size(-1)).squeeze(-1).view(attention.size(0), attention.size(1), attention.size(2))
        probs = torch.softmax(attention, dim=-1)
        new_query_embed = query_embed + query_att.unsqueeze(2) * torch.bmm(probs, out_memory_embed)

        probs2 = torch.softmax(attention, dim=1)
        kb_att = torch.bmm(query_att.unsqueeze(1), probs).squeeze(1)
        in_memory_embed = in_memory_embed + kb_att.unsqueeze(2) * torch.bmm(probs2.transpose(1, 2), new_query_embed)
        return new_query_embed, in_memory_embed, out_memory_embed

class AnsEncoder(nn.Module):
    """Answer Encoder"""
    def __init__(self, o_embed_size, hidden_size, num_ent_types, num_relations, vocab_size=None, \
                    vocab_embed_size=None, shared_embed=None, word_emb_dropout=None, \
                    ans_enc_dropout=None, use_cuda=True):
        super(AnsEncoder, self).__init__()
        # shared_embed and vocab_size cannot both be None.
        self.use_cuda = use_cuda
        self.ans_enc_dropout = ans_enc_dropout
        self.hidden_size = hidden_size
        self.ent_type_embed = nn.Embedding(num_ent_types, o_embed_size // 8, padding_idx=0)
        self.relation_embed = nn.Embedding(num_relations, o_embed_size, padding_idx=0)
        self.embed = shared_embed if shared_embed is not None else nn.Embedding(vocab_size, vocab_embed_size, padding_idx=0)
        self.vocab_embed_size = self.embed.weight.data.size(1)

        self.linear_type_bow_key = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_paths_key = nn.Linear(hidden_size + o_embed_size, hidden_size, bias=False)
        self.linear_ctx_key = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_type_bow_val = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_paths_val = nn.Linear(hidden_size + o_embed_size, hidden_size, bias=False)
        self.linear_ctx_val = nn.Linear(hidden_size, hidden_size, bias=False)

        # lstm for ans encoder
        self.lstm_enc_type = EncoderRNN(vocab_size, self.vocab_embed_size, hidden_size, \
                        dropout=word_emb_dropout, \
                        bidirectional=True, \
                        shared_embed=shared_embed, \
                        rnn_type='lstm', \
                        use_cuda=use_cuda)
        self.lstm_enc_path = EncoderRNN(vocab_size, self.vocab_embed_size, hidden_size, \
                        dropout=word_emb_dropout, \
                        bidirectional=True, \
                        shared_embed=shared_embed, \
                        rnn_type='lstm', \
                        use_cuda=use_cuda)
        self.lstm_enc_ctx = EncoderRNN(vocab_size, self.vocab_embed_size, hidden_size, \
                        dropout=word_emb_dropout, \
                        bidirectional=True, \
                        shared_embed=shared_embed, \
                        rnn_type='lstm', \
                        use_cuda=use_cuda)

    def forward(self, x_type_bow, x_types, x_type_bow_len, x_path_bow, x_paths, x_path_bow_len, x_ctx_ents, x_ctx_ent_len, x_ctx_ent_num):
        ans_type_bow, ans_types, ans_path_bow, ans_paths, ans_ctx_ent = self.enc_ans_features(x_type_bow, x_types, x_type_bow_len, x_path_bow, x_paths, x_path_bow_len, x_ctx_ents, x_ctx_ent_len, x_ctx_ent_num)
        ans_val, ans_key = self.enc_comp_kv(ans_type_bow, ans_types, ans_path_bow, ans_paths, ans_ctx_ent)
        return ans_val, ans_key

    def enc_comp_kv(self, ans_type_bow, ans_types, ans_path_bow, ans_paths, ans_ctx_ent):
        ans_type_bow_val = self.linear_type_bow_val(ans_type_bow)
        ans_paths_val = self.linear_paths_val(torch.cat([ans_path_bow, ans_paths], -1))
        ans_ctx_val = self.linear_ctx_val(ans_ctx_ent)

        ans_type_bow_key = self.linear_type_bow_key(ans_type_bow)
        ans_paths_key = self.linear_paths_key(torch.cat([ans_path_bow, ans_paths], -1))
        ans_ctx_key = self.linear_ctx_key(ans_ctx_ent)

        ans_comp_val = [ans_type_bow_val, ans_paths_val, ans_ctx_val]
        ans_comp_key = [ans_type_bow_key, ans_paths_key, ans_ctx_key]
        return ans_comp_val, ans_comp_key

    def enc_ans_features(self, x_type_bow, x_types, x_type_bow_len, x_path_bow, x_paths, x_path_bow_len, x_ctx_ents, x_ctx_ent_len, x_ctx_ent_num):
        '''
        x_types: answer type
        x_paths: answer path, i.e., bow of relation
        x_ctx_ents: answer context, i.e., bow of entity words, (batch_size, num_cands, num_ctx, L)
        '''
        # ans_types = torch.mean(self.ent_type_embed(x_types.view(-1, x_types.size(-1))), 1).view(x_types.size(0), x_types.size(1), -1)
        ans_type_bow = (self.lstm_enc_type(x_type_bow.view(-1, x_type_bow.size(-1)), x_type_bow_len.view(-1))[1]).view(x_type_bow.size(0), x_type_bow.size(1), -1)
        ans_path_bow = (self.lstm_enc_path(x_path_bow.view(-1, x_path_bow.size(-1)), x_path_bow_len.view(-1))[1]).view(x_path_bow.size(0), x_path_bow.size(1), -1)
        ans_paths = torch.mean(self.relation_embed(x_paths.view(-1, x_paths.size(-1))), 1).view(x_paths.size(0), x_paths.size(1), -1)

        # Avg over ctx
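        # i.e., a length-masked mean: zero out padded ctx slots, sum over the
        # ctx axis, and divide by the true ctx count (clamped to avoid 0/0).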
        ctx_num_mask = create_mask(x_ctx_ent_num.view(-1), x_ctx_ents.size(2), self.use_cuda).view(x_ctx_ent_num.shape + (-1,))
        ans_ctx_ent = (self.lstm_enc_ctx(x_ctx_ents.view(-1, x_ctx_ents.size(-1)), x_ctx_ent_len.view(-1))[1]).view(x_ctx_ents.size(0), x_ctx_ents.size(1), x_ctx_ents.size(2), -1)
        ans_ctx_ent = ctx_num_mask.unsqueeze(-1) * ans_ctx_ent
        ans_ctx_ent = torch.sum(ans_ctx_ent, dim=2) / torch.clamp(x_ctx_ent_num.float().unsqueeze(-1), min=VERY_SMALL_NUMBER)

        if self.ans_enc_dropout:
            # ans_types = F.dropout(ans_types, p=self.ans_enc_dropout, training=self.training)
            ans_type_bow = F.dropout(ans_type_bow, p=self.ans_enc_dropout, training=self.training)
            ans_path_bow = F.dropout(ans_path_bow, p=self.ans_enc_dropout, training=self.training)
            ans_paths = F.dropout(ans_paths, p=self.ans_enc_dropout, training=self.training)
            ans_ctx_ent = F.dropout(ans_ctx_ent, p=self.ans_enc_dropout, training=self.training)
        return ans_type_bow, None, ans_path_bow, ans_paths, ans_ctx_ent

class SeqEncoder(object):
    """Question Encoder"""
    def __init__(self, vocab_size, embed_size, hidden_size, \
                seq_enc_type='lstm', word_emb_dropout=None,
                cnn_kernel_size=[3], bidirectional=False, \
                shared_embed=None, init_word_embed=None, use_cuda=True):
        if seq_enc_type in ('lstm', 'gru'):
            self.que_enc = EncoderRNN(vocab_size, embed_size, hidden_size, \
                        dropout=word_emb_dropout, \
                        bidirectional=bidirectional, \
                        shared_embed=shared_embed, \
                        init_word_embed=init_word_embed, \
                        rnn_type=seq_enc_type, \
                        use_cuda=use_cuda)

        elif seq_enc_type == 'cnn':
            self.que_enc = EncoderCNN(vocab_size, embed_size, hidden_size, \
                        kernel_size=cnn_kernel_size, dropout=word_emb_dropout, \
                        shared_embed=shared_embed, \
                        init_word_embed=init_word_embed, \
                        use_cuda=use_cuda)
        else:
            raise RuntimeError('Unknown SeqEncoder type: {}'.format(seq_enc_type))

class EncoderRNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, dropout=None, \
        bidirectional=False, shared_embed=None, init_word_embed=None, rnn_type='lstm', use_cuda=True):
        super(EncoderRNN, self).__init__()
        if not rnn_type in ('lstm', 'gru'):
            raise RuntimeError('rnn_type is expected to be lstm or gru, got {}'.format(rnn_type))
        if bidirectional:
            print('[ Using bidirectional {} encoder ]'.format(rnn_type))
        else:
            print('[ Using {} encoder ]'.format(rnn_type))
        if bidirectional and hidden_size % 2 != 0:
            raise RuntimeError('hidden_size is expected to be even in the bidirectional mode!')
        self.dropout = dropout
        self.rnn_type = rnn_type
        self.use_cuda = use_cuda
        self.hidden_size = hidden_size // 2 if bidirectional else hidden_size
        self.num_directions = 2 if bidirectional else 1
        self.embed = shared_embed if shared_embed is not None else nn.Embedding(vocab_size, embed_size, padding_idx=0)
        model = nn.LSTM if rnn_type == 'lstm' else nn.GRU
        self.model = model(embed_size, self.hidden_size, 1, batch_first=True, bidirectional=bidirectional)
        if shared_embed is None:
            self.init_weights(init_word_embed)

    def init_weights(self, init_word_embed):
        if init_word_embed is not None:
            print('[ Using pretrained word embeddings ]')
            self.embed.weight.data.copy_(torch.from_numpy(init_word_embed))
        else:
            self.embed.weight.data.uniform_(-0.08, 0.08)

    def forward(self, x, x_len):
        """x: [batch_size * max_length]
           x_len: [batch_size]
        """
        x = self.embed(x)
        if self.dropout:
            x = F.dropout(x, p=self.dropout, training=self.training)
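
        # pack_padded_sequence requires length-sorted batches: sort descending
        # below and undo the permutation after the RNN (see inverse_indx).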

        sorted_x_len, indx = torch.sort(x_len, 0, descending=True)
        x = pack_padded_sequence(x[indx], sorted_x_len.data.tolist(), batch_first=True)

        h0 = to_cuda(torch.zeros(self.num_directions, x_len.size(0), self.hidden_size), self.use_cuda)
        if self.rnn_type == 'lstm':
            c0 = to_cuda(torch.zeros(self.num_directions, x_len.size(0), self.hidden_size), self.use_cuda)
            packed_h, (packed_h_t, _) = self.model(x, (h0, c0))
            if self.num_directions == 2:
                packed_h_t = torch.cat([packed_h_t[i] for i in range(packed_h_t.size(0))], -1)
        else:
            packed_h, packed_h_t = self.model(x, h0)
            if self.num_directions == 2:
                packed_h_t = packed_h_t.transpose(0, 1).contiguous().view(x_len.size(0), -1)

        hh, _ = pad_packed_sequence(packed_h, batch_first=True)

        # restore the sorting
        _, inverse_indx = torch.sort(indx, 0)
        restore_hh = hh[inverse_indx]
        restore_packed_h_t = packed_h_t[inverse_indx]
        return restore_hh, restore_packed_h_t


class EncoderCNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, kernel_size=[2, 3], \
            dropout=None, shared_embed=None, init_word_embed=None, use_cuda=True):
        super(EncoderCNN, self).__init__()
        print('[ Using CNN encoder with kernel size: {} ]'.format(kernel_size))
        self.use_cuda = use_cuda
        self.dropout = dropout
        self.embed = shared_embed if shared_embed is not None else nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.cnns = nn.ModuleList([nn.Conv1d(embed_size, hidden_size, kernel_size=k, padding=k-1) for k in kernel_size])

        if len(kernel_size) > 1:
            self.fc = nn.Linear(len(kernel_size) * hidden_size, hidden_size)
        if shared_embed is None:
            self.init_weights(init_word_embed)

    def init_weights(self, init_word_embed):
        if init_word_embed is not None:
            print('[ Using pretrained word embeddings ]')
            self.embed.weight.data.copy_(torch.from_numpy(init_word_embed))
        else:
            self.embed.weight.data.uniform_(-0.08, 0.08)

    def forward(self, x, x_len=None):
        """x: [batch_size * max_length]
           x_len: reserved
        """
        x = self.embed(x)
        if self.dropout:
            x = F.dropout(x, p=self.dropout, training=self.training)
        # Turn (batch_size, seq_len, embed_size) into (batch_size, embed_size, seq_len) for Conv1d
        x = x.transpose(1, 2)
        z = [conv(x) for conv in self.cnns]
        output = [F.max_pool1d(i, kernel_size=i.size(-1)).squeeze(-1) for i in z]

        if len(output) > 1:
            output = self.fc(torch.cat(output, -1))
        else:
            output = output[0]
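        # (None, output) mirrors EncoderRNN's (outputs, final_state) interface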
        return None, output


class Attention(nn.Module):
    def __init__(self, hidden_size, h_state_embed_size=None, in_memory_embed_size=None, atten_type='simple'):
        super(Attention, self).__init__()
        self.atten_type = atten_type
        if not h_state_embed_size:
            h_state_embed_size = hidden_size
        if not in_memory_embed_size:
            in_memory_embed_size = hidden_size
        if atten_type in ('mul', 'add'):
            self.W = torch.Tensor(h_state_embed_size, hidden_size)
            self.W = nn.Parameter(nn.init.xavier_uniform_(self.W))
            if atten_type == 'add':
                self.W2 = torch.Tensor(in_memory_embed_size, hidden_size)
                self.W2 = nn.Parameter(nn.init.xavier_uniform_(self.W2))
                self.W3 = torch.Tensor(hidden_size, 1)
                self.W3 = nn.Parameter(nn.init.xavier_uniform_(self.W3))
        elif atten_type == 'simple':
            pass
        else:
            raise RuntimeError('Unknown atten_type: {}'.format(self.atten_type))

    def forward(self, query_embed, in_memory_embed, atten_mask=None):
        if self.atten_type == 'simple': # simple attention
            attention = torch.bmm(in_memory_embed, query_embed.unsqueeze(2)).squeeze(2)
        elif self.atten_type == 'mul': # multiplicative attention
            attention = torch.bmm(in_memory_embed, torch.mm(query_embed, self.W).unsqueeze(2)).squeeze(2)
        elif self.atten_type == 'add': # additive attention
            attention = torch.tanh(torch.mm(in_memory_embed.view(-1, in_memory_embed.size(-1)), self.W2)\
                .view(in_memory_embed.size(0), -1, self.W2.size(-1)) \
                + torch.mm(query_embed, self.W).unsqueeze(1))
            attention = torch.mm(attention.view(-1, attention.size(-1)), self.W3).view(attention.size(0), -1)
        else:
            raise RuntimeError('Unknown atten_type: {}'.format(self.atten_type))

        if atten_mask is not None:
            # Exclude masked elements from the softmax
            attention = atten_mask * attention - (1 - atten_mask) * INF
        return attention

class SelfAttention_CoAtt(nn.Module):
    def __init__(self, hidden_size, use_cuda=True):
        super(SelfAttention_CoAtt, self).__init__()
        self.use_cuda = use_cuda
        self.hidden_size = hidden_size
        self.model = nn.LSTM(2 * hidden_size, hidden_size // 2, batch_first=True, bidirectional=True)

    def forward(self, x, x_len, atten_mask):
        CoAtt = torch.bmm(x, x.transpose(1, 2))
        CoAtt = atten_mask.unsqueeze(1) * CoAtt - (1 - atten_mask).unsqueeze(1) * INF
        CoAtt = torch.softmax(CoAtt, dim=-1)
        new_x = torch.cat([torch.bmm(CoAtt, x), x], -1)

        sorted_x_len, indx = torch.sort(x_len, 0, descending=True)
        new_x = pack_padded_sequence(new_x[indx], sorted_x_len.data.tolist(), batch_first=True)

        h0 = to_cuda(torch.zeros(2, x_len.size(0), self.hidden_size // 2), self.use_cuda)
        c0 = to_cuda(torch.zeros(2, x_len.size(0), self.hidden_size // 2), self.use_cuda)
        packed_h, (packed_h_t, _) = self.model(new_x, (h0, c0))

        # restore the sorting
        _, inverse_indx = torch.sort(indx, 0)
        packed_h_t = torch.cat([packed_h_t[i] for i in range(packed_h_t.size(0))], -1)
        restore_packed_h_t = packed_h_t[inverse_indx]
        output = restore_packed_h_t
        return output

def create_mask(x, N, use_cuda=True):
    x = x.data
    mask = np.zeros((x.size(0), N))
    for i in range(x.size(0)):
        mask[i, :x[i]] = 1
    return to_cuda(torch.Tensor(mask), use_cuda)
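
# A minimal shape sketch (illustrative only; sizes, tokens, and the padded
# lengths are assumptions) of the attention plumbing above: additive attention
# over a 5-slot memory with create_mask hiding padded slots, plus the CNN
# encoder's pooled output. Run as a module (python -m) because of the relative
# import above.
if __name__ == '__main__':
    _q = torch.randn(2, 16)                       # (batch, hidden)
    _mem = torch.randn(2, 5, 16)                  # (batch, num_cands, hidden)
    _mask = create_mask(torch.LongTensor([3, 5]), 5, use_cuda=False)
    _att = Attention(16, atten_type='add')(_q, _mem, atten_mask=_mask)
    print(torch.softmax(_att, dim=-1)[0])         # last two slots get ~0 weight
    _cnn = EncoderCNN(vocab_size=10, embed_size=8, hidden_size=16, kernel_size=[3], use_cuda=False)
    print(_cnn(torch.LongTensor([[1, 2, 3, 0]]))[1].shape)  # torch.Size([1, 16])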


================================================
FILE: src/core/bamnet/utils.py
================================================
'''
Created on Oct, 2017

@author: hugo

'''
import torch
from torch.autograd import Variable
import numpy as np


def to_cuda(x, use_cuda=True):
    if use_cuda and torch.cuda.is_available():
        x = x.cuda()
    return x

# One pass over the dataset
def next_batch(memories, queries, query_words, raw_queries, query_mentions, query_lengths, gold_ans_inds, batch_size):
    for i in range(0, len(memories), batch_size):
        yield (memories[i: i + batch_size], queries[i: i + batch_size], query_words[i: i + batch_size], raw_queries[i: i + batch_size], query_mentions[i: i + batch_size], query_lengths[i: i + batch_size]), gold_ans_inds[i: i + batch_size]

# One pass over the dataset
def next_ent_batch(memories, queries, query_lengths, gold_inds, batch_size):
    for i in range(0, len(memories), batch_size):
        yield (memories[i: i + batch_size], queries[i: i + batch_size], query_lengths[i: i + batch_size]), gold_inds[i: i + batch_size]
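
# A minimal sketch (illustrative toy data) of one pass over a dataset with
# batch_size=2; this file has no relative imports, so it can be run directly.
if __name__ == '__main__':
    _mems = [['m0'], ['m1'], ['m2']]
    _qs, _qlens, _gold = [[1], [2], [3]], [1, 1, 1], [[0], [1], [0]]
    for (_m, _q, _ql), _y in next_ent_batch(_mems, _qs, _qlens, _gold, batch_size=2):
        print(len(_m), _y)   # prints: 2 [[0], [1]]  then  1 [[0]]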


================================================
FILE: src/core/build_data/__init__.py
================================================
'''
Created on Oct, 2017

@author: hugo

'''


================================================
FILE: src/core/build_data/build_all.py
================================================
'''
Created on Oct, 2017

@author: hugo

'''
import os

from . import utils as build_utils
from ..utils.utils import *
from .build_data import build_vocab, build_data


def build(dpath, version=None, out_dir=None):
    if not build_utils.built(dpath, version_string=version):
        raise RuntimeError("Please build/preprocess the data by running the build_all_data.py script!")


================================================
FILE: src/core/build_data/build_data.py
================================================
'''
Created on Sep, 2017

@author: hugo

'''
import os
import math
import argparse
from itertools import count
from rapidfuzz import fuzz, process
from collections import defaultdict

from ..utils.utils import *
from ..utils.generic_utils import normalize_answer, unique
from ..utils.freebase_utils import if_filterout
from .. import config


IGNORE_DUMMY = True
ENT_TYPE_HOP = 1
# Entity mention types: 'NP', 'ORGANIZATION', 'DATE', 'NUMBER', 'MISC', 'ORDINAL', 'DURATION', 'PERSON', 'TIME', 'LOCATION'
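
# Layout of a kb value as consumed below (a sketch inferred from this code;
# field contents are illustrative):
#   kb[freebase_key] = {
#       'id': mid, 'name': [...], 'alias': [...],
#       'notable_types': [...], 'type': [...],
#       'neighbors': {relation: [str | bool | float | {mid: {...same layout...}}]},
#   }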

def build_kb_data(kb, used_fbkeys=None):
    entities = defaultdict(int)
    entity_types = defaultdict(int)
    relations = defaultdict(int)
    vocabs = defaultdict(int)
    if not used_fbkeys:
        used_fbkeys = kb.keys()
    for k in used_fbkeys:
        if not k in kb:
            continue
        v = kb[k]
        entities[v['id']] += 1
        # We prefer notable_types over type since they are more representative.
        # If notable_types are not available, we use only the first available type.
        # We found the type field contains much noise.
        selected_types = (v['notable_types'] + v['type'])[:ENT_TYPE_HOP]
        for ent_type in selected_types:
            entity_types[ent_type] += 1
        for token in [y for x in selected_types for y in x.lower().split('/')[-1].split('_')]:
            vocabs[token] += 1
        # Add entity vocabs
        selected_names = v['name'][:1] + v['alias'] # We need all topic entity aliases
        for token in [y for x in selected_names for y in tokenize(x.lower())]:
            vocabs[token] += 1
        if not 'neighbors' in v:
            continue
        for kk, vv in v['neighbors'].items(): # 1st hop
            if if_filterout(kk):
                continue
            relations[kk] += 1
            # Add relation vocabs
            for token in [x for x in kk.lower().split('/')[-1].split('_')]:
                vocabs[token] += 1
            for nbr in vv:
                if isinstance(nbr, str):
                    for token in [y for y in tokenize(nbr.lower())]:
                        vocabs[token] += 1
                    continue
                elif isinstance(nbr, bool):
                    continue
                elif isinstance(nbr, float):
                    continue
                    # vocabs.update([y for y in tokenize(str(nbr).lower())])
                elif isinstance(nbr, dict):
                    nbr_k = list(nbr.keys())[0]
                    nbr_v = nbr[nbr_k]
                    entities[nbr_k] += 1
                    selected_types = (nbr_v['notable_types'] + nbr_v['type'])[:ENT_TYPE_HOP]
                    for ent_type in selected_types:
                        entity_types[ent_type] += 1
                    selected_names = (nbr_v['name'] + nbr_v['alias'])[:1]
                    for token in [y for x in selected_names for y in tokenize(x.lower())] + \
                        [y for x in selected_types for y in x.lower().split('/')[-1].split('_')]:
                        vocabs[token] += 1
                    if not 'neighbors' in nbr_v:
                        continue
                    for kkk, vvv in nbr_v['neighbors'].items(): # 2nd hop
                        if if_filterout(kkk):
                            continue
                        relations[kkk] += 1
                        # Add relation vocabs
                        for token in [x for x in kkk.lower().split('/')[-1].split('_')]:
                            vocabs[token] += 1
                        for nbr_nbr in vvv:
                            if isinstance(nbr_nbr, str):
                                for token in [y for y in tokenize(nbr_nbr.lower())]:
                                    vocabs[token] += 1
                                continue
                            elif isinstance(nbr_nbr, bool):
                                continue
                            elif isinstance(nbr_nbr, float):
                                # vocabs.update([y for y in tokenize(str(nbr_nbr).lower())])
                                continue
                            elif isinstance(nbr_nbr, dict):
                                nbr_nbr_k = list(nbr_nbr.keys())[0]
                                nbr_nbr_v = nbr_nbr[nbr_nbr_k]
                                entities[nbr_nbr_k] += 1
                                selected_types = (nbr_nbr_v['notable_types'] + nbr_nbr_v['type'])[:ENT_TYPE_HOP]
                                for ent_type in selected_types:
                                    entity_types[ent_type] += 1
                                selected_names = (nbr_nbr_v['name'] + nbr_nbr_v['alias'])[:1]
                                for token in [y for x in selected_names for y in tokenize(x.lower())] + \
                                    [y for x in selected_types for y in x.lower().split('/')[-1].split('_')]:
                                    vocabs[token] += 1
                            else:
                                raise RuntimeError('Unknown type: %s' % type(nbr_nbr))
                else:
                    raise RuntimeError('Unknown type: %s' % type(nbr))
    return (entities, entity_types, relations, vocabs)

def build_qa_vocab(qa):
    vocabs = defaultdict(int)
    for each in qa:
        for token in tokenize(each['qText'].lower()):
            vocabs[token] += 1
    return vocabs

def delex_query_topic_ent(query, topic_ent, ent_types):
    query = tokenize(query.lower())
    if topic_ent == '':
        return query, None

    ent_type_dict = {}
    for ent, type_ in ent_types:
        if ent not in ent_type_dict:
            ent_type_dict[ent] = type_
        else:
            if ent_type_dict[ent] == 'NP':
                ent_type_dict[ent] = type_

    ret = process.extract(topic_ent.replace('_', ' '), set(list(zip(*ent_types))[0]), scorer=fuzz.token_sort_ratio)
    if len(ret) == 0:
        return query, None

    # We prefer non-NP entity mentions,
    # e.g., `uk` rather than `people in the uk` when matching `united_kingdom`
    topic_men = None
    topic_score = None
    # rapidfuzz's process.extract may return (match, score) or (match, score, key)
    # tuples depending on the version; the starred unpack handles both
    for token, score, *_ in ret:
        if ent_type_dict[token].lower() in config.topic_mention_types:
            topic_men = token
            topic_score = score
            break

    if topic_men is None:
        return query, None

    topic_ent_type = ent_type_dict[topic_men].lower()
    topic_tokens = tokenize(topic_men.lower())
    indices = [i for i, x in enumerate(query) if x == topic_tokens[0]]
    start_idx = None
    for i in indices:
        if query[i: i + len(topic_tokens)] == topic_tokens:
            start_idx = i
            end_idx = i + len(topic_tokens)
            break
    if start_idx is None:  # mention tokens not found verbatim in the query
        return query, None
    query_template = query[:start_idx] + [topic_ent_type] + query[end_idx:]
    return query_template, topic_men

def delex_query(query, ent_mens, mention_types):
    for men, type_ in ent_mens:
        type_ = type_.lower()
        if type_ in mention_types:
            men = tokenize(men.lower())
            indices = [i for i, x in enumerate(query) if x == men[0]]
            start_idx = None
            for i in indices:
                if query[i: i + len(men)] == men:
                    start_idx = i
                    end_idx = i + len(men)
                    break
            if start_idx is not None:
                query = query[:start_idx] + ['__{}__'.format(type_)] + query[end_idx:]
    return query
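
# For example (assuming tokenize splits on whitespace), with
# mention_types = {'person'} and ent_mens = [('bob dylan', 'PERSON')]:
#   ['where', 'was', 'bob', 'dylan', 'born'] -> ['where', 'was', '__person__', 'born']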

def build_data(qa, kb, entity2id, entityType2id, relation2id, vocab2id, pred_seed_ents=None):
    queries = []
    raw_queries = []
    query_mentions = []
    memories = []
    cand_labels = [] # Candidate answer labels (i.e., names)
    gold_ans_labels = [] # True gold answer labels
    gold_ans_inds = [] # The "gold" answer indices corresponding to the cand list
    for qid, each in enumerate(qa):
        freebase_key = each['freebaseKey'] if not pred_seed_ents else pred_seed_ents[qid]
        if isinstance(freebase_key, list):
            freebase_key = freebase_key[0] if len(freebase_key) > 0 else ''
        # Convert query to query template
        query, topic_men = delex_query_topic_ent(each['qText'], freebase_key, each['entities'])
        query2 = delex_query(query, each['entities'], config.delex_mention_types)
        q = [vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for x in query2]
        queries.append(q)
        raw_queries.append(query)

        query_mentions.append([(tokenize(x[0].lower()), x[1].lower()) for x in each['entities'] if topic_men != x[0]])
        gold_ans_labels.append(each['answers'])

        if not freebase_key in kb:
            gold_ans_inds.append([])
            memories.append([[]] * 8)
            cand_labels.append([])
            continue

        ans_cands = build_ans_cands(kb[freebase_key], entity2id, entityType2id, relation2id, vocab2id)
        memories.append(ans_cands[:-1])
        cand_labels.append(ans_cands[-1])
        if len(ans_cands[0]) == 0:
            gold_ans_inds.append([])
            continue

        norm_cand_labels = [normalize_answer(x) for x in ans_cands[-1]]
        tmp_cand_inds = []
        for a in each['answers']:
            a = normalize_answer(a)
            # Find all the candidate answers that match the gold answer.
            inds = [i for i, j in zip(count(), norm_cand_labels) if j == a]
            tmp_cand_inds.extend(inds)
        # Note that tmp_cand_inds can be empty, in which case
        # the question can *NOT* be answered by this KB entity.
        gold_ans_inds.append(tmp_cand_inds)
    return (queries, raw_queries, query_mentions, memories, cand_labels, gold_ans_inds, gold_ans_labels)

def build_vocab(data, freebase, used_fbkeys=None, min_freq=1):
    entities, entity_types, relations, kb_vocabs = build_kb_data(freebase, used_fbkeys)

    # Entity
    all_entities = set({ent for ent in entities if entities[ent] >= min_freq})
    entity2id = dict(zip(all_entities, range(len(config.RESERVED_ENTS), len(all_entities) + len(config.RESERVED_ENTS))))
    for ent, idx in config.RESERVED_ENTS.items():
        entity2id.update({ent: idx})

    # Entity type
    all_ent_types = set({ent_type for ent_type in entity_types if entity_types[ent_type] >= min_freq})
    all_ent_types.update(config.extra_ent_types)
    entityType2id = dict(zip(all_ent_types, range(len(config.RESERVED_ENT_TYPES), len(all_ent_types) + len(config.RESERVED_ENT_TYPES))))
    for ent_type, idx in config.RESERVED_ENT_TYPES.items():
        entityType2id.update({ent_type: idx})

    # Relation
    all_relations = set({rel for rel in relations if relations[rel] >= min_freq})
    all_relations.update(config.extra_rels)
    relation2id = dict(zip(all_relations, range(len(config.RESERVED_RELS), len(all_relations) + len(config.RESERVED_RELS))))
    for rel, idx in config.RESERVED_RELS.items():
        relation2id.update({rel: idx})

    # Vocab
    vocabs = build_qa_vocab(data)
    for token, count in kb_vocabs.items():
        vocabs[token] += count
    # sorted_vocabs = sorted(vocabs.items(), key=lambda d:d[1], reverse=True)
    all_tokens = set({token for token in vocabs if vocabs[token] >= min_freq})
    all_tokens.update(config.extra_vocab_tokens)
    vocab2id = dict(zip(all_tokens, range(len(config.RESERVED_TOKENS), len(all_tokens) + len(config.RESERVED_TOKENS))))
    for token, idx in config.RESERVED_TOKENS.items():
        vocab2id.update({token: idx})

    print('Num of entities: %s' % len(entity2id))
    print('Num of entity_types: %s' % len(entityType2id))
    print('Num of relations: %s' % len(relation2id))
    print('Num of vocabs: %s' % len(vocab2id))
    return entity2id, entityType2id, relation2id, vocab2id

def build_ans_cands(graph, entity2id, entityType2id, relation2id, vocab2id):
    cand_ans_bows = [] # bow of answer entity
    cand_ans_entities = [] # answer entity
    cand_ans_types = [] # type of answer entity
    cand_ans_type_bows = [] # bow of answer entity type
    cand_ans_paths = [] # relation path from topic entity to answer entity
    cand_ans_path_bows = []
    cand_ans_ctx = [] # context (i.e., 1-hop entity bows and relation bows) connects to the answer path
    cand_ans_topic_key_type = [] # topic key entity type
    cand_labels = [] # candidate answers

    selected_types = (graph['notable_types'] + graph['type'])[:ENT_TYPE_HOP]
    topic_key_ent_type_bows = [vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for y in selected_types for x in y.lower().split('/')[-1].split('_')]
    topic_key_ent_type = [entityType2id[x] if x in entityType2id else config.RESERVED_ENT_TYPES['UNK'] for x in selected_types]

    # We only consider the alias relations of topic entities
    for each in graph['alias']:
        cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])
        ent_bow = [vocab2id[y] if y in vocab2id else config.RESERVED_TOKENS['UNK'] for y in tokenize(each.lower())]
        cand_ans_bows.append(ent_bow)
        cand_ans_entities.append(config.RESERVED_ENTS['PAD'])
        cand_ans_types.append([])
        cand_ans_type_bows.append([])
        cand_ans_paths.append([relation2id['alias'] if 'alias' in relation2id else config.RESERVED_RELS['UNK']])
        cand_ans_path_bows.append([vocab2id['alias']])
        # We do not count the topic_entity as context since it is trivial
        cand_ans_ctx.append([[], []])
        cand_labels.append(each)

    if len(cand_labels) == 0 and (not 'neighbors' in graph or len(graph['neighbors']) == 0):
        return ([], [], [], [], [], [], [], [], [])

    # the graph may have alias candidates but no 'neighbors' key; guard with .get
    for k, v in graph.get('neighbors', {}).items():
        if if_filterout(k):
            continue
        k_bow = [vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for x in k.lower().split('/')[-1].split('_')]
        for nbr in v:
            if isinstance(nbr, str):
                cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])
                ent_bow = [vocab2id[y] if y in vocab2id else config.RESERVED_TOKENS['UNK'] for y in tokenize(nbr.lower())]
                cand_ans_bows.append(ent_bow)
                cand_ans_entities.append(config.RESERVED_ENTS['PAD'])
                cand_ans_types.append([])
                cand_ans_type_bows.append([])
                cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK']])
                cand_ans_path_bows.append(k_bow)
                cand_ans_ctx.append([[], []])
                cand_labels.append(nbr)
                continue
            elif isinstance(nbr, bool):
                cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])
                cand_ans_bows.append([vocab2id['true' if nbr else 'false']])
                cand_ans_entities.append(config.RESERVED_ENTS['PAD'])
                cand_ans_types.append([entityType2id['bool']])
                cand_ans_type_bows.append([vocab2id['bool']])
                cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK']])
                cand_ans_path_bows.append(k_bow)
                cand_ans_ctx.append([[], []])
                cand_labels.append('true' if nbr else 'false')
                continue
            elif isinstance(nbr, float):
                cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])
                cand_ans_bows.append([vocab2id[str(nbr)] if str(nbr) in vocab2id else config.RESERVED_TOKENS['UNK']])
                cand_ans_entities.append(config.RESERVED_ENTS['PAD'])
                cand_ans_types.append([entityType2id['num']])
                cand_ans_type_bows.append([vocab2id['num']])
                cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK']])
                cand_ans_path_bows.append(k_bow)
                cand_ans_ctx.append([[], []])
                cand_labels.append(str(nbr))
                continue
            elif isinstance(nbr, dict):
                nbr_k = list(nbr.keys())[0]
                nbr_v = nbr[nbr_k]
                selected_names = (nbr_v['name'] + nbr_v['alias'])[:1]
                is_dummy = True
                if not IGNORE_DUMMY or len(selected_names) > 0: # Otherwise, it is an intermediate (dummy) node
                    cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])
                    nbr_k_bow = [vocab2id[y] if y in vocab2id else config.RESERVED_TOKENS['UNK'] for x in selected_names for y in tokenize(x.lower())]
                    cand_ans_bows.append(nbr_k_bow)
                    cand_ans_entities.append(entity2id[nbr_k] if nbr_k in entity2id else config.RESERVED_ENTS['UNK'])
                    selected_types = (nbr_v['notable_types'] + nbr_v['type'])[:ENT_TYPE_HOP]
                    cand_ans_types.append([entityType2id[x] if x in entityType2id else config.RESERVED_ENT_TYPES['UNK'] for x in selected_types])
                    cand_ans_type_bows.append([vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for y in selected_types for x in y.lower().split('/')[-1].split('_')])
                    cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK']])
                    cand_ans_path_bows.append(k_bow)
                    cand_labels.append(selected_names[0] if len(selected_names) > 0 else 'UNK')
                    is_dummy = False

                if not 'neighbors' in nbr_v:
                    if not is_dummy:
                        cand_ans_ctx.append([[], []])
                    continue

                rels = []
                labels = []
                all_ctx = [set(), set()]
                for kk, vv in nbr_v['neighbors'].items(): # 2nd hop
                    if if_filterout(kk):
                        continue
                    kk_bow = [vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for x in kk.lower().split('/')[-1].split('_')]
                    all_ctx[1].add(kk)
                    for nbr_nbr in vv:
                        if isinstance(nbr_nbr, str):
                            cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])
                            ent_bow = [vocab2id[y] if y in vocab2id else config.RESERVED_TOKENS['UNK'] for y in tokenize(nbr_nbr.lower())]
                            cand_ans_bows.append(ent_bow)
                            cand_ans_entities.append(config.RESERVED_ENTS['PAD'])
                            cand_ans_types.append([])
                            cand_ans_type_bows.append([])
                            cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK'], relation2id[kk] if kk in relation2id else config.RESERVED_RELS['UNK']])
                            cand_ans_path_bows.append(kk_bow + k_bow)
                            labels.append(nbr_nbr)
                            all_ctx[0].add(nbr_nbr)
                            rels.append(kk)
                            continue
                        elif isinstance(nbr_nbr, bool):
                            cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])
                            cand_ans_bows.append([vocab2id['true' if nbr_nbr else 'false']])
                            cand_ans_entities.append(config.RESERVED_ENTS['PAD'])
                            cand_ans_types.append([entityType2id['bool']])
                            cand_ans_type_bows.append([vocab2id['bool']])
                            cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK'], relation2id[kk] if kk in relation2id else config.RESERVED_RELS['UNK']])
                            cand_ans_path_bows.append(kk_bow + k_bow)
                            labels.append('true' if nbr_nbr else 'false')
                            all_ctx[0].add('true' if nbr_nbr else 'false')
                            rels.append(kk)
                            continue
                        elif isinstance(nbr_nbr, float):
                            cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])
                            cand_ans_bows.append([vocab2id[str(nbr_nbr)] if str(nbr_nbr) in vocab2id else config.RESERVED_TOKENS['UNK']])
                            cand_ans_entities.append(config.RESERVED_ENTS['PAD'])
                            cand_ans_types.append([entityType2id['num']])
                            cand_ans_type_bows.append([vocab2id['num']])
                            cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK'], relation2id[kk] if kk in relation2id else config.RESERVED_RELS['UNK']])
                            cand_ans_path_bows.append(kk_bow + k_bow)
                            labels.append(str(nbr_nbr))
                            all_ctx[0].add(str(nbr_nbr))
                            rels.append(kk)
                            continue
                        elif isinstance(nbr_nbr, dict):
                            nbr_nbr_k = list(nbr_nbr.keys())[0]
                            nbr_nbr_v = nbr_nbr[nbr_nbr_k]
                            selected_names = (nbr_nbr_v['name'] + nbr_nbr_v['alias'])[:1]
                            if not IGNORE_DUMMY or len(selected_names) > 0:
                                cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])
                                ent_bow = [vocab2id[y] if y in vocab2id else config.RESERVED_TOKENS['UNK'] for x in selected_names for y in tokenize(x.lower())]
                                cand_ans_bows.append(ent_bow)
                                cand_ans_entities.append(entity2id[nbr_nbr_k] if nbr_nbr_k in entity2id else config.RESERVED_ENTS['UNK'])
                                selected_types = (nbr_nbr_v['notable_types'] + nbr_nbr_v['type'])[:ENT_TYPE_HOP]
                                cand_ans_types.append([entityType2id[x] if x in entityType2id else config.RESERVED_ENT_TYPES['UNK'] for x in selected_types])
                                cand_ans_type_bows.append([vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for y in selected_types for x in y.lower().split('/')[-1].split('_')])
                                cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK'], relation2id[kk] if kk in relation2id else config.RESERVED_RELS['UNK']])
                                cand_ans_path_bows.append(kk_bow + k_bow)
                                labels.append(selected_names[0] if len(selected_names) > 0 else 'UNK')
                                if len(selected_names) > 0:
                                    all_ctx[0].add(selected_names[0])
                                rels.append(kk)
                        else:
                            raise RuntimeError('Unknown type: %s' % type(nbr_nbr))

                assert len(labels) == len(rels)
                if not is_dummy:
                    ctx_ent_bow = [tokenize(x.lower()) for x in all_ctx[0]]
                    # ctx_rel_bow = list(set([vocab2id[y] for x in all_ctx[1] for y in x.lower().split('/')[-1].split('_') if y in vocab2id]))
                    ctx_rel_bow = []
                    cand_ans_ctx.append([ctx_ent_bow, ctx_rel_bow])
                for i in range(len(labels)):
                    tmp_ent_names = all_ctx[0] - set([labels[i]])
                    # tmp_rel_names = all_ctx[1] - set([rels[i]])
                    ctx_ent_bow = [tokenize(x.lower()) for x in tmp_ent_names]
                    # ctx_rel_bow = list(set([vocab2id[y] for x in tmp_rel_names for y in x.lower().split('/')[-1].split('_') if y in vocab2id]))
                    ctx_rel_bow = []
                    cand_ans_ctx.append([ctx_ent_bow, ctx_rel_bow])
                cand_labels.extend(labels)
            else:
                raise RuntimeError('Unknown type: %s' % type(nbr))

    assert len(cand_ans_bows) == len(cand_ans_entities) == len(cand_ans_types) == len(cand_ans_type_bows) == len(cand_ans_paths) \
            == len(cand_ans_ctx) == len(cand_labels) == len(cand_ans_topic_key_type) == len(cand_ans_path_bows)
    return (cand_ans_bows, cand_ans_entities, cand_ans_type_bows, cand_ans_types, cand_ans_path_bows, cand_ans_paths, cand_ans_ctx, cand_ans_topic_key_type, cand_labels)


# Build seed entity candidates for topic entity classification
def build_seed_ent_data(qa, kb, entity2id, entityType2id, relation2id, vocab2id, topn, dtype):
    queries = []
    seed_ent_features = []
    seed_ent_labels = []
    seed_ent_inds = []
    for each in qa:
        query = tokenize(each['qText'].lower())
        q = [vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for x in query]
        queries.append(q)
        tmp_features = []
        tmp_labels = []
        tmp_inds = []
        for i, freebase_key in enumerate(each['freebaseKeyCands'][:topn]):
            tmp_labels.append(freebase_key)
            if freebase_key == each['freebaseKey']:
                tmp_inds.append(i)

            if freebase_key in kb:
                features = build_seed_entity_feature(freebase_key, kb[freebase_key], entity2id, entityType2id, relation2id, vocab2id)
                tmp_features.append(features)
            else:
                tmp_features.append([[]] * 5)

        if dtype == 'test':
            if len(tmp_inds) == 0: # No answer
                tmp_inds.append(-1)
        else:
            assert len(tmp_labels) == topn

        assert len(tmp_inds) == 1
        seed_ent_features.append(list(zip(*tmp_features)))
        seed_ent_labels.append(tmp_labels)
        seed_ent_inds.append(tmp_inds)
    return (queries, seed_ent_features, seed_ent_labels, seed_ent_inds)

def build_seed_entity_feature(seed_ent, graph, entity2id, entityType2id, relation2id, vocab2id):
    # candidate seed entity features:
    # entity name
    # entity type
    # entity neighboring relations
    selected_names = (graph['name'] + graph['alias'])[:1]
    seed_ent_name = [vocab2id[y] if y in vocab2id else config.RESERVED_TOKENS['UNK'] for x in selected_names for y in tokenize(x.lower())]
    selected_types = (graph['notable_types'] + graph['type'])[:ENT_TYPE_HOP]
    seed_ent_type_name = [vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for y in selected_types for x in y.lower().split('/')[-1].split('_')]
    seed_ent_type = [entityType2id[x] if x in entityType2id else config.RESERVED_ENT_TYPES['UNK'] for x in selected_types]
    seed_rel_names = []
    seed_rels = []

    for k in graph['neighbors']:
        if if_filterout(k):
            continue
        k_bow = [vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for x in k.lower().split('/')[-1].split('_')]
        seed_rel_names.append(k_bow)
        seed_rels.append(relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK'])
    return (seed_ent_name, seed_ent_type_name, seed_ent_type, seed_rel_names, seed_rels)
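
# A minimal usage sketch for build_seed_entity_feature (the toy graph, id maps
# and the _demo_* helper are hypothetical, not part of the pipeline; assumes
# ENT_TYPE_HOP >= 1 so at least one type survives the slice):
def _demo_seed_entity_feature():
    graph = {
        'name': ['Barack Obama'],
        'alias': [],
        'notable_types': ['/government/us_president'],
        'type': [],
        'neighbors': {'/people/person/place_of_birth': ['Honolulu']},
    }
    vocab2id = {'PAD': 0, 'UNK': 1, 'barack': 2, 'obama': 3, 'us': 4,
                'president': 5, 'place': 6, 'of': 7, 'birth': 8}
    feats = build_seed_entity_feature('barack obama', graph, {}, {}, {}, vocab2id)
    # feats == ([2, 3],       # entity name token ids ('barack', 'obama')
    #           [4, 5],       # type name token ids ('us_president' -> 'us', 'president')
    #           [1],          # entity type ids (UNK: type not in entityType2id)
    #           [[6, 7, 8]],  # one token bag per relation ('place_of_birth')
    #           [1])          # relation ids (UNK: relation not in relation2id)
    return feats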


================================================
FILE: src/core/build_data/freebase.py
================================================
'''
Created on Sep, 2017

@author: hugo

'''
import os

from ..utils.utils import *


def fetch_meta(path):
    try:
        data = load_gzip_json(path)
    except Exception: # the cached meta file may be missing or unreadable
        return {}
    content = {}
    properties = data['property']
    if '/type/object/name' in properties:
        content['name'] = [x['value'] for x in properties['/type/object/name']['values']]
    else:
        content['name'] = []
    if '/common/topic/alias' in properties:
        content['alias'] = [x['value'] for x in properties['/common/topic/alias']['values']]
    else:
        content['alias'] = []
    if '/common/topic/notable_types' in properties:
        content['notable_types'] = [x['id'] for x in properties['/common/topic/notable_types']['values']]
    else:
        content['notable_types'] = []
    if '/type/object/type' in properties:
        content['type'] = [x['id'] for x in properties['/type/object/type']['values']]
    else:
        content['type'] = []
    return content

def fetch(data, data_dir):
    if not 'id' in data:
        return data['value']
    mid = data['id']
    # Metadata might not be in the subgraph; get it from the per-entity target files
    meta = fetch_meta(os.path.join(data_dir, '{}.json.gz'.format(mid.strip('/').replace('/', '.'))))
    if meta == {}:
        if 'property' not in data:
            if 'text' in data:
                return data['text']
            else:
                raise RuntimeError('No meta data available for %s: neither property nor text field found' % mid)
        properties = data['property']
        if '/type/object/name' in properties:
            meta['name'] = [x['value'] for x in properties['/type/object/name']['values']]
        else:
            meta['name'] = []
        if '/common/topic/alias' in properties:
            meta['alias'] = [x['value'] for x in properties['/common/topic/alias']['values']]
        else:
            meta['alias'] = []
        if '/common/topic/notable_types' in properties:
            meta['notable_types'] = [x['id'] for x in properties['/common/topic/notable_types']['values']]
        else:
            meta['notable_types'] = []
        if '/type/object/type' in properties:
            meta['type'] = [x['id'] for x in properties['/type/object/type']['values']]
        else:
            meta['type'] = []
    graph = {mid: meta}
    if not 'property' in data: # we stop at the 2nd hop
        return graph
    properties = data['property']
    neighbors = {}
    for k, v in properties.items():
        if k.startswith('/common') or k.startswith('/type') \
            or k.startswith('/freebase') or k.startswith('/user') \
            or k.startswith('/imdb'):
            continue
        if len(v['values']) > 0:
            neighbors[k] = []
            for nbr in v['values']:
                nbr_graph = fetch(nbr, data_dir)
                neighbors[k].append(nbr_graph)
    graph[mid]['neighbors'] = neighbors
    return graph
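
# A minimal sketch of the topic-API payload fetch() consumes (the toy mid and
# name are hypothetical). With no cached meta file on disk, fetch_meta()
# returns {} and the metadata is rebuilt from 'property'; /common, /type,
# /freebase, /user and /imdb properties are never expanded as neighbors.
def _demo_fetch():
    data = {
        'id': '/m/0abc',
        'property': {
            '/type/object/name': {'values': [{'value': 'Example Topic'}]},
        },
    }
    return fetch(data, '/nonexistent_dir')
    # -> {'/m/0abc': {'name': ['Example Topic'], 'alias': [],
    #                 'notable_types': [], 'type': [], 'neighbors': {}}}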


================================================
FILE: src/core/build_data/utils.py
================================================
'''
Created on Sep, 2017

@author: hugo

'''
import os
import datetime
import shutil
from collections import defaultdict
import numpy as np
from scipy.sparse import *

RESERVED_TOKENS = {'PAD': 0, 'UNK': 1}


def built(path, version_string=None):
    """Checks if 'built.log' flag has been set for that task.
    If a version_string is provided, this has to match, or the version
    is regarded as not built.
    """
    if version_string:
        fname = os.path.join(path, 'built.log')
        if not os.path.isfile(fname):
            return False
        else:
            with open(fname, 'r') as read:
                text = read.read().split('\n')
            return (len(text) > 1 and text[1] == version_string)
    else:
        return os.path.isfile(os.path.join(path, 'built.log'))

def mark_done(path, version_string=None):
    """Marks the path as done by adding a 'built.log' file with the current
    timestamp plus a version description string if specified.
    """
    with open(os.path.join(path, 'built.log'), 'w') as write:
        write.write(str(datetime.datetime.today()))
        if version_string:
            write.write('\n' + version_string)

def make_dir(path):
    """Makes the directory and any nonexistent parent directories."""
    os.makedirs(path, exist_ok=True)

def remove_dir(path):
    """Removes the given directory, if it exists."""
    shutil.rmtree(path, ignore_errors=True)
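
# A round-trip sketch for the build flags above (a throwaway temp dir; the
# version string must match exactly for built() to succeed):
def _demo_build_flags():
    import tempfile
    path = tempfile.mkdtemp()
    assert not built(path)
    mark_done(path, version_string='v1')  # writes timestamp, then 'v1' on line 2
    assert built(path, version_string='v1')
    assert not built(path, version_string='v2')
    remove_dir(path)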

def vectorize_data(queries, query_mentions, memories, max_query_size=None, max_query_markup_size=None, max_mem_size=None, \
                max_ans_bow_size=None, max_ans_type_bow_size=None, max_ans_path_bow_size=None, max_ans_path_size=None, \
                max_ans_ctx_entity_bows_size=None, max_ans_ctx_relation_bows_size=1, \
                verbose=True, fixed_size=False, vocab2id=None):
    cand_ans_bows, cand_ans_entities, cand_ans_type_bows, cand_ans_types, cand_ans_path_bows, cand_ans_paths, cand_ans_ctx, cand_ans_topic_key = zip(*memories)
    cand_ans_size = min(max(map(len, (x for x in cand_ans_entities)), default=0), max_mem_size if max_mem_size else float('inf'))
    if fixed_size:
        query_size = max_query_size
        # query_markup_size = max_query_markup_size
        cand_ans_bows_size = max_ans_bow_size
        cand_ans_type_bows_size = max_ans_type_bow_size
        cand_ans_path_bows_size = max_ans_path_bow_size
        cand_ans_paths_size = max_ans_path_size
    else:
        query_size = max(min(max(map(len, queries), default=0), max_query_size if max_query_size else float('inf')), 1)
        # query_markup_size = max(min(max(map(len, query_mentions), default=0), max_query_markup_size if max_query_markup_size else float('inf')), 1)
        cand_ans_bows_size = max(min(max(map(len, (y for x in cand_ans_bows for y in x)), default=0), max_ans_bow_size if max_ans_bow_size else float('inf')), 1)
        cand_ans_type_bows_size = max(min(max(map(len, (y for x in cand_ans_type_bows for y in x)), default=0), max_ans_type_bow_size if max_ans_type_bow_size else float('inf')), 1)
        cand_ans_path_bows_size = max(min(max(map(len, (y for x in cand_ans_path_bows for y in x)), default=0), max_ans_path_bow_size if max_ans_path_bow_size else float('inf')), 1)
        cand_ans_paths_size = max(min(max(map(len, (y for x in cand_ans_paths for y in x)), default=0), max_ans_path_size if max_ans_path_size else float('inf')), 1)
    cand_ans_types_size = max(max(map(len, (y for x in cand_ans_types for y in x)), default=0), 1)
    cand_ans_ctx_entity_bows_size = max(min(max(map(len, (z for x in cand_ans_ctx for y in x for z in y[0])), default=0), max_ans_ctx_entity_bows_size if max_ans_ctx_entity_bows_size else float('inf')), 1)
    cand_ans_ctx_relation_bows_size = max(min(max(map(len, (y[1] for x in cand_ans_ctx for y in x)), default=0), max_ans_ctx_relation_bows_size if max_ans_ctx_relation_bows_size else float('inf')), 1)
    cand_ans_topic_key_ent_type_bows_size = max(max(map(len, (y[0] for x in cand_ans_topic_key for y in x)), default=0), 1)
    cand_ans_topic_key_ent_types_size = max(max(map(len, (y[1] for x in cand_ans_topic_key for y in x)), default=0), 1)

    if verbose:
        print('\nquery_size: {}, cand_ans_size: {}, cand_ans_bows_size: {}, '
            'cand_ans_type_bows_size: {}, cand_ans_types_size: {}, cand_ans_path_bows_size: {}, cand_ans_paths_size: {}, '
            'cand_ans_ctx_entity_bows_size: {}, cand_ans_topic_key_ent_types_size: {}'\
            .format(query_size, cand_ans_size, cand_ans_bows_size, cand_ans_type_bows_size, \
            cand_ans_types_size, cand_ans_path_bows_size, cand_ans_paths_size, cand_ans_ctx_entity_bows_size, \
            cand_ans_topic_key_ent_types_size))

    # Question word
    qw_tokens = ["which", "what", "who", "whose", "whom", "where", "when", "how", "why", "whether"]
    qw_vids = [vocab2id[each] for each in qw_tokens if each in vocab2id]
    qw_vid2id = dict(zip(qw_vids, range(len(qw_vids))))

    Q = []
    QW = []
    Q_len = []
    for i, q in enumerate(queries):
        Q_len.append(min(query_size, len(q)))
        lq = max(0, query_size - len(q))
        q_vec = q[-query_size:] + [0] * lq
        Q.append(q_vec)
        tmp = [qw_vid2id[each] for each in q if each in qw_vid2id]
        tmp = tmp[-query_size:] + [0] * max(0, query_size - len(tmp))
        QW.append(tmp)

    cand_ans_bows_vec = []
    for x in cand_ans_bows:
        tmp = []
        for y in x:
            l = max(0, cand_ans_bows_size - len(y))
            tmp1 = y[:cand_ans_bows_size] + [0] * l
            tmp.append(tmp1)
        tmp += [[0] * cand_ans_bows_size] # Add a dummy candidate after the true sequence
        cand_ans_bows_vec.append(tmp)

    cand_ans_entities_vec = []
    for x in cand_ans_entities:
        cand_ans_entities_vec.append(x + [0]) # Add a dummy candidate after the true sequence

    cand_ans_types_vec = []
    for x in cand_ans_types:
        tmp = []
        for y in x:
            l = max(0, cand_ans_types_size - len(y))
            tmp1 = y[:cand_ans_types_size] + [0] * l
            tmp.append(tmp1)
        tmp += [[0] * cand_ans_types_size] # Add a dummy candidate after the true sequence
        cand_ans_types_vec.append(tmp)

    cand_ans_type_bows_vec = []
    cand_ans_type_bows_len = []
    for x in cand_ans_type_bows:
        tmp = []
        tmp_len = []
        for y in x:
            l = max(0, cand_ans_type_bows_size - len(y))
            tmp1 = y[:cand_ans_type_bows_size] + [0] * l
            tmp.append(tmp1)
            tmp_len.append(max(min(cand_ans_type_bows_size, len(y)), 1))
        tmp += [[0] * cand_ans_type_bows_size] # Add a dummy candidate after the true sequence
        tmp_len += [1]
        cand_ans_type_bows_vec.append(tmp)
        cand_ans_type_bows_len.append(tmp_len)

    cand_ans_paths_vec = []
    for x in cand_ans_paths:
        tmp = []
        for y in x:
            l = max(0, cand_ans_paths_size - len(y))
            tmp1 = y[:cand_ans_paths_size] + [0] * l
            tmp.append(tmp1)
        tmp += [[0] * cand_ans_paths_size] # Add a dummy candidate after the true sequence
        cand_ans_paths_vec.append(tmp)

    cand_ans_path_bows_vec = []
    cand_ans_path_bows_len = []
    for x in cand_ans_path_bows:
        tmp = []
        tmp_len = []
        for y in x:
            l = max(0, cand_ans_path_bows_size - len(y))
            tmp1 = y[:cand_ans_path_bows_size] + [0] * l
            tmp.append(tmp1)
            tmp_len.append(max(min(cand_ans_path_bows_size, len(y)), 1))
        tmp += [[0] * cand_ans_path_bows_size] # Add a dummy candidate after the true sequence
        tmp_len += [1]
        cand_ans_path_bows_vec.append(tmp)
        cand_ans_path_bows_len.append(tmp_len)

    cand_ans_ctx_entity_vec = []
    cand_ans_ctx_relation_vec = []
    for x in cand_ans_ctx:
        tmp_ent = []
        tmp_rel = []
        for y in x:
            tmp_ent.append(y[0]) # y[0] is a list of lists
            l_rel = max(0, cand_ans_ctx_relation_bows_size - len(y[1]))
            tmp_rel.append(y[1][:cand_ans_ctx_relation_bows_size] + [0] * l_rel)
        tmp_ent += [[]] # Add a dummy candidate after the true sequence
        tmp_rel += [[0] * cand_ans_ctx_relation_bows_size]
        cand_ans_ctx_entity_vec.append(tmp_ent)
        cand_ans_ctx_relation_vec.append(tmp_rel)

    cand_ans_topic_key_ent_type_bows_vec = []
    cand_ans_topic_key_ent_type_vec = []
    cand_ans_topic_key_ent_type_bows_len = []
    for x in cand_ans_topic_key:
        tmp_ent_type_bows = []
        tmp_ent_type = []
        tmp_ent_type_bow_len = []
        for y in x:
            tmp_ent_type_bows.append(y[0][:cand_ans_topic_key_ent_type_bows_size] + [0] * max(0, cand_ans_topic_key_ent_type_bows_size - len(y[0])))
            tmp_ent_type.append(y[1][:cand_ans_topic_key_ent_types_size] + [0] * max(0, cand_ans_topic_key_ent_types_size - len(y[1])))
            tmp_ent_type_bow_len.append(max(min(cand_ans_topic_key_ent_type_bows_size, len(y[0])), 1))
        tmp_ent_type_bows += [[0] * cand_ans_topic_key_ent_type_bows_size] # Add a dummy candidate after the true sequence
        tmp_ent_type += [[0] * cand_ans_topic_key_ent_types_size]
        tmp_ent_type_bow_len += [1]
        cand_ans_topic_key_ent_type_bows_vec.append(tmp_ent_type_bows)
        cand_ans_topic_key_ent_type_vec.append(tmp_ent_type)
        cand_ans_topic_key_ent_type_bows_len.append(tmp_ent_type_bow_len)
    return Q, QW, Q_len, list(zip(cand_ans_bows_vec, cand_ans_entities_vec, cand_ans_type_bows_vec, cand_ans_types_vec, cand_ans_type_bows_len, cand_ans_path_bows_vec, cand_ans_paths_vec, cand_ans_path_bows_len, cand_ans_ctx_entity_vec, cand_ans_ctx_relation_vec, cand_ans_topic_key_ent_type_bows_vec, cand_ans_topic_key_ent_type_vec, cand_ans_topic_key_ent_type_bows_len))
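
# The truncate-then-right-pad pattern repeated above, isolated as a sketch
# (pad_seq itself is hypothetical, not used by this module). Note that queries
# keep the last tokens (q[-size:]) while candidates keep the first (y[:size]).
def pad_seq(seq, size, pad=0):
    """Truncate seq to size items, then right-pad with pad up to size."""
    return seq[:size] + [pad] * max(0, size - len(seq))
# e.g. pad_seq([3, 7, 2], 5) -> [3, 7, 2, 0, 0]; pad_seq([3, 7, 2], 2) -> [3, 7]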


def vectorize_ent_data(queries, ent_memories, max_query_size=None, \
                max_seed_ent_name_size=None, max_seed_type_name_size=None, \
                max_seed_rel_name_size=None, max_seed_rel_size=None, verbose=True):
    seed_ent_name, seed_ent_type_name, seed_ent_type, seed_rel_names, seed_rels = zip(*ent_memories)

    max_query_size = max(min(max(map(len, queries), default=0), max_query_size if max_query_size else float('inf')), 1)
    cand_seed_ent_name_size = max(min(max(map(len, (y for x in seed_ent_name for y in x)), default=0), max_seed_ent_name_size if max_seed_ent_name_size else float('inf')), 1)
    cand_seed_type_name_size = max(min(max(map(len, (y for x in seed_ent_type_name for y in x)), default=0), max_seed_type_name_size if max_seed_type_name_size else float('inf')), 1)
    cand_seed_types_size = max(max(map(len, (y for x in seed_ent_type for y in x)), default=0), 1)
    cand_seed_rel_name_size = max(min(max(map(len, (z for x in seed_rel_names for y in x for z in y)), default=0), max_seed_rel_name_size if max_seed_rel_name_size else float('inf')), 1)
    cand_seed_rel_size = max(min(max(map(len, (y for x in seed_rels for y in x)), default=0), max_seed_rel_size if max_seed_rel_size else float('inf')), 1)


    if verbose:
        print('\nmax_query_size: {}, cand_seed_ent_name_size: {}, cand_seed_type_name_size: {}, '
            'cand_seed_types_size: {}, cand_seed_rel_name_size: {}, cand_seed_rel_size: {}'.format(max_query_size, \
                cand_seed_ent_name_size, cand_seed_type_name_size, cand_seed_types_size, \
                cand_seed_rel_name_size, cand_seed_rel_size))


    # Query vectorization
    Q = []
    Q_len = []
    for q in queries:
        Q_len.append(min(max_query_size, len(q)))
        lq = max(0, max_query_size - len(q))
        q_vec = q[-max_query_size:] + [0] * lq
        Q.append(q_vec)


    # Entity vectorization
    cand_seed_ent_name_vec = []
    cand_seed_ent_name_len = []
    for x in seed_ent_name:
        tmp = []
        tmp_len = []
        for y in x:
            l = max(0, cand_seed_ent_name_size - len(y))
            tmp1 = y[:cand_seed_ent_name_size] + [0] * l
            tmp.append(tmp1)
            tmp_len.append(max(min(cand_seed_ent_name_size, len(y)), 1))
        cand_seed_ent_name_vec.append(tmp)
        cand_seed_ent_name_len.append(tmp_len)

    cand_seed_type_vec = []
    for x in seed_ent_type:
        tmp = []
        for y in x:
            l = max(0, cand_seed_types_size - len(y))
            tmp1 = y[:cand_seed_types_size] + [0] * l
            tmp.append(tmp1)
        cand_seed_type_vec.append(tmp)

    cand_seed_type_name_vec = []
    cand_seed_type_name_len = []
    for x in seed_ent_type_name:
        tmp = []
        tmp_len = []
        for y in x:
            l = max(0, cand_seed_type_name_size - len(y))
            tmp1 = y[:cand_seed_type_name_size] + [0] * l
            tmp.append(tmp1)
            tmp_len.append(max(min(cand_seed_type_name_size, len(y)), 1))
        cand_seed_type_name_vec.append(tmp)
        cand_seed_type_name_len.append(tmp_len)


    cand_seed_rel_vec = []
    cand_seed_rel_mask = []
    for x in seed_rels: # example
        x_tmp = []
        x_mask = []
        for y in x: # seed entity
            l = max(0, cand_seed_rel_size - len(y))
            y_tmp = y[:cand_seed_rel_size] + [0] * l
            x_tmp.append(y_tmp)
            x_mask.append(min(len(y), cand_seed_rel_size))
        cand_seed_rel_vec.append(x_tmp)
        cand_seed_rel_mask.append(x_mask)


    cand_seed_rel_name_vec = []
    cand_seed_rel_name_len = []
    for x in seed_rel_names: # example
        x_tmp = []
        x_tmp_len = []
        for y in x: # seed entity
            y_tmp = []
            y_tmp_len = []
            for z in y: # relation
                z_l = max(0, cand_seed_rel_name_size - len(z))
                z_tmp = z[:cand_seed_rel_name_size] + [0] * z_l
                y_tmp.append(z_tmp)
                y_tmp_len.append(max(min(cand_seed_rel_name_size, len(z)), 1))
            y_l = max(0, cand_seed_rel_size - len(y))
            y_tmp += [[0] * cand_seed_rel_name_size] * y_l
            y_tmp_len += [1] * y_l
            x_tmp.append(y_tmp)
            x_tmp_len.append(y_tmp_len)
        cand_seed_rel_name_vec.append(x_tmp)
        cand_seed_rel_name_len.append(x_tmp_len)
    return Q, Q_len, list(zip(cand_seed_ent_name_vec, cand_seed_ent_name_len, cand_seed_type_name_vec, cand_seed_type_vec, cand_seed_type_name_len, cand_seed_rel_name_vec, cand_seed_rel_vec, cand_seed_rel_name_len, cand_seed_rel_mask))


================================================
FILE: src/core/build_data/webquestions.py
================================================
'''
Created on Sep, 2017

@author: hugo

'''
import os
# import re
import argparse
from nltk.parse.stanford import StanfordDependencyParser

from ..utils.utils import *
from ..utils.freebase_utils import if_filterout
from ..utils.generic_utils import *


def get_used_fbkeys(data_dir, out_dir):
    # Fetch freebase keys used in training and validation sets.
    fbkeys = set()
    split = ['factoid_webqa/train.json', 'factoid_webqa/valid.json']
    files = [os.path.join(data_dir, x) for x in split]
    for f in files:
        data = load_json(f)
        for qa in data:
            fbkeys.add(qa['freebaseKey'])
    dump_json(list(fbkeys), os.path.join(out_dir, 'fbkeys_train_valid.json'), indent=1)

def get_all_fbkeys(data_dir, out_dir):
    # Fetch all freebase keys possibly useful for answering questions.
    fbkeys = set()
    split = ['factoid_webqa/train.json', 'factoid_webqa/valid.json', 'factoid_webqa/test.json']
    files = [os.path.join(data_dir, x) for x in split]
    for f in files:
        data = load_json(f)
        for qa in data:
            fbkeys.add(qa['freebaseKey'])

    retrieved_test_path = os.path.join(data_dir, 'factoid_webqa/webquestions.examples.test.retrieved.json')
    if os.path.exists(retrieved_test_path):
        data = load_json(retrieved_test_path)
        for qa in data:
            if not 'retrievedList' in qa:
                continue
            for x in qa['retrievedList'].split():
                fbkeys.add(x.split(':')[0])
    dump_json(list(fbkeys), os.path.join(out_dir, 'fbkeys_train_valid_test_retrieved.json'), indent=1)

def main(fb_path, mid2key_path, data_dir, out_dir):
    HAS_DEP = False
    if HAS_DEP:
        dep_parser = StanfordDependencyParser(model_path="edu/stanford/nlp/models/lexparser/englishPCFG.ser.gz") # Set CLASSPATH and STANFORD_MODELS environment variables beforehand
    kb = load_ndjson(fb_path, return_type='dict')
    mid2key = load_json(mid2key_path)
    all_split_questions = []
    split = ['factoid_webqa/train.json', 'factoid_webqa/valid.json', 'factoid_webqa/test.json']
    files = [os.path.join(data_dir, x) for x in split]
    missing_mid2key = []

    for f in files:
        data_type = os.path.basename(f).split('.')[0]
        num_unanswerable = 0
        all_questions = []
        data = load_json(f)
        for q in data:
            questions = {}
            questions['answers'] = q['answers']
            questions['entities'] = q['entities']
            questions['qText'] = q['qText']
            questions['qId'] = q['qId']
            questions['freebaseKey'] = q['freebaseKey']
            questions['freebaseKeyCands'] = [q['freebaseKey']]
            for x in q['freebaseMids']:
                if x['mid'] in mid2key:
                    fbkey = mid2key[x['mid']]
                    if fbkey != q['freebaseKey']:
                        questions['freebaseKeyCands'].append(fbkey)
                else:
                    missing_mid2key.append(x['mid'])

            qtext = tokenize(q['qText'])
            if HAS_DEP:
                qw = list(set(qtext).intersection(question_word_list))
                question_word = qw[0] if len(qw) > 0 else ''
                topic_ent = q['freebaseKey']
                dep_path = extract_dep_feature(dep_parser, ' '.join(qtext), topic_ent, question_word)
            else:
                dep_path = []
            questions['dep_path'] = dep_path
            all_questions.append(questions)

            if not q['freebaseKey'] in kb:
                num_unanswerable += 1
                continue
            cand_ans = fetch_ans_cands(kb[q['freebaseKey']])
            norm_cand_ans = set([normalize_answer(x) for x in cand_ans])
            norm_gold_ans = [normalize_answer(x) for x in q['answers']]
            # Check if the gold answer can be found among the candidate answers.
            if len(norm_cand_ans.intersection(norm_gold_ans)) == 0:
                num_unanswerable += 1
                continue
        all_split_questions.append(all_questions)
        print('{} set: Num of unanswerable questions: {}'.format(data_type, num_unanswerable))

    for i, each in enumerate(all_split_questions):
        dump_ndjson(each, os.path.join(out_dir, split[i].split('/')[-1]))

def fetch_ans_cands(graph):
    cand_ans = set() # candidate answers
    # We only consider the alias relations of topic entities
    cand_ans.update(graph['alias'])
    for k, v in graph['neighbors'].items():
        if if_filterout(k):
            continue
        for nbr in v:
            if isinstance(nbr, str):
                cand_ans.add(nbr)
                continue
            elif isinstance(nbr, bool):
                cand_ans.add('true' if nbr else 'false')
                continue
            elif isinstance(nbr, float):
                cand_ans.add(str(nbr))
                continue
            elif isinstance(nbr, dict):
                nbr_k = list(nbr.keys())[0]
                nbr_v = nbr[nbr_k]
                selected_names = nbr_v['name'] if 'name' in nbr_v and len(nbr_v['name']) > 0 else (nbr_v['alias'][:1] if 'alias' in nbr_v else [])
                cand_ans.add(selected_names[0] if len(selected_names) > 0 else 'UNK')
                if not 'neighbors' in nbr_v:
                    continue
                for kk, vv in nbr_v['neighbors'].items(): # 2nd hop
                    if if_filterout(kk):
                        continue
                    for nbr_nbr in vv:
                        if isinstance(nbr_nbr, str):
                            cand_ans.add(nbr_nbr)
                            continue
                        elif isinstance(nbr_nbr, bool):
                            cand_ans.add('true' if nbr_nbr else 'false')
                            continue
                        elif isinstance(nbr_nbr, float):
                            cand_ans.add(str(nbr_nbr))
                            continue
                        elif isinstance(nbr_nbr, dict):
                            nbr_nbr_k = list(nbr_nbr.keys())[0]
                            nbr_nbr_v = nbr_nbr[nbr_nbr_k]
                            selected_names = nbr_nbr_v['name'] if 'name' in nbr_nbr_v and len(nbr_nbr_v['name']) > 0 else (nbr_nbr_v['alias'][:1] if 'alias' in nbr_nbr_v else [])
                            cand_ans.add(selected_names[0] if len(selected_names) > 0 else 'UNK')
                        else:
                            raise RuntimeError('Unknown type: %s' % type(nbr_nbr))
            else:
                raise RuntimeError('Unknown type: %s' % type(nbr))
    return list(cand_ans)
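
# A minimal sketch of fetch_ans_cands on a toy 2-hop subgraph (the entities
# and relations are hypothetical; if_filterout() drops none of them):
def _demo_fetch_ans_cands():
    graph = {
        'alias': ['honolulu'],
        'neighbors': {
            '/location/location/containedby': [
                {'/m/0x1': {'name': ['Hawaii'], 'alias': [], 'neighbors': {
                    '/location/location/time_zones': ['Hawaii-Aleutian']}}},
            ],
        },
    }
    return fetch_ans_cands(graph)
    # -> the aliases plus 1st- and 2nd-hop neighbor names, i.e. (in some order)
    #    ['honolulu', 'Hawaii', 'Hawaii-Aleutian']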


================================================
FILE: src/core/config.py
================================================

# Vocabulary
RESERVED_TOKENS = {'PAD': 0, 'UNK': 1}
RESERVED_ENTS = {'PAD': 0, 'UNK': 1}
RESERVED_ENT_TYPES = {'PAD': 0, 'UNK': 1}
RESERVED_RELS = {'PAD': 0, 'UNK': 1}

extra_vocab_tokens = ['alias', 'true', 'false', 'num', 'bool'] + \
    ['np', 'organization', 'date', 'number', 'misc', 'ordinal', 'duration', 'person', 'time', 'location'] + \
    ['__np__', '__organization__', '__date__', '__number__', '__misc__', '__ordinal__', '__duration__', '__person__', '__time__', '__location__']

extra_rels = ['alias']
extra_ent_types = ['num', 'bool']


# BAMnet entity mention types
topic_mention_types = {'person', 'organization', 'location', 'misc'}
# delex_mention_types = {'date', 'time', 'ordinal', 'number'}
delex_mention_types = {'date', 'ordinal', 'number'}
constraint_mention_types = delex_mention_types


================================================
FILE: src/core/utils/__init__.py
================================================
'''
Created on Oct, 2017

@author: hugo

'''


================================================
FILE: src/core/utils/freebase_utils.py
================================================
'''
Created on Oct, 2017

@author: hugo

'''
from rapidfuzz import fuzz, process


def if_filterout(s):
    if s.endswith('has_sentences') or \
        s.endswith('exceptions') or s.endswith('sww_base/source') or \
        s.endswith('kwtopic/assessment'):
        return True
    else:
        return False

def query_kb(kb, ent_name, fuzz_threshold=90):
    results = []
    for k, v in kb.items():
        ret = process.extractOne(ent_name, v['name'] + v['alias'], scorer=fuzz.token_sort_ratio)
        if ret is not None and ret[1] > fuzz_threshold: # extractOne returns None when there are no choices
            results.append((k, ret[0], ret[1]))
    results = sorted(results, key=lambda d:d[-1], reverse=True)
    return list(zip(*results))[0] if len(results) > 0 else []
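
# Usage sketch for query_kb (a toy kb; rapidfuzz's process.extractOne yields
# (match, score, index) tuples, so ret[1] is the fuzzy-match score):
def _demo_query_kb():
    kb = {
        '/m/02mjmr': {'name': ['Barack Obama'], 'alias': ['Obama']},
        '/m/0d3k14': {'name': ['John F. Kennedy'], 'alias': ['JFK']},
    }
    return query_kb(kb, 'Barack Obama')
    # -> ('/m/02mjmr',): the matching keys, best score first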


================================================
FILE: src/core/utils/generic_utils.py
================================================
'''
Created on Oct, 2017

@author: hugo

'''
import re, string
import numpy as np
from rapidfuzz import fuzz, process
from nltk.corpus import stopwords

from .utils import dump_ndarray, tokenize


question_word_list = 'who, when, what, where, how, which, why, whom, whose'.split(', ')
stop_words = set(stopwords.words("english"))

def find_parent(x, tree, conn='<-'):
    root = tree[0][0]
    path = []
    for parent, indicator, child in tree:
        if x == child[0]:
            path.extend([conn, '__{}__'.format(indicator), '-', parent[0]])
            if not parent == root:
                p = find_parent(parent[0], tree, conn)
                path.extend(p)
            return path
    return path
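
# Worked sketch for find_parent on a hand-built tree (triples follow nltk's
# ((head, tag), relation, (dependent, tag)) layout; the sentence is toy):
def _demo_find_parent():
    tree = [(('born', 'VBN'), 'nsubjpass', ('obama', 'NNP')),
            (('born', 'VBN'), 'advmod', ('where', 'WRB'))]
    return find_parent('where', tree)
    # -> ['<-', '__advmod__', '-', 'born']: 'where' hangs off the root 'born',
    #    so the walk up the tree stops after one hop.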

def extract_dep_feature(dep_parser, text, topic_ent, question_word):
    dep = dep_parser.raw_parse(text).__next__()
    tree = list(dep.triples())
    topic_ent = list(set(tokenize(topic_ent)) - stop_words)
    text = text.split()

    path_len = 1e5
    topic_ent_to_root = []
    for each in topic_ent:
        ret = process.extractOne(each, text, scorer=fuzz.token_sort_ratio)
        if ret is None or ret[1] < 85:
            continue
        tmp = find_parent(ret[0], tree, '->')
        if len(tmp) > 0 and len(tmp) < path_len:
            topic_ent_to_root = tmp
            path_len = len(tmp)
    question_word_to_root = find_parent(question_word, tree)
    # if len(question_word_to_root) == 0 or len(topic_ent_to_root) == 0:
        # import pdb;pdb.set_trace()
    return question_word_to_root + list(reversed(topic_ent_to_root[:-1]))

def unique(seq):
    seen = set()
    seen_add = seen.add
    return [x for x in seq if not (x in seen or seen_add(x))]
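# e.g. unique([3, 1, 3, 2, 1]) -> [3, 1, 2] (order-preserving de-duplication)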

re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[%s]' % re.escape(string.punctuation))

def normalize_answer(s):
    """Lower text and remove extra whitespace."""
    def remove_articles(text):
        return re_art.sub(' ', text)

    def remove_punc(text):
        return re_punc.sub(' ', text)  # convert punctuation to spaces

    def white_space_fix(text):
        return ' '.join(text.split())

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
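
# e.g. normalize_answer('The Beatles!') -> 'beatles'
# (lowercased; punctuation and articles become spaces; whitespace collapsed)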

def dump_embeddings(vocab_dict, emb_file, out_path, emb_size=300, binary=False, seed=123):
    vocab_emb = get_embeddings(emb_file, vocab_dict, binary)

    vocab_size = len(vocab_dict)
    np.random.seed(seed)
    embeddings = np.random.uniform(-0.08, 0.08, (vocab_size, emb_size))
    for w, idx in vocab_dict.items():
        if w in vocab_emb:
            embeddings[int(idx)] = vocab_emb[w]
    embeddings[0] = 0
    dump_ndarray(embeddings, out_path)
    return embeddings

def get_embeddings(emb_file, vocab, binary=False):
    pt = PreTrainEmbedding(emb_file, binary)
    vocab_embs = {}

    i = 0
    for each in vocab:
        emb = pt.get_embeddings(each)
        if emb is not None:
            vocab_embs[each] = emb
            i += 1
    print('get_wordemb hit ratio: %s' % (i / len(vocab)))
    return vocab_embs

class PreTrainEmbedding():
    def __init__(self, file, binary=False):
        import gensim
        self.model = gensim.models.KeyedVectors.load_word2vec_format(file, binary=binary)

    def get_embeddings(self, word):
        word_list = [word, word.upper(), word.lower(), word.title(), string.capwords(word, '_')]

        for w in word_list:
            try:
                return self.model[w]
            except KeyError:
                # print('Can not get embedding for ', w)
                continue
        return None


================================================
FILE: src/core/utils/metrics.py
================================================
'''
Created on Oct, 2017

@author: hugo

Note: Modified the official evaluation script provided by Berant et al.
(https://github.com/percyliang/sempre/blob/master/scripts/evaluation.py)
'''
from .generic_utils import normalize_answer


def calc_f1(gold_list, pred_list):
    """Return a tuple with recall, precision, and f1 for one example"""

    # Assume all questions have at least one answer
    if len(gold_list) == 0:
        raise RuntimeError('Gold list may not be empty')
    # If we return an empty list recall is zero and precision is one
    if len(pred_list) == 0:
        return (0, 1, 0)
    # It is guaranteed now that both lists are not empty

    # Normalize answers
    gold_list = [normalize_answer(s) for s in gold_list]
    pred_list = [normalize_answer(s) for s in pred_list]

    precision = 0
    for entity in pred_list:
        if entity in gold_list:
            precision += 1
    precision = float(precision) / len(pred_list)

    recall = 0
    for entity in gold_list:
        if entity in pred_list:
            recall += 1
    recall = float(recall) / len(gold_list)

    f1 = 0
    if precision + recall > 0:
        f1 = 2 * recall * precision / (precision + recall)
    return (recall, precision, f1)
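
# Worked example: calc_f1(['Paris'], ['Paris', 'Lyon'])
#   -> recall 1/1 = 1.0, precision 1/2 = 0.5, f1 = 2*1.0*0.5/1.5 ~= 0.667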

def calc_avg_f1(gold_list, pred_list, verbose=True):
    """Go over all examples and compute recall, precision and F1"""
    avg_recall = 0
    avg_precision = 0
    avg_f1 = 0
    count = 0

    out_f = open('error_analysis.txt', 'w')
    assert len(gold_list) == len(pred_list)
    for i, gold in enumerate(gold_list):
        recall, precision, f1 = calc_f1(gold, pred_list[i])
        avg_recall += recall
        avg_precision += precision
        avg_f1 += f1
        count += 1
        # Log every example; tighten to e.g. `if f1 < 0.6:` to keep only errors
        if True:
            out_f.write('{}\t{}\t{}\t{}\n'.format(i, gold, pred_list[i], f1))
    out_f.close()

    avg_recall = float(avg_recall) / count
    avg_precision = float(avg_precision) / count
    avg_f1 = float(avg_f1) / count
    avg_new_f1 = 0
    if avg_precision + avg_recall > 0:
        avg_new_f1 = 2 * avg_recall * avg_precision / (avg_precision + avg_recall)

    if verbose:
        print("Number of questions: " + str(count))
        print("Average recall over questions: " + str(avg_recall))
        print("Average precision over questions: " + str(avg_precision))
        print("Average f1 over questions: " + str(avg_f1))
        # print("F1 of average recall and average precision: " + str(avg_new_f1))
    return count, avg_recall, avg_precision, avg_f1


================================================
FILE: src/core/utils/utils.py
================================================
'''
Created on Sep, 2017

@author: hugo

'''
import os
import re
import yaml
import gzip
import json
import string
import numpy as np
from nltk.tokenize import wordpunct_tokenize#, word_tokenize


# tokenize = lambda s: word_tokenize(re.sub(r'[%s]' % punc_wo_dot, ' ', re.sub(r'(?<!\d)[%s](?!\d)' % string.punctuation, ' ', s)))
tokenize = lambda s: wordpunct_tokenize(re.sub('[%s]' % re.escape(string.punctuation), ' ', s))
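# e.g. tokenize("Where's Obama from?") -> ['Where', 's', 'Obama', 'from']
# (punctuation is replaced by spaces before wordpunct_tokenize splits)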

def get_config(config_path="config.yml"):
    with open(config_path, "r") as setting:
        config = yaml.load(setting, Loader=yaml.FullLoader) # newer PyYAML requires an explicit Loader
    return config

def print_config(config):
    print("**************** MODEL CONFIGURATION ****************")
    for key in sorted(config.keys()):
        val = config[key]
        keystr = "{}".format(key) + (" " * (24 - len(key)))
        print("{} -->   {}".format(keystr, val))
    print("**************** MODEL CONFIGURATION ****************")

def read_lines(path_to_file):
    data = []
    try:
        with open(path_to_file, 'r') as f:
            for line in f:
                tmp = [float(x) for x in line.strip().split()]
                data.append(tmp)
    except Exception as e:
        raise e

    return data

def dump_ndarray(data, path_to_file):
    try:
        with open(path_to_file, 'wb') as f:
            np.save(f, data)
    except Exception as e:
        raise e

def load_ndarray(path_to_file):
    try:
        with open(path_to_file, 'rb') as f:
            data = np.load(f)
    except Exception as e:
        raise e

    return data

def dump_ndjson(data, file):
    try:
        with open(file, 'w') as f:
            for each in data:
                f.write(json.dumps(each) + '\n')
    except Exception as e:
        raise e

def load_ndjson(file, return_type='array'):
    if return_type == 'array':
        return load_ndjson_to_array(file)
    elif return_type == 'dict':
        return load_ndjson_to_dict(file)
    else:
        raise RuntimeError('Unknown return_type: %s' % return_type)

def load_ndjson_to_array(file):
    data = []
    try:
        with open(file, 'r') as f:
            for line in f:
                data.append(json.loads(line.strip()))
    except Exception as e:
        raise e
    return data

def load_ndjson_to_dict(file):
    data = {}
    try:
        with open(file, 'r') as f:
            for line in f:
                data.update(json.loads(line.strip()))
    except Exception as e:
        raise e
    return data

def dump_json(data, file, indent=None):
    try:
        with open(file, 'w') as f:
            json.dump(data, f, indent=indent)
    except Exception as e:
        raise e

def load_json(file):
    try:
        with open(file, 'r') as f:
            data = json.load(f)
    except Exception as e:
        raise e
    return data

def dump_dict_ndjson(data, file):
    try:
        with open(file, 'w') as f:
            for k, v in data.items():
                line = json.dumps([k, v]) + '\n'
                f.write(line)
    except Exception as e:
        raise e

def load_gzip_json(file):
    try:
        with gzip.open(file, 'r') as f:
            data = json.load(f)
    except Exception as e:
        raise e
    return data

def get_all_files(dir, recursive=False):
    if recursive:
        return [os.path.join(root, file) for root, dirnames, filenames in os.walk(dir) for file in filenames if os.path.isfile(os.path.join(root, file)) and not file.startswith('.')]
    else:
        return [os.path.join(dir, filename) for filename in os.listdir(dir) if os.path.isfile(os.path.join(dir, filename)) and not filename.startswith('.')]

# Print iterations progress
def printProgressBar(iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '█'):
    """
    Call in a loop to create terminal progress bar
    @params:
        iteration   - Required  : current iteration (Int)
        total       - Required  : total iterations (Int)
        prefix      - Optional  : prefix string (Str)
        suffix      - Optional  : suffix string (Str)
        decimals    - Optional  : positive number of decimals in percent complete (Int)
        length      - Optional  : character length of bar (Int)
        fill        - Optional  : bar fill character (Str)
    """
    percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
    filledLength = int(length * iteration // total)
    bar = fill * filledLength + '-' * (length - filledLength)
    print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end = '\r')
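
# Usage sketch (a toy loop; the sleep merely makes the bar visible):
def _demo_progress_bar():
    import time
    total = 20
    for i in range(total):
        time.sleep(0.05)
        printProgressBar(i + 1, total, prefix='Progress:', suffix='Complete', length=50)
    print()  # move past the carriage-returned bar line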


================================================
FILE: src/joint_test.py
================================================
import timeit
import argparse
import numpy as np

from core.bamnet.entnet import EntnetAgent
from core.bamnet.bamnet import BAMnetAgent
from core.build_data.build_all import build
from core.build_data.utils import vectorize_ent_data, vectorize_data
from core.build_data.build_data import build_data
from core.utils.generic_utils import unique
from core.utils.utils import *
from core.utils.metrics import *


def dynamic_pred(pred, margin):
    predictions = []
    for i in range(len(pred)):
        predictions.append(unique([x[0] for x in pred[i] if x[1] + margin >= pred[i][0][1]]))
    return predictions
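
# Worked sketch: each pred[i] is assumed to hold (label, score) pairs sorted by
# score descending; every label within `margin` of the top score is kept.
#   pred = [[('Hawaii', 5.0), ('Kenya', 3.5), ('Kansas', 1.0)]]
#   dynamic_pred(pred, margin=2) -> [['Hawaii', 'Kenya']]
#   (3.5 + 2 >= 5.0 keeps 'Kenya'; 1.0 + 2 < 5.0 drops 'Kansas')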

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-bamnet_config', '--bamnet_config', required=True, type=str, help='path to the config file')
    parser.add_argument('-entnet_config', '--entnet_config', required=True, type=str, help='path to the config file')
    parser.add_argument('-raw_data', '--raw_data_dir', required=True, type=str, help='raw data dir')
    cfg = vars(parser.parse_args())
    bamnet_opt = get_config(cfg['bamnet_config'])
    entnet_opt = get_config(cfg['entnet_config'])

    start = timeit.default_timer()
    # Entnet
    # Ensure data is built
    build(entnet_opt['data_dir'])
    data_vec = load_json(os.path.join(entnet_opt['data_dir'], entnet_opt['test_data']))

    queries, memories, ent_labels, ent_inds = data_vec
    queries, query_lengths, memories = vectorize_ent_data(queries, \
                                        memories, max_query_size=entnet_opt['query_size'], \
                                        max_seed_ent_name_size=entnet_opt['max_seed_ent_name_size'], \
                                        max_seed_type_name_size=entnet_opt['max_seed_type_name_size'], \
                                        max_seed_rel_name_size=entnet_opt['max_seed_rel_name_size'], \
                                        max_seed_rel_size=entnet_opt['max_seed_rel_size'])

    ent_model = EntnetAgent(entnet_opt)
    acc = ent_model.evaluate([memories, queries, query_lengths], ent_inds, batch_size=entnet_opt['test_batch_size'])
    print('acc: {}'.format(acc))
    pred_seed_ents = ent_model.predict([memories, queries, query_lengths], ent_labels, batch_size=entnet_opt['test_batch_size'])


    # BAMnet
    # Ensure data is built
    build(bamnet_opt['data_dir'])
    entity2id = load_json(os.path.join(bamnet_opt['data_dir'], 'entity2id.json'))
    entityType2id = load_json(os.path.join(bamnet_opt['data_dir'], 'entityType2id.json'))
    relation2id = load_json(os.path.join(bamnet_opt['data_dir'], 'relation2id.json'))
    vocab2id = load_json(os.path.join(bamnet_opt['data_dir'], 'vocab2id.json'))
    ctx_stopwords = {'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', "you're", "you've", "you'll", "you'd", 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', "she's", 'her', 'hers', 'herself', 'it', "it's", 'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', "that'll", 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 'will', 'just', 'don', "don't", 'should', "should've", 'now', 'd', 'll', 'm', 'o', 're', 've', 'y', 'ain', 'aren', "aren't", 'couldn', "couldn't", 'didn', "didn't", 'doesn', "doesn't", 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'isn', "isn't", 'ma', 'mightn', "mightn't", 'mustn', "mustn't", 'needn', "needn't", 'shan', "shan't", 'shouldn', "shouldn't", 'wasn', "wasn't", 'weren', "weren't", 'won', "won't", 'wouldn', "wouldn't"}

    # Build data in real time
    freebase = load_ndjson(os.path.join(cfg['raw_data_dir'], 'freebase_full.json'), return_type='dict')
    test_data = load_ndjson(os.path.join(cfg['raw_data_dir'], 'raw_test.json'))
    data_vec = build_data(test_data, freebase, entity2id, entityType2id, relation2id, vocab2id, pred_seed_ents=pred_seed_ents)

    queries, raw_queries, query_mentions, memories, cand_labels, _, gold_ans_labels = data_vec
    queries, query_words, query_lengths, memories_vec = vectorize_data(queries, query_mentions, memories, \
                                        max_query_size=bamnet_opt['query_size'], \
                                        max_query_markup_size=bamnet_opt['query_markup_size'], \
                                        max_ans_bow_size=bamnet_opt['ans_bow_size'], \
                                        vocab2id=vocab2id)

    model = BAMnetAgent(bamnet_opt, ctx_stopwords, vocab2id)
    pred = model.predict([memories_vec, queries, query_words, raw_queries, query_mentions, query_lengths], cand_labels, batch_size=bamnet_opt['test_batch_size'], margin=2)

    print('\nPredictions')
    for margin in bamnet_opt['test_margin']:
        print('\nMargin: {}'.format(margin))
        predictions = dynamic_pred(pred, margin)
        calc_avg_f1(gold_ans_labels, predictions)
    print('Runtime: %ss' % (timeit.default_timer() - start))


================================================
FILE: src/run_freebase.py
================================================
'''
Created on Oct, 2017

@author: hugo

'''
import argparse
import os
import json

from core.build_data.freebase import *
from core.utils.utils import *


parser = argparse.ArgumentParser()
parser.add_argument('-data_dir', '--data_dir', required=True, type=str, help='path to the data dir')
parser.add_argument('-fbkeys', '--freebase_keys', required=True, type=str, help='path to the freebase key file')
parser.add_argument('-out_dir', '--out_dir', type=str, required=True, help='path to the output dir')
args = parser.parse_args()

ids = load_json(args.freebase_keys)
total = len(ids)
print('Fetching {} entities and their 2-hop neighbors.'.format(total))
print_bar_len = 50
cnt = 0
missing_ids = set()
with open(os.path.join(args.out_dir, 'freebase.json'), 'a') as out_f:
    for id_ in ids:
        try:
            data = load_gzip_json(os.path.join(args.data_dir, '{}.json.gz'.format(id_)))
        except Exception: # the entity file may be missing or unreadable
            missing_ids.add(id_)
            continue
        graph = fetch(data, args.data_dir)
        graph2 = {id_: list(graph.values())[0]}
        graph2[id_]['id'] = list(graph.keys())[0]
        line = json.dumps(graph2) + '\n'
        out_f.write(line)
        cnt += 1
        if total >= print_bar_len and cnt % (total // print_bar_len) == 0:
            printProgressBar(cnt, total, prefix='Progress:', suffix='Complete', length=print_bar_len)
    printProgressBar(cnt, total, prefix='Progress:', suffix='Complete', length=print_bar_len)

print('Missed %s mids' % len(missing_ids))
dump_json(list(missing_ids), os.path.join(args.out_dir, 'missing_fbids.json'))


================================================
FILE: src/run_webquestions.py
================================================
'''
Created on Oct, 2017

@author: hugo

'''
import argparse
from core.build_data.webquestions import *

parser = argparse.ArgumentParser()
parser.add_argument('-fb', '--freebase_path', required=True, type=str, help='path to the freebase data')
parser.add_argument('-mid2key', '--mid2key_path', required=True, type=str, help='path to the mid-to-freebase-key mapping file')
parser.add_argument('-data_dir', '--data_dir', required=True, type=str, help='path to the data dir')
parser.add_argument('-out_dir', '--out_dir', type=str, required=True, help='path to the output dir')
args = parser.parse_args()

main(args.freebase_path, args.mid2key_path, args.data_dir, args.out_dir)
# get_used_fbkeys(args.data_dir, args.out_dir)
# get_all_fbkeys(args.data_dir, args.out_dir)


================================================
FILE: src/test.py
================================================
import timeit
import argparse

from core.bamnet.bamnet import BAMnetAgent
from core.build_data.build_all import build
from core.build_data.utils import vectorize_data
from core.utils.utils import *
from core.utils.generic_utils import unique
from core.utils.metrics import *


def dynamic_pred(pred, margin):
    predictions = []
    for i in range(len(pred)):
        predictions.append(unique([x[0] for x in pred[i] if x[1] + margin >= pred[i][0][1]]))
    return predictions

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', '--config', required=True, type=str, help='path to the config file')
    cfg = vars(parser.parse_args())
    opt = get_config(cfg['config'])

    # Ensure data is built
    build(opt['data_dir'])
    data_vec = load_json(os.path.join(opt['data_dir'], opt['test_data']))
    vocab2id = load_json(os.path.join(opt['data_dir'], 'vocab2id.json'))
    ctx_stopwords = {'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', "you're", "you've", "you'll", "you'd", 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', "she's", 'her', 'hers', 'herself', 'it', "it's", 'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', "that'll", 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 'will', 'just', 'don', "don't", 'should', "should've", 'now', 'd', 'll', 'm', 'o', 're', 've', 'y', 'ain', 'aren', "aren't", 'couldn', "couldn't", 'didn', "didn't", 'doesn', "doesn't", 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'isn', "isn't", 'ma', 'mightn', "mightn't", 'mustn', "mustn't", 'needn', "needn't", 'shan', "shan't", 'shouldn', "shouldn't", 'wasn', "wasn't", 'weren', "weren't", 'won', "won't", 'wouldn', "wouldn't"}

    queries, raw_queries, query_mentions, memories, cand_labels, _, gold_ans_labels = data_vec
    queries, query_words, query_lengths, memories_vec = vectorize_data(queries, query_mentions, memories, \
                                        max_query_size=opt['query_size'], \
                                        max_query_markup_size=opt['query_markup_size'], \
                                        max_ans_bow_size=opt['ans_bow_size'], \
                                        vocab2id=vocab2id)

    start = timeit.default_timer()

    model = BAMnetAgent(opt, ctx_stopwords, vocab2id)
    pred = model.predict([memories_vec, queries, query_words, raw_queries, query_mentions, query_lengths], cand_labels, batch_size=opt['test_batch_size'], margin=2)

    print('\nPredictions')
    for margin in opt['test_margin']:
        print('\nMargin: {}'.format(margin))
        predictions = dynamic_pred(pred, margin)
        calc_avg_f1(gold_ans_labels, predictions)
    print('Runtime: %ss' % (timeit.default_timer() - start))


================================================
FILE: src/test_entnet.py
================================================
import os
import timeit
import argparse

from core.bamnet.entnet import EntnetAgent
from core.build_data.build_all import build
from core.build_data.utils import vectorize_ent_data
from core.utils.utils import *


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-dt', '--datatype', default='test', type=str, help='data type: {train, valid, test}')
    parser.add_argument('-config', '--config', required=True, type=str, help='path to the config file')
    cfg = vars(parser.parse_args())
    opt = get_config(cfg['config'])

    # Ensure data is built
    build(opt['data_dir'])
    data_vec = load_json(os.path.join(opt['data_dir'], opt['test_data']))

    queries, memories, ent_labels, ent_inds = data_vec
    queries, query_lengths, memories = vectorize_ent_data(queries, \
                                        memories, max_query_size=opt['query_size'], \
                                        max_seed_ent_name_size=opt['max_seed_ent_name_size'], \
                                        max_seed_type_name_size=opt['max_seed_type_name_size'], \
                                        max_seed_rel_name_size=opt['max_seed_rel_name_size'], \
                                        max_seed_rel_size=opt['max_seed_rel_size'])

    start = timeit.default_timer()

    ent_model = EntnetAgent(opt)
    acc = ent_model.evaluate([memories, queries, query_lengths], ent_inds, batch_size=opt['test_batch_size'])
    print('acc: {}'.format(acc))
    print('Runtime: %ss' % (timeit.default_timer() - start))


================================================
FILE: src/train.py
================================================
import os
import timeit
import argparse

from core.bamnet.bamnet import BAMnetAgent
from core.build_data.build_all import build
from core.build_data.utils import vectorize_data
from core.utils.utils import *


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', '--config', required=True, type=str, help='path to the config file')
    cfg = vars(parser.parse_args())
    opt = get_config(cfg['config'])
    print_config(opt)

    # Ensure data is built
    build(opt['data_dir'])
    train_vec = load_json(os.path.join(opt['data_dir'], opt['train_data']))
    valid_vec = load_json(os.path.join(opt['data_dir'], opt['valid_data']))

    vocab2id = load_json(os.path.join(opt['data_dir'], 'vocab2id.json'))
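    # Context stopwords masked out during answer-context matching; this inlined
    # set appears to match NLTK's English stopword list.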
    ctx_stopwords = {'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', "you're", "you've", "you'll", "you'd", 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', "she's", 'her', 'hers', 'herself', 'it', "it's", 'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', "that'll", 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 'will', 'just', 'don', "don't", 'should', "should've", 'now', 'd', 'll', 'm', 'o', 're', 've', 'y', 'ain', 'aren', "aren't", 'couldn', "couldn't", 'didn', "didn't", 'doesn', "doesn't", 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'isn', "isn't", 'ma', 'mightn', "mightn't", 'mustn', "mustn't", 'needn', "needn't", 'shan', "shan't", 'shouldn', "shouldn't", 'wasn', "wasn't", 'weren', "weren't", 'won', "won't", 'wouldn', "wouldn't"}

    train_queries, train_raw_queries, train_query_mentions, train_memories, _, train_gold_ans_inds, _ = train_vec
    train_queries, train_query_words, train_query_lengths, train_memories = vectorize_data(train_queries, train_query_mentions, \
                                        train_memories, max_query_size=opt['query_size'], \
                                        max_query_markup_size=opt['query_markup_size'], \
                                        max_mem_size=opt['mem_size'], \
                                        max_ans_bow_size=opt['ans_bow_size'], \
                                        max_ans_path_bow_size=opt['ans_path_bow_size'], \
                                        vocab2id=vocab2id)

    valid_queries, valid_raw_queries, valid_query_mentions, valid_memories, valid_cand_labels, valid_gold_ans_inds, valid_gold_ans_labels = valid_vec
    valid_queries, valid_query_words, valid_query_lengths, valid_memories = vectorize_data(valid_queries, valid_query_mentions, \
                                        valid_memories, max_query_size=opt['query_size'], \
                                        max_query_markup_size=opt['query_markup_size'], \
                                        max_mem_size=opt['mem_size'], \
                                        max_ans_bow_size=opt['ans_bow_size'], \
                                        max_ans_path_bow_size=opt['ans_path_bow_size'], \
                                        vocab2id=vocab2id)

    start = timeit.default_timer()

    model = BAMnetAgent(opt, ctx_stopwords, vocab2id)
    model.train([train_memories, train_queries, train_query_words, train_raw_queries, train_query_mentions, train_query_lengths], train_gold_ans_inds, \
        [valid_memories, valid_queries, valid_query_words, valid_raw_queries, valid_query_mentions, valid_query_lengths], \
        valid_gold_ans_inds, valid_cand_labels, valid_gold_ans_labels)

    print('Runtime: %ss' % (timeit.default_timer() - start))


================================================
FILE: src/train_entnet.py
================================================
import os
import timeit
import argparse

from core.bamnet.entnet import EntnetAgent
from core.build_data.build_all import build
from core.build_data.utils import vectorize_ent_data
from core.utils.utils import *


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', '--config', required=True, type=str, help='path to the config file')
    cfg = vars(parser.parse_args())
    opt = get_config(cfg['config'])
    print_config(opt)

    # Ensure data is built
    build(opt['data_dir'])
    train_vec = load_json(os.path.join(opt['data_dir'], opt['train_data']))
    valid_vec = load_json(os.path.join(opt['data_dir'], opt['valid_data']))

    train_queries, train_memories, _, train_ent_inds = train_vec
    train_queries, train_query_lengths, train_memories = vectorize_ent_data(train_queries, \
                                        train_memories, max_query_size=opt['query_size'], \
                                        max_seed_ent_name_size=opt['max_seed_ent_name_size'], \
                                        max_seed_type_name_size=opt['max_seed_type_name_size'], \
                                        max_seed_rel_name_size=opt['max_seed_rel_name_size'], \
                                        max_seed_rel_size=opt['max_seed_rel_size'])

    valid_queries, valid_memories, _, valid_ent_inds = valid_vec
    valid_queries, valid_query_lengths, valid_memories = vectorize_ent_data(valid_queries, \
                                        valid_memories, max_query_size=opt['query_size'], \
                                        max_seed_ent_name_size=opt['max_seed_ent_name_size'], \
                                        max_seed_type_name_size=opt['max_seed_type_name_size'], \
                                        max_seed_rel_name_size=opt['max_seed_rel_name_size'], \
                                        max_seed_rel_size=opt['max_seed_rel_size'])

    start = timeit.default_timer()

    ent_model = EntnetAgent(opt)
    ent_model.train([train_memories, train_queries, train_query_lengths], train_ent_inds, \
        [valid_memories, valid_queries, valid_query_lengths], valid_ent_inds)

    print('Runtime: %ss' % (timeit.default_timer() - start))
SYMBOL INDEX (130 symbols across 16 files)

FILE: src/core/bamnet/bamnet.py
  function get_text_overlap (line 26) | def get_text_overlap(raw_query, query_mentions, ctx_ent_names, vocab2id,...
  class BAMnetAgent (line 57) | class BAMnetAgent(object):
    method __init__ (line 60) | def __init__(self, opt, ctx_stops, vocab2id):
    method train (line 106) | def train(self, train_X, train_y, valid_X, valid_y, valid_cand_labels,...
    method predict (line 166) | def predict(self, xs, cand_labels, batch_size=32, margin=1, ys=None, v...
    method train_step (line 179) | def train_step(self, xs, ys, is_training=True):
    method predict_step (line 210) | def predict_step(self, xs, cand_labels, margin, verbose=False):
    method dynamic_ctx_negative_sampling (line 225) | def dynamic_ctx_negative_sampling(self, memories, ys, mem_size, ctx_bo...
    method pad_ctx_memory (line 279) | def pad_ctx_memory(self, memories, ctx_bow_size, raw_queries, query_me...
    method pack_gold_ans (line 321) | def pack_gold_ans(self, x, N, placeholder=-1):
    method set_loss_margin (line 329) | def set_loss_margin(self, scores, gold_mask, margin):
    method ranked_predictions (line 338) | def ranked_predictions(self, cand_labels, scores, margin):
    method save (line 345) | def save(self, path=None):
    method load (line 356) | def load(self, path):

FILE: src/core/bamnet/ent_modules.py
  class Entnet (line 20) | class Entnet(nn.Module):
    method __init__ (line 21) | def __init__(self, vocab_size, vocab_embed_size, o_embed_size, \
    method forward (line 64) | def forward(self, memories, queries, query_lengths):
    method clf_score (line 100) | def clf_score(self, q_r, ent_key):
    method create_mask (line 103) | def create_mask(self, x, N, use_cuda=True):
    method create_mask_3D (line 110) | def create_mask_3D(self, x, N, use_cuda=True):
  class EntEncoder (line 118) | class EntEncoder(nn.Module):
    method __init__ (line 120) | def __init__(self, o_embed_size, hidden_size, num_ent_types, num_relat...
    method forward (line 168) | def forward(self, x_ent_names, x_ent_name_len, x_type_names, x_types, ...
    method enc_kg_features (line 182) | def enc_kg_features(self, x_ent_names, x_ent_name_len, x_type_names, x...
  class EntRomHop (line 198) | class EntRomHop(nn.Module):
    method __init__ (line 199) | def __init__(self, query_embed_size, in_memory_embed_size, hidden_size...
    method forward (line 204) | def forward(self, h_state, key_memory_embed, val_memory_embed, atten_m...
  class GRUStep (line 211) | class GRUStep(nn.Module):
    method __init__ (line 212) | def __init__(self, hidden_size, input_size):
    method forward (line 219) | def forward(self, h_state, input_):

FILE: src/core/bamnet/entnet.py
  class EntnetAgent (line 24) | class EntnetAgent(object):
    method __init__ (line 25) | def __init__(self, opt):
    method train (line 69) | def train(self, train_X, train_y, valid_X, valid_y, seed=1234):
    method evaluate (line 126) | def evaluate(self, xs, ys, batch_size=1, silence=False):
    method predict (line 141) | def predict(self, xs, cand_labels, batch_size=1, silence=False):
    method train_step (line 152) | def train_step(self, xs, ys, is_training=True):
    method evaluate_step (line 179) | def evaluate_step(self, xs, ys):
    method predict_step (line 193) | def predict_step(self, xs, cand_labels):
    method pack_gold_ans (line 205) | def pack_gold_ans(self, x, N, placeholder=-1):
    method ranked_predictions (line 213) | def ranked_predictions(self, cand_labels, scores):
    method save (line 218) | def save(self, path=None):
    method load (line 229) | def load(self, path):

FILE: src/core/bamnet/modules.py
  class BAMnet (line 19) | class BAMnet(nn.Module):
    method __init__ (line 20) | def __init__(self, vocab_size, vocab_embed_size, o_embed_size, \
    method kb_aware_query_enc (line 60) | def kb_aware_query_enc(self, memories, queries, query_lengths, ans_mas...
    method forward (line 97) | def forward(self, memories, queries, query_lengths, query_words, ctx_m...
    method premature_score (line 136) | def premature_score(self, memories, queries, query_lengths, ctx_mask=N...
    method scoring (line 149) | def scoring(self, ans_r, q_r, mask=None):
  class RomHop (line 155) | class RomHop(nn.Module):
    method __init__ (line 156) | def __init__(self, query_embed_size, in_memory_embed_size, hidden_size...
    method forward (line 164) | def forward(self, query_embed, in_memory_embed, out_memory_embed, quer...
    method gru_step (line 170) | def gru_step(self, h_state, in_memory_embed, out_memory_embed, atten_m...
    method update_coatt_cat_maxpool (line 182) | def update_coatt_cat_maxpool(self, query_embed, in_memory_embed, out_m...
  class AnsEncoder (line 208) | class AnsEncoder(nn.Module):
    method __init__ (line 210) | def __init__(self, o_embed_size, hidden_size, num_ent_types, num_relat...
    method forward (line 250) | def forward(self, x_type_bow, x_types, x_type_bow_len, x_path_bow, x_p...
    method enc_comp_kv (line 255) | def enc_comp_kv(self, ans_type_bow, ans_types, ans_path_bow, ans_paths...
    method enc_ans_features (line 268) | def enc_ans_features(self, x_type_bow, x_types, x_type_bow_len, x_path...
  class SeqEncoder (line 293) | class SeqEncoder(object):
    method __init__ (line 295) | def __init__(self, vocab_size, embed_size, hidden_size, \
  class EncoderRNN (line 317) | class EncoderRNN(nn.Module):
    method __init__ (line 318) | def __init__(self, vocab_size, embed_size, hidden_size, dropout=None, \
    method init_weights (line 340) | def init_weights(self, init_word_embed):
    method forward (line 347) | def forward(self, x, x_len):
  class EncoderCNN (line 378) | class EncoderCNN(nn.Module):
    method __init__ (line 379) | def __init__(self, vocab_size, embed_size, hidden_size, kernel_size=[2...
    method init_weights (line 393) | def init_weights(self, init_word_embed):
    method forward (line 400) | def forward(self, x, x_len=None):
  class Attention (line 419) | class Attention(nn.Module):
    method __init__ (line 420) | def __init__(self, hidden_size, h_state_embed_size=None, in_memory_emb...
    method forward (line 440) | def forward(self, query_embed, in_memory_embed, atten_mask=None):
  class SelfAttention_CoAtt (line 458) | class SelfAttention_CoAtt(nn.Module):
    method __init__ (line 459) | def __init__(self, hidden_size, use_cuda=True):
    method forward (line 465) | def forward(self, x, x_len, atten_mask):
  function create_mask (line 485) | def create_mask(x, N, use_cuda=True):

FILE: src/core/bamnet/utils.py
  function to_cuda (line 12) | def to_cuda(x, use_cuda=True):
  function next_batch (line 18) | def next_batch(memories, queries, query_words, raw_queries, query_mentio...
  function next_ent_batch (line 23) | def next_ent_batch(memories, queries, query_lengths, gold_inds, batch_si...

FILE: src/core/build_data/build_all.py
  function build (line 14) | def build(dpath, version=None, out_dir=None):

FILE: src/core/build_data/build_data.py
  function build_kb_data (line 24) | def build_kb_data(kb, used_fbkeys=None):
  function build_qa_vocab (line 114) | def build_qa_vocab(qa):
  function delex_query_topic_ent (line 121) | def delex_query_topic_ent(query, topic_ent, ent_types):
  function delex_query (line 162) | def delex_query(query, ent_mens, mention_types):
  function build_data (line 178) | def build_data(qa, kb, entity2id, entityType2id, relation2id, vocab2id, ...
  function build_vocab (line 225) | def build_vocab(data, freebase, used_fbkeys=None, min_freq=1):
  function build_ans_cands (line 265) | def build_ans_cands(graph, entity2id, entityType2id, relation2id, vocab2...
  function build_seed_ent_data (line 449) | def build_seed_ent_data(qa, kb, entity2id, entityType2id, relation2id, v...
  function build_seed_entity_feature (line 484) | def build_seed_entity_feature(seed_ent, graph, entity2id, entityType2id,...

FILE: src/core/build_data/freebase.py
  function fetch_meta (line 12) | def fetch_meta(path):
  function fetch (line 37) | def fetch(data, data_dir):

FILE: src/core/build_data/utils.py
  function built (line 17) | def built(path, version_string=None):
  function mark_done (line 33) | def mark_done(path, version_string=None):
  function make_dir (line 42) | def make_dir(path):
  function remove_dir (line 46) | def remove_dir(path):
  function vectorize_data (line 50) | def vectorize_data(queries, query_mentions, memories, max_query_size=Non...
  function vectorize_ent_data (line 199) | def vectorize_ent_data(queries, ent_memories, max_query_size=None, \

FILE: src/core/build_data/webquestions.py
  function get_used_fbkeys (line 17) | def get_used_fbkeys(data_dir, out_dir):
  function get_all_fbkeys (line 28) | def get_all_fbkeys(data_dir, out_dir):
  function main (line 48) | def main(fb_path, mid2key_path, data_dir, out_dir):
  function fetch_ans_cands (line 107) | def fetch_ans_cands(graph):

FILE: src/core/utils/freebase_utils.py
  function if_filterout (line 10) | def if_filterout(s):
  function query_kb (line 18) | def query_kb(kb, ent_name, fuzz_threshold=90):

FILE: src/core/utils/generic_utils.py
  function find_parent (line 18) | def find_parent(x, tree, conn='<-'):
  function extract_dep_feature (line 30) | def extract_dep_feature(dep_parser, text, topic_ent, question_word):
  function unique (line 51) | def unique(seq):
  function normalize_answer (line 59) | def normalize_answer(s):
  function dump_embeddings (line 75) | def dump_embeddings(vocab_dict, emb_file, out_path, emb_size=300, binary...
  function get_embeddings (line 88) | def get_embeddings(emb_file, vocab, binary=False):
  class PreTrainEmbedding (line 101) | class PreTrainEmbedding():
    method __init__ (line 102) | def __init__(self, file, binary=False):
    method get_embeddings (line 106) | def get_embeddings(self, word):

FILE: src/core/utils/metrics.py
  function calc_f1 (line 12) | def calc_f1(gold_list, pred_list):
  function calc_avg_f1 (line 44) | def calc_avg_f1(gold_list, pred_list, verbose=True):
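
  For orientation, a hedged usage sketch (not from the repository): src/test.py
  above calls calc_avg_f1(gold_ans_labels, predictions), so with hypothetical
  answer-label lists the call looks like:

      from core.utils.metrics import calc_avg_f1
      gold = [['Barack Obama'], ['Paris', 'Lyon']]   # gold answers per question
      pred = [['Barack Obama'], ['Paris']]           # predictions per question
      calc_avg_f1(gold, pred, verbose=True)          # reports average F1 (assumed)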

FILE: src/core/utils/utils.py
  function get_config (line 20) | def get_config(config_path="config.yml"):
  function print_config (line 25) | def print_config(config):
  function read_lines (line 33) | def read_lines(path_to_file):
  function dump_ndarray (line 45) | def dump_ndarray(data, path_to_file):
  function load_ndarray (line 52) | def load_ndarray(path_to_file):
  function dump_ndjson (line 61) | def dump_ndjson(data, file):
  function load_ndjson (line 69) | def load_ndjson(file, return_type='array'):
  function load_ndjson_to_array (line 77) | def load_ndjson_to_array(file):
  function load_ndjson_to_dict (line 87) | def load_ndjson_to_dict(file):
  function dump_json (line 97) | def dump_json(data, file, indent=None):
  function load_json (line 104) | def load_json(file):
  function dump_dict_ndjson (line 112) | def dump_dict_ndjson(data, file):
  function load_gzip_json (line 121) | def load_gzip_json(file):
  function get_all_files (line 129) | def get_all_files(dir, recursive=False):
  function printProgressBar (line 136) | def printProgressBar(iteration, total, prefix = '', suffix = '', decimal...
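
  A hedged sketch of combining the config/IO helpers above, mirroring how the
  entry scripts use them (paths are hypothetical; that get_config parses a YAML
  config into a dict is an assumption consistent with opt['data_dir'] lookups):

      import os
      from core.utils.utils import get_config, load_json, dump_json
      opt = get_config('config/bamnet_webq.yml')
      data = load_json(os.path.join(opt['data_dir'], opt['train_data']))
      dump_json(data, '/tmp/train_copy.json', indent=2)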

FILE: src/joint_test.py
  function dynamic_pred (line 15) | def dynamic_pred(pred, margin):

FILE: src/test.py
  function dynamic_pred (line 12) | def dynamic_pred(pred, margin):
Condensed preview — 34 files, each showing path, character count, and a content snippet (full structured content: 179K chars).
[
  {
    "path": ".gitignore",
    "chars": 1216,
    "preview": "data/\nruns/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribut"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 4567,
    "preview": "# BAMnet\n\n\nCode & data accompanying the NAACL2019 paper [\"Bidirectional Attentive Memory Networks for Question Answering"
  },
  {
    "path": "requirements.txt",
    "chars": 83,
    "preview": "rapidfuzz==0.3.0\ngensim==3.5.0\nnltk==3.4.5\nnumpy==1.14.5\nPyYAML==5.1\ntorch==0.4.1\n\n"
  },
  {
    "path": "src/build_all_data.py",
    "chars": 4094,
    "preview": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nimport argparse\n\nfrom core.build_data.build_data import build_vocab, build_"
  },
  {
    "path": "src/build_pretrained_w2v.py",
    "chars": 974,
    "preview": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nimport argparse\nimport os\n\nfrom core.utils.utils import load_json\nfrom core"
  },
  {
    "path": "src/config/bamnet_webq.yml",
    "chars": 827,
    "preview": "# Seed 15 Data\nname: 'WebQuestions'\ndata_dir: '../data/WebQ/'\ntrain_data: 'train_vec.json'\nvalid_data: 'valid_vec.json'\n"
  },
  {
    "path": "src/config/entnet_webq.yml",
    "chars": 804,
    "preview": "# WebQuestions Data\nname: 'WebQuestions'\ndata_dir: '../data/WebQ/'\ntrain_data: 'train_ent_vec.json'\nvalid_data: 'valid_e"
  },
  {
    "path": "src/core/__init__.py",
    "chars": 45,
    "preview": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\n"
  },
  {
    "path": "src/core/bamnet/__init__.py",
    "chars": 45,
    "preview": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\n"
  },
  {
    "path": "src/core/bamnet/bamnet.py",
    "chars": 17939,
    "preview": "'''\nCreated on Sep, 2017\n\n@author: hugo\n\n'''\nimport os\nimport timeit\nimport numpy as np\n\nimport torch\nfrom torch import "
  },
  {
    "path": "src/core/bamnet/ent_modules.py",
    "chars": 11193,
    "preview": "'''\nCreated on Sep, 2018\n\n@author: hugo\n\n'''\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn.utils."
  },
  {
    "path": "src/core/bamnet/entnet.py",
    "chars": 10201,
    "preview": "'''\nCreated on Sep, 2018\n\n@author: hugo\n\n'''\nimport os\nimport timeit\nimport numpy as np\n\nimport torch\nfrom torch import "
  },
  {
    "path": "src/core/bamnet/modules.py",
    "chars": 26042,
    "preview": "'''\nCreated on Sep, 2017\n\n@author: hugo\n\n'''\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn.utils."
  },
  {
    "path": "src/core/bamnet/utils.py",
    "chars": 956,
    "preview": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nimport torch\nfrom torch.autograd import Variable\nimport numpy as np\n\n\ndef t"
  },
  {
    "path": "src/core/build_data/__init__.py",
    "chars": 45,
    "preview": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\n"
  },
  {
    "path": "src/core/build_data/build_all.py",
    "chars": 380,
    "preview": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nimport os\n\nfrom . import utils as build_utils\nfrom ..utils.utils import *\nf"
  },
  {
    "path": "src/core/build_data/build_data.py",
    "chars": 27020,
    "preview": "'''\nCreated on Sep, 2017\n\n@author: hugo\n\n'''\nimport os\nimport math\nimport argparse\nfrom itertools import count\nfrom rapi"
  },
  {
    "path": "src/core/build_data/freebase.py",
    "chars": 2863,
    "preview": "'''\nCreated on Sep, 2017\n\n@author: hugo\n\n'''\nimport os\n\nfrom ..utils.utils import *\n\n\ndef fetch_meta(path):\n    try:\n   "
  },
  {
    "path": "src/core/build_data/utils.py",
    "chars": 14408,
    "preview": "'''\nCreated on Sep, 2017\n\n@author: hugo\n\n'''\nimport os\nimport datetime\nimport shutil\nfrom collections import defaultdict"
  },
  {
    "path": "src/core/build_data/webquestions.py",
    "chars": 6602,
    "preview": "'''\nCreated on Sep, 2017\n\n@author: hugo\n\n'''\nimport os\n# import re\nimport argparse\nfrom nltk.parse.stanford import Stanf"
  },
  {
    "path": "src/core/config.py",
    "chars": 813,
    "preview": "\n# Vocabulary\nRESERVED_TOKENS = {'PAD': 0, 'UNK': 1}\nRESERVED_ENTS = {'PAD': 0, 'UNK': 1}\nRESERVED_ENT_TYPES = {'PAD': 0"
  },
  {
    "path": "src/core/utils/__init__.py",
    "chars": 45,
    "preview": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\n"
  },
  {
    "path": "src/core/utils/freebase_utils.py",
    "chars": 708,
    "preview": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nfrom rapidfuzz import fuzz, process\n\n\ndef if_filterout(s):\n    if s.endswit"
  },
  {
    "path": "src/core/utils/generic_utils.py",
    "chars": 3562,
    "preview": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nimport re, string\nimport numpy as np\nfrom rapidfuzz import fuzz, process\nfr"
  },
  {
    "path": "src/core/utils/metrics.py",
    "chars": 2533,
    "preview": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\nNote: Modified the official evaluation script provided by Berant et al.\n(https:"
  },
  {
    "path": "src/core/utils/utils.py",
    "chars": 4519,
    "preview": "'''\nCreated on Sep, 2017\n\n@author: hugo\n\n'''\nimport os\nimport re\nimport yaml\nimport gzip\nimport json\nimport string\nimpor"
  },
  {
    "path": "src/joint_test.py",
    "chars": 5580,
    "preview": "import timeit\nimport argparse\nimport numpy as np\n\nfrom core.bamnet.entnet import EntnetAgent\nfrom core.bamnet.bamnet imp"
  },
  {
    "path": "src/run_freebase.py",
    "chars": 1561,
    "preview": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nimport argparse\nimport os\nimport json\n\nfrom core.build_data.freebase import"
  },
  {
    "path": "src/run_webquestions.py",
    "chars": 750,
    "preview": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nimport argparse\nfrom core.build_data.webquestions import *\n\nparser = argpar"
  },
  {
    "path": "src/test.py",
    "chars": 3484,
    "preview": "import timeit\nimport argparse\n\nfrom core.bamnet.bamnet import BAMnetAgent\nfrom core.build_data.build_all import build\nfr"
  },
  {
    "path": "src/test_entnet.py",
    "chars": 1564,
    "preview": "import timeit\nimport argparse\nimport numpy as np\n\nfrom core.bamnet.entnet import EntnetAgent\nfrom core.build_data.build_"
  },
  {
    "path": "src/train.py",
    "chars": 4251,
    "preview": "import timeit\nimport argparse\nimport numpy as np\n\nfrom core.bamnet.bamnet import BAMnetAgent\nfrom core.build_data.build_"
  },
  {
    "path": "src/train_entnet.py",
    "chars": 2246,
    "preview": "import timeit\nimport argparse\nimport numpy as np\n\nfrom core.bamnet.entnet import EntnetAgent\nfrom core.build_data.build_"
  }
]

About this extraction

This page contains the full source code of the hugochan/BAMnet GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 34 files (169.2 KB), approximately 43.9k tokens, and a symbol index with 130 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
