[
  {
    "path": ".gitignore",
    "content": "*.pt\n*.amat\n*.mat\n*.out\n*.out~\n*.pyc\n*.pt~\n.gitignore~\n*.sh\n*.sh~\n*.py~\n*.json\n*.json~\n*.model\n*.h5\n*.tar.gz\n*.hdf5\n*.dict\n*.pkl\n"
  },
  {
    "path": "COLLINS.prm",
    "content": "##------------------------------------------##\n## Debug mode                               ##\n##   0: No debugging                        ##\n##   1: print data for individual sentence  ##\n##------------------------------------------##\nDEBUG 0\n\n##------------------------------------------##\n## MAX error                                ##\n##    Number of errors to stop the process. ##\n##    This is useful if there could be      ##\n##    tokenization errors.                  ##\n##    The process will stop when this number##\n##    of errors is accumulated.             ##\n##------------------------------------------##\nMAX_ERROR 10\n\n##------------------------------------------##\n## Cut-off length for statistics            ##\n##    At the end of evaluation, the         ##\n##    statistics for the sentences of length##\n##    less than or equal to this number will##\n##    be shown, on top of the statistics    ##\n##    for all the sentences                 ##\n##------------------------------------------##\nCUTOFF_LEN 10\n\n##------------------------------------------##\n## unlabeled or labeled bracketing          ##\n##    0: unlabeled bracketing               ##\n##    1: labeled bracketing                 ##\n##------------------------------------------##\nLABELED 0\n\n##------------------------------------------##\n## Delete labels                            ##\n##    list of labels to be ignored.         ##\n##    If it is a pre-terminal label, delete ##\n##    the word along with the brackets.     ##\n##    If it is a non-terminal label, just   ##\n##    delete the brackets (don't delete     ##\n##    children).                            ##\n##------------------------------------------##\nDELETE_LABEL TOP\nDELETE_LABEL -NONE-\nDELETE_LABEL ,\nDELETE_LABEL :\nDELETE_LABEL ``\nDELETE_LABEL ''\nDELETE_LABEL .\n\n##------------------------------------------##\n## Delete labels for length calculation     ##\n##    list of labels to be ignored for      ##\n##    length calculation purpose            ##\n##------------------------------------------##\nDELETE_LABEL_FOR_LENGTH -NONE-\n\n##------------------------------------------##\n## Equivalent labels, words                 ##\n##     the pairs are considered equivalent  ##\n##     This is non-directional.             ##\n##------------------------------------------##\nEQ_LABEL ADVP PRT\n\n# EQ_WORD  Example example\n"
  },
  {
    "path": "README.md",
    "content": "# Unsupervised Recurrent Neural Network Grammars\n\nThis is an implementation of the paper:  \n[Unsupervised Recurrent Neural Network Grammars](https://arxiv.org/abs/1904.03746)  \nYoon Kim, Alexander Rush, Lei Yu, Adhiguna Kuncoro, Chris Dyer, Gabor Melis  \nNAACL 2019  \n\n## Dependencies\nThe code was tested in `python 3.6` and `pytorch 1.0`.\n\n## Data  \nSample train/val/test data is in the `data/` folder. These are the standard datasets from PTB.\nFirst preprocess the data:\n```\npython preprocess.py --trainfile data/train.txt --valfile data/valid.txt --testfile data/test.txt \n--outputfile data/ptb --vocabminfreq 1 --lowercase 0 --replace_num 0 --batchsize 16\n```\nRunning this will save the following files in the `data/` folder: `ptb-train.pkl`, `ptb-val.pkl`,\n`ptb-test.pkl`, `ptb.dict`. Here `ptb.dict` is the word-idx mapping, and you can change the\noutput folder/name by changing the argument to `outputfile`. Also, the preprocessing here\nwill replace singletons with a single `<unk>` rather than with Berkeley parser's mapping rules\n(see below for results using this setup).\n\n## Training\nTo train the URNNG:\n```\npython train.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --save_path urnng.pt \n--mode unsupervised --gpu 0\n```\nwhere `save_path` is where you want to save the model, and `gpu 0` is for using the first GPU\nin the cluster (the mapping from PyTorch GPU index to your cluster's GPU index may vary).\nTraining should take 2 to 3 days depending on your setup.\n\nTo train the RNNG:\n```\npython train.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --save_path rnng.pt \n--mode supervised --train_q_epochs 18 --gpu 0 \n```\n\nFor fine-tuning:\n```\npython train.py --train_from rnng.pt --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl \n--save_path rnng-urnng.pt --mode unsupervised --lr 0.1 --train_q_epochs 10 --epochs 10 \n--min_epochs 6 --gpu 0 --kl_warmup 0\n```\n\nTo train the 
LM:\n```\npython train_lm.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl \n--test_file data/ptb-test.pkl --save_path lm.pt \n```\n\n## Evaluation\nTo evaluate perplexity with importance sampling on the test set:\n```\npython eval_ppl.py --model_file urnng.pt --test_file data/ptb-test.pkl --samples 1000 \n--is_temp 2 --gpu 0\n```\nThe argument `samples` is for the number of importance weighted samples, and `is_temp` is for\nflattening the inference network's distribution (footnote 14 in the paper).\nThe same evaluation code will work for RNNG. \n\nFor LM evaluation:\n```\npython train_lm.py --train_from lm.pt --test_file data/ptb-test.pkl --test 1\n```\n\nTo evaluate F1, first we need to parse the test set:\n```\npython parse.py --model_file urnng.pt --data_file data/ptb-test.txt --out_file pred-parse.txt \n--gold_out_file gold-parse.txt --gpu 0\n```\nThis will output the predicted parse trees into `pred-parse.txt`. We also output a version\nof the gold parse `gold-parse.txt` to be used as input for `evalb`, since sentences with only trivial spans are ignored by `parse.py`. Note that the corpus/sentence F1 results printed here do not correspond to the results reported in the paper, since these scores do not ignore punctuation. \n\nFinally, download/install `evalb`, available [here](https://nlp.cs.nyu.edu/evalb).\nThen run:\n```\nevalb -p COLLINS.prm gold-parse.txt pred-parse.txt\n```\nwhere `COLLINS.prm` is the parameter file (provided in this repo) that tells `evalb` to ignore\npunctuation and evaluate on unlabeled F1.\n\n## Note Regarding Preprocessing\nNote that some of the details of the preprocessing are slightly different from the original \npaper. In particular, in this implementation we replace singleton words with a single `<unk>` token\ninstead of using Berkeley parser's mapping rules. This results in slightly lower perplexity\nfor all models, since the vocabulary size is smaller. Here are the perplexity numbers I get\nin this setting:\n\n- RNNLM: 89.2 \n- RNNG: 83.7 \n- URNNG: 85.1 (F1: 38.4) \n- RNNG --> URNNG: 82.5\n\n## Acknowledgements\nSome of our preprocessing and evaluation code is based on the following repositories:  \n- [Recurrent Neural Network Grammars](https://github.com/clab/rnng)  \n- [Parsing Reading Predict Network](https://github.com/yikangshen/PRPN)  \n\n## License\nMIT"
  },
  {
    "path": "TreeCRF.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport itertools\nimport utils\nimport random\n  \nclass ConstituencyTreeCRF(nn.Module):\n  def __init__(self):\n    super(ConstituencyTreeCRF, self).__init__()\n    self.huge = 1e9\n\n  def logadd(self, x, y):\n    d = torch.max(x,y)  \n    return torch.log(torch.exp(x-d) + torch.exp(y-d)) + d    \n\n  def logsumexp(self, x, dim=1):\n    d = torch.max(x, dim)[0]\n    return torch.log(torch.exp(x - d.unsqueeze(dim).expand_as(x)).sum(dim)) + d\n\n  def _init_table(self, scores):\n    # initialize dynamic programming table\n    batch_size = scores.size(0)\n    n = scores.size(1)\n    self.alpha = [[scores.new(batch_size).fill_(-self.huge) for _ in range(n)] for _ in range(n)]\n\n  def _forward(self, scores):\n    #inside step\n    batch_size = scores.size(0)\n    n = scores.size(1)\n    self._init_table(scores)\n    for i in range(n):\n      self.alpha[i][i] = scores[:, i, i]\n    for k in np.arange(1, n+1):\n      for s in range(n):\n        t = s + k\n        if t > n-1:\n          break\n        tmp = [self.alpha[s][u] + self.alpha[u+1][t] + scores[:, s, t] for u in np.arange(s,t)]\n        tmp = torch.stack(tmp, 1)\n        self.alpha[s][t] = self.logsumexp(tmp, 1)\n            \n  def _backward(self, scores):\n    #outside step\n    batch_size = scores.size(0)\n    n = scores.size(1)\n    self.beta = [[None for _ in range(n)] for _ in range(n)]\n    self.beta[0][n-1] = scores.new(batch_size).fill_(0)\n    for k in np.arange(n-1, 0, -1):\n      for s in range(n):\n        t = s + k\n        if t > n-1:\n          break\n        for u in np.arange(s, t):                    \n          if s < u+1:\n            tmp = self.beta[s][t] + self.alpha[u+1][t] + scores[:, s, t]\n            if self.beta[s][u] is None:\n              self.beta[s][u] = tmp\n            else:\n              self.beta[s][u] = self.logadd(self.beta[s][u], tmp)\n    
      if u+1 < t+1:\n            tmp =  self.beta[s][t] + self.alpha[s][u]  + scores[:, s, t]\n            if self.beta[u+1][t] is None:\n              self.beta[u+1][t] = tmp\n            else:\n              self.beta[u+1][t] = self.logadd(self.beta[u+1][t], tmp)\n\n  def _marginal(self, scores):\n    batch_size = scores.size(0)\n    n = scores.size(1)\n    self.log_marginal = [[None for _ in range(n)] for _ in range(n)]\n    log_Z = self.alpha[0][n-1]\n    for s in range(n):\n      for t in np.arange(s, n):\n        self.log_marginal[s][t] = self.alpha[s][t] + self.beta[s][t] - log_Z\n  \n  def _entropy(self, scores):\n    batch_size = scores.size(0)\n    n = scores.size(1)\n    self.entropy = [[None for _ in range(n)] for _ in range(n)]\n    for i in range(n):\n      self.entropy[i][i] = scores.new(batch_size).fill_(0)\n    for k in np.arange(1, n+1):\n      for s in range(n):\n        t = s + k\n        if t > n-1:\n          break\n        score = []\n        prev_ent = []\n        for u in np.arange(s, t):\n          score.append(self.alpha[s][u] + self.alpha[u+1][t])\n          prev_ent.append(self.entropy[s][u] + self.entropy[u+1][t])\n        score = torch.stack(score, 1) \n        prev_ent = torch.stack(prev_ent, 1)\n        log_prob = F.log_softmax(score, dim = 1)\n        prob = log_prob.exp()        \n        entropy = ((prev_ent - log_prob)*prob).sum(1)\n        self.entropy[s][t] = entropy\n      \n        \n  def _sample(self, scores, alpha = None, argmax = False):    \n    # sample from p(tree | sent)\n    # also get the spans\n    if alpha is None:\n      self._forward(scores)\n      alpha = self.alpha\n    batch_size = scores.size(0)\n    n = scores.size(1)\n    tree = scores.new(batch_size, n, n).zero_()\n    all_log_probs = []\n    tree_brackets = []\n    spans = []\n    for b in range(batch_size):\n      sampled = [(0, n-1)]\n      span = [(0, n-1)]\n      queue = [(0, n-1)] #start, end\n      log_probs = []\n      tree_str = get_span_str(0, 
n-1)\n      while len(queue) > 0:\n        node = queue.pop(0)\n        start, end = node\n        left_parent = get_span_str(start, None)\n        right_parent = get_span_str(None, end)\n        score = []\n        score_idx = []\n        for u in np.arange(start, end):\n          score.append(alpha[start][u][b] + alpha[u+1][end][b])\n          score_idx.append([(start, u), (u+1, end)])\n        score = torch.stack(score, 0) \n        log_prob = F.log_softmax(score, dim = 0)\n        if argmax:\n          sample = torch.max(log_prob, 0)[1]\n        else:\n          prob = log_prob.exp()\n          sample = torch.multinomial(log_prob.exp(), 1)          \n        sample_idx = score_idx[sample.item()]\n        log_probs.append(log_prob[sample.item()])\n        for idx in sample_idx:\n          if idx[0] != idx[1]:\n            queue.append(idx)\n            span.append(idx)\n          sampled.append(idx)\n        left_child = '(' + get_span_str(sample_idx[0][0], sample_idx[0][1])    \n        right_child = get_span_str(sample_idx[1][0], sample_idx[1][1]) + ')'\n        if sample_idx[0][0] != sample_idx[0][1]:\n          tree_str = tree_str.replace(left_parent, left_child)\n        if sample_idx[1][0] != sample_idx[1][1]:\n          tree_str = tree_str.replace(right_parent, right_child)\n      all_log_probs.append(torch.stack(log_probs, 0).sum(0))\n      tree_brackets.append(tree_str)\n      spans.append(span[::-1])\n      for idx in sampled:\n        tree[b][idx[0]][idx[1]] = 1\n        \n    all_log_probs = torch.stack(all_log_probs, 0)\n    return tree, all_log_probs, tree_brackets, spans\n\n  def _viterbi(self, scores):\n    # cky algorithm\n    batch_size = scores.size(0)\n    n = scores.size(1)\n    self.max_scores = scores.new(batch_size, n, n).fill_(-self.huge)\n    self.bp = scores.new(batch_size, n, n).zero_()\n    self.argmax = scores.new(batch_size, n, n).zero_()\n    self.spans = [[] for _ in range(batch_size)]\n    tmp = scores.new(batch_size, 
n).zero_()\n    for i in range(n):\n      self.max_scores[:, i, i] = scores[:, i, i]      \n    for k in np.arange(1, n):\n      for s in np.arange(n):\n        t = s + k\n        if t > n-1:\n          break\n        for u in np.arange(s, t):\n          tmp = self.max_scores[:, s, u] + self.max_scores[:, u+1, t] + scores[:, s, t]\n          self.bp[:, s, t][self.max_scores[:, s, t] < tmp] = int(u)\n          self.max_scores[:, s, t] = torch.max(self.max_scores[:, s, t], tmp)\n    for b in range(batch_size):\n      self._backtrack(b, 0, n-1)      \n    return self.max_scores[:, 0, n-1], self.argmax, self.spans\n\n  def _backtrack(self, b, s, t):\n    u = int(self.bp[b][s][t])\n    self.argmax[b][s][t] = 1\n    if s == t:\n      return None      \n    else:\n      self.spans[b].insert(0, (s,t))\n      self._backtrack(b, s, u)\n      self._backtrack(b, u+1, t)\n    return None  \n \ndef get_span_str(start = None, end = None):\n  assert(start is not None or end is not None)\n  if start is None:\n    return ' '  + str(end) + ')'\n  elif end is None:\n    return '(' + str(start) + ' '\n  else:\n    return ' (' + str(start) + ' ' + str(end) + ') '    \n"
  },
  {
    "path": "data/test.txt",
    "content": "(S (INTJ (RB No)) (, ,) (NP (PRP it)) (VP (VBD was) (RB n't) (NP (NNP Black) (NNP Monday))) (. .))\n(S (CC But) (SBAR (IN while) (S (NP (DT the) (NNP New) (NNP York) (NNP Stock) (NNP Exchange)) (VP (VBD did) (RB n't) (VP (VB fall) (ADVP (RB apart)) (NP (NNP Friday)) (SBAR (IN as) (S (NP (DT the) (NNP Dow) (NNP Jones) (NNP Industrial) (NNP Average)) (VP (VBD plunged) (NP (NP (CD 190.58) (NNS points)) (PRN (: --) (NP (NP (JJS most)) (PP (IN of) (NP (PRP it))) (PP (IN in) (NP (DT the) (JJ final) (NN hour)))) (: --)))))))))) (NP (PRP it)) (ADVP (RB barely)) (VP (VBD managed) (S (VP (TO to) (VP (VB stay) (NP (NP (DT this) (NN side)) (PP (IN of) (NP (NN chaos)))))))) (. .))\n(S (NP (NP (DT Some) (`` ``) (NN circuit) (NNS breakers) ('' '')) (VP (VBN installed) (PP (IN after) (NP (DT the) (NNP October) (CD 1987) (NN crash))))) (VP (VBD failed) (NP (PRP$ their) (JJ first) (NN test)) (PRN (, ,) (S (NP (NNS traders)) (VP (VBP say))) (, ,)) (S (ADJP (JJ unable) (S (VP (TO to) (VP (VB cool) (NP (NP (DT the) (NN selling) (NN panic)) (PP (IN in) (NP (DT both) (NNS stocks) (CC and) (NNS futures)))))))))) (. .))\n(S (NP (NP (NP (DT The) (CD 49) (NN stock) (NN specialist) (NNS firms)) (PP (IN on) (NP (DT the) (NNP Big) (NNP Board) (NN floor)))) (: --) (NP (NP (DT the) (NNS buyers) (CC and) (NNS sellers)) (PP (IN of) (NP (JJ last) (NN resort))) (SBAR (WHNP (WP who)) (S (VP (VBD were) (VP (VBN criticized) (PP (IN after) (NP (DT the) (CD 1987) (NN crash)))))))) (: --)) (ADVP (RB once) (RB again)) (VP (MD could) (RB n't) (VP (VB handle) (NP (DT the) (NN selling) (NN pressure)))) (. .))\n(S (S (NP (JJ Big) (NN investment) (NNS banks)) (VP (VBD refused) (S (VP (TO to) (VP (VB step) (ADVP (IN up) (PP (TO to) (NP (DT the) (NN plate)))) (S (VP (TO to) (VP (VB support) (NP (DT the) (JJ beleaguered) (NN floor) (NNS traders)) (PP (IN by) (S (VP (VBG buying) (NP (NP (JJ big) (NNS blocks)) (PP (IN of) (NP (NN stock))))))))))))))) (, ,) (NP (NNS traders)) (VP (VBP say)) (. 
.))\n(S (NP (NP (JJ Heavy) (NN selling)) (PP (IN of) (NP (NP (NNP Standard) (CC &) (NNP Poor) (POS 's)) (JJ 500-stock) (NN index) (NNS futures))) (PP (IN in) (NP (NNP Chicago)))) (VP (ADVP (RB relentlessly)) (VBD beat) (NP (NNS stocks)) (ADVP (RB downward))) (. .))\n(S (NP (NP (CD Seven) (NNP Big) (NNP Board) (NNS stocks)) (: --) (NP (NP (NNP UAL)) (, ,) (NP (NNP AMR)) (, ,) (NP (NNP BankAmerica)) (, ,) (NP (NNP Walt) (NNP Disney)) (, ,) (NP (NNP Capital) (NNP Cities\\/ABC)) (, ,) (NP (NNP Philip) (NNP Morris)) (CC and) (NP (NNP Pacific) (NNP Telesis) (NNP Group))) (: --)) (VP (VP (VBD stopped) (S (VP (VBG trading)))) (CC and) (VP (ADVP (RB never)) (VBD resumed))) (. .))\n(S (NP (DT The) (NN finger-pointing)) (VP (VBZ has) (ADVP (RB already)) (VP (VBN begun))) (. .))\n(S (`` ``) (NP (DT The) (NN equity) (NN market)) (VP (VBD was) (ADJP (JJ illiquid))) (. .))\n(SINV (S (ADVP (RB Once) (RB again)) (-LRB- -LCB-) (NP (DT the) (NNS specialists)) (-RRB- -RCB-) (VP (VBD were) (RB not) (ADJP (JJ able) (S (VP (TO to) (VP (VB handle) (NP (NP (DT the) (NNS imbalances)) (PP (IN on) (NP (NP (DT the) (NN floor)) (PP (IN of) (NP (DT the) (NNP New) (NNP York) (NNP Stock) (NNP Exchange)))))))))))) (, ,) ('' '') (VP (VBD said)) (NP (NP (NNP Christopher) (NNP Pedersen)) (, ,) (NP (NP (JJ senior) (NN vice) (NN president)) (PP (IN at) (NP (NNP Twenty-First) (NNP Securities) (NNP Corp))))) (. .))\n(SINV (VP (VBD Countered)) (NP (NP (NNP James) (NNP Maguire)) (, ,) (NP (NP (NN chairman)) (PP (IN of) (NP (NNS specialists) (NNP Henderson) (NNP Brothers) (NNP Inc.))))) (: :) (`` ``) (S (NP (PRP It)) (VP (VBZ is) (ADJP (JJ easy)) (S (VP (TO to) (VP (VB say) (SBAR (S (NP (DT the) (NN specialist)) (VP (VBZ is) (RB n't) (VP (VBG doing) (NP (PRP$ his) (NN job))))))))))) (. .))\n(S (SBAR (WHADVP (WRB When)) (S (NP (DT the) (NN dollar)) (VP (VBZ is) (PP (IN in) (NP (DT a) (NN free-fall)))))) (, ,) (NP (RB even) (JJ central) (NNS banks)) (VP (MD ca) (RB n't) (VP (VB stop) (NP (PRP it)))) (. 
.))\n(S (NP (NNS Speculators)) (VP (VBP are) (VP (VBG calling) (PP (IN for) (NP (NP (DT a) (NN degree)) (PP (IN of) (NP (NN liquidity))) (SBAR (WHNP (WDT that)) (S (VP (VBZ is) (RB not) (ADVP (RB there)) (PP (IN in) (NP (DT the) (NN market)))))))))) (. .) ('' ''))\n(S (NP (NP (JJ Many) (NN money) (NNS managers)) (CC and) (NP (DT some) (NNS traders))) (VP (VBD had) (ADVP (RB already)) (VP (VBN left) (NP (PRP$ their) (NNS offices)) (NP (RB early) (NNP Friday) (NN afternoon)) (PP (IN on) (NP (DT a) (JJ warm) (NN autumn) (NN day))) (: --) (SBAR (IN because) (S (NP (DT the) (NN stock) (NN market)) (VP (VBD was) (ADJP (RB so) (JJ quiet))))))) (. .))\n(S (RB Then) (PP (IN in) (NP (DT a) (NN lightning) (NN plunge))) (, ,) (NP (DT the) (NNP Dow) (NNP Jones) (NNS industrials)) (PP (IN in) (NP (QP (RB barely) (DT an)) (NN hour))) (VP (VBD surrendered) (NP (NP (QP (RB about) (DT a)) (JJ third)) (PP (IN of) (NP (NP (PRP$ their) (NNS gains)) (NP (DT this) (NN year))))) (, ,) (S (VP (VBG chalking) (PRT (RP up)) (NP (NP (DT a) (ADJP (ADJP (JJ 190.58-point)) (, ,) (CC or) (ADJP (CD 6.9) (NN %)) (, ,)) (NN loss)) (PP (IN on) (NP (DT the) (NN day)))) (PP (IN in) (NP (JJ gargantuan) (NN trading) (NN volume)))))) (. .))\n(S (NP (JJ Final-hour) (NN trading)) (VP (VBD accelerated) (PP (TO to) (NP (NP (QP (CD 108.1) (CD million)) (NNS shares)) (, ,) (NP (NP (DT a) (NN record)) (PP (IN for) (NP (DT the) (NNP Big) (NNP Board))))))) (. .))\n(S (PP (IN At) (NP (NP (DT the) (NN end)) (PP (IN of) (NP (DT the) (NN day))))) (, ,) (NP (QP (CD 251.2) (CD million)) (NNS shares)) (VP (VBD were) (VP (VBN traded))) (. .))\n(S (NP (DT The) (NNP Dow) (NNP Jones) (NNS industrials)) (VP (VBD closed) (PP (IN at) (NP (CD 2569.26)))) (. 
.))\n(S (NP (NP (DT The) (NNP Dow) (POS 's)) (NN decline)) (VP (VBD was) (ADJP (JJ second) (PP (IN in) (NP (NN point) (NNS terms))) (PP (ADVP (RB only)) (TO to) (NP (NP (DT the) (JJ 508-point) (NNP Black) (NNP Monday) (NN crash)) (SBAR (WHNP (WDT that)) (S (VP (VBD occurred) (NP (NNP Oct.) (CD 19) (, ,) (CD 1987))))))))) (. .))\n(S (PP (IN In) (NP (NN percentage) (NNS terms))) (, ,) (ADVP (RB however)) (, ,) (NP (NP (DT the) (NNP Dow) (POS 's)) (NN dive)) (VP (VBD was) (NP (NP (NP (DT the) (JJ 12th-worst)) (ADVP (RB ever))) (CC and) (NP (NP (DT the) (JJS sharpest)) (SBAR (IN since) (S (NP (DT the) (NN market)) (VP (VBD fell) (NP (NP (CD 156.83)) (, ,) (CC or) (NP (CD 8) (NN %))) (, ,) (PP (NP (DT a) (NN week)) (IN after) (NP (NNP Black) (NNP Monday))))))))) (. .))\n"
  },
  {
    "path": "data/train.txt",
    "content": "(S (PP (IN In) (NP (NP (DT an) (NNP Oct.) (CD 19) (NN review)) (PP (IN of) (NP (`` ``) (NP (DT The) (NN Misanthrope)) ('' '') (PP (IN at) (NP (NP (NNP Chicago) (POS 's)) (NNP Goodman) (NNP Theatre))))) (PRN (-LRB- -LRB-) (`` ``) (S (NP (VBN Revitalized) (NNS Classics)) (VP (VBP Take) (NP (DT the) (NN Stage)) (PP (IN in) (NP (NNP Windy) (NNP City))))) (, ,) ('' '') (NP (NN Leisure) (CC &) (NNS Arts)) (-RRB- -RRB-)))) (, ,) (NP (NP (NP (DT the) (NN role)) (PP (IN of) (NP (NNP Celimene)))) (, ,) (VP (VBN played) (PP (IN by) (NP (NNP Kim) (NNP Cattrall)))) (, ,)) (VP (VBD was) (VP (ADVP (RB mistakenly)) (VBN attributed) (PP (TO to) (NP (NNP Christina) (NNP Haag))))) (. .))\n(S (NP (NNP Ms.) (NNP Haag)) (VP (VBZ plays) (NP (NNP Elianti))) (. .))\n(S (NP (NNP Rolls-Royce) (NNP Motor) (NNPS Cars) (NNP Inc.)) (VP (VBD said) (SBAR (S (NP (PRP it)) (VP (VBZ expects) (S (NP (PRP$ its) (NNP U.S.) (NNS sales)) (VP (TO to) (VP (VB remain) (ADJP (JJ steady)) (PP (IN at) (NP (QP (IN about) (CD 1,200)) (NNS cars))) (PP (IN in) (NP (CD 1990)))))))))) (. .))\n(S (NP (DT The) (NN luxury) (NN auto) (NN maker)) (NP (JJ last) (NN year)) (VP (VBD sold) (NP (CD 1,214) (NNS cars)) (PP (IN in) (NP (DT the) (NNP U.S.)))))\n(S (NP (NP (NNP Howard) (NNP Mosher)) (, ,) (NP (NP (NN president)) (CC and) (NP (JJ chief) (NN executive) (NN officer))) (, ,)) (VP (VBD said) (SBAR (S (NP (PRP he)) (VP (VBZ anticipates) (NP (NP (NN growth)) (PP (IN for) (NP (DT the) (NN luxury) (NN auto) (NN maker))) (PP (PP (IN in) (NP (NNP Britain) (CC and) (NNP Europe))) (, ,) (CC and) (PP (IN in) (NP (ADJP (JJ Far) (JJ Eastern)) (NNS markets))))))))) (. .))\n(S (NP (NNP BELL) (NNP INDUSTRIES) (NNP Inc.)) (VP (VBD increased) (NP (PRP$ its) (NN quarterly)) (PP (TO to) (NP (CD 10) (NNS cents))) (PP (IN from) (NP (NP (CD seven) (NNS cents)) (NP (DT a) (NN share))))) (. .))\n(S (NP (DT The) (JJ new) (NN rate)) (VP (MD will) (VP (VB be) (ADJP (JJ payable) (NP (NNP Feb.) (CD 15))))) (. 
.))\n(S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .))\n(S (NP (NP (NNP Bell)) (, ,) (VP (VBN based) (PP (IN in) (NP (NNP Los) (NNP Angeles)))) (, ,)) (VP (VBZ makes) (CC and) (VBZ distributes) (NP (UCP (JJ electronic) (, ,) (NN computer) (CC and) (NN building)) (NNS products))) (. .))\n(S (NP (NNS Investors)) (VP (VBP are) (VP (VBG appealing) (PP (TO to) (NP (DT the) (NNPS Securities) (CC and) (NNP Exchange) (NNP Commission))) (S (RB not) (VP (TO to) (VP (VB limit) (NP (NP (PRP$ their) (NN access)) (PP (TO to) (NP (NP (NN information)) (PP (IN about) (NP (NP (NN stock) (NNS purchases) (CC and) (NNS sales)) (PP (IN by) (NP (JJ corporate) (NNS insiders))))))))))))) (. .))\n(S (S (NP (DT A) (NNP SEC) (NN proposal) (S (VP (TO to) (VP (VB ease) (NP (NP (NN reporting) (NNS requirements)) (PP (IN for) (NP (DT some) (NN company) (NNS executives)))))))) (VP (MD would) (VP (VB undermine) (NP (NP (DT the) (NN usefulness)) (PP (IN of) (NP (NP (NN information)) (PP (IN on) (NP (NN insider) (NNS trades))))) (PP (IN as) (NP (DT a) (JJ stock-picking) (NN tool))))))) (, ,) (NP (NP (JJ individual) (NNS investors)) (CC and) (NP (JJ professional) (NN money) (NNS managers))) (VP (VBP contend)) (. .))\n(S (NP (PRP They)) (VP (VBP make) (NP (DT the) (NN argument)) (PP (IN in) (NP (NP (NNS letters)) (PP (TO to) (NP (DT the) (NN agency))) (PP (IN about) (NP (NP (NN rule) (NNS changes)) (VP (VBD proposed) (NP (DT this) (JJ past) (NN summer))) (SBAR (WHNP (IN that)) (, ,) (S (PP (IN among) (NP (JJ other) (NNS things))) (, ,) (VP (MD would) (VP (VB exempt) (NP (JJ many) (JJ middle-management) (NNS executives)) (PP (IN from) (S (VP (VBG reporting) (NP (NP (NNS trades)) (PP (IN in) (NP (NP (PRP$ their) (JJ own) (NNS companies) (POS ')) (NNS shares)))))))))))))))) (. 
.))\n(S (NP (DT The) (VBN proposed) (NNS changes)) (ADVP (RB also)) (VP (MD would) (VP (VB allow) (S (NP (NNS executives)) (VP (TO to) (VP (VB report) (NP (NP (NNS exercises)) (PP (IN of) (NP (NNS options)))) (ADVP (ADVP (RBR later)) (CC and) (ADVP (RBR less) (RB often)))))))) (. .))\n(S (NP (NP (JJ Many)) (PP (IN of) (NP (DT the) (NNS letters)))) (VP (VBP maintain) (SBAR (IN that) (S (S (NP (NN investor) (NN confidence)) (VP (VBZ has) (VP (VBN been) (VP (ADVP (RB so)) (VBN shaken) (PP (IN by) (NP (DT the) (CD 1987) (NN stock) (NN market) (NN crash))))))) (: --) (CC and) (S (NP (DT the) (NNS markets)) (ADVP (RB already)) (VP (ADVP (RB so)) (VBN stacked) (PP (IN against) (NP (DT the) (JJ little) (NN guy))))) (: --) (SBAR (IN that) (S (NP (NP (DT any) (NN decrease)) (PP (IN in) (NP (NP (NN information)) (PP (IN on) (NP (NN insider-trading) (NNS patterns)))))) (VP (MD might) (VP (VB prompt) (S (NP (NNS individuals)) (VP (TO to) (VP (VB get) (ADVP (RB out) (PP (IN of) (NP (NNS stocks)))) (ADVP (RB altogether)))))))))))) (. .))\n(SINV (`` ``) (S (NP (DT The) (NNP SEC)) (VP (VBZ has) (ADVP (RB historically)) (VP (VBN paid) (NP (NN obeisance)) (PP (TO to) (NP (NP (DT the) (NN ideal)) (PP (IN of) (NP (DT a) (JJ level) (NN playing) (NN field)))))))) (, ,) ('' '') (VP (VBD wrote)) (NP (NP (NNP Clyde) (NNP S.) (NNP McGregor)) (PP (IN of) (NP (NP (NNP Winnetka)) (, ,) (NP (NNP Ill.)) (, ,)))) (PP (IN in) (NP (NP (CD one)) (PP (IN of) (NP (NP (DT the) (CD 92) (NNS letters)) (SBAR (S (NP (DT the) (NN agency)) (VP (VBZ has) (VP (VBN received) (SBAR (IN since) (S (NP (DT the) (NNS changes)) (VP (VBD were) (VP (VBN proposed) (NP (NNP Aug.) (CD 17)))))))))))))) (. .))\n(S (`` ``) (ADVP (RB Apparently)) (NP (DT the) (NN commission)) (VP (VBD did) (RB not) (ADVP (RB really)) (VP (VB believe) (PP (IN in) (NP (DT this) (NN ideal))))) (. .) 
('' ''))\n(S (ADVP (RB Currently)) (, ,) (NP (DT the) (NNS rules)) (VP (VBP force) (S (NP (NP (NNS executives)) (, ,) (NP (NNS directors)) (CC and) (NP (JJ other) (JJ corporate) (NNS insiders))) (VP (TO to) (VP (VB report) (NP (NP (NNS purchases) (CC and) (NNS sales)) (PP (IN of) (NP (NP (PRP$ their) (NNS companies) (POS ')) (NNS shares)))) (PP (IN within) (NP (NP (QP (IN about) (DT a)) (NN month)) (PP (IN after) (NP (DT the) (NN transaction))))))))) (. .))\n(S (CC But) (NP (NP (QP (IN about) (CD 25)) (NN %)) (PP (IN of) (NP (DT the) (NNS insiders)))) (, ,) (PP (VBG according) (PP (TO to) (NP (NNP SEC) (NNS figures)))) (, ,) (VP (VBP file) (NP (PRP$ their) (NNS reports)) (ADVP (RB late))) (. .))\n(SINV (S (NP (DT The) (NNS changes)) (VP (VBD were) (VP (VBN proposed) (PP (IN in) (NP (DT an) (NN effort) (S (VP (TO to) (VP (VP (VB streamline) (NP (JJ federal) (NN bureaucracy))) (CC and) (VP (VB boost) (NP (NP (NN compliance)) (PP (IN by) (NP (NP (DT the) (NNS executives)) (`` ``) (SBAR (WHNP (WP who)) (S (VP (VBP are) (ADVP (RB really)) (VP (VBG calling) (NP (DT the) (NNS shots)))))))))))))))))) (, ,) ('' '') (VP (VBD said)) (NP (NP (NNP Brian) (NNP Lane)) (, ,) (NP (NP (JJ special) (NN counsel)) (PP (IN at) (NP (NP (NP (NP (DT the) (NNP SEC) (POS 's)) (NN office)) (PP (IN of) (NP (NN disclosure) (NN policy)))) (, ,) (SBAR (WHNP (WDT which)) (S (VP (VBD proposed) (NP (DT the) (NNS changes))))))))) (. .))\n(S (S (S (NP (NP (NNS Investors)) (, ,) (NP (NN money) (NNS managers)) (CC and) (NP (JJ corporate) (NNS officials))) (VP (VBD had) (PP (IN until) (NP (NN today))) (S (VP (TO to) (VP (VB comment) (PP (IN on) (NP (DT the) (NNS proposals)))))))) (, ,) (CC and) (S (NP (DT the) (NN issue)) (VP (VBZ has) (VP (VBN produced) (NP (NP (JJR more) (NN mail)) (PP (IN than) (NP (NP (ADJP (RB almost) (DT any)) (JJ other) (NN issue)) (PP (IN in) (NP (NN memory)))))))))) (, ,) (NP (NNP Mr.) (NNP Lane)) (VP (VBD said)) (. .))\n"
  },
  {
    "path": "data/valid.txt",
    "content": "(S (NP (NP (DT The) (NN economy) (POS 's)) (NN temperature)) (VP (MD will) (VP (VB be) (VP (VBN taken) (PP (IN from) (NP (JJ several) (NN vantage) (NNS points))) (NP (DT this) (NN week)) (, ,) (PP (IN with) (NP (NP (NNS readings)) (PP (IN on) (NP (NP (NN trade)) (, ,) (NP (NN output)) (, ,) (NP (NN housing)) (CC and) (NP (NN inflation))))))))) (. .))\n(S (NP (DT The) (ADJP (RBS most) (JJ troublesome)) (NN report)) (VP (MD may) (VP (VB be) (NP (NP (DT the) (NNP August) (NN merchandise) (NN trade) (NN deficit)) (ADJP (JJ due) (ADVP (IN out)) (NP (NN tomorrow)))))) (. .))\n(S (NP (DT The) (NN trade) (NN gap)) (VP (VBZ is) (VP (VBN expected) (S (VP (TO to) (VP (VB widen) (PP (TO to) (NP (QP (IN about) ($ $) (CD 9) (CD billion)))) (PP (IN from) (NP (NP (NNP July) (POS 's)) (QP ($ $) (CD 7.6) (CD billion))))))) (, ,) (PP (VBG according) (PP (TO to) (NP (NP (DT a) (NN survey)) (PP (IN by) (NP (NP (NNP MMS) (NNP International)) (, ,) (NP (NP (DT a) (NN unit)) (PP (IN of) (NP (NP (NNP McGraw-Hill) (NNP Inc.)) (, ,) (NP (NNP New) (NNP York)))))))))))) (. .))\n(S (NP (NP (NP (NNP Thursday) (POS 's)) (NN report)) (PP (IN on) (NP (DT the) (NNP September) (NN consumer) (NN price) (NN index)))) (VP (VBZ is) (VP (VBN expected) (S (VP (TO to) (VP (VB rise) (, ,) (SBAR (IN although) (ADVP (ADVP (RB not) (RB as) (RB sharply)) (PP (IN as) (NP (NP (DT the) (ADJP (CD 0.9) (NN %)) (NN gain)) (VP (VBN reported) (NP (NNP Friday)) (PP (IN in) (NP (DT the) (NN producer) (NN price) (NN index))))))))))))) (. .))\n(S (NP (DT That) (NN gain)) (VP (VBD was) (VP (VBG being) (VP (VBD cited) (PP (IN as) (NP (NP (DT a) (NN reason)) (SBAR (S (NP (DT the) (NN stock) (NN market)) (VP (VBD was) (ADVP (IN down)) (ADVP (RB early) (PP (IN in) (NP (NP (NNP Friday) (POS 's)) (NN session)))) (, ,) (SBAR (IN before) (S (NP (PRP it)) (VP (VBD got) (S (VP (VBN started) (PP (IN on) (NP (PRP$ its) (JJ reckless) (JJ 190-point) (NN plunge)))))))))))))))) (. 
.))\n(S (NP (NNS Economists)) (VP (VBP are) (VP (VBN divided) (PP (IN as) (PP (TO to) (SBAR (WHNP (WHADVP (WRB how) (JJ much)) (VBG manufacturing) (NN strength)) (S (NP (PRP they)) (VP (VBP expect) (S (VP (TO to) (VP (VB see) (PP (IN in) (NP (NP (NP (NNP September) (NNS reports)) (PP (IN on) (NP (NP (JJ industrial) (NN production)) (CC and) (NP (NN capacity) (NN utilization))))) (, ,) (ADJP (ADVP (RB also)) (JJ due) (NP (NN tomorrow))))))))))))))) (. .))\n(S (ADVP (RB Meanwhile)) (, ,) (NP (NP (NNP September) (NN housing) (NNS starts)) (, ,) (ADJP (JJ due) (NP (NNP Wednesday))) (, ,)) (VP (VBP are) (VP (VBN thought) (S (VP (TO to) (VP (VB have) (VP (VBN inched) (ADVP (RB upward)))))))) (. .))\n(SINV (S (`` ``) (NP (EX There)) (VP (VBZ 's) (NP (NP (DT a) (NN possibility)) (PP (IN of) (NP (NP (DT a) (NN surprise)) ('' '') (PP (IN in) (NP (DT the) (NN trade) (NN report)))))))) (, ,) (VP (VBD said)) (NP (NP (NNP Michael) (NNP Englund)) (, ,) (NP (NP (NN director)) (PP (IN of) (NP (NN research))) (PP (IN at) (NP (NNP MMS))))) (. .))\n(S (S (NP (NP (DT A) (NN widening)) (PP (IN of) (NP (DT the) (NN deficit)))) (, ,) (SBAR (IN if) (S (NP (PRP it)) (VP (VBD were) (VP (VBN combined) (PP (IN with) (NP (DT a) (ADJP (RB stubbornly) (JJ strong)) (NN dollar))))))) (, ,) (VP (MD would) (VP (VB exacerbate) (NP (NN trade) (NNS problems)))) (: --)) (CC but) (S (NP (DT the) (NN dollar)) (VP (VBD weakened) (NP (NNP Friday)) (SBAR (IN as) (S (NP (NNS stocks)) (VP (VBD plummeted)))))) (. .))\n(S (PP (IN In) (NP (DT any) (NN event))) (, ,) (NP (NP (NNP Mr.) (NNP Englund)) (CC and) (NP (JJ many) (NNS others))) (VP (VBP say) (SBAR (IN that) (S (NP (NP (DT the) (JJ easy) (NNS gains)) (PP (IN in) (S (VP (VBG narrowing) (NP (DT the) (NN trade) (NN gap)))))) (VP (VBP have) (ADVP (RB already)) (VP (VBN been) (VP (VBN made))))))) (. 
.))\n(S (`` ``) (S (NP (NN Trade)) (VP (VBZ is) (ADVP (RB definitely)) (VP (VBG going) (S (VP (TO to) (VP (VB be) (ADJP (RBR more) (RB politically) (JJ sensitive)) (PP (IN over) (NP (DT the) (JJ next) (QP (CD six) (CC or) (CD seven)) (NNS months))) (SBAR (IN as) (S (NP (NN improvement)) (VP (VBZ begins) (S (VP (TO to) (VP (VB slow))))))))))))) (, ,) ('' '') (NP (PRP he)) (VP (VBD said)) (. .))\n(S (S (NP (NNS Exports)) (VP (VBP are) (VP (VBN thought) (S (VP (TO to) (VP (VB have) (VP (VBN risen) (ADVP (ADVP (RB strongly) (PP (IN in) (NP (NNP August)))) (, ,) (CC but) (ADVP (ADVP (RB probably)) (RB not) (RB enough) (S (VP (TO to) (VP (VB offset) (NP (NP (DT the) (NN jump)) (PP (IN in) (NP (NNS imports)))))))))))))))) (, ,) (NP (NNS economists)) (VP (VBD said)) (. .))\n(S (NP (NP (NNS Views)) (PP (IN on) (NP (VBG manufacturing) (NN strength)))) (VP (VBP are) (ADJP (VBN split) (PP (IN between) (NP (NP (NP (NNS economists)) (SBAR (WHNP (WP who)) (S (VP (VBP read) (NP (NP (NP (NNP September) (POS 's)) (JJ low) (NN level)) (PP (IN of) (NP (NN factory) (NN job) (NN growth)))) (PP (IN as) (NP (NP (DT a) (NN sign)) (PP (IN of) (NP (DT a) (NN slowdown))))))))) (CC and) (NP (NP (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBP use) (NP (DT the) (ADJP (RB somewhat) (JJR more) (VBG comforting)) (JJ total) (NN employment) (NNS figures)) (PP (IN in) (NP (PRP$ their) (NNS calculations))))))))))) (. .))\n(S (S (NP (NP (DT The) (JJ wide) (NN range)) (PP (IN of) (NP (NP (NNS estimates)) (PP (IN for) (NP (DT the) (JJ industrial) (NN output) (NN number)))))) (VP (VBZ underscores) (NP (DT the) (NNS differences)))) (: :) (S (NP (DT The) (NNS forecasts)) (VP (VBD run) (PP (IN from) (NP (NP (DT a) (NN drop)) (PP (IN of) (NP (CD 0.5) (NN %))))) (PP (TO to) (NP (NP (DT an) (NN increase)) (PP (IN of) (NP (CD 0.4) (NN %))))) (, ,) (PP (VBG according) (PP (TO to) (NP (NNP MMS)))))) (. 
.))\n(S (NP (NP (DT A) (NN rebound)) (PP (IN in) (NP (NN energy) (NNS prices))) (, ,) (SBAR (WHNP (WDT which)) (S (VP (VBD helped) (VP (VB push) (PRT (RP up)) (NP (DT the) (NN producer) (NN price) (NN index)))))) (, ,)) (VP (VBZ is) (VP (VBN expected) (S (VP (TO to) (VP (VB do) (NP (DT the) (JJ same)) (PP (IN in) (NP (DT the) (NN consumer) (NN price) (NN report)))))))) (. .))\n(S (NP (DT The) (NN consensus) (NN view)) (VP (VBZ expects) (NP (NP (DT a) (ADJP (CD 0.4) (NN %)) (NN increase)) (PP (IN in) (NP (DT the) (NNP September) (NNP CPI)))) (PP (IN after) (NP (NP (DT a) (JJ flat) (NN reading)) (PP (IN in) (NP (NNP August)))))) (. .))\n(S (NP (NP (NNP Robert) (NNP H.) (NNP Chandross)) (, ,) (NP (NP (DT an) (NN economist)) (PP (IN for) (NP (NP (NP (NNP Lloyd) (POS 's)) (NNP Bank)) (PP (IN in) (NP (NNP New) (NNP York)))))) (, ,)) (VP (VBZ is) (PP (IN among) (NP (NP (DT those)) (VP (VBG expecting) (NP (NP (DT a) (ADJP (RBR more) (JJ moderate)) (NN gain)) (PP (IN in) (NP (DT the) (NNP CPI))) (PP (IN than) (PP (IN in) (NP (NP (NNS prices)) (PP (IN at) (NP (DT the) (NN producer) (NN level))))))))))) (. .))\n(S (`` ``) (S (S (NP (NN Auto) (NNS prices)) (VP (VBD had) (NP (DT a) (JJ big) (NN effect)) (PP (IN in) (NP (DT the) (NNP PPI))))) (, ,) (CC and) (S (PP (IN at) (NP (DT the) (NNP CPI) (NN level))) (NP (PRP they)) (VP (MD wo) (RB n't)))) (, ,) ('' '') (NP (PRP he)) (VP (VBD said)) (. .))\n(SINV (S (S (NP (NN Food) (NNS prices)) (VP (VBP are) (VP (VBN expected) (S (VP (TO to) (VP (VB be) (ADJP (JJ unchanged)))))))) (, ,) (CC but) (S (NP (NN energy) (NNS costs)) (VP (VBD jumped) (NP (NP (RB as) (RB much) (IN as) (CD 4)) (NN %))))) (, ,) (VP (VBD said)) (NP (NP (NNP Gary) (NNP Ciminero)) (, ,) (NP (NP (NN economist)) (PP (IN at) (NP (NNP Fleet\\/Norstar) (NNP Financial) (NNP Group))))) (. 
.))\n(S (NP (PRP He)) (ADVP (RB also)) (VP (VBZ says) (SBAR (S (NP (PRP he)) (VP (VBZ thinks) (SBAR (S (NP (`` ``) (NP (NN core) (NN inflation)) (, ,) ('' '') (SBAR (WHNP (WDT which)) (S (VP (VBZ excludes) (NP (DT the) (JJ volatile) (NN food) (CC and) (NN energy) (NNS prices))))) (, ,)) (VP (VBD was) (ADJP (JJ strong)) (NP (JJ last) (NN month))))))))) (. .))\n"
  },
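The bracketed PTB trees above are what `parse.py` evaluates predictions against: unlabeled F1 is computed over word-index spans read off each tree. As a rough standalone illustration (not the repo's `get_nonbinary_spans` from `utils.py`, which works off action sequences), spans can be recovered from a bracketed string with a small stack parser:

```python
def tree_spans(line):
    # Read (start, end) word-index spans off a bracketed tree like
    # "(S (NP (DT The) (NN dog)) (VP (VBZ barks)))".
    toks = line.replace('(', ' ( ').replace(')', ' ) ').split()
    spans, stack, w = [], [], 0
    expect_label = False  # the token right after '(' is a label, not a word
    for t in toks:
        if t == '(':
            stack.append(w)       # constituent opens at the current word index
            expect_label = True
        elif t == ')':
            spans.append((stack.pop(), w - 1))
        elif expect_label:
            expect_label = False  # skip the POS/nonterminal label
        else:
            w += 1                # a terminal word
    return spans
```

The last span emitted is always the trivial whole-sentence span, which the evaluation in `parse.py` discards.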
  {
    "path": "data.py",
    "content": "#!/usr/bin/env python3\nimport numpy as np\nimport torch\nimport pickle\n\nclass Dataset(object):\n  def __init__(self, data_file):\n    data = pickle.load(open(data_file, 'rb')) #get text data\n    self.sents = self._convert(data['source']).long()\n    self.other_data = data['other_data']\n    self.sent_lengths = self._convert(data['source_l']).long()\n    self.batch_size = self._convert(data['batch_l']).long()\n    self.batch_idx = self._convert(data['batch_idx']).long()\n    self.vocab_size = data['vocab_size'][0]\n    self.num_batches = self.batch_idx.size(0)\n    self.word2idx = data['word2idx']\n    self.idx2word = data['idx2word']\n\n  def _convert(self, x):\n    return torch.from_numpy(np.asarray(x))\n\n  def __len__(self):\n    return self.num_batches\n\n  def __getitem__(self, idx):\n    assert(idx < self.num_batches and idx >= 0)\n    start_idx = self.batch_idx[idx]\n    end_idx = start_idx + self.batch_size[idx]\n    length = self.sent_lengths[idx].item()\n    sents = self.sents[start_idx:end_idx]\n    other_data = self.other_data[start_idx:end_idx]\n    sent_str = [d[0] for d in other_data]\n    tags = [d[1] for d in other_data]\n    actions = [d[2] for d in other_data]\n    binary_tree = [d[3] for d in other_data]\n    spans = [d[5] for d in other_data]\n    batch_size = self.batch_size[idx].item()\n    # by default, we return sents with <s> </s> tokens\n    # hence we subtract 2 from length as these are (by default) not counted for evaluation\n    data_batch = [sents[:, :length], length-2, batch_size, actions, \n                  spans, binary_tree, other_data]\n    return data_batch\n"
  },
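`Dataset.__getitem__` above slices one contiguous block of equal-length sentences out of a single padded tensor, using per-batch `batch_idx` (start row) and `batch_l` (row count), then trims columns to the batch's true length. A toy numpy sketch of that slicing scheme (the arrays here are illustrative stand-ins, not the pickle's actual fields):

```python
import numpy as np

# 6 sentences stored row-contiguously, padded to max length 4
sents = np.arange(24).reshape(6, 4)
batch_idx = [0, 2, 5]   # start row of each batch
batch_l = [2, 3, 1]     # number of sentences in each batch
sent_l = [3, 4, 2]      # shared true length of sentences in each batch

def get_batch(i):
    # slice the batch's rows, then trim padding columns to its length
    start = batch_idx[i]
    return sents[start:start + batch_l[i], :sent_l[i]]
```

Grouping same-length sentences this way means no per-example length masking is needed inside a batch.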
  {
    "path": "eval_ppl.py",
    "content": "#!/usr/bin/env python3\nimport sys\nimport os\n\nimport argparse\nimport json\nimport random\nimport shutil\nimport copy\n\nimport torch\nfrom torch import cuda\nimport torch.nn as nn\nfrom torch.autograd import Variable\nfrom torch.nn.parameter import Parameter\n\nimport torch.nn.functional as F\nimport numpy as np\nimport time\nimport logging\nfrom data import Dataset\nfrom models import RNNG\nfrom utils import *\n\nparser = argparse.ArgumentParser()\n\n# Data path options\nparser.add_argument('--test_file', default='data/ptb-test.pkl')\nparser.add_argument('--model_file', default='')\nparser.add_argument('--is_temp', default=2., type=float, help='divide scores by is_temp before CRF')\nparser.add_argument('--samples', default=1000, type=int, help='samples for IS calculation')\nparser.add_argument('--count_eos_ppl', default=0, type=int, help='whether to count eos in val PPL')\nparser.add_argument('--gpu', default=2, type=int, help='which gpu to use')\nparser.add_argument('--seed', default=3435, type=int)\n\n\ndef main(args):\n  np.random.seed(args.seed)\n  torch.manual_seed(args.seed)\n  data = Dataset(args.test_file)  \n  checkpoint = torch.load(args.model_file)\n  model = checkpoint['model']\n  print(\"model architecture\")\n  print(model)\n  cuda.set_device(args.gpu)\n  model.cuda()\n  model.eval()\n  num_sents = 0\n  num_words = 0\n  total_nll_recon = 0.\n  total_kl = 0.\n  total_nll_iwae = 0.\n  samples_batch = 50\n  S = args.samples // samples_batch  \n  samples = S*samples_batch\n  with torch.no_grad():\n    for i in list(reversed(range(len(data)))):\n      sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i] \n      if length == 1:\n        # length 1 sents are ignored since our generative model requires sents of length >= 2\n        continue\n      if args.count_eos_ppl == 1:\n        length += 1\n      else:\n        sents = sents[:, :-1]\n      sents = sents.cuda()\n      ll_word_all2 = [] \n      
ll_action_p_all2 = [] \n      ll_action_q_all2 = [] \n      for j in range(S):                    \n        ll_word_all, ll_action_p_all, ll_action_q_all, actions_all, q_entropy = model(\n          sents, samples = samples_batch, is_temp = args.is_temp, has_eos = args.count_eos_ppl == 1)\n        ll_word_all2.append(ll_word_all.detach().cpu())\n        ll_action_p_all2.append(ll_action_p_all.detach().cpu())\n        ll_action_q_all2.append(ll_action_q_all.detach().cpu())\n      ll_word_all2 = torch.cat(ll_word_all2, 1)\n      ll_action_p_all2 = torch.cat(ll_action_p_all2, 1)\n      ll_action_q_all2 = torch.cat(ll_action_q_all2, 1)\n      sample_ll = torch.zeros(batch_size, ll_word_all2.size(1))\n      total_nll_recon += -ll_word_all.mean(1).sum().item()\n      total_kl += (ll_action_q_all - ll_action_p_all).mean(1).sum().item()\n      for j in range(sample_ll.size(1)):\n        ll_word_j, ll_action_p_j, ll_action_q_j = ll_word_all2[:, j], ll_action_p_all2[:, j], ll_action_q_all2[:, j]\n        sample_ll[:, j].copy_(ll_word_j + ll_action_p_j - ll_action_q_j)\n      ll_iwae = model.logsumexp(sample_ll, 1) - np.log(samples)\n      total_nll_iwae -= ll_iwae.sum().item()      \n      num_sents += batch_size\n      num_words += batch_size * length\n      \n      print('Batch: %d/%d, ElboPPL: %.2f, KL: %.4f, IwaePPL: %.2f' % \n            (i, len(data), np.exp((total_nll_recon + total_kl) / num_words),\n            total_kl / num_sents, np.exp(total_nll_iwae / num_words)))\n  elbo_ppl = np.exp((total_nll_recon + total_kl) / num_words)\n  recon_ppl = np.exp(total_nll_recon / num_words)\n  iwae_ppl = np.exp(total_nll_iwae /num_words)\n  kl = total_kl / num_sents  \n  print('ElboPPL: %.2f, ReconPPL: %.2f, KL: %.4f, IwaePPL: %.2f' % \n        (elbo_ppl, recon_ppl, kl, iwae_ppl))\n\nif __name__ == '__main__':\n  args = parser.parse_args()\n  main(args)\n"
  },
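`eval_ppl.py` reports two perplexities: an ELBO-based one and an importance-weighted (IWAE) one, where log p(x) is estimated as logsumexp over K sampled log weights `ll_word + ll_action_p - ll_action_q`, minus log K. A minimal numpy sketch of the two estimators; by Jensen's inequality the IWAE estimate is always at least the ELBO, so the IWAE PPL is at most the ELBO PPL:

```python
import numpy as np

def elbo(log_w):
    # log_w: (batch, K) log importance weights log p(x, z) - log q(z | x)
    return log_w.mean(axis=1)

def iwae(log_w):
    # log (1/K) sum_k exp(log_w[:, k]), computed stably via logsumexp
    K = log_w.shape[1]
    m = log_w.max(axis=1)
    return m + np.log(np.exp(log_w - m[:, None]).sum(axis=1)) - np.log(K)
```

With K = 1 the two estimators coincide; as K grows the IWAE estimate approaches log p(x) from below.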
  {
    "path": "models.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom utils import *\nfrom TreeCRF import ConstituencyTreeCRF\nfrom torch.distributions import Bernoulli\n\nclass RNNLM(nn.Module):\n  def __init__(self, vocab=10000,\n               w_dim=650,\n               h_dim=650,\n               num_layers=2,\n               dropout=0.5):\n    super(RNNLM, self).__init__()\n    self.h_dim = h_dim\n    self.num_layers = num_layers    \n    self.word_vecs = nn.Embedding(vocab, w_dim)\n    self.dropout = nn.Dropout(dropout)\n    self.rnn = nn.LSTM(w_dim, h_dim, num_layers = num_layers,\n                       dropout = dropout, batch_first = True)      \n    self.vocab_linear =  nn.Linear(h_dim, vocab)\n    self.vocab_linear.weight = self.word_vecs.weight # weight sharing\n\n  def forward(self, sent):\n    word_vecs = self.dropout(self.word_vecs(sent[:, :-1]))\n    h, _ = self.rnn(word_vecs)\n    log_prob = F.log_softmax(self.vocab_linear(self.dropout(h)), 2) # b x l x v\n    ll = torch.gather(log_prob, 2, sent[:, 1:].unsqueeze(2)).squeeze(2)\n    return ll.sum(1)\n  \n  def generate(self, bos = 2, eos = 3, max_len = 150):\n    x = []\n    bos = torch.LongTensor(1,1).cuda().fill_(bos)\n    emb = self.dropout(self.word_vecs(bos))\n    prev_h = None\n    for l in range(max_len):\n      h, prev_h = self.rnn(emb, prev_h)\n      prob = F.softmax(self.vocab_linear(self.dropout(h.squeeze(1))), 1)\n      sample = torch.multinomial(prob, 1)\n      emb = self.dropout(self.word_vecs(sample))\n      x.append(sample.item())\n      if x[-1] == eos:\n        x.pop()\n        break\n    return x\n\nclass SeqLSTM(nn.Module):\n  def __init__(self, i_dim = 200,\n               h_dim = 0,\n               num_layers = 1,\n               dropout = 0):\n    super(SeqLSTM, self).__init__()    \n    self.i_dim = i_dim\n    self.h_dim = h_dim\n    self.num_layers = num_layers\n    self.linears = nn.ModuleList([nn.Linear(h_dim + i_dim, h_dim*4) if l 
== 0 else\n                                  nn.Linear(h_dim*2, h_dim*4) for l in range(num_layers)])\n    self.dropout = dropout\n    self.dropout_layer = nn.Dropout(dropout)\n\n  def forward(self, x, prev_h = None):\n    if prev_h is None:\n      prev_h = [(x.new(x.size(0), self.h_dim).fill_(0),\n                 x.new(x.size(0), self.h_dim).fill_(0)) for _ in range(self.num_layers)]\n    curr_h = []\n    for l in range(self.num_layers):\n      input = x if l == 0 else curr_h[l-1][0]\n      if l > 0 and self.dropout > 0:\n        input = self.dropout_layer(input)\n      concat = torch.cat([input, prev_h[l][0]], 1)\n      all_sum = self.linears[l](concat)\n      i, f, o, g = all_sum.split(self.h_dim, 1)\n      c = F.sigmoid(f)*prev_h[l][1] + F.sigmoid(i)*F.tanh(g)\n      h = F.sigmoid(o)*F.tanh(c)\n      curr_h.append((h, c))\n    return curr_h\n\nclass TreeLSTM(nn.Module):\n  def __init__(self, dim = 200):\n    super(TreeLSTM, self).__init__()\n    self.dim = dim\n    self.linear = nn.Linear(dim*2, dim*5)\n\n  def forward(self, x1, x2, e=None):\n    if not isinstance(x1, tuple):\n      x1 = (x1, None)    \n    h1, c1 = x1 \n    if x2 is None: \n      x2 = (torch.zeros_like(h1), torch.zeros_like(h1))\n    elif not isinstance(x2, tuple):\n      x2 = (x2, None)    \n    h2, c2 = x2\n    if c1 is None:\n      c1 = torch.zeros_like(h1)\n    if c2 is None:\n      c2 = torch.zeros_like(h2)\n    concat = torch.cat([h1, h2], 1)\n    all_sum = self.linear(concat)\n    i, f1, f2, o, g = all_sum.split(self.dim, 1)\n\n    c = F.sigmoid(f1)*c1 + F.sigmoid(f2)*c2 + F.sigmoid(i)*F.tanh(g)\n    h = F.sigmoid(o)*F.tanh(c)\n    return (h, c)\n      \nclass RNNG(nn.Module):\n  def __init__(self, vocab = 100,\n               w_dim = 20, \n               h_dim = 20,\n               num_layers = 1,\n               dropout = 0,\n               q_dim = 20,\n               max_len = 250):\n    super(RNNG, self).__init__()\n    self.S = 0 #action idx for shift/generate\n    self.R = 1 
#action idx for reduce\n    self.emb = nn.Embedding(vocab, w_dim)\n    self.dropout = nn.Dropout(dropout)    \n    self.stack_rnn = SeqLSTM(w_dim, h_dim, num_layers = num_layers, dropout = dropout)\n    self.tree_rnn = TreeLSTM(w_dim)\n    self.vocab_mlp = nn.Sequential(nn.Dropout(dropout), nn.Linear(h_dim, vocab))\n    self.num_layers = num_layers\n    self.q_binary = nn.Sequential(nn.Linear(q_dim*2, q_dim*2), nn.ReLU(), nn.LayerNorm(q_dim*2),\n                                  nn.Dropout(dropout), nn.Linear(q_dim*2, 1))\n    self.action_mlp_p = nn.Sequential(nn.Dropout(dropout), nn.Linear(h_dim, 1))\n    self.w_dim = w_dim\n    self.h_dim = h_dim\n    self.q_dim = q_dim    \n    self.q_leaf_rnn = nn.LSTM(w_dim, q_dim, bidirectional = True, batch_first = True)\n    self.q_crf = ConstituencyTreeCRF()\n    self.pad1 = 0 # idx for <s> token from ptb.dict\n    self.pad2 = 2 # idx for </s> token from ptb.dict \n    self.q_pos_emb = nn.Embedding(max_len, w_dim) # position embeddings\n    self.vocab_mlp[-1].weight = self.emb.weight #share embeddings\n\n  def get_span_scores(self, x):\n    #produces the span scores s_ij\n    bos = x.new(x.size(0), 1).fill_(self.pad1)\n    eos  = x.new(x.size(0), 1).fill_(self.pad2)\n    x = torch.cat([bos, x, eos], 1)\n    x_vec = self.dropout(self.emb(x))\n    pos = torch.arange(0, x.size(1)).unsqueeze(0).expand_as(x).long().cuda()\n    x_vec = x_vec + self.dropout(self.q_pos_emb(pos))\n    q_h, _ = self.q_leaf_rnn(x_vec)\n    fwd = q_h[:, 1:, :self.q_dim]\n    bwd = q_h[:, :-1, self.q_dim:]\n    fwd_diff = fwd[:, 1:].unsqueeze(1) - fwd[:, :-1].unsqueeze(2)\n    bwd_diff = bwd[:, :-1].unsqueeze(2) - bwd[:, 1:].unsqueeze(1)\n    concat = torch.cat([fwd_diff, bwd_diff], 3)\n    scores = self.q_binary(concat).squeeze(3)\n    return scores\n\n  def get_action_masks(self, actions, length):\n    #this masks out actions so that we don't incur a loss if some actions are deterministic\n    #in practice this doesn't really seem to matter\n    mask 
= actions.new(actions.size(0), actions.size(1)).fill_(1)\n    for b in range(actions.size(0)):      \n      num_shift = 0\n      stack_len = 0\n      for l in range(actions.size(1)):\n        if stack_len < 2:\n          mask[b][l].fill_(0)\n        if actions[b][l].item() == self.S:\n          num_shift += 1\n          stack_len += 1\n        else:\n          stack_len -= 1\n    return mask\n\n  def forward(self, x, samples = 1, is_temp = 1., has_eos=True):\n    #For has eos, if </s> exists, then inference network ignores it. \n    #Note that </s> is predicted for training since we want the model to know when to stop.\n    #However it is ignored for PPL evaluation on the version of the PTB dataset from\n    #the original RNNG paper (Dyer et al. 2016)\n    init_emb = self.dropout(self.emb(x[:, 0]))\n    x = x[:, 1:]\n    batch, length = x.size(0), x.size(1)\n    if has_eos: \n      parse_length = length - 1\n      parse_x = x[:, :-1]\n    else:\n      parse_length = length\n      parse_x = x\n    word_vecs =  self.dropout(self.emb(x))\n    scores = self.get_span_scores(parse_x)\n    self.scores = scores\n    scores = scores / is_temp\n    self.q_crf._forward(scores)\n    self.q_crf._entropy(scores)\n    entropy = self.q_crf.entropy[0][parse_length-1]\n    crf_input = scores.unsqueeze(1).expand(batch, samples, parse_length, parse_length)\n    crf_input = crf_input.contiguous().view(batch*samples, parse_length, parse_length)\n    for i in range(len(self.q_crf.alpha)):\n      for j in range(len(self.q_crf.alpha)):\n        self.q_crf.alpha[i][j] = self.q_crf.alpha[i][j].unsqueeze(1).expand(\n          batch, samples).contiguous().view(batch*samples)        \n    _, log_probs_action_q, tree_brackets, spans = self.q_crf._sample(crf_input, self.q_crf.alpha)\n    actions = []\n    for b in range(crf_input.size(0)):    \n      action = get_actions(tree_brackets[b])\n      if has_eos:\n        actions.append(action + [self.S, self.R]) #we train the model to generate <s> and 
then do a final reduce\n      else:\n        actions.append(action)\n    actions = torch.Tensor(actions).float().cuda()\n    action_masks = self.get_action_masks(actions, length) \n    num_action = 2*length - 1\n    batch_expand = batch*samples\n    contexts = []\n    log_probs_action_p = [] #conditional prior\n    init_emb = init_emb.unsqueeze(1).expand(batch, samples, self.w_dim)\n    init_emb = init_emb.contiguous().view(batch_expand, self.w_dim)\n    init_stack = self.stack_rnn(init_emb, None)\n    x_expand = x.unsqueeze(1).expand(batch, samples, length)\n    x_expand = x_expand.contiguous().view(batch_expand, length)\n    word_vecs = self.dropout(self.emb(x_expand))\n    word_vecs = word_vecs.unsqueeze(2)\n    word_vecs_zeros = torch.zeros_like(word_vecs)\n    stack = [init_stack]\n    stack_child = [[] for _ in range(batch_expand)]\n    stack2 = [[] for _ in range(batch_expand)]\n    for b in range(batch_expand):\n      stack2[b].append([[init_stack[l][0][b], init_stack[l][1][b]] for l in range(self.num_layers)])\n    pointer = [0]*batch_expand\n    for l in range(num_action):\n      contexts.append(stack[-1][-1][0])\n      stack_input = []\n      child1_h = []\n      child1_c = []\n      child2_h = []\n      child2_c = []\n      stack_context = []\n      for b in range(batch_expand):\n        # batch all the shift/reduce operations separately\n        if actions[b][l].item() == self.R:\n          child1 = stack_child[b].pop()\n          child2 = stack_child[b].pop()\n          child1_h.append(child1[0])\n          child1_c.append(child1[1])\n          child2_h.append(child2[0])\n          child2_c.append(child2[1])\n          stack2[b].pop()\n          stack2[b].pop()\n      if len(child1_h) > 0:\n        child1_h = torch.cat(child1_h, 0)\n        child1_c = torch.cat(child1_c, 0)\n        child2_h = torch.cat(child2_h, 0)\n        child2_c = torch.cat(child2_c, 0)\n        new_child = self.tree_rnn((child1_h, child1_c), (child2_h, child2_c))\n\n      
child_idx = 0\n      stack_h = [[[], []] for _ in range(self.num_layers)]\n      for b in range(batch_expand):\n        assert(len(stack2[b]) - 1 == len(stack_child[b]))\n        for k in range(self.num_layers):\n          stack_h[k][0].append(stack2[b][-1][k][0])\n          stack_h[k][1].append(stack2[b][-1][k][1])\n        if actions[b][l].item() == self.S:          \n          input_b = word_vecs[b][pointer[b]]\n          stack_child[b].append((word_vecs[b][pointer[b]], word_vecs_zeros[b][pointer[b]]))\n          pointer[b] += 1          \n        else:\n          input_b = new_child[0][child_idx].unsqueeze(0)\n          stack_child[b].append((input_b, new_child[1][child_idx].unsqueeze(0)))\n          child_idx += 1\n        stack_input.append(input_b)\n      stack_input = torch.cat(stack_input, 0)\n      stack_h_all = []\n      for k in range(self.num_layers):\n        stack_h_all.append((torch.stack(stack_h[k][0], 0), torch.stack(stack_h[k][1], 0)))\n      stack_h = self.stack_rnn(stack_input, stack_h_all)\n      stack.append(stack_h)\n      for b in range(batch_expand):\n        stack2[b].append([[stack_h[k][0][b], stack_h[k][1][b]] for k in range(self.num_layers)])\n      \n    contexts = torch.stack(contexts, 1) #stack contexts\n    action_logit_p = self.action_mlp_p(contexts).squeeze(2) \n    action_prob_p = F.sigmoid(action_logit_p).clamp(min=1e-7, max=1-1e-7)\n    action_shift_score = (1 - action_prob_p).log()\n    action_reduce_score = action_prob_p.log()\n    action_score = (1-actions)*action_shift_score + actions*action_reduce_score\n    action_score = (action_score*action_masks).sum(1)\n    \n    word_contexts = contexts[actions < 1]\n    word_contexts = word_contexts.contiguous().view(batch*samples, length, self.h_dim)\n\n    log_probs_word = F.log_softmax(self.vocab_mlp(word_contexts), 2)\n    log_probs_word = torch.gather(log_probs_word, 2, x_expand.unsqueeze(2)).squeeze(2)\n    log_probs_word = log_probs_word.sum(1)\n    log_probs_word = 
log_probs_word.contiguous().view(batch, samples)\n    log_probs_action_p = action_score.contiguous().view(batch, samples)\n    log_probs_action_q = log_probs_action_q.contiguous().view(batch, samples)\n    actions = actions.contiguous().view(batch, samples, -1)\n    return log_probs_word, log_probs_action_p, log_probs_action_q, actions, entropy\n\n  def forward_actions(self, x, actions, has_eos=True):\n    # this is for when ground-truth actions are available\n    init_emb = self.dropout(self.emb(x[:, 0]))\n    x = x[:, 1:]    \n    if has_eos:\n      new_actions = []\n      for action in actions:\n        new_actions.append(action + [self.S, self.R])\n      actions = new_actions\n    batch, length = x.size(0), x.size(1)\n    word_vecs =  self.dropout(self.emb(x))\n    actions = torch.Tensor(actions).float().cuda()\n    action_masks = self.get_action_masks(actions, length)\n    num_action = 2*length - 1\n    contexts = []\n    log_probs_action_p = [] #prior\n    init_stack = self.stack_rnn(init_emb, None)\n    word_vecs = word_vecs.unsqueeze(2)\n    word_vecs_zeros = torch.zeros_like(word_vecs)\n    stack = [init_stack]\n    stack_child = [[] for _ in range(batch)]\n    stack2 = [[] for _ in range(batch)]\n    pointer = [0]*batch\n    for b in range(batch):\n      stack2[b].append([[init_stack[l][0][b], init_stack[l][1][b]] for l in range(self.num_layers)])\n    for l in range(num_action):\n      contexts.append(stack[-1][-1][0])\n      stack_input = []\n      child1_h = []\n      child1_c = []\n      child2_h = []\n      child2_c = []\n      stack_context = []\n      for b in range(batch):\n        if actions[b][l].item() == self.R:\n          child1 = stack_child[b].pop()\n          child2 = stack_child[b].pop()\n          child1_h.append(child1[0])\n          child1_c.append(child1[1])\n          child2_h.append(child2[0])\n          child2_c.append(child2[1])\n          stack2[b].pop()\n          stack2[b].pop()\n      if len(child1_h) > 0:\n        child1_h 
= torch.cat(child1_h, 0)\n        child1_c = torch.cat(child1_c, 0)\n        child2_h = torch.cat(child2_h, 0)\n        child2_c = torch.cat(child2_c, 0)\n        new_child = self.tree_rnn((child1_h, child1_c), (child2_h, child2_c))\n      child_idx = 0\n      stack_h = [[[], []] for _ in range(self.num_layers)]\n      for b in range(batch):\n        assert(len(stack2[b]) - 1 == len(stack_child[b]))\n        for k in range(self.num_layers):\n          stack_h[k][0].append(stack2[b][-1][k][0])\n          stack_h[k][1].append(stack2[b][-1][k][1])\n        if actions[b][l].item() == self.S:          \n          input_b = word_vecs[b][pointer[b]]\n          stack_child[b].append((word_vecs[b][pointer[b]], word_vecs_zeros[b][pointer[b]]))\n          pointer[b] += 1          \n        else:\n          input_b = new_child[0][child_idx].unsqueeze(0)\n          stack_child[b].append((input_b, new_child[1][child_idx].unsqueeze(0)))\n          child_idx += 1\n        stack_input.append(input_b)\n      stack_input = torch.cat(stack_input, 0)\n      stack_h_all = []\n      for k in range(self.num_layers):\n        stack_h_all.append((torch.stack(stack_h[k][0], 0), torch.stack(stack_h[k][1], 0)))\n      stack_h = self.stack_rnn(stack_input, stack_h_all)\n      stack.append(stack_h)\n      for b in range(batch):\n        stack2[b].append([[stack_h[k][0][b], stack_h[k][1][b]] for k in range(self.num_layers)])\n    contexts = torch.stack(contexts, 1)\n    action_logit_p = self.action_mlp_p(contexts).squeeze(2)\n    action_prob_p = F.sigmoid(action_logit_p).clamp(min=1e-7, max=1-1e-7)\n    action_shift_score = (1 - action_prob_p).log()\n    action_reduce_score = action_prob_p.log()\n    action_score = (1-actions)*action_shift_score + actions*action_reduce_score\n    action_score = (action_score*action_masks).sum(1)\n    \n    word_contexts = contexts[actions < 1]\n    word_contexts = word_contexts.contiguous().view(batch, length, self.h_dim)\n    log_probs_word = 
F.log_softmax(self.vocab_mlp(word_contexts), 2)\n    log_probs_word = torch.gather(log_probs_word, 2, x.unsqueeze(2)).squeeze(2).sum(1)\n    log_probs_action_p = action_score.contiguous().view(batch)\n    actions = actions.contiguous().view(batch, 1, -1)\n    return log_probs_word, log_probs_action_p, actions\n  \n  def forward_tree(self, x, actions, has_eos=True):\n    # this is log q( tree | x) for discriminative parser training in supervised RNNG\n    init_emb = self.dropout(self.emb(x[:, 0]))\n    x = x[:, 1:-1]\n    batch, length = x.size(0), x.size(1)\n    scores = self.get_span_scores(x)\n    crf_input = scores\n    gold_spans = scores.new(batch, length, length)\n    for b in range(batch):\n      gold_spans[b].copy_(torch.eye(length).cuda())\n      spans = get_spans(actions[b])\n      for span in spans:\n        gold_spans[b][span[0]][span[1]] = 1\n    self.q_crf._forward(crf_input)\n    log_Z = self.q_crf.alpha[0][length-1]\n    span_scores = (gold_spans*scores).sum(2).sum(1)\n    ll_action_q = span_scores - log_Z\n    return ll_action_q\n    \n  def logsumexp(self, x, dim=1):\n    d = torch.max(x, dim)[0]    \n    if x.dim() == 1:\n      return torch.log(torch.exp(x - d).sum(dim)) + d\n    else:\n      return torch.log(torch.exp(x - d.unsqueeze(dim).expand_as(x)).sum(dim)) + d    \n    \n"
  },
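`get_action_masks` in `models.py` zeroes out action positions that are forced: with fewer than two completed constituents on the stack only SHIFT is legal, so no action loss should be incurred there. A standalone sketch of that rule over a single action sequence (0 = SHIFT, 1 = REDUCE, matching `self.S`/`self.R`):

```python
def action_mask(actions):
    # actions: list of 0 (SHIFT) / 1 (REDUCE).
    # mask[l] = 0 where the action is deterministic because the stack
    # holds fewer than 2 constituents; SHIFT grows the stack by 1,
    # REDUCE pops two children and pushes one, shrinking it by 1.
    mask, stack_len = [], 0
    for a in actions:
        mask.append(0 if stack_len < 2 else 1)
        stack_len += 1 if a == 0 else -1
    return mask
```

For a binary tree over n words there are exactly 2n - 1 actions, which is why `forward` loops over `num_action = 2*length - 1` steps.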
  {
    "path": "parse.py",
    "content": "#!/usr/bin/env python3\nimport sys\nimport os\n\nimport argparse\nimport json\nimport random\nimport shutil\nimport copy\n\nimport torch\nfrom torch import cuda\nimport torch.nn as nn\nimport numpy as np\nimport time\nfrom utils import *\nimport utils\nimport re\n\nparser = argparse.ArgumentParser()\n\n# Data path options\nparser.add_argument('--data_file', default='ptb-test.txt')\nparser.add_argument('--model_file', default='urnng.pt')\nparser.add_argument('--out_file', default='pred-parse.txt')\nparser.add_argument('--gold_out_file', default='gold-parse.txt')\nparser.add_argument('--lowercase', type=int, default=0)\nparser.add_argument('--replace_num', type=int, default=0)\n\n# Inference options\nparser.add_argument('--gpu', default=0, type=int, help='which gpu to use')\n\ndef is_next_open_bracket(line, start_idx):\n    for char in line[(start_idx + 1):]:\n        if char == '(':\n            return True\n        elif char == ')':\n            return False\n    raise IndexError('Bracket possibly not balanced, open bracket not followed by closed bracket')    \n\ndef get_between_brackets(line, start_idx):\n    output = []\n    for char in line[(start_idx + 1):]:\n        if char == ')':\n            break\n        assert not(char == '(')\n        output.append(char)    \n    return ''.join(output)\n\ndef get_tags_tokens_lowercase(line):\n    output = []\n    line_strip = line.rstrip()\n    for i in range(len(line_strip)):\n        if i == 0:\n            assert line_strip[i] == '('    \n        if line_strip[i] == '(' and not(is_next_open_bracket(line_strip, i)): # fulfilling this condition means this is a terminal symbol\n            output.append(get_between_brackets(line_strip, i))\n    #print 'output:',output\n    output_tags = []\n    output_tokens = []\n    output_lowercase = []\n    for terminal in output:\n        terminal_split = terminal.split()\n        assert len(terminal_split) == 2 # each terminal contains a POS tag and word        \n  
      output_tags.append(terminal_split[0])\n        output_tokens.append(terminal_split[1])\n        output_lowercase.append(terminal_split[1].lower())\n    return [output_tags, output_tokens, output_lowercase]    \n\ndef get_nonterminal(line, start_idx):\n    assert line[start_idx] == '(' # make sure it's an open bracket\n    output = []\n    for char in line[(start_idx + 1):]:\n        if char == ' ':\n            break\n        assert not(char == '(') and not(char == ')')\n        output.append(char)\n    return ''.join(output)\n\n\ndef get_actions(line):\n    output_actions = []\n    line_strip = line.rstrip()\n    i = 0\n    max_idx = (len(line_strip) - 1)\n    while i <= max_idx:\n        assert line_strip[i] == '(' or line_strip[i] == ')'\n        if line_strip[i] == '(':\n            if is_next_open_bracket(line_strip, i): # open non-terminal\n                curr_NT = get_nonterminal(line_strip, i)\n                output_actions.append('NT(' + curr_NT + ')')\n                i += 1  \n                while line_strip[i] != '(': # get the next open bracket, which may be a terminal or another non-terminal\n                    i += 1\n            else: # it's a terminal symbol\n                output_actions.append('SHIFT')\n                while line_strip[i] != ')':\n                    i += 1\n                i += 1\n                while line_strip[i] != ')' and line_strip[i] != '(':\n                    i += 1\n        else:\n             output_actions.append('REDUCE')\n             if i == max_idx:\n                 break\n             i += 1\n             while line_strip[i] != ')' and line_strip[i] != '(':\n                 i += 1\n    assert i == max_idx  \n    return output_actions\n\ndef clean_number(w):    \n    new_w = re.sub('[0-9]{1,}([,.]?[0-9]*)*', 'N', w)\n    return new_w\n  \ndef main(args):\n  print('loading model from ' + args.model_file)\n  checkpoint = torch.load(args.model_file)\n  model = checkpoint['model']\n  word2idx = 
checkpoint['word2idx']\n  cuda.set_device(args.gpu)\n  model.eval()\n  model.cuda()\n  corpus_f1 = [0., 0., 0.] \n  sent_f1 = [] \n  pred_out = open(args.out_file, \"w\")\n  gold_out = open(args.gold_out_file, \"w\")\n  with torch.no_grad():\n    for j, gold_tree in enumerate(open(args.data_file, \"r\")):\n      tree = gold_tree.strip()\n      action = get_actions(tree)\n      tags, sent, sent_lower = get_tags_tokens_lowercase(tree)\n      sent_orig = sent[::]\n      if args.lowercase == 1:\n          sent = sent_lower\n      gold_span, binary_actions, nonbinary_actions = get_nonbinary_spans(action)\n      length = len(sent)\n      if args.replace_num == 1:\n          sent = [clean_number(w) for w in sent]\n      if length == 1:\n        continue # we ignore length 1 sents. this doesn't change F1 since we discard trivial spans\n      sent_idx = [word2idx[\"<s>\"]] + [word2idx[w] if w in word2idx else word2idx[\"<unk>\"] for w in sent]\n      sents = torch.from_numpy(np.array(sent_idx)).unsqueeze(0)\n      sents = sents.cuda()\n      ll_word_all, ll_action_p_all, ll_action_q_all, actions_all, q_entropy = model(\n          sents, samples = 1, is_temp = 1, has_eos = False)\n      _, binary_matrix, argmax_spans = model.q_crf._viterbi(model.scores)\n      tree = get_tree_from_binary_matrix(binary_matrix[0], len(sent))\n      actions = utils.get_actions(tree)\n      pred_span = [(a[0], a[1]) for a in argmax_spans[0]]\n      pred_span_set = set(pred_span[:-1]) #the last span in the list is always the\n      gold_span_set = set(gold_span[:-1]) #trivial sent-level span so we ignore it\n      tp, fp, fn = get_stats(pred_span_set, gold_span_set) \n      corpus_f1[0] += tp\n      corpus_f1[1] += fp\n      corpus_f1[2] += fn\n      binary_matrix = binary_matrix[0].cpu().numpy()\n      pred_tree = {}\n      for i in range(length):\n        tag = tags[i] # need gold tags so evalb correctly ignores punctuation\n        pred_tree[i] = \"(\" + tag + \" \" + sent_orig[i] + \")\"\n
      for k in np.arange(1, length):\n        for s in np.arange(length):\n          t = s + k\n          if t > length - 1: break\n          if binary_matrix[s][t] == 1:\n            nt = \"NT-1\" \n            span = \"(\" + nt + \" \" + pred_tree[s] + \" \" + pred_tree[t] +  \")\"\n            pred_tree[s] = span\n            pred_tree[t] = span\n      pred_tree = pred_tree[0]\n      pred_out.write(pred_tree.strip() + \"\\n\")\n      gold_out.write(gold_tree.strip() + \"\\n\")\n      print(pred_tree)\n      # sent-level F1 is based on L83-89 from https://github.com/yikangshen/PRPN/test_phrase_grammar.py\n      overlap = pred_span_set.intersection(gold_span_set)\n      prec = float(len(overlap)) / (len(pred_span_set) + 1e-8)\n      reca = float(len(overlap)) / (len(gold_span_set) + 1e-8)\n      if len(gold_span_set) == 0:\n        reca = 1. \n        if len(pred_span_set) == 0:\n          prec = 1.\n      f1 = 2 * prec * reca / (prec + reca + 1e-8)\n      sent_f1.append(f1)\n  pred_out.close()\n  gold_out.close()\n  tp, fp, fn = corpus_f1  \n  prec = tp / (tp + fp)\n  recall = tp / (tp + fn)\n  corpus_f1 = 2*prec*recall/(prec+recall) if prec+recall > 0 else 0.\n  print('Corpus F1: %.2f, Sentence F1: %.2f' %\n        (corpus_f1*100, np.mean(np.array(sent_f1))*100))\n\nif __name__ == '__main__':\n  args = parser.parse_args()\n  main(args)\n"
  },
  {
    "path": "preprocess.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\n\"\"\"Create data files\n\"\"\"\n\nimport os\nimport sys\nimport argparse\nimport numpy as np\nimport pickle\nimport itertools\nfrom collections import defaultdict\nimport utils\nimport re\n\nclass Indexer:\n    def __init__(self, symbols = [\"<pad>\",\"<unk>\",\"<s>\",\"</s>\"]):\n        self.vocab = defaultdict(int)\n        self.PAD = symbols[0]\n        self.UNK = symbols[1]\n        self.BOS = symbols[2]\n        self.EOS = symbols[3]\n        self.d = {self.PAD: 0, self.UNK: 1, self.BOS: 2, self.EOS: 3}\n        self.idx2word = {}\n        \n    def add_w(self, ws):\n        for w in ws:\n            if w not in self.d:\n                self.d[w] = len(self.d)\n\n    def convert(self, w):\n        return self.d[w] if w in self.d else self.d[self.UNK]\n\n    def convert_sequence(self, ls):\n        return [self.convert(l) for l in ls]\n\n    def write(self, outfile):\n        out = open(outfile, \"w\")\n        items = [(v, k) for k, v in self.d.items()]\n        items.sort()\n        for v, k in items:\n            out.write(\" \".join([k, str(v)]) + \"\\n\")\n        out.close()\n\n    def prune_vocab(self, k, cnt = False):\n        vocab_list = [(word, count) for word, count in self.vocab.items()]\n        if cnt:\n            self.pruned_vocab = {pair[0]:pair[1] for pair in vocab_list if pair[1] > k}\n        else:\n            vocab_list.sort(key = lambda x: x[1], reverse=True)\n            k = min(k, len(vocab_list))\n            self.pruned_vocab = {pair[0]:pair[1] for pair in vocab_list[:k]}\n        for word in self.pruned_vocab:\n            if word not in self.d:\n                self.d[word] = len(self.d)\n        for word, idx in self.d.items():\n            self.idx2word[idx] = word\n\n    def load_vocab(self, vocab_file):\n        self.d = {}\n        self.idx2word = {}\n        for line in open(vocab_file, 'r'):\n            v, k = line.strip().split()\n            self.d[v] = 
int(k)\n        for word, idx in self.d.items():\n            self.idx2word[idx] = word\n\n\ndef is_next_open_bracket(line, start_idx):\n    for char in line[(start_idx + 1):]:\n        if char == '(':\n            return True\n        elif char == ')':\n            return False\n    raise IndexError('Bracket possibly not balanced, open bracket not followed by closed bracket')    \n\ndef get_between_brackets(line, start_idx):\n    output = []\n    for char in line[(start_idx + 1):]:\n        if char == ')':\n            break\n        assert not(char == '(')\n        output.append(char)    \n    return ''.join(output)\n\ndef get_tags_tokens_lowercase(line):\n    output = []\n    line_strip = line.rstrip()\n    for i in range(len(line_strip)):\n        if i == 0:\n            assert line_strip[i] == '('    \n        if line_strip[i] == '(' and not(is_next_open_bracket(line_strip, i)): # fulfilling this condition means this is a terminal symbol\n            output.append(get_between_brackets(line_strip, i))\n    #print 'output:',output\n    output_tags = []\n    output_tokens = []\n    output_lowercase = []\n    for terminal in output:\n        terminal_split = terminal.split()\n        # print(terminal, terminal_split)\n        assert len(terminal_split) == 2 # each terminal contains a POS tag and word        \n        output_tags.append(terminal_split[0])\n        output_tokens.append(terminal_split[1])\n        output_lowercase.append(terminal_split[1].lower())\n    return [output_tags, output_tokens, output_lowercase]    \n\ndef get_nonterminal(line, start_idx):\n    assert line[start_idx] == '(' # make sure it's an open bracket\n    output = []\n    for char in line[(start_idx + 1):]:\n        if char == ' ':\n            break\n        assert not(char == '(') and not(char == ')')\n        output.append(char)\n    return ''.join(output)\n\n\ndef get_actions(line):\n    output_actions = []\n    line_strip = line.rstrip()\n    i = 0\n    max_idx = (len(line_strip) 
- 1)\n    while i <= max_idx:\n        assert line_strip[i] == '(' or line_strip[i] == ')'\n        if line_strip[i] == '(':\n            if is_next_open_bracket(line_strip, i): # open non-terminal\n                curr_NT = get_nonterminal(line_strip, i)\n                output_actions.append('NT(' + curr_NT + ')')\n                i += 1  \n                while line_strip[i] != '(': # get the next open bracket, which may be a terminal or another non-terminal\n                    i += 1\n            else: # it's a terminal symbol\n                output_actions.append('SHIFT')\n                while line_strip[i] != ')':\n                    i += 1\n                i += 1\n                while line_strip[i] != ')' and line_strip[i] != '(':\n                    i += 1\n        else:\n             output_actions.append('REDUCE')\n             if i == max_idx:\n                 break\n             i += 1\n             while line_strip[i] != ')' and line_strip[i] != '(':\n                 i += 1\n    assert i == max_idx  \n    return output_actions\n\ndef pad(ls, length, symbol):\n    if len(ls) >= length:\n        return ls[:length]\n    return ls + [symbol] * (length -len(ls))\n\ndef clean_number(w):    \n    new_w = re.sub('[0-9]{1,}([,.]?[0-9]*)*', 'N', w)\n    return new_w\n\ndef get_data(args):\n    indexer = Indexer([\"<pad>\",\"<unk>\",\"<s>\",\"</s>\"])\n\n    def make_vocab(textfile, seqlength, minseqlength, lowercase, replace_num,\n                   train=1, apply_length_filter=1):\n        num_sents = 0\n        max_seqlength = 0\n        for tree in open(textfile, 'r'):\n            tree = tree.strip()\n            tags, sent, sent_lower = get_tags_tokens_lowercase(tree)\n            \n            assert(len(tags) == len(sent))\n            if lowercase == 1:\n                sent = sent_lower\n            if replace_num == 1:\n                sent = [clean_number(w) for w in sent]\n            if (len(sent) > seqlength and apply_length_filter == 1) or 
len(sent) < minseqlength:\n                continue\n            num_sents += 1\n            max_seqlength = max(max_seqlength, len(sent))\n            if train == 1:\n                for word in sent:\n                    indexer.vocab[word] += 1\n        return num_sents, max_seqlength\n\n    def convert(textfile, lowercase, replace_num,  \n                batchsize, seqlength, minseqlength, outfile, num_sents, max_sent_l=0,\n                shuffle=0, include_boundary=1, apply_length_filter=1):\n        newseqlength = seqlength\n        if include_boundary == 1:\n            newseqlength += 2 #add 2 for EOS and BOS\n        sents = np.zeros((num_sents, newseqlength), dtype=int)\n        sent_lengths = np.zeros((num_sents,), dtype=int)\n        dropped = 0\n        sent_id = 0\n        other_data = []\n        for tree in open(textfile, 'r'):\n            tree = tree.strip()\n            action = get_actions(tree)\n            tags, sent, sent_lower = get_tags_tokens_lowercase(tree)\n            assert(len(tags) == len(sent))\n            if lowercase == 1:\n                sent = sent_lower\n            if (len(sent) > seqlength and apply_length_filter == 1) or len(sent) < minseqlength:\n                dropped += 1 # count filtered sentences so the summary printout is accurate\n                continue\n            sent_str = \" \".join(sent)\n            if replace_num == 1:\n                sent = [clean_number(w) for w in sent]\n            if include_boundary == 1:\n                sent = [indexer.BOS] + sent + [indexer.EOS]\n            max_sent_l = max(len(sent), max_sent_l)\n            sent_pad = pad(sent, newseqlength, indexer.PAD)\n            sents[sent_id] = np.array(indexer.convert_sequence(sent_pad), dtype=int)\n            sent_lengths[sent_id] = (sents[sent_id] != 0).sum()\n            span, binary_actions, nonbinary_actions = utils.get_nonbinary_spans(action)\n            other_data.append([sent_str, tags, action, \n                               binary_actions, nonbinary_actions, span, tree])\n            assert(2*(len(sent)- 2) - 1 
== len(binary_actions))\n            assert(sum(binary_actions) + 1 == len(sent) - 2)\n            sent_id += 1\n            if sent_id % 100000 == 0:\n                print(\"{}/{} sentences processed\".format(sent_id, num_sents))\n        print(sent_id, num_sents)\n        if shuffle == 1:\n            rand_idx = np.random.permutation(sent_id)\n            sents = sents[rand_idx]\n            sent_lengths = sent_lengths[rand_idx]\n            other_data = [other_data[idx] for idx in rand_idx]\n\n        print(len(sents), len(other_data))\n        #break up batches based on source lengths\n        sent_lengths = sent_lengths[:sent_id]\n        sent_sort = np.argsort(sent_lengths)\n        sents = sents[sent_sort]\n        other_data = [other_data[idx] for idx in sent_sort]\n        sent_l = sent_lengths[sent_sort]\n        curr_l = 1\n        l_location = [] #idx where sent length changes\n\n        for j,i in enumerate(sent_sort):\n            if sent_lengths[i] > curr_l:\n                curr_l = sent_lengths[i]\n                l_location.append(j)\n        l_location.append(len(sents))\n        #get batch sizes\n        curr_idx = 0\n        batch_idx = [0]\n        nonzeros = []\n        batch_l = []\n        batch_w = []\n        for i in range(len(l_location)-1):\n            while curr_idx < l_location[i+1]:\n                curr_idx = min(curr_idx + batchsize, l_location[i+1])\n                batch_idx.append(curr_idx)\n        for i in range(len(batch_idx)-1):\n            batch_l.append(batch_idx[i+1] - batch_idx[i])\n            batch_w.append(sent_l[batch_idx[i]])\n\n        # Write output\n        f = {}\n        f[\"source\"] = sents\n        f[\"other_data\"] = other_data\n        f[\"batch_l\"] = np.array(batch_l, dtype=int)\n        f[\"source_l\"] = np.array(batch_w, dtype=int)\n        f[\"sents_l\"]  = np.array(sent_l, dtype = int)\n        f[\"batch_idx\"] = np.array(batch_idx[:-1], dtype=int)\n        f[\"vocab_size\"] = 
np.array([len(indexer.d)])\n        f[\"idx2word\"] = indexer.idx2word\n        f[\"word2idx\"] = {word : idx for idx, word in indexer.idx2word.items()}\n        \n        print(\"Saved {} sentences (dropped {} due to length/unk filter)\".format(\n            len(f[\"source\"]), dropped))\n        pickle.dump(f, open(outfile, 'wb'))\n        return max_sent_l\n\n    print(\"First pass through data to get vocab...\")\n    num_sents_train, train_seqlength = make_vocab(args.trainfile, args.seqlength, args.minseqlength,\n                                                  args.lowercase, args.replace_num, 1, 1)\n    print(\"Number of sentences in training: {}\".format(num_sents_train))\n    num_sents_valid, valid_seqlength = make_vocab(args.valfile, args.seqlength, args.minseqlength, \n                                                  args.lowercase, args.replace_num, 0, 0)\n    print(\"Number of sentences in valid: {}\".format(num_sents_valid))\n    num_sents_test, test_seqlength = make_vocab(args.testfile, args.seqlength, args.minseqlength, \n                                                args.lowercase, args.replace_num, 0, 0)\n    print(\"Number of sentences in test: {}\".format(num_sents_test))\n\n    if args.vocabminfreq >= 0:\n        indexer.prune_vocab(args.vocabminfreq, True)        \n    else:\n        indexer.prune_vocab(args.vocabsize, False)\n    if args.vocabfile != '':\n        print('Loading pre-specified source vocab from ' + args.vocabfile)\n        indexer.load_vocab(args.vocabfile)\n    indexer.write(args.outputfile + \".dict\")\n    print(\"Vocab size: Original = {}, Pruned = {}\".format(len(indexer.vocab),\n                                                          len(indexer.d)))\n    print(train_seqlength, valid_seqlength, test_seqlength)\n    max_sent_l = 0\n    max_sent_l = convert(args.testfile, args.lowercase, args.replace_num, \n                         args.batchsize, test_seqlength, args.minseqlength, \n                         
args.outputfile + \"-test.pkl\", num_sents_test,\n                         max_sent_l, args.shuffle, args.include_boundary, 0)\n    max_sent_l = convert(args.valfile, args.lowercase, args.replace_num, \n                         args.batchsize, valid_seqlength, args.minseqlength, \n                         args.outputfile + \"-val.pkl\", num_sents_valid,\n                         max_sent_l, args.shuffle, args.include_boundary, 0)\n    max_sent_l = convert(args.trainfile, args.lowercase, args.replace_num, \n                         args.batchsize, args.seqlength,  args.minseqlength,\n                         args.outputfile + \"-train.pkl\", num_sents_train,\n                         max_sent_l, args.shuffle, args.include_boundary, 1)\n    print(\"Max sent length (before dropping): {}\".format(max_sent_l))\n\ndef main(arguments):\n    parser = argparse.ArgumentParser(\n        description=__doc__,\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    parser.add_argument('--vocabsize', help=\"Size of source vocabulary, constructed \"\n                                                \"by taking the top X most frequent words. \"\n                                                \" Rest are replaced with special UNK tokens.\",\n                                                type=int, default=10000)\n    parser.add_argument('--vocabminfreq', help=\"Minimum frequency for vocab. 
Use this instead of \"\n                                                \"vocabsize if >= 0\",\n                                                type=int, default=1)\n    parser.add_argument('--include_boundary', help=\"Add BOS/EOS tokens\", type=int, default=1)        \n    parser.add_argument('--lowercase', help=\"Lower case\", type=int, default=0)        \n    parser.add_argument('--replace_num', help=\"Replace numbers with N\", type=int, default=0)        \n    parser.add_argument('--trainfile', help=\"Path to training data.\", required=True)\n    parser.add_argument('--valfile', help=\"Path to validation data.\", required=True)\n    parser.add_argument('--testfile', help=\"Path to test data.\", required=True)\n    parser.add_argument('--batchsize', help=\"Size of each minibatch.\", type=int, default=16)\n    parser.add_argument('--seqlength', help=\"Maximum sequence length. Sequences longer \"\n                                               \"than this are dropped.\", type=int, default=200)\n    parser.add_argument('--minseqlength', help=\"Minimum sequence length. Sequences shorter \"\n                                               \"than this are dropped.\", type=int, default=0)\n    parser.add_argument('--outputfile', help=\"Prefix of the output file names. 
\", type=str,\n                        required=True)\n    parser.add_argument('--vocabfile', help=\"If working with a preset vocab, \"\n                                          \"then including this will ignore srcvocabsize and use the\"\n                                          \"vocab provided here.\",\n                                          type = str, default='')\n    parser.add_argument('--shuffle', help=\"If = 1, shuffle sentences before sorting (based on  \"\n                                           \"source length).\",\n                                          type = int, default = 1)\n    args = parser.parse_args(arguments)\n    np.random.seed(3435)\n    get_data(args)\n\nif __name__ == '__main__':\n    sys.exit(main(sys.argv[1:]))\n"
  },
  {
    "path": "train.py",
    "content": "#!/usr/bin/env python3\nimport sys\nimport os\n\nimport argparse\nimport json\nimport random\nimport shutil\nimport copy\n\nimport torch\nfrom torch import cuda\nimport torch.nn as nn\nfrom torch.autograd import Variable\nfrom torch.nn.parameter import Parameter\n\nimport torch.nn.functional as F\nimport numpy as np\nimport time\nimport logging\nfrom data import Dataset\nfrom models import RNNG\nfrom utils import *\n\nparser = argparse.ArgumentParser()\n\n# Data path options\nparser.add_argument('--train_file', default='data/ptb-1unk-train.pkl')\nparser.add_argument('--val_file', default='data/ptb-1unk-val.pkl')\nparser.add_argument('--train_from', default='')\n# Model options\nparser.add_argument('--w_dim', default=650, type=int, help='hidden dimension for LM/RNNG')\nparser.add_argument('--h_dim', default=650, type=int, help='hidden dimension for LM/RNNG')\nparser.add_argument('--q_dim', default=256, type=int, help='hidden dimension for variational RNN')\nparser.add_argument('--num_layers', default=2, type=int, help='number of layers in LM and the stack LSTM (for RNNG)')\nparser.add_argument('--dropout', default=0.5, type=float, help='dropout rate')\n# Optimization options\nparser.add_argument('--count_eos_ppl', default=0, type=int, help='whether to count eos in val PPL')\nparser.add_argument('--save_path', default='urnng.pt', help='where to save the data')\nparser.add_argument('--num_epochs', default=18, type=int, help='number of training epochs')\nparser.add_argument('--min_epochs', default=8, type=int, help='do not decay learning rate for at least this many epochs')\nparser.add_argument('--mode', default='unsupervised', type=str, choices=['unsupervised', 'supervised'])\nparser.add_argument('--mc_samples', default=5, type=int, \n                    help='how many samples for IWAE bound calc for evaluation')\nparser.add_argument('--samples', default=8, type=int, \n                    help='how many samples for score function 
gradients')\nparser.add_argument('--lr', default=1, type=float, help='starting learning rate')\nparser.add_argument('--q_lr', default=0.0001, type=float, help='learning rate for inference network q')\nparser.add_argument('--action_lr', default=0.1, type=float, help='learning rate for action layer')\nparser.add_argument('--decay', default=0.5, type=float, help='')\nparser.add_argument('--kl_warmup', default=2, type=int, help='')\nparser.add_argument('--train_q_epochs', default=2, type=int, help='')\nparser.add_argument('--param_init', default=0.1, type=float, help='parameter initialization (over uniform)')\nparser.add_argument('--max_grad_norm', default=5, type=float, help='gradient clipping parameter')\nparser.add_argument('--q_max_grad_norm', default=1, type=float, help='gradient clipping parameter for q')\nparser.add_argument('--gpu', default=2, type=int, help='which gpu to use')\nparser.add_argument('--seed', default=3435, type=int, help='random seed')\nparser.add_argument('--print_every', type=int, default=500, help='print stats after this many batches')\n\n\ndef main(args):\n  np.random.seed(args.seed)\n  torch.manual_seed(args.seed)\n  train_data = Dataset(args.train_file)\n  val_data = Dataset(args.val_file)  \n  vocab_size = int(train_data.vocab_size)    \n  print('Train: %d sents / %d batches, Val: %d sents / %d batches' % \n        (train_data.sents.size(0), len(train_data), val_data.sents.size(0), \n         len(val_data)))\n  print('Vocab size: %d' % vocab_size)\n  cuda.set_device(args.gpu)\n  if args.train_from == '':\n    model = RNNG(vocab = vocab_size,\n                 w_dim = args.w_dim, \n                 h_dim = args.h_dim,\n                 dropout = args.dropout,\n                 num_layers = args.num_layers,\n                 q_dim = args.q_dim)\n    if args.param_init > 0:\n      for param in model.parameters():    \n        param.data.uniform_(-args.param_init, args.param_init)      \n  else:\n    print('loading model from ' + 
args.train_from)\n    checkpoint = torch.load(args.train_from)\n    model = checkpoint['model']\n  print(\"model architecture\")\n  print(model)\n  q_params = []\n  action_params = []\n  model_params = []\n  for name, param in model.named_parameters():    \n    if 'action' in name:\n      print(name)\n      action_params.append(param)\n    elif 'q_' in name:\n      print(name)\n      q_params.append(param)\n    else:\n      model_params.append(param)\n  q_lr = args.q_lr\n  optimizer = torch.optim.SGD(model_params, lr=args.lr)\n  q_optimizer = torch.optim.Adam(q_params, lr=q_lr)\n  action_optimizer = torch.optim.SGD(action_params, lr=args.action_lr)\n  model.train()\n  model.cuda()\n\n  epoch = 0\n  decay= 0\n  if args.kl_warmup > 0:\n    kl_pen = 0.\n    kl_warmup_batch = 1./(args.kl_warmup * len(train_data))\n  else:\n    kl_pen = 1.\n  best_val_ppl = 5e5\n  best_val_f1 = 0\n  samples = args.samples\n  best_val_ppl, best_val_f1 = eval(val_data, model, samples = args.mc_samples, \n                                   count_eos_ppl = args.count_eos_ppl)\n  all_stats = [[0., 0., 0.]] #true pos, false pos, false neg for f1 calc\n  while epoch < args.num_epochs:\n    start_time = time.time()\n    epoch += 1  \n    if epoch > args.train_q_epochs:\n      #stop training q after this many epochs\n      args.q_lr = 0.\n      for param_group in q_optimizer.param_groups:\n        param_group['lr'] = args.q_lr\n    print('Starting epoch %d' % epoch)\n    train_nll_recon = 0.\n    train_nll_iwae = 0.\n    train_kl = 0.\n    train_q_entropy = 0.\n    num_sents = 0.\n    num_words = 0.\n    b = 0\n    for i in np.random.permutation(len(train_data)):\n      if args.kl_warmup > 0:\n        kl_pen = min(1., kl_pen + kl_warmup_batch) \n      sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = train_data[i]      \n      if length == 1:\n        # we ignore length 1 sents during training/eval since we work with binary trees only\n        continue\n      
sents = sents.cuda()\n      b += 1\n      q_optimizer.zero_grad()\n      optimizer.zero_grad()\n      action_optimizer.zero_grad()\n      if args.mode == 'unsupervised':\n        ll_word, ll_action_p, ll_action_q, all_actions, q_entropy = model(sents, samples=samples, \n                                                                          has_eos = True)\n        log_f = ll_word + kl_pen*ll_action_p\n        iwae_ll = log_f.mean(1).detach() + kl_pen*q_entropy.detach()\n        obj = log_f.mean(1)\n        if epoch < args.train_q_epochs:\n          obj += kl_pen*q_entropy\n          baseline = torch.zeros_like(log_f)\n          baseline_k = torch.zeros_like(log_f)\n          for k in range(samples):\n            baseline_k.copy_(log_f)\n            baseline_k[:, k].fill_(0)\n            baseline[:, k] =  baseline_k.detach().sum(1) / (samples - 1)        \n          obj += ((log_f.detach() - baseline.detach())*ll_action_q).mean(1)                      \n        kl = (ll_action_q - ll_action_p).mean(1).detach()\n        ll_word = ll_word.mean(1)\n        train_q_entropy += q_entropy.sum().item()\n      else:\n        gold_actions = gold_binary_trees\n        ll_action_q = model.forward_tree(sents, gold_actions, has_eos=True)        \n        ll_word, ll_action_p, all_actions = model.forward_actions(sents, gold_actions)\n        obj = ll_word + ll_action_p + ll_action_q\n        kl = -ll_action_q\n        iwae_ll = ll_word + ll_action_p\n      train_nll_iwae += -iwae_ll.sum().item()\n      actions = all_actions[:, 0].long().cpu()\n      train_nll_recon += -ll_word.sum().item()\n      train_kl += kl.sum().item()\n      (-obj.mean()).backward()      \n      if args.max_grad_norm > 0:\n        torch.nn.utils.clip_grad_norm_(model_params + action_params, args.max_grad_norm)        \n      if args.q_max_grad_norm > 0:\n        torch.nn.utils.clip_grad_norm_(q_params, args.q_max_grad_norm)        \n      q_optimizer.step()\n      optimizer.step()\n      
action_optimizer.step()\n      num_sents += batch_size\n      num_words += batch_size * length\n      for bb in range(batch_size):\n        action = list(actions[bb].numpy())\n        span_b = get_spans(action)\n        span_b_set = set(span_b[:-1]) #ignore the sentence-level trivial span\n        update_stats(span_b_set, [set(gold_spans[bb][:-1])], all_stats)\n      if b % args.print_every == 0:\n        all_f1 = get_f1(all_stats)\n        param_norm = sum([p.norm()**2 for p in model.parameters()]).item()**0.5\n        log_str = 'Epoch: %d, Batch: %d/%d, LR: %.4f, qLR: %.5f, qEnt: %.4f, TrainVAEPPL: %.2f, ' + \\\n                  'TrainReconPPL: %.2f, TrainKL: %.2f, TrainIWAEPPL: %.2f, ' + \\\n                  '|Param|: %.2f, BestValPerf: %.2f, BestValF1: %.2f, KLPen: %.4f, ' + \\\n                  'GoldTreeF1: %.2f, Throughput: %.2f examples/sec'\n        print(log_str %\n              (epoch, b, len(train_data), args.lr, args.q_lr, train_q_entropy / num_sents, \n               np.exp((train_nll_recon + train_kl)/ num_words),\n               np.exp(train_nll_recon/num_words), train_kl / num_sents, \n               np.exp(train_nll_iwae/num_words),\n               param_norm, best_val_ppl, best_val_f1, kl_pen, \n               all_f1[0], num_sents / (time.time() - start_time)))\n        sent_str = [train_data.idx2word[word_idx] for word_idx in list(sents[-1][1:-1].cpu().numpy())]\n        print(\"PRED:\", get_tree(action[:-2], sent_str))\n        print(\"GOLD:\", get_tree(gold_binary_trees[-1], sent_str))\n    print('--------------------------------')\n    print('Checking validation perf...')    \n    val_ppl, val_f1 = eval(val_data, model, \n                           samples = args.mc_samples, count_eos_ppl = args.count_eos_ppl)\n    print('--------------------------------')\n    if val_ppl < best_val_ppl:\n      best_val_ppl = val_ppl\n      best_val_f1 = val_f1\n      checkpoint = {\n        'args': args.__dict__,\n        'model': model.cpu(),\n        
'word2idx': train_data.word2idx,\n        'idx2word': train_data.idx2word\n      }\n      print('Saving checkpoint to %s' % args.save_path)\n      torch.save(checkpoint, args.save_path)\n      model.cuda()\n    else:\n      if epoch > args.min_epochs:\n        decay = 1\n    if decay == 1:\n      args.lr = args.decay*args.lr\n      args.q_lr = args.decay*args.q_lr\n      args.action_lr = args.decay*args.action_lr\n      for param_group in optimizer.param_groups:\n        param_group['lr'] = args.lr\n      for param_group in q_optimizer.param_groups:\n        param_group['lr'] = args.q_lr\n      for param_group in action_optimizer.param_groups:\n        param_group['lr'] = args.action_lr\n    if args.lr < 0.03:\n      break\n  print(\"Finished training!\")\n\ndef eval(data, model, samples = 0, count_eos_ppl = 0):\n  model.eval()\n  num_sents = 0\n  num_words = 0\n  total_nll_recon = 0.\n  total_kl = 0.\n  total_nll_iwae = 0.\n  corpus_f1 = [0., 0., 0.]\n  sent_f1 = [] \n  with torch.no_grad():\n    for i in list(reversed(range(len(data)))):\n      sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i] \n      if length == 1: # length 1 sents are ignored since URNNG needs at least length 2 sents\n        continue\n      if args.count_eos_ppl == 1:\n        tree_length = length\n        length += 1 \n      else:\n        sents = sents[:, :-1] \n        tree_length = length\n      sents = sents.cuda()\n      ll_word_all, ll_action_p_all, ll_action_q_all, actions_all, q_entropy = model(sents, \n                    samples = samples, has_eos = count_eos_ppl == 1)\n      ll_word, ll_action_p, ll_action_q = ll_word_all.mean(1), ll_action_p_all.mean(1), ll_action_q_all.mean(1)\n      kl = ll_action_q - ll_action_p\n      _, binary_matrix, argmax_spans = model.q_crf._viterbi(model.scores)\n      actions = []\n      for b in range(batch_size):        \n        tree = get_tree_from_binary_matrix(binary_matrix[b], tree_length)\n        
actions.append(get_actions(tree))\n      actions = torch.Tensor(actions).long()\n      total_nll_recon += -ll_word.sum().item()\n      total_kl += kl.sum().item()\n      num_sents += batch_size\n      num_words += batch_size * length\n      if samples > 0:\n        #PPL estimate based on IWAE\n        sample_ll = torch.zeros(batch_size, samples)\n        for j in range(samples):\n          ll_word_j, ll_action_p_j, ll_action_q_j = ll_word_all[:, j], ll_action_p_all[:, j], ll_action_q_all[:, j]\n          sample_ll[:, j].copy_(ll_word_j + ll_action_p_j - ll_action_q_j)\n        ll_iwae = model.logsumexp(sample_ll, 1) - np.log(samples)\n        total_nll_iwae -= ll_iwae.sum().item()      \n      for b in range(batch_size):\n        span_b = argmax_spans[b] # use the Viterbi argmax spans directly\n        span_b_set = set(span_b[:-1])        \n        gold_b_set = set(gold_spans[b][:-1])\n        tp, fp, fn = get_stats(span_b_set, gold_b_set) \n        corpus_f1[0] += tp\n        corpus_f1[1] += fp\n        corpus_f1[2] += fn\n\n        # sent-level F1 is based on L83-89 from https://github.com/yikangshen/PRPN/test_phrase_grammar.py\n        model_out = span_b_set\n        std_out = gold_b_set\n        overlap = model_out.intersection(std_out)\n        prec = float(len(overlap)) / (len(model_out) + 1e-8)\n        reca = float(len(overlap)) / (len(std_out) + 1e-8)\n        if len(std_out) == 0:\n          reca = 1. 
\n          if len(model_out) == 0:\n            prec = 1.\n        f1 = 2 * prec * reca / (prec + reca + 1e-8)\n        sent_f1.append(f1)\n  tp, fp, fn = corpus_f1  \n  prec = tp / (tp + fp)\n  recall = tp / (tp + fn)\n  corpus_f1 = 2*prec*recall/(prec+recall)*100 if prec+recall > 0 else 0.\n  sent_f1 = np.mean(np.array(sent_f1))*100\n\n  elbo_ppl = np.exp((total_nll_recon + total_kl) / num_words)\n  recon_ppl = np.exp(total_nll_recon / num_words)\n  iwae_ppl = np.exp(total_nll_iwae /num_words)\n  kl = total_kl / num_sents  \n  print('ElboPPL: %.2f, ReconPPL: %.2f, KL: %.4f, IwaePPL: %.2f, CorpusF1: %.2f, SentAvgF1: %.2f' % \n        (elbo_ppl, recon_ppl, kl, iwae_ppl, corpus_f1, sent_f1))\n  #note that corpus F1 printed here is different from what you should get from\n  #evalb since we do not ignore any tags (e.g. punctuation), while evalb ignores it\n  model.train()\n  return iwae_ppl, corpus_f1\n\nif __name__ == '__main__':\n  args = parser.parse_args()\n  main(args)\n"
  },
  {
    "path": "train_lm.py",
    "content": "#!/usr/bin/env python3\nimport sys\nimport os\n\nimport argparse\nimport json\nimport random\nimport shutil\nimport copy\n\nimport torch\nfrom torch import cuda\nimport torch.nn as nn\nfrom torch.autograd import Variable\nfrom torch.nn.parameter import Parameter\n\nimport torch.nn.functional as F\nimport numpy as np\nimport time\nimport logging\nfrom data import Dataset\nfrom models import RNNLM\nfrom utils import *\n\nparser = argparse.ArgumentParser()\n\n# Data path options\nparser.add_argument('--train_file', default='data/ptb-train.pkl')\nparser.add_argument('--val_file', default='data/ptb-val.pkl')\nparser.add_argument('--test_file', default='data/ptb-test.pkl')\nparser.add_argument('--train_from', default='')\n# Model options\nparser.add_argument('--w_dim', default=650, type=int, help='hidden dimension for LM')\nparser.add_argument('--h_dim', default=650, type=int, help='hidden dimension for LM')\nparser.add_argument('--num_layers', default=2, type=int, help='number of layers in LM and the stack LSTM (for RNNG)')\nparser.add_argument('--dropout', default=0.5, type=float, help='dropout rate')\n# Optimization options\nparser.add_argument('--count_eos_ppl', default=0, type=int, help='whether to count eos in val PPL')\nparser.add_argument('--test', default=0, type=int, help='')\nparser.add_argument('--save_path', default='urnng.pt', help='where to save the data')\nparser.add_argument('--num_epochs', default=30, type=int, help='number of training epochs')\nparser.add_argument('--min_epochs', default=8, type=int, help='do not decay learning rate for at least this many epochs')\nparser.add_argument('--lr', default=1, type=float, help='starting learning rate')\nparser.add_argument('--decay', default=0.5, type=float, help='')\nparser.add_argument('--param_init', default=0.1, type=float, help='parameter initialization (over uniform)')\nparser.add_argument('--max_grad_norm', default=5, type=float, help='gradient clipping 
parameter')\nparser.add_argument('--gpu', default=2, type=int, help='which gpu to use')\nparser.add_argument('--seed', default=3435, type=int, help='random seed')\nparser.add_argument('--print_every', type=int, default=500, help='print stats after this many batches')\n\n\ndef main(args):\n  np.random.seed(args.seed)\n  torch.manual_seed(args.seed)\n  train_data = Dataset(args.train_file)\n  val_data = Dataset(args.val_file)\n  vocab_size = int(train_data.vocab_size)\n  print('Train: %d sents / %d batches, Val: %d sents / %d batches' % \n        (train_data.sents.size(0), len(train_data), val_data.sents.size(0), \n         len(val_data)))\n  print('Vocab size: %d' % vocab_size)\n  cuda.set_device(args.gpu)\n  if args.train_from == '':\n    model = RNNLM(vocab = vocab_size,\n                  w_dim = args.w_dim, \n                  h_dim = args.h_dim,\n                  dropout = args.dropout,\n                  num_layers = args.num_layers)\n    if args.param_init > 0:\n      for param in model.parameters():\n        param.data.uniform_(-args.param_init, args.param_init)\n  else:\n    print('loading model from ' + args.train_from)\n    checkpoint = torch.load(args.train_from)\n    model = checkpoint['model']\n  print(\"model architecture\")\n  print(model)\n  optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)\n  model.train()\n  model.cuda()\n  epoch = 0\n  decay = 0\n  if args.test == 1:\n    test_data = Dataset(args.test_file)\n    eval(test_data, model, count_eos_ppl = args.count_eos_ppl)\n    sys.exit(0)\n  best_val_ppl = eval(val_data, model, count_eos_ppl = args.count_eos_ppl)\n  while epoch < args.num_epochs:\n    start_time = time.time()\n    epoch += 1\n    print('Starting epoch %d' % epoch)\n    train_nll = 0.\n    num_sents = 0.\n    num_words = 0.\n    b = 0\n    for i in np.random.permutation(len(train_data)):\n      sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = 
train_data[i]\n      if length == 1:\n        continue\n      sents = sents.cuda()\n      b += 1\n      optimizer.zero_grad()\n      nll = -model(sents).mean()\n      train_nll += nll.item()*batch_size\n      nll.backward()\n      if args.max_grad_norm > 0:\n        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)\n      optimizer.step()\n      num_sents += batch_size\n      num_words += batch_size * length\n      if b % args.print_every == 0:\n        param_norm = sum([p.norm()**2 for p in model.parameters()]).item()**0.5\n        print('Epoch: %d, Batch: %d/%d, LR: %.4f, TrainPPL: %.2f, |Param|: %.4f, BestValPerf: %.2f, Throughput: %.2f examples/sec' % \n              (epoch, b, len(train_data), args.lr, np.exp(train_nll / num_words), \n               param_norm, best_val_ppl, num_sents / (time.time() - start_time)))\n    print('--------------------------------')\n    print('Checking validation perf...')\n    val_ppl = eval(val_data, model, count_eos_ppl = args.count_eos_ppl)\n    print('--------------------------------')\n    if val_ppl < best_val_ppl:\n      best_val_ppl = val_ppl\n      checkpoint = {\n        'args': args.__dict__,\n        'model': model.cpu(),\n        'word2idx': train_data.word2idx,\n        'idx2word': train_data.idx2word\n      }\n      print('Saving checkpoint to %s' % args.save_path)\n      torch.save(checkpoint, args.save_path)\n      model.cuda()\n    else:\n      if epoch > args.min_epochs:\n        decay = 1\n    if decay == 1:\n      args.lr = args.decay*args.lr\n      for param_group in optimizer.param_groups:\n        param_group['lr'] = args.lr\n    if args.lr < 0.03:\n      break\n  print(\"Finished training\")\n\ndef eval(data, model, count_eos_ppl = 0):\n  model.eval()\n  num_words = 0\n  total_nll = 0.\n  with torch.no_grad():\n    for i in list(reversed(range(len(data)))):\n      sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, 
other_data = data[i]\n      if length == 1: # we ignore length 1 sents in URNNG eval so do this for LM too\n        continue\n      if count_eos_ppl == 1:\n        length += 1\n      else:\n        sents = sents[:, :-1]\n      sents = sents.cuda()\n      num_words += length * batch_size\n      nll = -model(sents).mean()\n      total_nll += nll.item()*batch_size\n  ppl = np.exp(total_nll / num_words)\n  print('PPL: %.2f' % (ppl))\n  model.train()\n  return ppl\n\nif __name__ == '__main__':\n  args = parser.parse_args()\n  main(args)\n"
  },
  {
    "path": "utils.py",
    "content": "#!/usr/bin/env python3\nimport numpy as np\nimport itertools\nimport random\n\n\ndef get_actions(tree, SHIFT = 0, REDUCE = 1, OPEN='(', CLOSE=')'):\n  #input tree in bracket form: ((A B) (C D))\n  #output action sequence: 0 0 1 0 0 1 1, where 0 is SHIFT and 1 is REDUCE\n  actions = []\n  tree = tree.strip()\n  i = 0\n  num_shift = 0\n  num_reduce = 0\n  left = 0\n  right = 0\n  while i < len(tree):\n    if tree[i] != ' ' and tree[i] != OPEN and tree[i] != CLOSE: #terminal      \n      if tree[i-1] == OPEN or tree[i-1] == ' ':\n        actions.append(SHIFT)\n        num_shift += 1\n    elif tree[i] == CLOSE:\n      actions.append(REDUCE)\n      num_reduce += 1\n      right += 1\n    elif tree[i] == OPEN:\n      left += 1\n    i += 1\n  assert(num_shift == num_reduce + 1)\n  return actions\n\n    \ndef get_tree(actions, sent = None, SHIFT = 0, REDUCE = 1):\n  #input action and sent (lists), e.g. S S R S S R R, A B C D\n  #output tree ((A B) (C D))\n  stack = []\n  pointer = 0\n  if sent is None:\n    sent = list(map(str, range((len(actions)+1) // 2)))\n  for action in actions:\n    if action == SHIFT:\n      word = sent[pointer]\n      stack.append(word)\n      pointer += 1\n    elif action == REDUCE:\n      right = stack.pop()\n      left = stack.pop()\n      stack.append('(' + left + ' ' + right + ')')\n  assert(len(stack) == 1)\n  return stack[-1]\n      \ndef get_spans(actions, SHIFT = 0, REDUCE = 1):\n  sent = list(range((len(actions)+1) // 2))\n  spans = []\n  pointer = 0\n  stack = []\n  for action in actions:\n    if action == SHIFT:\n      word = sent[pointer]\n      stack.append(word)\n      pointer += 1\n    elif action == REDUCE:\n      right = stack.pop()\n      left = stack.pop()\n      if isinstance(left, int):\n        left = (left, None)\n      if isinstance(right, int):\n        right = (None, right)\n      new_span = (left[0], right[1])\n      spans.append(new_span)\n      stack.append(new_span)\n  return spans\n\ndef 
get_stats(span1, span2):\n  tp = 0\n  fp = 0\n  fn = 0\n  for span in span1:\n    if span in span2:\n      tp += 1\n    else:\n      fp += 1\n  for span in span2:\n    if span not in span1:\n      fn += 1\n  return tp, fp, fn\n\ndef update_stats(pred_span, gold_spans, stats):\n  for gold_span, stat in zip(gold_spans, stats):\n    tp, fp, fn = get_stats(pred_span, gold_span)\n    stat[0] += tp\n    stat[1] += fp\n    stat[2] += fn\n\ndef get_f1(stats):\n  f1s = []\n  for stat in stats:\n    prec = stat[0] / (stat[0] + stat[1]) if stat[0] + stat[1] > 0 else 0.\n    recall = stat[0] / (stat[0] + stat[2]) if stat[0] + stat[2] > 0 else 0.\n    f1 = 100 * 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.\n    f1s.append(f1)\n  return f1s\n\n\ndef span_str(start = None, end = None):\n  assert(start is not None or end is not None)\n  if start is None:\n    return ' '  + str(end) + ')'\n  elif end is None:\n    return '(' + str(start) + ' '\n  else:\n    return ' (' + str(start) + ' ' + str(end) + ') '\n\n\ndef get_tree_from_binary_matrix(matrix, length):\n  sent = list(map(str, range(length)))\n  n = len(sent)\n  tree = {}\n  for i in range(n):\n    tree[i] = sent[i]\n  for k in np.arange(1, n):\n    for s in np.arange(n):\n      t = s + k\n      if t > n-1:\n        break\n      if matrix[s][t].item() == 1:\n        span = '(' + tree[s] + ' ' + tree[t] + ')'\n        tree[s] = span\n        tree[t] = span\n  return tree[0]\n\ndef get_nonbinary_spans(actions, SHIFT = 0, REDUCE = 1):\n  spans = []\n  stack = []\n  pointer = 0\n  binary_actions = []\n  nonbinary_actions = []\n  num_shift = 0\n  num_reduce = 0\n  for action in actions:\n    if action == \"SHIFT\":\n      nonbinary_actions.append(SHIFT)\n      stack.append((pointer, pointer))\n      pointer += 1\n      binary_actions.append(SHIFT)\n      num_shift += 1\n    elif action[:3] == 'NT(':\n      stack.append('(')\n    elif action == \"REDUCE\":\n      nonbinary_actions.append(REDUCE)\n      right = stack.pop()\n      left = 
right\n      n = 1\n      while stack[-1] != '(':\n        left = stack.pop()\n        n += 1\n      span = (left[0], right[1])\n      if left[0] != right[1]:\n        spans.append(span)\n      stack.pop()\n      stack.append(span)\n      while n > 1:\n        n -= 1\n        binary_actions.append(REDUCE)\n        num_reduce += 1\n    else:\n      assert False\n  assert(len(stack) == 1)\n  assert(num_shift == num_reduce + 1)\n  return spans, binary_actions, nonbinary_actions\n"
  }
]