Repository: xiaolalala/Distant-Supervised-Chinese-Relation-Extraction
Branch: master
Commit: d0eb40a049b0
Files: 19
Total size: 111.5 KB
Directory structure:
gitextract_idaztadn/
├── .gitignore
├── LICENSE
├── README.md
├── draw_plot.py
├── kg_data/
│ ├── EntityMatcher.py
│ ├── README.md
│ ├── SentenceSegment.py
│ ├── add_relation.ipynb
│ ├── data_process.ipynb
│ └── stop_word.txt
├── nrekit/
│ ├── data_loader.py
│ ├── framework.py
│ ├── network/
│ │ ├── classifier.py
│ │ ├── embedding.py
│ │ ├── encoder.py
│ │ └── selector.py
│ └── rl.py
├── test_demo.py
└── train_demo.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.ipynb_checkpoints/
kg_data/processed/
kg_data/baike_triples.txt
kg_data/baiketriples.zip
kg_data/.ipynb_checkpoints/
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2018 Tianyu Gao
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Distant-Supervised-Chinese-Relation-Extraction
## Chinese Relation Extraction Based on Distant Supervision
### Dataset Construction
* Chinese general-purpose knowledge base CN-DBpedia
* Distant supervision assumption
The processing pipeline is documented in kg_data/README.md. A processed subset of the data can be downloaded [here (Google Drive)](https://drive.google.com/open?id=1XmWW3-wveKiJFauqZTPcSbsgh4lECRBK).
### Model Selection
The models are those of thunlp/OpenNRE; see its documentation for details.
**Source:** https://github.com/thunlp/OpenNRE
### Running the Code
The dataset directory defaults to data/chinese in the code. Run:
```
python train_demo.py chinese pcnn att
```
### Results
Results for a subset of the relations:
Class|Precision|Recall|F1
:-:|:-:|:-:|:-:
**All**|**0.95428**|**0.95036**|**0.95232**
/人物/其它/民族|0.98374|0.979|0.98137
NA|0.96853|0.97824|0.97336
/人物/地点/国籍|0.84075|0.92673|0.88164
/组织/地点/位于|0.85157|0.83652|0.84398
/人物/其它/职业|0.86121|0.8037|0.83147
/人物/组织/毕业于|0.84137|0.78092|0.81002
/组织/人物/校长|0.94118|0.59259|0.72727
/人物/地点/出生地|0.81049|0.49028|0.61097
/人物/人物/家庭成员|0.65385|0.37778|0.47887
/人物/组织/属于|0.99999|0.11364|0.20408
/地点/地点/包含|0.99999|0.0625|0.11765
/组织/人物/创始人|0.99999|0.05882|0.11111
Some relations have very low recall; analysis suggests the likely cause is that the dataset contains very few samples for those relations.
================================================
FILE: draw_plot.py
================================================
import sklearn.metrics
import matplotlib
# Use 'Agg' so this program can run on a remote server without a display
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
result_dir = './test_result'
def main():
models = sys.argv[1:]
for model in models:
x = np.load(os.path.join(result_dir, model +'_x' + '.npy'))
y = np.load(os.path.join(result_dir, model + '_y' + '.npy'))
f1 = (2 * x * y / (x + y + 1e-20)).max()
auc = sklearn.metrics.auc(x=x, y=y)
#plt.plot(x, y, lw=2, label=model + '-auc='+str(auc))
plt.plot(x, y, lw=2, label=model)
print(model + ' : ' + 'auc = ' + str(auc) + ' | ' + 'max F1 = ' + str(f1))
print(' P@100: {} | P@200: {} | P@300: {} | Mean: {}'.format(y[100], y[200], y[300], (y[100] + y[200] + y[300]) / 3))
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.3, 1.0])
plt.xlim([0.0, 0.4])
plt.title('Precision-Recall')
plt.legend(loc="upper right")
plt.grid(True)
plt.savefig(os.path.join(result_dir, 'pr_curve'))
if __name__ == "__main__":
main()
================================================
FILE: kg_data/EntityMatcher.py
================================================
import pickle
import multiprocessing
import time
import os
class EntityMatcher:
def __init__(self, entity_file, sentences_folder, process_num):
self.process_num = process_num
with open(entity_file, 'rb') as f:
entity_dict = pickle.load(f)
self.entities = list(set(list(entity_dict.keys())))
# sentence files
self.sentence_files = []
for root, dirs, files in os.walk(sentences_folder):
for file in files:
self.sentence_files.append(os.path.join(root, file))
def match(self, file_name):
print('Start %s!'%(file_name))
with open(file_name, 'rb') as f:
data = pickle.load(f)
new_data = []
for sen in data:
eset = []
for entity in self.entities:
if entity in sen:
eset.append(entity)
if len(eset)>1:
new_data.append([sen, eset])
print('Done %s!'%(file_name))
return new_data, file_name
def write_file(self, data):
with open(data[1], 'wb') as f:
pickle.dump(data[0], f)
def run(self):
pool = multiprocessing.Pool(processes=self.process_num)
for one_file in self.sentence_files:
pool.apply_async(self.match, args=(str(one_file), ), callback=self.write_file)
pool.close()
pool.join()
if __name__ == "__main__":
entity_file = 'processed/entities.pkl'
sentences_folder = 'processed/sentences'
process_num = 8
st = time.localtime()
print('\n开始时间: ')
print(time.strftime("%Y-%m-%d %H:%M:%S", st))
em = EntityMatcher(entity_file, sentences_folder, process_num)
em.run()
ed = time.localtime()
print('结束时间: ')
print (time.strftime("%Y-%m-%d %H:%M:%S", ed))
================================================
FILE: kg_data/README.md
================================================
# Distant Supervision Dataset Construction Pipeline

## Execution Order
0. Download the raw data, unzip it, and place it in this directory as `kg_data/baike_triples.txt`
1. Start Jupyter Notebook in the kg_data directory
```
jupyter notebook .
```
2. Run data_process.ipynb from top to bottom
3. Match entities
```
python EntityMatcher.py
```
4. Segment the sentences
```
python SentenceSegment.py
```
5. Run add_relation.ipynb from top to bottom
## Raw Data
* Data source
The raw data is the publicly released portion of CN-DBpedia, a Chinese general-purpose encyclopedia knowledge graph, containing 9M+ encyclopedia entities and 66M+ triples: 4M+ abstracts, 19.8M+ tags, and 41M+ infobox entries.
**Download:** http://www.openkg.cn/dataset/cndbpedia
**Source:** http://kw.fudan.edu.cn/cndbpedia
* Data format
The download unzips to baike_triples.txt, in which each line is one triple: the first element is the entity name, the second is a relation or attribute name, and the third is the corresponding value.
```
"1+8"时代广场 中文名 "1+8"时代广场
"1+8"时代广场 地点 咸宁大道与银泉大道交叉口
"1+8"时代广场 实质 城市综合体项目
"1+8"时代广场 总建面 约11.28万方
北京 中文名称 北京
北京 BaiduTAG 北京, 简称“京”, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。北京地处中国华北地区, 中心位于东经116°20′、北纬39°56′, 东与天津毗连, 其余均与河北相邻, 北京市总面积16410.54平方千米。
北京 所属地区 中国华北地区
"1.4"河南兰考火灾事故 地点 河南<a>兰考县</a>城关镇
"1.4"河南兰考火灾事故 时间 2013年1月4日
"1.4"河南兰考火灾事故 结果 一人重伤
```
## Building the Entity Dictionary
The first element of each line in the raw data is an entity. **A regular expression keeps only entities that consist entirely of Chinese characters**, and the result is converted into a dictionary.
The processed data looks like this:
```python
{
"北京":{
"中文名称": "北京",
"BaiduCard": "北京, 简称“京”, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。北京地处中国华北地区, 中心位于东经116°20′、北纬39°56′, 东与天津毗连, 其余均与河北相邻, 北京市总面积16410.54平方千米。",
"所属地区": "中国华北地区"
},
...
}
```
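A minimal sketch of this step (mirroring the filtering in data_process.ipynb, which actually requires all three fields of a triple, not just the entity name, to be purely Chinese):
```python
import re

# keep only triples whose fields consist entirely of Chinese characters
all_chinese = re.compile(r'^[\u4e00-\u9fa5]+$')

entities = {}
with open('baike_triples.txt', 'r', encoding='utf8') as f:
    for line in f:
        parts = line.strip().split('\t')
        if len(parts) != 3:
            continue
        head, rel, value = parts
        if all_chinese.match(head) and all_chinese.match(rel) and all_chinese.match(value):
            entities.setdefault(head, {})[rel] = value
```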
## Collecting and Preprocessing Sentences
An entity's `BaiduCARD` attribute is its Baidu Baike abstract, which usually contains several sentences. The sentence collection is obtained from these abstracts and stored as a list.
Every sentence is then preprocessed: **all characters other than Chinese characters and common Chinese punctuation are removed, and multi-sentence abstracts are split into individual sentences**, again stored as a list.
The processed data looks like this:
```python
[
"北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。",
"北京地处中国华北地区, 中心位于东经、北纬, 东与天津毗连, 其余均与河北相邻, 北京市总面积平方千米。",
...
]
```
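A minimal sketch of the cleaning and splitting (the whitelist regex follows data_process.ipynb, which takes the abstracts from the raw `BaiduCARD` triples):
```python
import re

def process_sentence(text):
    # drop everything except Chinese characters and common punctuation, then split into sentences
    text = re.sub(r' +', ' ', text)
    text = re.sub(r'[^\u4e00-\u9fa5,\?\!,。?::!、;\(\)() ]', '', text)
    return [s for s in re.split(r'[\?\!。?!]', text) if s]

sentences = []
with open('baike_triples.txt', 'r', encoding='utf8') as f:
    for line in f:
        parts = line.strip().split('\t')
        if len(parts) == 3 and parts[1] == 'BaiduCARD':
            sentences.extend(process_sentence(parts[2]))
```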
## Matching Entities in Sentences
For every sentence, the entity set is traversed and every entity that occurs in the sentence is recorded via **plain string matching**. **Sentences containing no entity or only one entity are filtered out**, and the data becomes `[[sentence, [entity1,...]], ...]`.
The processed data looks like this:
```python
[
[
"北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。",
[
"北京",
"中华人民共和国",
"政治"
]
],
...
]
```
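A minimal sketch of the matching (EntityMatcher.py does the same over 100 pickled sentence files with a multiprocessing pool):
```python
def match_entities(sentences, entity_list):
    """Keep sentences that contain at least two known entities (plain substring match)."""
    matched = []
    for sen in sentences:
        eset = [e for e in entity_list if e in sen]
        if len(eset) > 1:
            matched.append([sen, eset])
    return matched

# e.g. matched = match_entities(sentences, list(entities.keys()))
```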
## Sentence Segmentation
The Chinese sentences are segmented with the Python `jieba` library, turning the data into `[[sentence, [entity1,...], [sentence_seg]], ...]`.
All segmented sentences are then collected as a corpus, on which word vectors are trained with the Python `word2vec` model from gensim.
**jieba tutorial:** https://github.com/fxsjy/jieba
**word2vec tutorial:** https://radimrehurek.com/gensim/models/word2vec.html
### User Dictionary
To keep entities from being split apart, all entities (the keys of the entity dictionary) are written to `dict.txt` and loaded as a user dictionary.
### Stop Words
The file `stop_word.txt` lists Chinese stop words, which are removed during segmentation. (Many such lists are available online.)
### Training Word Vectors
The processed data looks like this:
```python
[
[
"北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。",
[
"北京",
"中华人民共和国",
"政治"
],
[
"北京",
"简称",
"京",
...
]
],
...
]
```
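A minimal sketch of segmentation plus word-vector training (following SentenceSegment.py and add_relation.ipynb; `size=` is the gensim 3.x keyword and becomes `vector_size=` in gensim 4.x):
```python
import jieba
from gensim.models import Word2Vec

jieba.load_userdict('dict.txt')                 # entity names as a user dictionary
with open('stop_word.txt', encoding='utf8') as f:
    stop_words = set(line.strip() for line in f)

segmented = []
for sen, eset in matched:                       # output of the entity-matching step
    seg = [w for w in jieba.cut(sen) if w not in stop_words]
    segmented.append([sen, eset, seg])

# train 50-dimensional skip-gram vectors on the segmented corpus
model = Word2Vec([d[2] for d in segmented], sg=1, min_count=1, size=50, workers=4)
```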
## Filtering Sentence-Entity Pairs
After segmentation the entities are filtered again: any entity in a sentence's entity list that does not appear in the segmented sentence is removed. The remaining entities are then combined pairwise, turning the data into `[[sentence, entity_head, entity_tail, [sentence_seg]]]`. (One sentence may therefore yield several samples.)
Entities are matched first, unqualified sentences discarded, and only then are the sentences segmented and matched against the entities again, mainly because:
1. Some entity names are substrings of others, e.g. "北京" (Beijing) and "北京大学" (Peking University). In the sentence "北京大学是中国的著名大学。" the only entity that actually occurs is "北京大学".
2. Segmentation is slow; segmenting every sentence first and only then matching entities, without the initial filtering, would be far less efficient.
The processed data looks like this:
```python
[
[
"北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。",
"北京",
"中华人民共和国",
[
"北京",
"简称",
...
]
],
[
"北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。",
"北京",
"政治",
[
"北京",
"简称",
...
]
],
...
]
```
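A minimal sketch of the re-filtering and pairing (this is what SentenceSegment.segment does after cutting each sentence):
```python
samples = []
for sen, eset, seg in segmented:
    # keep only entities that survived segmentation as whole tokens
    kept = [e for e in eset if e in seg]
    if len(kept) > 1:
        # every ordered pair of distinct entities becomes a candidate sample
        for head in kept:
            for tail in kept:
                if head != tail:
                    samples.append([sen, head, tail, seg])
```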
## Adding Relation Labels
Based on an analysis of the raw data, 23 frequently occurring relation classes were defined manually (see the Appendix); 'NA' means the two entities have no relation or some other relation. Since the relation/attribute words in the raw data are not aligned (e.g. 妻子 and 夫人 both denote the same "wife" relation), hand-written rules are used to align and merge them.
Each record from the previous step is then labeled using the entity dictionary and the hand-written alignment table, turning the data into `[[sentence, entity_head, entity_tail, relation, [sentence_seg]]]`.
The processed data looks like this:
```python
[
[
"北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。",
"北京",
"中华人民共和国",
"/地点/地点/首都"
[
"北京",
"简称",
...
]
],
...
]
```
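A minimal sketch of the labeling rule, assuming the `config` (alias sets), `tgs` (entity-type tags), and `name` (label strings) dictionaries defined in add_relation.ipynb:
```python
def label(sample, entities, config, tgs, name):
    """Return a relation label, 'NA', or None (sample dropped), following add_relation.ipynb."""
    sen, head, tail, seg = sample
    attrs = entities.get(head, {})
    if tail not in attrs.values():
        return 'NA'                              # head and tail are unrelated in the knowledge base
    rel_word = next(k for k, v in attrs.items() if v == tail)
    head_tags = attrs.get('BaiduTAG', '')
    for key, aliases in config.items():
        if rel_word not in aliases:
            continue
        head_type, tail_type = key.split('_')[0].split('2')
        head_ok = any(t in head_tags for t in tgs[head_type])
        tail_ok = tail_type == 'oth' or any(
            t in entities.get(tail, {}).get('BaiduTAG', '') for t in tgs[tail_type])
        if head_ok and tail_ok:
            return name[key]
        break                                    # only the first matching alias set is considered
    return None                                  # relation word not aligned: drop the sample
```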
## Converting the Dataset Format
Finally, the data is converted into the target format, adding 'id', 'type', and other attributes:
```python
[
{
"head": {
"word": "北京",
"id": "666",
...
},
"tail": {
"word": "中华人名共和国",
"id": "6",
...
},
"relation": "/地点/地点/首都",
"sentence": "北京 简称 京 是 中华人民共和国 省级 行政区 首都 直辖市 是 全国 的 政治 文化 中心",
...
}
]
```
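A minimal sketch of the conversion (assuming `labeled_samples` holds records in the `[sentence, entity_head, entity_tail, relation, [sentence_seg]]` format above; ids are simply enumerated, as in add_relation.ipynb):
```python
entity_ids = {}

def eid(word):
    # assign each distinct entity a string id
    return str(entity_ids.setdefault(word, len(entity_ids)))

dataset = []
for sen, head, tail, relation, seg in labeled_samples:
    dataset.append({
        'head': {'word': head, 'id': eid(head)},
        'tail': {'word': tail, 'id': eid(tail)},
        'relation': relation,
        'sentence': ' '.join(seg),   # space-separated segmented sentence
        'ori_sen': sen,
    })
```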
## Train/Test Split
Each label is split into training and test sets at a ratio of roughly 3:1.
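A minimal sketch of the split (roughly the logic in add_relation.ipynb, which sends the first quarter of each label to the test set):
```python
import math
from collections import Counter

label_counts = Counter(d['relation'] for d in dataset)
test_quota = {r: max(int(math.floor(c * 0.25)), 1) for r, c in label_counts.items()}
taken = {r: 0 for r in label_counts}

train, test = [], []
for d in dataset:
    r = d['relation']
    if taken[r] < test_quota[r]:     # first quarter of each label goes to the test set
        test.append(d)
        taken[r] += 1
    else:
        train.append(d)
```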
# Known Issues
1. There are a lot of sentences, so entity matching, segmentation, and word-vector training take a long time (matching ~80k entities against ~4M sentences with 8 worker processes takes roughly 1-2 hours). It is advisable to estimate the running time on a small sample first and to use multiprocessing or a subset of the data.
2. Some of the data cleaning is rather crude and leaves room for improvement.
3. There are few relation classes, the alignment rules are simple, and the raw data itself contains noise (e.g. mis-classified BaiduTAG entries), so the resulting dataset is noisy.
# Appendix
**Relation alignment/merging:** a triple (head, relation, tail) is assigned the label `/人物/地点/出生地`, for example, only if head belongs to the person category, tail belongs to the place category, and relation is one of the alias words listed for `/人物/地点/出生地` below.
Entity type|BaiduTAG must contain at least one of
:-:|:-:
人物 (person)|人物、歌手、演员、作家
机构 (organization)|机构、企业、公司、学校、部门、大学
地点 (place)|地点、地理、城市、国家、地区
其它 (other)|unrestricted
<table>
<tr>
<th width=20px>No.</th>
<th>Relation class</th>
<th>Relation alias words</th>
</tr>
<tr>
<td align="center" width=10px>0</td>
<td width=200px align="center">NA</td>
<td>No relation in the raw data</td>
</tr>
<tr>
<td width=50px align="center">1</td>
<td width=200px align="center">/人物/人物/家庭成员</td>
<td>父亲、母亲、丈夫、妻子、儿子、女儿、哥哥、妹妹、姐姐、弟弟、孙子、孙女、爷爷、奶奶、外婆、外公、家人、家庭成员 ,夫人、对象、夫君</td>
</tr>
<tr>
<td width=50px align="center">2</td>
<td width=200px align="center">/人物/人物/社交关系</td>
<td> 朋友、好友、同学、合作、搭档、经纪人、师从</td>
</tr>
<tr>
<td width=50px align="center">3</td>
<td width=200px align="center">/人物/地点/出生地</td>
<td>出生地、出生于、来自、歌手出生地、作者出生地、出生在、作者出生地、出生</td>
</tr>
<tr>
<td width=50px align="center">4</td>
<td width=200px align="center">/人物/地点/居住地</td>
<td>居住地、主要居住地、居住、现居住、目前居住地、现居住于、居住地点、居住于</td>
</tr>
<tr>
<td width=50px align="center">5</td>
<td width=200px align="center">/人物/地点/国籍</td>
<td>国籍、国家</td>
</tr>
<tr>
<td width=50px align="center">6</td>
<td width=200px align="center">/人物/组织/毕业于</td>
<td>毕业院校、毕业于、毕业学院、本科毕业院校、最后毕业院校、毕业高中、毕业地点、本科毕业学校、知名校友</td>
</tr>
<tr>
<td width=50px align="center">7</td>
<td width=200px align="center">/人物/组织/属于</td>
<td>隶属单位、经纪公司、隶属关系、行政隶属、隶属学校、隶属大学、隶属地区、所属公司、签约公司、任职公司、工作单位、所属</td>
</tr>
<tr>
<td width=50px align="center">8</td>
<td width=200px align="center">/人物/其它/职业</td>
<td>职业</td>
</tr>
<tr>
<td width=50px align="center">9</td>
<td width=200px align="center">/人物/其它/民族</td>
<td>民族</td>
</tr>
<tr>
<td width=50px align="center">10</td>
<td width=200px align="center">/组织/人物/拥有者</td>
<td>拥有、拥有者</td>
</tr>
<tr>
<td width=50px align="center">11</td>
<td width=200px align="center">/组织/人物/创始人</td>
<td>创始人、创始、主要创始人、集团创始人</td>
</tr>
<tr>
<td width=50px align="center">12</td>
<td width=200px align="center">/组织/人物/校长</td>
<td>校长、现任校长、学校校长、总校长</td>
</tr>
<tr>
<td width=50px align="center">13</td>
<td width=200px align="center">/组织/人物/领导人</td>
<td>领导、现任领导、领导单位、主要领导、领导人、主要领导人</td>
</tr>
<tr>
<td width=50px align="center">14</td>
<td width=200px align="center">/组织/组织/周边</td>
<td>周围景观、周边景点</td>
</tr>
<tr>
<td width=50px align="center">15</td>
<td width=200px align="center">/组织/地点/位于</td>
<td>所属地区、国家、地区、地理位置、位于、区域、地点、总部地点、所在地、所在区域、位于城市、总部位于、酒店位于、学校位于、最早位于、地址、所在城市、城市、主要城市、坐落于</td>
</tr>
<tr>
<td width=50px align="center">16</td>
<td width=200px align="center">/地点/人物/相关人物</td>
<td>相关人物、知名人物、历史人物</td>
</tr>
<tr>
<td width=50px align="center">17</td>
<td width=200px align="center">/地点/地点/位于</td>
<td>所属地区、所属国、所属洲、所属州、所属国家、最大城市、地区、地理位置、位于、区域、地点、总部地点、所在地、所在区域、位于城市、总部位于、酒店位于、学校位于、最早位于、地址、所在城市、城市、主要城市、坐落于</td>
</tr>
<tr>
<td width=50px align="center">18</td>
<td width=200px align="center">/地点/地点/毗邻</td>
<td>毗邻、东邻、邻近行政区、相邻、紧邻、邻近、北邻、南邻、邻国</td>
</tr>
<tr>
<td width=50px align="center">19</td>
<td width=200px align="center">/地点/地点/包含</td>
<td>包含、包含国家、包含人物、下辖地区、下属、</td>
</tr>
<tr>
<td width=50px align="center">20</td>
<td width=200px align="center">/地点/地点/首都</td>
<td>首都</td>
</tr>
<tr>
<td width=50px align="center">21</td>
<td width=200px align="center">/地点/组织/景点</td>
<td>著名景点、主要景点、旅游景点、特色景点</td>
</tr>
<tr>
<td width=50px align="center">22</td>
<td width=200px align="center">/地点/其它/气候</td>
<td> 气候类型、气候条件、气候、气候带</td>
</tr>
</table>
================================================
FILE: kg_data/SentenceSegment.py
================================================
import pickle
import jieba
import os
import multiprocessing
import time
def read_txt(file_name):
txt_data = []
with open(file_name, 'r', encoding='utf8') as f:
d = f.readline()
while d:
txt_data.append(d.strip())
d = f.readline()
return txt_data
class SentenceSegment:
def __init__(self, dict_file, stop_word_file, sentences_folder, process_num):
self.process_num = process_num
self.stop_word = read_txt(stop_word_file)
jieba.load_userdict(dict_file)
# sentence files
self.sentence_files = []
for root, dirs, files in os.walk(sentences_folder):
for file in files:
self.sentence_files.append(os.path.join(root, file))
def segment(self, file_name):
with open(file_name, 'rb') as f:
data = pickle.load(f)
new_data = []
for d in data:
# sentence segment
sen_seg = []
for word in jieba.cut(d[0]):
if word not in self.stop_word:
sen_seg.append(word)
d.append(sen_seg)
# filter entities again
new_eset = []
for entity in d[1]:
if entity in sen_seg:
new_eset.append(entity)
# remove data the number of whose entity less than 2
# and rebuilt data
if len(new_eset)>1:
for i in new_eset:
for j in new_eset:
if j!=i:
new_data.append([d[0], i, j, sen_seg])
print('%s done!'%(file_name))
return new_data, file_name
def write_file(self, data):
with open(data[1], 'wb') as f:
pickle.dump(data[0], f)
def run(self):
pool = multiprocessing.Pool(processes=self.process_num)
for one_file in self.sentence_files:
pool.apply_async(self.segment, args=(str(one_file), ), callback=self.write_file)
pool.close()
pool.join()
if __name__ == "__main__":
dict_file = 'processed/entity.txt'
stop_word = 'stop_word.txt'
sentences_folder = 'processed/sentences'
process_num = 8
st = time.localtime()
ss = SentenceSegment(dict_file, stop_word, sentences_folder, process_num)
ss.run()
ed = time.localtime()
print('\n开始时间: ')
print(time.strftime("%Y-%m-%d %H:%M:%S", st))
print('结束时间: ')
print (time.strftime("%Y-%m-%d %H:%M:%S", ed))
================================================
FILE: kg_data/add_relation.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"import json\n",
"import os\n",
"import math\n",
"from gensim.models import Word2Vec"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = []\n",
"for i in range(100):\n",
" with open('processed/sentences/sen'+str(i), 'rb') as f:\n",
" data += pickle.load(f)\n",
"len(data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open('processed/entities.pkl', 'rb') as f:\n",
" entities = pickle.load(f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 关系标注"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# 根据规则 替换所有关系\n",
"\n",
"tgs = {\n",
" \"per\": [\"人物\", \"歌手\", \"演员\", \"作家\"],\n",
" \"org\": [\"机构\", \"企业\", \"公司\", \"学校\", \"部门\", \"大学\"],\n",
" \"pl\": [\"地点\", \"地理\", \"城市\", \"国家\", \"地区\"]\n",
"}\n",
"\n",
"config = {\n",
" # person 9\n",
" \"per2per_family_members\" : [\"父亲\",\"母亲\",\"丈夫\",\"妻子\",\"儿子\",\"女儿\",\"哥哥\",\"妹妹\",\"姐姐\",\"弟弟\",\"孙子\",\n",
" \"孙女\",\"爷爷\",\"奶奶\",\"外婆\", \"外公\",\"家人\",\"家庭成员\" ,\"夫人\",\"对象\",\"夫君\"],\n",
" \"per2per_social_members\" : [\"朋友\", \"好友\", \"同学\", \"合作\", \"搭档\", \"经纪人\", \"师从\"],\n",
"\n",
" \"per2pl_birth_place\" : [\"出生地\", \"出生于\", \"来自\", \"歌手出生地\", \"作者出生地\", \"出生在\", \"作者出生地\", \"出生\"],\n",
" \"per2pl_live_place\" : [\"居住地\", \"主要居住地\", \"居住\", \"现居住\", \"目前居住地\", \"现居住于\", \"居住地点\", \"居住于\"],\n",
" \n",
" \"per2pl_country\": [\"国籍\", \"国家\"],\n",
" \"per2org_graduate_from\" : [\"毕业院校\", \"毕业于\", \"毕业学院\", \"本科毕业院校\", \"最后毕业院校\", \"毕业高中\", \"毕业地点\", \"本科毕业学校\", \"知名校友\"],\n",
" \"per2org_belong_to\" : [\"隶属单位\", \"经纪公司\", \"隶属关系\", \"行政隶属\", \"隶属学校\", \"隶属大学\", \"隶属地区\", \"所属公司\", \"签约公司\", \"任职公司\", \"工作单位\", \"所属\"],\n",
"\n",
" \"per2oth_profession\" : ['职业'],\n",
" \"per2oth_nation\" : ['民族'],\n",
"\n",
" # orgnazition 9\n",
" \"org2per_owner\" : [\"拥有\", \"拥有者\"],\n",
" \"org2per_founder\" : [\"创始人\", \"创始\", \"主要创始人\", \"集团创始人\"],\n",
" \"org2per_school_leader\" : [\"校长\", \"现任校长\", \"学校校长\", \"总校长\"],\n",
" \"org2per_leader\" : [\"领导\", \"现任领导\", \"领导单位\", \"主要领导\", \"领导人\", \"主要领导人\"],\n",
"\n",
" \"org2org_surroundings\" : [\"周围景观\", \"周边景点\"],\n",
"\n",
"\n",
" \"org2pl_location\" : [\"所属地区\",\"国家\", \"地区\", \"地理位置\", \"位于\", \"区域\", \"地点\", \"总部地点\", \"所在地\", \"所在区域\", \"位于城市\", \"总部位于\", \"酒店位于\", \"学校位于\", \"最早位于\", \"地址\", \"所在城市\", \"城市\", \"主要城市\", \"坐落于\"],\n",
"\n",
"\n",
" # place 7\n",
" \"pl2per_main_character\" : [\"相关人物\", \"知名人物\", \"历史人物\"],\n",
"\n",
" \"pl2pl_location\" : [\"所属地区\",\"所属国\", \"所属洲\", \"所属州\", \"所属国家\", \"最大城市\", \"地区\", \"地理位置\", \"位于\", \"区域\", \"地点\", \"总部地点\", \"所在地\", \"所在区域\", \"位于城市\", \"总部位于\", \"酒店位于\", \"学校位于\", \"最早位于\", \"地址\", \"所在城市\", \"城市\", \"主要城市\", \"坐落于\"],\n",
" \"pl2pl_adjacement\" : [\"毗邻\", \"东邻\", \"邻近行政区\", \"相邻\", \"紧邻\", \"邻近\", \"北邻\", \"南邻\", \"邻国\"],\n",
" \"pl2pl_contains\" : [\"包含\", \"包含国家\", \"包含人物\", \"下辖地区\", \"下属\"],\n",
" \"pl2pl_captial\" : [\"首都\"],\n",
"\n",
" \"pl2org_sights\" : [\"著名景点\", \"主要景点\", \"旅游景点\", \"特色景点\"],\n",
" \"pl2oth_climate\" : [\"气候类型\", \"气候条件\", \"气候\", \"气候带\"],\n",
"}\n",
"\n",
"\n",
"name = {\n",
" \"per2per_family_members\": \"/人物/人物/家庭成员\",\n",
" \"per2per_social_members\": \"/人物/人物/社交关系\",\n",
"\n",
" \"per2pl_birth_place\": \"/人物/地点/出生地\",\n",
" \"per2pl_live_place\": \"/人物/地点/居住地\",\n",
" \"per2pl_country\" : \"/人物/地点/国籍\",\n",
" \"per2org_graduate_from\": \"/人物/组织/毕业于\",\n",
" \"per2org_belong_to\": \"/人物/组织/属于\",\n",
"\n",
" \"per2oth_profession\": \"/人物/其它/职业\",\n",
" \"per2oth_nation\": \"/人物/其它/民族\",\n",
"\n",
" \"org2per_owner\": \"/组织/人物/拥有者\",\n",
" \"org2per_founder\": \"/组织/人物/创始人\",\n",
" \"org2per_school_leader\": \"/组织/人物/校长\",\n",
" \"org2per_leader\": \"/组织/人物/领导人\",\n",
"\n",
" \"org2org_surroundings\": \"/组织/组织/周边\",\n",
"\n",
"\n",
" \"org2pl_location\": \"/组织/地点/位于\",\n",
"\n",
"\n",
" \"pl2per_main_character\": \"/地点/人物/相关人物\",\n",
"\n",
" \"pl2pl_location\": \"/地点/地点/位于\",\n",
" \"pl2pl_adjacement\": \"/地点/地点/毗邻\",\n",
" \"pl2pl_contains\": \"/地点/地点/包含\",\n",
" \"pl2pl_captial\": \"/地点/地点/首都\",\n",
"\n",
" \"pl2org_sights\": \"/地点/组织/景点\",\n",
" \"pl2oth_climate\": \"/地点/其它/气候\"\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def check(string, tgs):\n",
" for t in tgs:\n",
" if t in string:\n",
" return True\n",
" return False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# 标注\n",
"processed_data = []\n",
"for can in data:\n",
" # can [sentence, head, tail, segment]\n",
" sentence, head, tail, segment = can\n",
" if tail not in entities[can[1]].values():\n",
" can.append('NA')\n",
" processed_data.append(can)\n",
" continue\n",
" rel = ''\n",
" for key, value in entities[can[1]].items():\n",
" if value==tail:\n",
" rel = key\n",
" \n",
" for key, value in config.items():\n",
" if rel in value:\n",
" tp = key.split('_')[0].split('2')\n",
" if check(entities[can[1]].get('BaiduTAG', \"\"), tgs[tp[0]]):\n",
" if tp[1]=='oth':\n",
" can.append(name[key])\n",
" processed_data.append(can)\n",
" elif check(entities[can[2]].get('BaiduTAG', \"\"), tgs[tp[1]]):\n",
" can.append(name[key])\n",
" processed_data.append(can)\n",
" break\n",
"len(processed_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"e2id = {}\n",
"count = 0\n",
"e_set = set()\n",
"for i in processed_data:\n",
" e_set.add(i[1])\n",
" e_set.add(i[2])\n",
"for e in e_set:\n",
" e2id[e] = count\n",
" count += 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"total_data = []\n",
"for d in processed_data:\n",
" total_data.append({\n",
" 'head':{\n",
" 'word': d[1],\n",
" 'id': str(e2id[d[1]])\n",
" },\n",
" 'relation': d[-1],\n",
" 'tail': {\n",
" 'word': d[2],\n",
" 'id': str(e2id[d[2]])\n",
" },\n",
" 'sentence': ' '.join(d[-2]),\n",
" 'ori_sen': d[0],\n",
" 'sen_seg': d[-2]\n",
" })"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"rl = {}\n",
"for i in total_data:\n",
" rl[i['relation']] = rl.get(i['relation'], 0) + 1\n",
"\n",
"trl = {}\n",
"record = {}\n",
"for k,v in rl.items():\n",
" trl[k] = int(max(math.floor(v*0.25), 1))\n",
" record[k] = 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"train = []\n",
"test = []\n",
"for i in total_data:\n",
" if record[i['relation']]<=trl[i['relation']]:\n",
" test.append(i)\n",
" record[i['relation']] += 1\n",
" else:\n",
" train.append(i)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if not os.path.isdir('../data'):\n",
" os.mkdir('../data')\n",
"if not os.path.isdir('../data/chinese'):\n",
" os.mkdir('../data/chinese')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"with open('../data/chinese/train.json', 'w', encoding='utf8') as f:\n",
" json.dump(train, f)\n",
"with open('../data/chinese/test.json', 'w', encoding='utf8') as f:\n",
" json.dump(test, f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"count = 0\n",
"r2id = {}\n",
"for k in list(rl.keys()):\n",
" r2id[k] = count\n",
" count +=1\n",
"with open('../data/chinese/rel2id.json', 'w', encoding='utf8') as f:\n",
" json.dump(r2id, f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 训练词向量"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"senlist = []\n",
"for d in data:\n",
" senlist.append(d[-1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = Word2Vec(senlist, sg=5, min_count=1, size=50, workers=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"w2v = {}\n",
"for i in model.wv.index2word:\n",
" w2v[i] = model[i]\n",
"len(w2v)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"new_w2v = []\n",
"for word, vec in w2v.items():\n",
" new_w2v.append({\n",
" 'word': word,\n",
" 'vec': [float(i) for i in vec]\n",
" })"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"with open('../data/chinese/word_vec.json', 'w', encoding='utf8') as f:\n",
" json.dump(new_w2v, f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: kg_data/data_process.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"import json\n",
"import re\n",
"import os\n",
"import math\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1.读取初始数据"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = []\n",
"with open('baike_triples.txt', 'r', encoding='utf8') as f:\n",
" for line in tqdm(f):\n",
" data.append(line.strip().split('\\t'))\n",
"print(len(data), data[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"2.保留元素全为中文的三元组"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"all_chinese = re.compile('^[\\u4e00-\\u9fa5]*$')\n",
"new_data = []\n",
"for triple in tqdm(data):\n",
" if bool(re.search(all_chinese, triple[0])) and bool(re.search(all_chinese, triple[1])) and bool(re.search(all_chinese, triple[2])):\n",
" new_data.append(triple)\n",
"len(new_data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3.实体字典"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"entities = {}\n",
"for triple in tqdm(new_data):\n",
" entities[triple[0]] = entities.get(triple[0], {})\n",
" entities[triple[0]][triple[1]] = triple[2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if not os.path.isdir('processed'):\n",
" os.mkdir('processed')\n",
"with open('processed/entities.pkl', 'wb') as f:\n",
" pickle.dump(entities, f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"e_set = set()\n",
"with open('processed/entity.txt', 'w', encoding='utf8') as f:\n",
" for e in tqdm(entities.keys()):\n",
" if e not in e_set:\n",
" e_set.add(e)\n",
" f.write(e+'\\n')\n",
"len(e_set)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"4.获取句子集合"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def process_sentence(sent):\n",
" sent = re.sub(r' +', ' ', sent)\n",
" sent = re.sub(r'[^\\u4e00-\\u9fa5,\\?\\!,。?::!、;\\(\\)() ]', '', sent)\n",
" sent = re.split(r'[\\?\\!。?!]', sent)\n",
" return sent"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sents = []\n",
"for triple in tqdm(data):\n",
" if triple[1]=='BaiduCARD':\n",
" sent = process_sentence(triple[2])\n",
" if sent != \"\":\n",
" sents.append(sent)\n",
"print(len(sents), sents[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"5.句子存为100个文件"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if not os.path.isdir('processed/sentences'):\n",
" os.mkdir('processed/sentences')\n",
"\n",
" begin = 0\n",
"count = math.ceil(len(sents)*1.0/100)\n",
"end = min(begin+count, len(sents))\n",
"for i in tqdm(range(100)):\n",
" with open('processed/sentences/sen'+str(i), 'wb') as f:\n",
" pickle.dump(sents[begin:end], f)\n",
" begin = end\n",
" end = min(begin+count, len(sents))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: kg_data/stop_word.txt
================================================
可能
起
便于
有些
上述
纯粹
尽然
乃至于
极为
也
并不
其次
矣哉
为何
到头来
此外
[①e]
8
就是了
故此
移动
旁人
立刻
路经
使
元/吨
刚
归齐
归根结底
!
〕
迅速
获得
4
已
冒
哎
已矣
哪天
这么点儿
哪样
一些
甚且
起头
大
默默地
⑩
跟
逐步
它
拦腰
不仅仅
里面
同
设或
要么
以後
—
)÷(1-
扩大
它的
精光
归
战斗
成为
谁知
这般
此后
尽管如此
或多或少
{-
替
照
千万
人
连日
这么些
广泛
保险
到头
把
↑
前后
最近
他
瑟瑟
一片
正值
没有
而且
云云
正常
快要
从新
了
窃
光是
长此下去
挨家挨户
届时
而论
大批
“
一面
之前
大多数
较比
并肩
[⑧]
昂然
各位
[①④]
正是
在
自
活
最后
那些
不免
[①⑨]
设若
立即
蛮
尽心竭力
联系
叫做
具有
宣布
既
毕竟
那般
先生
啊呀
表明
.
并没有
切莫
一番
倒是
[②⑤]
以外
一来
不够
赶快
二
另行
尽可能
嗡嗡
另
格外
[②a]
毫无保留地
彻底
不独
经过
大家
即令
愤然
于
[③h]
[①g]
个别
接连不断
逐渐
$
简而言之
呜呼
适应
譬如
某某
庶几
背地里
具体地说
自打
仍旧
突然
趁热
要求
>
怪
常言说得好
哎呀
临到
强烈
”
--
此次
不可抗拒
另外
能
"
凑巧
不仅
唉
傥然
穷年累月
.一
为
充其极
敢于
'
单纯
起见
除了
如何
={
来说
企图
嘿
必将
毋宁
有所
那样
均
次第
挨个
失去
最好
上面
能够
举行
它是
出现
叮当
挨次
…………………………………………………③
看来
动辄
趁机
亲自
才能
亲手
与
亲眼
有的是
不定
但愿
由
内
下面
当前
必须
全都
共
这么
率然
不日
最
--
不足
几度
不如
凭借
猛然
看起来
俺们
暗地里
绝非
[②j]
而又
依靠
如同
紧接着
日见
[②e]
古来
依
即若
[④e]
多
看看
那么些
怎么
而是
各级
\
以来
如常
非徒
对比
以及
继后
所有
$
不但...而且
何处
才
清楚
居然
难说
或许
除此
按时
}
:
成年
来着
别处
何时
喽
继之
除去
成年累月
[③d]
变成
一何
每
除开
除
本身
偏偏
而
饱
别是
只怕
哪些
一时
倒不如说
到处
本地
反过来
@
自己
──
成心
策略地
平素
①
φ
全面
粗
等
然后
为什么
什麽
必
般的
哩
任凭
难怪
权时
exp
何苦
一转眼
出于
只限
假如
至若
藉以
只
突出
)
常言道
6
适用
无宁
任
陈年
理应
共同
以故
简言之
当真
欤
要
各
并不是
结合
安全
莫若
常
矣
按说
砰
因
不外
允许
不问
当庭
望
[②i]
的确
多数
咱们
若
尚且
高低
继而
这点
Ⅲ
不管怎样
只消
皆可
何以
与其
+ξ
各个
暗自
[①h]
以至于
惯常
有
这
相当
从速
以为
岂但
不外乎
焉
直接
[⑦]
仅仅
通过
也是
差一点
恰似
尔后
进入
除却
彼时
果真
赖以
但凡
所在
是以
尽早
真是
每当
出去
万
sup
一旦
一切
相反
即便
也好
截至
无论
朝着
_
;
八
真正
趁便
保管
要不
多么
啷当
咦
反映
|
两者
综上所述
于是乎
充分
呀
属于
诸
截然
倘若
[⑥]
从此
略
明显
互
齐
之类
难得
具体
恍然
大体
时候
趁着
遇到
按
老
再有
引起
7
若非
决不
[①B]
距
高兴
不可
即刻
绝对
运用
难道
极端
非常
[①⑥]
认真
末##末
哟
而况
其他
针对
历
大事
[③b]
类如
何妨
无法
%
总的说来
同一
/
:
上来
今
达到
或者
接下来
呼啦
比方
说明
[③e]
各地
谨
据此
各式
伟大
传
既往
2
b]
不得
不敢
是
这会儿
一则通过
呐
近年来
之一
下列
比及
防止
么
第二
6
何
[①③]
最大
从未
何须
颇
整个
不惟
等等
殆
如此
此间
争取
觉得
+
直到
得起
一个
越是
自个儿
即使
[②⑧]
独
或是
往往
率尔
零
{
[③①]
像
相信
较之
纵令
尤其
乘
好象
反倒是
随著
一样
是不是
绝不
更加
哪边
、
从古至今
日渐
将近
}>
下
偶尔
[①⑧]
除此以外
呕
一方面
当口儿
良好
亲口
待到
———
加入
犹且
且不说
即如
经
——
即是说
与其说
-
不可开交
方
(
×××
亲身
据
要不然
到底
复杂
比如说
不时
莫不
抑或
~±
哪年
本人
啦
奇
固
这样
来
~
③
不久
不限
不对
[③g]
++
您们
显然
总结
风雨无阻
如上所述
}
以免
虽则
部分
简直
打开天窗说亮话
极
哈哈
遵循
介于
’‘
现代
-
甚或
由是
亦
吱
嘛
果然
2.3%
这里
假使
根本
据称
达旦
[①⑦]
交口
特点
&
咱
正在
=☆
个人
马上
毫无
嘎嘎
巨大
开始
几经
因而
②
不拘
`
不是
单
除此之外
冲
更为
[②B]
这一来
全部
总的来说
联袂
或
其二
偶而
其
必然
[②⑥]
…
人家
大体上
保持
六
总是
不光
换言之
这个
尔尔
且
广大
四
近几年来
哉
多多益善
加之
断然
每年
多年前
再者
所
三
碰巧
[⑤f]
即或
罢了
起首
八成
多多少少
暗中
日益
喂
沿着
绝顶
主张
所谓
哈
容易
既…又
将要
不论
各自
甚么
倘然
某
迄
那麽
有及
-β
再
比
从小
她的
局外
该当
一定
不止
倘
敢
可是
省得
丰富
1.
你是
吧
[③]
随后
注意
-[*]-
何必
猛然间
……
怎
一起
三天两头
知道
宁可
从事
相应
如若
少数
这时
为了
啪达
⑥
据悉
忽地
φ.
顷刻
〔
㈧
进而
使得
[①c]
原来
赶
老老实实
不妨
可以
0
默然
可好
心里
自各儿
乘势
彻夜
密切
5:0
哦
)、
︿
一一
应用
特别是
相似
不若
怎么办
双方
就
从头
故
二话没说
是的
普通
比如
[]
【
连袂
取道
按理
起初
帮助
与否
不同
七
哼唷
大张旗鼓
孰料
啊
需要
实现
还要
〈
挨门挨户
存心
宁肯
一边
凝神
顷刻之间
先后
[③a]
孰知
(
还有
老是
就是
完成
抽冷子
可
大凡
如期
做到
因此
最後
若是
急匆匆
范围
大都
理该
[③c]
|
够瞧的
这边
重新
[①C]
→
沿
不巧
极度
不满
而已
有点
a]
好
哪儿
犹自
μ
宁愿
[①a]
不止一次
形成
您是
在于
?
下去
川流不息
主要
后
从宽
为着
吓
满足
A
勃然
尔等
它们的
敞开儿
连声
牢牢
到目前为止
根据
不能
从无到有
过
确定
怎样
出
已经
5
总的来看
?
行为
[①d]
固然
到
归根到底
从今以后
并且
此时
不然
⑨
除非
#
2
不下
然
较
℃
纯
凭
对应
扑通
遭到
每逢
是否
考虑
大约
为什麽
不怎么
比照
去
彼此
怎奈
梆
[②⑩]
]
有著
有的
何乐而不为
相等
儿
[②④
恰如
打从
不怕
][
强调
不亦乐乎
至
『
不力
不了
不过
轰然
一般
上
恐怕
相对而言
今後
二话不说
恰恰
取得
接著
一直
¥
差不多
[①⑤]
后来
<φ
9
③]
既然
啊哈
哪个
诸位
因为
所幸
之所以
哗
嘿嘿
贼死
』
自家
大抵
<Δ
屡屡
采取
别人
转动
不得不
拿
那末
不再
顿时
全身心
顺
全年
4
转贴
其后
嗯
别管
出来
处处
背靠背
呸
行动
[③F]
另方面
R.L.
余外
照着
不得了
严重
他是
[②g]
况且
白
种
待
得到
呵呵
那
限制
及其
[①②]
3
sub
从中
决定
A
一
@
》
被
似乎
对方
得出
相对
打
大量
彼
另悉
不一
迟早
[①D]
[⑩]
小
某些
用
将才
[④d]
一则
岂止
[⑤b]
长话短说
⑤
Lex
此中
向使
[
放量
尽快
切
不成
任何
该
乒
构成
离
当儿
掌握
[*]
若夫
现在
并无
基于
以致
假若
甚而
咋
矣乎
中间
从而
严格
请勿
//
不比
维持
[②b]
即
不胜
哇
累次
莫
三番两次
有关
恰恰相反
不起
不经意
人民
<<
除外
兮
上下
://
但
切不可
然後
难道说
不至于
别说
哪
尽量
完全
;
方面
借
光
不但
[②f]
社会主义
到了儿
至今
第
上去
漫说
从重
可见
我
乘隙
准备
反之
无
後面
一.
处理
较为
很少
[④c]
喏
尽管
进来
另一个
而言
普遍
有着
[⑨]
...
附近
恰逢
千万千万
极了
能否
绝
表示
巴巴
因了
[②③]
)
嗡
向
[②⑦]
以下
挨门逐户
任务
此
竟
[
这么样
另一方面
之后
进行
[①f]
从严
说来
之
单单
当中
正如
·
故意
专门
恰巧
大大
敢情
.日
c]
隔夜
大致
]∧′=[
来看
的
好的
没奈何
则甚
或则
换句话说
那么
按照
了解
呢
万一
最高
竟而
再则
Ψ
过去
大多
存在
略为
并没
岂非
臭
同样
非独
由于
[②h]
合理
反过来说
甭
不已
对于
不曾
比起
促进
怎么样
经常
然则
为止
仅
8
乘胜
独自
极力
鉴于
看样子
】
分期分批
当场
加以
深入
产生
日复一日
总之
二来
则
欢迎
让
呼哧
首先
接着
诚然
屡次
[②]
巩固
却
而后
0:2
'
乌乎
≈
快
前此
也罢
=
多年来
前面
三番五次
呵
从古到今
>
不仅...而且
却不
者
竟然
白白
#
的话
实际
分头
哎哟
同时
就要
必要
每时每刻
吧哒
奈
+
见
不变
要是
并非
间或
长线
人们
恰好
赶早不赶晚
立马
趁
如此等等
[⑤d]
如下
起先
尽如人意
看见
看
不仅仅是
或曰
比较
慢说
常常
庶乎
*
看上去
随时
兼之
[⑤]
基本上
概
左右
加上
积极
近
=″
,也
但是
乃
还
来得及
共总
代替
每个
再说
极其
得
当然
动不动
倒不如
以
奋勇
[④a]
曾
你的
何尝
从此以后
几乎
为主
有利
具体说来
只要
嘎
就此
得天独厚
当
如其
鄙人
老大
没
立时
长期以来
乘机
汝
不会
顷刻间
显著
坚决
全然
=(
谁
分期
喀
切切
自从
否则
啥
〉
随
正巧
看到
前进
而外
嗳
尔
顺着
虽
即将
之後
方才
别的
传说
一天
半
怎麽
互相
每天
那里
弹指之间
及至
叫
不尽
此地
此处
今年
反而
[③⑩]
上升
不只
适当
顷
他的
纵然
以前
人人
。
∈[
也就是说
充其量
来自
俺
如上
据我所知
论说
基本
方能
我们
......
几时
[①i]
[⑤e]
不由得
当地
趁势
自身
[-
阿
本着
练习
<±
据实
这些
通常
哪怕
《
这儿
[②G]
大面儿上
这种
你们
以后
些
切勿
来讲
[①A]
中小
隔日
不消
那个
不常
.数
▲
只是
她是
12%
e]
全体
不然的话
[⑤a]
转变
》),
[④b]
进步
从来
着呢
从轻
[①o]
如是
由此可见
得了
..
吗
如今
~+
从优
对待
并排
诸如
公然
五
宁
和
这麽
各人
例如
趁早
便
就算
认识
借此
当时
<
然而
哪里
非特
∪φ∈
_
对
纵使
多多
不能不
咧
只当
今天
■
从不
及时
[②②]
不择手段
乘虚
咚
我是
往
多次
豁然
有效
以至
很
还是
<λ
那边
等到
理当
随着
[②c]
今后
靠
并
大举
给
故而
决非
凡是
说说
她
又及
意思
>>
明确
开展
<
3
数/
仍然
他们
千
f]
弗
必定
如果
不得已
将
至于
γ
既是
为此
其实
重大
]
非但
ZXFITL
[①]
倍感
几番
设使
啊哟
究竟
论
致
刚巧
反手
9
初
近来
过于
总而言之
谁人
作为
莫不然
”,
,
~
毫不
先不先
都
再次
其它
就是说
什么样
立地
④
话说
呃
不尽然
一下
②c
呆呆地
特殊
据说
有力
×
除此而外
管
...................
进去
逢
屡次三番
毫无例外
尽心尽力
因着
如前所述
〕〔
嘘
带
关于
她们
不管
当着
一次
’
朝
加强
嗬
凡
大力
方便
避免
使用
尽
分别
起来
莫非
‘
十分
大略
多亏
::
会
遵照
不要
们
向着
相同
不特
不大
某个
如
累年
[②
又
眨眼
ng昉
其一
坚持
满
%
不少
具体来说
纵
略微
造成
嘻
其余
譬喻
多少
以便
依据
后面
地
後来
虽说
过来
~~~~
日臻
集中
不断
于是
由此
大概
唯有
反之则
忽然
什么
喔唷
刚才
连
迫于
顶多
与此同时
己
挨着
有时
./
叮咚
岂
只有
1
应该
5
乃至
前者
临
借以
召开
陡然
=
其中
=[
哼
0
及
非得
受到
您
[②①]
何止
那儿
姑且
以期
看出
哗啦
嘎登
反应
*
曾经
诚如
/
每每
九
腾
从早到晚
不知不觉
反之亦然
若果
В
一致
再者说
开外
本
怕
连同
愿意
始而
应当
目前
倍加
甚至
年复一年
虽然
既...又
」
更
LI
莫如
来不及
沙沙
这就是说
⑧
先後
[⑤]]
你
乎
那会儿
呗
甚至于
甫
重要
反倒
仍
组成
怪不得
更进一步
^
如次
倘使
[④]
似的
云尔
后者
用来
认为
规定
好在
别
从
倘或
几
.
当下
要不是
个
⑦
结果
′|
它们
&
着
再其次
谁料
许多
1
啐
我的
当头
在下
串行
大不了
常言说
缕缕
依照
刚好
不料
不必
=-
屡
不单
问题
立
按期
呜
不
巴
Δ
自后
替代
>λ
举凡
所以
咳
他人
且说
惟其
处在
传闻
!
不迭
继续
,
定
就地
连连
伙同
这次
何况
[②d]
全力
7
′∈
下来
极大
[①E]
那时
各种
当即
边
略加
周围
那么样
[①①]
连日来
匆匆
以上
很多
================================================
FILE: nrekit/data_loader.py
================================================
from six import iteritems
import json
import os
import multiprocessing
import numpy as np
import random
class file_data_loader:
def __next__(self):
raise NotImplementedError
def next(self):
return self.__next__()
def next_batch(self, batch_size):
raise NotImplementedError
class npy_data_loader(file_data_loader):
MODE_INSTANCE = 0 # One batch contains batch_size instances.
MODE_ENTPAIR_BAG = 1 # One batch contains batch_size bags, instances in which have the same entity pair (usually for testing).
MODE_RELFACT_BAG = 2 # One batch contains batch size bags, instances in which have the same relation fact. (usually for training).
def __iter__(self):
return self
def __init__(self, data_dir, prefix, mode, word_vec_npy='vec.npy', shuffle=True, max_length=120, batch_size=160):
if not os.path.isdir(data_dir):
raise Exception("[ERROR] Data dir doesn't exist!")
self.mode = mode
self.shuffle = shuffle
self.max_length = max_length
self.batch_size = batch_size
self.word_vec_mat = np.load(os.path.join(data_dir, word_vec_npy))
self.data_word = np.load(os.path.join(data_dir, prefix + "_word.npy"))
self.data_pos1 = np.load(os.path.join(data_dir, prefix + "_pos1.npy"))
self.data_pos2 = np.load(os.path.join(data_dir, prefix + "_pos2.npy"))
self.data_mask = np.load(os.path.join(data_dir, prefix + "_mask.npy"))
self.data_rel = np.load(os.path.join(data_dir, prefix + "_label.npy"))
self.data_length = np.load(os.path.join(data_dir, prefix + "_len.npy"))
self.scope = np.load(os.path.join(data_dir, prefix + "_instance_scope.npy"))
self.triple = np.load(os.path.join(data_dir, prefix + "_instance_triple.npy"))
self.relfact_tot = len(self.triple)
for i in range(self.scope.shape[0]):
self.scope[i][1] += 1
self.instance_tot = self.data_word.shape[0]
self.rel_tot = 53
if self.mode == self.MODE_INSTANCE:
self.order = list(range(self.instance_tot))
else:
self.order = list(range(len(self.scope)))
self.idx = 0
if self.shuffle:
random.shuffle(self.order)
print("Total relation fact: %d" % (self.relfact_tot))
def __next__(self):
return self.next_batch(self.batch_size)
def next_batch(self, batch_size):
if self.idx >= len(self.order):
self.idx = 0
if self.shuffle:
random.shuffle(self.order)
raise StopIteration
batch_data = {}
if self.mode == self.MODE_INSTANCE:
idx0 = self.idx
idx1 = self.idx + batch_size
if idx1 > len(self.order):
self.idx = 0
if self.shuffle:
random.shuffle(self.order)
raise StopIteration
self.idx = idx1
batch_data['word'] = self.data_word[idx0:idx1]
batch_data['pos1'] = self.data_pos1[idx0:idx1]
batch_data['pos2'] = self.data_pos2[idx0:idx1]
batch_data['rel'] = self.data_rel[idx0:idx1]
batch_data['length'] = self.data_length[idx0:idx1]
batch_data['scope'] = np.stack([list(range(idx1 - idx0)), list(range(1, idx1 - idx0 + 1))], axis=1)
elif self.mode == self.MODE_ENTPAIR_BAG or self.mode == self.MODE_RELFACT_BAG:
idx0 = self.idx
idx1 = self.idx + batch_size
if idx1 > len(self.order):
self.idx = 0
if self.shuffle:
random.shuffle(self.order)
raise StopIteration
self.idx = idx1
_word = []
_pos1 = []
_pos2 = []
_rel = []
_ins_rel = []
_multi_rel = []
_length = []
_scope = []
_mask = []
cur_pos = 0
for i in range(idx0, idx1):
_word.append(self.data_word[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
_pos1.append(self.data_pos1[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
_pos2.append(self.data_pos2[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
_rel.append(self.data_rel[self.scope[self.order[i]][0]])
_ins_rel.append(self.data_rel[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
_length.append(self.data_length[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
_mask.append(self.data_mask[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
bag_size = self.scope[self.order[i]][1] - self.scope[self.order[i]][0]
_scope.append([cur_pos, cur_pos + bag_size])
cur_pos = cur_pos + bag_size
if self.mode == self.MODE_ENTPAIR_BAG:
_one_multi_rel = np.zeros((self.rel_tot), dtype=np.int32)
for j in range(self.scope[self.order[i]][0], self.scope[self.order[i]][1]):
_one_multi_rel[self.data_rel[j]] = 1
_multi_rel.append(_one_multi_rel)
batch_data['word'] = np.concatenate(_word)
batch_data['pos1'] = np.concatenate(_pos1)
batch_data['pos2'] = np.concatenate(_pos2)
batch_data['rel'] = np.stack(_rel)
batch_data['ins_rel'] = np.concatenate(_ins_rel)
if self.mode == self.MODE_ENTPAIR_BAG:
batch_data['multi_rel'] = np.stack(_multi_rel)
batch_data['length'] = np.concatenate(_length)
batch_data['scope'] = np.stack(_scope)
batch_data['mask'] = np.concatenate(_mask)
return batch_data
class json_file_data_loader(file_data_loader):
MODE_INSTANCE = 0 # One batch contains batch_size instances.
MODE_ENTPAIR_BAG = 1 # One batch contains batch_size bags, instances in which have the same entity pair (usually for testing).
MODE_RELFACT_BAG = 2 # One batch contains batch size bags, instances in which have the same relation fact. (usually for training).
def _load_preprocessed_file(self):
name_prefix = '.'.join(self.file_name.split('/')[-1].split('.')[:-1])
word_vec_name_prefix = '.'.join(self.word_vec_file_name.split('/')[-1].split('.')[:-1])
processed_data_dir = '_processed_data'
if not os.path.isdir(processed_data_dir):
return False
word_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_word.npy')
pos1_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pos1.npy')
pos2_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pos2.npy')
rel_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_rel.npy')
mask_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_mask.npy')
length_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_length.npy')
entpair2scope_file_name = os.path.join(processed_data_dir, name_prefix + '_entpair2scope.json')
relfact2scope_file_name = os.path.join(processed_data_dir, name_prefix + '_relfact2scope.json')
word_vec_mat_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_mat.npy')
word2id_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_word2id.json')
if not os.path.exists(word_npy_file_name) or \
not os.path.exists(pos1_npy_file_name) or \
not os.path.exists(pos2_npy_file_name) or \
not os.path.exists(rel_npy_file_name) or \
not os.path.exists(mask_npy_file_name) or \
not os.path.exists(length_npy_file_name) or \
not os.path.exists(entpair2scope_file_name) or \
not os.path.exists(relfact2scope_file_name) or \
not os.path.exists(word_vec_mat_file_name) or \
not os.path.exists(word2id_file_name):
return False
print("Pre-processed files exist. Loading them...")
self.data_word = np.load(word_npy_file_name)
self.data_pos1 = np.load(pos1_npy_file_name)
self.data_pos2 = np.load(pos2_npy_file_name)
self.data_rel = np.load(rel_npy_file_name)
self.data_mask = np.load(mask_npy_file_name)
self.data_length = np.load(length_npy_file_name)
self.entpair2scope = json.load(open(entpair2scope_file_name))
self.relfact2scope = json.load(open(relfact2scope_file_name))
self.word_vec_mat = np.load(word_vec_mat_file_name)
self.word2id = json.load(open(word2id_file_name))
if self.data_word.shape[1] != self.max_length:
print("Pre-processed files don't match current settings. Reprocessing...")
return False
print("Finish loading")
return True
def __init__(self, file_name, word_vec_file_name, rel2id_file_name, mode, shuffle=True, max_length=120, case_sensitive=False, reprocess=False, batch_size=160):
'''
file_name: Json file storing the data in the following format
[
{
'sentence': 'Bill Gates is the founder of Microsoft .',
'head': {'word': 'Bill Gates', ...(other information)},
'tail': {'word': 'Microsoft', ...(other information)},
'relation': 'founder'
},
...
]
word_vec_file_name: Json file storing word vectors in the following format
[
{'word': 'the', 'vec': [0.418, 0.24968, ...]},
{'word': ',', 'vec': [0.013441, 0.23682, ...]},
...
]
        rel2id_file_name: Json file storing the relation-to-id dictionary in the following format
{
'NA': 0
'founder': 1
...
}
**IMPORTANT**: make sure the id of NA is 0!
mode: Specify how to get a batch of data. See MODE_* constants for details.
shuffle: Whether to shuffle the data, default as True. You should use shuffle when training.
        max_length: The length that all the sentences need to be extended to, default as 120.
case_sensitive: Whether the data processing is case-sensitive, default as False.
        reprocess: Redo the pre-processing even if pre-processed files exist, default as False.
batch_size: The size of each batch, default as 160.
'''
self.file_name = file_name
self.word_vec_file_name = word_vec_file_name
self.case_sensitive = case_sensitive
self.max_length = max_length
self.mode = mode
self.shuffle = shuffle
self.batch_size = batch_size
self.rel2id = json.load(open(rel2id_file_name))
if reprocess or not self._load_preprocessed_file(): # Try to load pre-processed files:
# Check files
if file_name is None or not os.path.isfile(file_name):
raise Exception("[ERROR] Data file doesn't exist")
if word_vec_file_name is None or not os.path.isfile(word_vec_file_name):
raise Exception("[ERROR] Word vector file doesn't exist")
# Load files
print("Loading data file...")
self.ori_data = json.load(open(self.file_name, "r"))
print("Finish loading")
print("Loading word vector file...")
self.ori_word_vec = json.load(open(self.word_vec_file_name, "r"))
print("Finish loading")
# Eliminate case sensitive
if not case_sensitive:
print("Elimiating case sensitive problem...")
for i in range(len(self.ori_data)):
self.ori_data[i]['sentence'] = self.ori_data[i]['sentence'].lower()
self.ori_data[i]['head']['word'] = self.ori_data[i]['head']['word'].lower()
self.ori_data[i]['tail']['word'] = self.ori_data[i]['tail']['word'].lower()
print("Finish eliminating")
# Sort data by entities and relations
print("Sort data...")
self.ori_data.sort(key=lambda a: a['head']['id'] + '#' + a['tail']['id'] + '#' + a['relation'])
print("Finish sorting")
# Pre-process word vec
self.word2id = {}
self.word_vec_tot = len(self.ori_word_vec)
UNK = self.word_vec_tot
BLANK = self.word_vec_tot + 1
self.word_vec_dim = len(self.ori_word_vec[0]['vec'])
print("Got {} words of {} dims".format(self.word_vec_tot, self.word_vec_dim))
print("Building word vector matrix and mapping...")
self.word_vec_mat = np.zeros((self.word_vec_tot, self.word_vec_dim), dtype=np.float32)
for cur_id, word in enumerate(self.ori_word_vec):
w = word['word']
if not case_sensitive:
w = w.lower()
self.word2id[w] = cur_id
self.word_vec_mat[cur_id, :] = word['vec']
self.word2id['UNK'] = UNK
self.word2id['BLANK'] = BLANK
print("Finish building")
# Pre-process data
print("Pre-processing data...")
self.instance_tot = len(self.ori_data)
self.entpair2scope = {} # (head, tail) -> scope
self.relfact2scope = {} # (head, tail, relation) -> scope
self.data_word = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
self.data_pos1 = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
self.data_pos2 = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
self.data_rel = np.zeros((self.instance_tot), dtype=np.int32)
self.data_mask = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
self.data_length = np.zeros((self.instance_tot), dtype=np.int32)
last_entpair = ''
last_entpair_pos = -1
last_relfact = ''
last_relfact_pos = -1
for i in range(self.instance_tot):
ins = self.ori_data[i]
if ins['relation'] in self.rel2id:
self.data_rel[i] = self.rel2id[ins['relation']]
else:
self.data_rel[i] = self.rel2id['NA']
sentence = ' '.join(ins['sentence'].split()) # delete extra spaces
head = ins['head']['word']
tail = ins['tail']['word']
cur_entpair = ins['head']['id'] + '#' + ins['tail']['id']
cur_relfact = ins['head']['id'] + '#' + ins['tail']['id'] + '#' + ins['relation']
if cur_entpair != last_entpair:
if last_entpair != '':
self.entpair2scope[last_entpair] = [last_entpair_pos, i] # left closed right open
last_entpair = cur_entpair
last_entpair_pos = i
if cur_relfact != last_relfact:
if last_relfact != '':
self.relfact2scope[last_relfact] = [last_relfact_pos, i]
last_relfact = cur_relfact
last_relfact_pos = i
p1 = sentence.find(' ' + head + ' ')
p2 = sentence.find(' ' + tail + ' ')
if p1 == -1:
if sentence[:len(head) + 1] == head + " ":
p1 = 0
elif sentence[-len(head) - 1:] == " " + head:
p1 = len(sentence) - len(head)
else:
p1 = 0 # shouldn't happen
else:
p1 += 1
if p2 == -1:
if sentence[:len(tail) + 1] == tail + " ":
p2 = 0
elif sentence[-len(tail) - 1:] == " " + tail:
p2 = len(sentence) - len(tail)
else:
p2 = 0 # shouldn't happen
else:
p2 += 1
# if p1 == -1 or p2 == -1:
# raise Exception("[ERROR] Sentence doesn't contain the entity, index = {}, sentence = {}, head = {}, tail = {}".format(i, sentence, head, tail))
words = sentence.split()
cur_ref_data_word = self.data_word[i]
cur_pos = 0
pos1 = -1
pos2 = -1
for j, word in enumerate(words):
if j < max_length:
if word in self.word2id:
cur_ref_data_word[j] = self.word2id[word]
else:
cur_ref_data_word[j] = UNK
if cur_pos == p1:
pos1 = j
p1 = -1
if cur_pos == p2:
pos2 = j
p2 = -1
cur_pos += len(word) + 1
for j in range(j + 1, max_length):
cur_ref_data_word[j] = BLANK
self.data_length[i] = len(words)
if len(words) > max_length:
self.data_length[i] = max_length
if pos1 == -1 or pos2 == -1:
raise Exception("[ERROR] Position error, index = {}, sentence = {}, head = {}, tail = {}".format(i, sentence, head, tail))
if pos1 >= max_length:
pos1 = max_length - 1
if pos2 >= max_length:
pos2 = max_length - 1
pos_min = min(pos1, pos2)
pos_max = max(pos1, pos2)
for j in range(max_length):
self.data_pos1[i][j] = j - pos1 + max_length
self.data_pos2[i][j] = j - pos2 + max_length
if j >= self.data_length[i]:
self.data_mask[i][j] = 0
elif j <= pos_min:
self.data_mask[i][j] = 1
elif j <= pos_max:
self.data_mask[i][j] = 2
else:
self.data_mask[i][j] = 3
if last_entpair != '':
self.entpair2scope[last_entpair] = [last_entpair_pos, self.instance_tot] # left closed right open
if last_relfact != '':
self.relfact2scope[last_relfact] = [last_relfact_pos, self.instance_tot]
print("Finish pre-processing")
print("Storing processed files...")
name_prefix = '.'.join(file_name.split('/')[-1].split('.')[:-1])
word_vec_name_prefix = '.'.join(word_vec_file_name.split('/')[-1].split('.')[:-1])
processed_data_dir = '_processed_data'
if not os.path.isdir(processed_data_dir):
os.mkdir(processed_data_dir)
np.save(os.path.join(processed_data_dir, name_prefix + '_word.npy'), self.data_word)
np.save(os.path.join(processed_data_dir, name_prefix + '_pos1.npy'), self.data_pos1)
np.save(os.path.join(processed_data_dir, name_prefix + '_pos2.npy'), self.data_pos2)
np.save(os.path.join(processed_data_dir, name_prefix + '_rel.npy'), self.data_rel)
np.save(os.path.join(processed_data_dir, name_prefix + '_mask.npy'), self.data_mask)
np.save(os.path.join(processed_data_dir, name_prefix + '_length.npy'), self.data_length)
json.dump(self.entpair2scope, open(os.path.join(processed_data_dir, name_prefix + '_entpair2scope.json'), 'w'))
json.dump(self.relfact2scope, open(os.path.join(processed_data_dir, name_prefix + '_relfact2scope.json'), 'w'))
np.save(os.path.join(processed_data_dir, word_vec_name_prefix + '_mat.npy'), self.word_vec_mat)
json.dump(self.word2id, open(os.path.join(processed_data_dir, word_vec_name_prefix + '_word2id.json'), 'w'))
print("Finish storing")
# Prepare for idx
self.instance_tot = self.data_word.shape[0]
self.entpair_tot = len(self.entpair2scope)
self.relfact_tot = 0 # The number of relation facts, without NA.
for key in self.relfact2scope:
if key[-2:] != 'NA':
self.relfact_tot += 1
self.rel_tot = len(self.rel2id)
if self.mode == self.MODE_INSTANCE:
self.order = list(range(self.instance_tot))
elif self.mode == self.MODE_ENTPAIR_BAG:
self.order = list(range(len(self.entpair2scope)))
self.scope_name = []
self.scope = []
for key, value in iteritems(self.entpair2scope):
self.scope_name.append(key)
self.scope.append(value)
elif self.mode == self.MODE_RELFACT_BAG:
self.order = list(range(len(self.relfact2scope)))
self.scope_name = []
self.scope = []
for key, value in iteritems(self.relfact2scope):
self.scope_name.append(key)
self.scope.append(value)
else:
raise Exception("[ERROR] Invalid mode")
self.idx = 0
if self.shuffle:
random.shuffle(self.order)
print("Total relation fact: %d" % (self.relfact_tot))
def __iter__(self):
return self
def __next__(self):
return self.next_batch(self.batch_size)
def next_batch(self, batch_size):
if self.idx >= len(self.order):
self.idx = 0
if self.shuffle:
random.shuffle(self.order)
raise StopIteration
batch_data = {}
if self.mode == self.MODE_INSTANCE:
idx0 = self.idx
idx1 = self.idx + batch_size
if idx1 > len(self.order):
idx1 = len(self.order)
self.idx = idx1
batch_data['word'] = self.data_word[idx0:idx1]
batch_data['pos1'] = self.data_pos1[idx0:idx1]
batch_data['pos2'] = self.data_pos2[idx0:idx1]
batch_data['rel'] = self.data_rel[idx0:idx1]
batch_data['mask'] = self.data_mask[idx0:idx1]
batch_data['length'] = self.data_length[idx0:idx1]
batch_data['scope'] = np.stack([list(range(batch_size)), list(range(1, batch_size + 1))], axis=1)
if idx1 - idx0 < batch_size:
padding = batch_size - (idx1 - idx0)
batch_data['word'] = np.concatenate([batch_data['word'], np.zeros((padding, self.data_word.shape[-1]), dtype=np.int32)])
batch_data['pos1'] = np.concatenate([batch_data['pos1'], np.zeros((padding, self.data_pos1.shape[-1]), dtype=np.int32)])
batch_data['pos2'] = np.concatenate([batch_data['pos2'], np.zeros((padding, self.data_pos2.shape[-1]), dtype=np.int32)])
batch_data['mask'] = np.concatenate([batch_data['mask'], np.zeros((padding, self.data_mask.shape[-1]), dtype=np.int32)])
batch_data['rel'] = np.concatenate([batch_data['rel'], np.zeros((padding), dtype=np.int32)])
batch_data['length'] = np.concatenate([batch_data['length'], np.zeros((padding), dtype=np.int32)])
elif self.mode == self.MODE_ENTPAIR_BAG or self.mode == self.MODE_RELFACT_BAG:
idx0 = self.idx
idx1 = self.idx + batch_size
if idx1 > len(self.order):
idx1 = len(self.order)
self.idx = idx1
_word = []
_pos1 = []
_pos2 = []
_mask = []
_rel = []
_ins_rel = []
_multi_rel = []
_entpair = []
_length = []
_scope = []
cur_pos = 0
for i in range(idx0, idx1):
_word.append(self.data_word[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
_pos1.append(self.data_pos1[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
_pos2.append(self.data_pos2[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
_mask.append(self.data_mask[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
_rel.append(self.data_rel[self.scope[self.order[i]][0]])
_ins_rel.append(self.data_rel[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
_length.append(self.data_length[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
bag_size = self.scope[self.order[i]][1] - self.scope[self.order[i]][0]
_scope.append([cur_pos, cur_pos + bag_size])
cur_pos = cur_pos + bag_size
if self.mode == self.MODE_ENTPAIR_BAG:
_one_multi_rel = np.zeros((self.rel_tot), dtype=np.int32)
for j in range(self.scope[self.order[i]][0], self.scope[self.order[i]][1]):
_one_multi_rel[self.data_rel[j]] = 1
_multi_rel.append(_one_multi_rel)
_entpair.append(self.scope_name[self.order[i]])
for i in range(batch_size - (idx1 - idx0)):
_word.append(np.zeros((1, self.data_word.shape[-1]), dtype=np.int32))
_pos1.append(np.zeros((1, self.data_pos1.shape[-1]), dtype=np.int32))
_pos2.append(np.zeros((1, self.data_pos2.shape[-1]), dtype=np.int32))
_mask.append(np.zeros((1, self.data_mask.shape[-1]), dtype=np.int32))
_rel.append(0)
_ins_rel.append(np.zeros((1), dtype=np.int32))
_length.append(np.zeros((1), dtype=np.int32))
_scope.append([cur_pos, cur_pos + 1])
cur_pos += 1
if self.mode == self.MODE_ENTPAIR_BAG:
_multi_rel.append(np.zeros((self.rel_tot), dtype=np.int32))
_entpair.append('None#None')
batch_data['word'] = np.concatenate(_word)
batch_data['pos1'] = np.concatenate(_pos1)
batch_data['pos2'] = np.concatenate(_pos2)
batch_data['mask'] = np.concatenate(_mask)
batch_data['rel'] = np.stack(_rel)
batch_data['ins_rel'] = np.concatenate(_ins_rel)
if self.mode == self.MODE_ENTPAIR_BAG:
batch_data['multi_rel'] = np.stack(_multi_rel)
batch_data['entpair'] = _entpair
batch_data['length'] = np.concatenate(_length)
batch_data['scope'] = np.stack(_scope)
return batch_data
================================================
FILE: nrekit/framework.py
================================================
import tensorflow as tf
import os
import sklearn.metrics
import numpy as np
import sys
import time
def average_gradients(tower_grads):
"""Calculate the average gradient for each shared variable across all towers.
Note that this function provides a synchronization point across all towers.
Args:
tower_grads: List of lists of (gradient, variable) tuples. The outer list
is over individual gradients. The inner list is over the gradient
calculation for each tower.
Returns:
List of pairs of (gradient, variable) where the gradient has been averaged
across all towers.
"""
average_grads = []
for grad_and_vars in zip(*tower_grads):
# Note that each grad_and_vars looks like the following:
# ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
grads = []
for g, _ in grad_and_vars:
# Add 0 dimension to the gradients to represent the tower.
expanded_g = tf.expand_dims(g, 0)
# Append on a 'tower' dimension which we will average over below.
grads.append(expanded_g)
# Average over the 'tower' dimension.
grad = tf.concat(axis=0, values=grads)
grad = tf.reduce_mean(grad, 0)
# Keep in mind that the Variables are redundant because they are shared
# across towers. So .. we will just return the first tower's pointer to
# the Variable.
v = grad_and_vars[0][1]
grad_and_var = (grad, v)
average_grads.append(grad_and_var)
return average_grads
class re_model:
def __init__(self, train_data_loader, batch_size, max_length=120):
self.word = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name='word')
self.pos1 = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name='pos1')
self.pos2 = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name='pos2')
self.label = tf.placeholder(dtype=tf.int32, shape=[batch_size], name='label')
self.ins_label = tf.placeholder(dtype=tf.int32, shape=[None], name='ins_label')
self.length = tf.placeholder(dtype=tf.int32, shape=[None], name='length')
self.scope = tf.placeholder(dtype=tf.int32, shape=[batch_size, 2], name='scope')
self.train_data_loader = train_data_loader
self.rel_tot = train_data_loader.rel_tot
self.word_vec_mat = train_data_loader.word_vec_mat
def loss(self):
raise NotImplementedError
def train_logit(self):
raise NotImplementedError
def test_logit(self):
raise NotImplementedError
class re_framework:
MODE_BAG = 0 # Train and test the model at bag level.
MODE_INS = 1 # Train and test the model at instance level
def __init__(self, train_data_loader, test_data_loader, max_length=120, batch_size=160):
self.train_data_loader = train_data_loader
self.test_data_loader = test_data_loader
self.sess = None
def one_step_multi_models(self, sess, models, batch_data_gen, run_array, return_label=True):
feed_dict = {}
batch_label = []
for model in models:
batch_data = batch_data_gen.next_batch(batch_data_gen.batch_size // len(models))
feed_dict.update({
model.word: batch_data['word'],
model.pos1: batch_data['pos1'],
model.pos2: batch_data['pos2'],
model.label: batch_data['rel'],
model.ins_label: batch_data['ins_rel'],
model.scope: batch_data['scope'],
model.length: batch_data['length'],
})
if 'mask' in batch_data and hasattr(model, "mask"):
feed_dict.update({model.mask: batch_data['mask']})
batch_label.append(batch_data['rel'])
result = sess.run(run_array, feed_dict)
batch_label = np.concatenate(batch_label)
if return_label:
result += [batch_label]
return result
def one_step(self, sess, model, batch_data, run_array):
feed_dict = {
model.word: batch_data['word'],
model.pos1: batch_data['pos1'],
model.pos2: batch_data['pos2'],
model.label: batch_data['rel'],
model.ins_label: batch_data['ins_rel'],
model.scope: batch_data['scope'],
model.length: batch_data['length'],
}
if 'mask' in batch_data and hasattr(model, "mask"):
feed_dict.update({model.mask: batch_data['mask']})
result = sess.run(run_array, feed_dict)
return result
def train(self,
model,
model_name,
ckpt_dir='./checkpoint',
summary_dir='./summary',
test_result_dir='./test_result',
learning_rate=0.5,
max_epoch=60,
pretrain_model=None,
test_epoch=1,
optimizer=tf.train.GradientDescentOptimizer,
gpu_nums=1):
assert(self.train_data_loader.batch_size % gpu_nums == 0)
print("Start training...")
# Init
config = tf.ConfigProto(allow_soft_placement=True)
self.sess = tf.Session(config=config)
optimizer = optimizer(learning_rate)
# Multi GPUs
tower_grads = []
tower_models = []
for gpu_id in range(gpu_nums):
with tf.device("/gpu:%d" % gpu_id):
with tf.name_scope("gpu_%d" % gpu_id):
cur_model = model(self.train_data_loader, self.train_data_loader.batch_size // gpu_nums, self.train_data_loader.max_length)
tower_grads.append(optimizer.compute_gradients(cur_model.loss()))
tower_models.append(cur_model)
tf.add_to_collection("loss", cur_model.loss())
tf.add_to_collection("train_logit", cur_model.train_logit())
loss_collection = tf.get_collection("loss")
loss = tf.add_n(loss_collection) / len(loss_collection)
train_logit_collection = tf.get_collection("train_logit")
train_logit = tf.concat(train_logit_collection, 0)
grads = average_gradients(tower_grads)
train_op = optimizer.apply_gradients(grads)
summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)
# Saver
saver = tf.train.Saver(max_to_keep=None)
if pretrain_model is None:
self.sess.run(tf.global_variables_initializer())
else:
saver.restore(self.sess, pretrain_model)
# Training
best_metric = 0
best_prec = None
best_recall = None
not_best_count = 0 # Stop training after several epochs without improvement.
for epoch in range(max_epoch):
print('###### Epoch ' + str(epoch) + ' ######')
tot_correct = 0
tot_not_na_correct = 0
tot = 0
tot_not_na = 0
i = 0
time_sum = 0
while True:
time_start = time.time()
try:
iter_loss, iter_logit, _train_op, iter_label = self.one_step_multi_models(self.sess, tower_models, self.train_data_loader, [loss, train_logit, train_op])
except StopIteration:
break
time_end = time.time()
t = time_end - time_start
time_sum += t
iter_output = iter_logit.argmax(-1)
iter_correct = (iter_output == iter_label).sum()
iter_not_na_correct = np.logical_and(iter_output == iter_label, iter_label != 0).sum()
tot_correct += iter_correct
tot_not_na_correct += iter_not_na_correct
tot += iter_label.shape[0]
tot_not_na += (iter_label != 0).sum()
if tot_not_na > 0:
sys.stdout.write("epoch %d step %d time %.2f | loss: %f, not NA accuracy: %f, accuracy: %f\r" % (epoch, i, t, iter_loss, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))
sys.stdout.flush()
i += 1
print("\nAverage iteration time: %f" % (time_sum / i))
if (epoch + 1) % test_epoch == 0:
metric = self.test(model)
if metric > best_metric:
best_metric = metric
best_prec = self.cur_prec
best_recall = self.cur_recall
print("Best model, storing...")
if not os.path.isdir(ckpt_dir):
os.mkdir(ckpt_dir)
path = saver.save(self.sess, os.path.join(ckpt_dir, model_name))
print("Finish storing")
not_best_count = 0
else:
not_best_count += 1
if not_best_count >= 20:
break
print("######")
print("Finish training " + model_name)
print("Best epoch auc = %f" % (best_metric))
if (not best_prec is None) and (not best_recall is None):
if not os.path.isdir(test_result_dir):
os.mkdir(test_result_dir)
np.save(os.path.join(test_result_dir, model_name + "_x.npy"), best_recall)
np.save(os.path.join(test_result_dir, model_name + "_y.npy"), best_prec)
def test(self,
model,
ckpt=None,
return_result=False,
mode=MODE_BAG):
if mode == re_framework.MODE_BAG:
return self.__test_bag__(model, ckpt=ckpt, return_result=return_result)
elif mode == re_framework.MODE_INS:
raise NotImplementedError
else:
raise NotImplementedError
def __test_bag__(self, model, ckpt=None, return_result=False):
print("Testing...")
if self.sess == None:
self.sess = tf.Session()
model = model(self.test_data_loader, self.test_data_loader.batch_size, self.test_data_loader.max_length)
if not ckpt is None:
saver = tf.train.Saver()
saver.restore(self.sess, ckpt)
tot_correct = 0
tot_not_na_correct = 0
tot = 0
tot_not_na = 0
entpair_tot = 0
test_result = []
pred_result = []
for i, batch_data in enumerate(self.test_data_loader):
iter_logit = self.one_step(self.sess, model, batch_data, [model.test_logit()])[0]
iter_output = iter_logit.argmax(-1)
iter_correct = (iter_output == batch_data['rel']).sum()
iter_not_na_correct = np.logical_and(iter_output == batch_data['rel'], batch_data['rel'] != 0).sum()
tot_correct += iter_correct
tot_not_na_correct += iter_not_na_correct
tot += batch_data['rel'].shape[0]
tot_not_na += (batch_data['rel'] != 0).sum()
if tot_not_na > 0:
sys.stdout.write("[TEST] step %d | not NA accuracy: %f, accuracy: %f\r" % (i, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))
sys.stdout.flush()
for idx in range(len(iter_logit)):
for rel in range(1, self.test_data_loader.rel_tot):
test_result.append({'score': iter_logit[idx][rel], 'flag': batch_data['multi_rel'][idx][rel]})
if batch_data['entpair'][idx] != "None#None":
pred_result.append({'score': float(iter_logit[idx][rel]), 'entpair': batch_data['entpair'][idx].encode('utf-8'), 'relation': rel})
entpair_tot += 1
sorted_test_result = sorted(test_result, key=lambda x: x['score'])
prec = []
recall = []
correct = 0
for i, item in enumerate(sorted_test_result[::-1]):
correct += item['flag']
prec.append(float(correct) / (i + 1))
recall.append(float(correct) / self.test_data_loader.relfact_tot)
auc = sklearn.metrics.auc(x=recall, y=prec)
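        # The loop above sweeps the candidate (bag, relation) scores from highest to lowest:
        # prec[k] is the fraction of the top-(k+1) predictions that are true facts and recall[k]
        # is the fraction of all labeled relation facts recovered so far, so (recall, prec) traces
        # the precision-recall curve and its area is the reported AUC. The training loop keeps the
        # curve of the best epoch and saves it as <model_name>_x.npy / <model_name>_y.npy for
        # later plotting.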
print("\n[TEST] auc: {}".format(auc))
print("Finish testing")
self.cur_prec = prec
self.cur_recall = recall
if not return_result:
return auc
else:
return (auc, pred_result)
================================================
FILE: nrekit/network/classifier.py
================================================
import tensorflow as tf
import numpy as np
def softmax_cross_entropy(x, label, rel_tot, weights_table=None, weights=1.0, var_scope=None):
with tf.variable_scope(var_scope or "loss", reuse=tf.AUTO_REUSE):
if weights_table is not None:
weights = tf.nn.embedding_lookup(weights_table, label)
label_onehot = tf.one_hot(indices=label, depth=rel_tot, dtype=tf.int32)
loss = tf.losses.softmax_cross_entropy(onehot_labels=label_onehot, logits=x, weights=weights)
tf.summary.scalar('loss', loss)
return loss
def sigmoid_cross_entropy(x, label, rel_tot, weights_table=None, var_scope=None):
with tf.variable_scope(var_scope or "loss", reuse=tf.AUTO_REUSE):
if weights_table is None:
weights = 1.0
else:
weights = tf.nn.embedding_lookup(weights_table, label)
label_onehot = tf.one_hot(indices=label, depth=rel_tot, dtype=tf.int32)
loss = tf.losses.sigmoid_cross_entropy(label_onehot, logits=x, weights=weights)
tf.summary.scalar('loss', loss)
return loss
# Soft-label
# Implemented here, but the result reported in the paper has not been reproduced.
def soft_label_softmax_cross_entropy(x, label, rel_tot, weights=1.0):
    with tf.name_scope("soft-label-loss"):
        label_onehot = tf.one_hot(indices=label, depth=rel_tot, dtype=tf.int32)
        nscore = x + 0.9 * tf.reshape(tf.reduce_max(x, 1), [-1, 1]) * tf.cast(label_onehot, tf.float32)
        nlabel = tf.one_hot(indices=tf.reshape(tf.argmax(nscore, axis=1), [-1]), depth=rel_tot, dtype=tf.int32)
        loss = tf.losses.softmax_cross_entropy(onehot_labels=nlabel, logits=nscore, weights=weights)
        tf.summary.scalar('loss', loss)
        return loss
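# Sketch of the soft-label idea implemented above (numbers are illustrative only): with logits
# x = [2.0, 5.0, 1.0] and distant-supervision label 0, the boosted scores are
# nscore = [2.0 + 0.9 * 5.0, 5.0, 1.0] = [6.5, 5.0, 1.0], so the bag is trained towards label 0;
# had the label's logit been much lower, argmax(nscore) would keep the model's own prediction,
# letting confident predictions override noisy distantly supervised labels.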
def output(x):
return tf.argmax(x, axis=-1)
================================================
FILE: nrekit/network/embedding.py
================================================
import tensorflow as tf
import numpy as np
def word_embedding(word, word_vec_mat, var_scope=None, word_embedding_dim=50, add_unk_and_blank=True):
with tf.variable_scope(var_scope or 'word_embedding', reuse=tf.AUTO_REUSE):
word_embedding = tf.get_variable('word_embedding', initializer=word_vec_mat, dtype=tf.float32)
if add_unk_and_blank:
word_embedding = tf.concat([word_embedding,
tf.get_variable("unk_word_embedding", [1, word_embedding_dim], dtype=tf.float32,
initializer=tf.contrib.layers.xavier_initializer()),
tf.constant(np.zeros((1, word_embedding_dim), dtype=np.float32))], 0)
x = tf.nn.embedding_lookup(word_embedding, word)
return x
def pos_embedding(pos1, pos2, var_scope=None, pos_embedding_dim=5, max_length=120):
with tf.variable_scope(var_scope or 'pos_embedding', reuse=tf.AUTO_REUSE):
pos_tot = max_length * 2
pos1_embedding = tf.get_variable('real_pos1_embedding', [pos_tot, pos_embedding_dim], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
# pos1_embedding = tf.concat([tf.zeros((1, pos_embedding_dim), dtype=tf.float32), real_pos1_embedding], 0)
pos2_embedding = tf.get_variable('real_pos2_embedding', [pos_tot, pos_embedding_dim], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
# pos2_embedding = tf.concat([tf.zeros((1, pos_embedding_dim), dtype=tf.float32), real_pos2_embedding], 0)
input_pos1 = tf.nn.embedding_lookup(pos1_embedding, pos1)
input_pos2 = tf.nn.embedding_lookup(pos2_embedding, pos2)
x = tf.concat([input_pos1, input_pos2], -1)
return x
def word_position_embedding(word, word_vec_mat, pos1, pos2, var_scope=None, word_embedding_dim=50, pos_embedding_dim=5, max_length=120, add_unk_and_blank=True):
w_embedding = word_embedding(word, word_vec_mat, var_scope=var_scope, word_embedding_dim=word_embedding_dim, add_unk_and_blank=add_unk_and_blank)
p_embedding = pos_embedding(pos1, pos2, var_scope=var_scope, pos_embedding_dim=pos_embedding_dim, max_length=max_length)
return tf.concat([w_embedding, p_embedding], -1)
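# Shape note: word_position_embedding returns a tensor of shape
# [batch_size, max_length, word_embedding_dim + 2 * pos_embedding_dim]
# (60 features per token with the default 50-dim word vectors and 5-dim position embeddings),
# which is the input expected by the encoders in nrekit/network/encoder.py.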
================================================
FILE: nrekit/network/encoder.py
================================================
import tensorflow as tf
import numpy as np
import math
def __dropout__(x, keep_prob=1.0):
return tf.contrib.layers.dropout(x, keep_prob=keep_prob)
def __pooling__(x):
return tf.reduce_max(x, axis=-2)
def __piecewise_pooling__(x, mask):
mask_embedding = tf.constant([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
mask = tf.nn.embedding_lookup(mask_embedding, mask)
hidden_size = x.shape[-1]
x = tf.reduce_max(tf.expand_dims(mask * 100, 2) + tf.expand_dims(x, 3), axis=1) - 100
return tf.reshape(x, [-1, hidden_size * 3])
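# How the piecewise pooling above works: `mask` holds a piece id per token (0 for padding,
# 1/2/3 for the three segments delimited by the two entities); the embedding lookup turns it
# into a one-hot vector over the three pieces. Broadcasting gives a [batch, length, hidden, 3]
# tensor in which tokens outside a piece sit 100 below tokens inside it, so the max over the
# length axis (minus the 100 offset) yields one max-pooled vector per piece, flattened to
# [batch, hidden * 3].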
def __cnn_cell__(x, hidden_size=230, kernel_size=3, stride_size=1):
x = tf.layers.conv1d(inputs=x,
filters=hidden_size,
kernel_size=kernel_size,
strides=stride_size,
padding='same',
kernel_initializer=tf.contrib.layers.xavier_initializer())
return x
def cnn(x, hidden_size=230, kernel_size=3, stride_size=1, activation=tf.nn.relu, var_scope=None, keep_prob=1.0):
with tf.variable_scope(var_scope or "cnn", reuse=tf.AUTO_REUSE):
max_length = x.shape[1]
x = __cnn_cell__(x, hidden_size, kernel_size, stride_size)
x = __pooling__(x)
x = activation(x)
x = __dropout__(x, keep_prob)
return x
def pcnn(x, mask, hidden_size=230, kernel_size=3, stride_size=1, activation=tf.nn.relu, var_scope=None, keep_prob=1.0):
with tf.variable_scope(var_scope or "pcnn", reuse=tf.AUTO_REUSE):
max_length = x.shape[1]
x = __cnn_cell__(x, hidden_size, kernel_size, stride_size)
x = __piecewise_pooling__(x, mask)
x = activation(x)
x = __dropout__(x, keep_prob)
return x
def __rnn_cell__(hidden_size, cell_name='lstm'):
if isinstance(cell_name, list) or isinstance(cell_name, tuple):
if len(cell_name) == 1:
return __rnn_cell__(hidden_size, cell_name[0])
        cells = [__rnn_cell__(hidden_size, c) for c in cell_name]
return tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
if cell_name.lower() == 'lstm':
return tf.contrib.rnn.BasicLSTMCell(hidden_size, state_is_tuple=True)
elif cell_name.lower() == 'gru':
return tf.contrib.rnn.GRUCell(hidden_size)
raise NotImplementedError
def rnn(x, length, hidden_size=230, cell_name='lstm', var_scope=None, keep_prob=1.0):
with tf.variable_scope(var_scope or "rnn", reuse=tf.AUTO_REUSE):
x = __dropout__(x, keep_prob)
cell = __rnn_cell__(hidden_size, cell_name)
_, states = tf.nn.dynamic_rnn(cell, x, sequence_length=length, dtype=tf.float32, scope='dynamic-rnn')
if isinstance(states, tuple):
states = states[0]
return states
def birnn(x, length, hidden_size=230, cell_name='lstm', var_scope=None, keep_prob=1.0):
with tf.variable_scope(var_scope or "birnn", reuse=tf.AUTO_REUSE):
x = __dropout__(x, keep_prob)
fw_cell = __rnn_cell__(hidden_size, cell_name)
bw_cell = __rnn_cell__(hidden_size, cell_name)
_, states = tf.nn.bidirectional_dynamic_rnn(fw_cell, bw_cell, x, sequence_length=length, dtype=tf.float32, scope='dynamic-bi-rnn')
fw_states, bw_states = states
if isinstance(fw_states, tuple):
fw_states = fw_states[0]
bw_states = bw_states[0]
return tf.concat([fw_states, bw_states], axis=1)
================================================
FILE: nrekit/network/selector.py
================================================
import tensorflow as tf
import numpy as np
def __dropout__(x, keep_prob=1.0):
return tf.contrib.layers.dropout(x, keep_prob=keep_prob)
def __logit__(x, rel_tot, var_scope=None):
with tf.variable_scope(var_scope or 'logit', reuse=tf.AUTO_REUSE):
relation_matrix = tf.get_variable('relation_matrix', shape=[rel_tot, x.shape[1]], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
bias = tf.get_variable('bias', shape=[rel_tot], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
logit = tf.matmul(x, tf.transpose(relation_matrix)) + bias
return logit
def __attention_train_logit__(x, query, rel_tot, var_scope=None):
with tf.variable_scope(var_scope or 'logit', reuse=tf.AUTO_REUSE):
relation_matrix = tf.get_variable('relation_matrix', shape=[rel_tot, x.shape[1]], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
bias = tf.get_variable('bias', shape=[rel_tot], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
current_relation = tf.nn.embedding_lookup(relation_matrix, query)
attention_logit = tf.reduce_sum(current_relation * x, -1) # sum[(n', hidden_size) \dot (n', hidden_size)] = (n)
return attention_logit
def __attention_test_logit__(x, rel_tot, var_scope=None):
with tf.variable_scope(var_scope or 'logit', reuse=tf.AUTO_REUSE):
relation_matrix = tf.get_variable('relation_matrix', shape=[rel_tot, x.shape[1]], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
bias = tf.get_variable('bias', shape=[rel_tot], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
attention_logit = tf.matmul(x, tf.transpose(relation_matrix)) # (n', hidden_size) x (hidden_size, rel_tot) = (n', rel_tot)
return attention_logit
def instance(x, rel_tot, var_scope=None, keep_prob=1.0):
x = __dropout__(x, keep_prob)
x = __logit__(x, rel_tot)
return x
def bag_attention(x, scope, query, rel_tot, is_training, var_scope=None, dropout_before=False, keep_prob=1.0):
with tf.variable_scope(var_scope or "attention", reuse=tf.AUTO_REUSE):
if is_training: # training
if dropout_before:
x = __dropout__(x, keep_prob)
bag_repre = []
attention_logit = __attention_train_logit__(x, query, rel_tot)
for i in range(scope.shape[0]):
bag_hidden_mat = x[scope[i][0]:scope[i][1]]
attention_score = tf.nn.softmax(attention_logit[scope[i][0]:scope[i][1]], -1)
bag_repre.append(tf.squeeze(tf.matmul(tf.expand_dims(attention_score, 0), bag_hidden_mat))) # (1, n') x (n', hidden_size) = (1, hidden_size) -> (hidden_size)
bag_repre = tf.stack(bag_repre)
if not dropout_before:
bag_repre = __dropout__(bag_repre, keep_prob)
return __logit__(bag_repre, rel_tot), bag_repre
else: # testing
attention_logit = __attention_test_logit__(x, rel_tot) # (n, rel_tot)
bag_repre = []
bag_logit = []
for i in range(scope.shape[0]):
bag_hidden_mat = x[scope[i][0]:scope[i][1]]
attention_score = tf.nn.softmax(tf.transpose(attention_logit[scope[i][0]:scope[i][1], :]), -1) # softmax of (rel_tot, n')
bag_repre_for_each_rel = tf.matmul(attention_score, bag_hidden_mat) # (rel_tot, n') \dot (n', hidden_size) = (rel_tot, hidden_size)
bag_logit_for_each_rel = __logit__(bag_repre_for_each_rel, rel_tot) # -> (rel_tot, rel_tot)
bag_repre.append(bag_repre_for_each_rel)
bag_logit.append(tf.diag_part(tf.nn.softmax(bag_logit_for_each_rel, -1))) # could be improved by sigmoid?
bag_repre = tf.stack(bag_repre)
bag_logit = tf.stack(bag_logit)
return bag_logit, bag_repre
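# Note on bag_attention: during training the attention query is the bag's distantly supervised
# relation (ins_label), giving one attended bag representation and one logit vector per bag;
# during testing every relation is tried as the query, producing a (rel_tot, hidden_size)
# representation per bag, and the diagonal of the resulting (rel_tot, rel_tot) logit matrix is
# kept so that each relation is scored by the representation attended with that same relation.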
def bag_average(x, scope, rel_tot, var_scope=None, dropout_before=False, keep_prob=1.0):
with tf.variable_scope(var_scope or "average", reuse=tf.AUTO_REUSE):
if dropout_before:
x = __dropout__(x, keep_prob)
bag_repre = []
for i in range(scope.shape[0]):
bag_hidden_mat = x[scope[i][0]:scope[i][1]]
bag_repre.append(tf.reduce_mean(bag_hidden_mat, 0)) # (n', hidden_size) -> (hidden_size)
bag_repre = tf.stack(bag_repre)
if not dropout_before:
bag_repre = __dropout__(bag_repre, keep_prob)
return __logit__(bag_repre, rel_tot), bag_repre
def bag_one(x, scope, query, rel_tot, is_training, var_scope=None, dropout_before=False, keep_prob=1.0): # could be improved?
with tf.variable_scope(var_scope or "one", reuse=tf.AUTO_REUSE):
if is_training: # training
if dropout_before:
x = __dropout__(x, keep_prob)
bag_repre = []
for i in range(scope.shape[0]):
bag_hidden_mat = x[scope[i][0]:scope[i][1]]
instance_logit = tf.nn.softmax(__logit__(bag_hidden_mat, rel_tot), -1) # (n', hidden_size) -> (n', rel_tot)
j = tf.argmax(instance_logit[:, query[i]], output_type=tf.int32)
bag_repre.append(bag_hidden_mat[j])
bag_repre = tf.stack(bag_repre)
if not dropout_before:
bag_repre = __dropout__(bag_repre, keep_prob)
return __logit__(bag_repre, rel_tot), bag_repre
else: # testing
if dropout_before:
x = __dropout__(x, keep_prob)
bag_repre = []
bag_logit = []
for i in range(scope.shape[0]):
bag_hidden_mat = x[scope[i][0]:scope[i][1]]
instance_logit = tf.nn.softmax(__logit__(bag_hidden_mat, rel_tot), -1) # (n', hidden_size) -> (n', rel_tot)
bag_logit.append(tf.reduce_max(instance_logit, 0))
bag_repre.append(bag_hidden_mat[0]) # fake max repre
bag_logit = tf.stack(bag_logit)
bag_repre = tf.stack(bag_repre)
return bag_logit, bag_repre
def bag_cross_max(x, scope, rel_tot, var_scope=None, dropout_before=False, keep_prob=1.0):
'''
Cross-sentence Max-pooling proposed by (Jiang et al. 2016.)
"Relation Extraction with Multi-instance Multi-label Convolutional Neural Networks"
https://pdfs.semanticscholar.org/8731/369a707046f3f8dd463d1fd107de31d40a24.pdf
'''
with tf.variable_scope(var_scope or "cross_max", reuse=tf.AUTO_REUSE):
if dropout_before:
x = __dropout__(x, keep_prob)
bag_repre = []
for i in range(scope.shape[0]):
bag_hidden_mat = x[scope[i][0]:scope[i][1]]
bag_repre.append(tf.reduce_max(bag_hidden_mat, 0)) # (n', hidden_size) -> (hidden_size)
bag_repre = tf.stack(bag_repre)
if not dropout_before:
bag_repre = __dropout__(bag_repre, keep_prob)
return __logit__(bag_repre, rel_tot), bag_repre
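# bag_average and bag_cross_max differ only in how a bag is reduced to one vector: the former
# takes the element-wise mean over the instances in the bag, the latter the element-wise max
# (cross-sentence max-pooling), and neither uses the relation query required by bag_attention
# and bag_one.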
================================================
FILE: nrekit/rl.py
================================================
import tensorflow as tf
import os
import sklearn.metrics
import numpy as np
import sys
import math
import time
# Package-relative imports so these modules resolve when this file is imported as nrekit.rl
from . import framework
from . import network
class policy_agent(framework.re_model):
def __init__(self, train_data_loader, batch_size, max_length=120):
framework.re_model.__init__(self, train_data_loader, batch_size, max_length)
self.weights = tf.placeholder(tf.float32, shape=(), name="weights_scalar")
x = network.embedding.word_position_embedding(self.word, self.word_vec_mat, self.pos1, self.pos2)
x_train = network.encoder.cnn(x, keep_prob=0.5)
x_test = network.encoder.cnn(x, keep_prob=1.0)
self._train_logit = network.selector.instance(x_train, 2, keep_prob=0.5)
self._test_logit = network.selector.instance(x_test, 2, keep_prob=1.0)
self._loss = network.classifier.softmax_cross_entropy(self._train_logit, self.ins_label, 2, weights=self.weights)
def loss(self):
return self._loss
def train_logit(self):
return self._train_logit
def test_logit(self):
return self._test_logit
class rl_re_framework(framework.re_framework):
def __init__(self, train_data_loader, test_data_loader, max_length=120, batch_size=160):
framework.re_framework.__init__(self, train_data_loader, test_data_loader, max_length, batch_size)
def agent_one_step(self, sess, agent_model, batch_data, run_array, weights=1):
feed_dict = {
agent_model.word: batch_data['word'],
agent_model.pos1: batch_data['pos1'],
agent_model.pos2: batch_data['pos2'],
agent_model.ins_label: batch_data['agent_label'],
agent_model.length: batch_data['length'],
agent_model.weights: weights
}
if 'mask' in batch_data and hasattr(agent_model, "mask"):
feed_dict.update({agent_model.mask: batch_data['mask']})
result = sess.run(run_array, feed_dict)
return result
def pretrain_main_model(self, max_epoch):
for epoch in range(max_epoch):
print('###### Epoch ' + str(epoch) + ' ######')
tot_correct = 0
tot_not_na_correct = 0
tot = 0
tot_not_na = 0
i = 0
time_sum = 0
for i, batch_data in enumerate(self.train_data_loader):
time_start = time.time()
iter_loss, iter_logit, _train_op = self.one_step(self.sess, self.model, batch_data, [self.model.loss(), self.model.train_logit(), self.train_op])
time_end = time.time()
t = time_end - time_start
time_sum += t
iter_output = iter_logit.argmax(-1)
iter_label = batch_data['rel']
iter_correct = (iter_output == iter_label).sum()
iter_not_na_correct = np.logical_and(iter_output == iter_label, iter_label != 0).sum()
tot_correct += iter_correct
tot_not_na_correct += iter_not_na_correct
tot += iter_label.shape[0]
tot_not_na += (iter_label != 0).sum()
if tot_not_na > 0:
sys.stdout.write("[pretrain main model] epoch %d step %d time %.2f | loss: %f, not NA accuracy: %f, accuracy: %f\r" % (epoch, i, t, iter_loss, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))
sys.stdout.flush()
i += 1
print("\nAverage iteration time: %f" % (time_sum / i))
def pretrain_agent_model(self, max_epoch):
# Pre-train policy agent
for epoch in range(max_epoch):
print('###### [Pre-train Policy Agent] Epoch ' + str(epoch) + ' ######')
tot_correct = 0
tot_not_na_correct = 0
tot = 0
tot_not_na = 0
time_sum = 0
for i, batch_data in enumerate(self.train_data_loader):
time_start = time.time()
batch_data['agent_label'] = batch_data['ins_rel'] + 0
batch_data['agent_label'][batch_data['agent_label'] > 0] = 1
iter_loss, iter_logit, _train_op = self.agent_one_step(self.sess, self.agent_model, batch_data, [self.agent_model.loss(), self.agent_model.train_logit(), self.agent_train_op])
time_end = time.time()
t = time_end - time_start
time_sum += t
iter_output = iter_logit.argmax(-1)
iter_label = batch_data['ins_rel']
iter_correct = (iter_output == iter_label).sum()
iter_not_na_correct = np.logical_and(iter_output == iter_label, iter_label != 0).sum()
tot_correct += iter_correct
tot_not_na_correct += iter_not_na_correct
tot += iter_label.shape[0]
tot_not_na += (iter_label != 0).sum()
if tot_not_na > 0:
sys.stdout.write("[pretrain policy agent] epoch %d step %d time %.2f | loss: %f, not NA accuracy: %f, accuracy: %f\r" % (epoch, i, t, iter_loss, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))
sys.stdout.flush()
i += 1
def train(self,
model, # The main model
agent_model, # The model of policy agent
model_name,
ckpt_dir='./checkpoint',
summary_dir='./summary',
test_result_dir='./test_result',
learning_rate=0.5,
max_epoch=60,
pretrain_agent_epoch=1,
pretrain_model=None,
test_epoch=1,
optimizer=tf.train.GradientDescentOptimizer):
print("Start training...")
# Init
self.model = model(self.train_data_loader, self.train_data_loader.batch_size, self.train_data_loader.max_length)
model_optimizer = optimizer(learning_rate)
grads = model_optimizer.compute_gradients(self.model.loss())
self.train_op = model_optimizer.apply_gradients(grads)
# Init policy agent
self.agent_model = agent_model(self.train_data_loader, self.train_data_loader.batch_size, self.train_data_loader.max_length)
agent_optimizer = optimizer(learning_rate)
agent_grads = agent_optimizer.compute_gradients(self.agent_model.loss())
self.agent_train_op = agent_optimizer.apply_gradients(agent_grads)
# Session, writer and saver
self.sess = tf.Session()
summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)
saver = tf.train.Saver(max_to_keep=None)
if pretrain_model is None:
self.sess.run(tf.global_variables_initializer())
else:
saver.restore(self.sess, pretrain_model)
self.pretrain_main_model(max_epoch=5) # Pre-train main model
        self.pretrain_agent_model(max_epoch=pretrain_agent_epoch) # Pre-train policy agent
# Train
tot_delete = 0
batch_count = 0
instance_count = 0
reward = 0.0
best_metric = 0
best_prec = None
best_recall = None
not_best_count = 0 # Stop training after several epochs without improvement.
for epoch in range(max_epoch):
print('###### Epoch ' + str(epoch) + ' ######')
tot_correct = 0
tot_not_na_correct = 0
tot = 0
tot_not_na = 0
i = 0
time_sum = 0
batch_stack = []
# Update policy agent
for i, batch_data in enumerate(self.train_data_loader):
# Make action
batch_data['agent_label'] = batch_data['ins_rel'] + 0
batch_data['agent_label'][batch_data['agent_label'] > 0] = 1
batch_stack.append(batch_data)
iter_logit = self.agent_one_step(self.sess, self.agent_model, batch_data, [self.agent_model.train_logit()])[0]
action_result = iter_logit.argmax(-1)
# Calculate reward
batch_delete = np.sum(np.logical_and(batch_data['ins_rel'] != 0, action_result == 0))
batch_data['ins_rel'][action_result == 0] = 0
iter_loss = self.one_step(self.sess, self.model, batch_data, [self.model.loss()])[0]
reward += iter_loss
tot_delete += batch_delete
batch_count += 1
# Update parameters of policy agent
alpha = 0.1
if batch_count == 100:
reward = reward / float(batch_count)
average_loss = reward
reward = - math.log(1 - math.e ** (-reward))
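                    # Reward shaping: with L the main model's average loss over the last 100 batches,
                    # reward = -log(1 - e^(-L)) increases as L decreases, and (reward * alpha) is then
                    # used below as the weight of the policy agent's cross-entropy update on the
                    # stacked batches.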
sys.stdout.write('tot delete : %f | reward : %f | average loss : %f\r' % (tot_delete, reward, average_loss))
sys.stdout.flush()
for batch_data in batch_stack:
self.agent_one_step(self.sess, self.agent_model, batch_data, [self.agent_train_op], weights=reward * alpha)
batch_count = 0
reward = 0
tot_delete = 0
batch_stack = []
i += 1
# Train the main model
for i, batch_data in enumerate(self.train_data_loader):
batch_data['agent_label'] = batch_data['ins_rel'] + 0
batch_data['agent_label'][batch_data['agent_label'] > 0] = 1
time_start = time.time()
# Make actions
iter_logit = self.agent_one_step(self.sess, self.agent_model, batch_data, [self.agent_model.train_logit()])[0]
action_result = iter_logit.argmax(-1)
batch_data['ins_rel'][action_result == 0] = 0
                # Real training: update the main model on the batch filtered by the agent
                iter_loss, iter_logit, _train_op = self.one_step(self.sess, self.model, batch_data, [self.model.loss(), self.model.train_logit(), self.train_op])
                time_end = time.time()
                t = time_end - time_start
                time_sum += t
                iter_output = iter_logit.argmax(-1)
                iter_label = batch_data['rel']
                iter_correct = (iter_output == iter_label).sum()
                iter_not_na_correct = np.logical_and(iter_output == iter_label, iter_label != 0).sum()
                tot_correct += iter_correct
                tot_not_na_correct += iter_not_na_correct
                tot += iter_label.shape[0]
                tot_not_na += (iter_label != 0).sum()
                if tot_not_na > 0:
                    sys.stdout.write("epoch %d step %d time %.2f | loss: %f, not NA accuracy: %f, accuracy: %f\r" % (epoch, i, t, iter_loss, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))
                    sys.stdout.flush()
i += 1
print("\nAverage iteration time: %f" % (time_sum / i))
if (epoch + 1) % test_epoch == 0:
metric = self.test(model)
if metric > best_metric:
best_metric = metric
best_prec = self.cur_prec
best_recall = self.cur_recall
print("Best model, storing...")
if not os.path.isdir(ckpt_dir):
os.mkdir(ckpt_dir)
path = saver.save(self.sess, os.path.join(ckpt_dir, model_name))
print("Finish storing")
not_best_count = 0
else:
not_best_count += 1
if not_best_count >= 20:
break
print("######")
print("Finish training " + model_name)
print("Best epoch auc = %f" % (best_metric))
if (not best_prec is None) and (not best_recall is None):
if not os.path.isdir(test_result_dir):
os.mkdir(test_result_dir)
np.save(os.path.join(test_result_dir, model_name + "_x.npy"), best_recall)
np.save(os.path.join(test_result_dir, model_name + "_y.npy"), best_prec)
================================================
FILE: test_demo.py
================================================
import nrekit
import numpy as np
import tensorflow as tf
import sys
import os
import json
dataset_name = 'nyt'
if len(sys.argv) > 1:
dataset_name = sys.argv[1]
dataset_dir = os.path.join('./data', dataset_name)
if not os.path.isdir(dataset_dir):
raise Exception("[ERROR] Dataset dir %s doesn't exist!" % (dataset_dir))
# The first 3 parameters are train / test data file name, word embedding file name and relation-id mapping file name respectively.
train_loader = nrekit.data_loader.json_file_data_loader(os.path.join(dataset_dir, 'train.json'),
os.path.join(dataset_dir, 'word_vec.json'),
os.path.join(dataset_dir, 'rel2id.json'),
mode=nrekit.data_loader.json_file_data_loader.MODE_RELFACT_BAG,
shuffle=True)
test_loader = nrekit.data_loader.json_file_data_loader(os.path.join(dataset_dir, 'test.json'),
os.path.join(dataset_dir, 'word_vec.json'),
os.path.join(dataset_dir, 'rel2id.json'),
mode=nrekit.data_loader.json_file_data_loader.MODE_ENTPAIR_BAG,
shuffle=False)
framework = nrekit.framework.re_framework(train_loader, test_loader)
class model(nrekit.framework.re_model):
encoder = "pcnn"
selector = "att"
def __init__(self, train_data_loader, batch_size, max_length=120):
nrekit.framework.re_model.__init__(self, train_data_loader, batch_size, max_length=max_length)
self.mask = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name="mask")
# Embedding
x = nrekit.network.embedding.word_position_embedding(self.word, self.word_vec_mat, self.pos1, self.pos2)
# Encoder
if model.encoder == "pcnn":
x_train = nrekit.network.encoder.pcnn(x, self.mask, keep_prob=0.5)
x_test = nrekit.network.encoder.pcnn(x, self.mask, keep_prob=1.0)
elif model.encoder == "cnn":
x_train = nrekit.network.encoder.cnn(x, keep_prob=0.5)
x_test = nrekit.network.encoder.cnn(x, keep_prob=1.0)
elif model.encoder == "rnn":
x_train = nrekit.network.encoder.rnn(x, self.length, keep_prob=0.5)
x_test = nrekit.network.encoder.rnn(x, self.length, keep_prob=1.0)
elif model.encoder == "birnn":
x_train = nrekit.network.encoder.birnn(x, self.length, keep_prob=0.5)
x_test = nrekit.network.encoder.birnn(x, self.length, keep_prob=1.0)
else:
raise NotImplementedError
# Selector
if model.selector == "att":
self._train_logit, train_repre = nrekit.network.selector.bag_attention(x_train, self.scope, self.ins_label, self.rel_tot, True, keep_prob=0.5)
self._test_logit, test_repre = nrekit.network.selector.bag_attention(x_test, self.scope, self.ins_label, self.rel_tot, False, keep_prob=1.0)
elif model.selector == "ave":
self._train_logit, train_repre = nrekit.network.selector.bag_average(x_train, self.scope, self.rel_tot, keep_prob=0.5)
self._test_logit, test_repre = nrekit.network.selector.bag_average(x_test, self.scope, self.rel_tot, keep_prob=1.0)
self._test_logit = tf.nn.softmax(self._test_logit)
elif model.selector == "max":
self._train_logit, train_repre = nrekit.network.selector.bag_maximum(x_train, self.scope, self.ins_label, self.rel_tot, True, keep_prob=0.5)
self._test_logit, test_repre = nrekit.network.selector.bag_maximum(x_test, self.scope, self.ins_label, self.rel_tot, False, keep_prob=1.0)
self._test_logit = tf.nn.softmax(self._test_logit)
else:
raise NotImplementedError
# Classifier
self._loss = nrekit.network.classifier.softmax_cross_entropy(self._train_logit, self.label, self.rel_tot, weights_table=self.get_weights())
def loss(self):
return self._loss
def train_logit(self):
return self._train_logit
def test_logit(self):
return self._test_logit
def get_weights(self):
with tf.variable_scope("weights_table", reuse=tf.AUTO_REUSE):
print("Calculating weights_table...")
_weights_table = np.zeros((self.rel_tot), dtype=np.float32)
for i in range(len(self.train_data_loader.data_rel)):
_weights_table[self.train_data_loader.data_rel[i]] += 1.0
_weights_table = 1 / (_weights_table ** 0.05)
weights_table = tf.get_variable(name='weights_table', dtype=tf.float32, trainable=False, initializer=_weights_table)
print("Finish calculating")
return weights_table
if len(sys.argv) > 2:
model.encoder = sys.argv[2]
if len(sys.argv) > 3:
model.selector = sys.argv[3]
auc, pred_result = framework.test(model, ckpt="./checkpoint/" + dataset_name + "_" + model.encoder + "_" + model.selector, return_result=True)
with open('./test_result/' + dataset_name + "_" + model.encoder + "_" + model.selector + "_pred.json", 'w') as outfile:
json.dump(pred_result, outfile)
================================================
FILE: train_demo.py
================================================
import nrekit
import numpy as np
import tensorflow as tf
import sys
import os
dataset_name = 'nyt'
if len(sys.argv) > 1:
dataset_name = sys.argv[1]
dataset_dir = os.path.join('./data', dataset_name)
if not os.path.isdir(dataset_dir):
raise Exception("[ERROR] Dataset dir %s doesn't exist!" % (dataset_dir))
# The first 3 parameters are train / test data file name, word embedding file name and relation-id mapping file name respectively.
train_loader = nrekit.data_loader.json_file_data_loader(os.path.join(dataset_dir, 'train.json'),
os.path.join(dataset_dir, 'word_vec.json'),
os.path.join(dataset_dir, 'rel2id.json'),
mode=nrekit.data_loader.json_file_data_loader.MODE_RELFACT_BAG,
shuffle=True)
test_loader = nrekit.data_loader.json_file_data_loader(os.path.join(dataset_dir, 'test.json'),
os.path.join(dataset_dir, 'word_vec.json'),
os.path.join(dataset_dir, 'rel2id.json'),
mode=nrekit.data_loader.json_file_data_loader.MODE_ENTPAIR_BAG,
shuffle=False)
framework = nrekit.framework.re_framework(train_loader, test_loader)
class model(nrekit.framework.re_model):
encoder = "pcnn"
selector = "att"
def __init__(self, train_data_loader, batch_size, max_length=120):
nrekit.framework.re_model.__init__(self, train_data_loader, batch_size, max_length=max_length)
self.mask = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name="mask")
# Embedding
x = nrekit.network.embedding.word_position_embedding(self.word, self.word_vec_mat, self.pos1, self.pos2)
# Encoder
if model.encoder == "pcnn":
x_train = nrekit.network.encoder.pcnn(x, self.mask, keep_prob=0.5)
x_test = nrekit.network.encoder.pcnn(x, self.mask, keep_prob=1.0)
elif model.encoder == "cnn":
x_train = nrekit.network.encoder.cnn(x, keep_prob=0.5)
x_test = nrekit.network.encoder.cnn(x, keep_prob=1.0)
elif model.encoder == "rnn":
x_train = nrekit.network.encoder.rnn(x, self.length, keep_prob=0.5)
x_test = nrekit.network.encoder.rnn(x, self.length, keep_prob=1.0)
elif model.encoder == "birnn":
x_train = nrekit.network.encoder.birnn(x, self.length, keep_prob=0.5)
x_test = nrekit.network.encoder.birnn(x, self.length, keep_prob=1.0)
else:
raise NotImplementedError
# Selector
if model.selector == "att":
self._train_logit, train_repre = nrekit.network.selector.bag_attention(x_train, self.scope, self.ins_label, self.rel_tot, True, keep_prob=0.5)
self._test_logit, test_repre = nrekit.network.selector.bag_attention(x_test, self.scope, self.ins_label, self.rel_tot, False, keep_prob=1.0)
elif model.selector == "ave":
self._train_logit, train_repre = nrekit.network.selector.bag_average(x_train, self.scope, self.rel_tot, keep_prob=0.5)
self._test_logit, test_repre = nrekit.network.selector.bag_average(x_test, self.scope, self.rel_tot, keep_prob=1.0)
self._test_logit = tf.nn.softmax(self._test_logit)
elif model.selector == "one":
self._train_logit, train_repre = nrekit.network.selector.bag_one(x_train, self.scope, self.label, self.rel_tot, True, keep_prob=0.5)
self._test_logit, test_repre = nrekit.network.selector.bag_one(x_test, self.scope, self.label, self.rel_tot, False, keep_prob=1.0)
self._test_logit = tf.nn.softmax(self._test_logit)
elif model.selector == "cross_max":
self._train_logit, train_repre = nrekit.network.selector.bag_cross_max(x_train, self.scope, self.rel_tot, keep_prob=0.5)
self._test_logit, test_repre = nrekit.network.selector.bag_cross_max(x_test, self.scope, self.rel_tot, keep_prob=1.0)
self._test_logit = tf.nn.softmax(self._test_logit)
else:
raise NotImplementedError
# Classifier
self._loss = nrekit.network.classifier.softmax_cross_entropy(self._train_logit, self.label, self.rel_tot, weights_table=self.get_weights())
def loss(self):
return self._loss
def train_logit(self):
return self._train_logit
def test_logit(self):
return self._test_logit
def get_weights(self):
with tf.variable_scope("weights_table", reuse=tf.AUTO_REUSE):
print("Calculating weights_table...")
_weights_table = np.zeros((self.rel_tot), dtype=np.float32)
for i in range(len(self.train_data_loader.data_rel)):
_weights_table[self.train_data_loader.data_rel[i]] += 1.0
_weights_table = 1 / (_weights_table ** 0.05)
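            # The table above gives each relation a weight proportional to count^(-0.05):
            # frequent relations such as NA are down-weighted slightly relative to rare ones,
            # softening the class imbalance in the softmax cross-entropy loss.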
weights_table = tf.get_variable(name='weights_table', dtype=tf.float32, trainable=False, initializer=_weights_table)
print("Finish calculating")
return weights_table
use_rl = False
if len(sys.argv) > 2:
model.encoder = sys.argv[2]
if len(sys.argv) > 3:
model.selector = sys.argv[3]
if len(sys.argv) > 4:
if sys.argv[4] == 'rl':
use_rl = True
if use_rl:
rl_framework = nrekit.rl.rl_re_framework(train_loader, test_loader)
rl_framework.train(model, nrekit.rl.policy_agent, model_name=dataset_name + "_" + model.encoder + "_" + model.selector + "_rl", max_epoch=60, ckpt_dir="checkpoint")
else:
framework.train(model, model_name=dataset_name + "_" + model.encoder + "_" + model.selector, max_epoch=60, ckpt_dir="checkpoint", gpu_nums=1)
================================================
SYMBOL INDEX (88 symbols across 12 files)
================================================
FILE: draw_plot.py
function main (line 12) | def main():
FILE: kg_data/EntityMatcher.py
class EntityMatcher (line 6) | class EntityMatcher:
method __init__ (line 7) | def __init__(self, entity_file, sentences_folder, process_num):
method match (line 20) | def match(self, file_name):
method write_file (line 35) | def write_file(self, data):
method run (line 39) | def run(self):
FILE: kg_data/SentenceSegment.py
function read_txt (line 7) | def read_txt(file_name):
class SentenceSegment (line 16) | class SentenceSegment:
method __init__ (line 17) | def __init__(self, dict_file, stop_word_file, sentences_folder, proces...
method segment (line 28) | def segment(self, file_name):
method write_file (line 54) | def write_file(self, data):
method run (line 58) | def run(self):
FILE: nrekit/data_loader.py
class file_data_loader (line 9) | class file_data_loader:
method __next__ (line 10) | def __next__(self):
method next (line 13) | def next(self):
method next_batch (line 16) | def next_batch(self, batch_size):
class npy_data_loader (line 19) | class npy_data_loader(file_data_loader):
method __iter__ (line 24) | def __iter__(self):
method __init__ (line 27) | def __init__(self, data_dir, prefix, mode, word_vec_npy='vec.npy', shu...
method __next__ (line 61) | def __next__(self):
method next_batch (line 64) | def next_batch(self, batch_size):
class json_file_data_loader (line 136) | class json_file_data_loader(file_data_loader):
method _load_preprocessed_file (line 141) | def _load_preprocessed_file(self):
method __init__ (line 185) | def __init__(self, file_name, word_vec_file_name, rel2id_file_name, mo...
method __iter__ (line 436) | def __iter__(self):
method __next__ (line 439) | def __next__(self):
method next_batch (line 442) | def next_batch(self, batch_size):
FILE: nrekit/framework.py
function average_gradients (line 8) | def average_gradients(tower_grads):
class re_model (line 45) | class re_model:
method __init__ (line 46) | def __init__(self, train_data_loader, batch_size, max_length=120):
method loss (line 58) | def loss(self):
method train_logit (line 61) | def train_logit(self):
method test_logit (line 64) | def test_logit(self):
class re_framework (line 67) | class re_framework:
method __init__ (line 71) | def __init__(self, train_data_loader, test_data_loader, max_length=120...
method one_step_multi_models (line 76) | def one_step_multi_models(self, sess, models, batch_data_gen, run_arra...
method one_step (line 99) | def one_step(self, sess, model, batch_data, run_array):
method train (line 114) | def train(self,
method test (line 225) | def test(self,
method __test_bag__ (line 237) | def __test_bag__(self, model, ckpt=None, return_result=False):
FILE: nrekit/network/classifier.py
function softmax_cross_entropy (line 4) | def softmax_cross_entropy(x, label, rel_tot, weights_table=None, weights...
function sigmoid_cross_entropy (line 13) | def sigmoid_cross_entropy(x, label, rel_tot, weights_table=None, var_sco...
function soft_label_softmax_cross_entropy (line 26) | def soft_label_softmax_cross_entropy(x):
function output (line 35) | def output(x):
FILE: nrekit/network/embedding.py
function word_embedding (line 4) | def word_embedding(word, word_vec_mat, var_scope=None, word_embedding_di...
function pos_embedding (line 15) | def pos_embedding(pos1, pos2, var_scope=None, pos_embedding_dim=5, max_l...
function word_position_embedding (line 29) | def word_position_embedding(word, word_vec_mat, pos1, pos2, var_scope=No...
FILE: nrekit/network/encoder.py
function __dropout__ (line 5) | def __dropout__(x, keep_prob=1.0):
function __pooling__ (line 8) | def __pooling__(x):
function __piecewise_pooling__ (line 11) | def __piecewise_pooling__(x, mask):
function __cnn_cell__ (line 18) | def __cnn_cell__(x, hidden_size=230, kernel_size=3, stride_size=1):
function cnn (line 27) | def cnn(x, hidden_size=230, kernel_size=3, stride_size=1, activation=tf....
function pcnn (line 36) | def pcnn(x, mask, hidden_size=230, kernel_size=3, stride_size=1, activat...
function __rnn_cell__ (line 45) | def __rnn_cell__(hidden_size, cell_name='lstm'):
function rnn (line 57) | def rnn(x, length, hidden_size=230, cell_name='lstm', var_scope=None, ke...
function birnn (line 66) | def birnn(x, length, hidden_size=230, cell_name='lstm', var_scope=None, ...
FILE: nrekit/network/selector.py
function __dropout__ (line 4) | def __dropout__(x, keep_prob=1.0):
function __logit__ (line 7) | def __logit__(x, rel_tot, var_scope=None):
function __attention_train_logit__ (line 14) | def __attention_train_logit__(x, query, rel_tot, var_scope=None):
function __attention_test_logit__ (line 22) | def __attention_test_logit__(x, rel_tot, var_scope=None):
function instance (line 29) | def instance(x, rel_tot, var_scope=None, keep_prob=1.0):
function bag_attention (line 34) | def bag_attention(x, scope, query, rel_tot, is_training, var_scope=None,...
function bag_average (line 64) | def bag_average(x, scope, rel_tot, var_scope=None, dropout_before=False,...
function bag_one (line 77) | def bag_one(x, scope, query, rel_tot, is_training, var_scope=None, dropo...
function bag_cross_max (line 106) | def bag_cross_max(x, scope, rel_tot, var_scope=None, dropout_before=Fals...
FILE: nrekit/rl.py
class policy_agent (line 11) | class policy_agent(framework.re_model):
method __init__ (line 12) | def __init__(self, train_data_loader, batch_size, max_length=120):
method loss (line 23) | def loss(self):
method train_logit (line 26) | def train_logit(self):
method test_logit (line 29) | def test_logit(self):
class rl_re_framework (line 32) | class rl_re_framework(framework.re_framework):
method __init__ (line 33) | def __init__(self, train_data_loader, test_data_loader, max_length=120...
method agent_one_step (line 36) | def agent_one_step(self, sess, agent_model, batch_data, run_array, wei...
method pretrain_main_model (line 50) | def pretrain_main_model(self, max_epoch):
method pretrain_agent_model (line 80) | def pretrain_agent_model(self, max_epoch):
method train (line 111) | def train(self,
FILE: test_demo.py
class model (line 29) | class model(nrekit.framework.re_model):
method __init__ (line 33) | def __init__(self, train_data_loader, batch_size, max_length=120):
method loss (line 74) | def loss(self):
method train_logit (line 77) | def train_logit(self):
method test_logit (line 80) | def test_logit(self):
method get_weights (line 83) | def get_weights(self):
FILE: train_demo.py
class model (line 28) | class model(nrekit.framework.re_model):
method __init__ (line 32) | def __init__(self, train_data_loader, batch_size, max_length=120):
method loss (line 77) | def loss(self):
method train_logit (line 80) | def train_logit(self):
method test_logit (line 83) | def test_logit(self):
method get_weights (line 86) | def get_weights(self):