Repository: harvardnlp/urnng
Branch: master
Commit: b1eeffa5b590
Files: 15
Total size: 101.9 KB

Directory structure:
gitextract_zduez7kc/

├── .gitignore
├── COLLINS.prm
├── README.md
├── TreeCRF.py
├── data/
│   ├── test.txt
│   ├── train.txt
│   └── valid.txt
├── data.py
├── eval_ppl.py
├── models.py
├── parse.py
├── preprocess.py
├── train.py
├── train_lm.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.pt
*.amat
*.mat
*.out
*.out~
*.pyc
*.pt~
.gitignore~
*.out~
*.sh
*.sh~
*.py~
*.json
*.json~
*.model
*.h5
*.tar.gz
*.hdf5
*.dict
*.pkl


================================================
FILE: COLLINS.prm
================================================
##------------------------------------------##
## Debug mode                               ##
##   0: No debugging                        ##
##   1: print data for individual sentence  ##
##------------------------------------------##
DEBUG 0

##------------------------------------------##
## MAX error                                ##
##    Number of errors to stop the process. ##
##    This is useful if there could be      ##
##    tokenization errors.                  ##
##    The process will stop when this number##
##    of errors are accumulated.            ##
##------------------------------------------##
MAX_ERROR 10

##------------------------------------------##
## Cut-off length for statistics            ##
##    At the end of evaluation, the         ##
##    statistics for the sentences of length##
##    less than or equal to this number will##
##    be shown, on top of the statistics    ##
##    for all the sentences                 ##
##------------------------------------------##
CUTOFF_LEN 10

##------------------------------------------##
## unlabeled or labeled bracketing          ##
##    0: unlabeled bracketing               ##
##    1: labeled bracketing                 ##
##------------------------------------------##
LABELED 0

##------------------------------------------##
## Delete labels                            ##
##    list of labels to be ignored.         ##
##    If it is a pre-terminal label, delete ##
##    the word along with the brackets.     ##
##    If it is a non-terminal label, just   ##
##    delete the brackets (don't delete     ##
##    children).                            ##
##------------------------------------------##
DELETE_LABEL TOP
DELETE_LABEL -NONE-
DELETE_LABEL ,
DELETE_LABEL :
DELETE_LABEL ``
DELETE_LABEL ''
DELETE_LABEL .

##------------------------------------------##
## Delete labels for length calculation     ##
##    list of labels to be ignored for      ##
##    length calculation purpose            ##
##------------------------------------------##
DELETE_LABEL_FOR_LENGTH -NONE-

##------------------------------------------##
## Equivalent labels, words                 ##
##     the pairs are considered equivalent  ##
##     This is non-directional.             ##
##------------------------------------------##
EQ_LABEL ADVP PRT

# EQ_WORD  Example example


================================================
FILE: README.md
================================================
# Unsupervised Recurrent Neural Network Grammars

This is an implementation of the paper:  
[Unsupervised Recurrent Neural Network Grammars](https://arxiv.org/abs/1904.03746)  
Yoon Kim, Alexander Rush, Lei Yu, Adhiguna Kuncoro, Chris Dyer, Gabor Melis  
NAACL 2019  

## Dependencies
The code was tested with `python 3.6` and `pytorch 1.0`.

## Data  
Sample train/val/test data is in the `data/` folder. These are small samples from the standard PTB splits, with one bracketed tree per line.
First preprocess the data:
```
python preprocess.py --trainfile data/train.txt --valfile data/valid.txt --testfile data/test.txt \
--outputfile data/ptb --vocabminfreq 1 --lowercase 0 --replace_num 0 --batchsize 16
```
Running this will save the following files in the `data/` folder: `ptb-train.pkl`, `ptb-val.pkl`,
`ptb-test.pkl`, `ptb.dict`. Here `ptb.dict` is the word-idx mapping, and you can change the
output folder/name by changing the argument to `outputfile`. Also, the preprocessing here
will replace singletons with a single `<unk>` rather than with Berkeley parser's mapping rules
(see below for results using this setup).
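
Each line of the files in `data/` is one bracketed PTB tree, so preprocessing first has to recover the word sequence from it. The snippet below is a minimal sketch of that step, not the code in `preprocess.py`; the helper name `tree_to_tokens` is hypothetical. It relies on PTB escaping literal brackets as `-LRB-`/`-RRB-` (visible in `data/test.txt`), so a leaf word never contains a parenthesis.

```python
import re

# Hypothetical helper. A preterminal looks like "(TAG word)": TAG is any run of
# non-space characters, and word contains no whitespace or parentheses
# (PTB escapes brackets as -LRB-/-RRB-). The capture group is the word.
PRETERMINAL = re.compile(r"\(\S+ ([^\s()]+)\)")


def tree_to_tokens(tree_str):
    """Return the leaf words of a one-line bracketed PTB tree, left to right."""
    return PRETERMINAL.findall(tree_str)


line = "(S (NP (NNP Ms.) (NNP Haag)) (VP (VBZ plays) (NP (NNP Elianti))) (. .))"
print(tree_to_tokens(line))  # ['Ms.', 'Haag', 'plays', 'Elianti', '.']
```

Options such as `--lowercase` or `--replace_num` would then be applied to this token list; how `preprocess.py` does so is not shown here.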

## Training
To train the URNNG:
```
python train.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --save_path urnng.pt \
--mode unsupervised --gpu 0
```
where `save_path` is where you want to save the model, and `gpu 0` is for using the first GPU
in the cluster (the mapping from PyTorch GPU index to your cluster's GPU index may vary).
Training should take 2 to 3 days depending on your setup.

To train the RNNG:
```
python train.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --save_path rnng.pt \
--mode supervised --train_q_epochs 18 --gpu 0
```

To fine-tune a trained RNNG with the unsupervised (URNNG) objective:
```
python train.py --train_from rnng.pt --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl \
--save_path rnng-urnng.pt --mode unsupervised --lr 0.1 --train_q_epochs 10 --epochs 10 \
--min_epochs 6 --gpu 0 --kl_warmup 0
```

To train the LM:
```
python train_lm.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl \
--test_file data/ptb-test.pkl --save_path lm.pt
```

## Evaluation
To evaluate perplexity with importance sampling on the test set:
```
python eval_ppl.py --model_file urnng.pt --test_file data/ptb-test.pkl --samples 1000 \
--is_temp 2 --gpu 0
```
The argument `samples` sets the number of importance-weighted samples, and `is_temp` is a temperature
used to flatten the inference network's distribution (footnote 14 in the paper).
The same evaluation code works for the RNNG.
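
As a rough illustration of what importance sampling estimates here, the sketch below implements the standard importance-weighted bound, log p(x) ≈ log((1/K) Σ_k p(x, z_k) / q(z_k | x)), with trees z_k drawn from the proposal q. It assumes you already have per-sample values of log p(x, z_k) and log q(z_k | x); the function names are hypothetical and do not mirror `eval_ppl.py`. Tempering is shown only for a categorical distribution, q_T(z) ∝ q(z)^(1/T); how `eval_ppl.py` applies `is_temp` to the tree CRF is not reproduced. Whatever proposal the trees are drawn from, the weights must use that same proposal's log-probabilities.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def iw_log_likelihood(log_joint, log_proposal):
    """Importance-weighted estimate of log p(x) for one sentence.

    log_joint[k]: log p(x, z_k) under the generative model.
    log_proposal[k]: log q(z_k | x) under the proposal z_k was sampled from.
    """
    k = len(log_joint)
    log_weights = [lj - lq for lj, lq in zip(log_joint, log_proposal)]
    return logsumexp(log_weights) - math.log(k)


def perplexity(sentence_log_likelihoods, num_tokens):
    """exp of the negative per-token log-likelihood, pooled over the corpus."""
    return math.exp(-sum(sentence_log_likelihoods) / num_tokens)


def flatten(log_probs, temp):
    """Temper a categorical distribution: q_T(z) proportional to q(z) ** (1 / temp)."""
    scaled = [lp / temp for lp in log_probs]
    log_norm = logsumexp(scaled)
    return [s - log_norm for s in scaled]


# Toy usage with made-up numbers: two sampled trees for a 3-token sentence.
ll = iw_log_likelihood([-12.0, -13.5], [-1.2, -0.9])
print(perplexity([ll], num_tokens=3))
```

Whether the token count should include an end-of-sentence symbol is an assumption left to the caller.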

For LM evaluation:
```
python train_lm.py --train_from lm.pt --test_file data/ptb-test.pkl --test 1
```

To evaluate F1, first parse the test set:
```
python parse.py --model_file urnng.pt --data_file data/test.txt --out_file pred-parse.txt \
--gold_out_file gold-parse.txt --gpu 0
```
This will output the predicted parse trees into `pred-parse.txt`. We also output a version
of the gold parses, `gold-parse.txt`, to be used as input for `evalb`, since `parse.py` ignores sentences with only trivial spans. Note that the corpus/sentence F1 scores printed by `parse.py` do not correspond to the results reported in the paper, since that evaluation does not ignore punctuation.
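
To make the difference between corpus-level and sentence-level F1 concrete, here is a minimal sketch of unlabeled span F1. It assumes spans are inclusive `(start, end)` word indices, the convention `TreeCRF.py` uses for the whole-sentence span `(0, n-1)`, and it takes "trivial" to mean single-word spans plus the whole-sentence span. Function names are hypothetical; the exact rules in `parse.py` are not reproduced here.

```python
def nontrivial_spans(spans, length):
    """Drop single-word spans and the whole-sentence span (0, length - 1)."""
    return {(s, t) for s, t in spans if t > s and (s, t) != (0, length - 1)}


def f1(matched, n_pred, n_gold):
    """Harmonic mean of precision and recall; 0.0 when nothing matches."""
    if matched == 0:
        return 0.0
    precision, recall = matched / n_pred, matched / n_gold
    return 2 * precision * recall / (precision + recall)


def evaluate(examples):
    """Return (corpus_f1, mean_sentence_f1) for (pred_spans, gold_spans, length) triples.

    Corpus F1 pools match counts over all sentences; sentence F1 is computed per
    sentence and then averaged. Sentences with no nontrivial gold span are skipped.
    """
    tot_match = tot_pred = tot_gold = 0
    sentence_scores = []
    for pred, gold, length in examples:
        pred = nontrivial_spans(pred, length)
        gold = nontrivial_spans(gold, length)
        if not gold:
            continue
        matched = len(pred & gold)
        tot_match += matched
        tot_pred += len(pred)
        tot_gold += len(gold)
        sentence_scores.append(f1(matched, len(pred), len(gold)))
    if not sentence_scores:
        return 0.0, 0.0
    return f1(tot_match, tot_pred, tot_gold), sum(sentence_scores) / len(sentence_scores)


examples = [
    ({(0, 3), (1, 3), (2, 3)}, {(0, 3), (0, 1), (2, 3), (0, 0)}, 4),
    ({(0, 1)}, {(0, 1)}, 2),  # only the trivial whole-sentence span: skipped
    ({(0, 4), (1, 4), (2, 4), (3, 4)}, {(0, 4), (1, 4), (2, 4), (3, 4)}, 5),
]
print(evaluate(examples))
```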

Finally, download/install `evalb`, available [here](https://nlp.cs.nyu.edu/evalb).
Then run:
```
evalb -p COLLINS.prm gold-parse.txt pred-parse.txt
```
where `COLLINS.prm` is the parameter file (provided in this repo) that tells `evalb` to ignore
punctuation and evaluate on unlabeled F1.
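
The punctuation handling comes from the `DELETE_LABEL` lines of `COLLINS.prm`. The sketch below only approximates the preterminal half of that rule (a word is dropped together with its bracket when its tag is listed) and ignores the non-terminal half (removing brackets while keeping their children). It is an illustration, not `evalb` itself, and the function names are hypothetical.

```python
import re

# Excerpt of the DELETE_LABEL lines from COLLINS.prm in this repo.
PRM_EXCERPT = """\
DELETE_LABEL TOP
DELETE_LABEL -NONE-
DELETE_LABEL ,
DELETE_LABEL :
DELETE_LABEL ``
DELETE_LABEL ''
DELETE_LABEL .
DELETE_LABEL_FOR_LENGTH -NONE-
"""


def delete_labels(prm_text):
    """Collect the labels named on DELETE_LABEL lines of an evalb parameter file."""
    labels = set()
    for line in prm_text.splitlines():
        parts = line.split()
        # DELETE_LABEL_FOR_LENGTH is a different directive, so match the keyword exactly.
        if len(parts) == 2 and parts[0] == "DELETE_LABEL":
            labels.add(parts[1])
    return labels


def scored_tokens(tree_str, ignore):
    """(tag, word) preterminals left after dropping those whose tag is ignored."""
    pairs = re.findall(r"\((\S+) ([^\s()]+)\)", tree_str)
    return [(tag, word) for tag, word in pairs if tag not in ignore]


sent = "(S (INTJ (RB No)) (, ,) (NP (PRP it)) (VP (VBD was) (RB n't) (NP (NNP Black) (NNP Monday))) (. .))"
print([word for _, word in scored_tokens(sent, delete_labels(PRM_EXCERPT))])
# ['No', 'it', 'was', "n't", 'Black', 'Monday']
```

Because the punctuation words disappear, span boundaries shift, which is one reason the F1 printed by `parse.py` differs from the `evalb` score.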

## Note Regarding Preprocessing
Note that some of the preprocessing details are slightly different from the original
paper. In particular, this implementation replaces singleton words with a single `<unk>` token
instead of using Berkeley parser's mapping rules. This results in slightly lower perplexity
for all models, since the vocabulary size is smaller. Here are the perplexity numbers I get
in this setting:

- RNNLM: 89.2 
- RNNG: 83.7 
- URNNG: 85.1 (F1: 38.4) 
- RNNG --> URNNG: 82.5
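
The singleton replacement described above amounts to building the vocabulary from training counts and mapping everything else to one shared token. This is a minimal sketch assuming "singleton" means a word seen exactly once in the training data; the names are hypothetical, and how `preprocess.py`'s `--vocabminfreq` maps onto this threshold is not shown here.

```python
from collections import Counter

UNK = "<unk>"


def build_vocab(train_sents):
    """Words seen more than once in training; singletons are excluded (assumed threshold)."""
    counts = Counter(word for sent in train_sents for word in sent)
    return {word for word, count in counts.items() if count > 1}


def unkify(sent, vocab):
    """Replace every out-of-vocabulary word with the single shared <unk> token."""
    return [word if word in vocab else UNK for word in sent]


train = [["the", "cat", "sat"], ["the", "dog", "sat"]]
vocab = build_vocab(train)
print(unkify(["the", "cat", "sat"], vocab))  # ['the', '<unk>', 'sat']
```

Collapsing all rare words into one type shrinks the vocabulary, which is consistent with the slightly lower perplexities listed above.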

## Acknowledgements
Some of our preprocessing and evaluation code is based on the following repositories:  
- [Recurrent Neural Network Grammars](https://github.com/clab/rnng)  
- [Parsing Reading Predict Network](https://github.com/yikangshen/PRPN)  

## License
MIT

================================================
FILE: TreeCRF.py
================================================
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import itertools
import utils
import random
  
class ConstituencyTreeCRF(nn.Module):
  def __init__(self):
    super(ConstituencyTreeCRF, self).__init__()
    self.huge = 1e9

  def logadd(self, x, y):
    d = torch.max(x,y)  
    return torch.log(torch.exp(x-d) + torch.exp(y-d)) + d    

  def logsumexp(self, x, dim=1):
    d = torch.max(x, dim)[0]
    return torch.log(torch.exp(x - d.unsqueeze(dim).expand_as(x)).sum(dim)) + d

  def _init_table(self, scores):
    # initialize dynamic programming table
    batch_size = scores.size(0)
    n = scores.size(1)
    self.alpha = [[scores.new(batch_size).fill_(-self.huge) for _ in range(n)] for _ in range(n)]

  def _forward(self, scores):
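    # inside algorithm: alpha[s][t] is the log-sum-exp, over all binary trees
    # spanning words s..t, of the summed span scores of each tree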
    batch_size = scores.size(0)
    n = scores.size(1)
    self._init_table(scores)
    for i in range(n):
      self.alpha[i][i] = scores[:, i, i]
    for k in np.arange(1, n+1):
      for s in range(n):
        t = s + k
        if t > n-1:
          break
        tmp = [self.alpha[s][u] + self.alpha[u+1][t] + scores[:, s, t] for u in np.arange(s,t)]
        tmp = torch.stack(tmp, 1)
        self.alpha[s][t] = self.logsumexp(tmp, 1)
            
  def _backward(self, scores):
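    # outside algorithm: beta[s][t] is the log-sum-exp of the scores of
    # everything outside span (s, t); requires self.alpha from _forward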
    batch_size = scores.size(0)
    n = scores.size(1)
    self.beta = [[None for _ in range(n)] for _ in range(n)]
    self.beta[0][n-1] = scores.new(batch_size).fill_(0)
    for k in np.arange(n-1, 0, -1):
      for s in range(n):
        t = s + k
        if t > n-1:
          break
        for u in np.arange(s, t):                    
          if s < u+1:
            tmp = self.beta[s][t] + self.alpha[u+1][t] + scores[:, s, t]
            if self.beta[s][u] is None:
              self.beta[s][u] = tmp
            else:
              self.beta[s][u] = self.logadd(self.beta[s][u], tmp)
          if u+1 < t+1:
            tmp =  self.beta[s][t] + self.alpha[s][u]  + scores[:, s, t]
            if self.beta[u+1][t] is None:
              self.beta[u+1][t] = tmp
            else:
              self.beta[u+1][t] = self.logadd(self.beta[u+1][t], tmp)

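  def _marginal(self, scores):
    # log posterior marginal of each span (s, t): alpha + beta - log Z
    # (requires _forward and _backward to have been run)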
    batch_size = scores.size(0)
    n = scores.size(1)
    self.log_marginal = [[None for _ in range(n)] for _ in range(n)]
    log_Z = self.alpha[0][n-1]
    for s in range(n):
      for t in np.arange(s, n):
        self.log_marginal[s][t] = self.alpha[s][t] + self.beta[s][t] - log_Z
  
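  def _entropy(self, scores):
    # entropy of the distribution over trees, built bottom-up from the split
    # probabilities implied by the inside scores (requires self.alpha from _forward)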
    batch_size = scores.size(0)
    n = scores.size(1)
    self.entropy = [[None for _ in range(n)] for _ in range(n)]
    for i in range(n):
      self.entropy[i][i] = scores.new(batch_size).fill_(0)
    for k in np.arange(1, n+1):
      for s in range(n):
        t = s + k
        if t > n-1:
          break
        score = []
        prev_ent = []
        for u in np.arange(s, t):
          score.append(self.alpha[s][u] + self.alpha[u+1][t])
          prev_ent.append(self.entropy[s][u] + self.entropy[u+1][t])
        score = torch.stack(score, 1) 
        prev_ent = torch.stack(prev_ent, 1)
        log_prob = F.log_softmax(score, dim = 1)
        prob = log_prob.exp()        
        entropy = ((prev_ent - log_prob)*prob).sum(1)
        self.entropy[s][t] = entropy
      
        
  def _sample(self, scores, alpha = None, argmax = False):    
    # sample from p(tree | sent)
    # also get the spans
    if alpha is None:
      self._forward(scores)
      alpha = self.alpha
    batch_size = scores.size(0)
    n = scores.size(1)
    tree = scores.new(batch_size, n, n).zero_()
    all_log_probs = []
    tree_brackets = []
    spans = []
    for b in range(batch_size):
      sampled = [(0, n-1)]
      span = [(0, n-1)]
      queue = [(0, n-1)] #start, end
      log_probs = []
      tree_str = get_span_str(0, n-1)
      while len(queue) > 0:
        node = queue.pop(0)
        start, end = node
        left_parent = get_span_str(start, None)
        right_parent = get_span_str(None, end)
        score = []
        score_idx = []
        for u in np.arange(start, end):
          score.append(alpha[start][u][b] + alpha[u+1][end][b])
          score_idx.append([(start, u), (u+1, end)])
        score = torch.stack(score, 0) 
        log_prob = F.log_softmax(score, dim = 0)
        if argmax:
          sample = torch.max(log_prob, 0)[1]
        else:
          sample = torch.multinomial(log_prob.exp(), 1)
        sample_idx = score_idx[sample.item()]
        log_probs.append(log_prob[sample.item()])
        for idx in sample_idx:
          if idx[0] != idx[1]:
            queue.append(idx)
            span.append(idx)
          sampled.append(idx)
        left_child = '(' + get_span_str(sample_idx[0][0], sample_idx[0][1])    
        right_child = get_span_str(sample_idx[1][0], sample_idx[1][1]) + ')'
        if sample_idx[0][0] != sample_idx[0][1]:
          tree_str = tree_str.replace(left_parent, left_child)
        if sample_idx[1][0] != sample_idx[1][1]:
          tree_str = tree_str.replace(right_parent, right_child)
      all_log_probs.append(torch.stack(log_probs, 0).sum(0))
      tree_brackets.append(tree_str)
      spans.append(span[::-1])
      for idx in sampled:
        tree[b][idx[0]][idx[1]] = 1
        
    all_log_probs = torch.stack(all_log_probs, 0)
    return tree, all_log_probs, tree_brackets, spans

  def _viterbi(self, scores):
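    # Viterbi (CKY) decoding of the highest-scoring binary tree;
    # back-pointers to the best split point are stored in self.bp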
    batch_size = scores.size(0)
    n = scores.size(1)
    self.max_scores = scores.new(batch_size, n, n).fill_(-self.huge)
    self.bp = scores.new(batch_size, n, n).zero_()
    self.argmax = scores.new(batch_size, n, n).zero_()
    self.spans = [[] for _ in range(batch_size)]
    tmp = scores.new(batch_size, n).zero_()
    for i in range(n):
      self.max_scores[:, i, i] = scores[:, i, i]      
    for k in np.arange(1, n):
      for s in np.arange(n):
        t = s + k
        if t > n-1:
          break
        for u in np.arange(s, t):
          tmp = self.max_scores[:, s, u] + self.max_scores[:, u+1, t] + scores[:, s, t]
          self.bp[:, s, t][self.max_scores[:, s, t] < tmp] = int(u)
          self.max_scores[:, s, t] = torch.max(self.max_scores[:, s, t], tmp)
    for b in range(batch_size):
      self._backtrack(b, 0, n-1)      
    return self.max_scores[:, 0, n-1], self.argmax, self.spans

  def _backtrack(self, b, s, t):
    u = int(self.bp[b][s][t])
    self.argmax[b][s][t] = 1
    if s == t:
      return None      
    else:
      self.spans[b].insert(0, (s,t))
      self._backtrack(b, s, u)
      self._backtrack(b, u+1, t)
    return None  
 
def get_span_str(start = None, end = None):
  assert(start is not None or end is not None)
  if start is None:
    return ' '  + str(end) + ')'
  elif end is None:
    return '(' + str(start) + ' '
  else:
    return ' (' + str(start) + ' ' + str(end) + ') '    


================================================
FILE: data/test.txt
================================================
(S (INTJ (RB No)) (, ,) (NP (PRP it)) (VP (VBD was) (RB n't) (NP (NNP Black) (NNP Monday))) (. .))
(S (CC But) (SBAR (IN while) (S (NP (DT the) (NNP New) (NNP York) (NNP Stock) (NNP Exchange)) (VP (VBD did) (RB n't) (VP (VB fall) (ADVP (RB apart)) (NP (NNP Friday)) (SBAR (IN as) (S (NP (DT the) (NNP Dow) (NNP Jones) (NNP Industrial) (NNP Average)) (VP (VBD plunged) (NP (NP (CD 190.58) (NNS points)) (PRN (: --) (NP (NP (JJS most)) (PP (IN of) (NP (PRP it))) (PP (IN in) (NP (DT the) (JJ final) (NN hour)))) (: --)))))))))) (NP (PRP it)) (ADVP (RB barely)) (VP (VBD managed) (S (VP (TO to) (VP (VB stay) (NP (NP (DT this) (NN side)) (PP (IN of) (NP (NN chaos)))))))) (. .))
(S (NP (NP (DT Some) (`` ``) (NN circuit) (NNS breakers) ('' '')) (VP (VBN installed) (PP (IN after) (NP (DT the) (NNP October) (CD 1987) (NN crash))))) (VP (VBD failed) (NP (PRP$ their) (JJ first) (NN test)) (PRN (, ,) (S (NP (NNS traders)) (VP (VBP say))) (, ,)) (S (ADJP (JJ unable) (S (VP (TO to) (VP (VB cool) (NP (NP (DT the) (NN selling) (NN panic)) (PP (IN in) (NP (DT both) (NNS stocks) (CC and) (NNS futures)))))))))) (. .))
(S (NP (NP (NP (DT The) (CD 49) (NN stock) (NN specialist) (NNS firms)) (PP (IN on) (NP (DT the) (NNP Big) (NNP Board) (NN floor)))) (: --) (NP (NP (DT the) (NNS buyers) (CC and) (NNS sellers)) (PP (IN of) (NP (JJ last) (NN resort))) (SBAR (WHNP (WP who)) (S (VP (VBD were) (VP (VBN criticized) (PP (IN after) (NP (DT the) (CD 1987) (NN crash)))))))) (: --)) (ADVP (RB once) (RB again)) (VP (MD could) (RB n't) (VP (VB handle) (NP (DT the) (NN selling) (NN pressure)))) (. .))
(S (S (NP (JJ Big) (NN investment) (NNS banks)) (VP (VBD refused) (S (VP (TO to) (VP (VB step) (ADVP (IN up) (PP (TO to) (NP (DT the) (NN plate)))) (S (VP (TO to) (VP (VB support) (NP (DT the) (JJ beleaguered) (NN floor) (NNS traders)) (PP (IN by) (S (VP (VBG buying) (NP (NP (JJ big) (NNS blocks)) (PP (IN of) (NP (NN stock))))))))))))))) (, ,) (NP (NNS traders)) (VP (VBP say)) (. .))
(S (NP (NP (JJ Heavy) (NN selling)) (PP (IN of) (NP (NP (NNP Standard) (CC &) (NNP Poor) (POS 's)) (JJ 500-stock) (NN index) (NNS futures))) (PP (IN in) (NP (NNP Chicago)))) (VP (ADVP (RB relentlessly)) (VBD beat) (NP (NNS stocks)) (ADVP (RB downward))) (. .))
(S (NP (NP (CD Seven) (NNP Big) (NNP Board) (NNS stocks)) (: --) (NP (NP (NNP UAL)) (, ,) (NP (NNP AMR)) (, ,) (NP (NNP BankAmerica)) (, ,) (NP (NNP Walt) (NNP Disney)) (, ,) (NP (NNP Capital) (NNP Cities\/ABC)) (, ,) (NP (NNP Philip) (NNP Morris)) (CC and) (NP (NNP Pacific) (NNP Telesis) (NNP Group))) (: --)) (VP (VP (VBD stopped) (S (VP (VBG trading)))) (CC and) (VP (ADVP (RB never)) (VBD resumed))) (. .))
(S (NP (DT The) (NN finger-pointing)) (VP (VBZ has) (ADVP (RB already)) (VP (VBN begun))) (. .))
(S (`` ``) (NP (DT The) (NN equity) (NN market)) (VP (VBD was) (ADJP (JJ illiquid))) (. .))
(SINV (S (ADVP (RB Once) (RB again)) (-LRB- -LCB-) (NP (DT the) (NNS specialists)) (-RRB- -RCB-) (VP (VBD were) (RB not) (ADJP (JJ able) (S (VP (TO to) (VP (VB handle) (NP (NP (DT the) (NNS imbalances)) (PP (IN on) (NP (NP (DT the) (NN floor)) (PP (IN of) (NP (DT the) (NNP New) (NNP York) (NNP Stock) (NNP Exchange)))))))))))) (, ,) ('' '') (VP (VBD said)) (NP (NP (NNP Christopher) (NNP Pedersen)) (, ,) (NP (NP (JJ senior) (NN vice) (NN president)) (PP (IN at) (NP (NNP Twenty-First) (NNP Securities) (NNP Corp))))) (. .))
(SINV (VP (VBD Countered)) (NP (NP (NNP James) (NNP Maguire)) (, ,) (NP (NP (NN chairman)) (PP (IN of) (NP (NNS specialists) (NNP Henderson) (NNP Brothers) (NNP Inc.))))) (: :) (`` ``) (S (NP (PRP It)) (VP (VBZ is) (ADJP (JJ easy)) (S (VP (TO to) (VP (VB say) (SBAR (S (NP (DT the) (NN specialist)) (VP (VBZ is) (RB n't) (VP (VBG doing) (NP (PRP$ his) (NN job))))))))))) (. .))
(S (SBAR (WHADVP (WRB When)) (S (NP (DT the) (NN dollar)) (VP (VBZ is) (PP (IN in) (NP (DT a) (NN free-fall)))))) (, ,) (NP (RB even) (JJ central) (NNS banks)) (VP (MD ca) (RB n't) (VP (VB stop) (NP (PRP it)))) (. .))
(S (NP (NNS Speculators)) (VP (VBP are) (VP (VBG calling) (PP (IN for) (NP (NP (DT a) (NN degree)) (PP (IN of) (NP (NN liquidity))) (SBAR (WHNP (WDT that)) (S (VP (VBZ is) (RB not) (ADVP (RB there)) (PP (IN in) (NP (DT the) (NN market)))))))))) (. .) ('' ''))
(S (NP (NP (JJ Many) (NN money) (NNS managers)) (CC and) (NP (DT some) (NNS traders))) (VP (VBD had) (ADVP (RB already)) (VP (VBN left) (NP (PRP$ their) (NNS offices)) (NP (RB early) (NNP Friday) (NN afternoon)) (PP (IN on) (NP (DT a) (JJ warm) (NN autumn) (NN day))) (: --) (SBAR (IN because) (S (NP (DT the) (NN stock) (NN market)) (VP (VBD was) (ADJP (RB so) (JJ quiet))))))) (. .))
(S (RB Then) (PP (IN in) (NP (DT a) (NN lightning) (NN plunge))) (, ,) (NP (DT the) (NNP Dow) (NNP Jones) (NNS industrials)) (PP (IN in) (NP (QP (RB barely) (DT an)) (NN hour))) (VP (VBD surrendered) (NP (NP (QP (RB about) (DT a)) (JJ third)) (PP (IN of) (NP (NP (PRP$ their) (NNS gains)) (NP (DT this) (NN year))))) (, ,) (S (VP (VBG chalking) (PRT (RP up)) (NP (NP (DT a) (ADJP (ADJP (JJ 190.58-point)) (, ,) (CC or) (ADJP (CD 6.9) (NN %)) (, ,)) (NN loss)) (PP (IN on) (NP (DT the) (NN day)))) (PP (IN in) (NP (JJ gargantuan) (NN trading) (NN volume)))))) (. .))
(S (NP (JJ Final-hour) (NN trading)) (VP (VBD accelerated) (PP (TO to) (NP (NP (QP (CD 108.1) (CD million)) (NNS shares)) (, ,) (NP (NP (DT a) (NN record)) (PP (IN for) (NP (DT the) (NNP Big) (NNP Board))))))) (. .))
(S (PP (IN At) (NP (NP (DT the) (NN end)) (PP (IN of) (NP (DT the) (NN day))))) (, ,) (NP (QP (CD 251.2) (CD million)) (NNS shares)) (VP (VBD were) (VP (VBN traded))) (. .))
(S (NP (DT The) (NNP Dow) (NNP Jones) (NNS industrials)) (VP (VBD closed) (PP (IN at) (NP (CD 2569.26)))) (. .))
(S (NP (NP (DT The) (NNP Dow) (POS 's)) (NN decline)) (VP (VBD was) (ADJP (JJ second) (PP (IN in) (NP (NN point) (NNS terms))) (PP (ADVP (RB only)) (TO to) (NP (NP (DT the) (JJ 508-point) (NNP Black) (NNP Monday) (NN crash)) (SBAR (WHNP (WDT that)) (S (VP (VBD occurred) (NP (NNP Oct.) (CD 19) (, ,) (CD 1987))))))))) (. .))
(S (PP (IN In) (NP (NN percentage) (NNS terms))) (, ,) (ADVP (RB however)) (, ,) (NP (NP (DT the) (NNP Dow) (POS 's)) (NN dive)) (VP (VBD was) (NP (NP (NP (DT the) (JJ 12th-worst)) (ADVP (RB ever))) (CC and) (NP (NP (DT the) (JJS sharpest)) (SBAR (IN since) (S (NP (DT the) (NN market)) (VP (VBD fell) (NP (NP (CD 156.83)) (, ,) (CC or) (NP (CD 8) (NN %))) (, ,) (PP (NP (DT a) (NN week)) (IN after) (NP (NNP Black) (NNP Monday))))))))) (. .))


================================================
FILE: data/train.txt
================================================
(S (PP (IN In) (NP (NP (DT an) (NNP Oct.) (CD 19) (NN review)) (PP (IN of) (NP (`` ``) (NP (DT The) (NN Misanthrope)) ('' '') (PP (IN at) (NP (NP (NNP Chicago) (POS 's)) (NNP Goodman) (NNP Theatre))))) (PRN (-LRB- -LRB-) (`` ``) (S (NP (VBN Revitalized) (NNS Classics)) (VP (VBP Take) (NP (DT the) (NN Stage)) (PP (IN in) (NP (NNP Windy) (NNP City))))) (, ,) ('' '') (NP (NN Leisure) (CC &) (NNS Arts)) (-RRB- -RRB-)))) (, ,) (NP (NP (NP (DT the) (NN role)) (PP (IN of) (NP (NNP Celimene)))) (, ,) (VP (VBN played) (PP (IN by) (NP (NNP Kim) (NNP Cattrall)))) (, ,)) (VP (VBD was) (VP (ADVP (RB mistakenly)) (VBN attributed) (PP (TO to) (NP (NNP Christina) (NNP Haag))))) (. .))
(S (NP (NNP Ms.) (NNP Haag)) (VP (VBZ plays) (NP (NNP Elianti))) (. .))
(S (NP (NNP Rolls-Royce) (NNP Motor) (NNPS Cars) (NNP Inc.)) (VP (VBD said) (SBAR (S (NP (PRP it)) (VP (VBZ expects) (S (NP (PRP$ its) (NNP U.S.) (NNS sales)) (VP (TO to) (VP (VB remain) (ADJP (JJ steady)) (PP (IN at) (NP (QP (IN about) (CD 1,200)) (NNS cars))) (PP (IN in) (NP (CD 1990)))))))))) (. .))
(S (NP (DT The) (NN luxury) (NN auto) (NN maker)) (NP (JJ last) (NN year)) (VP (VBD sold) (NP (CD 1,214) (NNS cars)) (PP (IN in) (NP (DT the) (NNP U.S.)))))
(S (NP (NP (NNP Howard) (NNP Mosher)) (, ,) (NP (NP (NN president)) (CC and) (NP (JJ chief) (NN executive) (NN officer))) (, ,)) (VP (VBD said) (SBAR (S (NP (PRP he)) (VP (VBZ anticipates) (NP (NP (NN growth)) (PP (IN for) (NP (DT the) (NN luxury) (NN auto) (NN maker))) (PP (PP (IN in) (NP (NNP Britain) (CC and) (NNP Europe))) (, ,) (CC and) (PP (IN in) (NP (ADJP (JJ Far) (JJ Eastern)) (NNS markets))))))))) (. .))
(S (NP (NNP BELL) (NNP INDUSTRIES) (NNP Inc.)) (VP (VBD increased) (NP (PRP$ its) (NN quarterly)) (PP (TO to) (NP (CD 10) (NNS cents))) (PP (IN from) (NP (NP (CD seven) (NNS cents)) (NP (DT a) (NN share))))) (. .))
(S (NP (DT The) (JJ new) (NN rate)) (VP (MD will) (VP (VB be) (ADJP (JJ payable) (NP (NNP Feb.) (CD 15))))) (. .))
(S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .))
(S (NP (NP (NNP Bell)) (, ,) (VP (VBN based) (PP (IN in) (NP (NNP Los) (NNP Angeles)))) (, ,)) (VP (VBZ makes) (CC and) (VBZ distributes) (NP (UCP (JJ electronic) (, ,) (NN computer) (CC and) (NN building)) (NNS products))) (. .))
(S (NP (NNS Investors)) (VP (VBP are) (VP (VBG appealing) (PP (TO to) (NP (DT the) (NNPS Securities) (CC and) (NNP Exchange) (NNP Commission))) (S (RB not) (VP (TO to) (VP (VB limit) (NP (NP (PRP$ their) (NN access)) (PP (TO to) (NP (NP (NN information)) (PP (IN about) (NP (NP (NN stock) (NNS purchases) (CC and) (NNS sales)) (PP (IN by) (NP (JJ corporate) (NNS insiders))))))))))))) (. .))
(S (S (NP (DT A) (NNP SEC) (NN proposal) (S (VP (TO to) (VP (VB ease) (NP (NP (NN reporting) (NNS requirements)) (PP (IN for) (NP (DT some) (NN company) (NNS executives)))))))) (VP (MD would) (VP (VB undermine) (NP (NP (DT the) (NN usefulness)) (PP (IN of) (NP (NP (NN information)) (PP (IN on) (NP (NN insider) (NNS trades))))) (PP (IN as) (NP (DT a) (JJ stock-picking) (NN tool))))))) (, ,) (NP (NP (JJ individual) (NNS investors)) (CC and) (NP (JJ professional) (NN money) (NNS managers))) (VP (VBP contend)) (. .))
(S (NP (PRP They)) (VP (VBP make) (NP (DT the) (NN argument)) (PP (IN in) (NP (NP (NNS letters)) (PP (TO to) (NP (DT the) (NN agency))) (PP (IN about) (NP (NP (NN rule) (NNS changes)) (VP (VBD proposed) (NP (DT this) (JJ past) (NN summer))) (SBAR (WHNP (IN that)) (, ,) (S (PP (IN among) (NP (JJ other) (NNS things))) (, ,) (VP (MD would) (VP (VB exempt) (NP (JJ many) (JJ middle-management) (NNS executives)) (PP (IN from) (S (VP (VBG reporting) (NP (NP (NNS trades)) (PP (IN in) (NP (NP (PRP$ their) (JJ own) (NNS companies) (POS ')) (NNS shares)))))))))))))))) (. .))
(S (NP (DT The) (VBN proposed) (NNS changes)) (ADVP (RB also)) (VP (MD would) (VP (VB allow) (S (NP (NNS executives)) (VP (TO to) (VP (VB report) (NP (NP (NNS exercises)) (PP (IN of) (NP (NNS options)))) (ADVP (ADVP (RBR later)) (CC and) (ADVP (RBR less) (RB often)))))))) (. .))
(S (NP (NP (JJ Many)) (PP (IN of) (NP (DT the) (NNS letters)))) (VP (VBP maintain) (SBAR (IN that) (S (S (NP (NN investor) (NN confidence)) (VP (VBZ has) (VP (VBN been) (VP (ADVP (RB so)) (VBN shaken) (PP (IN by) (NP (DT the) (CD 1987) (NN stock) (NN market) (NN crash))))))) (: --) (CC and) (S (NP (DT the) (NNS markets)) (ADVP (RB already)) (VP (ADVP (RB so)) (VBN stacked) (PP (IN against) (NP (DT the) (JJ little) (NN guy))))) (: --) (SBAR (IN that) (S (NP (NP (DT any) (NN decrease)) (PP (IN in) (NP (NP (NN information)) (PP (IN on) (NP (NN insider-trading) (NNS patterns)))))) (VP (MD might) (VP (VB prompt) (S (NP (NNS individuals)) (VP (TO to) (VP (VB get) (ADVP (RB out) (PP (IN of) (NP (NNS stocks)))) (ADVP (RB altogether)))))))))))) (. .))
(SINV (`` ``) (S (NP (DT The) (NNP SEC)) (VP (VBZ has) (ADVP (RB historically)) (VP (VBN paid) (NP (NN obeisance)) (PP (TO to) (NP (NP (DT the) (NN ideal)) (PP (IN of) (NP (DT a) (JJ level) (NN playing) (NN field)))))))) (, ,) ('' '') (VP (VBD wrote)) (NP (NP (NNP Clyde) (NNP S.) (NNP McGregor)) (PP (IN of) (NP (NP (NNP Winnetka)) (, ,) (NP (NNP Ill.)) (, ,)))) (PP (IN in) (NP (NP (CD one)) (PP (IN of) (NP (NP (DT the) (CD 92) (NNS letters)) (SBAR (S (NP (DT the) (NN agency)) (VP (VBZ has) (VP (VBN received) (SBAR (IN since) (S (NP (DT the) (NNS changes)) (VP (VBD were) (VP (VBN proposed) (NP (NNP Aug.) (CD 17)))))))))))))) (. .))
(S (`` ``) (ADVP (RB Apparently)) (NP (DT the) (NN commission)) (VP (VBD did) (RB not) (ADVP (RB really)) (VP (VB believe) (PP (IN in) (NP (DT this) (NN ideal))))) (. .) ('' ''))
(S (ADVP (RB Currently)) (, ,) (NP (DT the) (NNS rules)) (VP (VBP force) (S (NP (NP (NNS executives)) (, ,) (NP (NNS directors)) (CC and) (NP (JJ other) (JJ corporate) (NNS insiders))) (VP (TO to) (VP (VB report) (NP (NP (NNS purchases) (CC and) (NNS sales)) (PP (IN of) (NP (NP (PRP$ their) (NNS companies) (POS ')) (NNS shares)))) (PP (IN within) (NP (NP (QP (IN about) (DT a)) (NN month)) (PP (IN after) (NP (DT the) (NN transaction))))))))) (. .))
(S (CC But) (NP (NP (QP (IN about) (CD 25)) (NN %)) (PP (IN of) (NP (DT the) (NNS insiders)))) (, ,) (PP (VBG according) (PP (TO to) (NP (NNP SEC) (NNS figures)))) (, ,) (VP (VBP file) (NP (PRP$ their) (NNS reports)) (ADVP (RB late))) (. .))
(SINV (S (NP (DT The) (NNS changes)) (VP (VBD were) (VP (VBN proposed) (PP (IN in) (NP (DT an) (NN effort) (S (VP (TO to) (VP (VP (VB streamline) (NP (JJ federal) (NN bureaucracy))) (CC and) (VP (VB boost) (NP (NP (NN compliance)) (PP (IN by) (NP (NP (DT the) (NNS executives)) (`` ``) (SBAR (WHNP (WP who)) (S (VP (VBP are) (ADVP (RB really)) (VP (VBG calling) (NP (DT the) (NNS shots)))))))))))))))))) (, ,) ('' '') (VP (VBD said)) (NP (NP (NNP Brian) (NNP Lane)) (, ,) (NP (NP (JJ special) (NN counsel)) (PP (IN at) (NP (NP (NP (NP (DT the) (NNP SEC) (POS 's)) (NN office)) (PP (IN of) (NP (NN disclosure) (NN policy)))) (, ,) (SBAR (WHNP (WDT which)) (S (VP (VBD proposed) (NP (DT the) (NNS changes))))))))) (. .))
(S (S (S (NP (NP (NNS Investors)) (, ,) (NP (NN money) (NNS managers)) (CC and) (NP (JJ corporate) (NNS officials))) (VP (VBD had) (PP (IN until) (NP (NN today))) (S (VP (TO to) (VP (VB comment) (PP (IN on) (NP (DT the) (NNS proposals)))))))) (, ,) (CC and) (S (NP (DT the) (NN issue)) (VP (VBZ has) (VP (VBN produced) (NP (NP (JJR more) (NN mail)) (PP (IN than) (NP (NP (ADJP (RB almost) (DT any)) (JJ other) (NN issue)) (PP (IN in) (NP (NN memory)))))))))) (, ,) (NP (NNP Mr.) (NNP Lane)) (VP (VBD said)) (. .))


================================================
FILE: data/valid.txt
================================================
(S (NP (NP (DT The) (NN economy) (POS 's)) (NN temperature)) (VP (MD will) (VP (VB be) (VP (VBN taken) (PP (IN from) (NP (JJ several) (NN vantage) (NNS points))) (NP (DT this) (NN week)) (, ,) (PP (IN with) (NP (NP (NNS readings)) (PP (IN on) (NP (NP (NN trade)) (, ,) (NP (NN output)) (, ,) (NP (NN housing)) (CC and) (NP (NN inflation))))))))) (. .))
(S (NP (DT The) (ADJP (RBS most) (JJ troublesome)) (NN report)) (VP (MD may) (VP (VB be) (NP (NP (DT the) (NNP August) (NN merchandise) (NN trade) (NN deficit)) (ADJP (JJ due) (ADVP (IN out)) (NP (NN tomorrow)))))) (. .))
(S (NP (DT The) (NN trade) (NN gap)) (VP (VBZ is) (VP (VBN expected) (S (VP (TO to) (VP (VB widen) (PP (TO to) (NP (QP (IN about) ($ $) (CD 9) (CD billion)))) (PP (IN from) (NP (NP (NNP July) (POS 's)) (QP ($ $) (CD 7.6) (CD billion))))))) (, ,) (PP (VBG according) (PP (TO to) (NP (NP (DT a) (NN survey)) (PP (IN by) (NP (NP (NNP MMS) (NNP International)) (, ,) (NP (NP (DT a) (NN unit)) (PP (IN of) (NP (NP (NNP McGraw-Hill) (NNP Inc.)) (, ,) (NP (NNP New) (NNP York)))))))))))) (. .))
(S (NP (NP (NP (NNP Thursday) (POS 's)) (NN report)) (PP (IN on) (NP (DT the) (NNP September) (NN consumer) (NN price) (NN index)))) (VP (VBZ is) (VP (VBN expected) (S (VP (TO to) (VP (VB rise) (, ,) (SBAR (IN although) (ADVP (ADVP (RB not) (RB as) (RB sharply)) (PP (IN as) (NP (NP (DT the) (ADJP (CD 0.9) (NN %)) (NN gain)) (VP (VBN reported) (NP (NNP Friday)) (PP (IN in) (NP (DT the) (NN producer) (NN price) (NN index))))))))))))) (. .))
(S (NP (DT That) (NN gain)) (VP (VBD was) (VP (VBG being) (VP (VBD cited) (PP (IN as) (NP (NP (DT a) (NN reason)) (SBAR (S (NP (DT the) (NN stock) (NN market)) (VP (VBD was) (ADVP (IN down)) (ADVP (RB early) (PP (IN in) (NP (NP (NNP Friday) (POS 's)) (NN session)))) (, ,) (SBAR (IN before) (S (NP (PRP it)) (VP (VBD got) (S (VP (VBN started) (PP (IN on) (NP (PRP$ its) (JJ reckless) (JJ 190-point) (NN plunge)))))))))))))))) (. .))
(S (NP (NNS Economists)) (VP (VBP are) (VP (VBN divided) (PP (IN as) (PP (TO to) (SBAR (WHNP (WHADVP (WRB how) (JJ much)) (VBG manufacturing) (NN strength)) (S (NP (PRP they)) (VP (VBP expect) (S (VP (TO to) (VP (VB see) (PP (IN in) (NP (NP (NP (NNP September) (NNS reports)) (PP (IN on) (NP (NP (JJ industrial) (NN production)) (CC and) (NP (NN capacity) (NN utilization))))) (, ,) (ADJP (ADVP (RB also)) (JJ due) (NP (NN tomorrow))))))))))))))) (. .))
(S (ADVP (RB Meanwhile)) (, ,) (NP (NP (NNP September) (NN housing) (NNS starts)) (, ,) (ADJP (JJ due) (NP (NNP Wednesday))) (, ,)) (VP (VBP are) (VP (VBN thought) (S (VP (TO to) (VP (VB have) (VP (VBN inched) (ADVP (RB upward)))))))) (. .))
(SINV (S (`` ``) (NP (EX There)) (VP (VBZ 's) (NP (NP (DT a) (NN possibility)) (PP (IN of) (NP (NP (DT a) (NN surprise)) ('' '') (PP (IN in) (NP (DT the) (NN trade) (NN report)))))))) (, ,) (VP (VBD said)) (NP (NP (NNP Michael) (NNP Englund)) (, ,) (NP (NP (NN director)) (PP (IN of) (NP (NN research))) (PP (IN at) (NP (NNP MMS))))) (. .))
(S (S (NP (NP (DT A) (NN widening)) (PP (IN of) (NP (DT the) (NN deficit)))) (, ,) (SBAR (IN if) (S (NP (PRP it)) (VP (VBD were) (VP (VBN combined) (PP (IN with) (NP (DT a) (ADJP (RB stubbornly) (JJ strong)) (NN dollar))))))) (, ,) (VP (MD would) (VP (VB exacerbate) (NP (NN trade) (NNS problems)))) (: --)) (CC but) (S (NP (DT the) (NN dollar)) (VP (VBD weakened) (NP (NNP Friday)) (SBAR (IN as) (S (NP (NNS stocks)) (VP (VBD plummeted)))))) (. .))
(S (PP (IN In) (NP (DT any) (NN event))) (, ,) (NP (NP (NNP Mr.) (NNP Englund)) (CC and) (NP (JJ many) (NNS others))) (VP (VBP say) (SBAR (IN that) (S (NP (NP (DT the) (JJ easy) (NNS gains)) (PP (IN in) (S (VP (VBG narrowing) (NP (DT the) (NN trade) (NN gap)))))) (VP (VBP have) (ADVP (RB already)) (VP (VBN been) (VP (VBN made))))))) (. .))
(S (`` ``) (S (NP (NN Trade)) (VP (VBZ is) (ADVP (RB definitely)) (VP (VBG going) (S (VP (TO to) (VP (VB be) (ADJP (RBR more) (RB politically) (JJ sensitive)) (PP (IN over) (NP (DT the) (JJ next) (QP (CD six) (CC or) (CD seven)) (NNS months))) (SBAR (IN as) (S (NP (NN improvement)) (VP (VBZ begins) (S (VP (TO to) (VP (VB slow))))))))))))) (, ,) ('' '') (NP (PRP he)) (VP (VBD said)) (. .))
(S (S (NP (NNS Exports)) (VP (VBP are) (VP (VBN thought) (S (VP (TO to) (VP (VB have) (VP (VBN risen) (ADVP (ADVP (RB strongly) (PP (IN in) (NP (NNP August)))) (, ,) (CC but) (ADVP (ADVP (RB probably)) (RB not) (RB enough) (S (VP (TO to) (VP (VB offset) (NP (NP (DT the) (NN jump)) (PP (IN in) (NP (NNS imports)))))))))))))))) (, ,) (NP (NNS economists)) (VP (VBD said)) (. .))
(S (NP (NP (NNS Views)) (PP (IN on) (NP (VBG manufacturing) (NN strength)))) (VP (VBP are) (ADJP (VBN split) (PP (IN between) (NP (NP (NP (NNS economists)) (SBAR (WHNP (WP who)) (S (VP (VBP read) (NP (NP (NP (NNP September) (POS 's)) (JJ low) (NN level)) (PP (IN of) (NP (NN factory) (NN job) (NN growth)))) (PP (IN as) (NP (NP (DT a) (NN sign)) (PP (IN of) (NP (DT a) (NN slowdown))))))))) (CC and) (NP (NP (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBP use) (NP (DT the) (ADJP (RB somewhat) (JJR more) (VBG comforting)) (JJ total) (NN employment) (NNS figures)) (PP (IN in) (NP (PRP$ their) (NNS calculations))))))))))) (. .))
(S (S (NP (NP (DT The) (JJ wide) (NN range)) (PP (IN of) (NP (NP (NNS estimates)) (PP (IN for) (NP (DT the) (JJ industrial) (NN output) (NN number)))))) (VP (VBZ underscores) (NP (DT the) (NNS differences)))) (: :) (S (NP (DT The) (NNS forecasts)) (VP (VBD run) (PP (IN from) (NP (NP (DT a) (NN drop)) (PP (IN of) (NP (CD 0.5) (NN %))))) (PP (TO to) (NP (NP (DT an) (NN increase)) (PP (IN of) (NP (CD 0.4) (NN %))))) (, ,) (PP (VBG according) (PP (TO to) (NP (NNP MMS)))))) (. .))
(S (NP (NP (DT A) (NN rebound)) (PP (IN in) (NP (NN energy) (NNS prices))) (, ,) (SBAR (WHNP (WDT which)) (S (VP (VBD helped) (VP (VB push) (PRT (RP up)) (NP (DT the) (NN producer) (NN price) (NN index)))))) (, ,)) (VP (VBZ is) (VP (VBN expected) (S (VP (TO to) (VP (VB do) (NP (DT the) (JJ same)) (PP (IN in) (NP (DT the) (NN consumer) (NN price) (NN report)))))))) (. .))
(S (NP (DT The) (NN consensus) (NN view)) (VP (VBZ expects) (NP (NP (DT a) (ADJP (CD 0.4) (NN %)) (NN increase)) (PP (IN in) (NP (DT the) (NNP September) (NNP CPI)))) (PP (IN after) (NP (NP (DT a) (JJ flat) (NN reading)) (PP (IN in) (NP (NNP August)))))) (. .))
(S (NP (NP (NNP Robert) (NNP H.) (NNP Chandross)) (, ,) (NP (NP (DT an) (NN economist)) (PP (IN for) (NP (NP (NP (NNP Lloyd) (POS 's)) (NNP Bank)) (PP (IN in) (NP (NNP New) (NNP York)))))) (, ,)) (VP (VBZ is) (PP (IN among) (NP (NP (DT those)) (VP (VBG expecting) (NP (NP (DT a) (ADJP (RBR more) (JJ moderate)) (NN gain)) (PP (IN in) (NP (DT the) (NNP CPI))) (PP (IN than) (PP (IN in) (NP (NP (NNS prices)) (PP (IN at) (NP (DT the) (NN producer) (NN level))))))))))) (. .))
(S (`` ``) (S (S (NP (NN Auto) (NNS prices)) (VP (VBD had) (NP (DT a) (JJ big) (NN effect)) (PP (IN in) (NP (DT the) (NNP PPI))))) (, ,) (CC and) (S (PP (IN at) (NP (DT the) (NNP CPI) (NN level))) (NP (PRP they)) (VP (MD wo) (RB n't)))) (, ,) ('' '') (NP (PRP he)) (VP (VBD said)) (. .))
(SINV (S (S (NP (NN Food) (NNS prices)) (VP (VBP are) (VP (VBN expected) (S (VP (TO to) (VP (VB be) (ADJP (JJ unchanged)))))))) (, ,) (CC but) (S (NP (NN energy) (NNS costs)) (VP (VBD jumped) (NP (NP (RB as) (RB much) (IN as) (CD 4)) (NN %))))) (, ,) (VP (VBD said)) (NP (NP (NNP Gary) (NNP Ciminero)) (, ,) (NP (NP (NN economist)) (PP (IN at) (NP (NNP Fleet\/Norstar) (NNP Financial) (NNP Group))))) (. .))
(S (NP (PRP He)) (ADVP (RB also)) (VP (VBZ says) (SBAR (S (NP (PRP he)) (VP (VBZ thinks) (SBAR (S (NP (`` ``) (NP (NN core) (NN inflation)) (, ,) ('' '') (SBAR (WHNP (WDT which)) (S (VP (VBZ excludes) (NP (DT the) (JJ volatile) (NN food) (CC and) (NN energy) (NNS prices))))) (, ,)) (VP (VBD was) (ADJP (JJ strong)) (NP (JJ last) (NN month))))))))) (. .))


================================================
FILE: data.py
================================================
#!/usr/bin/env python3
import numpy as np
import torch
import pickle

class Dataset(object):
  def __init__(self, data_file):
    with open(data_file, 'rb') as f: # get text data
      data = pickle.load(f)
    self.sents = self._convert(data['source']).long()
    self.other_data = data['other_data']
    self.sent_lengths = self._convert(data['source_l']).long()
    self.batch_size = self._convert(data['batch_l']).long()
    self.batch_idx = self._convert(data['batch_idx']).long()
    self.vocab_size = data['vocab_size'][0]
    self.num_batches = self.batch_idx.size(0)
    self.word2idx = data['word2idx']
    self.idx2word = data['idx2word']

  def _convert(self, x):
    return torch.from_numpy(np.asarray(x))

  def __len__(self):
    return self.num_batches

  def __getitem__(self, idx):
    assert(idx < self.num_batches and idx >= 0)
    start_idx = self.batch_idx[idx]
    end_idx = start_idx + self.batch_size[idx]
    length = self.sent_lengths[idx].item()
    sents = self.sents[start_idx:end_idx]
    other_data = self.other_data[start_idx:end_idx]
    sent_str = [d[0] for d in other_data]
    tags = [d[1] for d in other_data]
    actions = [d[2] for d in other_data]
    binary_tree = [d[3] for d in other_data]
    spans = [d[5] for d in other_data]
    batch_size = self.batch_size[idx].item()
    # sents include the <s> and </s> tokens by default, so we subtract 2 from
    # length because these tokens are (by default) not counted for evaluation
    data_batch = [sents[:, :length], length-2, batch_size, actions, 
                  spans, binary_tree, other_data]
    return data_batch
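
# Illustrative sketch (not part of the original repository): a hypothetical,
# self-contained helper that mirrors the slicing in Dataset.__getitem__.
# It assumes the layout implied by that method: all sentences sit in one padded
# matrix, batch_idx[i] is the first row of batch i, batch_l[i] is its number of
# sentences, and source_l[i] is the length used for every sentence in batch i
# (including <s> and </s>). The toy arrays in the example below are made up.
import numpy as np

def toy_slice_batch(sents, batch_idx, batch_l, source_l, idx):
  start = batch_idx[idx]
  end = start + batch_l[idx]
  length = source_l[idx]
  # keep only this batch's rows and drop columns past `length`;
  # the returned length excludes <s> and </s>, as in __getitem__
  return sents[start:end, :length], length - 2

# Example: with sents = np.arange(30).reshape(5, 6), batch_idx = [0, 2],
# batch_l = [2, 3] and source_l = [4, 6], batch 1 is rows 2..4 (all 6 columns)
# and its evaluated length is 4.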


================================================
FILE: eval_ppl.py
================================================
#!/usr/bin/env python3
import sys
import os

import argparse
import json
import random
import shutil
import copy

import torch
from torch import cuda
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter

import torch.nn.functional as F
import numpy as np
import time
import logging
from data import Dataset
from models import RNNG
from utils import *

parser = argparse.ArgumentParser()

# Data path options
parser.add_argument('--test_file', default='data/ptb-test.pkl')
parser.add_argument('--model_file', default='')
parser.add_argument('--is_temp', default=2., type=float, help='divide scores by is_temp before CRF')
parser.add_argument('--samples', default=1000, type=int, help='samples for IS calculation')
parser.add_argument('--count_eos_ppl', default=0, type=int, help='whether to count eos in val PPL')
parser.add_argument('--gpu', default=2, type=int, help='which gpu to use')
parser.add_argument('--seed', default=3435, type=int)


def main(args):
  np.random.seed(args.seed)
  torch.manual_seed(args.seed)
  data = Dataset(args.test_file)  
  checkpoint = torch.load(args.model_file)
  model = checkpoint['model']
  print("model architecture")
  print(model)
  cuda.set_device(args.gpu)
  model.cuda()
  model.eval()
  num_sents = 0
  num_words = 0
  total_nll_recon = 0.
  total_kl = 0.
  total_nll_iwae = 0.
  samples_batch = 50
  S = args.samples // samples_batch  
  samples = S*samples_batch
  with torch.no_grad():
    for i in list(reversed(range(len(data)))):
      sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i] 
      if length == 1:
        # length 1 sents are ignored since our generative model requires sents of length >= 2
        continue
      if args.count_eos_ppl == 1:
        length += 1
      else:
        sents = sents[:, :-1]
      sents = sents.cuda()
      ll_word_all2 = [] 
      ll_action_p_all2 = [] 
      ll_action_q_all2 = [] 
      for j in range(S):                    
        ll_word_all, ll_action_p_all, ll_action_q_all, actions_all, q_entropy = model(
          sents, samples = samples_batch, is_temp = args.is_temp, has_eos = args.count_eos_ppl == 1)
        ll_word_all2.append(ll_word_all.detach().cpu())
        ll_action_p_all2.append(ll_action_p_all.detach().cpu())
        ll_action_q_all2.append(ll_action_q_all.detach().cpu())
      ll_word_all2 = torch.cat(ll_word_all2, 1)
      ll_action_p_all2 = torch.cat(ll_action_p_all2, 1)
      ll_action_q_all2 = torch.cat(ll_action_q_all2, 1)
      # average the reconstruction and KL terms over all `samples` draws, not only the last chunk
      total_nll_recon += -ll_word_all2.mean(1).sum().item()
      total_kl += (ll_action_q_all2 - ll_action_p_all2).mean(1).sum().item()
      # log importance weights: log p(x, z) - log q(z | x) for each sampled tree z
      sample_ll = ll_word_all2 + ll_action_p_all2 - ll_action_q_all2
      ll_iwae = model.logsumexp(sample_ll, 1) - np.log(samples)
      total_nll_iwae -= ll_iwae.sum().item()      
      num_sents += batch_size
      num_words += batch_size * length
      
      print('Batch: %d/%d, ElboPPL: %.2f, KL: %.4f, IwaePPL: %.2f' % 
            (i, len(data), np.exp((total_nll_recon + total_kl) / num_words),
            total_kl / num_sents, np.exp(total_nll_iwae / num_words)))
  elbo_ppl = np.exp((total_nll_recon + total_kl) / num_words)
  recon_ppl = np.exp(total_nll_recon / num_words)
  iwae_ppl = np.exp(total_nll_iwae /num_words)
  kl = total_kl / num_sents  
  print('ElboPPL: %.2f, ReconPPL: %.2f, KL: %.4f, IwaePPL: %.2f' % 
        (elbo_ppl, recon_ppl, kl, iwae_ppl))
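
# Illustrative sketch (not part of the original repository): how the ELBO and
# IWAE estimates in main() relate, written as a hypothetical standalone helper.
# Each input is a (batch, K) tensor of per-sample log terms, matching the shapes
# of ll_word_all2, ll_action_p_all2 and ll_action_q_all2 in main().
import math
import torch

def elbo_and_iwae(ll_word, ll_action_p, ll_action_q):
  # log importance weights: log p(x, z) - log q(z | x) for each sampled tree z
  log_w = ll_word + ll_action_p - ll_action_q
  num_samples = log_w.size(1)
  elbo = log_w.mean(1)  # average of the log weights
  iwae = torch.logsumexp(log_w, 1) - math.log(num_samples)  # log of the average weight
  return elbo, iwae

# By Jensen's inequality iwae >= elbo for every sentence. Each PPL printed in
# main() is exp(-(sum of the corresponding bound) / num_words).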

if __name__ == '__main__':
  args = parser.parse_args()
  main(args)


================================================
FILE: models.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from utils import *
from TreeCRF import ConstituencyTreeCRF
from torch.distributions import Bernoulli

class RNNLM(nn.Module):
  def __init__(self, vocab=10000,
               w_dim=650,
               h_dim=650,
               num_layers=2,
               dropout=0.5):
    super(RNNLM, self).__init__()
    self.h_dim = h_dim
    self.num_layers = num_layers    
    self.word_vecs = nn.Embedding(vocab, w_dim)
    self.dropout = nn.Dropout(dropout)
    self.rnn = nn.LSTM(w_dim, h_dim, num_layers = num_layers,
                       dropout = dropout, batch_first = True)      
    self.vocab_linear = nn.Linear(h_dim, vocab)
    self.vocab_linear.weight = self.word_vecs.weight # weight sharing (requires w_dim == h_dim)

  def forward(self, sent):
    word_vecs = self.dropout(self.word_vecs(sent[:, :-1]))
    h, _ = self.rnn(word_vecs)
    log_prob = F.log_softmax(self.vocab_linear(self.dropout(h)), 2) # b x l x v
    ll = torch.gather(log_prob, 2, sent[:, 1:].unsqueeze(2)).squeeze(2)
    return ll.sum(1)
  
  def generate(self, bos = 2, eos = 3, max_len = 150):
    x = []
    bos = torch.LongTensor(1,1).cuda().fill_(bos)
    emb = self.dropout(self.word_vecs(bos))
    prev_h = None
    for l in range(max_len):
      h, prev_h = self.rnn(emb, prev_h)
      prob = F.softmax(self.vocab_linear(self.dropout(h.squeeze(1))), 1)
      sample = torch.multinomial(prob, 1)
      emb = self.dropout(self.word_vecs(sample))
      x.append(sample.item())
      if x[-1] == eos:
        x.pop()
        break
    return x
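
# Illustrative sketch (not part of the original repository): the per-sentence
# log-likelihood computed at the end of RNNLM.forward, as a hypothetical
# standalone function. Logits at step t score the token at position t+1, so the
# targets are the input sentence shifted left by one.
import torch
import torch.nn.functional as F

def sentence_log_likelihood(logits, sent):
  # logits: (batch, length-1, vocab), computed from sent[:, :-1]
  # sent:   (batch, length) token ids
  log_prob = F.log_softmax(logits, 2)
  targets = sent[:, 1:].unsqueeze(2)  # next-token targets
  return torch.gather(log_prob, 2, targets).squeeze(2).sum(1)  # (batch,)

# With all-zero logits every token has probability 1/vocab, so a sentence of
# length L scores -(L-1) * log(vocab).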

class SeqLSTM(nn.Module):
  def __init__(self, i_dim = 200,
               h_dim = 0,
               num_layers = 1,
               dropout = 0):
    super(SeqLSTM, self).__init__()    
    self.i_dim = i_dim
    self.h_dim = h_dim
    self.num_layers = num_layers
    self.linears = nn.ModuleList([nn.Linear(h_dim + i_dim, h_dim*4) if l == 0 else
                                  nn.Linear(h_dim*2, h_dim*4) for l in range(num_layers)])
    self.dropout = dropout
    self.dropout_layer = nn.Dropout(dropout)

  def forward(self, x, prev_h = None):
    if prev_h is None:
      prev_h = [(x.new(x.size(0), self.h_dim).fill_(0),
                 x.new(x.size(0), self.h_dim).fill_(0)) for _ in range(self.num_layers)]
    curr_h = []
    for l in range(self.num_layers):
      input = x if l == 0 else curr_h[l-1][0]
      if l > 0 and self.dropout > 0:
        input = self.dropout_layer(input)
      concat = torch.cat([input, prev_h[l][0]], 1)
      all_sum = self.linears[l](concat)
      i, f, o, g = all_sum.split(self.h_dim, 1)
      c = torch.sigmoid(f)*prev_h[l][1] + torch.sigmoid(i)*torch.tanh(g)
      h = torch.sigmoid(o)*torch.tanh(c)
      curr_h.append((h, c))
    return curr_h
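
# Illustrative sketch (not part of the original repository): one SeqLSTM layer
# step as a hypothetical standalone function, plus a hypothetical helper that
# copies its weights into torch.nn.LSTMCell. SeqLSTM splits a single projection
# of [input, h] into gates in (i, f, o, g) order, while nn.LSTMCell stores its
# weights in (i, f, g, o) order, so the gate rows must be permuted.
import torch
import torch.nn as nn

def seq_lstm_step(linear, x, h, c):
  h_dim = h.size(1)
  i, f, o, g = linear(torch.cat([x, h], 1)).split(h_dim, 1)
  c_new = torch.sigmoid(f)*c + torch.sigmoid(i)*torch.tanh(g)
  h_new = torch.sigmoid(o)*torch.tanh(c_new)
  return h_new, c_new

def to_lstm_cell(linear, i_dim, h_dim):
  # permute gate rows from (i, f, o, g) to (i, f, g, o)
  order = torch.cat([torch.arange(0, 2*h_dim),
                     torch.arange(3*h_dim, 4*h_dim),
                     torch.arange(2*h_dim, 3*h_dim)])
  cell = nn.LSTMCell(i_dim, h_dim)
  with torch.no_grad():
    cell.weight_ih.copy_(linear.weight[order, :i_dim])  # columns acting on x
    cell.weight_hh.copy_(linear.weight[order, i_dim:])  # columns acting on h
    cell.bias_ih.copy_(linear.bias[order])
    cell.bias_hh.zero_()  # SeqLSTM uses a single bias
  return cell

# Usage: to_lstm_cell(lin, i_dim, h_dim)(x, (h, c)) should match
# seq_lstm_step(lin, x, h, c) up to floating-point error.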

class TreeLSTM(nn.Module):
  def __init__(self, dim = 200):
    super(TreeLSTM, self).__init__()
    self.dim = dim
    self.linear = nn.Linear(dim*2, dim*5)

  def forward(self, x1, x2, e=None):
    if not isinstance(x1, tuple):
      x1 = (x1, None)    
    h1, c1 = x1 
    if x2 is None: 
      x2 = (torch.zeros_like(h1), torch.zeros_like(h1))
    elif not isinstance(x2, tuple):
      x2 = (x2, None)    
    h2, c2 = x2
    if c1 is None:
      c1 = torch.zeros_like(h1)
    if c2 is None:
      c2 = torch.zeros_like(h2)
    concat = torch.cat([h1, h2], 1)
    all_sum = self.linear(concat)
    i, f1, f2, o, g = all_sum.split(self.dim, 1)

    c = torch.sigmoid(f1)*c1 + torch.sigmoid(f2)*c2 + torch.sigmoid(i)*torch.tanh(g)
    h = torch.sigmoid(o)*torch.tanh(c)
    return (h, c)
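
# Illustrative sketch (not part of the original repository): the binary TreeLSTM
# composition in TreeLSTM.forward as a hypothetical standalone function. Each
# child has its own forget gate (f1, f2), so the parent cell can keep more of
# one constituent than the other.
import torch
import torch.nn as nn

def tree_lstm_compose(linear, h1, c1, h2, c2):
  dim = h1.size(1)
  i, f1, f2, o, g = linear(torch.cat([h1, h2], 1)).split(dim, 1)
  c = torch.sigmoid(f1)*c1 + torch.sigmoid(f2)*c2 + torch.sigmoid(i)*torch.tanh(g)
  h = torch.sigmoid(o)*torch.tanh(c)
  return h, c

# Sanity check: with all weights and biases at zero every gate is 0.5 and
# tanh(g) is 0, so the parent cell is the average of the two child cells.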
      
class RNNG(nn.Module):
  def __init__(self, vocab = 100,
               w_dim = 20, 
               h_dim = 20,
               num_layers = 1,
               dropout = 0,
               q_dim = 20,
               max_len = 250):
    super(RNNG, self).__init__()
    self.S = 0 #action idx for shift/generate
    self.R = 1 #action idx for reduce
    self.emb = nn.Embedding(vocab, w_dim)
    self.dropout = nn.Dropout(dropout)    
    self.stack_rnn = SeqLSTM(w_dim, h_dim, num_layers = num_layers, dropout = dropout)
    self.tree_rnn = TreeLSTM(w_dim)
    self.vocab_mlp = nn.Sequential(nn.Dropout(dropout), nn.Linear(h_dim, vocab))
    self.num_layers = num_layers
    self.q_binary = nn.Sequential(nn.Linear(q_dim*2, q_dim*2), nn.ReLU(), nn.LayerNorm(q_dim*2),
                                  nn.Dropout(dropout), nn.Linear(q_dim*2, 1))
    self.action_mlp_p = nn.Sequential(nn.Dropout(dropout), nn.Linear(h_dim, 1))
    self.w_dim = w_dim
    self.h_dim = h_dim
    self.q_dim = q_dim    
    self.q_leaf_rnn = nn.LSTM(w_dim, q_dim, bidirectional = True, batch_first = True)
    self.q_crf = ConstituencyTreeCRF()
    self.pad1 = 0 # idx for <s> token from ptb.dict
    self.pad2 = 2 # idx for </s> token from ptb.dict 
    self.q_pos_emb = nn.Embedding(max_len, w_dim) # position embeddings
    self.vocab_mlp[-1].weight = self.emb.weight # share embeddings (requires w_dim == h_dim)

  def get_span_scores(self, x):
    #produces the span scores s_ij
    bos = x.new(x.size(0), 1).fill_(self.pad1)
    eos  = x.new(x.size(0), 1).fill_(self.pad2)
    x = torch.cat([bos, x, eos], 1)
    x_vec = self.dropout(self.emb(x))
    pos = torch.arange(0, x.size(1), device=x.device).unsqueeze(0).expand_as(x).long()
    x_vec = x_vec + self.dropout(self.q_pos_emb(pos))
    q_h, _ = self.q_leaf_rnn(x_vec)
    fwd = q_h[:, 1:, :self.q_dim]
    bwd = q_h[:, :-1, self.q_dim:]
    fwd_diff = fwd[:, 1:].unsqueeze(1) - fwd[:, :-1].unsqueeze(2)
    bwd_diff = bwd[:, :-1].unsqueeze(2) - bwd[:, 1:].unsqueeze(1)
    concat = torch.cat([fwd_diff, bwd_diff], 3)
    scores = self.q_binary(concat).squeeze(3)
    return scores

  def get_action_masks(self, actions, length):
    #mask out forced actions so they incur no loss: while fewer than two elements are
    #on the stack, REDUCE is impossible and SHIFT is deterministic
    #(in practice this masking does not seem to matter much)
    mask = actions.new(actions.size(0), actions.size(1)).fill_(1)
    for b in range(actions.size(0)):      
      num_shift = 0
      stack_len = 0
      for l in range(actions.size(1)):
        if stack_len < 2:
          mask[b][l].fill_(0)
        if actions[b][l].item() == self.S:
          num_shift += 1
          stack_len += 1
        else:
          stack_len -= 1
    return mask

  def forward(self, x, samples = 1, is_temp = 1., has_eos=True):
    #If has_eos is True, x ends with </s>, which the inference network ignores.
    #Note that </s> is still predicted during training so that the model learns when to stop.
    #However, it is ignored for PPL evaluation on the version of the PTB dataset from
    #the original RNNG paper (Dyer et al. 2016).
    init_emb = self.dropout(self.emb(x[:, 0]))
    x = x[:, 1:]
    batch, length = x.size(0), x.size(1)
    if has_eos: 
      parse_length = length - 1
      parse_x = x[:, :-1]
    else:
      parse_length = length
      parse_x = x
    word_vecs =  self.dropout(self.emb(x))
    scores = self.get_span_scores(parse_x)
    self.scores = scores
    scores = scores / is_temp
    self.q_crf._forward(scores)
    self.q_crf._entropy(scores)
    entropy = self.q_crf.entropy[0][parse_length-1]
    crf_input = scores.unsqueeze(1).expand(batch, samples, parse_length, parse_length)
    crf_input = crf_input.contiguous().view(batch*samples, parse_length, parse_length)
    for i in range(len(self.q_crf.alpha)):
      for j in range(len(self.q_crf.alpha)):
        self.q_crf.alpha[i][j] = self.q_crf.alpha[i][j].unsqueeze(1).expand(
          batch, samples).contiguous().view(batch*samples)        
    _, log_probs_action_q, tree_brackets, spans = self.q_crf._sample(crf_input, self.q_crf.alpha)
    actions = []
    for b in range(crf_input.size(0)):    
      action = get_actions(tree_brackets[b])
      if has_eos:
        actions.append(action + [self.S, self.R]) #we train the model to generate </s> and then do a final reduce
      else:
        actions.append(action)
    actions = torch.Tensor(actions).float().cuda()
    action_masks = self.get_action_masks(actions, length) 
    num_action = 2*length - 1
    batch_expand = batch*samples
    contexts = []
    log_probs_action_p = [] #conditional prior
    init_emb = init_emb.unsqueeze(1).expand(batch, samples, self.w_dim)
    init_emb = init_emb.contiguous().view(batch_expand, self.w_dim)
    init_stack = self.stack_rnn(init_emb, None)
    x_expand = x.unsqueeze(1).expand(batch, samples, length)
    x_expand = x_expand.contiguous().view(batch_expand, length)
    word_vecs = self.dropout(self.emb(x_expand))
    word_vecs = word_vecs.unsqueeze(2)
    word_vecs_zeros = torch.zeros_like(word_vecs)
    stack = [init_stack]
    stack_child = [[] for _ in range(batch_expand)]
    stack2 = [[] for _ in range(batch_expand)]
    for b in range(batch_expand):
      stack2[b].append([[init_stack[l][0][b], init_stack[l][1][b]] for l in range(self.num_layers)])
    pointer = [0]*batch_expand
    for l in range(num_action):
      contexts.append(stack[-1][-1][0])
      stack_input = []
      child1_h = []
      child1_c = []
      child2_h = []
      child2_c = []
      stack_context = []
      for b in range(batch_expand):
        # collect the two children of every REDUCE in the batch so the TreeLSTM composes them in one call
        if actions[b][l].item() == self.R:
          child1 = stack_child[b].pop()
          child2 = stack_child[b].pop()
          child1_h.append(child1[0])
          child1_c.append(child1[1])
          child2_h.append(child2[0])
          child2_c.append(child2[1])
          stack2[b].pop()
          stack2[b].pop()
      if len(child1_h) > 0:
        child1_h = torch.cat(child1_h, 0)
        child1_c = torch.cat(child1_c, 0)
        child2_h = torch.cat(child2_h, 0)
        child2_c = torch.cat(child2_c, 0)
        new_child = self.tree_rnn((child1_h, child1_c), (child2_h, child2_c))

      child_idx = 0
      stack_h = [[[], []] for _ in range(self.num_layers)]
      for b in range(batch_expand):
        assert(len(stack2[b]) - 1 == len(stack_child[b]))
        for k in range(self.num_layers):
          stack_h[k][0].append(stack2[b][-1][k][0])
          stack_h[k][1].append(stack2[b][-1][k][1])
        if actions[b][l].item() == self.S:          
          input_b = word_vecs[b][pointer[b]]
          stack_child[b].append((word_vecs[b][pointer[b]], word_vecs_zeros[b][pointer[b]]))
          pointer[b] += 1          
        else:
          input_b = new_child[0][child_idx].unsqueeze(0)
          stack_child[b].append((input_b, new_child[1][child_idx].unsqueeze(0)))
          child_idx += 1
        stack_input.append(input_b)
      stack_input = torch.cat(stack_input, 0)
      stack_h_all = []
      for k in range(self.num_layers):
        stack_h_all.append((torch.stack(stack_h[k][0], 0), torch.stack(stack_h[k][1], 0)))
      stack_h = self.stack_rnn(stack_input, stack_h_all)
      stack.append(stack_h)
      for b in range(batch_expand):
        stack2[b].append([[stack_h[k][0][b], stack_h[k][1][b]] for k in range(self.num_layers)])
      
    contexts = torch.stack(contexts, 1) #stack contexts
    action_logit_p = self.action_mlp_p(contexts).squeeze(2) 
    action_prob_p = torch.sigmoid(action_logit_p).clamp(min=1e-7, max=1-1e-7)
    action_shift_score = (1 - action_prob_p).log()
    action_reduce_score = action_prob_p.log()
    action_score = (1-actions)*action_shift_score + actions*action_reduce_score
    action_score = (action_score*action_masks).sum(1)
    
    word_contexts = contexts[actions < 1]
    word_contexts = word_contexts.contiguous().view(batch*samples, length, self.h_dim)

    log_probs_word = F.log_softmax(self.vocab_mlp(word_contexts), 2)
    log_probs_word = torch.gather(log_probs_word, 2, x_expand.unsqueeze(2)).squeeze(2)
    log_probs_word = log_probs_word.sum(1)
    log_probs_word = log_probs_word.contiguous().view(batch, samples)
    log_probs_action_p = action_score.contiguous().view(batch, samples)
    log_probs_action_q = log_probs_action_q.contiguous().view(batch, samples)
    actions = actions.contiguous().view(batch, samples, -1)
    return log_probs_word, log_probs_action_p, log_probs_action_q, actions, entropy

  def forward_actions(self, x, actions, has_eos=True):
    # this is for when ground-truth actions are available
    init_emb = self.dropout(self.emb(x[:, 0]))
    x = x[:, 1:]    
    if has_eos:
      new_actions = []
      for action in actions:
        new_actions.append(action + [self.S, self.R])
      actions = new_actions
    batch, length = x.size(0), x.size(1)
    word_vecs =  self.dropout(self.emb(x))
    actions = torch.Tensor(actions).float().cuda()
    action_masks = self.get_action_masks(actions, length)
    num_action = 2*length - 1
    contexts = []
    log_probs_action_p = [] #prior
    init_stack = self.stack_rnn(init_emb, None)
    word_vecs = word_vecs.unsqueeze(2)
    word_vecs_zeros = torch.zeros_like(word_vecs)
    stack = [init_stack]
    stack_child = [[] for _ in range(batch)]
    stack2 = [[] for _ in range(batch)]
    pointer = [0]*batch
    for b in range(batch):
      stack2[b].append([[init_stack[l][0][b], init_stack[l][1][b]] for l in range(self.num_layers)])
    for l in range(num_action):
      contexts.append(stack[-1][-1][0])
      stack_input = []
      child1_h = []
      child1_c = []
      child2_h = []
      child2_c = []
      stack_context = []
      for b in range(batch):
        if actions[b][l].item() == self.R:
          child1 = stack_child[b].pop()
          child2 = stack_child[b].pop()
          child1_h.append(child1[0])
          child1_c.append(child1[1])
          child2_h.append(child2[0])
          child2_c.append(child2[1])
          stack2[b].pop()
          stack2[b].pop()
      if len(child1_h) > 0:
        child1_h = torch.cat(child1_h, 0)
        child1_c = torch.cat(child1_c, 0)
        child2_h = torch.cat(child2_h, 0)
        child2_c = torch.cat(child2_c, 0)
        new_child = self.tree_rnn((child1_h, child1_c), (child2_h, child2_c))
      child_idx = 0
      stack_h = [[[], []] for _ in range(self.num_layers)]
      for b in range(batch):
        assert(len(stack2[b]) - 1 == len(stack_child[b]))
        for k in range(self.num_layers):
          stack_h[k][0].append(stack2[b][-1][k][0])
          stack_h[k][1].append(stack2[b][-1][k][1])
        if actions[b][l].item() == self.S:          
          input_b = word_vecs[b][pointer[b]]
          stack_child[b].append((word_vecs[b][pointer[b]], word_vecs_zeros[b][pointer[b]]))
          pointer[b] += 1          
        else:
          input_b = new_child[0][child_idx].unsqueeze(0)
          stack_child[b].append((input_b, new_child[1][child_idx].unsqueeze(0)))
          child_idx += 1
        stack_input.append(input_b)
      stack_input = torch.cat(stack_input, 0)
      stack_h_all = []
      for k in range(self.num_layers):
        stack_h_all.append((torch.stack(stack_h[k][0], 0), torch.stack(stack_h[k][1], 0)))
      stack_h = self.stack_rnn(stack_input, stack_h_all)
      stack.append(stack_h)
      for b in range(batch):
        stack2[b].append([[stack_h[k][0][b], stack_h[k][1][b]] for k in range(self.num_layers)])
    contexts = torch.stack(contexts, 1)
    action_logit_p = self.action_mlp_p(contexts).squeeze(2)
    action_prob_p = torch.sigmoid(action_logit_p).clamp(min=1e-7, max=1-1e-7)
    action_shift_score = (1 - action_prob_p).log()
    action_reduce_score = action_prob_p.log()
    action_score = (1-actions)*action_shift_score + actions*action_reduce_score
    action_score = (action_score*action_masks).sum(1)
    
    word_contexts = contexts[actions < 1]
    word_contexts = word_contexts.contiguous().view(batch, length, self.h_dim)
    log_probs_word = F.log_softmax(self.vocab_mlp(word_contexts), 2)
    log_probs_word = torch.gather(log_probs_word, 2, x.unsqueeze(2)).squeeze(2).sum(1)
    log_probs_action_p = action_score.contiguous().view(batch)
    actions = actions.contiguous().view(batch, 1, -1)
    return log_probs_word, log_probs_action_p, actions
  
  def forward_tree(self, x, actions, has_eos=True):
    # this computes log q(tree | x), used to train the discriminative parser in the supervised RNNG
    init_emb = self.dropout(self.emb(x[:, 0]))
    x = x[:, 1:-1]
    batch, length = x.size(0), x.size(1)
    scores = self.get_span_scores(x)
    crf_input = scores
    gold_spans = scores.new(batch, length, length)
    for b in range(batch):
      gold_spans[b].copy_(torch.eye(length).cuda())
      spans = get_spans(actions[b])
      for span in spans:
        gold_spans[b][span[0]][span[1]] = 1
    self.q_crf._forward(crf_input)
    log_Z = self.q_crf.alpha[0][length-1]
    span_scores = (gold_spans*scores).sum(2).sum(1)
    ll_action_q = span_scores - log_Z
    return ll_action_q
    
  def logsumexp(self, x, dim=1):
    # numerically stable log(sum(exp(x))) along `dim`; also works for 1-D inputs
    return torch.logsumexp(x, dim)
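
# Illustrative sketch (not part of the original repository): the span features
# built in RNNG.get_span_scores, as a hypothetical standalone function. q_h holds
# BiLSTM states over [pad, x_1 .. x_n, pad] (length n+2); for every span (i, j)
# the feature concatenates a difference of forward states and a difference of
# backward states, using exactly the indexing of get_span_scores. The span
# scores s_ij are then q_binary applied to these features.
import torch

def span_difference_features(q_h, q_dim):
  fwd = q_h[:, 1:, :q_dim]   # forward states at positions 1..n+1
  bwd = q_h[:, :-1, q_dim:]  # backward states at positions 0..n
  # fwd_diff[b, i, j] = q_h[b, j+2, :q_dim] - q_h[b, i+1, :q_dim]
  fwd_diff = fwd[:, 1:].unsqueeze(1) - fwd[:, :-1].unsqueeze(2)
  # bwd_diff[b, i, j] = q_h[b, i, q_dim:] - q_h[b, j+1, q_dim:]
  bwd_diff = bwd[:, :-1].unsqueeze(2) - bwd[:, 1:].unsqueeze(1)
  return torch.cat([fwd_diff, bwd_diff], 3)  # (batch, n, n, 2*q_dim)

# Broadcasting over all (i, j) pairs lets a single q_binary call score every
# span, at the cost of O(n^2) memory; entries with i > j are computed as well.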
    


================================================
FILE: parse.py
================================================
#!/usr/bin/env python3
import sys
import os

import argparse
import json
import random
import shutil
import copy

import torch
from torch import cuda
import torch.nn as nn
import numpy as np
import time
from utils import *
import utils
import re

parser = argparse.ArgumentParser()

# Data path options
parser.add_argument('--data_file', default='ptb-test.txt')
parser.add_argument('--model_file', default='urnng.pt')
parser.add_argument('--out_file', default='pred-parse.txt')
parser.add_argument('--gold_out_file', default='gold-parse.txt')
parser.add_argument('--lowercase', type=int, default=0)
parser.add_argument('--replace_num', type=int, default=0)

# Inference options
parser.add_argument('--gpu', default=0, type=int, help='which gpu to use')

def is_next_open_bracket(line, start_idx):
    for char in line[(start_idx + 1):]:
        if char == '(':
            return True
        elif char == ')':
            return False
    raise IndexError('Bracket possibly not balanced, open bracket not followed by closing bracket')

def get_between_brackets(line, start_idx):
    output = []
    for char in line[(start_idx + 1):]:
        if char == ')':
            break
        assert not(char == '(')
        output.append(char)    
    return ''.join(output)

def get_tags_tokens_lowercase(line):
    output = []
    line_strip = line.rstrip()
    for i in range(len(line_strip)):
        if i == 0:
            assert line_strip[i] == '('    
        if line_strip[i] == '(' and not(is_next_open_bracket(line_strip, i)): # fulfilling this condition means this is a terminal symbol
            output.append(get_between_brackets(line_strip, i))
    output_tags = []
    output_tokens = []
    output_lowercase = []
    for terminal in output:
        terminal_split = terminal.split()
        assert len(terminal_split) == 2 # each terminal contains a POS tag and word        
        output_tags.append(terminal_split[0])
        output_tokens.append(terminal_split[1])
        output_lowercase.append(terminal_split[1].lower())
    return [output_tags, output_tokens, output_lowercase]    
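
# Illustrative sketch (not part of the original repository): a hypothetical
# regex-based alternative to get_tags_tokens_lowercase. It assumes, like the
# function above, that every preterminal looks like "(TAG word)" with no
# brackets or whitespace inside the tag or the word (the PTB escapes brackets
# as -LRB-/-RRB-).
import re

_LEAF_RE = re.compile(r'\(([^\s()]+) ([^\s()]+)\)')

def get_tags_tokens_regex(line):
    pairs = _LEAF_RE.findall(line)
    tags = [t for t, _ in pairs]
    tokens = [w for _, w in pairs]
    return [tags, tokens, [w.lower() for w in tokens]]

# Nonterminals such as "(NP (DT ..." never match, because the second
# group cannot start with "(".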

def get_nonterminal(line, start_idx):
    assert line[start_idx] == '(' # make sure it's an open bracket
    output = []
    for char in line[(start_idx + 1):]:
        if char == ' ':
            break
        assert not(char == '(') and not(char == ')')
        output.append(char)
    return ''.join(output)


def get_actions(line):
    output_actions = []
    line_strip = line.rstrip()
    i = 0
    max_idx = (len(line_strip) - 1)
    while i <= max_idx:
        assert line_strip[i] == '(' or line_strip[i] == ')'
        if line_strip[i] == '(':
            if is_next_open_bracket(line_strip, i): # open non-terminal
                curr_NT = get_nonterminal(line_strip, i)
                output_actions.append('NT(' + curr_NT + ')')
                i += 1  
                while line_strip[i] != '(': # get the next open bracket, which may be a terminal or another non-terminal
                    i += 1
            else: # it's a terminal symbol
                output_actions.append('SHIFT')
                while line_strip[i] != ')':
                    i += 1
                i += 1
                while line_strip[i] != ')' and line_strip[i] != '(':
                    i += 1
        else:
            output_actions.append('REDUCE')
            if i == max_idx:
                break
            i += 1
            while line_strip[i] != ')' and line_strip[i] != '(':
                i += 1
    assert i == max_idx  
    return output_actions

def clean_number(w):    
    new_w = re.sub('[0-9]{1,}([,.]?[0-9]*)*', 'N', w)
    return new_w
  
def main(args):
  print('loading model from ' + args.model_file)
  checkpoint = torch.load(args.model_file)
  model = checkpoint['model']
  word2idx = checkpoint['word2idx']
  cuda.set_device(args.gpu)
  model.eval()
  model.cuda()
  corpus_f1 = [0., 0., 0.] 
  sent_f1 = [] 
  pred_out = open(args.out_file, "w")
  gold_out = open(args.gold_out_file, "w")
  with torch.no_grad():
    for j, gold_tree in enumerate(open(args.data_file, "r")):
      tree = gold_tree.strip()
      action = get_actions(tree)
      tags, sent, sent_lower = get_tags_tokens_lowercase(tree)
      sent_orig = sent[::]
      if args.lowercase == 1:
          sent = sent_lower
      gold_span, binary_actions, nonbinary_actions = get_nonbinary_spans(action)
      length = len(sent)
      if args.replace_num == 1:
          sent = [clean_number(w) for w in sent]
      if length == 1:
        # skip length-1 sentences; corpus F1 is unaffected since their only span
        # is the trivial sentence-level span, which is discarded anyway
        continue
      sent_idx = [word2idx["<s>"]] + [word2idx[w] if w in word2idx else word2idx["<unk>"] for w in sent]
      sents = torch.from_numpy(np.array(sent_idx)).unsqueeze(0)
      sents = sents.cuda()
      ll_word_all, ll_action_p_all, ll_action_q_all, actions_all, q_entropy = model(
          sents, samples = 1, is_temp = 1, has_eos = False)
      _, binary_matrix, argmax_spans = model.q_crf._viterbi(model.scores)
      tree = get_tree_from_binary_matrix(binary_matrix[0], len(sent))
      actions = utils.get_actions(tree)
      pred_span = [(a[0], a[1]) for a in argmax_spans[0]]
      # the last span in each list is always the trivial sentence-level span, so we ignore it
      pred_span_set = set(pred_span[:-1])
      gold_span_set = set(gold_span[:-1])
      tp, fp, fn = get_stats(pred_span_set, gold_span_set) 
      corpus_f1[0] += tp
      corpus_f1[1] += fp
      corpus_f1[2] += fn
      binary_matrix = binary_matrix[0].cpu().numpy()
      pred_tree = {}
      for i in range(length):
        tag = tags[i] # need gold tags so evalb correctly ignores punctuation
        pred_tree[i] = "(" + tag + " " + sent_orig[i] + ")"
      for k in np.arange(1, length):
        for s in np.arange(length):
          t = s + k
          if t > length - 1: break
          if binary_matrix[s][t] == 1:
            nt = "NT-1" 
            span = "(" + nt + " " + pred_tree[s] + " " + pred_tree[t] +  ")"
            pred_tree[s] = span
            pred_tree[t] = span
      pred_tree = pred_tree[0]
      pred_out.write(pred_tree.strip() + "\n")
      gold_out.write(gold_tree.strip() + "\n")
      print(pred_tree)
      # sent-level F1 is based on L83-89 from https://github.com/yikangshen/PRPN/test_phrase_grammar.py
      overlap = pred_span_set.intersection(gold_span_set)
      prec = float(len(overlap)) / (len(pred_span_set) + 1e-8)
      reca = float(len(overlap)) / (len(gold_span_set) + 1e-8)
      if len(gold_span_set) == 0:
        reca = 1. 
        if len(pred_span_set) == 0:
          prec = 1.
      f1 = 2 * prec * reca / (prec + reca + 1e-8)
      sent_f1.append(f1)
  pred_out.close()
  gold_out.close()
  tp, fp, fn = corpus_f1
  prec = tp / (tp + fp) if tp + fp > 0 else 0.
  recall = tp / (tp + fn) if tp + fn > 0 else 0.
  corpus_f1 = 2*prec*recall/(prec+recall) if prec+recall > 0 else 0.
  print('Corpus F1: %.2f, Sentence F1: %.2f' %
        (corpus_f1*100, np.mean(np.array(sent_f1))*100))

if __name__ == '__main__':
  args = parser.parse_args()
  main(args)


================================================
FILE: preprocess.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Create data files
"""

import os
import sys
import argparse
import numpy as np
import pickle
import itertools
from collections import defaultdict
import utils
import re

class Indexer:
    def __init__(self, symbols = ["<pad>","<unk>","<s>","</s>"]):
        self.vocab = defaultdict(int)
        self.PAD = symbols[0]
        self.UNK = symbols[1]
        self.BOS = symbols[2]
        self.EOS = symbols[3]
        self.d = {self.PAD: 0, self.UNK: 1, self.BOS: 2, self.EOS: 3}
        self.idx2word = {}
        
    def add_w(self, ws):
        for w in ws:
            if w not in self.d:
                self.d[w] = len(self.d)

    def convert(self, w):
        return self.d[w] if w in self.d else self.d[self.UNK]

    def convert_sequence(self, ls):
        return [self.convert(l) for l in ls]

    def write(self, outfile):
        out = open(outfile, "w")
        items = [(v, k) for k, v in self.d.items()]
        items.sort()
        for v, k in items:
            out.write(" ".join([k, str(v)]) + "\n")
        out.close()

    def prune_vocab(self, k, cnt = False):
        vocab_list = [(word, count) for word, count in self.vocab.items()]
        if cnt:
            self.pruned_vocab = {pair[0]:pair[1] for pair in vocab_list if pair[1] > k}
        else:
            vocab_list.sort(key = lambda x: x[1], reverse=True)
            k = min(k, len(vocab_list))
            self.pruned_vocab = {pair[0]:pair[1] for pair in vocab_list[:k]}
        for word in self.pruned_vocab:
            if word not in self.d:
                self.d[word] = len(self.d)
        for word, idx in self.d.items():
            self.idx2word[idx] = word

    def load_vocab(self, vocab_file):
        self.d = {}
        self.idx2word = {}
        for line in open(vocab_file, 'r'):
            v, k = line.strip().split()
            self.d[v] = int(k)
        for word, idx in self.d.items():
            self.idx2word[idx] = word

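# Sketch (hypothetical helper, not part of the original preprocess.py): this shows
# how Indexer.prune_vocab assigns indices. With cnt=True (used when --vocabminfreq
# >= 0), a word is kept only if its training count is strictly greater than k.
# Otherwise, the k most frequent words are kept. The four special symbols always
# occupy indices 0-3, and kept words get the next indices in the order visited.
def _sketch_prune_vocab(counts, k, cnt=False,
                        specials=("<pad>", "<unk>", "<s>", "</s>")):
    d = {sym: i for i, sym in enumerate(specials)}
    items = list(counts.items())
    if cnt:
        kept = [w for w, c in items if c > k]
    else:
        items.sort(key=lambda x: x[1], reverse=True)  # stable: ties keep insertion order
        kept = [w for w, _ in items[:k]]
        
    for w in kept:
        if w not in d:
            d[w] = len(d)
    return d

# Usage: _sketch_prune_vocab({"a": 3, "b": 1, "c": 2}, 1, cnt=True) maps "a" -> 4
# and "c" -> 5. Because "b" occurs only once, it would be converted to <unk>.
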

def is_next_open_bracket(line, start_idx):
    for char in line[(start_idx + 1):]:
        if char == '(':
            return True
        elif char == ')':
            return False
    raise IndexError('Brackets possibly unbalanced: open bracket not followed by a closing bracket')

def get_between_brackets(line, start_idx):
    output = []
    for char in line[(start_idx + 1):]:
        if char == ')':
            break
        assert not(char == '(')
        output.append(char)    
    return ''.join(output)

def get_tags_tokens_lowercase(line):
    output = []
    line_strip = line.rstrip()
    for i in range(len(line_strip)):
        if i == 0:
            assert line_strip[i] == '('    
        if line_strip[i] == '(' and not(is_next_open_bracket(line_strip, i)): # fulfilling this condition means this is a terminal symbol
            output.append(get_between_brackets(line_strip, i))
    #print 'output:',output
    output_tags = []
    output_tokens = []
    output_lowercase = []
    for terminal in output:
        terminal_split = terminal.split()
        # print(terminal, terminal_split)
        assert len(terminal_split) == 2 # each terminal contains a POS tag and word        
        output_tags.append(terminal_split[0])
        output_tokens.append(terminal_split[1])
        output_lowercase.append(terminal_split[1].lower())
    return [output_tags, output_tokens, output_lowercase]    

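# Sketch (hypothetical alternative, not the method used above): in a one-line
# PTB-style tree, every terminal has the form "(TAG word)" with no nested
# brackets. The same (tags, tokens, lowercased tokens) triple can therefore be
# read off with a single regular expression. This assumes tags and words never
# contain spaces or raw brackets (PTB escapes brackets as -LRB-/-RRB-).
def _sketch_tags_tokens_lowercase(line):
    import re
    pairs = re.findall(r"\(([^\s()]+) ([^\s()]+)\)", line)
    tags = [tag for tag, _ in pairs]
    tokens = [word for _, word in pairs]
    return [tags, tokens, [w.lower() for w in tokens]]

# Usage: _sketch_tags_tokens_lowercase("(S (NP (DT The) (NN cat)) (VP (VBD sat)))")
# returns [["DT", "NN", "VBD"], ["The", "cat", "sat"], ["the", "cat", "sat"]].
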
def get_nonterminal(line, start_idx):
    assert line[start_idx] == '(' # make sure it's an open bracket
    output = []
    for char in line[(start_idx + 1):]:
        if char == ' ':
            break
        assert not(char == '(') and not(char == ')')
        output.append(char)
    return ''.join(output)


def get_actions(line):
    output_actions = []
    line_strip = line.rstrip()
    i = 0
    max_idx = (len(line_strip) - 1)
    while i <= max_idx:
        assert line_strip[i] == '(' or line_strip[i] == ')'
        if line_strip[i] == '(':
            if is_next_open_bracket(line_strip, i): # open non-terminal
                curr_NT = get_nonterminal(line_strip, i)
                output_actions.append('NT(' + curr_NT + ')')
                i += 1  
                while line_strip[i] != '(': # get the next open bracket, which may be a terminal or another non-terminal
                    i += 1
            else: # it's a terminal symbol
                output_actions.append('SHIFT')
                while line_strip[i] != ')':
                    i += 1
                i += 1
                while line_strip[i] != ')' and line_strip[i] != '(':
                    i += 1
        else:
            output_actions.append('REDUCE')
            if i == max_idx:
                break
            i += 1
            while line_strip[i] != ')' and line_strip[i] != '(':
                i += 1
    assert i == max_idx
    return output_actions

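# Sketch (hypothetical alternative, not the method used above): get_actions()
# walks the raw characters of a bracketed tree and emits the top-down RNNG
# transition sequence. It produces NT(X) when a nonterminal opens, SHIFT for each
# "(TAG word)" terminal, and REDUCE for each closing bracket of a nonterminal.
# The same sequence can be produced from a token list, assuming a well-formed tree
# whose terminals are exactly "(TAG word)".
def _sketch_tree_to_actions(tree):
    import re
    tokens = re.findall(r"\(|\)|[^\s()]+", tree)
    actions = []
    i = 0
    while i < len(tokens):
        if tokens[i] == "(":
            label = tokens[i + 1]
            if tokens[i + 2] == "(":           # nonterminal: its first child follows
                actions.append("NT(" + label + ")")
                i += 2
            elif tokens[i + 3] == ")":         # terminal: "(" TAG word ")"
                actions.append("SHIFT")
                i += 4
            else:
                raise ValueError("malformed terminal near token %d" % i)
        elif tokens[i] == ")":                 # closes a nonterminal
            actions.append("REDUCE")
            i += 1
        else:
            raise ValueError("unexpected token %r" % tokens[i])
    return actions

# Usage: _sketch_tree_to_actions("(S (NP (DT The) (NN cat)) (VP (VBD sat)))") returns
# ["NT(S)", "NT(NP)", "SHIFT", "SHIFT", "REDUCE", "NT(VP)", "SHIFT", "REDUCE", "REDUCE"].
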
def pad(ls, length, symbol):
    if len(ls) >= length:
        return ls[:length]
    return ls + [symbol] * (length - len(ls))

def clean_number(w):    
    new_w = re.sub('[0-9]{1,}([,.]?[0-9]*)*', 'N', w)
    return new_w

def get_data(args):
    indexer = Indexer(["<pad>","<unk>","<s>","</s>"])

    def make_vocab(textfile, seqlength, minseqlength, lowercase, replace_num,
                   train=1, apply_length_filter=1):
        num_sents = 0
        max_seqlength = 0
        for tree in open(textfile, 'r'):
            tree = tree.strip()
            tags, sent, sent_lower = get_tags_tokens_lowercase(tree)
            
            assert(len(tags) == len(sent))
            if lowercase == 1:
                sent = sent_lower
            if replace_num == 1:
                sent = [clean_number(w) for w in sent]
            if (len(sent) > seqlength and apply_length_filter == 1) or len(sent) < minseqlength:
                continue
            num_sents += 1
            max_seqlength = max(max_seqlength, len(sent))
            if train == 1:
                for word in sent:
                    indexer.vocab[word] += 1
        return num_sents, max_seqlength

    def convert(textfile, lowercase, replace_num,  
                batchsize, seqlength, minseqlength, outfile, num_sents, max_sent_l=0,
                shuffle=0, include_boundary=1, apply_length_filter=1):
        newseqlength = seqlength
        if include_boundary == 1:
            newseqlength += 2 #add 2 for EOS and BOS
        sents = np.zeros((num_sents, newseqlength), dtype=int)
        sent_lengths = np.zeros((num_sents,), dtype=int)
        dropped = 0
        sent_id = 0
        other_data = []
        for tree in open(textfile, 'r'):
            tree = tree.strip()
            action = get_actions(tree)
            tags, sent, sent_lower = get_tags_tokens_lowercase(tree)
            assert(len(tags) == len(sent))
            if lowercase == 1:
                sent = sent_lower
            if (len(sent) > seqlength and apply_length_filter == 1) or len(sent) < minseqlength:
                dropped += 1
                continue
            sent_str = " ".join(sent)
            if replace_num == 1:
                sent = [clean_number(w) for w in sent]
            if include_boundary == 1:
                sent = [indexer.BOS] + sent + [indexer.EOS]
            max_sent_l = max(len(sent), max_sent_l)
            sent_pad = pad(sent, newseqlength, indexer.PAD)
            sents[sent_id] = np.array(indexer.convert_sequence(sent_pad), dtype=int)
            sent_lengths[sent_id] = (sents[sent_id] != 0).sum()
            span, binary_actions, nonbinary_actions = utils.get_nonbinary_spans(action)
            other_data.append([sent_str, tags, action, 
                               binary_actions, nonbinary_actions, span, tree])
            # both checks assume include_boundary == 1, i.e. sent holds BOS + words + EOS
            assert(2*(len(sent) - 2) - 1 == len(binary_actions))
            assert(sum(binary_actions) + 1 == len(sent) - 2)
            sent_id += 1
            if sent_id % 100000 == 0:
                print("{}/{} sentences processed".format(sent_id, num_sents))
        print(sent_id, num_sents)
        if shuffle == 1:
            rand_idx = np.random.permutation(sent_id)
            sents = sents[rand_idx]
            sent_lengths = sent_lengths[rand_idx]
            other_data = [other_data[idx] for idx in rand_idx]

        print(len(sents), len(other_data))
        # sort sentences by length, then split each equal-length run into batches
        sent_lengths = sent_lengths[:sent_id]
        sent_sort = np.argsort(sent_lengths)
        sents = sents[sent_sort]
        other_data = [other_data[idx] for idx in sent_sort]
        sent_l = sent_lengths[sent_sort]
        curr_l = 1
        l_location = [] #idx where sent length changes

        for j,i in enumerate(sent_sort):
            if sent_lengths[i] > curr_l:
                curr_l = sent_lengths[i]
                l_location.append(j)
        l_location.append(len(sents))
        #get batch sizes
        curr_idx = 0
        batch_idx = [0]
        nonzeros = []
        batch_l = []
        batch_w = []
        for i in range(len(l_location)-1):
            while curr_idx < l_location[i+1]:
                curr_idx = min(curr_idx + batchsize, l_location[i+1])
                batch_idx.append(curr_idx)
        for i in range(len(batch_idx)-1):
            batch_l.append(batch_idx[i+1] - batch_idx[i])
            batch_w.append(sent_l[batch_idx[i]])

        # Write output
        f = {}
        f["source"] = sents
        f["other_data"] = other_data
        f["batch_l"] = np.array(batch_l, dtype=int)
        f["source_l"] = np.array(batch_w, dtype=int)
        f["sents_l"]  = np.array(sent_l, dtype = int)
        f["batch_idx"] = np.array(batch_idx[:-1], dtype=int)
        f["vocab_size"] = np.array([len(indexer.d)])
        f["idx2word"] = indexer.idx2word
        f["word2idx"] = {word : idx for idx, word in indexer.idx2word.items()}
        
        print("Saved {} sentences (dropped {} due to length filter)".format(
            len(f["source"]), dropped))
        with open(outfile, 'wb') as fout:
            pickle.dump(f, fout)
        return max_sent_l

    print("First pass through data to get vocab...")
    num_sents_train, train_seqlength = make_vocab(args.trainfile, args.seqlength, args.minseqlength,
                                                  args.lowercase, args.replace_num, 1, 1)
    print("Number of sentences in training: {}".format(num_sents_train))
    num_sents_valid, valid_seqlength = make_vocab(args.valfile, args.seqlength, args.minseqlength, 
                                                  args.lowercase, args.replace_num, 0, 0)
    print("Number of sentences in valid: {}".format(num_sents_valid))
    num_sents_test, test_seqlength = make_vocab(args.testfile, args.seqlength, args.minseqlength, 
                                                args.lowercase, args.replace_num, 0, 0)
    print("Number of sentences in test: {}".format(num_sents_test))

    if args.vocabminfreq >= 0:
        indexer.prune_vocab(args.vocabminfreq, True)        
    else:
        indexer.prune_vocab(args.vocabsize, False)
    if args.vocabfile != '':
        print('Loading pre-specified source vocab from ' + args.vocabfile)
        indexer.load_vocab(args.vocabfile)
    indexer.write(args.outputfile + ".dict")
    print("Vocab size: Original = {}, Pruned = {}".format(len(indexer.vocab),
                                                          len(indexer.d)))
    print(train_seqlength, valid_seqlength, test_seqlength)
    max_sent_l = 0
    max_sent_l = convert(args.testfile, args.lowercase, args.replace_num, 
                         args.batchsize, test_seqlength, args.minseqlength, 
                         args.outputfile + "-test.pkl", num_sents_test,
                         max_sent_l, args.shuffle, args.include_boundary, 0)
    max_sent_l = convert(args.valfile, args.lowercase, args.replace_num, 
                         args.batchsize, valid_seqlength, args.minseqlength, 
                         args.outputfile + "-val.pkl", num_sents_valid,
                         max_sent_l, args.shuffle, args.include_boundary, 0)
    max_sent_l = convert(args.trainfile, args.lowercase, args.replace_num, 
                         args.batchsize, args.seqlength,  args.minseqlength,
                         args.outputfile + "-train.pkl", num_sents_train,
                         max_sent_l, args.shuffle, args.include_boundary, 1)
    print("Max sent length (before dropping): {}".format(max_sent_l))

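# Sketch (hypothetical helper, not part of the original preprocess.py): convert()
# sorts sentences by length, then cuts each run of equal-length sentences into
# chunks of at most `batchsize`. As a result, every sentence in a batch has the
# same length. This standalone version returns batches as lists of original
# sentence indices. It uses a stable sort, whereas np.argsort's default sort may
# order ties differently; the batch sizes are the same either way.
def _sketch_length_batches(lengths, batchsize):
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches = []
    start = 0
    while start < len(order):
        curr_len = lengths[order[start]]
        end = start
        # grow the batch while the length stays the same and the cap is not hit
        while (end < len(order) and lengths[order[end]] == curr_len
               and end - start < batchsize):
            end += 1
        batches.append(order[start:end])
        start = end
    return batches

# Usage: _sketch_length_batches([3, 5, 3, 5, 5, 7], 2) returns
# [[0, 2], [1, 3], [4], [5]], i.e. batch sizes 2, 2, 1, 1 as in batch_l.
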
def main(arguments):
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--vocabsize', help="Size of source vocabulary, constructed "
                                                "by taking the top X most frequent words. "
                                                "The rest are replaced with the special UNK token. "
                                                "Only used if vocabminfreq < 0.",
                                                type=int, default=10000)
    parser.add_argument('--vocabminfreq', help="Minimum frequency for vocab: words occurring "
                                                "this many times or fewer in the training data "
                                                "are replaced with UNK. Used instead of "
                                                "vocabsize if >= 0.",
                                                type=int, default=1)
    parser.add_argument('--include_boundary', help="Add BOS/EOS tokens", type=int, default=1)        
    parser.add_argument('--lowercase', help="Lower case", type=int, default=0)        
    parser.add_argument('--replace_num', help="Replace numbers with N", type=int, default=0)        
    parser.add_argument('--trainfile', help="Path to training data.", required=True)
    parser.add_argument('--valfile', help="Path to validation data.", required=True)
    parser.add_argument('--testfile', help="Path to test data.", required=True)
    parser.add_argument('--batchsize', help="Size of each minibatch.", type=int, default=16)
    parser.add_argument('--seqlength', help="Maximum sequence length. Training sequences longer "
                                               "than this are dropped (valid/test are not filtered).",
                                               type=int, default=200)
    parser.add_argument('--minseqlength', help="Minimum sequence length. Sequences shorter "
                                               "than this are dropped.", type=int, default=0)
    parser.add_argument('--outputfile', help="Prefix of the output file names.", type=str,
                        required=True)
    parser.add_argument('--vocabfile', help="If working with a preset vocab, "
                                          "then including this will ignore vocabsize/vocabminfreq "
                                          "and use the vocab provided here.",
                                          type = str, default='')
    parser.add_argument('--shuffle', help="If = 1, shuffle sentences before sorting them "
                                           "by length.",
                                          type = int, default = 1)
    args = parser.parse_args(arguments)
    np.random.seed(3435)
    get_data(args)

if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))


================================================
FILE: train.py
================================================
#!/usr/bin/env python3
import sys
import os

import argparse
import json
import random
import shutil
import copy

import torch
from torch import cuda
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter

import torch.nn.functional as F
import numpy as np
import time
import logging
from data import Dataset
from models import RNNG
from utils import *

parser = argparse.ArgumentParser()

# Data path options
parser.add_argument('--train_file', default='data/ptb-1unk-train.pkl')
parser.add_argument('--val_file', default='data/ptb-1unk-val.pkl')
parser.add_argument('--train_from', default='')
# Model options
parser.add_argument('--w_dim', default=650, type=int, help='word embedding dimension for LM/RNNG')
parser.add_argument('--h_dim', default=650, type=int, help='hidden dimension for LM/RNNG')
parser.add_argument('--q_dim', default=256, type=int, help='hidden dimension for variational RNN')
parser.add_argument('--num_layers', default=2, type=int, help='number of layers in LM and the stack LSTM (for RNNG)')
parser.add_argument('--dropout', default=0.5, type=float, help='dropout rate')
# Optimization options
parser.add_argument('--count_eos_ppl', default=0, type=int, help='whether to count eos in val PPL')
parser.add_argument('--save_path', default='urnng.pt', help='where to save the model checkpoint')
parser.add_argument('--num_epochs', default=18, type=int, help='number of training epochs')
parser.add_argument('--min_epochs', default=8, type=int, help='do not decay learning rate for at least this many epochs')
parser.add_argument('--mode', default='unsupervised', type=str, choices=['unsupervised', 'supervised'])
parser.add_argument('--mc_samples', default=5, type=int, 
                    help='how many samples for IWAE bound calc for evaluation')
parser.add_argument('--samples', default=8, type=int, 
                    help='how many samples for score function gradients')
parser.add_argument('--lr', default=1, type=float, help='starting learning rate')
parser.add_argument('--q_lr', default=0.0001, type=float, help='learning rate for inference network q')
parser.add_argument('--action_lr', default=0.1, type=float, help='learning rate for action layer')
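parser.add_argument('--decay', default=0.5, type=float, help='factor by which learning rates are multiplied when decayed')
parser.add_argument('--kl_warmup', default=2, type=int, help='number of epochs over which the KL penalty is linearly annealed from 0 to 1')
parser.add_argument('--train_q_epochs', default=2, type=int, help='number of epochs to train the inference network q; its learning rate is set to 0 afterwards')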
parser.add_argument('--param_init', default=0.1, type=float, help='parameter initialization (over uniform)')
parser.add_argument('--max_grad_norm', default=5, type=float, help='gradient clipping parameter')
parser.add_argument('--q_max_grad_norm', default=1, type=float, help='gradient clipping parameter for q')
parser.add_argument('--gpu', default=2, type=int, help='which gpu to use')
parser.add_argument('--seed', default=3435, type=int, help='random seed')
parser.add_argument('--print_every', type=int, default=500, help='print stats after this many batches')

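# Sketch (hypothetical helper, not part of the original train.py): during KL
# warmup, main() raises kl_pen linearly by 1/(kl_warmup * len(train_data)) before
# every batch. kl_pen weights the prior log-probability and entropy terms; it
# reaches 1 after `kl_warmup` epochs and stays there. This function reproduces
# that schedule.
def _sketch_kl_schedule(kl_warmup, num_batches, num_epochs):
  if kl_warmup <= 0:
    return [1.] * (num_batches * num_epochs)
  step = 1. / (kl_warmup * num_batches)
  kl_pen = 0.
  schedule = []
  for _ in range(num_batches * num_epochs):
    kl_pen = min(1., kl_pen + step)
    schedule.append(kl_pen)
  return schedule

# Usage: _sketch_kl_schedule(2, 4, 3) increases by 0.125 per batch and hits 1.0
# at the 8th batch (end of epoch 2).
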

def main(args):
  np.random.seed(args.seed)
  torch.manual_seed(args.seed)
  train_data = Dataset(args.train_file)
  val_data = Dataset(args.val_file)  
  vocab_size = int(train_data.vocab_size)    
  print('Train: %d sents / %d batches, Val: %d sents / %d batches' % 
        (train_data.sents.size(0), len(train_data), val_data.sents.size(0), 
         len(val_data)))
  print('Vocab size: %d' % vocab_size)
  cuda.set_device(args.gpu)
  if args.train_from == '':
    model = RNNG(vocab = vocab_size,
                 w_dim = args.w_dim, 
                 h_dim = args.h_dim,
                 dropout = args.dropout,
                 num_layers = args.num_layers,
                 q_dim = args.q_dim)
    if args.param_init > 0:
      for param in model.parameters():    
        param.data.uniform_(-args.param_init, args.param_init)      
  else:
    print('loading model from ' + args.train_from)
    checkpoint = torch.load(args.train_from)
    model = checkpoint['model']
  print("model architecture")
  print(model)
  q_params = []
  action_params = []
  model_params = []
  for name, param in model.named_parameters():    
    if 'action' in name:
      print(name)
      action_params.append(param)
    elif 'q_' in name:
      print(name)
      q_params.append(param)
    else:
      model_params.append(param)
  q_lr = args.q_lr
  optimizer = torch.optim.SGD(model_params, lr=args.lr)
  q_optimizer = torch.optim.Adam(q_params, lr=q_lr)
  action_optimizer = torch.optim.SGD(action_params, lr=args.action_lr)
  model.train()
  model.cuda()

  epoch = 0
  decay= 0
  if args.kl_warmup > 0:
    kl_pen = 0.
    kl_warmup_batch = 1./(args.kl_warmup * len(train_data))
  else:
    kl_pen = 1.
  best_val_ppl = 5e5
  best_val_f1 = 0
  samples = args.samples
  best_val_ppl, best_val_f1 = eval(val_data, model, samples = args.mc_samples, 
                                   count_eos_ppl = args.count_eos_ppl)
  all_stats = [[0., 0., 0.]] #true pos, false pos, false neg for f1 calc
  while epoch < args.num_epochs:
    start_time = time.time()
    epoch += 1  
    if epoch > args.train_q_epochs:
      #stop training q after this many epochs
      args.q_lr = 0.
      for param_group in q_optimizer.param_groups:
        param_group['lr'] = args.q_lr
    print('Starting epoch %d' % epoch)
    train_nll_recon = 0.
    train_nll_iwae = 0.
    train_kl = 0.
    train_q_entropy = 0.
    num_sents = 0.
    num_words = 0.
    b = 0
    for i in np.random.permutation(len(train_data)):
      if args.kl_warmup > 0:
        kl_pen = min(1., kl_pen + kl_warmup_batch) 
      sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = train_data[i]      
      if length == 1:
        # we ignore length 1 sents during training/eval since we work with binary trees only
        continue
      sents = sents.cuda()
      b += 1
      q_optimizer.zero_grad()
      optimizer.zero_grad()
      action_optimizer.zero_grad()
      if args.mode == 'unsupervised':
        ll_word, ll_action_p, ll_action_q, all_actions, q_entropy = model(sents, samples=samples, 
                                                                          has_eos = True)
        log_f = ll_word + kl_pen*ll_action_p
        iwae_ll = log_f.mean(1).detach() + kl_pen*q_entropy.detach()
        obj = log_f.mean(1)
        if epoch < args.train_q_epochs:
          obj += kl_pen*q_entropy
          baseline = torch.zeros_like(log_f)
          baseline_k = torch.zeros_like(log_f)
          for k in range(samples):
            baseline_k.copy_(log_f)
            baseline_k[:, k].fill_(0)
            baseline[:, k] =  baseline_k.detach().sum(1) / (samples - 1)        
          obj += ((log_f.detach() - baseline.detach())*ll_action_q).mean(1)                      
        kl = (ll_action_q - ll_action_p).mean(1).detach()
        ll_word = ll_word.mean(1)
        train_q_entropy += q_entropy.sum().item()
      else:
        gold_actions = gold_binary_trees
        ll_action_q = model.forward_tree(sents, gold_actions, has_eos=True)        
        ll_word, ll_action_p, all_actions = model.forward_actions(sents, gold_actions)
        obj = ll_word + ll_action_p + ll_action_q
        kl = -ll_action_q
        iwae_ll = ll_word + ll_action_p
      train_nll_iwae += -iwae_ll.sum().item()
      actions = all_actions[:, 0].long().cpu()
      train_nll_recon += -ll_word.sum().item()
      train_kl += kl.sum().item()
      (-obj.mean()).backward()      
      if args.max_grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(model_params + action_params, args.max_grad_norm)        
      if args.q_max_grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(q_params, args.q_max_grad_norm)        
      q_optimizer.step()
      optimizer.step()
      action_optimizer.step()
      num_sents += batch_size
      num_words += batch_size * length
      for bb in range(batch_size):
        action = list(actions[bb].numpy())
        span_b = get_spans(action)
        span_b_set = set(span_b[:-1]) #ignore the sentence-level trivial span
        update_stats(span_b_set, [set(gold_spans[bb][:-1])], all_stats)
      if b % args.print_every == 0:
        all_f1 = get_f1(all_stats)
        param_norm = sum([p.norm()**2 for p in model.parameters()]).item()**0.5
        log_str = 'Epoch: %d, Batch: %d/%d, LR: %.4f, qLR: %.5f, qEnt: %.4f, TrainVAEPPL: %.2f, ' + \
                  'TrainReconPPL: %.2f, TrainKL: %.2f, TrainIWAEPPL: %.2f, ' + \
                  '|Param|: %.2f, BestValPerf: %.2f, BestValF1: %.2f, KLPen: %.4f, ' + \
                  'GoldTreeF1: %.2f, Throughput: %.2f examples/sec'
        print(log_str %
              (epoch, b, len(train_data), args.lr, args.q_lr, train_q_entropy / num_sents, 
               np.exp((train_nll_recon + train_kl)/ num_words),
               np.exp(train_nll_recon/num_words), train_kl / num_sents, 
               np.exp(train_nll_iwae/num_words),
               param_norm, best_val_ppl, best_val_f1, kl_pen, 
               all_f1[0], num_sents / (time.time() - start_time)))
        sent_str = [train_data.idx2word[word_idx] for word_idx in list(sents[-1][1:-1].cpu().numpy())]
        print("PRED:", get_tree(action[:-2], sent_str))
        print("GOLD:", get_tree(gold_binary_trees[-1], sent_str))
    print('--------------------------------')
    print('Checking validation perf...')    
    val_ppl, val_f1 = eval(val_data, model, 
                           samples = args.mc_samples, count_eos_ppl = args.count_eos_ppl)
    print('--------------------------------')
    if val_ppl < best_val_ppl:
      best_val_ppl = val_ppl
      best_val_f1 = val_f1
      checkpoint = {
        'args': args.__dict__,
        'model': model.cpu(),
        'word2idx': train_data.word2idx,
        'idx2word': train_data.idx2word
      }
      print('Saving checkpoint to %s' % args.save_path)
      torch.save(checkpoint, args.save_path)
      model.cuda()
    else:
      if epoch > args.min_epochs:
        decay = 1
    if decay == 1:
      args.lr = args.decay*args.lr
      args.q_lr = args.decay*args.q_lr
      args.action_lr = args.decay*args.action_lr
      for param_group in optimizer.param_groups:
        param_group['lr'] = args.lr
      for param_group in q_optimizer.param_groups:
        param_group['lr'] = args.q_lr
      for param_group in action_optimizer.param_groups:
        param_group['lr'] = args.action_lr
    if args.lr < 0.03:
      break
  print("Finished training!")

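# Sketch (hypothetical helper, not part of the original train.py): for each of the
# K samples, the loop in main() uses the mean of log_f over the other K - 1
# samples as the baseline. This is a leave-one-out baseline for the
# score-function gradient. The same quantity can be computed without a loop, as
# below; it requires K >= 2.
def _sketch_leave_one_out_baseline(log_f):
  import numpy as np
  log_f = np.asarray(log_f, dtype=float)   # shape (batch, K)
  K = log_f.shape[1]
  if K < 2:
    raise ValueError("leave-one-out baseline needs at least 2 samples")
  # (sum over all samples - own sample) / (K - 1)
  return (log_f.sum(axis=1, keepdims=True) - log_f) / (K - 1)

# Usage: _sketch_leave_one_out_baseline([[1., 2., 3.]]) returns [[2.5, 2., 1.5]].
# The score-function weight for sample k is then log_f[:, k] - baseline[:, k].
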
def eval(data, model, samples = 0, count_eos_ppl = 0):
  model.eval()
  num_sents = 0
  num_words = 0
  total_nll_recon = 0.
  total_kl = 0.
  total_nll_iwae = 0.
  corpus_f1 = [0., 0., 0.]
  sent_f1 = [] 
  with torch.no_grad():
    for i in list(reversed(range(len(data)))):
      sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i] 
      if length == 1: # length 1 sents are ignored since URNNG needs at least length 2 sents
        continue
      if count_eos_ppl == 1:
        tree_length = length
        length += 1 
      else:
        sents = sents[:, :-1] 
        tree_length = length
      sents = sents.cuda()
      ll_word_all, ll_action_p_all, ll_action_q_all, actions_all, q_entropy = model(sents, 
                    samples = samples, has_eos = count_eos_ppl == 1)
      ll_word, ll_action_p, ll_action_q = ll_word_all.mean(1), ll_action_p_all.mean(1), ll_action_q_all.mean(1)
      kl = ll_action_q - ll_action_p
      _, binary_matrix, argmax_spans = model.q_crf._viterbi(model.scores)
      actions = []
      for b in range(batch_size):        
        tree = get_tree_from_binary_matrix(binary_matrix[b], tree_length)
        actions.append(get_actions(tree))
      actions = torch.Tensor(actions).long()
      total_nll_recon += -ll_word.sum().item()
      total_kl += kl.sum().item()
      num_sents += batch_size
      num_words += batch_size * length
      if samples > 0:
        #PPL estimate based on IWAE
        sample_ll = torch.zeros(batch_size, samples)
        for j in range(samples):
          ll_word_j, ll_action_p_j, ll_action_q_j = ll_word_all[:, j], ll_action_p_all[:, j], ll_action_q_all[:, j]
          sample_ll[:, j].copy_(ll_word_j + ll_action_p_j - ll_action_q_j)
        ll_iwae = model.logsumexp(sample_ll, 1) - np.log(samples)
        total_nll_iwae -= ll_iwae.sum().item()      
      for b in range(batch_size):
        action = list(actions[b].numpy())
        span_b = get_spans(action)
        span_b = argmax_spans[b]
        span_b_set = set(span_b[:-1])        
        gold_b_set = set(gold_spans[b][:-1])
        tp, fp, fn = get_stats(span_b_set, gold_b_set) 
        corpus_f1[0] += tp
        corpus_f1[1] += fp
        corpus_f1[2] += fn

        # sent-level F1 is based on L83-89 from https://github.com/yikangshen/PRPN/test_phrase_grammar.py
        model_out = span_b_set
        std_out = gold_b_set
        overlap = model_out.intersection(std_out)
        prec = float(len(overlap)) / (len(model_out) + 1e-8)
        reca = float(len(overlap)) / (len(std_out) + 1e-8)
        if len(std_out) == 0:
          reca = 1. 
          if len(model_out) == 0:
            prec = 1.
        f1 = 2 * prec * reca / (prec + reca + 1e-8)
        sent_f1.append(f1)
  tp, fp, fn = corpus_f1
  prec = tp / (tp + fp) if tp + fp > 0 else 0.
  recall = tp / (tp + fn) if tp + fn > 0 else 0.
  corpus_f1 = 2*prec*recall/(prec+recall)*100 if prec+recall > 0 else 0.
  sent_f1 = np.mean(np.array(sent_f1))*100

  elbo_ppl = np.exp((total_nll_recon + total_kl) / num_words)
  recon_ppl = np.exp(total_nll_recon / num_words)
  iwae_ppl = np.exp(total_nll_iwae /num_words)
  kl = total_kl / num_sents  
  print('ElboPPL: %.2f, ReconPPL: %.2f, KL: %.4f, IwaePPL: %.2f, CorpusF1: %.2f, SentAvgF1: %.2f' % 
        (elbo_ppl, recon_ppl, kl, iwae_ppl, corpus_f1, sent_f1))
  # Note: the corpus F1 printed here differs from what evalb reports. Here no tags
  # are ignored, whereas evalb ignores some of them (e.g. punctuation).
  model.train()
  return iwae_ppl, corpus_f1

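# Sketch (hypothetical helper, not part of the original train.py): eval()
# estimates log p(x) with the importance-weighted (IWAE) bound
#   log (1/K) sum_k exp(ll_word_k + ll_action_p_k - ll_action_q_k).
# Subtracting the maximum log-weight before exponentiating keeps the computation
# numerically stable.
def _sketch_iwae_log_likelihood(log_weights):
  import math
  m = max(log_weights)
  if m == -math.inf:
    return -math.inf
  total = sum(math.exp(w - m) for w in log_weights)
  return m + math.log(total) - math.log(len(log_weights))

# Usage: _sketch_iwae_log_likelihood([0., math.log(3.)]) returns log(2). The
# perplexity reported by eval() is exp(-(summed log-likelihood) / num_words).
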
if __name__ == '__main__':
  args = parser.parse_args()
  main(args)


================================================
FILE: train_lm.py
================================================
#!/usr/bin/env python3
import sys
import os

import argparse
import json
import random
import shutil
import copy

import torch
from torch import cuda
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter

import torch.nn.functional as F
import numpy as np
import time
import logging
from data import Dataset
from models import RNNLM
from utils import *

parser = argparse.ArgumentParser()

# Data path options
parser.add_argument('--train_file', default='data/ptb-train.pkl')
parser.add_argument('--val_file', default='data/ptb-val.pkl')
parser.add_argument('--test_file', default='data/ptb-test.pkl')
parser.add_argument('--train_from', default='')
# Model options
parser.add_argument('--w_dim', default=650, type=int, help='word embedding dimension for LM')
parser.add_argument('--h_dim', default=650, type=int, help='hidden dimension for LM')
parser.add_argument('--num_layers', default=2, type=int, help='number of LSTM layers in the LM')
parser.add_argument('--dropout', default=0.5, type=float, help='dropout rate')
# Optimization options
parser.add_argument('--count_eos_ppl', default=0, type=int, help='whether to count eos in val PPL')
parser.add_argument('--test', default=0, type=int, help='if 1, evaluate on test_file and exit without training')
parser.add_argument('--save_path', default='urnng.pt', help='where to save the model checkpoint')
parser.add_argument('--num_epochs', default=30, type=int, help='number of training epochs')
parser.add_argument('--min_epochs', default=8, type=int, help='do not decay learning rate for at least this many epochs')
parser.add_argument('--lr', default=1, type=float, help='starting learning rate')
parser.add_argument('--decay', default=0.5, type=float, help='learning rate decay factor')
parser.add_argument('--param_init', default=0.1, type=float, help='parameter initialization (over uniform)')
parser.add_argument('--max_grad_norm', default=5, type=float, help='gradient clipping parameter')
parser.add_argument('--gpu', default=2, type=int, help='which gpu to use')
parser.add_argument('--seed', default=3435, type=int, help='random seed')
parser.add_argument('--print_every', type=int, default=500, help='print stats after this many batches')


def main(args):
  np.random.seed(args.seed)
  torch.manual_seed(args.seed)
  train_data = Dataset(args.train_file)
  val_data = Dataset(args.val_file)  
  vocab_size = int(train_data.vocab_size)    
  print('Train: %d sents / %d batches, Val: %d sents / %d batches' % 
        (train_data.sents.size(0), len(train_data), val_data.sents.size(0), 
         len(val_data)))
  print('Vocab size: %d' % vocab_size)
  cuda.set_device(args.gpu)
  if args.train_from == '':
    model = RNNLM(vocab = vocab_size,
                  w_dim = args.w_dim, 
                  h_dim = args.h_dim,
                  dropout = args.dropout,
                  num_layers = args.num_layers)
    if args.param_init > 0:
      for param in model.parameters():    
        param.data.uniform_(-args.param_init, args.param_init)      
  else:
    print('loading model from ' + args.train_from)
    checkpoint = torch.load(args.train_from)
    model = checkpoint['model']
  print("model architecture")
  print(model)
  optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
  model.train()
  model.cuda()
  epoch = 0
  decay = 0
  if args.test == 1:
    test_data = Dataset(args.test_file)  
    test_ppl = eval(test_data, model, count_eos_ppl = args.count_eos_ppl)
    sys.exit(0)
  best_val_ppl = eval(val_data, model, count_eos_ppl = args.count_eos_ppl)
  while epoch < args.num_epochs:
    start_time = time.time()
    epoch += 1  
    print('Starting epoch %d' % epoch)
    train_nll = 0.
    num_sents = 0.
    num_words = 0.
    b = 0
    for i in np.random.permutation(len(train_data)):
      sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = train_data[i]
      if length == 1:
        continue
      sents = sents.cuda()
      b += 1
      # clear gradients left over from the previous batch
      optimizer.zero_grad()
      nll = -model(sents).mean()
      train_nll += nll.item()*batch_size
      nll.backward()
      if args.max_grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)        
      optimizer.step()
      num_sents += batch_size
      num_words += batch_size * length
      if b % args.print_every == 0:
        param_norm = sum([p.norm()**2 for p in model.parameters()]).item()**0.5
        print('Epoch: %d, Batch: %d/%d, LR: %.4f, TrainPPL: %.2f, |Param|: %.4f, BestValPerf: %.2f, Throughput: %.2f examples/sec' % 
              (epoch, b, len(train_data), args.lr, np.exp(train_nll / num_words), 
               param_norm, best_val_ppl, num_sents / (time.time() - start_time)))
    print('--------------------------------')
    print('Checking validation perf...')    
    val_ppl = eval(val_data, model,  count_eos_ppl = args.count_eos_ppl)
    print('--------------------------------')
    if val_ppl < best_val_ppl:
      best_val_ppl = val_ppl
      checkpoint = {
        'args': args.__dict__,
        'model': model.cpu(),
        'word2idx': train_data.word2idx,
        'idx2word': train_data.idx2word
      }
      print('Saving checkpoint to %s' % args.save_path)
      torch.save(checkpoint, args.save_path)
      model.cuda()
    else:
      if epoch > args.min_epochs:
        decay = 1
    if decay == 1:
      args.lr = args.decay*args.lr
      for param_group in optimizer.param_groups:
        param_group['lr'] = args.lr
    if args.lr < 0.03:
      break
  print("Finished training")

def eval(data, model, count_eos_ppl = 0):
  model.eval()
  num_words = 0
  total_nll = 0.
  with torch.no_grad():
    for i in list(reversed(range(len(data)))):
      sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i] 
      if length == 1: #we ignore length 1 sents in URNNG eval so do this for LM too
        continue
      if count_eos_ppl == 1:
        length += 1 
      else:
        sents = sents[:, :-1] 
      sents = sents.cuda()
      num_words += length * batch_size
      nll = -model(sents).mean()
      total_nll += nll.item()*batch_size
  ppl = np.exp(total_nll / num_words)
  print('PPL: %.2f' % (ppl))
  model.train()
  return ppl

if __name__ == '__main__':
  args = parser.parse_args()
  main(args)
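The learning-rate logic at the end of each epoch in train_lm.py has a property that is easy to miss. Once validation perplexity fails to improve after min_epochs, decay is set to 1 and never reset, so the learning rate is multiplied by the decay factor at every later epoch until it falls below 0.03 and training stops. The hypothetical lr_schedule function below replays only that control flow, without any training, on a given list of validation perplexities, and returns the learning rate each epoch would use.

```python
def lr_schedule(val_ppls, init_val_ppl, lr=1.0, decay_rate=0.5,
                min_epochs=8, num_epochs=30, min_lr=0.03):
    # Replays the end-of-epoch logic of train_lm.py main() and returns the
    # learning rate used during each epoch that would actually run.
    best_val_ppl = init_val_ppl
    decay = False
    used = []
    for epoch, val_ppl in enumerate(val_ppls[:num_epochs], start=1):
        used.append(lr)
        if val_ppl < best_val_ppl:
            best_val_ppl = val_ppl      # train_lm.py saves a checkpoint here
        elif epoch > min_epochs:
            decay = True                # never switched off again
        if decay:
            lr = decay_rate * lr
        if lr < min_lr:
            break
    return used


print(lr_schedule([100, 90, 95, 85, 80, 70, 60, 50, 40], init_val_ppl=200, min_epochs=2))
# [1.0, 1.0, 1.0, 0.5, 0.25, 0.125, 0.0625, 0.03125]
```

In this run the learning rate keeps halving after epoch 3 even though validation perplexity improves again from epoch 4 onward, and the ninth epoch never runs because the rate drops below 0.03 after epoch 8.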


================================================
FILE: utils.py
================================================
#!/usr/bin/env python3
import numpy as np
import itertools
import random


def get_actions(tree, SHIFT = 0, REDUCE = 1, OPEN='(', CLOSE=')'):
  #input tree in bracket form: ((A B) (C D))
  #output action sequence: 0 0 1 0 0 1 1, where 0 is SHIFT and 1 is REDUCE
  actions = []
  tree = tree.strip()
  i = 0
  num_shift = 0
  num_reduce = 0
  left = 0
  right = 0
  while i < len(tree):
    if tree[i] != ' ' and tree[i] != OPEN and tree[i] != CLOSE: #terminal      
      if tree[i-1] == OPEN or tree[i-1] == ' ':
        actions.append(SHIFT)
        num_shift += 1
    elif tree[i] == CLOSE:
      actions.append(REDUCE)
      num_reduce += 1
      right += 1
    elif tree[i] == OPEN:
      left += 1
    i += 1
  assert(num_shift == num_reduce + 1)
  return actions

    
def get_tree(actions, sent = None, SHIFT = 0, REDUCE = 1):
  #input action and sent (lists), e.g. S S R S S R R, A B C D
  #output tree ((A B) (C D))
  stack = []
  pointer = 0
  if sent is None:
    sent = list(map(str, range((len(actions)+1) // 2)))
  for action in actions:
    if action == SHIFT:
      word = sent[pointer]
      stack.append(word)
      pointer += 1
    elif action == REDUCE:
      right = stack.pop()
      left = stack.pop()
      stack.append('(' + left + ' ' + right + ')')
  assert(len(stack) == 1)
  return stack[-1]
      
def get_spans(actions, SHIFT = 0, REDUCE = 1):
  sent = list(range((len(actions)+1) // 2))
  spans = []
  pointer = 0
  stack = []
  for action in actions:
    if action == SHIFT:
      word = sent[pointer]
      stack.append(word)
      pointer += 1
    elif action == REDUCE:
      right = stack.pop()
      left = stack.pop()
      if isinstance(left, int):
        left = (left, None)
      if isinstance(right, int):
        right = (None, right)
      new_span = (left[0], right[1])
      spans.append(new_span)
      stack.append(new_span)
  return spans

def get_stats(span1, span2):
  tp = 0
  fp = 0
  fn = 0
  for span in span1:
    if span in span2:
      tp += 1
    else:
      fp += 1
  for span in span2:
    if span not in span1:
      fn += 1
  return tp, fp, fn

def update_stats(pred_span, gold_spans, stats):
  for gold_span, stat in zip(gold_spans, stats):
    tp, fp, fn = get_stats(pred_span, gold_span)
    stat[0] += tp
    stat[1] += fp
    stat[2] += fn

def get_f1(stats):
  f1s = []
  for stat in stats:
    prec = stat[0] / (stat[0] + stat[1])
    recall = stat[0] / (stat[0] + stat[2])
    f1 = 2*prec*recall / (prec + recall)*100 if prec+recall > 0 else 0.
    f1s.append(f1)
  return f1s


def span_str(start = None, end = None):
  assert(start is not None or end is not None)
  if start is None:
    return ' '  + str(end) + ')'
  elif end is None:
    return '(' + str(start) + ' '
  else:
    return ' (' + str(start) + ' ' + str(end) + ') '    


def get_tree_from_binary_matrix(matrix, length):    
  sent = list(map(str, range(length)))
  n = len(sent)
  tree = {}
  for i in range(n):
    tree[i] = sent[i]
  for k in np.arange(1, n):
    for s in np.arange(n):
      t = s + k
      if t > n-1:
        break
      if matrix[s][t].item() == 1:
        span = '(' + tree[s] + ' ' + tree[t] + ')'
        tree[s] = span
        tree[t] = span
  return tree[0]

def get_nonbinary_spans(actions, SHIFT = 0, REDUCE = 1):
  spans = []
  stack = []
  pointer = 0
  binary_actions = []
  nonbinary_actions = []
  num_shift = 0
  num_reduce = 0
  for action in actions:
    # print(action, stack)
    if action == "SHIFT":
      nonbinary_actions.append(SHIFT)
      stack.append((pointer, pointer))
      pointer += 1
      binary_actions.append(SHIFT)
      num_shift += 1
    elif action[:3] == 'NT(':
      stack.append('(')            
    elif action == "REDUCE":
      nonbinary_actions.append(REDUCE)
      right = stack.pop()
      left = right
      n = 1
      while stack[-1] != '(':
        left = stack.pop()
        n += 1
      span = (left[0], right[1])
      if left[0] != right[1]:
        spans.append(span)
      stack.pop()
      stack.append(span)
      while n > 1:
        n -= 1
        binary_actions.append(REDUCE)        
        num_reduce += 1
    else:
      assert False  
  assert(len(stack) == 1)
  assert(num_shift == num_reduce + 1)
  return spans, binary_actions, nonbinary_actions
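utils.py encodes a binary tree as a SHIFT/REDUCE sequence (get_actions), decodes it back (get_tree), and reads constituent spans off the same sequence (get_spans). The following sketch is a compact, self-contained reimplementation for illustration only, not the repository's code. It tokenizes the bracket string instead of scanning characters, represents each word as a one-word span (i, i) rather than a bare index, and uses hypothetical function names. For ((A B) (C D)) it yields the action sequence 0 0 1 0 0 1 1 given in the get_actions comment and the spans (0, 1), (2, 3), (0, 3).

```python
SHIFT, REDUCE = 0, 1


def tree_to_actions(tree):
    # One SHIFT per word and one REDUCE per closing bracket.
    tokens = tree.replace('(', ' ( ').replace(')', ' ) ').split()
    return [REDUCE if tok == ')' else SHIFT for tok in tokens if tok != '(']


def actions_to_spans(actions):
    # Replay the actions on a stack of (start, end) spans; every REDUCE merges
    # the top two items and records the resulting constituent span.
    stack, spans, pointer = [], [], 0
    for action in actions:
        if action == SHIFT:
            stack.append((pointer, pointer))
            pointer += 1
        else:
            right = stack.pop()
            left = stack.pop()
            span = (left[0], right[1])
            spans.append(span)
            stack.append(span)
    assert len(stack) == 1, "actions do not describe a single binary tree"
    return spans


def actions_to_tree(actions, sent):
    # Rebuild the bracketed string from the actions and the word list.
    stack, pointer = [], 0
    for action in actions:
        if action == SHIFT:
            stack.append(sent[pointer])
            pointer += 1
        else:
            right = stack.pop()
            left = stack.pop()
            stack.append('(' + left + ' ' + right + ')')
    assert len(stack) == 1, "actions do not describe a single binary tree"
    return stack[0]


actions = tree_to_actions('((A B) (C D))')
print(actions)                    # [0, 0, 1, 0, 0, 1, 1]
print(actions_to_spans(actions))  # [(0, 1), (2, 3), (0, 3)]
print(actions_to_tree(actions, ['A', 'B', 'C', 'D']))  # ((A B) (C D))
```

As in get_spans, the last span is always the one covering the whole sentence, which is why train.py drops it with [:-1] before scoring.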
SYMBOL INDEX (74 symbols across 9 files)

FILE: TreeCRF.py
  class ConstituencyTreeCRF (line 11) | class ConstituencyTreeCRF(nn.Module):
    method __init__ (line 12) | def __init__(self):
    method logadd (line 16) | def logadd(self, x, y):
    method logsumexp (line 20) | def logsumexp(self, x, dim=1):
    method _init_table (line 24) | def _init_table(self, scores):
    method _forward (line 30) | def _forward(self, scores):
    method _backward (line 46) | def _backward(self, scores):
    method _marginal (line 71) | def _marginal(self, scores):
    method _entropy (line 80) | def _entropy(self, scores):
    method _sample (line 104) | def _sample(self, scores, alpha = None, argmax = False):
    method _viterbi (line 161) | def _viterbi(self, scores):
    method _backtrack (line 185) | def _backtrack(self, b, s, t):
  function get_span_str (line 196) | def get_span_str(start = None, end = None):

FILE: data.py
  class Dataset (line 6) | class Dataset(object):
    method __init__ (line 7) | def __init__(self, data_file):
    method _convert (line 19) | def _convert(self, x):
    method __len__ (line 22) | def __len__(self):
    method __getitem__ (line 25) | def __getitem__(self, idx):

FILE: eval_ppl.py
  function main (line 37) | def main(args):

FILE: models.py
  class RNNLM (line 9) | class RNNLM(nn.Module):
    method __init__ (line 10) | def __init__(self, vocab=10000,
    method forward (line 25) | def forward(self, sent):
    method generate (line 32) | def generate(self, bos = 2, eos = 3, max_len = 150):
  class SeqLSTM (line 48) | class SeqLSTM(nn.Module):
    method __init__ (line 49) | def __init__(self, i_dim = 200,
    method forward (line 62) | def forward(self, x, prev_h = None):
  class TreeLSTM (line 79) | class TreeLSTM(nn.Module):
    method __init__ (line 80) | def __init__(self, dim = 200):
    method forward (line 85) | def forward(self, x1, x2, e=None):
  class RNNG (line 106) | class RNNG(nn.Module):
    method __init__ (line 107) | def __init__(self, vocab = 100,
    method get_span_scores (line 136) | def get_span_scores(self, x):
    method get_action_masks (line 153) | def get_action_masks(self, actions, length):
    method forward (line 170) | def forward(self, x, samples = 1, is_temp = 1., has_eos=True):
    method forward_actions (line 296) | def forward_actions(self, x, actions, has_eos=True):
    method forward_tree (line 385) | def forward_tree(self, x, actions, has_eos=True):
    method logsumexp (line 404) | def logsumexp(self, x, dim=1):

FILE: parse.py
  function is_next_open_bracket (line 33) | def is_next_open_bracket(line, start_idx):
  function get_between_brackets (line 41) | def get_between_brackets(line, start_idx):
  function get_tags_tokens_lowercase (line 50) | def get_tags_tokens_lowercase(line):
  function get_nonterminal (line 70) | def get_nonterminal(line, start_idx):
  function get_actions (line 81) | def get_actions(line):
  function clean_number (line 112) | def clean_number(w):
  function main (line 116) | def main(args):

FILE: preprocess.py
  class Indexer (line 17) | class Indexer:
    method __init__ (line 18) | def __init__(self, symbols = ["<pad>","<unk>","<s>","</s>"]):
    method add_w (line 27) | def add_w(self, ws):
    method convert (line 32) | def convert(self, w):
    method convert_sequence (line 35) | def convert_sequence(self, ls):
    method write (line 38) | def write(self, outfile):
    method prune_vocab (line 46) | def prune_vocab(self, k, cnt = False):
    method load_vocab (line 60) | def load_vocab(self, vocab_file):
  function is_next_open_bracket (line 70) | def is_next_open_bracket(line, start_idx):
  function get_between_brackets (line 78) | def get_between_brackets(line, start_idx):
  function get_tags_tokens_lowercase (line 87) | def get_tags_tokens_lowercase(line):
  function get_nonterminal (line 108) | def get_nonterminal(line, start_idx):
  function get_actions (line 119) | def get_actions(line):
  function pad (line 150) | def pad(ls, length, symbol):
  function clean_number (line 155) | def clean_number(w):
  function get_data (line 159) | def get_data(args):
  function main (line 311) | def main(arguments):

FILE: train.py
  function main (line 61) | def main(args):
  function eval (line 243) | def eval(data, model, samples = 0, count_eos_ppl = 0):

FILE: train_lm.py
  function main (line 52) | def main(args):
  function eval (line 143) | def eval(data, model, count_eos_ppl = 0):

FILE: utils.py
  function get_actions (line 7) | def get_actions(tree, SHIFT = 0, REDUCE = 1, OPEN='(', CLOSE=')'):
  function get_tree (line 33) | def get_tree(actions, sent = None, SHIFT = 0, REDUCE = 1):
  function get_spans (line 52) | def get_spans(actions, SHIFT = 0, REDUCE = 1):
  function get_stats (line 74) | def get_stats(span1, span2):
  function update_stats (line 88) | def update_stats(pred_span, gold_spans, stats):
  function get_f1 (line 95) | def get_f1(stats):
  function span_str (line 105) | def span_str(start = None, end = None):
  function get_tree_from_binary_matrix (line 115) | def get_tree_from_binary_matrix(matrix, length):
  function get_nonbinary_spans (line 132) | def get_nonbinary_spans(actions, SHIFT = 0, REDUCE = 1):

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
