[
  {
    "path": ".gitignore",
    "content": ".ipynb_checkpoints/\nkg_data/processed/\nkg_data/baike_triples.txt\nkg_data/baiketriples.zip\nkg_data/.ipynb_checkpoints/"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2018 Tianyu Gao\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "﻿# Distant-Supervised-Chinese-Relation-Extraction\n## Chinese relation extraction based on distant supervision\n\n### Dataset construction\n\n* Chinese general-domain knowledge base CN-DBpedia\n* The distant-supervision assumption\n\nThe processing pipeline is described in kg_data/README.md. Click [here (Google Drive)](https://drive.google.com/open?id=1XmWW3-wveKiJFauqZTPcSbsgh4lECRBK) to download a processed subset of the data.\n\n### Model selection\n\nWe use the models from thunlp/OpenNRE; see its documentation for details.\n\n**Source link:** https://github.com/thunlp/OpenNRE\n\n### Running the code\n\nThe code expects the dataset under data/chinese by default. Run:\n```\npython train_demo.py chinese pcnn att\n```\n### Results\n\nResults for selected relations:\n\nRelation|Precision|Recall|F1\n:-:|:-:|:-:|:-:\n**All**|**0.95428**|**0.95036**|**0.95232**\n/人物/其它/民族|0.98374|0.979|0.98137\nNA|0.96853|0.97824|0.97336\n/人物/地点/国籍|0.84075|0.92673|0.88164\n/组织/地点/位于|0.85157|0.83652|0.84398\n/人物/其它/职业|0.86121|0.8037|0.83147\n/人物/组织/毕业于|0.84137|0.78092|0.81002\n/组织/人物/校长|0.94118|0.59259|0.72727\n/人物/地点/出生地|0.81049|0.49028|0.61097\n/人物/人物/家庭成员|0.65385|0.37778|0.47887\n/人物/组织/属于|0.99999|0.11364|0.20408\n/地点/地点/包含|0.99999|0.0625|0.11765\n/组织/人物/创始人|0.99999|0.05882|0.11111\n\nSome relations have very low recall; inspection suggests this is because the dataset contains very few samples for them.\n\n"
  },
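The F1 column in the README's results table is the harmonic mean of the precision and recall columns; a quick sanity check of two rows (values taken from the table):

```python
def f1_score(precision, recall):
    # harmonic mean of precision and recall
    return 2 * precision * recall / (precision + recall)

# "All" row and the /组织/人物/校长 row from the results table
print(round(f1_score(0.95428, 0.95036), 5))  # 0.95232
print(round(f1_score(0.94118, 0.59259), 5))  # 0.72727
```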
  {
    "path": "draw_plot.py",
    "content": "import sklearn.metrics\nimport matplotlib\n# Use the 'Agg' backend so this script can run on a headless remote server\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport sys\nimport os\n\nresult_dir = './test_result'\n\ndef main():\n    models = sys.argv[1:]\n    for model in models:\n        x = np.load(os.path.join(result_dir, model + '_x.npy'))\n        y = np.load(os.path.join(result_dir, model + '_y.npy'))\n        f1 = (2 * x * y / (x + y + 1e-20)).max()\n        auc = sklearn.metrics.auc(x=x, y=y)\n        # plt.plot(x, y, lw=2, label=model + '-auc=' + str(auc))\n        plt.plot(x, y, lw=2, label=model)\n        print(model + ' : ' + 'auc = ' + str(auc) + ' | ' + 'max F1 = ' + str(f1))\n        print('    P@100: {} | P@200: {} | P@300: {} | Mean: {}'.format(y[100], y[200], y[300], (y[100] + y[200] + y[300]) / 3))\n\n    plt.xlabel('Recall')\n    plt.ylabel('Precision')\n    plt.ylim([0.3, 1.0])\n    plt.xlim([0.0, 0.4])\n    plt.title('Precision-Recall')\n    plt.legend(loc=\"upper right\")\n    plt.grid(True)\n    plt.savefig(os.path.join(result_dir, 'pr_curve'))\n\nif __name__ == \"__main__\":\n    main()\n"
  },
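draw_plot.py reports the best F1 achievable at any point on the precision-recall curve; the same computation in plain Python, on toy recall/precision values (illustrative only, not real model output):

```python
# toy precision-recall points, one per decision threshold
recall    = [0.05, 0.10, 0.20, 0.30]
precision = [0.95, 0.90, 0.70, 0.50]

# best F1 over all thresholds, mirroring (2 * x * y / (x + y + 1e-20)).max()
best_f1 = max(2 * r * p / (r + p + 1e-20) for r, p in zip(recall, precision))
print(round(best_f1, 4))  # 0.375, reached at recall=0.30, precision=0.50
```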
  {
    "path": "kg_data/EntityMatcher.py",
    "content": "import pickle\nimport multiprocessing\nimport time\nimport os\n\nclass EntityMatcher:\n    def __init__(self, entity_file, sentences_folder, process_num):\n        self.process_num = process_num\n\n        with open(entity_file, 'rb') as f:\n            entity_dict = pickle.load(f)\n        # dict keys are already unique\n        self.entities = list(entity_dict.keys())\n\n        # collect all sentence files\n        self.sentence_files = []\n        for root, dirs, files in os.walk(sentences_folder):\n            for file in files:\n                self.sentence_files.append(os.path.join(root, file))\n\n    def match(self, file_name):\n        print('Start %s!' % (file_name))\n        with open(file_name, 'rb') as f:\n            data = pickle.load(f)\n        new_data = []\n        for sen in data:\n            # collect every entity that occurs in the sentence as a substring\n            eset = []\n            for entity in self.entities:\n                if entity in sen:\n                    eset.append(entity)\n            # keep only sentences containing at least two entities\n            if len(eset) > 1:\n                new_data.append([sen, eset])\n        print('Done %s!' % (file_name))\n        return new_data, file_name\n\n    def write_file(self, data):\n        # overwrite the sentence file with [sentence, [entities]] pairs\n        with open(data[1], 'wb') as f:\n            pickle.dump(data[0], f)\n\n    def run(self):\n        pool = multiprocessing.Pool(processes=self.process_num)\n        for one_file in self.sentence_files:\n            pool.apply_async(self.match, args=(str(one_file), ), callback=self.write_file)\n        pool.close()\n        pool.join()\n\nif __name__ == \"__main__\":\n    entity_file = 'processed/entities.pkl'\n    sentences_folder = 'processed/sentences'\n    process_num = 8\n    st = time.localtime()\n    print('\\nStart time: ')\n    print(time.strftime(\"%Y-%m-%d %H:%M:%S\", st))\n    em = EntityMatcher(entity_file, sentences_folder, process_num)\n    em.run()\n\n    ed = time.localtime()\n    print('End time: ')\n    print(time.strftime(\"%Y-%m-%d %H:%M:%S\", ed))\n"
  },
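EntityMatcher.match is a plain substring scan over every (sentence, entity) pair, keeping only sentences that contain at least two entities; a minimal sketch of the same logic on made-up sentences and entities (not from CN-DBpedia):

```python
entities = ["北京", "中华人民共和国", "天津"]
sentences = [
    "北京是中华人民共和国的首都。",  # two entities matched -> kept
    "天津是一座城市。",              # one entity matched  -> dropped
]

new_data = []
for sen in sentences:
    eset = [e for e in entities if e in sen]  # substring match
    if len(eset) > 1:                         # keep sentences with >= 2 entities
        new_data.append([sen, eset])

print(new_data)
```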
  {
    "path": "kg_data/README.md",
    "content": "# Distant-supervision dataset construction pipeline\n\n\n![](http://www.bbvdd.com/d/20190314161746ddc.png)\n\n## Run order\n0. Download the raw data and unzip it into this directory, i.e. `kg_data/baike_triples.txt`\n1. Start Jupyter in the kg_data directory\n```\njupyter notebook .\n```\n2. Run data_process.ipynb from top to bottom\n3. Match entities\n```\npython EntityMatcher.py\n```\n4. Segment the sentences\n```\npython SentenceSegment.py\n```\n5. Run add_relation.ipynb from top to bottom\n\n\n## Raw data\n\n* Source\n\nThe raw data is the publicly released portion of CN-DBpedia, a Chinese general-domain encyclopedia knowledge graph. It contains over 9 million encyclopedia entities and over 66 million triples, including 4 million+ abstracts, 19.8 million+ tag entries, and 41 million+ infobox entries.\n\n**Download:** http://www.openkg.cn/dataset/cndbpedia\n\n**Source link:** http://kw.fudan.edu.cn/cndbpedia\n\n* Format\n\nUnzipping the archive yields the file baike_triples.txt, with one triple per line: the entity name, a relation or attribute name, and the corresponding value.\n\n```\n\"1+8\"时代广场   中文名  \"1+8\"时代广场\n\"1+8\"时代广场   地点    咸宁大道与银泉大道交叉口\n\"1+8\"时代广场   实质    城市综合体项目\n\"1+8\"时代广场   总建面  约11.28万方\n北京   中文名称   北京\n北京   BaiduTAG  北京, 简称“京”, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。北京地处中国华北地区, 中心位于东经116°20′、北纬39°56′, 东与天津毗连, 其余均与河北相邻, 北京市总面积16410.54平方千米。\n北京   所属地区   中国华北地区\n\"1.4\"河南兰考火灾事故   地点    河南<a>兰考县</a>城关镇\n\"1.4\"河南兰考火灾事故   时间    2013年1月4日\n\"1.4\"河南兰考火灾事故   结果    一人重伤\n```\n\n## Building the entity dictionary\n\nThe first element of each raw-data line is an entity. **Keep only entities consisting entirely of Chinese characters (filtered with a regular expression)** and convert the triples into a dictionary.\n\nResulting format:\n\n```python\n{\n    \"北京\":{\n        \"中文名称\": \"北京\",\n        \"BaiduCard\": \"北京, 简称“京”, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。北京地处中国华北地区, 中心位于东经116°20′、北纬39°56′, 东与天津毗连, 其余均与河北相邻, 北京市总面积16410.54平方千米。\",\n        \"所属地区\": \"中国华北地区\"\n    },\n    ...\n}\n```\n\n## Collecting and preprocessing sentences\n\nAn entity's `BaiduCard` attribute is its Baidu Baike abstract, which usually consists of several sentences. Collect these abstracts from the entity dictionary into a list.\n\nThen preprocess every sentence: **remove all characters other than Chinese characters and common Chinese punctuation, and split multi-sentence abstracts into individual sentences**, stored as a list.\n\nResulting format:\n```python\n[\n    \"北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。\",\n    \"北京地处中国华北地区, 中心位于东经、北纬, 东与天津毗连, 其余均与河北相邻, 北京市总面积平方千米。\",\n    ...\n]\n```\n\n## Matching entities in sentences\n\nFor each sentence, iterate over the entity set and record, by **substring matching**, every entity that appears in it. **Discard sentences in which no entity or only one entity appears**, storing the data as `[[sentence, [entity1,...]], ...]`.\n\nResulting format:\n```python\n[\n    [\n        \"北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。\",\n        [\n            \"北京\",\n            \"中华人民共和国\",\n            \"政治\"\n        ]\n    ], \n    ...\n]\n```\n\n## Word segmentation\n\nSegment the Chinese sentences with Python's `jieba` library, producing data of the form `[[sentence, [entity1,...], [sentence_seg]], ...]`.\n\nCollect all segmented sentences as a corpus and train word vectors with gensim's `word2vec` implementation.\n\n**jieba tutorial:** https://github.com/fxsjy/jieba\n\n**word2vec tutorial:** https://radimrehurek.com/gensim/models/word2vec.html\n\n### User dictionary\nTo keep entities from being split apart during segmentation, write all entities (the key set of the entity dictionary) to the file `dict.txt` as a jieba user dictionary.\n### Stop words\nProvide a file `stop_word.txt` and remove Chinese stop words during segmentation. (Many such lists are available online.)\n\n### Training word vectors\n\nResulting format:\n```python\n[\n    [\n        \"北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。\",\n        [\n            \"北京\",\n            \"中华人民共和国\",\n            \"政治\"\n        ],\n        [\n            \"北京\",\n            \"简称\",\n            \"京\",\n            ...\n        ]\n    ], \n    ...\n]\n```\n\n## Filtering sentences and entity pairs\n\nAfter segmentation, filter each sentence's entity list again: remove any entity that does not appear as a token in the segmented sentence. Then combine the remaining entities into all ordered pairs, producing data of the form `[[sentence, entity_head, entity_tail, [sentence_seg]]]`. (One sentence may therefore yield several samples.)\n\nThe reason for matching entities first, discarding unusable sentences, then segmenting, and finally re-matching entities against the segmented sentence is twofold:\n1. Some entity names are substrings of other entity names, e.g. “北京” and “北京大学”. In the sentence “北京大学是中国的著名大学。”, the only entity that should be kept is “北京大学”.\n2. Segmentation is slow; segmenting every sentence and only then matching entities, without the initial filtering, would be far less efficient.\n\nResulting format:\n```python\n[\n    [\n        \"北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。\",\n        \"北京\",\n        \"中华人民共和国\",\n        [\n            \"北京\",\n            \"简称\",\n            ...\n        ]\n    ], \n    [\n        \"北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。\",\n        \"北京\",\n        \"政治\",\n        [\n            \"北京\",\n            \"简称\",\n            ...\n        ]\n    ], \n    ...\n]\n```\n\n## Adding relation labels\n\nBased on an analysis of the raw data, 23 frequently occurring relation types were manually predefined (see the appendix), where 'NA' means the two entities have no relation or a relation outside this set. The relations/attributes in the raw data are also not aligned (e.g. 妻子 and 夫人 denote the same relation), so hand-written rules are used to align and merge them.\n\nFor every record from the previous step, assign a relation label according to the entity dictionary and the manual alignment table, producing data of the form `[[sentence, entity_head, entity_tail, relation, [sentence_seg]]]`.\n\nResulting format:\n```python\n[\n    [\n        \"北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。\",\n        \"北京\",\n        \"中华人民共和国\",\n        \"/地点/地点/首都\",\n        [\n            \"北京\",\n            \"简称\",\n            ...\n        ]\n    ], \n    ...\n]\n```\n## Converting the dataset format\n\nConvert the data into its final format and add attributes such as 'id' and 'type'. The data becomes:\n```python\n[\n    {\n        \"head\": {\n            \"word\": \"北京\",\n            \"id\": \"666\",\n            ...\n        },\n        \"tail\": {\n            \"word\": \"中华人民共和国\",\n            \"id\": \"6\",\n            ...\n        },\n        \"relation\": \"/地点/地点/首都\",\n        \"sentence\": \"北京 简称 京 是 中华人民共和国 省级 行政区 首都 直辖市 是 全国 的 政治 文化 中心\",\n        ...\n    }\n]\n```\n\n## Train/test split\n\nEach label is split into training and test sets at a 3:1 ratio.\n\n# Known issues\n1. With this many sentences, entity matching, segmentation, and word-vector training take a long time (matching roughly 4 million sentences against 80,000 entities takes about 1-2 hours with 8 worker processes). Estimate the running time on a small sample first, and use multiprocessing and/or a data subset.\n2. Some of the data cleaning is fairly crude and leaves room for improvement.\n3. There are relatively few relation types, the alignment rules are simple, and the raw data itself contains noise (e.g. misclassified BaiduTAG values), so the resulting dataset is noisy.\n\n\n# Appendix\n\n\n**Relation alignment/merging:** a triple (head, relation, tail) is labeled `/人物/地点/出生地` when head belongs to the person category, tail belongs to the place category, and relation is one of the indicator words in the set for `/人物/地点/出生地`.\n\nEntity category|BaiduTAG must contain at least one of\n:-:|:-:\n人物 (person)|人物、歌手、演员、作家\n机构 (organization)|机构、企业、公司、学校、部门、大学\n地点 (place)|地点、地理、城市、国家、地区\n其它 (other)|unrestricted\n\n\n\n<table>\n    <tr>\n        <th width=20px>No.</th>\n        <th>Relation</th>\n        <th>Indicator word set</th>\n    </tr>\n    <tr>\n        <td align=\"center\" width=10px>0</td>\n        <td width=200px align=\"center\">NA</td>\n        <td>no relation in the raw data</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">1</td>\n        <td width=200px align=\"center\">/人物/人物/家庭成员</td>\n        <td>父亲、母亲、丈夫、妻子、儿子、女儿、哥哥、妹妹、姐姐、弟弟、孙子、孙女、爷爷、奶奶、外婆、外公、家人、家庭成员、夫人、对象、夫君</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">2</td>\n        <td width=200px align=\"center\">/人物/人物/社交关系</td>\n        <td>朋友、好友、同学、合作、搭档、经纪人、师从</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">3</td>\n        <td width=200px align=\"center\">/人物/地点/出生地</td>\n        <td>出生地、出生于、来自、歌手出生地、作者出生地、出生在、出生</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">4</td>\n        <td width=200px align=\"center\">/人物/地点/居住地</td>\n        <td>居住地、主要居住地、居住、现居住、目前居住地、现居住于、居住地点、居住于</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">5</td>\n        <td width=200px align=\"center\">/人物/地点/国籍</td>\n        <td>国籍、国家</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">6</td>\n        <td width=200px align=\"center\">/人物/组织/毕业于</td>\n        <td>毕业院校、毕业于、毕业学院、本科毕业院校、最后毕业院校、毕业高中、毕业地点、本科毕业学校、知名校友</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">7</td>\n        <td width=200px align=\"center\">/人物/组织/属于</td>\n        <td>隶属单位、经纪公司、隶属关系、行政隶属、隶属学校、隶属大学、隶属地区、所属公司、签约公司、任职公司、工作单位、所属</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">8</td>\n        <td width=200px align=\"center\">/人物/其它/职业</td>\n        <td>职业</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">9</td>\n        <td width=200px align=\"center\">/人物/其它/民族</td>\n        <td>民族</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">10</td>\n        <td width=200px align=\"center\">/组织/人物/拥有者</td>\n        <td>拥有、拥有者</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">11</td>\n        <td width=200px align=\"center\">/组织/人物/创始人</td>\n        <td>创始人、创始、主要创始人、集团创始人</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">12</td>\n        <td width=200px align=\"center\">/组织/人物/校长</td>\n        <td>校长、现任校长、学校校长、总校长</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">13</td>\n        <td width=200px align=\"center\">/组织/人物/领导人</td>\n        <td>领导、现任领导、领导单位、主要领导、领导人、主要领导人</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">14</td>\n        <td width=200px align=\"center\">/组织/组织/周边</td>\n        <td>周围景观、周边景点</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">15</td>\n        <td width=200px align=\"center\">/组织/地点/位于</td>\n        <td>所属地区、国家、地区、地理位置、位于、区域、地点、总部地点、所在地、所在区域、位于城市、总部位于、酒店位于、学校位于、最早位于、地址、所在城市、城市、主要城市、坐落于</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">16</td>\n        <td width=200px align=\"center\">/地点/人物/相关人物</td>\n        <td>相关人物、知名人物、历史人物</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">17</td>\n        <td width=200px align=\"center\">/地点/地点/位于</td>\n        <td>所属地区、所属国、所属洲、所属州、所属国家、最大城市、地区、地理位置、位于、区域、地点、总部地点、所在地、所在区域、位于城市、总部位于、酒店位于、学校位于、最早位于、地址、所在城市、城市、主要城市、坐落于</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">18</td>\n        <td width=200px align=\"center\">/地点/地点/毗邻</td>\n        <td>毗邻、东邻、邻近行政区、相邻、紧邻、邻近、北邻、南邻、邻国</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">19</td>\n        <td width=200px align=\"center\">/地点/地点/包含</td>\n        <td>包含、包含国家、包含人物、下辖地区、下属</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">20</td>\n        <td width=200px align=\"center\">/地点/地点/首都</td>\n        <td>首都</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">21</td>\n        <td width=200px align=\"center\">/地点/组织/景点</td>\n        <td>著名景点、主要景点、旅游景点、特色景点</td>\n    </tr>\n    <tr>\n        <td width=50px align=\"center\">22</td>\n        <td width=200px align=\"center\">/地点/其它/气候</td>\n        <td>气候类型、气候条件、气候、气候带</td>\n    </tr>\n</table>"
  },
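The labeling rule described in kg_data/README.md (the head's BaiduTAG must contain a head-category keyword, one of the head's attribute values must equal the tail, and the attribute name must be an indicator word) can be sketched as follows; `tgs`, `config`, and `name` are tiny stand-ins for the real tables in add_relation.ipynb:

```python
# tiny stand-ins for the real tag/config/name tables
tgs = {"per": ["人物", "歌手"], "pl": ["地点", "城市", "国家"]}
config = {"per2pl_birth_place": ["出生地", "出生于"]}
name = {"per2pl_birth_place": "/人物/地点/出生地"}

def check(tag, keywords):
    # does the BaiduTAG string mention any category keyword?
    return any(k in tag for k in keywords)

def label(entities, head, tail):
    """Label an ordered entity pair, falling back to 'NA'."""
    for rel_word, value in entities[head].items():
        if value != tail:
            continue  # the head's attribute must point at the tail entity
        for key, words in config.items():
            if rel_word in words:
                h_cat, t_cat = key.split("_")[0].split("2")
                if (check(entities[head].get("BaiduTAG", ""), tgs[h_cat])
                        and check(entities.get(tail, {}).get("BaiduTAG", ""), tgs[t_cat])):
                    return name[key]
    return "NA"

entities = {
    "某歌手": {"BaiduTAG": "歌手", "出生地": "北京"},
    "北京": {"BaiduTAG": "中国的城市"},
}
print(label(entities, "某歌手", "北京"))  # /人物/地点/出生地
```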
  {
    "path": "kg_data/SentenceSegment.py",
    "content": "import pickle\nimport jieba\nimport os\nimport multiprocessing\nimport time\n\ndef read_txt(file_name):\n    txt_data = []\n    with open(file_name, 'r', encoding='utf8') as f:\n        for line in f:\n            txt_data.append(line.strip())\n    return txt_data\n\nclass SentenceSegment:\n    def __init__(self, dict_file, stop_word_file, sentences_folder, process_num):\n        self.process_num = process_num\n        self.stop_word = set(read_txt(stop_word_file))\n        # load all entities as a user dictionary so jieba keeps them whole\n        jieba.load_userdict(dict_file)\n        # collect all sentence files\n        self.sentence_files = []\n        for root, dirs, files in os.walk(sentences_folder):\n            for file in files:\n                self.sentence_files.append(os.path.join(root, file))\n\n    def segment(self, file_name):\n        with open(file_name, 'rb') as f:\n            data = pickle.load(f)\n        new_data = []\n        for d in data:\n            # segment the sentence, dropping stop words\n            sen_seg = []\n            for word in jieba.cut(d[0]):\n                if word not in self.stop_word:\n                    sen_seg.append(word)\n            d.append(sen_seg)\n            # filter the entities again: keep only those that survive\n            # segmentation as whole tokens\n            new_eset = []\n            for entity in d[1]:\n                if entity in sen_seg:\n                    new_eset.append(entity)\n            # drop sentences with fewer than two remaining entities, then\n            # expand each sentence into ordered (head, tail) entity pairs\n            if len(new_eset) > 1:\n                for i in new_eset:\n                    for j in new_eset:\n                        if j != i:\n                            new_data.append([d[0], i, j, sen_seg])\n        print('%s done!' % (file_name))\n        return new_data, file_name\n\n    def write_file(self, data):\n        with open(data[1], 'wb') as f:\n            pickle.dump(data[0], f)\n\n    def run(self):\n        pool = multiprocessing.Pool(processes=self.process_num)\n        for one_file in self.sentence_files:\n            pool.apply_async(self.segment, args=(str(one_file), ), callback=self.write_file)\n        pool.close()\n        pool.join()\n\nif __name__ == \"__main__\":\n    dict_file = 'processed/entity.txt'\n    stop_word = 'stop_word.txt'\n    sentences_folder = 'processed/sentences'\n    process_num = 8\n    st = time.localtime()\n\n    ss = SentenceSegment(dict_file, stop_word, sentences_folder, process_num)\n    ss.run()\n\n    ed = time.localtime()\n    print('\\nStart time: ')\n    print(time.strftime(\"%Y-%m-%d %H:%M:%S\", st))\n    print('End time: ')\n    print(time.strftime(\"%Y-%m-%d %H:%M:%S\", ed))\n"
  },
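SentenceSegment.segment re-filters each sentence's entities against the segmented tokens, then expands the sentence into every ordered (head, tail) pair; a minimal sketch with hand-written tokens standing in for jieba output:

```python
from itertools import permutations

sentence = "北京是中华人民共和国的首都。"
sen_seg = ["北京", "是", "中华人民共和国", "的", "首都"]  # pretend jieba output
matched = ["北京", "中华人民共和国", "天坛"]

# keep only entities that survive segmentation as whole tokens
kept = [e for e in matched if e in sen_seg]

# expand into ordered (head, tail) pairs, as segment() does
samples = [[sentence, h, t, sen_seg] for h, t in permutations(kept, 2)]
print(len(samples))  # 2 surviving entities -> 2 ordered pairs
```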
  {
    "path": "kg_data/add_relation.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import pickle\\n\",\n    \"import json\\n\",\n    \"import os\\n\",\n    \"import math\\n\",\n    \"from gensim.models import Word2Vec\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"0\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"data = []\\n\",\n    \"for i in range(100):\\n\",\n    \"    with open('processed/sentences/sen'+str(i), 'rb') as f:\\n\",\n    \"        data += pickle.load(f)\\n\",\n    \"len(data)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"with open('processed/entities.pkl', 'rb') as f:\\n\",\n    \"    entities = pickle.load(f)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 关系标注\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# 根据规则 替换所有关系\\n\",\n    \"\\n\",\n    \"tgs = {\\n\",\n    \"    \\\"per\\\": [\\\"人物\\\", \\\"歌手\\\", \\\"演员\\\", \\\"作家\\\"],\\n\",\n    \"    \\\"org\\\": [\\\"机构\\\", \\\"企业\\\", \\\"公司\\\", \\\"学校\\\", \\\"部门\\\", \\\"大学\\\"],\\n\",\n    \"    \\\"pl\\\": [\\\"地点\\\", \\\"地理\\\", \\\"城市\\\", \\\"国家\\\", \\\"地区\\\"]\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"config = {\\n\",\n    \"        # person  9\\n\",\n    \"    \\\"per2per_family_members\\\" : [\\\"父亲\\\",\\\"母亲\\\",\\\"丈夫\\\",\\\"妻子\\\",\\\"儿子\\\",\\\"女儿\\\",\\\"哥哥\\\",\\\"妹妹\\\",\\\"姐姐\\\",\\\"弟弟\\\",\\\"孙子\\\",\\n\",\n    \"                        
\\\"孙女\\\",\\\"爷爷\\\",\\\"奶奶\\\",\\\"外婆\\\", \\\"外公\\\",\\\"家人\\\",\\\"家庭成员\\\" ,\\\"夫人\\\",\\\"对象\\\",\\\"夫君\\\"],\\n\",\n    \"    \\\"per2per_social_members\\\" : [\\\"朋友\\\", \\\"好友\\\", \\\"同学\\\", \\\"合作\\\", \\\"搭档\\\", \\\"经纪人\\\", \\\"师从\\\"],\\n\",\n    \"\\n\",\n    \"    \\\"per2pl_birth_place\\\" : [\\\"出生地\\\", \\\"出生于\\\", \\\"来自\\\", \\\"歌手出生地\\\", \\\"作者出生地\\\", \\\"出生在\\\", \\\"作者出生地\\\", \\\"出生\\\"],\\n\",\n    \"    \\\"per2pl_live_place\\\" : [\\\"居住地\\\", \\\"主要居住地\\\", \\\"居住\\\", \\\"现居住\\\", \\\"目前居住地\\\", \\\"现居住于\\\", \\\"居住地点\\\", \\\"居住于\\\"],\\n\",\n    \"    \\n\",\n    \"    \\\"per2pl_country\\\": [\\\"国籍\\\", \\\"国家\\\"],\\n\",\n    \"    \\\"per2org_graduate_from\\\" : [\\\"毕业院校\\\", \\\"毕业于\\\", \\\"毕业学院\\\", \\\"本科毕业院校\\\", \\\"最后毕业院校\\\", \\\"毕业高中\\\", \\\"毕业地点\\\", \\\"本科毕业学校\\\", \\\"知名校友\\\"],\\n\",\n    \"    \\\"per2org_belong_to\\\" : [\\\"隶属单位\\\", \\\"经纪公司\\\", \\\"隶属关系\\\", \\\"行政隶属\\\", \\\"隶属学校\\\", \\\"隶属大学\\\", \\\"隶属地区\\\", \\\"所属公司\\\", \\\"签约公司\\\", \\\"任职公司\\\", \\\"工作单位\\\", \\\"所属\\\"],\\n\",\n    \"\\n\",\n    \"    \\\"per2oth_profession\\\" : ['职业'],\\n\",\n    \"    \\\"per2oth_nation\\\" : ['民族'],\\n\",\n    \"\\n\",\n    \"    # orgnazition  9\\n\",\n    \"    \\\"org2per_owner\\\" : [\\\"拥有\\\", \\\"拥有者\\\"],\\n\",\n    \"    \\\"org2per_founder\\\" : [\\\"创始人\\\", \\\"创始\\\", \\\"主要创始人\\\", \\\"集团创始人\\\"],\\n\",\n    \"    \\\"org2per_school_leader\\\" : [\\\"校长\\\", \\\"现任校长\\\", \\\"学校校长\\\", \\\"总校长\\\"],\\n\",\n    \"    \\\"org2per_leader\\\" : [\\\"领导\\\", \\\"现任领导\\\", \\\"领导单位\\\", \\\"主要领导\\\", \\\"领导人\\\", \\\"主要领导人\\\"],\\n\",\n    \"\\n\",\n    \"    \\\"org2org_surroundings\\\" : [\\\"周围景观\\\", \\\"周边景点\\\"],\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"    \\\"org2pl_location\\\" : [\\\"所属地区\\\",\\\"国家\\\", \\\"地区\\\", \\\"地理位置\\\", \\\"位于\\\", \\\"区域\\\", \\\"地点\\\", \\\"总部地点\\\", \\\"所在地\\\", \\\"所在区域\\\", \\\"位于城市\\\", \\\"总部位于\\\", \\\"酒店位于\\\", \\\"学校位于\\\", \\\"最早位于\\\", \\\"地址\\\", 
\\\"所在城市\\\", \\\"城市\\\", \\\"主要城市\\\", \\\"坐落于\\\"],\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"    # place  7\\n\",\n    \"    \\\"pl2per_main_character\\\" : [\\\"相关人物\\\", \\\"知名人物\\\", \\\"历史人物\\\"],\\n\",\n    \"\\n\",\n    \"    \\\"pl2pl_location\\\" : [\\\"所属地区\\\",\\\"所属国\\\", \\\"所属洲\\\", \\\"所属州\\\", \\\"所属国家\\\", \\\"最大城市\\\", \\\"地区\\\", \\\"地理位置\\\", \\\"位于\\\", \\\"区域\\\", \\\"地点\\\", \\\"总部地点\\\", \\\"所在地\\\", \\\"所在区域\\\", \\\"位于城市\\\", \\\"总部位于\\\", \\\"酒店位于\\\", \\\"学校位于\\\", \\\"最早位于\\\", \\\"地址\\\", \\\"所在城市\\\", \\\"城市\\\", \\\"主要城市\\\", \\\"坐落于\\\"],\\n\",\n    \"    \\\"pl2pl_adjacement\\\" : [\\\"毗邻\\\", \\\"东邻\\\", \\\"邻近行政区\\\", \\\"相邻\\\", \\\"紧邻\\\", \\\"邻近\\\", \\\"北邻\\\", \\\"南邻\\\", \\\"邻国\\\"],\\n\",\n    \"    \\\"pl2pl_contains\\\" : [\\\"包含\\\", \\\"包含国家\\\", \\\"包含人物\\\", \\\"下辖地区\\\", \\\"下属\\\"],\\n\",\n    \"    \\\"pl2pl_captial\\\" : [\\\"首都\\\"],\\n\",\n    \"\\n\",\n    \"    \\\"pl2org_sights\\\" : [\\\"著名景点\\\", \\\"主要景点\\\", \\\"旅游景点\\\", \\\"特色景点\\\"],\\n\",\n    \"    \\\"pl2oth_climate\\\" : [\\\"气候类型\\\", \\\"气候条件\\\", \\\"气候\\\", \\\"气候带\\\"],\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"name = {\\n\",\n    \"    \\\"per2per_family_members\\\": \\\"/人物/人物/家庭成员\\\",\\n\",\n    \"    \\\"per2per_social_members\\\": \\\"/人物/人物/社交关系\\\",\\n\",\n    \"\\n\",\n    \"    \\\"per2pl_birth_place\\\": \\\"/人物/地点/出生地\\\",\\n\",\n    \"    \\\"per2pl_live_place\\\": \\\"/人物/地点/居住地\\\",\\n\",\n    \"    \\\"per2pl_country\\\" : \\\"/人物/地点/国籍\\\",\\n\",\n    \"    \\\"per2org_graduate_from\\\": \\\"/人物/组织/毕业于\\\",\\n\",\n    \"    \\\"per2org_belong_to\\\": \\\"/人物/组织/属于\\\",\\n\",\n    \"\\n\",\n    \"    \\\"per2oth_profession\\\": \\\"/人物/其它/职业\\\",\\n\",\n    \"    \\\"per2oth_nation\\\": \\\"/人物/其它/民族\\\",\\n\",\n    \"\\n\",\n    \"    \\\"org2per_owner\\\": \\\"/组织/人物/拥有者\\\",\\n\",\n    \"    \\\"org2per_founder\\\": \\\"/组织/人物/创始人\\\",\\n\",\n    \"    \\\"org2per_school_leader\\\": \\\"/组织/人物/校长\\\",\\n\",\n 
   \"    \\\"org2per_leader\\\": \\\"/组织/人物/领导人\\\",\\n\",\n    \"\\n\",\n    \"    \\\"org2org_surroundings\\\": \\\"/组织/组织/周边\\\",\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"    \\\"org2pl_location\\\": \\\"/组织/地点/位于\\\",\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"    \\\"pl2per_main_character\\\": \\\"/地点/人物/相关人物\\\",\\n\",\n    \"\\n\",\n    \"    \\\"pl2pl_location\\\": \\\"/地点/地点/位于\\\",\\n\",\n    \"    \\\"pl2pl_adjacement\\\": \\\"/地点/地点/毗邻\\\",\\n\",\n    \"    \\\"pl2pl_contains\\\": \\\"/地点/地点/包含\\\",\\n\",\n    \"    \\\"pl2pl_captial\\\": \\\"/地点/地点/首都\\\",\\n\",\n    \"\\n\",\n    \"    \\\"pl2org_sights\\\": \\\"/地点/组织/景点\\\",\\n\",\n    \"    \\\"pl2oth_climate\\\": \\\"/地点/其它/气候\\\"\\n\",\n    \"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def check(string, tgs):\\n\",\n    \"    for t in tgs:\\n\",\n    \"        if t in string:\\n\",\n    \"            return True\\n\",\n    \"    return False\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# 标注\\n\",\n    \"processed_data = []\\n\",\n    \"for can in data:\\n\",\n    \"    # can  [sentence, head, tail, segment]\\n\",\n    \"    sentence, head, tail, segment = can\\n\",\n    \"    if tail not in entities[can[1]].values():\\n\",\n    \"        can.append('NA')\\n\",\n    \"        processed_data.append(can)\\n\",\n    \"        continue\\n\",\n    \"    rel = ''\\n\",\n    \"    for key, value in entities[can[1]].items():\\n\",\n    \"        if value==tail:\\n\",\n    \"            rel = key\\n\",\n    \"    \\n\",\n    \"    for key, value in config.items():\\n\",\n    \"        if rel in value:\\n\",\n    \"            tp = key.split('_')[0].split('2')\\n\",\n    \"            if check(entities[can[1]].get('BaiduTAG', \\\"\\\"), tgs[tp[0]]):\\n\",\n    \"               
 if tp[1]=='oth':\\n\",\n    \"                    can.append(name[key])\\n\",\n    \"                    processed_data.append(can)\\n\",\n    \"                elif check(entities[can[2]].get('BaiduTAG', \\\"\\\"), tgs[tp[1]]):\\n\",\n    \"                    can.append(name[key])\\n\",\n    \"                    processed_data.append(can)\\n\",\n    \"            break\\n\",\n    \"len(processed_data)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"e2id = {}\\n\",\n    \"count = 0\\n\",\n    \"e_set = set()\\n\",\n    \"for i in processed_data:\\n\",\n    \"    e_set.add(i[1])\\n\",\n    \"    e_set.add(i[2])\\n\",\n    \"for e in e_set:\\n\",\n    \"    e2id[e] = count\\n\",\n    \"    count += 1\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"total_data = []\\n\",\n    \"for d in processed_data:\\n\",\n    \"    total_data.append({\\n\",\n    \"        'head':{\\n\",\n    \"            'word': d[1],\\n\",\n    \"            'id': str(e2id[d[1]])\\n\",\n    \"        },\\n\",\n    \"        'relation': d[-1],\\n\",\n    \"        'tail': {\\n\",\n    \"            'word': d[2],\\n\",\n    \"            'id': str(e2id[d[2]])\\n\",\n    \"        },\\n\",\n    \"        'sentence': ' '.join(d[-2]),\\n\",\n    \"        'ori_sen': d[0],\\n\",\n    \"        'sen_seg': d[-2]\\n\",\n    \"    })\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"rl = {}\\n\",\n    \"for i in total_data:\\n\",\n    \"    rl[i['relation']] = rl.get(i['relation'], 0) + 1\\n\",\n    \"\\n\",\n    \"trl = {}\\n\",\n    \"record = {}\\n\",\n    \"for k,v in rl.items():\\n\",\n    \"    trl[k] = 
int(max(math.floor(v*0.25), 1))\\n\",\n    \"    record[k] = 0\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"train = []\\n\",\n    \"test = []\\n\",\n    \"for i in total_data:\\n\",\n    \"    if record[i['relation']]<=trl[i['relation']]:\\n\",\n    \"        test.append(i)\\n\",\n    \"        record[i['relation']] += 1\\n\",\n    \"    else:\\n\",\n    \"        train.append(i)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not os.path.isdir('../data'):\\n\",\n    \"    os.mkdir('../data')\\n\",\n    \"if not os.path.isdir('../data/chinese'):\\n\",\n    \"    os.mkdir('../data/chinese')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"with open('../data/chinese/train.json', 'w', encoding='utf8') as f:\\n\",\n    \"    json.dump(train, f)\\n\",\n    \"with open('../data/chinese/test.json', 'w', encoding='utf8') as f:\\n\",\n    \"    json.dump(test, f)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"count = 0\\n\",\n    \"r2id = {}\\n\",\n    \"for k in list(rl.keys()):\\n\",\n    \"    r2id[k] = count\\n\",\n    \"    count +=1\\n\",\n    \"with open('../data/chinese/rel2id.json', 'w', encoding='utf8') as f:\\n\",\n    \"    json.dump(r2id, f)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 训练词向量\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"senlist = []\\n\",\n    \"for d in data:\\n\",\n    \"    
senlist.append(d[-1])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model = Word2Vec(senlist, sg=1, min_count=1, size=50, workers=4)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"w2v = {}\\n\",\n    \"for i in model.wv.index2word:\\n\",\n    \"    w2v[i] = model.wv[i]\\n\",\n    \"len(w2v)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"new_w2v = []\\n\",\n    \"for word, vec in w2v.items():\\n\",\n    \"    new_w2v.append({\\n\",\n    \"        'word': word,\\n\",\n    \"        'vec': [float(i) for i in vec]\\n\",\n    \"    })\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"collapsed\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"with open('../data/chinese/word_vec.json', 'w', encoding='utf8') as f:\\n\",\n    \"    json.dump(new_w2v, f)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python [default]\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.6.8\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
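add_relation.ipynb sends roughly the first quarter of each relation's samples to the test set; a compact sketch of that per-relation counting scheme on toy labels. (Note the notebook's `record[i['relation']]<=trl[i['relation']]` comparison admits one sample more than the quota per relation; this sketch uses a strict `<`.)

```python
import math

# toy relation labels; the real data carries full sample dicts
labels = ["NA"] * 8 + ["/人物/地点/出生地"] * 4

# per-relation test quota: roughly a quarter of that relation's samples
counts = {}
for r in labels:
    counts[r] = counts.get(r, 0) + 1
quota = {r: max(math.floor(n * 0.25), 1) for r, n in counts.items()}

train_set, test_set, seen = [], [], {}
for r in labels:
    if seen.get(r, 0) < quota[r]:  # fill each relation's test quota first
        test_set.append(r)
        seen[r] = seen.get(r, 0) + 1
    else:
        train_set.append(r)

print(len(train_set), len(test_set))  # 9 3
```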
  {
    "path": "kg_data/data_process.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import pickle\\n\",\n    \"import json\\n\",\n    \"import re\\n\",\n    \"import os\\n\",\n    \"import math\\n\",\n    \"from tqdm import tqdm\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"1.读取初始数据\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"data = []\\n\",\n    \"with open('baike_triples.txt', 'r', encoding='utf8') as f:\\n\",\n    \"    for line in tqdm(f):\\n\",\n    \"        data.append(line.strip().split('\\\\t'))\\n\",\n    \"print(len(data), data[0])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"2.保留元素全为中文的三元组\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"all_chinese = re.compile('^[\\\\u4e00-\\\\u9fa5]*$')\\n\",\n    \"new_data = []\\n\",\n    \"for triple in tqdm(data):\\n\",\n    \"    if bool(re.search(all_chinese, triple[0])) and bool(re.search(all_chinese, triple[1])) and bool(re.search(all_chinese, triple[2])):\\n\",\n    \"        new_data.append(triple)\\n\",\n    \"len(new_data)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"3.实体字典\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"scrolled\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"entities = {}\\n\",\n    \"for triple in tqdm(new_data):\\n\",\n    \"    entities[triple[0]] = entities.get(triple[0], {})\\n\",\n    \"    entities[triple[0]][triple[1]] = triple[2]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"if not os.path.isdir('processed'):\\n\",\n    \"    os.mkdir('processed')\\n\",\n    \"with open('processed/entities.pkl', 'wb') as f:\\n\",\n    \"    pickle.dump(entities, f)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"e_set = set()\\n\",\n    \"with open('processed/entity.txt', 'w', encoding='utf8') as f:\\n\",\n    \"    for e in tqdm(entities.keys()):\\n\",\n    \"        if e not in e_set:\\n\",\n    \"            e_set.add(e)\\n\",\n    \"            f.write(e+'\\\\n')\\n\",\n    \"len(e_set)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"4. Collect the sentences\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def process_sentence(sent):\\n\",\n    \"    sent = re.sub(r' +', ' ', sent)\\n\",\n    \"    sent = re.sub(r'[^\\\\u4e00-\\\\u9fa5,\\\\?\\\\!，。？：:！、；\\\\(\\\\)（） ]', '', sent)\\n\",\n    \"    sent = re.split(r'[\\\\?\\\\!。？！]', sent)\\n\",\n    \"    return [s for s in sent if s]  # drop empty fragments left by the split\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sents = []\\n\",\n    \"for triple in tqdm(data):\\n\",\n    \"    if triple[1]=='BaiduCARD':\\n\",\n    \"        sent = process_sentence(triple[2])\\n\",\n    \"        if sent:\\n\",\n    \"            sents.append(sent)\\n\",\n    \"print(len(sents), sents[0])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"5. Save the sentences into 100 files\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not os.path.isdir('processed/sentences'):\\n\",\n    \"    os.mkdir('processed/sentences')\\n\",\n    \"\\n\",\n    \"begin = 0\\n\",\n    \"count = 
math.ceil(len(sents)*1.0/100)\\n\",\n    \"end = min(begin+count, len(sents))\\n\",\n    \"for i in tqdm(range(100)):\\n\",\n    \"    with open('processed/sentences/sen'+str(i), 'wb') as f:\\n\",\n    \"        pickle.dump(sents[begin:end], f)\\n\",\n    \"    begin = end\\n\",\n    \"    end = min(begin+count, len(sents))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python [default]\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.6.8\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "kg_data/stop_word.txt",
    "content": "\n可能\n起\n便于\n有些\n上述\n纯粹\n尽然\n乃至于\n极为\n也\n并不\n其次\n矣哉\n为何\n到头来\n此外\n［①ｅ］\n8\n就是了\n故此\n移动\n旁人\n立刻\n路经\n使\n元／吨\n刚\n归齐\n归根结底\n!\n〕\n迅速\n获得\n４\n已\n冒\n哎\n已矣\n哪天\n这么点儿\n哪样\n一些\n甚且\n起头\n大\n默默地\n⑩\n跟\n逐步\n它\n拦腰\n不仅仅\n里面\n同\n设或\n要么\n以後\n—\n）÷（１－\n扩大\n它的\n精光\n归\n战斗\n成为\n谁知\n这般\n此后\n尽管如此\n或多或少\n｛－\n替\n照\n千万\n人\n连日\n这么些\n广泛\n保险\n到头\n把\n↑\n前后\n最近\n他\n瑟瑟\n一片\n正值\n没有\n而且\n云云\n正常\n快要\n从新\n了\n窃\n光是\n长此下去\n挨家挨户\n届时\n而论\n大批\n“\n一面\n之前\n大多数\n较比\n并肩\n［⑧］\n昂然\n各位\n［①④］\n正是\n在\n自\n活\n最后\n那些\n不免\n［①⑨］\n设若\n立即\n蛮\n尽心竭力\n联系\n叫做\n具有\n宣布\n既\n毕竟\n那般\n先生\n啊呀\n表明\n．\n并没有\n切莫\n一番\n倒是\n［②⑤］\n以外\n一来\n不够\n赶快\n二\n另行\n尽可能\n嗡嗡\n另\n格外\n［②ａ］\n毫无保留地\n彻底\n不独\n经过\n大家\n即令\n愤然\n于\n［③ｈ］\n［①ｇ］\n个别\n接连不断\n逐渐\n$\n简而言之\n呜呼\n适应\n譬如\n某某\n庶几\n背地里\n具体地说\n自打\n仍旧\n突然\n趁热\n要求\n＞\n怪\n常言说得好\n哎呀\n临到\n强烈\n”\n－－\n此次\n不可抗拒\n另外\n能\n\"\n凑巧\n不仅\n唉\n傥然\n穷年累月\n.一\n为\n充其极\n敢于\n＇\n单纯\n起见\n除了\n如何\n＝｛\n来说\n企图\n嘿\n必将\n毋宁\n有所\n那样\n均\n次第\n挨个\n失去\n最好\n上面\n能够\n举行\n它是\n出现\n叮当\n挨次\n…………………………………………………③\n看来\n动辄\n趁机\n亲自\n才能\n亲手\n与\n亲眼\n有的是\n不定\n但愿\n由\n内\n下面\n当前\n必须\n全都\n共\n这么\n率然\n不日\n最\n--\n不足\n几度\n不如\n凭借\n猛然\n看起来\n俺们\n暗地里\n绝非\n［②ｊ］\n而又\n依靠\n如同\n紧接着\n日见\n［②ｅ］\n古来\n依\n即若\n［④ｅ］\n多\n看看\n那么些\n怎么\n而是\n各级\n\\\n以来\n如常\n非徒\n对比\n以及\n继后\n所有\n＄\n不但...而且\n何处\n才\n清楚\n居然\n难说\n或许\n除此\n按时\n}\n:\n成年\n来着\n别处\n何时\n喽\n继之\n除去\n成年累月\n［③ｄ］\n变成\n一何\n每\n除开\n除\n本身\n偏偏\n而\n饱\n别是\n只怕\n哪些\n一时\n倒不如说\n到处\n本地\n反过来\n@\n自己\n──\n成心\n策略地\n平素\n①\nφ\n全面\n粗\n等\n然后\n为什么\n什麽\n必\n般的\n哩\n任凭\n难怪\n权时\nexp\n何苦\n一转眼\n出于\n只限\n假如\n至若\n藉以\n只\n突出\n）\n常言道\n6\n适用\n无宁\n任\n陈年\n理应\n共同\n以故\n简言之\n当真\n欤\n要\n各\n并不是\n结合\n安全\n莫若\n常\n矣\n按说\n砰\n因\n不外\n允许\n不问\n当庭\n望\n［②ｉ］\n的确\n多数\n咱们\n若\n尚且\n高低\n继而\n这点\nⅢ\n不管怎样\n只消\n皆可\n何以\n与其\n＋ξ\n各个\n暗自\n［①ｈ］\n以至于\n惯常\n有\n这\n相当\n从速\n以为\n岂但\n不外乎\n焉\n直接\n［⑦］\n仅仅\n通过\n也是\n差一点\n恰似\n尔后\n进入\n除却\n彼时\n果真\n赖以\n但凡\n所在\n是以\n尽早\n真是\n每当\n出去\n万\nsup\n一旦\n一切\n相反\n即便\n也好\n截至\n无论\n朝着\n_\n；\n八\n真正\n趁便\n保管\n要不\n多么\n啷当\n咦\n反映\n｜\n两者\n综上所述\n于是乎\n充分\n呀\n属于\n诸\n截然\n倘若\n［⑥］\n从此\n略\n明显\n互\n齐\n之类\n难得\n具体\n恍然\n大体\n时候\n趁着\n遇到\n按\n老\n再有\n引起\n7\n若非\n决不\n［①Ｂ］\n距\n高兴\n不可\n即刻\n绝对\n运用\n难道\n极端
\n非常\n［①⑥］\n认真\n末##末\n哟\n而况\n其他\n针对\n历\n大事\n［③ｂ］\n类如\n何妨\n无法\n％\n总的说来\n同一\n／\n：\n上来\n今\n达到\n或者\n接下来\n呼啦\n比方\n说明\n［③ｅ］\n各地\n谨\n据此\n各式\n伟大\n传\n既往\n２\nｂ］\n不得\n不敢\n是\n这会儿\n一则通过\n呐\n近年来\n之一\n下列\n比及\n防止\n么\n第二\n６\n何\n［①③］\n最大\n从未\n何须\n颇\n整个\n不惟\n等等\n殆\n如此\n此间\n争取\n觉得\n＋\n直到\n得起\n一个\n越是\n自个儿\n即使\n［②⑧］\n独\n或是\n往往\n率尔\n零\n｛\n［③①］\n像\n相信\n较之\n纵令\n尤其\n乘\n好象\n反倒是\n随著\n一样\n是不是\n绝不\n更加\n哪边\n、\n从古至今\n日渐\n将近\n｝＞\n下\n偶尔\n［①⑧］\n除此以外\n呕\n一方面\n当口儿\n良好\n亲口\n待到\n———\n加入\n犹且\n且不说\n即如\n经\n——\n即是说\n与其说\n－\n不可开交\n方\n（\n×××\n亲身\n据\n要不然\n到底\n复杂\n比如说\n不时\n莫不\n抑或\n～±\n哪年\n本人\n啦\n奇\n固\n这样\n来\n~\n③\n不久\n不限\n不对\n［③ｇ］\n＋＋\n您们\n显然\n总结\n风雨无阻\n如上所述\n｝\n以免\n虽则\n部分\n简直\n打开天窗说亮话\n极\n哈哈\n遵循\n介于\n’‘\n现代\n-\n甚或\n由是\n亦\n吱\n嘛\n果然\n２．３％\n这里\n假使\n根本\n据称\n达旦\n［①⑦］\n交口\n特点\n&\n咱\n正在\n＝☆\n个人\n马上\n毫无\n嘎嘎\n巨大\n开始\n几经\n因而\n②\n不拘\n`\n不是\n单\n除此之外\n冲\n更为\n［②Ｂ］\n这一来\n全部\n总的来说\n联袂\n或\n其二\n偶而\n其\n必然\n［②⑥］\n…\n人家\n大体上\n保持\n六\n总是\n不光\n换言之\n这个\n尔尔\n且\n广大\n四\n近几年来\n哉\n多多益善\n加之\n断然\n每年\n多年前\n再者\n所\n三\n碰巧\n［⑤ｆ］\n即或\n罢了\n起首\n八成\n多多少少\n暗中\n日益\n喂\n沿着\n绝顶\n主张\n所谓\n哈\n容易\n既…又\n将要\n不论\n各自\n甚么\n倘然\n某\n迄\n那麽\n有及\n－β\n再\n比\n从小\n她的\n局外\n该当\n一定\n不止\n倘\n敢\n可是\n省得\n丰富\n１．\n你是\n吧\n［③］\n随后\n注意\n－［＊］－\n何必\n猛然间\n……\n怎\n一起\n三天两头\n知道\n宁可\n从事\n相应\n如若\n少数\n这时\n为了\n啪达\n⑥\n据悉\n忽地\nφ．\n顷刻\n〔\n㈧\n进而\n使得\n［①ｃ］\n原来\n赶\n老老实实\n不妨\n可以\n0\n默然\n可好\n心里\n自各儿\n乘势\n彻夜\n密切\n５：０\n哦\n）、\n︿\n一一\n应用\n特别是\n相似\n不若\n怎么办\n双方\n就\n从头\n故\n二话没说\n是的\n普通\n比如\n［］\n【\n连袂\n取道\n按理\n起初\n帮助\n与否\n不同\n七\n哼唷\n大张旗鼓\n孰料\n啊\n需要\n实现\n还要\n〈\n挨门挨户\n存心\n宁肯\n一边\n凝神\n顷刻之间\n先后\n［③ａ］\n孰知\n(\n还有\n老是\n就是\n完成\n抽冷子\n可\n大凡\n如期\n做到\n因此\n最後\n若是\n急匆匆\n范围\n大都\n理该\n［③ｃ］\n|\n够瞧的\n这边\n重新\n［①Ｃ］\n→\n沿\n不巧\n极度\n不满\n而已\n有点\nａ］\n好\n哪儿\n犹自\nμ\n宁愿\n［①ａ］\n不止一次\n形成\n您是\n在于\n?\n下去\n川流不息\n主要\n后\n从宽\n为着\n吓\n满足\nＡ\n勃然\n尔等\n它们的\n敞开儿\n连声\n牢牢\n到目前为止\n根据\n不能\n从无到有\n过\n确定\n怎样\n出\n已经\n５\n总的来看\n？\n行为\n［①ｄ］\n固然\n到\n归根到底\n从今以后\n并且\n此时\n不然\n⑨\n除非\n#\n2\n不下\n然\n较\n℃\n纯\n凭\n对应\n扑通\n遭到\n每逢\n是否\n考虑\n大约\n为什麽\n不怎么\n比照\n去\n彼此\n怎奈\n梆\n［②⑩］\n]\n有著\n有的\n何乐而不为\n相等\n儿\n［②④\n恰如\n打从\n不怕\n］［\n强调\n不亦乐乎\n至\n『\n不力\n不了\n不过\n轰然\n一般\n上\n恐怕\n相对而言\n今後\n二话不
说\n恰恰\n取得\n接著\n一直\n￥\n差不多\n［①⑤］\n后来\n＜φ\n9\n③］\n既然\n啊哈\n哪个\n诸位\n因为\n所幸\n之所以\n哗\n嘿嘿\n贼死\n』\n自家\n大抵\n＜Δ\n屡屡\n采取\n别人\n转动\n不得不\n拿\n那末\n不再\n顿时\n全身心\n顺\n全年\n4\n转贴\n其后\n嗯\n别管\n出来\n处处\n背靠背\n呸\n行动\n［③Ｆ］\n另方面\nＲ．Ｌ．\n余外\n照着\n不得了\n严重\n他是\n［②ｇ］\n况且\n白\n种\n待\n得到\n呵呵\n那\n限制\n及其\n［①②］\n3\nsub\n从中\n决定\nA\n一\n＠\n》\n被\n似乎\n对方\n得出\n相对\n打\n大量\n彼\n另悉\n不一\n迟早\n［①Ｄ］\n［⑩］\n小\n某些\n用\n将才\n［④ｄ］\n一则\n岂止\n［⑤ｂ］\n长话短说\n⑤\nLex\n此中\n向使\n［\n放量\n尽快\n切\n不成\n任何\n该\n乒\n构成\n离\n当儿\n掌握\n［＊］\n若夫\n现在\n并无\n基于\n以致\n假若\n甚而\n咋\n矣乎\n中间\n从而\n严格\n请勿\n//\n不比\n维持\n［②ｂ］\n即\n不胜\n哇\n累次\n莫\n三番两次\n有关\n恰恰相反\n不起\n不经意\n人民\n＜＜\n除外\n兮\n上下\n://\n但\n切不可\n然後\n难道说\n不至于\n别说\n哪\n尽量\n完全\n;\n方面\n借\n光\n不但\n［②ｆ］\n社会主义\n到了儿\n至今\n第\n上去\n漫说\n从重\n可见\n我\n乘隙\n准备\n反之\n无\n後面\n一.\n处理\n较为\n很少\n［④ｃ］\n喏\n尽管\n进来\n另一个\n而言\n普遍\n有着\n［⑨］\n...\n附近\n恰逢\n千万千万\n极了\n能否\n绝\n表示\n巴巴\n因了\n［②③］\n)\n嗡\n向\n［②⑦］\n以下\n挨门逐户\n任务\n此\n竟\n[\n这么样\n另一方面\n之后\n进行\n［①ｆ］\n从严\n说来\n之\n单单\n当中\n正如\n·\n故意\n专门\n恰巧\n大大\n敢情\n.日\nｃ］\n隔夜\n大致\n］∧′＝［\n来看\n的\n好的\n没奈何\n则甚\n或则\n换句话说\n那么\n按照\n了解\n呢\n万一\n最高\n竟而\n再则\nΨ\n过去\n大多\n存在\n略为\n并没\n岂非\n臭\n同样\n非独\n由于\n［②ｈ］\n合理\n反过来说\n甭\n不已\n对于\n不曾\n比起\n促进\n怎么样\n经常\n然则\n为止\n仅\n８\n乘胜\n独自\n极力\n鉴于\n看样子\n】\n分期分批\n当场\n加以\n深入\n产生\n日复一日\n总之\n二来\n则\n欢迎\n让\n呼哧\n首先\n接着\n诚然\n屡次\n［②］\n巩固\n却\n而后\n０：２\n'\n乌乎\n≈\n快\n前此\n也罢\n=\n多年来\n前面\n三番五次\n呵\n从古到今\n>\n不仅...而且\n却不\n者\n竟然\n白白\n＃\n的话\n实际\n分头\n哎哟\n同时\n就要\n必要\n每时每刻\n吧哒\n奈\n+\n见\n不变\n要是\n并非\n间或\n长线\n人们\n恰好\n赶早不赶晚\n立马\n趁\n如此等等\n［⑤ｄ］\n如下\n起先\n尽如人意\n看见\n看\n不仅仅是\n或曰\n比较\n慢说\n常常\n庶乎\n*\n看上去\n随时\n兼之\n［⑤］\n基本上\n概\n左右\n加上\n积极\n近\n＝″\n，也\n但是\n乃\n还\n来得及\n共总\n代替\n每个\n再说\n极其\n得\n当然\n动不动\n倒不如\n以\n奋勇\n［④ａ］\n曾\n你的\n何尝\n从此以后\n几乎\n为主\n有利\n具体说来\n只要\n嘎\n就此\n得天独厚\n当\n如其\n鄙人\n老大\n没\n立时\n长期以来\n乘机\n汝\n不会\n顷刻间\n显著\n坚决\n全然\n＝（\n谁\n分期\n喀\n切切\n自从\n否则\n啥\n〉\n随\n正巧\n看到\n前进\n而外\n嗳\n尔\n顺着\n虽\n即将\n之後\n方才\n别的\n传说\n一天\n半\n怎麽\n互相\n每天\n那里\n弹指之间\n及至\n叫\n不尽\n此地\n此处\n今年\n反而\n［③⑩］\n上升\n不只\n适当\n顷\n他的\n纵然\n以前\n人人\n。\n∈［\n也就是说\n充其量\n来自\n俺\n如上\n据我所知\n论说\n基本\n方能\n我们\n......\n几时\n［①ｉ］\n［⑤ｅ］\n不由得\n当地\n趁势\n自身\n［－\n阿\n本着\n练习\n＜±\n据实\n这些\n通常\n哪怕\n《\n这儿\n［②Ｇ］\n大面儿上\n这种\n你们\n以
后\n些\n切勿\n来讲\n［①Ａ］\n中小\n隔日\n不消\n那个\n不常\n.数\n▲\n只是\n她是\n１２％\nｅ］\n全体\n不然的话\n［⑤ａ］\n转变\n》），\n［④ｂ］\n进步\n从来\n着呢\n从轻\n［①ｏ］\n如是\n由此可见\n得了\n..\n吗\n如今\n～＋\n从优\n对待\n并排\n诸如\n公然\n五\n宁\n和\n这麽\n各人\n例如\n趁早\n便\n就算\n认识\n借此\n当时\n<\n然而\n哪里\n非特\n∪φ∈\n＿\n对\n纵使\n多多\n不能不\n咧\n只当\n今天\n■\n从不\n及时\n［②②］\n不择手段\n乘虚\n咚\n我是\n往\n多次\n豁然\n有效\n以至\n很\n还是\n＜λ\n那边\n等到\n理当\n随着\n［②ｃ］\n今后\n靠\n并\n大举\n给\n故而\n决非\n凡是\n说说\n她\n又及\n意思\n>>\n明确\n开展\n＜\n３\n数/\n仍然\n他们\n千\nｆ］\n弗\n必定\n如果\n不得已\n将\n至于\nγ\n既是\n为此\n其实\n重大\n］\n非但\nＺＸＦＩＴＬ\n［①］\n倍感\n几番\n设使\n啊哟\n究竟\n论\n致\n刚巧\n反手\n９\n初\n近来\n过于\n总而言之\n谁人\n作为\n莫不然\n”，\n，\n～\n毫不\n先不先\n都\n再次\n其它\n就是说\n什么样\n立地\n④\n话说\n呃\n不尽然\n一下\n②ｃ\n呆呆地\n特殊\n据说\n有力\n×\n除此而外\n管\n...................\n进去\n逢\n屡次三番\n毫无例外\n尽心尽力\n因着\n如前所述\n〕〔\n嘘\n带\n关于\n她们\n不管\n当着\n一次\n’\n朝\n加强\n嗬\n凡\n大力\n方便\n避免\n使用\n尽\n分别\n起来\n莫非\n‘\n十分\n大略\n多亏\n::\n会\n遵照\n不要\n们\n向着\n相同\n不特\n不大\n某个\n如\n累年\n［②\n又\n眨眼\nｎｇ昉\n其一\n坚持\n满\n%\n不少\n具体来说\n纵\n略微\n造成\n嘻\n其余\n譬喻\n多少\n以便\n依据\n后面\n地\n後来\n虽说\n过来\n~~~~\n日臻\n集中\n不断\n于是\n由此\n大概\n唯有\n反之则\n忽然\n什么\n喔唷\n刚才\n连\n迫于\n顶多\n与此同时\n己\n挨着\n有时\n./\n叮咚\n岂\n只有\n1\n应该\n5\n乃至\n前者\n临\n借以\n召开\n陡然\n＝\n其中\n＝［\n哼\n０\n及\n非得\n受到\n您\n［②①］\n何止\n那儿\n姑且\n以期\n看出\n哗啦\n嘎登\n反应\n＊\n曾经\n诚如\n/\n每每\n九\n腾\n从早到晚\n不知不觉\n反之亦然\n若果\nВ\n一致\n再者说\n开外\n本\n怕\n连同\n愿意\n始而\n应当\n目前\n倍加\n甚至\n年复一年\n虽然\n既...又\n」\n更\nＬＩ\n莫如\n来不及\n沙沙\n这就是说\n⑧\n先後\n［⑤］］\n你\n乎\n那会儿\n呗\n甚至于\n甫\n重要\n反倒\n仍\n组成\n怪不得\n更进一步\n^\n如次\n倘使\n［④］\n似的\n云尔\n后者\n用来\n认为\n规定\n好在\n别\n从\n倘或\n几\n.\n当下\n要不是\n个\n⑦\n结果\n′｜\n它们\n＆\n着\n再其次\n谁料\n许多\n１\n啐\n我的\n当头\n在下\n串行\n大不了\n常言说\n缕缕\n依照\n刚好\n不料\n不必\n＝－\n屡\n不单\n问题\n立\n按期\n呜\n不\n巴\nΔ\n自后\n替代\n＞λ\n举凡\n所以\n咳\n他人\n且说\n惟其\n处在\n传闻\n！\n不迭\n继续\n,\n定\n就地\n连连\n伙同\n这次\n何况\n［②ｄ］\n全力\n７\n′∈\n下来\n极大\n［①Ｅ］\n那时\n各种\n当即\n边\n略加\n周围\n那么样\n［①①］\n连日来\n匆匆\n以上\n很多\n"
  },
  {
    "path": "nrekit/data_loader.py",
    "content": "from six import iteritems\n\nimport json\nimport os\nimport multiprocessing\nimport numpy as np\nimport random\n\nclass file_data_loader:\n    def __next__(self):\n        raise NotImplementedError\n    \n    def next(self):\n        return self.__next__()\n\n    def next_batch(self, batch_size):\n        raise NotImplementedError\n\nclass npy_data_loader(file_data_loader):\n    MODE_INSTANCE = 0      # One batch contains batch_size instances.\n    MODE_ENTPAIR_BAG = 1   # One batch contains batch_size bags, instances in which have the same entity pair (usually for testing).\n    MODE_RELFACT_BAG = 2   # One batch contains batch size bags, instances in which have the same relation fact. (usually for training).\n\n    def __iter__(self):\n        return self\n\n    def __init__(self, data_dir, prefix, mode, word_vec_npy='vec.npy', shuffle=True, max_length=120, batch_size=160):\n        if not os.path.isdir(data_dir):\n            raise Exception(\"[ERROR] Data dir doesn't exist!\")\n        self.mode = mode\n        self.shuffle = shuffle\n        self.max_length = max_length\n        self.batch_size = batch_size\n        self.word_vec_mat = np.load(os.path.join(data_dir, word_vec_npy))\n        self.data_word = np.load(os.path.join(data_dir, prefix + \"_word.npy\")) \n        self.data_pos1 = np.load(os.path.join(data_dir, prefix + \"_pos1.npy\")) \n        self.data_pos2 = np.load(os.path.join(data_dir, prefix + \"_pos2.npy\")) \n        self.data_mask = np.load(os.path.join(data_dir, prefix + \"_mask.npy\")) \n        self.data_rel = np.load(os.path.join(data_dir, prefix + \"_label.npy\")) \n        self.data_length = np.load(os.path.join(data_dir, prefix + \"_len.npy\")) \n        self.scope = np.load(os.path.join(data_dir, prefix + \"_instance_scope.npy\"))\n        self.triple = np.load(os.path.join(data_dir, prefix + \"_instance_triple.npy\"))\n        self.relfact_tot = len(self.triple)\n        for i in range(self.scope.shape[0]):\n          
  self.scope[i][1] += 1\n\n        self.instance_tot = self.data_word.shape[0]\n        self.rel_tot = 53\n\n        if self.mode == self.MODE_INSTANCE:\n            self.order = list(range(self.instance_tot))\n        else:\n            self.order = list(range(len(self.scope)))\n        self.idx = 0\n\n        if self.shuffle:\n            random.shuffle(self.order) \n\n        print(\"Total relation fact: %d\" % (self.relfact_tot))\n\n    def __next__(self):\n        return self.next_batch(self.batch_size)\n\n    def next_batch(self, batch_size):\n        if self.idx >= len(self.order):\n            self.idx = 0\n            if self.shuffle:\n                random.shuffle(self.order) \n            raise StopIteration\n\n        batch_data = {}\n\n        if self.mode == self.MODE_INSTANCE:\n            idx0 = self.idx\n            idx1 = self.idx + batch_size\n            if idx1 > len(self.order):\n                self.idx = 0\n                if self.shuffle:\n                    random.shuffle(self.order) \n                raise StopIteration\n            self.idx = idx1\n            batch_data['word'] = self.data_word[idx0:idx1]\n            batch_data['pos1'] = self.data_pos1[idx0:idx1]\n            batch_data['pos2'] = self.data_pos2[idx0:idx1]\n            batch_data['rel'] = self.data_rel[idx0:idx1]\n            batch_data['length'] = self.data_length[idx0:idx1]\n            batch_data['scope'] = np.stack([list(range(idx1 - idx0)), list(range(1, idx1 - idx0 + 1))], axis=1)\n        elif self.mode == self.MODE_ENTPAIR_BAG or self.mode == self.MODE_RELFACT_BAG:\n            idx0 = self.idx\n            idx1 = self.idx + batch_size\n            if idx1 > len(self.order):\n                self.idx = 0\n                if self.shuffle:\n                    random.shuffle(self.order) \n                raise StopIteration\n            self.idx = idx1\n            _word = []\n            _pos1 = []\n            _pos2 = []\n            _rel = []\n            
_ins_rel = []\n            _multi_rel = []\n            _length = []\n            _scope = []\n            _mask = []\n            cur_pos = 0\n            for i in range(idx0, idx1):\n                _word.append(self.data_word[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])\n                _pos1.append(self.data_pos1[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])\n                _pos2.append(self.data_pos2[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])\n                _rel.append(self.data_rel[self.scope[self.order[i]][0]])\n                _ins_rel.append(self.data_rel[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])\n                _length.append(self.data_length[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])\n                _mask.append(self.data_mask[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])\n                bag_size = self.scope[self.order[i]][1] - self.scope[self.order[i]][0]\n                _scope.append([cur_pos, cur_pos + bag_size])\n                cur_pos = cur_pos + bag_size\n                if self.mode == self.MODE_ENTPAIR_BAG:\n                    _one_multi_rel = np.zeros((self.rel_tot), dtype=np.int32)\n                    for j in range(self.scope[self.order[i]][0], self.scope[self.order[i]][1]):\n                        _one_multi_rel[self.data_rel[j]] = 1\n                    _multi_rel.append(_one_multi_rel)\n            batch_data['word'] = np.concatenate(_word)\n            batch_data['pos1'] = np.concatenate(_pos1)\n            batch_data['pos2'] = np.concatenate(_pos2)\n            batch_data['rel'] = np.stack(_rel)\n            batch_data['ins_rel'] = np.concatenate(_ins_rel)\n            if self.mode == self.MODE_ENTPAIR_BAG:\n                batch_data['multi_rel'] = np.stack(_multi_rel)\n            batch_data['length'] = np.concatenate(_length)\n            batch_data['scope'] = np.stack(_scope)\n            batch_data['mask'] = 
np.concatenate(_mask)\n\n        return batch_data\n\nclass json_file_data_loader(file_data_loader):\n    MODE_INSTANCE = 0      # One batch contains batch_size instances.\n    MODE_ENTPAIR_BAG = 1   # One batch contains batch_size bags, instances in which have the same entity pair (usually for testing).\n    MODE_RELFACT_BAG = 2   # One batch contains batch size bags, instances in which have the same relation fact. (usually for training).\n\n    def _load_preprocessed_file(self):\n        name_prefix = '.'.join(self.file_name.split('/')[-1].split('.')[:-1])\n        word_vec_name_prefix = '.'.join(self.word_vec_file_name.split('/')[-1].split('.')[:-1])\n        processed_data_dir = '_processed_data'\n        if not os.path.isdir(processed_data_dir):\n            return False\n        word_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_word.npy')\n        pos1_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pos1.npy')\n        pos2_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pos2.npy')\n        rel_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_rel.npy')\n        mask_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_mask.npy')\n        length_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_length.npy')\n        entpair2scope_file_name = os.path.join(processed_data_dir, name_prefix + '_entpair2scope.json')\n        relfact2scope_file_name = os.path.join(processed_data_dir, name_prefix + '_relfact2scope.json')\n        word_vec_mat_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_mat.npy')\n        word2id_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_word2id.json')\n        if not os.path.exists(word_npy_file_name) or \\\n           not os.path.exists(pos1_npy_file_name) or \\\n           not os.path.exists(pos2_npy_file_name) or \\\n           not os.path.exists(rel_npy_file_name) or \\\n           not 
os.path.exists(mask_npy_file_name) or \\\n           not os.path.exists(length_npy_file_name) or \\\n           not os.path.exists(entpair2scope_file_name) or \\\n           not os.path.exists(relfact2scope_file_name) or \\\n           not os.path.exists(word_vec_mat_file_name) or \\\n           not os.path.exists(word2id_file_name):\n            return False\n        print(\"Pre-processed files exist. Loading them...\")\n        self.data_word = np.load(word_npy_file_name)\n        self.data_pos1 = np.load(pos1_npy_file_name)\n        self.data_pos2 = np.load(pos2_npy_file_name)\n        self.data_rel = np.load(rel_npy_file_name)\n        self.data_mask = np.load(mask_npy_file_name)\n        self.data_length = np.load(length_npy_file_name)\n        self.entpair2scope = json.load(open(entpair2scope_file_name))\n        self.relfact2scope = json.load(open(relfact2scope_file_name))\n        self.word_vec_mat = np.load(word_vec_mat_file_name)\n        self.word2id = json.load(open(word2id_file_name))\n        if self.data_word.shape[1] != self.max_length:\n            print(\"Pre-processed files don't match current settings. 
Reprocessing...\")\n            return False\n        print(\"Finish loading\")\n        return True\n\n    def __init__(self, file_name, word_vec_file_name, rel2id_file_name, mode, shuffle=True, max_length=120, case_sensitive=False, reprocess=False, batch_size=160):\n        '''\n        file_name: Json file storing the data in the following format\n            [\n                {\n                    'sentence': 'Bill Gates is the founder of Microsoft .',\n                    'head': {'word': 'Bill Gates', ...(other information)},\n                    'tail': {'word': 'Microsoft', ...(other information)},\n                    'relation': 'founder'\n                },\n                ...\n            ]\n        word_vec_file_name: Json file storing word vectors in the following format\n            [\n                {'word': 'the', 'vec': [0.418, 0.24968, ...]},\n                {'word': ',', 'vec': [0.013441, 0.23682, ...]},\n                ...\n            ]\n        rel2id_file_name: Json file storing the relation-to-id dictionary in the following format\n            {\n                'NA': 0,\n                'founder': 1,\n                ...\n            }\n            **IMPORTANT**: make sure the id of NA is 0!\n        mode: Specify how to get a batch of data. See MODE_* constants for details.\n        shuffle: Whether to shuffle the data, default: True. Shuffling should be enabled for training.\n        max_length: The length that all sentences are padded or truncated to, default: 120.\n        case_sensitive: Whether the data processing is case-sensitive, default: False.\n        reprocess: Redo the pre-processing even if pre-processed files exist, default: False.\n        batch_size: The size of each batch, default: 160.\n        '''\n\n        self.file_name = file_name\n        self.word_vec_file_name = word_vec_file_name\n        self.case_sensitive = case_sensitive\n        self.max_length = max_length\n        self.mode = mode\n        self.shuffle = shuffle\n        self.batch_size = batch_size\n        self.rel2id = json.load(open(rel2id_file_name))\n\n        if reprocess or not self._load_preprocessed_file(): # Try to load pre-processed files\n            # Check files\n            if file_name is None or not os.path.isfile(file_name):\n                raise Exception(\"[ERROR] Data file doesn't exist\")\n            if word_vec_file_name is None or not os.path.isfile(word_vec_file_name):\n                raise Exception(\"[ERROR] Word vector file doesn't exist\")\n\n            # Load files\n            print(\"Loading data file...\")\n            self.ori_data = json.load(open(self.file_name, \"r\"))\n            print(\"Finish loading\")\n            print(\"Loading word vector file...\")\n            self.ori_word_vec = json.load(open(self.word_vec_file_name, \"r\"))\n            print(\"Finish loading\")\n            \n            # Eliminate case sensitivity\n            if not case_sensitive:\n                print(\"Eliminating case sensitivity...\")\n                for i in range(len(self.ori_data)):\n                    self.ori_data[i]['sentence'] = self.ori_data[i]['sentence'].lower()\n                    self.ori_data[i]['head']['word'] = self.ori_data[i]['head']['word'].lower()\n                    self.ori_data[i]['tail']['word'] = self.ori_data[i]['tail']['word'].lower()\n     
           print(\"Finish eliminating\")\n\n            # Sort data by entities and relations\n            print(\"Sort data...\")\n            self.ori_data.sort(key=lambda a: a['head']['id'] + '#' + a['tail']['id'] + '#' + a['relation'])\n            print(\"Finish sorting\")\n       \n            # Pre-process word vec\n            self.word2id = {}\n            self.word_vec_tot = len(self.ori_word_vec)\n            UNK = self.word_vec_tot\n            BLANK = self.word_vec_tot + 1\n            self.word_vec_dim = len(self.ori_word_vec[0]['vec'])\n            print(\"Got {} words of {} dims\".format(self.word_vec_tot, self.word_vec_dim))\n            print(\"Building word vector matrix and mapping...\")\n            self.word_vec_mat = np.zeros((self.word_vec_tot, self.word_vec_dim), dtype=np.float32)\n            for cur_id, word in enumerate(self.ori_word_vec):\n                w = word['word']\n                if not case_sensitive:\n                    w = w.lower()\n                self.word2id[w] = cur_id\n                self.word_vec_mat[cur_id, :] = word['vec']\n            self.word2id['UNK'] = UNK\n            self.word2id['BLANK'] = BLANK\n            print(\"Finish building\")\n\n            # Pre-process data\n            print(\"Pre-processing data...\")\n            self.instance_tot = len(self.ori_data)\n            self.entpair2scope = {} # (head, tail) -> scope\n            self.relfact2scope = {} # (head, tail, relation) -> scope\n            self.data_word = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)\n            self.data_pos1 = np.zeros((self.instance_tot, self.max_length), dtype=np.int32) \n            self.data_pos2 = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)\n            self.data_rel = np.zeros((self.instance_tot), dtype=np.int32)\n            self.data_mask = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)\n            self.data_length = np.zeros((self.instance_tot), 
dtype=np.int32)\n            last_entpair = ''\n            last_entpair_pos = -1\n            last_relfact = ''\n            last_relfact_pos = -1\n            for i in range(self.instance_tot):\n                ins = self.ori_data[i]\n                if ins['relation'] in self.rel2id:\n                    self.data_rel[i] = self.rel2id[ins['relation']]\n                else:\n                    self.data_rel[i] = self.rel2id['NA']\n                sentence = ' '.join(ins['sentence'].split()) # delete extra spaces\n                head = ins['head']['word']\n                tail = ins['tail']['word']\n                cur_entpair = ins['head']['id'] + '#' + ins['tail']['id']\n                cur_relfact = ins['head']['id'] + '#' + ins['tail']['id'] + '#' + ins['relation']\n                if cur_entpair != last_entpair:\n                    if last_entpair != '':\n                        self.entpair2scope[last_entpair] = [last_entpair_pos, i] # left closed right open\n                    last_entpair = cur_entpair\n                    last_entpair_pos = i\n                if cur_relfact != last_relfact:\n                    if last_relfact != '':\n                        self.relfact2scope[last_relfact] = [last_relfact_pos, i]\n                    last_relfact = cur_relfact\n                    last_relfact_pos = i\n                p1 = sentence.find(' ' + head + ' ')\n                p2 = sentence.find(' ' + tail + ' ')\n                if p1 == -1:\n                    if sentence[:len(head) + 1] == head + \" \":\n                        p1 = 0\n                    elif sentence[-len(head) - 1:] == \" \" + head:\n                        p1 = len(sentence) - len(head)\n                    else:\n                        p1 = 0 # shouldn't happen\n                else:\n                    p1 += 1\n                if p2 == -1:\n                    if sentence[:len(tail) + 1] == tail + \" \":\n                        p2 = 0\n                    elif 
sentence[-len(tail) - 1:] == \" \" + tail:\n                        p2 = len(sentence) - len(tail)\n                    else:\n                        p2 = 0 # shouldn't happen\n                else:\n                    p2 += 1\n                # if p1 == -1 or p2 == -1:\n                #     raise Exception(\"[ERROR] Sentence doesn't contain the entity, index = {}, sentence = {}, head = {}, tail = {}\".format(i, sentence, head, tail))\n\n                words = sentence.split()\n                cur_ref_data_word = self.data_word[i]         \n                cur_pos = 0\n                pos1 = -1\n                pos2 = -1\n                for j, word in enumerate(words):\n                    if j < max_length:\n                        if word in self.word2id:\n                            cur_ref_data_word[j] = self.word2id[word]\n                        else:\n                            cur_ref_data_word[j] = UNK\n                    if cur_pos == p1:\n                        pos1 = j\n                        p1 = -1\n                    if cur_pos == p2:\n                        pos2 = j\n                        p2 = -1\n                    cur_pos += len(word) + 1\n                for j in range(j + 1, max_length):\n                    cur_ref_data_word[j] = BLANK\n                self.data_length[i] = len(words)\n                if len(words) > max_length:\n                    self.data_length[i] = max_length\n                if pos1 == -1 or pos2 == -1:\n                    raise Exception(\"[ERROR] Position error, index = {}, sentence = {}, head = {}, tail = {}\".format(i, sentence, head, tail))\n                if pos1 >= max_length:\n                    pos1 = max_length - 1\n                if pos2 >= max_length:\n                    pos2 = max_length - 1\n                pos_min = min(pos1, pos2)\n                pos_max = max(pos1, pos2)\n                for j in range(max_length):\n                    self.data_pos1[i][j] = j - pos1 + max_length\n    
                self.data_pos2[i][j] = j - pos2 + max_length\n                    if j >= self.data_length[i]:\n                        self.data_mask[i][j] = 0\n                    elif j <= pos_min:\n                        self.data_mask[i][j] = 1\n                    elif j <= pos_max:\n                        self.data_mask[i][j] = 2\n                    else:\n                        self.data_mask[i][j] = 3\n                    \n            if last_entpair != '':\n                self.entpair2scope[last_entpair] = [last_entpair_pos, self.instance_tot] # left closed right open\n            if last_relfact != '':\n                self.relfact2scope[last_relfact] = [last_relfact_pos, self.instance_tot]\n\n            print(\"Finish pre-processing\")     \n\n            print(\"Storing processed files...\")\n            name_prefix = '.'.join(file_name.split('/')[-1].split('.')[:-1])\n            word_vec_name_prefix = '.'.join(word_vec_file_name.split('/')[-1].split('.')[:-1])\n            processed_data_dir = '_processed_data'\n            if not os.path.isdir(processed_data_dir):\n                os.mkdir(processed_data_dir)\n            np.save(os.path.join(processed_data_dir, name_prefix + '_word.npy'), self.data_word)\n            np.save(os.path.join(processed_data_dir, name_prefix + '_pos1.npy'), self.data_pos1)\n            np.save(os.path.join(processed_data_dir, name_prefix + '_pos2.npy'), self.data_pos2)\n            np.save(os.path.join(processed_data_dir, name_prefix + '_rel.npy'), self.data_rel)\n            np.save(os.path.join(processed_data_dir, name_prefix + '_mask.npy'), self.data_mask)\n            np.save(os.path.join(processed_data_dir, name_prefix + '_length.npy'), self.data_length)\n            json.dump(self.entpair2scope, open(os.path.join(processed_data_dir, name_prefix + '_entpair2scope.json'), 'w'))\n            json.dump(self.relfact2scope, open(os.path.join(processed_data_dir, name_prefix + '_relfact2scope.json'), 'w'))\n         
   np.save(os.path.join(processed_data_dir, word_vec_name_prefix + '_mat.npy'), self.word_vec_mat)\n            json.dump(self.word2id, open(os.path.join(processed_data_dir, word_vec_name_prefix + '_word2id.json'), 'w'))\n            print(\"Finish storing\")\n\n        # Prepare for idx\n        self.instance_tot = self.data_word.shape[0]\n        self.entpair_tot = len(self.entpair2scope)\n        self.relfact_tot = 0 # The number of relation facts, without NA.\n        for key in self.relfact2scope:\n            if key[-2:] != 'NA':\n                self.relfact_tot += 1\n        self.rel_tot = len(self.rel2id)\n\n        if self.mode == self.MODE_INSTANCE:\n            self.order = list(range(self.instance_tot))\n        elif self.mode == self.MODE_ENTPAIR_BAG:\n            self.order = list(range(len(self.entpair2scope)))\n            self.scope_name = []\n            self.scope = []\n            for key, value in iteritems(self.entpair2scope):\n                self.scope_name.append(key)\n                self.scope.append(value)\n        elif self.mode == self.MODE_RELFACT_BAG:\n            self.order = list(range(len(self.relfact2scope)))\n            self.scope_name = []\n            self.scope = []\n            for key, value in iteritems(self.relfact2scope):\n                self.scope_name.append(key)\n                self.scope.append(value)\n        else:\n            raise Exception(\"[ERROR] Invalid mode\")\n        self.idx = 0\n\n        if self.shuffle:\n            random.shuffle(self.order) \n\n        print(\"Total relation fact: %d\" % (self.relfact_tot))\n\n    def __iter__(self):\n        return self\n\n    def __next__(self):\n        return self.next_batch(self.batch_size)\n\n    def next_batch(self, batch_size):\n        if self.idx >= len(self.order):\n            self.idx = 0\n            if self.shuffle:\n                random.shuffle(self.order) \n            raise StopIteration\n\n        batch_data = {}\n\n        if self.mode == 
self.MODE_INSTANCE:\n            idx0 = self.idx\n            idx1 = self.idx + batch_size\n            if idx1 > len(self.order):\n                idx1 = len(self.order)\n            self.idx = idx1\n            batch_data['word'] = self.data_word[idx0:idx1]\n            batch_data['pos1'] = self.data_pos1[idx0:idx1]\n            batch_data['pos2'] = self.data_pos2[idx0:idx1]\n            batch_data['rel'] = self.data_rel[idx0:idx1]\n            batch_data['mask'] = self.data_mask[idx0:idx1]\n            batch_data['length'] = self.data_length[idx0:idx1]\n            batch_data['scope'] = np.stack([list(range(batch_size)), list(range(1, batch_size + 1))], axis=1)\n            if idx1 - idx0 < batch_size:\n                padding = batch_size - (idx1 - idx0)\n                batch_data['word'] = np.concatenate([batch_data['word'], np.zeros((padding, self.data_word.shape[-1]), dtype=np.int32)])\n                batch_data['pos1'] = np.concatenate([batch_data['pos1'], np.zeros((padding, self.data_pos1.shape[-1]), dtype=np.int32)])\n                batch_data['pos2'] = np.concatenate([batch_data['pos2'], np.zeros((padding, self.data_pos2.shape[-1]), dtype=np.int32)])\n                batch_data['mask'] = np.concatenate([batch_data['mask'], np.zeros((padding, self.data_mask.shape[-1]), dtype=np.int32)])\n                batch_data['rel'] = np.concatenate([batch_data['rel'], np.zeros((padding), dtype=np.int32)])\n                batch_data['length'] = np.concatenate([batch_data['length'], np.zeros((padding), dtype=np.int32)])\n        elif self.mode == self.MODE_ENTPAIR_BAG or self.mode == self.MODE_RELFACT_BAG:\n            idx0 = self.idx\n            idx1 = self.idx + batch_size\n            if idx1 > len(self.order):\n                idx1 = len(self.order)\n            self.idx = idx1\n            _word = []\n            _pos1 = []\n            _pos2 = []\n            _mask = []\n            _rel = []\n            _ins_rel = []\n            _multi_rel = []\n         
   _entpair = []\n            _length = []\n            _scope = []\n            cur_pos = 0\n            for i in range(idx0, idx1):\n                _word.append(self.data_word[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])\n                _pos1.append(self.data_pos1[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])\n                _pos2.append(self.data_pos2[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])\n                _mask.append(self.data_mask[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])\n                _rel.append(self.data_rel[self.scope[self.order[i]][0]])\n                _ins_rel.append(self.data_rel[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])\n                _length.append(self.data_length[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])\n                bag_size = self.scope[self.order[i]][1] - self.scope[self.order[i]][0]\n                _scope.append([cur_pos, cur_pos + bag_size])\n                cur_pos = cur_pos + bag_size\n                if self.mode == self.MODE_ENTPAIR_BAG:\n                    _one_multi_rel = np.zeros((self.rel_tot), dtype=np.int32)\n                    for j in range(self.scope[self.order[i]][0], self.scope[self.order[i]][1]):\n                        _one_multi_rel[self.data_rel[j]] = 1\n                    _multi_rel.append(_one_multi_rel)\n                    _entpair.append(self.scope_name[self.order[i]])\n            for i in range(batch_size - (idx1 - idx0)):\n                _word.append(np.zeros((1, self.data_word.shape[-1]), dtype=np.int32))\n                _pos1.append(np.zeros((1, self.data_pos1.shape[-1]), dtype=np.int32))\n                _pos2.append(np.zeros((1, self.data_pos2.shape[-1]), dtype=np.int32))\n                _mask.append(np.zeros((1, self.data_mask.shape[-1]), dtype=np.int32))\n                _rel.append(0)\n                _ins_rel.append(np.zeros((1), dtype=np.int32))\n                
_length.append(np.zeros((1), dtype=np.int32))\n                _scope.append([cur_pos, cur_pos + 1])\n                cur_pos += 1\n                if self.mode == self.MODE_ENTPAIR_BAG:\n                    _multi_rel.append(np.zeros((self.rel_tot), dtype=np.int32))\n                    _entpair.append('None#None')\n            batch_data['word'] = np.concatenate(_word)\n            batch_data['pos1'] = np.concatenate(_pos1)\n            batch_data['pos2'] = np.concatenate(_pos2)\n            batch_data['mask'] = np.concatenate(_mask)\n            batch_data['rel'] = np.stack(_rel)\n            batch_data['ins_rel'] = np.concatenate(_ins_rel)\n            if self.mode == self.MODE_ENTPAIR_BAG:\n                batch_data['multi_rel'] = np.stack(_multi_rel)\n                batch_data['entpair'] = _entpair\n            batch_data['length'] = np.concatenate(_length)\n            batch_data['scope'] = np.stack(_scope)\n\n        return batch_data\n"
  },
  {
    "path": "nrekit/framework.py",
    "content": "import tensorflow as tf\nimport os\nimport sklearn.metrics\nimport numpy as np\nimport sys\nimport time\n\ndef average_gradients(tower_grads):\n    \"\"\"Calculate the average gradient for each shared variable across all towers.\n\n    Note that this function provides a synchronization point across all towers.\n\n    Args:\n        tower_grads: List of lists of (gradient, variable) tuples. The outer list\n            is over individual gradients. The inner list is over the gradient\n            calculation for each tower.\n    Returns:\n         List of pairs of (gradient, variable) where the gradient has been averaged\n         across all towers.\n    \"\"\"\n    average_grads = []\n    for grad_and_vars in zip(*tower_grads):\n        # Note that each grad_and_vars looks like the following:\n        #     ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))\n        grads = []\n        for g, _ in grad_and_vars:\n            # Add 0 dimension to the gradients to represent the tower.\n            expanded_g = tf.expand_dims(g, 0)\n\n            # Append on a 'tower' dimension which we will average over below.\n            grads.append(expanded_g)\n\n        # Average over the 'tower' dimension.\n        grad = tf.concat(axis=0, values=grads)\n        grad = tf.reduce_mean(grad, 0)\n\n        # Keep in mind that the Variables are redundant because they are shared\n        # across towers. So .. 
we will just return the first tower's pointer to\n        # the Variable.\n        v = grad_and_vars[0][1]\n        grad_and_var = (grad, v)\n        average_grads.append(grad_and_var)\n    return average_grads\n\nclass re_model:\n    def __init__(self, train_data_loader, batch_size, max_length=120):\n        self.word = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name='word')\n        self.pos1 = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name='pos1')\n        self.pos2 = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name='pos2')\n        self.label = tf.placeholder(dtype=tf.int32, shape=[batch_size], name='label')\n        self.ins_label = tf.placeholder(dtype=tf.int32, shape=[None], name='ins_label')\n        self.length = tf.placeholder(dtype=tf.int32, shape=[None], name='length')\n        self.scope = tf.placeholder(dtype=tf.int32, shape=[batch_size, 2], name='scope')\n        self.train_data_loader = train_data_loader\n        self.rel_tot = train_data_loader.rel_tot\n        self.word_vec_mat = train_data_loader.word_vec_mat\n\n    def loss(self):\n        raise NotImplementedError\n    \n    def train_logit(self):\n        raise NotImplementedError\n    \n    def test_logit(self):\n        raise NotImplementedError\n\nclass re_framework:\n    MODE_BAG = 0 # Train and test the model at bag level.\n    MODE_INS = 1 # Train and test the model at instance level\n\n    def __init__(self, train_data_loader, test_data_loader, max_length=120, batch_size=160):\n        self.train_data_loader = train_data_loader\n        self.test_data_loader = test_data_loader\n        self.sess = None\n\n    def one_step_multi_models(self, sess, models, batch_data_gen, run_array, return_label=True):\n        feed_dict = {}\n        batch_label = []\n        for model in models:\n            batch_data = batch_data_gen.next_batch(batch_data_gen.batch_size // len(models))\n            feed_dict.update({\n                model.word: 
batch_data['word'],\n                model.pos1: batch_data['pos1'],\n                model.pos2: batch_data['pos2'],\n                model.label: batch_data['rel'],\n                model.ins_label: batch_data['ins_rel'],\n                model.scope: batch_data['scope'],\n                model.length: batch_data['length'],\n            })\n            if 'mask' in batch_data and hasattr(model, \"mask\"):\n                feed_dict.update({model.mask: batch_data['mask']})\n            batch_label.append(batch_data['rel'])\n        result = sess.run(run_array, feed_dict)\n        batch_label = np.concatenate(batch_label)\n        if return_label:\n            result += [batch_label]\n        return result\n\n    def one_step(self, sess, model, batch_data, run_array):\n        feed_dict = {\n            model.word: batch_data['word'],\n            model.pos1: batch_data['pos1'],\n            model.pos2: batch_data['pos2'],\n            model.label: batch_data['rel'],\n            model.ins_label: batch_data['ins_rel'],\n            model.scope: batch_data['scope'],\n            model.length: batch_data['length'],\n        }\n        if 'mask' in batch_data and hasattr(model, \"mask\"):\n            feed_dict.update({model.mask: batch_data['mask']})\n        result = sess.run(run_array, feed_dict)\n        return result\n\n    def train(self,\n              model,\n              model_name,\n              ckpt_dir='./checkpoint',\n              summary_dir='./summary',\n              test_result_dir='./test_result',\n              learning_rate=0.5,\n              max_epoch=60,\n              pretrain_model=None,\n              test_epoch=1,\n              optimizer=tf.train.GradientDescentOptimizer,\n              gpu_nums=1):\n        \n        assert(self.train_data_loader.batch_size % gpu_nums == 0)\n        print(\"Start training...\")\n        \n        # Init\n        config = tf.ConfigProto(allow_soft_placement=True)\n        self.sess = 
tf.Session(config=config)\n        optimizer = optimizer(learning_rate)\n        \n        # Multi GPUs\n        tower_grads = []\n        tower_models = []\n        for gpu_id in range(gpu_nums):\n            with tf.device(\"/gpu:%d\" % gpu_id):\n                with tf.name_scope(\"gpu_%d\" % gpu_id):\n                    cur_model = model(self.train_data_loader, self.train_data_loader.batch_size // gpu_nums, self.train_data_loader.max_length)\n                    tower_grads.append(optimizer.compute_gradients(cur_model.loss()))\n                    tower_models.append(cur_model)\n                    tf.add_to_collection(\"loss\", cur_model.loss())\n                    tf.add_to_collection(\"train_logit\", cur_model.train_logit())\n\n        loss_collection = tf.get_collection(\"loss\")\n        loss = tf.add_n(loss_collection) / len(loss_collection)\n        train_logit_collection = tf.get_collection(\"train_logit\")\n        train_logit = tf.concat(train_logit_collection, 0)\n\n        grads = average_gradients(tower_grads)\n        train_op = optimizer.apply_gradients(grads)\n        summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)\n\n        # Saver\n        saver = tf.train.Saver(max_to_keep=None)\n        if pretrain_model is None:\n            self.sess.run(tf.global_variables_initializer())\n        else:\n            saver.restore(self.sess, pretrain_model)\n\n        # Training\n        best_metric = 0\n        best_prec = None\n        best_recall = None\n        not_best_count = 0 # Stop training after several epochs without improvement.\n        for epoch in range(max_epoch):\n            print('###### Epoch ' + str(epoch) + ' ######')\n            tot_correct = 0\n            tot_not_na_correct = 0\n            tot = 0\n            tot_not_na = 0\n            i = 0\n            time_sum = 0\n            while True:\n                time_start = time.time()\n                try:\n                    iter_loss, iter_logit, 
_train_op, iter_label = self.one_step_multi_models(self.sess, tower_models, self.train_data_loader, [loss, train_logit, train_op])\n                except StopIteration:\n                    break\n                time_end = time.time()\n                t = time_end - time_start\n                time_sum += t\n                iter_output = iter_logit.argmax(-1)\n                iter_correct = (iter_output == iter_label).sum()\n                iter_not_na_correct = np.logical_and(iter_output == iter_label, iter_label != 0).sum()\n                tot_correct += iter_correct\n                tot_not_na_correct += iter_not_na_correct\n                tot += iter_label.shape[0]\n                tot_not_na += (iter_label != 0).sum()\n                if tot_not_na > 0:\n                    sys.stdout.write(\"epoch %d step %d time %.2f | loss: %f, not NA accuracy: %f, accuracy: %f\\r\" % (epoch, i, t, iter_loss, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))\n                    sys.stdout.flush()\n                i += 1\n            print(\"\\nAverage iteration time: %f\" % (time_sum / i))\n\n            if (epoch + 1) % test_epoch == 0:\n                metric = self.test(model)\n                if metric > best_metric:\n                    best_metric = metric\n                    best_prec = self.cur_prec\n                    best_recall = self.cur_recall\n                    print(\"Best model, storing...\")\n                    if not os.path.isdir(ckpt_dir):\n                        os.mkdir(ckpt_dir)\n                    path = saver.save(self.sess, os.path.join(ckpt_dir, model_name))\n                    print(\"Finish storing\")\n                    not_best_count = 0\n                else:\n                    not_best_count += 1\n\n            if not_best_count >= 20:\n                break\n        \n        print(\"######\")\n        print(\"Finish training \" + model_name)\n        print(\"Best epoch auc = %f\" % (best_metric))\n        
if best_prec is not None and best_recall is not None:\n            if not os.path.isdir(test_result_dir):\n                os.mkdir(test_result_dir)\n            np.save(os.path.join(test_result_dir, model_name + \"_x.npy\"), best_recall)\n            np.save(os.path.join(test_result_dir, model_name + \"_y.npy\"), best_prec)\n\n    def test(self,\n             model,\n             ckpt=None,\n             return_result=False,\n             mode=MODE_BAG):\n        if mode == re_framework.MODE_BAG:\n            return self.__test_bag__(model, ckpt=ckpt, return_result=return_result)\n        elif mode == re_framework.MODE_INS:\n            raise NotImplementedError\n        else:\n            raise NotImplementedError\n        \n    def __test_bag__(self, model, ckpt=None, return_result=False):\n        print(\"Testing...\")\n        if self.sess is None:\n            self.sess = tf.Session()\n        model = model(self.test_data_loader, self.test_data_loader.batch_size, self.test_data_loader.max_length)\n        if ckpt is not None:\n            saver = tf.train.Saver()\n            saver.restore(self.sess, ckpt)\n        tot_correct = 0\n        tot_not_na_correct = 0\n        tot = 0\n        tot_not_na = 0\n        entpair_tot = 0\n        test_result = []\n        pred_result = []\n         \n        for i, batch_data in enumerate(self.test_data_loader):\n            iter_logit = self.one_step(self.sess, model, batch_data, [model.test_logit()])[0]\n            iter_output = iter_logit.argmax(-1)\n            iter_correct = (iter_output == batch_data['rel']).sum()\n            iter_not_na_correct = np.logical_and(iter_output == batch_data['rel'], batch_data['rel'] != 0).sum()\n            tot_correct += iter_correct\n            tot_not_na_correct += iter_not_na_correct\n            tot += batch_data['rel'].shape[0]\n            tot_not_na += (batch_data['rel'] != 0).sum()\n            if tot_not_na > 0:\n                sys.stdout.write(\"[TEST] step %d | 
not NA accuracy: %f, accuracy: %f\\r\" % (i, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))\n                sys.stdout.flush()\n            for idx in range(len(iter_logit)):\n                for rel in range(1, self.test_data_loader.rel_tot):\n                    test_result.append({'score': iter_logit[idx][rel], 'flag': batch_data['multi_rel'][idx][rel]})\n                    if batch_data['entpair'][idx] != \"None#None\":\n                        pred_result.append({'score': float(iter_logit[idx][rel]), 'entpair': batch_data['entpair'][idx].encode('utf-8'), 'relation': rel})\n                entpair_tot += 1 \n        sorted_test_result = sorted(test_result, key=lambda x: x['score'])\n        prec = []\n        recall = [] \n        correct = 0\n        for i, item in enumerate(sorted_test_result[::-1]):\n            correct += item['flag']\n            prec.append(float(correct) / (i + 1))\n            recall.append(float(correct) / self.test_data_loader.relfact_tot)\n        auc = sklearn.metrics.auc(x=recall, y=prec)\n        print(\"\\n[TEST] auc: {}\".format(auc))\n        print(\"Finish testing\")\n        self.cur_prec = prec\n        self.cur_recall = recall\n\n        if not return_result:\n            return auc\n        else:\n            return (auc, pred_result)\n"
  },
  {
    "path": "nrekit/network/classifier.py",
    "content": "import tensorflow as tf\nimport numpy as np\n\ndef softmax_cross_entropy(x, label, rel_tot, weights_table=None, weights=1.0, var_scope=None):\n    with tf.variable_scope(var_scope or \"loss\", reuse=tf.AUTO_REUSE):\n        if weights_table is not None:\n            weights = tf.nn.embedding_lookup(weights_table, label)\n        label_onehot = tf.one_hot(indices=label, depth=rel_tot, dtype=tf.int32)\n        loss = tf.losses.softmax_cross_entropy(onehot_labels=label_onehot, logits=x, weights=weights)\n        tf.summary.scalar('loss', loss)\n        return loss\n\ndef sigmoid_cross_entropy(x, label, rel_tot, weights_table=None, var_scope=None):\n    with tf.variable_scope(var_scope or \"loss\", reuse=tf.AUTO_REUSE):\n        if weights_table is None:\n            weights = 1.0\n        else:\n            weights = tf.nn.embedding_lookup(weights_table, label)\n        label_onehot = tf.one_hot(indices=label, depth=rel_tot, dtype=tf.int32)\n        loss = tf.losses.sigmoid_cross_entropy(label_onehot, logits=x, weights=weights)\n        tf.summary.scalar('loss', loss)\n        return loss\n\n# Soft-label loss.\n# Implemented, but the result reported in the paper has not been reproduced yet.\ndef soft_label_softmax_cross_entropy(x, label, rel_tot, weights=1.0):\n    with tf.name_scope(\"soft-label-loss\"):\n        label_onehot = tf.one_hot(indices=label, depth=rel_tot, dtype=tf.int32)\n        nscore = x + 0.9 * tf.reshape(tf.reduce_max(x, 1), [-1, 1]) * tf.cast(label_onehot, tf.float32)\n        nlabel = tf.one_hot(indices=tf.reshape(tf.argmax(nscore, axis=1), [-1]), depth=rel_tot, dtype=tf.int32)\n        loss = tf.losses.softmax_cross_entropy(onehot_labels=nlabel, logits=nscore, weights=weights)\n        tf.summary.scalar('loss', loss)\n        return loss\n\ndef output(x):\n    return tf.argmax(x, axis=-1)\n"
  },
  {
    "path": "nrekit/network/embedding.py",
    "content": "import tensorflow as tf\nimport numpy as np\n\ndef word_embedding(word, word_vec_mat, var_scope=None, word_embedding_dim=50, add_unk_and_blank=True):\n    with tf.variable_scope(var_scope or 'word_embedding', reuse=tf.AUTO_REUSE):\n        word_embedding = tf.get_variable('word_embedding', initializer=word_vec_mat, dtype=tf.float32)\n        if add_unk_and_blank:\n            word_embedding = tf.concat([word_embedding,\n                                        tf.get_variable(\"unk_word_embedding\", [1, word_embedding_dim], dtype=tf.float32,\n                                            initializer=tf.contrib.layers.xavier_initializer()),\n                                        tf.constant(np.zeros((1, word_embedding_dim), dtype=np.float32))], 0)\n        x = tf.nn.embedding_lookup(word_embedding, word)\n        return x\n\ndef pos_embedding(pos1, pos2, var_scope=None, pos_embedding_dim=5, max_length=120):\n    with tf.variable_scope(var_scope or 'pos_embedding', reuse=tf.AUTO_REUSE):\n        pos_tot = max_length * 2\n\n        pos1_embedding = tf.get_variable('real_pos1_embedding', [pos_tot, pos_embedding_dim], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) \n        # pos1_embedding = tf.concat([tf.zeros((1, pos_embedding_dim), dtype=tf.float32), real_pos1_embedding], 0)\n        pos2_embedding = tf.get_variable('real_pos2_embedding', [pos_tot, pos_embedding_dim], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) \n        # pos2_embedding = tf.concat([tf.zeros((1, pos_embedding_dim), dtype=tf.float32), real_pos2_embedding], 0)\n\n        input_pos1 = tf.nn.embedding_lookup(pos1_embedding, pos1)\n        input_pos2 = tf.nn.embedding_lookup(pos2_embedding, pos2)\n        x = tf.concat([input_pos1, input_pos2], -1)\n        return x\n\ndef word_position_embedding(word, word_vec_mat, pos1, pos2, var_scope=None, word_embedding_dim=50, pos_embedding_dim=5, max_length=120, add_unk_and_blank=True):\n    
w_embedding = word_embedding(word, word_vec_mat, var_scope=var_scope, word_embedding_dim=word_embedding_dim, add_unk_and_blank=add_unk_and_blank)\n    p_embedding = pos_embedding(pos1, pos2, var_scope=var_scope, pos_embedding_dim=pos_embedding_dim, max_length=max_length)\n    return tf.concat([w_embedding, p_embedding], -1)\n"
  },
  {
    "path": "nrekit/network/encoder.py",
    "content": "import tensorflow as tf\nimport numpy as np\nimport math\n\ndef __dropout__(x, keep_prob=1.0):\n    return tf.contrib.layers.dropout(x, keep_prob=keep_prob)\n\ndef __pooling__(x):\n    return tf.reduce_max(x, axis=-2)\n\ndef __piecewise_pooling__(x, mask):\n    mask_embedding = tf.constant([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)\n    mask = tf.nn.embedding_lookup(mask_embedding, mask)\n    hidden_size = x.shape[-1]\n    x = tf.reduce_max(tf.expand_dims(mask * 100, 2) + tf.expand_dims(x, 3), axis=1) - 100\n    return tf.reshape(x, [-1, hidden_size * 3])\n\ndef __cnn_cell__(x, hidden_size=230, kernel_size=3, stride_size=1):\n    x = tf.layers.conv1d(inputs=x, \n                         filters=hidden_size, \n                         kernel_size=kernel_size, \n                         strides=stride_size, \n                         padding='same', \n                         kernel_initializer=tf.contrib.layers.xavier_initializer())\n    return x\n\ndef cnn(x, hidden_size=230, kernel_size=3, stride_size=1, activation=tf.nn.relu, var_scope=None, keep_prob=1.0):\n    with tf.variable_scope(var_scope or \"cnn\", reuse=tf.AUTO_REUSE):\n        max_length = x.shape[1]\n        x = __cnn_cell__(x, hidden_size, kernel_size, stride_size)\n        x = __pooling__(x)\n        x = activation(x)\n        x = __dropout__(x, keep_prob)\n        return x\n\ndef pcnn(x, mask, hidden_size=230, kernel_size=3, stride_size=1, activation=tf.nn.relu, var_scope=None, keep_prob=1.0):\n    with tf.variable_scope(var_scope or \"pcnn\", reuse=tf.AUTO_REUSE):\n        max_length = x.shape[1]\n        x = __cnn_cell__(x, hidden_size, kernel_size, stride_size)\n        x = __piecewise_pooling__(x, mask)\n        x = activation(x)\n        x = __dropout__(x, keep_prob)\n        return x\n\ndef __rnn_cell__(hidden_size, cell_name='lstm'):\n    if isinstance(cell_name, list) or isinstance(cell_name, tuple):\n        if len(cell_name) == 1:\n            return 
__rnn_cell__(hidden_size, cell_name[0])\n        cells = [__rnn_cell__(hidden_size, c) for c in cell_name]\n        return tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)\n    if cell_name.lower() == 'lstm':\n        return tf.contrib.rnn.BasicLSTMCell(hidden_size, state_is_tuple=True)\n    elif cell_name.lower() == 'gru':\n        return tf.contrib.rnn.GRUCell(hidden_size)\n    raise NotImplementedError\n\ndef rnn(x, length, hidden_size=230, cell_name='lstm', var_scope=None, keep_prob=1.0):\n    with tf.variable_scope(var_scope or \"rnn\", reuse=tf.AUTO_REUSE):\n        x = __dropout__(x, keep_prob)\n        cell = __rnn_cell__(hidden_size, cell_name)\n        _, states = tf.nn.dynamic_rnn(cell, x, sequence_length=length, dtype=tf.float32, scope='dynamic-rnn')\n        if isinstance(states, tuple):\n            states = states[0]\n        return states\n\ndef birnn(x, length, hidden_size=230, cell_name='lstm', var_scope=None, keep_prob=1.0):\n    with tf.variable_scope(var_scope or \"birnn\", reuse=tf.AUTO_REUSE):\n        x = __dropout__(x, keep_prob)\n        fw_cell = __rnn_cell__(hidden_size, cell_name)\n        bw_cell = __rnn_cell__(hidden_size, cell_name)\n        _, states = tf.nn.bidirectional_dynamic_rnn(fw_cell, bw_cell, x, sequence_length=length, dtype=tf.float32, scope='dynamic-bi-rnn')\n        fw_states, bw_states = states\n        if isinstance(fw_states, tuple):\n            fw_states = fw_states[0]\n            bw_states = bw_states[0]\n        return tf.concat([fw_states, bw_states], axis=1)\n\n"
  },
  {
    "path": "nrekit/network/selector.py",
    "content": "import tensorflow as tf\nimport numpy as np\n\ndef __dropout__(x, keep_prob=1.0):\n    return tf.contrib.layers.dropout(x, keep_prob=keep_prob)\n\ndef __logit__(x, rel_tot, var_scope=None):\n    with tf.variable_scope(var_scope or 'logit', reuse=tf.AUTO_REUSE):\n        relation_matrix = tf.get_variable('relation_matrix', shape=[rel_tot, x.shape[1]], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())\n        bias = tf.get_variable('bias', shape=[rel_tot], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())\n        logit = tf.matmul(x, tf.transpose(relation_matrix)) + bias\n    return logit\n\ndef __attention_train_logit__(x, query, rel_tot, var_scope=None):\n    with tf.variable_scope(var_scope or 'logit', reuse=tf.AUTO_REUSE):\n        relation_matrix = tf.get_variable('relation_matrix', shape=[rel_tot, x.shape[1]], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())\n        bias = tf.get_variable('bias', shape=[rel_tot], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())\n    current_relation = tf.nn.embedding_lookup(relation_matrix, query)\n    attention_logit = tf.reduce_sum(current_relation * x, -1) # sum[(n', hidden_size) \\dot (n', hidden_size)] = (n)\n    return attention_logit\n\ndef __attention_test_logit__(x, rel_tot, var_scope=None):\n    with tf.variable_scope(var_scope or 'logit', reuse=tf.AUTO_REUSE):\n        relation_matrix = tf.get_variable('relation_matrix', shape=[rel_tot, x.shape[1]], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())\n        bias = tf.get_variable('bias', shape=[rel_tot], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())\n    attention_logit = tf.matmul(x, tf.transpose(relation_matrix)) # (n', hidden_size) x (hidden_size, rel_tot) = (n', rel_tot)\n    return attention_logit\n\ndef instance(x, rel_tot, var_scope=None, keep_prob=1.0):\n    x = __dropout__(x, keep_prob)\n    x = __logit__(x, 
rel_tot)\n    return x\n\ndef bag_attention(x, scope, query, rel_tot, is_training, var_scope=None, dropout_before=False, keep_prob=1.0):\n    with tf.variable_scope(var_scope or \"attention\", reuse=tf.AUTO_REUSE):\n        if is_training: # training\n            if dropout_before:\n                x = __dropout__(x, keep_prob)\n            bag_repre = []\n            attention_logit = __attention_train_logit__(x, query, rel_tot)\n            for i in range(scope.shape[0]):\n                bag_hidden_mat = x[scope[i][0]:scope[i][1]]\n                attention_score = tf.nn.softmax(attention_logit[scope[i][0]:scope[i][1]], -1)\n                bag_repre.append(tf.squeeze(tf.matmul(tf.expand_dims(attention_score, 0), bag_hidden_mat))) # (1, n') x (n', hidden_size) = (1, hidden_size) -> (hidden_size)\n            bag_repre = tf.stack(bag_repre)\n            if not dropout_before:\n                bag_repre = __dropout__(bag_repre, keep_prob)\n            return __logit__(bag_repre, rel_tot), bag_repre\n        else: # testing\n            attention_logit = __attention_test_logit__(x, rel_tot) # (n, rel_tot)\n            bag_repre = [] \n            bag_logit = []\n            for i in range(scope.shape[0]):\n                bag_hidden_mat = x[scope[i][0]:scope[i][1]]\n                attention_score = tf.nn.softmax(tf.transpose(attention_logit[scope[i][0]:scope[i][1], :]), -1) # softmax of (rel_tot, n')\n                bag_repre_for_each_rel = tf.matmul(attention_score, bag_hidden_mat) # (rel_tot, n') \\dot (n', hidden_size) = (rel_tot, hidden_size)\n                bag_logit_for_each_rel = __logit__(bag_repre_for_each_rel, rel_tot) # -> (rel_tot, rel_tot)\n                bag_repre.append(bag_repre_for_each_rel)\n                bag_logit.append(tf.diag_part(tf.nn.softmax(bag_logit_for_each_rel, -1))) # could be improved by sigmoid?\n            bag_repre = tf.stack(bag_repre)\n            bag_logit = tf.stack(bag_logit)\n            return bag_logit, 
bag_repre\n\ndef bag_average(x, scope, rel_tot, var_scope=None, dropout_before=False, keep_prob=1.0):\n    with tf.variable_scope(var_scope or \"average\", reuse=tf.AUTO_REUSE):\n        if dropout_before:\n            x = __dropout__(x, keep_prob)\n        bag_repre = []\n        for i in range(scope.shape[0]):\n            bag_hidden_mat = x[scope[i][0]:scope[i][1]]\n            bag_repre.append(tf.reduce_mean(bag_hidden_mat, 0)) # (n', hidden_size) -> (hidden_size)\n        bag_repre = tf.stack(bag_repre)\n        if not dropout_before:\n            bag_repre = __dropout__(bag_repre, keep_prob)\n    return __logit__(bag_repre, rel_tot), bag_repre\n\ndef bag_one(x, scope, query, rel_tot, is_training, var_scope=None, dropout_before=False, keep_prob=1.0): # could be improved?\n    with tf.variable_scope(var_scope or \"one\", reuse=tf.AUTO_REUSE):\n        if is_training: # training\n            if dropout_before:\n                x = __dropout__(x, keep_prob)\n            bag_repre = []\n            for i in range(scope.shape[0]):\n                bag_hidden_mat = x[scope[i][0]:scope[i][1]]\n                instance_logit = tf.nn.softmax(__logit__(bag_hidden_mat, rel_tot), -1) # (n', hidden_size) -> (n', rel_tot)\n                j = tf.argmax(instance_logit[:, query[i]], output_type=tf.int32)\n                bag_repre.append(bag_hidden_mat[j])\n            bag_repre = tf.stack(bag_repre)\n            if not dropout_before:\n                bag_repre = __dropout__(bag_repre, keep_prob)\n            return __logit__(bag_repre, rel_tot), bag_repre\n        else: # testing\n            if dropout_before:\n                x = __dropout__(x, keep_prob)\n            bag_repre = []\n            bag_logit = []\n            for i in range(scope.shape[0]):\n                bag_hidden_mat = x[scope[i][0]:scope[i][1]]\n                instance_logit = tf.nn.softmax(__logit__(bag_hidden_mat, rel_tot), -1) # (n', hidden_size) -> (n', rel_tot)\n                
bag_logit.append(tf.reduce_max(instance_logit, 0))\n                bag_repre.append(bag_hidden_mat[0]) # fake max repre\n            bag_logit = tf.stack(bag_logit)\n            bag_repre = tf.stack(bag_repre)\n            return bag_logit, bag_repre\n\ndef bag_cross_max(x, scope, rel_tot, var_scope=None, dropout_before=False, keep_prob=1.0):\n    '''\n    Cross-sentence Max-pooling proposed by (Jiang et al. 2016.)\n    \"Relation Extraction with Multi-instance Multi-label Convolutional Neural Networks\"\n    https://pdfs.semanticscholar.org/8731/369a707046f3f8dd463d1fd107de31d40a24.pdf\n    '''\n    with tf.variable_scope(var_scope or \"cross_max\", reuse=tf.AUTO_REUSE):\n        if dropout_before:\n            x = __dropout__(x, keep_prob)\n        bag_repre = []\n        for i in range(scope.shape[0]):\n            bag_hidden_mat = x[scope[i][0]:scope[i][1]]\n            bag_repre.append(tf.reduce_max(bag_hidden_mat, 0)) # (n', hidden_size) -> (hidden_size)\n        bag_repre = tf.stack(bag_repre)\n        if not dropout_before:\n            bag_repre = __dropout__(bag_repre, keep_prob)\n    return __logit__(bag_repre, rel_tot), bag_repre\n"
  },
  {
    "path": "nrekit/rl.py",
    "content": "import tensorflow as tf\nimport os\nimport sklearn.metrics\nimport numpy as np\nimport sys\nimport math\nimport time\nimport framework\nimport network\n\nclass policy_agent(framework.re_model):\n    def __init__(self, train_data_loader, batch_size, max_length=120):\n        framework.re_model.__init__(self, train_data_loader, batch_size, max_length)\n        self.weights = tf.placeholder(tf.float32, shape=(), name=\"weights_scalar\")\n\n        x = network.embedding.word_position_embedding(self.word, self.word_vec_mat, self.pos1, self.pos2)\n        x_train = network.encoder.cnn(x, keep_prob=0.5)\n        x_test = network.encoder.cnn(x, keep_prob=1.0)\n        self._train_logit = network.selector.instance(x_train, 2, keep_prob=0.5)\n        self._test_logit = network.selector.instance(x_test, 2, keep_prob=1.0)\n        self._loss = network.classifier.softmax_cross_entropy(self._train_logit, self.ins_label, 2, weights=self.weights)\n\n    def loss(self):\n        return self._loss\n\n    def train_logit(self):\n        return self._train_logit\n\n    def test_logit(self):\n        return self._test_logit\n\nclass rl_re_framework(framework.re_framework):\n    def __init__(self, train_data_loader, test_data_loader, max_length=120, batch_size=160):\n        framework.re_framework.__init__(self, train_data_loader, test_data_loader, max_length, batch_size)\n\n    def agent_one_step(self, sess, agent_model, batch_data, run_array, weights=1):\n        feed_dict = {\n            agent_model.word: batch_data['word'],\n            agent_model.pos1: batch_data['pos1'],\n            agent_model.pos2: batch_data['pos2'],\n            agent_model.ins_label: batch_data['agent_label'],\n            agent_model.length: batch_data['length'],\n            agent_model.weights: weights\n        }\n        if 'mask' in batch_data and hasattr(agent_model, \"mask\"):\n            feed_dict.update({agent_model.mask: batch_data['mask']})\n        result = sess.run(run_array, 
feed_dict)\n        return result\n\n    def pretrain_main_model(self, max_epoch):\n        for epoch in range(max_epoch):\n            print('###### Epoch ' + str(epoch) + ' ######')\n            tot_correct = 0\n            tot_not_na_correct = 0\n            tot = 0\n            tot_not_na = 0\n            i = 0\n            time_sum = 0\n            \n            for i, batch_data in enumerate(self.train_data_loader):\n                time_start = time.time()\n                iter_loss, iter_logit, _train_op = self.one_step(self.sess, self.model, batch_data, [self.model.loss(), self.model.train_logit(), self.train_op])\n                time_end = time.time()\n                t = time_end - time_start\n                time_sum += t\n                iter_output = iter_logit.argmax(-1)\n                iter_label = batch_data['rel']\n                iter_correct = (iter_output == iter_label).sum()\n                iter_not_na_correct = np.logical_and(iter_output == iter_label, iter_label != 0).sum()\n                tot_correct += iter_correct\n                tot_not_na_correct += iter_not_na_correct\n                tot += iter_label.shape[0]\n                tot_not_na += (iter_label != 0).sum()\n                if tot_not_na > 0:\n                    sys.stdout.write(\"[pretrain main model] epoch %d step %d time %.2f | loss: %f, not NA accuracy: %f, accuracy: %f\\r\" % (epoch, i, t, iter_loss, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))\n                    sys.stdout.flush()\n                i += 1\n            print(\"\\nAverage iteration time: %f\" % (time_sum / i))\n\n    def pretrain_agent_model(self, max_epoch):\n        # Pre-train policy agent\n        for epoch in range(max_epoch):\n            print('###### [Pre-train Policy Agent] Epoch ' + str(epoch) + ' ######')\n            tot_correct = 0\n            tot_not_na_correct = 0\n            tot = 0\n            tot_not_na = 0\n            time_sum = 0\n            \n           
 for i, batch_data in enumerate(self.train_data_loader):\n                time_start = time.time()\n                batch_data['agent_label'] = batch_data['ins_rel'] + 0\n                batch_data['agent_label'][batch_data['agent_label'] > 0] = 1\n                iter_loss, iter_logit, _train_op = self.agent_one_step(self.sess, self.agent_model, batch_data, [self.agent_model.loss(), self.agent_model.train_logit(), self.agent_train_op])\n                time_end = time.time()\n                t = time_end - time_start\n                time_sum += t\n                iter_output = iter_logit.argmax(-1)\n                iter_label = batch_data['ins_rel']\n                iter_correct = (iter_output == iter_label).sum()\n                iter_not_na_correct = np.logical_and(iter_output == iter_label, iter_label != 0).sum()\n                tot_correct += iter_correct\n                tot_not_na_correct += iter_not_na_correct\n                tot += iter_label.shape[0]\n                tot_not_na += (iter_label != 0).sum()\n                if tot_not_na > 0:\n                    sys.stdout.write(\"[pretrain policy agent] epoch %d step %d time %.2f | loss: %f, not NA accuracy: %f, accuracy: %f\\r\" % (epoch, i, t, iter_loss, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))\n                    sys.stdout.flush()\n                i += 1\n\n    def train(self,\n              model, # The main model\n              agent_model, # The model of policy agent\n              model_name,\n              ckpt_dir='./checkpoint',\n              summary_dir='./summary',\n              test_result_dir='./test_result',\n              learning_rate=0.5,\n              max_epoch=60,\n              pretrain_agent_epoch=1,\n              pretrain_model=None,\n              test_epoch=1,\n              optimizer=tf.train.GradientDescentOptimizer):\n        \n        print(\"Start training...\")\n        \n        # Init\n        self.model = model(self.train_data_loader, 
self.train_data_loader.batch_size, self.train_data_loader.max_length)\n        model_optimizer = optimizer(learning_rate)\n        grads = model_optimizer.compute_gradients(self.model.loss())\n        self.train_op = model_optimizer.apply_gradients(grads)\n\n        # Init policy agent\n        self.agent_model = agent_model(self.train_data_loader, self.train_data_loader.batch_size, self.train_data_loader.max_length)\n        agent_optimizer = optimizer(learning_rate)\n        agent_grads = agent_optimizer.compute_gradients(self.agent_model.loss())\n        self.agent_train_op = agent_optimizer.apply_gradients(agent_grads)\n\n        # Session, writer and saver\n        self.sess = tf.Session()\n        summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)\n        saver = tf.train.Saver(max_to_keep=None)\n        if pretrain_model is None:\n            self.sess.run(tf.global_variables_initializer())\n        else:\n            saver.restore(self.sess, pretrain_model)\n\n        self.pretrain_main_model(max_epoch=5) # Pre-train main model\n        self.pretrain_agent_model(max_epoch=1) # Pre-train policy agent \n\n        # Train\n        tot_delete = 0\n        batch_count = 0\n        instance_count = 0\n        reward = 0.0\n        best_metric = 0\n        best_prec = None\n        best_recall = None\n        not_best_count = 0 # Stop training after several epochs without improvement.\n        for epoch in range(max_epoch):\n            print('###### Epoch ' + str(epoch) + ' ######')\n            tot_correct = 0\n            tot_not_na_correct = 0\n            tot = 0\n            tot_not_na = 0\n            i = 0\n            time_sum = 0\n            batch_stack = []\n           \n            # Update policy agent\n            for i, batch_data in enumerate(self.train_data_loader):\n                # Make action\n                batch_data['agent_label'] = batch_data['ins_rel'] + 0\n                
batch_data['agent_label'][batch_data['agent_label'] > 0] = 1\n                batch_stack.append(batch_data)\n                iter_logit = self.agent_one_step(self.sess, self.agent_model, batch_data, [self.agent_model.train_logit()])[0]\n                action_result = iter_logit.argmax(-1)\n                \n                # Calculate reward\n                batch_delete = np.sum(np.logical_and(batch_data['ins_rel'] != 0, action_result == 0))\n                batch_data['ins_rel'][action_result == 0] = 0\n                iter_loss = self.one_step(self.sess, self.model, batch_data, [self.model.loss()])[0]\n                reward += iter_loss\n                tot_delete += batch_delete\n                batch_count += 1\n\n                # Update parameters of policy agent\n                alpha = 0.1\n                if batch_count == 100:\n                    reward = reward / float(batch_count)\n                    average_loss = reward\n                    reward = - math.log(1 - math.e ** (-reward))\n                    sys.stdout.write('tot delete : %f | reward : %f | average loss : %f\\r' % (tot_delete, reward, average_loss))\n                    sys.stdout.flush()\n                    for batch_data in batch_stack:\n                        self.agent_one_step(self.sess, self.agent_model, batch_data, [self.agent_train_op], weights=reward * alpha)\n                    batch_count = 0\n                    reward = 0\n                    tot_delete = 0\n                    batch_stack = []\n                i += 1\n\n            # Train the main model\n            for i, batch_data in enumerate(self.train_data_loader):\n                batch_data['agent_label'] = batch_data['ins_rel'] + 0\n                batch_data['agent_label'][batch_data['agent_label'] > 0] = 1\n                time_start = time.time()\n\n                # Make actions\n                iter_logit = self.agent_one_step(self.sess, self.agent_model, batch_data, 
[self.agent_model.train_logit()])[0]\n                action_result = iter_logit.argmax(-1)\n                batch_data['ins_rel'][action_result == 0] = 0\n                \n                # Real training: update the main model (not the agent) on the filtered batch\n                iter_loss, iter_logit, _train_op = self.one_step(self.sess, self.model, batch_data, [self.model.loss(), self.model.train_logit(), self.train_op])\n                time_end = time.time()\n                t = time_end - time_start\n                time_sum += t\n                iter_output = iter_logit.argmax(-1)\n                iter_label = batch_data['rel']\n                iter_correct = (iter_output == iter_label).sum()\n                iter_not_na_correct = np.logical_and(iter_output == iter_label, iter_label != 0).sum()\n                tot_correct += iter_correct\n                tot_not_na_correct += iter_not_na_correct\n                tot += iter_label.shape[0]\n                tot_not_na += (iter_label != 0).sum()\n                if tot_not_na > 0:\n                    sys.stdout.write(\"epoch %d step %d time %.2f | loss: %f, not NA accuracy: %f, accuracy: %f\\r\" % (epoch, i, t, iter_loss, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))\n                    sys.stdout.flush()\n                i += 1\n            print(\"\\nAverage iteration time: %f\" % (time_sum / i))\n\n            if (epoch + 1) % test_epoch == 0:\n                metric = self.test(model)\n                if metric > best_metric:\n                    best_metric = metric\n                    best_prec = self.cur_prec\n                    best_recall = self.cur_recall\n                    print(\"Best model, storing...\")\n                    if not os.path.isdir(ckpt_dir):\n                        os.mkdir(ckpt_dir)\n                    path = saver.save(self.sess, os.path.join(ckpt_dir, model_name))\n                    print(\"Finish storing\")\n                    not_best_count = 0\n                else:\n                    not_best_count += 1\n\n            if not_best_count >= 20:\n                break\n\n        print(\"######\")\n        print(\"Finish training \" + model_name)\n        print(\"Best epoch auc = %f\" % (best_metric))\n        if (not best_prec is None) and (not best_recall is None):\n            if not os.path.isdir(test_result_dir):\n                os.mkdir(test_result_dir)\n            np.save(os.path.join(test_result_dir, model_name + \"_x.npy\"), best_recall)\n            np.save(os.path.join(test_result_dir, model_name + \"_y.npy\"), best_prec)\n\n"
  },
  {
    "path": "test_demo.py",
    "content": "import nrekit\nimport numpy as np\nimport tensorflow as tf\nimport sys\nimport os\nimport json\n\ndataset_name = 'nyt'\nif len(sys.argv) > 1:\n    dataset_name = sys.argv[1]\ndataset_dir = os.path.join('./data', dataset_name)\nif not os.path.isdir(dataset_dir):\n    raise Exception(\"[ERROR] Dataset dir %s doesn't exist!\" % (dataset_dir))\n\n# The first 3 parameters are train / test data file name, word embedding file name and relation-id mapping file name respectively.\ntrain_loader = nrekit.data_loader.json_file_data_loader(os.path.join(dataset_dir, 'train.json'), \n                                                        os.path.join(dataset_dir, 'word_vec.json'),\n                                                        os.path.join(dataset_dir, 'rel2id.json'), \n                                                        mode=nrekit.data_loader.json_file_data_loader.MODE_RELFACT_BAG,\n                                                        shuffle=True)\ntest_loader = nrekit.data_loader.json_file_data_loader(os.path.join(dataset_dir, 'test.json'), \n                                                       os.path.join(dataset_dir, 'word_vec.json'),\n                                                       os.path.join(dataset_dir, 'rel2id.json'), \n                                                       mode=nrekit.data_loader.json_file_data_loader.MODE_ENTPAIR_BAG,\n                                                       shuffle=False)\n\nframework = nrekit.framework.re_framework(train_loader, test_loader)\n\nclass model(nrekit.framework.re_model):\n    encoder = \"pcnn\"\n    selector = \"att\"\n\n    def __init__(self, train_data_loader, batch_size, max_length=120):\n        nrekit.framework.re_model.__init__(self, train_data_loader, batch_size, max_length=max_length)\n        self.mask = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name=\"mask\")\n        \n        # Embedding\n        x = 
nrekit.network.embedding.word_position_embedding(self.word, self.word_vec_mat, self.pos1, self.pos2)\n\n        # Encoder\n        if model.encoder == \"pcnn\":\n            x_train = nrekit.network.encoder.pcnn(x, self.mask, keep_prob=0.5)\n            x_test = nrekit.network.encoder.pcnn(x, self.mask, keep_prob=1.0)\n        elif model.encoder == \"cnn\":\n            x_train = nrekit.network.encoder.cnn(x, keep_prob=0.5)\n            x_test = nrekit.network.encoder.cnn(x, keep_prob=1.0)\n        elif model.encoder == \"rnn\":\n            x_train = nrekit.network.encoder.rnn(x, self.length, keep_prob=0.5)\n            x_test = nrekit.network.encoder.rnn(x, self.length, keep_prob=1.0)\n        elif model.encoder == \"birnn\":\n            x_train = nrekit.network.encoder.birnn(x, self.length, keep_prob=0.5)\n            x_test = nrekit.network.encoder.birnn(x, self.length, keep_prob=1.0)\n        else:\n            raise NotImplementedError\n\n        # Selector\n        if model.selector == \"att\":\n            self._train_logit, train_repre = nrekit.network.selector.bag_attention(x_train, self.scope, self.ins_label, self.rel_tot, True, keep_prob=0.5)\n            self._test_logit, test_repre = nrekit.network.selector.bag_attention(x_test, self.scope, self.ins_label, self.rel_tot, False, keep_prob=1.0)\n        elif model.selector == \"ave\":\n            self._train_logit, train_repre = nrekit.network.selector.bag_average(x_train, self.scope, self.rel_tot, keep_prob=0.5)\n            self._test_logit, test_repre = nrekit.network.selector.bag_average(x_test, self.scope, self.rel_tot, keep_prob=1.0)\n            self._test_logit = tf.nn.softmax(self._test_logit)\n        elif model.selector == \"max\":\n            self._train_logit, train_repre = nrekit.network.selector.bag_maximum(x_train, self.scope, self.ins_label, self.rel_tot, True, keep_prob=0.5)\n            self._test_logit, test_repre = nrekit.network.selector.bag_maximum(x_test, self.scope, 
self.ins_label, self.rel_tot, False, keep_prob=1.0)\n            self._test_logit = tf.nn.softmax(self._test_logit)\n        else:\n            raise NotImplementedError\n        \n        # Classifier\n        self._loss = nrekit.network.classifier.softmax_cross_entropy(self._train_logit, self.label, self.rel_tot, weights_table=self.get_weights())\n \n    def loss(self):\n        return self._loss\n\n    def train_logit(self):\n        return self._train_logit\n\n    def test_logit(self):\n        return self._test_logit\n\n    def get_weights(self):\n        with tf.variable_scope(\"weights_table\", reuse=tf.AUTO_REUSE):\n            print(\"Calculating weights_table...\")\n            _weights_table = np.zeros((self.rel_tot), dtype=np.float32)\n            for i in range(len(self.train_data_loader.data_rel)):\n                _weights_table[self.train_data_loader.data_rel[i]] += 1.0 \n            _weights_table = 1 / (_weights_table ** 0.05)\n            weights_table = tf.get_variable(name='weights_table', dtype=tf.float32, trainable=False, initializer=_weights_table)\n            print(\"Finish calculating\")\n        return weights_table\n\nif len(sys.argv) > 2:\n    model.encoder = sys.argv[2]\nif len(sys.argv) > 3:\n    model.selector = sys.argv[3]\n\nauc, pred_result = framework.test(model, ckpt=\"./checkpoint/\" + dataset_name + \"_\" + model.encoder + \"_\" + model.selector, return_result=True)\n\nwith open('./test_result/' + dataset_name + \"_\" + model.encoder + \"_\" + model.selector + \"_pred.json\", 'w') as outfile:\n    json.dump(pred_result, outfile)\n\n"
  },
  {
    "path": "train_demo.py",
    "content": "import nrekit\nimport numpy as np\nimport tensorflow as tf\nimport sys\nimport os\n\ndataset_name = 'nyt'\nif len(sys.argv) > 1:\n    dataset_name = sys.argv[1]\ndataset_dir = os.path.join('./data', dataset_name)\nif not os.path.isdir(dataset_dir):\n    raise Exception(\"[ERROR] Dataset dir %s doesn't exist!\" % (dataset_dir))\n\n# The first 3 parameters are train / test data file name, word embedding file name and relation-id mapping file name respectively.\ntrain_loader = nrekit.data_loader.json_file_data_loader(os.path.join(dataset_dir, 'train.json'), \n                                                        os.path.join(dataset_dir, 'word_vec.json'),\n                                                        os.path.join(dataset_dir, 'rel2id.json'), \n                                                        mode=nrekit.data_loader.json_file_data_loader.MODE_RELFACT_BAG,\n                                                        shuffle=True)\ntest_loader = nrekit.data_loader.json_file_data_loader(os.path.join(dataset_dir, 'test.json'), \n                                                       os.path.join(dataset_dir, 'word_vec.json'),\n                                                       os.path.join(dataset_dir, 'rel2id.json'), \n                                                       mode=nrekit.data_loader.json_file_data_loader.MODE_ENTPAIR_BAG,\n                                                       shuffle=False)\n\nframework = nrekit.framework.re_framework(train_loader, test_loader)\n\nclass model(nrekit.framework.re_model):\n    encoder = \"pcnn\"\n    selector = \"att\"\n\n    def __init__(self, train_data_loader, batch_size, max_length=120):\n        nrekit.framework.re_model.__init__(self, train_data_loader, batch_size, max_length=max_length)\n        self.mask = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name=\"mask\")\n        \n        # Embedding\n        x = nrekit.network.embedding.word_position_embedding(self.word, 
self.word_vec_mat, self.pos1, self.pos2)\n\n        # Encoder\n        if model.encoder == \"pcnn\":\n            x_train = nrekit.network.encoder.pcnn(x, self.mask, keep_prob=0.5)\n            x_test = nrekit.network.encoder.pcnn(x, self.mask, keep_prob=1.0)\n        elif model.encoder == \"cnn\":\n            x_train = nrekit.network.encoder.cnn(x, keep_prob=0.5)\n            x_test = nrekit.network.encoder.cnn(x, keep_prob=1.0)\n        elif model.encoder == \"rnn\":\n            x_train = nrekit.network.encoder.rnn(x, self.length, keep_prob=0.5)\n            x_test = nrekit.network.encoder.rnn(x, self.length, keep_prob=1.0)\n        elif model.encoder == \"birnn\":\n            x_train = nrekit.network.encoder.birnn(x, self.length, keep_prob=0.5)\n            x_test = nrekit.network.encoder.birnn(x, self.length, keep_prob=1.0)\n        else:\n            raise NotImplementedError\n\n        # Selector\n        if model.selector == \"att\":\n            self._train_logit, train_repre = nrekit.network.selector.bag_attention(x_train, self.scope, self.ins_label, self.rel_tot, True, keep_prob=0.5)\n            self._test_logit, test_repre = nrekit.network.selector.bag_attention(x_test, self.scope, self.ins_label, self.rel_tot, False, keep_prob=1.0)\n        elif model.selector == \"ave\":\n            self._train_logit, train_repre = nrekit.network.selector.bag_average(x_train, self.scope, self.rel_tot, keep_prob=0.5)\n            self._test_logit, test_repre = nrekit.network.selector.bag_average(x_test, self.scope, self.rel_tot, keep_prob=1.0)\n            self._test_logit = tf.nn.softmax(self._test_logit)\n        elif model.selector == \"one\":\n            self._train_logit, train_repre = nrekit.network.selector.bag_one(x_train, self.scope, self.label, self.rel_tot, True, keep_prob=0.5)\n            self._test_logit, test_repre = nrekit.network.selector.bag_one(x_test, self.scope, self.label, self.rel_tot, False, keep_prob=1.0)\n            self._test_logit = 
tf.nn.softmax(self._test_logit)\n        elif model.selector == \"cross_max\":\n            self._train_logit, train_repre = nrekit.network.selector.bag_cross_max(x_train, self.scope, self.rel_tot, keep_prob=0.5)\n            self._test_logit, test_repre = nrekit.network.selector.bag_cross_max(x_test, self.scope, self.rel_tot, keep_prob=1.0)\n            self._test_logit = tf.nn.softmax(self._test_logit)\n        else:\n            raise NotImplementedError\n        \n        # Classifier\n        self._loss = nrekit.network.classifier.softmax_cross_entropy(self._train_logit, self.label, self.rel_tot, weights_table=self.get_weights())\n \n    def loss(self):\n        return self._loss\n\n    def train_logit(self):\n        return self._train_logit\n\n    def test_logit(self):\n        return self._test_logit\n\n    def get_weights(self):\n        with tf.variable_scope(\"weights_table\", reuse=tf.AUTO_REUSE):\n            print(\"Calculating weights_table...\")\n            _weights_table = np.zeros((self.rel_tot), dtype=np.float32)\n            for i in range(len(self.train_data_loader.data_rel)):\n                _weights_table[self.train_data_loader.data_rel[i]] += 1.0 \n            _weights_table = 1 / (_weights_table ** 0.05)\n            weights_table = tf.get_variable(name='weights_table', dtype=tf.float32, trainable=False, initializer=_weights_table)\n            print(\"Finish calculating\")\n        return weights_table\n\nuse_rl = False\nif len(sys.argv) > 2:\n    model.encoder = sys.argv[2]\nif len(sys.argv) > 3:\n    model.selector = sys.argv[3]\nif len(sys.argv) > 4:\n    if sys.argv[4] == 'rl':\n        use_rl = True\n\nif use_rl:\n    rl_framework = nrekit.rl.rl_re_framework(train_loader, test_loader)\n    rl_framework.train(model, nrekit.rl.policy_agent, model_name=dataset_name + \"_\" + model.encoder + \"_\" + model.selector + \"_rl\", max_epoch=60, ckpt_dir=\"checkpoint\")\nelse:\n    framework.train(model, model_name=dataset_name + \"_\" + 
model.encoder + \"_\" + model.selector, max_epoch=60, ckpt_dir=\"checkpoint\", gpu_nums=1)\n"
  }
]
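
The bag pooling performed by the selectors in `nrekit/network/selector.py` (`bag_average` and `bag_cross_max`) can be sketched outside TensorFlow as follows. This is a minimal NumPy illustration: `select_bags` and its arguments are hypothetical names for this sketch, not part of nrekit, but the `scope` convention (each bag is the slice `x[start:end]` of per-sentence hidden vectors) mirrors the selector code above.

```python
import numpy as np

def select_bags(x, scope, mode="average"):
    """Pool each bag of sentence representations into one bag representation.

    x     : (num_sentences, hidden_size) array of per-sentence hidden vectors
    scope : list of (start, end) pairs; bag i owns rows x[start:end]
    """
    reprs = []
    for start, end in scope:
        bag = x[start:end]                 # (n', hidden_size)
        if mode == "average":              # like bag_average: mean over sentences
            reprs.append(bag.mean(axis=0))
        elif mode == "cross_max":          # like bag_cross_max: element-wise max
            reprs.append(bag.max(axis=0))
        else:
            raise ValueError(mode)
    return np.stack(reprs)                 # (num_bags, hidden_size)

# Two bags: rows 0-1 form bag 0, row 2 forms bag 1.
x = np.array([[1., 4.], [3., 2.], [5., 0.]])
scope = [(0, 2), (2, 3)]
avg = select_bags(x, scope, "average")     # bag means: [2., 3.] and [5., 0.]
cmx = select_bags(x, scope, "cross_max")   # element-wise maxima: [3., 4.] and [5., 0.]
```

`bag_one` follows the same scoping pattern but, at training time, keeps only the single sentence whose softmax score for the bag's gold relation is highest.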
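
The reward shaping in `rl_re_framework.train` (`reward = - math.log(1 - math.e ** (-reward))`, applied to the average main-model loss every 100 batches) can be isolated as below. A small sketch under the assumption that `shaped_reward` is just an illustrative name; the point is that a lower average loss on the agent-filtered batches maps to a larger scalar reward, which then scales the agent's gradient updates via the `weights` placeholder.

```python
import math

def shaped_reward(average_loss):
    # e**(-loss) grows as the loss shrinks, so the log argument moves
    # toward 0 and -log(...) grows: better filtering -> bigger reward.
    return -math.log(1 - math.e ** (-average_loss))

low = shaped_reward(0.5)   # smaller loss -> larger reward
high = shaped_reward(2.0)  # larger loss -> smaller reward
```

Note the function is only defined for `average_loss > 0`; a zero loss would make the log argument 0, which is why it is applied to a 100-batch average rather than per batch.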