[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2019 Louis Zhibin Lv\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "NER_BERT_CRF.py",
    "content": "# -*- coding: utf-8 -*-\r\n\r\n# # # #\r\n# NER_BERT_CRF.py\r\n# @author Zhibin.LU\r\n# @created Fri Feb 15 2019 22:47:19 GMT-0500 (EST)\r\n# @last-modified Sun Mar 31 2019 12:17:08 GMT-0400 (EDT)\r\n# @website: https://louis-udm.github.io\r\n# @description: Bert pytorch pretrainde model with or without CRF for NER\r\n# The NER_BERT_CRF.py include 2 model:\r\n# - model 1:\r\n#   - This is just a pretrained BertForTokenClassification, For a comparision with my BERT-CRF model\r\n# - model 2:\r\n#   - A pretrained BERT with CRF model.\r\n# - data set\r\n#   - [CoNLL-2003](https://github.com/FuYanzhe2/Name-Entity-Recognition/tree/master/BERT-BiLSTM-CRF-NER/NERdata)\r\n# # # #\r\n\r\n\r\n# %%\r\nimport sys\r\nimport os\r\nimport time\r\nimport importlib\r\nimport numpy as np\r\nimport matplotlib.pyplot as plt\r\nimport torch\r\nimport torch.nn.functional as F\r\nimport torch.nn as nn\r\nimport torch.autograd as autograd\r\nimport torch.optim as optim\r\n\r\nfrom torch.utils.data.distributed import DistributedSampler\r\nfrom torch.utils import data\r\n\r\nfrom tqdm import tqdm, trange\r\nimport collections\r\n\r\nfrom pytorch_pretrained_bert.modeling import BertModel, BertForTokenClassification, BertLayerNorm\r\nimport pickle\r\nfrom pytorch_pretrained_bert.optimization import BertAdam, warmup_linear\r\nfrom pytorch_pretrained_bert.tokenization import BertTokenizer\r\n\r\ndef set_work_dir(local_path=\"ner_bert_crf\", server_path=\"ner_bert_crf\"):\r\n    if (os.path.exists(os.getenv(\"HOME\")+'/'+local_path)):\r\n        os.chdir(os.getenv(\"HOME\")+'/'+local_path)\r\n    elif (os.path.exists(os.getenv(\"HOME\")+'/'+server_path)):\r\n        os.chdir(os.getenv(\"HOME\")+'/'+server_path)\r\n    else:\r\n        raise Exception('Set work path error!')\r\n\r\n\r\ndef get_data_dir(local_path=\"ner_bert_crf\", server_path=\"ner_bert_crf\"):\r\n    if (os.path.exists(os.getenv(\"HOME\")+'/'+local_path)):\r\n        return os.getenv(\"HOME\")+'/'+local_path\r\n    elif (os.path.exists(os.getenv(\"HOME\")+'/'+server_path)):\r\n        return os.getenv(\"HOME\")+'/'+server_path\r\n    else:\r\n        raise Exception('get data path error!')\r\n\r\n\r\nprint('Python version ', sys.version)\r\nprint('PyTorch version ', torch.__version__)\r\n\r\nset_work_dir()\r\nprint('Current dir:', os.getcwd())\r\n\r\ncuda_yes = torch.cuda.is_available()\r\n# cuda_yes = False\r\nprint('Cuda is available?', cuda_yes)\r\ndevice = torch.device(\"cuda:0\" if cuda_yes else \"cpu\")\r\nprint('Device:', device)\r\n\r\ndata_dir = os.path.join(get_data_dir(), 'NER_data/CoNLL2003/')\r\n# \"Whether to run training.\"\r\ndo_train = True\r\n# \"Whether to run eval on the dev set.\"\r\ndo_eval = True\r\n# \"Whether to run the model in inference mode on the test set.\"\r\ndo_predict = True\r\n# Whether load checkpoint file before train model\r\nload_checkpoint = True\r\n# \"The vocabulary file that the BERT model was trained on.\"\r\nmax_seq_length = 180 #256\r\nbatch_size = 32 #32\r\n# \"The initial learning rate for Adam.\"\r\nlearning_rate0 = 5e-5\r\nlr0_crf_fc = 8e-5\r\nweight_decay_finetune = 1e-5 #0.01\r\nweight_decay_crf_fc = 5e-6 #0.005\r\ntotal_train_epochs = 15\r\ngradient_accumulation_steps = 1\r\nwarmup_proportion = 0.1\r\noutput_dir = './output/'\r\nbert_model_scale = 'bert-base-cased'\r\ndo_lower_case = False\r\n# eval_batch_size = 8\r\n# predict_batch_size = 8\r\n# \"Proportion of training to perform linear learning rate warmup for. \"\r\n# \"E.g., 0.1 = 10% of training.\"\r\n# warmup_proportion = 0.1\r\n# \"How often to save the model checkpoint.\"\r\n# save_checkpoints_steps = 1000\r\n# \"How many steps to make in each estimator call.\"\r\n# iterations_per_loop = 1000\r\n\r\n\r\n# %%\r\n'''\r\nFunctions and Classes for read and organize data set\r\n'''\r\n\r\nclass InputExample(object):\r\n    \"\"\"A single training/test example for NER.\"\"\"\r\n\r\n    def __init__(self, guid, words, labels):\r\n        \"\"\"Constructs a InputExample.\r\n\r\n        Args:\r\n          guid: Unique id for the example(a sentence or a pair of sentences).\r\n          words: list of words of sentence\r\n          labels_a/labels_b: (Optional) string. The label seqence of the text_a/text_b. This should be\r\n            specified for train and dev examples, but not for test examples.\r\n        \"\"\"\r\n        self.guid = guid\r\n        # list of words of the sentence,example: [EU, rejects, German, call, to, boycott, British, lamb .]\r\n        self.words = words\r\n        # list of label sequence of the sentence,like: [B-ORG, O, B-MISC, O, O, O, B-MISC, O, O]\r\n        self.labels = labels\r\n\r\n\r\nclass InputFeatures(object):\r\n    \"\"\"A single set of features of data.\r\n    result of convert_examples_to_features(InputExample)\r\n    \"\"\"\r\n\r\n    def __init__(self, input_ids, input_mask, segment_ids,  predict_mask, label_ids):\r\n        self.input_ids = input_ids\r\n        self.input_mask = input_mask\r\n        self.segment_ids = segment_ids\r\n        self.predict_mask = predict_mask\r\n        self.label_ids = label_ids\r\n\r\n\r\nclass DataProcessor(object):\r\n    \"\"\"Base class for data converters for sequence classification data sets.\"\"\"\r\n\r\n    def get_train_examples(self, data_dir):\r\n        \"\"\"Gets a collection of `InputExample`s for the train set.\"\"\"\r\n        raise NotImplementedError()\r\n\r\n    def get_dev_examples(self, data_dir):\r\n        \"\"\"Gets a collection of `InputExample`s for the dev set.\"\"\"\r\n        raise NotImplementedError()\r\n\r\n    def get_labels(self):\r\n        \"\"\"Gets the list of labels for this data set.\"\"\"\r\n        raise NotImplementedError()\r\n\r\n    @classmethod\r\n    def _read_data(cls, input_file):\r\n        \"\"\"\r\n        Reads a BIO data.\r\n        \"\"\"\r\n        with open(input_file) as f:\r\n            # out_lines = []\r\n            out_lists = []\r\n            entries = f.read().strip().split(\"\\n\\n\")\r\n            for entry in entries:\r\n                words = []\r\n                ner_labels = []\r\n                pos_tags = []\r\n                bio_pos_tags = []\r\n                for line in entry.splitlines():\r\n                    pieces = line.strip().split()\r\n                    if len(pieces) < 1:\r\n                        continue\r\n                    word = pieces[0]\r\n                    # if word == \"-DOCSTART-\" or word == '':\r\n                    #     continue\r\n                    words.append(word)\r\n                    pos_tags.append(pieces[1])\r\n                    bio_pos_tags.append(pieces[2])\r\n                    ner_labels.append(pieces[-1])\r\n                # sentence = ' '.join(words)\r\n                # ner_seq = ' '.join(ner_labels)\r\n                # pos_tag_seq = ' '.join(pos_tags)\r\n                # bio_pos_tag_seq = ' '.join(bio_pos_tags)\r\n                # out_lines.append([sentence, pos_tag_seq, bio_pos_tag_seq, ner_seq])\r\n                # out_lines.append([sentence, ner_seq])\r\n                out_lists.append([words,pos_tags,bio_pos_tags,ner_labels])\r\n        return out_lists\r\n\r\n\r\nclass CoNLLDataProcessor(DataProcessor):\r\n    '''\r\n    CoNLL-2003\r\n    '''\r\n\r\n    def __init__(self):\r\n        self._label_types = [ 'X', '[CLS]', '[SEP]', 'O', 'I-LOC', 'B-PER', 'I-PER', 'I-ORG', 'I-MISC', 'B-MISC', 'B-LOC', 'B-ORG']\r\n        self._num_labels = len(self._label_types)\r\n        self._label_map = {label: i for i,\r\n                           label in enumerate(self._label_types)}\r\n\r\n    def get_train_examples(self, data_dir):\r\n        return self._create_examples(\r\n            self._read_data(os.path.join(data_dir, \"train.txt\")))\r\n\r\n    def get_dev_examples(self, data_dir):\r\n        return self._create_examples(\r\n            self._read_data(os.path.join(data_dir, \"valid.txt\")))\r\n\r\n    def get_test_examples(self, data_dir):\r\n        return self._create_examples(\r\n            self._read_data(os.path.join(data_dir, \"test.txt\")))\r\n\r\n    def get_labels(self):\r\n        return self._label_types\r\n\r\n    def get_num_labels(self):\r\n        return self.get_num_labels\r\n\r\n    def get_label_map(self):\r\n        return self._label_map\r\n\r\n    def get_start_label_id(self):\r\n        return self._label_map['[CLS]']\r\n\r\n    def get_stop_label_id(self):\r\n        return self._label_map['[SEP]']\r\n\r\n    def _create_examples(self, all_lists):\r\n        examples = []\r\n        for (i, one_lists) in enumerate(all_lists):\r\n            guid = i\r\n            words = one_lists[0]\r\n            labels = one_lists[-1]\r\n            examples.append(InputExample(\r\n                guid=guid, words=words, labels=labels))\r\n        return examples\r\n\r\n    def _create_examples2(self, lines):\r\n        examples = []\r\n        for (i, line) in enumerate(lines):\r\n            guid = i\r\n            text = line[0]\r\n            ner_label = line[-1]\r\n            examples.append(InputExample(\r\n                guid=guid, text_a=text, labels_a=ner_label))\r\n        return examples\r\n\r\n\r\ndef example2feature(example, tokenizer, label_map, max_seq_length):\r\n\r\n    add_label = 'X'\r\n    # tokenize_count = []\r\n    tokens = ['[CLS]']\r\n    predict_mask = [0]\r\n    label_ids = [label_map['[CLS]']]\r\n    for i, w in enumerate(example.words):\r\n        # use bertTokenizer to split words\r\n        # 1996-08-22 => 1996 - 08 - 22\r\n        # sheepmeat => sheep ##me ##at\r\n        sub_words = tokenizer.tokenize(w)\r\n        if not sub_words:\r\n            sub_words = ['[UNK]']\r\n        # tokenize_count.append(len(sub_words))\r\n        tokens.extend(sub_words)\r\n        for j in range(len(sub_words)):\r\n            if j == 0:\r\n                predict_mask.append(1)\r\n                label_ids.append(label_map[example.labels[i]])\r\n            else:\r\n                # '##xxx' -> 'X' (see bert paper)\r\n                predict_mask.append(0)\r\n                label_ids.append(label_map[add_label])\r\n\r\n    # truncate\r\n    if len(tokens) > max_seq_length - 1:\r\n        print('Example No.{} is too long, length is {}, truncated to {}!'.format(example.guid, len(tokens), max_seq_length))\r\n        tokens = tokens[0:(max_seq_length - 1)]\r\n        predict_mask = predict_mask[0:(max_seq_length - 1)]\r\n        label_ids = label_ids[0:(max_seq_length - 1)]\r\n    tokens.append('[SEP]')\r\n    predict_mask.append(0)\r\n    label_ids.append(label_map['[SEP]'])\r\n\r\n    input_ids = tokenizer.convert_tokens_to_ids(tokens)\r\n    segment_ids = [0] * len(input_ids)\r\n    input_mask = [1] * len(input_ids)\r\n\r\n    feat=InputFeatures(\r\n                # guid=example.guid,\r\n                # tokens=tokens,\r\n                input_ids=input_ids,\r\n                input_mask=input_mask,\r\n                segment_ids=segment_ids,\r\n                predict_mask=predict_mask,\r\n                label_ids=label_ids)\r\n\r\n    return feat\r\n\r\nclass NerDataset(data.Dataset):\r\n    def __init__(self, examples, tokenizer, label_map, max_seq_length):\r\n        self.examples=examples\r\n        self.tokenizer=tokenizer\r\n        self.label_map=label_map\r\n        self.max_seq_length=max_seq_length\r\n\r\n    def __len__(self):\r\n        return len(self.examples)\r\n\r\n    def __getitem__(self, idx):\r\n        feat=example2feature(self.examples[idx], self.tokenizer, self.label_map, max_seq_length)\r\n        return feat.input_ids, feat.input_mask, feat.segment_ids, feat.predict_mask, feat.label_ids\r\n\r\n    @classmethod\r\n    def pad(cls, batch):\r\n\r\n        seqlen_list = [len(sample[0]) for sample in batch]\r\n        maxlen = np.array(seqlen_list).max()\r\n\r\n        f = lambda x, seqlen: [sample[x] + [0] * (seqlen - len(sample[x])) for sample in batch] # 0: X for padding\r\n        input_ids_list = torch.LongTensor(f(0, maxlen))\r\n        input_mask_list = torch.LongTensor(f(1, maxlen))\r\n        segment_ids_list = torch.LongTensor(f(2, maxlen))\r\n        predict_mask_list = torch.ByteTensor(f(3, maxlen))\r\n        label_ids_list = torch.LongTensor(f(4, maxlen))\r\n\r\n        return input_ids_list, input_mask_list, segment_ids_list, predict_mask_list, label_ids_list\r\n\r\ndef f1_score(y_true, y_pred):\r\n    '''\r\n    0,1,2,3 are [CLS],[SEP],[X],O\r\n    '''\r\n    ignore_id=3\r\n\r\n    num_proposed = len(y_pred[y_pred>ignore_id])\r\n    num_correct = (np.logical_and(y_true==y_pred, y_true>ignore_id)).sum()\r\n    num_gold = len(y_true[y_true>ignore_id])\r\n\r\n    try:\r\n        precision = num_correct / num_proposed\r\n    except ZeroDivisionError:\r\n        precision = 1.0\r\n\r\n    try:\r\n        recall = num_correct / num_gold\r\n    except ZeroDivisionError:\r\n        recall = 1.0\r\n\r\n    try:\r\n        f1 = 2*precision*recall / (precision + recall)\r\n    except ZeroDivisionError:\r\n        if precision*recall==0:\r\n            f1=1.0\r\n        else:\r\n            f1=0\r\n\r\n    return precision, recall, f1\r\n\r\n#%%\r\n'''\r\nPrepare data set\r\n'''\r\n# random.seed(44)\r\nnp.random.seed(44)\r\ntorch.manual_seed(44)\r\nif cuda_yes:\r\n    torch.cuda.manual_seed_all(44)\r\n\r\n# Load pre-trained model tokenizer (vocabulary)\r\nconllProcessor = CoNLLDataProcessor()\r\nlabel_list = conllProcessor.get_labels()\r\nlabel_map = conllProcessor.get_label_map()\r\ntrain_examples = conllProcessor.get_train_examples(data_dir)\r\ndev_examples = conllProcessor.get_dev_examples(data_dir)\r\ntest_examples = conllProcessor.get_test_examples(data_dir)\r\n\r\ntotal_train_steps = int(len(train_examples) / batch_size / gradient_accumulation_steps * total_train_epochs)\r\n\r\nprint(\"***** Running training *****\")\r\nprint(\"  Num examples = %d\"% len(train_examples))\r\nprint(\"  Batch size = %d\"% batch_size)\r\nprint(\"  Num steps = %d\"% total_train_steps)\r\n\r\ntokenizer = BertTokenizer.from_pretrained(bert_model_scale, do_lower_case=do_lower_case)\r\n\r\ntrain_dataset = NerDataset(train_examples,tokenizer,label_map,max_seq_length)\r\ndev_dataset = NerDataset(dev_examples,tokenizer,label_map,max_seq_length)\r\ntest_dataset = NerDataset(test_examples,tokenizer,label_map,max_seq_length)\r\n\r\ntrain_dataloader = data.DataLoader(dataset=train_dataset,\r\n                                batch_size=batch_size,\r\n                                shuffle=True,\r\n                                num_workers=4,\r\n                                collate_fn=NerDataset.pad)\r\n\r\ndev_dataloader = data.DataLoader(dataset=dev_dataset,\r\n                                batch_size=batch_size,\r\n                                shuffle=False,\r\n                                num_workers=4,\r\n                                collate_fn=NerDataset.pad)\r\n\r\ntest_dataloader = data.DataLoader(dataset=test_dataset,\r\n                                batch_size=batch_size,\r\n                                shuffle=False,\r\n                                num_workers=4,\r\n                                collate_fn=NerDataset.pad)\r\n\r\n\r\n#%%\r\n'''\r\n#####  Use only BertForTokenClassification  #####\r\n'''\r\nprint('*** Use only BertForTokenClassification ***')\r\n\r\nif load_checkpoint and os.path.exists(output_dir+'/ner_bert_checkpoint.pt'):\r\n    checkpoint = torch.load(output_dir+'/ner_bert_checkpoint.pt', map_location='cpu')\r\n    start_epoch = checkpoint['epoch']+1\r\n    valid_acc_prev = checkpoint['valid_acc']\r\n    valid_f1_prev = checkpoint['valid_f1']\r\n    model = BertForTokenClassification.from_pretrained(\r\n        bert_model_scale, state_dict=checkpoint['model_state'], num_labels=len(label_list))\r\n    print('Loaded the pretrain NER_BERT model, epoch:',checkpoint['epoch'],'valid acc:',\r\n            checkpoint['valid_acc'], 'valid f1:', checkpoint['valid_f1'])\r\nelse:\r\n    start_epoch = 0\r\n    valid_acc_prev = 0\r\n    valid_f1_prev = 0\r\n    model = BertForTokenClassification.from_pretrained(\r\n        bert_model_scale, num_labels=len(label_list))\r\n\r\nmodel.to(device)\r\n\r\n# Prepare optimizer\r\nnamed_params = list(model.named_parameters())\r\nno_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\r\noptimizer_grouped_parameters = [\r\n    {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay_finetune},\r\n    {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\r\n]\r\noptimizer = BertAdam(optimizer_grouped_parameters, lr=learning_rate0, warmup=warmup_proportion, t_total=total_train_steps)\r\n# optimizer = optim.Adam(model.parameters(), lr=learning_rate0)\r\n\r\ndef evaluate(model, predict_dataloader, batch_size, epoch_th, dataset_name):\r\n    # print(\"***** Running prediction *****\")\r\n    model.eval()\r\n    all_preds = []\r\n    all_labels = []\r\n    total=0\r\n    correct=0\r\n    start = time.time()\r\n    with torch.no_grad():\r\n        for batch in predict_dataloader:\r\n            batch = tuple(t.to(device) for t in batch)\r\n            input_ids, input_mask, segment_ids, predict_mask, label_ids = batch\r\n            out_scores = model(input_ids, segment_ids, input_mask)\r\n            # out_scores = out_scores.detach().cpu().numpy()\r\n            _, predicted = torch.max(out_scores, -1)\r\n            valid_predicted = torch.masked_select(predicted, predict_mask)\r\n            valid_label_ids = torch.masked_select(label_ids, predict_mask)\r\n            # print(len(valid_label_ids),len(valid_predicted),len(valid_label_ids)==len(valid_predicted))\r\n            all_preds.extend(valid_predicted.tolist())\r\n            all_labels.extend(valid_label_ids.tolist())\r\n            total += len(valid_label_ids)\r\n            correct += valid_predicted.eq(valid_label_ids).sum().item()\r\n\r\n    test_acc = correct/total\r\n    precision, recall, f1 = f1_score(np.array(all_labels), np.array(all_preds))\r\n    end = time.time()\r\n    print('Epoch:%d, Acc:%.2f, Precision: %.2f, Recall: %.2f, F1: %.2f on %s, Spend: %.3f minutes for evaluation' \\\r\n        % (epoch_th, 100.*test_acc, 100.*precision, 100.*recall, 100.*f1, dataset_name,(end-start)/60.0))\r\n    print('--------------------------------------------------------------')\r\n    return test_acc, f1\r\n\r\n\r\n#%%\r\n# train procedure using only BertForTokenClassification\r\n# train_start = time.time()\r\nglobal_step_th = int(len(train_examples) / batch_size / gradient_accumulation_steps * start_epoch)\r\n# for epoch in trange(start_epoch, total_train_epochs, desc=\"Epoch\"):\r\nfor epoch in range(start_epoch, total_train_epochs):\r\n    tr_loss = 0\r\n    train_start = time.time()\r\n    model.train()\r\n    optimizer.zero_grad()\r\n    # for step, batch in enumerate(tqdm(train_dataloader, desc=\"Iteration\")):\r\n    for step, batch in enumerate(train_dataloader):\r\n        batch = tuple(t.to(device) for t in batch)\r\n\r\n        input_ids, input_mask, segment_ids, predict_mask, label_ids = batch\r\n        loss = model(input_ids, segment_ids, input_mask, label_ids)\r\n\r\n        if gradient_accumulation_steps > 1:\r\n            loss = loss / gradient_accumulation_steps\r\n\r\n        loss.backward()\r\n        tr_loss += loss.item()\r\n\r\n        if (step + 1) % gradient_accumulation_steps == 0:\r\n            # modify learning rate with special warm up BERT uses\r\n            lr_this_step = learning_rate0 * warmup_linear(global_step_th/total_train_steps, warmup_proportion)\r\n            for param_group in optimizer.param_groups:\r\n                param_group['lr'] = lr_this_step\r\n            optimizer.step()\r\n            optimizer.zero_grad()\r\n            global_step_th += 1\r\n\r\n        print(\"Epoch:{}-{}/{}, CrossEntropyLoss: {} \".format(epoch, step, len(train_dataloader), loss.item()))\r\n\r\n    print('--------------------------------------------------------------')\r\n    print(\"Epoch:{} completed, Total training's Loss: {}, Spend: {}m\".format(epoch, tr_loss, (time.time() - train_start) / 60.0))\r\n    valid_acc, valid_f1 = evaluate(model, dev_dataloader, batch_size, epoch, 'Valid_set')\r\n    # Save a checkpoint\r\n    if valid_f1 > valid_f1_prev:\r\n        # model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self\r\n        torch.save({'epoch': epoch, 'model_state': model.state_dict(), 'valid_acc': valid_acc,\r\n            'valid_f1': valid_f1, 'max_seq_length': max_seq_length, 'lower_case': do_lower_case},\r\n                    os.path.join(output_dir, 'ner_bert_checkpoint.pt'))\r\n        valid_f1_prev = valid_f1\r\n\r\nevaluate(model, test_dataloader, batch_size, total_train_epochs-1, 'Test_set')\r\n\r\n#%%\r\n'''\r\nTest_set prediction using the best epoch of NER_BERT model\r\n'''\r\ncheckpoint = torch.load(output_dir+'/ner_bert_checkpoint.pt', map_location='cpu')\r\nepoch = checkpoint['epoch']\r\nvalid_acc_prev = checkpoint['valid_acc']\r\nvalid_f1_prev = checkpoint['valid_f1']\r\nmodel = BertForTokenClassification.from_pretrained(\r\n    bert_model_scale, state_dict=checkpoint['model_state'], num_labels=len(label_list))\r\n# if os.path.exists(output_dir+'/ner_bert_crf_checkpoint.pt'):\r\nmodel.to(device)\r\nprint('Loaded the pretrain NER_BERT model, epoch:',checkpoint['epoch'],'valid acc:', \r\n        checkpoint['valid_acc'], 'valid f1:', checkpoint['valid_f1'])\r\n\r\nmodel.to(device)\r\n# evaluate(model, train_dataloader, batch_size, total_train_epochs-1, 'Train_set')\r\nevaluate(model, test_dataloader, batch_size, epoch, 'Test_set')\r\n\r\n\r\n#%%\r\n'''\r\n#####  Use BertModel + CRF  #####\r\nCRF is for transition and the maximum likelyhood estimate(MLE).\r\nBert is for latent label -> Emission of word embedding.\r\n'''\r\nprint('*** Use BertModel + CRF ***')\r\n\r\ndef log_sum_exp_1vec(vec):  # shape(1,m)\r\n    max_score = vec[0, np.argmax(vec)]\r\n    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])\r\n    return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))\r\n\r\ndef log_sum_exp_mat(log_M, axis=-1):  # shape(n,m)\r\n    return torch.max(log_M, axis)[0]+torch.log(torch.exp(log_M-torch.max(log_M, axis)[0][:, None]).sum(axis))\r\n\r\ndef log_sum_exp_batch(log_Tensor, axis=-1): # shape (batch_size,n,m)\r\n    return torch.max(log_Tensor, axis)[0]+torch.log(torch.exp(log_Tensor-torch.max(log_Tensor, axis)[0].view(log_Tensor.shape[0],-1,1)).sum(axis))\r\n\r\n\r\nclass BERT_CRF_NER(nn.Module):\r\n\r\n    def __init__(self, bert_model, start_label_id, stop_label_id, num_labels, max_seq_length, batch_size, device):\r\n        super(BERT_CRF_NER, self).__init__()\r\n        self.hidden_size = 768\r\n        self.start_label_id = start_label_id\r\n        self.stop_label_id = stop_label_id\r\n        self.num_labels = num_labels\r\n        # self.max_seq_length = max_seq_length\r\n        self.batch_size = batch_size\r\n        self.device=device\r\n\r\n        # use pretrainded BertModel\r\n        self.bert = bert_model\r\n        self.dropout = torch.nn.Dropout(0.2)\r\n        # Maps the output of the bert into label space.\r\n        self.hidden2label = nn.Linear(self.hidden_size, self.num_labels)\r\n\r\n        # Matrix of transition parameters.  Entry i,j is the score of transitioning *to* i *from* j.\r\n        self.transitions = nn.Parameter(\r\n            torch.randn(self.num_labels, self.num_labels))\r\n\r\n        # These two statements enforce the constraint that we never transfer *to* the start tag(or label),\r\n        # and we never transfer *from* the stop label (the model would probably learn this anyway,\r\n        # so this enforcement is likely unimportant)\r\n        self.transitions.data[start_label_id, :] = -10000\r\n        self.transitions.data[:, stop_label_id] = -10000\r\n\r\n        nn.init.xavier_uniform_(self.hidden2label.weight)\r\n        nn.init.constant_(self.hidden2label.bias, 0.0)\r\n        # self.apply(self.init_bert_weights)\r\n\r\n    def init_bert_weights(self, module):\r\n        \"\"\" Initialize the weights.\r\n        \"\"\"\r\n        if isinstance(module, (nn.Linear, nn.Embedding)):\r\n            # Slightly different from the TF version which uses truncated_normal for initialization\r\n            # cf https://github.com/pytorch/pytorch/pull/5617\r\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\r\n        elif isinstance(module, BertLayerNorm):\r\n            module.bias.data.zero_()\r\n            module.weight.data.fill_(1.0)\r\n        if isinstance(module, nn.Linear) and module.bias is not None:\r\n            module.bias.data.zero_()\r\n\r\n    def _forward_alg(self, feats):\r\n        '''\r\n        this also called alpha-recursion or forward recursion, to calculate log_prob of all barX\r\n        '''\r\n\r\n        # T = self.max_seq_length\r\n        T = feats.shape[1]\r\n        batch_size = feats.shape[0]\r\n\r\n        # alpha_recursion,forward, alpha(zt)=p(zt,bar_x_1:t)\r\n        log_alpha = torch.Tensor(batch_size, 1, self.num_labels).fill_(-10000.).to(self.device)\r\n        # normal_alpha_0 : alpha[0]=Ot[0]*self.PIs\r\n        # self.start_label has all of the score. it is log,0 is p=1\r\n        log_alpha[:, 0, self.start_label_id] = 0\r\n\r\n        # feats: sentances -> word embedding -> lstm -> MLP -> feats\r\n        # feats is the probability of emission, feat.shape=(1,tag_size)\r\n        for t in range(1, T):\r\n            log_alpha = (log_sum_exp_batch(self.transitions + log_alpha, axis=-1) + feats[:, t]).unsqueeze(1)\r\n\r\n        # log_prob of all barX\r\n        log_prob_all_barX = log_sum_exp_batch(log_alpha)\r\n        return log_prob_all_barX\r\n\r\n    def _get_bert_features(self, input_ids, segment_ids, input_mask):\r\n        '''\r\n        sentances -> word embedding -> lstm -> MLP -> feats\r\n        '''\r\n        bert_seq_out, _ = self.bert(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, output_all_encoded_layers=False)\r\n        bert_seq_out = self.dropout(bert_seq_out)\r\n        bert_feats = self.hidden2label(bert_seq_out)\r\n        return bert_feats\r\n\r\n    def _score_sentence(self, feats, label_ids):\r\n        '''\r\n        Gives the score of a provided label sequence\r\n        p(X=w1:t,Zt=tag1:t)=...p(Zt=tag_t|Zt-1=tag_t-1)p(xt|Zt=tag_t)...\r\n        '''\r\n\r\n        # T = self.max_seq_length\r\n        T = feats.shape[1]\r\n        batch_size = feats.shape[0]\r\n\r\n        batch_transitions = self.transitions.expand(batch_size,self.num_labels,self.num_labels)\r\n        batch_transitions = batch_transitions.flatten(1)\r\n\r\n        score = torch.zeros((feats.shape[0],1)).to(device)\r\n        # the 0th node is start_label->start_word,\bthe probability of them=1. so t begin with 1.\r\n        for t in range(1, T):\r\n            score = score + \\\r\n                batch_transitions.gather(-1, (label_ids[:, t]*self.num_labels+label_ids[:, t-1]).view(-1,1)) \\\r\n                    + feats[:, t].gather(-1, label_ids[:, t].view(-1,1)).view(-1,1)\r\n        return score\r\n\r\n    def _viterbi_decode(self, feats):\r\n        '''\r\n        Max-Product Algorithm or viterbi algorithm, argmax(p(z_0:t|x_0:t))\r\n        '''\r\n\r\n        # T = self.max_seq_length\r\n        T = feats.shape[1]\r\n        batch_size = feats.shape[0]\r\n\r\n        # batch_transitions=self.transitions.expand(batch_size,self.num_labels,self.num_labels)\r\n\r\n        log_delta = torch.Tensor(batch_size, 1, self.num_labels).fill_(-10000.).to(self.device)\r\n        log_delta[:, 0, self.start_label_id] = 0\r\n\r\n        # psi is for the vaule of the last latent that make P(this_latent) maximum.\r\n        psi = torch.zeros((batch_size, T, self.num_labels), dtype=torch.long).to(self.device)  # psi[0]=0000 useless\r\n        for t in range(1, T):\r\n            # delta[t][k]=max_z1:t-1( p(x1,x2,...,xt,z1,z2,...,zt-1,zt=k|theta) )\r\n            # delta[t] is the max prob of the path from  z_t-1 to z_t[k]\r\n            log_delta, psi[:, t] = torch.max(self.transitions + log_delta, -1)\r\n            # psi[t][k]=argmax_z1:t-1( p(x1,x2,...,xt,z1,z2,...,zt-1,zt=k|theta) )\r\n            # psi[t][k] is the path choosed from z_t-1 to z_t[k],the value is the z_state(is k) index of z_t-1\r\n            log_delta = (log_delta + feats[:, t]).unsqueeze(1)\r\n\r\n        # trace back\r\n        path = torch.zeros((batch_size, T), dtype=torch.long).to(self.device)\r\n\r\n        # max p(z1:t,all_x|theta)\r\n        max_logLL_allz_allx, path[:, -1] = torch.max(log_delta.squeeze(), -1)\r\n\r\n        for t in range(T-2, -1, -1):\r\n            # choose the state of z_t according the state choosed of z_t+1.\r\n            path[:, t] = psi[:, t+1].gather(-1,path[:, t+1].view(-1,1)).squeeze()\r\n\r\n        return max_logLL_allz_allx, path\r\n\r\n    def neg_log_likelihood(self, input_ids, segment_ids, input_mask, label_ids):\r\n        bert_feats = self._get_bert_features(input_ids, segment_ids, input_mask)\r\n        forward_score = self._forward_alg(bert_feats)\r\n        # p(X=w1:t,Zt=tag1:t)=...p(Zt=tag_t|Zt-1=tag_t-1)p(xt|Zt=tag_t)...\r\n        gold_score = self._score_sentence(bert_feats, label_ids)\r\n        # - log[ p(X=w1:t,Zt=tag1:t)/p(X=w1:t) ] = - log[ p(Zt=tag1:t|X=w1:t) ]\r\n        return torch.mean(forward_score - gold_score)\r\n\r\n    # this forward is just for predict, not for train\r\n    # dont confuse this with _forward_alg above.\r\n    def forward(self, input_ids, segment_ids, input_mask):\r\n        # Get the emission scores from the BiLSTM\r\n        bert_feats = self._get_bert_features(input_ids, segment_ids, input_mask)\r\n\r\n        # Find the best path, given the features.\r\n        score, label_seq_ids = self._viterbi_decode(bert_feats)\r\n        return score, label_seq_ids\r\n\r\n\r\nstart_label_id = conllProcessor.get_start_label_id()\r\nstop_label_id = conllProcessor.get_stop_label_id()\r\n\r\nbert_model = BertModel.from_pretrained(bert_model_scale)\r\nmodel = BERT_CRF_NER(bert_model, start_label_id, stop_label_id, len(label_list), max_seq_length, batch_size, device)\r\n\r\n#%%\r\nif load_checkpoint and os.path.exists(output_dir+'/ner_bert_crf_checkpoint.pt'):\r\n    checkpoint = torch.load(output_dir+'/ner_bert_crf_checkpoint.pt', map_location='cpu')\r\n    start_epoch = checkpoint['epoch']+1\r\n    valid_acc_prev = checkpoint['valid_acc']\r\n    valid_f1_prev = checkpoint['valid_f1']\r\n    pretrained_dict=checkpoint['model_state']\r\n    net_state_dict = model.state_dict()\r\n    pretrained_dict_selected = {k: v for k, v in pretrained_dict.items() if k in net_state_dict}\r\n    net_state_dict.update(pretrained_dict_selected)\r\n    model.load_state_dict(net_state_dict)\r\n    print('Loaded the pretrain NER_BERT_CRF model, epoch:',checkpoint['epoch'],'valid acc:',\r\n            checkpoint['valid_acc'], 'valid f1:', checkpoint['valid_f1'])\r\nelse:\r\n    start_epoch = 0\r\n    valid_acc_prev = 0\r\n    valid_f1_prev = 0\r\n\r\nmodel.to(device)\r\n\r\n# Prepare optimizer\r\nparam_optimizer = list(model.named_parameters())\r\n\r\nno_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\r\nnew_param = ['transitions', 'hidden2label.weight', 'hidden2label.bias']\r\noptimizer_grouped_parameters = [\r\n    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) \\\r\n        and not any(nd in n for nd in new_param)], 'weight_decay': weight_decay_finetune},\r\n    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) \\\r\n        and not any(nd in n for nd in new_param)], 'weight_decay': 0.0},\r\n    {'params': [p for n, p in param_optimizer if n in ('transitions','hidden2label.weight')] \\\r\n        , 'lr':lr0_crf_fc, 'weight_decay': weight_decay_crf_fc},\r\n    {'params': [p for n, p in param_optimizer if n == 'hidden2label.bias'] \\\r\n        , 'lr':lr0_crf_fc, 'weight_decay': 0.0}\r\n]\r\noptimizer = BertAdam(optimizer_grouped_parameters, lr=learning_rate0, warmup=warmup_proportion, t_total=total_train_steps)\r\n# optimizer = optim.Adam(model.parameters(), lr=learning_rate0)\r\n\r\ndef warmup_linear(x, warmup=0.002):\r\n    if x < warmup:\r\n        return x/warmup\r\n    return 1.0 - x\r\n\r\ndef evaluate(model, predict_dataloader, batch_size, epoch_th, dataset_name):\r\n    # print(\"***** Running prediction *****\")\r\n    model.eval()\r\n    all_preds = []\r\n    all_labels = []\r\n    total=0\r\n    correct=0\r\n    start = time.time()\r\n    with torch.no_grad():\r\n        for batch in predict_dataloader:\r\n            batch = tuple(t.to(device) for t in batch)\r\n            input_ids, input_mask, segment_ids, predict_mask, label_ids = batch\r\n            _, predicted_label_seq_ids = model(input_ids, segment_ids, input_mask)\r\n            # _, predicted = torch.max(out_scores, -1)\r\n            valid_predicted = torch.masked_select(predicted_label_seq_ids, predict_mask)\r\n            valid_label_ids = torch.masked_select(label_ids, predict_mask)\r\n            all_preds.extend(valid_predicted.tolist())\r\n            all_labels.extend(valid_label_ids.tolist())\r\n            # print(len(valid_label_ids),len(valid_predicted),len(valid_label_ids)==len(valid_predicted))\r\n            total += len(valid_label_ids)\r\n            correct += valid_predicted.eq(valid_label_ids).sum().item()\r\n\r\n    test_acc = correct/total\r\n    precision, recall, f1 = f1_score(np.array(all_labels), np.array(all_preds))\r\n    end = time.time()\r\n    print('Epoch:%d, Acc:%.2f, Precision: %.2f, Recall: %.2f, F1: %.2f on %s, Spend:%.3f minutes for evaluation' \\\r\n        % (epoch_th, 100.*test_acc, 100.*precision, 100.*recall, 100.*f1, dataset_name,(end-start)/60.0))\r\n    print('--------------------------------------------------------------')\r\n    return test_acc, f1\r\n\r\n#%%\r\n# train procedure\r\nglobal_step_th = int(len(train_examples) / batch_size / gradient_accumulation_steps * start_epoch)\r\n\r\n# train_start=time.time()\r\n# for epoch in trange(start_epoch, total_train_epochs, desc=\"Epoch\"):\r\nfor epoch in range(start_epoch, total_train_epochs):\r\n    tr_loss = 0\r\n    train_start = time.time()\r\n    model.train()\r\n    optimizer.zero_grad()\r\n    # for step, batch in enumerate(tqdm(train_dataloader, desc=\"Iteration\")):\r\n    for step, batch in enumerate(train_dataloader):\r\n        batch = tuple(t.to(device) for t in batch)\r\n        input_ids, input_mask, segment_ids, predict_mask, label_ids = batch\r\n\r\n        neg_log_likelihood = model.neg_log_likelihood(input_ids, segment_ids, input_mask, label_ids)\r\n\r\n        if gradient_accumulation_steps > 1:\r\n            neg_log_likelihood = neg_log_likelihood / gradient_accumulation_steps\r\n\r\n        neg_log_likelihood.backward()\r\n\r\n        tr_loss += neg_log_likelihood.item()\r\n\r\n        if (step + 1) % gradient_accumulation_steps == 0:\r\n            # modify learning rate with special warm up BERT uses\r\n            lr_this_step = learning_rate0 * warmup_linear(global_step_th/total_train_steps, warmup_proportion)\r\n            for param_group in optimizer.param_groups:\r\n                param_group['lr'] = lr_this_step\r\n            optimizer.step()\r\n            optimizer.zero_grad()\r\n            global_step_th += 1\r\n\r\n        print(\"Epoch:{}-{}/{}, Negative loglikelihood: {} \".format(epoch, step, len(train_dataloader), neg_log_likelihood.item()))\r\n\r\n    print('--------------------------------------------------------------')\r\n    print(\"Epoch:{} completed, Total training's Loss: {}, Spend: {}m\".format(epoch, tr_loss, (time.time() - train_start)/60.0))\r\n    valid_acc, valid_f1 = evaluate(model, dev_dataloader, batch_size, epoch, 'Valid_set')\r\n\r\n    # Save a checkpoint\r\n    if valid_f1 > valid_f1_prev:\r\n        # model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self\r\n        torch.save({'epoch': epoch, 'model_state': model.state_dict(), 'valid_acc': valid_acc,\r\n            'valid_f1': valid_f1, 'max_seq_length': max_seq_length, 'lower_case': do_lower_case},\r\n                    os.path.join(output_dir, 'ner_bert_crf_checkpoint.pt'))\r\n        valid_f1_prev = valid_f1\r\n\r\nevaluate(model, test_dataloader, batch_size, total_train_epochs-1, 'Test_set')\r\n\r\n\r\n#%%\r\n'''\r\nTest_set prediction using the best epoch of NER_BERT_CRF model\r\n'''\r\ncheckpoint = torch.load(output_dir+'/ner_bert_crf_checkpoint.pt', map_location='cpu')\r\nepoch = checkpoint['epoch']\r\nvalid_acc_prev = checkpoint['valid_acc']\r\nvalid_f1_prev = checkpoint['valid_f1']\r\npretrained_dict=checkpoint['model_state']\r\nnet_state_dict = model.state_dict()\r\npretrained_dict_selected = {k: v for k, v in pretrained_dict.items() if k in net_state_dict}\r\nnet_state_dict.update(pretrained_dict_selected)\r\nmodel.load_state_dict(net_state_dict)\r\nprint('Loaded the pretrain  NER_BERT_CRF  model, epoch:',checkpoint['epoch'],'valid acc:',\r\n      checkpoint['valid_acc'], 'valid f1:', checkpoint['valid_f1'])\r\n\r\nmodel.to(device)\r\n#evaluate(model, train_dataloader, batch_size, total_train_epochs-1, 'Train_set')\r\nevaluate(model, test_dataloader, batch_size, epoch, 'Test_set')\r\n# print('Total spend:',(time.time()-train_start)/60.0)\r\n\r\n\r\n#%%\r\nmodel.eval()\r\nwith torch.no_grad():\r\n    demon_dataloader = data.DataLoader(dataset=test_dataset,\r\n                                batch_size=10,\r\n                                shuffle=False,\r\n                                num_workers=4,\r\n                                collate_fn=pad)\r\n    for batch in demon_dataloader:\r\n        batch = tuple(t.to(device) for t in batch)\r\n        input_ids, input_mask, segment_ids, predict_mask, label_ids = batch\r\n        _, predicted_label_seq_ids = model(input_ids, segment_ids, input_mask)\r\n        # _, predicted = torch.max(out_scores, -1)\r\n        valid_predicted = torch.masked_select(predicted_label_seq_ids, predict_mask)\r\n        # valid_label_ids = torch.masked_select(label_ids, predict_mask)\r\n        for i in range(10):\r\n            print(predicted_label_seq_ids[i])\r\n            print(label_ids[i])\r\n            new_ids=predicted_label_seq_ids[i].cpu().numpy()[predict_mask[i].cpu().numpy()==1]\r\n            print(list(map(lambda i: label_list[i], new_ids)))\r\n            print(test_examples[i].labels)\r\n        break\r\n#%%\r\nprint(conllProcessor.get_label_map())\r\n# print(test_examples[8].words)\r\n# print(test_features[8].label_ids)\r\n"
  },
  {
    "path": "README.md",
    "content": "# NER implementation with BERT and CRF model\n> Zhibin Lu\n\nThis is a named entity recognizer based on [BERT Model(pytorch-pretrained-BERT)](https://github.com/huggingface/pytorch-pretrained-BERT) and CRF.\n\nSomeone construct model with BERT, LSTM and CRF, like this [BERT-BiLSTM-CRF-NER](https://github.com/FuYanzhe2/Name-Entity-Recognition/tree/master/BERT-BiLSTM-CRF-NER), but in theory, the BERT mechanism has replaced the role of LSTM, so I think LSTM is redundant.\n\nFor the performance, BERT+CRF is always a little better than single BERT in my experience.\n\n## Requirements\n- python 3.6\n- pytorch 1.0.0\n- [pytorch-pretrained-bert 0.4.0](https://github.com/huggingface/transformers/releases/tag/v0.4.0)\n## Overview\nThe NER_BERT_CRF.py include 2 model:\n- model 1:\n  - This is just a pretrained BertForTokenClassification, For a comparision with my BERT-CRF model\n- model 2:\n  - A pretrained BERT with CRF model.\n- data set\n  - [CoNLL-2003](https://github.com/Franck-Dernoncourt/NeuroNER/tree/master/neuroner/data/conll2003/en)\n### Parameters\n- NER_labels = ['X', '[CLS]', '[SEP]', 'O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']\n- max_seq_length = 180\n- batch_size = 32\n- learning_rate = 5e-5\n- weight_decay = 1e-5\n- learning_rate for CRF and FC: 8e-5 \n- weight_decay for CRF and FC: 5e-6\n- total_train_epochs = 20\n- bert_model_scale = 'bert-base-cased'\n- do_lower_case = False\n### Performance\n- [Bert paper](https://arxiv.org/abs/1810.04805)\n  - F1-Score on valid data: 96.4 %\n  - F1-Score on test data: 92.4 %\n- BertForTokenClassification (epochs = 15)\n  - Accuracy on valid data: 99.10 %\n  - Accuracy on test data: 98.11 %\n  - F1-Score on valid data: 96.18 %\n  - F1-Score on test data: 92.17 %\n- Bert+CRF (epochs = 16)\n  - Accuracy on valid data: 99.10 %\n  - Accuracy of test data: 98.14 % \n  - F1-Score on valid data: 96.23 %\n  - F1-Score on test data: 92.29 %\n### References\n- [Bert paper](https://arxiv.org/abs/1810.04805)\n- [Bert with PyTorch implementation](https://github.com/huggingface/pytorch-pretrained-BERT)\n- [ericput/Bert-ner](https://github.com/ericput/bert-ner)\n- [CoNLL-2003 data set](https://github.com/Franck-Dernoncourt/NeuroNER/tree/master/neuroner/data/conll2003/en)\n- [Kyubyong/bert_ner](https://github.com/Kyubyong/bert_ner)\n"
  }
]