[
  {
    "path": ".gitignore",
    "content": "data/\nruns/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# BAMnet\n\n\nCode & data accompanying the NAACL2019 paper [\"Bidirectional Attentive Memory Networks for Question Answering over Knowledge Bases\"](https://arxiv.org/abs/1903.02188)\n\n\n## Get started\n\n\n### Prerequisites\nThis code is written in python 3. You will need to install a few python packages in order to run the code.\nWe recommend you to use `virtualenv` to manage your python packages and environments.\nPlease take the following steps to create a python virtual environment.\n\n* If you have not installed `virtualenv`, install it with ```pip install virtualenv```.\n* Create a virtual environment with ```virtualenv venv```.\n* Activate the virtual environment with `source venv/bin/activate`.\n* Install the package requirements with `pip install -r requirements.txt`.\n\n\n\n\n### Run the KBQA system\n\n* Download the preprocessed data from [here](https://1drv.ms/u/s!AjiSpuwVTt09gSE2niFGjdIVsqA7?e=PEf6sT) and put the data folder under the root directory.\n\n\n* Create a folder (e.g., `runs/WebQ/`) to save model checkpoint. You can download the pretrained models from [here](https://1drv.ms/u/s!AjiSpuwVTt09gSLcnrp0GyKtpWBg?e=DtqYt8). (Note: if you cannot access the above data and pretrained models, please download from [here](http://academic.hugochan.net/download/BAMnet-WebQ.zip).)\n\n\n* Please modify the config files in the `src/config/` folder to suit your needs. 
Note that you can start with modifying only the data folder (e.g., `data_dir`, `model_file`, `pre_word2vec`) and vocab size (e.g., `vocab_size`, `num_ent_types`, `num_relations`), and leave other hyperparameters as they are.\n\n\n* Go to the `BAMnet/src` folder, train the BAMnet model\n\n\t```\n\tpython train.py -config config/bamnet_webq.yml\n\t```\n\t\n\n*  Test the BAMnet model (with ground-truth topic entity)\n\t\n\t```\n\tpython test.py -config config/bamnet_webq.yml\n\t```\n\n*  Train the topic entity predictor\n\n\t```\n\tpython train_entnet.py -config config/entnet_webq.yml\n\t```\n\n*  Test the topic entity predictor\n\n\t```\n\tpython test_entnet.py -config config/entnet_webq.yml\n\t```\n\n*  Test the whole system (BAMnet + topic entity predictor)\n\n\t```\n\tpython joint_test.py -bamnet_config config/bamnet_webq.yml -entnet_config config/entnet_webq.yml -raw_data ../data/WebQ\n\t```\n\n\n\n### Preprocess the dataset on your own\n\n* Go to the `BAMnet/src` folder, to prepare data for the BAMnet model, run the following cmd:\n\n\t```\n\tpython build_all_data.py -data_dir ../data/WebQ -fb_dir ../data/WebQ -out_dir ../data/WebQ\n\t```\n\t\n* To prepare data for the topic entity predictor model, run the following cmd:\n\n\t```\n\tpython build_all_data.py -dtype ent -data_dir ../data/WebQ -fb_dir ../data/WebQ -out_dir ../data/WebQ\n\t```\n\n\n Note that in the message printed out, you will see some data statistics such as `vocab_size`, `num_ent_types `, `num_relations`. 
These numbers will be used later when modifying the config files.\n\n\n* Download the pretrained Glove word embeddings [glove.840B.300d.zip](http://nlp.stanford.edu/data/wordvecs/glove.840B.300d.zip).\n\n* Unzip the file and convert glove format to word2vec format using the following cmd:\n\n\t```\n\tpython -m gensim.scripts.glove2word2vec --input glove.840B.300d.txt --output glove.840B.300d.w2v\n\t```\n\n* Fetch the pretrained Glove vectors for our vocabulary.\n\n\t```\n\tpython build_pretrained_w2v.py -emb glove.840B.300d.w2v -data_dir ../data/WebQ -out ../data/WebQ/glove_pretrained_300d_w2v.npy -emb_size 300\n\t```\n\n\n\n\n## Architecture\n\n<center><img src=\"images/overall_arch.png\"/></center>\n\n\n\n## Experiment results on WebQuestions\n\n\n### Results on WebQuestions test set. Bold: best in-category performance. \r\n\r\n\n<center><img src=\"images/results.png\" width=\"300\" height=\"500\"/></center>\n\n\n\n\n\n\n### Predicted answers of BAMnet w/ and w/o bidirectional attention on the WebQuestions test set\r\n\n![pred_examples](images/pred_examples.png \"pred_examples\")\n\n\n\n### Attention heatmap generated by the reasoning module\r\n\n![attn_heatmap](images/attn_heatmap.png \"attn_heatmap\")\n\n\n\n\n\n## Reference\n\nIf you found this code useful, please consider citing the following paper:\n\nYu Chen, Lingfei Wu, Mohammed J. Zaki. **\"Bidirectional Attentive Memory Networks for Question Answering over Knowledge Bases.\"** *In Proc. 2019 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL-HLT2019). June 2019.*\n\n\n\t@article{chen2019bidirectional,\n\t  title={Bidirectional Attentive Memory Networks for Question Answering over Knowledge Bases},\n\t  author={Chen, Yu and Wu, Lingfei and Zaki, Mohammed J},\n\t  journal={arXiv preprint arXiv:1903.02188},\n\t  year={2019}\n\t}\n"
  },
  {
    "path": "requirements.txt",
    "content": "rapidfuzz==0.3.0\ngensim==3.5.0\nnltk==3.4.5\nnumpy==1.14.5\nPyYAML==5.1\ntorch==0.4.1\n\n"
  },
  {
    "path": "src/build_all_data.py",
    "content": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nimport argparse\n\nfrom core.build_data.build_data import build_vocab, build_data, build_seed_ent_data\nfrom core.utils.utils import *\nfrom core.build_data import utils as build_utils\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-data_dir', '--data_dir', required=True, type=str, help='path to the data dir')\n    parser.add_argument('-fb_dir', '--fb_dir', required=True, type=str, help='path to the freebase dir')\n    parser.add_argument('-out_dir', '--out_dir', required=True, type=str, help='path to the output dir')\n    parser.add_argument('-dtype', '--data_type', default='qa', type=str, help='data type')\n    parser.add_argument('-min_freq', '--min_freq', default=1, type=int, help='min word vocab freq')\n    parser.add_argument('-topn', '--topn', default=15, type=int, help='top n candidates')\n    args = parser.parse_args()\n\n    train_data = load_ndjson(os.path.join(args.data_dir, 'raw_train.json'))\n    valid_data = load_ndjson(os.path.join(args.data_dir, 'raw_valid.json'))\n    test_data = load_ndjson(os.path.join(args.data_dir, 'raw_test.json'))\n    freebase = load_ndjson(os.path.join(args.fb_dir, 'freebase_full.json'), return_type='dict')\n\n    if not (os.path.exists(os.path.join(args.out_dir, 'entity2id.json')) and \\\n        os.path.exists(os.path.join(args.out_dir, 'entityType2id.json')) and \\\n        os.path.exists(os.path.join(args.out_dir, 'relation2id.json')) and \\\n        os.path.exists(os.path.join(args.out_dir, 'vocab2id.json'))):\n\n        used_fbkeys = set()\n        for each in train_data + valid_data:\n            used_fbkeys.update(each['freebaseKeyCands'][:args.topn])\n        print('# of used_fbkeys: {}'.format(len(used_fbkeys)))\n\n        entity2id, entityType2id, relation2id, vocab2id = build_vocab(train_data + valid_data, freebase, used_fbkeys, min_freq=args.min_freq)\n        dump_json(entity2id, 
os.path.join(args.out_dir, 'entity2id.json'))\n        dump_json(entityType2id, os.path.join(args.out_dir, 'entityType2id.json'))\n        dump_json(relation2id, os.path.join(args.out_dir, 'relation2id.json'))\n        dump_json(vocab2id, os.path.join(args.out_dir, 'vocab2id.json'))\n    else:\n        entity2id = load_json(os.path.join(args.out_dir, 'entity2id.json'))\n        entityType2id = load_json(os.path.join(args.out_dir, 'entityType2id.json'))\n        relation2id = load_json(os.path.join(args.out_dir, 'relation2id.json'))\n        vocab2id = load_json(os.path.join(args.out_dir, 'vocab2id.json'))\n        print('Using pre-built vocabs stored in %s' % args.out_dir)\n\n    if args.data_type == 'qa':\n        train_vec = build_data(train_data, freebase, entity2id, entityType2id, relation2id, vocab2id)\n        valid_vec = build_data(valid_data, freebase, entity2id, entityType2id, relation2id, vocab2id)\n        test_vec = build_data(test_data, freebase, entity2id, entityType2id, relation2id, vocab2id)\n        dump_json(train_vec, os.path.join(args.out_dir, 'train_vec.json'))\n        dump_json(valid_vec, os.path.join(args.out_dir, 'valid_vec.json'))\n        dump_json(test_vec, os.path.join(args.out_dir, 'test_vec.json'))\n        print('Saved data to {}'.format(os.path.join(args.out_dir, 'train(valid, or test)_vec.json')))\n    else:\n        train_vec = build_seed_ent_data(train_data, freebase, entity2id, entityType2id, relation2id, vocab2id, args.topn, dtype='train')\n        valid_vec = build_seed_ent_data(valid_data, freebase, entity2id, entityType2id, relation2id, vocab2id, args.topn, dtype='valid')\n        test_vec = build_seed_ent_data(test_data, freebase, entity2id, entityType2id, relation2id, vocab2id, args.topn, dtype='test')\n        dump_json(train_vec, os.path.join(args.out_dir, 'train_ent_vec.json'))\n        dump_json(valid_vec, os.path.join(args.out_dir, 'valid_ent_vec.json'))\n        dump_json(test_vec, os.path.join(args.out_dir, 
'test_ent_vec.json'))\n        print('Saved data to {}'.format(os.path.join(args.out_dir, 'train(valid, or test)_ent_vec.json')))\n\n    # Mark the data as built.\n    build_utils.mark_done(args.out_dir)\n"
  },
  {
    "path": "src/build_pretrained_w2v.py",
    "content": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nimport argparse\nimport os\n\nfrom core.utils.utils import load_json\nfrom core.utils.generic_utils import dump_embeddings\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-emb', '--embed_path', required=True, type=str, help='path to the pretrained word embeddings')\n    parser.add_argument('-data_dir', '--data_dir', required=True, type=str, help='path to the data dir')\n    parser.add_argument('-out', '--out_path', required=True, type=str, help='path to the output path')\n    parser.add_argument('-emb_size', '--emb_size', required=True, type=int, help='embedding size')\n    parser.add_argument('--binary', action='store_true', help='flag: binary file')\n    args = parser.parse_args()\n\n    vocab_dict = load_json(os.path.join(args.data_dir, 'vocab2id.json'))\n    dump_embeddings(vocab_dict, args.embed_path, args.out_path, emb_size=args.emb_size, binary=True if args.binary else False)\n"
  },
  {
    "path": "src/config/bamnet_webq.yml",
    "content": "# Seed 15 Data\nname: 'WebQuestions'\ndata_dir: '../data/WebQ/'\ntrain_data: 'train_vec.json'\nvalid_data: 'valid_vec.json'\ntest_data: 'test_vec.json'\npre_word2vec: '../data/WebQ/glove_pretrained_300d_w2v.npy'\n\n# Full vocab\nvocab_size: 100797\nnum_ent_types: 1712\nnum_relations: 4996\n\nnum_query_words: 10\n\n# Output\nmodel_file: '../runs/WebQ/bamnet.md'\n\n# Model\nquery_size: 32\nquery_markup_size: 1 # Not used\nans_bow_size: 1 # Not used\nans_path_bow_size: null\nans_ctx_entity_bow_size: 6\n\nvocab_embed_size: 300\nhidden_size: 128\no_embed_size: 128\nmem_size: 96\nword_emb_dropout: 0.3\nque_enc_dropout: 0.3\nans_enc_dropout: 0.2\nattention: 'add'\nnum_hops: 1\n\n# Training\nlearning_rate: 0.001\nbatch_size: 32\nnum_epochs: 100\nvalid_patience: 10\nmargin: 1\n\n# Testing\ntest_batch_size: 1\ntest_margin:\n        - 0.7\n\n# Device\nno_cuda: False\ngpu: 0\n"
  },
  {
    "path": "src/config/entnet_webq.yml",
    "content": "# WebQuestions Data\nname: 'WebQuestions'\ndata_dir: '../data/WebQ/'\ntrain_data: 'train_ent_vec.json'\nvalid_data: 'valid_ent_vec.json'\ntest_data: 'test_ent_vec.json'\n\n# Full vocab\nvocab_size: 100797\nnum_ent_types: 1712\nnum_relations: 4996\npre_word2vec: '../data/WebQ/glove_pretrained_300d_w2v.npy'\n\n\n# Output\nmodel_file: '../runs/WebQ/entnet.md'\n\n\n# Model\nquery_size: 32\nmax_seed_ent_name_size: null\nmax_seed_type_name_size: null\nmax_seed_rel_name_size: null\nmax_seed_rel_size: null\n\nvocab_embed_size: 300\nhidden_size: 128\no_embed_size: 128\nword_emb_dropout: 0.3\nque_enc_dropout: 0.3\nent_enc_dropout: 0.2\nattention: 'simple'\nseq_enc_type: 'cnn'\nnum_ent_hops: 1\n\n# Training\nlearning_rate: 0.001\nbatch_size: 32\nnum_epochs: 100\nvalid_patience: 10\n\n# Testing\ntest_batch_size: 1\n\n# Device\nno_cuda: False\ngpu: 0\n"
  },
  {
    "path": "src/core/__init__.py",
    "content": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\n"
  },
  {
    "path": "src/core/bamnet/__init__.py",
    "content": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\n"
  },
  {
    "path": "src/core/bamnet/bamnet.py",
    "content": "'''\nCreated on Sep, 2017\n\n@author: hugo\n\n'''\nimport os\nimport timeit\nimport numpy as np\n\nimport torch\nfrom torch import optim\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\nfrom torch.nn import MultiLabelMarginLoss\nimport torch.backends.cudnn as cudnn\n\nfrom .modules import BAMnet\nfrom .utils import to_cuda, next_batch\nfrom ..utils.utils import load_ndarray\nfrom ..utils.generic_utils import unique\nfrom ..utils.metrics import *\nfrom .. import config\n\n\nCTX_BOW_INDEX = -5\ndef get_text_overlap(raw_query, query_mentions, ctx_ent_names, vocab2id, ctx_stops, query):\n    def longest_common_substring(s1, s2):\n       m = [[0] * (1 + len(s2)) for i in range(1 + len(s1))]\n       longest, x_longest = 0, 0\n       for x in range(1, 1 + len(s1)):\n           for y in range(1, 1 + len(s2)):\n               if s1[x - 1] == s2[y - 1]:\n                   m[x][y] = m[x - 1][y - 1] + 1\n                   if m[x][y] > longest:\n                       longest = m[x][y]\n                       x_longest = x\n               else:\n                   m[x][y] = 0\n       return s1[x_longest - longest: x_longest]\n\n    sub_seq = longest_common_substring(raw_query, ctx_ent_names)\n    if len(set(sub_seq) - ctx_stops) == 0:\n        return []\n\n    men_type = None\n    for men, type_ in query_mentions:\n        if type_.lower() in config.constraint_mention_types:\n            if '_'.join(sub_seq) in '_'.join(men):\n                men_type = '__{}__'.format(type_.lower())\n                break\n\n    if men_type:\n        return [vocab2id[men_type] if men_type in vocab2id else config.RESERVED_TOKENS['UNK']]\n    else:\n        return [vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for x in sub_seq]\n\nclass BAMnetAgent(object):\n    \"\"\" Bidirectional attentive memory network agent.\n    \"\"\"\n    def __init__(self, opt, ctx_stops, vocab2id):\n        self.ctx_stops = ctx_stops\n        self.vocab2id = vocab2id\n       
 opt['cuda'] = not opt['no_cuda'] and torch.cuda.is_available()\n        if opt['cuda']:\n            print('[ Using CUDA ]')\n            torch.cuda.set_device(opt['gpu'])\n            # It enables benchmark mode in cudnn, which\n            # leads to faster runtime when the input sizes do not vary.\n            cudnn.benchmark = True\n\n        self.opt = opt\n        if self.opt['pre_word2vec']:\n            pre_w2v = load_ndarray(self.opt['pre_word2vec'])\n        else:\n            pre_w2v = None\n\n        self.model = BAMnet(opt['vocab_size'], opt['vocab_embed_size'], \\\n                opt['o_embed_size'], opt['hidden_size'], \\\n                opt['num_ent_types'], opt['num_relations'], \\\n                opt['num_query_words'], \\\n                word_emb_dropout=opt['word_emb_dropout'], \\\n                que_enc_dropout=opt['que_enc_dropout'], \\\n                ans_enc_dropout=opt['ans_enc_dropout'], \\\n                pre_w2v=pre_w2v, \\\n                num_hops=opt['num_hops'], \\\n                att=opt['attention'], \\\n                use_cuda=opt['cuda'])\n        if opt['cuda']:\n            self.model.cuda()\n\n        # MultiLabelMarginLoss\n        # For each sample in the mini-batch:\n        # loss(x, y) = sum_ij(max(0, 1 - (x[y[j]] - x[i]))) / x.size(0)\n        self.loss_fn = MultiLabelMarginLoss()\n\n        optim_params = [p for p in self.model.parameters() if p.requires_grad]\n        self.optimizers = {'bamnet': optim.Adam(optim_params, lr=opt['learning_rate'])}\n        self.scheduler = ReduceLROnPlateau(self.optimizers['bamnet'], mode='min', \\\n                    patience=self.opt['valid_patience'] // 3, verbose=True)\n\n        if opt.get('model_file') and os.path.isfile(opt['model_file']):\n            print('Loading existing model parameters from ' + opt['model_file'])\n            self.load(opt['model_file'])\n        super(BAMnetAgent, self).__init__()\n\n    def train(self, train_X, train_y, valid_X, valid_y, 
valid_cand_labels, valid_gold_ans_labels, seed=1234):\n        print('Training size: {}, Validation size: {}'.format(len(train_y), len(valid_y)))\n        random1 = np.random.RandomState(seed)\n        random2 = np.random.RandomState(seed)\n        random3 = np.random.RandomState(seed)\n        random4 = np.random.RandomState(seed)\n        random5 = np.random.RandomState(seed)\n        random6 = np.random.RandomState(seed)\n        random7 = np.random.RandomState(seed)\n        memories, queries, query_words, raw_queries, query_mentions, query_lengths = train_X\n        gold_ans_inds = train_y\n\n        valid_memories, valid_queries, valid_query_words, valid_raw_queries, valid_query_mentions, valid_query_lengths = valid_X\n        valid_gold_ans_inds = valid_y\n\n        n_incr_error = 0  # nb. of consecutive increase in error\n        best_loss = float(\"inf\")\n        num_batches = len(queries) // self.opt['batch_size'] + (len(queries) % self.opt['batch_size'] != 0)\n        num_valid_batches = len(valid_queries) // self.opt['batch_size'] + (len(valid_queries) % self.opt['batch_size'] != 0)\n        for epoch in range(1, self.opt['num_epochs'] + 1):\n            start = timeit.default_timer()\n            n_incr_error += 1\n            random1.shuffle(memories)\n            random2.shuffle(queries)\n            random3.shuffle(query_words)\n            random4.shuffle(raw_queries)\n            random5.shuffle(query_mentions)\n            random6.shuffle(query_lengths)\n            random7.shuffle(gold_ans_inds)\n            train_gen = next_batch(memories, queries, query_words, raw_queries, query_mentions, query_lengths, gold_ans_inds, self.opt['batch_size'])\n            train_loss = 0\n            for batch_xs, batch_ys in train_gen:\n                train_loss += self.train_step(batch_xs, batch_ys) / num_batches\n\n            valid_gen = next_batch(valid_memories, valid_queries, valid_query_words, valid_raw_queries, valid_query_mentions, 
valid_query_lengths, valid_gold_ans_inds, self.opt['batch_size'])\n            valid_loss = 0\n            for batch_valid_xs, batch_valid_ys in valid_gen:\n                valid_loss += self.train_step(batch_valid_xs, batch_valid_ys, is_training=False) / num_valid_batches\n            self.scheduler.step(valid_loss)\n\n            # if False:\n            if epoch > 0:\n                pred = self.predict(valid_X, valid_cand_labels, batch_size=1, margin=self.opt['margin'], silence=True)\n                predictions = [unique([x[0] for x in each]) for each in pred]\n                valid_f1 = calc_avg_f1(valid_gold_ans_labels, predictions, verbose=False)[-1]\n            else:\n                valid_f1 = 0.\n            print('Epoch {}/{}: Runtime: {}s, Train loss: {:.4}, valid loss: {:.4}, valid F1: {:.4}'.format(epoch, self.opt['num_epochs'], \\\n                                                    int(timeit.default_timer() - start), train_loss, valid_loss, valid_f1))\n\n            if valid_loss < best_loss:\n                best_loss = valid_loss\n                n_incr_error = 0\n                self.save()\n\n            if n_incr_error >= self.opt['valid_patience']:\n                print('Early stopping occured. 
Optimization Finished!')\n                self.save(self.opt['model_file'] + '.final')\n                break\n\n    def predict(self, xs, cand_labels, batch_size=32, margin=1, ys=None, verbose=False, silence=False):\n        '''Prediction scores are returned in the verbose mode.\n        '''\n        if not silence:\n            print('Testing size: {}'.format(len(cand_labels)))\n        memories, queries, query_words, raw_queries, query_mentions, query_lengths = xs\n        gen = next_batch(memories, queries, query_words, raw_queries, query_mentions, query_lengths, cand_labels, batch_size)\n        predictions = []\n        for batch_xs, batch_cands in gen:\n            batch_pred = self.predict_step(batch_xs, batch_cands, margin, verbose=verbose)\n            predictions.extend(batch_pred)\n        return predictions\n\n    def train_step(self, xs, ys, is_training=True):\n        # Sets the module in training mode.\n        # This has any effect only on modules such as Dropout or BatchNorm.\n        self.model.train(mode=is_training)\n        with torch.set_grad_enabled(is_training):\n            # Organize inputs for network\n            selected_memories, new_ys, ctx_mask = self.dynamic_ctx_negative_sampling(xs[0], ys, self.opt['mem_size'], \\\n                                    self.opt['ans_ctx_entity_bow_size'], xs[3], xs[4], xs[1])\n            selected_memories = [to_cuda(torch.LongTensor(np.array(x)), self.opt['cuda']) for x in zip(*selected_memories)]\n            ctx_mask = to_cuda(ctx_mask, self.opt['cuda'])\n            queries = to_cuda(torch.LongTensor(xs[1]), self.opt['cuda'])\n            query_words = to_cuda(torch.LongTensor(xs[2]), self.opt['cuda'])\n            query_lengths = to_cuda(torch.LongTensor(xs[5]), self.opt['cuda'])\n            mem_hop_scores = self.model(selected_memories, queries, query_lengths, query_words, ctx_mask=None)\n            # Set margin\n            new_ys, mask_ys = self.pack_gold_ans(new_ys, 
mem_hop_scores[-1].size(1), placeholder=-1)\n\n            loss = 0\n            for _, s in enumerate(mem_hop_scores):\n                s = self.set_loss_margin(s, mask_ys, self.opt['margin'])\n                loss += self.loss_fn(s, new_ys)\n            loss /= len(mem_hop_scores)\n\n            if is_training:\n                for o in self.optimizers.values():\n                    o.zero_grad()\n                loss.backward()\n                for o in self.optimizers.values():\n                    o.step()\n            return loss.item()\n\n    def predict_step(self, xs, cand_labels, margin, verbose=False):\n        self.model.train(mode=False)\n        with torch.set_grad_enabled(False):\n            # Organize inputs for network\n            memories, ctx_mask = self.pad_ctx_memory(xs[0], self.opt['ans_ctx_entity_bow_size'], xs[3], xs[4], xs[1])\n            memories = [to_cuda(torch.LongTensor(np.array(x)), self.opt['cuda']) for x in zip(*memories)]\n            ctx_mask = to_cuda(ctx_mask, self.opt['cuda'])\n            queries = to_cuda(torch.LongTensor(xs[1]), self.opt['cuda'])\n            query_words = to_cuda(torch.LongTensor(xs[2]), self.opt['cuda'])\n            query_lengths = to_cuda(torch.LongTensor(xs[5]), self.opt['cuda'])\n            mem_hop_scores = self.model(memories, queries, query_lengths, query_words, ctx_mask=None)\n\n            predictions = self.ranked_predictions(cand_labels, mem_hop_scores[-1].data, margin)\n            return predictions\n\n    def dynamic_ctx_negative_sampling(self, memories, ys, mem_size, ctx_bow_size, raw_queries, query_mentions, queries):\n        # Randomly select negative samples from the candidiate answer set\n        ctx_bow_size = max(min(max(map(len, (a for x in list(zip(*memories))[CTX_BOW_INDEX] for y in x for a in y)), default=0), ctx_bow_size), 1)\n\n        selected_memories = []\n        new_ys = []\n        ctx_mask = []\n        for i in range(len(ys)):\n            n = len(memories[i][0]) - 1 # 
The last element is a dummy candidate\n            num_gold = len(ys[i]) if mem_size > len(ys[i]) else \\\n                    (mem_size - min(mem_size // 2, n - len(ys[i]))) # Max possible (pos, neg) pairs\n            selected_gold_inds = np.random.choice(ys[i], num_gold, replace=False).tolist() if len(ys[i]) > 0 else []\n            if n > len(ys[i]):\n                p = np.ones(n)\n                p[ys[i]] = 0\n                p = p / np.sum(p)\n                selected_inds = np.random.choice(n, min(mem_size, n) - num_gold, replace=False, p=p).tolist()\n            else:\n                selected_inds = []\n            augmented_selected_inds = selected_gold_inds + selected_inds + [-1] * max(mem_size - n, 0)\n            xx = [min(mem_size, n)] + [np.array(x)[augmented_selected_inds] for x in memories[i][:CTX_BOW_INDEX]]\n\n            ctx_bow = []\n            ctx_bow_len = []\n            ctx_num = []\n            tmp_ctx_mask = np.zeros(mem_size)\n            for _, idx in enumerate(augmented_selected_inds):\n                tmp_ctx = []\n                tmp_ctx_len = []\n                for ctx_ent_names in memories[i][CTX_BOW_INDEX][idx]:\n                    sub_seq = get_text_overlap(raw_queries[i], query_mentions[i], ctx_ent_names, self.vocab2id, self.ctx_stops, queries[i])\n                    if len(sub_seq) > 0:\n                        tmp_ctx_mask[_] = 1\n                        tmp_ctx.append(sub_seq[:ctx_bow_size] + [config.RESERVED_TOKENS['PAD']] * max(0, ctx_bow_size - len(sub_seq)))\n                        tmp_ctx_len.append(max(min(ctx_bow_size, len(sub_seq)), 1))\n                ctx_bow.append(tmp_ctx)\n                ctx_bow_len.append(tmp_ctx_len)\n                ctx_num.append(len(tmp_ctx))\n\n            xx += [ctx_bow, ctx_bow_len, ctx_num]\n            xx += [np.array(x)[augmented_selected_inds] for x in memories[i][CTX_BOW_INDEX+1:]]\n            selected_memories.append(xx)\n            new_ys.append(list(range(num_gold)))\n    
        ctx_mask.append(tmp_ctx_mask)\n\n        max_ctx_num = max(max([y for x in selected_memories for y in x[CTX_BOW_INDEX]]), 1)\n        for i in range(len(selected_memories)): # Example\n            for j in range(len(selected_memories[i][-1])): # Cand\n                count = selected_memories[i][CTX_BOW_INDEX][j]\n                if count < max_ctx_num:\n                    selected_memories[i][CTX_BOW_INDEX - 2][j] += [[config.RESERVED_TOKENS['PAD']] * ctx_bow_size] * (max_ctx_num - count)\n                    selected_memories[i][CTX_BOW_INDEX - 1][j] += [1] * (max_ctx_num - count)\n        return selected_memories, new_ys, torch.Tensor(np.array(ctx_mask))\n\n    def pad_ctx_memory(self, memories, ctx_bow_size, raw_queries, query_mentions, queries):\n        cand_ans_size = max(max(map(len, list(zip(*memories))[0]), default=0) - 1, 1) # The last element is a dummy candidate\n        ctx_bow_size = max(min(max(map(len, (a for x in list(zip(*memories))[CTX_BOW_INDEX] for y in x for a in y)), default=0), ctx_bow_size), 1)\n\n        pad_memories = []\n        ctx_mask = []\n        for i in range(len(memories)):\n            n = len(memories[i][0]) - 1 # The last element is a dummy candidate\n            augmented_inds = list(range(n)) + [-1] * (cand_ans_size - n)\n            xx = [n] + [np.array(x)[augmented_inds] for x in memories[i][:CTX_BOW_INDEX]]\n\n            ctx_bow = []\n            ctx_bow_len = []\n            ctx_num = []\n            tmp_ctx_mask = np.zeros(cand_ans_size)\n            for _, idx in enumerate(augmented_inds):\n                tmp_ctx = []\n                tmp_ctx_len = []\n                for ctx_ent_names in memories[i][CTX_BOW_INDEX][idx]:\n                    sub_seq = get_text_overlap(raw_queries[i], query_mentions[i], ctx_ent_names, self.vocab2id, self.ctx_stops, queries[i])\n                    if len(sub_seq) > 0:\n                        tmp_ctx_mask[_] = 1\n                        tmp_ctx.append(sub_seq[:ctx_bow_size] 
+ [config.RESERVED_TOKENS['PAD']] * max(0, ctx_bow_size - len(sub_seq)))\n                        tmp_ctx_len.append(max(min(ctx_bow_size, len(sub_seq)), 1))\n                ctx_bow.append(tmp_ctx)\n                ctx_bow_len.append(tmp_ctx_len)\n                ctx_num.append(len(tmp_ctx))\n\n            xx += [ctx_bow, ctx_bow_len, ctx_num]\n            xx += [np.array(x)[augmented_inds] for x in memories[i][CTX_BOW_INDEX+1:]]\n            pad_memories.append(xx)\n            ctx_mask.append(tmp_ctx_mask)\n\n        max_ctx_num = max(max([y for x in pad_memories for y in x[CTX_BOW_INDEX]]), 1)\n        for i in range(len(pad_memories)): # Example\n            for j in range(len(pad_memories[i][-1])): # Cand\n                count = pad_memories[i][CTX_BOW_INDEX][j]\n                if count < max_ctx_num:\n                    pad_memories[i][CTX_BOW_INDEX - 2][j] += [[config.RESERVED_TOKENS['PAD']] * ctx_bow_size] * (max_ctx_num - count)\n                    pad_memories[i][CTX_BOW_INDEX - 1][j] += [1] * (max_ctx_num - count)\n        return pad_memories, torch.Tensor(np.array(ctx_mask))\n\n    def pack_gold_ans(self, x, N, placeholder=-1):\n        y = np.ones((len(x), N), dtype='int64') * placeholder\n        mask = np.zeros((len(x), N))\n        for i in range(len(x)):\n            y[i, :len(x[i])] = x[i]\n            mask[i, :len(x[i])] = 1\n        return to_cuda(torch.LongTensor(y), self.opt['cuda']), to_cuda(torch.Tensor(mask), self.opt['cuda'])\n\n    def set_loss_margin(self, scores, gold_mask, margin):\n        \"\"\"Since the pytorch built-in MultiLabelMarginLoss fixes the margin as 1.\n        We simply work around this annoying feature by *modifying* the golden scores.\n        E.g., if we want margin as 3, we decrease each golden score by 3 - 1 before\n        feeding it to the built-in loss.\n        \"\"\"\n        new_scores = scores - (margin - 1) * gold_mask\n        return new_scores\n\n    def ranked_predictions(self, cand_labels, scores, 
margin):\n        _, sorted_inds = scores.sort(descending=True, dim=1)\n        return [[(cand_labels[i][j], scores[i][j]) for j in r if scores[i][j] + margin >= scores[i][r[0]] \\\n                and cand_labels[i][j] != 'UNK'] \\\n                if len(cand_labels[i]) > 0 and scores[i][r[0]] > -1e4 else [] \\\n                for i, r in enumerate(sorted_inds)] # Very large negative ones are dummy candidates\n\n    def save(self, path=None):\n        path = self.opt.get('model_file', None) if path is None else path\n\n        if path:\n            checkpoint = {}\n            checkpoint['bamnet'] = self.model.state_dict()\n            checkpoint['bamnet_optim'] = self.optimizers['bamnet'].state_dict()\n            with open(path, 'wb') as write:\n                torch.save(checkpoint, write)\n                print('Saved model to {}'.format(path))\n\n    def load(self, path):\n        with open(path, 'rb') as read:\n            checkpoint = torch.load(read, map_location=lambda storage, loc: storage)\n        self.model.load_state_dict(checkpoint['bamnet'])\n        self.optimizers['bamnet'].load_state_dict(checkpoint['bamnet_optim'])\n"
  },
  {
    "path": "src/core/bamnet/ent_modules.py",
    "content": "'''\nCreated on Sep, 2018\n\n@author: hugo\n\n'''\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence\nimport torch.nn.functional as F\n\nfrom .modules import SeqEncoder, SelfAttention_CoAtt, Attention\nfrom .utils import to_cuda\n\n\nINF = 1e20\nVERY_SMALL_NUMBER = 1e-10\nclass Entnet(nn.Module):\n    def __init__(self, vocab_size, vocab_embed_size, o_embed_size, \\\n        hidden_size, num_ent_types, num_relations, \\\n        seq_enc_type='cnn', \\\n        word_emb_dropout=None, \\\n        que_enc_dropout=None,\\\n        ent_enc_dropout=None, \\\n        pre_w2v=None, \\\n        num_hops=1, \\\n        att='add', \\\n        use_cuda=True):\n        super(Entnet, self).__init__()\n        self.use_cuda = use_cuda\n        self.seq_enc_type = seq_enc_type\n        self.que_enc_dropout = que_enc_dropout\n        self.ent_enc_dropout = ent_enc_dropout\n        self.num_hops = num_hops\n        self.hidden_size = hidden_size\n        self.que_enc = SeqEncoder(vocab_size, vocab_embed_size, hidden_size, \\\n                        seq_enc_type=seq_enc_type, \\\n                        word_emb_dropout=word_emb_dropout, \\\n                        bidirectional=True, \\\n                        cnn_kernel_size=[2, 3], \\\n                        init_word_embed=pre_w2v, \\\n                        use_cuda=use_cuda).que_enc\n\n        self.ent_enc = EntEncoder(o_embed_size, hidden_size, \\\n                        num_ent_types, num_relations, \\\n                        vocab_size=vocab_size, \\\n                        vocab_embed_size=vocab_embed_size, \\\n                        shared_embed=self.que_enc.embed, \\\n                        seq_enc_type=seq_enc_type, \\\n                        word_emb_dropout=word_emb_dropout, \\\n                        ent_enc_dropout=ent_enc_dropout, \\\n                        use_cuda=use_cuda)\n        self.batchnorm = 
nn.BatchNorm1d(hidden_size)\n\n        if seq_enc_type in ('lstm', 'gru'):\n            self.self_atten = SelfAttention_CoAtt(hidden_size)\n            print('[ Using self-attention on question encoder ]')\n\n        self.ent_memory_hop = EntRomHop(hidden_size, hidden_size, hidden_size, atten_type=att)\n        print('[ Using {}-hop entity memory update ]'.format(num_hops))\n\n    def forward(self, memories, queries, query_lengths):\n        x_ent_names, x_ent_name_len, x_type_names, x_types, x_type_name_len, x_rel_names, x_rels, x_rel_name_len, x_rel_mask = memories\n        x_rel_mask = self.create_mask_3D(x_rel_mask, x_rels.size(-1), use_cuda=self.use_cuda)\n\n        # Question encoder\n        if self.seq_enc_type in ('lstm', 'gru'):\n            Q_r = self.que_enc(queries, query_lengths)[0]\n            if self.que_enc_dropout:\n                Q_r = F.dropout(Q_r, p=self.que_enc_dropout, training=self.training)\n\n            query_mask = self.create_mask(query_lengths, Q_r.size(1), self.use_cuda)\n            q_r = self.self_atten(Q_r, query_lengths, query_mask)\n        else:\n            q_r = self.que_enc(queries, query_lengths)[1]\n            if self.que_enc_dropout:\n                q_r = F.dropout(q_r, p=self.que_enc_dropout, training=self.training)\n\n        # Entity encoder\n        ent_val, ent_key = self.ent_enc(x_ent_names, x_ent_name_len, x_type_names, x_types, x_type_name_len, x_rel_names, x_rels, x_rel_name_len, x_rel_mask)\n\n        ent_val = torch.cat([each.unsqueeze(2) for each in ent_val], 2)\n        ent_key = torch.cat([each.unsqueeze(2) for each in ent_key], 2)\n        ent_val = torch.sum(ent_val, 2)\n        ent_key = torch.sum(ent_key, 2)\n\n        mem_hop_scores = []\n        mid_score = self.clf_score(q_r, ent_key)\n        mem_hop_scores.append(mid_score)\n\n        for _ in range(self.num_hops):\n            q_r = q_r + self.ent_memory_hop(q_r, ent_key, ent_val)\n            q_r = self.batchnorm(q_r)\n            mid_score = 
self.clf_score(q_r, ent_key)\n            mem_hop_scores.append(mid_score)\n        return mem_hop_scores\n\n    def clf_score(self, q_r, ent_key):\n        return torch.matmul(ent_key, q_r.unsqueeze(-1)).squeeze(-1)\n\n    def create_mask(self, x, N, use_cuda=True):\n        x = x.data\n        mask = np.zeros((x.size(0), N))\n        for i in range(x.size(0)):\n            mask[i, :x[i]] = 1\n        return to_cuda(torch.Tensor(mask), use_cuda)\n\n    def create_mask_3D(self, x, N, use_cuda=True):\n        x = x.data\n        mask = np.zeros((x.size(0), x.size(1), N))\n        for i in range(x.size(0)):\n            for j in range(x.size(1)):\n                mask[i, j, :x[i, j]] = 1\n        return to_cuda(torch.Tensor(mask), use_cuda)\n\nclass EntEncoder(nn.Module):\n    \"\"\"Entity Encoder\"\"\"\n    def __init__(self, o_embed_size, hidden_size, num_ent_types, num_relations, vocab_size=None, \\\n                    vocab_embed_size=None, shared_embed=None, seq_enc_type='lstm', word_emb_dropout=None, \\\n                    ent_enc_dropout=None, use_cuda=True):\n        super(EntEncoder, self).__init__()\n        # Cannot have embed and vocab_size set as None at the same time.\n        self.ent_enc_dropout = ent_enc_dropout\n        self.hidden_size = hidden_size\n        self.relation_embed = nn.Embedding(num_relations, o_embed_size, padding_idx=0)\n        self.embed = shared_embed if shared_embed is not None else nn.Embedding(vocab_size, vocab_embed_size, padding_idx=0)\n        self.vocab_embed_size = self.embed.weight.data.size(1)\n\n        self.linear_node_name_key = nn.Linear(hidden_size, hidden_size, bias=False)\n        self.linear_node_type_key = nn.Linear(hidden_size, hidden_size, bias=False)\n        self.linear_rels_key = nn.Linear(hidden_size + o_embed_size, hidden_size, bias=False)\n        self.linear_node_name_val = nn.Linear(hidden_size, hidden_size, bias=False)\n        self.linear_node_type_val = nn.Linear(hidden_size, hidden_size, 
bias=False)\n        self.linear_rels_val = nn.Linear(hidden_size + o_embed_size, hidden_size, bias=False)\n\n        self.kg_enc_ent = SeqEncoder(vocab_size, \\\n                        self.vocab_embed_size, \\\n                        hidden_size, \\\n                        seq_enc_type=seq_enc_type, \\\n                        word_emb_dropout=word_emb_dropout, \\\n                        bidirectional=True, \\\n                        cnn_kernel_size=[3], \\\n                        shared_embed=shared_embed, \\\n                        use_cuda=use_cuda).que_enc # entity name\n\n        self.kg_enc_type = SeqEncoder(vocab_size, \\\n                        self.vocab_embed_size, \\\n                        hidden_size, \\\n                        seq_enc_type=seq_enc_type, \\\n                        word_emb_dropout=word_emb_dropout, \\\n                        bidirectional=True, \\\n                        cnn_kernel_size=[3], \\\n                        shared_embed=shared_embed, \\\n                        use_cuda=use_cuda).que_enc # entity type name\n\n        self.kg_enc_rel = SeqEncoder(vocab_size, \\\n                        self.vocab_embed_size, \\\n                        hidden_size, \\\n                        seq_enc_type=seq_enc_type, \\\n                        word_emb_dropout=word_emb_dropout, \\\n                        bidirectional=True, \\\n                        cnn_kernel_size=[3], \\\n                        shared_embed=shared_embed, \\\n                        use_cuda=use_cuda).que_enc # relation name\n\n    def forward(self, x_ent_names, x_ent_name_len, x_type_names, x_types, x_type_name_len, x_rel_names, x_rels, x_rel_name_len, x_rel_mask):\n        node_ent_names, node_type_names, node_types, edge_rel_names, edge_rels = self.enc_kg_features(x_ent_names, x_ent_name_len, x_type_names, x_types, x_type_name_len, x_rel_names, x_rels, x_rel_name_len, x_rel_mask)\n        node_name_key = self.linear_node_name_key(node_ent_names)\n   
     node_type_key = self.linear_node_type_key(node_type_names)\n        rel_key = self.linear_rels_key(torch.cat([edge_rel_names, edge_rels], -1))\n\n        node_name_val = self.linear_node_name_val(node_ent_names)\n        node_type_val = self.linear_node_type_val(node_type_names)\n        rel_val = self.linear_rels_val(torch.cat([edge_rel_names, edge_rels], -1))\n\n        ent_comp_val = [node_name_val, node_type_val, rel_val]\n        ent_comp_key = [node_name_key, node_type_key, rel_key]\n        return ent_comp_val, ent_comp_key\n\n    def enc_kg_features(self, x_ent_names, x_ent_name_len, x_type_names, x_types, x_type_name_len, x_rel_names, x_rels, x_rel_name_len, x_rel_mask):\n        node_ent_names = (self.kg_enc_ent(x_ent_names.view(-1, x_ent_names.size(-1)), x_ent_name_len.view(-1))[1]).view(x_ent_names.size(0), x_ent_names.size(1), -1)\n        node_type_names = (self.kg_enc_type(x_type_names.view(-1, x_type_names.size(-1)), x_type_name_len.view(-1))[1]).view(x_type_names.size(0), x_type_names.size(1), -1)\n        node_types = None\n        edge_rel_names = torch.mean((self.kg_enc_rel(x_rel_names.view(-1, x_rel_names.size(-1)), x_rel_name_len.view(-1))[1]).view(x_rel_names.size(0), x_rel_names.size(1), x_rel_names.size(2), -1), 2)\n        edge_rels = torch.mean(self.relation_embed(x_rels.view(-1, x_rels.size(-1))), 1).view(x_rels.size(0), x_rels.size(1), -1)\n\n        if self.ent_enc_dropout:\n            node_ent_names = F.dropout(node_ent_names, p=self.ent_enc_dropout, training=self.training)\n            node_type_names = F.dropout(node_type_names, p=self.ent_enc_dropout, training=self.training)\n            # node_types = F.dropout(node_types, p=self.ent_enc_dropout, training=self.training)\n            edge_rel_names = F.dropout(edge_rel_names, p=self.ent_enc_dropout, training=self.training)\n            edge_rels = F.dropout(edge_rels, p=self.ent_enc_dropout, training=self.training)\n        return node_ent_names, node_type_names, node_types, 
edge_rel_names, edge_rels\n\n\nclass EntRomHop(nn.Module):\n    def __init__(self, query_embed_size, in_memory_embed_size, hidden_size, atten_type='add'):\n        super(EntRomHop, self).__init__()\n        self.atten = Attention(hidden_size, query_embed_size, in_memory_embed_size, atten_type=atten_type)\n        self.gru_step = GRUStep(hidden_size, in_memory_embed_size)\n\n    def forward(self, h_state, key_memory_embed, val_memory_embed, atten_mask=None):\n        attention = self.atten(h_state, key_memory_embed, atten_mask=atten_mask)\n        probs = torch.softmax(attention, dim=-1)\n        memory_output = torch.bmm(probs.unsqueeze(1), val_memory_embed).squeeze(1)\n        h_state = self.gru_step(h_state, memory_output)\n        return h_state\n\nclass GRUStep(nn.Module):\n    def __init__(self, hidden_size, input_size):\n        super(GRUStep, self).__init__()\n        '''GRU module'''\n        self.linear_z = nn.Linear(hidden_size + input_size, hidden_size, bias=False)\n        self.linear_r = nn.Linear(hidden_size + input_size, hidden_size, bias=False)\n        self.linear_t = nn.Linear(hidden_size + input_size, hidden_size, bias=False)\n\n    def forward(self, h_state, input_):\n        z = torch.sigmoid(self.linear_z(torch.cat([h_state, input_], -1)))\n        r = torch.sigmoid(self.linear_r(torch.cat([h_state, input_], -1)))\n        t = torch.tanh(self.linear_t(torch.cat([r * h_state, input_], -1)))\n        h_state = (1 - z) * h_state + z * t\n        return h_state\n"
  },
  {
    "path": "src/core/bamnet/entnet.py",
    "content": "'''\nCreated on Sep, 2018\n\n@author: hugo\n\n'''\nimport os\nimport timeit\nimport numpy as np\n\nimport torch\nfrom torch import optim\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\nfrom torch.nn import CrossEntropyLoss, MultiLabelMarginLoss\nimport torch.backends.cudnn as cudnn\n\nfrom .ent_modules import Entnet\nfrom .utils import to_cuda, next_ent_batch\nfrom ..utils.utils import load_ndarray\nfrom ..utils.generic_utils import unique\nfrom ..utils.metrics import *\n\n\nclass EntnetAgent(object):\n    def __init__(self, opt):\n        opt['cuda'] = not opt['no_cuda'] and torch.cuda.is_available()\n        if opt['cuda']:\n            print('[ Using CUDA ]')\n            torch.cuda.set_device(opt['gpu'])\n            # It enables benchmark mode in cudnn, which\n            # leads to faster runtime when the input sizes do not vary.\n            cudnn.benchmark = True\n\n        self.opt = opt\n        if self.opt['pre_word2vec']:\n            pre_w2v = load_ndarray(self.opt['pre_word2vec'])\n        else:\n            pre_w2v = None\n\n        self.ent_model = Entnet(opt['vocab_size'], opt['vocab_embed_size'], \\\n                opt['o_embed_size'], opt['hidden_size'], \\\n                opt['num_ent_types'], opt['num_relations'], \\\n                seq_enc_type=opt['seq_enc_type'], \\\n                word_emb_dropout=opt['word_emb_dropout'], \\\n                que_enc_dropout=opt['que_enc_dropout'], \\\n                ent_enc_dropout=opt['ent_enc_dropout'], \\\n                pre_w2v=pre_w2v, \\\n                num_hops=opt['num_ent_hops'], \\\n                att=opt['attention'], \\\n                use_cuda=opt['cuda'])\n        if opt['cuda']:\n            self.ent_model.cuda()\n\n        self.loss_fn = MultiLabelMarginLoss()\n\n        optim_params = [p for p in self.ent_model.parameters() if p.requires_grad]\n        self.optimizers = {'entnet': optim.Adam(optim_params, lr=opt['learning_rate'])}\n        self.scheduler = 
ReduceLROnPlateau(self.optimizers['entnet'], mode='min', \\\n                    patience=self.opt['valid_patience'] // 3, verbose=True)\n\n        if opt.get('model_file') and os.path.isfile(opt['model_file']):\n            print('Loading existing ent_model parameters from ' + opt['model_file'])\n            self.load(opt['model_file'])\n        else:\n            self.save()\n            self.load(opt['model_file'])\n        super(EntnetAgent, self).__init__()\n\n    def train(self, train_X, train_y, valid_X, valid_y, seed=1234):\n        print('Training size: {}, Validation size: {}'.format(len(train_y), len(valid_y)))\n        random1 = np.random.RandomState(seed)\n        random2 = np.random.RandomState(seed)\n        random3 = np.random.RandomState(seed)\n        random4 = np.random.RandomState(seed)\n        memories, queries, query_lengths = train_X\n        ent_inds = train_y\n\n        valid_memories, valid_queries, valid_query_lengths = valid_X\n        valid_ent_inds = valid_y\n\n        n_incr_error = 0  # nb. 
of consecutive increase in error\n        best_loss = float(\"inf\")\n        best_acc = 0\n        num_batches = len(queries) // self.opt['batch_size'] + (len(queries) % self.opt['batch_size'] != 0)\n        num_valid_batches = len(valid_queries) // self.opt['batch_size'] + (len(valid_queries) % self.opt['batch_size'] != 0)\n        for epoch in range(1, self.opt['num_epochs'] + 1):\n            start = timeit.default_timer()\n            n_incr_error += 1\n            random1.shuffle(memories)\n            random2.shuffle(queries)\n            random3.shuffle(query_lengths)\n            random4.shuffle(ent_inds)\n            train_gen = next_ent_batch(memories, queries, query_lengths, ent_inds, self.opt['batch_size'])\n            train_loss = 0\n            for batch_xs, batch_ys in train_gen:\n                train_loss += self.train_step(batch_xs, batch_ys) / num_batches\n\n            valid_gen = next_ent_batch(valid_memories, valid_queries, valid_query_lengths, valid_ent_inds, self.opt['batch_size'])\n            valid_loss = 0\n            for batch_valid_xs, batch_valid_ys in valid_gen:\n                valid_loss += self.train_step(batch_valid_xs, batch_valid_ys, is_training=False) / num_valid_batches\n            self.scheduler.step(valid_loss)\n\n            if epoch > 0:\n                valid_acc = self.evaluate(valid_X, valid_ent_inds, batch_size=1, silence=True)\n                # valid_acc = 0.\n                print('Epoch {}/{}: Runtime: {}s, Training loss: {:.4}, validation loss: {:.4}, validation ACC: {:.4}'.format(epoch, self.opt['num_epochs'], \\\n                                                    int(timeit.default_timer() - start), train_loss, valid_loss, valid_acc))\n\n                # self.scheduler.step(valid_acc)\n                # if valid_acc > best_acc:\n                #     best_acc = valid_acc\n                #     n_incr_error = 0\n                #     self.save()\n\n                if valid_loss < best_loss:\n                
    best_loss = valid_loss\n                    n_incr_error = 0\n                    self.save()\n\n                if n_incr_error >= self.opt['valid_patience']:\n                    print('Early stopping occured. Optimization Finished!')\n                    self.save(self.opt['model_file'] + '.final')\n                    break\n\n    def evaluate(self, xs, ys, batch_size=1, silence=False):\n        '''Prediction scores are returned in the verbose mode.\n        '''\n        if not silence:\n            print('Data size: {}'.format(len(xs[0])))\n        memories, queries, query_lengths = xs\n        gen = next_ent_batch(memories, queries, query_lengths, ys, batch_size)\n        correct = 0\n        num_samples = 0\n        for batch_xs, batch_ys in gen:\n            correct += self.evaluate_step(batch_xs, batch_ys)\n            num_samples += len(batch_ys)\n        acc = 100 * correct / num_samples\n        return acc\n\n    def predict(self, xs, cand_labels, batch_size=1, silence=False):\n        if not silence:\n            print('Data size: {}'.format(len(xs[0])))\n        memories, queries, query_lengths = xs\n        gen = next_ent_batch(memories, queries, query_lengths, cand_labels, batch_size)\n        predictions = []\n        for batch_xs, batch_cands in gen:\n            batch_pred = self.predict_step(batch_xs, batch_cands)\n            predictions.extend(batch_pred)\n        return predictions\n\n    def train_step(self, xs, ys, is_training=True):\n        # Sets the module in training mode.\n        # This has any effect only on modules such as Dropout or BatchNorm.\n        self.ent_model.train(mode=is_training)\n        with torch.set_grad_enabled(is_training):\n            # Organize inputs for network\n            memories = [to_cuda(torch.LongTensor(np.array(x)), self.opt['cuda']) for x in zip(*xs[0])]\n            queries = to_cuda(torch.LongTensor(xs[1]), self.opt['cuda'])\n            query_lengths = to_cuda(torch.LongTensor(xs[2]), 
self.opt['cuda'])\n            mem_hop_scores = self.ent_model(memories, queries, query_lengths)\n            # ys = to_cuda(torch.LongTensor(ys), self.opt['cuda']).squeeze(-1)\n            # Set margin\n            ys, mask_ys = self.pack_gold_ans(ys, mem_hop_scores[-1].size(1), placeholder=-1)\n\n            loss = 0\n            for _, s in enumerate(mem_hop_scores):\n                loss += self.loss_fn(s, ys)\n            loss /= len(mem_hop_scores)\n\n            if is_training:\n                for o in self.optimizers.values():\n                    o.zero_grad()\n                loss.backward()\n                for o in self.optimizers.values():\n                    o.step()\n            return loss.item()\n\n    def evaluate_step(self, xs, ys):\n        self.ent_model.train(mode=False)\n        with torch.set_grad_enabled(False):\n            # Organize inputs for network\n            memories = [to_cuda(torch.LongTensor(np.array(x)), self.opt['cuda']) for x in zip(*xs[0])]\n            queries = to_cuda(torch.LongTensor(xs[1]), self.opt['cuda'])\n            query_lengths = to_cuda(torch.LongTensor(xs[2]), self.opt['cuda'])\n            scores = self.ent_model(memories, queries, query_lengths)[-1]\n            ys = to_cuda(torch.LongTensor(ys), self.opt['cuda']).squeeze(1)\n\n            predictions = scores.max(1)[1].type_as(ys)\n            correct = predictions.eq(ys).sum()\n            return correct.item()\n\n    def predict_step(self, xs, cand_labels):\n        self.ent_model.train(mode=False)\n        with torch.set_grad_enabled(False):\n            # Organize inputs for network\n            memories = [to_cuda(torch.LongTensor(np.array(x)), self.opt['cuda']) for x in zip(*xs[0])]\n            queries = to_cuda(torch.LongTensor(xs[1]), self.opt['cuda'])\n            query_lengths = to_cuda(torch.LongTensor(xs[2]), self.opt['cuda'])\n            scores = self.ent_model(memories, queries, query_lengths)[-1]\n\n            predictions = 
self.ranked_predictions(cand_labels, scores)\n            return predictions\n\n    def pack_gold_ans(self, x, N, placeholder=-1):\n        y = np.ones((len(x), N), dtype='int64') * placeholder\n        mask = np.zeros((len(x), N))\n        for i in range(len(x)):\n            y[i, :len(x[i])] = x[i]\n            mask[i, :len(x[i])] = 1\n        return to_cuda(torch.LongTensor(y), self.opt['cuda']), to_cuda(torch.Tensor(mask), self.opt['cuda'])\n\n    def ranked_predictions(self, cand_labels, scores):\n        _, sorted_inds = scores.sort(descending=True, dim=1)\n        return [cand_labels[i][r[0]] if len(cand_labels[i]) > 0 else '' \\\n                for i, r in enumerate(sorted_inds)]\n\n    def save(self, path=None):\n        path = self.opt.get('model_file', None) if path is None else path\n\n        if path:\n            checkpoint = {}\n            checkpoint['entnet'] = self.ent_model.state_dict()\n            checkpoint['entnet_optim'] = self.optimizers['entnet'].state_dict()\n            with open(path, 'wb') as write:\n                torch.save(checkpoint, write)\n                print('Saved ent_model to {}'.format(path))\n\n    def load(self, path):\n        with open(path, 'rb') as read:\n            checkpoint = torch.load(read, map_location=lambda storage, loc: storage)\n        self.ent_model.load_state_dict(checkpoint['entnet'])\n        self.optimizers['entnet'].load_state_dict(checkpoint['entnet_optim'])\n"
  },
  {
    "path": "src/core/bamnet/modules.py",
    "content": "'''\nCreated on Sep, 2017\n\n@author: hugo\n\n'''\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence\nimport torch.nn.functional as F\n\nfrom .utils import to_cuda\n\n\nINF = 1e20\nVERY_SMALL_NUMBER = 1e-10\nclass BAMnet(nn.Module):\n    def __init__(self, vocab_size, vocab_embed_size, o_embed_size, \\\n        hidden_size, num_ent_types, num_relations, num_query_words, \\\n        word_emb_dropout=None,\\\n        que_enc_dropout=None,\\\n        ans_enc_dropout=None, \\\n        pre_w2v=None, \\\n        num_hops=1, \\\n        att='add', \\\n        use_cuda=True):\n        super(BAMnet, self).__init__()\n        self.use_cuda = use_cuda\n        self.word_emb_dropout = word_emb_dropout\n        self.que_enc_dropout = que_enc_dropout\n        self.ans_enc_dropout = ans_enc_dropout\n        self.num_hops = num_hops\n        self.hidden_size = hidden_size\n        self.que_enc = SeqEncoder(vocab_size, vocab_embed_size, hidden_size, \\\n                        seq_enc_type='lstm', \\\n                        word_emb_dropout=word_emb_dropout, bidirectional=True, \\\n                        init_word_embed=pre_w2v, use_cuda=use_cuda).que_enc\n\n        self.ans_enc = AnsEncoder(o_embed_size, hidden_size, \\\n                        num_ent_types, num_relations, \\\n                        vocab_size=vocab_size, \\\n                        vocab_embed_size=vocab_embed_size, \\\n                        shared_embed=self.que_enc.embed, \\\n                        word_emb_dropout=word_emb_dropout, \\\n                        ans_enc_dropout=ans_enc_dropout, \\\n                        use_cuda=use_cuda)\n\n        self.qw_embed = nn.Embedding(num_query_words, o_embed_size // 8, padding_idx=0)\n        self.batchnorm = nn.BatchNorm1d(hidden_size)\n\n        self.init_atten = Attention(hidden_size, hidden_size, hidden_size, atten_type=att)\n        self.self_atten = 
SelfAttention_CoAtt(hidden_size)\n        print('[ Using self-attention on question encoder ]')\n\n        self.memory_hop = RomHop(hidden_size, hidden_size, hidden_size, atten_type=att)\n        print('[ Using {}-hop memory update ]'.format(self.num_hops))\n\n    def kb_aware_query_enc(self, memories, queries, query_lengths, ans_mask, ctx_mask=None):\n        # Question encoder\n        Q_r = self.que_enc(queries, query_lengths)[0]\n        if self.que_enc_dropout:\n            Q_r = F.dropout(Q_r, p=self.que_enc_dropout, training=self.training)\n\n        query_mask = create_mask(query_lengths, Q_r.size(1), self.use_cuda)\n        q_r_init = self.self_atten(Q_r, query_lengths, query_mask)\n\n        # Answer encoder\n        _, _, _, x_type_bow, x_types, x_type_bow_len, x_path_bow, x_paths, x_path_bow_len, x_ctx_ent, x_ctx_ent_len, x_ctx_ent_num, _, _, _, _ = memories\n        ans_comp_val, ans_comp_key = self.ans_enc(x_type_bow, x_types, x_type_bow_len, x_path_bow, x_paths, x_path_bow_len, x_ctx_ent, x_ctx_ent_len, x_ctx_ent_num)\n        if self.ans_enc_dropout:\n            for _ in range(len(ans_comp_key)):\n                ans_comp_key[_] = F.dropout(ans_comp_key[_], p=self.ans_enc_dropout, training=self.training)\n        # KB memory summary\n        ans_comp_atts = [self.init_atten(q_r_init, each, atten_mask=ans_mask) for each in ans_comp_key]\n        if ctx_mask is not None:\n            ans_comp_atts[-1] = ctx_mask * ans_comp_atts[-1] - (1 - ctx_mask) * INF\n        ans_comp_probs = [torch.softmax(each, dim=-1) for each in ans_comp_atts]\n        memory_summary = []\n        for i, probs in enumerate(ans_comp_probs):\n            memory_summary.append(torch.bmm(probs.unsqueeze(1), ans_comp_val[i]))\n        memory_summary = torch.cat(memory_summary, 1)\n\n        # Co-attention\n        CoAtt = torch.bmm(Q_r, memory_summary.transpose(1, 2)) # co-attention matrix\n        CoAtt = query_mask.unsqueeze(-1) * CoAtt - (1 - query_mask).unsqueeze(-1) * INF\n   
     if ctx_mask is not None:\n            # mask over empty ctx elements\n            ctx_mask_global = (ctx_mask.sum(-1, keepdim=True) > 0).float()\n            CoAtt[:, :, -1] = ctx_mask_global * CoAtt[:, :, -1].clone() - (1 - ctx_mask_global) * INF\n\n        q_att = F.max_pool1d(CoAtt, kernel_size=CoAtt.size(-1)).squeeze(-1)\n        q_att = torch.softmax(q_att, dim=-1)\n        return (ans_comp_val, ans_comp_key), (q_att, Q_r), query_mask\n\n    def forward(self, memories, queries, query_lengths, query_words, ctx_mask=None):\n        ctx_mask = None\n        mem_hop_scores = []\n        ans_mask = create_mask(memories[0], memories[2].size(1), self.use_cuda)\n\n        # Multi-task learning on answer type matching\n        # question word vec\n        self.qw_vec = torch.mean(self.qw_embed(query_words), 1)\n        # answer type vec\n        x_types = memories[4]\n        ans_types = torch.mean(self.ans_enc.ent_type_embed(x_types.view(-1, x_types.size(-1))), 1).view(x_types.size(0), x_types.size(1), -1)\n        qw_anstype_loss = torch.bmm(ans_types, self.qw_vec.unsqueeze(2)).squeeze(2)\n        if ans_mask is not None:\n            qw_anstype_loss = ans_mask * qw_anstype_loss - (1 - ans_mask) * INF # Make dummy candidates have large negative scores\n        mem_hop_scores.append(qw_anstype_loss)\n\n\n        # Kb-aware question attention module\n        (ans_val, ans_key), (q_att, Q_r), query_mask = self.kb_aware_query_enc(memories, queries, query_lengths, ans_mask, ctx_mask=ctx_mask)\n        ans_val = torch.cat([each.unsqueeze(2) for each in ans_val], 2)\n        ans_key = torch.cat([each.unsqueeze(2) for each in ans_key], 2)\n\n        q_r = torch.bmm(q_att.unsqueeze(1), Q_r).squeeze(1)\n        mid_score = self.scoring(ans_key.sum(2), q_r, mask=ans_mask)\n        mem_hop_scores.append(mid_score)\n\n        Q_r, ans_key, ans_val = self.memory_hop(Q_r, ans_key, ans_val, q_att, atten_mask=ans_mask, ctx_mask=ctx_mask, query_mask=query_mask)\n        q_r = 
torch.bmm(q_att.unsqueeze(1), Q_r).squeeze(1)\n        mid_score = self.scoring(ans_key, q_r, mask=ans_mask)\n        mem_hop_scores.append(mid_score)\n\n        # Generalization module\n        for _ in range(self.num_hops):\n            q_r_tmp = self.memory_hop.gru_step(q_r, ans_key, ans_val, atten_mask=ans_mask)\n            q_r = self.batchnorm(q_r + q_r_tmp)\n            mid_score = self.scoring(ans_key, q_r, mask=ans_mask)\n            mem_hop_scores.append(mid_score)\n        return mem_hop_scores\n\n    def premature_score(self, memories, queries, query_lengths, ctx_mask=None):\n        ctx_mask = None\n        ans_mask = create_mask(memories[0], memories[2].size(1), self.use_cuda)\n\n        # Kb-aware question attention module\n        (ans_val, ans_key), (q_att, Q_r), query_mask = self.kb_aware_query_enc(memories, queries, query_lengths, ans_mask, ctx_mask=ctx_mask)\n        ans_key = torch.cat([each.unsqueeze(2) for each in ans_key], 2)\n\n        mem_hop_scores = []\n        q_r = torch.bmm(q_att.unsqueeze(1), Q_r).squeeze(1)\n        score = self.scoring(ans_key.sum(2), q_r, mask=ans_mask)\n        return score\n\n    def scoring(self, ans_r, q_r, mask=None):\n        score = torch.bmm(ans_r, q_r.unsqueeze(2)).squeeze(2)\n        if mask is not None:\n            score = mask * score - (1 - mask) * INF # Make dummy candidates have large negative scores\n        return score\n\nclass RomHop(nn.Module):\n    def __init__(self, query_embed_size, in_memory_embed_size, hidden_size, atten_type='add'):\n        super(RomHop, self).__init__()\n        self.hidden_size = hidden_size\n        self.gru_linear_z = nn.Linear(2 * hidden_size, hidden_size, bias=False)\n        self.gru_linear_r = nn.Linear(2 * hidden_size, hidden_size, bias=False)\n        self.gru_linear_t = nn.Linear(2 * hidden_size, hidden_size, bias=False)\n        self.gru_atten = Attention(hidden_size, query_embed_size, in_memory_embed_size, atten_type=atten_type)\n\n    def forward(self, 
query_embed, in_memory_embed, out_memory_embed, query_att, \\\n                atten_mask=None, ctx_mask=None, query_mask=None):\n        output = self.update_coatt_cat_maxpool(query_embed, in_memory_embed, out_memory_embed, query_att, \\\n                    atten_mask=atten_mask, ctx_mask=ctx_mask, query_mask=query_mask)\n        return output\n\n    def gru_step(self, h_state, in_memory_embed, out_memory_embed, atten_mask=None):\n        attention = self.gru_atten(h_state, in_memory_embed, atten_mask=atten_mask)\n        probs = torch.softmax(attention, dim=-1)\n\n        memory_output = torch.bmm(probs.unsqueeze(1), out_memory_embed).squeeze(1)\n        # GRU-like memory update\n        z = torch.sigmoid(self.gru_linear_z(torch.cat([h_state, memory_output], -1)))\n        r = torch.sigmoid(self.gru_linear_r(torch.cat([h_state, memory_output], -1)))\n        t = torch.tanh(self.gru_linear_t(torch.cat([r * h_state, memory_output], -1)))\n        output = (1 - z) * h_state + z * t\n        return output\n\n    def update_coatt_cat_maxpool(self, query_embed, in_memory_embed, out_memory_embed, query_att, atten_mask=None, ctx_mask=None, query_mask=None):\n        attention = torch.bmm(query_embed, in_memory_embed.view(in_memory_embed.size(0), -1, in_memory_embed.size(-1))\\\n            .transpose(1, 2)).view(query_embed.size(0), query_embed.size(1), in_memory_embed.size(1), -1) # bs * N * M * k\n        if ctx_mask is not None:\n            attention[:, :, :, -1] = ctx_mask.unsqueeze(1) * attention[:, :, :, -1].clone() - (1 - ctx_mask).unsqueeze(1) * INF\n        if atten_mask is not None:\n            attention = atten_mask.unsqueeze(1).unsqueeze(-1) * attention - (1 - atten_mask).unsqueeze(1).unsqueeze(-1) * INF\n        if query_mask is not None:\n            attention = query_mask.unsqueeze(2).unsqueeze(-1) * attention - (1 - query_mask).unsqueeze(2).unsqueeze(-1) * INF\n\n        # Importance module\n        kb_feature_att = 
F.max_pool1d(attention.view(attention.size(0), attention.size(1), -1).transpose(1, 2), kernel_size=attention.size(1)).squeeze(-1).view(attention.size(0), -1, attention.size(-1))\n        kb_feature_att = torch.softmax(kb_feature_att, dim=-1).view(-1, kb_feature_att.size(-1)).unsqueeze(1)\n        in_memory_embed = torch.bmm(kb_feature_att, in_memory_embed.view(-1, in_memory_embed.size(2), in_memory_embed.size(-1))).squeeze(1).view(in_memory_embed.size(0), in_memory_embed.size(1), -1)\n        out_memory_embed = out_memory_embed.sum(2)\n\n        # Enhanced module\n        attention = F.max_pool1d(attention.view(attention.size(0), -1, attention.size(-1)), kernel_size=attention.size(-1)).squeeze(-1).view(attention.size(0), attention.size(1), attention.size(2))\n        probs = torch.softmax(attention, dim=-1)\n        new_query_embed = query_embed + query_att.unsqueeze(2) * torch.bmm(probs, out_memory_embed)\n\n        probs2 = torch.softmax(attention, dim=1)\n        kb_att = torch.bmm(query_att.unsqueeze(1), probs).squeeze(1)\n        in_memory_embed = in_memory_embed + kb_att.unsqueeze(2) * torch.bmm(probs2.transpose(1, 2), new_query_embed)\n        return new_query_embed, in_memory_embed, out_memory_embed\n\nclass AnsEncoder(nn.Module):\n    \"\"\"Answer Encoder\"\"\"\n    def __init__(self, o_embed_size, hidden_size, num_ent_types, num_relations, vocab_size=None, \\\n                    vocab_embed_size=None, shared_embed=None, word_emb_dropout=None, \\\n                    ans_enc_dropout=None, use_cuda=True):\n        super(AnsEncoder, self).__init__()\n        # Cannot have embed and vocab_size set as None at the same time.\n        self.use_cuda = use_cuda\n        self.ans_enc_dropout = ans_enc_dropout\n        self.hidden_size = hidden_size\n        self.ent_type_embed = nn.Embedding(num_ent_types, o_embed_size // 8, padding_idx=0)\n        self.relation_embed = nn.Embedding(num_relations, o_embed_size, padding_idx=0)\n        self.embed = shared_embed if 
shared_embed is not None else nn.Embedding(vocab_size, vocab_embed_size, padding_idx=0)\n        self.vocab_embed_size = self.embed.weight.data.size(1)\n\n        self.linear_type_bow_key = nn.Linear(hidden_size, hidden_size, bias=False)\n        self.linear_paths_key = nn.Linear(hidden_size + o_embed_size, hidden_size, bias=False)\n        self.linear_ctx_key = nn.Linear(hidden_size, hidden_size, bias=False)\n        self.linear_type_bow_val = nn.Linear(hidden_size, hidden_size, bias=False)\n        self.linear_paths_val = nn.Linear(hidden_size + o_embed_size, hidden_size, bias=False)\n        self.linear_ctx_val = nn.Linear(hidden_size, hidden_size, bias=False)\n\n        # lstm for ans encoder\n        self.lstm_enc_type = EncoderRNN(vocab_size, self.vocab_embed_size, hidden_size, \\\n                        dropout=word_emb_dropout, \\\n                        bidirectional=True, \\\n                        shared_embed=shared_embed, \\\n                        rnn_type='lstm', \\\n                        use_cuda=use_cuda)\n        self.lstm_enc_path = EncoderRNN(vocab_size, self.vocab_embed_size, hidden_size, \\\n                        dropout=word_emb_dropout, \\\n                        bidirectional=True, \\\n                        shared_embed=shared_embed, \\\n                        rnn_type='lstm', \\\n                        use_cuda=use_cuda)\n        self.lstm_enc_ctx = EncoderRNN(vocab_size, self.vocab_embed_size, hidden_size, \\\n                        dropout=word_emb_dropout, \\\n                        bidirectional=True, \\\n                        shared_embed=shared_embed, \\\n                        rnn_type='lstm', \\\n                        use_cuda=use_cuda)\n\n    def forward(self, x_type_bow, x_types, x_type_bow_len, x_path_bow, x_paths, x_path_bow_len, x_ctx_ents, x_ctx_ent_len, x_ctx_ent_num):\n        ans_type_bow, ans_types, ans_path_bow, ans_paths, ans_ctx_ent = self.enc_ans_features(x_type_bow, x_types, x_type_bow_len, 
x_path_bow, x_paths, x_path_bow_len, x_ctx_ents, x_ctx_ent_len, x_ctx_ent_num)\n        ans_val, ans_key = self.enc_comp_kv(ans_type_bow, ans_types, ans_path_bow, ans_paths, ans_ctx_ent)\n        return ans_val, ans_key\n\n    def enc_comp_kv(self, ans_type_bow, ans_types, ans_path_bow, ans_paths, ans_ctx_ent):\n        ans_type_bow_val = self.linear_type_bow_val(ans_type_bow)\n        ans_paths_val = self.linear_paths_val(torch.cat([ans_path_bow, ans_paths], -1))\n        ans_ctx_val = self.linear_ctx_val(ans_ctx_ent)\n\n        ans_type_bow_key = self.linear_type_bow_key(ans_type_bow)\n        ans_paths_key = self.linear_paths_key(torch.cat([ans_path_bow, ans_paths], -1))\n        ans_ctx_key = self.linear_ctx_key(ans_ctx_ent)\n\n        ans_comp_val = [ans_type_bow_val, ans_paths_val, ans_ctx_val]\n        ans_comp_key = [ans_type_bow_key, ans_paths_key, ans_ctx_key]\n        return ans_comp_val, ans_comp_key\n\n    def enc_ans_features(self, x_type_bow, x_types, x_type_bow_len, x_path_bow, x_paths, x_path_bow_len, x_ctx_ents, x_ctx_ent_len, x_ctx_ent_num):\n        '''\n        x_types: answer type\n        x_paths: answer path, i.e., bow of relation\n        x_ctx_ents: answer context, i.e., bow of entity words, (batch_size, num_cands, num_ctx, L)\n        '''\n        # ans_types = torch.mean(self.ent_type_embed(x_types.view(-1, x_types.size(-1))), 1).view(x_types.size(0), x_types.size(1), -1)\n        ans_type_bow = (self.lstm_enc_type(x_type_bow.view(-1, x_type_bow.size(-1)), x_type_bow_len.view(-1))[1]).view(x_type_bow.size(0), x_type_bow.size(1), -1)\n        ans_path_bow = (self.lstm_enc_path(x_path_bow.view(-1, x_path_bow.size(-1)), x_path_bow_len.view(-1))[1]).view(x_path_bow.size(0), x_path_bow.size(1), -1)\n        ans_paths = torch.mean(self.relation_embed(x_paths.view(-1, x_paths.size(-1))), 1).view(x_paths.size(0), x_paths.size(1), -1)\n\n        # Avg over ctx\n        ctx_num_mask = create_mask(x_ctx_ent_num.view(-1), x_ctx_ents.size(2), 
self.use_cuda).view(x_ctx_ent_num.shape + (-1,))\n        ans_ctx_ent = (self.lstm_enc_ctx(x_ctx_ents.view(-1, x_ctx_ents.size(-1)), x_ctx_ent_len.view(-1))[1]).view(x_ctx_ents.size(0), x_ctx_ents.size(1), x_ctx_ents.size(2), -1)\n        ans_ctx_ent = ctx_num_mask.unsqueeze(-1) * ans_ctx_ent\n        ans_ctx_ent = torch.sum(ans_ctx_ent, dim=2) / torch.clamp(x_ctx_ent_num.float().unsqueeze(-1), min=VERY_SMALL_NUMBER)\n\n        if self.ans_enc_dropout:\n            # ans_types = F.dropout(ans_types, p=self.ans_enc_dropout, training=self.training)\n            ans_type_bow = F.dropout(ans_type_bow, p=self.ans_enc_dropout, training=self.training)\n            ans_path_bow = F.dropout(ans_path_bow, p=self.ans_enc_dropout, training=self.training)\n            ans_paths = F.dropout(ans_paths, p=self.ans_enc_dropout, training=self.training)\n            ans_ctx_ent = F.dropout(ans_ctx_ent, p=self.ans_enc_dropout, training=self.training)\n        return ans_type_bow, None, ans_path_bow, ans_paths, ans_ctx_ent\n\nclass SeqEncoder(object):\n    \"\"\"Question Encoder\"\"\"\n    def __init__(self, vocab_size, embed_size, hidden_size, \\\n                seq_enc_type='lstm', word_emb_dropout=None,\n                cnn_kernel_size=[3], bidirectional=False, \\\n                shared_embed=None, init_word_embed=None, use_cuda=True):\n        if seq_enc_type in ('lstm', 'gru'):\n            self.que_enc = EncoderRNN(vocab_size, embed_size, hidden_size, \\\n                        dropout=word_emb_dropout, \\\n                        bidirectional=bidirectional, \\\n                        shared_embed=shared_embed, \\\n                        init_word_embed=init_word_embed, \\\n                        rnn_type=seq_enc_type, \\\n                        use_cuda=use_cuda)\n\n        elif seq_enc_type == 'cnn':\n            self.que_enc = EncoderCNN(vocab_size, embed_size, hidden_size, \\\n                        kernel_size=cnn_kernel_size, dropout=word_emb_dropout, \\\n          
              shared_embed=shared_embed, \\\n                        init_word_embed=init_word_embed, \\\n                        use_cuda=use_cuda)\n        else:\n            raise RuntimeError('Unknown SeqEncoder type: {}'.format(seq_enc_type))\n\nclass EncoderRNN(nn.Module):\n    def __init__(self, vocab_size, embed_size, hidden_size, dropout=None, \\\n        bidirectional=False, shared_embed=None, init_word_embed=None, rnn_type='lstm', use_cuda=True):\n        super(EncoderRNN, self).__init__()\n        if not rnn_type in ('lstm', 'gru'):\n            raise RuntimeError('rnn_type is expected to be lstm or gru, got {}'.format(rnn_type))\n        if bidirectional:\n            print('[ Using bidirectional {} encoder ]'.format(rnn_type))\n        else:\n            print('[ Using {} encoder ]'.format(rnn_type))\n        if bidirectional and hidden_size % 2 != 0:\n            raise RuntimeError('hidden_size is expected to be even in the bidirectional mode!')\n        self.dropout = dropout\n        self.rnn_type = rnn_type\n        self.use_cuda = use_cuda\n        self.hidden_size = hidden_size // 2 if bidirectional else hidden_size\n        self.num_directions = 2 if bidirectional else 1\n        self.embed = shared_embed if shared_embed is not None else nn.Embedding(vocab_size, embed_size, padding_idx=0)\n        model = nn.LSTM if rnn_type == 'lstm' else nn.GRU\n        self.model = model(embed_size, self.hidden_size, 1, batch_first=True, bidirectional=bidirectional)\n        if shared_embed is None:\n            self.init_weights(init_word_embed)\n\n    def init_weights(self, init_word_embed):\n        if init_word_embed is not None:\n            print('[ Using pretrained word embeddings ]')\n            self.embed.weight.data.copy_(torch.from_numpy(init_word_embed))\n        else:\n            self.embed.weight.data.uniform_(-0.08, 0.08)\n\n    def forward(self, x, x_len):\n        \"\"\"x: [batch_size * max_length]\n           x_len: [batch_size]\n        
\"\"\"\n        x = self.embed(x)\n        if self.dropout:\n            x = F.dropout(x, p=self.dropout, training=self.training)\n\n        sorted_x_len, indx = torch.sort(x_len, 0, descending=True)\n        x = pack_padded_sequence(x[indx], sorted_x_len.data.tolist(), batch_first=True)\n\n        h0 = to_cuda(torch.zeros(self.num_directions, x_len.size(0), self.hidden_size), self.use_cuda)\n        if self.rnn_type == 'lstm':\n            c0 = to_cuda(torch.zeros(self.num_directions, x_len.size(0), self.hidden_size), self.use_cuda)\n            packed_h, (packed_h_t, _) = self.model(x, (h0, c0))\n            if self.num_directions == 2:\n                packed_h_t = torch.cat([packed_h_t[i] for i in range(packed_h_t.size(0))], -1)\n        else:\n            packed_h, packed_h_t = self.model(x, h0)\n            if self.num_directions == 2:\n                packed_h_t = packed_h_t.transpose(0, 1).contiguous().view(x_len.size(0), -1)\n\n        hh, _ = pad_packed_sequence(packed_h, batch_first=True)\n\n        # restore the sorting\n        _, inverse_indx = torch.sort(indx, 0)\n        restore_hh = hh[inverse_indx]\n        restore_packed_h_t = packed_h_t[inverse_indx]\n        return restore_hh, restore_packed_h_t\n\n\nclass EncoderCNN(nn.Module):\n    def __init__(self, vocab_size, embed_size, hidden_size, kernel_size=[2, 3], \\\n            dropout=None, shared_embed=None, init_word_embed=None, use_cuda=True):\n        super(EncoderCNN, self).__init__()\n        print('[ Using CNN encoder with kernel size: {} ]'.format(kernel_size))\n        self.use_cuda = use_cuda\n        self.dropout = dropout\n        self.embed = shared_embed if shared_embed is not None else nn.Embedding(vocab_size, embed_size, padding_idx=0)\n        self.cnns = nn.ModuleList([nn.Conv1d(embed_size, hidden_size, kernel_size=k, padding=k-1) for k in kernel_size])\n\n        if len(kernel_size) > 1:\n            self.fc = nn.Linear(len(kernel_size) * hidden_size, hidden_size)\n      
  if shared_embed is None:\n            self.init_weights(init_word_embed)\n\n    def init_weights(self, init_word_embed):\n        if init_word_embed is not None:\n            print('[ Using pretrained word embeddings ]')\n            self.embed.weight.data.copy_(torch.from_numpy(init_word_embed))\n        else:\n            self.embed.weight.data.uniform_(-0.08, 0.08)\n\n    def forward(self, x, x_len=None):\n        \"\"\"x: [batch_size * max_length]\n           x_len: reserved\n        \"\"\"\n        x = self.embed(x)\n        if self.dropout:\n            x = F.dropout(x, p=self.dropout, training=self.training)\n        # Trun(batch_size, seq_len, embed_size) to (batch_size, embed_size, seq_len) for cnn1d\n        x = x.transpose(1, 2)\n        z = [conv(x) for conv in self.cnns]\n        output = [F.max_pool1d(i, kernel_size=i.size(-1)).squeeze(-1) for i in z]\n\n        if len(output) > 1:\n            output = self.fc(torch.cat(output, -1))\n        else:\n            output = output[0]\n        return None, output\n\n\nclass Attention(nn.Module):\n    def __init__(self, hidden_size, h_state_embed_size=None, in_memory_embed_size=None, atten_type='simple'):\n        super(Attention, self).__init__()\n        self.atten_type = atten_type\n        if not h_state_embed_size:\n            h_state_embed_size = hidden_size\n        if not in_memory_embed_size:\n            in_memory_embed_size = hidden_size\n        if atten_type in ('mul', 'add'):\n            self.W = torch.Tensor(h_state_embed_size, hidden_size)\n            self.W = nn.Parameter(nn.init.xavier_uniform_(self.W))\n            if atten_type == 'add':\n                self.W2 = torch.Tensor(in_memory_embed_size, hidden_size)\n                self.W2 = nn.Parameter(nn.init.xavier_uniform_(self.W2))\n                self.W3 = torch.Tensor(hidden_size, 1)\n                self.W3 = nn.Parameter(nn.init.xavier_uniform_(self.W3))\n        elif atten_type == 'simple':\n            pass\n        else:\n 
           raise RuntimeError('Unknown atten_type: {}'.format(self.atten_type))\n\n    def forward(self, query_embed, in_memory_embed, atten_mask=None):\n        if self.atten_type == 'simple': # simple attention\n            attention = torch.bmm(in_memory_embed, query_embed.unsqueeze(2)).squeeze(2)\n        elif self.atten_type == 'mul': # multiplicative attention\n            attention = torch.bmm(in_memory_embed, torch.mm(query_embed, self.W).unsqueeze(2)).squeeze(2)\n        elif self.atten_type == 'add': # additive attention\n            attention = torch.tanh(torch.mm(in_memory_embed.view(-1, in_memory_embed.size(-1)), self.W2)\\\n                .view(in_memory_embed.size(0), -1, self.W2.size(-1)) \\\n                + torch.mm(query_embed, self.W).unsqueeze(1))\n            attention = torch.mm(attention.view(-1, attention.size(-1)), self.W3).view(attention.size(0), -1)\n        else:\n            raise RuntimeError('Unknown atten_type: {}'.format(self.atten_type))\n\n        if atten_mask is not None:\n            # Exclude masked elements from the softmax\n            attention = atten_mask * attention - (1 - atten_mask) * INF\n        return attention\n\nclass SelfAttention_CoAtt(nn.Module):\n    def __init__(self, hidden_size, use_cuda=True):\n        super(SelfAttention_CoAtt, self).__init__()\n        self.use_cuda = use_cuda\n        self.hidden_size = hidden_size\n        self.model = nn.LSTM(2 * hidden_size, hidden_size // 2, batch_first=True, bidirectional=True)\n\n    def forward(self, x, x_len, atten_mask):\n        CoAtt = torch.bmm(x, x.transpose(1, 2))\n        CoAtt = atten_mask.unsqueeze(1) * CoAtt - (1 - atten_mask).unsqueeze(1) * INF\n        CoAtt = torch.softmax(CoAtt, dim=-1)\n        new_x = torch.cat([torch.bmm(CoAtt, x), x], -1)\n\n        sorted_x_len, indx = torch.sort(x_len, 0, descending=True)\n        new_x = pack_padded_sequence(new_x[indx], sorted_x_len.data.tolist(), batch_first=True)\n\n        h0 = to_cuda(torch.zeros(2, 
x_len.size(0), self.hidden_size // 2), self.use_cuda)\n        c0 = to_cuda(torch.zeros(2, x_len.size(0), self.hidden_size // 2), self.use_cuda)\n        packed_h, (packed_h_t, _) = self.model(new_x, (h0, c0))\n\n        # restore the sorting\n        _, inverse_indx = torch.sort(indx, 0)\n        packed_h_t = torch.cat([packed_h_t[i] for i in range(packed_h_t.size(0))], -1)\n        restore_packed_h_t = packed_h_t[inverse_indx]\n        output = restore_packed_h_t\n        return output\n\ndef create_mask(x, N, use_cuda=True):\n    x = x.data\n    mask = np.zeros((x.size(0), N))\n    for i in range(x.size(0)):\n        mask[i, :x[i]] = 1\n    return to_cuda(torch.Tensor(mask), use_cuda)\n"
  },
  {
    "path": "src/core/bamnet/utils.py",
    "content": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nimport torch\nfrom torch.autograd import Variable\nimport numpy as np\n\n\ndef to_cuda(x, use_cuda=True):\n    if use_cuda and torch.cuda.is_available():\n        x = x.cuda()\n    return x\n\n# One pass over the dataset\ndef next_batch(memories, queries, query_words, raw_queries, query_mentions, query_lengths, gold_ans_inds, batch_size):\n    for i in range(0, len(memories), batch_size):\n        yield (memories[i: i + batch_size], queries[i: i + batch_size], query_words[i: i + batch_size], raw_queries[i: i + batch_size], query_mentions[i: i + batch_size], query_lengths[i: i + batch_size]), gold_ans_inds[i: i + batch_size]\n\n# One pass over the dataset\ndef next_ent_batch(memories, queries, query_lengths, gold_inds, batch_size):\n    for i in range(0, len(memories), batch_size):\n        yield (memories[i: i + batch_size], queries[i: i + batch_size], query_lengths[i: i + batch_size]), gold_inds[i: i + batch_size]\n"
  },
  {
    "path": "src/core/build_data/__init__.py",
    "content": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\n"
  },
  {
    "path": "src/core/build_data/build_all.py",
    "content": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nimport os\n\nfrom . import utils as build_utils\nfrom ..utils.utils import *\nfrom .build_data import build_vocab, build_data\n\n\ndef build(dpath, version=None, out_dir=None):\n    if not build_utils.built(dpath, version_string=version):\n        raise RuntimeError(\"Please build/preprocess the data by running the build_all_data.py script!\")\n"
  },
  {
    "path": "src/core/build_data/build_data.py",
    "content": "'''\nCreated on Sep, 2017\n\n@author: hugo\n\n'''\nimport os\nimport math\nimport argparse\nfrom itertools import count\nfrom rapidfuzz import fuzz, process\nfrom collections import defaultdict\n\nfrom ..utils.utils import *\nfrom ..utils.generic_utils import normalize_answer, unique\nfrom ..utils.freebase_utils import if_filterout\nfrom .. import config\n\n\nIGNORE_DUMMY = True\nENT_TYPE_HOP = 1\n# Entity mention types: 'NP', 'ORGANIZATION', 'DATE', 'NUMBER', 'MISC', 'ORDINAL', 'DURATION', 'PERSON', 'TIME', 'LOCATION'\n\ndef build_kb_data(kb, used_fbkeys=None):\n    entities = defaultdict(int)\n    entity_types = defaultdict(int)\n    relations = defaultdict(int)\n    vocabs = defaultdict(int)\n    if not used_fbkeys:\n        used_fbkeys = kb.keys()\n    for k in used_fbkeys:\n        if not k in kb:\n            continue\n        v = kb[k]\n        entities[v['id']] += 1\n        # We prefer notable_types than type since they are more representative.\n        # If notable_types are not available, we use only the first available type.\n        # We found the type field contains much noise.\n        selected_types = (v['notable_types'] + v['type'])[:ENT_TYPE_HOP]\n        for ent_type in selected_types:\n            entity_types[ent_type] += 1\n        for token in [y for x in selected_types for y in x.lower().split('/')[-1].split('_')]:\n            vocabs[token] += 1\n        # Add entity vocabs\n        selected_names = v['name'][:1] + v['alias'] # We need all topic entity alias\n        for token in [y for x in selected_names for y in tokenize(x.lower())]:\n            vocabs[token] += 1\n        if not 'neighbors' in v:\n            continue\n        for kk, vv in v['neighbors'].items(): # 1st hop\n            if if_filterout(kk):\n                continue\n            relations[kk] += 1\n            # Add relation vocabs\n            for token in [x for x in kk.lower().split('/')[-1].split('_')]:\n                vocabs[token] += 1\n          
  for nbr in vv:\n                if isinstance(nbr, str):\n                    for token in [y for y in tokenize(nbr.lower())]:\n                        vocabs[token] += 1\n                    continue\n                elif isinstance(nbr, bool):\n                    continue\n                elif isinstance(nbr, float):\n                    continue\n                    # vocabs.update([y for y in tokenize(str(nbr).lower())])\n                elif isinstance(nbr, dict):\n                    nbr_k = list(nbr.keys())[0]\n                    nbr_v = nbr[nbr_k]\n                    entities[nbr_k] += 1\n                    selected_types = (nbr_v['notable_types'] + nbr_v['type'])[:ENT_TYPE_HOP]\n                    for ent_type in selected_types:\n                        entity_types[ent_type] += 1\n                    selected_names = (nbr_v['name'] + nbr_v['alias'])[:1]\n                    for token in [y for x in selected_names for y in tokenize(x.lower())] + \\\n                        [y for x in selected_types for y in x.lower().split('/')[-1].split('_')]:\n                        vocabs[token] += 1\n                    if not 'neighbors' in nbr_v:\n                        continue\n                    for kkk, vvv in nbr_v['neighbors'].items(): # 2nd hop\n                        if if_filterout(kkk):\n                            continue\n                        relations[kkk] += 1\n                        # Add relation vocabs\n                        for token in [x for x in kkk.lower().split('/')[-1].split('_')]:\n                            vocabs[token] += 1\n                        for nbr_nbr in vvv:\n                            if isinstance(nbr_nbr, str):\n                                for token in [y for y in tokenize(nbr_nbr.lower())]:\n                                    vocabs[token] += 1\n                                continue\n                            elif isinstance(nbr_nbr, bool):\n                                continue\n             
               elif isinstance(nbr_nbr, float):\n                                # vocabs.update([y for y in tokenize(str(nbr_nbr).lower())])\n                                continue\n                            elif isinstance(nbr_nbr, dict):\n                                nbr_nbr_k = list(nbr_nbr.keys())[0]\n                                nbr_nbr_v = nbr_nbr[nbr_nbr_k]\n                                entities[nbr_nbr_k] += 1\n                                selected_types = (nbr_nbr_v['notable_types'] + nbr_nbr_v['type'])[:ENT_TYPE_HOP]\n                                for ent_type in selected_types:\n                                    entity_types[ent_type] += 1\n                                selected_names = (nbr_nbr_v['name'] + nbr_nbr_v['alias'])[:1]\n                                for token in [y for x in selected_names for y in tokenize(x.lower())] + \\\n                                    [y for x in selected_types for y in x.lower().split('/')[-1].split('_')]:\n                                    vocabs[token] += 1\n                            else:\n                                raise RuntimeError('Unknown type: %s' % type(nbr_nbr))\n                else:\n                    raise RuntimeError('Unknown type: %s' % type(nbr))\n    return (entities, entity_types, relations, vocabs)\n\ndef build_qa_vocab(qa):\n    vocabs = defaultdict(int)\n    for each in qa:\n        for token in tokenize(each['qText'].lower()):\n            vocabs[token] += 1\n    return vocabs\n\ndef delex_query_topic_ent(query, topic_ent, ent_types):\n    query = tokenize(query.lower())\n    if topic_ent == '':\n        return query, None\n\n    ent_type_dict = {}\n    for ent, type_ in ent_types:\n        if ent not in ent_type_dict:\n            ent_type_dict[ent] = type_\n        else:\n            if ent_type_dict[ent] == 'NP':\n                ent_type_dict[ent] = type_\n\n    ret = process.extract(topic_ent.replace('_', ' '), set(list(zip(*ent_types))[0]), 
scorer=fuzz.token_sort_ratio)\n    if len(ret) == 0:\n        return query, None\n\n    # We prefer Non-NP entity mentions\n    # e.g., we prefer `uk` than `people in the uk` when matching `united_kingdom`\n    topic_men = None\n    topic_score = None\n    for token, score in ret:\n        if ent_type_dict[token].lower() in config.topic_mention_types:\n            topic_men = token\n            topic_score = score\n            break\n\n    if topic_men is None:\n        return query, None\n\n    topic_ent_type = ent_type_dict[topic_men].lower()\n    topic_tokens = tokenize(topic_men.lower())\n    indices = [i for i, x in enumerate(query) if x == topic_tokens[0]]\n    for i in indices:\n        if query[i: i + len(topic_tokens)] == topic_tokens:\n            start_idx = i\n            end_idx = i + len(topic_tokens)\n            break\n    query_template = query[:start_idx] + [topic_ent_type] + query[end_idx:]\n    return query_template, topic_men\n\ndef delex_query(query, ent_mens, mention_types):\n    for men, type_ in ent_mens:\n        type_ = type_.lower()\n        if type_ in mention_types:\n            men = tokenize(men.lower())\n            indices = [i for i, x in enumerate(query) if x == men[0]]\n            start_idx = None\n            for i in indices:\n                if query[i: i + len(men)] == men:\n                    start_idx = i\n                    end_idx = i + len(men)\n                    break\n            if start_idx is not None:\n                query = query[:start_idx] + ['__{}__'.format(type_)] + query[end_idx:]\n    return query\n\ndef build_data(qa, kb, entity2id, entityType2id, relation2id, vocab2id, pred_seed_ents=None):\n    queries = []\n    raw_queries = []\n    query_mentions = []\n    memories = []\n    cand_labels = [] # Candidate answer labels (i.e., names)\n    gold_ans_labels = [] # True gold answer labels\n    gold_ans_inds = [] # The \"gold\" answer indices corresponding to the cand list\n    for qid, each in 
enumerate(qa):\n        freebase_key = each['freebaseKey'] if not pred_seed_ents else pred_seed_ents[qid]\n        if isinstance(freebase_key, list):\n            freebase_key = freebase_key[0] if len(freebase_key) > 0 else ''\n        # Convert query to query template\n        query, topic_men = delex_query_topic_ent(each['qText'], freebase_key, each['entities'])\n        query2 = delex_query(query, each['entities'], config.delex_mention_types)\n        q = [vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for x in query2]\n        queries.append(q)\n        raw_queries.append(query)\n\n        query_mentions.append([(tokenize(x[0].lower()), x[1].lower()) for x in each['entities'] if topic_men != x[0]])\n        gold_ans_labels.append(each['answers'])\n\n        if not freebase_key in kb:\n            gold_ans_inds.append([])\n            memories.append([[]] * 8)\n            cand_labels.append([])\n            continue\n\n        ans_cands = build_ans_cands(kb[freebase_key], entity2id, entityType2id, relation2id, vocab2id)\n        memories.append(ans_cands[:-1])\n        cand_labels.append(ans_cands[-1])\n        if len(ans_cands[0]) == 0:\n            gold_ans_inds.append([])\n            continue\n\n        norm_cand_labels = [normalize_answer(x) for x in ans_cands[-1]]\n        tmp_cand_inds = []\n        for a in each['answers']:\n            a = normalize_answer(a)\n            # Find all the candidiate answers which match the gold answer.\n            inds = [i for i, j in zip(count(), norm_cand_labels) if j == a]\n            tmp_cand_inds.extend(inds)\n        # Note that tmp_cand_inds can be empty in which case\n        # the question can *NOT* be answered by this KB entity.\n        gold_ans_inds.append(tmp_cand_inds)\n    return (queries, raw_queries, query_mentions, memories, cand_labels, gold_ans_inds, gold_ans_labels)\n\ndef build_vocab(data, freebase, used_fbkeys=None, min_freq=1):\n    entities, entity_types, relations, kb_vocabs 
= build_kb_data(freebase, used_fbkeys)\n\n    # Entity\n    all_entities = set({ent for ent in entities if entities[ent] >= min_freq})\n    entity2id = dict(zip(all_entities, range(len(config.RESERVED_ENTS), len(all_entities) + len(config.RESERVED_ENTS))))\n    for ent, idx in config.RESERVED_ENTS.items():\n        entity2id.update({ent: idx})\n\n    # Entity type\n    all_ent_types = set({ent_type for ent_type in entity_types if entity_types[ent_type] >= min_freq})\n    all_ent_types.update(config.extra_ent_types)\n    entityType2id = dict(zip(all_ent_types, range(len(config.RESERVED_ENT_TYPES), len(all_ent_types) + len(config.RESERVED_ENT_TYPES))))\n    for ent_type, idx in config.RESERVED_ENT_TYPES.items():\n        entityType2id.update({ent_type: idx})\n\n    # Relation\n    all_relations = set({rel for rel in relations if relations[rel] >= min_freq})\n    all_relations.update(config.extra_rels)\n    relation2id = dict(zip(all_relations, range(len(config.RESERVED_RELS), len(all_relations) + len(config.RESERVED_RELS))))\n    for rel, idx in config.RESERVED_RELS.items():\n        relation2id.update({rel: idx})\n\n    # Vocab\n    vocabs = build_qa_vocab(data)\n    for token, count in kb_vocabs.items():\n        vocabs[token] += count\n    # sorted_vocabs = sorted(vocabs.items(), key=lambda d:d[1], reverse=True)\n    all_tokens = set({token for token in vocabs if vocabs[token] >= min_freq})\n    all_tokens.update(config.extra_vocab_tokens)\n    vocab2id = dict(zip(all_tokens, range(len(config.RESERVED_TOKENS), len(all_tokens) + len(config.RESERVED_TOKENS))))\n    for token, idx in config.RESERVED_TOKENS.items():\n        vocab2id.update({token: idx})\n\n    print('Num of entities: %s' % len(entity2id))\n    print('Num of entity_types: %s' % len(entityType2id))\n    print('Num of relations: %s' % len(relation2id))\n    print('Num of vocabs: %s' % len(vocab2id))\n    return entity2id, entityType2id, relation2id, vocab2id\n\ndef build_ans_cands(graph, entity2id, 
entityType2id, relation2id, vocab2id):\n    cand_ans_bows = [] # bow of answer entity\n    cand_ans_entities = [] # answer entity\n    cand_ans_types = [] # type of answer entity\n    cand_ans_type_bows = [] # bow of answer entity type\n    cand_ans_paths = [] # relation path from topic entity to answer entity\n    cand_ans_path_bows = []\n    cand_ans_ctx = [] # context (i.e., 1-hop entity bows and relation bows) connects to the answer path\n    cand_ans_topic_key_type = [] # topic key entity type\n    cand_labels = [] # candidiate answers\n\n    selected_types = (graph['notable_types'] + graph['type'])[:ENT_TYPE_HOP]\n    topic_key_ent_type_bows = [vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for y in selected_types for x in y.lower().split('/')[-1].split('_')]\n    topic_key_ent_type = [entityType2id[x] if x in entityType2id else config.RESERVED_ENT_TYPES['UNK'] for x in selected_types]\n\n    # We only consider the alias relations of topic entityies\n    for each in graph['alias']:\n        cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])\n        ent_bow = [vocab2id[y] if y in vocab2id else config.RESERVED_TOKENS['UNK'] for y in tokenize(each.lower())]\n        cand_ans_bows.append(ent_bow)\n        cand_ans_entities.append(config.RESERVED_ENTS['PAD'])\n        cand_ans_types.append([])\n        cand_ans_type_bows.append([])\n        cand_ans_paths.append([relation2id['alias'] if 'alias' in relation2id else config.RESERVED_RELS['UNK']])\n        cand_ans_path_bows.append([vocab2id['alias']])\n        # We do not count the topic_entity as context since it is trivial\n        cand_ans_ctx.append([[], []])\n        cand_labels.append(each)\n\n    if len(cand_labels) == 0 and (not 'neighbors' in graph or len(graph['neighbors']) == 0):\n        return ([], [], [], [], [], [], [], [], [])\n\n    for k, v in graph['neighbors'].items():\n        if if_filterout(k):\n            continue\n        k_bow = [vocab2id[x] if x 
in vocab2id else config.RESERVED_TOKENS['UNK'] for x in k.lower().split('/')[-1].split('_')]\n        for nbr in v:\n            if isinstance(nbr, str):\n                cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])\n                ent_bow = [vocab2id[y] if y in vocab2id else config.RESERVED_TOKENS['UNK'] for y in tokenize(nbr.lower())]\n                cand_ans_bows.append(ent_bow)\n                cand_ans_entities.append(config.RESERVED_ENTS['PAD'])\n                cand_ans_types.append([])\n                cand_ans_type_bows.append([])\n                cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK']])\n                cand_ans_path_bows.append(k_bow)\n                cand_ans_ctx.append([[], []])\n                cand_labels.append(nbr)\n                continue\n            elif isinstance(nbr, bool):\n                cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])\n                cand_ans_bows.append([vocab2id['true' if nbr else 'false']])\n                cand_ans_entities.append(config.RESERVED_ENTS['PAD'])\n                cand_ans_types.append([entityType2id['bool']])\n                cand_ans_type_bows.append([vocab2id['bool']])\n                cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK']])\n                cand_ans_path_bows.append(k_bow)\n                cand_ans_ctx.append([[], []])\n                cand_labels.append('true' if nbr else 'false')\n                continue\n            elif isinstance(nbr, float):\n                cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])\n                cand_ans_bows.append([vocab2id[str(nbr)] if str(nbr) in vocab2id else config.RESERVED_TOKENS['UNK']])\n                cand_ans_entities.append(config.RESERVED_ENTS['PAD'])\n                cand_ans_types.append([entityType2id['num']])\n                
cand_ans_type_bows.append([vocab2id['num']])\n                cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK']])\n                cand_ans_path_bows.append(k_bow)\n                cand_ans_ctx.append([[], []])\n                cand_labels.append(str(nbr))\n                continue\n            elif isinstance(nbr, dict):\n                nbr_k = list(nbr.keys())[0]\n                nbr_v = nbr[nbr_k]\n                selected_names = (nbr_v['name'] + nbr_v['alias'])[:1]\n                is_dummy = True\n                if not IGNORE_DUMMY or len(selected_names) > 0: # Otherwise, it is an intermediate (dummpy) node\n                    cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])\n                    nbr_k_bow = [vocab2id[y] if y in vocab2id else config.RESERVED_TOKENS['UNK'] for x in selected_names for y in tokenize(x.lower())]\n                    cand_ans_bows.append(nbr_k_bow)\n                    cand_ans_entities.append(entity2id[nbr_k] if nbr_k in entity2id else config.RESERVED_ENTS['UNK'])\n                    selected_types = (nbr_v['notable_types'] + nbr_v['type'])[:ENT_TYPE_HOP]\n                    cand_ans_types.append([entityType2id[x] if x in entityType2id else config.RESERVED_ENT_TYPES['UNK'] for x in selected_types])\n                    cand_ans_type_bows.append([vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for y in selected_types for x in y.lower().split('/')[-1].split('_')])\n                    cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK']])\n                    cand_ans_path_bows.append(k_bow)\n                    cand_labels.append(selected_names[0] if len(selected_names) > 0 else 'UNK')\n                    is_dummy = False\n\n                if not 'neighbors' in nbr_v:\n                    if not is_dummy:\n                        cand_ans_ctx.append([[], []])\n                    continue\n\n            
    rels = []\n                labels = []\n                all_ctx = [set(), set()]\n                for kk, vv in nbr_v['neighbors'].items(): # 2nd hop\n                    if if_filterout(kk):\n                        continue\n                    kk_bow = [vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for x in kk.lower().split('/')[-1].split('_')]\n                    all_ctx[1].add(kk)\n                    for nbr_nbr in vv:\n                        if isinstance(nbr_nbr, str):\n                            cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])\n                            ent_bow = [vocab2id[y] if y in vocab2id else config.RESERVED_TOKENS['UNK'] for y in tokenize(nbr_nbr.lower())]\n                            cand_ans_bows.append(ent_bow)\n                            cand_ans_entities.append(config.RESERVED_ENTS['PAD'])\n                            cand_ans_types.append([])\n                            cand_ans_type_bows.append([])\n                            cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK'], relation2id[kk] if kk in relation2id else config.RESERVED_RELS['UNK']])\n                            cand_ans_path_bows.append(kk_bow + k_bow)\n                            labels.append(nbr_nbr)\n                            all_ctx[0].add(nbr_nbr)\n                            rels.append(kk)\n                            continue\n                        elif isinstance(nbr_nbr, bool):\n                            cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])\n                            cand_ans_bows.append([vocab2id['true' if nbr_nbr else 'false']])\n                            cand_ans_entities.append(config.RESERVED_ENTS['PAD'])\n                            cand_ans_types.append([entityType2id['bool']])\n                            cand_ans_type_bows.append([vocab2id['bool']])\n                            
cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK'], relation2id[kk] if kk in relation2id else config.RESERVED_RELS['UNK']])\n                            cand_ans_path_bows.append(kk_bow + k_bow)\n                            labels.append('true' if nbr_nbr else 'false')\n                            all_ctx[0].add('true' if nbr_nbr else 'false')\n                            rels.append(kk)\n                            continue\n                        elif isinstance(nbr_nbr, float):\n                            cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])\n                            cand_ans_bows.append([vocab2id[str(nbr_nbr)] if str(nbr_nbr) in vocab2id else config.RESERVED_TOKENS['UNK']])\n                            cand_ans_entities.append(config.RESERVED_ENTS['PAD'])\n                            cand_ans_types.append([entityType2id['num']])\n                            cand_ans_type_bows.append([vocab2id['num']])\n                            cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK'], relation2id[kk] if kk in relation2id else config.RESERVED_RELS['UNK']])\n                            cand_ans_path_bows.append(kk_bow + k_bow)\n                            labels.append(str(nbr_nbr))\n                            all_ctx[0].add(str(nbr_nbr))\n                            rels.append(kk)\n                            continue\n                        elif isinstance(nbr_nbr, dict):\n                            nbr_nbr_k = list(nbr_nbr.keys())[0]\n                            nbr_nbr_v = nbr_nbr[nbr_nbr_k]\n                            selected_names = (nbr_nbr_v['name'] + nbr_nbr_v['alias'])[:1]\n                            if not IGNORE_DUMMY or len(selected_names) > 0:\n                                cand_ans_topic_key_type.append([topic_key_ent_type_bows, topic_key_ent_type])\n                                ent_bow = [vocab2id[y] if y in 
vocab2id else config.RESERVED_TOKENS['UNK'] for x in selected_names for y in tokenize(x.lower())]\n                                cand_ans_bows.append(ent_bow)\n                                cand_ans_entities.append(entity2id[nbr_nbr_k] if nbr_nbr_k in entity2id else config.RESERVED_ENTS['UNK'])\n                                selected_types = (nbr_nbr_v['notable_types'] + nbr_nbr_v['type'])[:ENT_TYPE_HOP]\n                                cand_ans_types.append([entityType2id[x] if x in entityType2id else config.RESERVED_ENT_TYPES['UNK'] for x in selected_types])\n                                cand_ans_type_bows.append([vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for y in selected_types for x in y.lower().split('/')[-1].split('_')])\n                                cand_ans_paths.append([relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK'], relation2id[kk] if kk in relation2id else config.RESERVED_RELS['UNK']])\n                                cand_ans_path_bows.append(kk_bow + k_bow)\n                                labels.append(selected_names[0] if len(selected_names) > 0 else 'UNK')\n                                if len(selected_names) > 0:\n                                    all_ctx[0].add(selected_names[0])\n                                rels.append(kk)\n                        else:\n                            raise RuntimeError('Unknown type: %s' % type(nbr_nbr))\n\n                assert len(labels) == len(rels)\n                if not is_dummy:\n                    ctx_ent_bow = [tokenize(x.lower()) for x in all_ctx[0]]\n                    # ctx_rel_bow = list(set([vocab2id[y] for x in all_ctx[1] for y in x.lower().split('/')[-1].split('_') if y in vocab2id]))\n                    ctx_rel_bow = []\n                    cand_ans_ctx.append([ctx_ent_bow, ctx_rel_bow])\n                for i in range(len(labels)):\n                    tmp_ent_names = all_ctx[0] - set([labels[i]])\n                    # 
tmp_rel_names = all_ctx[1] - set([rels[i]])\n                    ctx_ent_bow = [tokenize(x.lower()) for x in tmp_ent_names]\n                    # ctx_rel_bow = list(set([vocab2id[y] for x in tmp_rel_names for y in x.lower().split('/')[-1].split('_') if y in vocab2id]))\n                    ctx_rel_bow = []\n                    cand_ans_ctx.append([ctx_ent_bow, ctx_rel_bow])\n                cand_labels.extend(labels)\n            else:\n                raise RuntimeError('Unknown type: %s' % type(nbr))\n\n    assert len(cand_ans_bows) == len(cand_ans_entities) == len(cand_ans_types) == len(cand_ans_type_bows) == len(cand_ans_paths) \\\n            == len(cand_ans_ctx) == len(cand_labels) == len(cand_ans_topic_key_type) == len(cand_ans_path_bows)\n    return (cand_ans_bows, cand_ans_entities, cand_ans_type_bows, cand_ans_types, cand_ans_path_bows, cand_ans_paths, cand_ans_ctx, cand_ans_topic_key_type, cand_labels)\n\n\n# Build seed entity candidates for topic entity classification\ndef build_seed_ent_data(qa, kb, entity2id, entityType2id, relation2id, vocab2id, topn, dtype):\n    queries = []\n    seed_ent_features = []\n    seed_ent_labels = []\n    seed_ent_inds = []\n    for each in qa:\n        query = tokenize(each['qText'].lower())\n        q = [vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for x in query]\n        queries.append(q)\n        tmp_features = []\n        tmp_labels = []\n        tmp_inds = []\n        for i, freebase_key in enumerate(each['freebaseKeyCands'][:topn]):\n            tmp_labels.append(freebase_key)\n            if freebase_key == each['freebaseKey']:\n                tmp_inds.append(i)\n\n            if freebase_key in kb:\n                features = build_seed_entity_feature(freebase_key, kb[freebase_key], entity2id, entityType2id, relation2id, vocab2id)\n                tmp_features.append(features)\n            else:\n                tmp_features.append([[]] * 5)\n\n        if dtype == 'test':\n            if 
len(tmp_inds) == 0: # No answer\n                tmp_inds.append(-1)\n        else:\n            assert len(tmp_labels) == topn\n\n        assert len(tmp_inds) == 1\n        seed_ent_features.append(list(zip(*tmp_features)))\n        seed_ent_labels.append(tmp_labels)\n        seed_ent_inds.append(tmp_inds)\n    return (queries, seed_ent_features, seed_ent_labels, seed_ent_inds)\n\ndef build_seed_entity_feature(seed_ent, graph, entity2id, entityType2id, relation2id, vocab2id):\n    # candidate seed entity features:\n    # entity name\n    # entity type\n    # entity neighboring relations\n    selected_names = (graph['name'] + graph['alias'])[:1]\n    seed_ent_name = [vocab2id[y] if y in vocab2id else config.RESERVED_TOKENS['UNK'] for x in selected_names for y in tokenize(x.lower())]\n    selected_types = (graph['notable_types'] + graph['type'])[:ENT_TYPE_HOP]\n    seed_ent_type_name = [vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for y in selected_types for x in y.lower().split('/')[-1].split('_')]\n    seed_ent_type = [entityType2id[x] if x in entityType2id else config.RESERVED_ENT_TYPES['UNK'] for x in selected_types]\n    seed_rel_names = []\n    seed_rels = []\n\n    for k in graph['neighbors']:\n        if if_filterout(k):\n            continue\n        k_bow = [vocab2id[x] if x in vocab2id else config.RESERVED_TOKENS['UNK'] for x in k.lower().split('/')[-1].split('_')]\n        seed_rel_names.append(k_bow)\n        seed_rels.append(relation2id[k] if k in relation2id else config.RESERVED_RELS['UNK'])\n    return (seed_ent_name, seed_ent_type_name, seed_ent_type, seed_rel_names, seed_rels)\n"
  },
  {
    "path": "src/core/build_data/freebase.py",
    "content": "'''\nCreated on Sep, 2017\n\n@author: hugo\n\n'''\nimport os\n\nfrom ..utils.utils import *\n\n\ndef fetch_meta(path):\n    try:\n        data = load_gzip_json(path)\n    except:\n        return {}\n    content = {}\n    properties = data['property']\n    if '/type/object/name' in properties:\n        content['name'] = [x['value'] for x in properties['/type/object/name']['values']]\n    else:\n        content['name'] = []\n    if '/common/topic/alias' in properties:\n        content['alias'] = [x['value'] for x in properties['/common/topic/alias']['values']]\n    else:\n        content['alias'] = []\n    if '/common/topic/notable_types' in properties:\n        content['notable_types'] = [x['id'] for x in properties['/common/topic/notable_types']['values']]\n    else:\n        content['notable_types'] = []\n    if '/type/object/type' in properties:\n        content['type'] = [x['id'] for x in properties['/type/object/type']['values']]\n    else:\n        content['type'] = []\n    return content\n\ndef fetch(data, data_dir):\n    if not 'id' in data:\n        return data['value']\n    mid = data['id']\n    # meta data might not be in the subgraph, get it from target files\n    meta = fetch_meta(os.path.join(data_dir, '{}.json.gz'.format(mid.strip('/').replace('/', '.'))))\n    if meta == {}:\n        if not 'property' in data:\n            if 'text' in data:\n                return data['text']\n            else:\n                import pdb;pdb.set_trace()\n        properties = data['property']\n        if '/type/object/name' in properties:\n            meta['name'] = [x['value'] for x in properties['/type/object/name']['values']]\n        else:\n            meta['name'] = []\n        if '/common/topic/alias' in properties:\n            meta['alias'] = [x['value'] for x in properties['/common/topic/alias']['values']]\n        else:\n            meta['alias'] = []\n        if '/common/topic/notable_types' in properties:\n            
meta['notable_types'] = [x['id'] for x in properties['/common/topic/notable_types']['values']]\n        else:\n            meta['notable_types'] = []\n        if '/type/object/type' in properties:\n            meta['type'] = [x['id'] for x in properties['/type/object/type']['values']]\n        else:\n            meta['type'] = []\n    graph = {mid: meta}\n    if not 'property' in data: # we stop at the 2nd hop\n        return graph\n    properties = data['property']\n    neighbors = {}\n    for k, v in properties.items():\n        if k.startswith('/common') or k.startswith('/type') \\\n            or k.startswith('/freebase') or k.startswith('/user') \\\n            or k.startswith('/imdb'):\n            continue\n        if len(v['values']) > 0:\n            neighbors[k] = []\n            for nbr in v['values']:\n                nbr_graph = fetch(nbr, data_dir)\n                neighbors[k].append(nbr_graph)\n    graph[mid]['neighbors'] = neighbors\n    return graph\n"
  },
  {
    "path": "src/core/build_data/utils.py",
    "content": "'''\nCreated on Sep, 2017\n\n@author: hugo\n\n'''\nimport os\nimport datetime\nimport shutil\nfrom collections import defaultdict\nimport numpy as np\nfrom scipy.sparse import *\n\nRESERVED_TOKENS = {'PAD': 0, 'UNK': 1}\n\n\ndef built(path, version_string=None):\n    \"\"\"Checks if 'built.log' flag has been set for that task.\n    If a version_string is provided, this has to match, or the version\n    is regarded as not built.\n    \"\"\"\n    if version_string:\n        fname = os.path.join(path, 'built.log')\n        if not os.path.isfile(fname):\n            return False\n        else:\n            with open(fname, 'r') as read:\n                text = read.read().split('\\n')\n            return (len(text) > 1 and text[1] == version_string)\n    else:\n        return os.path.isfile(os.path.join(path, 'built.log'))\n\ndef mark_done(path, version_string=None):\n    \"\"\"Marks the path as done by adding a 'built.log' file with the current\n    timestamp plus a version description string if specified.\n    \"\"\"\n    with open(os.path.join(path, 'built.log'), 'w') as write:\n        write.write(str(datetime.datetime.today()))\n        if version_string:\n            write.write('\\n' + version_string)\n\ndef make_dir(path):\n    \"\"\"Makes the directory and any nonexistent parent directories.\"\"\"\n    os.makedirs(path, exist_ok=True)\n\ndef remove_dir(path):\n    \"\"\"Removes the given directory, if it exists.\"\"\"\n    shutil.rmtree(path, ignore_errors=True)\n\ndef vectorize_data(queries, query_mentions, memories, max_query_size=None, max_query_markup_size=None, max_mem_size=None, \\\n                max_ans_bow_size=None, max_ans_type_bow_size=None, max_ans_path_bow_size=None, max_ans_path_size=None, \\\n                max_ans_ctx_entity_bows_size=None, max_ans_ctx_relation_bows_size=1, \\\n                verbose=True, fixed_size=False, vocab2id=None):\n    cand_ans_bows, cand_ans_entities, cand_ans_type_bows, cand_ans_types, 
cand_ans_path_bows, cand_ans_paths, cand_ans_ctx, cand_ans_topic_key = zip(*memories)\n    cand_ans_size = min(max(map(len, (x for x in cand_ans_entities)), default=0), max_mem_size if max_mem_size else float('inf'))\n    if fixed_size:\n        query_size = max_query_size\n        # query_markup_size = max_query_markup_size\n        cand_ans_bows_size = max_ans_bow_size\n        cand_ans_type_bows_size = max_ans_type_bow_size\n        cand_ans_path_bows_size = max_ans_path_bow_size\n        cand_ans_paths_size = max_ans_path_size\n    else:\n        query_size = max(min(max(map(len, queries), default=0), max_query_size if max_query_size else float('inf')), 1)\n        # query_markup_size = max(min(max(map(len, query_mentions), default=0), max_query_markup_size if max_query_markup_size else float('inf')), 1)\n        cand_ans_bows_size = max(min(max(map(len, (y for x in cand_ans_bows for y in x)), default=0), max_ans_bow_size if max_ans_bow_size else float('inf')), 1)\n        cand_ans_type_bows_size = max(min(max(map(len, (y for x in cand_ans_type_bows for y in x)), default=0), max_ans_type_bow_size if max_ans_type_bow_size else float('inf')), 1)\n        cand_ans_path_bows_size = max(min(max(map(len, (y for x in cand_ans_path_bows for y in x)), default=0), max_ans_path_bow_size if max_ans_path_bow_size else float('inf')), 1)\n        cand_ans_paths_size = max(min(max(map(len, (y for x in cand_ans_paths for y in x)), default=0), max_ans_path_size if max_ans_path_size else float('inf')), 1)\n    cand_ans_types_size = max(max(map(len, (y for x in cand_ans_types for y in x)), default=0), 1)\n    cand_ans_ctx_entity_bows_size = max(min(max(map(len, (z for x in cand_ans_ctx for y in x for z in y[0])), default=0), max_ans_ctx_entity_bows_size if max_ans_ctx_entity_bows_size else float('inf')), 1)\n    cand_ans_ctx_relation_bows_size = max(min(max(map(len, (y[1] for x in cand_ans_ctx for y in x)), default=0), max_ans_ctx_relation_bows_size if 
max_ans_ctx_relation_bows_size else float('inf')), 1)\n    cand_ans_topic_key_ent_type_bows_size = max(max(map(len, (y[0] for x in cand_ans_topic_key for y in x)), default=0), 1)\n    cand_ans_topic_key_ent_types_size = max(max(map(len, (y[1] for x in cand_ans_topic_key for y in x)), default=0), 1)\n\n    if verbose:\n        print('\\nquery_size: {}, cand_ans_size: {}, cand_ans_bows_size: {}, '\n            'cand_ans_type_bows_size: {}, cand_ans_types_size: {}, cand_ans_path_bows_size: {}, cand_ans_paths_size: {}, '\n            'cand_ans_ctx_entity_bows_size: {}, cand_ans_topic_key_ent_types_size: {}'\\\n            .format(query_size, cand_ans_size, cand_ans_bows_size, cand_ans_type_bows_size, \\\n            cand_ans_types_size, cand_ans_path_bows_size, cand_ans_paths_size, cand_ans_ctx_entity_bows_size, \\\n            cand_ans_topic_key_ent_types_size))\n\n    # Question word\n    qw_tokens = [\"which\", \"what\", \"who\", \"whose\", \"whom\", \"where\", \"when\", \"how\", \"why\", \"whether\"]\n    qw_vids = [vocab2id[each] for each in qw_tokens if each in vocab2id]\n    qw_vid2id = dict(zip(qw_vids, range(len(qw_vids))))\n\n    Q = []\n    QW = []\n    Q_len = []\n    for i, q in enumerate(queries):\n        Q_len.append(min(query_size, len(q)))\n        lq = max(0, query_size - len(q))\n        q_vec = q[-query_size:] + [0] * lq\n        Q.append(q_vec)\n        tmp = [qw_vid2id[each] for each in q if each in qw_vid2id]\n        tmp = tmp[-query_size:] + [0] * max(0, query_size - len(tmp))\n        QW.append(tmp)\n\n    cand_ans_bows_vec = []\n    for x in cand_ans_bows:\n        tmp = []\n        for y in x:\n            l = max(0, cand_ans_bows_size - len(y))\n            tmp1 = y[:cand_ans_bows_size] + [0] * l\n            tmp.append(tmp1)\n        tmp += [[0] * cand_ans_bows_size] # Add a dummy candidate after the true sequence\n        cand_ans_bows_vec.append(tmp)\n\n    cand_ans_entities_vec = []\n    for x in cand_ans_entities:\n        
cand_ans_entities_vec.append(x + [0]) # Add a dummy candidate after the true sequence\n\n    cand_ans_types_vec = []\n    for x in cand_ans_types:\n        tmp = []\n        for y in x:\n            l = max(0, cand_ans_types_size - len(y))\n            tmp1 = y[:cand_ans_types_size] + [0] * l\n            tmp.append(tmp1)\n        tmp += [[0] * cand_ans_types_size] # Add a dummy candidate after the true sequence\n        cand_ans_types_vec.append(tmp)\n\n    cand_ans_type_bows_vec = []\n    cand_ans_type_bows_len = []\n    for x in cand_ans_type_bows:\n        tmp = []\n        tmp_len = []\n        for y in x:\n            l = max(0, cand_ans_type_bows_size - len(y))\n            tmp1 = y[:cand_ans_type_bows_size] + [0] * l\n            tmp.append(tmp1)\n            tmp_len.append(max(min(cand_ans_type_bows_size, len(y)), 1))\n        tmp += [[0] * cand_ans_type_bows_size] # Add a dummy candidate after the true sequence\n        tmp_len += [1]\n        cand_ans_type_bows_vec.append(tmp)\n        cand_ans_type_bows_len.append(tmp_len)\n\n    cand_ans_paths_vec = []\n    for x in cand_ans_paths:\n        tmp = []\n        for y in x:\n            l = max(0, cand_ans_paths_size - len(y))\n            tmp1 = y[:cand_ans_paths_size] + [0] * l\n            tmp.append(tmp1)\n        tmp += [[0] * cand_ans_paths_size] # Add a dummy candidate after the true sequence\n        cand_ans_paths_vec.append(tmp)\n\n    cand_ans_path_bows_vec = []\n    cand_ans_path_bows_len = []\n    for x in cand_ans_path_bows:\n        tmp = []\n        tmp_len = []\n        for y in x:\n            l = max(0, cand_ans_path_bows_size - len(y))\n            tmp1 = y[:cand_ans_path_bows_size] + [0] * l\n            tmp.append(tmp1)\n            tmp_len.append(max(min(cand_ans_path_bows_size, len(y)), 1))\n        tmp += [[0] * cand_ans_path_bows_size] # Add a dummy candidate after the true sequence\n        tmp_len += [1]\n        cand_ans_path_bows_vec.append(tmp)\n        
cand_ans_path_bows_len.append(tmp_len)\n\n    cand_ans_ctx_entity_vec = []\n    cand_ans_ctx_relation_vec = []\n    for x in cand_ans_ctx:\n        tmp_ent = []\n        tmp_rel = []\n        for y in x:\n            tmp_ent.append(y[0]) # y[0] is a list of lists\n            l_rel = max(0, cand_ans_ctx_relation_bows_size - len(y[1]))\n            tmp_rel.append(y[1][:cand_ans_ctx_relation_bows_size] + [0] * l_rel)\n        tmp_ent += [[]] # Add a dummy candidate after the true sequence\n        tmp_rel += [[0] * cand_ans_ctx_relation_bows_size]\n        cand_ans_ctx_entity_vec.append(tmp_ent)\n        cand_ans_ctx_relation_vec.append(tmp_rel)\n\n    cand_ans_topic_key_ent_type_bows_vec = []\n    cand_ans_topic_key_ent_type_vec = []\n    cand_ans_topic_key_ent_type_bows_len = []\n    for x in cand_ans_topic_key:\n        tmp_ent_type_bows = []\n        tmp_ent_type = []\n        tmp_ent_type_bow_len = []\n        for y in x:\n            tmp_ent_type_bows.append(y[0][:cand_ans_topic_key_ent_type_bows_size] + [0] * max(0, cand_ans_topic_key_ent_type_bows_size - len(y[0])))\n            tmp_ent_type.append(y[1][:cand_ans_topic_key_ent_types_size] + [0] * max(0, cand_ans_topic_key_ent_types_size - len(y[1])))\n            tmp_ent_type_bow_len.append(max(min(cand_ans_topic_key_ent_type_bows_size, len(y[0])), 1))\n        tmp_ent_type_bows += [[0] * cand_ans_topic_key_ent_type_bows_size] # Add a dummy candidate after the true sequence\n        tmp_ent_type += [[0] * cand_ans_topic_key_ent_types_size]\n        tmp_ent_type_bow_len += [1]\n        cand_ans_topic_key_ent_type_bows_vec.append(tmp_ent_type_bows)\n        cand_ans_topic_key_ent_type_vec.append(tmp_ent_type)\n        cand_ans_topic_key_ent_type_bows_len.append(tmp_ent_type_bow_len)\n    return Q, QW, Q_len, list(zip(cand_ans_bows_vec, cand_ans_entities_vec, cand_ans_type_bows_vec, cand_ans_types_vec, cand_ans_type_bows_len, cand_ans_path_bows_vec, cand_ans_paths_vec, cand_ans_path_bows_len, 
cand_ans_ctx_entity_vec, cand_ans_ctx_relation_vec, cand_ans_topic_key_ent_type_bows_vec, cand_ans_topic_key_ent_type_vec, cand_ans_topic_key_ent_type_bows_len))\n\n\ndef vectorize_ent_data(queries, ent_memories, max_query_size=None, \\\n                max_seed_ent_name_size=None, max_seed_type_name_size=None, \\\n                max_seed_rel_name_size=None, max_seed_rel_size=None, verbose=True):\n    seed_ent_name, seed_ent_type_name, seed_ent_type, seed_rel_names, seed_rels = zip(*ent_memories)\n\n    max_query_size = max(min(max(map(len, queries), default=0), max_query_size if max_query_size else float('inf')), 1)\n    cand_seed_ent_name_size = max(min(max(map(len, (y for x in seed_ent_name for y in x)), default=0), max_seed_ent_name_size if max_seed_ent_name_size else float('inf')), 1)\n    cand_seed_type_name_size = max(min(max(map(len, (y for x in seed_ent_type_name for y in x)), default=0), max_seed_type_name_size if max_seed_type_name_size else float('inf')), 1)\n    cand_seed_types_size = max(max(map(len, (y for x in seed_ent_type for y in x)), default=0), 1)\n    cand_seed_rel_name_size = max(min(max(map(len, (z for x in seed_rel_names for y in x for z in y)), default=0), max_seed_rel_name_size if max_seed_rel_name_size else float('inf')), 1)\n    cand_seed_rel_size = max(min(max(map(len, (y for x in seed_rels for y in x)), default=0), max_seed_rel_size if max_seed_rel_size else float('inf')), 1)\n\n\n    if verbose:\n        print('\\nmax_query_size: {}, cand_seed_ent_name_size: {}, cand_seed_type_name_size: {}, '\n            'cand_seed_types_size: {}, cand_seed_rel_name_size: {}, cand_seed_rel_size: {}'.format(max_query_size, \\\n                cand_seed_ent_name_size, cand_seed_type_name_size, cand_seed_types_size, \\\n                cand_seed_rel_name_size, cand_seed_rel_size))\n\n\n    # Query vectorization\n    Q = []\n    Q_len = []\n    for q in queries:\n        Q_len.append(min(max_query_size, len(q)))\n        lq = max(0, max_query_size - 
len(q))\n        q_vec = q[-max_query_size:] + [0] * lq\n        Q.append(q_vec)\n\n\n    # Entity vectorization\n    cand_seed_ent_name_vec = []\n    cand_seed_ent_name_len = []\n    for x in seed_ent_name:\n        tmp = []\n        tmp_len = []\n        for y in x:\n            l = max(0, cand_seed_ent_name_size - len(y))\n            tmp1 = y[:cand_seed_ent_name_size] + [0] * l\n            tmp.append(tmp1)\n            tmp_len.append(max(min(cand_seed_ent_name_size, len(y)), 1))\n        cand_seed_ent_name_vec.append(tmp)\n        cand_seed_ent_name_len.append(tmp_len)\n\n    cand_seed_type_vec = []\n    for x in seed_ent_type:\n        tmp = []\n        for y in x:\n            l = max(0, cand_seed_types_size - len(y))\n            tmp1 = y[:cand_seed_types_size] + [0] * l\n            tmp.append(tmp1)\n        cand_seed_type_vec.append(tmp)\n\n    cand_seed_type_name_vec = []\n    cand_seed_type_name_len = []\n    for x in seed_ent_type_name:\n        tmp = []\n        tmp_len = []\n        for y in x:\n            l = max(0, cand_seed_type_name_size - len(y))\n            tmp1 = y[:cand_seed_type_name_size] + [0] * l\n            tmp.append(tmp1)\n            tmp_len.append(max(min(cand_seed_type_name_size, len(y)), 1))\n        cand_seed_type_name_vec.append(tmp)\n        cand_seed_type_name_len.append(tmp_len)\n\n\n    cand_seed_rel_vec = []\n    cand_seed_rel_mask = []\n    for x in seed_rels: # example\n        x_tmp = []\n        x_mask = []\n        for y in x: # seed entity\n            l = max(0, cand_seed_rel_size - len(y))\n            y_tmp = y[:cand_seed_rel_size] + [0] * l\n            x_tmp.append(y_tmp)\n            x_mask.append(min(len(y), cand_seed_rel_size))\n        cand_seed_rel_vec.append(x_tmp)\n        cand_seed_rel_mask.append(x_mask)\n\n\n    cand_seed_rel_name_vec = []\n    cand_seed_rel_name_len = []\n    for x in seed_rel_names: # example\n        x_tmp = []\n        x_tmp_len = []\n        for y in x: # seed entity\n            
y_tmp = []\n            y_tmp_len = []\n            for z in y: # relation\n                z_l = max(0, cand_seed_rel_name_size - len(z))\n                z_tmp = z[:cand_seed_rel_name_size] + [0] * z_l\n                y_tmp.append(z_tmp)\n                y_tmp_len.append(max(min(cand_seed_rel_name_size, len(z)), 1))\n            y_l = max(0, cand_seed_rel_size - len(y))\n            y_tmp += [[0] * cand_seed_rel_name_size] * y_l\n            y_tmp_len += [1] * y_l\n            x_tmp.append(y_tmp)\n            x_tmp_len.append(y_tmp_len)\n        cand_seed_rel_name_vec.append(x_tmp)\n        cand_seed_rel_name_len.append(x_tmp_len)\n    return Q, Q_len, list(zip(cand_seed_ent_name_vec, cand_seed_ent_name_len, cand_seed_type_name_vec, cand_seed_type_vec, cand_seed_type_name_len, cand_seed_rel_name_vec, cand_seed_rel_vec, cand_seed_rel_name_len, cand_seed_rel_mask))\n"
  },
  {
    "path": "src/core/build_data/webquestions.py",
    "content": "'''\nCreated on Sep, 2017\n\n@author: hugo\n\n'''\nimport os\n# import re\nimport argparse\nfrom nltk.parse.stanford import StanfordDependencyParser\n\nfrom ..utils.utils import *\nfrom ..utils.freebase_utils import if_filterout\nfrom ..utils.generic_utils import *\n\n\ndef get_used_fbkeys(data_dir, out_dir):\n    # Fetch freebase keys used in training and validation sets.\n    fbkeys = set()\n    split = ['factoid_webqa/train.json', 'factoid_webqa/valid.json']\n    files = [os.path.join(data_dir, x) for x in split]\n    for f in files:\n        data = load_json(f)\n        for qa in data:\n            fbkeys.add(qa['freebaseKey'])\n    dump_json(list(fbkeys), os.path.join(out_dir, 'fbkeys_train_valid.json'), indent=1)\n\ndef get_all_fbkeys(data_dir, out_dir):\n    # Fetch all freebase keys possibily useful to answer questions.\n    fbkeys = set()\n    split = ['factoid_webqa/train.json', 'factoid_webqa/valid.json', 'factoid_webqa/test.json']\n    files = [os.path.join(data_dir, x) for x in split]\n    for f in files:\n        data = load_json(f)\n        for qa in data:\n            fbkeys.add(qa['freebaseKey'])\n\n    retrieved_test_path = os.path.join(data_dir, 'factoid_webqa/webquestions.examples.test.retrieved.json')\n    if os.path.exists(retrieved_test_path):\n        data = load_json(retrieved_test_path)\n        for qa in data:\n            if not 'retrievedList' in qa:\n                continue\n            for x in qa['retrievedList'].split():\n                fbkeys.add(x.split(':')[0])\n    dump_json(list(fbkeys), os.path.join(out_dir, 'fbkeys_train_valid_test_retrieved.json'), indent=1)\n\ndef main(fb_path, mid2key_path, data_dir, out_dir):\n    HAS_DEP = False\n    if HAS_DEP:\n        dep_parser = StanfordDependencyParser(model_path=\"edu/stanford/nlp/models/lexparser/englishPCFG.ser.gz\") # Set CLASSPATH and STANFORD_MODELS environment variables beforehand\n    kb = load_ndjson(fb_path, return_type='dict')\n    mid2key = 
load_json(mid2key_path)\n    all_split_questions = []\n    split = ['factoid_webqa/train.json', 'factoid_webqa/valid.json', 'factoid_webqa/test.json']\n    files = [os.path.join(data_dir, x) for x in split]\n    missing_mid2key = []\n\n    for f in files:\n        data_type = os.path.basename(f).split('.')[0]\n        num_unanswerable = 0\n        all_questions = []\n        data = load_json(f)\n        for q in data:\n            questions = {}\n            questions['answers'] = q['answers']\n            questions['entities'] = q['entities']\n            questions['qText'] = q['qText']\n            questions['qId'] = q['qId']\n            questions['freebaseKey'] = q['freebaseKey']\n            questions['freebaseKeyCands'] = [q['freebaseKey']]\n            for x in q['freebaseMids']:\n                if x['mid'] in mid2key:\n                    fbkey = mid2key[x['mid']]\n                    if fbkey != q['freebaseKey']:\n                        questions['freebaseKeyCands'].append(fbkey)\n                else:\n                    missing_mid2key.append(x['mid'])\n\n            qtext = tokenize(q['qText'])\n            if HAS_DEP:\n                qw = list(set(qtext).intersection(question_word_list))\n                question_word = qw[0] if len(qw) > 0 else ''\n                topic_ent = q['freebaseKey']\n                dep_path = extract_dep_feature(dep_parser, ' '.join(qtext), topic_ent, question_word)\n            else:\n                dep_path = []\n            questions['dep_path'] = dep_path\n            all_questions.append(questions)\n\n            if not q['freebaseKey'] in kb:\n                num_unanswerable += 1\n                continue\n            cand_ans = fetch_ans_cands(kb[q['freebaseKey']])\n            norm_cand_ans = set([normalize_answer(x) for x in cand_ans])\n            norm_gold_ans = [normalize_answer(x) for x in q['answers']]\n            # Check if we can find the gold answer from the candidiate answers.\n            if 
len(norm_cand_ans.intersection(norm_gold_ans)) == 0:\n                num_unanswerable += 1\n                continue\n        all_split_questions.append(all_questions)\n        print('{} set: Num of unanswerable questions: {}'.format(data_type, num_unanswerable))\n\n    for i, each in enumerate(all_split_questions):\n        dump_ndjson(each, os.path.join(out_dir, split[i].split('/')[-1]))\n\ndef fetch_ans_cands(graph):\n    cand_ans = set() # candidiate answers\n    # We only consider the alias relations of topic entityies\n    cand_ans.update(graph['alias'])\n    for k, v in graph['neighbors'].items():\n        if if_filterout(k):\n            continue\n        for nbr in v:\n            if isinstance(nbr, str):\n                cand_ans.add(nbr)\n                continue\n            elif isinstance(nbr, bool):\n                cand_ans.add('true' if nbr else 'false')\n                continue\n            elif isinstance(nbr, float):\n                cand_ans.add(str(nbr))\n                continue\n            elif isinstance(nbr, dict):\n                nbr_k = list(nbr.keys())[0]\n                nbr_v = nbr[nbr_k]\n                selected_names = nbr_v['name'] if 'name' in nbr_v and len(nbr_v['name']) > 0 else (nbr_v['alias'][:1] if 'alias' in nbr_v else [])\n                cand_ans.add(selected_names[0] if len(selected_names) > 0 else 'UNK')\n                if not 'neighbors' in nbr_v:\n                    continue\n                for kk, vv in nbr_v['neighbors'].items(): # 2nd hop\n                    if if_filterout(kk):\n                        continue\n                    for nbr_nbr in vv:\n                        if isinstance(nbr_nbr, str):\n                            cand_ans.add(nbr_nbr)\n                            continue\n                        elif isinstance(nbr_nbr, bool):\n                            cand_ans.add('true' if nbr_nbr else 'false')\n                            continue\n                        elif isinstance(nbr_nbr, 
float):\n                            cand_ans.add(str(nbr_nbr))\n                            continue\n                        elif isinstance(nbr_nbr, dict):\n                            nbr_nbr_k = list(nbr_nbr.keys())[0]\n                            nbr_nbr_v = nbr_nbr[nbr_nbr_k]\n                            selected_names = nbr_nbr_v['name'] if 'name' in nbr_nbr_v and len(nbr_nbr_v['name']) > 0 else (nbr_nbr_v['alias'][:1] if 'alias' in nbr_nbr_v else [])\n                            cand_ans.add(selected_names[0] if len(selected_names) > 0 else 'UNK')\n                        else:\n                            raise RuntimeError('Unknown type: %s' % type(nbr_nbr))\n            else:\n                raise RuntimeError('Unknown type: %s' % type(nbr))\n    return list(cand_ans)\n"
  },
  {
    "path": "src/core/config.py",
    "content": "\n# Vocabulary\nRESERVED_TOKENS = {'PAD': 0, 'UNK': 1}\nRESERVED_ENTS = {'PAD': 0, 'UNK': 1}\nRESERVED_ENT_TYPES = {'PAD': 0, 'UNK': 1}\nRESERVED_RELS = {'PAD': 0, 'UNK': 1}\n\nextra_vocab_tokens = ['alias', 'true', 'false', 'num', 'bool'] + \\\n    ['np', 'organization', 'date', 'number', 'misc', 'ordinal', 'duration', 'person', 'time', 'location'] + \\\n    ['__np__', '__organization__', '__date__', '__number__', '__misc__', '__ordinal__', '__duration__', '__person__', '__time__', '__location__']\n\nextra_rels = ['alias']\nextra_ent_types = ['num', 'bool']\n\n\n# BAMnet entity mention types\ntopic_mention_types = {'person', 'organization', 'location', 'misc'}\n# delex_mention_types = {'date', 'time', 'ordinal', 'number'}\ndelex_mention_types = {'date', 'ordinal', 'number'}\nconstraint_mention_types = delex_mention_types\n"
  },
  {
    "path": "src/core/utils/__init__.py",
    "content": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\n"
  },
  {
    "path": "src/core/utils/freebase_utils.py",
    "content": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nfrom rapidfuzz import fuzz, process\n\n\ndef if_filterout(s):\n    if s.endswith('has_sentences') or \\\n        s.endswith('exceptions') or s.endswith('sww_base/source') or \\\n        s.endswith('kwtopic/assessment'):\n        return True\n    else:\n        return False\n\ndef query_kb(kb, ent_name, fuzz_threshold=90):\n    results = []\n    for k, v in kb.items():\n        ret = process.extractOne(ent_name, v['name'] + v['alias'], scorer=fuzz.token_sort_ratio)\n        if ret[1] > fuzz_threshold:\n            results.append((k, ret[0], ret[1]))\n    results = sorted(results, key=lambda d:d[-1], reverse=True)\n    return list(zip(*results))[0] if len(results) > 0 else []\n"
  },
  {
    "path": "src/core/utils/generic_utils.py",
    "content": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nimport re, string\nimport numpy as np\nfrom rapidfuzz import fuzz, process\nfrom nltk.corpus import stopwords\n\nfrom .utils import dump_ndarray, tokenize\n\n\nquestion_word_list = 'who, when, what, where, how, which, why, whom, whose'.split(', ')\nstop_words = set(stopwords.words(\"english\"))\n\ndef find_parent(x, tree, conn='<-'):\n    root = tree[0][0]\n    path = []\n    for parent, indicator, child in tree:\n        if x == child[0]:\n            path.extend([conn, '__{}__'.format(indicator), '-', parent[0]])\n            if not parent == root:\n                p = find_parent(parent[0], tree, conn)\n                path.extend(p)\n            return path\n    return path\n\ndef extract_dep_feature(dep_parser, text, topic_ent, question_word):\n    dep = dep_parser.raw_parse(text).__next__()\n    tree = list(dep.triples())\n    topic_ent = list(set(tokenize(topic_ent)) - stop_words)\n    text = text.split()\n\n    path_len = 1e5\n    topic_ent_to_root = []\n    for each in topic_ent:\n        ret = process.extractOne(each, text, scorer=fuzz.token_sort_ratio)\n        if ret[1] < 85:\n            continue\n        tmp = find_parent(ret[0], tree, '->')\n        if len(tmp) > 0 and len(tmp) < path_len:\n            topic_ent_to_root = tmp\n            path_len = len(tmp)\n    question_word_to_root = find_parent(question_word, tree)\n    # if len(question_word_to_root) == 0 or len(topic_ent_to_root) == 0:\n        # import pdb;pdb.set_trace()\n    return question_word_to_root + list(reversed(topic_ent_to_root[:-1]))\n\ndef unique(seq):\n    seen = set()\n    seen_add = seen.add\n    return [x for x in seq if not (x in seen or seen_add(x))]\n\nre_art = re.compile(r'\\b(a|an|the)\\b')\nre_punc = re.compile(r'[%s]' % re.escape(string.punctuation))\n\ndef normalize_answer(s):\n    \"\"\"Lower text and remove extra whitespace.\"\"\"\n    def remove_articles(text):\n        return re_art.sub(' ', 
text)\n\n    def remove_punc(text):\n        return re_punc.sub(' ', text)  # convert punctuation to spaces\n\n    def white_space_fix(text):\n        return ' '.join(text.split())\n\n    def lower(text):\n        return text.lower()\n\n    return white_space_fix(remove_articles(remove_punc(lower(s))))\n\ndef dump_embeddings(vocab_dict, emb_file, out_path, emb_size=300, binary=False, seed=123):\n    vocab_emb = get_embeddings(emb_file, vocab_dict, binary)\n\n    vocab_size = len(vocab_dict)\n    np.random.seed(seed)\n    embeddings = np.random.uniform(-0.08, 0.08, (vocab_size, emb_size))\n    for w, idx in vocab_dict.items():\n        if w in vocab_emb:\n            embeddings[int(idx)] = vocab_emb[w]\n    embeddings[0] = 0\n    dump_ndarray(embeddings, out_path)\n    return embeddings\n\ndef get_embeddings(emb_file, vocab, binary=False):\n    pt = PreTrainEmbedding(emb_file, binary)\n    vocab_embs = {}\n\n    i = 0.\n    for each in vocab:\n        emb = pt.get_embeddings(each)\n        if not emb is None:\n            vocab_embs[each] = emb\n            i += 1\n    print('get_wordemb hit ratio: %s' % (i / len(vocab)))\n    return vocab_embs\n\nclass PreTrainEmbedding():\n    def __init__(self, file, binary=False):\n        import gensim\n        self.model = gensim.models.KeyedVectors.load_word2vec_format(file, binary=binary)\n\n    def get_embeddings(self, word):\n        word_list = [word, word.upper(), word.lower(), word.title(), string.capwords(word, '_')]\n\n        for w in word_list:\n            try:\n                return self.model[w]\n            except KeyError:\n                # print('Can not get embedding for ', w)\n                continue\n        return None\n"
  },
  {
    "path": "src/core/utils/metrics.py",
    "content": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\nNote: Modified the official evaluation script provided by Berant et al.\n(https://github.com/percyliang/sempre/blob/master/scripts/evaluation.py)\n'''\nfrom .generic_utils import normalize_answer\n\n\ndef calc_f1(gold_list, pred_list):\n    \"\"\"Return a tuple with recall, precision, and f1 for one example\"\"\"\n\n    # Assume all questions have at least one answer\n    if len(gold_list) == 0:\n        raise RuntimeError('Gold list may not be empty')\n    # If we return an empty list recall is zero and precision is one\n    if len(pred_list) == 0:\n        return (0, 1, 0)\n    # It is guaranteed now that both lists are not empty\n\n    # Normalize answers\n    gold_list = [normalize_answer(s) for s in gold_list]\n    pred_list = [normalize_answer(s) for s in pred_list]\n\n    precision = 0\n    for entity in pred_list:\n        if entity in gold_list:\n            precision += 1\n    precision = float(precision) / len(pred_list)\n\n    recall = 0\n    for entity in gold_list:\n        if entity in pred_list:\n              recall += 1\n    recall = float(recall) / len(gold_list)\n\n    f1 = 0\n    if precision + recall > 0:\n        f1 = 2 * recall * precision / (precision + recall)\n    return (recall, precision, f1)\n\ndef calc_avg_f1(gold_list, pred_list, verbose=True):\n    \"\"\"Go over all examples and compute recall, precision and F1\"\"\"\n    avg_recall = 0\n    avg_precision = 0\n    avg_f1 = 0\n    count = 0\n\n    out_f = open('error_analysis.txt', 'w')\n    assert len(gold_list) == len(pred_list)\n    for i, gold in enumerate(gold_list):\n        recall, precision, f1 = calc_f1(gold, pred_list[i])\n        avg_recall += recall\n        avg_precision += precision\n        avg_f1 += f1\n        count += 1\n        if True:\n        # if f1 < 0.6:\n            out_f.write('{}\\t{}\\t{}\\t{}\\n'.format(i, gold, pred_list[i], f1))\n    out_f.close()\n\n    avg_recall = float(avg_recall) / 
count\n    avg_precision = float(avg_precision) / count\n    avg_f1 = float(avg_f1) / count\n    avg_new_f1 = 0\n    if avg_precision + avg_recall > 0:\n        avg_new_f1 = 2 * avg_recall * avg_precision / (avg_precision + avg_recall)\n\n    if verbose:\n        print(\"Number of questions: \" + str(count))\n        print(\"Average recall over questions: \" + str(avg_recall))\n        print(\"Average precision over questions: \" + str(avg_precision))\n        print(\"Average f1 over questions: \" + str(avg_f1))\n        # print(\"F1 of average recall and average precision: \" + str(avg_new_f1))\n    return count, avg_recall, avg_precision, avg_f1\n"
  },
  {
    "path": "src/core/utils/utils.py",
    "content": "'''\nCreated on Sep, 2017\n\n@author: hugo\n\n'''\nimport os\nimport re\nimport yaml\nimport gzip\nimport json\nimport string\nimport numpy as np\nfrom nltk.tokenize import wordpunct_tokenize#, word_tokenize\n\n\n# tokenize = lambda s: word_tokenize(re.sub(r'[%s]' % punc_wo_dot, ' ', re.sub(r'(?<!\\d)[%s](?!\\d)' % string.punctuation, ' ', s)))\ntokenize = lambda s: wordpunct_tokenize(re.sub('[%s]' % re.escape(string.punctuation), ' ', s))\n\ndef get_config(config_path=\"config.yml\"):\n    with open(config_path, \"r\") as setting:\n        config = yaml.load(setting)\n    return config\n\ndef print_config(config):\n    print(\"**************** MODEL CONFIGURATION ****************\")\n    for key in sorted(config.keys()):\n        val = config[key]\n        keystr = \"{}\".format(key) + (\" \" * (24 - len(key)))\n        print(\"{} -->   {}\".format(keystr, val))\n    print(\"**************** MODEL CONFIGURATION ****************\")\n\ndef read_lines(path_to_file):\n    data = []\n    try:\n        with open(path_to_file, 'r') as f:\n            for line in f:\n                tmp = [float(x) for x in line.strip().split()]\n                data.append(tmp)\n    except Exception as e:\n        raise e\n\n    return data\n\ndef dump_ndarray(data, path_to_file):\n    try:\n        with open(path_to_file, 'wb') as f:\n            np.save(f, data)\n    except Exception as e:\n        raise e\n\ndef load_ndarray(path_to_file):\n    try:\n        with open(path_to_file, 'rb') as f:\n            data = np.load(f)\n    except Exception as e:\n        raise e\n\n    return data\n\ndef dump_ndjson(data, file):\n    try:\n        with open(file, 'w') as f:\n            for each in data:\n                f.write(json.dumps(each) + '\\n')\n    except Exception as e:\n        raise e\n\ndef load_ndjson(file, return_type='array'):\n    if return_type == 'array':\n        return load_ndjson_to_array(file)\n    elif return_type == 'dict':\n        return 
load_ndjson_to_dict(file)\n    else:\n        raise RuntimeError('Unknown return_type: %s' % return_type)\n\ndef load_ndjson_to_array(file):\n    data = []\n    try:\n        with open(file, 'r') as f:\n            for line in f:\n                data.append(json.loads(line.strip()))\n    except Exception as e:\n        raise e\n    return data\n\ndef load_ndjson_to_dict(file):\n    data = {}\n    try:\n        with open(file, 'r') as f:\n            for line in f:\n                data.update(json.loads(line.strip()))\n    except Exception as e:\n        raise e\n    return data\n\ndef dump_json(data, file, indent=None):\n    try:\n        with open(file, 'w') as f:\n            json.dump(data, f, indent=indent)\n    except Exception as e:\n        raise e\n\ndef load_json(file):\n    try:\n        with open(file, 'r') as f:\n            data = json.load(f)\n    except Exception as e:\n        raise e\n    return data\n\ndef dump_dict_ndjson(data, file):\n    try:\n        with open(file, 'w') as f:\n            for k, v in data.items():\n                line = json.dumps([k, v]) + '\\n'\n                f.write(line)\n    except Exception as e:\n        raise e\n\ndef load_gzip_json(file):\n    try:\n        with gzip.open(file, 'r') as f:\n            data = json.load(f)\n    except Exception as e:\n        raise e\n    return data\n\ndef get_all_files(dir, recursive=False):\n    if recursive:\n        return [os.path.join(root, file) for root, dirnames, filenames in os.walk(dir) for file in filenames if os.path.isfile(os.path.join(root, file)) and not file.startswith('.')]\n    else:\n        return [os.path.join(dir, filename) for filename in os.listdir(dir) if os.path.isfile(os.path.join(dir, filename)) and not filename.startswith('.')]\n\n# Print iterations progress\ndef printProgressBar(iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '█'):\n    \"\"\"\n    Call in a loop to create terminal progress bar\n    @params:\n        
iteration   - Required  : current iteration (Int)\n        total       - Required  : total iterations (Int)\n        prefix      - Optional  : prefix string (Str)\n        suffix      - Optional  : suffix string (Str)\n        decimals    - Optional  : positive number of decimals in percent complete (Int)\n        length      - Optional  : character length of bar (Int)\n        fill        - Optional  : bar fill character (Str)\n    \"\"\"\n    percent = (\"{0:.\" + str(decimals) + \"f}\").format(100 * (iteration / float(total)))\n    filledLength = int(length * iteration // total)\n    bar = fill * filledLength + '-' * (length - filledLength)\n    print('\\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end = '\\r')\n"
  },
  {
    "path": "src/joint_test.py",
    "content": "import timeit\nimport argparse\nimport numpy as np\n\nfrom core.bamnet.entnet import EntnetAgent\nfrom core.bamnet.bamnet import BAMnetAgent\nfrom core.build_data.build_all import build\nfrom core.build_data.utils import vectorize_ent_data, vectorize_data\nfrom core.build_data.build_data import build_data\nfrom core.utils.generic_utils import unique\nfrom core.utils.utils import *\nfrom core.utils.metrics import *\n\n\ndef dynamic_pred(pred, margin):\n    predictions = []\n    for i in range(len(pred)):\n        predictions.append(unique([x[0] for x in pred[i] if x[1] + margin >= pred[i][0][1]]))\n    return predictions\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-bamnet_config', '--bamnet_config', required=True, type=str, help='path to the config file')\n    parser.add_argument('-entnet_config', '--entnet_config', required=True, type=str, help='path to the config file')\n    parser.add_argument('-raw_data', '--raw_data_dir', required=True, type=str, help='raw data dir')\n    cfg = vars(parser.parse_args())\n    bamnet_opt = get_config(cfg['bamnet_config'])\n    entnet_opt = get_config(cfg['entnet_config'])\n\n    start = timeit.default_timer()\n    # Entnet\n    # Ensure data is built\n    build(entnet_opt['data_dir'])\n    data_vec = load_json(os.path.join(entnet_opt['data_dir'], entnet_opt['test_data']))\n\n    queries, memories, ent_labels, ent_inds = data_vec\n    queries, query_lengths, memories = vectorize_ent_data(queries, \\\n                                        memories, max_query_size=entnet_opt['query_size'], \\\n                                        max_seed_ent_name_size=entnet_opt['max_seed_ent_name_size'], \\\n                                        max_seed_type_name_size=entnet_opt['max_seed_type_name_size'], \\\n                                        max_seed_rel_name_size=entnet_opt['max_seed_rel_name_size'], \\\n                                        
max_seed_rel_size=entnet_opt['max_seed_rel_size'])\n\n    ent_model = EntnetAgent(entnet_opt)\n    acc = ent_model.evaluate([memories, queries, query_lengths], ent_inds, batch_size=entnet_opt['test_batch_size'])\n    print('acc: {}'.format(acc))\n    pred_seed_ents = ent_model.predict([memories, queries, query_lengths], ent_labels, batch_size=entnet_opt['test_batch_size'])\n\n\n    # BAMnet\n    # Ensure data is built\n    build(bamnet_opt['data_dir'])\n    entity2id = load_json(os.path.join(bamnet_opt['data_dir'], 'entity2id.json'))\n    entityType2id = load_json(os.path.join(bamnet_opt['data_dir'], 'entityType2id.json'))\n    relation2id = load_json(os.path.join(bamnet_opt['data_dir'], 'relation2id.json'))\n    vocab2id = load_json(os.path.join(bamnet_opt['data_dir'], 'vocab2id.json'))\n    ctx_stopwords = {'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', \"you're\", \"you've\", \"you'll\", \"you'd\", 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', \"she's\", 'her', 'hers', 'herself', 'it', \"it's\", 'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', \"that'll\", 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 'will', 'just', 'don', \"don't\", 'should', \"should've\", 'now', 'd', 'll', 'm', 'o', 're', 've', 'y', 'ain', 'aren', \"aren't\", 
'couldn', \"couldn't\", 'didn', \"didn't\", 'doesn', \"doesn't\", 'hadn', \"hadn't\", 'hasn', \"hasn't\", 'haven', \"haven't\", 'isn', \"isn't\", 'ma', 'mightn', \"mightn't\", 'mustn', \"mustn't\", 'needn', \"needn't\", 'shan', \"shan't\", 'shouldn', \"shouldn't\", 'wasn', \"wasn't\", 'weren', \"weren't\", 'won', \"won't\", 'wouldn', \"wouldn't\"}\n\n    # Build data in real time\n    freebase = load_ndjson(os.path.join(cfg['raw_data_dir'], 'freebase_full.json'), return_type='dict')\n    test_data = load_ndjson(os.path.join(cfg['raw_data_dir'], 'raw_test.json'))\n    data_vec = build_data(test_data, freebase, entity2id, entityType2id, relation2id, vocab2id, pred_seed_ents=pred_seed_ents)\n\n    queries, raw_queries, query_mentions, memories, cand_labels, _, gold_ans_labels = data_vec\n    queries, query_words, query_lengths, memories_vec = vectorize_data(queries, query_mentions, memories, \\\n                                        max_query_size=bamnet_opt['query_size'], \\\n                                        max_query_markup_size=bamnet_opt['query_markup_size'], \\\n                                        max_ans_bow_size=bamnet_opt['ans_bow_size'], \\\n                                        vocab2id=vocab2id)\n\n    model = BAMnetAgent(bamnet_opt, ctx_stopwords, vocab2id)\n    pred = model.predict([memories_vec, queries, query_words, raw_queries, query_mentions, query_lengths], cand_labels, batch_size=bamnet_opt['test_batch_size'], margin=2)\n\n    print('\\nPredictions')\n    for margin in bamnet_opt['test_margin']:\n        print('\\nMargin: {}'.format(margin))\n        predictions = dynamic_pred(pred, margin)\n        calc_avg_f1(gold_ans_labels, predictions)\n    print('Runtime: %ss' % (timeit.default_timer() - start))\n"
  },
  {
    "path": "src/run_freebase.py",
    "content": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nimport argparse\nimport os\nimport json\n\nfrom core.build_data.freebase import *\nfrom core.utils.utils import *\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument('-data_dir', '--data_dir', required=True, type=str, help='path to the data dir')\nparser.add_argument('-fbkeys', '--freebase_keys', required=True, type=str, help='path to the freebase key file')\nparser.add_argument('-out_dir', '--out_dir', type=str, required=True, help='path to the output dir')\nargs = parser.parse_args()\n\nids = load_json(args.freebase_keys)\ntotal = len(ids)\nprint('Fetching {} entities and their 2-hop neighbors.'.format(total))\nprint_bar_len = 50\ncnt = 0\nmissing_ids = set()\nwith open(os.path.join(args.out_dir, 'freebase.json'), 'a') as out_f:\n    for id_ in ids:\n        try:\n            data = load_gzip_json(os.path.join(args.data_dir, '{}.json.gz'.format(id_)))\n        except:\n            missing_ids.add(id_)\n            continue\n        graph = fetch(data, args.data_dir)\n        graph2 = {id_: list(graph.values())[0]}\n        graph2[id_]['id'] = list(graph.keys())[0]\n        line = json.dumps(graph2) + '\\n'\n        out_f.write(line)\n        cnt += 1\n        if cnt % int(total / print_bar_len) == 0:\n            printProgressBar(cnt, total, prefix='Progress:', suffix='Complete', length=print_bar_len)\n    printProgressBar(cnt, total, prefix='Progress:', suffix='Complete', length=print_bar_len)\n\nprint('Missed %s mids' % len(missing_ids))\ndump_json(list(missing_ids), os.path.join(args.out_dir, 'missing_fbids.json'))\n"
  },
  {
    "path": "src/run_webquestions.py",
    "content": "'''\nCreated on Oct, 2017\n\n@author: hugo\n\n'''\nimport argparse\nfrom core.build_data.webquestions import *\n\nparser = argparse.ArgumentParser()\nparser.add_argument('-fb', '--freebase_path', required=True, type=str, help='path to the freebase data')\nparser.add_argument('-mid2key', '--mid2key_path', required=True, type=str, help='path to the freebase data')\nparser.add_argument('-data_dir', '--data_dir', required=True, type=str, help='path to the data dir')\nparser.add_argument('-out_dir', '--out_dir', type=str, required=True, help='path to the output dir')\nargs = parser.parse_args()\n\nmain(args.freebase_path, args.mid2key_path, args.data_dir, args.out_dir)\n# get_used_fbkeys(args.data_dir, args.out_dir)\n# get_all_fbkeys(args.data_dir, args.out_dir)\n"
  },
  {
    "path": "src/test.py",
    "content": "import timeit\nimport argparse\n\nfrom core.bamnet.bamnet import BAMnetAgent\nfrom core.build_data.build_all import build\nfrom core.build_data.utils import vectorize_data\nfrom core.utils.utils import *\nfrom core.utils.generic_utils import unique\nfrom core.utils.metrics import *\n\n\ndef dynamic_pred(pred, margin):\n    predictions = []\n    for i in range(len(pred)):\n        predictions.append(unique([x[0] for x in pred[i] if x[1] + margin >= pred[i][0][1]]))\n    return predictions\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-config', '--config', required=True, type=str, help='path to the config file')\n    cfg = vars(parser.parse_args())\n    opt = get_config(cfg['config'])\n\n    # Ensure data is built\n    build(opt['data_dir'])\n    data_vec = load_json(os.path.join(opt['data_dir'], opt['test_data']))\n    vocab2id = load_json(os.path.join(opt['data_dir'], 'vocab2id.json'))\n    ctx_stopwords = {'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', \"you're\", \"you've\", \"you'll\", \"you'd\", 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', \"she's\", 'her', 'hers', 'herself', 'it', \"it's\", 'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', \"that'll\", 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 
'too', 'very', 's', 't', 'can', 'will', 'just', 'don', \"don't\", 'should', \"should've\", 'now', 'd', 'll', 'm', 'o', 're', 've', 'y', 'ain', 'aren', \"aren't\", 'couldn', \"couldn't\", 'didn', \"didn't\", 'doesn', \"doesn't\", 'hadn', \"hadn't\", 'hasn', \"hasn't\", 'haven', \"haven't\", 'isn', \"isn't\", 'ma', 'mightn', \"mightn't\", 'mustn', \"mustn't\", 'needn', \"needn't\", 'shan', \"shan't\", 'shouldn', \"shouldn't\", 'wasn', \"wasn't\", 'weren', \"weren't\", 'won', \"won't\", 'wouldn', \"wouldn't\"}\n\n    queries, raw_queries, query_mentions, memories, cand_labels, _, gold_ans_labels = data_vec\n    queries, query_words, query_lengths, memories_vec = vectorize_data(queries, query_mentions, memories, \\\n                                        max_query_size=opt['query_size'], \\\n                                        max_query_markup_size=opt['query_markup_size'], \\\n                                        max_ans_bow_size=opt['ans_bow_size'], \\\n                                        vocab2id=vocab2id)\n\n    start = timeit.default_timer()\n\n    model = BAMnetAgent(opt, ctx_stopwords, vocab2id)\n    pred = model.predict([memories_vec, queries, query_words, raw_queries, query_mentions, query_lengths], cand_labels, batch_size=opt['test_batch_size'], margin=2)\n\n    print('\\nPredictions')\n    for margin in opt['test_margin']:\n        print('\\nMargin: {}'.format(margin))\n        predictions = dynamic_pred(pred, margin)\n        calc_avg_f1(gold_ans_labels, predictions)\n    print('Runtime: %ss' % (timeit.default_timer() - start))\n    import pdb;pdb.set_trace()\n"
  },
  {
    "path": "src/test_entnet.py",
    "content": "import timeit\nimport argparse\nimport numpy as np\n\nfrom core.bamnet.entnet import EntnetAgent\nfrom core.build_data.build_all import build\nfrom core.build_data.utils import vectorize_ent_data\nfrom core.utils.utils import *\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-dt', '--datatype', default='test', type=str, help='data type: {train, valid, test}')\n    parser.add_argument('-config', '--config', required=True, type=str, help='path to the config file')\n    cfg = vars(parser.parse_args())\n    opt = get_config(cfg['config'])\n\n    # Ensure data is built\n    build(opt['data_dir'])\n    data_vec = load_json(os.path.join(opt['data_dir'], opt['test_data']))\n\n    queries, memories, ent_labels, ent_inds = data_vec\n    queries, query_lengths, memories = vectorize_ent_data(queries, \\\n                                        memories, max_query_size=opt['query_size'], \\\n                                        max_seed_ent_name_size=opt['max_seed_ent_name_size'], \\\n                                        max_seed_type_name_size=opt['max_seed_type_name_size'], \\\n                                        max_seed_rel_name_size=opt['max_seed_rel_name_size'], \\\n                                        max_seed_rel_size=opt['max_seed_rel_size'])\n\n    start = timeit.default_timer()\n\n    ent_model = EntnetAgent(opt)\n    acc = ent_model.evaluate([memories, queries, query_lengths], ent_inds, batch_size=opt['test_batch_size'])\n    print('acc: {}'.format(acc))\n    print('Runtime: %ss' % (timeit.default_timer() - start))\n"
  },
  {
    "path": "src/train.py",
    "content": "import timeit\nimport argparse\nimport numpy as np\n\nfrom core.bamnet.bamnet import BAMnetAgent\nfrom core.build_data.build_all import build\nfrom core.build_data.utils import vectorize_data\nfrom core.utils.utils import *\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-config', '--config', required=True, type=str, help='path to the config file')\n    cfg = vars(parser.parse_args())\n    opt = get_config(cfg['config'])\n    print_config(opt)\n\n    # Ensure data is built\n    build(opt['data_dir'])\n    train_vec = load_json(os.path.join(opt['data_dir'], opt['train_data']))\n    valid_vec = load_json(os.path.join(opt['data_dir'], opt['valid_data']))\n\n    vocab2id = load_json(os.path.join(opt['data_dir'], 'vocab2id.json'))\n    ctx_stopwords = {'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', \"you're\", \"you've\", \"you'll\", \"you'd\", 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', \"she's\", 'her', 'hers', 'herself', 'it', \"it's\", 'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', \"that'll\", 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 'will', 'just', 'don', \"don't\", 'should', \"should've\", 'now', 'd', 'll', 'm', 'o', 're', 've', 'y', 'ain', 'aren', \"aren't\", 
'couldn', \"couldn't\", 'didn', \"didn't\", 'doesn', \"doesn't\", 'hadn', \"hadn't\", 'hasn', \"hasn't\", 'haven', \"haven't\", 'isn', \"isn't\", 'ma', 'mightn', \"mightn't\", 'mustn', \"mustn't\", 'needn', \"needn't\", 'shan', \"shan't\", 'shouldn', \"shouldn't\", 'wasn', \"wasn't\", 'weren', \"weren't\", 'won', \"won't\", 'wouldn', \"wouldn't\"}\n\n    train_queries, train_raw_queries, train_query_mentions, train_memories, _, train_gold_ans_inds, _ = train_vec\n    train_queries, train_query_words, train_query_lengths, train_memories = vectorize_data(train_queries, train_query_mentions, \\\n                                        train_memories, max_query_size=opt['query_size'], \\\n                                        max_query_markup_size=opt['query_markup_size'], \\\n                                        max_mem_size=opt['mem_size'], \\\n                                        max_ans_bow_size=opt['ans_bow_size'], \\\n                                        max_ans_path_bow_size=opt['ans_path_bow_size'], \\\n                                        vocab2id=vocab2id)\n\n    valid_queries, valid_raw_queries, valid_query_mentions, valid_memories, valid_cand_labels, valid_gold_ans_inds, valid_gold_ans_labels = valid_vec\n    valid_queries, valid_query_words, valid_query_lengths, valid_memories = vectorize_data(valid_queries, valid_query_mentions, \\\n                                        valid_memories, max_query_size=opt['query_size'], \\\n                                        max_query_markup_size=opt['query_markup_size'], \\\n                                        max_mem_size=opt['mem_size'], \\\n                                        max_ans_bow_size=opt['ans_bow_size'], \\\n                                        max_ans_path_bow_size=opt['ans_path_bow_size'], \\\n                                        vocab2id=vocab2id)\n\n    start = timeit.default_timer()\n\n    model = BAMnetAgent(opt, ctx_stopwords, vocab2id)\n    
model.train([train_memories, train_queries, train_query_words, train_raw_queries, train_query_mentions, train_query_lengths], train_gold_ans_inds, \\\n        [valid_memories, valid_queries, valid_query_words, valid_raw_queries, valid_query_mentions, valid_query_lengths], \\\n        valid_gold_ans_inds, valid_cand_labels, valid_gold_ans_labels)\n\n    print('Runtime: %ss' % (timeit.default_timer() - start))\n"
  },
  {
    "path": "src/train_entnet.py",
    "content": "import timeit\nimport argparse\nimport numpy as np\n\nfrom core.bamnet.entnet import EntnetAgent\nfrom core.build_data.build_all import build\nfrom core.build_data.utils import vectorize_ent_data\nfrom core.utils.utils import *\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-config', '--config', required=True, type=str, help='path to the config file')\n    cfg = vars(parser.parse_args())\n    opt = get_config(cfg['config'])\n    print_config(opt)\n\n    # Ensure data is built\n    build(opt['data_dir'])\n    train_vec = load_json(os.path.join(opt['data_dir'], opt['train_data']))\n    valid_vec = load_json(os.path.join(opt['data_dir'], opt['valid_data']))\n\n    train_queries, train_memories, _, train_ent_inds = train_vec\n    train_queries, train_query_lengths, train_memories = vectorize_ent_data(train_queries, \\\n                                        train_memories, max_query_size=opt['query_size'], \\\n                                        max_seed_ent_name_size=opt['max_seed_ent_name_size'], \\\n                                        max_seed_type_name_size=opt['max_seed_type_name_size'], \\\n                                        max_seed_rel_name_size=opt['max_seed_rel_name_size'], \\\n                                        max_seed_rel_size=opt['max_seed_rel_size'])\n\n    valid_queries, valid_memories, _, valid_ent_inds = valid_vec\n    valid_queries, valid_query_lengths, valid_memories = vectorize_ent_data(valid_queries, \\\n                                        valid_memories, max_query_size=opt['query_size'], \\\n                                        max_seed_ent_name_size=opt['max_seed_ent_name_size'], \\\n                                        max_seed_type_name_size=opt['max_seed_type_name_size'], \\\n                                        max_seed_rel_name_size=opt['max_seed_rel_name_size'], \\\n                                        
max_seed_rel_size=opt['max_seed_rel_size'])\n\n    start = timeit.default_timer()\n\n    ent_model = EntnetAgent(opt)\n    ent_model.train([train_memories, train_queries, train_query_lengths], train_ent_inds, \\\n        [valid_memories, valid_queries, valid_query_lengths], valid_ent_inds)\n\n    print('Runtime: %ss' % (timeit.default_timer() - start))\n"
  }
]