Repository: xiaolalala/Distant-Supervised-Chinese-Relation-Extraction
Branch: master
Commit: d0eb40a049b0
Files: 19
Total size: 111.5 KB

Directory structure:
gitextract_idaztadn/

├── .gitignore
├── LICENSE
├── README.md
├── draw_plot.py
├── kg_data/
│   ├── EntityMatcher.py
│   ├── README.md
│   ├── SentenceSegment.py
│   ├── add_relation.ipynb
│   ├── data_process.ipynb
│   └── stop_word.txt
├── nrekit/
│   ├── data_loader.py
│   ├── framework.py
│   ├── network/
│   │   ├── classifier.py
│   │   ├── embedding.py
│   │   ├── encoder.py
│   │   └── selector.py
│   └── rl.py
├── test_demo.py
└── train_demo.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.ipynb_checkpoints/
kg_data/processed/
kg_data/baike_triples.txt
kg_data/baiketriples.zip
kg_data/.ipynb_checkpoints/

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2018 Tianyu Gao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# Distant-Supervised-Chinese-Relation-Extraction
## Chinese Relation Extraction Based on Distant Supervision

### Dataset Construction

* CN-DBpedia, a general-purpose Chinese knowledge base
* The distant-supervision assumption

The processing pipeline is documented in kg_data/README.md. Click [here (Google Drive)](https://drive.google.com/open?id=1XmWW3-wveKiJFauqZTPcSbsgh4lECRBK) to download a processed subset of the data.
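The distant-supervision assumption can be illustrated with a toy sketch (the KB entry and sentence below are invented for illustration; the real pipeline lives under kg_data/):

```python
# Distant supervision: if the knowledge base contains (head, relation, tail)
# and a sentence mentions both head and tail, label that sentence with the
# relation. The KB entry here is a toy example, not real CN-DBpedia data.
kb = {("北京", "中华人民共和国"): "/地点/地点/首都"}

def distant_label(sentence, head, tail):
    # only label when both entities actually occur in the sentence
    if head in sentence and tail in sentence:
        return kb.get((head, tail), "NA")
    return "NA"

sent = "北京是中华人民共和国的首都。"
```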

### Model

The models come from thunlp/OpenNRE; see its documentation for details.

**Source:** https://github.com/thunlp/OpenNRE

### Running the Code

The dataset directory defaults to data/chinese. Run:
```
python train_demo.py chinese pcnn att
```
### Results

Results for selected relations:

Relation|Precision|Recall|F1
:-:|:-:|:-:|:-:
**All**|**0.95428**|**0.95036**|**0.95232**
/人物/其它/民族|0.98374|0.979|0.98137
NA|0.96853|0.97824|0.97336
/人物/地点/国籍|0.84075|0.92673|0.88164
/组织/地点/位于|0.85157|0.83652|0.84398
/人物/其它/职业|0.86121|0.8037|0.83147
/人物/组织/毕业于|0.84137|0.78092|0.81002
/组织/人物/校长|0.94118|0.59259|0.72727
/人物/地点/出生地|0.81049|0.49028|0.61097
/人物/人物/家庭成员|0.65385|0.37778|0.47887
/人物/组织/属于|0.99999|0.11364|0.20408
/地点/地点/包含|0.99999|0.0625|0.11765
/组织/人物/创始人|0.99999|0.05882|0.11111

Recall is very low for some relations; inspection suggests this is because the dataset contains very few samples of those relations.



================================================
FILE: draw_plot.py
================================================
import sklearn.metrics
import matplotlib
# Use 'Agg' so this program can run on a headless remote server
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import sys
import os

result_dir = './test_result'

def main():
    models = sys.argv[1:]
    for model in models:
        x = np.load(os.path.join(result_dir, model +'_x' + '.npy')) 
        y = np.load(os.path.join(result_dir, model + '_y' + '.npy'))
        f1 = (2 * x * y / (x + y + 1e-20)).max()
        auc = sklearn.metrics.auc(x=x, y=y)
        #plt.plot(x, y, lw=2, label=model + '-auc='+str(auc))
        plt.plot(x, y, lw=2, label=model)
        print(model + ' : ' + 'auc = ' + str(auc) + ' | ' + 'max F1 = ' + str(f1))
        print('    P@100: {} | P@200: {} | P@300: {} | Mean: {}'.format(y[100], y[200], y[300], (y[100] + y[200] + y[300]) / 3))
       
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.3, 1.0])
    plt.xlim([0.0, 0.4])
    plt.title('Precision-Recall')
    plt.legend(loc="upper right")
    plt.grid(True)
    plt.savefig(os.path.join(result_dir, 'pr_curve'))

if __name__ == "__main__":
    main()


================================================
FILE: kg_data/EntityMatcher.py
================================================
import pickle
import multiprocessing
import time
import os

class EntityMatcher:
    def __init__(self, entity_file, sentences_folder, process_num):
        self.process_num = process_num

        with open(entity_file, 'rb') as f:
            entity_dict = pickle.load(f)
        self.entities = list(set(list(entity_dict.keys())))

        # sentence files
        self.sentence_files = []
        for root, dirs, files in os.walk(sentences_folder):
            for file in files:
                self.sentence_files.append(os.path.join(root, file))

    def match(self, file_name):
        print('Start %s!'%(file_name))
        with open(file_name, 'rb') as f:
            data = pickle.load(f)
        new_data = []
        for sen in data:
            eset = []
            for entity in self.entities:
                if entity in sen:
                    eset.append(entity)
            if len(eset)>1:
                new_data.append([sen, eset])
        print('Done %s!'%(file_name))
        return new_data, file_name

    def write_file(self, data):
        with open(data[1], 'wb') as f:
            pickle.dump(data[0], f)

    def run(self):
        pool = multiprocessing.Pool(processes=self.process_num)
        for one_file in self.sentence_files:
            pool.apply_async(self.match, args=(str(one_file), ), callback=self.write_file)
        pool.close()
        pool.join()

if __name__ == "__main__":
    entity_file = 'processed/entities.pkl'
    sentences_folder = 'processed/sentences'
    process_num = 8
    st = time.localtime()
    print('\nStart time: ')
    print(time.strftime("%Y-%m-%d %H:%M:%S", st))
    em = EntityMatcher(entity_file, sentences_folder, process_num)
    em.run()

    ed = time.localtime()
    print('End time: ')
    print(time.strftime("%Y-%m-%d %H:%M:%S", ed))

================================================
FILE: kg_data/README.md
================================================
# Building the Distantly Supervised Dataset


![](http://www.bbvdd.com/d/20190314161746ddc.png)

## Pipeline
0. Download the raw data and unpack it into this directory as `kg_data/baike_triples.txt`
1. Start Jupyter Notebook in the kg_data directory
```
jupyter notebook .
```
2. Run data_process.ipynb top to bottom
3. Match entities
``` 
python EntityMatcher.py
```
4. Segment sentences
```
python SentenceSegment.py
```
5. Run add_relation.ipynb top to bottom


## Raw Data

* Source

The raw data is the public portion of CN-DBpedia, a general-purpose Chinese encyclopedic knowledge graph, containing 9M+ encyclopedia entities and 66M+ triples: 4M+ abstracts, 19.8M+ tags, and 41M+ infobox entries.

**Download:** http://www.openkg.cn/dataset/cndbpedia

**Source:** http://kw.fudan.edu.cn/cndbpedia

* Format

The archive unpacks to baike_triples.txt, with one triple per line: an entity name, a relation or attribute word, and the corresponding value.

```
"1+8"时代广场   中文名  "1+8"时代广场
"1+8"时代广场   地点    咸宁大道与银泉大道交叉口
"1+8"时代广场   实质    城市综合体项目
"1+8"时代广场   总建面  约11.28万方
北京   中文名称   北京
北京   BaiduTAG  北京, 简称“京”, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。北京地处中国华北地区, 中心位于东经116°20′、北纬39°56′, 东与天津毗连, 其余均与河北相邻, 北京市总面积16410.54平方千米。
北京   所属地区   中国华北地区
"1.4"河南兰考火灾事故   地点    河南<a>兰考县</a>城关镇
"1.4"河南兰考火灾事故   时间    2013年1月4日
"1.4"河南兰考火灾事故   结果    一人重伤
```

## Building the Entity Dictionary

The first element of each raw line is an entity. **Entities whose names consist entirely of Chinese characters are kept (filtered with a regular expression)**, and the triples are converted into a dictionary.

The processed data looks like this:

```python
{
    "北京":{
        "中文名称": "北京",
        "BaiduCard": "北京, 简称“京”, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。北京地处中国华北地区, 中心位于东经116°20′、北纬39°56′, 东与天津毗连, 其余均与河北相邻, 北京市总面积16410.54平方千米。",
        "所属地区": "中国华北地区"
    },
    ...
}
```
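This step can be sketched with the standard library (the regex here requires a non-empty all-Chinese string, a slight simplification of the notebook's filter; the sample triples are illustrative):

```python
import re

# Sketch of the entity-dictionary step: keep triples whose elements consist
# entirely of Chinese characters, then group attributes under each head
# entity. The sample triples are illustrative, not real CN-DBpedia lines.
ALL_CHINESE = re.compile(r'^[\u4e00-\u9fa5]+$')

def build_entity_dict(triples):
    entities = {}
    for head, rel, tail in triples:
        if all(ALL_CHINESE.match(x) for x in (head, rel, tail)):
            entities.setdefault(head, {})[rel] = tail
    return entities

sample = [
    ("北京", "所属地区", "中国华北地区"),
    ("北京", "中文名称", "北京"),
    ('"1+8"时代广场', "实质", "城市综合体项目"),  # dropped: non-Chinese characters
]
entity_dict = build_entity_dict(sample)
```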

## Collecting and Preprocessing Sentences

An entity's `BaiduCard` attribute is its Baidu Baike abstract, which usually contains several sentences. Sentences are collected from the entity dictionary and stored as a list.

Every sentence is preprocessed: **all characters other than Chinese characters and common Chinese punctuation are removed, and multi-sentence strings are split into individual sentences**, again stored as a list.

The processed data looks like this:
```python
[
    "北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。",
    "北京地处中国华北地区, 中心位于东经、北纬, 东与天津毗连, 其余均与河北相邻, 北京市总面积平方千米。",
    ...
]
```
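A minimal sketch of this preprocessing (the exact character class is an assumption modeled on the notebook, not a verbatim copy):

```python
import re

# Sketch of the preprocessing step: drop every character that is not a
# Chinese character or common Chinese punctuation, then split on
# sentence-ending marks.
def split_sentences(text):
    text = re.sub(r'[^\u4e00-\u9fa5,,。??::!!、;;()() ]', '', text)
    return [s for s in re.split(r'[。?!?!]', text) if s]

abstract = "北京地处中国华北地区, 中心位于东经116°20′、北纬39°56′。北京市总面积16410.54平方千米。"
sentences = split_sentences(abstract)
```

Note how the digits and degree marks disappear, matching the processed example above.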

## Matching Entities to Sentences

For each sentence, iterate over the entity set and record, via **string matching**, every entity that appears in it. **Sentences containing fewer than two entities are filtered out**, producing data in the format `[[sentence, [entity1,...]], ...]`.

The processed data looks like this:
```python
[
    [
        "北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。",
        [
            "北京",
            "中华人民共和国",
            "政治"
        ]
    ], 
    ...
]
```

## Sentence Segmentation

Chinese sentences are segmented with the Python `jieba` library, producing data in the format `[[sentence, [entity1,...], [sentence_seg]], ...]`.

All segmented sentences are collected as a corpus for training word vectors with gensim's `word2vec`.

**jieba tutorial:** https://github.com/fxsjy/jieba 

**word2vec tutorial:** https://radimrehurek.com/gensim/models/word2vec.html

### User Dictionary
To keep entity names from being split apart, all entities (the key set of the entity dictionary) are written to `processed/entity.txt` and loaded as a jieba user dictionary.
### Stop Words
Chinese stop words listed in `stop_word.txt` are removed during segmentation. (Such lists are widely available online.)

### Training Word Vectors

The processed data looks like this:
```python
[
    [
        "北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。",
        [
            "北京",
            "中华人民共和国",
            "政治"
        ],
        [
            "北京",
            "简称",
            "京",
            ...
        ]
    ], 
    ...
]
```

## Filtering Sentence-Entity Pairs

After segmentation, each sentence's entity list is filtered again: any entity that does not appear as a token in the segmented sentence is removed. The remaining entities are combined pairwise, producing data in the format `[[sentence, entity_head, entity_tail, [sentence_seg]]]`. (One sentence may therefore yield multiple samples.)

Entities are matched against the raw sentences first (to discard unusable sentences), segmentation happens afterwards, and entities are then re-matched against the tokens, for two reasons:
1. One entity name may be a substring of another, e.g. "北京" (Beijing) and "北京大学" (Peking University). In the sentence "北京大学是中国的著名大学。" only "北京大学" should match.
2. Segmentation is slow; segmenting all sentences before any filtering and only then matching entities would be much less efficient.

The processed data looks like this:
```python
[
    [
        "北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。",
        "北京",
        "中华人民共和国",
        [
            "北京",
            "简称",
            ...
        ]
    ], 
    [
        "北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。",
        "北京",
        "政治",
        [
            "北京",
            "简称",
            ...
        ]
    ], 
    ...
]
```
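The re-filtering and pairing can be sketched as follows (sample data adapted from the Peking University example above):

```python
from itertools import permutations

# Sketch of the pair-building step: keep only entities that survive
# segmentation as whole tokens, then emit one sample per ordered
# (head, tail) pair.
def build_pairs(sentence, entity_list, seg):
    kept = [e for e in entity_list if e in seg]
    return [[sentence, head, tail, seg] for head, tail in permutations(kept, 2)]

seg = ["北京大学", "是", "中国", "著名", "大学"]
pairs = build_pairs("北京大学是中国的著名大学。", ["北京", "北京大学", "中国"], seg)
# "北京" is dropped: it is not a token of the segmented sentence
```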

## Adding Relation Labels

Based on an analysis of the raw data, 23 frequently occurring relation types were defined by hand (see the appendix); 'NA' means the two entities have no relation or some other relation. Since relations/attributes in the raw data are not aligned (e.g. 妻子 and 夫人 both denote "wife"), hand-written rules align and merge them.

Each record from the previous step is labeled using the entity dictionary and the hand-built alignment table, producing data in the format `[[sentence, entity_head, entity_tail, relation, [sentence_seg]]]`.

The processed data looks like this:
```python
[
    [
        "北京, 简称京, 是中华人民共和国省级行政区、首都、直辖市, 是全国的政治、文化中心。",
        "北京",
        "中华人民共和国",
        "/地点/地点/首都",
        [
            "北京",
            "简称",
            ...
        ]
    ], 
    ...
]
```
## Converting the Dataset Format

Finally the data is converted into its target format, adding 'id', 'type', and other fields:
```python
[
    {
        "head": {
            "word": "北京",
            "id": "666",
            ...
        },
        "tail": {
            "word": "中华人民共和国",
            "id": "6",
            ...
        },
        "relation": "/地点/地点/首都",
        "sentence": "北京 简称 京 是 中华人民共和国 省级 行政区 首都 直辖市 是 全国 的 政治 文化 中心",
        ...
    }
]
```

## Train/Test Split

Each label is split between the training and test sets at a 3:1 ratio.
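A minimal sketch of this per-label split, mirroring the counting logic in add_relation.ipynb (the `"relation"` field name follows the dataset format above; the sample counts are illustrative):

```python
import math
from collections import Counter

# Sketch of the per-label 3:1 split: roughly a quarter of each relation's
# samples (at least one) go to the test set, the rest to training.
def split_by_relation(samples, test_ratio=0.25):
    counts = Counter(s["relation"] for s in samples)
    quota = {r: max(math.floor(c * test_ratio), 1) for r, c in counts.items()}
    taken = {r: 0 for r in counts}
    train, test = [], []
    for s in samples:
        r = s["relation"]
        if taken[r] < quota[r]:
            test.append(s)
            taken[r] += 1
        else:
            train.append(s)
    return train, test

samples = [{"relation": "/地点/地点/首都"}] * 8 + [{"relation": "NA"}] * 4
train, test = split_by_relation(samples)
```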

# Known Issues
1. With this many sentences, entity matching, segmentation, and word-vector training take a long time (matching 80K entities against 4M sentences takes roughly 1-2 hours with 8 worker processes). Estimate the running time on a small sample first, and rely on multiprocessing or a data subset.
2. Some of the data cleaning is fairly crude and leaves room for improvement.
3. The relation set is small, the alignment rules are simple, and the raw data is itself noisy (e.g. some BaiduTAG values are misclassified), so the resulting dataset contains noise.


# Appendix


**Relation alignment/merging:** a triple (head, relation, tail) is assigned the relation `/人物/地点/出生地` when head belongs to the person category, tail belongs to the place category, and relation is one of the alias words mapped to `/人物/地点/出生地`.

Entity category|BaiduTAG must contain at least one of
:-:|:-:
人物 (person)|人物、歌手、演员、作家
机构 (organization)|机构、企业、公司、学校、部门、大学
地点 (place)|地点、地理、城市、国家、地区
其它 (other)|unrestricted
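The rule can be sketched for a single relation (tag and alias sets are abbreviated from the tables in this appendix; the full rule set lives in add_relation.ipynb):

```python
# Sketch of the alignment rule: label a triple /人物/地点/出生地 when the
# head's BaiduTAG contains a person tag, the tail's contains a place tag,
# and the relation word is in the alias set. Sets abbreviated for brevity.
PER_TAGS = ["人物", "歌手", "演员", "作家"]
PL_TAGS = ["地点", "地理", "城市", "国家", "地区"]
BIRTH_ALIASES = {"出生地", "出生于", "来自", "出生在", "出生"}

def has_tag(baidu_tag, tags):
    return any(t in baidu_tag for t in tags)

def label_birth_place(head_info, tail_info, relation):
    if (relation in BIRTH_ALIASES
            and has_tag(head_info.get("BaiduTAG", ""), PER_TAGS)
            and has_tag(tail_info.get("BaiduTAG", ""), PL_TAGS)):
        return "/人物/地点/出生地"
    return "NA"
```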



<table>
    <tr>
        <th width=20px>No.</th>
        <th>Relation</th>
        <th>Alias words</th>
    </tr>
    <tr>
        <td align="center" width=10px>0</td>
        <td width=200px align="center">NA</td>
        <td>no relation in the raw data</td>
    </tr>
    <tr>
        <td width=50px align="center">1</td>
        <td width=200px align="center">/人物/人物/家庭成员</td>
        <td>父亲、母亲、丈夫、妻子、儿子、女儿、哥哥、妹妹、姐姐、弟弟、孙子、孙女、爷爷、奶奶、外婆、外公、家人、家庭成员、夫人、对象、夫君</td>
    </tr>
    <tr>
        <td width=50px align="center">2</td>
        <td width=200px align="center">/人物/人物/社交关系</td>
        <td> 朋友、好友、同学、合作、搭档、经纪人、师从</td>
    </tr>
    <tr>
        <td width=50px align="center">3</td>
        <td width=200px align="center">/人物/地点/出生地</td>
        <td>出生地、出生于、来自、歌手出生地、作者出生地、出生在、作者出生地、出生</td>
    </tr>
    <tr>
        <td width=50px align="center">4</td>
        <td width=200px align="center">/人物/地点/居住地</td>
        <td>居住地、主要居住地、居住、现居住、目前居住地、现居住于、居住地点、居住于</td>
    </tr>
    <tr>
        <td width=50px align="center">5</td>
        <td width=200px align="center">/人物/地点/国籍</td>
        <td>国籍、国家</td>
    </tr>
    <tr>
        <td width=50px align="center">6</td>
        <td width=200px align="center">/人物/组织/毕业于</td>
        <td>毕业院校、毕业于、毕业学院、本科毕业院校、最后毕业院校、毕业高中、毕业地点、本科毕业学校、知名校友</td>
    </tr>
    <tr>
        <td width=50px align="center">7</td>
        <td width=200px align="center">/人物/组织/属于</td>
        <td>隶属单位、经纪公司、隶属关系、行政隶属、隶属学校、隶属大学、隶属地区、所属公司、签约公司、任职公司、工作单位、所属</td>
    </tr>
    <tr>
        <td width=50px align="center">8</td>
        <td width=200px align="center">/人物/其它/职业</td>
        <td>职业</td>
    </tr>
    <tr>
        <td width=50px align="center">9</td>
        <td width=200px align="center">/人物/其它/民族</td>
        <td>民族</td>
    </tr>
    <tr>
        <td width=50px align="center">10</td>
        <td width=200px align="center">/组织/人物/拥有者</td>
        <td>拥有、拥有者</td>
    </tr>
    <tr>
        <td width=50px align="center">11</td>
        <td width=200px align="center">/组织/人物/创始人</td>
        <td>创始人、创始、主要创始人、集团创始人</td>
    </tr>
    <tr>
        <td width=50px align="center">12</td>
        <td width=200px align="center">/组织/人物/校长</td>
        <td>校长、现任校长、学校校长、总校长</td>
    </tr>
    <tr>
        <td width=50px align="center">13</td>
        <td width=200px align="center">/组织/人物/领导人</td>
        <td>领导、现任领导、领导单位、主要领导、领导人、主要领导人</td>
    </tr>
    <tr>
        <td width=50px align="center">14</td>
        <td width=200px align="center">/组织/组织/周边</td>
        <td>周围景观、周边景点</td>
    </tr>
    <tr>
        <td width=50px align="center">15</td>
        <td width=200px align="center">/组织/地点/位于</td>
        <td>所属地区、国家、地区、地理位置、位于、区域、地点、总部地点、所在地、所在区域、位于城市、总部位于、酒店位于、学校位于、最早位于、地址、所在城市、城市、主要城市、坐落于</td>
    </tr>
    <tr>
        <td width=50px align="center">16</td>
        <td width=200px align="center">/地点/人物/相关人物</td>
        <td>相关人物、知名人物、历史人物</td>
    </tr>
    <tr>
        <td width=50px align="center">17</td>
        <td width=200px align="center">/地点/地点/位于</td>
        <td>所属地区、所属国、所属洲、所属州、所属国家、最大城市、地区、地理位置、位于、区域、地点、总部地点、所在地、所在区域、位于城市、总部位于、酒店位于、学校位于、最早位于、地址、所在城市、城市、主要城市、坐落于</td>
    </tr>
    <tr>
        <td width=50px align="center">18</td>
        <td width=200px align="center">/地点/地点/毗邻</td>
        <td>毗邻、东邻、邻近行政区、相邻、紧邻、邻近、北邻、南邻、邻国</td>
    </tr>
    <tr>
        <td width=50px align="center">19</td>
        <td width=200px align="center">/地点/地点/包含</td>
        <td>包含、包含国家、包含人物、下辖地区、下属、</td>
    </tr>
    <tr>
        <td width=50px align="center">20</td>
        <td width=200px align="center">/地点/地点/首都</td>
        <td>首都</td>
    </tr>
    <tr>
        <td width=50px align="center">21</td>
        <td width=200px align="center">/地点/组织/景点</td>
        <td>著名景点、主要景点、旅游景点、特色景点</td>
    </tr>
    <tr>
        <td width=50px align="center">22</td>
        <td width=200px align="center">/地点/其它/气候</td>
        <td> 气候类型、气候条件、气候、气候带</td>
    </tr>
</table>

================================================
FILE: kg_data/SentenceSegment.py
================================================
import pickle
import jieba
import os
import multiprocessing
import time

def read_txt(file_name):
    txt_data = []
    with open(file_name, 'r', encoding='utf8') as f:
        d = f.readline()
        while d:
            txt_data.append(d.strip())
            d = f.readline()
    return txt_data

class SentenceSegment:
    def __init__(self, dict_file, stop_word_file, sentences_folder, process_num):
        self.process_num = process_num
        self.stop_word = read_txt(stop_word_file)
        jieba.load_userdict(dict_file)
        # sentence files
        self.sentence_files = []
        for root, dirs, files in os.walk(sentences_folder):
            for file in files:
                self.sentence_files.append(os.path.join(root, file))


    def segment(self, file_name):
        with open(file_name, 'rb') as f:
            data = pickle.load(f)
        new_data = []
        for d in data:
            # sentence segment
            sen_seg = []
            for word in jieba.cut(d[0]):
                if word not in self.stop_word:
                    sen_seg.append(word)
            d.append(sen_seg)
            # filter entities again
            new_eset = []
            for entity in d[1]:
                if entity in sen_seg:
                    new_eset.append(entity)
            # drop sentences that still have fewer than two entities,
            # otherwise emit one record per ordered entity pair
            if len(new_eset)>1:
                for i in new_eset:
                    for j in new_eset:
                        if j!=i:
                            new_data.append([d[0], i, j, sen_seg])
        print('%s done!'%(file_name))
        return new_data, file_name

    def write_file(self, data):
        with open(data[1], 'wb') as f:
            pickle.dump(data[0], f)

    def run(self):
        pool = multiprocessing.Pool(processes=self.process_num)
        for one_file in self.sentence_files:
            pool.apply_async(self.segment, args=(str(one_file), ), callback=self.write_file)
        pool.close()
        pool.join()

if __name__ == "__main__":
    dict_file = 'processed/entity.txt'
    stop_word = 'stop_word.txt'
    sentences_folder = 'processed/sentences'
    process_num = 8
    st = time.localtime()

    ss = SentenceSegment(dict_file, stop_word, sentences_folder, process_num)
    ss.run()

    ed = time.localtime()
    print('\nStart time: ')
    print(time.strftime("%Y-%m-%d %H:%M:%S", st))
    print('End time: ')
    print(time.strftime("%Y-%m-%d %H:%M:%S", ed))

================================================
FILE: kg_data/add_relation.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import json\n",
    "import os\n",
    "import math\n",
    "from gensim.models import Word2Vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = []\n",
    "for i in range(100):\n",
    "    with open('processed/sentences/sen'+str(i), 'rb') as f:\n",
    "        data += pickle.load(f)\n",
    "len(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('processed/entities.pkl', 'rb') as f:\n",
    "    entities = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Relation labeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Alignment rules: map raw relation/attribute words to the predefined labels\n",
    "\n",
    "tgs = {\n",
    "    \"per\": [\"人物\", \"歌手\", \"演员\", \"作家\"],\n",
    "    \"org\": [\"机构\", \"企业\", \"公司\", \"学校\", \"部门\", \"大学\"],\n",
    "    \"pl\": [\"地点\", \"地理\", \"城市\", \"国家\", \"地区\"]\n",
    "}\n",
    "\n",
    "config = {\n",
    "        # person  9\n",
    "    \"per2per_family_members\" : [\"父亲\",\"母亲\",\"丈夫\",\"妻子\",\"儿子\",\"女儿\",\"哥哥\",\"妹妹\",\"姐姐\",\"弟弟\",\"孙子\",\n",
    "                        \"孙女\",\"爷爷\",\"奶奶\",\"外婆\", \"外公\",\"家人\",\"家庭成员\" ,\"夫人\",\"对象\",\"夫君\"],\n",
    "    \"per2per_social_members\" : [\"朋友\", \"好友\", \"同学\", \"合作\", \"搭档\", \"经纪人\", \"师从\"],\n",
    "\n",
    "    \"per2pl_birth_place\" : [\"出生地\", \"出生于\", \"来自\", \"歌手出生地\", \"作者出生地\", \"出生在\", \"作者出生地\", \"出生\"],\n",
    "    \"per2pl_live_place\" : [\"居住地\", \"主要居住地\", \"居住\", \"现居住\", \"目前居住地\", \"现居住于\", \"居住地点\", \"居住于\"],\n",
    "    \n",
    "    \"per2pl_country\": [\"国籍\", \"国家\"],\n",
    "    \"per2org_graduate_from\" : [\"毕业院校\", \"毕业于\", \"毕业学院\", \"本科毕业院校\", \"最后毕业院校\", \"毕业高中\", \"毕业地点\", \"本科毕业学校\", \"知名校友\"],\n",
    "    \"per2org_belong_to\" : [\"隶属单位\", \"经纪公司\", \"隶属关系\", \"行政隶属\", \"隶属学校\", \"隶属大学\", \"隶属地区\", \"所属公司\", \"签约公司\", \"任职公司\", \"工作单位\", \"所属\"],\n",
    "\n",
    "    \"per2oth_profession\" : ['职业'],\n",
    "    \"per2oth_nation\" : ['民族'],\n",
    "\n",
    "    # organization  6\n",
    "    \"org2per_owner\" : [\"拥有\", \"拥有者\"],\n",
    "    \"org2per_founder\" : [\"创始人\", \"创始\", \"主要创始人\", \"集团创始人\"],\n",
    "    \"org2per_school_leader\" : [\"校长\", \"现任校长\", \"学校校长\", \"总校长\"],\n",
    "    \"org2per_leader\" : [\"领导\", \"现任领导\", \"领导单位\", \"主要领导\", \"领导人\", \"主要领导人\"],\n",
    "\n",
    "    \"org2org_surroundings\" : [\"周围景观\", \"周边景点\"],\n",
    "\n",
    "\n",
    "    \"org2pl_location\" : [\"所属地区\",\"国家\", \"地区\", \"地理位置\", \"位于\", \"区域\", \"地点\", \"总部地点\", \"所在地\", \"所在区域\", \"位于城市\", \"总部位于\", \"酒店位于\", \"学校位于\", \"最早位于\", \"地址\", \"所在城市\", \"城市\", \"主要城市\", \"坐落于\"],\n",
    "\n",
    "\n",
    "    # place  7\n",
    "    \"pl2per_main_character\" : [\"相关人物\", \"知名人物\", \"历史人物\"],\n",
    "\n",
    "    \"pl2pl_location\" : [\"所属地区\",\"所属国\", \"所属洲\", \"所属州\", \"所属国家\", \"最大城市\", \"地区\", \"地理位置\", \"位于\", \"区域\", \"地点\", \"总部地点\", \"所在地\", \"所在区域\", \"位于城市\", \"总部位于\", \"酒店位于\", \"学校位于\", \"最早位于\", \"地址\", \"所在城市\", \"城市\", \"主要城市\", \"坐落于\"],\n",
    "    \"pl2pl_adjacement\" : [\"毗邻\", \"东邻\", \"邻近行政区\", \"相邻\", \"紧邻\", \"邻近\", \"北邻\", \"南邻\", \"邻国\"],\n",
    "    \"pl2pl_contains\" : [\"包含\", \"包含国家\", \"包含人物\", \"下辖地区\", \"下属\"],\n",
    "    \"pl2pl_captial\" : [\"首都\"],\n",
    "\n",
    "    \"pl2org_sights\" : [\"著名景点\", \"主要景点\", \"旅游景点\", \"特色景点\"],\n",
    "    \"pl2oth_climate\" : [\"气候类型\", \"气候条件\", \"气候\", \"气候带\"],\n",
    "}\n",
    "\n",
    "\n",
    "name = {\n",
    "    \"per2per_family_members\": \"/人物/人物/家庭成员\",\n",
    "    \"per2per_social_members\": \"/人物/人物/社交关系\",\n",
    "\n",
    "    \"per2pl_birth_place\": \"/人物/地点/出生地\",\n",
    "    \"per2pl_live_place\": \"/人物/地点/居住地\",\n",
    "    \"per2pl_country\" : \"/人物/地点/国籍\",\n",
    "    \"per2org_graduate_from\": \"/人物/组织/毕业于\",\n",
    "    \"per2org_belong_to\": \"/人物/组织/属于\",\n",
    "\n",
    "    \"per2oth_profession\": \"/人物/其它/职业\",\n",
    "    \"per2oth_nation\": \"/人物/其它/民族\",\n",
    "\n",
    "    \"org2per_owner\": \"/组织/人物/拥有者\",\n",
    "    \"org2per_founder\": \"/组织/人物/创始人\",\n",
    "    \"org2per_school_leader\": \"/组织/人物/校长\",\n",
    "    \"org2per_leader\": \"/组织/人物/领导人\",\n",
    "\n",
    "    \"org2org_surroundings\": \"/组织/组织/周边\",\n",
    "\n",
    "\n",
    "    \"org2pl_location\": \"/组织/地点/位于\",\n",
    "\n",
    "\n",
    "    \"pl2per_main_character\": \"/地点/人物/相关人物\",\n",
    "\n",
    "    \"pl2pl_location\": \"/地点/地点/位于\",\n",
    "    \"pl2pl_adjacement\": \"/地点/地点/毗邻\",\n",
    "    \"pl2pl_contains\": \"/地点/地点/包含\",\n",
    "    \"pl2pl_captial\": \"/地点/地点/首都\",\n",
    "\n",
    "    \"pl2org_sights\": \"/地点/组织/景点\",\n",
    "    \"pl2oth_climate\": \"/地点/其它/气候\"\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check(string, tgs):\n",
    "    for t in tgs:\n",
    "        if t in string:\n",
    "            return True\n",
    "    return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Label every candidate according to the rules above\n",
    "processed_data = []\n",
    "for can in data:\n",
    "    # can  [sentence, head, tail, segment]\n",
    "    sentence, head, tail, segment = can\n",
    "    if tail not in entities[can[1]].values():\n",
    "        can.append('NA')\n",
    "        processed_data.append(can)\n",
    "        continue\n",
    "    rel = ''\n",
    "    for key, value in entities[can[1]].items():\n",
    "        if value==tail:\n",
    "            rel = key\n",
    "    \n",
    "    for key, value in config.items():\n",
    "        if rel in value:\n",
    "            tp = key.split('_')[0].split('2')\n",
    "            if check(entities[can[1]].get('BaiduTAG', \"\"), tgs[tp[0]]):\n",
    "                if tp[1]=='oth':\n",
    "                    can.append(name[key])\n",
    "                    processed_data.append(can)\n",
    "                elif check(entities[can[2]].get('BaiduTAG', \"\"), tgs[tp[1]]):\n",
    "                    can.append(name[key])\n",
    "                    processed_data.append(can)\n",
    "            break\n",
    "len(processed_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "e2id = {}\n",
    "count = 0\n",
    "e_set = set()\n",
    "for i in processed_data:\n",
    "    e_set.add(i[1])\n",
    "    e_set.add(i[2])\n",
    "for e in e_set:\n",
    "    e2id[e] = count\n",
    "    count += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "total_data = []\n",
    "for d in processed_data:\n",
    "    total_data.append({\n",
    "        'head':{\n",
    "            'word': d[1],\n",
    "            'id': str(e2id[d[1]])\n",
    "        },\n",
    "        'relation': d[-1],\n",
    "        'tail': {\n",
    "            'word': d[2],\n",
    "            'id': str(e2id[d[2]])\n",
    "        },\n",
    "        'sentence': ' '.join(d[-2]),\n",
    "        'ori_sen': d[0],\n",
    "        'sen_seg': d[-2]\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "rl = {}\n",
    "for i in total_data:\n",
    "    rl[i['relation']] = rl.get(i['relation'], 0) + 1\n",
    "\n",
    "trl = {}\n",
    "record = {}\n",
    "for k,v in rl.items():\n",
    "    trl[k] = int(max(math.floor(v*0.25), 1))\n",
    "    record[k] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train = []\n",
    "test = []\n",
    "for i in total_data:\n",
    "    if record[i['relation']] < trl[i['relation']]:\n",
    "        test.append(i)\n",
    "        record[i['relation']] += 1\n",
    "    else:\n",
    "        train.append(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.isdir('../data'):\n",
    "    os.mkdir('../data')\n",
    "if not os.path.isdir('../data/chinese'):\n",
    "    os.mkdir('../data/chinese')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with open('../data/chinese/train.json', 'w', encoding='utf8') as f:\n",
    "    json.dump(train, f)\n",
    "with open('../data/chinese/test.json', 'w', encoding='utf8') as f:\n",
    "    json.dump(test, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "count = 0\n",
    "r2id = {}\n",
    "for k in list(rl.keys()):\n",
    "    r2id[k] = count\n",
    "    count +=1\n",
    "with open('../data/chinese/rel2id.json', 'w', encoding='utf8') as f:\n",
    "    json.dump(r2id, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training word vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "senlist = []\n",
    "for d in data:\n",
    "    senlist.append(d[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Word2Vec(senlist, sg=1, min_count=1, size=50, workers=4)  # sg must be 0 or 1; 1 = skip-gram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "w2v = {}\n",
    "for i in model.wv.index2word:\n",
    "    w2v[i] = model.wv[i]\n",
    "len(w2v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "new_w2v = []\n",
    "for word, vec in w2v.items():\n",
    "    new_w2v.append({\n",
    "        'word': word,\n",
    "        'vec': [float(i) for i in vec]\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with open('../data/chinese/word_vec.json', 'w', encoding='utf8') as f:\n",
    "    json.dump(new_w2v, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}


================================================
FILE: kg_data/data_process.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import json\n",
    "import re\n",
    "import os\n",
    "import math\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Load the raw data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = []\n",
    "with open('baike_triples.txt', 'r', encoding='utf8') as f:\n",
    "    for line in tqdm(f):\n",
    "        data.append(line.strip().split('\\t'))\n",
    "print(len(data), data[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. Keep only triples whose elements are entirely Chinese"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_chinese = re.compile('^[\\u4e00-\\u9fa5]*$')\n",
    "new_data = []\n",
    "for triple in tqdm(data):\n",
    "    if bool(re.search(all_chinese, triple[0])) and bool(re.search(all_chinese, triple[1])) and bool(re.search(all_chinese, triple[2])):\n",
    "        new_data.append(triple)\n",
    "len(new_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3. Build the entity dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "entities = {}\n",
    "for triple in tqdm(new_data):\n",
    "    entities[triple[0]] = entities.get(triple[0], {})\n",
    "    entities[triple[0]][triple[1]] = triple[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.isdir('processed'):\n",
    "    os.mkdir('processed')\n",
    "with open('processed/entities.pkl', 'wb') as f:\n",
    "    pickle.dump(entities, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "e_set = set()\n",
    "with open('processed/entity.txt', 'w', encoding='utf8') as f:\n",
    "    for e in tqdm(entities.keys()):\n",
    "        if e not in e_set:\n",
    "            e_set.add(e)\n",
    "            f.write(e+'\\n')\n",
    "len(e_set)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4. Collect the sentence set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_sentence(sent):\n",
    "    sent = re.sub(r' +', ' ', sent)\n",
    "    sent = re.sub(r'[^\\u4e00-\\u9fa5,\\?\\!,。?::!、;\\(\\)() ]', '', sent)\n",
    "    sent = re.split(r'[\\?\\!。?!]', sent)\n",
    "    return sent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sents = []\n",
    "for triple in tqdm(data):\n",
    "    if triple[1]=='BaiduCARD':\n",
    "        sent = process_sentence(triple[2])\n",
    "        if sent != ['']:\n",
    "            sents.append(sent)\n",
    "print(len(sents), sents[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "5. Save the sentences into 100 files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.isdir('processed/sentences'):\n",
    "    os.mkdir('processed/sentences')\n",
    "\n",
    "begin = 0\n",
    "count = math.ceil(len(sents)*1.0/100)\n",
    "end = min(begin+count, len(sents))\n",
    "for i in tqdm(range(100)):\n",
    "    with open('processed/sentences/sen'+str(i), 'wb') as f:\n",
    "        pickle.dump(sents[begin:end], f)\n",
    "    begin = end\n",
    "    end = min(begin+count, len(sents))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}


================================================
FILE: kg_data/stop_word.txt
================================================

可能
起
便于
有些
上述
纯粹
尽然
乃至于
极为
也
并不
其次
矣哉
为何
到头来
此外
[①e]
8
就是了
故此
移动
旁人
立刻
路经
使
元/吨
刚
归齐
归根结底
!
〕
迅速
获得
4
已
冒
哎
已矣
哪天
这么点儿
哪样
一些
甚且
起头
大
默默地
⑩
跟
逐步
它
拦腰
不仅仅
里面
同
设或
要么
以後
—
)÷(1-
扩大
它的
精光
归
战斗
成为
谁知
这般
此后
尽管如此
或多或少
{-
替
照
千万
人
连日
这么些
广泛
保险
到头
把
↑
前后
最近
他
瑟瑟
一片
正值
没有
而且
云云
正常
快要
从新
了
窃
光是
长此下去
挨家挨户
届时
而论
大批
“
一面
之前
大多数
较比
并肩
[⑧]
昂然
各位
[①④]
正是
在
自
活
最后
那些
不免
[①⑨]
设若
立即
蛮
尽心竭力
联系
叫做
具有
宣布
既
毕竟
那般
先生
啊呀
表明
.
并没有
切莫
一番
倒是
[②⑤]
以外
一来
不够
赶快
二
另行
尽可能
嗡嗡
另
格外
[②a]
毫无保留地
彻底
不独
经过
大家
即令
愤然
于
[③h]
[①g]
个别
接连不断
逐渐
$
简而言之
呜呼
适应
譬如
某某
庶几
背地里
具体地说
自打
仍旧
突然
趁热
要求
>
怪
常言说得好
哎呀
临到
强烈
”
--
此次
不可抗拒
另外
能
"
凑巧
不仅
唉
傥然
穷年累月
.一
为
充其极
敢于
'
单纯
起见
除了
如何
={
来说
企图
嘿
必将
毋宁
有所
那样
均
次第
挨个
失去
最好
上面
能够
举行
它是
出现
叮当
挨次
…………………………………………………③
看来
动辄
趁机
亲自
才能
亲手
与
亲眼
有的是
不定
但愿
由
内
下面
当前
必须
全都
共
这么
率然
不日
最
--
不足
几度
不如
凭借
猛然
看起来
俺们
暗地里
绝非
[②j]
而又
依靠
如同
紧接着
日见
[②e]
古来
依
即若
[④e]
多
看看
那么些
怎么
而是
各级
\
以来
如常
非徒
对比
以及
继后
所有
$
不但...而且
何处
才
清楚
居然
难说
或许
除此
按时
}
:
成年
来着
别处
何时
喽
继之
除去
成年累月
[③d]
变成
一何
每
除开
除
本身
偏偏
而
饱
别是
只怕
哪些
一时
倒不如说
到处
本地
反过来
@
自己
──
成心
策略地
平素
①
φ
全面
粗
等
然后
为什么
什麽
必
般的
哩
任凭
难怪
权时
exp
何苦
一转眼
出于
只限
假如
至若
藉以
只
突出
)
常言道
6
适用
无宁
任
陈年
理应
共同
以故
简言之
当真
欤
要
各
并不是
结合
安全
莫若
常
矣
按说
砰
因
不外
允许
不问
当庭
望
[②i]
的确
多数
咱们
若
尚且
高低
继而
这点
Ⅲ
不管怎样
只消
皆可
何以
与其
+ξ
各个
暗自
[①h]
以至于
惯常
有
这
相当
从速
以为
岂但
不外乎
焉
直接
[⑦]
仅仅
通过
也是
差一点
恰似
尔后
进入
除却
彼时
果真
赖以
但凡
所在
是以
尽早
真是
每当
出去
万
sup
一旦
一切
相反
即便
也好
截至
无论
朝着
_
;
八
真正
趁便
保管
要不
多么
啷当
咦
反映
|
两者
综上所述
于是乎
充分
呀
属于
诸
截然
倘若
[⑥]
从此
略
明显
互
齐
之类
难得
具体
恍然
大体
时候
趁着
遇到
按
老
再有
引起
7
若非
决不
[①B]
距
高兴
不可
即刻
绝对
运用
难道
极端
非常
[①⑥]
认真
末##末
哟
而况
其他
针对
历
大事
[③b]
类如
何妨
无法
%
总的说来
同一
/
:
上来
今
达到
或者
接下来
呼啦
比方
说明
[③e]
各地
谨
据此
各式
伟大
传
既往
2
b]
不得
不敢
是
这会儿
一则通过
呐
近年来
之一
下列
比及
防止
么
第二
6
何
[①③]
最大
从未
何须
颇
整个
不惟
等等
殆
如此
此间
争取
觉得
+
直到
得起
一个
越是
自个儿
即使
[②⑧]
独
或是
往往
率尔
零
{
[③①]
像
相信
较之
纵令
尤其
乘
好象
反倒是
随著
一样
是不是
绝不
更加
哪边
、
从古至今
日渐
将近
}>
下
偶尔
[①⑧]
除此以外
呕
一方面
当口儿
良好
亲口
待到
———
加入
犹且
且不说
即如
经
——
即是说
与其说
-
不可开交
方
(
×××
亲身
据
要不然
到底
复杂
比如说
不时
莫不
抑或
~±
哪年
本人
啦
奇
固
这样
来
~
③
不久
不限
不对
[③g]
++
您们
显然
总结
风雨无阻
如上所述
}
以免
虽则
部分
简直
打开天窗说亮话
极
哈哈
遵循
介于
’‘
现代
-
甚或
由是
亦
吱
嘛
果然
2.3%
这里
假使
根本
据称
达旦
[①⑦]
交口
特点
&
咱
正在
=☆
个人
马上
毫无
嘎嘎
巨大
开始
几经
因而
②
不拘
`
不是
单
除此之外
冲
更为
[②B]
这一来
全部
总的来说
联袂
或
其二
偶而
其
必然
[②⑥]
…
人家
大体上
保持
六
总是
不光
换言之
这个
尔尔
且
广大
四
近几年来
哉
多多益善
加之
断然
每年
多年前
再者
所
三
碰巧
[⑤f]
即或
罢了
起首
八成
多多少少
暗中
日益
喂
沿着
绝顶
主张
所谓
哈
容易
既…又
将要
不论
各自
甚么
倘然
某
迄
那麽
有及
-β
再
比
从小
她的
局外
该当
一定
不止
倘
敢
可是
省得
丰富
1.
你是
吧
[③]
随后
注意
-[*]-
何必
猛然间
……
怎
一起
三天两头
知道
宁可
从事
相应
如若
少数
这时
为了
啪达
⑥
据悉
忽地
φ.
顷刻
〔
㈧
进而
使得
[①c]
原来
赶
老老实实
不妨
可以
0
默然
可好
心里
自各儿
乘势
彻夜
密切
5:0
哦
)、
︿
一一
应用
特别是
相似
不若
怎么办
双方
就
从头
故
二话没说
是的
普通
比如
[]
【
连袂
取道
按理
起初
帮助
与否
不同
七
哼唷
大张旗鼓
孰料
啊
需要
实现
还要
〈
挨门挨户
存心
宁肯
一边
凝神
顷刻之间
先后
[③a]
孰知
(
还有
老是
就是
完成
抽冷子
可
大凡
如期
做到
因此
最後
若是
急匆匆
范围
大都
理该
[③c]
|
够瞧的
这边
重新
[①C]
→
沿
不巧
极度
不满
而已
有点
a]
好
哪儿
犹自
μ
宁愿
[①a]
不止一次
形成
您是
在于
?
下去
川流不息
主要
后
从宽
为着
吓
满足
A
勃然
尔等
它们的
敞开儿
连声
牢牢
到目前为止
根据
不能
从无到有
过
确定
怎样
出
已经
5
总的来看
?
行为
[①d]
固然
到
归根到底
从今以后
并且
此时
不然
⑨
除非
#
2
不下
然
较
℃
纯
凭
对应
扑通
遭到
每逢
是否
考虑
大约
为什麽
不怎么
比照
去
彼此
怎奈
梆
[②⑩]
]
有著
有的
何乐而不为
相等
儿
[②④
恰如
打从
不怕
][
强调
不亦乐乎
至
『
不力
不了
不过
轰然
一般
上
恐怕
相对而言
今後
二话不说
恰恰
取得
接著
一直
¥
差不多
[①⑤]
后来
<φ
9
③]
既然
啊哈
哪个
诸位
因为
所幸
之所以
哗
嘿嘿
贼死
』
自家
大抵
<Δ
屡屡
采取
别人
转动
不得不
拿
那末
不再
顿时
全身心
顺
全年
4
转贴
其后
嗯
别管
出来
处处
背靠背
呸
行动
[③F]
另方面
R.L.
余外
照着
不得了
严重
他是
[②g]
况且
白
种
待
得到
呵呵
那
限制
及其
[①②]
3
sub
从中
决定
A
一
@
》
被
似乎
对方
得出
相对
打
大量
彼
另悉
不一
迟早
[①D]
[⑩]
小
某些
用
将才
[④d]
一则
岂止
[⑤b]
长话短说
⑤
Lex
此中
向使
[
放量
尽快
切
不成
任何
该
乒
构成
离
当儿
掌握
[*]
若夫
现在
并无
基于
以致
假若
甚而
咋
矣乎
中间
从而
严格
请勿
//
不比
维持
[②b]
即
不胜
哇
累次
莫
三番两次
有关
恰恰相反
不起
不经意
人民
<<
除外
兮
上下
://
但
切不可
然後
难道说
不至于
别说
哪
尽量
完全
;
方面
借
光
不但
[②f]
社会主义
到了儿
至今
第
上去
漫说
从重
可见
我
乘隙
准备
反之
无
後面
一.
处理
较为
很少
[④c]
喏
尽管
进来
另一个
而言
普遍
有着
[⑨]
...
附近
恰逢
千万千万
极了
能否
绝
表示
巴巴
因了
[②③]
)
嗡
向
[②⑦]
以下
挨门逐户
任务
此
竟
[
这么样
另一方面
之后
进行
[①f]
从严
说来
之
单单
当中
正如
·
故意
专门
恰巧
大大
敢情
.日
c]
隔夜
大致
]∧′=[
来看
的
好的
没奈何
则甚
或则
换句话说
那么
按照
了解
呢
万一
最高
竟而
再则
Ψ
过去
大多
存在
略为
并没
岂非
臭
同样
非独
由于
[②h]
合理
反过来说
甭
不已
对于
不曾
比起
促进
怎么样
经常
然则
为止
仅
8
乘胜
独自
极力
鉴于
看样子
】
分期分批
当场
加以
深入
产生
日复一日
总之
二来
则
欢迎
让
呼哧
首先
接着
诚然
屡次
[②]
巩固
却
而后
0:2
'
乌乎
≈
快
前此
也罢
=
多年来
前面
三番五次
呵
从古到今
>
不仅...而且
却不
者
竟然
白白
#
的话
实际
分头
哎哟
同时
就要
必要
每时每刻
吧哒
奈
+
见
不变
要是
并非
间或
长线
人们
恰好
赶早不赶晚
立马
趁
如此等等
[⑤d]
如下
起先
尽如人意
看见
看
不仅仅是
或曰
比较
慢说
常常
庶乎
*
看上去
随时
兼之
[⑤]
基本上
概
左右
加上
积极
近
=″
,也
但是
乃
还
来得及
共总
代替
每个
再说
极其
得
当然
动不动
倒不如
以
奋勇
[④a]
曾
你的
何尝
从此以后
几乎
为主
有利
具体说来
只要
嘎
就此
得天独厚
当
如其
鄙人
老大
没
立时
长期以来
乘机
汝
不会
顷刻间
显著
坚决
全然
=(
谁
分期
喀
切切
自从
否则
啥
〉
随
正巧
看到
前进
而外
嗳
尔
顺着
虽
即将
之後
方才
别的
传说
一天
半
怎麽
互相
每天
那里
弹指之间
及至
叫
不尽
此地
此处
今年
反而
[③⑩]
上升
不只
适当
顷
他的
纵然
以前
人人
。
∈[
也就是说
充其量
来自
俺
如上
据我所知
论说
基本
方能
我们
......
几时
[①i]
[⑤e]
不由得
当地
趁势
自身
[-
阿
本着
练习
<±
据实
这些
通常
哪怕
《
这儿
[②G]
大面儿上
这种
你们
以后
些
切勿
来讲
[①A]
中小
隔日
不消
那个
不常
.数
▲
只是
她是
12%
e]
全体
不然的话
[⑤a]
转变
》),
[④b]
进步
从来
着呢
从轻
[①o]
如是
由此可见
得了
..
吗
如今
~+
从优
对待
并排
诸如
公然
五
宁
和
这麽
各人
例如
趁早
便
就算
认识
借此
当时
<
然而
哪里
非特
∪φ∈
_
对
纵使
多多
不能不
咧
只当
今天
■
从不
及时
[②②]
不择手段
乘虚
咚
我是
往
多次
豁然
有效
以至
很
还是
<λ
那边
等到
理当
随着
[②c]
今后
靠
并
大举
给
故而
决非
凡是
说说
她
又及
意思
>>
明确
开展
<
3
数/
仍然
他们
千
f]
弗
必定
如果
不得已
将
至于
γ
既是
为此
其实
重大
]
非但
ZXFITL
[①]
倍感
几番
设使
啊哟
究竟
论
致
刚巧
反手
9
初
近来
过于
总而言之
谁人
作为
莫不然
”,
,
~
毫不
先不先
都
再次
其它
就是说
什么样
立地
④
话说
呃
不尽然
一下
②c
呆呆地
特殊
据说
有力
×
除此而外
管
...................
进去
逢
屡次三番
毫无例外
尽心尽力
因着
如前所述
〕〔
嘘
带
关于
她们
不管
当着
一次
’
朝
加强
嗬
凡
大力
方便
避免
使用
尽
分别
起来
莫非
‘
十分
大略
多亏
::
会
遵照
不要
们
向着
相同
不特
不大
某个
如
累年
[②
又
眨眼
ng昉
其一
坚持
满
%
不少
具体来说
纵
略微
造成
嘻
其余
譬喻
多少
以便
依据
后面
地
後来
虽说
过来
~~~~
日臻
集中
不断
于是
由此
大概
唯有
反之则
忽然
什么
喔唷
刚才
连
迫于
顶多
与此同时
己
挨着
有时
./
叮咚
岂
只有
1
应该
5
乃至
前者
临
借以
召开
陡然
=
其中
=[
哼
0
及
非得
受到
您
[②①]
何止
那儿
姑且
以期
看出
哗啦
嘎登
反应
*
曾经
诚如
/
每每
九
腾
从早到晚
不知不觉
反之亦然
若果
В
一致
再者说
开外
本
怕
连同
愿意
始而
应当
目前
倍加
甚至
年复一年
虽然
既...又
」
更
LI
莫如
来不及
沙沙
这就是说
⑧
先後
[⑤]]
你
乎
那会儿
呗
甚至于
甫
重要
反倒
仍
组成
怪不得
更进一步
^
如次
倘使
[④]
似的
云尔
后者
用来
认为
规定
好在
别
从
倘或
几
.
当下
要不是
个
⑦
结果
′|
它们
&
着
再其次
谁料
许多
1
啐
我的
当头
在下
串行
大不了
常言说
缕缕
依照
刚好
不料
不必
=-
屡
不单
问题
立
按期
呜
不
巴
Δ
自后
替代
>λ
举凡
所以
咳
他人
且说
惟其
处在
传闻
!
不迭
继续
,
定
就地
连连
伙同
这次
何况
[②d]
全力
7
′∈
下来
极大
[①E]
那时
各种
当即
边
略加
周围
那么样
[①①]
连日来
匆匆
以上
很多


================================================
FILE: nrekit/data_loader.py
================================================
from six import iteritems

import json
import os
import multiprocessing
import numpy as np
import random

class file_data_loader:
    def __next__(self):
        raise NotImplementedError
    
    def next(self):
        return self.__next__()

    def next_batch(self, batch_size):
        raise NotImplementedError

class npy_data_loader(file_data_loader):
    MODE_INSTANCE = 0      # One batch contains batch_size instances.
    MODE_ENTPAIR_BAG = 1   # One batch contains batch_size bags, whose instances share the same entity pair (usually for testing).
    MODE_RELFACT_BAG = 2   # One batch contains batch_size bags, whose instances share the same relation fact (usually for training).

    def __iter__(self):
        return self

    def __init__(self, data_dir, prefix, mode, word_vec_npy='vec.npy', shuffle=True, max_length=120, batch_size=160):
        if not os.path.isdir(data_dir):
            raise Exception("[ERROR] Data dir doesn't exist!")
        self.mode = mode
        self.shuffle = shuffle
        self.max_length = max_length
        self.batch_size = batch_size
        self.word_vec_mat = np.load(os.path.join(data_dir, word_vec_npy))
        self.data_word = np.load(os.path.join(data_dir, prefix + "_word.npy")) 
        self.data_pos1 = np.load(os.path.join(data_dir, prefix + "_pos1.npy")) 
        self.data_pos2 = np.load(os.path.join(data_dir, prefix + "_pos2.npy")) 
        self.data_mask = np.load(os.path.join(data_dir, prefix + "_mask.npy")) 
        self.data_rel = np.load(os.path.join(data_dir, prefix + "_label.npy")) 
        self.data_length = np.load(os.path.join(data_dir, prefix + "_len.npy")) 
        self.scope = np.load(os.path.join(data_dir, prefix + "_instance_scope.npy"))
        self.triple = np.load(os.path.join(data_dir, prefix + "_instance_triple.npy"))
        self.relfact_tot = len(self.triple)
        for i in range(self.scope.shape[0]):
            self.scope[i][1] += 1

        self.instance_tot = self.data_word.shape[0]
        self.rel_tot = 53

        if self.mode == self.MODE_INSTANCE:
            self.order = list(range(self.instance_tot))
        else:
            self.order = list(range(len(self.scope)))
        self.idx = 0

        if self.shuffle:
            random.shuffle(self.order) 

        print("Total relation facts: %d" % (self.relfact_tot))

    def __next__(self):
        return self.next_batch(self.batch_size)

    def next_batch(self, batch_size):
        if self.idx >= len(self.order):
            self.idx = 0
            if self.shuffle:
                random.shuffle(self.order) 
            raise StopIteration

        batch_data = {}

        if self.mode == self.MODE_INSTANCE:
            idx0 = self.idx
            idx1 = self.idx + batch_size
            if idx1 > len(self.order):
                self.idx = 0
                if self.shuffle:
                    random.shuffle(self.order) 
                raise StopIteration
            self.idx = idx1
            batch_data['word'] = self.data_word[idx0:idx1]
            batch_data['pos1'] = self.data_pos1[idx0:idx1]
            batch_data['pos2'] = self.data_pos2[idx0:idx1]
            batch_data['rel'] = self.data_rel[idx0:idx1]
            batch_data['length'] = self.data_length[idx0:idx1]
            batch_data['scope'] = np.stack([list(range(idx1 - idx0)), list(range(1, idx1 - idx0 + 1))], axis=1)
        elif self.mode == self.MODE_ENTPAIR_BAG or self.mode == self.MODE_RELFACT_BAG:
            idx0 = self.idx
            idx1 = self.idx + batch_size
            if idx1 > len(self.order):
                self.idx = 0
                if self.shuffle:
                    random.shuffle(self.order) 
                raise StopIteration
            self.idx = idx1
            _word = []
            _pos1 = []
            _pos2 = []
            _rel = []
            _ins_rel = []
            _multi_rel = []
            _length = []
            _scope = []
            _mask = []
            cur_pos = 0
            for i in range(idx0, idx1):
                _word.append(self.data_word[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _pos1.append(self.data_pos1[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _pos2.append(self.data_pos2[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _rel.append(self.data_rel[self.scope[self.order[i]][0]])
                _ins_rel.append(self.data_rel[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _length.append(self.data_length[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _mask.append(self.data_mask[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                bag_size = self.scope[self.order[i]][1] - self.scope[self.order[i]][0]
                _scope.append([cur_pos, cur_pos + bag_size])
                cur_pos = cur_pos + bag_size
                if self.mode == self.MODE_ENTPAIR_BAG:
                    _one_multi_rel = np.zeros((self.rel_tot), dtype=np.int32)
                    for j in range(self.scope[self.order[i]][0], self.scope[self.order[i]][1]):
                        _one_multi_rel[self.data_rel[j]] = 1
                    _multi_rel.append(_one_multi_rel)
            batch_data['word'] = np.concatenate(_word)
            batch_data['pos1'] = np.concatenate(_pos1)
            batch_data['pos2'] = np.concatenate(_pos2)
            batch_data['rel'] = np.stack(_rel)
            batch_data['ins_rel'] = np.concatenate(_ins_rel)
            if self.mode == self.MODE_ENTPAIR_BAG:
                batch_data['multi_rel'] = np.stack(_multi_rel)
            batch_data['length'] = np.concatenate(_length)
            batch_data['scope'] = np.stack(_scope)
            batch_data['mask'] = np.concatenate(_mask)

        return batch_data

class json_file_data_loader(file_data_loader):
    MODE_INSTANCE = 0      # One batch contains batch_size instances.
    MODE_ENTPAIR_BAG = 1   # One batch contains batch_size bags, whose instances share the same entity pair (usually for testing).
    MODE_RELFACT_BAG = 2   # One batch contains batch_size bags, whose instances share the same relation fact (usually for training).

    def _load_preprocessed_file(self):
        name_prefix = '.'.join(self.file_name.split('/')[-1].split('.')[:-1])
        word_vec_name_prefix = '.'.join(self.word_vec_file_name.split('/')[-1].split('.')[:-1])
        processed_data_dir = '_processed_data'
        if not os.path.isdir(processed_data_dir):
            return False
        word_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_word.npy')
        pos1_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pos1.npy')
        pos2_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pos2.npy')
        rel_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_rel.npy')
        mask_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_mask.npy')
        length_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_length.npy')
        entpair2scope_file_name = os.path.join(processed_data_dir, name_prefix + '_entpair2scope.json')
        relfact2scope_file_name = os.path.join(processed_data_dir, name_prefix + '_relfact2scope.json')
        word_vec_mat_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_mat.npy')
        word2id_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_word2id.json')
        if not os.path.exists(word_npy_file_name) or \
           not os.path.exists(pos1_npy_file_name) or \
           not os.path.exists(pos2_npy_file_name) or \
           not os.path.exists(rel_npy_file_name) or \
           not os.path.exists(mask_npy_file_name) or \
           not os.path.exists(length_npy_file_name) or \
           not os.path.exists(entpair2scope_file_name) or \
           not os.path.exists(relfact2scope_file_name) or \
           not os.path.exists(word_vec_mat_file_name) or \
           not os.path.exists(word2id_file_name):
            return False
        print("Pre-processed files exist. Loading them...")
        self.data_word = np.load(word_npy_file_name)
        self.data_pos1 = np.load(pos1_npy_file_name)
        self.data_pos2 = np.load(pos2_npy_file_name)
        self.data_rel = np.load(rel_npy_file_name)
        self.data_mask = np.load(mask_npy_file_name)
        self.data_length = np.load(length_npy_file_name)
        self.entpair2scope = json.load(open(entpair2scope_file_name))
        self.relfact2scope = json.load(open(relfact2scope_file_name))
        self.word_vec_mat = np.load(word_vec_mat_file_name)
        self.word2id = json.load(open(word2id_file_name))
        if self.data_word.shape[1] != self.max_length:
            print("Pre-processed files don't match current settings. Reprocessing...")
            return False
        print("Finish loading")
        return True

    def __init__(self, file_name, word_vec_file_name, rel2id_file_name, mode, shuffle=True, max_length=120, case_sensitive=False, reprocess=False, batch_size=160):
        '''
        file_name: Json file storing the data in the following format
            [
                {
                    'sentence': 'Bill Gates is the founder of Microsoft .',
                    'head': {'word': 'Bill Gates', ...(other information)},
                    'tail': {'word': 'Microsoft', ...(other information)},
                    'relation': 'founder'
                },
                ...
            ]
        word_vec_file_name: Json file storing word vectors in the following format
            [
                {'word': 'the', 'vec': [0.418, 0.24968, ...]},
                {'word': ',', 'vec': [0.013441, 0.23682, ...]},
                ...
            ]
        rel2id_file_name: Json file storing the relation-to-id dictionary in the following format
            {
                'NA': 0,
                'founder': 1,
                ...
            }
            **IMPORTANT**: make sure the id of NA is 0!
        mode: Specify how to get a batch of data. See MODE_* constants for details.
        shuffle: Whether to shuffle the data, default as True. You should shuffle when training.
        max_length: The length to which all sentences are padded or truncated, default as 120.
        case_sensitive: Whether the data processing is case-sensitive, default as False.
        reprocess: Redo the pre-processing even if pre-processed files exist, default as False.
        batch_size: The size of each batch, default as 160.
        '''
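        # A minimal usage sketch (hypothetical file paths, not part of the original repo):
        #
        #     loader = json_file_data_loader('./data/train.json', './data/word_vec.json',
        #                                    './data/rel2id.json',
        #                                    mode=json_file_data_loader.MODE_RELFACT_BAG,
        #                                    shuffle=True)
        #     for batch in loader:
        #         # each batch is a dict of numpy arrays:
        #         # batch['word'], batch['pos1'], batch['pos2'], batch['mask'],
        #         # batch['rel'], batch['ins_rel'], batch['length'], batch['scope']
        #         ...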

        self.file_name = file_name
        self.word_vec_file_name = word_vec_file_name
        self.case_sensitive = case_sensitive
        self.max_length = max_length
        self.mode = mode
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.rel2id = json.load(open(rel2id_file_name))

        if reprocess or not self._load_preprocessed_file(): # Try to load pre-processed files:
            # Check files
            if file_name is None or not os.path.isfile(file_name):
                raise Exception("[ERROR] Data file doesn't exist")
            if word_vec_file_name is None or not os.path.isfile(word_vec_file_name):
                raise Exception("[ERROR] Word vector file doesn't exist")

            # Load files
            print("Loading data file...")
            self.ori_data = json.load(open(self.file_name, "r"))
            print("Finish loading")
            print("Loading word vector file...")
            self.ori_word_vec = json.load(open(self.word_vec_file_name, "r"))
            print("Finish loading")
            
            # Eliminate case sensitivity
            if not case_sensitive:
                print("Eliminating case sensitivity...")
                for i in range(len(self.ori_data)):
                    self.ori_data[i]['sentence'] = self.ori_data[i]['sentence'].lower()
                    self.ori_data[i]['head']['word'] = self.ori_data[i]['head']['word'].lower()
                    self.ori_data[i]['tail']['word'] = self.ori_data[i]['tail']['word'].lower()
                print("Finish eliminating")

            # Sort data by entities and relations
            print("Sort data...")
            self.ori_data.sort(key=lambda a: a['head']['id'] + '#' + a['tail']['id'] + '#' + a['relation'])
            print("Finish sorting")
       
            # Pre-process word vec
            self.word2id = {}
            self.word_vec_tot = len(self.ori_word_vec)
            UNK = self.word_vec_tot
            BLANK = self.word_vec_tot + 1
            self.word_vec_dim = len(self.ori_word_vec[0]['vec'])
            print("Got {} words of {} dims".format(self.word_vec_tot, self.word_vec_dim))
            print("Building word vector matrix and mapping...")
            self.word_vec_mat = np.zeros((self.word_vec_tot, self.word_vec_dim), dtype=np.float32)
            for cur_id, word in enumerate(self.ori_word_vec):
                w = word['word']
                if not case_sensitive:
                    w = w.lower()
                self.word2id[w] = cur_id
                self.word_vec_mat[cur_id, :] = word['vec']
            self.word2id['UNK'] = UNK
            self.word2id['BLANK'] = BLANK
            print("Finish building")

            # Pre-process data
            print("Pre-processing data...")
            self.instance_tot = len(self.ori_data)
            self.entpair2scope = {} # (head, tail) -> scope
            self.relfact2scope = {} # (head, tail, relation) -> scope
            self.data_word = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
            self.data_pos1 = np.zeros((self.instance_tot, self.max_length), dtype=np.int32) 
            self.data_pos2 = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
            self.data_rel = np.zeros((self.instance_tot), dtype=np.int32)
            self.data_mask = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
            self.data_length = np.zeros((self.instance_tot), dtype=np.int32)
            last_entpair = ''
            last_entpair_pos = -1
            last_relfact = ''
            last_relfact_pos = -1
            for i in range(self.instance_tot):
                ins = self.ori_data[i]
                if ins['relation'] in self.rel2id:
                    self.data_rel[i] = self.rel2id[ins['relation']]
                else:
                    self.data_rel[i] = self.rel2id['NA']
                sentence = ' '.join(ins['sentence'].split()) # delete extra spaces
                head = ins['head']['word']
                tail = ins['tail']['word']
                cur_entpair = ins['head']['id'] + '#' + ins['tail']['id']
                cur_relfact = ins['head']['id'] + '#' + ins['tail']['id'] + '#' + ins['relation']
                if cur_entpair != last_entpair:
                    if last_entpair != '':
                        self.entpair2scope[last_entpair] = [last_entpair_pos, i] # left closed right open
                    last_entpair = cur_entpair
                    last_entpair_pos = i
                if cur_relfact != last_relfact:
                    if last_relfact != '':
                        self.relfact2scope[last_relfact] = [last_relfact_pos, i]
                    last_relfact = cur_relfact
                    last_relfact_pos = i
                p1 = sentence.find(' ' + head + ' ')
                p2 = sentence.find(' ' + tail + ' ')
                if p1 == -1:
                    if sentence[:len(head) + 1] == head + " ":
                        p1 = 0
                    elif sentence[-len(head) - 1:] == " " + head:
                        p1 = len(sentence) - len(head)
                    else:
                        p1 = 0 # shouldn't happen
                else:
                    p1 += 1
                if p2 == -1:
                    if sentence[:len(tail) + 1] == tail + " ":
                        p2 = 0
                    elif sentence[-len(tail) - 1:] == " " + tail:
                        p2 = len(sentence) - len(tail)
                    else:
                        p2 = 0 # shouldn't happen
                else:
                    p2 += 1
                # if p1 == -1 or p2 == -1:
                #     raise Exception("[ERROR] Sentence doesn't contain the entity, index = {}, sentence = {}, head = {}, tail = {}".format(i, sentence, head, tail))

                words = sentence.split()
                cur_ref_data_word = self.data_word[i]         
                cur_pos = 0
                pos1 = -1
                pos2 = -1
                for j, word in enumerate(words):
                    if j < max_length:
                        if word in self.word2id:
                            cur_ref_data_word[j] = self.word2id[word]
                        else:
                            cur_ref_data_word[j] = UNK
                    if cur_pos == p1:
                        pos1 = j
                        p1 = -1
                    if cur_pos == p2:
                        pos2 = j
                        p2 = -1
                    cur_pos += len(word) + 1
                for j in range(j + 1, max_length):
                    cur_ref_data_word[j] = BLANK
                self.data_length[i] = len(words)
                if len(words) > max_length:
                    self.data_length[i] = max_length
                if pos1 == -1 or pos2 == -1:
                    raise Exception("[ERROR] Position error, index = {}, sentence = {}, head = {}, tail = {}".format(i, sentence, head, tail))
                if pos1 >= max_length:
                    pos1 = max_length - 1
                if pos2 >= max_length:
                    pos2 = max_length - 1
                pos_min = min(pos1, pos2)
                pos_max = max(pos1, pos2)
                for j in range(max_length):
                    self.data_pos1[i][j] = j - pos1 + max_length
                    self.data_pos2[i][j] = j - pos2 + max_length
                    if j >= self.data_length[i]:
                        self.data_mask[i][j] = 0
                    elif j <= pos_min:
                        self.data_mask[i][j] = 1
                    elif j <= pos_max:
                        self.data_mask[i][j] = 2
                    else:
                        self.data_mask[i][j] = 3
                    
            if last_entpair != '':
                self.entpair2scope[last_entpair] = [last_entpair_pos, self.instance_tot] # left closed right open
            if last_relfact != '':
                self.relfact2scope[last_relfact] = [last_relfact_pos, self.instance_tot]

            print("Finish pre-processing")     

            print("Storing processed files...")
            name_prefix = '.'.join(file_name.split('/')[-1].split('.')[:-1])
            word_vec_name_prefix = '.'.join(word_vec_file_name.split('/')[-1].split('.')[:-1])
            processed_data_dir = '_processed_data'
            if not os.path.isdir(processed_data_dir):
                os.mkdir(processed_data_dir)
            np.save(os.path.join(processed_data_dir, name_prefix + '_word.npy'), self.data_word)
            np.save(os.path.join(processed_data_dir, name_prefix + '_pos1.npy'), self.data_pos1)
            np.save(os.path.join(processed_data_dir, name_prefix + '_pos2.npy'), self.data_pos2)
            np.save(os.path.join(processed_data_dir, name_prefix + '_rel.npy'), self.data_rel)
            np.save(os.path.join(processed_data_dir, name_prefix + '_mask.npy'), self.data_mask)
            np.save(os.path.join(processed_data_dir, name_prefix + '_length.npy'), self.data_length)
            json.dump(self.entpair2scope, open(os.path.join(processed_data_dir, name_prefix + '_entpair2scope.json'), 'w'))
            json.dump(self.relfact2scope, open(os.path.join(processed_data_dir, name_prefix + '_relfact2scope.json'), 'w'))
            np.save(os.path.join(processed_data_dir, word_vec_name_prefix + '_mat.npy'), self.word_vec_mat)
            json.dump(self.word2id, open(os.path.join(processed_data_dir, word_vec_name_prefix + '_word2id.json'), 'w'))
            print("Finish storing")

        # Prepare for idx
        self.instance_tot = self.data_word.shape[0]
        self.entpair_tot = len(self.entpair2scope)
        self.relfact_tot = 0 # The number of relation facts, without NA.
        for key in self.relfact2scope:
            if key[-2:] != 'NA':
                self.relfact_tot += 1
        self.rel_tot = len(self.rel2id)

        if self.mode == self.MODE_INSTANCE:
            self.order = list(range(self.instance_tot))
        elif self.mode == self.MODE_ENTPAIR_BAG:
            self.order = list(range(len(self.entpair2scope)))
            self.scope_name = []
            self.scope = []
            for key, value in iteritems(self.entpair2scope):
                self.scope_name.append(key)
                self.scope.append(value)
        elif self.mode == self.MODE_RELFACT_BAG:
            self.order = list(range(len(self.relfact2scope)))
            self.scope_name = []
            self.scope = []
            for key, value in iteritems(self.relfact2scope):
                self.scope_name.append(key)
                self.scope.append(value)
        else:
            raise Exception("[ERROR] Invalid mode")
        self.idx = 0

        if self.shuffle:
            random.shuffle(self.order) 

        print("Total relation facts: %d" % (self.relfact_tot))

    def __iter__(self):
        return self

    def __next__(self):
        return self.next_batch(self.batch_size)

    def next_batch(self, batch_size):
        if self.idx >= len(self.order):
            self.idx = 0
            if self.shuffle:
                random.shuffle(self.order) 
            raise StopIteration

        batch_data = {}

        if self.mode == self.MODE_INSTANCE:
            idx0 = self.idx
            idx1 = self.idx + batch_size
            if idx1 > len(self.order):
                idx1 = len(self.order)
            self.idx = idx1
            batch_data['word'] = self.data_word[idx0:idx1]
            batch_data['pos1'] = self.data_pos1[idx0:idx1]
            batch_data['pos2'] = self.data_pos2[idx0:idx1]
            batch_data['rel'] = self.data_rel[idx0:idx1]
            batch_data['mask'] = self.data_mask[idx0:idx1]
            batch_data['length'] = self.data_length[idx0:idx1]
            batch_data['scope'] = np.stack([list(range(batch_size)), list(range(1, batch_size + 1))], axis=1)
            if idx1 - idx0 < batch_size:
                padding = batch_size - (idx1 - idx0)
                batch_data['word'] = np.concatenate([batch_data['word'], np.zeros((padding, self.data_word.shape[-1]), dtype=np.int32)])
                batch_data['pos1'] = np.concatenate([batch_data['pos1'], np.zeros((padding, self.data_pos1.shape[-1]), dtype=np.int32)])
                batch_data['pos2'] = np.concatenate([batch_data['pos2'], np.zeros((padding, self.data_pos2.shape[-1]), dtype=np.int32)])
                batch_data['mask'] = np.concatenate([batch_data['mask'], np.zeros((padding, self.data_mask.shape[-1]), dtype=np.int32)])
                batch_data['rel'] = np.concatenate([batch_data['rel'], np.zeros((padding), dtype=np.int32)])
                batch_data['length'] = np.concatenate([batch_data['length'], np.zeros((padding), dtype=np.int32)])
        elif self.mode == self.MODE_ENTPAIR_BAG or self.mode == self.MODE_RELFACT_BAG:
            idx0 = self.idx
            idx1 = self.idx + batch_size
            if idx1 > len(self.order):
                idx1 = len(self.order)
            self.idx = idx1
            _word = []
            _pos1 = []
            _pos2 = []
            _mask = []
            _rel = []
            _ins_rel = []
            _multi_rel = []
            _entpair = []
            _length = []
            _scope = []
            cur_pos = 0
            for i in range(idx0, idx1):
                _word.append(self.data_word[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _pos1.append(self.data_pos1[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _pos2.append(self.data_pos2[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _mask.append(self.data_mask[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _rel.append(self.data_rel[self.scope[self.order[i]][0]])
                _ins_rel.append(self.data_rel[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _length.append(self.data_length[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                bag_size = self.scope[self.order[i]][1] - self.scope[self.order[i]][0]
                _scope.append([cur_pos, cur_pos + bag_size])
                cur_pos = cur_pos + bag_size
                if self.mode == self.MODE_ENTPAIR_BAG:
                    _one_multi_rel = np.zeros((self.rel_tot), dtype=np.int32)
                    for j in range(self.scope[self.order[i]][0], self.scope[self.order[i]][1]):
                        _one_multi_rel[self.data_rel[j]] = 1
                    _multi_rel.append(_one_multi_rel)
                    _entpair.append(self.scope_name[self.order[i]])
            for i in range(batch_size - (idx1 - idx0)):
                _word.append(np.zeros((1, self.data_word.shape[-1]), dtype=np.int32))
                _pos1.append(np.zeros((1, self.data_pos1.shape[-1]), dtype=np.int32))
                _pos2.append(np.zeros((1, self.data_pos2.shape[-1]), dtype=np.int32))
                _mask.append(np.zeros((1, self.data_mask.shape[-1]), dtype=np.int32))
                _rel.append(0)
                _ins_rel.append(np.zeros((1), dtype=np.int32))
                _length.append(np.zeros((1), dtype=np.int32))
                _scope.append([cur_pos, cur_pos + 1])
                cur_pos += 1
                if self.mode == self.MODE_ENTPAIR_BAG:
                    _multi_rel.append(np.zeros((self.rel_tot), dtype=np.int32))
                    _entpair.append('None#None')
            batch_data['word'] = np.concatenate(_word)
            batch_data['pos1'] = np.concatenate(_pos1)
            batch_data['pos2'] = np.concatenate(_pos2)
            batch_data['mask'] = np.concatenate(_mask)
            batch_data['rel'] = np.stack(_rel)
            batch_data['ins_rel'] = np.concatenate(_ins_rel)
            if self.mode == self.MODE_ENTPAIR_BAG:
                batch_data['multi_rel'] = np.stack(_multi_rel)
                batch_data['entpair'] = _entpair
            batch_data['length'] = np.concatenate(_length)
            batch_data['scope'] = np.stack(_scope)

        return batch_data
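In bag mode, `next_batch` flattens every selected bag's sentences into one array and records each bag's `[start, end)` span in `scope`. A minimal standalone sketch of that bookkeeping (toy bag sizes, not from the repo):

```python
import numpy as np

# Toy illustration of how `scope` indexes bags in a flattened batch:
# bags of size 2, 1 and 3 occupy rows 0-1, 2 and 3-5 respectively.
bag_sizes = [2, 1, 3]
scope, cur = [], 0
for n in bag_sizes:
    scope.append([cur, cur + n])
    cur += n
scope = np.stack(scope)
# scope -> [[0, 2], [2, 3], [3, 6]]
```

A selector can then slice `x[scope[i][0]:scope[i][1]]` to recover bag `i` from the concatenated sentence matrix.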


================================================
FILE: nrekit/framework.py
================================================
import tensorflow as tf
import os
import sklearn.metrics
import numpy as np
import sys
import time

def average_gradients(tower_grads):
    """Calculate the average gradient for each shared variable across all towers.

    Note that this function provides a synchronization point across all towers.

    Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer list
            is over individual gradients. The inner list is over the gradient
            calculation for each tower.
    Returns:
         List of pairs of (gradient, variable) where the gradient has been averaged
         across all towers.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # Note that each grad_and_vars looks like the following:
        #     ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
        grads = []
        for g, _ in grad_and_vars:
            # Add 0 dimension to the gradients to represent the tower.
            expanded_g = tf.expand_dims(g, 0)

            # Append on a 'tower' dimension which we will average over below.
            grads.append(expanded_g)

        # Average over the 'tower' dimension.
        grad = tf.concat(axis=0, values=grads)
        grad = tf.reduce_mean(grad, 0)

        # Keep in mind that the Variables are redundant because they are shared
        # across towers. So .. we will just return the first tower's pointer to
        # the Variable.
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)
    return average_grads

class re_model:
    def __init__(self, train_data_loader, batch_size, max_length=120):
        self.word = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name='word')
        self.pos1 = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name='pos1')
        self.pos2 = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name='pos2')
        self.label = tf.placeholder(dtype=tf.int32, shape=[batch_size], name='label')
        self.ins_label = tf.placeholder(dtype=tf.int32, shape=[None], name='ins_label')
        self.length = tf.placeholder(dtype=tf.int32, shape=[None], name='length')
        self.scope = tf.placeholder(dtype=tf.int32, shape=[batch_size, 2], name='scope')
        self.train_data_loader = train_data_loader
        self.rel_tot = train_data_loader.rel_tot
        self.word_vec_mat = train_data_loader.word_vec_mat

    def loss(self):
        raise NotImplementedError
    
    def train_logit(self):
        raise NotImplementedError
    
    def test_logit(self):
        raise NotImplementedError

class re_framework:
    MODE_BAG = 0 # Train and test the model at bag level.
    MODE_INS = 1 # Train and test the model at instance level

    def __init__(self, train_data_loader, test_data_loader, max_length=120, batch_size=160):
        self.train_data_loader = train_data_loader
        self.test_data_loader = test_data_loader
        self.sess = None

    def one_step_multi_models(self, sess, models, batch_data_gen, run_array, return_label=True):
        feed_dict = {}
        batch_label = []
        for model in models:
            batch_data = batch_data_gen.next_batch(batch_data_gen.batch_size // len(models))
            feed_dict.update({
                model.word: batch_data['word'],
                model.pos1: batch_data['pos1'],
                model.pos2: batch_data['pos2'],
                model.label: batch_data['rel'],
                model.ins_label: batch_data['ins_rel'],
                model.scope: batch_data['scope'],
                model.length: batch_data['length'],
            })
            if 'mask' in batch_data and hasattr(model, "mask"):
                feed_dict.update({model.mask: batch_data['mask']})
            batch_label.append(batch_data['rel'])
        result = sess.run(run_array, feed_dict)
        batch_label = np.concatenate(batch_label)
        if return_label:
            result += [batch_label]
        return result

    def one_step(self, sess, model, batch_data, run_array):
        feed_dict = {
            model.word: batch_data['word'],
            model.pos1: batch_data['pos1'],
            model.pos2: batch_data['pos2'],
            model.label: batch_data['rel'],
            model.ins_label: batch_data['ins_rel'],
            model.scope: batch_data['scope'],
            model.length: batch_data['length'],
        }
        if 'mask' in batch_data and hasattr(model, "mask"):
            feed_dict.update({model.mask: batch_data['mask']})
        result = sess.run(run_array, feed_dict)
        return result

    def train(self,
              model,
              model_name,
              ckpt_dir='./checkpoint',
              summary_dir='./summary',
              test_result_dir='./test_result',
              learning_rate=0.5,
              max_epoch=60,
              pretrain_model=None,
              test_epoch=1,
              optimizer=tf.train.GradientDescentOptimizer,
              gpu_nums=1):
        
        assert(self.train_data_loader.batch_size % gpu_nums == 0)
        print("Start training...")
        
        # Init
        config = tf.ConfigProto(allow_soft_placement=True)
        self.sess = tf.Session(config=config)
        optimizer = optimizer(learning_rate)
        
        # Multi GPUs
        tower_grads = []
        tower_models = []
        for gpu_id in range(gpu_nums):
            with tf.device("/gpu:%d" % gpu_id):
                with tf.name_scope("gpu_%d" % gpu_id):
                    cur_model = model(self.train_data_loader, self.train_data_loader.batch_size // gpu_nums, self.train_data_loader.max_length)
                    tower_grads.append(optimizer.compute_gradients(cur_model.loss()))
                    tower_models.append(cur_model)
                    tf.add_to_collection("loss", cur_model.loss())
                    tf.add_to_collection("train_logit", cur_model.train_logit())

        loss_collection = tf.get_collection("loss")
        loss = tf.add_n(loss_collection) / len(loss_collection)
        train_logit_collection = tf.get_collection("train_logit")
        train_logit = tf.concat(train_logit_collection, 0)

        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)

        # Saver
        saver = tf.train.Saver(max_to_keep=None)
        if pretrain_model is None:
            self.sess.run(tf.global_variables_initializer())
        else:
            saver.restore(self.sess, pretrain_model)

        # Training
        best_metric = 0
        best_prec = None
        best_recall = None
        not_best_count = 0 # Stop training after several epochs without improvement.
        for epoch in range(max_epoch):
            print('###### Epoch ' + str(epoch) + ' ######')
            tot_correct = 0
            tot_not_na_correct = 0
            tot = 0
            tot_not_na = 0
            i = 0
            time_sum = 0
            while True:
                time_start = time.time()
                try:
                    iter_loss, iter_logit, _train_op, iter_label = self.one_step_multi_models(self.sess, tower_models, self.train_data_loader, [loss, train_logit, train_op])
                except StopIteration:
                    break
                time_end = time.time()
                t = time_end - time_start
                time_sum += t
                iter_output = iter_logit.argmax(-1)
                iter_correct = (iter_output == iter_label).sum()
                iter_not_na_correct = np.logical_and(iter_output == iter_label, iter_label != 0).sum()
                tot_correct += iter_correct
                tot_not_na_correct += iter_not_na_correct
                tot += iter_label.shape[0]
                tot_not_na += (iter_label != 0).sum()
                if tot_not_na > 0:
                    sys.stdout.write("epoch %d step %d time %.2f | loss: %f, not NA accuracy: %f, accuracy: %f\r" % (epoch, i, t, iter_loss, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))
                    sys.stdout.flush()
                i += 1
            print("\nAverage iteration time: %f" % (time_sum / i))

            if (epoch + 1) % test_epoch == 0:
                metric = self.test(model)
                if metric > best_metric:
                    best_metric = metric
                    best_prec = self.cur_prec
                    best_recall = self.cur_recall
                    print("Best model, storing...")
                    if not os.path.isdir(ckpt_dir):
                        os.mkdir(ckpt_dir)
                    path = saver.save(self.sess, os.path.join(ckpt_dir, model_name))
                    print("Finish storing")
                    not_best_count = 0
                else:
                    not_best_count += 1

            if not_best_count >= 20:
                break
        
        print("######")
        print("Finish training " + model_name)
        print("Best epoch auc = %f" % (best_metric))
        if best_prec is not None and best_recall is not None:
            if not os.path.isdir(test_result_dir):
                os.mkdir(test_result_dir)
            np.save(os.path.join(test_result_dir, model_name + "_x.npy"), best_recall)
            np.save(os.path.join(test_result_dir, model_name + "_y.npy"), best_prec)

    def test(self,
             model,
             ckpt=None,
             return_result=False,
             mode=MODE_BAG):
        if mode == re_framework.MODE_BAG:
            return self.__test_bag__(model, ckpt=ckpt, return_result=return_result)
        elif mode == re_framework.MODE_INS:
            raise NotImplementedError
        else:
            raise NotImplementedError
        
    def __test_bag__(self, model, ckpt=None, return_result=False):
        print("Testing...")
        if self.sess is None:
            self.sess = tf.Session()
        model = model(self.test_data_loader, self.test_data_loader.batch_size, self.test_data_loader.max_length)
        if not ckpt is None:
            saver = tf.train.Saver()
            saver.restore(self.sess, ckpt)
        tot_correct = 0
        tot_not_na_correct = 0
        tot = 0
        tot_not_na = 0
        entpair_tot = 0
        test_result = []
        pred_result = []
         
        for i, batch_data in enumerate(self.test_data_loader):
            iter_logit = self.one_step(self.sess, model, batch_data, [model.test_logit()])[0]
            iter_output = iter_logit.argmax(-1)
            iter_correct = (iter_output == batch_data['rel']).sum()
            iter_not_na_correct = np.logical_and(iter_output == batch_data['rel'], batch_data['rel'] != 0).sum()
            tot_correct += iter_correct
            tot_not_na_correct += iter_not_na_correct
            tot += batch_data['rel'].shape[0]
            tot_not_na += (batch_data['rel'] != 0).sum()
            if tot_not_na > 0:
                sys.stdout.write("[TEST] step %d | not NA accuracy: %f, accuracy: %f\r" % (i, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))
                sys.stdout.flush()
            for idx in range(len(iter_logit)):
                for rel in range(1, self.test_data_loader.rel_tot):
                    test_result.append({'score': iter_logit[idx][rel], 'flag': batch_data['multi_rel'][idx][rel]})
                    if batch_data['entpair'][idx] != "None#None":
                        pred_result.append({'score': float(iter_logit[idx][rel]), 'entpair': batch_data['entpair'][idx].encode('utf-8'), 'relation': rel})
                entpair_tot += 1 
        sorted_test_result = sorted(test_result, key=lambda x: x['score'])
        prec = []
        recall = [] 
        correct = 0
        for i, item in enumerate(sorted_test_result[::-1]):
            correct += item['flag']
            prec.append(float(correct) / (i + 1))
            recall.append(float(correct) / self.test_data_loader.relfact_tot)
        auc = sklearn.metrics.auc(x=recall, y=prec)
        print("\n[TEST] auc: {}".format(auc))
        print("Finish testing")
        self.cur_prec = prec
        self.cur_recall = recall

        if not return_result:
            return auc
        else:
            return (auc, pred_result)
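`__test_bag__` builds its precision-recall curve by sorting all (score, flag) pairs by score in descending order and accumulating hits; recall is normalized by the total number of gold relational facts. A standalone sketch of that construction with toy values (scores, flags and `relfact_tot` are made up, not from the repo):

```python
# Toy precision-recall construction mirroring __test_bag__:
# flag = 1 means the predicted (entity pair, relation) is a gold fact.
results = [{'score': 0.9, 'flag': 1},
           {'score': 0.7, 'flag': 0},
           {'score': 0.4, 'flag': 1}]
relfact_tot = 2  # total gold relational facts in the test set (assumed)

prec, recall, correct = [], [], 0
for i, item in enumerate(sorted(results, key=lambda r: r['score'], reverse=True)):
    correct += item['flag']
    prec.append(correct / (i + 1))        # precision among top-(i+1) predictions
    recall.append(correct / relfact_tot)  # fraction of gold facts recovered
# prec   -> [1.0, 0.5, 2/3]
# recall -> [0.5, 0.5, 1.0]
```

`sklearn.metrics.auc(x=recall, y=prec)` then integrates this curve; recall is non-decreasing by construction, as `auc` requires.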


================================================
FILE: nrekit/network/classifier.py
================================================
import tensorflow as tf
import numpy as np

def softmax_cross_entropy(x, label, rel_tot, weights_table=None, weights=1.0, var_scope=None):
    with tf.variable_scope(var_scope or "loss", reuse=tf.AUTO_REUSE):
        if weights_table is not None:
            weights = tf.nn.embedding_lookup(weights_table, label)
        label_onehot = tf.one_hot(indices=label, depth=rel_tot, dtype=tf.int32)
        loss = tf.losses.softmax_cross_entropy(onehot_labels=label_onehot, logits=x, weights=weights)
        tf.summary.scalar('loss', loss)
        return loss

def sigmoid_cross_entropy(x, label, rel_tot, weights_table=None, var_scope=None):
    with tf.variable_scope(var_scope or "loss", reuse=tf.AUTO_REUSE):
        if weights_table is None:
            weights = 1.0
        else:
            weights = tf.nn.embedding_lookup(weights_table, label)
        label_onehot = tf.one_hot(indices=label, depth=rel_tot, dtype=tf.int32)
        loss = tf.losses.sigmoid_cross_entropy(label_onehot, logits=x, weights=weights)
        tf.summary.scalar('loss', loss)
        return loss

# Soft-label
# Implemented as in the paper, but the reported results were not reproduced.
def soft_label_softmax_cross_entropy(x, label, rel_tot, weights=1.0, var_scope=None):
    with tf.name_scope(var_scope or "soft-label-loss"):
        label_onehot = tf.one_hot(indices=label, depth=rel_tot, dtype=tf.int32)
        # Boost the distantly-supervised label's score by 0.9 * max score,
        # then treat the resulting argmax as the (soft) training target.
        nscore = x + 0.9 * tf.reshape(tf.reduce_max(x, 1), [-1, 1]) * tf.cast(label_onehot, tf.float32)
        nlabel = tf.one_hot(indices=tf.reshape(tf.argmax(nscore, axis=1), [-1]), depth=rel_tot, dtype=tf.int32)
        loss = tf.losses.softmax_cross_entropy(onehot_labels=nlabel, logits=nscore, weights=weights)
        tf.summary.scalar('loss', loss)
        return loss

def output(x):
    return tf.argmax(x, axis=-1)
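The soft-label relabeling step can be traced in plain NumPy. This is a hedged sketch of the same arithmetic (`soft_relabel` and all values are hypothetical, not part of nrekit): the noisy label keeps its bag unless the model's own preference beats it even after the 0.9 boost.

```python
import numpy as np

# Sketch of the relabeling inside soft_label_softmax_cross_entropy:
# boost the distant-supervision label's score by 0.9 * max score,
# then take the argmax as the new training target.
def soft_relabel(scores, noisy_label, confidence=0.9):
    onehot = np.eye(scores.shape[-1])[noisy_label]
    nscore = scores + confidence * scores.max(axis=1, keepdims=True) * onehot
    return nscore.argmax(axis=1)

scores = np.array([[0.05, 0.90, 0.05],   # model strongly prefers class 1
                   [0.40, 0.30, 0.30]])  # model weakly prefers class 0
labels = np.array([2, 2])                # distant-supervision labels
# Bag 0: boost 0.9*0.90=0.81 lifts class 2 to 0.86 < 0.90 -> relabeled to 1.
# Bag 1: boost 0.9*0.40=0.36 lifts class 2 to 0.66 > 0.40 -> keeps label 2.
```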


================================================
FILE: nrekit/network/embedding.py
================================================
import tensorflow as tf
import numpy as np

def word_embedding(word, word_vec_mat, var_scope=None, word_embedding_dim=50, add_unk_and_blank=True):
    with tf.variable_scope(var_scope or 'word_embedding', reuse=tf.AUTO_REUSE):
        word_embedding = tf.get_variable('word_embedding', initializer=word_vec_mat, dtype=tf.float32)
        if add_unk_and_blank:
            word_embedding = tf.concat([word_embedding,
                                        tf.get_variable("unk_word_embedding", [1, word_embedding_dim], dtype=tf.float32,
                                            initializer=tf.contrib.layers.xavier_initializer()),
                                        tf.constant(np.zeros((1, word_embedding_dim), dtype=np.float32))], 0)
        x = tf.nn.embedding_lookup(word_embedding, word)
        return x

def pos_embedding(pos1, pos2, var_scope=None, pos_embedding_dim=5, max_length=120):
    with tf.variable_scope(var_scope or 'pos_embedding', reuse=tf.AUTO_REUSE):
        pos_tot = max_length * 2

        pos1_embedding = tf.get_variable('real_pos1_embedding', [pos_tot, pos_embedding_dim], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) 
        # pos1_embedding = tf.concat([tf.zeros((1, pos_embedding_dim), dtype=tf.float32), real_pos1_embedding], 0)
        pos2_embedding = tf.get_variable('real_pos2_embedding', [pos_tot, pos_embedding_dim], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) 
        # pos2_embedding = tf.concat([tf.zeros((1, pos_embedding_dim), dtype=tf.float32), real_pos2_embedding], 0)

        input_pos1 = tf.nn.embedding_lookup(pos1_embedding, pos1)
        input_pos2 = tf.nn.embedding_lookup(pos2_embedding, pos2)
        x = tf.concat([input_pos1, input_pos2], -1)
        return x

def word_position_embedding(word, word_vec_mat, pos1, pos2, var_scope=None, word_embedding_dim=50, pos_embedding_dim=5, max_length=120, add_unk_and_blank=True):
    w_embedding = word_embedding(word, word_vec_mat, var_scope=var_scope, word_embedding_dim=word_embedding_dim, add_unk_and_blank=add_unk_and_blank)
    p_embedding = pos_embedding(pos1, pos2, var_scope=var_scope, pos_embedding_dim=pos_embedding_dim, max_length=max_length)
    return tf.concat([w_embedding, p_embedding], -1)
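With the module defaults (50-d word vectors, 5-d position vectors, `max_length=120`), `word_position_embedding` maps each token to a 60-d input. A NumPy shape sketch of the three lookups and the final concat (random tables stand in for the learned variables):

```python
import numpy as np

# Shape sketch (NumPy, not TF) of word_position_embedding: each token
# gets a 50-d word vector plus two 5-d position vectors -> 60 dims.
vocab, max_length = 100, 120
word_vec = np.random.randn(vocab, 50).astype(np.float32)
pos1_vec = np.random.randn(max_length * 2, 5).astype(np.float32)  # pos_tot rows
pos2_vec = np.random.randn(max_length * 2, 5).astype(np.float32)

batch = 4
word = np.random.randint(0, vocab, (batch, max_length))           # token ids
pos1 = np.random.randint(0, max_length * 2, (batch, max_length))  # rel. positions
pos2 = np.random.randint(0, max_length * 2, (batch, max_length))

x = np.concatenate([word_vec[word], pos1_vec[pos1], pos2_vec[pos2]], -1)
# x.shape == (4, 120, 60)
```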


================================================
FILE: nrekit/network/encoder.py
================================================
import tensorflow as tf
import numpy as np
import math

def __dropout__(x, keep_prob=1.0):
    return tf.contrib.layers.dropout(x, keep_prob=keep_prob)

def __pooling__(x):
    return tf.reduce_max(x, axis=-2)

def __piecewise_pooling__(x, mask):
    mask_embedding = tf.constant([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
    mask = tf.nn.embedding_lookup(mask_embedding, mask)
    hidden_size = x.shape[-1]
    x = tf.reduce_max(tf.expand_dims(mask * 100, 2) + tf.expand_dims(x, 3), axis=1) - 100
    return tf.reshape(x, [-1, hidden_size * 3])

def __cnn_cell__(x, hidden_size=230, kernel_size=3, stride_size=1):
    x = tf.layers.conv1d(inputs=x, 
                         filters=hidden_size, 
                         kernel_size=kernel_size, 
                         strides=stride_size, 
                         padding='same', 
                         kernel_initializer=tf.contrib.layers.xavier_initializer())
    return x

def cnn(x, hidden_size=230, kernel_size=3, stride_size=1, activation=tf.nn.relu, var_scope=None, keep_prob=1.0):
    with tf.variable_scope(var_scope or "cnn", reuse=tf.AUTO_REUSE):
        max_length = x.shape[1]
        x = __cnn_cell__(x, hidden_size, kernel_size, stride_size)
        x = __pooling__(x)
        x = activation(x)
        x = __dropout__(x, keep_prob)
        return x

def pcnn(x, mask, hidden_size=230, kernel_size=3, stride_size=1, activation=tf.nn.relu, var_scope=None, keep_prob=1.0):
    with tf.variable_scope(var_scope or "pcnn", reuse=tf.AUTO_REUSE):
        max_length = x.shape[1]
        x = __cnn_cell__(x, hidden_size, kernel_size, stride_size)
        x = __piecewise_pooling__(x, mask)
        x = activation(x)
        x = __dropout__(x, keep_prob)
        return x

def __rnn_cell__(hidden_size, cell_name='lstm'):
    if isinstance(cell_name, list) or isinstance(cell_name, tuple):
        if len(cell_name) == 1:
            return __rnn_cell__(hidden_size, cell_name[0])
        cells = [__rnn_cell__(hidden_size, c) for c in cell_name]
        return tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
    if cell_name.lower() == 'lstm':
        return tf.contrib.rnn.BasicLSTMCell(hidden_size, state_is_tuple=True)
    elif cell_name.lower() == 'gru':
        return tf.contrib.rnn.GRUCell(hidden_size)
    raise NotImplementedError

def rnn(x, length, hidden_size=230, cell_name='lstm', var_scope=None, keep_prob=1.0):
    with tf.variable_scope(var_scope or "rnn", reuse=tf.AUTO_REUSE):
        x = __dropout__(x, keep_prob)
        cell = __rnn_cell__(hidden_size, cell_name)
        _, states = tf.nn.dynamic_rnn(cell, x, sequence_length=length, dtype=tf.float32, scope='dynamic-rnn')
        if isinstance(states, tuple):
            states = states[0]
        return states

def birnn(x, length, hidden_size=230, cell_name='lstm', var_scope=None, keep_prob=1.0):
    with tf.variable_scope(var_scope or "birnn", reuse=tf.AUTO_REUSE):
        x = __dropout__(x, keep_prob)
        fw_cell = __rnn_cell__(hidden_size, cell_name)
        bw_cell = __rnn_cell__(hidden_size, cell_name)
        _, states = tf.nn.bidirectional_dynamic_rnn(fw_cell, bw_cell, x, sequence_length=length, dtype=tf.float32, scope='dynamic-bi-rnn')
        fw_states, bw_states = states
        if isinstance(fw_states, tuple):
            fw_states = fw_states[0]
            bw_states = bw_states[0]
        return tf.concat([fw_states, bw_states], axis=1)
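The `+100`/`-100` trick in `__piecewise_pooling__` deserves a trace: adding 100 to a token's entry in exactly one of the three segment columns guarantees that, after the max over the length axis, each column's winner comes from its own segment, and the offset is then subtracted back out. A single-sentence NumPy re-derivation (no batch axis, hidden size 1; an assumed mirror of the TF code above):

```python
import numpy as np

# mask[i] in {1,2,3} marks which of the three piecewise segments token i
# belongs to (0 = padding). In-segment positions get +100 before the max,
# so out-of-segment values always lose; the +100 is removed afterwards.
mask_embedding = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]],
                          dtype=np.float32)
mask = np.array([1, 1, 2, 3, 0])              # one sentence, 5 tokens
m = mask_embedding[mask]                      # (5, 3) segment indicators
x = np.array([[1.], [4.], [2.], [3.], [9.]])  # (5, 1); the 9 is padding noise
pooled = (m[:, None, :] * 100 + x[:, :, None]).max(axis=0) - 100
# pooled.flatten() -> [4., 2., 3.]: per-segment maxima; padding is ignored
# even though its raw value (9) is the global maximum.
```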



================================================
FILE: nrekit/network/selector.py
================================================
import tensorflow as tf
import numpy as np

def __dropout__(x, keep_prob=1.0):
    return tf.contrib.layers.dropout(x, keep_prob=keep_prob)

def __logit__(x, rel_tot, var_scope=None):
    with tf.variable_scope(var_scope or 'logit', reuse=tf.AUTO_REUSE):
        relation_matrix = tf.get_variable('relation_matrix', shape=[rel_tot, x.shape[1]], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
        bias = tf.get_variable('bias', shape=[rel_tot], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
        logit = tf.matmul(x, tf.transpose(relation_matrix)) + bias
    return logit

def __attention_train_logit__(x, query, rel_tot, var_scope=None):
    with tf.variable_scope(var_scope or 'logit', reuse=tf.AUTO_REUSE):
        relation_matrix = tf.get_variable('relation_matrix', shape=[rel_tot, x.shape[1]], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
        bias = tf.get_variable('bias', shape=[rel_tot], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
    current_relation = tf.nn.embedding_lookup(relation_matrix, query)
    attention_logit = tf.reduce_sum(current_relation * x, -1) # sum[(n', hidden_size) \dot (n', hidden_size)] = (n)
    return attention_logit

def __attention_test_logit__(x, rel_tot, var_scope=None):
    with tf.variable_scope(var_scope or 'logit', reuse=tf.AUTO_REUSE):
        relation_matrix = tf.get_variable('relation_matrix', shape=[rel_tot, x.shape[1]], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
        bias = tf.get_variable('bias', shape=[rel_tot], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
    attention_logit = tf.matmul(x, tf.transpose(relation_matrix)) # (n', hidden_size) x (hidden_size, rel_tot) = (n', rel_tot)
    return attention_logit

def instance(x, rel_tot, var_scope=None, keep_prob=1.0):
    x = __dropout__(x, keep_prob)
    x = __logit__(x, rel_tot)
    return x

def bag_attention(x, scope, query, rel_tot, is_training, var_scope=None, dropout_before=False, keep_prob=1.0):
    with tf.variable_scope(var_scope or "attention", reuse=tf.AUTO_REUSE):
        if is_training: # training
            if dropout_before:
                x = __dropout__(x, keep_prob)
            bag_repre = []
            attention_logit = __attention_train_logit__(x, query, rel_tot)
            for i in range(scope.shape[0]):
                bag_hidden_mat = x[scope[i][0]:scope[i][1]]
                attention_score = tf.nn.softmax(attention_logit[scope[i][0]:scope[i][1]], -1)
                bag_repre.append(tf.squeeze(tf.matmul(tf.expand_dims(attention_score, 0), bag_hidden_mat))) # (1, n') x (n', hidden_size) = (1, hidden_size) -> (hidden_size)
            bag_repre = tf.stack(bag_repre)
            if not dropout_before:
                bag_repre = __dropout__(bag_repre, keep_prob)
            return __logit__(bag_repre, rel_tot), bag_repre
        else: # testing
            attention_logit = __attention_test_logit__(x, rel_tot) # (n, rel_tot)
            bag_repre = [] 
            bag_logit = []
            for i in range(scope.shape[0]):
                bag_hidden_mat = x[scope[i][0]:scope[i][1]]
                attention_score = tf.nn.softmax(tf.transpose(attention_logit[scope[i][0]:scope[i][1], :]), -1) # softmax of (rel_tot, n')
                bag_repre_for_each_rel = tf.matmul(attention_score, bag_hidden_mat) # (rel_tot, n') \dot (n', hidden_size) = (rel_tot, hidden_size)
                bag_logit_for_each_rel = __logit__(bag_repre_for_each_rel, rel_tot) # -> (rel_tot, rel_tot)
                bag_repre.append(bag_repre_for_each_rel)
                bag_logit.append(tf.diag_part(tf.nn.softmax(bag_logit_for_each_rel, -1))) # could be improved by sigmoid?
            bag_repre = tf.stack(bag_repre)
            bag_logit = tf.stack(bag_logit)
            return bag_logit, bag_repre

def bag_average(x, scope, rel_tot, var_scope=None, dropout_before=False, keep_prob=1.0):
    with tf.variable_scope(var_scope or "average", reuse=tf.AUTO_REUSE):
        if dropout_before:
            x = __dropout__(x, keep_prob)
        bag_repre = []
        for i in range(scope.shape[0]):
            bag_hidden_mat = x[scope[i][0]:scope[i][1]]
            bag_repre.append(tf.reduce_mean(bag_hidden_mat, 0)) # (n', hidden_size) -> (hidden_size)
        bag_repre = tf.stack(bag_repre)
        if not dropout_before:
            bag_repre = __dropout__(bag_repre, keep_prob)
    return __logit__(bag_repre, rel_tot), bag_repre

def bag_one(x, scope, query, rel_tot, is_training, var_scope=None, dropout_before=False, keep_prob=1.0): # could be improved?
    with tf.variable_scope(var_scope or "one", reuse=tf.AUTO_REUSE):
        if is_training: # training
            if dropout_before:
                x = __dropout__(x, keep_prob)
            bag_repre = []
            for i in range(scope.shape[0]):
                bag_hidden_mat = x[scope[i][0]:scope[i][1]]
                instance_logit = tf.nn.softmax(__logit__(bag_hidden_mat, rel_tot), -1) # (n', hidden_size) -> (n', rel_tot)
                j = tf.argmax(instance_logit[:, query[i]], output_type=tf.int32)
                bag_repre.append(bag_hidden_mat[j])
            bag_repre = tf.stack(bag_repre)
            if not dropout_before:
                bag_repre = __dropout__(bag_repre, keep_prob)
            return __logit__(bag_repre, rel_tot), bag_repre
        else: # testing
            if dropout_before:
                x = __dropout__(x, keep_prob)
            bag_repre = []
            bag_logit = []
            for i in range(scope.shape[0]):
                bag_hidden_mat = x[scope[i][0]:scope[i][1]]
                instance_logit = tf.nn.softmax(__logit__(bag_hidden_mat, rel_tot), -1) # (n', hidden_size) -> (n', rel_tot)
                bag_logit.append(tf.reduce_max(instance_logit, 0))
                bag_repre.append(bag_hidden_mat[0]) # fake max repre
            bag_logit = tf.stack(bag_logit)
            bag_repre = tf.stack(bag_repre)
            return bag_logit, bag_repre

def bag_cross_max(x, scope, rel_tot, var_scope=None, dropout_before=False, keep_prob=1.0):
    '''
    Cross-sentence Max-pooling proposed by (Jiang et al. 2016.)
    "Relation Extraction with Multi-instance Multi-label Convolutional Neural Networks"
    https://pdfs.semanticscholar.org/8731/369a707046f3f8dd463d1fd107de31d40a24.pdf
    '''
    with tf.variable_scope(var_scope or "cross_max", reuse=tf.AUTO_REUSE):
        if dropout_before:
            x = __dropout__(x, keep_prob)
        bag_repre = []
        for i in range(scope.shape[0]):
            bag_hidden_mat = x[scope[i][0]:scope[i][1]]
            bag_repre.append(tf.reduce_max(bag_hidden_mat, 0)) # (n', hidden_size) -> (hidden_size)
        bag_repre = tf.stack(bag_repre)
        if not dropout_before:
            bag_repre = __dropout__(bag_repre, keep_prob)
    return __logit__(bag_repre, rel_tot), bag_repre
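At training time, `bag_attention` scores every sentence vector against the embedding of the bag's labelled relation, softmax-normalizes the scores within the bag, and takes the weighted sum as the bag representation. A standalone NumPy sketch of that selective-attention step for a single bag (random vectors stand in for encoder outputs and `relation_matrix`; not part of nrekit):

```python
import numpy as np

# One bag, training branch: attention over sentences w.r.t. the label's
# relation vector, then a softmax-weighted combination.
def softmax(v):
    e = np.exp(v - v.max())  # shift by max for numerical stability
    return e / e.sum()

hidden = 4
rng = np.random.default_rng(0)
x = rng.standard_normal((3, hidden))   # a bag of 3 sentence vectors
rel_vec = rng.standard_normal(hidden)  # row of relation_matrix for the label
att = softmax(x @ rel_vec)             # per-sentence attention weights, sum to 1
bag_repre = att @ x                    # (hidden,) bag representation
```

At test time the label is unknown, so the TF code instead computes one such weighted representation per candidate relation and scores each against its own relation, which is why `tf.diag_part` appears there.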


================================================
FILE: nrekit/rl.py
================================================
import tensorflow as tf
import os
import sklearn.metrics
import numpy as np
import sys
import math
import time
import framework
import network

class policy_agent(framework.re_model):
    def __init__(self, train_data_loader, batch_size, max_length=120):
        framework.re_model.__init__(self, train_data_loader, batch_size, max_length)
        self.weights = tf.placeholder(tf.float32, shape=(), name="weights_scalar")

        x = network.embedding.word_position_embedding(self.word, self.word_vec_mat, self.pos1, self.pos2)
        x_train = network.encoder.cnn(x, keep_prob=0.5)
        x_test = network.encoder.cnn(x, keep_prob=1.0)
        self._train_logit = network.selector.instance(x_train, 2, keep_prob=0.5)
        self._test_logit = network.selector.instance(x_test, 2, keep_prob=1.0)
        self._loss = network.classifier.softmax_cross_entropy(self._train_logit, self.ins_label, 2, weights=self.weights)

    def loss(self):
        return self._loss

    def train_logit(self):
        return self._train_logit

    def test_logit(self):
        return self._test_logit

class rl_re_framework(framework.re_framework):
    def __init__(self, train_data_loader, test_data_loader, max_length=120, batch_size=160):
        framework.re_framework.__init__(self, train_data_loader, test_data_loader, max_length, batch_size)

    def agent_one_step(self, sess, agent_model, batch_data, run_array, weights=1):
        feed_dict = {
            agent_model.word: batch_data['word'],
            agent_model.pos1: batch_data['pos1'],
            agent_model.pos2: batch_data['pos2'],
            agent_model.ins_label: batch_data['agent_label'],
            agent_model.length: batch_data['length'],
            agent_model.weights: weights
        }
        if 'mask' in batch_data and hasattr(agent_model, "mask"):
            feed_dict.update({agent_model.mask: batch_data['mask']})
        result = sess.run(run_array, feed_dict)
        return result

    def pretrain_main_model(self, max_epoch):
        for epoch in range(max_epoch):
            print('###### Epoch ' + str(epoch) + ' ######')
            tot_correct = 0
            tot_not_na_correct = 0
            tot = 0
            tot_not_na = 0
            i = 0
            time_sum = 0
            
            for i, batch_data in enumerate(self.train_data_loader):
                time_start = time.time()
                iter_loss, iter_logit, _train_op = self.one_step(self.sess, self.model, batch_data, [self.model.loss(), self.model.train_logit(), self.train_op])
                time_end = time.time()
                t = time_end - time_start
                time_sum += t
                iter_output = iter_logit.argmax(-1)
                iter_label = batch_data['rel']
                iter_correct = (iter_output == iter_label).sum()
                iter_not_na_correct = np.logical_and(iter_output == iter_label, iter_label != 0).sum()
                tot_correct += iter_correct
                tot_not_na_correct += iter_not_na_correct
                tot += iter_label.shape[0]
                tot_not_na += (iter_label != 0).sum()
                if tot_not_na > 0:
                    sys.stdout.write("[pretrain main model] epoch %d step %d time %.2f | loss: %f, not NA accuracy: %f, accuracy: %f\r" % (epoch, i, t, iter_loss, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))
                    sys.stdout.flush()
            print("\nAverage iteration time: %f" % (time_sum / (i + 1)))

    def pretrain_agent_model(self, max_epoch):
        # Pre-train policy agent
        for epoch in range(max_epoch):
            print('###### [Pre-train Policy Agent] Epoch ' + str(epoch) + ' ######')
            tot_correct = 0
            tot_not_na_correct = 0
            tot = 0
            tot_not_na = 0
            time_sum = 0
            
            for i, batch_data in enumerate(self.train_data_loader):
                time_start = time.time()
                batch_data['agent_label'] = batch_data['ins_rel'] + 0
                batch_data['agent_label'][batch_data['agent_label'] > 0] = 1
                iter_loss, iter_logit, _train_op = self.agent_one_step(self.sess, self.agent_model, batch_data, [self.agent_model.loss(), self.agent_model.train_logit(), self.agent_train_op])
                time_end = time.time()
                t = time_end - time_start
                time_sum += t
                iter_output = iter_logit.argmax(-1)
                iter_label = batch_data['ins_rel']
                iter_correct = (iter_output == iter_label).sum()
                iter_not_na_correct = np.logical_and(iter_output == iter_label, iter_label != 0).sum()
                tot_correct += iter_correct
                tot_not_na_correct += iter_not_na_correct
                tot += iter_label.shape[0]
                tot_not_na += (iter_label != 0).sum()
                if tot_not_na > 0:
                    sys.stdout.write("[pretrain policy agent] epoch %d step %d time %.2f | loss: %f, not NA accuracy: %f, accuracy: %f\r" % (epoch, i, t, iter_loss, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))
                    sys.stdout.flush()

    def train(self,
              model, # The main model
              agent_model, # The model of policy agent
              model_name,
              ckpt_dir='./checkpoint',
              summary_dir='./summary',
              test_result_dir='./test_result',
              learning_rate=0.5,
              max_epoch=60,
              pretrain_agent_epoch=1,
              pretrain_model=None,
              test_epoch=1,
              optimizer=tf.train.GradientDescentOptimizer):
        
        print("Start training...")
        
        # Init
        self.model = model(self.train_data_loader, self.train_data_loader.batch_size, self.train_data_loader.max_length)
        model_optimizer = optimizer(learning_rate)
        grads = model_optimizer.compute_gradients(self.model.loss())
        self.train_op = model_optimizer.apply_gradients(grads)

        # Init policy agent
        self.agent_model = agent_model(self.train_data_loader, self.train_data_loader.batch_size, self.train_data_loader.max_length)
        agent_optimizer = optimizer(learning_rate)
        agent_grads = agent_optimizer.compute_gradients(self.agent_model.loss())
        self.agent_train_op = agent_optimizer.apply_gradients(agent_grads)

        # Session, writer and saver
        self.sess = tf.Session()
        summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)
        saver = tf.train.Saver(max_to_keep=None)
        if pretrain_model is None:
            self.sess.run(tf.global_variables_initializer())
        else:
            saver.restore(self.sess, pretrain_model)

        self.pretrain_main_model(max_epoch=5) # Pre-train main model
        self.pretrain_agent_model(max_epoch=1) # Pre-train policy agent 

        # Train
        tot_delete = 0
        batch_count = 0
        instance_count = 0
        reward = 0.0
        best_metric = 0
        best_prec = None
        best_recall = None
        not_best_count = 0 # Stop training after several epochs without improvement.
        for epoch in range(max_epoch):
            print('###### Epoch ' + str(epoch) + ' ######')
            tot_correct = 0
            tot_not_na_correct = 0
            tot = 0
            tot_not_na = 0
            i = 0
            time_sum = 0
            batch_stack = []
           
            # Update policy agent
            for i, batch_data in enumerate(self.train_data_loader):
                # Make action
                batch_data['agent_label'] = batch_data['ins_rel'] + 0
                batch_data['agent_label'][batch_data['agent_label'] > 0] = 1
                batch_stack.append(batch_data)
                iter_logit = self.agent_one_step(self.sess, self.agent_model, batch_data, [self.agent_model.train_logit()])[0]
                action_result = iter_logit.argmax(-1)
                
                # Calculate reward
                batch_delete = np.sum(np.logical_and(batch_data['ins_rel'] != 0, action_result == 0))
                batch_data['ins_rel'][action_result == 0] = 0
                iter_loss = self.one_step(self.sess, self.model, batch_data, [self.model.loss()])[0]
                reward += iter_loss
                tot_delete += batch_delete
                batch_count += 1

                # Update parameters of policy agent
                alpha = 0.1
                if batch_count == 100:
                    reward = reward / float(batch_count)
                    average_loss = reward
                    reward = - math.log(max(1.0 - math.exp(-reward), 1e-8))  # guard against log(0) as loss -> 0
                    sys.stdout.write('tot delete : %f | reward : %f | average loss : %f\r' % (tot_delete, reward, average_loss))
                    sys.stdout.flush()
                    for batch_data in batch_stack:
                        self.agent_one_step(self.sess, self.agent_model, batch_data, [self.agent_train_op], weights=reward * alpha)
                    batch_count = 0
                    reward = 0
                    tot_delete = 0
                    batch_stack = []

            # Train the main model
            for i, batch_data in enumerate(self.train_data_loader):
                batch_data['agent_label'] = batch_data['ins_rel'] + 0
                batch_data['agent_label'][batch_data['agent_label'] > 0] = 1
                time_start = time.time()

                # Make actions
                iter_logit = self.agent_one_step(self.sess, self.agent_model, batch_data, [self.agent_model.train_logit()])[0]
                action_result = iter_logit.argmax(-1)
                batch_data['ins_rel'][action_result == 0] = 0
                
                # Real training: update the main model (not the agent) on the filtered batch
                iter_loss, iter_logit, _train_op = self.one_step(self.sess, self.model, batch_data, [self.model.loss(), self.model.train_logit(), self.train_op])
                time_end = time.time()
                t = time_end - time_start
                time_sum += t
                iter_output = iter_logit.argmax(-1)
                iter_label = batch_data['rel']
                tot_correct += (iter_output == iter_label).sum()
                tot_not_na_correct += np.logical_and(iter_output == iter_label, iter_label != 0).sum()
                tot += iter_label.shape[0]
                tot_not_na += (iter_label != 0).sum()
                if tot_not_na > 0:
                    sys.stdout.write("epoch %d step %d time %.2f | loss: %f, not NA accuracy: %f, accuracy: %f\r" % (epoch, i, t, iter_loss, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))
                    sys.stdout.flush()
            print("\nAverage iteration time: %f" % (time_sum / (i + 1)))

            if (epoch + 1) % test_epoch == 0:
                metric = self.test(model)
                if metric > best_metric:
                    best_metric = metric
                    best_prec = self.cur_prec
                    best_recall = self.cur_recall
                    print("Best model, storing...")
                    if not os.path.isdir(ckpt_dir):
                        os.mkdir(ckpt_dir)
                    path = saver.save(self.sess, os.path.join(ckpt_dir, model_name))
                    print("Finish storing")
                    not_best_count = 0
                else:
                    not_best_count += 1

            if not_best_count >= 20:
                break

        print("######")
        print("Finish training " + model_name)
        print("Best epoch auc = %f" % (best_metric))
        if best_prec is not None and best_recall is not None:
            if not os.path.isdir(test_result_dir):
                os.mkdir(test_result_dir)
            np.save(os.path.join(test_result_dir, model_name + "_x.npy"), best_recall)
            np.save(os.path.join(test_result_dir, model_name + "_y.npy"), best_prec)
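The label bookkeeping in `pretrain_agent_model` and `train` above relies on two numpy idioms: `ins_rel + 0` forces a copy before binarizing, and boolean-mask assignment relabels agent-rejected instances as NA. A minimal sketch with a toy batch (the values are illustrative, not from any real dataset):

```python
import numpy as np

# Toy batch: 0 means NA, positive ids are real relation labels.
ins_rel = np.array([0, 3, 0, 7, 1])

# "ins_rel + 0" forces a copy, so binarizing agent_label to
# {0: NA, 1: has-relation} leaves ins_rel itself untouched.
agent_label = ins_rel + 0
agent_label[agent_label > 0] = 1

# A hypothetical agent decision per instance (argmax over 2 logits):
# 0 = discard the instance, 1 = keep it.
action_result = np.array([1, 0, 1, 1, 0])

# Rejected instances are counted, then relabeled as NA before the
# main model sees the batch — mirroring the masking step in train().
batch_delete = np.sum(np.logical_and(ins_rel != 0, action_result == 0))
ins_rel[action_result == 0] = 0

print(agent_label.tolist())  # [0, 1, 0, 1, 1]
print(int(batch_delete))     # 2
print(ins_rel.tolist())      # [0, 0, 0, 7, 0]
```

Without the `+ 0` copy, the in-place binarization would destroy the original relation ids that the main model still needs.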

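The policy update in `train` turns the main model's average loss over 100 batches into a reward via `-log(1 - e^{-loss})`: lower loss on the agent-filtered batches yields a larger reward. A small sketch of that transform (the `eps` guard is my addition — the raw formula hits `log(0)` as the loss approaches 0):

```python
import math

def batch_reward(avg_loss, eps=1e-8):
    """Reward shaping used in rl.py: -log(1 - e^{-loss}).
    Monotonically decreasing in the loss, so a cleaner batch
    (lower main-model loss) gives the agent a larger reward."""
    return -math.log(max(1.0 - math.exp(-avg_loss), eps))

for loss in (0.1, 0.5, 1.0, 2.0):
    print("%.1f -> %.4f" % (loss, batch_reward(loss)))
```

In `train` this scalar (scaled by `alpha = 0.1`) is fed back as the `weights` placeholder when replaying the stored batches through the agent, weighting its policy-gradient step.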


================================================
FILE: test_demo.py
================================================
import nrekit
import numpy as np
import tensorflow as tf
import sys
import os
import json

dataset_name = 'nyt'
if len(sys.argv) > 1:
    dataset_name = sys.argv[1]
dataset_dir = os.path.join('./data', dataset_name)
if not os.path.isdir(dataset_dir):
    raise Exception("[ERROR] Dataset dir %s doesn't exist!" % (dataset_dir))

# The first three arguments are the data file, the word embedding file, and the relation-to-id mapping file.
train_loader = nrekit.data_loader.json_file_data_loader(os.path.join(dataset_dir, 'train.json'), 
                                                        os.path.join(dataset_dir, 'word_vec.json'),
                                                        os.path.join(dataset_dir, 'rel2id.json'), 
                                                        mode=nrekit.data_loader.json_file_data_loader.MODE_RELFACT_BAG,
                                                        shuffle=True)
test_loader = nrekit.data_loader.json_file_data_loader(os.path.join(dataset_dir, 'test.json'), 
                                                       os.path.join(dataset_dir, 'word_vec.json'),
                                                       os.path.join(dataset_dir, 'rel2id.json'), 
                                                       mode=nrekit.data_loader.json_file_data_loader.MODE_ENTPAIR_BAG,
                                                       shuffle=False)

framework = nrekit.framework.re_framework(train_loader, test_loader)

class model(nrekit.framework.re_model):
    encoder = "pcnn"
    selector = "att"

    def __init__(self, train_data_loader, batch_size, max_length=120):
        nrekit.framework.re_model.__init__(self, train_data_loader, batch_size, max_length=max_length)
        self.mask = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name="mask")
        
        # Embedding
        x = nrekit.network.embedding.word_position_embedding(self.word, self.word_vec_mat, self.pos1, self.pos2)

        # Encoder
        if model.encoder == "pcnn":
            x_train = nrekit.network.encoder.pcnn(x, self.mask, keep_prob=0.5)
            x_test = nrekit.network.encoder.pcnn(x, self.mask, keep_prob=1.0)
        elif model.encoder == "cnn":
            x_train = nrekit.network.encoder.cnn(x, keep_prob=0.5)
            x_test = nrekit.network.encoder.cnn(x, keep_prob=1.0)
        elif model.encoder == "rnn":
            x_train = nrekit.network.encoder.rnn(x, self.length, keep_prob=0.5)
            x_test = nrekit.network.encoder.rnn(x, self.length, keep_prob=1.0)
        elif model.encoder == "birnn":
            x_train = nrekit.network.encoder.birnn(x, self.length, keep_prob=0.5)
            x_test = nrekit.network.encoder.birnn(x, self.length, keep_prob=1.0)
        else:
            raise NotImplementedError

        # Selector
        if model.selector == "att":
            self._train_logit, train_repre = nrekit.network.selector.bag_attention(x_train, self.scope, self.ins_label, self.rel_tot, True, keep_prob=0.5)
            self._test_logit, test_repre = nrekit.network.selector.bag_attention(x_test, self.scope, self.ins_label, self.rel_tot, False, keep_prob=1.0)
        elif model.selector == "ave":
            self._train_logit, train_repre = nrekit.network.selector.bag_average(x_train, self.scope, self.rel_tot, keep_prob=0.5)
            self._test_logit, test_repre = nrekit.network.selector.bag_average(x_test, self.scope, self.rel_tot, keep_prob=1.0)
            self._test_logit = tf.nn.softmax(self._test_logit)
        elif model.selector == "max":
            self._train_logit, train_repre = nrekit.network.selector.bag_maximum(x_train, self.scope, self.ins_label, self.rel_tot, True, keep_prob=0.5)
            self._test_logit, test_repre = nrekit.network.selector.bag_maximum(x_test, self.scope, self.ins_label, self.rel_tot, False, keep_prob=1.0)
            self._test_logit = tf.nn.softmax(self._test_logit)
        else:
            raise NotImplementedError
        
        # Classifier
        self._loss = nrekit.network.classifier.softmax_cross_entropy(self._train_logit, self.label, self.rel_tot, weights_table=self.get_weights())
 
    def loss(self):
        return self._loss

    def train_logit(self):
        return self._train_logit

    def test_logit(self):
        return self._test_logit

    def get_weights(self):
        with tf.variable_scope("weights_table", reuse=tf.AUTO_REUSE):
            print("Calculating weights_table...")
            _weights_table = np.zeros((self.rel_tot), dtype=np.float32)
            for i in range(len(self.train_data_loader.data_rel)):
                _weights_table[self.train_data_loader.data_rel[i]] += 1.0 
            _weights_table = 1 / (_weights_table ** 0.05)
            weights_table = tf.get_variable(name='weights_table', dtype=tf.float32, trainable=False, initializer=_weights_table)
            print("Finish calculating")
        return weights_table

if len(sys.argv) > 2:
    model.encoder = sys.argv[2]
if len(sys.argv) > 3:
    model.selector = sys.argv[3]

auc, pred_result = framework.test(model, ckpt="./checkpoint/" + dataset_name + "_" + model.encoder + "_" + model.selector, return_result=True)

with open('./test_result/' + dataset_name + "_" + model.encoder + "_" + model.selector + "_pred.json", 'w') as outfile:
    json.dump(pred_result, outfile)
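The `get_weights` method above builds a per-relation loss weight as a heavily damped inverse frequency, `count ** -0.05`, so the dominant NA class is only slightly down-weighted. The construction in isolation, with hypothetical counts:

```python
import numpy as np

# Hypothetical relation labels over training data; id 0 (NA) dominates.
rel_tot = 4
data_rel = [0, 0, 0, 0, 0, 0, 1, 1, 2, 3]

# Same construction as get_weights(): count each relation id,
# then take the damped inverse frequency 1 / count**0.05.
counts = np.zeros((rel_tot,), dtype=np.float32)
for r in data_rel:
    counts[r] += 1.0
weights = 1 / (counts ** 0.05)

# Rare relations get only marginally larger weights than frequent ones.
print(weights.round(4).tolist())
```

Note the tiny exponent keeps all weights close to 1.0; a relation with zero occurrences would yield an infinite weight, so the table assumes every id in `rel2id` appears at least once in training.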



================================================
FILE: train_demo.py
================================================
import nrekit
import numpy as np
import tensorflow as tf
import sys
import os

dataset_name = 'nyt'
if len(sys.argv) > 1:
    dataset_name = sys.argv[1]
dataset_dir = os.path.join('./data', dataset_name)
if not os.path.isdir(dataset_dir):
    raise Exception("[ERROR] Dataset dir %s doesn't exist!" % (dataset_dir))

# The first three arguments are the data file, the word embedding file, and the relation-to-id mapping file.
train_loader = nrekit.data_loader.json_file_data_loader(os.path.join(dataset_dir, 'train.json'), 
                                                        os.path.join(dataset_dir, 'word_vec.json'),
                                                        os.path.join(dataset_dir, 'rel2id.json'), 
                                                        mode=nrekit.data_loader.json_file_data_loader.MODE_RELFACT_BAG,
                                                        shuffle=True)
test_loader = nrekit.data_loader.json_file_data_loader(os.path.join(dataset_dir, 'test.json'), 
                                                       os.path.join(dataset_dir, 'word_vec.json'),
                                                       os.path.join(dataset_dir, 'rel2id.json'), 
                                                       mode=nrekit.data_loader.json_file_data_loader.MODE_ENTPAIR_BAG,
                                                       shuffle=False)

framework = nrekit.framework.re_framework(train_loader, test_loader)

class model(nrekit.framework.re_model):
    encoder = "pcnn"
    selector = "att"

    def __init__(self, train_data_loader, batch_size, max_length=120):
        nrekit.framework.re_model.__init__(self, train_data_loader, batch_size, max_length=max_length)
        self.mask = tf.placeholder(dtype=tf.int32, shape=[None, max_length], name="mask")
        
        # Embedding
        x = nrekit.network.embedding.word_position_embedding(self.word, self.word_vec_mat, self.pos1, self.pos2)

        # Encoder
        if model.encoder == "pcnn":
            x_train = nrekit.network.encoder.pcnn(x, self.mask, keep_prob=0.5)
            x_test = nrekit.network.encoder.pcnn(x, self.mask, keep_prob=1.0)
        elif model.encoder == "cnn":
            x_train = nrekit.network.encoder.cnn(x, keep_prob=0.5)
            x_test = nrekit.network.encoder.cnn(x, keep_prob=1.0)
        elif model.encoder == "rnn":
            x_train = nrekit.network.encoder.rnn(x, self.length, keep_prob=0.5)
            x_test = nrekit.network.encoder.rnn(x, self.length, keep_prob=1.0)
        elif model.encoder == "birnn":
            x_train = nrekit.network.encoder.birnn(x, self.length, keep_prob=0.5)
            x_test = nrekit.network.encoder.birnn(x, self.length, keep_prob=1.0)
        else:
            raise NotImplementedError

        # Selector
        if model.selector == "att":
            self._train_logit, train_repre = nrekit.network.selector.bag_attention(x_train, self.scope, self.ins_label, self.rel_tot, True, keep_prob=0.5)
            self._test_logit, test_repre = nrekit.network.selector.bag_attention(x_test, self.scope, self.ins_label, self.rel_tot, False, keep_prob=1.0)
        elif model.selector == "ave":
            self._train_logit, train_repre = nrekit.network.selector.bag_average(x_train, self.scope, self.rel_tot, keep_prob=0.5)
            self._test_logit, test_repre = nrekit.network.selector.bag_average(x_test, self.scope, self.rel_tot, keep_prob=1.0)
            self._test_logit = tf.nn.softmax(self._test_logit)
        elif model.selector == "one":
            self._train_logit, train_repre = nrekit.network.selector.bag_one(x_train, self.scope, self.label, self.rel_tot, True, keep_prob=0.5)
            self._test_logit, test_repre = nrekit.network.selector.bag_one(x_test, self.scope, self.label, self.rel_tot, False, keep_prob=1.0)
            self._test_logit = tf.nn.softmax(self._test_logit)
        elif model.selector == "cross_max":
            self._train_logit, train_repre = nrekit.network.selector.bag_cross_max(x_train, self.scope, self.rel_tot, keep_prob=0.5)
            self._test_logit, test_repre = nrekit.network.selector.bag_cross_max(x_test, self.scope, self.rel_tot, keep_prob=1.0)
            self._test_logit = tf.nn.softmax(self._test_logit)
        else:
            raise NotImplementedError
        
        # Classifier
        self._loss = nrekit.network.classifier.softmax_cross_entropy(self._train_logit, self.label, self.rel_tot, weights_table=self.get_weights())
 
    def loss(self):
        return self._loss

    def train_logit(self):
        return self._train_logit

    def test_logit(self):
        return self._test_logit

    def get_weights(self):
        with tf.variable_scope("weights_table", reuse=tf.AUTO_REUSE):
            print("Calculating weights_table...")
            _weights_table = np.zeros((self.rel_tot), dtype=np.float32)
            for i in range(len(self.train_data_loader.data_rel)):
                _weights_table[self.train_data_loader.data_rel[i]] += 1.0 
            _weights_table = 1 / (_weights_table ** 0.05)
            weights_table = tf.get_variable(name='weights_table', dtype=tf.float32, trainable=False, initializer=_weights_table)
            print("Finish calculating")
        return weights_table

use_rl = False
if len(sys.argv) > 2:
    model.encoder = sys.argv[2]
if len(sys.argv) > 3:
    model.selector = sys.argv[3]
if len(sys.argv) > 4:
    if sys.argv[4] == 'rl':
        use_rl = True

if use_rl:
    rl_framework = nrekit.rl.rl_re_framework(train_loader, test_loader)
    rl_framework.train(model, nrekit.rl.policy_agent, model_name=dataset_name + "_" + model.encoder + "_" + model.selector + "_rl", max_epoch=60, ckpt_dir="checkpoint")
else:
    framework.train(model, model_name=dataset_name + "_" + model.encoder + "_" + model.selector, max_epoch=60, ckpt_dir="checkpoint", gpu_nums=1)
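The argument handling above follows a simple positional-override pattern: each `sys.argv` slot replaces a default, and a trailing `rl` token switches on the reinforcement-learning framework. A self-contained sketch of that logic (the function name is illustrative, not part of the repo):

```python
def parse_args(argv):
    """Mirror train_demo.py's positional-argument fallback:
    dataset, encoder, selector each default unless overridden,
    and a trailing 'rl' flag selects the RL training path."""
    dataset, encoder, selector, use_rl = 'nyt', 'pcnn', 'att', False
    if len(argv) > 1:
        dataset = argv[1]
    if len(argv) > 2:
        encoder = argv[2]
    if len(argv) > 3:
        selector = argv[3]
    if len(argv) > 4 and argv[4] == 'rl':
        use_rl = True
    return dataset, encoder, selector, use_rl

# e.g. invoked as: python train_demo.py nyt cnn max rl
print(parse_args(['train_demo.py', 'nyt', 'cnn', 'max', 'rl']))
# ('nyt', 'cnn', 'max', True)
```

So `python train_demo.py` alone trains the default `nyt` + `pcnn` + `att` configuration, while the fourth argument alone decides between `framework.train` and `rl_framework.train`.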