Repository: RowitZou/LGN
Branch: master
Commit: c1e7a681f514
Files: 11
Total size: 84.2 KB
Directory structure:
gitextract_y25cez6i/
├── LICENSE
├── README.md
├── main.py
├── model/
│ ├── LGN.py
│ ├── crf.py
│ └── module.py
└── utils/
  ├── alphabet.py
  ├── data.py
  ├── functions.py
  ├── metric.py
  └── word_trie.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2019 Yicheng Zou
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# LGN
PyTorch implementation of [A Lexicon-Based Graph Neural Network for Chinese NER](https://www.aclweb.org/anthology/D19-1096.pdf).
Parts of the code are adapted from https://github.com/jiesutd/LatticeLSTM.
## Requirements
* Python 3.6 or higher
* PyTorch 0.4.1 or higher
## Input Format
Files use the BMES tagging scheme, with one character and its label per line. Sentences are separated by an empty line.
```
印 B-LOC
度 M-LOC
河 E-LOC
流 O
经 O
印 B-GPE
度 E-GPE
```
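The input format above can be read with a few lines of Python. The following is a minimal sketch, not the repository's own loader (`utils/data.py`). It assumes that every non-empty line of a labeled file holds one character and one label separated by whitespace, so unlabeled raw files would need different handling. The helpers `read_bmes` and `bmes_to_spans` are hypothetical names used only for illustration.

```python
def read_bmes(lines):
    """Group 'char label' lines into (chars, labels) sentences; blank lines end a sentence."""
    sentences = []
    chars, labels = [], []
    for line in lines:
        line = line.strip()
        if not line:
            if chars:
                sentences.append((chars, labels))
                chars, labels = [], []
            continue
        char, label = line.split()[:2]  # assumes a labeled file
        chars.append(char)
        labels.append(label)
    if chars:  # last sentence may not be followed by a blank line
        sentences.append((chars, labels))
    return sentences


def bmes_to_spans(labels):
    """Convert BMES labels into (start, end_inclusive, entity_type) spans."""
    spans = []
    start, start_type = None, None
    for i, label in enumerate(labels):
        if label == 'O':
            start = None
            continue
        prefix, ent_type = label.split('-', 1)
        if prefix == 'S':
            spans.append((i, i, ent_type))
            start = None
        elif prefix == 'B':
            start, start_type = i, ent_type
        elif prefix == 'E' and start is not None and ent_type == start_type:
            spans.append((start, i, ent_type))
            start = None
        elif prefix == 'M' and (start is None or ent_type != start_type):
            start = None  # malformed continuation: drop the partial entity
    return spans


sample = ["印 B-LOC", "度 M-LOC", "河 E-LOC", "流 O", "经 O", "印 B-GPE", "度 E-GPE", ""]
for chars, labels in read_bmes(sample):
    print("".join(chars), bmes_to_spans(labels))
```

For the sample sentence, the sketch finds a `LOC` span over the first three characters and a `GPE` span over the last two. Malformed tag sequences are dropped rather than repaired. That choice is arbitrary and may not match how the repository's metric code treats them.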
## Usage
* Training
```
python main.py --status train \
    --train data/onto4ner.cn/train.char.bmes \
    --dev data/onto4ner.cn/dev.char.bmes \
    --test data/onto4ner.cn/test.char.bmes \
    --saved_model saved_model/model_onto4ner \
    --saved_set data/onto4ner.cn/saved.dset
```
* Testing
```
python main.py --status test \
    --test data/onto4ner.cn/test.char.bmes \
    --saved_model saved_model/model_onto4ner \
    --saved_set data/onto4ner.cn/saved.dset
```
* Decoding (the raw file may be either labeled or unlabeled)
```
python main.py --status decode \
    --raw data/onto4ner.cn/test.char.bmes \
    --output tagged_file.txt \
    --saved_model saved_model/model_onto4ner \
    --saved_set data/onto4ner.cn/saved.dset
```
## Data
The pretrained character and word embeddings can be downloaded from [Lattice LSTM](https://github.com/jiesutd/LatticeLSTM).
Original datasets can be found at [OntoNotes](https://catalog.ldc.upenn.edu/LDC2011T03), [MSRA](http://sighan.cs.uchicago.edu/bakeoff2006/),
[Weibo](https://github.com/hltcoe/golden-horse) and [Resume](https://github.com/jiesutd/LatticeLSTM/tree/master/ResumeNER).
The preprocessed datasets, already converted to the input format of our code, are available at [Google Drive](https://drive.google.com/open?id=1Rvju5_gp2E6BFiqzMBtnMqVP803AbBcm) and
[Baidu Pan](https://pan.baidu.com/s/1zbzLriRpc8S_5ez_upC7OA) (Code: akcm)
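The embedding files are plain text. The following is a hedged sketch that assumes each line holds a token followed by its space-separated vector components, optionally preceded by a word2vec-style `count dim` header. Under that assumption, it loads the vectors and aligns them with a vocabulary. Tokens missing from the file are initialized uniformly in `[-sqrt(3/d), sqrt(3/d)]`, which is the same range `model/LGN.py` uses to re-initialize the unknown-word row. The functions `load_text_embeddings` and `build_embedding_matrix` are hypothetical and are not the repository's `utils/functions.py` code.

```python
import math
import random


def load_text_embeddings(lines):
    """Parse 'token v1 ... vd' lines into {token: [floats]} and return (vectors, dim)."""
    vectors, dim = {}, None
    for line in lines:
        parts = line.rstrip().split()
        if len(parts) <= 2:
            continue  # blank line or a 'count dim' header (1-d vectors are not supported)
        token, values = parts[0], parts[1:]
        if dim is None:
            dim = len(values)  # infer the dimension from the first vector row
        if len(values) != dim:
            continue  # skip malformed rows
        vectors[token] = [float(v) for v in values]
    return vectors, dim


def build_embedding_matrix(vocab, vectors, dim, seed=0):
    """Return one row per vocab token: pretrained if available, else uniform random."""
    rng = random.Random(seed)
    scale = math.sqrt(3.0 / dim)
    matrix, hits = [], 0
    for token in vocab:
        if token in vectors:
            matrix.append(list(vectors[token]))
            hits += 1
        else:
            matrix.append([rng.uniform(-scale, scale) for _ in range(dim)])
    return matrix, hits


vectors, dim = load_text_embeddings(["2 3", "印 0.1 0.2 0.3", "度 0.4 0.5 0.6"])
matrix, hits = build_embedding_matrix(["印", "河"], vectors, dim)
print(dim, hits, matrix[0])
```

The sketch keeps everything in plain Python lists for clarity. A real pipeline would typically build a NumPy array and copy it into an `nn.Embedding`, as `model/LGN.py` does with `weight.data.copy_`.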
## Pretrained Model Downloads
We also provide pretrained models for the four datasets; these are the same models reported in the paper.
If you retrain the models from scratch with the same hyper-parameter settings, you may obtain a slightly
lower or higher F1 score than the one reported in the paper (in our experiments we selected the best-performing model).
Pretrained models and related hyper-parameter settings are available at [Google Drive](https://drive.google.com/file/d/1KKkCW8WRhgR2P2UbRpNpKyE_RAv1EREv/view?usp=sharing) or [Baidu Pan](https://pan.baidu.com/s/1U89EwnhPpMa4bNrS--u4EA).
Running main.py in test mode with the pretrained models gives the following results:
| Datasets | Precision | Recall | F1 |
|:--------------:|:---------:|:-------:|:-----:|
| OntoNotes dev | 74.00 | 70.03 | 71.96 |
| OntoNotes test | 76.13 | 73.68 | 74.89 |
| MSRA dev | - | - | - |
| MSRA test | 94.19 | 92.73 | 93.46 |
| Weibo dev | 66.09 | 59.13 | 62.42 |
| Weibo test | 65.71 | 55.56 | 60.21 |
| Resume dev | 94.27 | 94.59 | 94.43 |
| Resume test | 95.28 | 95.46 | 95.37 |
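F1 is the harmonic mean of precision and recall, F1 = 2PR / (P + R). The short check below recomputes F1 from the precision and recall shown in the table. The table values are rounded to two decimals, most likely before F1 was taken from the unrounded scores. As a result, the recomputed F1 can differ from the reported one in the last digit: for OntoNotes test, 76.13 and 73.68 give 74.88 rather than 74.89. The names `f1` and `REPORTED` are illustrative only.

```python
def f1(precision, recall):
    """Harmonic mean of precision and recall (same units as the inputs)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (precision, recall, reported F1), copied from the table above
REPORTED = {
    "OntoNotes test": (76.13, 73.68, 74.89),
    "MSRA test": (94.19, 92.73, 93.46),
    "Weibo test": (65.71, 55.56, 60.21),
    "Resume test": (95.28, 95.46, 95.37),
}

for name, (p, r, f) in REPORTED.items():
    print("%s: recomputed F1 = %.2f, reported F1 = %.2f" % (name, f1(p, r), f))
```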
## Citation
```
@inproceedings{gui2019lexicon,
  title={A Lexicon-Based Graph Neural Network for Chinese NER},
  author={Gui, Tao and Zou, Yicheng and Zhang, Qi and Peng, Minlong and
          Fu, Jinlan and Wei, Zhongyu and Huang, Xuanjing},
  booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing
             and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)},
  pages={1039--1049},
  year={2019}
}
```
================================================
FILE: main.py
================================================
# -*- coding: utf-8 -*-
# @Author: Yicheng Zou
# @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn
import time
import sys
import argparse
import random
import torch
import gc
import pickle
import os
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
from utils.metric import get_ner_fmeasure
from model.LGN import Graph
from utils.data import Data
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
def lr_decay(optimizer, epoch, decay_rate, init_lr):
    # exponential decay: lr = init_lr * (1 - decay_rate) ** epoch
    lr = init_lr * ((1 - decay_rate) ** epoch)
    print(" Learning rate is set as:", lr)
    for param_group in optimizer.param_groups:
        # the graph aggregation modules ('aggr' group) use twice the base learning rate
        if param_group['name'] == 'aggr':
            param_group['lr'] = lr * 2.
        else:
            param_group['lr'] = lr
    return optimizer
def data_initialization(data, word_file, train_file, dev_file, test_file):
    data.build_word_file(word_file)
    if train_file:
        data.build_alphabet(train_file)
        data.build_word_alphabet(train_file)
    if dev_file:
        data.build_alphabet(dev_file)
        data.build_word_alphabet(dev_file)
    if test_file:
        data.build_alphabet(test_file)
        data.build_word_alphabet(test_file)
    return data
def predict_check(pred_variable, gold_variable, mask_variable):
    pred = pred_variable.cpu().data.numpy()
    gold = gold_variable.cpu().data.numpy()
    mask = mask_variable.cpu().data.numpy()
    overlapped = (pred == gold)
    # count correct tags on non-padding positions only
    right_token = np.sum(overlapped * mask)
    total_token = mask.sum()
    return right_token, total_token
def recover_label(pred_variable, gold_variable, mask_variable, label_alphabet):
    batch_size = gold_variable.size(0)
    seq_len = gold_variable.size(1)
    mask = mask_variable.cpu().data.numpy()
    pred_tag = pred_variable.cpu().data.numpy()
    gold_tag = gold_variable.cpu().data.numpy()
    pred_label = []
    gold_label = []
    for idx in range(batch_size):
        # map label ids back to label strings, skipping padding positions
        pred = [label_alphabet.get_instance(pred_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0]
        gold = [label_alphabet.get_instance(gold_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0]
        assert len(pred) == len(gold)
        pred_label.append(pred)
        gold_label.append(gold)
    return pred_label, gold_label
def print_args(args):
    print("CONFIG SUMMARY:")
    print(" Batch size: %s" % (args.batch_size))
    print(" If use GPU: %s" % (args.use_gpu))
    print(" If use CRF: %s" % (args.use_crf))
    print(" Epoch number: %s" % (args.num_epoch))
    print(" Learning rate: %s" % (args.lr))
    print(" L2 regularization (weight decay): %s" % (args.weight_decay))
    print(" If use edge embedding: %s" % (args.use_edge))
    print(" If use global node: %s" % (args.use_global))
    print(" Bidirectional digraph: %s" % (args.bidirectional))
    print(" Update step number: %s" % (args.iters))
    print(" Attention dropout rate: %s" % (args.tf_drop_rate))
    print(" Embedding dropout rate: %s" % (args.emb_drop_rate))
    print(" Hidden state dimension: %s" % (args.hidden_dim))
    print(" Learning rate decay ratio: %s" % (args.lr_decay))
    print(" Aggregation module dropout rate: %s" % (args.cell_drop_rate))
    print(" Head number of attention: %s" % (args.num_head))
    print(" Head dimension of attention: %s" % (args.head_dim))
    print("CONFIG SUMMARY END.")
    sys.stdout.flush()
def evaluate(data, args, model, name):
    if name == "train":
        instances = data.train_Ids
    elif name == "dev":
        instances = data.dev_Ids
    elif name == 'test':
        instances = data.test_Ids
    elif name == 'raw':
        instances = data.raw_Ids
    else:
        print("Error: wrong evaluate name,", name)
        exit(0)
    pred_results = []
    gold_results = []
    # set model in eval mode
    model.eval()
    batch_size = args.batch_size
    start_time = time.time()
    train_num = len(instances)
    total_batch = train_num // batch_size + 1
    for batch_id in range(total_batch):
        start = batch_id * batch_size
        end = (batch_id + 1) * batch_size
        if end > train_num:
            end = train_num
        instance = instances[start:end]
        if not instance:
            continue
        word_list, batch_char, batch_label, mask = batchify_with_label(instance, args.use_gpu)
        _, tag_seq = model(word_list, batch_char, mask)
        pred_label, gold_label = recover_label(tag_seq, batch_label, mask, data.label_alphabet)
        pred_results += pred_label
        gold_results += gold_label
    decode_time = time.time() - start_time
    speed = len(instances) / decode_time
    acc, p, r, f = get_ner_fmeasure(gold_results, pred_results)
    return speed, acc, p, r, f, pred_results
def batchify_with_label(input_batch_list, gpu):
    batch_size = len(input_batch_list)
    chars = [sent[0] for sent in input_batch_list]
    words = [sent[1] for sent in input_batch_list]
    labels = [sent[2] for sent in input_batch_list]
    sent_lengths = torch.LongTensor(list(map(len, chars)))
    max_sent_len = sent_lengths.max()
    # zero-padded id tensors and a 0/1 mask marking real (non-padding) characters
    char_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_sent_len))).long()
    label_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_sent_len))).long()
    mask = autograd.Variable(torch.zeros((batch_size, max_sent_len))).byte()
    for idx, (seq, label, seq_len) in enumerate(zip(chars, labels, sent_lengths)):
        char_seq_tensor[idx, :seq_len] = torch.LongTensor(seq)
        label_seq_tensor[idx, :seq_len] = torch.LongTensor(label)
        mask[idx, :seq_len] = torch.Tensor([1] * int(seq_len))
    if gpu:
        char_seq_tensor = char_seq_tensor.cuda()
        label_seq_tensor = label_seq_tensor.cuda()
        mask = mask.cuda()
    return words, char_seq_tensor, label_seq_tensor, mask
def train(data, args, saved_model_path):
    print("Training model...")
    model = Graph(data, args)
    if args.use_gpu:
        model = model.cuda()
    print('# generated parameters:', sum(param.numel() for param in model.parameters()))
    print("Finished built model.")
    best_dev_epoch = 0
    best_dev_f = -1
    best_dev_p = -1
    best_dev_r = -1
    best_test_f = -1
    best_test_p = -1
    best_test_r = -1
    # Initialize the optimizer: parameters of nn.ModuleList submodules (the per-iteration
    # aggregation modules) form the 'aggr' group, all remaining parameters form 'other'
    aggr_module_params = []
    other_module_params = []
    for m_name in model._modules:
        m = model._modules[m_name]
        if isinstance(m, torch.nn.ModuleList):
            for p in m.parameters():
                if p.requires_grad:
                    aggr_module_params.append(p)
        else:
            for p in m.parameters():
                if p.requires_grad:
                    other_module_params.append(p)
    optimizer = optim.Adam([
        {"params": aggr_module_params, "name": "aggr"},
        {"params": other_module_params, "name": "other"}
    ], lr=args.lr, weight_decay=args.weight_decay)
    for idx in range(args.num_epoch):
        epoch_start = time.time()
        temp_start = epoch_start
        print("Epoch: %s/%s" % (idx, args.num_epoch))
        optimizer = lr_decay(optimizer, idx, args.lr_decay, args.lr)
        sample_loss = 0
        batch_loss = 0
        total_loss = 0
        right_token = 0
        whole_token = 0
        random.shuffle(data.train_Ids)
        # set model in train mode
        model.train()
        model.zero_grad()
        batch_size = args.batch_size
        train_num = len(data.train_Ids)
        total_batch = train_num // batch_size + 1
        for batch_id in range(total_batch):
            # Get one batch-sized instance
            start = batch_id * batch_size
            end = (batch_id + 1) * batch_size
            if end > train_num:
                end = train_num
            instance = data.train_Ids[start:end]
            if not instance:
                continue
            word_list, batch_char, batch_label, mask = batchify_with_label(instance, args.use_gpu)
            loss, tag_seq = model(word_list, batch_char, mask, batch_label)
            right, whole = predict_check(tag_seq, batch_label, mask)
            right_token += right
            whole_token += whole
            sample_loss += loss.data
            total_loss += loss.data
            batch_loss += loss
            if end % 500 == 0:
                temp_time = time.time()
                temp_cost = temp_time - temp_start
                temp_start = temp_time
                print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" %
                      (end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token))
                sys.stdout.flush()
                sample_loss = 0
            # Update once per batch; the last, possibly smaller batch is applied as well
            if end % args.batch_size == 0 or end == train_num:
                batch_loss.backward()
                optimizer.step()
                model.zero_grad()
                batch_loss = 0
        temp_time = time.time()
        temp_cost = temp_time - temp_start
        print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" %
              (end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token))
        epoch_finish = time.time()
        epoch_cost = epoch_finish - epoch_start
        print("Epoch: %s training finished. Time: %.2fs, speed: %.2fst/s, total loss: %s" %
              (idx, epoch_cost, train_num / epoch_cost, total_loss))
        # dev
        speed, acc, dev_p, dev_r, dev_f, _ = evaluate(data, args, model, "dev")
        dev_finish = time.time()
        dev_cost = dev_finish - epoch_finish
        print("Dev: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" %
              (dev_cost, speed, acc, dev_p, dev_r, dev_f))
        # test
        speed, acc, test_p, test_r, test_f, _ = evaluate(data, args, model, "test")
        test_finish = time.time()
        test_cost = test_finish - dev_finish
        print("Test: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" %
              (test_cost, speed, acc, test_p, test_r, test_f))
        if dev_f > best_dev_f:
            print("Exceed previous best f score: %.4f" % best_dev_f)
            torch.save(model.state_dict(), saved_model_path + "_best")
            best_dev_p = dev_p
            best_dev_r = dev_r
            best_dev_f = dev_f
            best_dev_epoch = idx + 1
            best_test_p = test_p
            best_test_r = test_r
            best_test_f = test_f
        model_idx_path = saved_model_path + "_" + str(idx)
        torch.save(model.state_dict(), model_idx_path)
        with open(saved_model_path + "_result.txt", "a") as file:
            file.write(model_idx_path + '\n')
            file.write("Dev score: p: %.4f, r: %.4f, f: %.4f\n" % (dev_p, dev_r, dev_f))
            file.write("Test score: p: %.4f, r: %.4f, f: %.4f\n\n" % (test_p, test_r, test_f))
        print("Best dev epoch: %d" % best_dev_epoch)
        print("Best dev score: p: %.4f, r: %.4f, f: %.4f" % (best_dev_p, best_dev_r, best_dev_f))
        print("Best test score: p: %.4f, r: %.4f, f: %.4f" % (best_test_p, best_test_r, best_test_f))
        gc.collect()
    with open(saved_model_path + "_result.txt", "a") as file:
        file.write("Best epoch: %d" % best_dev_epoch + '\n')
        file.write("Best Dev score: p: %.4f, r: %.4f, f: %.4f\n" % (best_dev_p, best_dev_r, best_dev_f))
        file.write("Test score: p: %.4f, r: %.4f, f: %.4f\n\n" % (best_test_p, best_test_r, best_test_f))
    with open(saved_model_path + "_best_HP.config", "wb") as file:
        pickle.dump(args, file)
def load_model_decode(model_dir, data, args, name):
    model_dir = model_dir + "_best"
    print("Load Model from file: ", model_dir)
    model = Graph(data, args)
    # Load the weights onto the CPU first so that a checkpoint saved on a GPU can also be
    # restored on a CPU-only machine; the model is moved to the GPU afterwards if requested
    model.load_state_dict(torch.load(model_dir, map_location='cpu'))
    if args.use_gpu:
        model = model.cuda()
    print("Decode %s data ..." % name)
    start_time = time.time()
    speed, acc, p, r, f, pred_results = evaluate(data, args, model, name)
    end_time = time.time()
    time_cost = end_time - start_time
    print("%s: time:%.2fs, speed:%.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" %
          (name, time_cost, speed, acc, p, r, f))
    return pred_results
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--status', choices=['train', 'test', 'decode'], help='Function status.', default='train')
    parser.add_argument('--use_gpu', type=str2bool, default=True)
    parser.add_argument('--train', help='Training set.', default='data/onto4ner.cn/train.char.bmes')
    parser.add_argument('--dev', help='Developing set.', default='data/onto4ner.cn/dev.char.bmes')
    parser.add_argument('--test', help='Testing set.', default='data/onto4ner.cn/test.char.bmes')
    parser.add_argument('--raw', help='Raw file for decoding.')
    parser.add_argument('--output', help='Output results for decoding.')
    parser.add_argument('--saved_set', help='Path of saved data set.', default='data/onto4ner.cn/saved.dset')
    parser.add_argument('--saved_model', help='Path of saved model.', default="saved_model/model_onto4ner")
    parser.add_argument('--char_emb', help='Path of character embedding file.', default="data/gigaword_chn.all.a2b.uni.ite50.vec")
    parser.add_argument('--word_emb', help='Path of word embedding file.', default="data/ctb.50d.vec")
    parser.add_argument('--use_crf', type=str2bool, default=True)
    parser.add_argument('--use_edge', type=str2bool, default=True, help='If use lexicon embeddings (edge embeddings).')
    parser.add_argument('--use_global', type=str2bool, default=True, help='If use the global node.')
    parser.add_argument('--bidirectional', type=str2bool, default=True, help='If use bidirectional digraph.')
    parser.add_argument('--seed', help='Random seed', default=1023, type=int)
    parser.add_argument('--batch_size', help='Batch size.', default=1, type=int)
    parser.add_argument('--num_epoch', default=100, type=int, help="Epoch number.")
    parser.add_argument('--iters', default=4, type=int, help='The number of Graph iterations.')
    parser.add_argument('--hidden_dim', default=50, type=int, help='Hidden state size.')
    parser.add_argument('--num_head', default=10, type=int, help='Number of transformer head.')
    parser.add_argument('--head_dim', default=20, type=int, help='Head dimension of transformer.')
    parser.add_argument('--tf_drop_rate', default=0.1, type=float, help='Transformer dropout rate.')
    parser.add_argument('--emb_drop_rate', default=0.5, type=float, help='Embedding dropout rate.')
    parser.add_argument('--cell_drop_rate', default=0.2, type=float, help='Aggregation module dropout rate.')
    parser.add_argument('--word_alphabet_size', type=int, help='Word alphabet size.')
    parser.add_argument('--char_alphabet_size', type=int, help='Char alphabet size.')
    parser.add_argument('--label_alphabet_size', type=int, help='Label alphabet size.')
    parser.add_argument('--char_dim', type=int, help='Char embedding size.')
    parser.add_argument('--word_dim', type=int, help='Word embedding size.')
    parser.add_argument('--lr', type=float, default=2e-05)
    parser.add_argument('--lr_decay', type=float, default=0)
    parser.add_argument('--weight_decay', type=float, default=0)
    args = parser.parse_args()
    status = args.status.lower()
    seed_num = args.seed
    random.seed(seed_num)
    torch.manual_seed(seed_num)
    np.random.seed(seed_num)
    train_file = args.train
    dev_file = args.dev
    test_file = args.test
    raw_file = args.raw
    output_file = args.output
    saved_set_path = args.saved_set
    saved_model_path = args.saved_model
    char_file = args.char_emb
    word_file = args.word_emb
    if status == 'train':
        if os.path.exists(saved_set_path):
            print('Loading saved data set...')
            with open(saved_set_path, 'rb') as f:
                data = pickle.load(f)
        else:
            data = Data()
            data_initialization(data, word_file, train_file, dev_file, test_file)
            data.generate_instance_with_words(train_file, 'train')
            data.generate_instance_with_words(dev_file, 'dev')
            data.generate_instance_with_words(test_file, 'test')
            data.build_char_pretrain_emb(char_file)
            data.build_word_pretrain_emb(word_file)
            if saved_set_path is not None:
                print('Dumping data...')
                with open(saved_set_path, 'wb') as f:
                    pickle.dump(data, f)
        data.show_data_summary()
        args.word_alphabet_size = data.word_alphabet.size()
        args.char_alphabet_size = data.char_alphabet.size()
        args.label_alphabet_size = data.label_alphabet.size()
        args.char_dim = data.char_emb_dim
        args.word_dim = data.word_emb_dim
        print_args(args)
        train(data, args, saved_model_path)
    elif status == 'test':
        assert not (test_file is None)
        if os.path.exists(saved_set_path):
            print('Loading saved data set...')
            with open(saved_set_path, 'rb') as f:
                data = pickle.load(f)
        else:
            print("Cannot find saved data set: ", saved_set_path)
            exit(0)
        data.generate_instance_with_words(test_file, 'test')
        with open(saved_model_path + "_best_HP.config", "rb") as f:
            args = pickle.load(f)
        data.show_data_summary()
        print_args(args)
        load_model_decode(saved_model_path, data, args, "test")
    elif status == 'decode':
        assert not (raw_file is None or output_file is None)
        if os.path.exists(saved_set_path):
            print('Loading saved data set...')
            with open(saved_set_path, 'rb') as f:
                data = pickle.load(f)
        else:
            print("Cannot find saved data set: ", saved_set_path)
            exit(0)
        data.generate_instance_with_words(raw_file, 'raw')
        with open(saved_model_path + "_best_HP.config", "rb") as f:
            args = pickle.load(f)
        data.show_data_summary()
        print_args(args)
        decode_results = load_model_decode(saved_model_path, data, args, "raw")
        data.write_decoded_results(output_file, decode_results, 'raw')
    else:
        print("Invalid argument! Please use valid arguments! (train/test/decode)")
================================================
FILE: model/LGN.py
================================================
# -*- coding: utf-8 -*-
# @Author: Yicheng Zou
# @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn
from model.crf import CRF
from model.module import *
class Graph(nn.Module):
    def __init__(self, data, args):
        super(Graph, self).__init__()
        self.gpu = args.use_gpu
        self.char_emb_dim = args.char_dim
        self.word_emb_dim = args.word_dim
        self.hidden_dim = args.hidden_dim
        self.num_head = args.num_head  # 5 10 20
        self.head_dim = args.head_dim  # 10 20
        self.tf_dropout_rate = args.tf_drop_rate
        self.iters = args.iters
        self.bmes_dim = 10
        self.length_dim = 10
        self.max_word_length = 5
        self.emb_dropout_rate = args.emb_drop_rate
        self.cell_dropout_rate = args.cell_drop_rate
        self.use_crf = args.use_crf
        self.use_global = args.use_global
        self.use_edge = args.use_edge
        self.bidirectional = args.bidirectional
        self.label_size = args.label_alphabet_size
        # char embedding
        self.char_embedding = nn.Embedding(args.char_alphabet_size, self.char_emb_dim)
        if data.pretrain_char_embedding is not None:
            self.char_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_char_embedding))
        if self.use_edge:
            # word embedding
            self.word_embedding = nn.Embedding(args.word_alphabet_size, self.word_emb_dim)
            if data.pretrain_word_embedding is not None:
                # re-initialize row 0 (unknown word) uniformly in [-sqrt(3/d), sqrt(3/d)]
                scale = np.sqrt(3.0 / self.word_emb_dim)
                data.pretrain_word_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.word_emb_dim])
                self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding))
            # bmes embedding
            self.bmes_embedding = nn.Embedding(4, self.bmes_dim)
            """
            self.edge_emb_linear = nn.Sequential(
                nn.Linear(self.word_emb_dim, self.hidden_dim),
                nn.ELU()
            )
            """
        # lstm
        self.emb_rnn_f = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True)
        self.emb_rnn_b = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True)
        # length embedding
        self.length_embedding = nn.Embedding(self.max_word_length, self.length_dim)
        self.dropout = nn.Dropout(self.emb_dropout_rate)
        self.norm = nn.LayerNorm(self.hidden_dim)
        if self.use_edge:
            # Node aggregation module
            self.edge2node_f = nn.ModuleList(
                [MultiHeadAtt(self.hidden_dim, self.hidden_dim * 2 + self.length_dim,
                              nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                 for _ in range(self.iters)])
            # Edge aggregation module
            self.node2edge_f = nn.ModuleList(
                [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.bmes_dim, nhead=self.num_head,
                              head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                 for _ in range(self.iters)])
        else:
            # Node aggregation module
            self.edge2node_f = nn.ModuleList(
                [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.length_dim,
                              nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                 for _ in range(self.iters)])
        if self.use_global:
            # Global Node aggregation module
            self.glo_att_f_node = nn.ModuleList(
                [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                 for _ in range(self.iters)])
            if self.use_edge:
                self.glo_att_f_edge = nn.ModuleList(
                    [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                     for _ in range(self.iters)])
            # Updating modules
            if self.use_edge:
                self.glo_rnn_f = Global_Cell(self.hidden_dim * 3, self.hidden_dim, dropout=self.cell_dropout_rate)
                self.node_rnn_f = Nodes_Cell(self.hidden_dim * 5, self.hidden_dim, dropout=self.cell_dropout_rate)
                self.edge_rnn_f = Edges_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate)
            else:
                self.glo_rnn_f = Global_Cell(self.hidden_dim * 2, self.hidden_dim, dropout=self.cell_dropout_rate)
                self.node_rnn_f = Nodes_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate)
        else:
            # Updating modules
            self.node_rnn_f = Nodes_Cell(self.hidden_dim * 3, self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate)
            if self.use_edge:
                self.edge_rnn_f = Edges_Cell(self.hidden_dim * 2, self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate)
        if self.bidirectional:
            if self.use_edge:
                # Node aggregation module
                self.edge2node_b = nn.ModuleList(
                    [MultiHeadAtt(self.hidden_dim, self.hidden_dim * 2 + self.length_dim,
                                  nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                     for _ in range(self.iters)])
                # Edge aggregation module
                self.node2edge_b = nn.ModuleList(
                    [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.bmes_dim, nhead=self.num_head,
                                  head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                     for _ in range(self.iters)])
            else:
                # Node aggregation module
                self.edge2node_b = nn.ModuleList(
                    [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.length_dim,
                                  nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                     for _ in range(self.iters)])
            if self.use_global:
                # Global Node aggregation module
                self.glo_att_b_node = nn.ModuleList(
                    [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                     for _ in range(self.iters)])
                if self.use_edge:
                    self.glo_att_b_edge = nn.ModuleList(
                        [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                         for _ in range(self.iters)])
                # Updating modules (dropout passed by keyword, as in the forward direction)
                if self.use_edge:
                    self.glo_rnn_b = Global_Cell(self.hidden_dim * 3, self.hidden_dim, dropout=self.cell_dropout_rate)
                    self.node_rnn_b = Nodes_Cell(self.hidden_dim * 5, self.hidden_dim, dropout=self.cell_dropout_rate)
                    self.edge_rnn_b = Edges_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate)
                else:
                    self.glo_rnn_b = Global_Cell(self.hidden_dim * 2, self.hidden_dim, dropout=self.cell_dropout_rate)
                    self.node_rnn_b = Nodes_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate)
            else:
                # Updating modules
                self.node_rnn_b = Nodes_Cell(self.hidden_dim * 3, self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate)
                if self.use_edge:
                    self.edge_rnn_b = Edges_Cell(self.hidden_dim * 2, self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate)
        if self.bidirectional:
            output_dim = self.hidden_dim * 2
        else:
            output_dim = self.hidden_dim
        self.layer_att_W = nn.Linear(output_dim, 1)
        if self.use_crf:
            self.hidden2tag = nn.Linear(output_dim, self.label_size + 2)
            self.crf = CRF(self.label_size, self.gpu)
        else:
            self.hidden2tag = nn.Linear(output_dim, self.label_size)
            self.criterion = nn.CrossEntropyLoss()
    def construct_graph(self, batch_size, seq_len, word_list):
        # self.cuda is a bound method (always truthy), so use the configured GPU flag instead
        device = 'cuda' if self.gpu else 'cpu'
        if self.use_edge:
            unk_index = torch.tensor(0, device=device)
            unk_emb = self.word_embedding(unk_index)
            bmes_emb_b = self.bmes_embedding(torch.tensor(0, device=device))
            bmes_emb_m = self.bmes_embedding(torch.tensor(1, device=device))
            bmes_emb_e = self.bmes_embedding(torch.tensor(2, device=device))
            bmes_emb_s = self.bmes_embedding(torch.tensor(3, device=device))
        sen_nodes_mask_list = []
        sen_words_length_list = []
        sen_words_mask_f_list = []
        sen_words_mask_b_list = []
        sen_word_embed_list = []
        sen_bmes_embed_list = []
        max_edge_num = -1
        for sen in range(batch_size):
            sen_nodes_mask = torch.zeros([1, seq_len], device=device).byte()
            sen_words_length = torch.zeros([1, self.length_dim], device=device)
            sen_words_mask_f = torch.zeros([1, seq_len], device=device).byte()
            sen_words_mask_b = torch.zeros([1, seq_len], device=device).byte()
            if self.use_edge:
                sen_word_embed = unk_emb[None, :]
                sen_bmes_embed = torch.zeros([1, seq_len, self.bmes_dim], device=device)
            for w in range(seq_len):
                if w < len(word_list[sen]) and word_list[sen][w]:
                    for word, word_len in zip(word_list[sen][w][0], word_list[sen][w][1]):
                        if word_len <= self.max_word_length:
                            word_length_index = torch.tensor(word_len - 1, device=device)
                        else:
                            word_length_index = torch.tensor(self.max_word_length - 1, device=device)
                        word_length = self.length_embedding(word_length_index)
                        sen_words_length = torch.cat([sen_words_length, word_length[None, :]], 0)
                        # mask: Masked elements are marked by 1, batch_size * word_num * seq_len
                        nodes_mask = torch.ones([1, seq_len], device=device).byte()
                        words_mask_f = torch.ones([1, seq_len], device=device).byte()
                        words_mask_b = torch.ones([1, seq_len], device=device).byte()
                        words_mask_f[0, w + word_len - 1] = 0
                        sen_words_mask_f = torch.cat([sen_words_mask_f, words_mask_f], 0)
                        words_mask_b[0, w] = 0
                        sen_words_mask_b = torch.cat([sen_words_mask_b, words_mask_b], 0)
                        if self.use_edge:
                            word_index = torch.tensor(word, device=device)
                            word_embedding = self.word_embedding(word_index)
                            sen_word_embed = torch.cat([sen_word_embed, word_embedding[None, :]], 0)
                            bmes_embed = torch.zeros([1, seq_len, self.bmes_dim], device=device)
                            for index in range(word_len):
                                nodes_mask[0, w + index] = 0
                                if word_len == 1:
                                    bmes_embed[0, w + index, :] = bmes_emb_s
                                elif index == 0:
                                    bmes_embed[0, w + index, :] = bmes_emb_b
                                elif index == word_len - 1:
                                    bmes_embed[0, w + index, :] = bmes_emb_e
                                else:
                                    bmes_embed[0, w + index, :] = bmes_emb_m
                            sen_bmes_embed = torch.cat([sen_bmes_embed, bmes_embed], 0)
                            sen_nodes_mask = torch.cat([sen_nodes_mask, nodes_mask], 0)
            if sen_words_mask_f.size(0) > max_edge_num:
                max_edge_num = sen_words_mask_f.size(0)
            sen_words_mask_f_list.append(sen_words_mask_f.unsqueeze_(0))
            sen_words_mask_b_list.append(sen_words_mask_b.unsqueeze_(0))
            sen_words_length_list.append(sen_words_length.unsqueeze_(0))
            if self.use_edge:
                sen_nodes_mask_list.append(sen_nodes_mask.unsqueeze_(0))
                sen_word_embed_list.append(sen_word_embed.unsqueeze_(0))
                sen_bmes_embed_list.append(sen_bmes_embed.unsqueeze_(0))
        edges_mask = torch.zeros([batch_size, max_edge_num], device=device)
        batch_words_mask_f = torch.ones([batch_size, max_edge_num, seq_len], device=device).byte()
        batch_words_mask_b = torch.ones([batch_size, max_edge_num, seq_len], device=device).byte()
        batch_words_length = torch.zeros([batch_size, max_edge_num, self.length_dim], device=device)
        if self.use_edge:
            batch_nodes_mask = torch.zeros([batch_size, max_edge_num, seq_len], device=device).byte()
            batch_word_embed = torch.zeros([batch_size, max_edge_num, self.word_emb_dim], device=device)
            batch_bmes_embed = torch.zeros([batch_size, max_edge_num, seq_len, self.bmes_dim], device=device)
        else:
            batch_word_embed = None
            batch_bmes_embed = None
            batch_nodes_mask = None
        for index in range(batch_size):
            curr_edge_num = sen_words_mask_f_list[index].size(1)
            edges_mask[index, 0:curr_edge_num] = 1.
            batch_words_mask_f[index, 0:curr_edge_num, :] = sen_words_mask_f_list[index]
            batch_words_mask_b[index, 0:curr_edge_num, :] = sen_words_mask_b_list[index]
            batch_words_length[index, 0:curr_edge_num, :] = sen_words_length_list[index]
            if self.use_edge:
                batch_nodes_mask[index, 0:curr_edge_num, :] = sen_nodes_mask_list[index]
                batch_word_embed[index, 0:curr_edge_num, :] = sen_word_embed_list[index]
                batch_bmes_embed[index, 0:curr_edge_num, :, :] = sen_bmes_embed_list[index]
        return batch_word_embed, batch_bmes_embed, batch_nodes_mask, batch_words_mask_f, \
            batch_words_mask_b, batch_words_length, edges_mask
    def update_graph(self, word_list, word_inputs, mask):
        mask = mask.float()
        node_embeds = self.char_embedding(word_inputs)  # batch_size, max_seq_len, embedding
        B, L, _ = node_embeds.size()
        edge_embs, bmes_embs, nodes_mask, words_mask_f, words_mask_b, words_length, edges_mask = \
            self.construct_graph(B, L, word_list)
        node_embeds = self.dropout(node_embeds)
        _, N, _ = words_mask_f.size()
        if self.use_edge:
            edge_embs = self.dropout(edge_embs)
        # forward direction digraph
        nodes_f, _ = self.emb_rnn_f(node_embeds)
        nodes_f = nodes_f * mask.unsqueeze(2)
        nodes_f_cat = nodes_f[:, None, :, :]
        _, _, H = nodes_f.size()
        if self.use_edge:
            edges_f = edge_embs * edges_mask.unsqueeze(2)
            edges_f_cat = edges_f[:, None, :, :]
            if self.use_global:
                glo_f = edges_f.sum(1, keepdim=True) / edges_mask.sum(1, keepdim=True).unsqueeze_(2) + \
                    nodes_f.sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)
                glo_f_cat = glo_f[:, None, :, :]
        else:
            if self.use_global:
                glo_f = (nodes_f * mask.unsqueeze(2)).sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)
                glo_f_cat = glo_f[:, None, :, :]
        for i in range(self.iters):
            # Attention-based aggregation
            if self.use_edge and N > 1:
                bmes_nodes_f = torch.cat([nodes_f.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1)
                edges_att_f = self.node2edge_f[i](edges_f, bmes_nodes_f, nodes_mask.transpose(1, 2))
            nodes_begin_f = torch.sum(nodes_f[:, None, :, :] * (1 - words_mask_b)[:, :, :, None].float(), 2)
            nodes_begin_f = torch.cat([torch.zeros([B, 1, H], device=nodes_f.device), nodes_begin_f[:, 1:N, :]], 1)
            if self.use_edge:
                nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([edges_f, nodes_begin_f, words_length], -1).unsqueeze(2), words_mask_f)
                if self.use_global:
                    glo_att_f = torch.cat([self.glo_att_f_node[i](glo_f, nodes_f, (1 - mask).byte()),
                                           self.glo_att_f_edge[i](glo_f, edges_f, (1 - edges_mask).byte())], -1)
            else:
                nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([nodes_begin_f, words_length], -1).unsqueeze(2), words_mask_f)
                if self.use_global:
                    glo_att_f = self.glo_att_f_node[i](glo_f, nodes_f, (1 - mask).byte())
            # RNN-based update
            if self.use_edge and N > 1:
                if self.use_global:
                    edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :],
edges_att_f[:, 1:N, :], glo_att_f.expand(B, N-1, H*2))], 1)
else:
edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :], edges_att_f[:, 1:N, :])], 1)
edges_f_cat = torch.cat([edges_f_cat, edges_f[:, None, :, :]], 1)
edges_f = torch.cat([edges_f[:, 0:1, :], self.norm(torch.sum(edges_f_cat[:, :, 1:N, :], 1))], 1)
nodes_f_r = torch.cat([torch.zeros([B, 1, self.hidden_dim], device=nodes_f.device), nodes_f[:, 0:(L-1), :]], 1)
if self.use_global:
nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, nodes_att_f, glo_att_f.expand(B, L, -1))
else:
nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, nodes_att_f)
nodes_f_cat = torch.cat([nodes_f_cat, nodes_f[:, None, :, :]], 1)
nodes_f = self.norm(torch.sum(nodes_f_cat, 1))
if self.use_global:
glo_f = self.glo_rnn_f(glo_f, glo_att_f)
glo_f_cat = torch.cat([glo_f_cat, glo_f[:, None, :, :]], 1)
glo_f = self.norm(torch.sum(glo_f_cat, 1))
nodes_cat = nodes_f_cat
# backward direction digraph
if self.bidirectional:
nodes_b, _ = self.emb_rnn_b(torch.flip(node_embeds, [1]))
nodes_b = torch.flip(nodes_b, [1])
nodes_b = nodes_b * mask.unsqueeze(2)
nodes_b_cat = nodes_b[:, None, :, :]
if self.use_edge:
edges_b = edge_embs * edges_mask.unsqueeze(2)
edges_b_cat = edges_b[:, None, :, :]
if self.use_global:
glo_b = edges_b.sum(1, keepdim=True) / edges_mask.sum(1, keepdim=True).unsqueeze_(2) + \
nodes_b.sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)
glo_b_cat = glo_b[:, None, :, :]
else:
if self.use_global:
glo_b = nodes_b.sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)
glo_b_cat = glo_b[:, None, :, :]
for i in range(self.iters):
# Attention-based aggregation
if self.use_edge and N > 1:
bmes_nodes_b = torch.cat([nodes_b.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1)
edges_att_b = self.node2edge_b[i](edges_b, bmes_nodes_b, nodes_mask.transpose(1, 2))
nodes_begin_b = torch.sum(nodes_b[:, None, :, :] * (1 - words_mask_f)[:, :, :, None].float(), 2)
nodes_begin_b = torch.cat([torch.zeros([B, 1, H], device=nodes_b.device), nodes_begin_b[:, 1:N, :]], 1)
if self.use_edge:
nodes_att_b = self.edge2node_b[i](nodes_b, torch.cat([edges_b, nodes_begin_b, words_length], -1).unsqueeze(2), words_mask_b)
if self.use_global:
glo_att_b = torch.cat([self.glo_att_b_node[i](glo_b, nodes_b, (1-mask).byte()),
self.glo_att_b_edge[i](glo_b, edges_b, (1-edges_mask).byte())], -1)
else:
nodes_att_b = self.edge2node_b[i](nodes_b, torch.cat([nodes_begin_b, words_length], -1).unsqueeze(2), words_mask_b)
if self.use_global:
glo_att_b = self.glo_att_b_node[i](glo_b, nodes_b, (1-mask).byte())
# RNN-based update
if self.use_edge and N > 1:
if self.use_global:
edges_b = torch.cat([edges_b[:, 0:1, :], self.edge_rnn_b(edges_b[:, 1:N, :],
edges_att_b[:, 1:N, :], glo_att_b.expand(B, N-1, H*2))], 1)
else:
edges_b = torch.cat([edges_b[:, 0:1, :], self.edge_rnn_b(edges_b[:, 1:N, :], edges_att_b[:, 1:N, :])], 1)
edges_b_cat = torch.cat([edges_b_cat, edges_b[:, None, :, :]], 1)
edges_b = torch.cat([edges_b[:, 0:1, :], self.norm(torch.sum(edges_b_cat[:, :, 1:N, :], 1))], 1)
nodes_b_r = torch.cat([nodes_b[:, 1:L, :], torch.zeros([B, 1, self.hidden_dim], device=nodes_b.device)], 1)
if self.use_global:
nodes_b = self.node_rnn_b(nodes_b, nodes_b_r, nodes_att_b, glo_att_b.expand(B, L, -1))
else:
nodes_b = self.node_rnn_b(nodes_b, nodes_b_r, nodes_att_b)
nodes_b_cat = torch.cat([nodes_b_cat, nodes_b[:, None, :, :]], 1)
nodes_b = self.norm(torch.sum(nodes_b_cat, 1))
if self.use_global:
glo_b = self.glo_rnn_b(glo_b, glo_att_b)
glo_b_cat = torch.cat([glo_b_cat, glo_b[:, None, :, :]], 1)
glo_b = self.norm(torch.sum(glo_b_cat, 1))
nodes_cat = torch.cat([nodes_f_cat, nodes_b_cat], -1)
layer_att = torch.sigmoid(self.layer_att_W(nodes_cat))
layer_alpha = F.softmax(layer_att, 1)
nodes = torch.sum(layer_alpha * nodes_cat, 1)
tags = self.hidden2tag(nodes)
return tags
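`update_graph` keeps the node states from every iteration in `nodes_cat` and fuses them with a learned attention over the layer dimension. It computes a sigmoid score, applies a softmax across layers, and takes a weighted sum. The sketch below isolates that step. It assumes `layer_att_W` maps each feature vector to a single score; this is an assumption, because its definition in `__init__` is not reproduced here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerAttention(nn.Module):
    """Fuse per-iteration node states (B, K, L, D) into (B, L, D). Hypothetical module."""

    def __init__(self, dim):
        super().__init__()
        self.W = nn.Linear(dim, 1)                   # assumption: one score per layer/position

    def forward(self, nodes_cat):
        scores = torch.sigmoid(self.W(nodes_cat))    # (B, K, L, 1)
        alpha = F.softmax(scores, dim=1)             # normalise over the K layers
        return torch.sum(alpha * nodes_cat, dim=1)   # (B, L, D)


torch.manual_seed(0)
layer_att = LayerAttention(6)
states = torch.randn(2, 3, 5, 6)                     # B=2, K=3 iterations, L=5, D=6
fused = layer_att(states)
```

Because the weights sum to one over the layers, stacking identical layers returns that layer unchanged.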
def forward(self, word_list, batch_inputs, mask, batch_label=None):
tags = self.update_graph(word_list, batch_inputs, mask)
if batch_label is not None:
if self.use_crf:
total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label)
else:
total_loss = self.criterion(tags.view(-1, self.label_size), batch_label.view(-1))
else:
total_loss = None
if self.use_crf:
_, tag_seq = self.crf._viterbi_decode(tags, mask)
else:
tag_seq = tags.argmax(-1)
return total_loss, tag_seq
================================================
FILE: model/crf.py
================================================
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn
import torch
import torch.autograd as autograd
import torch.nn as nn
START_TAG = -2
STOP_TAG = -1
# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec, m_size):
"""
calculate log of exp sum
args:
vec (batch_size, vanishing_dim, hidden_dim) : input tensor
m_size : hidden_dim
return:
batch_size, hidden_dim
"""
_, idx = torch.max(vec, 1) # B * 1 * M
max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size) # B * M
return max_score.view(-1, m_size) + torch.log(torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size) # B * M
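`log_sum_exp` subtracts the per-column maximum before exponentiating, so large scores do not overflow. The check below reproduces the helper verbatim and compares it with `torch.logsumexp` (available in recent PyTorch releases) on inputs shaped `(batch, from_tag, to_tag)`.

```python
import torch


def log_sum_exp(vec, m_size):
    # Copy of the crf.py helper: reduce over dim 1 with the max-shift trick.
    _, idx = torch.max(vec, 1)
    max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size)
    return max_score.view(-1, m_size) + torch.log(
        torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size)


vec = torch.randn(4, 5, 5, dtype=torch.float64)
reference = torch.logsumexp(vec, dim=1)

big = torch.full((1, 3, 2), 1000.0)   # a naive exp(1000.) overflows to inf in float32
stable = log_sum_exp(big, 2)          # finite: 1000 + log(3) in every column
```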
class CRF(nn.Module):
def __init__(self, tagset_size, gpu):
super(CRF, self).__init__()
print ("build batched crf...")
self.gpu = gpu
# Matrix of transition parameters. Entry i,j is the score of transitioning *from* i *to* j.
self.average_batch = False
self.tagset_size = tagset_size
# # We add 2 here, because of START_TAG and STOP_TAG
# # transitions (f_tag_size, t_tag_size), transition value from f_tag to t_tag
init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2)
# init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2)
# init_transitions[:,START_TAG] = -1000.0
# init_transitions[STOP_TAG,:] = -1000.0
# init_transitions[:,0] = -1000.0
# init_transitions[0,:] = -1000.0
if self.gpu:
init_transitions = init_transitions.cuda()
self.transitions = nn.Parameter(init_transitions) #(t+2,t+2)
# self.transitions = nn.Parameter(torch.Tensor(self.tagset_size+2, self.tagset_size+2))
# self.transitions.data.zero_()
def _calculate_PZ(self, feats, mask):
"""
input:
feats: (batch, seq_len, self.tag_size+2) (b,m,t+2)
masks: (batch, seq_len) (b,m)
"""
batch_size = feats.size(0)
seq_len = feats.size(1)
tag_size = feats.size(2)
# print feats.view(seq_len, tag_size)
assert(tag_size == self.tagset_size+2)
mask = mask.transpose(1,0).contiguous() #(m,b)
ins_num = seq_len * batch_size
## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1)
feats = feats.transpose(1,0).contiguous().view(ins_num,1, tag_size).expand(ins_num, tag_size, tag_size) #(i,t+2,t+2) all t+2 rows along dim 1 are identical
## need to consider start
scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)
scores = scores.view(seq_len, batch_size, tag_size, tag_size)
# build iter
seq_iter = enumerate(scores)
_, inivalues = seq_iter.__next__() # bat_size * from_target_size * to_target_size (b,t,t) inivalues are the scores at the first character of each sentence
# only need start from start_tag
partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size (b,t,1)
## add start score (from start to all tag, duplicate to batch_size)
# partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1)
# iter over last scores
for idx, cur_values in seq_iter:
# previous to_target is current from_target
# partition: previous results log(exp(from_target)), #(batch_size * from_target)
# cur_values: bat_size * from_target * to_target
cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
cur_partition = log_sum_exp(cur_values, tag_size) #(b,t)
# print cur_partition.data
# (bat_size * from_target * to_target) -> (bat_size * to_target)
# partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1)
mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)
## effective updated partition part, only keep the partition value of mask value = 1
masked_cur_partition = cur_partition.masked_select(mask_idx)
## let mask_idx broadcastable, to disable warning
mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)
## replace the partition where the maskvalue=1, other partition value keeps the same
partition.masked_scatter_(mask_idx, masked_cur_partition)
# until the last state, add transition score for all partition (and do log_sum_exp) then select the value in STOP_TAG
cur_values = self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
cur_partition = log_sum_exp(cur_values, tag_size) #(batch_size,hidden_dim)
final_partition = cur_partition[:, STOP_TAG] #(batch_size)
return final_partition.sum(), scores #scores: (seq_len, batch, tag_size, tag_size)
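`_calculate_PZ` is the CRF forward algorithm. It starts from the `START_TAG` row, repeatedly applies log-sum-exp over the previous tag, and finally adds the transition into `STOP_TAG`. For a single unpadded sentence the recursion reduces to the sketch below, which can be checked against brute-force enumeration of every tag path. The names are hypothetical, and `transitions[i, j]` scores a move from tag `i` to tag `j`, matching the code. Like the batched version, the sketch does not exclude `START_TAG`/`STOP_TAG` from the inner positions.

```python
import itertools

import torch

START_TAG, STOP_TAG = -2, -1


def log_partition(feats, transitions):
    """log Z for one sentence. feats: (L, T) emissions; transitions[i, j]: score of i -> j."""
    alpha = transitions[START_TAG] + feats[0]              # (T,) scores after the first char
    for t in range(1, feats.size(0)):
        # alpha[i] + transitions[i, j] + feats[t, j], reduced over the previous tag i
        alpha = torch.logsumexp(alpha[:, None] + transitions + feats[t][None, :], dim=0)
    return torch.logsumexp(alpha + transitions[:, STOP_TAG], dim=0)


def path_score(feats, transitions, tags):
    """Unnormalised score of one tag path, including START/STOP transitions."""
    score = transitions[START_TAG, tags[0]] + transitions[tags[-1], STOP_TAG]
    for t, tag in enumerate(tags):
        score = score + feats[t, tag]
        if t > 0:
            score = score + transitions[tags[t - 1], tag]
    return score


torch.manual_seed(0)
T, L = 4, 3                                   # 4 tags (incl. START/STOP), 3 characters
feats = torch.randn(L, T, dtype=torch.float64)
transitions = torch.randn(T, T, dtype=torch.float64)
brute = torch.logsumexp(torch.stack([path_score(feats, transitions, list(p))
                                     for p in itertools.product(range(T), repeat=L)]), dim=0)
```

The batched code adds masking on top of this recursion, so each sentence's partition stops updating after its last real character.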
def _viterbi_decode(self, feats, mask):
"""
input:
feats: (batch, seq_len, self.tag_size+2)
mask: (batch, seq_len)
output:
decode_idx: (batch, seq_len) decoded sequence
path_score: (batch, 1) corresponding score for each sequence (to be implemented)
"""
batch_size = feats.size(0)
seq_len = feats.size(1)
tag_size = feats.size(2)
assert(tag_size == self.tagset_size+2)
## calculate sentence length for each sentence
length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long() #(batch_size,1) mask length (number of real characters) of each sentence
## mask to (seq_len, batch_size)
mask = mask.transpose(1,0).contiguous() #(seq_len,b)
ins_num = seq_len * batch_size
## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1)
feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) #(ins_num, tag_size, tag_size)
## need to consider start
scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)
scores = scores.view(seq_len, batch_size, tag_size, tag_size)
# build iter
seq_iter = enumerate(scores)
## record the position of best score
back_points = list()
partition_history = list()
## reverse mask (the plain `mask = 1 - mask` form was buggy, so this cast-based form is used instead)
# mask = 1 + (-1)*mask
mask = (1 - mask.long()).byte()
_, inivalues = seq_iter.__next__() # bat_size * from_target_size * to_target_size
# only need start from start_tag
partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size
partition_history.append(partition) #(seqlen,batch_size,tag_size,1)
# iter over last scores
for idx, cur_values in seq_iter:
# previous to_target is current from_target
# partition: previous results log(exp(from_target)), #(batch_size * from_target)
# cur_values: batch_size * from_target * to_target
cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
## forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1) # do not consider START_TAG/STOP_TAG
partition, cur_bp = torch.max(cur_values,dim=1)
partition_history.append(partition.unsqueeze(2))
## cur_bp: (batch_size, tag_size) max source score position in current tag
## set padded label as 0, which will be filtered in post processing
cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
back_points.append(cur_bp)
### add score to final STOP_TAG
partition_history = torch.cat(partition_history,dim=0).view(seq_len, batch_size,-1).transpose(1,0).contiguous() ## (batch_size, seq_len, tag_size)
### get the last position of each sentence, and select the last partitions using gather()
last_position = length_mask.view(batch_size,1,1).expand(batch_size, 1, tag_size) -1
last_partition = torch.gather(partition_history, 1, last_position).view(batch_size,tag_size,1)
### calculate the score from last partition to end state (and then select the STOP_TAG from it)
last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size)
_, last_bp = torch.max(last_values, 1) #(batch_size,tag_size)
pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size)).long()
if self.gpu:
pad_zero = pad_zero.cuda()
back_points.append(pad_zero)
back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)
## select end ids in STOP_TAG
pointer = last_bp[:, STOP_TAG] #(batch_size)
insert_last = pointer.contiguous().view(batch_size,1,1).expand(batch_size,1, tag_size)
back_points = back_points.transpose(1,0).contiguous() #(batch_size,sq_len,tag_size)
## move the end ids(expand to tag_size) to the corresponding position of back_points to replace the 0 values
# print "lp:",last_position
# print "il:",insert_last
back_points.scatter_(1, last_position, insert_last) ##(batch_size,sq_len,tag_size)
# print "bp:",back_points
# exit(0)
back_points = back_points.transpose(1,0).contiguous() #(seq_len, batch_size, tag_size)
## decode from the end; padded positions get id 0, which is filtered out in the following evaluation
decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size))
if self.gpu:
decode_idx = decode_idx.cuda()
decode_idx[-1] = pointer.data
for idx in range(len(back_points)-2, -1, -1):
pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) #pointer's size:(batch_size,1)
decode_idx[idx] = pointer.squeeze(1).data
path_score = None
decode_idx = decode_idx.transpose(1,0) #(batch_size, sent_len)
return path_score, decode_idx #
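`_viterbi_decode` follows the same recursion with `max` in place of log-sum-exp and stores the argmax back-pointers. It then walks them backwards from the best transition into `STOP_TAG`. Below is a single-sentence sketch using the same `transitions[i, j]` (from `i` to `j`) convention. The names are hypothetical, and the result can be verified against brute force on a tiny example.

```python
import itertools

import torch

START_TAG, STOP_TAG = -2, -1


def sequence_score(feats, transitions, tags):
    score = transitions[START_TAG, tags[0]] + transitions[tags[-1], STOP_TAG]
    for t, tag in enumerate(tags):
        score = score + feats[t, tag]
        if t > 0:
            score = score + transitions[tags[t - 1], tag]
    return score


def viterbi(feats, transitions):
    """Best tag path for one sentence. feats: (L, T); returns (score, list of tag ids)."""
    delta = transitions[START_TAG] + feats[0]
    back_points = []
    for t in range(1, feats.size(0)):
        delta, bp = (delta[:, None] + transitions + feats[t][None, :]).max(dim=0)
        back_points.append(bp)                 # bp[j]: best previous tag given tag j at step t
    best_score, last = (delta + transitions[:, STOP_TAG]).max(dim=0)
    path = [int(last)]
    for bp in reversed(back_points):           # follow the pointers from the last position
        path.append(int(bp[path[-1]]))
    path.reverse()
    return best_score, path


torch.manual_seed(1)
feats = torch.randn(4, 4, dtype=torch.float64)
transitions = torch.randn(4, 4, dtype=torch.float64)
best_score, best_path = viterbi(feats, transitions)
brute_path = max(itertools.product(range(4), repeat=4),
                 key=lambda p: float(sequence_score(feats, transitions, list(p))))
```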
def forward(self, feats, mask):
path_score, best_path = self._viterbi_decode(feats, mask)
return path_score, best_path
def _score_sentence(self, scores, mask, tags):
"""
input:
scores: variable (seq_len, batch, tag_size, tag_size)
mask: (batch, seq_len)
tags: tensor (batch, seq_len)
output:
score: sum of score for gold sequences within whole batch
"""
# Gives the score of a provided tag sequence
batch_size = scores.size(1)
seq_len = scores.size(0)
tag_size = scores.size(2)
## encode each (previous label, current label) bigram as a single flat index
new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len))
if self.gpu:
new_tags = new_tags.cuda()
for idx in range(seq_len):
if idx == 0:
## start -> first score
new_tags[:,0] = (tag_size - 2)*tag_size + tags[:,0]
else:
new_tags[:,idx] = tags[:,idx-1]*tag_size + tags[:,idx]
## transition for label to STOP_TAG
end_transition = self.transitions[:,STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size)
## length for batch, last word position = length - 1
length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long()
## index the label id of last word
end_ids = torch.gather(tags, 1, length_mask - 1)
## index the transition score for end_id to STOP_TAG
end_energy = torch.gather(end_transition, 1, end_ids)
## convert tag as (seq_len, batch_size, 1)
new_tags = new_tags.transpose(1,0).contiguous().view(seq_len, batch_size, 1)
### use the bigram ids to index the tag_size * tag_size flattened positions of scores
tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size) # seq_len * bat_size
## mask transpose to (seq_len, batch_size)
tg_energy = tg_energy.masked_select(mask.transpose(1,0))
# ## calculate the score from START_TAG to first label
# start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size)
# start_energy = torch.gather(start_transition, 1, tags[0,:])
## add all score together
# gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum()
gold_score = tg_energy.sum() + end_energy.sum()
return gold_score
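`_score_sentence` avoids a double index by encoding each (previous, current) label pair as `prev * tag_size + cur` and gathering from the flattened `tag_size * tag_size` score grid. The first character uses `tag_size - 2` (the `START_TAG` slot) as its previous label. A small sketch of that indexing trick for one sentence follows; `gold_energy` is a hypothetical name.

```python
import torch


def gold_energy(scores, tags, start_tag):
    """scores: (L, T, T) with scores[t, i, j] = emission[t, j] + transition[i, j]."""
    L, T, _ = scores.shape
    prev = [start_tag] + tags[:-1]                               # label before each position
    flat_idx = torch.tensor([p * T + c for p, c in zip(prev, tags)]).view(L, 1)
    return torch.gather(scores.reshape(L, T * T), 1, flat_idx).view(L)


torch.manual_seed(0)
T = 5
scores = torch.randn(3, T, T)
energy = gold_energy(scores, [1, 4, 0], start_tag=T - 2)
# energy[k] equals scores[k, prev_k, tag_k]
```

The gold score is these energies, masked to real characters, plus the transition from the last real label into `STOP_TAG`. The loss is then `log Z` minus that score.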
def neg_log_likelihood_loss(self, feats, mask, tags):
# negative log likelihood
batch_size = feats.size(0)
forward_score, scores = self._calculate_PZ(feats, mask) #forward_score: scalar log partition summed over the batch, scores: (seq_len, batch, tag_size, tag_size)
gold_score = self._score_sentence(scores, mask, tags)
#print ("batch, f:", forward_score.data, " g:", gold_score.data, " dis:", forward_score.data - gold_score.data)
# exit(0)
if self.average_batch:
return (forward_score - gold_score)/batch_size
else:
return forward_score - gold_score
================================================
FILE: model/module.py
================================================
# -*- coding: utf-8 -*-
# @Author: Yicheng Zou
# @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn
import torch
import math
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
class MultiHeadAtt(nn.Module):
def __init__(self, nhid, keyhid, nhead=10, head_dim=10, dropout=0.1, if_g=False):
super(MultiHeadAtt, self).__init__()
if if_g:
self.WQ = nn.Conv2d(nhid * 3, nhead * head_dim, 1)
else:
self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WK = nn.Conv2d(keyhid, nhead * head_dim, 1)
self.WV = nn.Conv2d(keyhid, nhead * head_dim, 1)
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)
self.drop = nn.Dropout(dropout)
self.norm = nn.LayerNorm(nhid)
self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim
def forward(self, query_h, value, mask, query_g=None):
if not (query_g is None):
query = torch.cat([query_h, query_g], -1)
else:
query = query_h
query = query.permute(0, 2, 1)[:, :, :, None]
value = value.permute(0, 3, 1, 2)
residual = query_h
nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim
B, QL, H = query_h.shape
_, _, VL, VD = value.shape # VD = 1 or VD = QL
assert VD == 1 or VD == QL
# q: (B, H, QL, 1)
# v: (B, H, VL, VD)
q, k, v = self.WQ(query), self.WK(value), self.WV(value)
q = q.view(B, nhead, head_dim, 1, QL)
k = k.view(B, nhead, head_dim, VL, VD)
v = v.view(B, nhead, head_dim, VL, VD)
alpha = (q * k).sum(2, keepdim=True) / np.sqrt(head_dim)
alpha = alpha.masked_fill(mask[:, None, None, :, :], -np.inf)
alpha = self.drop(F.softmax(alpha, 3))
att = (alpha * v).sum(3).view(B, nhead * head_dim, QL, 1)
output = F.leaky_relu(self.WO(att)).permute(0, 2, 3, 1).view(B, QL, H)
output = self.norm(output + residual)
return output
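`MultiHeadAtt` uses `nn.Conv2d(..., 1)` for its Q/K/V/O projections. A kernel-size-1 convolution over a `(B, C, H, W)` tensor is a per-position linear map over the channel axis. That is why the inputs are permuted so the hidden size sits in dim 1. The following quick equivalence check uses an `nn.Linear` that shares the same weights:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(8, 6, kernel_size=1)    # 8 input channels -> 6 output channels
linear = nn.Linear(8, 6)
with torch.no_grad():                    # give the Linear the same parameters
    linear.weight.copy_(conv.weight.view(6, 8))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, 8, 5, 3)              # (B, hidden, QL, VD), channels in dim 1
via_conv = conv(x)                       # (B, 6, 5, 3)
via_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
```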
class GloAtt(nn.Module):
def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1):
# Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value
super(GloAtt, self).__init__()
self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)
self.drop = nn.Dropout(dropout)
self.norm = nn.LayerNorm(nhid)
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim)
self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim
def forward(self, x, y, mask=None):
# x: (B, 1, H) -> (B, H, 1, 1); y: (B, L, H) -> (B, H, L, 1)
nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim
B, L, H = y.shape
x = x.permute(0, 2, 1)[:, :, :, None]
y = y.permute(0, 2, 1)[:, :, :, None]
residual = x
q, k, v = self.WQ(x), self.WK(y), self.WV(y)
q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h
k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L
v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h
pre_a = torch.matmul(q, k) / np.sqrt(head_dim)
if mask is not None:
pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf'))
alphas = self.drop(F.softmax(pre_a, 3)) # B, N, 1, L
att = torch.matmul(alphas, v).view(B, -1, 1, 1) # B, N, 1, h -> B, N*h, 1, 1
output = F.leaky_relu(self.WO(att)) + residual
output = self.norm(output.permute(0, 2, 3, 1)).view(B, 1, H)
return output
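`GloAtt` lets one global query of shape `(B, 1, H)` attend over a padded sequence. It fills padded positions with `-inf` before the softmax, so they receive exactly zero weight. Below is a single-head sketch of that masking. It uses a hypothetical `global_attention` and omits the learned projections, residual, and LayerNorm. It also uses a `bool` mask, since recent PyTorch versions expect boolean masks for `masked_fill`.

```python
import math

import torch
import torch.nn.functional as F


def global_attention(q, y, pad_mask):
    """q: (B, 1, H) global query; y: (B, L, H) sequence; pad_mask: (B, L) bool, True = padding."""
    scores = torch.matmul(q, y.transpose(1, 2)) / math.sqrt(q.size(-1))  # (B, 1, L)
    scores = scores.masked_fill(pad_mask[:, None, :], float('-inf'))
    alpha = F.softmax(scores, dim=-1)                                   # padded -> weight 0
    return torch.matmul(alpha, y), alpha                                # (B, 1, H), (B, 1, L)


torch.manual_seed(0)
q, y = torch.randn(2, 1, 4), torch.randn(2, 3, 4)
mask = torch.tensor([[False, False, True], [False, False, False]])
out, alpha = global_attention(q, y, mask)
```

If every position of a row were masked, the softmax would produce NaNs, so the caller has to guarantee at least one valid position.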
class Nodes_Cell(nn.Module):
def __init__(self, input_h, hid_h, use_global=True, dropout=0.2):
super(Nodes_Cell, self).__init__()
self.use_global = use_global
self.hidden_size = hid_h
self.Wix = nn.Linear(input_h, hid_h)
self.Wi2 = nn.Linear(input_h, hid_h)
self.Wf = nn.Linear(input_h, hid_h)
self.Wcx = nn.Linear(input_h, hid_h)
self.drop = nn.Dropout(dropout)
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
nn.init.uniform_(weight, -stdv, stdv)
def forward(self, h, h2, x, glo=None):
x = self.drop(x)
if self.use_global:
glo = self.drop(glo)
cat_all = torch.cat([h, h2, x, glo], -1)
else:
cat_all = torch.cat([h, h2, x], -1)
ix = torch.sigmoid(self.Wix(cat_all))
i2 = torch.sigmoid(self.Wi2(cat_all))
f = torch.sigmoid(self.Wf(cat_all))
cx = torch.tanh(self.Wcx(cat_all))
alpha = F.softmax(torch.cat([ix.unsqueeze(1), i2.unsqueeze(1), f.unsqueeze(1)], 1), 1)
output = (alpha[:, 0] * cx) + (alpha[:, 1] * h2) + (alpha[:, 2] * h)
return output
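`Nodes_Cell` does not use LSTM-style independent gates. Its three sigmoid gates pass through a softmax across the gate axis, so the new state is an element-wise convex combination of the candidate `cx`, the neighbouring state `h2` and the previous state `h`. `Edges_Cell` and `Global_Cell` do the same with two gates. A minimal sketch of that fusion step follows; `fuse_gates` is a hypothetical name.

```python
import math

import torch
import torch.nn.functional as F


def fuse_gates(gates, candidates):
    """gates, candidates: (B, K, D). Softmax over K, then a weighted sum of the K candidates."""
    alpha = F.softmax(gates, dim=1)
    return (alpha * candidates).sum(1), alpha


torch.manual_seed(0)
gates = torch.sigmoid(torch.randn(2, 3, 4))   # like ix, i2, f in Nodes_Cell
candidates = torch.randn(2, 3, 4)             # like cx, h2, h
fused, alpha = fuse_gates(gates, candidates)

# The softmax inputs lie in (0, 1), so each of the 3 weights stays inside
# (1 / (1 + 2e), e / (e + 2)), roughly (0.155, 0.576): no input is ever fully gated off.
low, high = 1 / (1 + 2 * math.e), math.e / (math.e + 2)
```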
class Edges_Cell(nn.Module):
def __init__(self, input_h, hid_h, use_global=True, dropout=0.2):
super(Edges_Cell, self).__init__()
self.use_global = use_global
self.hidden_size = hid_h
self.Wi = nn.Linear(input_h, hid_h)
self.Wf = nn.Linear(input_h, hid_h)
self.Wc = nn.Linear(input_h, hid_h)
self.drop = nn.Dropout(dropout)
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
nn.init.uniform_(weight, -stdv, stdv)
def forward(self, h, x, glo=None):
x = self.drop(x)
if self.use_global:
glo = self.drop(glo)
cat_all = torch.cat([h, x, glo], -1)
else:
cat_all = torch.cat([h, x], -1)
i = torch.sigmoid(self.Wi(cat_all))
f = torch.sigmoid(self.Wf(cat_all))
c = torch.tanh(self.Wc(cat_all))
alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1)
output = (alpha[:, 0] * c) + (alpha[:, 1] * h)
return output
class Global_Cell(nn.Module):
def __init__(self, input_h, hid_h, dropout=0.2):
super(Global_Cell, self).__init__()
self.hidden_size = hid_h
self.Wi = nn.Linear(input_h, hid_h)
self.Wf = nn.Linear(input_h, hid_h)
self.Wc = nn.Linear(input_h, hid_h)
self.drop = nn.Dropout(dropout)
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
nn.init.uniform_(weight, -stdv, stdv)
def forward(self, h, x):
x = self.drop(x)
cat_all = torch.cat([h, x], -1)
i = torch.sigmoid(self.Wi(cat_all))
f = torch.sigmoid(self.Wf(cat_all))
c = torch.tanh(self.Wc(cat_all))
alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1)
output = (alpha[:, 0] * c) + (alpha[:, 1] * h)
return output
================================================
FILE: utils/alphabet.py
================================================
# -*- coding: utf-8 -*-
# @Author: Max
# @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn
"""
Alphabet maps objects to integer ids. It provides a two-way mapping between indices and objects.
"""
import json
import os
class Alphabet:
def __init__(self, name, label=False, keep_growing=True):
self.__name = name
self.UNKNOWN = "</unk>"
self.label = label
self.instance2index = {}
self.instances = []
self.keep_growing = keep_growing
# Index 0 is reserved by default; real entries start from 1.
self.default_index = 0
self.next_index = 1
if not self.label:
self.add(self.UNKNOWN)
def clear(self, keep_growing=True):
self.instance2index = {}
self.instances = []
self.keep_growing = keep_growing
# Index 0 is reserved by default; real entries start from 1.
self.default_index = 0
self.next_index = 1
def add(self, instance):
if instance not in self.instance2index:
self.instances.append(instance)
self.instance2index[instance] = self.next_index
self.next_index += 1
def get_index(self, instance):
try:
return self.instance2index[instance]
except KeyError:
if self.keep_growing:
index = self.next_index
self.add(instance)
return index
else:
return self.instance2index[self.UNKNOWN]
def get_instance(self, index):
if index == 0:
# First index is occupied by the wildcard element.
return None
try:
return self.instances[index - 1]
except IndexError:
print('WARNING: Alphabet get_instance, unknown instance index {}, returning the first instance.'.format(index))
return self.instances[0]
def size(self):
return len(self.instances) + 1
def iteritems(self):
return self.instance2index.items()
def enumerate_items(self, start=1):
if start < 1 or start >= self.size():
raise IndexError("Enumerate is allowed between [1 : size of the alphabet)")
return zip(range(start, len(self.instances) + 1), self.instances[start - 1:])
def close(self):
self.keep_growing = False
def open(self):
self.keep_growing = True
def get_content(self):
return {'instance2index': self.instance2index, 'instances': self.instances}
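def from_json(self, data):
self.instances = data["instances"]
self.instance2index = data["instance2index"]
# keep next_index consistent with the loaded records so later add() calls do not reuse ids
self.next_index = len(self.instances) + 1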
def save(self, output_directory, name=None):
"""
Save both alphabet records (instance2index and instances) to the given directory.
:param output_directory: Directory to save the alphabet to.
:param name: The alphabet saving name, optional.
:return:
"""
saving_name = name if name else self.__name
try:
json.dump(self.get_content(), open(os.path.join(output_directory, saving_name + ".json"), 'w'))
except Exception as e:
print("Exception: Alphabet is not saved: " + repr(e))
def load(self, input_directory, name=None):
"""
Load both alphabet records from the given directory.
:param input_directory: Directory to load the alphabet from.
:param name: The alphabet loading name, optional.
:return:
"""
loading_name = name if name else self.__name
self.from_json(json.load(open(os.path.join(input_directory, loading_name + ".json"))))
================================================
FILE: utils/data.py
================================================
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn
import sys
from utils.alphabet import Alphabet
from utils.functions import *
from utils.word_trie import Word_Trie
class Data:
def __init__(self):
self.MAX_SENTENCE_LENGTH = 250
self.MAX_WORD_LENGTH = -1
self.number_normalized = True
self.norm_char_emb = True
self.norm_word_emb = False
self.char_alphabet = Alphabet('character')
self.label_alphabet = Alphabet('label', True)
self.word_dict = Word_Trie()
self.word_alphabet = Alphabet('word')
self.train_texts = []
self.dev_texts = []
self.test_texts = []
self.raw_texts = []
self.train_Ids = []
self.dev_Ids = []
self.test_Ids = []
self.raw_Ids = []
self.char_emb_dim = 50
self.word_emb_dim = 50
self.pretrain_char_embedding = None
self.pretrain_word_embedding = None
self.label_size = 0
def show_data_summary(self):
print("DATA SUMMARY:")
print(" MAX SENTENCE LENGTH: %s"%(self.MAX_SENTENCE_LENGTH))
print(" MAX WORD LENGTH: %s"%(self.MAX_WORD_LENGTH))
print(" Number normalized: %s"%(self.number_normalized))
print(" Word alphabet size: %s"%(self.word_alphabet.size()))
print(" Char alphabet size: %s"%(self.char_alphabet.size()))
print(" Label alphabet size: %s"%(self.label_alphabet.size()))
print(" Word embedding size: %s"%(self.word_emb_dim))
print(" Char embedding size: %s"%(self.char_emb_dim))
print(" Norm char emb: %s"%(self.norm_char_emb))
print(" Norm word emb: %s"%(self.norm_word_emb))
print(" Train instance number: %s"%(len(self.train_texts)))
print(" Dev instance number: %s"%(len(self.dev_texts)))
print(" Test instance number: %s"%(len(self.test_texts)))
print(" Raw instance number: %s"%(len(self.raw_texts)))
print("DATA SUMMARY END.")
sys.stdout.flush()
def build_alphabet(self, input_file):
self.char_alphabet.open()
self.label_alphabet.open()
with open(input_file, 'r', encoding="utf-8") as f:
for line in f:
line = line.strip()
if len(line) == 0:
continue
pair = line.split()
char = pair[0]
if self.number_normalized:
# Mapping numbers to 0
char = normalize_word(char)
label = pair[-1]
self.label_alphabet.add(label)
self.char_alphabet.add(char)
self.label_alphabet.close()
self.char_alphabet.close()
def build_word_file(self, word_file):
# build the word trie (lexicon) from the first column of the word embedding file
with open(word_file, 'r', encoding="utf-8") as f:
for line in f:
fields = line.strip().split()
if fields:
self.word_dict.insert(fields[0])
print("Building the word dict...")
def build_word_alphabet(self, input_file):
print("Loading file: " + input_file)
self.word_alphabet.open()
word_list = []
with open(input_file, 'r', encoding="utf-8") as f:
for line in f:
line = line.strip()
if len(line) > 0:
word = line.split()[0]
if self.number_normalized:
word = normalize_word(word)
word_list.append(word)
else:
for idx in range(len(word_list)):
matched_words = self.word_dict.recursive_search(word_list[idx:])
for matched_word in matched_words:
self.word_alphabet.add(matched_word)
word_list = []
self.word_alphabet.close()
print("word alphabet size:", self.word_alphabet.size())
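`build_word_file` inserts every lexicon word into `Word_Trie`. `build_word_alphabet` then asks, for each character position, which lexicon words start there. `utils/word_trie.py` is not reproduced here, so the trie below is only a hypothetical stand-in. Its `match_prefixes` mimics how `recursive_search` is used in this code: it returns the lexicon words that are prefixes of the remaining characters.

```python
_END = object()   # sentinel key marking the end of an inserted word


class PrefixTrie:
    """Hypothetical stand-in for Word_Trie (its real interface may differ)."""

    def __init__(self):
        self.root = {}

    def insert(self, word):
        node = self.root
        for ch in word:
            node = node.setdefault(ch, {})
        node[_END] = True

    def match_prefixes(self, chars):
        """Return every inserted word that is a prefix of the character list `chars`."""
        node, matched = self.root, []
        for i, ch in enumerate(chars):
            if ch not in node:
                break
            node = node[ch]
            if _END in node:
                matched.append("".join(chars[:i + 1]))
        return matched


trie = PrefixTrie()
for word in ["南京", "南京市", "市长", "长江大桥"]:
    trie.insert(word)
sentence = list("南京市长江大桥")
matches = [trie.match_prefixes(sentence[i:]) for i in range(len(sentence))]
# matches[0] -> ['南京', '南京市'], matches[2] -> ['市长'], matches[3] -> ['长江大桥']
```

Overlapping matches such as "南京市" and "市长" are kept side by side; each becomes a lexicon edge, and the model has to resolve the ambiguity.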
def build_char_pretrain_emb(self, emb_path):
print ("Building character pretrain emb...")
self.pretrain_char_embedding, self.char_emb_dim = build_pretrain_embedding(emb_path, self.char_alphabet, self.norm_char_emb)
def build_word_pretrain_emb(self, emb_path):
print ("Building word pretrain emb...")
self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(emb_path, self.word_alphabet, self.norm_word_emb)
def generate_instance_with_words(self, input_file, name):
if name == "train":
self.train_texts, self.train_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet,
self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
elif name == "dev":
self.dev_texts, self.dev_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet,
self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
elif name == "test":
self.test_texts, self.test_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet,
self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
elif name == "raw":
self.raw_texts, self.raw_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet,
self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
else:
print("Error: you can only generate train/dev/test/raw instance! Illegal input:%s"%(name))
def write_decoded_results(self, output_file, predict_results, name):
fout = open(output_file, 'w', encoding="utf-8")
sent_num = len(predict_results)
content_list = []
if name == 'raw':
content_list = self.raw_texts
elif name == 'test':
content_list = self.test_texts
elif name == 'dev':
content_list = self.dev_texts
elif name == 'train':
content_list = self.train_texts
else:
print("Error: illegal name when writing prediction results; name should be one of train/dev/test/raw!")
assert(sent_num == len(content_list))
for idx in range(sent_num):
sent_length = len(predict_results[idx])
for idy in range(sent_length):
# content_list[idx] is a list with [chars, words, labels]
fout.write(content_list[idx][0][idy] + " " + predict_results[idx][idy] + '\n')
fout.write('\n')
fout.close()
print("Prediction results for %s have been written to file: %s"%(name, output_file))
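`write_decoded_results` writes one `char label` pair per line, taking the character from `content_list[idx][0]` and the label from the prediction, and ends each sentence with a blank line; this is the same column layout that `read_instance_with_gaz` reads. The sketch below reproduces that layout in memory instead of writing a file; `format_decoded` is a hypothetical helper, not part of the repository.

```python
from typing import List


def format_decoded(chars_per_sent: List[List[str]], labels_per_sent: List[List[str]]) -> str:
    # Hypothetical helper: same "char label" layout as Data.write_decoded_results.
    assert len(chars_per_sent) == len(labels_per_sent)
    lines: List[str] = []
    for chars, labels in zip(chars_per_sent, labels_per_sent):
        # One line per predicted label, paired with the character at that position.
        for idy in range(len(labels)):
            lines.append(chars[idy] + " " + labels[idy])
        lines.append("")  # empty line closes the sentence
    # Every line, including each sentence separator, ends with a newline.
    return "".join(line + "\n" for line in lines)
```

Because the result has a blank line after every sentence, it can be read back by a reader that, like `read_instance_with_gaz`, treats an empty line as a sentence boundary.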
================================================
FILE: utils/functions.py
================================================
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn
import numpy as np
def normalize_word(word):
new_word = ""
for char in word:
if char.isdigit():
new_word += '0'
else:
new_word += char
return new_word
def read_instance_with_gaz(input_file, word_dict, char_alphabet, word_alphabet, label_alphabet, number_normalized, max_sent_length):
instance_texts = []
instance_Ids = []
with open(input_file, 'r', encoding="utf-8") as f:
chars = []
labels = []
char_Ids = []
label_Ids = []
for line in f:
if len(line) > 1:
pairs = line.strip().split()
char = pairs[0]
if number_normalized:
char = normalize_word(char)
chars.append(char)
char_Ids.append(char_alphabet.get_index(char))
if len(pairs) > 1:
label = pairs[-1]
else:
label = 'O'
labels.append(label)
label_Ids.append(label_alphabet.get_index(label))
# A sentence is finished.
else:
# Only keep non-empty sentences shorter than max_sent_length (a negative max_sent_length disables the limit).
if ((max_sent_length < 0) or (len(chars) < max_sent_length)) and (len(chars)>0):
words = []
word_Ids = []
for idx in range(len(chars)):
# All lexicon words (length >= 2) starting at position idx, longest first.
matched_list = word_dict.recursive_search(chars[idx:])
matched_length = [len(a) for a in matched_list]
words.append(matched_list)
matched_Id = [word_alphabet.get_index(word) for word in matched_list]
if matched_Id:
word_Ids.append([matched_Id, matched_length])
else:
word_Ids.append([])
instance_texts.append([chars, words, labels])
instance_Ids.append([char_Ids, word_Ids, label_Ids])
chars = []
labels = []
char_Ids = []
label_Ids = []
return instance_texts, instance_Ids
def build_pretrain_embedding(embedding_path, word_alphabet, norm=True, embedd_dim=50):
def norm2one(vec):
root_sum_square = np.sqrt(np.sum(np.square(vec)))
return vec / root_sum_square
embedd_dict = dict()
if embedding_path is not None:
embedd_dict, embedd_dim = load_pretrain_emb(embedding_path)
scale = np.sqrt(3.0 / embedd_dim)
pretrain_emb = np.zeros([word_alphabet.size(), embedd_dim])  # zeros rather than np.empty, so any row not filled below is not left as uninitialized memory
not_match = 0
for word, index in word_alphabet.instance2index.items():
if word.lower() in embedd_dict:
if norm:
pretrain_emb[index,:] = norm2one(embedd_dict[word.lower()])
else:
pretrain_emb[index,:] = embedd_dict[word.lower()]
elif word in embedd_dict:
if norm:
pretrain_emb[index,:] = norm2one(embedd_dict[word])
else:
pretrain_emb[index,:] = embedd_dict[word]
else:
pretrain_emb[index,:] = np.random.uniform(-scale, scale, [1, embedd_dim])
not_match += 1
pretrained_size = len(embedd_dict)
print("Embedding:\n pretrain word:%s, match:%s, oov:%s, oov%%:%.4f" %
(pretrained_size, word_alphabet.size() - not_match, not_match, (not_match+0.)/word_alphabet.size()))
return pretrain_emb, embedd_dim
def load_pretrain_emb(embedding_path):
embedd_dict = dict()
with open(embedding_path, 'r', encoding="utf-8") as f:
for line in f:
line = line.strip()
if len(line) == 0:
continue
tokens = line.split()
embedd_dim = len(tokens) - 1
embedd = np.empty([1, embedd_dim])
embedd[:] = tokens[1:]
embedd_dict[tokens[0]] = embedd
return embedd_dict, embedd_dim
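`build_pretrain_embedding` looks each alphabet entry up in the pretrained table (lowercased form first, then the exact form), optionally rescales a matched vector to unit L2 norm with `norm2one`, and otherwise draws the row from U(-a, a) with a = sqrt(3 / d), so each component has variance a^2 / 3 = 1 / d for embedding size d. A minimal NumPy sketch of those three pieces, assuming NumPy >= 1.17 for `default_rng`; `lookup_vector` and `random_oov_rows` are hypothetical names, and the original draws from the global `np.random` state rather than a seeded generator.

```python
from typing import Dict, Optional

import numpy as np


def norm2one(vec: np.ndarray) -> np.ndarray:
    # Rescale a vector to unit L2 norm (same formula as norm2one in functions.py).
    return vec / np.sqrt(np.sum(np.square(vec)))


def lookup_vector(word: str, embedd_dict: Dict[str, np.ndarray]) -> Optional[np.ndarray]:
    # Same priority as build_pretrain_embedding: lowercased form first, then exact form.
    if word.lower() in embedd_dict:
        return embedd_dict[word.lower()]
    if word in embedd_dict:
        return embedd_dict[word]
    return None  # OOV: the caller falls back to a random row


def random_oov_rows(n_rows: int, embedd_dim: int, seed: int = 0) -> np.ndarray:
    # U(-a, a) with a = sqrt(3 / d) has variance a^2 / 3 = 1 / d per component.
    scale = np.sqrt(3.0 / embedd_dim)
    rng = np.random.default_rng(seed)  # seeded here for reproducibility (an assumption)
    return rng.uniform(-scale, scale, size=(n_rows, embedd_dim))
```

Tying the scale to d keeps random OOV rows at roughly the same overall magnitude whatever the embedding size, which is one plausible reason for this initialization.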
================================================
FILE: utils/metric.py
================================================
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn
# input as sentence level labels
def get_ner_fmeasure(golden_lists, predict_lists):
sent_num = len(golden_lists)
golden_full = []
predict_full = []
right_full = []
right_tag = 0
all_tag = 0
for idx in range(0,sent_num):
golden_list = golden_lists[idx]
predict_list = predict_lists[idx]
for idy in range(len(golden_list)):
if golden_list[idy] == predict_list[idy]:
right_tag += 1
all_tag += len(golden_list)
gold_matrix = get_ner_BMES(golden_list)
pred_matrix = get_ner_BMES(predict_list)
right_ner = list(set(gold_matrix).intersection(set(pred_matrix)))
golden_full += gold_matrix
predict_full += pred_matrix
right_full += right_ner
right_num = len(right_full)
golden_num = len(golden_full)
predict_num = len(predict_full)
if predict_num == 0:
precision = -1
else:
precision = (right_num+0.0)/predict_num
if golden_num == 0:
recall = -1
else:
recall = (right_num+0.0)/golden_num
if (precision == -1) or (recall == -1) or (precision+recall) <= 0.:
f_measure = -1
else:
f_measure = 2*precision*recall/(precision+recall)
accuracy = (right_tag+0.0)/all_tag if all_tag > 0 else -1
print("gold_num = ", golden_num, " pred_num = ", predict_num, " right_num = ", right_num)
return accuracy, precision, recall, f_measure
def reverse_style(input_string):
target_position = input_string.index('[')
input_len = len(input_string)
output_string = input_string[target_position:input_len] + input_string[0:target_position]
return output_string
def get_ner_BMES(label_list):
list_len = len(label_list)
begin_label = 'B-'
end_label = 'E-'
single_label = 'S-'
whole_tag = ''
index_tag = ''
tag_list = []
stand_matrix = []
for i in range(0, list_len):
# wordlabel = word_list[i]
current_label = label_list[i].upper() if label_list[i] else ''
if begin_label in current_label:
if index_tag != '':
tag_list.append(whole_tag + ',' + str(i-1))
whole_tag = current_label.replace(begin_label,"",1) +'[' +str(i)
index_tag = current_label.replace(begin_label,"",1)
elif single_label in current_label:
if index_tag != '':
tag_list.append(whole_tag + ',' + str(i-1))
whole_tag = current_label.replace(single_label,"",1) +'[' +str(i)
tag_list.append(whole_tag)
whole_tag = ""
index_tag = ""
elif end_label in current_label:
if index_tag != '':
tag_list.append(whole_tag +',' + str(i))
whole_tag = ''
index_tag = ''
else:
continue
if (whole_tag != '') and (index_tag != ''):
tag_list.append(whole_tag)
tag_list_len = len(tag_list)
for i in range(0, tag_list_len):
if len(tag_list[i]) > 0:
tag_list[i] = tag_list[i]+ ']'
insert_list = reverse_style(tag_list[i])
stand_matrix.append(insert_list)
return stand_matrix
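`get_ner_fmeasure` scores at the entity level: `get_ner_BMES` turns each tag sequence into strings such as `[0,1]PER`, and a prediction counts as correct only if its type, start and end all match a gold entity. The sketch below restates that with `(type, start, end)` tuples. It is a simplified version that handles well-formed BMES sequences only and does not reproduce how `get_ner_BMES` closes or keeps ill-formed spans (for example a `B-` tag with no matching `E-`); `bmes_spans` and `entity_prf` are hypothetical names.

```python
from typing import List, Set, Tuple

Span = Tuple[str, int, int]  # (entity type, first index, last index), inclusive


def bmes_spans(labels: List[str]) -> Set[Span]:
    """Entity spans of a well-formed BMES tag sequence (simplified sketch)."""
    spans: Set[Span] = set()
    start, etype = None, None
    for i, tag in enumerate(labels):
        tag = tag.upper()
        if tag.startswith("S-"):
            spans.add((tag[2:], i, i))
            start, etype = None, None
        elif tag.startswith("B-"):
            start, etype = i, tag[2:]
        elif tag.startswith("E-") and start is not None:
            # As in get_ner_BMES, the type comes from the opening B- tag.
            spans.add((etype, start, i))
            start, etype = None, None
        # M- and O tags neither open nor close a span.
    return spans


def entity_prf(gold: List[List[str]], pred: List[List[str]]) -> Tuple[float, float, float]:
    right = golden = predicted = 0
    for gold_tags, pred_tags in zip(gold, pred):
        gold_set, pred_set = bmes_spans(gold_tags), bmes_spans(pred_tags)
        right += len(gold_set & pred_set)
        golden += len(gold_set)
        predicted += len(pred_set)
    # -1 marks an undefined score, following get_ner_fmeasure.
    precision = right / predicted if predicted else -1.0
    recall = right / golden if golden else -1.0
    if precision == -1 or recall == -1 or precision + recall <= 0:
        return precision, recall, -1.0
    return precision, recall, 2 * precision * recall / (precision + recall)
```

For gold `B-PER E-PER O S-LOC` against prediction `B-PER E-PER O O`, the prediction finds one of the two gold entities, so precision is 1.0, recall is 0.5 and F1 is 2/3.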
================================================
FILE: utils/word_trie.py
================================================
# -*- coding: utf-8 -*-
# @Author: Yicheng Zou
# @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn
_end = "_end_"
class Word_Trie:
def __init__(self):
self.root = dict()
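# Return every lexicon word that is a prefix of word_list, longest first.
# Only words of length >= 2 are matched, and word_list is truncated in place,
# so callers should pass a copy (e.g. a slice such as chars[idx:]).
def recursive_search(self, word_list):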
match_list = []
while len(word_list) > 1:
if self.search(word_list):
match_list.append("".join(word_list))
del word_list[-1]
return match_list
def search(self, word):
current_dict = self.root
for char in word:
if char in current_dict:
current_dict = current_dict[char]
else:
return False
else:
if _end in current_dict:
return True
else:
return False
def insert(self, word):
current_dict = self.root
for char in word:
current_dict = current_dict.setdefault(char, {})
current_dict[_end] = _end
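`recursive_search` finds the lexicon words that start at a position by calling `search` on the whole remaining sequence and then on ever-shorter prefixes. The sketch below swaps in a single trie walk per start position, which visits each character at most once and, for the same lexicon, should return the same words in the same longest-first order; like `recursive_search`, it skips one-character words. `build_trie`, `prefix_matches` and `lattice` are hypothetical names; `lattice` mirrors the per-character `words` lists built in `read_instance_with_gaz`.

```python
from typing import Dict, List

_END = "_end_"  # end-of-word marker, same role as _end in word_trie.py


def build_trie(words: List[str]) -> Dict:
    # Nested dicts keyed by character, as in Word_Trie.insert.
    root: Dict = {}
    for word in words:
        node = root
        for char in word:
            node = node.setdefault(char, {})
        node[_END] = _END
    return root


def prefix_matches(root: Dict, chars: List[str]) -> List[str]:
    """Lexicon words of length >= 2 that start at chars[0], longest first."""
    node, matches = root, []
    for i, char in enumerate(chars):
        if char not in node:
            break  # no lexicon word continues with this character
        node = node[char]
        if _END in node and i >= 1:
            matches.append("".join(chars[: i + 1]))
    return matches[::-1]  # recursive_search reports the longest prefix first


def lattice(root: Dict, chars: List[str]) -> List[List[str]]:
    # One list of matched words per character position.
    return [prefix_matches(root, chars[idx:]) for idx in range(len(chars))]
```

For the lexicon 南京, 南京市, 市长, 长江, 长江大桥, 大桥 and the sentence 南京市长江大桥, position 0 yields 南京市 and 南京, and position 3 yields 长江大桥 and 长江; the lengths stored next to the word ids in `word_Ids` are just `len()` of these strings.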
SYMBOL INDEX (78 symbols across 9 files)
FILE: main.py
function str2bool (line 21) | def str2bool(v):
function lr_decay (line 32) | def lr_decay(optimizer, epoch, decay_rate, init_lr):
function data_initialization (line 43) | def data_initialization(data, word_file, train_file, dev_file, test_file):
function predict_check (line 59) | def predict_check(pred_variable, gold_variable, mask_variable):
function recover_label (line 70) | def recover_label(pred_variable, gold_variable, mask_variable, label_alp...
function print_args (line 90) | def print_args(args):
function evaluate (line 113) | def evaluate(data, args, model, name):
function batchify_with_label (line 160) | def batchify_with_label(input_batch_list, gpu):
function train (line 186) | def train(data, args, saved_model_path):
function load_model_decode (line 337) | def load_model_decode(model_dir, data, args, name):
FILE: model/LGN.py
class Graph (line 9) | class Graph(nn.Module):
method __init__ (line 10) | def __init__(self, data, args):
method construct_graph (line 169) | def construct_graph(self, batch_size, seq_len, word_list):
method update_graph (line 282) | def update_graph(self, word_list, word_inputs, mask):
method forward (line 441) | def forward(self, word_list, batch_inputs, mask, batch_label=None):
FILE: model/crf.py
function log_sum_exp (line 13) | def log_sum_exp(vec, m_size):
class CRF (line 27) | class CRF(nn.Module):
method __init__ (line 29) | def __init__(self, tagset_size, gpu):
method _calculate_PZ (line 51) | def _calculate_PZ(self, feats, mask):
method _viterbi_decode (line 105) | def _viterbi_decode(self, feats, mask):
method forward (line 193) | def forward(self, feats):
method _score_sentence (line 198) | def _score_sentence(self, scores, mask, tags):
method neg_log_likelihood_loss (line 249) | def neg_log_likelihood_loss(self, feats, mask, tags):
FILE: model/module.py
class MultiHeadAtt (line 13) | class MultiHeadAtt(nn.Module):
method __init__ (line 14) | def __init__(self, nhid, keyhid, nhead=10, head_dim=10, dropout=0.1, i...
method forward (line 31) | def forward(self, query_h, value, mask, query_g=None):
class GloAtt (line 67) | class GloAtt(nn.Module):
method __init__ (line 68) | def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1):
method forward (line 83) | def forward(self, x, y, mask=None):
class Nodes_Cell (line 109) | class Nodes_Cell(nn.Module):
method __init__ (line 110) | def __init__(self, input_h, hid_h, use_global=True, dropout=0.2):
method reset_parameters (line 123) | def reset_parameters(self):
method forward (line 128) | def forward(self, h, h2, x, glo=None):
class Edges_Cell (line 149) | class Edges_Cell(nn.Module):
method __init__ (line 150) | def __init__(self, input_h, hid_h, use_global=True, dropout=0.2):
method reset_parameters (line 163) | def reset_parameters(self):
method forward (line 168) | def forward(self, h, x, glo=None):
class Global_Cell (line 188) | class Global_Cell(nn.Module):
method __init__ (line 189) | def __init__(self, input_h, hid_h, dropout=0.2):
method reset_parameters (line 201) | def reset_parameters(self):
method forward (line 206) | def forward(self, h, x):
FILE: utils/alphabet.py
class Alphabet (line 12) | class Alphabet:
method __init__ (line 13) | def __init__(self, name, label=False, keep_growing=True):
method clear (line 27) | def clear(self, keep_growing=True):
method add (line 36) | def add(self, instance):
method get_index (line 42) | def get_index(self, instance):
method get_instance (line 53) | def get_instance(self, index):
method size (line 63) | def size(self):
method iteritems (line 66) | def iteritems(self):
method enumerate_items (line 69) | def enumerate_items(self, start=1):
method close (line 74) | def close(self):
method open (line 77) | def open(self):
method get_content (line 80) | def get_content(self):
method from_json (line 83) | def from_json(self, data):
method save (line 87) | def save(self, output_directory, name=None):
method load (line 100) | def load(self, input_directory, name=None):
FILE: utils/data.py
class Data (line 11) | class Data:
method __init__ (line 12) | def __init__(self):
method show_data_summary (line 38) | def show_data_summary(self):
method build_alphabet (line 57) | def build_alphabet(self, input_file):
method build_word_file (line 78) | def build_word_file(self, word_file):
method build_word_alphabet (line 87) | def build_word_alphabet(self, input_file):
method build_char_pretrain_emb (line 108) | def build_char_pretrain_emb(self, emb_path):
method build_word_pretrain_emb (line 112) | def build_word_pretrain_emb(self, emb_path):
method generate_instance_with_words (line 116) | def generate_instance_with_words(self, input_file, name):
method write_decoded_results (line 132) | def write_decoded_results(self, output_file, predict_results, name):
FILE: utils/functions.py
function normalize_word (line 8) | def normalize_word(word):
function read_instance_with_gaz (line 18) | def read_instance_with_gaz(input_file, word_dict, char_alphabet, word_al...
function build_pretrain_embedding (line 71) | def build_pretrain_embedding(embedding_path, word_alphabet, norm=True, e...
function load_pretrain_emb (line 104) | def load_pretrain_emb(embedding_path):
FILE: utils/metric.py
function get_ner_fmeasure (line 7) | def get_ner_fmeasure(golden_lists, predict_lists):
function reverse_style (line 49) | def reverse_style(input_string):
function get_ner_BMES (line 56) | def get_ner_BMES(label_list):
FILE: utils/word_trie.py
class Word_Trie (line 8) | class Word_Trie:
method __init__ (line 9) | def __init__(self):
method recursive_search (line 12) | def recursive_search(self, word_list):
method search (line 20) | def search(self, word):
method insert (line 33) | def insert(self, word):