Repository: harvardnlp/urnng Branch: master Commit: b1eeffa5b590 Files: 15 Total size: 101.9 KB Directory structure: gitextract_zduez7kc/ ├── .gitignore ├── COLLINS.prm ├── README.md ├── TreeCRF.py ├── data/ │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── data.py ├── eval_ppl.py ├── models.py ├── parse.py ├── preprocess.py ├── train.py ├── train_lm.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.pt *.amat *.mat *.out *.out~ *.pyc *.pt~ .gitignore~ *.out~ *.sh *.sh~ *.py~ *.json *.json~ *.model *.h5 *.tar.gz *.hdf5 *.dict *.pkl ================================================ FILE: COLLINS.prm ================================================ ##------------------------------------------## ## Debug mode ## ## 0: No debugging ## ## 1: print data for individual sentence ## ##------------------------------------------## DEBUG 0 ##------------------------------------------## ## MAX error ## ## Number of errors to stop the process. ## ## This is useful if there could be ## ## tokenization errors. ## ## The process will stop when this number## ## of errors are accumulated. ## ##------------------------------------------## MAX_ERROR 10 ##------------------------------------------## ## Cut-off length for statistics ## ## At the end of evaluation, the ## ## statistics for the sentences of length## ## less than or equal to this number will## ## be shown, on top of the statistics ## ## for all the sentences ## ##------------------------------------------## CUTOFF_LEN 10 ##------------------------------------------## ## unlabeled or labeled bracketing ## ## 0: unlabeled bracketing ## ## 1: labeled bracketing ## ##------------------------------------------## LABELED 0 ##------------------------------------------## ## Delete labels ## ## list of labels to be ignored. ## ## If it is a pre-terminal label, delete ## ## the word along with the brackets. ## ## If it is a non-terminal label, just ## ## delete the brackets (don't delete ## ## the children). ## ##------------------------------------------## DELETE_LABEL TOP DELETE_LABEL -NONE- DELETE_LABEL , DELETE_LABEL : DELETE_LABEL `` DELETE_LABEL '' DELETE_LABEL . ##------------------------------------------## ## Delete labels for length calculation ## ## list of labels to be ignored for ## ## length calculation purpose ## ##------------------------------------------## DELETE_LABEL_FOR_LENGTH -NONE- ##------------------------------------------## ## Equivalent labels, words ## ## the pairs are considered equivalent ## ## This is non-directional. ## ##------------------------------------------## EQ_LABEL ADVP PRT # EQ_WORD Example example ================================================ FILE: README.md ================================================ # Unsupervised Recurrent Neural Network Grammars This is an implementation of the paper: [Unsupervised Recurrent Neural Network Grammars](https://arxiv.org/abs/1904.03746) Yoon Kim, Alexander Rush, Lei Yu, Adhiguna Kuncoro, Chris Dyer, Gabor Melis NAACL 2019 ## Dependencies The code was tested in `python 3.6` and `pytorch 1.0`. ## Data Sample train/val/test data is in the `data/` folder. These are the standard datasets from PTB.
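Each line in these files is a Penn Treebank-style bracketed parse with POS-tagged terminals. As a quick illustration (not part of this repo; assumes NLTK is installed), you can inspect one tree and recover its token sequence like this:
```
# Illustrative only: read one bracketed tree from data/train.txt.
# Assumes NLTK is available (pip install nltk); the repo's own
# preprocess.py uses its own helper functions instead.
from nltk import Tree

with open("data/train.txt") as f:
    tree = Tree.fromstring(f.readline())
print(tree.leaves())  # the words of the sentence
print(tree.pos())     # (word, POS tag) pairs
```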
First preprocess the data: ``` python preprocess.py --trainfile data/train.txt --valfile data/valid.txt --testfile data/test.txt --outputfile data/ptb --vocabminfreq 1 --lowercase 0 --replace_num 0 --batchsize 16 ``` Running this will save the following files in the `data/` folder: `ptb-train.pkl`, `ptb-val.pkl`, `ptb-test.pkl`, `ptb.dict`. Here `ptb.dict` is the word-idx mapping, and you can change the output folder/name by changing the argument to `outputfile`. Also, the preprocessing here will replace singletons with a single `<unk>` token rather than with the Berkeley parser's mapping rules (see below for results using this setup). ## Training To train the URNNG: ``` python train.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --save_path urnng.pt --mode unsupervised --gpu 0 ``` where `save_path` is where you want to save the model, and `gpu 0` is for using the first GPU in the cluster (the mapping from PyTorch GPU index to your cluster's GPU index may vary). Training should take 2 to 3 days depending on your setup. To train the RNNG: ``` python train.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --save_path rnng.pt --mode supervised --train_q_epochs 18 --gpu 0 ``` For fine-tuning: ``` python train.py --train_from rnng.pt --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --save_path rnng-urnng.pt --mode unsupervised --lr 0.1 --train_q_epochs 10 --epochs 10 --min_epochs 6 --gpu 0 --kl_warmup 0 ``` To train the LM: ``` python train_lm.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --test_file data/ptb-test.pkl --save_path lm.pt ``` ## Evaluation To evaluate perplexity with importance sampling on the test set: ``` python eval_ppl.py --model_file urnng.pt --test_file data/ptb-test.pkl --samples 1000 --is_temp 2 --gpu 0 ``` The argument `samples` gives the number of importance-weighted samples, and `is_temp` is the temperature used to flatten the inference network's distribution (footnote 14 in the paper). The same evaluation code will work for the RNNG. For LM evaluation: ``` python train_lm.py --train_from lm.pt --test_file data/ptb-test.pkl --test 1 ``` To evaluate F1, first we need to parse the test set: ``` python parse.py --model_file urnng.pt --data_file data/ptb-test.txt --out_file pred-parse.txt --gold_out_file gold-parse.txt --gpu 0 ``` This will output the predicted parse trees into `pred-parse.txt`. We also output a version of the gold parses, `gold-parse.txt`, to be used as input for `evalb`, since sentences with only trivial spans are ignored by `parse.py`. Note that the corpus/sentence F1 results printed here do not correspond to the results reported in the paper, since this evaluation does not ignore punctuation. Finally, download/install `evalb`, available [here](https://nlp.cs.nyu.edu/evalb). Then run: ``` evalb -p COLLINS.prm gold-parse.txt pred-parse.txt ``` where `COLLINS.prm` is the parameter file (provided in this repo) that tells `evalb` to ignore punctuation and evaluate on unlabeled F1. ## Note Regarding Preprocessing Note that some of the details regarding the preprocessing are slightly different from the original paper. In particular, in this implementation we replace singleton words with a single `<unk>` token instead of using the Berkeley parser's mapping rules. This results in slightly lower perplexity for all models, since the vocabulary size is smaller.
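For concreteness, here is a minimal sketch of the singleton-replacement idea (illustrative only, not the repo's code; the actual logic lives in `preprocess.py`, whose `Indexer.prune_vocab` keeps words whose count exceeds `vocabminfreq`):
```
# Illustrative sketch: map words that occur only once in the
# training corpus to a single <unk> token.
from collections import Counter

def replace_singletons(sents, unk="<unk>"):
    counts = Counter(w for sent in sents for w in sent)
    return [[w if counts[w] > 1 else unk for w in sent] for sent in sents]

sents = [["the", "cat", "sat"], ["the", "dog", "sat"]]
print(replace_singletons(sents))
# [['the', '<unk>', 'sat'], ['the', '<unk>', 'sat']]
```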
Here are the perplexity numbers I get in this setting: - RNNLM: 89.2 - RNNG: 83.7 - URNNG: 85.1 (F1: 38.4) - RNNG --> URNNG: 82.5 ## Acknowledgements Some of our preprocessing and evaluation code is based on the following repositories: - [Recurrent Neural Network Grammars](https://github.com/clab/rnng) - [Parsing Reading Predict Network](https://github.com/yikangshen/PRPN) ## License MIT ================================================ FILE: TreeCRF.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import itertools import utils import random class ConstituencyTreeCRF(nn.Module): def __init__(self): super(ConstituencyTreeCRF, self).__init__() self.huge = 1e9 def logadd(self, x, y): d = torch.max(x,y) return torch.log(torch.exp(x-d) + torch.exp(y-d)) + d def logsumexp(self, x, dim=1): d = torch.max(x, dim)[0] return torch.log(torch.exp(x - d.unsqueeze(dim).expand_as(x)).sum(dim)) + d def _init_table(self, scores): # initialize dynamic programming table batch_size = scores.size(0) n = scores.size(1) self.alpha = [[scores.new(batch_size).fill_(-self.huge) for _ in range(n)] for _ in range(n)] def _forward(self, scores): #inside step batch_size = scores.size(0) n = scores.size(1) self._init_table(scores) for i in range(n): self.alpha[i][i] = scores[:, i, i] for k in np.arange(1, n+1): for s in range(n): t = s + k if t > n-1: break tmp = [self.alpha[s][u] + self.alpha[u+1][t] + scores[:, s, t] for u in np.arange(s,t)] tmp = torch.stack(tmp, 1) self.alpha[s][t] = self.logsumexp(tmp, 1) def _backward(self, scores): #outside step batch_size = scores.size(0) n = scores.size(1) self.beta = [[None for _ in range(n)] for _ in range(n)] self.beta[0][n-1] = scores.new(batch_size).fill_(0) for k in np.arange(n-1, 0, -1): for s in range(n): t = s + k if t > n-1: break for u in np.arange(s, t): if s < u+1: tmp = self.beta[s][t] + self.alpha[u+1][t] + scores[:, s, t] if self.beta[s][u] is None: self.beta[s][u] = tmp else: self.beta[s][u] = self.logadd(self.beta[s][u], tmp) if u+1 < t+1: tmp = self.beta[s][t] + self.alpha[s][u] + scores[:, s, t] if self.beta[u+1][t] is None: self.beta[u+1][t] = tmp else: self.beta[u+1][t] = self.logadd(self.beta[u+1][t], tmp) def _marginal(self, scores): batch_size = scores.size(0) n = scores.size(1) self.log_marginal = [[None for _ in range(n)] for _ in range(n)] log_Z = self.alpha[0][n-1] for s in range(n): for t in np.arange(s, n): self.log_marginal[s][t] = self.alpha[s][t] + self.beta[s][t] - log_Z def _entropy(self, scores): batch_size = scores.size(0) n = scores.size(1) self.entropy = [[None for _ in range(n)] for _ in range(n)] for i in range(n): self.entropy[i][i] = scores.new(batch_size).fill_(0) for k in np.arange(1, n+1): for s in range(n): t = s + k if t > n-1: break score = [] prev_ent = [] for u in np.arange(s, t): score.append(self.alpha[s][u] + self.alpha[u+1][t]) prev_ent.append(self.entropy[s][u] + self.entropy[u+1][t]) score = torch.stack(score, 1) prev_ent = torch.stack(prev_ent, 1) log_prob = F.log_softmax(score, dim = 1) prob = log_prob.exp() entropy = ((prev_ent - log_prob)*prob).sum(1) self.entropy[s][t] = entropy def _sample(self, scores, alpha = None, argmax = False): # sample from p(tree | sent) # also get the spans if alpha is None: self._forward(scores) alpha = self.alpha batch_size = scores.size(0) n = scores.size(1) tree = scores.new(batch_size, n, n).zero_() all_log_probs = [] tree_brackets = [] spans = [] for b in range(batch_size):
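# top-down sampling for one batch element: repeatedly pop a span off the queue and sample its split point from the inside scores (alpha)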
sampled = [(0, n-1)] span = [(0, n-1)] queue = [(0, n-1)] #start, end log_probs = [] tree_str = get_span_str(0, n-1) while len(queue) > 0: node = queue.pop(0) start, end = node left_parent = get_span_str(start, None) right_parent = get_span_str(None, end) score = [] score_idx = [] for u in np.arange(start, end): score.append(alpha[start][u][b] + alpha[u+1][end][b]) score_idx.append([(start, u), (u+1, end)]) score = torch.stack(score, 0) log_prob = F.log_softmax(score, dim = 0) if argmax: sample = torch.max(log_prob, 0)[1] else: prob = log_prob.exp() sample = torch.multinomial(log_prob.exp(), 1) sample_idx = score_idx[sample.item()] log_probs.append(log_prob[sample.item()]) for idx in sample_idx: if idx[0] != idx[1]: queue.append(idx) span.append(idx) sampled.append(idx) left_child = '(' + get_span_str(sample_idx[0][0], sample_idx[0][1]) right_child = get_span_str(sample_idx[1][0], sample_idx[1][1]) + ')' if sample_idx[0][0] != sample_idx[0][1]: tree_str = tree_str.replace(left_parent, left_child) if sample_idx[1][0] != sample_idx[1][1]: tree_str = tree_str.replace(right_parent, right_child) all_log_probs.append(torch.stack(log_probs, 0).sum(0)) tree_brackets.append(tree_str) spans.append(span[::-1]) for idx in sampled: tree[b][idx[0]][idx[1]] = 1 all_log_probs = torch.stack(all_log_probs, 0) return tree, all_log_probs, tree_brackets, spans def _viterbi(self, scores): # cky algorithm batch_size = scores.size(0) n = scores.size(1) self.max_scores = scores.new(batch_size, n, n).fill_(-self.huge) self.bp = scores.new(batch_size, n, n).zero_() self.argmax = scores.new(batch_size, n, n).zero_() self.spans = [[] for _ in range(batch_size)] tmp = scores.new(batch_size, n).zero_() for i in range(n): self.max_scores[:, i, i] = scores[:, i, i] for k in np.arange(1, n): for s in np.arange(n): t = s + k if t > n-1: break for u in np.arange(s, t): tmp = self.max_scores[:, s, u] + self.max_scores[:, u+1, t] + scores[:, s, t] self.bp[:, s, t][self.max_scores[:, s, t] < tmp] = int(u) self.max_scores[:, s, t] = torch.max(self.max_scores[:, s, t], tmp) for b in range(batch_size): self._backtrack(b, 0, n-1) return self.max_scores[:, 0, n-1], self.argmax, self.spans def _backtrack(self, b, s, t): u = int(self.bp[b][s][t]) self.argmax[b][s][t] = 1 if s == t: return None else: self.spans[b].insert(0, (s,t)) self._backtrack(b, s, u) self._backtrack(b, u+1, t) return None def get_span_str(start = None, end = None): assert(start is not None or end is not None) if start is None: return ' ' + str(end) + ')' elif end is None: return '(' + str(start) + ' ' else: return ' (' + str(start) + ' ' + str(end) + ') ' ================================================ FILE: data/test.txt ================================================ (S (INTJ (RB No)) (, ,) (NP (PRP it)) (VP (VBD was) (RB n't) (NP (NNP Black) (NNP Monday))) (. .)) (S (CC But) (SBAR (IN while) (S (NP (DT the) (NNP New) (NNP York) (NNP Stock) (NNP Exchange)) (VP (VBD did) (RB n't) (VP (VB fall) (ADVP (RB apart)) (NP (NNP Friday)) (SBAR (IN as) (S (NP (DT the) (NNP Dow) (NNP Jones) (NNP Industrial) (NNP Average)) (VP (VBD plunged) (NP (NP (CD 190.58) (NNS points)) (PRN (: --) (NP (NP (JJS most)) (PP (IN of) (NP (PRP it))) (PP (IN in) (NP (DT the) (JJ final) (NN hour)))) (: --)))))))))) (NP (PRP it)) (ADVP (RB barely)) (VP (VBD managed) (S (VP (TO to) (VP (VB stay) (NP (NP (DT this) (NN side)) (PP (IN of) (NP (NN chaos)))))))) (. 
.)) (S (NP (NP (DT Some) (`` ``) (NN circuit) (NNS breakers) ('' '')) (VP (VBN installed) (PP (IN after) (NP (DT the) (NNP October) (CD 1987) (NN crash))))) (VP (VBD failed) (NP (PRP$ their) (JJ first) (NN test)) (PRN (, ,) (S (NP (NNS traders)) (VP (VBP say))) (, ,)) (S (ADJP (JJ unable) (S (VP (TO to) (VP (VB cool) (NP (NP (DT the) (NN selling) (NN panic)) (PP (IN in) (NP (DT both) (NNS stocks) (CC and) (NNS futures)))))))))) (. .)) (S (NP (NP (NP (DT The) (CD 49) (NN stock) (NN specialist) (NNS firms)) (PP (IN on) (NP (DT the) (NNP Big) (NNP Board) (NN floor)))) (: --) (NP (NP (DT the) (NNS buyers) (CC and) (NNS sellers)) (PP (IN of) (NP (JJ last) (NN resort))) (SBAR (WHNP (WP who)) (S (VP (VBD were) (VP (VBN criticized) (PP (IN after) (NP (DT the) (CD 1987) (NN crash)))))))) (: --)) (ADVP (RB once) (RB again)) (VP (MD could) (RB n't) (VP (VB handle) (NP (DT the) (NN selling) (NN pressure)))) (. .)) (S (S (NP (JJ Big) (NN investment) (NNS banks)) (VP (VBD refused) (S (VP (TO to) (VP (VB step) (ADVP (IN up) (PP (TO to) (NP (DT the) (NN plate)))) (S (VP (TO to) (VP (VB support) (NP (DT the) (JJ beleaguered) (NN floor) (NNS traders)) (PP (IN by) (S (VP (VBG buying) (NP (NP (JJ big) (NNS blocks)) (PP (IN of) (NP (NN stock))))))))))))))) (, ,) (NP (NNS traders)) (VP (VBP say)) (. .)) (S (NP (NP (JJ Heavy) (NN selling)) (PP (IN of) (NP (NP (NNP Standard) (CC &) (NNP Poor) (POS 's)) (JJ 500-stock) (NN index) (NNS futures))) (PP (IN in) (NP (NNP Chicago)))) (VP (ADVP (RB relentlessly)) (VBD beat) (NP (NNS stocks)) (ADVP (RB downward))) (. .)) (S (NP (NP (CD Seven) (NNP Big) (NNP Board) (NNS stocks)) (: --) (NP (NP (NNP UAL)) (, ,) (NP (NNP AMR)) (, ,) (NP (NNP BankAmerica)) (, ,) (NP (NNP Walt) (NNP Disney)) (, ,) (NP (NNP Capital) (NNP Cities\/ABC)) (, ,) (NP (NNP Philip) (NNP Morris)) (CC and) (NP (NNP Pacific) (NNP Telesis) (NNP Group))) (: --)) (VP (VP (VBD stopped) (S (VP (VBG trading)))) (CC and) (VP (ADVP (RB never)) (VBD resumed))) (. .)) (S (NP (DT The) (NN finger-pointing)) (VP (VBZ has) (ADVP (RB already)) (VP (VBN begun))) (. .)) (S (`` ``) (NP (DT The) (NN equity) (NN market)) (VP (VBD was) (ADJP (JJ illiquid))) (. .)) (SINV (S (ADVP (RB Once) (RB again)) (-LRB- -LCB-) (NP (DT the) (NNS specialists)) (-RRB- -RCB-) (VP (VBD were) (RB not) (ADJP (JJ able) (S (VP (TO to) (VP (VB handle) (NP (NP (DT the) (NNS imbalances)) (PP (IN on) (NP (NP (DT the) (NN floor)) (PP (IN of) (NP (DT the) (NNP New) (NNP York) (NNP Stock) (NNP Exchange)))))))))))) (, ,) ('' '') (VP (VBD said)) (NP (NP (NNP Christopher) (NNP Pedersen)) (, ,) (NP (NP (JJ senior) (NN vice) (NN president)) (PP (IN at) (NP (NNP Twenty-First) (NNP Securities) (NNP Corp))))) (. .)) (SINV (VP (VBD Countered)) (NP (NP (NNP James) (NNP Maguire)) (, ,) (NP (NP (NN chairman)) (PP (IN of) (NP (NNS specialists) (NNP Henderson) (NNP Brothers) (NNP Inc.))))) (: :) (`` ``) (S (NP (PRP It)) (VP (VBZ is) (ADJP (JJ easy)) (S (VP (TO to) (VP (VB say) (SBAR (S (NP (DT the) (NN specialist)) (VP (VBZ is) (RB n't) (VP (VBG doing) (NP (PRP$ his) (NN job))))))))))) (. .)) (S (SBAR (WHADVP (WRB When)) (S (NP (DT the) (NN dollar)) (VP (VBZ is) (PP (IN in) (NP (DT a) (NN free-fall)))))) (, ,) (NP (RB even) (JJ central) (NNS banks)) (VP (MD ca) (RB n't) (VP (VB stop) (NP (PRP it)))) (. .)) (S (NP (NNS Speculators)) (VP (VBP are) (VP (VBG calling) (PP (IN for) (NP (NP (DT a) (NN degree)) (PP (IN of) (NP (NN liquidity))) (SBAR (WHNP (WDT that)) (S (VP (VBZ is) (RB not) (ADVP (RB there)) (PP (IN in) (NP (DT the) (NN market)))))))))) (. .) 
('' '')) (S (NP (NP (JJ Many) (NN money) (NNS managers)) (CC and) (NP (DT some) (NNS traders))) (VP (VBD had) (ADVP (RB already)) (VP (VBN left) (NP (PRP$ their) (NNS offices)) (NP (RB early) (NNP Friday) (NN afternoon)) (PP (IN on) (NP (DT a) (JJ warm) (NN autumn) (NN day))) (: --) (SBAR (IN because) (S (NP (DT the) (NN stock) (NN market)) (VP (VBD was) (ADJP (RB so) (JJ quiet))))))) (. .)) (S (RB Then) (PP (IN in) (NP (DT a) (NN lightning) (NN plunge))) (, ,) (NP (DT the) (NNP Dow) (NNP Jones) (NNS industrials)) (PP (IN in) (NP (QP (RB barely) (DT an)) (NN hour))) (VP (VBD surrendered) (NP (NP (QP (RB about) (DT a)) (JJ third)) (PP (IN of) (NP (NP (PRP$ their) (NNS gains)) (NP (DT this) (NN year))))) (, ,) (S (VP (VBG chalking) (PRT (RP up)) (NP (NP (DT a) (ADJP (ADJP (JJ 190.58-point)) (, ,) (CC or) (ADJP (CD 6.9) (NN %)) (, ,)) (NN loss)) (PP (IN on) (NP (DT the) (NN day)))) (PP (IN in) (NP (JJ gargantuan) (NN trading) (NN volume)))))) (. .)) (S (NP (JJ Final-hour) (NN trading)) (VP (VBD accelerated) (PP (TO to) (NP (NP (QP (CD 108.1) (CD million)) (NNS shares)) (, ,) (NP (NP (DT a) (NN record)) (PP (IN for) (NP (DT the) (NNP Big) (NNP Board))))))) (. .)) (S (PP (IN At) (NP (NP (DT the) (NN end)) (PP (IN of) (NP (DT the) (NN day))))) (, ,) (NP (QP (CD 251.2) (CD million)) (NNS shares)) (VP (VBD were) (VP (VBN traded))) (. .)) (S (NP (DT The) (NNP Dow) (NNP Jones) (NNS industrials)) (VP (VBD closed) (PP (IN at) (NP (CD 2569.26)))) (. .)) (S (NP (NP (DT The) (NNP Dow) (POS 's)) (NN decline)) (VP (VBD was) (ADJP (JJ second) (PP (IN in) (NP (NN point) (NNS terms))) (PP (ADVP (RB only)) (TO to) (NP (NP (DT the) (JJ 508-point) (NNP Black) (NNP Monday) (NN crash)) (SBAR (WHNP (WDT that)) (S (VP (VBD occurred) (NP (NNP Oct.) (CD 19) (, ,) (CD 1987))))))))) (. .)) (S (PP (IN In) (NP (NN percentage) (NNS terms))) (, ,) (ADVP (RB however)) (, ,) (NP (NP (DT the) (NNP Dow) (POS 's)) (NN dive)) (VP (VBD was) (NP (NP (NP (DT the) (JJ 12th-worst)) (ADVP (RB ever))) (CC and) (NP (NP (DT the) (JJS sharpest)) (SBAR (IN since) (S (NP (DT the) (NN market)) (VP (VBD fell) (NP (NP (CD 156.83)) (, ,) (CC or) (NP (CD 8) (NN %))) (, ,) (PP (NP (DT a) (NN week)) (IN after) (NP (NNP Black) (NNP Monday))))))))) (. .)) ================================================ FILE: data/train.txt ================================================ (S (PP (IN In) (NP (NP (DT an) (NNP Oct.) (CD 19) (NN review)) (PP (IN of) (NP (`` ``) (NP (DT The) (NN Misanthrope)) ('' '') (PP (IN at) (NP (NP (NNP Chicago) (POS 's)) (NNP Goodman) (NNP Theatre))))) (PRN (-LRB- -LRB-) (`` ``) (S (NP (VBN Revitalized) (NNS Classics)) (VP (VBP Take) (NP (DT the) (NN Stage)) (PP (IN in) (NP (NNP Windy) (NNP City))))) (, ,) ('' '') (NP (NN Leisure) (CC &) (NNS Arts)) (-RRB- -RRB-)))) (, ,) (NP (NP (NP (DT the) (NN role)) (PP (IN of) (NP (NNP Celimene)))) (, ,) (VP (VBN played) (PP (IN by) (NP (NNP Kim) (NNP Cattrall)))) (, ,)) (VP (VBD was) (VP (ADVP (RB mistakenly)) (VBN attributed) (PP (TO to) (NP (NNP Christina) (NNP Haag))))) (. .)) (S (NP (NNP Ms.) (NNP Haag)) (VP (VBZ plays) (NP (NNP Elianti))) (. .)) (S (NP (NNP Rolls-Royce) (NNP Motor) (NNPS Cars) (NNP Inc.)) (VP (VBD said) (SBAR (S (NP (PRP it)) (VP (VBZ expects) (S (NP (PRP$ its) (NNP U.S.) (NNS sales)) (VP (TO to) (VP (VB remain) (ADJP (JJ steady)) (PP (IN at) (NP (QP (IN about) (CD 1,200)) (NNS cars))) (PP (IN in) (NP (CD 1990)))))))))) (. 
.)) (S (NP (DT The) (NN luxury) (NN auto) (NN maker)) (NP (JJ last) (NN year)) (VP (VBD sold) (NP (CD 1,214) (NNS cars)) (PP (IN in) (NP (DT the) (NNP U.S.))))) (S (NP (NP (NNP Howard) (NNP Mosher)) (, ,) (NP (NP (NN president)) (CC and) (NP (JJ chief) (NN executive) (NN officer))) (, ,)) (VP (VBD said) (SBAR (S (NP (PRP he)) (VP (VBZ anticipates) (NP (NP (NN growth)) (PP (IN for) (NP (DT the) (NN luxury) (NN auto) (NN maker))) (PP (PP (IN in) (NP (NNP Britain) (CC and) (NNP Europe))) (, ,) (CC and) (PP (IN in) (NP (ADJP (JJ Far) (JJ Eastern)) (NNS markets))))))))) (. .)) (S (NP (NNP BELL) (NNP INDUSTRIES) (NNP Inc.)) (VP (VBD increased) (NP (PRP$ its) (NN quarterly)) (PP (TO to) (NP (CD 10) (NNS cents))) (PP (IN from) (NP (NP (CD seven) (NNS cents)) (NP (DT a) (NN share))))) (. .)) (S (NP (DT The) (JJ new) (NN rate)) (VP (MD will) (VP (VB be) (ADJP (JJ payable) (NP (NNP Feb.) (CD 15))))) (. .)) (S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .)) (S (NP (NP (NNP Bell)) (, ,) (VP (VBN based) (PP (IN in) (NP (NNP Los) (NNP Angeles)))) (, ,)) (VP (VBZ makes) (CC and) (VBZ distributes) (NP (UCP (JJ electronic) (, ,) (NN computer) (CC and) (NN building)) (NNS products))) (. .)) (S (NP (NNS Investors)) (VP (VBP are) (VP (VBG appealing) (PP (TO to) (NP (DT the) (NNPS Securities) (CC and) (NNP Exchange) (NNP Commission))) (S (RB not) (VP (TO to) (VP (VB limit) (NP (NP (PRP$ their) (NN access)) (PP (TO to) (NP (NP (NN information)) (PP (IN about) (NP (NP (NN stock) (NNS purchases) (CC and) (NNS sales)) (PP (IN by) (NP (JJ corporate) (NNS insiders))))))))))))) (. .)) (S (S (NP (DT A) (NNP SEC) (NN proposal) (S (VP (TO to) (VP (VB ease) (NP (NP (NN reporting) (NNS requirements)) (PP (IN for) (NP (DT some) (NN company) (NNS executives)))))))) (VP (MD would) (VP (VB undermine) (NP (NP (DT the) (NN usefulness)) (PP (IN of) (NP (NP (NN information)) (PP (IN on) (NP (NN insider) (NNS trades))))) (PP (IN as) (NP (DT a) (JJ stock-picking) (NN tool))))))) (, ,) (NP (NP (JJ individual) (NNS investors)) (CC and) (NP (JJ professional) (NN money) (NNS managers))) (VP (VBP contend)) (. .)) (S (NP (PRP They)) (VP (VBP make) (NP (DT the) (NN argument)) (PP (IN in) (NP (NP (NNS letters)) (PP (TO to) (NP (DT the) (NN agency))) (PP (IN about) (NP (NP (NN rule) (NNS changes)) (VP (VBD proposed) (NP (DT this) (JJ past) (NN summer))) (SBAR (WHNP (IN that)) (, ,) (S (PP (IN among) (NP (JJ other) (NNS things))) (, ,) (VP (MD would) (VP (VB exempt) (NP (JJ many) (JJ middle-management) (NNS executives)) (PP (IN from) (S (VP (VBG reporting) (NP (NP (NNS trades)) (PP (IN in) (NP (NP (PRP$ their) (JJ own) (NNS companies) (POS ')) (NNS shares)))))))))))))))) (. .)) (S (NP (DT The) (VBN proposed) (NNS changes)) (ADVP (RB also)) (VP (MD would) (VP (VB allow) (S (NP (NNS executives)) (VP (TO to) (VP (VB report) (NP (NP (NNS exercises)) (PP (IN of) (NP (NNS options)))) (ADVP (ADVP (RBR later)) (CC and) (ADVP (RBR less) (RB often)))))))) (. 
.)) (S (NP (NP (JJ Many)) (PP (IN of) (NP (DT the) (NNS letters)))) (VP (VBP maintain) (SBAR (IN that) (S (S (NP (NN investor) (NN confidence)) (VP (VBZ has) (VP (VBN been) (VP (ADVP (RB so)) (VBN shaken) (PP (IN by) (NP (DT the) (CD 1987) (NN stock) (NN market) (NN crash))))))) (: --) (CC and) (S (NP (DT the) (NNS markets)) (ADVP (RB already)) (VP (ADVP (RB so)) (VBN stacked) (PP (IN against) (NP (DT the) (JJ little) (NN guy))))) (: --) (SBAR (IN that) (S (NP (NP (DT any) (NN decrease)) (PP (IN in) (NP (NP (NN information)) (PP (IN on) (NP (NN insider-trading) (NNS patterns)))))) (VP (MD might) (VP (VB prompt) (S (NP (NNS individuals)) (VP (TO to) (VP (VB get) (ADVP (RB out) (PP (IN of) (NP (NNS stocks)))) (ADVP (RB altogether)))))))))))) (. .)) (SINV (`` ``) (S (NP (DT The) (NNP SEC)) (VP (VBZ has) (ADVP (RB historically)) (VP (VBN paid) (NP (NN obeisance)) (PP (TO to) (NP (NP (DT the) (NN ideal)) (PP (IN of) (NP (DT a) (JJ level) (NN playing) (NN field)))))))) (, ,) ('' '') (VP (VBD wrote)) (NP (NP (NNP Clyde) (NNP S.) (NNP McGregor)) (PP (IN of) (NP (NP (NNP Winnetka)) (, ,) (NP (NNP Ill.)) (, ,)))) (PP (IN in) (NP (NP (CD one)) (PP (IN of) (NP (NP (DT the) (CD 92) (NNS letters)) (SBAR (S (NP (DT the) (NN agency)) (VP (VBZ has) (VP (VBN received) (SBAR (IN since) (S (NP (DT the) (NNS changes)) (VP (VBD were) (VP (VBN proposed) (NP (NNP Aug.) (CD 17)))))))))))))) (. .)) (S (`` ``) (ADVP (RB Apparently)) (NP (DT the) (NN commission)) (VP (VBD did) (RB not) (ADVP (RB really)) (VP (VB believe) (PP (IN in) (NP (DT this) (NN ideal))))) (. .) ('' '')) (S (ADVP (RB Currently)) (, ,) (NP (DT the) (NNS rules)) (VP (VBP force) (S (NP (NP (NNS executives)) (, ,) (NP (NNS directors)) (CC and) (NP (JJ other) (JJ corporate) (NNS insiders))) (VP (TO to) (VP (VB report) (NP (NP (NNS purchases) (CC and) (NNS sales)) (PP (IN of) (NP (NP (PRP$ their) (NNS companies) (POS ')) (NNS shares)))) (PP (IN within) (NP (NP (QP (IN about) (DT a)) (NN month)) (PP (IN after) (NP (DT the) (NN transaction))))))))) (. .)) (S (CC But) (NP (NP (QP (IN about) (CD 25)) (NN %)) (PP (IN of) (NP (DT the) (NNS insiders)))) (, ,) (PP (VBG according) (PP (TO to) (NP (NNP SEC) (NNS figures)))) (, ,) (VP (VBP file) (NP (PRP$ their) (NNS reports)) (ADVP (RB late))) (. .)) (SINV (S (NP (DT The) (NNS changes)) (VP (VBD were) (VP (VBN proposed) (PP (IN in) (NP (DT an) (NN effort) (S (VP (TO to) (VP (VP (VB streamline) (NP (JJ federal) (NN bureaucracy))) (CC and) (VP (VB boost) (NP (NP (NN compliance)) (PP (IN by) (NP (NP (DT the) (NNS executives)) (`` ``) (SBAR (WHNP (WP who)) (S (VP (VBP are) (ADVP (RB really)) (VP (VBG calling) (NP (DT the) (NNS shots)))))))))))))))))) (, ,) ('' '') (VP (VBD said)) (NP (NP (NNP Brian) (NNP Lane)) (, ,) (NP (NP (JJ special) (NN counsel)) (PP (IN at) (NP (NP (NP (NP (DT the) (NNP SEC) (POS 's)) (NN office)) (PP (IN of) (NP (NN disclosure) (NN policy)))) (, ,) (SBAR (WHNP (WDT which)) (S (VP (VBD proposed) (NP (DT the) (NNS changes))))))))) (. .)) (S (S (S (NP (NP (NNS Investors)) (, ,) (NP (NN money) (NNS managers)) (CC and) (NP (JJ corporate) (NNS officials))) (VP (VBD had) (PP (IN until) (NP (NN today))) (S (VP (TO to) (VP (VB comment) (PP (IN on) (NP (DT the) (NNS proposals)))))))) (, ,) (CC and) (S (NP (DT the) (NN issue)) (VP (VBZ has) (VP (VBN produced) (NP (NP (JJR more) (NN mail)) (PP (IN than) (NP (NP (ADJP (RB almost) (DT any)) (JJ other) (NN issue)) (PP (IN in) (NP (NN memory)))))))))) (, ,) (NP (NNP Mr.) (NNP Lane)) (VP (VBD said)) (. 
.)) ================================================ FILE: data/valid.txt ================================================ (S (NP (NP (DT The) (NN economy) (POS 's)) (NN temperature)) (VP (MD will) (VP (VB be) (VP (VBN taken) (PP (IN from) (NP (JJ several) (NN vantage) (NNS points))) (NP (DT this) (NN week)) (, ,) (PP (IN with) (NP (NP (NNS readings)) (PP (IN on) (NP (NP (NN trade)) (, ,) (NP (NN output)) (, ,) (NP (NN housing)) (CC and) (NP (NN inflation))))))))) (. .)) (S (NP (DT The) (ADJP (RBS most) (JJ troublesome)) (NN report)) (VP (MD may) (VP (VB be) (NP (NP (DT the) (NNP August) (NN merchandise) (NN trade) (NN deficit)) (ADJP (JJ due) (ADVP (IN out)) (NP (NN tomorrow)))))) (. .)) (S (NP (DT The) (NN trade) (NN gap)) (VP (VBZ is) (VP (VBN expected) (S (VP (TO to) (VP (VB widen) (PP (TO to) (NP (QP (IN about) ($ $) (CD 9) (CD billion)))) (PP (IN from) (NP (NP (NNP July) (POS 's)) (QP ($ $) (CD 7.6) (CD billion))))))) (, ,) (PP (VBG according) (PP (TO to) (NP (NP (DT a) (NN survey)) (PP (IN by) (NP (NP (NNP MMS) (NNP International)) (, ,) (NP (NP (DT a) (NN unit)) (PP (IN of) (NP (NP (NNP McGraw-Hill) (NNP Inc.)) (, ,) (NP (NNP New) (NNP York)))))))))))) (. .)) (S (NP (NP (NP (NNP Thursday) (POS 's)) (NN report)) (PP (IN on) (NP (DT the) (NNP September) (NN consumer) (NN price) (NN index)))) (VP (VBZ is) (VP (VBN expected) (S (VP (TO to) (VP (VB rise) (, ,) (SBAR (IN although) (ADVP (ADVP (RB not) (RB as) (RB sharply)) (PP (IN as) (NP (NP (DT the) (ADJP (CD 0.9) (NN %)) (NN gain)) (VP (VBN reported) (NP (NNP Friday)) (PP (IN in) (NP (DT the) (NN producer) (NN price) (NN index))))))))))))) (. .)) (S (NP (DT That) (NN gain)) (VP (VBD was) (VP (VBG being) (VP (VBD cited) (PP (IN as) (NP (NP (DT a) (NN reason)) (SBAR (S (NP (DT the) (NN stock) (NN market)) (VP (VBD was) (ADVP (IN down)) (ADVP (RB early) (PP (IN in) (NP (NP (NNP Friday) (POS 's)) (NN session)))) (, ,) (SBAR (IN before) (S (NP (PRP it)) (VP (VBD got) (S (VP (VBN started) (PP (IN on) (NP (PRP$ its) (JJ reckless) (JJ 190-point) (NN plunge)))))))))))))))) (. .)) (S (NP (NNS Economists)) (VP (VBP are) (VP (VBN divided) (PP (IN as) (PP (TO to) (SBAR (WHNP (WHADVP (WRB how) (JJ much)) (VBG manufacturing) (NN strength)) (S (NP (PRP they)) (VP (VBP expect) (S (VP (TO to) (VP (VB see) (PP (IN in) (NP (NP (NP (NNP September) (NNS reports)) (PP (IN on) (NP (NP (JJ industrial) (NN production)) (CC and) (NP (NN capacity) (NN utilization))))) (, ,) (ADJP (ADVP (RB also)) (JJ due) (NP (NN tomorrow))))))))))))))) (. .)) (S (ADVP (RB Meanwhile)) (, ,) (NP (NP (NNP September) (NN housing) (NNS starts)) (, ,) (ADJP (JJ due) (NP (NNP Wednesday))) (, ,)) (VP (VBP are) (VP (VBN thought) (S (VP (TO to) (VP (VB have) (VP (VBN inched) (ADVP (RB upward)))))))) (. .)) (SINV (S (`` ``) (NP (EX There)) (VP (VBZ 's) (NP (NP (DT a) (NN possibility)) (PP (IN of) (NP (NP (DT a) (NN surprise)) ('' '') (PP (IN in) (NP (DT the) (NN trade) (NN report)))))))) (, ,) (VP (VBD said)) (NP (NP (NNP Michael) (NNP Englund)) (, ,) (NP (NP (NN director)) (PP (IN of) (NP (NN research))) (PP (IN at) (NP (NNP MMS))))) (. .)) (S (S (NP (NP (DT A) (NN widening)) (PP (IN of) (NP (DT the) (NN deficit)))) (, ,) (SBAR (IN if) (S (NP (PRP it)) (VP (VBD were) (VP (VBN combined) (PP (IN with) (NP (DT a) (ADJP (RB stubbornly) (JJ strong)) (NN dollar))))))) (, ,) (VP (MD would) (VP (VB exacerbate) (NP (NN trade) (NNS problems)))) (: --)) (CC but) (S (NP (DT the) (NN dollar)) (VP (VBD weakened) (NP (NNP Friday)) (SBAR (IN as) (S (NP (NNS stocks)) (VP (VBD plummeted)))))) (. 
.)) (S (PP (IN In) (NP (DT any) (NN event))) (, ,) (NP (NP (NNP Mr.) (NNP Englund)) (CC and) (NP (JJ many) (NNS others))) (VP (VBP say) (SBAR (IN that) (S (NP (NP (DT the) (JJ easy) (NNS gains)) (PP (IN in) (S (VP (VBG narrowing) (NP (DT the) (NN trade) (NN gap)))))) (VP (VBP have) (ADVP (RB already)) (VP (VBN been) (VP (VBN made))))))) (. .)) (S (`` ``) (S (NP (NN Trade)) (VP (VBZ is) (ADVP (RB definitely)) (VP (VBG going) (S (VP (TO to) (VP (VB be) (ADJP (RBR more) (RB politically) (JJ sensitive)) (PP (IN over) (NP (DT the) (JJ next) (QP (CD six) (CC or) (CD seven)) (NNS months))) (SBAR (IN as) (S (NP (NN improvement)) (VP (VBZ begins) (S (VP (TO to) (VP (VB slow))))))))))))) (, ,) ('' '') (NP (PRP he)) (VP (VBD said)) (. .)) (S (S (NP (NNS Exports)) (VP (VBP are) (VP (VBN thought) (S (VP (TO to) (VP (VB have) (VP (VBN risen) (ADVP (ADVP (RB strongly) (PP (IN in) (NP (NNP August)))) (, ,) (CC but) (ADVP (ADVP (RB probably)) (RB not) (RB enough) (S (VP (TO to) (VP (VB offset) (NP (NP (DT the) (NN jump)) (PP (IN in) (NP (NNS imports)))))))))))))))) (, ,) (NP (NNS economists)) (VP (VBD said)) (. .)) (S (NP (NP (NNS Views)) (PP (IN on) (NP (VBG manufacturing) (NN strength)))) (VP (VBP are) (ADJP (VBN split) (PP (IN between) (NP (NP (NP (NNS economists)) (SBAR (WHNP (WP who)) (S (VP (VBP read) (NP (NP (NP (NNP September) (POS 's)) (JJ low) (NN level)) (PP (IN of) (NP (NN factory) (NN job) (NN growth)))) (PP (IN as) (NP (NP (DT a) (NN sign)) (PP (IN of) (NP (DT a) (NN slowdown))))))))) (CC and) (NP (NP (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBP use) (NP (DT the) (ADJP (RB somewhat) (JJR more) (VBG comforting)) (JJ total) (NN employment) (NNS figures)) (PP (IN in) (NP (PRP$ their) (NNS calculations))))))))))) (. .)) (S (S (NP (NP (DT The) (JJ wide) (NN range)) (PP (IN of) (NP (NP (NNS estimates)) (PP (IN for) (NP (DT the) (JJ industrial) (NN output) (NN number)))))) (VP (VBZ underscores) (NP (DT the) (NNS differences)))) (: :) (S (NP (DT The) (NNS forecasts)) (VP (VBD run) (PP (IN from) (NP (NP (DT a) (NN drop)) (PP (IN of) (NP (CD 0.5) (NN %))))) (PP (TO to) (NP (NP (DT an) (NN increase)) (PP (IN of) (NP (CD 0.4) (NN %))))) (, ,) (PP (VBG according) (PP (TO to) (NP (NNP MMS)))))) (. .)) (S (NP (NP (DT A) (NN rebound)) (PP (IN in) (NP (NN energy) (NNS prices))) (, ,) (SBAR (WHNP (WDT which)) (S (VP (VBD helped) (VP (VB push) (PRT (RP up)) (NP (DT the) (NN producer) (NN price) (NN index)))))) (, ,)) (VP (VBZ is) (VP (VBN expected) (S (VP (TO to) (VP (VB do) (NP (DT the) (JJ same)) (PP (IN in) (NP (DT the) (NN consumer) (NN price) (NN report)))))))) (. .)) (S (NP (DT The) (NN consensus) (NN view)) (VP (VBZ expects) (NP (NP (DT a) (ADJP (CD 0.4) (NN %)) (NN increase)) (PP (IN in) (NP (DT the) (NNP September) (NNP CPI)))) (PP (IN after) (NP (NP (DT a) (JJ flat) (NN reading)) (PP (IN in) (NP (NNP August)))))) (. .)) (S (NP (NP (NNP Robert) (NNP H.) (NNP Chandross)) (, ,) (NP (NP (DT an) (NN economist)) (PP (IN for) (NP (NP (NP (NNP Lloyd) (POS 's)) (NNP Bank)) (PP (IN in) (NP (NNP New) (NNP York)))))) (, ,)) (VP (VBZ is) (PP (IN among) (NP (NP (DT those)) (VP (VBG expecting) (NP (NP (DT a) (ADJP (RBR more) (JJ moderate)) (NN gain)) (PP (IN in) (NP (DT the) (NNP CPI))) (PP (IN than) (PP (IN in) (NP (NP (NNS prices)) (PP (IN at) (NP (DT the) (NN producer) (NN level))))))))))) (. 
.)) (S (`` ``) (S (S (NP (NN Auto) (NNS prices)) (VP (VBD had) (NP (DT a) (JJ big) (NN effect)) (PP (IN in) (NP (DT the) (NNP PPI))))) (, ,) (CC and) (S (PP (IN at) (NP (DT the) (NNP CPI) (NN level))) (NP (PRP they)) (VP (MD wo) (RB n't)))) (, ,) ('' '') (NP (PRP he)) (VP (VBD said)) (. .)) (SINV (S (S (NP (NN Food) (NNS prices)) (VP (VBP are) (VP (VBN expected) (S (VP (TO to) (VP (VB be) (ADJP (JJ unchanged)))))))) (, ,) (CC but) (S (NP (NN energy) (NNS costs)) (VP (VBD jumped) (NP (NP (RB as) (RB much) (IN as) (CD 4)) (NN %))))) (, ,) (VP (VBD said)) (NP (NP (NNP Gary) (NNP Ciminero)) (, ,) (NP (NP (NN economist)) (PP (IN at) (NP (NNP Fleet\/Norstar) (NNP Financial) (NNP Group))))) (. .)) (S (NP (PRP He)) (ADVP (RB also)) (VP (VBZ says) (SBAR (S (NP (PRP he)) (VP (VBZ thinks) (SBAR (S (NP (`` ``) (NP (NN core) (NN inflation)) (, ,) ('' '') (SBAR (WHNP (WDT which)) (S (VP (VBZ excludes) (NP (DT the) (JJ volatile) (NN food) (CC and) (NN energy) (NNS prices))))) (, ,)) (VP (VBD was) (ADJP (JJ strong)) (NP (JJ last) (NN month))))))))) (. .)) ================================================ FILE: data.py ================================================ #!/usr/bin/env python3 import numpy as np import torch import pickle class Dataset(object): def __init__(self, data_file): data = pickle.load(open(data_file, 'rb')) #get text data self.sents = self._convert(data['source']).long() self.other_data = data['other_data'] self.sent_lengths = self._convert(data['source_l']).long() self.batch_size = self._convert(data['batch_l']).long() self.batch_idx = self._convert(data['batch_idx']).long() self.vocab_size = data['vocab_size'][0] self.num_batches = self.batch_idx.size(0) self.word2idx = data['word2idx'] self.idx2word = data['idx2word'] def _convert(self, x): return torch.from_numpy(np.asarray(x)) def __len__(self): return self.num_batches def __getitem__(self, idx): assert(idx < self.num_batches and idx >= 0) start_idx = self.batch_idx[idx] end_idx = start_idx + self.batch_size[idx] length = self.sent_lengths[idx].item() sents = self.sents[start_idx:end_idx] other_data = self.other_data[start_idx:end_idx] sent_str = [d[0] for d in other_data] tags = [d[1] for d in other_data] actions = [d[2] for d in other_data] binary_tree = [d[3] for d in other_data] spans = [d[5] for d in other_data] batch_size = self.batch_size[idx].item() # by default, we return sents with <s> and </s> tokens # hence we subtract 2 from length as these are (by default) not counted for evaluation data_batch = [sents[:, :length], length-2, batch_size, actions, spans, binary_tree, other_data] return data_batch ================================================ FILE: eval_ppl.py ================================================ #!/usr/bin/env python3 import sys import os import argparse import json import random import shutil import copy import torch from torch import cuda import torch.nn as nn from torch.autograd import Variable from torch.nn.parameter import Parameter import torch.nn.functional as F import numpy as np import time import logging from data import Dataset from models import RNNG from utils import * parser = argparse.ArgumentParser() # Data path options parser.add_argument('--test_file', default='data/ptb-test.pkl') parser.add_argument('--model_file', default='') parser.add_argument('--is_temp', default=2., type=float, help='divide scores by is_temp before CRF') parser.add_argument('--samples', default=1000, type=int, help='samples for IS calculation') parser.add_argument('--count_eos_ppl', default=0, type=int, help='whether to count
eos in val PPL') parser.add_argument('--gpu', default=2, type=int, help='which gpu to use') parser.add_argument('--seed', default=3435, type=int) def main(args): np.random.seed(args.seed) torch.manual_seed(args.seed) data = Dataset(args.test_file) checkpoint = torch.load(args.model_file) model = checkpoint['model'] print("model architecture") print(model) cuda.set_device(args.gpu) model.cuda() model.eval() num_sents = 0 num_words = 0 total_nll_recon = 0. total_kl = 0. total_nll_iwae = 0. samples_batch = 50 S = args.samples // samples_batch samples = S*samples_batch with torch.no_grad(): for i in list(reversed(range(len(data)))): sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i] if length == 1: # length 1 sents are ignored since our generative model requires sents of length >= 2 continue if args.count_eos_ppl == 1: length += 1 else: sents = sents[:, :-1] sents = sents.cuda() ll_word_all2 = [] ll_action_p_all2 = [] ll_action_q_all2 = [] for j in range(S): ll_word_all, ll_action_p_all, ll_action_q_all, actions_all, q_entropy = model( sents, samples = samples_batch, is_temp = args.is_temp, has_eos = args.count_eos_ppl == 1) ll_word_all2.append(ll_word_all.detach().cpu()) ll_action_p_all2.append(ll_action_p_all.detach().cpu()) ll_action_q_all2.append(ll_action_q_all.detach().cpu()) ll_word_all2 = torch.cat(ll_word_all2, 1) ll_action_p_all2 = torch.cat(ll_action_p_all2, 1) ll_action_q_all2 = torch.cat(ll_action_q_all2, 1) sample_ll = torch.zeros(batch_size, ll_word_all2.size(1)) total_nll_recon += -ll_word_all.mean(1).sum().item() total_kl += (ll_action_q_all - ll_action_p_all).mean(1).sum().item() for j in range(sample_ll.size(1)): ll_word_j, ll_action_p_j, ll_action_q_j = ll_word_all2[:, j], ll_action_p_all2[:, j], ll_action_q_all2[:, j] sample_ll[:, j].copy_(ll_word_j + ll_action_p_j - ll_action_q_j) ll_iwae = model.logsumexp(sample_ll, 1) - np.log(samples) total_nll_iwae -= ll_iwae.sum().item() num_sents += batch_size num_words += batch_size * length print('Batch: %d/%d, ElboPPL: %.2f, KL: %.4f, IwaePPL: %.2f' % (i, len(data), np.exp((total_nll_recon + total_kl) / num_words), total_kl / num_sents, np.exp(total_nll_iwae / num_words))) elbo_ppl = np.exp((total_nll_recon + total_kl) / num_words) recon_ppl = np.exp(total_nll_recon / num_words) iwae_ppl = np.exp(total_nll_iwae /num_words) kl = total_kl / num_sents print('ElboPPL: %.2f, ReconPPL: %.2f, KL: %.4f, IwaePPL: %.2f' % (elbo_ppl, recon_ppl, kl, iwae_ppl)) if __name__ == '__main__': args = parser.parse_args() main(args) ================================================ FILE: models.py ================================================ import torch from torch import nn import torch.nn.functional as F import numpy as np from utils import * from TreeCRF import ConstituencyTreeCRF from torch.distributions import Bernoulli class RNNLM(nn.Module): def __init__(self, vocab=10000, w_dim=650, h_dim=650, num_layers=2, dropout=0.5): super(RNNLM, self).__init__() self.h_dim = h_dim self.num_layers = num_layers self.word_vecs = nn.Embedding(vocab, w_dim) self.dropout = nn.Dropout(dropout) self.rnn = nn.LSTM(w_dim, h_dim, num_layers = num_layers, dropout = dropout, batch_first = True) self.vocab_linear = nn.Linear(h_dim, vocab) self.vocab_linear.weight = self.word_vecs.weight # weight sharing def forward(self, sent): word_vecs = self.dropout(self.word_vecs(sent[:, :-1])) h, _ = self.rnn(word_vecs) log_prob = F.log_softmax(self.vocab_linear(self.dropout(h)), 2) # b x l x v ll = torch.gather(log_prob, 2, sent[:, 
1:].unsqueeze(2)).squeeze(2) return ll.sum(1) def generate(self, bos = 2, eos = 3, max_len = 150): x = [] bos = torch.LongTensor(1,1).cuda().fill_(bos) emb = self.dropout(self.word_vecs(bos)) prev_h = None for l in range(max_len): h, prev_h = self.rnn(emb, prev_h) prob = F.softmax(self.vocab_linear(self.dropout(h.squeeze(1))), 1) sample = torch.multinomial(prob, 1) emb = self.dropout(self.word_vecs(sample)) x.append(sample.item()) if x[-1] == eos: x.pop() break return x class SeqLSTM(nn.Module): def __init__(self, i_dim = 200, h_dim = 0, num_layers = 1, dropout = 0): super(SeqLSTM, self).__init__() self.i_dim = i_dim self.h_dim = h_dim self.num_layers = num_layers self.linears = nn.ModuleList([nn.Linear(h_dim + i_dim, h_dim*4) if l == 0 else nn.Linear(h_dim*2, h_dim*4) for l in range(num_layers)]) self.dropout = dropout self.dropout_layer = nn.Dropout(dropout) def forward(self, x, prev_h = None): if prev_h is None: prev_h = [(x.new(x.size(0), self.h_dim).fill_(0), x.new(x.size(0), self.h_dim).fill_(0)) for _ in range(self.num_layers)] curr_h = [] for l in range(self.num_layers): input = x if l == 0 else curr_h[l-1][0] if l > 0 and self.dropout > 0: input = self.dropout_layer(input) concat = torch.cat([input, prev_h[l][0]], 1) all_sum = self.linears[l](concat) i, f, o, g = all_sum.split(self.h_dim, 1) c = F.sigmoid(f)*prev_h[l][1] + F.sigmoid(i)*F.tanh(g) h = F.sigmoid(o)*F.tanh(c) curr_h.append((h, c)) return curr_h class TreeLSTM(nn.Module): def __init__(self, dim = 200): super(TreeLSTM, self).__init__() self.dim = dim self.linear = nn.Linear(dim*2, dim*5) def forward(self, x1, x2, e=None): if not isinstance(x1, tuple): x1 = (x1, None) h1, c1 = x1 if x2 is None: x2 = (torch.zeros_like(h1), torch.zeros_like(h1)) elif not isinstance(x2, tuple): x2 = (x2, None) h2, c2 = x2 if c1 is None: c1 = torch.zeros_like(h1) if c2 is None: c2 = torch.zeros_like(h2) concat = torch.cat([h1, h2], 1) all_sum = self.linear(concat) i, f1, f2, o, g = all_sum.split(self.dim, 1) c = F.sigmoid(f1)*c1 + F.sigmoid(f2)*c2 + F.sigmoid(i)*F.tanh(g) h = F.sigmoid(o)*F.tanh(c) return (h, c) class RNNG(nn.Module): def __init__(self, vocab = 100, w_dim = 20, h_dim = 20, num_layers = 1, dropout = 0, q_dim = 20, max_len = 250): super(RNNG, self).__init__() self.S = 0 #action idx for shift/generate self.R = 1 #action idx for reduce self.emb = nn.Embedding(vocab, w_dim) self.dropout = nn.Dropout(dropout) self.stack_rnn = SeqLSTM(w_dim, h_dim, num_layers = num_layers, dropout = dropout) self.tree_rnn = TreeLSTM(w_dim) self.vocab_mlp = nn.Sequential(nn.Dropout(dropout), nn.Linear(h_dim, vocab)) self.num_layers = num_layers self.q_binary = nn.Sequential(nn.Linear(q_dim*2, q_dim*2), nn.ReLU(), nn.LayerNorm(q_dim*2), nn.Dropout(dropout), nn.Linear(q_dim*2, 1)) self.action_mlp_p = nn.Sequential(nn.Dropout(dropout), nn.Linear(h_dim, 1)) self.w_dim = w_dim self.h_dim = h_dim self.q_dim = q_dim self.q_leaf_rnn = nn.LSTM(w_dim, q_dim, bidirectional = True, batch_first = True) self.q_crf = ConstituencyTreeCRF() self.pad1 = 0 # idx for <pad> token from ptb.dict self.pad2 = 2 # idx for <s> token from ptb.dict self.q_pos_emb = nn.Embedding(max_len, w_dim) # position embeddings self.vocab_mlp[-1].weight = self.emb.weight #share embeddings def get_span_scores(self, x): #produces the span scores s_ij bos = x.new(x.size(0), 1).fill_(self.pad1) eos = x.new(x.size(0), 1).fill_(self.pad2) x = torch.cat([bos, x, eos], 1) x_vec = self.dropout(self.emb(x)) pos = torch.arange(0, x.size(1)).unsqueeze(0).expand_as(x).long().cuda() x_vec = x_vec +
self.dropout(self.q_pos_emb(pos)) q_h, _ = self.q_leaf_rnn(x_vec) fwd = q_h[:, 1:, :self.q_dim] bwd = q_h[:, :-1, self.q_dim:] fwd_diff = fwd[:, 1:].unsqueeze(1) - fwd[:, :-1].unsqueeze(2) bwd_diff = bwd[:, :-1].unsqueeze(2) - bwd[:, 1:].unsqueeze(1) concat = torch.cat([fwd_diff, bwd_diff], 3) scores = self.q_binary(concat).squeeze(3) return scores def get_action_masks(self, actions, length): #this masks out actions so that we don't incur a loss if some actions are deterministic #in practice this doesn't really seem to matter mask = actions.new(actions.size(0), actions.size(1)).fill_(1) for b in range(actions.size(0)): num_shift = 0 stack_len = 0 for l in range(actions.size(1)): if stack_len < 2: mask[b][l].fill_(0) if actions[b][l].item() == self.S: num_shift += 1 stack_len += 1 else: stack_len -= 1 return mask def forward(self, x, samples = 1, is_temp = 1., has_eos=True): #For has_eos: if <eos> exists, then the inference network ignores it. #Note that <eos> is predicted during training since we want the model to know when to stop. #However it is ignored for PPL evaluation on the version of the PTB dataset from #the original RNNG paper (Dyer et al. 2016) init_emb = self.dropout(self.emb(x[:, 0])) x = x[:, 1:] batch, length = x.size(0), x.size(1) if has_eos: parse_length = length - 1 parse_x = x[:, :-1] else: parse_length = length parse_x = x word_vecs = self.dropout(self.emb(x)) scores = self.get_span_scores(parse_x) self.scores = scores scores = scores / is_temp self.q_crf._forward(scores) self.q_crf._entropy(scores) entropy = self.q_crf.entropy[0][parse_length-1] crf_input = scores.unsqueeze(1).expand(batch, samples, parse_length, parse_length) crf_input = crf_input.contiguous().view(batch*samples, parse_length, parse_length) for i in range(len(self.q_crf.alpha)): for j in range(len(self.q_crf.alpha)): self.q_crf.alpha[i][j] = self.q_crf.alpha[i][j].unsqueeze(1).expand( batch, samples).contiguous().view(batch*samples) _, log_probs_action_q, tree_brackets, spans = self.q_crf._sample(crf_input, self.q_crf.alpha) actions = [] for b in range(crf_input.size(0)): action = get_actions(tree_brackets[b]) if has_eos: actions.append(action + [self.S, self.R]) #we train the model to generate <eos> and then do a final reduce else: actions.append(action) actions = torch.Tensor(actions).float().cuda() action_masks = self.get_action_masks(actions, length) num_action = 2*length - 1 batch_expand = batch*samples contexts = [] log_probs_action_p = [] #conditional prior init_emb = init_emb.unsqueeze(1).expand(batch, samples, self.w_dim) init_emb = init_emb.contiguous().view(batch_expand, self.w_dim) init_stack = self.stack_rnn(init_emb, None) x_expand = x.unsqueeze(1).expand(batch, samples, length) x_expand = x_expand.contiguous().view(batch_expand, length) word_vecs = self.dropout(self.emb(x_expand)) word_vecs = word_vecs.unsqueeze(2) word_vecs_zeros = torch.zeros_like(word_vecs) stack = [init_stack] stack_child = [[] for _ in range(batch_expand)] stack2 = [[] for _ in range(batch_expand)] for b in range(batch_expand): stack2[b].append([[init_stack[l][0][b], init_stack[l][1][b]] for l in range(self.num_layers)]) pointer = [0]*batch_expand for l in range(num_action): contexts.append(stack[-1][-1][0]) stack_input = [] child1_h = [] child1_c = [] child2_h = [] child2_c = [] stack_context = [] for b in range(batch_expand): # batch all the shift/reduce operations separately if actions[b][l].item() == self.R: child1 = stack_child[b].pop() child2 = stack_child[b].pop() child1_h.append(child1[0]) child1_c.append(child1[1])
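# (child2's h/c are appended just below) gather both popped children across the batch so all REDUCE compositions at this step go through a single batched TreeLSTM call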
child2_h.append(child2[0]) child2_c.append(child2[1]) stack2[b].pop() stack2[b].pop() if len(child1_h) > 0: child1_h = torch.cat(child1_h, 0) child1_c = torch.cat(child1_c, 0) child2_h = torch.cat(child2_h, 0) child2_c = torch.cat(child2_c, 0) new_child = self.tree_rnn((child1_h, child1_c), (child2_h, child2_c)) child_idx = 0 stack_h = [[[], []] for _ in range(self.num_layers)] for b in range(batch_expand): assert(len(stack2[b]) - 1 == len(stack_child[b])) for k in range(self.num_layers): stack_h[k][0].append(stack2[b][-1][k][0]) stack_h[k][1].append(stack2[b][-1][k][1]) if actions[b][l].item() == self.S: input_b = word_vecs[b][pointer[b]] stack_child[b].append((word_vecs[b][pointer[b]], word_vecs_zeros[b][pointer[b]])) pointer[b] += 1 else: input_b = new_child[0][child_idx].unsqueeze(0) stack_child[b].append((input_b, new_child[1][child_idx].unsqueeze(0))) child_idx += 1 stack_input.append(input_b) stack_input = torch.cat(stack_input, 0) stack_h_all = [] for k in range(self.num_layers): stack_h_all.append((torch.stack(stack_h[k][0], 0), torch.stack(stack_h[k][1], 0))) stack_h = self.stack_rnn(stack_input, stack_h_all) stack.append(stack_h) for b in range(batch_expand): stack2[b].append([[stack_h[k][0][b], stack_h[k][1][b]] for k in range(self.num_layers)]) contexts = torch.stack(contexts, 1) #stack contexts action_logit_p = self.action_mlp_p(contexts).squeeze(2) action_prob_p = F.sigmoid(action_logit_p).clamp(min=1e-7, max=1-1e-7) action_shift_score = (1 - action_prob_p).log() action_reduce_score = action_prob_p.log() action_score = (1-actions)*action_shift_score + actions*action_reduce_score action_score = (action_score*action_masks).sum(1) word_contexts = contexts[actions < 1] word_contexts = word_contexts.contiguous().view(batch*samples, length, self.h_dim) log_probs_word = F.log_softmax(self.vocab_mlp(word_contexts), 2) log_probs_word = torch.gather(log_probs_word, 2, x_expand.unsqueeze(2)).squeeze(2) log_probs_word = log_probs_word.sum(1) log_probs_word = log_probs_word.contiguous().view(batch, samples) log_probs_action_p = action_score.contiguous().view(batch, samples) log_probs_action_q = log_probs_action_q.contiguous().view(batch, samples) actions = actions.contiguous().view(batch, samples, -1) return log_probs_word, log_probs_action_p, log_probs_action_q, actions, entropy def forward_actions(self, x, actions, has_eos=True): # this is for when ground-truth actions are available init_emb = self.dropout(self.emb(x[:, 0])) x = x[:, 1:] if has_eos: new_actions = [] for action in actions: new_actions.append(action + [self.S, self.R]) actions = new_actions batch, length = x.size(0), x.size(1) word_vecs = self.dropout(self.emb(x)) actions = torch.Tensor(actions).float().cuda() action_masks = self.get_action_masks(actions, length) num_action = 2*length - 1 contexts = [] log_probs_action_p = [] #prior init_stack = self.stack_rnn(init_emb, None) word_vecs = word_vecs.unsqueeze(2) word_vecs_zeros = torch.zeros_like(word_vecs) stack = [init_stack] stack_child = [[] for _ in range(batch)] stack2 = [[] for _ in range(batch)] pointer = [0]*batch for b in range(batch): stack2[b].append([[init_stack[l][0][b], init_stack[l][1][b]] for l in range(self.num_layers)]) for l in range(num_action): contexts.append(stack[-1][-1][0]) stack_input = [] child1_h = [] child1_c = [] child2_h = [] child2_c = [] stack_context = [] for b in range(batch): if actions[b][l].item() == self.R: child1 = stack_child[b].pop() child2 = stack_child[b].pop() child1_h.append(child1[0]) child1_c.append(child1[1])
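# same pattern as in forward(): batch the popped children so all REDUCEs at this step share one TreeLSTM composition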
child2_h.append(child2[0]) child2_c.append(child2[1]) stack2[b].pop() stack2[b].pop() if len(child1_h) > 0: child1_h = torch.cat(child1_h, 0) child1_c = torch.cat(child1_c, 0) child2_h = torch.cat(child2_h, 0) child2_c = torch.cat(child2_c, 0) new_child = self.tree_rnn((child1_h, child1_c), (child2_h, child2_c)) child_idx = 0 stack_h = [[[], []] for _ in range(self.num_layers)] for b in range(batch): assert(len(stack2[b]) - 1 == len(stack_child[b])) for k in range(self.num_layers): stack_h[k][0].append(stack2[b][-1][k][0]) stack_h[k][1].append(stack2[b][-1][k][1]) if actions[b][l].item() == self.S: input_b = word_vecs[b][pointer[b]] stack_child[b].append((word_vecs[b][pointer[b]], word_vecs_zeros[b][pointer[b]])) pointer[b] += 1 else: input_b = new_child[0][child_idx].unsqueeze(0) stack_child[b].append((input_b, new_child[1][child_idx].unsqueeze(0))) child_idx += 1 stack_input.append(input_b) stack_input = torch.cat(stack_input, 0) stack_h_all = [] for k in range(self.num_layers): stack_h_all.append((torch.stack(stack_h[k][0], 0), torch.stack(stack_h[k][1], 0))) stack_h = self.stack_rnn(stack_input, stack_h_all) stack.append(stack_h) for b in range(batch): stack2[b].append([[stack_h[k][0][b], stack_h[k][1][b]] for k in range(self.num_layers)]) contexts = torch.stack(contexts, 1) action_logit_p = self.action_mlp_p(contexts).squeeze(2) action_prob_p = F.sigmoid(action_logit_p).clamp(min=1e-7, max=1-1e-7) action_shift_score = (1 - action_prob_p).log() action_reduce_score = action_prob_p.log() action_score = (1-actions)*action_shift_score + actions*action_reduce_score action_score = (action_score*action_masks).sum(1) word_contexts = contexts[actions < 1] word_contexts = word_contexts.contiguous().view(batch, length, self.h_dim) log_probs_word = F.log_softmax(self.vocab_mlp(word_contexts), 2) log_probs_word = torch.gather(log_probs_word, 2, x.unsqueeze(2)).squeeze(2).sum(1) log_probs_action_p = action_score.contiguous().view(batch) actions = actions.contiguous().view(batch, 1, -1) return log_probs_word, log_probs_action_p, actions def forward_tree(self, x, actions, has_eos=True): # this is log q( tree | x) for discriminative parser training in supervised RNNG init_emb = self.dropout(self.emb(x[:, 0])) x = x[:, 1:-1] batch, length = x.size(0), x.size(1) scores = self.get_span_scores(x) crf_input = scores gold_spans = scores.new(batch, length, length) for b in range(batch): gold_spans[b].copy_(torch.eye(length).cuda()) spans = get_spans(actions[b]) for span in spans: gold_spans[b][span[0]][span[1]] = 1 self.q_crf._forward(crf_input) log_Z = self.q_crf.alpha[0][length-1] span_scores = (gold_spans*scores).sum(2).sum(1) ll_action_q = span_scores - log_Z return ll_action_q def logsumexp(self, x, dim=1): d = torch.max(x, dim)[0] if x.dim() == 1: return torch.log(torch.exp(x - d).sum(dim)) + d else: return torch.log(torch.exp(x - d.unsqueeze(dim).expand_as(x)).sum(dim)) + d ================================================ FILE: parse.py ================================================ #!/usr/bin/env python3 import sys import os import argparse import json import random import shutil import copy import torch from torch import cuda import torch.nn as nn import numpy as np import time from utils import * import utils import re parser = argparse.ArgumentParser() # Data path options parser.add_argument('--data_file', default='ptb-test.txt') parser.add_argument('--model_file', default='urnng.pt') parser.add_argument('--out_file', default='pred-parse.txt') parser.add_argument('--gold_out_file', 
default='gold-parse.txt') parser.add_argument('--lowercase', type=int, default=0) parser.add_argument('--replace_num', type=int, default=0) # Inference options parser.add_argument('--gpu', default=0, type=int, help='which gpu to use') def is_next_open_bracket(line, start_idx): for char in line[(start_idx + 1):]: if char == '(': return True elif char == ')': return False raise IndexError('Bracket possibly not balanced, open bracket not followed by closed bracket') def get_between_brackets(line, start_idx): output = [] for char in line[(start_idx + 1):]: if char == ')': break assert not(char == '(') output.append(char) return ''.join(output) def get_tags_tokens_lowercase(line): output = [] line_strip = line.rstrip() for i in range(len(line_strip)): if i == 0: assert line_strip[i] == '(' if line_strip[i] == '(' and not(is_next_open_bracket(line_strip, i)): # fulfilling this condition means this is a terminal symbol output.append(get_between_brackets(line_strip, i)) #print 'output:',output output_tags = [] output_tokens = [] output_lowercase = [] for terminal in output: terminal_split = terminal.split() assert len(terminal_split) == 2 # each terminal contains a POS tag and word output_tags.append(terminal_split[0]) output_tokens.append(terminal_split[1]) output_lowercase.append(terminal_split[1].lower()) return [output_tags, output_tokens, output_lowercase] def get_nonterminal(line, start_idx): assert line[start_idx] == '(' # make sure it's an open bracket output = [] for char in line[(start_idx + 1):]: if char == ' ': break assert not(char == '(') and not(char == ')') output.append(char) return ''.join(output) def get_actions(line): output_actions = [] line_strip = line.rstrip() i = 0 max_idx = (len(line_strip) - 1) while i <= max_idx: assert line_strip[i] == '(' or line_strip[i] == ')' if line_strip[i] == '(': if is_next_open_bracket(line_strip, i): # open non-terminal curr_NT = get_nonterminal(line_strip, i) output_actions.append('NT(' + curr_NT + ')') i += 1 while line_strip[i] != '(': # get the next open bracket, which may be a terminal or another non-terminal i += 1 else: # it's a terminal symbol output_actions.append('SHIFT') while line_strip[i] != ')': i += 1 i += 1 while line_strip[i] != ')' and line_strip[i] != '(': i += 1 else: output_actions.append('REDUCE') if i == max_idx: break i += 1 while line_strip[i] != ')' and line_strip[i] != '(': i += 1 assert i == max_idx return output_actions def clean_number(w): new_w = re.sub('[0-9]{1,}([,.]?[0-9]*)*', 'N', w) return new_w def main(args): print('loading model from ' + args.model_file) checkpoint = torch.load(args.model_file) model = checkpoint['model'] word2idx = checkpoint['word2idx'] cuda.set_device(args.gpu) model.eval() model.cuda() corpus_f1 = [0., 0., 0.] sent_f1 = [] pred_out = open(args.out_file, "w") gold_out = open(args.gold_out_file, "w") with torch.no_grad(): for j, gold_tree in enumerate(open(args.data_file, "r")): tree = gold_tree.strip() action = get_actions(tree) tags, sent, sent_lower = get_tags_tokens_lowercase(tree) sent_orig = sent[::] if args.lowercase == 1: sent = sent_lower gold_span, binary_actions, nonbinary_actions = get_nonbinary_spans(action) length = len(sent) if args.replace_num == 1: sent = [clean_number(w) for w in sent] if length == 1: continue # we ignore length 1 sents. 
                # this doesn't change F1 since trivial spans are discarded anyway
                continue
            sent_idx = [word2idx["<s>"]] + [word2idx[w] if w in word2idx else word2idx["<unk>"] for w in sent]
            sents = torch.from_numpy(np.array(sent_idx)).unsqueeze(0)
            sents = sents.cuda()
            ll_word_all, ll_action_p_all, ll_action_q_all, actions_all, q_entropy = model(
                sents, samples=1, is_temp=1, has_eos=False)
            _, binary_matrix, argmax_spans = model.q_crf._viterbi(model.scores)
            tree = get_tree_from_binary_matrix(binary_matrix[0], len(sent))
            actions = utils.get_actions(tree)
            pred_span = [(a[0], a[1]) for a in argmax_spans[0]]
            pred_span_set = set(pred_span[:-1])  # the last span in the list is always the
            gold_span_set = set(gold_span[:-1])  # trivial sent-level span, so we ignore it
            tp, fp, fn = get_stats(pred_span_set, gold_span_set)
            corpus_f1[0] += tp
            corpus_f1[1] += fp
            corpus_f1[2] += fn
            binary_matrix = binary_matrix[0].cpu().numpy()
            pred_tree = {}
            for i in range(length):
                tag = tags[i]  # need gold tags so evalb correctly ignores punctuation
                pred_tree[i] = "(" + tag + " " + sent_orig[i] + ")"
            for k in np.arange(1, length):
                for s in np.arange(length):
                    t = s + k
                    if t > length - 1:
                        break
                    if binary_matrix[s][t] == 1:
                        nt = "NT-1"
                        span = "(" + nt + " " + pred_tree[s] + " " + pred_tree[t] + ")"
                        pred_tree[s] = span
                        pred_tree[t] = span
            pred_tree = pred_tree[0]
            pred_out.write(pred_tree.strip() + "\n")
            gold_out.write(gold_tree.strip() + "\n")
            print(pred_tree)
            # sent-level F1 is based on L83-89 from https://github.com/yikangshen/PRPN/test_phrase_grammar.py
            overlap = pred_span_set.intersection(gold_span_set)
            prec = float(len(overlap)) / (len(pred_span_set) + 1e-8)
            reca = float(len(overlap)) / (len(gold_span_set) + 1e-8)
            if len(gold_span_set) == 0:
                reca = 1.
                if len(pred_span_set) == 0:
                    prec = 1.
            f1 = 2 * prec * reca / (prec + reca + 1e-8)
            sent_f1.append(f1)
    pred_out.close()
    gold_out.close()
    tp, fp, fn = corpus_f1
    prec = tp / (tp + fp)
    recall = tp / (tp + fn)
    corpus_f1 = 2*prec*recall/(prec + recall) if prec + recall > 0 else 0.
    print('Corpus F1: %.2f, Sentence F1: %.2f' %
          (corpus_f1*100, np.mean(np.array(sent_f1))*100))

if __name__ == '__main__':
    args = parser.parse_args()
    main(args)
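The bracket-walking helpers in `parse.py` are easiest to check on a tiny example. A hypothetical sanity check (the tree below is made up; it assumes the repo's dependencies are installed so that `parse` is importable):

```python
from parse import get_actions, get_tags_tokens_lowercase

tree = "(S (NP (DT the) (NN dog)) (VP (VBZ barks)))"
print(get_actions(tree))
# ['NT(S)', 'NT(NP)', 'SHIFT', 'SHIFT', 'REDUCE', 'NT(VP)', 'SHIFT', 'REDUCE', 'REDUCE']
print(get_tags_tokens_lowercase(tree))
# [['DT', 'NN', 'VBZ'], ['the', 'dog', 'barks'], ['the', 'dog', 'barks']]
```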
================================================
FILE: preprocess.py
================================================

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Create data files"""

import os
import sys
import argparse
import numpy as np
import pickle
import itertools
from collections import defaultdict
import utils
import re

class Indexer:
    def __init__(self, symbols=["<pad>", "<unk>", "<s>", "</s>"]):
        self.vocab = defaultdict(int)
        self.PAD = symbols[0]
        self.UNK = symbols[1]
        self.BOS = symbols[2]
        self.EOS = symbols[3]
        self.d = {self.PAD: 0, self.UNK: 1, self.BOS: 2, self.EOS: 3}
        self.idx2word = {}

    def add_w(self, ws):
        for w in ws:
            if w not in self.d:
                self.d[w] = len(self.d)

    def convert(self, w):
        return self.d[w] if w in self.d else self.d[self.UNK]

    def convert_sequence(self, ls):
        return [self.convert(l) for l in ls]

    def write(self, outfile):
        out = open(outfile, "w")
        items = [(v, k) for k, v in self.d.items()]
        items.sort()
        for v, k in items:
            out.write(" ".join([k, str(v)]) + "\n")
        out.close()

    def prune_vocab(self, k, cnt=False):
        vocab_list = [(word, count) for word, count in self.vocab.items()]
        if cnt:
            self.pruned_vocab = {pair[0]: pair[1] for pair in vocab_list if pair[1] > k}
        else:
            vocab_list.sort(key=lambda x: x[1], reverse=True)
            k = min(k, len(vocab_list))
            self.pruned_vocab = {pair[0]: pair[1] for pair in vocab_list[:k]}
        for word in self.pruned_vocab:
            if word not in self.d:
                self.d[word] = len(self.d)
        for word, idx in self.d.items():
            self.idx2word[idx] = word

    def load_vocab(self, vocab_file):
        self.d = {}
        self.idx2word = {}
        for line in open(vocab_file, 'r'):
            v, k = line.strip().split()
            self.d[v] = int(k)
        for word, idx in self.d.items():
            self.idx2word[idx] = word

def is_next_open_bracket(line, start_idx):
    for char in line[(start_idx + 1):]:
        if char == '(':
            return True
        elif char == ')':
            return False
    raise IndexError('Bracket possibly not balanced, open bracket not followed by closed bracket')

def get_between_brackets(line, start_idx):
    output = []
    for char in line[(start_idx + 1):]:
        if char == ')':
            break
        assert not(char == '(')
        output.append(char)
    return ''.join(output)

def get_tags_tokens_lowercase(line):
    output = []
    line_strip = line.rstrip()
    for i in range(len(line_strip)):
        if i == 0:
            assert line_strip[i] == '('
        if line_strip[i] == '(' and not(is_next_open_bracket(line_strip, i)):
            # fulfilling this condition means this is a terminal symbol
            output.append(get_between_brackets(line_strip, i))
    output_tags = []
    output_tokens = []
    output_lowercase = []
    for terminal in output:
        terminal_split = terminal.split()
        assert len(terminal_split) == 2  # each terminal contains a POS tag and a word
        output_tags.append(terminal_split[0])
        output_tokens.append(terminal_split[1])
        output_lowercase.append(terminal_split[1].lower())
    return [output_tags, output_tokens, output_lowercase]

def get_nonterminal(line, start_idx):
    assert line[start_idx] == '('  # make sure it's an open bracket
    output = []
    for char in line[(start_idx + 1):]:
        if char == ' ':
            break
        assert not(char == '(') and not(char == ')')
        output.append(char)
    return ''.join(output)

def get_actions(line):
    output_actions = []
    line_strip = line.rstrip()
    i = 0
    max_idx = (len(line_strip) - 1)
    while i <= max_idx:
        assert line_strip[i] == '(' or line_strip[i] == ')'
        if line_strip[i] == '(':
            if is_next_open_bracket(line_strip, i):  # open non-terminal
                curr_NT = get_nonterminal(line_strip, i)
                output_actions.append('NT(' + curr_NT + ')')
                i += 1
                while line_strip[i] != '(':
                    # get the next open bracket, which may be a terminal or another non-terminal
                    i += 1
            else:  # it's a terminal symbol
                output_actions.append('SHIFT')
                while line_strip[i] != ')':
                    i += 1
                i += 1
                while line_strip[i] != ')' and line_strip[i] != '(':
                    i += 1
        else:
            output_actions.append('REDUCE')
            if i == max_idx:
                break
            i += 1
            while line_strip[i] != ')' and line_strip[i] != '(':
                i += 1
    assert i == max_idx
    return output_actions

def pad(ls, length, symbol):
    if len(ls) >= length:
        return ls[:length]
    return ls + [symbol] * (length - len(ls))

def clean_number(w):
    new_w = re.sub('[0-9]{1,}([,.]?[0-9]*)*', 'N', w)
    return new_w

def get_data(args):
    indexer = Indexer(["<pad>", "<unk>", "<s>", "</s>"])

    def make_vocab(textfile, seqlength, minseqlength, lowercase, replace_num,
                   train=1, apply_length_filter=1):
        num_sents = 0
        max_seqlength = 0
        for tree in open(textfile, 'r'):
            tree = tree.strip()
            tags, sent, sent_lower = get_tags_tokens_lowercase(tree)
            assert(len(tags) == len(sent))
            if lowercase == 1:
                sent = sent_lower
            if replace_num == 1:
                sent = [clean_number(w) for w in sent]
            if (len(sent) > seqlength and apply_length_filter == 1) or len(sent) < minseqlength:
                continue
            num_sents += 1
            max_seqlength = max(max_seqlength, len(sent))
            if train == 1:
                for word in sent:
                    indexer.vocab[word] += 1
        return num_sents, max_seqlength

    def convert(textfile, lowercase, replace_num, batchsize, seqlength, minseqlength,
                outfile, num_sents, max_sent_l=0, shuffle=0, include_boundary=1,
                apply_length_filter=1):
        newseqlength = seqlength
        if include_boundary == 1:
            newseqlength += 2  # add 2 for BOS and EOS
        sents = np.zeros((num_sents, newseqlength), dtype=int)
        sent_lengths = np.zeros((num_sents,), dtype=int)
        dropped = 0
        sent_id = 0
        other_data = []
        for tree in open(textfile, 'r'):
            tree = tree.strip()
            action = get_actions(tree)
            tags, sent, sent_lower = get_tags_tokens_lowercase(tree)
            assert(len(tags) == len(sent))
            if lowercase == 1:
                sent = sent_lower
            if (len(sent) > seqlength and apply_length_filter == 1) or len(sent) < minseqlength:
                continue
            sent_str = " ".join(sent)
            if replace_num == 1:
                sent = [clean_number(w) for w in sent]
            if include_boundary == 1:
                sent = [indexer.BOS] + sent + [indexer.EOS]
            max_sent_l = max(len(sent), max_sent_l)
            sent_pad = pad(sent, newseqlength, indexer.PAD)
            sents[sent_id] = np.array(indexer.convert_sequence(sent_pad), dtype=int)
            sent_lengths[sent_id] = (sents[sent_id] != 0).sum()
            span, binary_actions, nonbinary_actions = utils.get_nonbinary_spans(action)
            other_data.append([sent_str, tags, action, binary_actions,
                               nonbinary_actions, span, tree])
            assert(2*(len(sent) - 2) - 1 == len(binary_actions))
            assert(sum(binary_actions) + 1 == len(sent) - 2)
            sent_id += 1
            if sent_id % 100000 == 0:
                print("{}/{} sentences processed".format(sent_id, num_sents))
        print(sent_id, num_sents)
        if shuffle == 1:
            rand_idx = np.random.permutation(sent_id)
            sents = sents[rand_idx]
            sent_lengths = sent_lengths[rand_idx]
            other_data = [other_data[idx] for idx in rand_idx]
        print(len(sents), len(other_data))
        # break up batches based on source lengths
        sent_lengths = sent_lengths[:sent_id]
        sent_sort = np.argsort(sent_lengths)
        sents = sents[sent_sort]
        other_data = [other_data[idx] for idx in sent_sort]
        sent_l = sent_lengths[sent_sort]
        curr_l = 1
        l_location = []  # idx where sent length changes
        for j, i in enumerate(sent_sort):
            if sent_lengths[i] > curr_l:
                curr_l = sent_lengths[i]
                l_location.append(j)
        l_location.append(len(sents))
        # get batch sizes
        curr_idx = 0
        batch_idx = [0]
        nonzeros = []
        batch_l = []
        batch_w = []
        for i in range(len(l_location) - 1):
            while curr_idx < l_location[i+1]:
                curr_idx = min(curr_idx + batchsize, l_location[i+1])
                batch_idx.append(curr_idx)
        for i in range(len(batch_idx) - 1):
            batch_l.append(batch_idx[i+1] - batch_idx[i])
            batch_w.append(sent_l[batch_idx[i]])
        # Write output
        f = {}
        f["source"] = sents
        f["other_data"] = other_data
        f["batch_l"] = np.array(batch_l, dtype=int)
        f["source_l"] = np.array(batch_w, dtype=int)
        f["sents_l"] = np.array(sent_l, dtype=int)
        f["batch_idx"] = np.array(batch_idx[:-1], dtype=int)
        f["vocab_size"] = np.array([len(indexer.d)])
        f["idx2word"] = indexer.idx2word
        f["word2idx"] = {word: idx for idx, word in indexer.idx2word.items()}
        print("Saved {} sentences (dropped {} due to length/unk filter)".format(
            len(f["source"]), dropped))
        pickle.dump(f, open(outfile, 'wb'))
        return max_sent_l

    print("First pass through data to get vocab...")
    num_sents_train, train_seqlength = make_vocab(args.trainfile, args.seqlength,
                                                  args.minseqlength, args.lowercase,
                                                  args.replace_num, 1, 1)
    print("Number of sentences in training: {}".format(num_sents_train))
    num_sents_valid, valid_seqlength = make_vocab(args.valfile, args.seqlength,
                                                  args.minseqlength, args.lowercase,
                                                  args.replace_num, 0, 0)
    print("Number of sentences in valid: {}".format(num_sents_valid))
    num_sents_test, test_seqlength = make_vocab(args.testfile, args.seqlength,
                                                args.minseqlength, args.lowercase,
                                                args.replace_num, 0, 0)
    print("Number of sentences in test: {}".format(num_sents_test))
    if args.vocabminfreq >= 0:
        indexer.prune_vocab(args.vocabminfreq, True)
    else:
        indexer.prune_vocab(args.vocabsize, False)
    if args.vocabfile != '':
        print('Loading pre-specified source vocab from ' + args.vocabfile)
        indexer.load_vocab(args.vocabfile)
    indexer.write(args.outputfile + ".dict")
    print("Vocab size: Original = {}, Pruned = {}".format(len(indexer.vocab), len(indexer.d)))
    print(train_seqlength, valid_seqlength, test_seqlength)
    max_sent_l = 0
    max_sent_l = convert(args.testfile, args.lowercase, args.replace_num, args.batchsize,
                         test_seqlength, args.minseqlength, args.outputfile + "-test.pkl",
                         num_sents_test, max_sent_l, args.shuffle, args.include_boundary, 0)
    max_sent_l = convert(args.valfile, args.lowercase, args.replace_num, args.batchsize,
                         valid_seqlength, args.minseqlength, args.outputfile + "-val.pkl",
                         num_sents_valid, max_sent_l, args.shuffle, args.include_boundary, 0)
    max_sent_l = convert(args.trainfile, args.lowercase, args.replace_num, args.batchsize,
                         args.seqlength, args.minseqlength, args.outputfile + "-train.pkl",
                         num_sents_train, max_sent_l, args.shuffle, args.include_boundary, 1)
    print("Max sent length (before dropping): {}".format(max_sent_l))

def main(arguments):
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--vocabsize', help="Size of source vocabulary, constructed "
                                            "by taking the top X most frequent words. "
                                            "Rest are replaced with special UNK tokens.",
                        type=int, default=10000)
    parser.add_argument('--vocabminfreq', help="Minimum frequency for vocab. "
                                               "Use this instead of vocabsize if > 0",
                        type=int, default=1)
    parser.add_argument('--include_boundary', help="Add BOS/EOS tokens", type=int, default=1)
    parser.add_argument('--lowercase', help="Lower case", type=int, default=0)
    parser.add_argument('--replace_num', help="Replace numbers with N", type=int, default=0)
    parser.add_argument('--trainfile', help="Path to training data.", required=True)
    parser.add_argument('--valfile', help="Path to validation data.", required=True)
    parser.add_argument('--testfile', help="Path to test data.", required=True)
    parser.add_argument('--batchsize', help="Size of each minibatch.", type=int, default=16)
    parser.add_argument('--seqlength', help="Maximum sequence length. Sequences longer "
                                            "than this are dropped.", type=int, default=200)
    parser.add_argument('--minseqlength', help="Minimum sequence length. Sequences shorter "
                                               "than this are dropped.", type=int, default=0)
    parser.add_argument('--outputfile', help="Prefix of the output file names.",
                        type=str, required=True)
    parser.add_argument('--vocabfile', help="If working with a preset vocab, "
                                            "then including this will ignore vocabsize "
                                            "and use the vocab provided here.",
                        type=str, default='')
    parser.add_argument('--shuffle', help="If = 1, shuffle sentences before sorting "
                                          "(based on source length).", type=int, default=1)
    args = parser.parse_args(arguments)
    np.random.seed(3435)
    get_data(args)

if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
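For reference, a rough sketch of what the pickled output of `convert` contains (field names are taken from the code above; the path assumes the README's preprocessing command was run with `--outputfile data/ptb`):

```python
import pickle

data = pickle.load(open("data/ptb-train.pkl", "rb"))
print(data["vocab_size"])    # array([V]), vocab size including <pad>/<unk>/<s>/</s>
print(data["source"].shape)  # (num_sents, max_len + 2), padded word indices with BOS/EOS
# each other_data entry mirrors the list built in convert():
sent_str, tags, action, binary_actions, nonbinary_actions, span, tree = data["other_data"][0]
print(binary_actions[:10])   # binarized action sequence, 0 = SHIFT, 1 = REDUCE
```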
================================================
FILE: train.py
================================================

#!/usr/bin/env python3
import sys
import os
import argparse
import json
import random
import shutil
import copy
import torch
from torch import cuda
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import numpy as np
import time
import logging
from data import Dataset
from models import RNNG
from utils import *

parser = argparse.ArgumentParser()

# Data path options
parser.add_argument('--train_file', default='data/ptb-1unk-train.pkl')
parser.add_argument('--val_file', default='data/ptb-1unk-val.pkl')
parser.add_argument('--train_from', default='')
# Model options
parser.add_argument('--w_dim', default=650, type=int, help='word embedding dimension for LM/RNNG')
parser.add_argument('--h_dim', default=650, type=int, help='hidden dimension for LM/RNNG')
parser.add_argument('--q_dim', default=256, type=int, help='hidden dimension for variational RNN')
parser.add_argument('--num_layers', default=2, type=int,
                    help='number of layers in LM and the stack LSTM (for RNNG)')
parser.add_argument('--dropout', default=0.5, type=float, help='dropout rate')
# Optimization options
parser.add_argument('--count_eos_ppl', default=0, type=int, help='whether to count eos in val PPL')
parser.add_argument('--save_path', default='urnng.pt', help='where to save the model checkpoint')
parser.add_argument('--num_epochs', default=18, type=int, help='number of training epochs')
parser.add_argument('--min_epochs', default=8, type=int,
                    help='do not decay learning rate for at least this many epochs')
parser.add_argument('--mode', default='unsupervised', type=str,
                    choices=['unsupervised', 'supervised'])
parser.add_argument('--mc_samples', default=5, type=int,
                    help='how many samples for the IWAE bound calculation at evaluation')
parser.add_argument('--samples', default=8, type=int,
                    help='how many samples for score function gradients')
parser.add_argument('--lr', default=1, type=float, help='starting learning rate')
parser.add_argument('--q_lr', default=0.0001, type=float, help='learning rate for inference network q')
parser.add_argument('--action_lr', default=0.1, type=float, help='learning rate for action layer')
parser.add_argument('--decay', default=0.5, type=float, help='learning rate decay factor')
parser.add_argument('--kl_warmup', default=2, type=int,
                    help='number of epochs over which to anneal the KL penalty')
parser.add_argument('--train_q_epochs', default=2, type=int,
                    help='stop training the inference network q after this many epochs')
parser.add_argument('--param_init', default=0.1, type=float,
                    help='parameter initialization (over uniform)')
parser.add_argument('--max_grad_norm', default=5, type=float, help='gradient clipping parameter')
parser.add_argument('--q_max_grad_norm', default=1, type=float, help='gradient clipping parameter for q')
parser.add_argument('--gpu', default=2, type=int, help='which gpu to use')
parser.add_argument('--seed', default=3435, type=int, help='random seed')
parser.add_argument('--print_every', type=int, default=500, help='print stats after this many batches')

def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_data = Dataset(args.train_file)
    val_data = Dataset(args.val_file)
    vocab_size = int(train_data.vocab_size)
    print('Train: %d sents / %d batches, Val: %d sents / %d batches' %
          (train_data.sents.size(0), len(train_data), val_data.sents.size(0), len(val_data)))
    print('Vocab size: %d' % vocab_size)
    cuda.set_device(args.gpu)
    if args.train_from == '':
        model = RNNG(vocab=vocab_size,
                     w_dim=args.w_dim,
                     h_dim=args.h_dim,
                     dropout=args.dropout,
                     num_layers=args.num_layers,
                     q_dim=args.q_dim)
        if args.param_init > 0:
            for param in model.parameters():
                param.data.uniform_(-args.param_init, args.param_init)
    else:
        print('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']
    print("model architecture")
    print(model)
    q_params = []
    action_params = []
    model_params = []
    for name, param in model.named_parameters():
        if 'action' in name:
            print(name)
            action_params.append(param)
        elif 'q_' in name:
            print(name)
            q_params.append(param)
        else:
            model_params.append(param)
    q_lr = args.q_lr
    optimizer = torch.optim.SGD(model_params, lr=args.lr)
    q_optimizer = torch.optim.Adam(q_params, lr=q_lr)
    action_optimizer = torch.optim.SGD(action_params, lr=args.action_lr)
    model.train()
    model.cuda()
    epoch = 0
    decay = 0
    if args.kl_warmup > 0:
        kl_pen = 0.
        kl_warmup_batch = 1. / (args.kl_warmup * len(train_data))
    else:
        kl_pen = 1.
    best_val_ppl = 5e5
    best_val_f1 = 0
    samples = args.samples
    best_val_ppl, best_val_f1 = eval(val_data, model, samples=args.mc_samples,
                                     count_eos_ppl=args.count_eos_ppl)
    all_stats = [[0., 0., 0.]]  # true pos, false pos, false neg for F1 calc
    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        if epoch > args.train_q_epochs:  # stop training q after this many epochs
            args.q_lr = 0.
            for param_group in q_optimizer.param_groups:
                param_group['lr'] = args.q_lr
        print('Starting epoch %d' % epoch)
        train_nll_recon = 0.
        train_nll_iwae = 0.
        train_kl = 0.
        train_q_entropy = 0.
        num_sents = 0.
        num_words = 0.
        b = 0
        for i in np.random.permutation(len(train_data)):
            if args.kl_warmup > 0:
                kl_pen = min(1., kl_pen + kl_warmup_batch)
            sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = train_data[i]
            if length == 1:
                # we ignore length-1 sents during training/eval since we work with binary trees only
                continue
            sents = sents.cuda()
            b += 1
            q_optimizer.zero_grad()
            optimizer.zero_grad()
            action_optimizer.zero_grad()
            if args.mode == 'unsupervised':
                ll_word, ll_action_p, ll_action_q, all_actions, q_entropy = model(
                    sents, samples=samples, has_eos=True)
                log_f = ll_word + kl_pen*ll_action_p
                iwae_ll = log_f.mean(1).detach() + kl_pen*q_entropy.detach()
                obj = log_f.mean(1)
                if epoch < args.train_q_epochs:
                    obj += kl_pen*q_entropy
                # leave-one-out baseline for the score function gradient: the baseline for
                # sample k is the average of log_f over the other (samples - 1) samples
                baseline = torch.zeros_like(log_f)
                baseline_k = torch.zeros_like(log_f)
                for k in range(samples):
                    baseline_k.copy_(log_f)
                    baseline_k[:, k].fill_(0)
                    baseline[:, k] = baseline_k.detach().sum(1) / (samples - 1)
                obj += ((log_f.detach() - baseline.detach())*ll_action_q).mean(1)
                kl = (ll_action_q - ll_action_p).mean(1).detach()
                ll_word = ll_word.mean(1)
                train_q_entropy += q_entropy.sum().item()
            else:
                gold_actions = gold_binary_trees
                ll_action_q = model.forward_tree(sents, gold_actions, has_eos=True)
                ll_word, ll_action_p, all_actions = model.forward_actions(sents, gold_actions)
                obj = ll_word + ll_action_p + ll_action_q
                kl = -ll_action_q
                iwae_ll = ll_word + ll_action_p
            train_nll_iwae += -iwae_ll.sum().item()
            actions = all_actions[:, 0].long().cpu()
            train_nll_recon += -ll_word.sum().item()
            train_kl += kl.sum().item()
            (-obj.mean()).backward()
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model_params + action_params, args.max_grad_norm)
            if args.q_max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(q_params, args.q_max_grad_norm)
            q_optimizer.step()
            optimizer.step()
            action_optimizer.step()
            num_sents += batch_size
            num_words += batch_size * length
            for bb in range(batch_size):
                action = list(actions[bb].numpy())
                span_b = get_spans(action)
                span_b_set = set(span_b[:-1])  # ignore the sentence-level trivial span
                update_stats(span_b_set, [set(gold_spans[bb][:-1])], all_stats)
            if b % args.print_every == 0:
                all_f1 = get_f1(all_stats)
                param_norm = sum([p.norm()**2 for p in model.parameters()]).item()**0.5
                log_str = 'Epoch: %d, Batch: %d/%d, LR: %.4f, qLR: %.5f, qEnt: %.4f, TrainVAEPPL: %.2f, ' + \
                          'TrainReconPPL: %.2f, TrainKL: %.2f, TrainIWAEPPL: %.2f, ' + \
                          '|Param|: %.2f, BestValPerf: %.2f, BestValF1: %.2f, KLPen: %.4f, ' + \
                          'GoldTreeF1: %.2f, Throughput: %.2f examples/sec'
                print(log_str % (epoch, b, len(train_data), args.lr, args.q_lr,
                                 train_q_entropy / num_sents,
                                 np.exp((train_nll_recon + train_kl) / num_words),
                                 np.exp(train_nll_recon / num_words),
                                 train_kl / num_sents,
                                 np.exp(train_nll_iwae / num_words),
                                 param_norm, best_val_ppl, best_val_f1, kl_pen,
                                 all_f1[0], num_sents / (time.time() - start_time)))
                sent_str = [train_data.idx2word[word_idx] for word_idx in list(sents[-1][1:-1].cpu().numpy())]
                print("PRED:", get_tree(action[:-2], sent_str))
                print("GOLD:", get_tree(gold_binary_trees[-1], sent_str))
        print('--------------------------------')
        print('Checking validation perf...')
        val_ppl, val_f1 = eval(val_data, model, samples=args.mc_samples,
                               count_eos_ppl=args.count_eos_ppl)
        print('--------------------------------')
        if val_ppl < best_val_ppl:
            best_val_ppl = val_ppl
            best_val_f1 = val_f1
            checkpoint = {
                'args': args.__dict__,
                'model': model.cpu(),
                'word2idx': train_data.word2idx,
                'idx2word': train_data.idx2word
            }
            print('Saving checkpoint to %s' % args.save_path)
            torch.save(checkpoint, args.save_path)
            model.cuda()
        else:
            if epoch > args.min_epochs:
                decay = 1
        if decay == 1:
            args.lr = args.decay*args.lr
            args.q_lr = args.decay*args.q_lr
            args.action_lr = args.decay*args.action_lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr
            for param_group in q_optimizer.param_groups:
                param_group['lr'] = args.q_lr
            for param_group in action_optimizer.param_groups:
                param_group['lr'] = args.action_lr
        if args.lr < 0.03:
            break
    print("Finished training!")

def eval(data, model, samples=0, count_eos_ppl=0):
    model.eval()
    num_sents = 0
    num_words = 0
    total_nll_recon = 0.
    total_kl = 0.
    total_nll_iwae = 0.
    corpus_f1 = [0., 0., 0.]
    sent_f1 = []
    with torch.no_grad():
        for i in list(reversed(range(len(data)))):
            sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i]
            if length == 1:
                # length-1 sents are ignored since URNNG needs at least length-2 sents
                continue
            if count_eos_ppl == 1:
                tree_length = length
                length += 1
            else:
                sents = sents[:, :-1]
                tree_length = length
            sents = sents.cuda()
            ll_word_all, ll_action_p_all, ll_action_q_all, actions_all, q_entropy = model(
                sents, samples=samples, has_eos=count_eos_ppl == 1)
            ll_word, ll_action_p, ll_action_q = ll_word_all.mean(1), ll_action_p_all.mean(1), ll_action_q_all.mean(1)
            kl = ll_action_q - ll_action_p
            _, binary_matrix, argmax_spans = model.q_crf._viterbi(model.scores)
            actions = []
            for b in range(batch_size):
                tree = get_tree_from_binary_matrix(binary_matrix[b], tree_length)
                actions.append(get_actions(tree))
            actions = torch.Tensor(actions).long()
            total_nll_recon += -ll_word.sum().item()
            total_kl += kl.sum().item()
            num_sents += batch_size
            num_words += batch_size * length
            if samples > 0:  # PPL estimate based on the IWAE bound
                sample_ll = torch.zeros(batch_size, samples)
                for j in range(samples):
                    ll_word_j, ll_action_p_j, ll_action_q_j = ll_word_all[:, j], ll_action_p_all[:, j], ll_action_q_all[:, j]
                    sample_ll[:, j].copy_(ll_word_j + ll_action_p_j - ll_action_q_j)
                ll_iwae = model.logsumexp(sample_ll, 1) - np.log(samples)
                total_nll_iwae -= ll_iwae.sum().item()
            for b in range(batch_size):
                action = list(actions[b].numpy())
                span_b = get_spans(action)
                span_b = argmax_spans[b]  # use the Viterbi spans directly
                span_b_set = set(span_b[:-1])
                gold_b_set = set(gold_spans[b][:-1])
                tp, fp, fn = get_stats(span_b_set, gold_b_set)
                corpus_f1[0] += tp
                corpus_f1[1] += fp
                corpus_f1[2] += fn
                # sent-level F1 is based on L83-89 from https://github.com/yikangshen/PRPN/test_phrase_grammar.py
                model_out = span_b_set
                std_out = gold_b_set
                overlap = model_out.intersection(std_out)
                prec = float(len(overlap)) / (len(model_out) + 1e-8)
                reca = float(len(overlap)) / (len(std_out) + 1e-8)
                if len(std_out) == 0:
                    reca = 1.
                    if len(model_out) == 0:
                        prec = 1.
                f1 = 2 * prec * reca / (prec + reca + 1e-8)
                sent_f1.append(f1)
    tp, fp, fn = corpus_f1
    prec = tp / (tp + fp)
    recall = tp / (tp + fn)
    corpus_f1 = 2*prec*recall/(prec + recall)*100 if prec + recall > 0 else 0.
    sent_f1 = np.mean(np.array(sent_f1))*100
    elbo_ppl = np.exp((total_nll_recon + total_kl) / num_words)
    recon_ppl = np.exp(total_nll_recon / num_words)
    iwae_ppl = np.exp(total_nll_iwae / num_words)
    kl = total_kl / num_sents
    print('ElboPPL: %.2f, ReconPPL: %.2f, KL: %.4f, IwaePPL: %.2f, CorpusF1: %.2f, SentAvgF1: %.2f' %
          (elbo_ppl, recon_ppl, kl, iwae_ppl, corpus_f1, sent_f1))
    # note that the corpus F1 printed here is different from what you should get from
    # evalb, since we do not ignore any tags (e.g. punctuation) while evalb ignores them
    model.train()
    return iwae_ppl, corpus_f1

if __name__ == '__main__':
    args = parser.parse_args()
    main(args)
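Two pieces of the objective above are worth spelling out. The score-function term uses a leave-one-out baseline: for sample k, the baseline is the mean of `log_f` over the other `samples - 1` samples, which reduces variance without biasing the gradient. The IWAE-style perplexity in `eval` is a log-mean-exp of per-sample importance weights. A self-contained sketch of that arithmetic with made-up numbers:

```python
import torch

# log importance weights log p(x, z_k) - log q(z_k | x) for one sentence, K = 3 samples
sample_ll = torch.tensor([[-40.0, -42.0, -41.0]])
K = sample_ll.size(1)
ll_iwae = torch.logsumexp(sample_ll, dim=1) - torch.log(torch.tensor(float(K)))
num_words = 10  # hypothetical sentence length
print(torch.exp(-ll_iwae / num_words))  # per-word perplexity estimate, ~58.5 here
```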
================================================
FILE: train_lm.py
================================================

#!/usr/bin/env python3
import sys
import os
import argparse
import json
import random
import shutil
import copy
import torch
from torch import cuda
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import numpy as np
import time
import logging
from data import Dataset
from models import RNNLM
from utils import *

parser = argparse.ArgumentParser()

# Data path options
parser.add_argument('--train_file', default='data/ptb-train.pkl')
parser.add_argument('--val_file', default='data/ptb-val.pkl')
parser.add_argument('--test_file', default='data/ptb-test.pkl')
parser.add_argument('--train_from', default='')
# Model options
parser.add_argument('--w_dim', default=650, type=int, help='word embedding dimension for LM')
parser.add_argument('--h_dim', default=650, type=int, help='hidden dimension for LM')
parser.add_argument('--num_layers', default=2, type=int,
                    help='number of layers in LM and the stack LSTM (for RNNG)')
parser.add_argument('--dropout', default=0.5, type=float, help='dropout rate')
# Optimization options
parser.add_argument('--count_eos_ppl', default=0, type=int, help='whether to count eos in val PPL')
parser.add_argument('--test', default=0, type=int, help='if = 1, evaluate on the test set and exit')
parser.add_argument('--save_path', default='urnng.pt', help='where to save the model checkpoint')
parser.add_argument('--num_epochs', default=30, type=int, help='number of training epochs')
parser.add_argument('--min_epochs', default=8, type=int,
                    help='do not decay learning rate for at least this many epochs')
parser.add_argument('--lr', default=1, type=float, help='starting learning rate')
parser.add_argument('--decay', default=0.5, type=float, help='learning rate decay factor')
parser.add_argument('--param_init', default=0.1, type=float,
                    help='parameter initialization (over uniform)')
parser.add_argument('--max_grad_norm', default=5, type=float, help='gradient clipping parameter')
parser.add_argument('--gpu', default=2, type=int, help='which gpu to use')
parser.add_argument('--seed', default=3435, type=int, help='random seed')
parser.add_argument('--print_every', type=int, default=500, help='print stats after this many batches')

def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_data = Dataset(args.train_file)
    val_data = Dataset(args.val_file)
    vocab_size = int(train_data.vocab_size)
    print('Train: %d sents / %d batches, Val: %d sents / %d batches' %
          (train_data.sents.size(0), len(train_data), val_data.sents.size(0), len(val_data)))
    print('Vocab size: %d' % vocab_size)
    cuda.set_device(args.gpu)
    if args.train_from == '':
        model = RNNLM(vocab=vocab_size,
                      w_dim=args.w_dim,
                      h_dim=args.h_dim,
                      dropout=args.dropout,
                      num_layers=args.num_layers)
        if args.param_init > 0:
            for param in model.parameters():
                param.data.uniform_(-args.param_init, args.param_init)
    else:
        print('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']
    print("model architecture")
    print(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    model.train()
    model.cuda()
    epoch = 0
    decay = 0
    if args.test == 1:
        test_data = Dataset(args.test_file)
        test_ppl = eval(test_data, model, count_eos_ppl=args.count_eos_ppl)
        sys.exit(0)
    best_val_ppl = eval(val_data, model, count_eos_ppl=args.count_eos_ppl)
    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        print('Starting epoch %d' % epoch)
        train_nll = 0.
        num_sents = 0.
        num_words = 0.
        b = 0
        for i in np.random.permutation(len(train_data)):
            sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = train_data[i]
            if length == 1:
                continue
            sents = sents.cuda()
            b += 1
            optimizer.zero_grad()
            nll = -model(sents).mean()
            train_nll += nll.item()*batch_size
            nll.backward()
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            num_sents += batch_size
            num_words += batch_size * length
            if b % args.print_every == 0:
                param_norm = sum([p.norm()**2 for p in model.parameters()]).item()**0.5
                print('Epoch: %d, Batch: %d/%d, LR: %.4f, TrainPPL: %.2f, |Param|: %.4f, '
                      'BestValPerf: %.2f, Throughput: %.2f examples/sec' %
                      (epoch, b, len(train_data), args.lr, np.exp(train_nll / num_words),
                       param_norm, best_val_ppl, num_sents / (time.time() - start_time)))
        print('--------------------------------')
        print('Checking validation perf...')
        val_ppl = eval(val_data, model, count_eos_ppl=args.count_eos_ppl)
        print('--------------------------------')
        if val_ppl < best_val_ppl:
            best_val_ppl = val_ppl
            checkpoint = {
                'args': args.__dict__,
                'model': model.cpu(),
                'word2idx': train_data.word2idx,
                'idx2word': train_data.idx2word
            }
            print('Saving checkpoint to %s' % args.save_path)
            torch.save(checkpoint, args.save_path)
            model.cuda()
        else:
            if epoch > args.min_epochs:
                decay = 1
        if decay == 1:
            args.lr = args.decay*args.lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr
        if args.lr < 0.03:
            break
    print("Finished training")

def eval(data, model, count_eos_ppl=0):
    model.eval()
    num_words = 0
    total_nll = 0.
    with torch.no_grad():
        for i in list(reversed(range(len(data)))):
            sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i]
            if length == 1:
                # we ignore length-1 sents in URNNG eval, so do the same for the LM
                continue
            if count_eos_ppl == 1:
                length += 1
            else:
                sents = sents[:, :-1]
            sents = sents.cuda()
            num_words += length * batch_size
            nll = -model(sents).mean()
            total_nll += nll.item()*batch_size
    ppl = np.exp(total_nll / num_words)
    print('PPL: %.2f' % ppl)
    model.train()
    return ppl

if __name__ == '__main__':
    args = parser.parse_args()
    main(args)
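The perplexity bookkeeping in `train_lm.py` works on summed NLLs: `model(sents)` returns per-sentence log-likelihoods, the batch mean is un-averaged by multiplying with `batch_size`, and PPL is the exponentiated per-word NLL. A tiny worked example with made-up numbers:

```python
import numpy as np

ll = np.array([-20.0, -25.0])  # hypothetical log-likelihoods of two 5-word sentences
total_nll = -ll.sum()          # 45.0
num_words = 2 * 5              # 10
print(np.exp(total_nll / num_words))  # exp(4.5) ~ 90.0
```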
================================================
FILE: utils.py
================================================

#!/usr/bin/env python3
import numpy as np
import itertools
import random

def get_actions(tree, SHIFT=0, REDUCE=1, OPEN='(', CLOSE=')'):
    # input: tree in bracket form, e.g. ((A B) (C D))
    # output: action sequence 0 0 1 0 0 1 1, where 0 is SHIFT and 1 is REDUCE
    actions = []
    tree = tree.strip()
    i = 0
    num_shift = 0
    num_reduce = 0
    left = 0
    right = 0
    while i < len(tree):
        if tree[i] != ' ' and tree[i] != OPEN and tree[i] != CLOSE:  # terminal
            if tree[i-1] == OPEN or tree[i-1] == ' ':
                actions.append(SHIFT)
                num_shift += 1
        elif tree[i] == CLOSE:
            actions.append(REDUCE)
            num_reduce += 1
            right += 1
        elif tree[i] == OPEN:
            left += 1
        i += 1
    assert(num_shift == num_reduce + 1)
    return actions

def get_tree(actions, sent=None, SHIFT=0, REDUCE=1):
    # input: actions and sent (lists), e.g. S S R S S R R and A B C D
    # output: tree in bracket form, e.g. ((A B) (C D))
    stack = []
    pointer = 0
    if sent is None:
        sent = list(map(str, range((len(actions)+1) // 2)))
    for action in actions:
        if action == SHIFT:
            word = sent[pointer]
            stack.append(word)
            pointer += 1
        elif action == REDUCE:
            right = stack.pop()
            left = stack.pop()
            stack.append('(' + left + ' ' + right + ')')
    assert(len(stack) == 1)
    return stack[-1]

def get_spans(actions, SHIFT=0, REDUCE=1):
    sent = list(range((len(actions)+1) // 2))
    spans = []
    pointer = 0
    stack = []
    for action in actions:
        if action == SHIFT:
            word = sent[pointer]
            stack.append(word)
            pointer += 1
        elif action == REDUCE:
            right = stack.pop()
            left = stack.pop()
            if isinstance(left, int):
                left = (left, None)
            if isinstance(right, int):
                right = (None, right)
            new_span = (left[0], right[1])
            spans.append(new_span)
            stack.append(new_span)
    return spans

def get_stats(span1, span2):
    tp = 0
    fp = 0
    fn = 0
    for span in span1:
        if span in span2:
            tp += 1
        else:
            fp += 1
    for span in span2:
        if span not in span1:
            fn += 1
    return tp, fp, fn

def update_stats(pred_span, gold_spans, stats):
    for gold_span, stat in zip(gold_spans, stats):
        tp, fp, fn = get_stats(pred_span, gold_span)
        stat[0] += tp
        stat[1] += fp
        stat[2] += fn

def get_f1(stats):
    f1s = []
    for stat in stats:
        prec = stat[0] / (stat[0] + stat[1])
        recall = stat[0] / (stat[0] + stat[2])
        f1 = 2*prec*recall / (prec + recall)*100 if prec + recall > 0 else 0.
        f1s.append(f1)
    return f1s

def span_str(start=None, end=None):
    assert(start is not None or end is not None)
    if start is None:
        return ' ' + str(end) + ')'
    elif end is None:
        return '(' + str(start) + ' '
    else:
        return ' (' + str(start) + ' ' + str(end) + ') '

def get_tree_from_binary_matrix(matrix, length):
    sent = list(map(str, range(length)))
    n = len(sent)
    tree = {}
    for i in range(n):
        tree[i] = sent[i]
    for k in np.arange(1, n):
        for s in np.arange(n):
            t = s + k
            if t > n - 1:
                break
            if matrix[s][t].item() == 1:
                span = '(' + tree[s] + ' ' + tree[t] + ')'
                tree[s] = span
                tree[t] = span
    return tree[0]

def get_nonbinary_spans(actions, SHIFT=0, REDUCE=1):
    spans = []
    stack = []
    pointer = 0
    binary_actions = []
    nonbinary_actions = []
    num_shift = 0
    num_reduce = 0
    for action in actions:
        if action == "SHIFT":
            nonbinary_actions.append(SHIFT)
            stack.append((pointer, pointer))
            pointer += 1
            binary_actions.append(SHIFT)
            num_shift += 1
        elif action[:3] == 'NT(':
            stack.append('(')
        elif action == "REDUCE":
            nonbinary_actions.append(REDUCE)
            right = stack.pop()
            left = right
            n = 1
            while stack[-1] != '(':
                left = stack.pop()
                n += 1
            span = (left[0], right[1])
            if left[0] != right[1]:
                spans.append(span)
            stack.pop()
            stack.append(span)
            while n > 1:
                n -= 1
                binary_actions.append(REDUCE)
                num_reduce += 1
        else:
            assert False
    assert(len(stack) == 1)
    assert(num_shift == num_reduce + 1)
    return spans, binary_actions, nonbinary_actions
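To make the SHIFT/REDUCE conventions in `utils.py` concrete, a small sanity check (assuming the repo root is on `PYTHONPATH`):

```python
from utils import get_actions, get_tree, get_spans

tree = "((A B) (C D))"
actions = get_actions(tree)
print(actions)                                  # [0, 0, 1, 0, 0, 1, 1] (0 = SHIFT, 1 = REDUCE)
print(get_tree(actions, ["A", "B", "C", "D"]))  # ((A B) (C D))
print(get_spans(actions))                       # [(0, 1), (2, 3), (0, 3)]
```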