[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2019 Yicheng Zou\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# LGN\n\nPyTorch implementation of [A Lexicon-Based Graph Neural Network for Chinese NER](https://www.aclweb.org/anthology/D19-1096.pdf).\n\nThe code is partially based on https://github.com/jiesutd/LatticeLSTM.\n\n## Requirements\n\n* Python 3.6 or higher\n* PyTorch 0.4.1 or higher\n\n## Input Format\n\nBMES tag scheme, with one character and its label per line. Sentences are separated by an empty line.\n\n\t印   B-LOC\n\t度   M-LOC\n\t河   E-LOC\n\t流   O\n\t经   O\n\t印   B-GPE\n\t度   E-GPE\n\n## Usage\n\n* Training\n\n\t\tpython main.py --status train \\\n\t\t               --train data/onto4ner.cn/train.char.bmes \\\n\t\t               --dev data/onto4ner.cn/dev.char.bmes \\\n\t\t               --test data/onto4ner.cn/test.char.bmes \\\n\t\t               --saved_model saved_model/model_onto4ner \\\n\t\t               --saved_set data/onto4ner.cn/saved.dset\n\t\t               \n* Testing\n\n\t\tpython main.py --status test \\\n\t\t               --test data/onto4ner.cn/test.char.bmes \\\n\t\t               --saved_model saved_model/model_onto4ner \\\n\t\t               --saved_set data/onto4ner.cn/saved.dset\n\t\t               \n* Decoding (the raw file can be either labeled or unlabeled)\n\n\t\tpython main.py --status decode \\\n\t\t               --raw data/onto4ner.cn/test.char.bmes \\\n\t\t               --output tagged_file.txt \\\n\t\t               --saved_model saved_model/model_onto4ner \\\n\t\t               --saved_set data/onto4ner.cn/saved.dset\n\t\t               \n## Data\n\nThe pretrained character and word embeddings can be downloaded from [Lattice LSTM](https://github.com/jiesutd/LatticeLSTM).\n\nOriginal datasets can be found at [OntoNotes](https://catalog.ldc.upenn.edu/LDC2011T03), [MSRA](http://sighan.cs.uchicago.edu/bakeoff2006/), \n[Weibo](https://github.com/hltcoe/golden-horse) and [Resume](https://github.com/jiesutd/LatticeLSTM/tree/master/ResumeNER).\nThe preprocessed datasets, already converted to the input format above, are available at [Google Drive](https://drive.google.com/open?id=1Rvju5_gp2E6BFiqzMBtnMqVP803AbBcm) and \n[Baidu Pan](https://pan.baidu.com/s/1zbzLriRpc8S_5ez_upC7OA) (extraction code: akcm)\n\n## Pretrained Model Downloads\n\nWe also provide pretrained models on the four datasets; these are the same models reported in the paper.\nIf you retrain models from scratch with the same hyper-parameter settings, you may obtain a slightly \nlower or higher F1 score than reported in the paper (in our experiments we selected the best-performing model).\n\nPretrained models and the related hyper-parameter settings are available at [Google Drive](https://drive.google.com/file/d/1KKkCW8WRhgR2P2UbRpNpKyE_RAv1EREv/view?usp=sharing) or [Baidu Pan](https://pan.baidu.com/s/1U89EwnhPpMa4bNrS--u4EA).\n\nRunning main.py in test mode with the pretrained models gives the following results:\n\n| Dataset        | Precision | Recall  | F1    |\n|:--------------:|:---------:|:-------:|:-----:|\n| OntoNotes dev  |   74.00   |  70.03  | 71.96 |\n| OntoNotes test |   76.13   |  73.68  | 74.89 |\n| MSRA dev       |     -     |   -     |   -   |\n| MSRA test      |   94.19   |  92.73  | 93.46 |\n| Weibo dev      |   66.09   |  59.13  | 62.42 |\n| Weibo test     |   65.71   |  55.56  | 60.21 |\n| Resume dev     |   94.27   |  94.59  | 94.43 |\n| Resume test    |   95.28   |  95.46  | 95.37 |\n\n## Citation\n\n\t@inproceedings{gui2019lexicon,\n  \t title={A Lexicon-Based Graph Neural Network for Chinese NER},\n  \t author={Gui, Tao and Zou, Yicheng and Zhang, Qi and Peng, Minlong and \n\t Fu, Jinlan and Wei, Zhongyu and Huang, Xuanjing},\n  \t booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing \n\t and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)},\n  \t pages={1039--1049},\n  \t year={2019}\n\t}\n"
  },
  {
    "path": "main.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author: Yicheng Zou\n# @Last Modified by:   Yicheng Zou,    Contact: yczou18@fudan.edu.cn\n\nimport time\nimport sys\nimport argparse\nimport random\nimport torch\nimport gc\nimport pickle\nimport os\nimport torch.autograd as autograd\nimport torch.optim as optim\nimport numpy as np\nfrom utils.metric import get_ner_fmeasure\nfrom model.LGN import Graph\nfrom utils.data import Data\n\n\ndef str2bool(v):\n    if isinstance(v, bool):\n        return v\n    if v.lower() in ('yes', 'true', 't', 'y', '1'):\n        return True\n    elif v.lower() in ('no', 'false', 'f', 'n', '0'):\n        return False\n    else:\n        raise argparse.ArgumentTypeError('Boolean value expected.')\n\n\ndef lr_decay(optimizer, epoch, decay_rate, init_lr):\n    # Exponential decay: lr = init_lr * (1 - decay_rate) ** epoch.\n    # The 'aggr' parameter group (graph aggregation modules) uses twice the base rate.\n    lr = init_lr * ((1-decay_rate)**epoch)\n    print(\" Learning rate is set to:\", lr)\n    for param_group in optimizer.param_groups:\n        if param_group['name'] == 'aggr':\n            param_group['lr'] = lr * 2.\n        else:\n            param_group['lr'] = lr\n    return optimizer\n\n\ndef data_initialization(data, word_file, train_file, dev_file, test_file):\n\n    data.build_word_file(word_file)\n\n    if train_file:\n        data.build_alphabet(train_file)\n        data.build_word_alphabet(train_file)\n    if dev_file:\n        data.build_alphabet(dev_file)\n        data.build_word_alphabet(dev_file)\n    if test_file:\n        data.build_alphabet(test_file)\n        data.build_word_alphabet(test_file)\n    return data\n\n\ndef predict_check(pred_variable, gold_variable, mask_variable):\n    # Count correctly predicted tokens, ignoring padded positions (mask == 0).\n    pred = pred_variable.cpu().data.numpy()\n    gold = gold_variable.cpu().data.numpy()\n    mask = mask_variable.cpu().data.numpy()\n    overlapped = (pred == gold)\n    right_token = np.sum(overlapped * mask)\n    total_token = mask.sum()\n    return right_token, total_token\n\n\ndef recover_label(pred_variable, gold_variable, mask_variable, label_alphabet):\n\n    batch_size = 
gold_variable.size(0)\n    seq_len = gold_variable.size(1)\n    mask = mask_variable.cpu().data.numpy()\n    pred_tag = pred_variable.cpu().data.numpy()\n    gold_tag = gold_variable.cpu().data.numpy()\n    pred_label = []\n    gold_label = []\n\n    for idx in range(batch_size):\n        pred = [label_alphabet.get_instance(pred_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0]\n        gold = [label_alphabet.get_instance(gold_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0]\n        assert len(pred) == len(gold)\n        pred_label.append(pred)\n        gold_label.append(gold)\n\n    return pred_label, gold_label\n\n\ndef print_args(args):\n    print(\"CONFIG SUMMARY:\")\n    print(\"     Batch size: %s\" % (args.batch_size))\n    print(\"     Use GPU: %s\" % (args.use_gpu))\n    print(\"     Use CRF: %s\" % (args.use_crf))\n    print(\"     Number of epochs: %s\" % (args.num_epoch))\n    print(\"     Learning rate: %s\" % (args.lr))\n    print(\"     L2 weight decay: %s\" % (args.weight_decay))\n    print(\"     Use edge embedding: %s\" % (args.use_edge))\n    print(\"     Use global node: %s\" % (args.use_global))\n    print(\"     Bidirectional digraph: %s\" % (args.bidirectional))\n    print(\"     Number of update steps: %s\" % (args.iters))\n    print(\"     Attention dropout rate: %s\" % (args.tf_drop_rate))\n    print(\"     Embedding dropout rate: %s\" % (args.emb_drop_rate))\n    print(\"     Hidden state dimension: %s\" % (args.hidden_dim))\n    print(\"     Learning rate decay ratio: %s\" % (args.lr_decay))\n    print(\"     Aggregation module dropout rate: %s\" % (args.cell_drop_rate))\n    print(\"     Number of attention heads: %s\" % (args.num_head))\n    print(\"     Attention head dimension: %s\" % (args.head_dim))\n    print(\"CONFIG SUMMARY END.\")\n    sys.stdout.flush()\n\n\ndef evaluate(data, args, model, name):\n    if name == \"train\":\n        instances = 
data.train_Ids\n    elif name == \"dev\":\n        instances = data.dev_Ids\n    elif name == 'test':\n        instances = data.test_Ids\n    elif name == 'raw':\n        instances = data.raw_Ids\n    else:\n        print(\"Error: unknown evaluation set name:\", name)\n        sys.exit(1)\n\n    pred_results = []\n    gold_results = []\n\n    # set model in eval mode\n    model.eval()\n    batch_size = args.batch_size\n    start_time = time.time()\n    train_num = len(instances)\n    total_batch = train_num // batch_size + 1\n\n    for batch_id in range(total_batch):\n        start = batch_id*batch_size\n        end = (batch_id+1)*batch_size\n        if end > train_num:\n            end = train_num\n        instance = instances[start:end]\n        if not instance:\n            continue\n\n        word_list, batch_char, batch_label, mask = batchify_with_label(instance, args.use_gpu)\n        _, tag_seq = model(word_list, batch_char, mask)\n\n        pred_label, gold_label = recover_label(tag_seq, batch_label, mask, data.label_alphabet)\n\n        pred_results += pred_label\n        gold_results += gold_label\n\n    decode_time = time.time() - start_time\n    speed = len(instances) / decode_time\n\n    acc, p, r, f = get_ner_fmeasure(gold_results, pred_results)\n    return speed, acc, p, r, f, pred_results\n\n\ndef batchify_with_label(input_batch_list, gpu):\n\n    batch_size = len(input_batch_list)\n    chars = [sent[0] for sent in input_batch_list]\n    words = [sent[1] for sent in input_batch_list]\n    labels = [sent[2] for sent in input_batch_list]\n\n    sent_lengths = torch.LongTensor(list(map(len, chars)))\n    max_sent_len = sent_lengths.max()\n    char_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_sent_len))).long()\n    label_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_sent_len))).long()\n    mask = autograd.Variable(torch.zeros((batch_size, max_sent_len))).byte()\n\n    for idx, (seq, label, seq_len) in enumerate(zip(chars, labels, 
sent_lengths)):\n        char_seq_tensor[idx, :seq_len] = torch.LongTensor(seq)\n        label_seq_tensor[idx, :seq_len] = torch.LongTensor(label)\n        mask[idx, :seq_len] = torch.Tensor([1] * int(seq_len))\n\n    if gpu:\n        char_seq_tensor = char_seq_tensor.cuda()\n        label_seq_tensor = label_seq_tensor.cuda()\n        mask = mask.cuda()\n\n    return words, char_seq_tensor, label_seq_tensor, mask\n\n\ndef train(data, args, saved_model_path):\n\n    print(\"Training model...\")\n    model = Graph(data, args)\n    if args.use_gpu:\n        model = model.cuda()\n    print('# model parameters:', sum(param.numel() for param in model.parameters()))\n    print(\"Finished building model.\")\n\n    best_dev_epoch = 0\n    best_dev_f = -1\n    best_dev_p = -1\n    best_dev_r = -1\n\n    best_test_f = -1\n    best_test_p = -1\n    best_test_r = -1\n\n    # Initialize the optimizer.\n    # Parameters of ModuleList children (the stacked graph aggregation modules) form the 'aggr'\n    # group, which lr_decay() trains at twice the base learning rate; all others form 'other'.\n    aggr_module_params = []\n    other_module_params = []\n    for m_name in model._modules:\n        m = model._modules[m_name]\n        if isinstance(m, torch.nn.ModuleList):\n            for p in m.parameters():\n                if p.requires_grad:\n                    aggr_module_params.append(p)\n        else:\n            for p in m.parameters():\n                if p.requires_grad:\n                    other_module_params.append(p)\n\n    optimizer = optim.Adam([\n            {\"params\": (aggr_module_params), \"name\": \"aggr\"},\n            {\"params\": (other_module_params), \"name\": \"other\"}\n        ],\n        lr=args.lr,\n        weight_decay=args.weight_decay\n    )\n\n    for idx in range(args.num_epoch):\n        epoch_start = time.time()\n        temp_start = epoch_start\n        print((\"Epoch: %s/%s\" %(idx, args.num_epoch)))\n        optimizer = lr_decay(optimizer, idx, args.lr_decay, args.lr)\n        sample_loss = 0\n        batch_loss = 0\n        total_loss = 0\n        right_token = 0\n        whole_token = 0\n        
random.shuffle(data.train_Ids)\n        # set model in train mode\n        model.train()\n        model.zero_grad()\n        batch_size = args.batch_size\n        train_num = len(data.train_Ids)\n        total_batch = train_num // batch_size + 1\n\n        for batch_id in range(total_batch):\n            # Get one batch-sized instance\n            start = batch_id * batch_size\n            end = (batch_id + 1) * batch_size\n            if end > train_num:\n                end = train_num\n            instance = data.train_Ids[start:end]\n            if not instance:\n                continue\n\n            word_list, batch_char, batch_label, mask = batchify_with_label(instance, args.use_gpu)\n            loss, tag_seq = model(word_list, batch_char, mask, batch_label)\n            right, whole = predict_check(tag_seq, batch_label, mask)\n            right_token += right\n            whole_token += whole\n            sample_loss += loss.data\n            total_loss += loss.data\n            batch_loss += loss\n\n            if end % 500 == 0:\n                temp_time = time.time()\n                temp_cost = temp_time - temp_start\n                temp_start = temp_time\n                print((\"     Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f\" %\n                       (end, temp_cost, sample_loss, right_token, whole_token, (right_token+0.)/whole_token)))\n                sys.stdout.flush()\n                sample_loss = 0\n            # Update after every full batch and after the final partial batch,\n            # whose accumulated loss would otherwise never be back-propagated.\n            if end % args.batch_size == 0 or end == train_num:\n                batch_loss.backward()\n                optimizer.step()\n                model.zero_grad()\n                batch_loss = 0\n\n        temp_time = time.time()\n        temp_cost = temp_time - temp_start\n        print((\"     Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f\" %\n               (end, temp_cost, sample_loss, right_token, whole_token, (right_token+0.)/whole_token)))\n        epoch_finish = time.time()\n        epoch_cost = epoch_finish - epoch_start\n   
     print((\"Epoch: %s training finished. Time: %.2fs, speed: %.2fst/s, total loss: %s\" %\n               (idx, epoch_cost, train_num/epoch_cost, total_loss)))\n\n        # dev\n        speed, acc, dev_p, dev_r, dev_f, _ = evaluate(data, args, model, \"dev\")\n        dev_finish = time.time()\n        dev_cost = dev_finish - epoch_finish\n\n        print((\"Dev: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f\" %\n               (dev_cost, speed, acc, dev_p, dev_r, dev_f)))\n\n        # test\n        speed, acc, test_p, test_r, test_f, _ = evaluate(data, args, model, \"test\")\n        test_finish = time.time()\n        test_cost = test_finish - dev_finish\n\n        print((\"Test: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f\" %\n               (test_cost, speed, acc, test_p, test_r, test_f)))\n\n        if dev_f > best_dev_f:\n            print(\"Exceeded previous best dev f score: %.4f\" % best_dev_f)\n            torch.save(model.state_dict(), saved_model_path + \"_best\")\n            best_dev_p = dev_p\n            best_dev_r = dev_r\n            best_dev_f = dev_f\n            best_dev_epoch = idx + 1\n            best_test_p = test_p\n            best_test_r = test_r\n            best_test_f = test_f\n\n        model_idx_path = saved_model_path + \"_\" + str(idx)\n        torch.save(model.state_dict(), model_idx_path)\n        with open(saved_model_path + \"_result.txt\", \"a\") as file:\n            file.write(model_idx_path + '\\n')\n            file.write(\"Dev score: p: %.4f, r: %.4f, f: %.4f\\n\" % (dev_p, dev_r, dev_f))\n            file.write(\"Test score: p: %.4f, r: %.4f, f: %.4f\\n\\n\" % (test_p, test_r, test_f))\n\n        print(\"Best dev epoch: %d\" % best_dev_epoch)\n        print(\"Best dev score: p: %.4f, r: %.4f, f: %.4f\" % (best_dev_p, best_dev_r, best_dev_f))\n        print(\"Best test score: p: %.4f, r: %.4f, f: %.4f\" % (best_test_p, best_test_r, best_test_f))\n\n        
gc.collect()\n\n    with open(saved_model_path + \"_result.txt\", \"a\") as file:\n        file.write(\"Best epoch: %d\" % best_dev_epoch + '\\n')\n        file.write(\"Best dev score: p: %.4f, r: %.4f, f: %.4f\\n\" % (best_dev_p, best_dev_r, best_dev_f))\n        file.write(\"Test score: p: %.4f, r: %.4f, f: %.4f\\n\\n\" % (best_test_p, best_test_r, best_test_f))\n\n    with open(saved_model_path + \"_best_HP.config\", \"wb\") as file:\n        pickle.dump(args, file)\n\n\ndef load_model_decode(model_dir, data, args, name):\n    model_dir = model_dir + \"_best\"\n    print(\"Loading model from file:\", model_dir)\n    model = Graph(data, args)\n    # Map all stored tensors to the CPU first, so a checkpoint saved on a GPU can also be\n    # loaded on a CPU-only machine; the model is moved to the GPU below if requested.\n    model.load_state_dict(torch.load(model_dir, map_location=lambda storage, loc: storage))\n    if args.use_gpu:\n        model = model.cuda()\n\n    print((\"Decoding %s data ...\" % name))\n    start_time = time.time()\n    speed, acc, p, r, f, pred_results = evaluate(data, args, model, name)\n    end_time = time.time()\n    time_cost = end_time - start_time\n    print((\"%s: time:%.2fs, speed:%.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f\" %\n           (name, time_cost, speed, acc, p, r, f)))\n\n    return pred_results\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--status', choices=['train', 'test', 'decode'], help='Run mode.', default='train')\n    parser.add_argument('--use_gpu', type=str2bool, default=True, help='Whether to use the GPU.')\n    parser.add_argument('--train', help='Training set.', default='data/onto4ner.cn/train.char.bmes')\n    parser.add_argument('--dev', help='Development set.', default='data/onto4ner.cn/dev.char.bmes')\n    parser.add_argument('--test', help='Test set.', default='data/onto4ner.cn/test.char.bmes')\n    parser.add_argument('--raw', help='Raw file for decoding.')\n    parser.add_argument('--output', help='Output file for decoding results.')\n    parser.add_argument('--saved_set', help='Path of 
saved data set.', default='data/onto4ner.cn/saved.dset')\n    parser.add_argument('--saved_model', help='Path of saved model.', default=\"saved_model/model_onto4ner\")\n    parser.add_argument('--char_emb', help='Path of character embedding file.', default=\"data/gigaword_chn.all.a2b.uni.ite50.vec\")\n    parser.add_argument('--word_emb', help='Path of word embedding file.', default=\"data/ctb.50d.vec\")\n\n    parser.add_argument('--use_crf', type=str2bool, default=True, help='Whether to use a CRF output layer.')\n    parser.add_argument('--use_edge', type=str2bool, default=True, help='Whether to use lexicon embeddings (edge embeddings).')\n    parser.add_argument('--use_global', type=str2bool, default=True, help='Whether to use the global node.')\n    parser.add_argument('--bidirectional', type=str2bool, default=True, help='Whether to use a bidirectional digraph.')\n\n    parser.add_argument('--seed', help='Random seed.', default=1023, type=int)\n    parser.add_argument('--batch_size', help='Batch size.', default=1, type=int)\n    parser.add_argument('--num_epoch', default=100, type=int, help=\"Number of epochs.\")\n    parser.add_argument('--iters', default=4, type=int, help='Number of graph update iterations.')\n    parser.add_argument('--hidden_dim', default=50, type=int, help='Hidden state size.')\n    parser.add_argument('--num_head', default=10, type=int, help='Number of attention heads.')\n    parser.add_argument('--head_dim', default=20, type=int, help='Dimension of each attention head.')\n    parser.add_argument('--tf_drop_rate', default=0.1, type=float, help='Attention dropout rate.')\n    parser.add_argument('--emb_drop_rate', default=0.5, type=float, help='Embedding dropout rate.')\n    parser.add_argument('--cell_drop_rate', default=0.2, type=float, help='Aggregation module dropout rate.')\n    parser.add_argument('--word_alphabet_size', type=int, help='Word alphabet size.')\n    parser.add_argument('--char_alphabet_size', type=int, help='Char alphabet size.')\n    parser.add_argument('--label_alphabet_size', type=int, 
help='Label alphabet size.')\n    parser.add_argument('--char_dim', type=int, help='Char embedding size.')\n    parser.add_argument('--word_dim', type=int, help='Word embedding size.')\n    parser.add_argument('--lr', type=float, default=2e-05)\n    parser.add_argument('--lr_decay', type=float, default=0)\n    parser.add_argument('--weight_decay', type=float, default=0)\n\n    args = parser.parse_args()\n\n    status = args.status.lower()\n    seed_num = args.seed\n    random.seed(seed_num)\n    torch.manual_seed(seed_num)\n    np.random.seed(seed_num)\n\n    train_file = args.train\n    dev_file = args.dev\n    test_file = args.test\n    raw_file = args.raw\n    output_file = args.output\n    saved_set_path = args.saved_set\n    saved_model_path = args.saved_model\n    char_file = args.char_emb\n    word_file = args.word_emb\n\n    if status == 'train':\n        if os.path.exists(saved_set_path):\n            print('Loading saved data set...')\n            with open(saved_set_path, 'rb') as f:\n                data = pickle.load(f)\n        else:\n            data = Data()\n            data_initialization(data, word_file, train_file, dev_file, test_file)\n            data.generate_instance_with_words(train_file, 'train')\n            data.generate_instance_with_words(dev_file, 'dev')\n            data.generate_instance_with_words(test_file, 'test')\n            data.build_char_pretrain_emb(char_file)\n            data.build_word_pretrain_emb(word_file)\n            if saved_set_path is not None:\n                print('Dumping data...')\n                with open(saved_set_path, 'wb') as f:\n                    pickle.dump(data, f)\n        data.show_data_summary()\n        args.word_alphabet_size = data.word_alphabet.size()\n        args.char_alphabet_size = data.char_alphabet.size()\n        args.label_alphabet_size = data.label_alphabet.size()\n        args.char_dim = data.char_emb_dim\n        args.word_dim = data.word_emb_dim\n        print_args(args)\n       
 train(data, args, saved_model_path)\n\n    elif status == 'test':\n        assert not (test_file is None)\n        if os.path.exists(saved_set_path):\n            print('Loading saved data set...')\n            with open(saved_set_path, 'rb') as f:\n                data = pickle.load(f)\n        else:\n            print(\"Cannot find saved data set: \", saved_set_path)\n            exit(0)\n        data.generate_instance_with_words(test_file, 'test')\n        with open(saved_model_path + \"_best_HP.config\", \"rb\") as f:\n            args = pickle.load(f)\n        data.show_data_summary()\n        print_args(args)\n        load_model_decode(saved_model_path, data, args, \"test\")\n\n    elif status == 'decode':\n        assert not (raw_file is None or output_file is None)\n        if os.path.exists(saved_set_path):\n            print('Loading saved data set...')\n            with open(saved_set_path, 'rb') as f:\n                data = pickle.load(f)\n        else:\n            print(\"Cannot find saved data set: \", saved_set_path)\n            exit(0)\n        data.generate_instance_with_words(raw_file, 'raw')\n        with open(saved_model_path + \"_best_HP.config\", \"rb\") as f:\n            args = pickle.load(f)\n        data.show_data_summary()\n        print_args(args)\n        decode_results = load_model_decode(saved_model_path, data, args, \"raw\")\n        data.write_decoded_results(output_file, decode_results, 'raw')\n    else:\n        print(\"Invalid argument! Please use valid arguments! (train/test/decode)\")\n"
  },
  {
    "path": "model/LGN.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author: Yicheng Zou\n# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn\n\nfrom model.crf import CRF\nfrom model.module import *\n\n\nclass Graph(nn.Module):\n    def __init__(self, data, args):\n        super(Graph, self).__init__()\n\n        self.gpu = args.use_gpu\n        self.char_emb_dim = args.char_dim\n        self.word_emb_dim = args.word_dim\n        self.hidden_dim = args.hidden_dim\n        self.num_head = args.num_head  # 5 10 20\n        self.head_dim = args.head_dim  # 10 20\n        self.tf_dropout_rate = args.tf_drop_rate\n        self.iters = args.iters\n        self.bmes_dim = 10\n        self.length_dim = 10\n        self.max_word_length = 5\n        self.emb_dropout_rate = args.emb_drop_rate\n        self.cell_dropout_rate = args.cell_drop_rate\n        self.use_crf = args.use_crf\n        self.use_global = args.use_global\n        self.use_edge = args.use_edge\n        self.bidirectional = args.bidirectional\n        self.label_size = args.label_alphabet_size\n\n        # char embedding\n        self.char_embedding = nn.Embedding(args.char_alphabet_size, self.char_emb_dim)\n        if data.pretrain_char_embedding is not None:\n            self.char_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_char_embedding))\n\n        if self.use_edge:\n\n            # word embedding\n            self.word_embedding = nn.Embedding(args.word_alphabet_size, self.word_emb_dim)\n            if data.pretrain_word_embedding is not None:\n                scale = np.sqrt(3.0 / self.word_emb_dim)\n                data.pretrain_word_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.word_emb_dim])\n                self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding))\n\n            # bmes embedding\n            self.bmes_embedding = nn.Embedding(4, self.bmes_dim)\n            \"\"\"\n            self.edge_emb_linear = nn.Sequential(\n                
nn.Linear(self.word_emb_dim, self.hidden_dim),\n                nn.ELU()\n            )\n            \"\"\"\n        # lstm\n        self.emb_rnn_f = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True)\n        self.emb_rnn_b = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True)\n\n        # length embedding\n        self.length_embedding = nn.Embedding(self.max_word_length, self.length_dim)\n\n        self.dropout = nn.Dropout(self.emb_dropout_rate)\n        self.norm = nn.LayerNorm(self.hidden_dim)\n\n        if self.use_edge:\n            # Node aggregation module\n            self.edge2node_f = nn.ModuleList(\n                [MultiHeadAtt(self.hidden_dim, self.hidden_dim * 2 + self.length_dim,\n                              nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)\n                 for _ in range(self.iters)])\n            # Edge aggregation module\n            self.node2edge_f = nn.ModuleList(\n                [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.bmes_dim, nhead=self.num_head,\n                              head_dim=self.head_dim, dropout=self.tf_dropout_rate)\n                 for _ in range(self.iters)])\n\n        else:\n            # Node aggregation module\n            self.edge2node_f = nn.ModuleList(\n                [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.length_dim,\n                              nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)\n                 for _ in range(self.iters)])\n\n        if self.use_global:\n            # Global Node aggregation module\n            self.glo_att_f_node = nn.ModuleList(\n                [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)\n                 for _ in range(self.iters)])\n\n            if self.use_edge:\n                self.glo_att_f_edge = nn.ModuleList(\n                    [GloAtt(self.hidden_dim, nhead=self.num_head, 
head_dim=self.head_dim, dropout=self.tf_dropout_rate)\n                     for _ in range(self.iters)])\n\n            # Updating modules\n            if self.use_edge:\n                self.glo_rnn_f = Global_Cell(self.hidden_dim * 3, self.hidden_dim, dropout=self.cell_dropout_rate)\n                self.node_rnn_f = Nodes_Cell(self.hidden_dim * 5, self.hidden_dim, dropout=self.cell_dropout_rate)\n                self.edge_rnn_f = Edges_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate)\n            else:\n                self.glo_rnn_f = Global_Cell(self.hidden_dim * 2, self.hidden_dim, dropout=self.cell_dropout_rate)\n                self.node_rnn_f = Nodes_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate)\n\n        else:\n            # Updating modules\n            self.node_rnn_f = Nodes_Cell(self.hidden_dim * 3, self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate)\n            if self.use_edge:\n                self.edge_rnn_f = Edges_Cell(self.hidden_dim * 2, self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate)\n\n        if self.bidirectional:\n\n            if self.use_edge:\n                # Node aggregation module\n                self.edge2node_b = nn.ModuleList(\n                    [MultiHeadAtt(self.hidden_dim, self.hidden_dim * 2 + self.length_dim,\n                                  nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)\n                     for _ in range(self.iters)])\n                # Edge aggregation module\n                self.node2edge_b = nn.ModuleList(\n                    [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.bmes_dim, nhead=self.num_head,\n                                  head_dim=self.head_dim, dropout=self.tf_dropout_rate)\n                     for _ in range(self.iters)])\n\n            else:\n                # Node aggregation module\n                self.edge2node_b = nn.ModuleList(\n                  
  [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.length_dim,\n                                  nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)\n                     for _ in range(self.iters)])\n\n            if self.use_global:\n                # Global Node aggregation module\n                self.glo_att_b_node = nn.ModuleList(\n                    [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)\n                     for _ in range(self.iters)])\n                if self.use_edge:\n                    self.glo_att_b_edge = nn.ModuleList(\n                        [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)\n                         for _ in range(self.iters)])\n\n                # Updating modules (dropout is passed by keyword, as in the forward direction,\n                # so it cannot bind to another positional parameter such as use_global)\n                if self.use_edge:\n                    self.glo_rnn_b = Global_Cell(self.hidden_dim * 3, self.hidden_dim, dropout=self.cell_dropout_rate)\n                    self.node_rnn_b = Nodes_Cell(self.hidden_dim * 5, self.hidden_dim, dropout=self.cell_dropout_rate)\n                    self.edge_rnn_b = Edges_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate)\n                else:\n                    self.glo_rnn_b = Global_Cell(self.hidden_dim * 2, self.hidden_dim, dropout=self.cell_dropout_rate)\n                    self.node_rnn_b = Nodes_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate)\n\n            else:\n                # Updating modules\n                self.node_rnn_b = Nodes_Cell(self.hidden_dim * 3, self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate)\n                if self.use_edge:\n                    self.edge_rnn_b = Edges_Cell(self.hidden_dim * 2, self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate)\n\n        if self.bidirectional:\n            output_dim = self.hidden_dim * 2\n        else:\n            output_dim = self.hidden_dim\n\n        self.layer_att_W = 
nn.Linear(output_dim, 1)\n\n        if self.use_crf:\n            self.hidden2tag = nn.Linear(output_dim, self.label_size + 2)\n            self.crf = CRF(self.label_size, self.gpu)\n        else:\n            self.hidden2tag = nn.Linear(output_dim, self.label_size)\n            self.criterion = nn.CrossEntropyLoss()\n\n    def construct_graph(self, batch_size, seq_len, word_list):\n\n        if self.gpu:\n            device = 'cuda'\n        else:\n            device = 'cpu'\n        if self.use_edge:\n            unk_index = torch.tensor(0, device=device)\n            unk_emb = self.word_embedding(unk_index)\n\n            bmes_emb_b = self.bmes_embedding(torch.tensor(0, device=device))\n            bmes_emb_m = self.bmes_embedding(torch.tensor(1, device=device))\n            bmes_emb_e = self.bmes_embedding(torch.tensor(2, device=device))\n            bmes_emb_s = self.bmes_embedding(torch.tensor(3, device=device))\n\n        sen_nodes_mask_list = []\n        sen_words_length_list = []\n        sen_words_mask_f_list = []\n        sen_words_mask_b_list = []\n        sen_word_embed_list = []\n        sen_bmes_embed_list = []\n        max_edge_num = -1\n\n        for sen in range(batch_size):\n            sen_nodes_mask = torch.zeros([1, seq_len], device=device).byte()\n            sen_words_length = torch.zeros([1, self.length_dim], device=device)\n            sen_words_mask_f = torch.zeros([1, seq_len], device=device).byte()\n            sen_words_mask_b = torch.zeros([1, seq_len], device=device).byte()\n\n            if self.use_edge:\n                sen_word_embed = unk_emb[None, :]\n                sen_bmes_embed = torch.zeros([1, seq_len, self.bmes_dim], device=device)\n\n            for w in range(seq_len):\n                if w < len(word_list[sen]) and word_list[sen][w]:\n                    for word, word_len in zip(word_list[sen][w][0], word_list[sen][w][1]):\n\n                        if word_len <= self.max_word_length:\n                            
word_length_index = torch.tensor(word_len-1, device=device)\n                        else:\n                            word_length_index = torch.tensor(self.max_word_length - 1, device=device)\n                        word_length = self.length_embedding(word_length_index)\n                        sen_words_length = torch.cat([sen_words_length, word_length[None, :]], 0)\n\n                        # mask: Masked elements are marked by 1, batch_size * word_num * seq_len\n                        nodes_mask = torch.ones([1, seq_len], device=device).byte()\n                        words_mask_f = torch.ones([1, seq_len], device=device).byte()\n                        words_mask_b = torch.ones([1, seq_len], device=device).byte()\n\n                        words_mask_f[0, w + word_len - 1] = 0\n                        sen_words_mask_f = torch.cat([sen_words_mask_f, words_mask_f], 0)\n\n                        words_mask_b[0, w] = 0\n                        sen_words_mask_b = torch.cat([sen_words_mask_b, words_mask_b], 0)\n\n                        if self.use_edge:\n                            word_index = torch.tensor(word, device=device)\n                            word_embedding = self.word_embedding(word_index)\n                            sen_word_embed = torch.cat([sen_word_embed, word_embedding[None, :]], 0)\n\n                            bmes_embed = torch.zeros([1, seq_len, self.bmes_dim], device=device)\n\n                            for index in range(word_len):\n                                nodes_mask[0, w + index] = 0\n                                if word_len == 1:\n                                    bmes_embed[0, w + index, :] = bmes_emb_s\n                                elif index == 0:\n                                    bmes_embed[0, w + index, :] = bmes_emb_b\n                                elif index == word_len - 1:\n                                    bmes_embed[0, w + index, :] = bmes_emb_e\n                                else:\n           
                         bmes_embed[0, w + index, :] = bmes_emb_m\n\n                            sen_bmes_embed = torch.cat([sen_bmes_embed, bmes_embed], 0)\n                            sen_nodes_mask = torch.cat([sen_nodes_mask, nodes_mask], 0)\n\n            if sen_words_mask_f.size(0) > max_edge_num:\n                max_edge_num = sen_words_mask_f.size(0)\n            sen_words_mask_f_list.append(sen_words_mask_f.unsqueeze_(0))\n            sen_words_mask_b_list.append(sen_words_mask_b.unsqueeze_(0))\n            sen_words_length_list.append(sen_words_length.unsqueeze_(0))\n            if self.use_edge:\n                sen_nodes_mask_list.append(sen_nodes_mask.unsqueeze_(0))\n                sen_word_embed_list.append(sen_word_embed.unsqueeze_(0))\n                sen_bmes_embed_list.append(sen_bmes_embed.unsqueeze_(0))\n\n        edges_mask = torch.zeros([batch_size, max_edge_num], device=device)\n        batch_words_mask_f = torch.ones([batch_size, max_edge_num, seq_len], device=device).byte()\n        batch_words_mask_b = torch.ones([batch_size, max_edge_num, seq_len], device=device).byte()\n        batch_words_length = torch.zeros([batch_size, max_edge_num, self.length_dim], device=device)\n        if self.use_edge:\n            batch_nodes_mask = torch.zeros([batch_size, max_edge_num, seq_len], device=device).byte()\n            batch_word_embed = torch.zeros([batch_size, max_edge_num, self.word_emb_dim], device=device)\n            batch_bmes_embed = torch.zeros([batch_size, max_edge_num, seq_len, self.bmes_dim], device=device)\n        else:\n            batch_word_embed = None\n            batch_bmes_embed = None\n            batch_nodes_mask = None\n\n        for index in range(batch_size):\n            curr_edge_num = sen_words_mask_f_list[index].size(1)\n            edges_mask[index, 0:curr_edge_num] = 1.\n            batch_words_mask_f[index, 0:curr_edge_num, :] = sen_words_mask_f_list[index]\n            batch_words_mask_b[index, 0:curr_edge_num, 
:] = sen_words_mask_b_list[index]\n            batch_words_length[index, 0:curr_edge_num, :] = sen_words_length_list[index]\n            if self.use_edge:\n                batch_nodes_mask[index, 0:curr_edge_num, :] = sen_nodes_mask_list[index]\n                batch_word_embed[index, 0:curr_edge_num, :] = sen_word_embed_list[index]\n                batch_bmes_embed[index, 0:curr_edge_num, :, :] = sen_bmes_embed_list[index]\n\n        return batch_word_embed, batch_bmes_embed, batch_nodes_mask, batch_words_mask_f, \\\n               batch_words_mask_b, batch_words_length, edges_mask\n\n    def update_graph(self, word_list, word_inputs, mask):\n        mask = mask.float()\n        node_embeds = self.char_embedding(word_inputs)  # batch_size, max_seq_len, embedding\n        B, L, _ = node_embeds.size()\n\n        edge_embs, bmes_embs, nodes_mask, words_mask_f, words_mask_b, words_length, edges_mask = \\\n            self.construct_graph(B, L, word_list)\n\n        node_embeds = self.dropout(node_embeds)\n\n        _, N, _ = words_mask_f.size()\n\n        if self.use_edge:\n            edge_embs = self.dropout(edge_embs)\n\n        # forward direction digraph\n        nodes_f, _ = self.emb_rnn_f(node_embeds)\n        nodes_f = nodes_f * mask.unsqueeze(2)\n        nodes_f_cat = nodes_f[:, None, :, :]\n        _, _, H = nodes_f.size()\n\n        if self.use_edge:\n            edges_f = edge_embs * edges_mask.unsqueeze(2)\n            edges_f_cat = edges_f[:, None, :, :]\n\n            if self.use_global:\n                glo_f = edges_f.sum(1, keepdim=True) / edges_mask.sum(1, keepdim=True).unsqueeze_(2) + \\\n                        nodes_f.sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)\n                glo_f_cat = glo_f[:, None, :, :]\n\n        else:\n            if self.use_global:\n                glo_f = (nodes_f * mask.unsqueeze(2)).sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)\n                glo_f_cat = glo_f[:, None, :, :]\n\n 
       for i in range(self.iters):\n\n            # Attention-based aggregation\n            if self.use_edge and N > 1:\n                bmes_nodes_f = torch.cat([nodes_f.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1)\n                edges_att_f = self.node2edge_f[i](edges_f, bmes_nodes_f, nodes_mask.transpose(1, 2))\n\n            nodes_begin_f = torch.sum(nodes_f[:, None, :, :] * (1 - words_mask_b)[:, :, :, None].float(), 2)\n            nodes_begin_f = torch.cat([torch.zeros([B, 1, H], device=nodes_f.device), nodes_begin_f[:, 1:N, :]], 1)\n\n            if self.use_edge:\n                nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([edges_f, nodes_begin_f, words_length], -1).unsqueeze(2), words_mask_f)\n                if self.use_global:\n                    glo_att_f = torch.cat([self.glo_att_f_node[i](glo_f, nodes_f, (1 - mask).byte()),\n                                           self.glo_att_f_edge[i](glo_f, edges_f, (1 - edges_mask).byte())], -1)\n            else:\n                nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([nodes_begin_f, words_length], -1).unsqueeze(2), words_mask_f)\n                if self.use_global:\n                    glo_att_f = self.glo_att_f_node[i](glo_f, nodes_f, (1 - mask).byte())\n\n            # RNN-based update\n            if self.use_edge and N > 1:\n                if self.use_global:\n                    edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :],\n                                         edges_att_f[:, 1:N, :], glo_att_f.expand(B, N-1, H*2))], 1)\n                else:\n                    edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :], edges_att_f[:, 1:N, :])], 1)\n\n                edges_f_cat = torch.cat([edges_f_cat, edges_f[:, None, :, :]], 1)\n                edges_f = torch.cat([edges_f[:, 0:1, :], self.norm(torch.sum(edges_f_cat[:, :, 1:N, :], 1))], 1)\n\n            nodes_f_r = torch.cat([torch.zeros([B, 1, 
self.hidden_dim], device=nodes_f.device), nodes_f[:, 0:(L-1), :]], 1)\n\n            if self.use_global:\n                nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, nodes_att_f, glo_att_f.expand(B, L, -1))\n            else:\n                nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, nodes_att_f)\n\n            nodes_f_cat = torch.cat([nodes_f_cat, nodes_f[:, None, :, :]], 1)\n            nodes_f = self.norm(torch.sum(nodes_f_cat, 1))\n\n            if self.use_global:\n                glo_f = self.glo_rnn_f(glo_f, glo_att_f)\n                glo_f_cat = torch.cat([glo_f_cat, glo_f[:, None, :, :]], 1)\n                glo_f = self.norm(torch.sum(glo_f_cat, 1))\n\n        nodes_cat = nodes_f_cat\n\n        # backward direction digraph\n        if self.bidirectional:\n            nodes_b, _ = self.emb_rnn_b(torch.flip(node_embeds, [1]))\n            nodes_b = torch.flip(nodes_b, [1])\n            nodes_b = nodes_b * mask.unsqueeze(2)\n            nodes_b_cat = nodes_b[:, None, :, :]\n\n            if self.use_edge:\n                edges_b = edge_embs * edges_mask.unsqueeze(2)\n                edges_b_cat = edges_b[:, None, :, :]\n                if self.use_global:\n                    glo_b = edges_b.sum(1, keepdim=True) / edges_mask.sum(1, keepdim=True).unsqueeze_(2) + \\\n                            nodes_b.sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)\n                    glo_b_cat = glo_b[:, None, :, :]\n\n            else:\n                if self.use_global:\n                    glo_b = nodes_b.sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)\n                    glo_b_cat = glo_b[:, None, :, :]\n\n            for i in range(self.iters):\n\n                # Attention-based aggregation\n                if self.use_edge and N > 1:\n                    bmes_nodes_b = torch.cat([nodes_b.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1)\n                    edges_att_b = self.node2edge_b[i](edges_b, 
bmes_nodes_b, nodes_mask.transpose(1, 2))\n\n                nodes_begin_b = torch.sum(nodes_b[:, None, :, :] * (1 - words_mask_f)[:, :, :, None].float(), 2)\n                nodes_begin_b = torch.cat([torch.zeros([B, 1, H], device=nodes_b.device), nodes_begin_b[:, 1:N, :]], 1)\n\n                if self.use_edge:\n                    nodes_att_b = self.edge2node_b[i](nodes_b, torch.cat([edges_b, nodes_begin_b, words_length], -1).unsqueeze(2), words_mask_b)\n                    if self.use_global:\n                        glo_att_b = torch.cat([self.glo_att_b_node[i](glo_b, nodes_b, (1-mask).byte()),\n                                               self.glo_att_b_edge[i](glo_b, edges_b, (1-edges_mask).byte())], -1)\n                else:\n                    nodes_att_b = self.edge2node_b[i](nodes_b, torch.cat([nodes_begin_b, words_length], -1).unsqueeze(2), words_mask_b)\n                    if self.use_global:\n                        glo_att_b = self.glo_att_b_node[i](glo_b, nodes_b, (1-mask).byte())\n\n                # RNN-based update\n                if self.use_edge and N > 1:\n                    if self.use_global:\n                        edges_b = torch.cat([edges_b[:, 0:1, :], self.edge_rnn_b(edges_b[:, 1:N, :],\n                                             edges_att_b[:, 1:N, :], glo_att_b.expand(B, N-1, H*2))], 1)\n                    else:\n                        edges_b = torch.cat([edges_b[:, 0:1, :], self.edge_rnn_b(edges_b[:, 1:N, :], edges_att_b[:, 1:N, :])], 1)\n\n                    edges_b_cat = torch.cat([edges_b_cat, edges_b[:, None, :, :]], 1)\n                    edges_b = torch.cat([edges_b[:, 0:1, :], self.norm(torch.sum(edges_b_cat[:, :, 1:N, :], 1))], 1)\n\n                nodes_b_r = torch.cat([nodes_b[:, 1:L, :], torch.zeros([B, 1, self.hidden_dim], device=nodes_b.device)], 1)\n\n                if self.use_global:\n                    nodes_b = self.node_rnn_b(nodes_b, nodes_b_r, nodes_att_b, glo_att_b.expand(B, L, -1))\n          
      else:\n                    nodes_b = self.node_rnn_b(nodes_b, nodes_b_r, nodes_att_b)\n\n                nodes_b_cat = torch.cat([nodes_b_cat, nodes_b[:, None, :, :]], 1)\n                nodes_b = self.norm(torch.sum(nodes_b_cat, 1))\n\n                if self.use_global:\n                    glo_b = self.glo_rnn_b(glo_b, glo_att_b)\n                    glo_b_cat = torch.cat([glo_b_cat, glo_b[:, None, :, :]], 1)\n                    glo_b = self.norm(torch.sum(glo_b_cat, 1))\n\n            nodes_cat = torch.cat([nodes_f_cat, nodes_b_cat], -1)\n\n        layer_att = torch.sigmoid(self.layer_att_W(nodes_cat))\n        layer_alpha = F.softmax(layer_att, 1)\n        nodes = torch.sum(layer_alpha * nodes_cat, 1)\n\n        tags = self.hidden2tag(nodes)\n\n        return tags\n\n    def forward(self, word_list, batch_inputs, mask, batch_label=None):\n\n        tags = self.update_graph(word_list, batch_inputs, mask)\n\n        if batch_label is not None:\n            if self.use_crf:\n                total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label)\n            else:\n                total_loss = self.criterion(tags.view(-1, self.label_size), batch_label.view(-1))\n        else:\n            total_loss = None\n\n        if self.use_crf:\n            _, tag_seq = self.crf._viterbi_decode(tags, mask)\n        else:\n            tag_seq = tags.argmax(-1)\n\n        return total_loss, tag_seq\n"
  },
  {
    "path": "model/crf.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author: Jie Yang\n# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn\n\nimport torch\nimport torch.autograd as autograd\nimport torch.nn as nn\nSTART_TAG = -2\nSTOP_TAG = -1\n\n\n# Compute log sum exp in a numerically stable way for the forward algorithm\ndef log_sum_exp(vec, m_size):\n    \"\"\"\n    calculate log of exp sum\n    args:\n        vec (batch_size, vanishing_dim, hidden_dim) : input tensor\n        m_size : hidden_dim\n    return:\n        batch_size, hidden_dim\n    \"\"\"\n    _, idx = torch.max(vec, 1)  # B * M\n    max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size)  # B * 1 * M\n    return max_score.view(-1, m_size) + torch.log(torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size)  # B * M\n\n\nclass CRF(nn.Module):\n\n    def __init__(self, tagset_size, gpu):\n        super(CRF, self).__init__()\n        print (\"build batched crf...\")\n        self.gpu = gpu\n        # Matrix of transition parameters.  
Entry i,j is the score of transitioning *from* i *to* j.\n        self.average_batch = False\n        self.tagset_size = tagset_size\n        # # We add 2 here, because of START_TAG and STOP_TAG\n        # # transitions (f_tag_size, t_tag_size), transition value from f_tag to t_tag\n        init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2)\n        # init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2)\n        # init_transitions[:,START_TAG] = -1000.0\n        # init_transitions[STOP_TAG,:] = -1000.0\n        # init_transitions[:,0] = -1000.0\n        # init_transitions[0,:] = -1000.0\n        if self.gpu:\n            init_transitions = init_transitions.cuda()\n        self.transitions = nn.Parameter(init_transitions)  #(t+2,t+2)\n\n        # self.transitions = nn.Parameter(torch.Tensor(self.tagset_size+2, self.tagset_size+2))\n        # self.transitions.data.zero_()\n\n    def _calculate_PZ(self, feats, mask):\n        \"\"\"\n            input:\n                feats: (batch, seq_len, self.tag_size+2)  (b,m,t+2)\n                mask: (batch, seq_len)   (b,m)\n        \"\"\"\n        batch_size = feats.size(0)\n        seq_len = feats.size(1)\n        tag_size = feats.size(2)\n        # print feats.view(seq_len, tag_size)\n        assert(tag_size == self.tagset_size+2)\n        mask = mask.transpose(1,0).contiguous()  #(m,b)\n        ins_num = seq_len * batch_size\n        ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1)\n        feats = feats.transpose(1,0).contiguous().view(ins_num,1, tag_size).expand(ins_num, tag_size, tag_size)  #(i,t+2,t+2) all t+2 entries along the 2nd dimension are identical\n        ## need to consider start\n        scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)\n        scores = scores.view(seq_len, batch_size, tag_size, tag_size)\n        # build iter\n        seq_iter = enumerate(scores)\n        _, inivalues = seq_iter.__next__()  # 
bat_size * from_target_size * to_target_size  (b,t,t) inivalues holds the scores at the first character of each sentence\n        # only need start from start_tag\n        partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1)  # bat_size * to_target_size (b,t,1)\n\n        ## add start score (from start to all tag, duplicate to batch_size)\n        # partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1)\n        # iter over last scores\n        for idx, cur_values in seq_iter:\n            # previous to_target is current from_target\n            # partition: previous results log(exp(from_target)), #(batch_size * from_target)\n            # cur_values: bat_size * from_target * to_target\n            \n            cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)\n            cur_partition = log_sum_exp(cur_values, tag_size)  #(b,t)\n            # print cur_partition.data\n            \n            # (bat_size * from_target * to_target) -> (bat_size * to_target)\n            # partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1)\n            mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)\n            \n            ## effective updated partition part, only keep the partition value of mask value = 1\n            masked_cur_partition = cur_partition.masked_select(mask_idx)\n            ## make mask_idx broadcastable, to disable warning\n            mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)\n\n            ## replace the partition where mask value = 1, other partition values stay the same\n            partition.masked_scatter_(mask_idx, masked_cur_partition)  \n        # until the last state, add transition score for all partition (and do log_sum_exp) then select the value in STOP_TAG\n        cur_values = self.transitions.view(1,tag_size, 
tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)\n        cur_partition = log_sum_exp(cur_values, tag_size)  #(batch_size,tag_size)\n        final_partition = cur_partition[:, STOP_TAG]  #(batch_size)\n        return final_partition.sum(), scores #scores: (seq_len, batch, tag_size, tag_size)\n\n\n    def _viterbi_decode(self, feats, mask):\n        \"\"\"\n            input:\n                feats: (batch, seq_len, self.tag_size+2)\n                mask: (batch, seq_len)\n            output:\n                decode_idx: (batch, seq_len) decoded sequence\n                path_score: (batch, 1) corresponding score for each sequence (to be implemented)\n        \"\"\"\n        batch_size = feats.size(0)\n        seq_len = feats.size(1)\n        tag_size = feats.size(2)\n        assert(tag_size == self.tagset_size+2)\n        ## calculate sentence length for each sentence\n        length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long()  #(batch_size,1) mask length of each sentence\n        ## mask to (seq_len, batch_size)\n        mask = mask.transpose(1,0).contiguous()  #(seq_len,b)\n        ins_num = seq_len * batch_size\n        ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1)\n        feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)  #(ins_num, tag_size, tag_size)\n        ## need to consider start\n        scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)\n        scores = scores.view(seq_len, batch_size, tag_size, tag_size)\n\n        # build iter\n        seq_iter = enumerate(scores)\n        ## record the position of best score\n        back_points = list()\n        partition_history = list()\n        \n        ##  reverse mask (bug for mask = 1- mask, use this as alternative choice)\n        # mask = 1 + 
(-1)*mask\n        mask = (1 - mask.long()).byte()\n        _, inivalues = seq_iter.__next__()  # bat_size * from_target_size * to_target_size\n        # only need start from start_tag\n        partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1)  # bat_size * to_target_size\n        partition_history.append(partition) #(seqlen,batch_size,tag_size,1)\n        # iter over last scores\n        for idx, cur_values in seq_iter:\n            # previous to_target is current from_target\n            # partition: best path score ending in each from_target so far, #(batch_size * from_target)\n            # cur_values: batch_size * from_target * to_target\n            cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)\n            ## forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1) # do not consider START_TAG/STOP_TAG\n            partition, cur_bp = torch.max(cur_values,dim=1)\n            partition_history.append(partition.unsqueeze(2))\n            ## cur_bp: (batch_size, tag_size) max source score position in current tag\n            ## set padded label as 0, which will be filtered in post processing\n            cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0) \n            back_points.append(cur_bp)\n        ### add score to final STOP_TAG\n        partition_history = torch.cat(partition_history,dim=0).view(seq_len, batch_size,-1).transpose(1,0).contiguous() ## (batch_size, seq_len, tag_size)\n        ### get the last position for each sentence, and select the last partitions using gather()\n        last_position = length_mask.view(batch_size,1,1).expand(batch_size, 1, tag_size) -1\n        last_partition = torch.gather(partition_history, 1, last_position).view(batch_size,tag_size,1)\n        ### calculate the score from last partition to end state (and then select the STOP_TAG from it)\n        last_values = last_partition.expand(batch_size, 
tag_size, tag_size) + self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size)\n        _, last_bp = torch.max(last_values, 1)  #(batch_size,tag_size)\n        pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size)).long()\n        if self.gpu:\n            pad_zero = pad_zero.cuda()\n        back_points.append(pad_zero)\n        back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)\n        \n        ## select end ids in STOP_TAG\n        pointer = last_bp[:, STOP_TAG] #(batch_size)\n        insert_last = pointer.contiguous().view(batch_size,1,1).expand(batch_size,1, tag_size)\n        back_points = back_points.transpose(1,0).contiguous()   #(batch_size,seq_len,tag_size)\n        ## move the end ids(expand to tag_size) to the corresponding position of back_points to replace the 0 values\n        # print \"lp:\",last_position\n        # print \"il:\",insert_last\n        back_points.scatter_(1, last_position, insert_last)  ##(batch_size,seq_len,tag_size)\n        # print \"bp:\",back_points\n        # exit(0)\n        back_points = back_points.transpose(1,0).contiguous()  #(seq_len, batch_size, tag_size)\n        ## decode from the end, padded position ids are 0, which will be filtered in the following evaluation\n        decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size))\n        if self.gpu:\n            decode_idx = decode_idx.cuda()\n        decode_idx[-1] = pointer.data\n        for idx in range(len(back_points)-2, -1, -1):\n            pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) #pointer's size:(batch_size,1)\n            decode_idx[idx] = pointer.squeeze(1).data\n        path_score = None\n        decode_idx = decode_idx.transpose(1,0) #(batch_size, sent_len)\n        return path_score, decode_idx\n\n\n    def forward(self, feats, mask):\n        path_score, best_path = self._viterbi_decode(feats, mask)\n        return path_score, best_path\n        \n\n    
def _score_sentence(self, scores, mask, tags):\n        \"\"\"\n            input:\n                scores: variable (seq_len, batch, tag_size, tag_size)\n                mask: (batch, seq_len)\n                tags: tensor  (batch, seq_len)\n            output:\n                score: sum of the gold-sequence scores over the whole batch\n        \"\"\"\n        # Gives the score of a provided tag sequence\n        batch_size = scores.size(1)\n        seq_len = scores.size(0)\n        tag_size = scores.size(2)\n        ## encode each label bigram (previous tag, current tag) as a single index into the flattened scores\n        new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len))\n        if self.gpu:\n            new_tags = new_tags.cuda()\n        for idx in range(seq_len):\n            if idx == 0:\n                ## start -> first score\n                new_tags[:,0] =  (tag_size - 2)*tag_size + tags[:,0]\n\n            else:\n                new_tags[:,idx] =  tags[:,idx-1]*tag_size + tags[:,idx]\n\n        ## transition for label to STOP_TAG\n        end_transition = self.transitions[:,STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size)\n        ## length for batch, last word position = length - 1\n        length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long()\n        ## index the label id of last word\n        end_ids = torch.gather(tags, 1, length_mask - 1)\n\n        ## index the transition score for end_id to STOP_TAG\n        end_energy = torch.gather(end_transition, 1, end_ids)\n\n        ## convert tag as (seq_len, batch_size, 1)\n        new_tags = new_tags.transpose(1,0).contiguous().view(seq_len, batch_size, 1)\n        ### look up each bigram index among the tag_size * tag_size flattened positions of scores\n        tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size)  # seq_len * bat_size\n        ## mask transpose to (seq_len, batch_size)\n        tg_energy = 
tg_energy.masked_select(mask.transpose(1,0))\n        \n        # ## calculate the score from START_TAG to first label\n        # start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size)\n        # start_energy = torch.gather(start_transition, 1, tags[0,:])\n\n        ## add all score together\n        # gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum()\n        gold_score = tg_energy.sum() + end_energy.sum()\n        return gold_score\n\n    def neg_log_likelihood_loss(self, feats, mask, tags):\n        # negative log-likelihood\n        batch_size = feats.size(0)\n        forward_score, scores = self._calculate_PZ(feats, mask)  #forward_score: scalar, scores: (seq_len, batch, tag_size, tag_size)\n        gold_score = self._score_sentence(scores, mask, tags)\n        #print (\"batch, f:\", forward_score.data, \" g:\", gold_score.data, \" dis:\", forward_score.data - gold_score.data)\n        # exit(0)\n        if self.average_batch:\n            return (forward_score - gold_score)/batch_size\n        else:\n            return forward_score - gold_score\n"
  },
  {
    "path": "model/module.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author: Yicheng Zou\n# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn\n\n\nimport torch\nimport math\nimport torch.nn as nn\nimport numpy as np\nimport torch.nn.functional as F\n\n\nclass MultiHeadAtt(nn.Module):\n    def __init__(self, nhid, keyhid, nhead=10, head_dim=10, dropout=0.1, if_g=False):\n        super(MultiHeadAtt, self).__init__()\n\n        if if_g:\n            self.WQ = nn.Conv2d(nhid * 3, nhead * head_dim, 1)\n        else:\n            self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)\n        self.WK = nn.Conv2d(keyhid, nhead * head_dim, 1)\n        self.WV = nn.Conv2d(keyhid, nhead * head_dim, 1)\n        self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)\n\n        self.drop = nn.Dropout(dropout)\n\n        self.norm = nn.LayerNorm(nhid)\n\n        self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim\n\n    def forward(self, query_h, value, mask, query_g=None):\n\n        if not (query_g is None):\n            query = torch.cat([query_h, query_g], -1)\n        else:\n            query = query_h\n        query = query.permute(0, 2, 1)[:, :, :, None]\n        value = value.permute(0, 3, 1, 2)\n\n        residual = query_h\n        nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim\n\n        B, QL, H = query_h.shape\n\n        _, _, VL, VD = value.shape  # VD = 1 or VD = QL\n\n        assert VD == 1 or VD == QL\n        # q: (B, H, QL, 1)\n        # v: (B, H, VL, VD)\n        q, k, v = self.WQ(query), self.WK(value), self.WV(value)\n\n        q = q.view(B, nhead, head_dim, 1, QL)\n        k = k.view(B, nhead, head_dim, VL, VD)\n        v = v.view(B, nhead, head_dim, VL, VD)\n\n        alpha = (q * k).sum(2, keepdim=True) / np.sqrt(head_dim)\n        alpha = alpha.masked_fill(mask[:, None, None, :, :], -np.inf)\n        alpha = self.drop(F.softmax(alpha, 3))\n        att = (alpha * v).sum(3).view(B, nhead * head_dim, QL, 1)\n\n        output = 
F.leaky_relu(self.WO(att)).permute(0, 2, 3, 1).view(B, QL, H)\n        output = self.norm(output + residual)\n\n        return output\n\n\nclass GloAtt(nn.Module):\n    def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1):\n        # Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value\n        super(GloAtt, self).__init__()\n        self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)\n        self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)\n        self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)\n        self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)\n\n        self.drop = nn.Dropout(dropout)\n\n        self.norm = nn.LayerNorm(nhid)\n\n        # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim)\n        self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim\n\n    def forward(self, x, y, mask=None):\n        # x: (B, 1, H), y: (B, L, H); permuted below to (B, H, 1, 1) and (B, H, L, 1)\n        nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim\n        B, L, H = y.shape\n\n        x = x.permute(0, 2, 1)[:, :, :, None]\n        y = y.permute(0, 2, 1)[:, :, :, None]\n\n        residual = x\n        q, k, v = self.WQ(x), self.WK(y), self.WV(y)\n\n        q = q.view(B, nhead, 1, head_dim)  # B, H, 1, 1 -> B, N, 1, h\n        k = k.view(B, nhead, head_dim, L)  # B, H, L, 1 -> B, N, h, L\n        v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2)  # B, H, L, 1 -> B, N, L, h\n\n        pre_a = torch.matmul(q, k) / np.sqrt(head_dim)\n        if mask is not None:\n            pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf'))\n        alphas = self.drop(F.softmax(pre_a, 3))  # B, N, 1, L\n        att = torch.matmul(alphas, v).view(B, -1, 1, 1)  # B, N, 1, h -> B, N*h, 1, 1\n        output = F.leaky_relu(self.WO(att)) + residual\n        output = self.norm(output.permute(0, 2, 3, 1)).view(B, 1, H)\n\n        return output\n\n\nclass Nodes_Cell(nn.Module):\n    def __init__(self, input_h, hid_h, use_global=True, dropout=0.2):\n        super(Nodes_Cell, 
self).__init__()\n\n        self.use_global = use_global\n        self.hidden_size = hid_h\n        self.Wix = nn.Linear(input_h, hid_h)\n        self.Wi2 = nn.Linear(input_h, hid_h)\n        self.Wf = nn.Linear(input_h, hid_h)\n        self.Wcx = nn.Linear(input_h, hid_h)\n\n        self.drop = nn.Dropout(dropout)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        stdv = 1.0 / math.sqrt(self.hidden_size)\n        for weight in self.parameters():\n            nn.init.uniform_(weight, -stdv, stdv)\n\n    def forward(self, h, h2, x, glo=None):\n\n        x = self.drop(x)\n\n        if self.use_global:\n            glo = self.drop(glo)\n            cat_all = torch.cat([h, h2, x, glo], -1)\n        else:\n            cat_all = torch.cat([h, h2, x], -1)\n\n        ix = torch.sigmoid(self.Wix(cat_all))\n        i2 = torch.sigmoid(self.Wi2(cat_all))\n        f = torch.sigmoid(self.Wf(cat_all))\n        cx = torch.tanh(self.Wcx(cat_all))\n\n        alpha = F.softmax(torch.cat([ix.unsqueeze(1), i2.unsqueeze(1), f.unsqueeze(1)], 1), 1)\n        output = (alpha[:, 0] * cx) + (alpha[:, 1] * h2) + (alpha[:, 2] * h)\n\n        return output\n\n\nclass Edges_Cell(nn.Module):\n    def __init__(self, input_h, hid_h, use_global=True, dropout=0.2):\n        super(Edges_Cell, self).__init__()\n\n        self.use_global = use_global\n        self.hidden_size = hid_h\n        self.Wi = nn.Linear(input_h, hid_h)\n        self.Wf = nn.Linear(input_h, hid_h)\n        self.Wc = nn.Linear(input_h, hid_h)\n\n        self.drop = nn.Dropout(dropout)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        stdv = 1.0 / math.sqrt(self.hidden_size)\n        for weight in self.parameters():\n            nn.init.uniform_(weight, -stdv, stdv)\n\n    def forward(self, h, x, glo=None):\n\n        x = self.drop(x)\n\n        if self.use_global:\n            glo = self.drop(glo)\n            cat_all = torch.cat([h, x, glo], -1)\n        else:\n            
cat_all = torch.cat([h, x], -1)\n\n        i = torch.sigmoid(self.Wi(cat_all))\n        f = torch.sigmoid(self.Wf(cat_all))\n        c = torch.tanh(self.Wc(cat_all))\n\n        alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1)\n        output = (alpha[:, 0] * c) + (alpha[:, 1] * h)\n\n        return output\n\n\nclass Global_Cell(nn.Module):\n    def __init__(self, input_h, hid_h, dropout=0.2):\n        super(Global_Cell, self).__init__()\n\n        self.hidden_size = hid_h\n        self.Wi = nn.Linear(input_h, hid_h)\n        self.Wf = nn.Linear(input_h, hid_h)\n        self.Wc = nn.Linear(input_h, hid_h)\n\n        self.drop = nn.Dropout(dropout)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        stdv = 1.0 / math.sqrt(self.hidden_size)\n        for weight in self.parameters():\n            nn.init.uniform_(weight, -stdv, stdv)\n\n    def forward(self, h, x):\n\n        x = self.drop(x)\n\n        cat_all = torch.cat([h, x], -1)\n        i = torch.sigmoid(self.Wi(cat_all))\n        f = torch.sigmoid(self.Wf(cat_all))\n        c = torch.tanh(self.Wc(cat_all))\n\n        alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1)\n        output = (alpha[:, 0] * c) + (alpha[:, 1] * h)\n\n        return output\n"
  },
  {
    "path": "utils/alphabet.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author: Max\n# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn\n\n\"\"\"\nAlphabet maps objects to integer ids. It provides a two-way mapping between objects and their indices.\n\"\"\"\nimport json\nimport os\n\n\nclass Alphabet:\n    def __init__(self, name, label=False, keep_growing=True):\n        self.__name = name\n        self.UNKNOWN = \"</unk>\"\n        self.label = label\n        self.instance2index = {}\n        self.instances = []\n        self.keep_growing = keep_growing\n\n        # Index 0 is reserved by default; real instances start from 1.\n        self.default_index = 0\n        self.next_index = 1\n        if not self.label:\n            self.add(self.UNKNOWN)\n\n    def clear(self, keep_growing=True):\n        self.instance2index = {}\n        self.instances = []\n        self.keep_growing = keep_growing\n\n        # Index 0 is reserved by default; real instances start from 1.\n        self.default_index = 0\n        self.next_index = 1\n\n    def add(self, instance):\n        if instance not in self.instance2index:\n            self.instances.append(instance)\n            self.instance2index[instance] = self.next_index\n            self.next_index += 1\n\n    def get_index(self, instance):\n        try:\n            return self.instance2index[instance]\n        except KeyError:\n            if self.keep_growing:\n                index = self.next_index\n                self.add(instance)\n                return index\n            else:\n                return self.instance2index[self.UNKNOWN]\n\n    def get_instance(self, index):\n        if index == 0:\n            # Index 0 is reserved for the wildcard element.\n            return None\n        try:\n            return self.instances[index - 1]\n        except IndexError:\n            print('WARNING: Alphabet.get_instance: unknown instance index {}, returning the first instance.'.format(index))\n            return self.instances[0]\n\n
    def size(self):\n        return len(self.instances) + 1\n\n    def iteritems(self):\n        return self.instance2index.items()\n\n    def enumerate_items(self, start=1):\n        if start < 1 or start >= self.size():\n            raise IndexError(\"Enumerate is allowed between [1 : size of the alphabet)\")\n        return zip(range(start, len(self.instances) + 1), self.instances[start - 1:])\n\n    def close(self):\n        self.keep_growing = False\n\n    def open(self):\n        self.keep_growing = True\n\n    def get_content(self):\n        return {'instance2index': self.instance2index, 'instances': self.instances}\n\n    def from_json(self, data):\n        self.instances = data[\"instances\"]\n        self.instance2index = data[\"instance2index\"]\n        # Keep next_index consistent with the loaded instances (index 0 is reserved).\n        self.next_index = len(self.instances) + 1\n\n    def save(self, output_directory, name=None):\n        \"\"\"\n        Save both alphabet records (instance2index and instances) to the given directory as JSON.\n        :param output_directory: Directory to save the alphabet file to.\n        :param name: The alphabet saving name, optional.\n        :return:\n        \"\"\"\n        saving_name = name if name else self.__name\n        try:\n            with open(os.path.join(output_directory, saving_name + \".json\"), 'w') as f:\n                json.dump(self.get_content(), f)\n        except Exception as e:\n            print(\"Exception: Alphabet is not saved: \" + repr(e))\n\n    def load(self, input_directory, name=None):\n        \"\"\"\n        Load the alphabet records from the given directory. This allows an existing alphabet to be reused even if\n        the model structure changes.\n        :param input_directory: Directory to load the alphabet file from.\n        :param name: The alphabet loading name, optional.\n        :return:\n        \"\"\"\n        loading_name = name if name else self.__name\n        with open(os.path.join(input_directory, loading_name + \".json\")) as f:\n            self.from_json(json.load(f))\n"
  },
  {
    "path": "utils/data.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author: Jie Yang\n# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn\n\nimport sys\nfrom utils.alphabet import Alphabet\nfrom utils.functions import *\nfrom utils.word_trie import Word_Trie\n\n\nclass Data:\n    def __init__(self): \n        self.MAX_SENTENCE_LENGTH = 250\n        self.MAX_WORD_LENGTH = -1\n        self.number_normalized = True\n        self.norm_char_emb = True\n        self.norm_word_emb = False\n        self.char_alphabet = Alphabet('character')\n        self.label_alphabet = Alphabet('label', True)\n        self.word_dict = Word_Trie()\n        self.word_alphabet = Alphabet('word')\n\n        self.train_texts = []\n        self.dev_texts = []\n        self.test_texts = []\n        self.raw_texts = []\n\n        self.train_Ids = []\n        self.dev_Ids = []\n        self.test_Ids = []\n        self.raw_Ids = []\n        self.char_emb_dim = 50\n        self.word_emb_dim = 50\n        self.pretrain_char_embedding = None\n        self.pretrain_word_embedding = None\n        self.label_size = 0\n        \n    def show_data_summary(self):\n        print(\"DATA SUMMARY:\")\n        print(\"     MAX SENTENCE LENGTH: %s\"%(self.MAX_SENTENCE_LENGTH))\n        print(\"     MAX   WORD   LENGTH: %s\"%(self.MAX_WORD_LENGTH))\n        print(\"     Number   normalized: %s\"%(self.number_normalized))\n        print(\"     Word  alphabet size: %s\"%(self.word_alphabet.size()))\n        print(\"     Char  alphabet size: %s\"%(self.char_alphabet.size()))\n        print(\"     Label alphabet size: %s\"%(self.label_alphabet.size()))\n        print(\"     Word embedding size: %s\"%(self.word_emb_dim))\n        print(\"     Char embedding size: %s\"%(self.char_emb_dim))\n        print(\"     Norm     char   emb: %s\"%(self.norm_char_emb))\n        print(\"     Norm     word   emb: %s\"%(self.norm_word_emb))\n        print(\"     Train instance number: %s\"%(len(self.train_texts)))\n        print(\"     
Dev   instance number: %s\"%(len(self.dev_texts)))\n        print(\"     Test  instance number: %s\"%(len(self.test_texts)))\n        print(\"     Raw   instance number: %s\"%(len(self.raw_texts)))\n        print(\"DATA SUMMARY END.\")\n        sys.stdout.flush()\n\n    def build_alphabet(self, input_file):\n        self.char_alphabet.open()\n        self.label_alphabet.open()\n\n        with open(input_file, 'r', encoding=\"utf-8\") as f:\n            for line in f:\n                line = line.strip()\n                if len(line) == 0:\n                    continue\n                pair = line.split()\n                char = pair[0]\n                if self.number_normalized:\n                    # Map every digit to 0\n                    char = normalize_word(char)\n                label = pair[-1]\n                self.label_alphabet.add(label)\n                self.char_alphabet.add(char)\n\n        self.label_alphabet.close()\n        self.char_alphabet.close()\n\n    def build_word_file(self, word_file):\n        # Build the word trie from the first column (the word) of each line in the word embedding file.\n        print(\"Building the word dict...\")\n        with open(word_file, 'r', encoding=\"utf-8\") as f:\n            for line in f:\n                tokens = line.strip().split()\n                if tokens:\n                    self.word_dict.insert(tokens[0])\n\n    def build_word_alphabet(self, input_file):\n        print(\"Loading file: \" + input_file)\n        self.word_alphabet.open()\n        word_list = []\n        with open(input_file, 'r', encoding=\"utf-8\") as f:\n            for line in f:\n                line = line.strip()\n                if len(line) > 0:\n                    word = line.split()[0]\n                    if self.number_normalized:\n                        word = normalize_word(word)\n                    word_list.append(word)\n                else:\n                    for idx in range(len(word_list)):\n                        matched_words = 
self.word_dict.recursive_search(word_list[idx:])\n                        for matched_word in matched_words:\n                            self.word_alphabet.add(matched_word)\n                    word_list = []\n        self.word_alphabet.close()\n        print(\"word alphabet size:\", self.word_alphabet.size())\n\n    def build_char_pretrain_emb(self, emb_path):\n        print (\"Building character pretrain emb...\")\n        self.pretrain_char_embedding, self.char_emb_dim = build_pretrain_embedding(emb_path, self.char_alphabet, self.norm_char_emb)\n\n    def build_word_pretrain_emb(self, emb_path):\n        print (\"Building word pretrain emb...\")\n        self.pretrain_word_embedding, self.word_emb_dim = build_pretrain_embedding(emb_path, self.word_alphabet, self.norm_word_emb)\n\n    def generate_instance_with_words(self, input_file, name):\n        if name == \"train\":\n            self.train_texts, self.train_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet,\n                    self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)\n        elif name == \"dev\":\n            self.dev_texts, self.dev_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet,\n                    self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)\n        elif name == \"test\":\n            self.test_texts, self.test_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet,\n                    self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)\n        elif name == \"raw\":\n            self.raw_texts, self.raw_Ids = read_instance_with_gaz(input_file, self.word_dict, self.char_alphabet,\n                    self.word_alphabet, self.label_alphabet, self.number_normalized, self.MAX_SENTENCE_LENGTH)\n        else:\n            print(\"Error: you can only generate train/dev/test/raw instance! 
Illegal input:%s\"%(name))\n\n    def write_decoded_results(self, output_file, predict_results, name):\n        sent_num = len(predict_results)\n        content_list = []\n        if name == 'raw':\n            content_list = self.raw_texts\n        elif name == 'test':\n            content_list = self.test_texts\n        elif name == 'dev':\n            content_list = self.dev_texts\n        elif name == 'train':\n            content_list = self.train_texts\n        else:\n            print(\"Error: illegal name while writing prediction results; name must be one of train/dev/test/raw!\")\n        assert sent_num == len(content_list)\n        with open(output_file, 'w', encoding=\"utf-8\") as fout:\n            for idx in range(sent_num):\n                sent_length = len(predict_results[idx])\n                for idy in range(sent_length):\n                    # content_list[idx] is a list of [chars, words, labels]; write each char with its predicted label.\n                    fout.write(content_list[idx][0][idy] + \" \" + predict_results[idx][idy] + '\\n')\n                fout.write('\\n')\n        print(\"Prediction results for %s have been written to file: %s\"%(name, output_file))\n"
  },
  {
    "path": "utils/functions.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author: Jie Yang\n# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn\n\nimport numpy as np\n\n\ndef normalize_word(word):\n    new_word = \"\"\n    for char in word:\n        if char.isdigit():\n            new_word += '0'\n        else:\n            new_word += char\n    return new_word\n\n\ndef read_instance_with_gaz(input_file, word_dict, char_alphabet, word_alphabet, label_alphabet, number_normalized, max_sent_length):\n    instence_texts = []\n    instence_Ids = []\n\n    with open(input_file, 'r', encoding=\"utf-8\") as f:\n\n        chars = []\n        labels = []\n        char_Ids = []\n        label_Ids = []\n\n        for line in f:\n            if len(line) > 1:\n                pairs = line.strip().split()\n                char = pairs[0]\n                if number_normalized:\n                    char = normalize_word(char)\n                chars.append(char)\n                char_Ids.append(char_alphabet.get_index(char))\n                if len(pairs) > 1:\n                    label = pairs[-1]\n                else:\n                    label = 'O'\n                labels.append(label)\n                label_Ids.append(label_alphabet.get_index(label))\n\n            # A sentence is finished.\n            else:\n                # Only keep the sentence whose length is smaller than MAX_SENT_LENGTH.\n                if ((max_sent_length < 0) or (len(chars) < max_sent_length)) and (len(chars)>0):\n                    words = []\n                    word_Ids = []\n                    for idx in range(len(chars)):\n                        matched_list = word_dict.recursive_search(chars[idx:])\n                        matched_length = [len(a) for a in matched_list]\n\n                        words.append(matched_list)\n                        matched_Id = [word_alphabet.get_index(word) for word in matched_list]\n                        if matched_Id:\n                            
word_Ids.append([matched_Id, matched_length])\n                        else:\n                            word_Ids.append([])\n\n                    instence_texts.append([chars, words, labels])\n                    instence_Ids.append([char_Ids, word_Ids, label_Ids])\n                chars = []\n                labels = []\n                char_Ids = []\n                label_Ids = []\n\n    return instence_texts, instence_Ids\n\n\ndef build_pretrain_embedding(embedding_path, word_alphabet, norm=True, embedd_dim=50):\n\n    def norm2one(vec):\n        root_sum_square = np.sqrt(np.sum(np.square(vec)))\n        return vec / root_sum_square\n\n    embedd_dict = dict()\n    if embedding_path is not None:\n        embedd_dict, embedd_dim = load_pretrain_emb(embedding_path)\n\n    scale = np.sqrt(3.0 / embedd_dim)\n    # Row 0 is the reserved index and is never filled in the loop below, so start from zeros\n    # instead of the uninitialized memory returned by np.empty.\n    pretrain_emb = np.zeros([word_alphabet.size(), embedd_dim])\n    not_match = 0\n    for word, index in word_alphabet.instance2index.items():\n        if word.lower() in embedd_dict:\n            if norm:\n                pretrain_emb[index,:] = norm2one(embedd_dict[word.lower()])\n            else:\n                pretrain_emb[index,:] = embedd_dict[word.lower()]\n        elif word in embedd_dict:\n            if norm:\n                pretrain_emb[index,:] = norm2one(embedd_dict[word])\n            else:\n                pretrain_emb[index,:] = embedd_dict[word]\n        else:\n            pretrain_emb[index,:] = np.random.uniform(-scale, scale, [1, embedd_dim])\n            not_match += 1\n    pretrained_size = len(embedd_dict)\n    print(\"Embedding:\\n     pretrain word:%s, match:%s, oov:%s, oov%%:%.4f\" %\n          (pretrained_size, word_alphabet.size() - not_match, not_match, (not_match+0.)/word_alphabet.size()))\n    return pretrain_emb, embedd_dim\n\n\ndef load_pretrain_emb(embedding_path):\n    embedd_dict = dict()\n    with open(embedding_path, 'r', encoding=\"utf-8\") as f:\n        for line in f:\n            line = line.strip()\n            if len(line) == 0:\n                continue\n            tokens = line.split()\n            embedd_dim = len(tokens) - 1\n            embedd = np.empty([1, embedd_dim])\n            embedd[:] = tokens[1:]\n            embedd_dict[tokens[0]] = embedd\n    return embedd_dict, embedd_dim\n"
  },
  {
    "path": "utils/metric.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author: Jie Yang\n# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn\n\n\n# input as sentence level labels\ndef get_ner_fmeasure(golden_lists, predict_lists):\n    sent_num = len(golden_lists)\n    golden_full = []\n    predict_full = []\n    right_full = []\n    right_tag = 0\n    all_tag = 0\n    for idx in range(0,sent_num):\n        golden_list = golden_lists[idx]\n        predict_list = predict_lists[idx]\n        for idy in range(len(golden_list)):\n            if golden_list[idy] == predict_list[idy]:\n                right_tag += 1\n        all_tag += len(golden_list)\n\n        gold_matrix = get_ner_BMES(golden_list)\n        pred_matrix = get_ner_BMES(predict_list)\n\n        right_ner = list(set(gold_matrix).intersection(set(pred_matrix)))\n        golden_full += gold_matrix\n        predict_full += pred_matrix\n        right_full += right_ner\n    right_num = len(right_full)\n    golden_num = len(golden_full)\n    predict_num = len(predict_full)\n    if predict_num == 0:\n        precision = -1\n    else:\n        precision =  (right_num+0.0)/predict_num\n    if golden_num == 0:\n        recall = -1\n    else:\n        recall = (right_num+0.0)/golden_num\n    if (precision == -1) or (recall == -1) or (precision+recall) <= 0.:\n        f_measure = -1\n    else:\n        f_measure = 2*precision*recall/(precision+recall)\n    accuracy = (right_tag+0.0)/all_tag\n    print(\"gold_num = \", golden_num, \" pred_num = \", predict_num, \" right_num = \", right_num)\n    return accuracy, precision, recall, f_measure\n\n\ndef reverse_style(input_string):\n    target_position = input_string.index('[')\n    input_len = len(input_string)\n    output_string = input_string[target_position:input_len] + input_string[0:target_position]\n    return output_string\n\n\ndef get_ner_BMES(label_list):\n\n    list_len = len(label_list)\n    begin_label = 'B-'\n    end_label = 'E-'\n    single_label = 'S-'\n    
whole_tag = ''\n    index_tag = ''\n    tag_list = []\n    stand_matrix = []\n    for i in range(0, list_len):\n        current_label = label_list[i].upper() if label_list[i] else ''\n        if begin_label in current_label:\n            # B-: close any unfinished entity, then open a new one at position i.\n            if index_tag != '':\n                tag_list.append(whole_tag + ',' + str(i-1))\n            whole_tag = current_label.replace(begin_label,\"\",1) +'[' +str(i)\n            index_tag = current_label.replace(begin_label,\"\",1)\n\n        elif single_label in current_label:\n            # S-: close any unfinished entity, then record a single-character entity.\n            if index_tag != '':\n                tag_list.append(whole_tag + ',' + str(i-1))\n            whole_tag = current_label.replace(single_label,\"\",1) +'[' +str(i)\n            tag_list.append(whole_tag)\n            whole_tag = \"\"\n            index_tag = \"\"\n        elif end_label in current_label:\n            # E-: close the open entity at position i.\n            if index_tag != '':\n                tag_list.append(whole_tag +',' + str(i))\n            whole_tag = ''\n            index_tag = ''\n        else:\n            # M- and O labels neither start nor end an entity.\n            continue\n    if (whole_tag != '') and (index_tag != ''):\n        tag_list.append(whole_tag)\n    tag_list_len = len(tag_list)\n\n    for i in range(0, tag_list_len):\n        if len(tag_list[i]) > 0:\n            tag_list[i] = tag_list[i] + ']'\n            insert_list = reverse_style(tag_list[i])\n            stand_matrix.append(insert_list)\n\n    return stand_matrix\n"
  },
  {
    "path": "utils/word_trie.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author: Yicheng Zou\n# @Last Modified by:   Yicheng Zou,     Contact: yczou18@fudan.edu.cn\n\n_end = \"_end_\"\n\n\nclass Word_Trie:\n    def __init__(self):\n        self.root = dict()\n\n    def recursive_search(self, word_list):\n        # Return every lexicon word (of length >= 2) that is a prefix of word_list, longest first.\n        # Prefixes are sliced, so the caller's list is not modified.\n        match_list = []\n        for end in range(len(word_list), 1, -1):\n            prefix = word_list[:end]\n            if self.search(prefix):\n                match_list.append(\"\".join(prefix))\n        return match_list\n\n    def search(self, word):\n        # True only if the exact character sequence `word` was inserted into the trie.\n        current_dict = self.root\n        for char in word:\n            if char not in current_dict:\n                return False\n            current_dict = current_dict[char]\n        return _end in current_dict\n\n    def insert(self, word):\n        current_dict = self.root\n        for char in word:\n            current_dict = current_dict.setdefault(char, {})\n        current_dict[_end] = _end\n"
  }
]