Repository: RowitZou/LGN
Branch: master
Commit: c1e7a681f514
Files: 11
Total size: 84.2 KB

Directory structure:
gitextract_y25cez6i/

├── LICENSE
├── README.md
├── main.py
├── model/
│   ├── LGN.py
│   ├── crf.py
│   └── module.py
└── utils/
    ├── alphabet.py
    ├── data.py
    ├── functions.py
    ├── metric.py
    └── word_trie.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2019 Yicheng Zou

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# LGN

PyTorch implementation of [A Lexicon-Based Graph Neural Network for Chinese NER](https://www.aclweb.org/anthology/D19-1096.pdf).

Parts of the code are adapted from https://github.com/jiesutd/LatticeLSTM.

## Requirements

* Python 3.6 or higher
* PyTorch 0.4.1 or higher

## Input Format

Files use the BMES tagging scheme, with one character and its label per line. Sentences are separated by an empty line.

	印   B-LOC
	度   M-LOC
	河   E-LOC
	流   O
	经   O
	印   B-GPE
	度   E-GPE

## Usage

* Training

		python main.py --status train \
		               --train data/onto4ner.cn/train.char.bmes \
		               --dev data/onto4ner.cn/dev.char.bmes \
		               --test data/onto4ner.cn/test.char.bmes \
		               --saved_model saved_model/model_onto4ner \
		               --saved_set data/onto4ner.cn/saved.dset
		               
* Testing

		python main.py --status test \
		               --test data/onto4ner.cn/test.char.bmes \
		               --saved_model saved_model/model_onto4ner \
		               --saved_set data/onto4ner.cn/saved.dset
		               
* Decoding (the raw file may be either labeled or unlabeled)

		python main.py --status decode \
		               --raw data/onto4ner.cn/test.char.bmes \
		               --output tagged_file.txt \
		               --saved_model saved_model/model_onto4ner \
		               --saved_set data/onto4ner.cn/saved.dset
		               
## Data

The pretrained character and word embeddings can be downloaded from [Lattice LSTM](https://github.com/jiesutd/LatticeLSTM).

Original datasets can be found at [OntoNotes](https://catalog.ldc.upenn.edu/LDC2011T03), [MSRA](http://sighan.cs.uchicago.edu/bakeoff2006/), 
[Weibo](https://github.com/hltcoe/golden-horse) and [Resume](https://github.com/jiesutd/LatticeLSTM/tree/master/ResumeNER).
The preprocessed datasets, which follow the input format expected by our code, are available at [Google Drive](https://drive.google.com/open?id=1Rvju5_gp2E6BFiqzMBtnMqVP803AbBcm) and 
[Baidu Pan](https://pan.baidu.com/s/1zbzLriRpc8S_5ez_upC7OA) (extraction code: akcm).

## Pretrained Model Downloads

We also provide pretrained models for the four datasets; these are the same models whose results are reported in the paper.
If you retrain the models from scratch with the same hyper-parameter settings, you may obtain a slightly 
lower or higher F1 score than reported in the paper (in our experiments, we selected the best-performing model).

Pretrained models and related hyper-parameter settings are available at [Google Drive](https://drive.google.com/file/d/1KKkCW8WRhgR2P2UbRpNpKyE_RAv1EREv/view?usp=sharing) or [Baidu Pan](https://pan.baidu.com/s/1U89EwnhPpMa4bNrS--u4EA).

Running main.py in test mode with the pretrained models gives the following results:

| Dataset        | Precision (%) | Recall (%) | F1 (%) |
|:--------------:|:-------------:|:----------:|:------:|
| OntoNotes dev  |     74.00     |   70.03    | 71.96  |
| OntoNotes test |     76.13     |   73.68    | 74.89  |
| MSRA dev       |       -       |     -      |   -    |
| MSRA test      |     94.19     |   92.73    | 93.46  |
| Weibo dev      |     66.09     |   59.13    | 62.42  |
| Weibo test     |     65.71     |   55.56    | 60.21  |
| Resume dev     |     94.27     |   94.59    | 94.43  |
| Resume test    |     95.28     |   95.46    | 95.37  |

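The F1 column is the harmonic mean of the precision and recall columns, F1 = 2PR / (P + R). As a quick consistency check (an illustration, not part of the repository), the sketch below recomputes F1 from the P and R values in the results table. Because those inputs are already rounded to two decimals, the recomputed values only agree with the reported ones to within about 0.01. `f1_score` and `REPORTED` are hypothetical names.

```python
def f1_score(precision, recall):
    """F1 = 2PR / (P + R), the harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (precision, recall, reported F1) copied from the results table; MSRA dev has no numbers.
REPORTED = {
    "OntoNotes dev": (74.00, 70.03, 71.96),
    "OntoNotes test": (76.13, 73.68, 74.89),
    "MSRA test": (94.19, 92.73, 93.46),
    "Weibo dev": (66.09, 59.13, 62.42),
    "Weibo test": (65.71, 55.56, 60.21),
    "Resume dev": (94.27, 94.59, 94.43),
    "Resume test": (95.28, 95.46, 95.37),
}

for name, (p, r, reported) in REPORTED.items():
    computed = f1_score(p, r)
    # P and R are rounded to two decimals, so exact agreement is not expected.
    print(f"{name:<15} computed F1 = {computed:.3f}  reported = {reported:.2f}")
```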
## Citation

	@inproceedings{gui2019lexicon,
  	 title={A Lexicon-Based Graph Neural Network for Chinese NER},
  	 author={Gui, Tao and Zou, Yicheng and Zhang, Qi and Peng, Minlong and 
	 Fu, Jinlan and Wei, Zhongyu and Huang, Xuanjing},
  	 booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing 
	 and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)},
  	 pages={1039--1049},
  	 year={2019}
	}


================================================
FILE: main.py
================================================
# -*- coding: utf-8 -*-
# @Author: Yicheng Zou
# @Last Modified by:   Yicheng Zou,    Contact: yczou18@fudan.edu.cn

import time
import sys
import argparse
import random
import torch
import gc
import pickle
import os
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
from utils.metric import get_ner_fmeasure
from model.LGN import Graph
from utils.data import Data


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


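def lr_decay(optimizer, epoch, decay_rate, init_lr):
    # Exponential decay: lr = init_lr * (1 - decay_rate) ** epoch.
    # Parameters in the "aggr" group (the aggregation modules) get twice this rate.
    lr = init_lr * ((1-decay_rate)**epoch)
    print(" Learning rate is set to:", lr)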
    for param_group in optimizer.param_groups:
        if param_group['name'] == 'aggr':
            param_group['lr'] = lr * 2.
        else:
            param_group['lr'] = lr
    return optimizer


def data_initialization(data, word_file, train_file, dev_file, test_file):

    data.build_word_file(word_file)

    if train_file:
        data.build_alphabet(train_file)
        data.build_word_alphabet(train_file)
    if dev_file:
        data.build_alphabet(dev_file)
        data.build_word_alphabet(dev_file)
    if test_file:
        data.build_alphabet(test_file)
        data.build_word_alphabet(test_file)
    return data


def predict_check(pred_variable, gold_variable, mask_variable):

    pred = pred_variable.cpu().data.numpy()
    gold = gold_variable.cpu().data.numpy()
    mask = mask_variable.cpu().data.numpy()
    overlapped = (pred == gold)
    right_token = np.sum(overlapped * mask)
    total_token = mask.sum()
    return right_token, total_token


def recover_label(pred_variable, gold_variable, mask_variable, label_alphabet):

    batch_size = gold_variable.size(0)
    seq_len = gold_variable.size(1)
    mask = mask_variable.cpu().data.numpy()
    pred_tag = pred_variable.cpu().data.numpy()
    gold_tag = gold_variable.cpu().data.numpy()
    pred_label = []
    gold_label = []

    for idx in range(batch_size):
        pred = [label_alphabet.get_instance(pred_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0]
        gold = [label_alphabet.get_instance(gold_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0]
        assert(len(pred)==len(gold))
        pred_label.append(pred)
        gold_label.append(gold)

    return pred_label, gold_label


def print_args(args):
    print("CONFIG SUMMARY:")
    print("     Batch size: %s" % (args.batch_size))
    print("     Use GPU: %s" % (args.use_gpu))
    print("     Use CRF: %s" % (args.use_crf))
    print("     Number of epochs: %s" % (args.num_epoch))
    print("     Learning rate: %s" % (args.lr))
    print("     L2 regularization (weight decay): %s" % (args.weight_decay))
    print("     Use edge embedding: %s" % (args.use_edge))
    print("     Use global node: %s" % (args.use_global))
    print("     Bidirectional digraph: %s" % (args.bidirectional))
    print("     Number of graph update steps: %s" % (args.iters))
    print("     Attention dropout rate: %s" % (args.tf_drop_rate))
    print("     Embedding dropout rate: %s" % (args.emb_drop_rate))
    print("     Hidden state dimension: %s" % (args.hidden_dim))
    print("     Learning rate decay ratio: %s" % (args.lr_decay))
    print("     Aggregation module dropout rate: %s" % (args.cell_drop_rate))
    print("     Number of attention heads: %s" % (args.num_head))
    print("     Attention head dimension: %s" % (args.head_dim))
    print("CONFIG SUMMARY END.")
    sys.stdout.flush()


def evaluate(data, args, model, name):
    if name == "train":
        instances = data.train_Ids
    elif name == "dev":
        instances = data.dev_Ids
    elif name == 'test':
        instances = data.test_Ids
    elif name == 'raw':
        instances = data.raw_Ids
    else:
        print("Error: wrong evaluate name,", name)
        sys.exit(1)

    pred_results = []
    gold_results = []

    # set model to evaluation mode
    model.eval()
    batch_size = args.batch_size
    start_time = time.time()
    train_num = len(instances)
    total_batch = train_num // batch_size + 1

    for batch_id in range(total_batch):
        start = batch_id*batch_size
        end = (batch_id+1)*batch_size
        if end > train_num:
            end = train_num
        instance = instances[start:end]
        if not instance:
            continue

        word_list, batch_char, batch_label, mask = batchify_with_label(instance, args.use_gpu)
        _, tag_seq = model(word_list, batch_char, mask)

        pred_label, gold_label = recover_label(tag_seq, batch_label, mask, data.label_alphabet)

        pred_results += pred_label
        gold_results += gold_label

    decode_time = time.time() - start_time
    speed = len(instances) / decode_time

    acc, p, r, f = get_ner_fmeasure(gold_results, pred_results)
    return speed, acc, p, r, f, pred_results


def batchify_with_label(input_batch_list, gpu):

    batch_size = len(input_batch_list)
    chars = [sent[0] for sent in input_batch_list]
    words = [sent[1] for sent in input_batch_list]
    labels = [sent[2] for sent in input_batch_list]

    sent_lengths = torch.LongTensor(list(map(len, chars)))
    max_sent_len = sent_lengths.max().item()
    # Pad with id 0 up to the longest sentence in the batch; the mask marks real characters with 1.
    char_seq_tensor = torch.zeros((batch_size, max_sent_len), dtype=torch.long)
    label_seq_tensor = torch.zeros((batch_size, max_sent_len), dtype=torch.long)
    mask = torch.zeros((batch_size, max_sent_len), dtype=torch.uint8)

    for idx, (seq, label, seq_len) in enumerate(zip(chars, labels, sent_lengths)):
        char_seq_tensor[idx, :seq_len] = torch.LongTensor(seq)
        label_seq_tensor[idx, :seq_len] = torch.LongTensor(label)
        mask[idx, :seq_len] = 1

    if gpu:
        char_seq_tensor = char_seq_tensor.cuda()
        label_seq_tensor = label_seq_tensor.cuda()
        mask = mask.cuda()

    return words, char_seq_tensor, label_seq_tensor, mask


def train(data, args, saved_model_path):

    print("Training model...")
    model = Graph(data, args)
    if args.use_gpu:
        model = model.cuda()
    print('Number of parameters:', sum(param.numel() for param in model.parameters()))
    print("Finished building model.")

    best_dev_epoch = 0
    best_dev_f = -1
    best_dev_p = -1
    best_dev_r = -1

    best_test_f = -1
    best_test_p = -1
    best_test_r = -1

    # Initialize the optimizer
    aggr_module_params = []
    other_module_params = []
    for m_name in model._modules:
        m = model._modules[m_name]
        if isinstance(m, torch.nn.ModuleList):
            for p in m.parameters():
                if p.requires_grad:
                    aggr_module_params.append(p)
        else:
            for p in m.parameters():
                if p.requires_grad:
                    other_module_params.append(p)

    optimizer = optim.Adam([
            {"params": (aggr_module_params), "name": "aggr"},
            {"params": (other_module_params), "name": "other"}
        ],
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    for idx in range(args.num_epoch):
        epoch_start = time.time()
        temp_start = epoch_start
        print(("Epoch: %s/%s" %(idx, args.num_epoch)))
        optimizer = lr_decay(optimizer, idx, args.lr_decay, args.lr)
        sample_loss = 0
        batch_loss = 0
        total_loss = 0
        right_token = 0
        whole_token = 0
        random.shuffle(data.train_Ids)
        # set model to training mode
        model.train()
        model.zero_grad()
        batch_size = args.batch_size
        train_num = len(data.train_Ids)
        total_batch = train_num // batch_size + 1

        for batch_id in range(total_batch):
            # Get one batch-sized instance
            start = batch_id * batch_size
            end = (batch_id + 1) * batch_size
            if end > train_num:
                end = train_num
            instance = data.train_Ids[start:end]
            if not instance:
                continue

            word_list, batch_char, batch_label, mask = batchify_with_label(instance, args.use_gpu)
            loss, tag_seq = model(word_list, batch_char, mask, batch_label)
            right, whole = predict_check(tag_seq, batch_label, mask)
            right_token += right
            whole_token += whole
            sample_loss += loss.item()
            total_loss += loss.item()
            batch_loss += loss

            if end % 500 == 0:
                temp_time = time.time()
                temp_cost = temp_time - temp_start
                temp_start = temp_time
                print(("     Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" %
                       (end, temp_cost, sample_loss, right_token, whole_token, (right_token+0.)/whole_token)))
                sys.stdout.flush()
                sample_loss = 0
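            # Update after every full batch, and also after the final (possibly partial) batch,
            # whose loss would otherwise be discarded at the start of the next epoch.
            if end % args.batch_size == 0 or end == train_num: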
                batch_loss.backward()
                optimizer.step()
                model.zero_grad()
                batch_loss = 0

        temp_time = time.time()
        temp_cost = temp_time - temp_start
        print(("     Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" %
               (end, temp_cost, sample_loss, right_token, whole_token, (right_token+0.)/whole_token)))
        epoch_finish = time.time()
        epoch_cost = epoch_finish - epoch_start
        print(("Epoch: %s training finished. Time: %.2fs, speed: %.2fst/s,  total loss: %s" %
               (idx, epoch_cost, train_num/epoch_cost, total_loss)))

        # dev
        speed, acc, dev_p, dev_r, dev_f, _ = evaluate(data, args, model, "dev")
        dev_finish = time.time()
        dev_cost = dev_finish - epoch_finish

        print(("Dev: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" %
               (dev_cost, speed, acc, dev_p, dev_r, dev_f)))

        # test
        speed, acc, test_p, test_r, test_f, _ = evaluate(data, args, model, "test")
        test_finish = time.time()
        test_cost = test_finish - dev_finish

        print(("Test: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" %
               (test_cost, speed, acc, test_p, test_r, test_f)))

        if dev_f > best_dev_f:
            print("Exceeded previous best dev F1: %.4f" % best_dev_f)
            torch.save(model.state_dict(), saved_model_path + "_best")
            best_dev_p = dev_p
            best_dev_r = dev_r
            best_dev_f = dev_f
            best_dev_epoch = idx + 1
            best_test_p = test_p
            best_test_r = test_r
            best_test_f = test_f

        model_idx_path = saved_model_path + "_" + str(idx)
        torch.save(model.state_dict(), model_idx_path)
        with open(saved_model_path + "_result.txt", "a") as file:
            file.write(model_idx_path + '\n')
            file.write("Dev score: p: %.4f, r: %.4f, f: %.4f\n" % (dev_p, dev_r, dev_f))
            file.write("Test score: p: %.4f, r: %.4f, f: %.4f\n\n" % (test_p, test_r, test_f))

        print("Best dev epoch: %d" % best_dev_epoch)
        print("Best dev score: p: %.4f, r: %.4f, f: %.4f" % (best_dev_p, best_dev_r, best_dev_f))
        print("Best test score: p: %.4f, r: %.4f, f: %.4f" % (best_test_p, best_test_r, best_test_f))

        gc.collect()

    with open(saved_model_path + "_result.txt", "a") as file:
        file.write("Best epoch: %d\n" % best_dev_epoch)
        file.write("Best Dev score: p: %.4f, r: %.4f, f: %.4f\n" % (best_dev_p, best_dev_r, best_dev_f))
        file.write("Test score at best dev epoch: p: %.4f, r: %.4f, f: %.4f\n\n" % (best_test_p, best_test_r, best_test_f))

    with open(saved_model_path + "_best_HP.config", "wb") as file:
        pickle.dump(args, file)


def load_model_decode(model_dir, data, args, name):
    model_dir = model_dir + "_best"
    print("Load Model from file: ", model_dir)
    model = Graph(data, args)
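    # Map the saved tensors onto the CPU first, so a model trained on GPU can also be loaded
    # on a CPU-only machine; it is moved to the GPU afterwards if requested.
    model.load_state_dict(torch.load(model_dir, map_location=lambda storage, loc: storage))
    if args.use_gpu:
        model = model.cuda()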

    print(("Decode %s data ..." % name))
    start_time = time.time()
    speed, acc, p, r, f, pred_results = evaluate(data, args, model, name)
    end_time = time.time()
    time_cost = end_time - start_time
    print(("%s: time:%.2fs, speed:%.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" %
           (name, time_cost, speed, acc, p, r, f)))

    return pred_results


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--status', choices=['train', 'test', 'decode'], help='Function status.', default='train')
    parser.add_argument('--use_gpu', type=str2bool, default=True)
    parser.add_argument('--train', help='Training set.', default='data/onto4ner.cn/train.char.bmes')
    parser.add_argument('--dev', help='Development set.', default='data/onto4ner.cn/dev.char.bmes')
    parser.add_argument('--test', help='Test set.', default='data/onto4ner.cn/test.char.bmes')
    parser.add_argument('--raw', help='Raw file for decoding.')
    parser.add_argument('--output', help='Output results for decoding.')
    parser.add_argument('--saved_set', help='Path of saved data set.', default='data/onto4ner.cn/saved.dset')
    parser.add_argument('--saved_model', help='Path of saved model.', default="saved_model/model_onto4ner")
    parser.add_argument('--char_emb', help='Path of character embedding file.', default="data/gigaword_chn.all.a2b.uni.ite50.vec")
    parser.add_argument('--word_emb', help='Path of word embedding file.', default="data/ctb.50d.vec")

    parser.add_argument('--use_crf', type=str2bool, default=True)
    parser.add_argument('--use_edge', type=str2bool, default=True, help='Whether to use lexicon embeddings (edge embeddings).')
    parser.add_argument('--use_global', type=str2bool, default=True, help='Whether to use the global node.')
    parser.add_argument('--bidirectional', type=str2bool, default=True, help='Whether to use a bidirectional digraph.')

    parser.add_argument('--seed', help='Random seed', default=1023, type=int)
    parser.add_argument('--batch_size', help='Batch size.', default=1, type=int)
    parser.add_argument('--num_epoch', default=100, type=int, help="Number of epochs.")
    parser.add_argument('--iters', default=4, type=int, help='Number of graph iterations.')
    parser.add_argument('--hidden_dim', default=50, type=int, help='Hidden state size.')
    parser.add_argument('--num_head', default=10, type=int, help='Number of attention heads.')
    parser.add_argument('--head_dim', default=20, type=int, help='Dimension of each attention head.')
    parser.add_argument('--tf_drop_rate', default=0.1, type=float, help='Transformer dropout rate.')
    parser.add_argument('--emb_drop_rate', default=0.5, type=float, help='Embedding dropout rate.')
    parser.add_argument('--cell_drop_rate', default=0.2, type=float, help='Aggregation module dropout rate.')
    parser.add_argument('--word_alphabet_size', type=int, help='Word alphabet size.')
    parser.add_argument('--char_alphabet_size', type=int, help='Char alphabet size.')
    parser.add_argument('--label_alphabet_size', type=int, help='Label alphabet size.')
    parser.add_argument('--char_dim', type=int, help='Char embedding size.')
    parser.add_argument('--word_dim', type=int, help='Word embedding size.')
    parser.add_argument('--lr', type=float, default=2e-05)
    parser.add_argument('--lr_decay', type=float, default=0)
    parser.add_argument('--weight_decay', type=float, default=0)

    args = parser.parse_args()

    status = args.status.lower()
    seed_num = args.seed
    random.seed(seed_num)
    torch.manual_seed(seed_num)
    np.random.seed(seed_num)

    train_file = args.train
    dev_file = args.dev
    test_file = args.test
    raw_file = args.raw
    output_file = args.output
    saved_set_path = args.saved_set
    saved_model_path = args.saved_model
    char_file = args.char_emb
    word_file = args.word_emb

    if status == 'train':
        if os.path.exists(saved_set_path):
            print('Loading saved data set...')
            with open(saved_set_path, 'rb') as f:
                data = pickle.load(f)
        else:
            data = Data()
            data_initialization(data, word_file, train_file, dev_file, test_file)
            data.generate_instance_with_words(train_file, 'train')
            data.generate_instance_with_words(dev_file, 'dev')
            data.generate_instance_with_words(test_file, 'test')
            data.build_char_pretrain_emb(char_file)
            data.build_word_pretrain_emb(word_file)
            if saved_set_path is not None:
                print('Dumping data...')
                with open(saved_set_path, 'wb') as f:
                    pickle.dump(data, f)
        data.show_data_summary()
        args.word_alphabet_size = data.word_alphabet.size()
        args.char_alphabet_size = data.char_alphabet.size()
        args.label_alphabet_size = data.label_alphabet.size()
        args.char_dim = data.char_emb_dim
        args.word_dim = data.word_emb_dim
        print_args(args)
        train(data, args, saved_model_path)

    elif status == 'test':
        assert not (test_file is None)
        if os.path.exists(saved_set_path):
            print('Loading saved data set...')
            with open(saved_set_path, 'rb') as f:
                data = pickle.load(f)
        else:
            print("Cannot find saved data set: ", saved_set_path)
            sys.exit(1)
        data.generate_instance_with_words(test_file, 'test')
        with open(saved_model_path + "_best_HP.config", "rb") as f:
            args = pickle.load(f)
        data.show_data_summary()
        print_args(args)
        load_model_decode(saved_model_path, data, args, "test")

    elif status == 'decode':
        assert not (raw_file is None or output_file is None)
        if os.path.exists(saved_set_path):
            print('Loading saved data set...')
            with open(saved_set_path, 'rb') as f:
                data = pickle.load(f)
        else:
            print("Cannot find saved data set: ", saved_set_path)
            sys.exit(1)
        data.generate_instance_with_words(raw_file, 'raw')
        with open(saved_model_path + "_best_HP.config", "rb") as f:
            args = pickle.load(f)
        data.show_data_summary()
        print_args(args)
        decode_results = load_model_decode(saved_model_path, data, args, "raw")
        data.write_decoded_results(output_file, decode_results, 'raw')
    else:
        print("Invalid argument! Please use valid arguments! (train/test/decode)")

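In `train()` in main.py, every parameter of a top-level `nn.ModuleList` (the stacked `MultiHeadAtt`/`GloAtt` aggregation modules) goes into the `"aggr"` optimizer group, and all remaining parameters go into `"other"`. Each epoch, `lr_decay()` sets lr = init_lr × (1 − decay_rate)^epoch and gives the `"aggr"` group twice that rate. The torch-free sketch below only illustrates that arithmetic; `scheduled_lrs` is a hypothetical helper, not part of the repository.

```python
def scheduled_lrs(init_lr, decay_rate, epoch):
    """Hypothetical mirror of lr_decay(): per-group learning rates at a given epoch."""
    lr = init_lr * (1 - decay_rate) ** epoch
    # "aggr": parameters of top-level nn.ModuleList members (2x rate);
    # "other": everything else (base rate).
    return {"aggr": lr * 2.0, "other": lr}


# With the defaults (--lr 2e-05, --lr_decay 0) the rates never change.
for epoch in (0, 10, 99):
    print(epoch, scheduled_lrs(2e-05, 0.0, epoch))

# With a hypothetical decay ratio of 0.05, the base rate shrinks geometrically.
print(scheduled_lrs(2e-05, 0.05, 10))
```

Note that with the default `--lr_decay 0` this schedule is constant, so the only effect of the grouping is the doubled learning rate for the aggregation modules.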

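`batchify_with_label()` in main.py right-pads every sentence in a batch with id 0 up to the batch's longest sentence and builds a 0/1 mask over the real characters. Meanwhile, `train()` and `evaluate()` walk the data in `train_num // batch_size + 1` slices and skip the empty trailing slice. The pure-Python sketch below reproduces that bookkeeping with plain lists instead of tensors; `pad_batch` and `batch_slices` are hypothetical helpers for illustration only.

```python
def pad_batch(char_id_seqs, label_id_seqs, pad_id=0):
    """Right-pad id sequences to the batch's max length and build a 0/1 mask."""
    max_len = max(len(seq) for seq in char_id_seqs)
    chars, labels, mask = [], [], []
    for seq, lab in zip(char_id_seqs, label_id_seqs):
        pad = max_len - len(seq)
        chars.append(list(seq) + [pad_id] * pad)
        labels.append(list(lab) + [pad_id] * pad)
        mask.append([1] * len(seq) + [0] * pad)  # 1 = real character, 0 = padding
    return chars, labels, mask


def batch_slices(num_instances, batch_size):
    """Yield (start, end) pairs as train()/evaluate() slice the data, skipping empty slices."""
    for batch_id in range(num_instances // batch_size + 1):
        start = batch_id * batch_size
        end = min((batch_id + 1) * batch_size, num_instances)
        if start < end:
            yield start, end


chars, labels, mask = pad_batch([[5, 6, 7], [8]], [[1, 2, 3], [4]])
print(chars)  # [[5, 6, 7], [8, 0, 0]]
print(mask)   # [[1, 1, 1], [1, 0, 0]]
print(list(batch_slices(5, 2)))  # [(0, 2), (2, 4), (4, 5)]
```

The mask is what keeps padded positions out of the token accuracy in `predict_check()` and out of the label lists in `recover_label()`.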
================================================
FILE: model/LGN.py
================================================
# -*- coding: utf-8 -*-
# @Author: Yicheng Zou
# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn

from model.crf import CRF
from model.module import *


class Graph(nn.Module):
    def __init__(self, data, args):
        super(Graph, self).__init__()

        self.gpu = args.use_gpu
        self.char_emb_dim = args.char_dim
        self.word_emb_dim = args.word_dim
        self.hidden_dim = args.hidden_dim
        self.num_head = args.num_head  # 5 10 20
        self.head_dim = args.head_dim  # 10 20
        self.tf_dropout_rate = args.tf_drop_rate
        self.iters = args.iters
        self.bmes_dim = 10
        self.length_dim = 10
        self.max_word_length = 5
        self.emb_dropout_rate = args.emb_drop_rate
        self.cell_dropout_rate = args.cell_drop_rate
        self.use_crf = args.use_crf
        self.use_global = args.use_global
        self.use_edge = args.use_edge
        self.bidirectional = args.bidirectional
        self.label_size = args.label_alphabet_size

        # char embedding
        self.char_embedding = nn.Embedding(args.char_alphabet_size, self.char_emb_dim)
        if data.pretrain_char_embedding is not None:
            self.char_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_char_embedding))

        if self.use_edge:

            # word embedding
            self.word_embedding = nn.Embedding(args.word_alphabet_size, self.word_emb_dim)
            if data.pretrain_word_embedding is not None:
                scale = np.sqrt(3.0 / self.word_emb_dim)
                data.pretrain_word_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.word_emb_dim])
                self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding))

            # bmes embedding
            self.bmes_embedding = nn.Embedding(4, self.bmes_dim)
            """
            self.edge_emb_linear = nn.Sequential(
                nn.Linear(self.word_emb_dim, self.hidden_dim),
                nn.ELU()
            )
            """
        # lstm
        self.emb_rnn_f = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True)
        self.emb_rnn_b = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True)

        # length embedding
        self.length_embedding = nn.Embedding(self.max_word_length, self.length_dim)

        self.dropout = nn.Dropout(self.emb_dropout_rate)
        self.norm = nn.LayerNorm(self.hidden_dim)

        if self.use_edge:
            # Node aggregation module
            self.edge2node_f = nn.ModuleList(
                [MultiHeadAtt(self.hidden_dim, self.hidden_dim * 2 + self.length_dim,
                              nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                 for _ in range(self.iters)])
            # Edge aggregation module
            self.node2edge_f = nn.ModuleList(
                [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.bmes_dim, nhead=self.num_head,
                              head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                 for _ in range(self.iters)])

        else:
            # Node aggregation module
            self.edge2node_f = nn.ModuleList(
                [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.length_dim,
                              nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                 for _ in range(self.iters)])

        if self.use_global:
            # Global Node aggregation module
            self.glo_att_f_node = nn.ModuleList(
                [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                 for _ in range(self.iters)])

            if self.use_edge:
                self.glo_att_f_edge = nn.ModuleList(
                    [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                     for _ in range(self.iters)])

            # Updating modules
            if self.use_edge:
                self.glo_rnn_f = Global_Cell(self.hidden_dim * 3, self.hidden_dim, dropout=self.cell_dropout_rate)
                self.node_rnn_f = Nodes_Cell(self.hidden_dim * 5, self.hidden_dim, dropout=self.cell_dropout_rate)
                self.edge_rnn_f = Edges_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate)
            else:
                self.glo_rnn_f = Global_Cell(self.hidden_dim * 2, self.hidden_dim, dropout=self.cell_dropout_rate)
                self.node_rnn_f = Nodes_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate)

        else:
            # Updating modules
            self.node_rnn_f = Nodes_Cell(self.hidden_dim * 3, self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate)
            if self.use_edge:
                self.edge_rnn_f = Edges_Cell(self.hidden_dim * 2, self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate)

        if self.bidirectional:

            if self.use_edge:
                # Node aggregation module
                self.edge2node_b = nn.ModuleList(
                    [MultiHeadAtt(self.hidden_dim, self.hidden_dim * 2 + self.length_dim,
                                  nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                     for _ in range(self.iters)])
                # Edge aggregation module
                self.node2edge_b = nn.ModuleList(
                    [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.bmes_dim, nhead=self.num_head,
                                  head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                     for _ in range(self.iters)])

            else:
                # Node aggregation module
                self.edge2node_b = nn.ModuleList(
                    [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.length_dim,
                                  nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                     for _ in range(self.iters)])

            if self.use_global:
                # Global Node aggregation module
                self.glo_att_b_node = nn.ModuleList(
                    [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                     for _ in range(self.iters)])
                if self.use_edge:
                    self.glo_att_b_edge = nn.ModuleList(
                        [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
                         for _ in range(self.iters)])

                # Updating modules
                if self.use_edge:
                    # Pass dropout by keyword, matching the forward-direction cells above.
                    self.glo_rnn_b = Global_Cell(self.hidden_dim * 3, self.hidden_dim, dropout=self.cell_dropout_rate)
                    self.node_rnn_b = Nodes_Cell(self.hidden_dim * 5, self.hidden_dim, dropout=self.cell_dropout_rate)
                    self.edge_rnn_b = Edges_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate)
                else:
                    self.glo_rnn_b = Global_Cell(self.hidden_dim * 2, self.hidden_dim, dropout=self.cell_dropout_rate)
                    self.node_rnn_b = Nodes_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate)

            else:
                # Updating modules
                self.node_rnn_b = Nodes_Cell(self.hidden_dim * 3, self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate)
                if self.use_edge:
                    self.edge_rnn_b = Edges_Cell(self.hidden_dim * 2, self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate)

        if self.bidirectional:
            output_dim = self.hidden_dim * 2
        else:
            output_dim = self.hidden_dim

        self.layer_att_W = nn.Linear(output_dim, 1)

        if self.use_crf:
            self.hidden2tag = nn.Linear(output_dim, self.label_size + 2)
            self.crf = CRF(self.label_size, self.gpu)
        else:
            self.hidden2tag = nn.Linear(output_dim, self.label_size)
            self.criterion = nn.CrossEntropyLoss()

    def construct_graph(self, batch_size, seq_len, word_list):

        # Build the graph tensors on the same device as the model parameters.
        device = next(self.parameters()).device
        if self.use_edge:
            unk_index = torch.tensor(0, device=device)
            unk_emb = self.word_embedding(unk_index)

            bmes_emb_b = self.bmes_embedding(torch.tensor(0, device=device))
            bmes_emb_m = self.bmes_embedding(torch.tensor(1, device=device))
            bmes_emb_e = self.bmes_embedding(torch.tensor(2, device=device))
            bmes_emb_s = self.bmes_embedding(torch.tensor(3, device=device))

        sen_nodes_mask_list = []
        sen_words_length_list = []
        sen_words_mask_f_list = []
        sen_words_mask_b_list = []
        sen_word_embed_list = []
        sen_bmes_embed_list = []
        max_edge_num = -1

        for sen in range(batch_size):
            sen_nodes_mask = torch.zeros([1, seq_len], device=device).byte()
            sen_words_length = torch.zeros([1, self.length_dim], device=device)
            sen_words_mask_f = torch.zeros([1, seq_len], device=device).byte()
            sen_words_mask_b = torch.zeros([1, seq_len], device=device).byte()

            if self.use_edge:
                sen_word_embed = unk_emb[None, :]
                sen_bmes_embed = torch.zeros([1, seq_len, self.bmes_dim], device=device)

            for w in range(seq_len):
                if w < len(word_list[sen]) and word_list[sen][w]:
                    for word, word_len in zip(word_list[sen][w][0], word_list[sen][w][1]):

                        if word_len <= self.max_word_length:
                            word_length_index = torch.tensor(word_len-1, device=device)
                        else:
                            word_length_index = torch.tensor(self.max_word_length - 1, device=device)
                        word_length = self.length_embedding(word_length_index)
                        sen_words_length = torch.cat([sen_words_length, word_length[None, :]], 0)

                        # mask: Masked elements are marked by 1, batch_size * word_num * seq_len
                        nodes_mask = torch.ones([1, seq_len], device=device).byte()
                        words_mask_f = torch.ones([1, seq_len], device=device).byte()
                        words_mask_b = torch.ones([1, seq_len], device=device).byte()

                        words_mask_f[0, w + word_len - 1] = 0
                        sen_words_mask_f = torch.cat([sen_words_mask_f, words_mask_f], 0)

                        words_mask_b[0, w] = 0
                        sen_words_mask_b = torch.cat([sen_words_mask_b, words_mask_b], 0)

                        if self.use_edge:
                            word_index = torch.tensor(word, device=device)
                            word_embedding = self.word_embedding(word_index)
                            sen_word_embed = torch.cat([sen_word_embed, word_embedding[None, :]], 0)

                            bmes_embed = torch.zeros([1, seq_len, self.bmes_dim], device=device)

                            for index in range(word_len):
                                nodes_mask[0, w + index] = 0
                                if word_len == 1:
                                    bmes_embed[0, w + index, :] = bmes_emb_s
                                elif index == 0:
                                    bmes_embed[0, w + index, :] = bmes_emb_b
                                elif index == word_len - 1:
                                    bmes_embed[0, w + index, :] = bmes_emb_e
                                else:
                                    bmes_embed[0, w + index, :] = bmes_emb_m

                            sen_bmes_embed = torch.cat([sen_bmes_embed, bmes_embed], 0)
                            sen_nodes_mask = torch.cat([sen_nodes_mask, nodes_mask], 0)

            if sen_words_mask_f.size(0) > max_edge_num:
                max_edge_num = sen_words_mask_f.size(0)
            sen_words_mask_f_list.append(sen_words_mask_f.unsqueeze_(0))
            sen_words_mask_b_list.append(sen_words_mask_b.unsqueeze_(0))
            sen_words_length_list.append(sen_words_length.unsqueeze_(0))
            if self.use_edge:
                sen_nodes_mask_list.append(sen_nodes_mask.unsqueeze_(0))
                sen_word_embed_list.append(sen_word_embed.unsqueeze_(0))
                sen_bmes_embed_list.append(sen_bmes_embed.unsqueeze_(0))

        edges_mask = torch.zeros([batch_size, max_edge_num], device=device)
        batch_words_mask_f = torch.ones([batch_size, max_edge_num, seq_len], device=device).byte()
        batch_words_mask_b = torch.ones([batch_size, max_edge_num, seq_len], device=device).byte()
        batch_words_length = torch.zeros([batch_size, max_edge_num, self.length_dim], device=device)
        if self.use_edge:
            batch_nodes_mask = torch.zeros([batch_size, max_edge_num, seq_len], device=device).byte()
            batch_word_embed = torch.zeros([batch_size, max_edge_num, self.word_emb_dim], device=device)
            batch_bmes_embed = torch.zeros([batch_size, max_edge_num, seq_len, self.bmes_dim], device=device)
        else:
            batch_word_embed = None
            batch_bmes_embed = None
            batch_nodes_mask = None

        for index in range(batch_size):
            curr_edge_num = sen_words_mask_f_list[index].size(1)
            edges_mask[index, 0:curr_edge_num] = 1.
            batch_words_mask_f[index, 0:curr_edge_num, :] = sen_words_mask_f_list[index]
            batch_words_mask_b[index, 0:curr_edge_num, :] = sen_words_mask_b_list[index]
            batch_words_length[index, 0:curr_edge_num, :] = sen_words_length_list[index]
            if self.use_edge:
                batch_nodes_mask[index, 0:curr_edge_num, :] = sen_nodes_mask_list[index]
                batch_word_embed[index, 0:curr_edge_num, :] = sen_word_embed_list[index]
                batch_bmes_embed[index, 0:curr_edge_num, :, :] = sen_bmes_embed_list[index]

        return batch_word_embed, batch_bmes_embed, batch_nodes_mask, batch_words_mask_f, \
               batch_words_mask_b, batch_words_length, edges_mask

    def update_graph(self, word_list, word_inputs, mask):
        mask = mask.float()
        node_embeds = self.char_embedding(word_inputs)  # batch_size, max_seq_len, embedding
        B, L, _ = node_embeds.size()

        edge_embs, bmes_embs, nodes_mask, words_mask_f, words_mask_b, words_length, edges_mask = \
            self.construct_graph(B, L, word_list)

        node_embeds = self.dropout(node_embeds)

        _, N, _ = words_mask_f.size()

        if self.use_edge:
            edge_embs = self.dropout(edge_embs)

        # forward direction digraph
        nodes_f, _ = self.emb_rnn_f(node_embeds)
        nodes_f = nodes_f * mask.unsqueeze(2)
        nodes_f_cat = nodes_f[:, None, :, :]
        _, _, H = nodes_f.size()

        if self.use_edge:
            edges_f = edge_embs * edges_mask.unsqueeze(2)
            edges_f_cat = edges_f[:, None, :, :]

            if self.use_global:
                glo_f = edges_f.sum(1, keepdim=True) / edges_mask.sum(1, keepdim=True).unsqueeze_(2) + \
                        nodes_f.sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)
                glo_f_cat = glo_f[:, None, :, :]

        else:
            if self.use_global:
                glo_f = (nodes_f * mask.unsqueeze(2)).sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)
                glo_f_cat = glo_f[:, None, :, :]

        for i in range(self.iters):

            # Attention-based aggregation
            if self.use_edge and N > 1:
                bmes_nodes_f = torch.cat([nodes_f.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1)
                edges_att_f = self.node2edge_f[i](edges_f, bmes_nodes_f, nodes_mask.transpose(1, 2))

            nodes_begin_f = torch.sum(nodes_f[:, None, :, :] * (1 - words_mask_b)[:, :, :, None].float(), 2)
            nodes_begin_f = torch.cat([torch.zeros([B, 1, H], device=nodes_f.device), nodes_begin_f[:, 1:N, :]], 1)

            if self.use_edge:
                nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([edges_f, nodes_begin_f, words_length], -1).unsqueeze(2), words_mask_f)
                if self.use_global:
                    glo_att_f = torch.cat([self.glo_att_f_node[i](glo_f, nodes_f, (1 - mask).byte()),
                                           self.glo_att_f_edge[i](glo_f, edges_f, (1 - edges_mask).byte())], -1)
            else:
                nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([nodes_begin_f, words_length], -1).unsqueeze(2), words_mask_f)
                if self.use_global:
                    glo_att_f = self.glo_att_f_node[i](glo_f, nodes_f, (1 - mask).byte())

            # RNN-based update
            if self.use_edge and N > 1:
                if self.use_global:
                    edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :],
                                         edges_att_f[:, 1:N, :], glo_att_f.expand(B, N-1, H*2))], 1)
                else:
                    edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :], edges_att_f[:, 1:N, :])], 1)

                edges_f_cat = torch.cat([edges_f_cat, edges_f[:, None, :, :]], 1)
                edges_f = torch.cat([edges_f[:, 0:1, :], self.norm(torch.sum(edges_f_cat[:, :, 1:N, :], 1))], 1)

            nodes_f_r = torch.cat([torch.zeros([B, 1, self.hidden_dim], device=nodes_f.device), nodes_f[:, 0:(L-1), :]], 1)

            if self.use_global:
                nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, nodes_att_f, glo_att_f.expand(B, L, -1))
            else:
                nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, nodes_att_f)

            nodes_f_cat = torch.cat([nodes_f_cat, nodes_f[:, None, :, :]], 1)
            nodes_f = self.norm(torch.sum(nodes_f_cat, 1))

            if self.use_global:
                glo_f = self.glo_rnn_f(glo_f, glo_att_f)
                glo_f_cat = torch.cat([glo_f_cat, glo_f[:, None, :, :]], 1)
                glo_f = self.norm(torch.sum(glo_f_cat, 1))

        nodes_cat = nodes_f_cat

        # backward direction digraph
        if self.bidirectional:
            nodes_b, _ = self.emb_rnn_b(torch.flip(node_embeds, [1]))
            nodes_b = torch.flip(nodes_b, [1])
            nodes_b = nodes_b * mask.unsqueeze(2)
            nodes_b_cat = nodes_b[:, None, :, :]

            if self.use_edge:
                edges_b = edge_embs * edges_mask.unsqueeze(2)
                edges_b_cat = edges_b[:, None, :, :]
                if self.use_global:
                    glo_b = edges_b.sum(1, keepdim=True) / edges_mask.sum(1, keepdim=True).unsqueeze_(2) + \
                            nodes_b.sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)
                    glo_b_cat = glo_b[:, None, :, :]

            else:
                if self.use_global:
                    glo_b = nodes_b.sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)
                    glo_b_cat = glo_b[:, None, :, :]

            for i in range(self.iters):

                # Attention-based aggregation
                if self.use_edge and N > 1:
                    bmes_nodes_b = torch.cat([nodes_b.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1)
                    edges_att_b = self.node2edge_b[i](edges_b, bmes_nodes_b, nodes_mask.transpose(1, 2))

                nodes_begin_b = torch.sum(nodes_b[:, None, :, :] * (1 - words_mask_f)[:, :, :, None].float(), 2)
                nodes_begin_b = torch.cat([torch.zeros([B, 1, H], device=nodes_b.device), nodes_begin_b[:, 1:N, :]], 1)

                if self.use_edge:
                    nodes_att_b = self.edge2node_b[i](nodes_b, torch.cat([edges_b, nodes_begin_b, words_length], -1).unsqueeze(2), words_mask_b)
                    if self.use_global:
                        glo_att_b = torch.cat([self.glo_att_b_node[i](glo_b, nodes_b, (1-mask).byte()),
                                               self.glo_att_b_edge[i](glo_b, edges_b, (1-edges_mask).byte())], -1)
                else:
                    nodes_att_b = self.edge2node_b[i](nodes_b, torch.cat([nodes_begin_b, words_length], -1).unsqueeze(2), words_mask_b)
                    if self.use_global:
                        glo_att_b = self.glo_att_b_node[i](glo_b, nodes_b, (1-mask).byte())

                # RNN-based update
                if self.use_edge and N > 1:
                    if self.use_global:
                        edges_b = torch.cat([edges_b[:, 0:1, :], self.edge_rnn_b(edges_b[:, 1:N, :],
                                             edges_att_b[:, 1:N, :], glo_att_b.expand(B, N-1, H*2))], 1)
                    else:
                        edges_b = torch.cat([edges_b[:, 0:1, :], self.edge_rnn_b(edges_b[:, 1:N, :], edges_att_b[:, 1:N, :])], 1)

                    edges_b_cat = torch.cat([edges_b_cat, edges_b[:, None, :, :]], 1)
                    edges_b = torch.cat([edges_b[:, 0:1, :], self.norm(torch.sum(edges_b_cat[:, :, 1:N, :], 1))], 1)

                nodes_b_r = torch.cat([nodes_b[:, 1:L, :], torch.zeros([B, 1, self.hidden_dim], device=nodes_b.device)], 1)

                if self.use_global:
                    nodes_b = self.node_rnn_b(nodes_b, nodes_b_r, nodes_att_b, glo_att_b.expand(B, L, -1))
                else:
                    nodes_b = self.node_rnn_b(nodes_b, nodes_b_r, nodes_att_b)

                nodes_b_cat = torch.cat([nodes_b_cat, nodes_b[:, None, :, :]], 1)
                nodes_b = self.norm(torch.sum(nodes_b_cat, 1))

                if self.use_global:
                    glo_b = self.glo_rnn_b(glo_b, glo_att_b)
                    glo_b_cat = torch.cat([glo_b_cat, glo_b[:, None, :, :]], 1)
                    glo_b = self.norm(torch.sum(glo_b_cat, 1))

            nodes_cat = torch.cat([nodes_f_cat, nodes_b_cat], -1)

        layer_att = torch.sigmoid(self.layer_att_W(nodes_cat))
        layer_alpha = F.softmax(layer_att, 1)
        nodes = torch.sum(layer_alpha * nodes_cat, 1)

        tags = self.hidden2tag(nodes)

        return tags

    def forward(self, word_list, batch_inputs, mask, batch_label=None):

        tags = self.update_graph(word_list, batch_inputs, mask)

        if batch_label is not None:
            if self.use_crf:
                total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label)
            else:
                total_loss = self.criterion(tags.view(-1, self.label_size), batch_label.view(-1))
        else:
            total_loss = None

        if self.use_crf:
            _, tag_seq = self.crf._viterbi_decode(tags, mask)
        else:
            tag_seq = tags.argmax(-1)

        return total_loss, tag_seq
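

# --- Illustrative sketch (not part of the original LGN implementation) ---
# This is a simplified, pure-Python view of what construct_graph() records for one
# lexicon word that starts at character `start`:
#   * the forward mask row leaves only the word's last character unmasked;
#   * the backward mask row leaves only its first character unmasked;
#   * every covered character gets a B/M/E/S label (indices 0/1/2/3, matching
#     the bmes_embedding lookups above).
# As in the tensors above, 0 means "attend" and 1 means "masked". Words longer
# than max_word_length share the last length-embedding row. The helper name
# `_sketch_word_edge` is hypothetical, and the model never calls it.
def _sketch_word_edge(start, word_len, seq_len, max_word_length):
    mask_f = [1] * seq_len
    mask_b = [1] * seq_len
    mask_f[start + word_len - 1] = 0  # forward digraph: the word feeds its last character
    mask_b[start] = 0                 # backward digraph: the word feeds its first character
    bmes = [None] * seq_len           # None = character not covered by this word
    for index in range(word_len):
        if word_len == 1:
            bmes[start + index] = 3   # S: single-character word
        elif index == 0:
            bmes[start + index] = 0   # B: beginning
        elif index == word_len - 1:
            bmes[start + index] = 2   # E: end
        else:
            bmes[start + index] = 1   # M: middle
    length_index = min(word_len, max_word_length) - 1  # row of length_embedding
    return mask_f, mask_b, bmes, length_index

# Usage: _sketch_word_edge(1, 3, 5, 4) returns
# ([1, 1, 1, 0, 1], [1, 0, 1, 1, 1], [None, 0, 1, 2, None], 2).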


================================================
FILE: model/crf.py
================================================
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn

import torch
import torch.autograd as autograd
import torch.nn as nn
START_TAG = -2
STOP_TAG = -1


# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec, m_size):
    """
    calculate log of exp sum
    args:
        vec (batch_size, vanishing_dim, hidden_dim) : input tensor
        m_size : hidden_dim
    return:
        batch_size, hidden_dim
    """
    _, idx = torch.max(vec, 1)  # B * 1 * M
    max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size)  # B * M
    return max_score.view(-1, m_size) + torch.log(torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size)  # B * M
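

# --- Illustrative sketch (not part of the original code) ---
# log_sum_exp() above relies on the max-shift identity
#     log sum_j exp(v_j) = m + log sum_j exp(v_j - m),   where m = max_j v_j,
# so the largest exponent is exp(0) = 1 and nothing overflows. The hypothetical
# helper below applies the same identity along dim 1. You can compare it with
# PyTorch's built-in torch.logsumexp.
import torch


def _sketch_stable_lse(vec):
    # vec: (batch, from_tag, to_tag); reduce over the "from" dimension (dim 1)
    max_score, _ = torch.max(vec, dim=1, keepdim=True)  # (batch, 1, to_tag)
    return max_score.squeeze(1) + torch.log(torch.sum(torch.exp(vec - max_score), dim=1))

# Usage: scores that would overflow a naive exp() stay finite. For example,
# _sketch_stable_lse(torch.full((1, 2, 3), 1000.0)) is 1000 + log(2) in every entry.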


class CRF(nn.Module):

    def __init__(self, tagset_size, gpu):
        super(CRF, self).__init__()
        print ("build batched crf...")
        self.gpu = gpu
        # Matrix of transition parameters.  Entry i,j is the score of transitioning *to* i *from* j.
        self.average_batch = False
        self.tagset_size = tagset_size
        # # We add 2 here, because of START_TAG and STOP_TAG
        # # transitions (f_tag_size, t_tag_size), transition value from f_tag to t_tag
        init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2)
        # init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2)
        # init_transitions[:,START_TAG] = -1000.0
        # init_transitions[STOP_TAG,:] = -1000.0
        # init_transitions[:,0] = -1000.0
        # init_transitions[0,:] = -1000.0
        if self.gpu:
            init_transitions = init_transitions.cuda()
        self.transitions = nn.Parameter(init_transitions)  #(t+2,t+2)

        # self.transitions = nn.Parameter(torch.Tensor(self.tagset_size+2, self.tagset_size+2))
        # self.transitions.data.zero_()

    def _calculate_PZ(self, feats, mask):
        """
            input:
                feats: (batch, seq_len, self.tag_size+2)  (b,m,t+2)
                mask: (batch, seq_len)   (b,m)
        """
        batch_size = feats.size(0)
        seq_len = feats.size(1)
        tag_size = feats.size(2)
        assert(tag_size == self.tagset_size+2)
        mask = mask.transpose(1,0).contiguous()  #(m,b)
        ins_num = seq_len * batch_size
        ## be careful with the view shape: it is .view(ins_num, 1, tag_size), not .view(ins_num, tag_size, 1)
        feats = feats.transpose(1,0).contiguous().view(ins_num,1, tag_size).expand(ins_num, tag_size, tag_size)  #(i,t+2,t+2): every slice along dim 1 is identical
        ## need to consider start
        scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)
        scores = scores.view(seq_len, batch_size, tag_size, tag_size)
        # build iter
        seq_iter = enumerate(scores)
        _, inivalues = next(seq_iter)  # bat_size * from_target_size * to_target_size  (b,t,t); inivalues holds the scores at the first character of each sentence
        # only need start from start_tag
        partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1)  # bat_size * to_target_size (b,t,1)

        ## add start score (from start to all tag, duplicate to batch_size)
        # partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1)
        # iter over last scores
        for idx, cur_values in seq_iter:
            # previous to_target is current from_target
            # partition: previous results log(exp(from_target)), #(batch_size * from_target)
            # cur_values: bat_size * from_target * to_target
            
            cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
            # (bat_size * from_target * to_target) -> (bat_size * to_target)
            cur_partition = log_sum_exp(cur_values, tag_size)  #(b,t)

            mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)

            ## effective updated partition part: only keep the partition values where mask == 1
            masked_cur_partition = cur_partition.masked_select(mask_idx)
            ## make mask_idx broadcastable, to disable a warning
            mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)

            ## replace the partition where mask == 1; other partition values stay the same
            partition.masked_scatter_(mask_idx, masked_cur_partition)
        # until the last state, add transition score for all partition (and do log_sum_exp) then select the value in STOP_TAG
        cur_values = self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
        cur_partition = log_sum_exp(cur_values, tag_size)  #(batch_size,hidden_dim)
        final_partition = cur_partition[:, STOP_TAG]  #(batch_size)
        return final_partition.sum(), scores #scores: (seq_len, batch, tag_size, tag_size)


    def _viterbi_decode(self, feats, mask):
        """
            input:
                feats: (batch, seq_len, self.tag_size+2)
                mask: (batch, seq_len)
            output:
                decode_idx: (batch, seq_len) decoded sequence
                path_score: (batch, 1) corresponding score for each sequence (not implemented yet; returned as None)
        """
        batch_size = feats.size(0)
        seq_len = feats.size(1)
        tag_size = feats.size(2)
        assert(tag_size == self.tagset_size+2)
        ## calculate sentence length for each sentence
        length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long()  #(batch_size,1) mask length of each sentence
        ## mask to (seq_len, batch_size)
        mask = mask.transpose(1,0).contiguous()  #(seq_len,b)
        ins_num = seq_len * batch_size
        ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1)
        feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)  #(ins_num, tag_size, tag_size)
        ## need to consider start
        scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)
        scores = scores.view(seq_len, batch_size, tag_size, tag_size)

        # build iter
        seq_iter = enumerate(scores)
        ## record the position of best score
        back_points = list()
        partition_history = list()
        
        ## reverse the mask so that padded positions are 1 (computed via long(), since `mask = 1 - mask` was buggy here)
        mask = (1 - mask.long()).byte()
        _, inivalues = next(seq_iter)  # bat_size * from_target_size * to_target_size
        # only need start from start_tag
        partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1)  # bat_size * to_target_size
        partition_history.append(partition) #(seqlen,batch_size,tag_size,1)
        # iter over last scores
        for idx, cur_values in seq_iter:
            # previous to_target is current from_target
            # partition: previous results log(exp(from_target)), #(batch_size * from_target)
            # cur_values: batch_size * from_target * to_target
            cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
            ## forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1) # do not consider START_TAG/STOP_TAG
            partition, cur_bp = torch.max(cur_values,dim=1)
            partition_history.append(partition.unsqueeze(2))
            ## cur_bp: (batch_size, tag_size) max source score position in current tag
            ## set padded label as 0, which will be filtered in post processing
            cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
            back_points.append(cur_bp)
        ### add score to final STOP_TAG
        partition_history = torch.cat(partition_history,dim=0).view(seq_len, batch_size,-1).transpose(1,0).contiguous() ## (batch_size, seq_len, tag_size)
        ### get the last position for each sentence, and select the last partitions using gather()
        last_position = length_mask.view(batch_size,1,1).expand(batch_size, 1, tag_size) -1
        last_partition = torch.gather(partition_history, 1, last_position).view(batch_size,tag_size,1)
        ### calculate the score from last partition to end state (and then select the STOP_TAG from it)
        last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size)
        _, last_bp = torch.max(last_values, 1)  #(batch_size,tag_size)
        pad_zero = torch.zeros(batch_size, tag_size, dtype=torch.long, device=feats.device)
        back_points.append(pad_zero)
        back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)
        
        ## select end ids in STOP_TAG
        pointer = last_bp[:, STOP_TAG] #(batch_size)
        insert_last = pointer.contiguous().view(batch_size,1,1).expand(batch_size,1, tag_size)
        back_points = back_points.transpose(1,0).contiguous()   #(batch_size,seq_len,tag_size)
        ## move the end ids (expanded to tag_size) to the corresponding position of back_points to replace the 0 values
        back_points.scatter_(1, last_position, insert_last)  ##(batch_size,seq_len,tag_size)
        back_points = back_points.transpose(1,0).contiguous()  #(seq_len, batch_size, tag_size)
        ## decode from the end; padded positions get id 0 and are filtered out during evaluation
        decode_idx = torch.empty(seq_len, batch_size, dtype=torch.long, device=feats.device)
        decode_idx[-1] = pointer.data
        for idx in range(len(back_points)-2, -1, -1):
            pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) #pointer's size:(batch_size,1)
            decode_idx[idx] = pointer.squeeze(1).data
        path_score = None
        decode_idx = decode_idx.transpose(1,0) #(batch_size, sent_len)
        return path_score, decode_idx


    def forward(self, feats, mask):
        path_score, best_path = self._viterbi_decode(feats, mask)
        return path_score, best_path
        

    def _score_sentence(self, scores, mask, tags):
        """
            input:
                scores: variable (seq_len, batch, tag_size, tag_size)
                mask: (batch, seq_len)
                tags: tensor  (batch, seq_len)
            output:
                score: sum of score for gold sequences within whole batch
        """
        # Gives the score of a provided tag sequence
        batch_size = scores.size(1)
        seq_len = scores.size(0)
        tag_size = scores.size(2)
        ## convert tag values into a new format: each label bigram (previous, current) is encoded as a single index
        new_tags = torch.empty(batch_size, seq_len, dtype=torch.long, device=scores.device)
        for idx in range(seq_len):
            if idx == 0:
                ## start -> first score
                new_tags[:,0] = (tag_size - 2)*tag_size + tags[:,0]

            else:
                new_tags[:,idx] = tags[:,idx-1]*tag_size + tags[:,idx]

        ## transition for label to STOP_TAG
        end_transition = self.transitions[:,STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size)
        ## length for batch,  last word position = length - 1
        length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long()
        ## index the label id of last word
        end_ids = torch.gather(tags, 1, length_mask - 1)

        ## index the transition score for end_id to STOP_TAG
        end_energy = torch.gather(end_transition, 1, end_ids)

        ## convert tag as (seq_len, batch_size, 1)
        new_tags = new_tags.transpose(1,0).contiguous().view(seq_len, batch_size, 1)
        ### the bigram ids index into the flattened (tag_size * tag_size) scores
        tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size)  # seq_len * bat_size
        ## mask transpose to (seq_len, batch_size)
        tg_energy = tg_energy.masked_select(mask.transpose(1,0))
        
        # ## calculate the score from START_TAG to first label
        # start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size)
        # start_energy = torch.gather(start_transition, 1, tags[0,:])

        ## add all score together
        # gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum()
        gold_score = tg_energy.sum() + end_energy.sum()
        return gold_score

    def neg_log_likelihood_loss(self, feats, mask, tags):
        # negative log-likelihood
        batch_size = feats.size(0)
        forward_score, scores = self._calculate_PZ(feats, mask)  # forward_score: scalar tensor, scores: (seq_len, batch, tag_size, tag_size)
        gold_score = self._score_sentence(scores, mask, tags)
        if self.average_batch:
            return (forward_score - gold_score)/batch_size
        else:
            return forward_score - gold_score
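

# --- Illustrative sketch (not part of the original code) ---
# This is a brute-force reference for what _calculate_PZ() and _score_sentence()
# compute on one unpadded sentence. It uses the same conventions as the CRF class:
#   * transitions[i, j] scores a move from tag i to tag j;
#   * row START_TAG (-2) scores the first tag;
#   * column STOP_TAG (-1) scores the transition out of the last tag.
# Like the class, the partition sums over all tag_size + 2 labels, START/STOP
# included. The helper names are hypothetical. Enumerating every path is only
# feasible for tiny inputs, but it makes the forward-algorithm result easy to check.
import itertools
import torch


def _sketch_path_score(emissions, transitions, path):
    # emissions: (seq_len, tag_size + 2); path: list of tag ids, one per position
    score = transitions[-2, path[0]] + emissions[0, path[0]]
    for t in range(1, len(path)):
        score = score + transitions[path[t - 1], path[t]] + emissions[t, path[t]]
    return score + transitions[path[-1], -1]


def _sketch_log_partition_brute_force(emissions, transitions):
    seq_len, num_labels = emissions.shape
    scores = [_sketch_path_score(emissions, transitions, list(p))
              for p in itertools.product(range(num_labels), repeat=seq_len)]
    return torch.logsumexp(torch.stack(scores), dim=0)


def _sketch_log_partition_forward(emissions, transitions):
    # alpha[j]: log-sum of the scores of all prefixes that end in tag j
    alpha = transitions[-2, :] + emissions[0]
    for t in range(1, emissions.size(0)):
        alpha = torch.logsumexp(alpha[:, None] + transitions, dim=0) + emissions[t]
    return torch.logsumexp(alpha + transitions[:, -1], dim=0)

# Usage: for random float64 emissions of shape (3, 4) and transitions of shape (4, 4),
# the two partition functions agree. The NLL returned by neg_log_likelihood_loss
# corresponds to log Z minus _sketch_path_score(..., gold_path).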


================================================
FILE: model/module.py
================================================
# -*- coding: utf-8 -*-
# @Author: Yicheng Zou
# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn


import torch
import math
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


class MultiHeadAtt(nn.Module):
    def __init__(self, nhid, keyhid, nhead=10, head_dim=10, dropout=0.1, if_g=False):
        super(MultiHeadAtt, self).__init__()

        if if_g:
            self.WQ = nn.Conv2d(nhid * 3, nhead * head_dim, 1)
        else:
            self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WK = nn.Conv2d(keyhid, nhead * head_dim, 1)
        self.WV = nn.Conv2d(keyhid, nhead * head_dim, 1)
        self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)

        self.drop = nn.Dropout(dropout)

        self.norm = nn.LayerNorm(nhid)

        self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim

    def forward(self, query_h, value, mask, query_g=None):

        if not (query_g is None):
            query = torch.cat([query_h, query_g], -1)
        else:
            query = query_h
        query = query.permute(0, 2, 1)[:, :, :, None]
        value = value.permute(0, 3, 1, 2)

        residual = query_h
        nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim

        B, QL, H = query_h.shape

        _, _, VL, VD = value.shape  # VD = 1 or VD = QL

        assert VD == 1 or VD == QL
        # q: (B, H, QL, 1)
        # v: (B, H, VL, VD)
        q, k, v = self.WQ(query), self.WK(value), self.WV(value)

        q = q.view(B, nhead, head_dim, 1, QL)
        k = k.view(B, nhead, head_dim, VL, VD)
        v = v.view(B, nhead, head_dim, VL, VD)

        alpha = (q * k).sum(2, keepdim=True) / np.sqrt(head_dim)
        alpha = alpha.masked_fill(mask[:, None, None, :, :], -np.inf)
        alpha = self.drop(F.softmax(alpha, 3))
        att = (alpha * v).sum(3).view(B, nhead * head_dim, QL, 1)

        output = F.leaky_relu(self.WO(att)).permute(0, 2, 3, 1).view(B, QL, H)
        output = self.norm(output + residual)

        return output
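

# --- Illustrative sketch (not part of the original code) ---
# This is a single-head version of the masking in MultiHeadAtt.forward(). Scores
# are scaled by sqrt(d), and positions whose mask is True/1 are filled with -inf
# before the softmax, so they receive exactly zero weight. If every key of a query
# were masked, the softmax would return NaN. construct_graph() appears to avoid
# this by giving each sentence an unmasked first edge row. Note that recent PyTorch
# versions expect a bool mask for masked_fill, which this sketch uses.
# `_sketch_masked_attention` is a hypothetical helper, not an API of this repository.
import math
import torch
import torch.nn.functional as F


def _sketch_masked_attention(q, k, v, mask):
    # q: (B, QL, d); k, v: (B, KL, d); mask: (B, QL, KL) bool, True = masked
    scores = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(q.size(-1))
    scores = scores.masked_fill(mask, float('-inf'))
    weights = F.softmax(scores, dim=-1)  # (B, QL, KL); masked entries are 0
    return torch.matmul(weights, v), weights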


class GloAtt(nn.Module):
    def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1):
        # Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value
        super(GloAtt, self).__init__()
        self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)

        self.drop = nn.Dropout(dropout)

        self.norm = nn.LayerNorm(nhid)

        # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim)
        self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim

    def forward(self, x, y, mask=None):
        # x: (B, 1, H) -> (B, H, 1, 1); y: (B, L, H) -> (B, H, L, 1) after the permutes below
        nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim
        B, L, H = y.shape

        x = x.permute(0, 2, 1)[:, :, :, None]
        y = y.permute(0, 2, 1)[:, :, :, None]

        residual = x
        q, k, v = self.WQ(x), self.WK(y), self.WV(y)

        q = q.view(B, nhead, 1, head_dim)  # B, H, 1, 1 -> B, N, 1, h
        k = k.view(B, nhead, head_dim, L)  # B, H, L, 1 -> B, N, h, L
        v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2)  # B, H, L, 1 -> B, N, L, h

        pre_a = torch.matmul(q, k) / np.sqrt(head_dim)
        if mask is not None:
            pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf'))
        alphas = self.drop(F.softmax(pre_a, 3))  # B, N, 1, L
        att = torch.matmul(alphas, v).view(B, -1, 1, 1)  # B, N, 1, h -> B, N*h, 1, 1
        output = F.leaky_relu(self.WO(att)) + residual
        output = self.norm(output.permute(0, 2, 3, 1)).view(B, 1, H)

        return output


class Nodes_Cell(nn.Module):
    def __init__(self, input_h, hid_h, use_global=True, dropout=0.2):
        super(Nodes_Cell, self).__init__()

        self.use_global = use_global
        self.hidden_size = hid_h
        self.Wix = nn.Linear(input_h, hid_h)
        self.Wi2 = nn.Linear(input_h, hid_h)
        self.Wf = nn.Linear(input_h, hid_h)
        self.Wcx = nn.Linear(input_h, hid_h)

        self.drop = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, h, h2, x, glo=None):

        x = self.drop(x)

        if self.use_global:
            glo = self.drop(glo)
            cat_all = torch.cat([h, h2, x, glo], -1)
        else:
            cat_all = torch.cat([h, h2, x], -1)

        ix = torch.sigmoid(self.Wix(cat_all))
        i2 = torch.sigmoid(self.Wi2(cat_all))
        f = torch.sigmoid(self.Wf(cat_all))
        cx = torch.tanh(self.Wcx(cat_all))

        alpha = F.softmax(torch.cat([ix.unsqueeze(1), i2.unsqueeze(1), f.unsqueeze(1)], 1), 1)
        output = (alpha[:, 0] * cx) + (alpha[:, 1] * h2) + (alpha[:, 2] * h)

        return output


class Edges_Cell(nn.Module):
    def __init__(self, input_h, hid_h, use_global=True, dropout=0.2):
        super(Edges_Cell, self).__init__()

        self.use_global = use_global
        self.hidden_size = hid_h
        self.Wi = nn.Linear(input_h, hid_h)
        self.Wf = nn.Linear(input_h, hid_h)
        self.Wc = nn.Linear(input_h, hid_h)

        self.drop = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, h, x, glo=None):

        x = self.drop(x)

        if self.use_global:
            glo = self.drop(glo)
            cat_all = torch.cat([h, x, glo], -1)
        else:
            cat_all = torch.cat([h, x], -1)

        i = torch.sigmoid(self.Wi(cat_all))
        f = torch.sigmoid(self.Wf(cat_all))
        c = torch.tanh(self.Wc(cat_all))

        alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1)
        output = (alpha[:, 0] * c) + (alpha[:, 1] * h)

        return output


class Global_Cell(nn.Module):
    def __init__(self, input_h, hid_h, dropout=0.2):
        super(Global_Cell, self).__init__()

        self.hidden_size = hid_h
        self.Wi = nn.Linear(input_h, hid_h)
        self.Wf = nn.Linear(input_h, hid_h)
        self.Wc = nn.Linear(input_h, hid_h)

        self.drop = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, h, x):

        x = self.drop(x)

        cat_all = torch.cat([h, x], -1)
        i = torch.sigmoid(self.Wi(cat_all))
        f = torch.sigmoid(self.Wf(cat_all))
        c = torch.tanh(self.Wc(cat_all))

        alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1)
        output = (alpha[:, 0] * c) + (alpha[:, 1] * h)

        return output
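Illustrative example (not part of the repository). Edges_Cell and Global_Cell turn two sigmoid gates into mixing weights with a softmax across the gate axis, then return a per-dimension convex combination of the candidate state c and the previous state h. Nodes_Cell does the same with three gates. The sketch below reproduces that arithmetic for a single unbatched vector in plain Python. The function names, and the choice to pass gate pre-activations directly instead of applying the Linear layers Wi, Wf and Wc to the concatenated inputs, are assumptions made for illustration. Because the softmax is applied to sigmoid outputs in (0, 1), each of the two weights stays roughly within (0.27, 0.73), so neither input is ever fully switched off.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def softmax_over_gates(gates):
    """Softmax across the gate axis, independently for every hidden dimension.

    gates: list of K gate vectors of equal length. This mirrors the
    unsqueeze(1) / cat / softmax(x, 1) pattern in the cells, minus the batch axis.
    """
    dim = len(gates[0])
    weights = [[0.0] * dim for _ in gates]
    for d in range(dim):
        exps = [math.exp(g[d]) for g in gates]
        total = sum(exps)
        for k, e in enumerate(exps):
            weights[k][d] = e / total
    return weights


def edge_cell_update(h, i_pre, f_pre, c_pre):
    """Hypothetical unbatched Edges_Cell step computed from gate pre-activations."""
    i = [sigmoid(v) for v in i_pre]      # input gate
    f = [sigmoid(v) for v in f_pre]      # forget gate
    c = [math.tanh(v) for v in c_pre]    # candidate state
    w_c, w_h = softmax_over_gates([i, f])
    # Per-dimension convex combination of the candidate and the previous state.
    return [w_c[d] * c[d] + w_h[d] * h[d] for d in range(len(h))]
```

For example, with all pre-activations at zero both gates equal 0.5, the weights are 0.5 each, and `edge_cell_update([1.0], [0.0], [0.0], [0.0])` averages tanh(0) = 0 with h = 1 to give 0.5.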


================================================
FILE: utils/alphabet.py
================================================
# -*- coding: utf-8 -*-
# @Author: Max
# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn

"""
Alphabet maps objects to integer ids. It provides a two-way mapping between indices and objects.
"""
import json
import os


class Alphabet:
    def __init__(self, name, label=False, keep_growing=True):
        self.__name = name
        self.UNKNOWN = "</unk>"
        self.label = label
        self.instance2index = {}
        self.instances = []
        self.keep_growing = keep_growing

        # Index 0 is reserved; real entries start at 1.
        self.default_index = 0
        self.next_index = 1
        if not self.label:
            self.add(self.UNKNOWN)

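    def clear(self, keep_growing=True):
        self.instance2index = {}
        self.instances = []
        self.keep_growing = keep_growing

        # Index 0 is reserved; real entries start at 1.
        self.default_index = 0
        self.next_index = 1
        # Re-add the unknown token, as __init__ does, so a closed non-label alphabet can still map unseen items.
        if not self.label:
            self.add(self.UNKNOWN)
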
    def add(self, instance):
        if instance not in self.instance2index:
            self.instances.append(instance)
            self.instance2index[instance] = self.next_index
            self.next_index += 1

    def get_index(self, instance):
        try:
            return self.instance2index[instance]
        except KeyError:
            if self.keep_growing:
                index = self.next_index
                self.add(instance)
                return index
            else:
                return self.instance2index[self.UNKNOWN]

    def get_instance(self, index):
        if index == 0:
            # First index is occupied by the wildcard element.
            return None
        try:
            return self.instances[index - 1]
        except IndexError:
            print('WARNING: Alphabet.get_instance: unknown instance index {}, returning the first instance.'.format(index))
            return self.instances[0]

    def size(self):
        return len(self.instances) + 1

    def iteritems(self):
        return self.instance2index.items()

    def enumerate_items(self, start=1):
        if start < 1 or start >= self.size():
            raise IndexError("Enumerate is allowed between [1 : size of the alphabet)")
        return zip(range(start, len(self.instances) + 1), self.instances[start - 1:])

    def close(self):
        self.keep_growing = False

    def open(self):
        self.keep_growing = True

    def get_content(self):
        return {'instance2index': self.instance2index, 'instances': self.instances}

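    def from_json(self, data):
        self.instances = data["instances"]
        self.instance2index = data["instance2index"]
        # Keep the id counter in sync so that later add() calls do not reuse existing ids.
        self.next_index = len(self.instances) + 1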

    def save(self, output_directory, name=None):
        """
        Save the alphabet records to the given directory as a JSON file.
        :param output_directory: Directory in which to save the alphabet.
        :param name: The alphabet saving name, optional.
        :return:
        """
        saving_name = name if name else self.__name
        try:
            with open(os.path.join(output_directory, saving_name + ".json"), 'w') as f:
                json.dump(self.get_content(), f)
        except Exception as e:
            print("Exception: Alphabet is not saved: " + repr(e))

    def load(self, input_directory, name=None):
        """
        Load the alphabet records from the given directory.
        :param input_directory: Directory containing the saved alphabet JSON file.
        :param name: The alphabet loading name, optional.
        :return:
        """
        loading_name = name if name else self.__name
        with open(os.path.join(input_directory, loading_name + ".json")) as f:
            self.from_json(json.load(f))
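Illustrative example (not part of the repository). Alphabet reserves id 0 (get_instance(0) returns None and size() counts it), stores instances[k] under id k + 1, registers </unk> first for non-label alphabets, and, once closed, maps unseen items to the </unk> id. The hypothetical MiniAlphabet below is a stripped-down sketch of that convention in plain Python; it is not the repository class. One design difference is deliberate: ids are derived from the length of the instance list rather than from a separate next_index counter, so an alphabet rebuilt from its saved instance list cannot drift out of sync with a counter.

```python
import json


class MiniAlphabet:
    """Hypothetical sketch of the Alphabet indexing convention (not the repo class)."""

    UNKNOWN = "</unk>"

    def __init__(self, label=False):
        self.instance2index = {}
        self.instances = []
        self.keep_growing = True
        if not label:
            self.add(self.UNKNOWN)  # non-label alphabets reserve id 1 for </unk>

    def add(self, item):
        if item not in self.instance2index:
            self.instances.append(item)
            # instances[k] gets id k + 1, so id 0 stays free for padding.
            self.instance2index[item] = len(self.instances)

    def get_index(self, item):
        if item in self.instance2index:
            return self.instance2index[item]
        if self.keep_growing:
            self.add(item)
            return self.instance2index[item]
        return self.instance2index[self.UNKNOWN]  # closed: unseen -> </unk>

    def get_instance(self, index):
        return None if index == 0 else self.instances[index - 1]

    def size(self):
        return len(self.instances) + 1  # +1 for the reserved id 0

    def to_json(self):
        return json.dumps({"instances": self.instances})

    @classmethod
    def from_json(cls, text):
        # label=True starts empty; replaying the saved list in order
        # restores every item (including </unk>) under its original id.
        alphabet = cls(label=True)
        for item in json.loads(text)["instances"]:
            alphabet.add(item)
        alphabet.keep_growing = False
        return alphabet
```

Typical use: grow the alphabet over the training data, close it by setting `keep_growing = False`, and from then on unseen characters map to id 1 (</unk>).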


================================================
FILE: utils/data.py
================================================
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn

import sys
from utils.alphabet import Alphabet
from utils.functions import *
from utils.word_trie import Word_Trie


class Data:
    def __init__(self): 
        self.MAX_SENTENCE_LENGTH = 250
        self.MAX_WORD_LENGTH = -1
        self.number_normalized = True
        self.norm_char_emb = True
        self.norm_word_emb = False
        self.char_alphabet = Alphabet('character')
        self.label_alphabet = Alphabet('label', True)
        self.word_dict = Word_Trie()
        self.word_alphabet = Alphabet('word')

        self.train_texts = []
        self.dev_texts = []
        self.test_texts = []
        self.raw_texts = []

        self.train_Ids = []
        self.dev_Ids = []
        self.test_Ids = []
        self.raw_Ids = []
        self.char_emb_dim = 50
        self.word_emb_dim = 50
        self.pretrain_char_embedding = None
        self.pretrain_word_embedding = None
        self.label_size = 0
        
    def show_data_summary(self):
        print("DATA SUMMARY:")
        print("     MAX SENTENCE LENGTH: %s"%(self.MAX_SENTENCE_LENGTH))
        print("     MAX   WORD   LENGTH: %s"%(self.MAX_WORD_LENGTH))
        print("     Number   normalized: %s"%(self.number_normalized))
        print("     Word  alphabet size: %s"%(self.word_alphabet.size()))
        print("     Char  alphabet size: %s"%(self.char_alphabet.size()))
        print("     Label alphabet size: %s"%(self.label_alphabet.size()))
        print("     Word embedding size: %s"%(self.word_emb_dim))
        print("     Char embedding size: %s"%(self.char_emb_dim))
        print("     Norm     char   emb: %s"%(self.norm_char_emb))
        print("     Norm     word   emb: %s"%(self.norm_word_emb))
        print("     Train instance number: %s"%(len(self.train_texts)))
        print("     Dev   instance number: %s"%(len(self.dev_texts)))
        print("     Test  instance number: %s"%(len(self.test_texts)))
        print("     Raw   instance number: %s"%(len(self.raw_texts)))
        print("DATA SUMMARY END.")
        sys.stdout.flush()

    def build_alphabet(self, input_file):
        self.char_alphabet.open()
        self.label_alphabet.open()

        with open(input_file, 'r', encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if len(line) == 0:
                    continue
                pair = line.split()
                char = pair[0]
                if self.number_normalized:
                    # Map every digit to '0'
                    char = normalize_word(char)
                label = pair[-1]
                self.label_alphabet.add(label)
                self.char_alphabet.add(char)

        self.label_alphabet.close()
        self.char_alphabet.close()

    def build_word_file(self, word_file):
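        # Build the word trie from a lexicon file (e.g. a word embedding file):
        # the first whitespace-separated token on each non-empty line is inserted as a word.
        print("Building the word dict...")
        with open(word_file, 'r', encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split()
                if parts:
                    self.word_dict.insert(parts[0])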

    def build_word_alphabet(self, input_file):
        print("Loading file: " + input_file)
        self.word_alphabet.open()
        word_list = []
        with open(input_file, 'r', encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if len(line) > 0:
                    word = line.split()[0]
                    if self.number_normalized:
                        word = normalize_word(word)
                    word_list.append(word)
                else:
                    for idx in range(len(word_list)):
                        matched_words = self.word_dict.recursive_search(word_list[idx:])
                        for matched_word in matched_words:
                            self.word_alphabet.add(matched_word)
                    word_list = []
        self.word_alphabet.close()
        print("word alphabet size:", self.word_alphabet.size())

    def build_char_pretrain_emb(self, emb_path):
        print("Building character pretrain emb...")
        self.pretrain_char_embedding, self.char_emb_dim = build_pretrain_embedding(emb_path, self.char_alphabet, self.norm_char_emb)

    def build_word_pretrain_emb(self, emb_path):
        print("Building word pretrain emb...")
        self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(emb_path, self.word_alphabet, self.norm_word_emb)

    def generate_instance_with_words(self, input_file, name):
        if name == "train":
            self.train_texts, self.train_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet,
                    self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "dev":
            self.dev_texts, self.dev_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet,
                    self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "test":
            self.test_texts, self.test_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet,
                    self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
        elif name == "raw":
            self.raw_texts, self.raw_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet,
                    self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)
        else:
            print("Error: you can only generate train/dev/test/raw instances! Illegal input: %s"%(name))

    def write_decoded_results(self, output_file, predict_results, name):
        sent_num = len(predict_results)
        content_list = []
        if name == 'raw':
            content_list = self.raw_texts
        elif name == 'test':
            content_list = self.test_texts
        elif name == 'dev':
            content_list = self.dev_texts
        elif name == 'train':
            content_list = self.train_texts
        else:
            print("Error: illegal name while writing prediction results; name should be one of train/dev/test/raw!")
        assert(sent_num == len(content_list))
        with open(output_file, 'w', encoding="utf-8") as fout:
            for idx in range(sent_num):
                sent_length = len(predict_results[idx])
                for idy in range(sent_length):
                    # content_list[idx] is [chars, words, labels]; write each character with its predicted label.
                    fout.write(content_list[idx][0][idy] + " " + predict_results[idx][idy] + '\n')
                fout.write('\n')
        print("Predicted %s results have been written to file: %s"%(name, output_file))
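Illustrative example (not part of the repository). build_word_alphabet and read_instance_with_gaz read the same layout: one character per line, the label in the last column, and a blank line between sentences; write_decoded_results writes predictions back in a "char label" layout. The hypothetical helpers below sketch that round trip in plain Python. They differ from the loops in this file in one respect: those loops only act on a sentence when they reach a blank line, whereas this reader also keeps a final sentence that is not followed by one.

```python
def read_char_label_sentences(lines, default_label="O"):
    """Group 'char label' lines into (chars, labels) sentences.

    Blank lines separate sentences; a line with a single column gets
    default_label, as in read_instance_with_gaz.
    """
    sentences, chars, labels = [], [], []
    for line in lines:
        parts = line.split()
        if parts:
            chars.append(parts[0])
            labels.append(parts[-1] if len(parts) > 1 else default_label)
        elif chars:
            sentences.append((chars, labels))
            chars, labels = [], []
    if chars:  # flush a last sentence that has no trailing blank line
        sentences.append((chars, labels))
    return sentences


def format_predictions(sentences, predictions):
    """Render predicted tags in the 'char label' layout, one blank line per sentence."""
    out = []
    for (chars, _), tags in zip(sentences, predictions):
        out.extend(f"{c} {t}\n" for c, t in zip(chars, tags))
        out.append("\n")
    return "".join(out)
```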


================================================
FILE: utils/functions.py
================================================
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn

import numpy as np


def normalize_word(word):
    new_word = ""
    for char in word:
        if char.isdigit():
            new_word += '0'
        else:
            new_word += char
    return new_word


def read_instance_with_gaz(input_file, word_dict, char_alphabet, word_alphabet, label_alphabet, number_normalized, max_sent_length):
    instance_texts = []
    instance_Ids = []

    with open(input_file, 'r', encoding="utf-8") as f:

        chars = []
        labels = []
        char_Ids = []
        label_Ids = []

        for line in f:
            pairs = line.strip().split()
            if pairs:
                char = pairs[0]
                if number_normalized:
                    char = normalize_word(char)
                chars.append(char)
                char_Ids.append(char_alphabet.get_index(char))
                if len(pairs) > 1:
                    label = pairs[-1]
                else:
                    label = 'O'
                labels.append(label)
                label_Ids.append(label_alphabet.get_index(label))

            # A blank line marks the end of a sentence.
            else:
                # Keep only non-empty sentences shorter than max_sent_length (a negative value disables the limit).
                if ((max_sent_length < 0) or (len(chars) < max_sent_length)) and (len(chars) > 0):
                    words = []
                    word_Ids = []
                    for idx in range(len(chars)):
                        # Pass a slice: recursive_search shortens its argument in place.
                        matched_list = word_dict.recursive_search(chars[idx:])
                        matched_length = [len(a) for a in matched_list]

                        words.append(matched_list)
                        matched_Id = [word_alphabet.get_index(word) for word in matched_list]
                        if matched_Id:
                            word_Ids.append([matched_Id, matched_length])
                        else:
                            word_Ids.append([])

                    instance_texts.append([chars, words, labels])
                    instance_Ids.append([char_Ids, word_Ids, label_Ids])
                chars = []
                labels = []
                char_Ids = []
                label_Ids = []

    return instance_texts, instance_Ids


def build_pretrain_embedding(embedding_path, word_alphabet, norm=True, embedd_dim=50):

    def norm2one(vec):
        root_sum_square = np.sqrt(np.sum(np.square(vec)))
        return vec / root_sum_square

    embedd_dict = dict()
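    if embedding_path is not None:
        embedd_dict, embedd_dim = load_pretrain_emb(embedding_path)

    scale = np.sqrt(3.0 / embedd_dim)
    # Row 0 is the reserved padding index and is never filled below, so start from zeros
    # rather than uninitialized memory.
    pretrain_emb = np.zeros([word_alphabet.size(), embedd_dim])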
    not_match = 0
    for word, index in word_alphabet.instance2index.items():
        if word.lower() in embedd_dict:
            if norm:
                pretrain_emb[index,:] = norm2one(embedd_dict[word.lower()])
            else:
                pretrain_emb[index,:] = embedd_dict[word.lower()]
        elif word in embedd_dict:
            if norm:
                pretrain_emb[index,:] = norm2one(embedd_dict[word])
            else:
                pretrain_emb[index,:] = embedd_dict[word]
        else:
            pretrain_emb[index,:] = np.random.uniform(-scale, scale, [1, embedd_dim])
            not_match += 1
    pretrained_size = len(embedd_dict)
    print("Embedding:\n     pretrain word:%s, match:%s, oov:%s, oov%%:%.4f" %
          (pretrained_size, word_alphabet.size() - not_match, not_match, (not_match+0.)/word_alphabet.size()))
    return pretrain_emb, embedd_dim


def load_pretrain_emb(embedding_path):
    embedd_dict = dict()
    with open(embedding_path, 'r', encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if len(line) == 0:
                continue
            tokens = line.split()
            embedd_dim = len(tokens) - 1
            embedd = np.empty([1, embedd_dim])
            embedd[:] = tokens[1:]
            embedd_dict[tokens[0]] = embedd
    return embedd_dict, embedd_dim
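Illustrative example (not part of the repository). build_pretrain_embedding looks up each alphabet entry by its lower-cased form first and then by its exact form, optionally L2-normalizes a found vector (norm2one), and otherwise draws the row from U(-sqrt(3/d), sqrt(3/d)); that range gives each component variance (3/d)/3 = 1/d. The sketch below restates this in plain Python under stated assumptions: the vocabulary is a dict whose ids run contiguously from 1, pretrained vectors are plain lists, and row 0 (the reserved padding id) is left at zero. The function name and the seeded random.Random are hypothetical choices made for reproducibility.

```python
import math
import random


def build_embedding_table(vocab, pretrained, dim, normalize=True, seed=0):
    """Return (table, oov_count); table[i] is the vector for id i, row 0 is padding.

    vocab: dict word -> id, ids assumed contiguous from 1.
    pretrained: dict word -> list of floats of length dim.
    """
    rng = random.Random(seed)
    scale = math.sqrt(3.0 / dim)  # U(-a, a) has variance a**2 / 3 = 1 / dim
    table = [[0.0] * dim for _ in range(len(vocab) + 1)]
    oov = 0
    for word, idx in vocab.items():
        # Same lookup order as build_pretrain_embedding: lower-cased form, then exact form.
        vec = pretrained.get(word.lower())
        if vec is None:
            vec = pretrained.get(word)
        if vec is None:
            table[idx] = [rng.uniform(-scale, scale) for _ in range(dim)]
            oov += 1
        elif normalize:
            norm = math.sqrt(sum(v * v for v in vec))
            table[idx] = [v / norm for v in vec]  # unit L2 norm, like norm2one
        else:
            table[idx] = list(vec)
    return table, oov
```

For instance, with `{"hi": [3.0, 4.0]}` as the pretrained table, the vocabulary entry "Hi" is found through its lower-cased form and normalized to [0.6, 0.8].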


================================================
FILE: utils/metric.py
================================================
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn


# Inputs are per-sentence label sequences (a list of tag lists for gold and for predictions).
def get_ner_fmeasure(golden_lists, predict_lists):
    sent_num = len(golden_lists)
    golden_full = []
    predict_full = []
    right_full = []
    right_tag = 0
    all_tag = 0
    for idx in range(0,sent_num):
        golden_list = golden_lists[idx]
        predict_list = predict_lists[idx]
        for idy in range(len(golden_list)):
            if golden_list[idy] == predict_list[idy]:
                right_tag += 1
        all_tag += len(golden_list)

        gold_matrix = get_ner_BMES(golden_list)
        pred_matrix = get_ner_BMES(predict_list)

        right_ner = list(set(gold_matrix).intersection(set(pred_matrix)))
        golden_full += gold_matrix
        predict_full += pred_matrix
        right_full += right_ner
    right_num = len(right_full)
    golden_num = len(golden_full)
    predict_num = len(predict_full)
    if predict_num == 0:
        precision = -1
    else:
        precision = (right_num+0.0)/predict_num
    if golden_num == 0:
        recall = -1
    else:
        recall = (right_num+0.0)/golden_num
    if (precision == -1) or (recall == -1) or (precision+recall) <= 0.:
        f_measure = -1
    else:
        f_measure = 2*precision*recall/(precision+recall)
    accuracy = (right_tag+0.0)/all_tag
    print("gold_num = ", golden_num, " pred_num = ", predict_num, " right_num = ", right_num)
    return accuracy, precision, recall, f_measure


def reverse_style(input_string):
    target_position = input_string.index('[')
    input_len = len(input_string)
    output_string = input_string[target_position:input_len] + input_string[0:target_position]
    return output_string


def get_ner_BMES(label_list):

    list_len = len(label_list)
    begin_label = 'B-'
    end_label = 'E-'
    single_label = 'S-'
    whole_tag = ''
    index_tag = ''
    tag_list = []
    stand_matrix = []
    for i in range(0, list_len):
        # wordlabel = word_list[i]
        current_label = label_list[i].upper() if label_list[i] else []
        if begin_label in current_label:
            if index_tag != '':
                tag_list.append(whole_tag + ',' + str(i-1))
            whole_tag = current_label.replace(begin_label,"",1) +'[' +str(i)
            index_tag = current_label.replace(begin_label,"",1)
            
        elif single_label in current_label:
            if index_tag != '':
                tag_list.append(whole_tag + ',' + str(i-1))
            whole_tag = current_label.replace(single_label,"",1) +'[' +str(i)
            tag_list.append(whole_tag)
            whole_tag = ""
            index_tag = ""
        elif end_label in current_label:
            if index_tag != '':
                tag_list.append(whole_tag +',' + str(i))
            whole_tag = ''
            index_tag = ''
        else:
            continue
    if (whole_tag != '') and (index_tag != ''):
        tag_list.append(whole_tag)
    tag_list_len = len(tag_list)

    for i in range(0, tag_list_len):
        if len(tag_list[i]) > 0:
            tag_list[i] = tag_list[i]+ ']'
            insert_list = reverse_style(tag_list[i])
            stand_matrix.append(insert_list)

    return stand_matrix
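Illustrative example (not part of the repository). get_ner_fmeasure scores at the entity level: each tag sequence is turned into a set of typed spans, a predicted span counts as correct only if the same span is in the gold set for that sentence, and F1 = 2PR / (P + R). The sketch below restates that computation with (type, start, end) tuples instead of the "[start,end]TYPE" strings built by get_ner_BMES. It is deliberately simpler and stricter than get_ner_BMES: a B- span is kept only when an E- tag of the same type closes it with nothing but matching M- tags in between, whereas get_ner_BMES also recovers spans that are not properly closed. It also returns 0.0 rather than the -1 sentinel when a ratio is undefined. The function names are hypothetical.

```python
def bmes_spans(tags):
    """Return the set of (type, start, end) spans encoded by a BMES tag sequence.

    Only well-formed spans are kept: S-X alone, or B-X (M-X)* E-X.
    """
    spans = set()
    start, etype = None, None
    for i, tag in enumerate(tags):
        prefix, _, label = tag.upper().partition("-")
        if prefix == "S" and label:
            spans.add((label, i, i))
            start = None
        elif prefix == "B" and label:
            start, etype = i, label
        elif prefix == "M" and label == etype and start is not None:
            continue  # span is still open
        elif prefix == "E" and label == etype and start is not None:
            spans.add((etype, start, i))
            start = None
        else:  # "O", a type mismatch, or a stray M-/E- tag breaks any open span
            start = None
    return spans


def entity_prf(gold_sents, pred_sents):
    """Entity-level precision, recall and F1 over parallel lists of tag sequences."""
    right = n_gold = n_pred = 0
    for gold, pred in zip(gold_sents, pred_sents):
        gold_spans, pred_spans = bmes_spans(gold), bmes_spans(pred)
        right += len(gold_spans & pred_spans)
        n_gold += len(gold_spans)
        n_pred += len(pred_spans)
    precision = right / n_pred if n_pred else 0.0
    recall = right / n_gold if n_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1
```

With gold tags B-PER E-PER O S-LOC and predicted tags B-PER E-PER O O, one of two gold entities is found with no false positives, so P = 1.0, R = 0.5 and F1 = 2/3.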


================================================
FILE: utils/word_trie.py
================================================
# -*- coding: utf-8 -*-
# @Author: Yicheng Zou
# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn

_end = "_end_"


class Word_Trie:
    def __init__(self):
        self.root = dict()

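    def recursive_search(self, word_list):
        """
        Return every lexicon word of length >= 2 that is a prefix of word_list, longest first.
        Note: despite the name, the search is iterative, and word_list is shortened in place,
        so callers should pass a copy or slice (as data.py and functions.py do).
        """
        match_list = []
        while len(word_list) > 1:
            if self.search(word_list):
                match_list.append("".join(word_list))
            del word_list[-1]
        return match_list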

    def search(self, word):
        current_dict = self.root
        for char in word:
            if char in current_dict:
                current_dict = current_dict[char]
            else:
                return False
        else:
            if _end in current_dict:
                return True
            else:
                return False

    def insert(self, word):
        current_dict = self.root
        for char in word:
            current_dict = current_dict.setdefault(char, {})
        current_dict[_end] = _end
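Illustrative example (not part of the repository). Word_Trie stores the lexicon as nested dicts with an "_end_" marker, and recursive_search, called once per start position in data.py and functions.py, returns every lexicon word of at least two characters that begins at that position, longest first. Because it re-walks the trie for every candidate length, one start position can cost O(n^2) character steps for a remaining length of n. The hypothetical words_starting_at below is an alternative sketch, not the repository's method: it walks the trie once per start position (O(n)), stops as soon as the path leaves the trie, and returns matches in the same longest-first order.

```python
END = "_end_"  # same end-of-word marker as Word_Trie


def insert(root, word):
    node = root
    for ch in word:
        node = node.setdefault(ch, {})
    node[END] = END


def words_starting_at(root, chars, start, min_len=2):
    """Lexicon words that begin at chars[start], longest first."""
    node, found = root, []
    for end in range(start, len(chars)):
        node = node.get(chars[end])
        if node is None:  # path leaves the trie: no longer word can match
            break
        if END in node and end - start + 1 >= min_len:
            found.append("".join(chars[start:end + 1]))
    found.reverse()  # recursive_search reports the longest match first
    return found
```

The full lexicon lattice for a sentence is then `[words_starting_at(root, chars, i) for i in range(len(chars))]`, the same per-position lists that read_instance_with_gaz collects.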
SYMBOL INDEX (78 symbols across 9 files)

FILE: main.py
  function str2bool (line 21) | def str2bool(v):
  function lr_decay (line 32) | def lr_decay(optimizer, epoch, decay_rate, init_lr):
  function data_initialization (line 43) | def data_initialization(data, word_file, train_file, dev_file, test_file):
  function predict_check (line 59) | def predict_check(pred_variable, gold_variable, mask_variable):
  function recover_label (line 70) | def recover_label(pred_variable, gold_variable, mask_variable, label_alp...
  function print_args (line 90) | def print_args(args):
  function evaluate (line 113) | def evaluate(data, args, model, name):
  function batchify_with_label (line 160) | def batchify_with_label(input_batch_list, gpu):
  function train (line 186) | def train(data, args, saved_model_path):
  function load_model_decode (line 337) | def load_model_decode(model_dir, data, args, name):

FILE: model/LGN.py
  class Graph (line 9) | class Graph(nn.Module):
    method __init__ (line 10) | def __init__(self, data, args):
    method construct_graph (line 169) | def construct_graph(self, batch_size, seq_len, word_list):
    method update_graph (line 282) | def update_graph(self, word_list, word_inputs, mask):
    method forward (line 441) | def forward(self, word_list, batch_inputs, mask, batch_label=None):

FILE: model/crf.py
  function log_sum_exp (line 13) | def log_sum_exp(vec, m_size):
  class CRF (line 27) | class CRF(nn.Module):
    method __init__ (line 29) | def __init__(self, tagset_size, gpu):
    method _calculate_PZ (line 51) | def _calculate_PZ(self, feats, mask):
    method _viterbi_decode (line 105) | def _viterbi_decode(self, feats, mask):
    method forward (line 193) | def forward(self, feats):
    method _score_sentence (line 198) | def _score_sentence(self, scores, mask, tags):
    method neg_log_likelihood_loss (line 249) | def neg_log_likelihood_loss(self, feats, mask, tags):

FILE: model/module.py
  class MultiHeadAtt (line 13) | class MultiHeadAtt(nn.Module):
    method __init__ (line 14) | def __init__(self, nhid, keyhid, nhead=10, head_dim=10, dropout=0.1, i...
    method forward (line 31) | def forward(self, query_h, value, mask, query_g=None):
  class GloAtt (line 67) | class GloAtt(nn.Module):
    method __init__ (line 68) | def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1):
    method forward (line 83) | def forward(self, x, y, mask=None):
  class Nodes_Cell (line 109) | class Nodes_Cell(nn.Module):
    method __init__ (line 110) | def __init__(self, input_h, hid_h, use_global=True, dropout=0.2):
    method reset_parameters (line 123) | def reset_parameters(self):
    method forward (line 128) | def forward(self, h, h2, x, glo=None):
  class Edges_Cell (line 149) | class Edges_Cell(nn.Module):
    method __init__ (line 150) | def __init__(self, input_h, hid_h, use_global=True, dropout=0.2):
    method reset_parameters (line 163) | def reset_parameters(self):
    method forward (line 168) | def forward(self, h, x, glo=None):
  class Global_Cell (line 188) | class Global_Cell(nn.Module):
    method __init__ (line 189) | def __init__(self, input_h, hid_h, dropout=0.2):
    method reset_parameters (line 201) | def reset_parameters(self):
    method forward (line 206) | def forward(self, h, x):

FILE: utils/alphabet.py
  class Alphabet (line 12) | class Alphabet:
    method __init__ (line 13) | def __init__(self, name, label=False, keep_growing=True):
    method clear (line 27) | def clear(self, keep_growing=True):
    method add (line 36) | def add(self, instance):
    method get_index (line 42) | def get_index(self, instance):
    method get_instance (line 53) | def get_instance(self, index):
    method size (line 63) | def size(self):
    method iteritems (line 66) | def iteritems(self):
    method enumerate_items (line 69) | def enumerate_items(self, start=1):
    method close (line 74) | def close(self):
    method open (line 77) | def open(self):
    method get_content (line 80) | def get_content(self):
    method from_json (line 83) | def from_json(self, data):
    method save (line 87) | def save(self, output_directory, name=None):
    method load (line 100) | def load(self, input_directory, name=None):

FILE: utils/data.py
  class Data (line 11) | class Data:
    method __init__ (line 12) | def __init__(self):
    method show_data_summary (line 38) | def show_data_summary(self):
    method build_alphabet (line 57) | def build_alphabet(self, input_file):
    method build_word_file (line 78) | def build_word_file(self, word_file):
    method build_word_alphabet (line 87) | def build_word_alphabet(self, input_file):
    method build_char_pretrain_emb (line 108) | def build_char_pretrain_emb(self, emb_path):
    method build_word_pretrain_emb (line 112) | def build_word_pretrain_emb(self, emb_path):
    method generate_instance_with_words (line 116) | def generate_instance_with_words(self, input_file, name):
    method write_decoded_results (line 132) | def write_decoded_results(self, output_file, predict_results, name):

FILE: utils/functions.py
  function normalize_word (line 8) | def normalize_word(word):
  function read_instance_with_gaz (line 18) | def read_instance_with_gaz(input_file, word_dict, char_alphabet, word_al...
  function build_pretrain_embedding (line 71) | def build_pretrain_embedding(embedding_path, word_alphabet, norm=True, e...
  function load_pretrain_emb (line 104) | def load_pretrain_emb(embedding_path):

FILE: utils/metric.py
  function get_ner_fmeasure (line 7) | def get_ner_fmeasure(golden_lists, predict_lists):
  function reverse_style (line 49) | def reverse_style(input_string):
  function get_ner_BMES (line 56) | def get_ner_BMES(label_list):

FILE: utils/word_trie.py
  class Word_Trie (line 8) | class Word_Trie:
    method __init__ (line 9) | def __init__(self):
    method recursive_search (line 12) | def recursive_search(self, word_list):
    method search (line 20) | def search(self, word):
    method insert (line 33) | def insert(self, word):