Repository: RowitZou/LGN Branch: master Commit: c1e7a681f514 Files: 11 Total size: 84.2 KB Directory structure: gitextract_y25cez6i/ ├── LICENSE ├── README.md ├── main.py ├── model/ │ ├── LGN.py │ ├── crf.py │ └── module.py └── utils/ ├── alphabet.py ├── data.py ├── functions.py ├── metric.py └── word_trie.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2019 Yicheng Zou Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # LGN Pytorch implementation of [A Lexicon-Based Graph Neural Network for Chinese NER](https://www.aclweb.org/anthology/D19-1096.pdf). The code is partially referred to https://github.com/jiesutd/LatticeLSTM. 
## Requirements * Python 3.6 or higher * PyTorch 0.4.1 or higher ## Input Format BMES tag scheme, with each character and its label on one line. Sentences are separated by a blank line. 印 B-LOC 度 M-LOC 河 E-LOC 流 O 经 O 印 B-GPE 度 E-GPE ## Usage * Training python main.py --status train \ --train data/onto4ner.cn/train.char.bmes \ --dev data/onto4ner.cn/dev.char.bmes \ --test data/onto4ner.cn/test.char.bmes \ --saved_model saved_model/model_onto4ner \ --saved_set data/onto4ner.cn/saved.dset * Testing python main.py --status test \ --test data/onto4ner.cn/test.char.bmes \ --saved_model saved_model/model_onto4ner \ --saved_set data/onto4ner.cn/saved.dset * Decoding (The raw file may be either labeled or unlabeled.) python main.py --status decode \ --raw data/onto4ner.cn/test.char.bmes \ --output tagged_file.txt \ --saved_model saved_model/model_onto4ner \ --saved_set data/onto4ner.cn/saved.dset ## Data The pretrained character and word embeddings can be downloaded from [Lattice LSTM](https://github.com/jiesutd/LatticeLSTM). Original datasets can be found at [OntoNotes](https://catalog.ldc.upenn.edu/LDC2011T03), [MSRA](http://sighan.cs.uchicago.edu/bakeoff2006/), [Weibo](https://github.com/hltcoe/golden-horse) and [Resume](https://github.com/jiesutd/LatticeLSTM/tree/master/ResumeNER). The preprocessed datasets that satisfy the input format of our code are available at [Google Drive](https://drive.google.com/open?id=1Rvju5_gp2E6BFiqzMBtnMqVP803AbBcm) and [Baidu Pan](https://pan.baidu.com/s/1zbzLriRpc8S_5ez_upC7OA) (Code: akcm). ## Pretrained Model Downloads We also provide pretrained models on the four datasets, which are the same models as reported in the paper. If you retrain models from scratch under the same hyper-parameter settings, you may obtain a slightly lower or higher F1 score than that reported in the paper (in our experiments we selected the model that performed best). 
Pretrained models and related hyper-parameter settings are available at [Google Drive](https://drive.google.com/file/d/1KKkCW8WRhgR2P2UbRpNpKyE_RAv1EREv/view?usp=sharing) or [Baidu Pan](https://pan.baidu.com/s/1U89EwnhPpMa4bNrS--u4EA). When running main.py in test mode for pretrained models, you can get the results as follows: | Datasets | Precision | Recall | F1 | |:--------------:|:---------:|:-------:|:-----:| | OntoNotes dev | 74.00 | 70.03 | 71.96 | | OntoNotes test | 76.13 | 73.68 | 74.89 | | MSRA dev | - | - | - | | MSRA test | 94.19 | 92.73 | 93.46 | | Weibo dev | 66.09 | 59.13 | 62.42 | | Weibo test | 65.71 | 55.56 | 60.21 | | Resume dev | 94.27 | 94.59 | 94.43 | | Resume test | 95.28 | 95.46 | 95.37 | ## Citation @inproceedings{gui2019lexicon, title={A Lexicon-Based Graph Neural Network for Chinese NER}, author={Gui, Tao and Zou, Yicheng and Zhang, Qi and Peng, Minlong and Fu, Jinlan and Wei, Zhongyu and Huang, Xuanjing}, booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)}, pages={1039--1049}, year={2019} } ================================================ FILE: main.py ================================================ # -*- coding: utf-8 -*- # @Author: Yicheng Zou # @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn import time import sys import argparse import random import torch import gc import pickle import os import torch.autograd as autograd import torch.optim as optim import numpy as np from utils.metric import get_ner_fmeasure from model.LGN import Graph from utils.data import Data def str2bool(v): if isinstance(v, bool): return v if v.lower() in ('yes', 'true', 't', 'y', '1'): return True elif v.lower() in ('no', 'false', 'f', 'n', '0'): return False else: raise argparse.ArgumentTypeError('Boolean value expected.') def lr_decay(optimizer, epoch, decay_rate, init_lr): lr = init_lr * 
((1-decay_rate)**epoch) print( " Learning rate is setted as:", lr) for param_group in optimizer.param_groups: if param_group['name'] == 'aggr': param_group['lr'] = lr * 2. else: param_group['lr'] = lr return optimizer def data_initialization(data, word_file, train_file, dev_file, test_file): data.build_word_file(word_file) if train_file: data.build_alphabet(train_file) data.build_word_alphabet(train_file) if dev_file: data.build_alphabet(dev_file) data.build_word_alphabet(dev_file) if test_file: data.build_alphabet(test_file) data.build_word_alphabet(test_file) return data def predict_check(pred_variable, gold_variable, mask_variable): pred = pred_variable.cpu().data.numpy() gold = gold_variable.cpu().data.numpy() mask = mask_variable.cpu().data.numpy() overlaped = (pred == gold) right_token = np.sum(overlaped * mask) total_token = mask.sum() return right_token, total_token def recover_label(pred_variable, gold_variable, mask_variable, label_alphabet): batch_size = gold_variable.size(0) seq_len = gold_variable.size(1) mask = mask_variable.cpu().data.numpy() pred_tag = pred_variable.cpu().data.numpy() gold_tag = gold_variable.cpu().data.numpy() pred_label = [] gold_label = [] for idx in range(batch_size): pred = [label_alphabet.get_instance(pred_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0] gold = [label_alphabet.get_instance(gold_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0] assert(len(pred)==len(gold)) pred_label.append(pred) gold_label.append(gold) return pred_label, gold_label def print_args(args): print("CONFIG SUMMARY:") print(" Batch size: %s" % (args.batch_size)) print(" If use GPU: %s" % (args.use_gpu)) print(" If use CRF: %s" % (args.use_crf)) print(" Epoch number: %s" % (args.num_epoch)) print(" Learning rate: %s" % (args.lr)) print(" L2 normalization rate: %s" % (args.weight_decay)) print(" If use edge embedding: %s" % (args.use_edge)) print(" If use global node: %s" % (args.use_global)) print(" Bidirectional 
digraph: %s" % (args.bidirectional)) print(" Update step number: %s" % (args.iters)) print(" Attention dropout rate: %s" % (args.tf_drop_rate)) print(" Embedding dropout rate: %s" % (args.emb_drop_rate)) print(" Hidden state dimension: %s" % (args.hidden_dim)) print(" Learning rate decay ratio: %s" % (args.lr_decay)) print(" Aggregation module dropout rate: %s" % (args.cell_drop_rate)) print(" Head number of attention: %s" % (args.num_head)) print(" Head dimension of attention: %s" % (args.head_dim)) print("CONFIG SUMMARY END.") sys.stdout.flush() def evaluate(data, args, model, name): if name == "train": instances = data.train_Ids elif name == "dev": instances = data.dev_Ids elif name == 'test': instances = data.test_Ids elif name == 'raw': instances = data.raw_Ids else: print("Error: wrong evaluate name,", name) exit(0) pred_results = [] gold_results = [] # set model in eval model model.eval() batch_size = args.batch_size start_time = time.time() train_num = len(instances) total_batch = train_num // batch_size + 1 for batch_id in range(total_batch): start = batch_id*batch_size end = (batch_id+1)*batch_size if end > train_num: end = train_num instance = instances[start:end] if not instance: continue word_list, batch_char, batch_label, mask = batchify_with_label(instance, args.use_gpu) _, tag_seq = model(word_list, batch_char, mask) pred_label, gold_label = recover_label(tag_seq, batch_label, mask, data.label_alphabet) pred_results += pred_label gold_results += gold_label decode_time = time.time() - start_time speed = len(instances) / decode_time acc, p, r, f = get_ner_fmeasure(gold_results, pred_results) return speed, acc, p, r, f, pred_results def batchify_with_label(input_batch_list, gpu): batch_size = len(input_batch_list) chars = [sent[0] for sent in input_batch_list] words = [sent[1] for sent in input_batch_list] labels = [sent[2] for sent in input_batch_list] sent_lengths = torch.LongTensor(list(map(len, chars))) max_sent_len = sent_lengths.max() 
char_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_sent_len))).long() label_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_sent_len))).long() mask = autograd.Variable(torch.zeros((batch_size, max_sent_len))).byte() for idx, (seq, label, seq_len) in enumerate(zip(chars, labels, sent_lengths)): char_seq_tensor[idx, :seq_len] = torch.LongTensor(seq) label_seq_tensor[idx, :seq_len] = torch.LongTensor(label) mask[idx, :seq_len] = torch.Tensor([1] * int(seq_len)) if gpu: char_seq_tensor = char_seq_tensor.cuda() label_seq_tensor = label_seq_tensor.cuda() mask = mask.cuda() return words, char_seq_tensor, label_seq_tensor, mask def train(data, args, saved_model_path): print( "Training model...") model = Graph(data, args) if args.use_gpu: model = model.cuda() print('# generated parameters:', sum(param.numel() for param in model.parameters())) print( "Finished built model.") best_dev_epoch = 0 best_dev_f = -1 best_dev_p = -1 best_dev_r = -1 best_test_f = -1 best_test_p = -1 best_test_r = -1 # Initialize the optimizer aggr_module_params = [] other_module_params = [] for m_name in model._modules: m = model._modules[m_name] if isinstance(m, torch.nn.ModuleList): for p in m.parameters(): if p.requires_grad: aggr_module_params.append(p) else: for p in m.parameters(): if p.requires_grad: other_module_params.append(p) optimizer = optim.Adam([ {"params": (aggr_module_params), "name": "aggr"}, {"params": (other_module_params), "name": "other"} ], lr=args.lr, weight_decay=args.weight_decay ) for idx in range(args.num_epoch): epoch_start = time.time() temp_start = epoch_start print(("Epoch: %s/%s" %(idx, args.num_epoch))) optimizer = lr_decay(optimizer, idx, args.lr_decay, args.lr) sample_loss = 0 batch_loss = 0 total_loss = 0 right_token = 0 whole_token = 0 random.shuffle(data.train_Ids) # set model in train model model.train() model.zero_grad() batch_size = args.batch_size train_num = len(data.train_Ids) total_batch = train_num // batch_size + 1 for batch_id 
in range(total_batch): # Get one batch-sized instance start = batch_id * batch_size end = (batch_id + 1) * batch_size if end > train_num: end = train_num instance = data.train_Ids[start:end] if not instance: continue word_list, batch_char, batch_label, mask = batchify_with_label(instance, args.use_gpu) loss, tag_seq = model(word_list, batch_char, mask, batch_label) right, whole = predict_check(tag_seq, batch_label, mask) right_token += right whole_token += whole sample_loss += loss.data total_loss += loss.data batch_loss += loss if end % 500 == 0: temp_time = time.time() temp_cost = temp_time - temp_start temp_start = temp_time print((" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % (end, temp_cost, sample_loss, right_token, whole_token, (right_token+0.)/whole_token))) sys.stdout.flush() sample_loss = 0 if end % args.batch_size == 0: batch_loss.backward() optimizer.step() model.zero_grad() batch_loss = 0 temp_time = time.time() temp_cost = temp_time - temp_start print((" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % (end, temp_cost, sample_loss, right_token, whole_token, (right_token+0.)/whole_token))) epoch_finish = time.time() epoch_cost = epoch_finish - epoch_start print(("Epoch: %s training finished. 
Time: %.2fs, speed: %.2fst/s, total loss: %s" % (idx, epoch_cost, train_num/epoch_cost, total_loss))) # dev speed, acc, dev_p, dev_r, dev_f, _ = evaluate(data, args, model, "dev") dev_finish = time.time() dev_cost = dev_finish - epoch_finish print(("Dev: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (dev_cost, speed, acc, dev_p, dev_r, dev_f))) # test speed, acc, test_p, test_r, test_f, _ = evaluate(data, args, model, "test") test_finish = time.time() test_cost = test_finish - dev_finish print(("Test: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (test_cost, speed, acc, test_p, test_r, test_f))) if dev_f > best_dev_f: print("Exceed previous best f score: %.4f" % best_dev_f) torch.save(model.state_dict(), saved_model_path + "_best") best_dev_p = dev_p best_dev_r = dev_r best_dev_f = dev_f best_dev_epoch = idx + 1 best_test_p = test_p best_test_r = test_r best_test_f = test_f model_idx_path = saved_model_path + "_" + str(idx) torch.save(model.state_dict(), model_idx_path) with open(saved_model_path + "_result.txt", "a") as file: file.write(model_idx_path + '\n') file.write("Dev score: %.4f, r: %.4f, f: %.4f\n" % (dev_p, dev_r, dev_f)) file.write("Test score: %.4f, r: %.4f, f: %.4f\n\n" % (test_p, test_r, test_f)) file.close() print("Best dev epoch: %d" % best_dev_epoch) print("Best dev score: p: %.4f, r: %.4f, f: %.4f" % (best_dev_p, best_dev_r, best_dev_f)) print("Best test score: p: %.4f, r: %.4f, f: %.4f" % (best_test_p, best_test_r, best_test_f)) gc.collect() with open(saved_model_path + "_result.txt", "a") as file: file.write("Best epoch: %d" % best_dev_epoch + '\n') file.write("Best Dev score: %.4f, r: %.4f, f: %.4f\n" % (best_dev_p, best_dev_r, best_dev_f)) file.write("Test score: %.4f, r: %.4f, f: %.4f\n\n" % (best_test_p, best_test_r, best_test_f)) file.close() with open(saved_model_path + "_best_HP.config", "wb") as file: pickle.dump(args, file) def load_model_decode(model_dir, data, args, name): model_dir = 
model_dir + "_best" print("Load Model from file: ", model_dir) model = Graph(data, args) model.load_state_dict(torch.load(model_dir)) # load model need consider if the model trained in GPU and load in CPU, or vice versa if args.use_gpu: model = model.cuda() print(("Decode %s data ..." % name)) start_time = time.time() speed, acc, p, r, f, pred_results = evaluate(data, args, model, name) end_time = time.time() time_cost = end_time - start_time print(("%s: time:%.2fs, speed:%.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (name, time_cost, speed, acc, p, r, f))) return pred_results if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--status', choices=['train', 'test', 'decode'], help='Function status.', default='train') parser.add_argument('--use_gpu', type=str2bool, default=True) parser.add_argument('--train', help='Training set.', default='data/onto4ner.cn/train.char.bmes') parser.add_argument('--dev', help='Developing set.', default='data/onto4ner.cn/dev.char.bmes') parser.add_argument('--test', help='Testing set.', default='data/onto4ner.cn/test.char.bmes') parser.add_argument('--raw', help='Raw file for decoding.') parser.add_argument('--output', help='Output results for decoding.') parser.add_argument('--saved_set', help='Path of saved data set.', default='data/onto4ner.cn/saved.dset') parser.add_argument('--saved_model', help='Path of saved model.', default="saved_model/model_onto4ner") parser.add_argument('--char_emb', help='Path of character embedding file.', default="data/gigaword_chn.all.a2b.uni.ite50.vec") parser.add_argument('--word_emb', help='Path of word embedding file.', default="data/ctb.50d.vec") parser.add_argument('--use_crf', type=str2bool, default=True) parser.add_argument('--use_edge', type=str2bool, default=True, help='If use lexicon embeddings (edge embeddings).') parser.add_argument('--use_global', type=str2bool, default=True, help='If use the global node.') parser.add_argument('--bidirectional', 
type=str2bool, default=True, help='If use bidirectional digraph.') parser.add_argument('--seed', help='Random seed', default=1023, type=int) parser.add_argument('--batch_size', help='Batch size.', default=1, type=int) parser.add_argument('--num_epoch',default=100, type=int, help="Epoch number.") parser.add_argument('--iters', default=4, type=int, help='The number of Graph iterations.') parser.add_argument('--hidden_dim', default=50, type=int, help='Hidden state size.') parser.add_argument('--num_head', default=10, type=int, help='Number of transformer head.') parser.add_argument('--head_dim', default=20, type=int, help='Head dimension of transformer.') parser.add_argument('--tf_drop_rate', default=0.1, type=float, help='Transformer dropout rate.') parser.add_argument('--emb_drop_rate', default=0.5, type=float, help='Embedding dropout rate.') parser.add_argument('--cell_drop_rate', default=0.2, type=float, help='Aggregation module dropout rate.') parser.add_argument('--word_alphabet_size', type=int, help='Word alphabet size.') parser.add_argument('--char_alphabet_size', type=int, help='Char alphabet size.') parser.add_argument('--label_alphabet_size', type=int, help='Label alphabet size.') parser.add_argument('--char_dim', type=int, help='Char embedding size.') parser.add_argument('--word_dim', type=int, help='Word embedding size.') parser.add_argument('--lr', type=float, default=2e-05) parser.add_argument('--lr_decay', type=float, default=0) parser.add_argument('--weight_decay', type=float, default=0) args = parser.parse_args() status = args.status.lower() seed_num = args.seed random.seed(seed_num) torch.manual_seed(seed_num) np.random.seed(seed_num) train_file = args.train dev_file = args.dev test_file = args.test raw_file = args.raw output_file = args.output saved_set_path = args.saved_set saved_model_path = args.saved_model char_file = args.char_emb word_file = args.word_emb if status == 'train': if os.path.exists(saved_set_path): print('Loading saved data 
set...') with open(saved_set_path, 'rb') as f: data = pickle.load(f) else: data = Data() data_initialization(data, word_file, train_file, dev_file, test_file) data.generate_instance_with_words(train_file, 'train') data.generate_instance_with_words(dev_file, 'dev') data.generate_instance_with_words(test_file, 'test') data.build_char_pretrain_emb(char_file) data.build_word_pretrain_emb(word_file) if saved_set_path is not None: print('Dumping data...') with open(saved_set_path, 'wb') as f: pickle.dump(data, f) data.show_data_summary() args.word_alphabet_size = data.word_alphabet.size() args.char_alphabet_size = data.char_alphabet.size() args.label_alphabet_size = data.label_alphabet.size() args.char_dim = data.char_emb_dim args.word_dim = data.word_emb_dim print_args(args) train(data, args, saved_model_path) elif status == 'test': assert not (test_file is None) if os.path.exists(saved_set_path): print('Loading saved data set...') with open(saved_set_path, 'rb') as f: data = pickle.load(f) else: print("Cannot find saved data set: ", saved_set_path) exit(0) data.generate_instance_with_words(test_file, 'test') with open(saved_model_path + "_best_HP.config", "rb") as f: args = pickle.load(f) data.show_data_summary() print_args(args) load_model_decode(saved_model_path, data, args, "test") elif status == 'decode': assert not (raw_file is None or output_file is None) if os.path.exists(saved_set_path): print('Loading saved data set...') with open(saved_set_path, 'rb') as f: data = pickle.load(f) else: print("Cannot find saved data set: ", saved_set_path) exit(0) data.generate_instance_with_words(raw_file, 'raw') with open(saved_model_path + "_best_HP.config", "rb") as f: args = pickle.load(f) data.show_data_summary() print_args(args) decode_results = load_model_decode(saved_model_path, data, args, "raw") data.write_decoded_results(output_file, decode_results, 'raw') else: print("Invalid argument! Please use valid arguments! 
(train/test/decode)") ================================================ FILE: model/LGN.py ================================================ # -*- coding: utf-8 -*- # @Author: Yicheng Zou # @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn from model.crf import CRF from model.module import * class Graph(nn.Module): def __init__(self, data, args): super(Graph, self).__init__() self.gpu = args.use_gpu self.char_emb_dim = args.char_dim self.word_emb_dim = args.word_dim self.hidden_dim = args.hidden_dim self.num_head = args.num_head # 5 10 20 self.head_dim = args.head_dim # 10 20 self.tf_dropout_rate = args.tf_drop_rate self.iters = args.iters self.bmes_dim = 10 self.length_dim = 10 self.max_word_length = 5 self.emb_dropout_rate = args.emb_drop_rate self.cell_dropout_rate = args.cell_drop_rate self.use_crf = args.use_crf self.use_global = args.use_global self.use_edge = args.use_edge self.bidirectional = args.bidirectional self.label_size = args.label_alphabet_size # char embedding self.char_embedding = nn.Embedding(args.char_alphabet_size, self.char_emb_dim) if data.pretrain_char_embedding is not None: self.char_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_char_embedding)) if self.use_edge: # word embedding self.word_embedding = nn.Embedding(args.word_alphabet_size, self.word_emb_dim) if data.pretrain_word_embedding is not None: scale = np.sqrt(3.0 / self.word_emb_dim) data.pretrain_word_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.word_emb_dim]) self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding)) # bmes embedding self.bmes_embedding = nn.Embedding(4, self.bmes_dim) """ self.edge_emb_linear = nn.Sequential( nn.Linear(self.word_emb_dim, self.hidden_dim), nn.ELU() ) """ # lstm self.emb_rnn_f = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True) self.emb_rnn_b = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True) # length embedding self.length_embedding = 
nn.Embedding(self.max_word_length, self.length_dim) self.dropout = nn.Dropout(self.emb_dropout_rate) self.norm = nn.LayerNorm(self.hidden_dim) if self.use_edge: # Node aggregation module self.edge2node_f = nn.ModuleList( [MultiHeadAtt(self.hidden_dim, self.hidden_dim * 2 + self.length_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) for _ in range(self.iters)]) # Edge aggregation module self.node2edge_f = nn.ModuleList( [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.bmes_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) for _ in range(self.iters)]) else: # Node aggregation module self.edge2node_f = nn.ModuleList( [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.length_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) for _ in range(self.iters)]) if self.use_global: # Global Node aggregation module self.glo_att_f_node = nn.ModuleList( [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) for _ in range(self.iters)]) if self.use_edge: self.glo_att_f_edge = nn.ModuleList( [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) for _ in range(self.iters)]) # Updating modules if self.use_edge: self.glo_rnn_f = Global_Cell(self.hidden_dim * 3, self.hidden_dim, dropout=self.cell_dropout_rate) self.node_rnn_f = Nodes_Cell(self.hidden_dim * 5, self.hidden_dim, dropout=self.cell_dropout_rate) self.edge_rnn_f = Edges_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate) else: self.glo_rnn_f = Global_Cell(self.hidden_dim * 2, self.hidden_dim, dropout=self.cell_dropout_rate) self.node_rnn_f = Nodes_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate) else: # Updating modules self.node_rnn_f = Nodes_Cell(self.hidden_dim * 3, self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate) if self.use_edge: self.edge_rnn_f = 
Edges_Cell(self.hidden_dim * 2, self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate) if self.bidirectional: if self.use_edge: # Node aggregation module self.edge2node_b = nn.ModuleList( [MultiHeadAtt(self.hidden_dim, self.hidden_dim * 2 + self.length_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) for _ in range(self.iters)]) # Edge aggregation module self.node2edge_b = nn.ModuleList( [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.bmes_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) for _ in range(self.iters)]) else: # Node aggregation module self.edge2node_b = nn.ModuleList( [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.length_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) for _ in range(self.iters)]) if self.use_global: # Global Node aggregation module self.glo_att_b_node = nn.ModuleList( [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) for _ in range(self.iters)]) if self.use_edge: self.glo_att_b_edge = nn.ModuleList( [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) for _ in range(self.iters)]) # Updating modules if self.use_edge: self.glo_rnn_b = Global_Cell(self.hidden_dim * 3, self.hidden_dim, self.cell_dropout_rate) self.node_rnn_b = Nodes_Cell(self.hidden_dim * 5, self.hidden_dim, self.cell_dropout_rate) self.edge_rnn_b = Edges_Cell(self.hidden_dim * 4, self.hidden_dim, self.cell_dropout_rate) else: self.glo_rnn_b = Global_Cell(self.hidden_dim * 2, self.hidden_dim, self.cell_dropout_rate) self.node_rnn_b = Nodes_Cell(self.hidden_dim * 4, self.hidden_dim, self.cell_dropout_rate) else: # Updating modules self.node_rnn_b = Nodes_Cell(self.hidden_dim * 3, self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate) if self.use_edge: self.edge_rnn_b = Edges_Cell(self.hidden_dim * 2, self.hidden_dim, use_global=False, 
dropout=self.cell_dropout_rate) if self.bidirectional: output_dim = self.hidden_dim * 2 else: output_dim = self.hidden_dim self.layer_att_W = nn.Linear(output_dim, 1) if self.use_crf: self.hidden2tag = nn.Linear(output_dim, self.label_size + 2) self.crf = CRF(self.label_size, self.gpu) else: self.hidden2tag = nn.Linear(output_dim, self.label_size) self.criterion = nn.CrossEntropyLoss() def construct_graph(self, batch_size, seq_len, word_list): if self.cuda: device = 'cuda' else: device = 'cpu' if self.use_edge: unk_index = torch.tensor(0, device=device) unk_emb = self.word_embedding(unk_index) bmes_emb_b = self.bmes_embedding(torch.tensor(0, device=device)) bmes_emb_m = self.bmes_embedding(torch.tensor(1, device=device)) bmes_emb_e = self.bmes_embedding(torch.tensor(2, device=device)) bmes_emb_s = self.bmes_embedding(torch.tensor(3, device=device)) sen_nodes_mask_list = [] sen_words_length_list =[] sen_words_mask_f_list = [] sen_words_mask_b_list = [] sen_word_embed_list = [] sen_bmes_embed_list = [] max_edge_num = -1 for sen in range(batch_size): sen_nodes_mask = torch.zeros([1, seq_len], device=device).byte() sen_words_length = torch.zeros([1, self.length_dim], device=device) sen_words_mask_f = torch.zeros([1, seq_len], device=device).byte() sen_words_mask_b = torch.zeros([1, seq_len], device=device).byte() if self.use_edge: sen_word_embed = unk_emb[None, :] sen_bmes_embed = torch.zeros([1, seq_len, self.bmes_dim], device=device) for w in range(seq_len): if w < len(word_list[sen]) and word_list[sen][w]: for word, word_len in zip(word_list[sen][w][0], word_list[sen][w][1]): if word_len <= self.max_word_length: word_length_index = torch.tensor(word_len-1, device=device) else: word_length_index = torch.tensor(self.max_word_length - 1, device=device) word_length = self.length_embedding(word_length_index) sen_words_length = torch.cat([sen_words_length, word_length[None, :]], 0) # mask: Masked elements are marked by 1, batch_size * word_num * seq_len nodes_mask = 
torch.ones([1, seq_len], device=device).byte() words_mask_f = torch.ones([1, seq_len], device=device).byte() words_mask_b = torch.ones([1, seq_len], device=device).byte() words_mask_f[0, w + word_len - 1] = 0 sen_words_mask_f = torch.cat([sen_words_mask_f, words_mask_f], 0) words_mask_b[0, w] = 0 sen_words_mask_b = torch.cat([sen_words_mask_b, words_mask_b], 0) if self.use_edge: word_index = torch.tensor(word, device=device) word_embedding = self.word_embedding(word_index) sen_word_embed = torch.cat([sen_word_embed, word_embedding[None, :]], 0) bmes_embed = torch.zeros([1, seq_len, self.bmes_dim], device=device) for index in range(word_len): nodes_mask[0, w + index] = 0 if word_len == 1: bmes_embed[0, w + index, :] = bmes_emb_s elif index == 0: bmes_embed[0, w + index, :] = bmes_emb_b elif index == word_len - 1: bmes_embed[0, w + index, :] = bmes_emb_e else: bmes_embed[0, w + index, :] = bmes_emb_m sen_bmes_embed = torch.cat([sen_bmes_embed, bmes_embed], 0) sen_nodes_mask = torch.cat([sen_nodes_mask, nodes_mask], 0) if sen_words_mask_f.size(0) > max_edge_num: max_edge_num = sen_words_mask_f.size(0) sen_words_mask_f_list.append(sen_words_mask_f.unsqueeze_(0)) sen_words_mask_b_list.append(sen_words_mask_b.unsqueeze_(0)) sen_words_length_list.append(sen_words_length.unsqueeze_(0)) if self.use_edge: sen_nodes_mask_list.append(sen_nodes_mask.unsqueeze_(0)) sen_word_embed_list.append(sen_word_embed.unsqueeze_(0)) sen_bmes_embed_list.append(sen_bmes_embed.unsqueeze_(0)) edges_mask = torch.zeros([batch_size, max_edge_num], device=device) batch_words_mask_f = torch.ones([batch_size, max_edge_num, seq_len], device=device).byte() batch_words_mask_b = torch.ones([batch_size, max_edge_num, seq_len], device=device).byte() batch_words_length = torch.zeros([batch_size, max_edge_num, self.length_dim], device=device) if self.use_edge: batch_nodes_mask = torch.zeros([batch_size, max_edge_num, seq_len], device=device).byte() batch_word_embed = torch.zeros([batch_size, max_edge_num, 
self.word_emb_dim], device=device) batch_bmes_embed = torch.zeros([batch_size, max_edge_num, seq_len, self.bmes_dim], device=device) else: batch_word_embed = None batch_bmes_embed = None batch_nodes_mask = None for index in range(batch_size): curr_edge_num = sen_words_mask_f_list[index].size(1) edges_mask[index, 0:curr_edge_num] = 1. batch_words_mask_f[index, 0:curr_edge_num, :] = sen_words_mask_f_list[index] batch_words_mask_b[index, 0:curr_edge_num, :] = sen_words_mask_b_list[index] batch_words_length[index, 0:curr_edge_num, :] = sen_words_length_list[index] if self.use_edge: batch_nodes_mask[index, 0:curr_edge_num, :] = sen_nodes_mask_list[index] batch_word_embed[index, 0:curr_edge_num, :] = sen_word_embed_list[index] batch_bmes_embed[index, 0:curr_edge_num, :, :] = sen_bmes_embed_list[index] return batch_word_embed, batch_bmes_embed, batch_nodes_mask, batch_words_mask_f, \ batch_words_mask_b, batch_words_length, edges_mask def update_graph(self, word_list, word_inputs, mask): mask = mask.float() node_embeds = self.char_embedding(word_inputs) # batch_size, max_seq_len, embedding B, L, _ = node_embeds.size() edge_embs, bmes_embs, nodes_mask, words_mask_f, words_mask_b, words_length, edges_mask = \ self.construct_graph(B, L, word_list) node_embeds = self.dropout(node_embeds) _, N, _ = words_mask_f.size() if self.use_edge: edge_embs = self.dropout(edge_embs) # forward direction digraph nodes_f, _ = self.emb_rnn_f(node_embeds) nodes_f = nodes_f * mask.unsqueeze(2) nodes_f_cat = nodes_f[:, None, :, :] _, _, H = nodes_f.size() if self.use_edge: edges_f = edge_embs * edges_mask.unsqueeze(2) edges_f_cat = edges_f[:, None, :, :] if self.use_global: glo_f = edges_f.sum(1, keepdim=True) / edges_mask.sum(1, keepdim=True).unsqueeze_(2) + \ nodes_f.sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2) glo_f_cat = glo_f[:, None, :, :] else: if self.use_global: glo_f = (nodes_f * mask.unsqueeze(2)).sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2) 
glo_f_cat = glo_f[:, None, :, :] for i in range(self.iters): # Attention-based aggregation if self.use_edge and N > 1: bmes_nodes_f = torch.cat([nodes_f.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1) edges_att_f = self.node2edge_f[i](edges_f, bmes_nodes_f, nodes_mask.transpose(1, 2)) nodes_begin_f = torch.sum(nodes_f[:, None, :, :] * (1 - words_mask_b)[:, :, :, None].float(), 2) nodes_begin_f = torch.cat([torch.zeros([B, 1, H], device=nodes_f.device), nodes_begin_f[:, 1:N, :]], 1) if self.use_edge: nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([edges_f, nodes_begin_f, words_length], -1).unsqueeze(2), words_mask_f) if self.use_global: glo_att_f = torch.cat([self.glo_att_f_node[i](glo_f, nodes_f, (1 - mask).byte()), self.glo_att_f_edge[i](glo_f, edges_f, (1 - edges_mask).byte())], -1) else: nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([nodes_begin_f, words_length], -1).unsqueeze(2), words_mask_f) if self.use_global: glo_att_f = self.glo_att_f_node[i](glo_f, nodes_f, (1 - mask).byte()) # RNN-based update if self.use_edge and N > 1: if self.use_global: edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :], edges_att_f[:, 1:N, :], glo_att_f.expand(B, N-1, H*2))], 1) else: edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :], edges_att_f[:, 1:N, :])], 1) edges_f_cat = torch.cat([edges_f_cat, edges_f[:, None, :, :]], 1) edges_f = torch.cat([edges_f[:, 0:1, :], self.norm(torch.sum(edges_f_cat[:, :, 1:N, :], 1))], 1) nodes_f_r = torch.cat([torch.zeros([B, 1, self.hidden_dim], device=nodes_f.device), nodes_f[:, 0:(L-1), :]], 1) if self.use_global: nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, nodes_att_f, glo_att_f.expand(B, L, -1)) else: nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, nodes_att_f) nodes_f_cat = torch.cat([nodes_f_cat, nodes_f[:, None, :, :]], 1) nodes_f = self.norm(torch.sum(nodes_f_cat, 1)) if self.use_global: glo_f = self.glo_rnn_f(glo_f, glo_att_f) glo_f_cat = 
torch.cat([glo_f_cat, glo_f[:, None, :, :]], 1) glo_f = self.norm(torch.sum(glo_f_cat, 1)) nodes_cat = nodes_f_cat # backward direction digraph if self.bidirectional: nodes_b, _ = self.emb_rnn_b(torch.flip(node_embeds, [1])) nodes_b = torch.flip(nodes_b, [1]) nodes_b = nodes_b * mask.unsqueeze(2) nodes_b_cat = nodes_b[:, None, :, :] if self.use_edge: edges_b = edge_embs * edges_mask.unsqueeze(2) edges_b_cat = edges_b[:, None, :, :] if self.use_global: glo_b = edges_b.sum(1, keepdim=True) / edges_mask.sum(1, keepdim=True).unsqueeze_(2) + \ nodes_b.sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2) glo_b_cat = glo_b[:, None, :, :] else: if self.use_global: glo_b = nodes_b.sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2) glo_b_cat = glo_b[:, None, :, :] for i in range(self.iters): # Attention-based aggregation if self.use_edge and N > 1: bmes_nodes_b = torch.cat([nodes_b.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1) edges_att_b = self.node2edge_b[i](edges_b, bmes_nodes_b, nodes_mask.transpose(1, 2)) nodes_begin_b = torch.sum(nodes_b[:, None, :, :] * (1 - words_mask_f)[:, :, :, None].float(), 2) nodes_begin_b = torch.cat([torch.zeros([B, 1, H], device=nodes_b.device), nodes_begin_b[:, 1:N, :]], 1) if self.use_edge: nodes_att_b = self.edge2node_b[i](nodes_b, torch.cat([edges_b, nodes_begin_b, words_length], -1).unsqueeze(2), words_mask_b) if self.use_global: glo_att_b = torch.cat([self.glo_att_b_node[i](glo_b, nodes_b, (1-mask).byte()), self.glo_att_b_edge[i](glo_b, edges_b, (1-edges_mask).byte())], -1) else: nodes_att_b = self.edge2node_b[i](nodes_b, torch.cat([nodes_begin_b, words_length], -1).unsqueeze(2), words_mask_b) if self.use_global: glo_att_b = self.glo_att_b_node[i](glo_b, nodes_b, (1-mask).byte()) # RNN-based update if self.use_edge and N > 1: if self.use_global: edges_b = torch.cat([edges_b[:, 0:1, :], self.edge_rnn_b(edges_b[:, 1:N, :], edges_att_b[:, 1:N, :], glo_att_b.expand(B, N-1, H*2))], 1) else: edges_b = 
def log_sum_exp(vec, m_size):
    """Compute log(sum(exp(vec))) over dimension 1 in a numerically stable way.

    Args:
        vec: tensor of shape (batch_size, vanishing_dim, hidden_dim).
        m_size: hidden_dim, i.e. the size of the last dimension of ``vec``.

    Returns:
        Tensor of shape (batch_size, hidden_dim).
    """
    # torch.logsumexp performs the max-shift internally and, unlike the manual
    # "max + log(sum(exp(x - max)))" formulation, does not produce NaN when a
    # whole column is -inf (it yields -inf instead).
    return torch.logsumexp(vec, 1).view(-1, m_size)
self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size) scores = scores.view(seq_len, batch_size, tag_size, tag_size) # build iter seq_iter = enumerate(scores) _, inivalues = seq_iter.__next__() # bat_size * from_target_size * to_target_size (b,t,t) inivalues是每个句子的第一个字 # only need start from start_tag partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size (b,t,1) ## add start score (from start to all tag, duplicate to batch_size) # partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1) # iter over last scores for idx, cur_values in seq_iter: # previous to_target is current from_target # partition: previous results log(exp(from_target)), #(batch_size * from_target) # cur_values: bat_size * from_target * to_target cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) cur_partition = log_sum_exp(cur_values, tag_size) #(b,t) # print cur_partition.data # (bat_size * from_target * to_target) -> (bat_size * to_target) # partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1) mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size) ## effective updated partition part, only keep the partition value of mask value = 1 masked_cur_partition = cur_partition.masked_select(mask_idx) ## let mask_idx broadcastable, to disable warning mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1) ## replace the partition where the maskvalue=1, other partition value keeps the same partition.masked_scatter_(mask_idx, masked_cur_partition) # until the last state, add transition score for all partition (and do log_sum_exp) then select the value in STOP_TAG cur_values = self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 
def _viterbi_decode(self, feats, mask):
    """Decode the best-scoring tag sequence for every sentence in the batch.

    Args:
        feats: emission scores of shape (batch, seq_len, tagset_size + 2).
        mask: (batch, seq_len); non-zero entries mark real tokens. Sentence
            lengths are taken as ``mask.sum(1)``, so padding is assumed to be
            at the end of each sentence (TODO confirm against the caller).

    Returns:
        (path_score, decode_idx): ``path_score`` is always None (not
        implemented); ``decode_idx`` is a (batch, seq_len) long tensor of tag
        ids, with padded positions set to 0.
    """
    batch_size = feats.size(0)
    seq_len = feats.size(1)
    tag_size = feats.size(2)
    assert tag_size == self.tagset_size + 2
    # Allocate helper tensors next to the input instead of relying on self.gpu.
    device = feats.device

    # Length of every sentence, shape (batch_size, 1).
    length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long()

    # Padding mask, shape (seq_len, batch_size). Using a comparison instead of
    # ``.byte()`` yields whatever mask dtype the installed PyTorch expects
    # (bool on current releases), which ``masked_fill_`` requires.
    pad_mask = mask.transpose(1, 0).contiguous().long() == 0

    ins_num = seq_len * batch_size
    # Be careful with the view shape: it is (ins_num, 1, tag_size), not
    # (ins_num, tag_size, 1), so the emission is repeated for every source tag.
    feats = feats.transpose(1, 0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
    scores = feats + self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
    scores = scores.view(seq_len, batch_size, tag_size, tag_size)

    back_points = []
    partition_history = []

    # First step: only transitions leaving START_TAG are valid.
    partition = scores[0][:, START_TAG, :].clone().view(batch_size, tag_size, 1)
    partition_history.append(partition)

    for idx in range(1, seq_len):
        # The previous "to" tag becomes the current "from" tag.
        cur_values = scores[idx] + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
        partition, cur_bp = torch.max(cur_values, dim=1)
        partition_history.append(partition.unsqueeze(2))
        # Back-pointers of padded positions are set to 0 and filtered later.
        cur_bp.masked_fill_(pad_mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
        back_points.append(cur_bp)

    # (batch_size, seq_len, tag_size)
    partition_history = torch.cat(partition_history, dim=0).view(seq_len, batch_size, -1).transpose(1, 0).contiguous()

    # Pick the partition at the last real token of every sentence.
    last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1
    last_partition = torch.gather(partition_history, 1, last_position).view(batch_size, tag_size, 1)

    # Score of moving from the last partition to every tag; the STOP_TAG
    # column is selected below.
    last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size)
    _, last_bp = torch.max(last_values, 1)

    pad_zero = torch.zeros(batch_size, tag_size, dtype=torch.long, device=device)
    back_points.append(pad_zero)
    back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)

    # Best final tag of every sentence, shape (batch_size,).
    pointer = last_bp[:, STOP_TAG]
    insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size)
    back_points = back_points.transpose(1, 0).contiguous()
    # Write the final tag at each sentence's last real position, replacing the
    # 0 placeholder stored there.
    back_points.scatter_(1, last_position, insert_last)
    back_points = back_points.transpose(1, 0).contiguous()

    # Follow the back-pointers from the end; every row is assigned below, but
    # start from zeros rather than uninitialised memory.
    decode_idx = torch.zeros(seq_len, batch_size, dtype=torch.long, device=device)
    decode_idx[-1] = pointer
    for idx in range(len(back_points) - 2, -1, -1):
        pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1))
        decode_idx[idx] = pointer.squeeze(1)

    path_score = None
    decode_idx = decode_idx.transpose(1, 0)
    return path_score, decode_idx
class MultiHeadAtt(nn.Module):
    """Multi-head attention from a sequence of queries onto per-query value slots.

    All projections are 1x1 convolutions; the attended result is passed through
    a leaky ReLU, added to the query (residual) and layer-normalised.
    """

    def __init__(self, nhid, keyhid, nhead=10, head_dim=10, dropout=0.1, if_g=False):
        super(MultiHeadAtt, self).__init__()
        proj_dim = nhead * head_dim
        # With if_g the query projection takes nhid * 3 input channels
        # (forward concatenates query_h with query_g on the last dimension).
        query_dim = nhid * 3 if if_g else nhid

        self.WQ = nn.Conv2d(query_dim, proj_dim, 1)
        self.WK = nn.Conv2d(keyhid, proj_dim, 1)
        self.WV = nn.Conv2d(keyhid, proj_dim, 1)
        self.WO = nn.Conv2d(proj_dim, nhid, 1)

        self.drop = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(nhid)

        self.nhid = nhid
        self.nhead = nhead
        self.head_dim = head_dim

    def forward(self, query_h, value, mask, query_g=None):
        """Attend from ``query_h`` (B, QL, H) over ``value`` (B, VL, VD, keyhid).

        ``VD`` must be 1 or QL. ``mask`` is indexed as (B, VL, VD); positions
        where it is set are excluded from the softmax over VL. Returns a
        (B, QL, H) tensor.
        """
        if query_g is None:
            query = query_h
        else:
            query = torch.cat([query_h, query_g], -1)

        # Channels-first layouts for the 1x1 convolutions.
        query = query.permute(0, 2, 1)[:, :, :, None]  # (B, C, QL, 1)
        value = value.permute(0, 3, 1, 2)              # (B, keyhid, VL, VD)

        heads, width = self.nhead, self.head_dim
        batch, q_len, hidden = query_h.shape
        v_len, v_dup = value.shape[2], value.shape[3]
        assert v_dup == 1 or v_dup == q_len

        q = self.WQ(query).view(batch, heads, width, 1, q_len)
        k = self.WK(value).view(batch, heads, width, v_len, v_dup)
        v = self.WV(value).view(batch, heads, width, v_len, v_dup)

        # Scaled dot product per head, reduced over the head dimension.
        scores = (q * k).sum(2, keepdim=True) / np.sqrt(width)
        scores = scores.masked_fill(mask[:, None, None, :, :], -np.inf)
        weights = self.drop(F.softmax(scores, 3))

        context = (weights * v).sum(3).view(batch, heads * width, q_len, 1)
        projected = F.leaky_relu(self.WO(context))
        projected = projected.permute(0, 2, 3, 1).view(batch, q_len, hidden)
        return self.norm(projected + query_h)
class Nodes_Cell(nn.Module):
    """Gated recurrent cell that updates a node state from two previous states.

    Three sigmoid gates are normalised against each other with a softmax and
    mix the candidate state with ``h2`` and ``h``.
    """

    def __init__(self, input_h, hid_h, use_global=True, dropout=0.2):
        super(Nodes_Cell, self).__init__()
        self.use_global = use_global
        self.hidden_size = hid_h

        # Gate projections (candidate gate, h2 gate, h gate) and the candidate.
        self.Wix = nn.Linear(input_h, hid_h)
        self.Wi2 = nn.Linear(input_h, hid_h)
        self.Wf = nn.Linear(input_h, hid_h)
        self.Wcx = nn.Linear(input_h, hid_h)

        self.drop = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """Draw every parameter uniformly from [-1/sqrt(hidden), 1/sqrt(hidden)]."""
        bound = 1.0 / math.sqrt(self.hidden_size)
        for param in self.parameters():
            nn.init.uniform_(param, -bound, bound)

    def forward(self, h, h2, x, glo=None):
        """Return the new node state.

        ``h``, ``h2``, ``x`` (and ``glo`` when ``use_global`` is set) are
        concatenated on the last dimension to form the gate input.
        """
        parts = [h, h2, self.drop(x)]
        if self.use_global:
            parts.append(self.drop(glo))
        gate_input = torch.cat(parts, -1)

        gate_cx = torch.sigmoid(self.Wix(gate_input))
        gate_h2 = torch.sigmoid(self.Wi2(gate_input))
        gate_h = torch.sigmoid(self.Wf(gate_input))
        candidate = torch.tanh(self.Wcx(gate_input))

        # Normalise the three gates so that they sum to one element-wise.
        alpha = F.softmax(torch.stack([gate_cx, gate_h2, gate_h], 1), 1)
        output = (alpha[:, 0] * candidate) + (alpha[:, 1] * h2) + (alpha[:, 2] * h)
        return output
class Alphabet:
    """Two-way mapping between instances (characters, words, labels) and integer ids.

    Index 0 is reserved as the default slot; instances are numbered from 1 in
    insertion order. Unless ``label`` is true, ``UNKNOWN`` is inserted first.
    """

    def __init__(self, name, label=False, keep_growing=True):
        self.__name = name
        self.UNKNOWN = ""
        self.label = label
        self.instance2index = {}
        self.instances = []
        self.keep_growing = keep_growing

        # Index 0 is occupied by default, all else following.
        self.default_index = 0
        self.next_index = 1
        if not self.label:
            self.add(self.UNKNOWN)

    def clear(self, keep_growing=True):
        """Drop every instance and reset the index counter.

        Note that ``UNKNOWN`` is not re-inserted.
        """
        self.instance2index = {}
        self.instances = []
        self.keep_growing = keep_growing

        # Index 0 is occupied by default, all else following.
        self.default_index = 0
        self.next_index = 1

    def add(self, instance):
        """Register ``instance`` under the next free index if it is new."""
        if instance not in self.instance2index:
            self.instances.append(instance)
            self.instance2index[instance] = self.next_index
            self.next_index += 1

    def get_index(self, instance):
        """Return the index of ``instance``.

        Unknown instances are added while the alphabet is open; once it is
        closed they map to the index of ``UNKNOWN`` (a ``KeyError`` is raised
        if ``UNKNOWN`` itself was never added, e.g. for label alphabets).
        """
        try:
            return self.instance2index[instance]
        except KeyError:
            if self.keep_growing:
                index = self.next_index
                self.add(instance)
                return index
            else:
                return self.instance2index[self.UNKNOWN]

    def get_instance(self, index):
        """Return the instance stored at ``index`` (``None`` for index 0)."""
        if index == 0:
            # First index is occupied by the wildcard element.
            return None
        try:
            return self.instances[index - 1]
        except IndexError:
            print('WARNING:Alphabet get_instance ,unknown instance index {}, return the first label.'.format(index))
            return self.instances[0]

    def size(self):
        """Number of indices in use, including the reserved index 0."""
        return len(self.instances) + 1

    def iteritems(self):
        """Return the (instance, index) pairs."""
        return self.instance2index.items()

    def enumerate_items(self, start=1):
        """Iterate over (index, instance) pairs beginning at index ``start``."""
        if start < 1 or start >= self.size():
            raise IndexError("Enumerate is allowed between [1 : size of the alphabet)")
        return zip(range(start, len(self.instances) + 1), self.instances[start - 1:])

    def close(self):
        """Stop adding unseen instances in ``get_index``."""
        self.keep_growing = False

    def open(self):
        """Allow ``get_index`` to add unseen instances again."""
        self.keep_growing = True

    def get_content(self):
        """Return the serialisable state of the alphabet."""
        return {'instance2index': self.instance2index, 'instances': self.instances}

    def from_json(self, data):
        """Restore the alphabet from a dict produced by ``get_content``."""
        # Copy so that later ``add`` calls do not mutate the caller's data.
        self.instances = list(data["instances"])
        self.instance2index = dict(data["instance2index"])
        # Keep the counter in step with the restored content; otherwise the
        # next ``add`` would hand out an index that is already taken.
        self.next_index = len(self.instances) + 1

    def save(self, output_directory, name=None):
        """
        Save the alphabet records to the given directory as JSON.
        :param output_directory: Directory to write ``<name>.json`` into.
        :param name: The alphabet saving name, optional.
        :return:
        """
        saving_name = name if name else self.__name
        try:
            with open(os.path.join(output_directory, saving_name + ".json"), 'w') as f:
                json.dump(self.get_content(), f)
        except Exception as e:
            # Best effort by design: report the failure instead of raising.
            print("Exception: Alphabet is not saved: " + repr(e))

    def load(self, input_directory, name=None):
        """
        Load the alphabet records from ``<name>.json`` in the given directory.
        :param input_directory: Directory the alphabet was saved to.
        :param name: The alphabet loading name, optional.
        :return:
        """
        loading_name = name if name else self.__name
        with open(os.path.join(input_directory, loading_name + ".json")) as f:
            self.from_json(json.load(f))
def build_word_file(self, word_file):
    """Insert the first token of every line of ``word_file`` into ``self.word_dict``.

    The file is read as UTF-8; tokens after the first one on a line are
    ignored (presumably embedding values; verify against the lexicon format).
    Blank or whitespace-only lines are skipped.
    """
    with open(word_file, 'r', encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            # A blank line has no first token; indexing it used to raise
            # IndexError.
            if not fields:
                continue
            self.word_dict.insert(fields[0])
    print("Building the word dict...")
build_word_pretrain_emb(self, emb_path): print ("Building word pretrain emb...") self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(emb_path, self.word_alphabet, self.norm_word_emb) def generate_instance_with_words(self, input_file, name): if name == "train": self.train_texts, self.train_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet, self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) elif name == "dev": self.dev_texts, self.dev_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet, self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) elif name == "test": self.test_texts, self.test_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet, self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) elif name == "raw": self.raw_texts, self.raw_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet, self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH) else: print("Error: you can only generate train/dev/test/raw instance! 
def normalize_word(word):
    """Return ``word`` with every digit character (per ``str.isdigit``) replaced by '0'."""
    # Build the result in one pass instead of repeated string concatenation.
    return "".join('0' if char.isdigit() else char for char in word)
def build_pretrain_embedding(embedding_path, word_alphabet, norm=True, embedd_dim=50):
    """Build an embedding matrix aligned with the indices of ``word_alphabet``.

    For every ``(word, index)`` in ``word_alphabet.instance2index`` the row
    ``index`` is filled with the pretrained vector of the lower-cased word,
    else of the word as written, else with a sample drawn uniformly from
    ``[-scale, scale]`` where ``scale = sqrt(3 / embedd_dim)``.

    Args:
        embedding_path: Path handed to ``load_pretrain_emb``; ``None`` skips
            loading so that every word is treated as out-of-vocabulary.
        word_alphabet: Object exposing ``size()`` and ``instance2index``.
        norm: When true, pretrained vectors are scaled to unit L2 norm.
        embedd_dim: Dimension used when no embedding file is loaded; it is
            replaced by the loaded dimension otherwise.

    Returns:
        Tuple ``(pretrain_emb, embedd_dim)`` where ``pretrain_emb`` has shape
        ``[word_alphabet.size(), embedd_dim]``.
    """
    def norm2one(vec):
        root_sum_square = np.sqrt(np.sum(np.square(vec)))
        # An all-zero vector has no direction; dividing would produce NaNs.
        if root_sum_square == 0:
            return vec
        return vec / root_sum_square

    embedd_dict = dict()
    if embedding_path is not None:
        embedd_dict, embedd_dim = load_pretrain_emb(embedding_path)
    scale = np.sqrt(3.0 / embedd_dim)
    vocab_size = word_alphabet.size()
    # Zero-initialise instead of np.empty: any row that has no entry in
    # instance2index would otherwise expose uninitialised memory (arbitrary
    # values, possibly NaN/inf) and make runs non-reproducible.
    pretrain_emb = np.zeros([vocab_size, embedd_dim])
    not_match = 0
    for word, index in word_alphabet.instance2index.items():
        lowered = word.lower()
        if lowered in embedd_dict:
            vector = embedd_dict[lowered]
        elif word in embedd_dict:
            vector = embedd_dict[word]
        else:
            pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, embedd_dim])
            not_match += 1
            continue
        pretrain_emb[index, :] = norm2one(vector) if norm else vector

    pretrained_size = len(embedd_dict)
    # Guard the ratio so an empty alphabet reports 0 instead of crashing.
    oov_ratio = (not_match + 0.) / vocab_size if vocab_size else 0.
    print("Embedding:\n pretrain word:%s, match:%s, oov:%s, oov%%:%.4f" % (pretrained_size, vocab_size - not_match, not_match, oov_ratio))
    return pretrain_emb, embedd_dim
def get_ner_fmeasure(golden_lists, predict_lists):
    """Compute token accuracy and entity-level precision / recall / F1.

    Entities are extracted per sentence with ``get_ner_BMES``; an entity is
    counted as correct when the same extracted string occurs in both the gold
    and the predicted sentence.

    Args:
        golden_lists: Gold label sequences, one per sentence.
        predict_lists: Predicted label sequences, parallel to ``golden_lists``.

    Returns:
        Tuple ``(accuracy, precision, recall, f_measure)``. A value that is
        undefined because its denominator is zero is reported as ``-1``.
    """
    right_tag = 0
    all_tag = 0
    right_num = 0
    golden_num = 0
    predict_num = 0
    for idx, golden_list in enumerate(golden_lists):
        predict_list = predict_lists[idx]
        for idy, gold_label in enumerate(golden_list):
            if gold_label == predict_list[idy]:
                right_tag += 1
        all_tag += len(golden_list)
        gold_matrix = get_ner_BMES(golden_list)
        pred_matrix = get_ner_BMES(predict_list)
        # Only the counts are needed, so no corpus-wide lists are kept.
        golden_num += len(gold_matrix)
        predict_num += len(pred_matrix)
        right_num += len(set(gold_matrix).intersection(pred_matrix))

    if predict_num == 0:
        precision = -1
    else:
        precision = (right_num+0.0)/predict_num
    if golden_num == 0:
        recall = -1
    else:
        recall = (right_num+0.0)/golden_num
    if (precision == -1) or (recall == -1) or (precision+recall) <= 0.:
        f_measure = -1
    else:
        f_measure = 2*precision*recall/(precision+recall)
    # With no tokens at all the accuracy is undefined; use the same -1
    # sentinel as the other metrics instead of raising ZeroDivisionError.
    if all_tag == 0:
        accuracy = -1
    else:
        accuracy = (right_tag+0.0)/all_tag
    print("gold_num = ", golden_num, " pred_num = ", predict_num, " right_num = ", right_num)
    return accuracy, precision, recall, f_measure
# Sentinel key stored in a trie node to mark that a complete word ends there.
_end = "_end_"


class Word_Trie:
    """Prefix tree of words, stored as nested dicts keyed by token."""

    def __init__(self):
        self.root = dict()

    def recursive_search(self, word_list):
        """Return every stored word that is a prefix of ``word_list``.

        Only prefixes of at least two tokens are considered, and matches are
        returned longest first. The trie is walked once and ``word_list`` is
        left untouched, so any sliceable sequence (list, tuple, str) works.

        Args:
            word_list: Sequence of tokens to match from its first element.

        Returns:
            List of matched words, each the concatenation of its tokens.
        """
        node = self.root
        matches = []
        for length, token in enumerate(word_list, 1):
            node = node.get(token)
            # Stop at the first token with no child node; the isinstance
            # check also rejects the sentinel's own (string) value.
            if not isinstance(node, dict):
                break
            if length > 1 and _end in node:
                matches.append("".join(word_list[:length]))
        matches.reverse()
        return matches

    def search(self, word):
        """Return True if exactly ``word`` has been inserted."""
        current_dict = self.root
        for char in word:
            if char not in current_dict:
                return False
            current_dict = current_dict[char]
        return _end in current_dict

    def insert(self, word):
        """Add ``word`` (an iterable of tokens) to the trie."""
        current_dict = self.root
        for char in word:
            current_dict = current_dict.setdefault(char, {})
        current_dict[_end] = _end