Repository: harvardnlp/urnng
Branch: master
Commit: b1eeffa5b590
Files: 15
Total size: 101.9 KB
Directory structure:
gitextract_zduez7kc/
├── .gitignore
├── COLLINS.prm
├── README.md
├── TreeCRF.py
├── data/
│ ├── test.txt
│ ├── train.txt
│ └── valid.txt
├── data.py
├── eval_ppl.py
├── models.py
├── parse.py
├── preprocess.py
├── train.py
├── train_lm.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.pt
*.amat
*.mat
*.out
*.out~
*.pyc
*.pt~
.gitignore~
*.sh
*.sh~
*.py~
*.json
*.json~
*.model
*.h5
*.tar.gz
*.hdf5
*.dict
*.pkl
================================================
FILE: COLLINS.prm
================================================
##------------------------------------------##
## Debug mode ##
## 0: No debugging ##
## 1: print data for individual sentence ##
##------------------------------------------##
DEBUG 0
##------------------------------------------##
## MAX error ##
## Number of errors to stop the process. ##
## This is useful if there could be ##
## tokenization errors. ##
## The process will stop when this number##
## of errors is accumulated. ##
##------------------------------------------##
MAX_ERROR 10
##------------------------------------------##
## Cut-off length for statistics ##
## At the end of evaluation, the ##
## statistics for the sentences of length##
## less than or equal to this number will##
## be shown, on top of the statistics ##
## for all the sentences ##
##------------------------------------------##
CUTOFF_LEN 10
##------------------------------------------##
## unlabeled or labeled bracketing ##
## 0: unlabeled bracketing ##
## 1: labeled bracketing ##
##------------------------------------------##
LABELED 0
##------------------------------------------##
## Delete labels ##
## list of labels to be ignored. ##
## If it is a pre-terminal label, delete ##
## the word along with the brackets. ##
## If it is a non-terminal label, just ##
## delete the brackets (don't delete ##
## children). ##
##------------------------------------------##
DELETE_LABEL TOP
DELETE_LABEL -NONE-
DELETE_LABEL ,
DELETE_LABEL :
DELETE_LABEL ``
DELETE_LABEL ''
DELETE_LABEL .
##------------------------------------------##
## Delete labels for length calculation ##
## list of labels to be ignored for ##
## length calculation purpose ##
##------------------------------------------##
DELETE_LABEL_FOR_LENGTH -NONE-
##------------------------------------------##
## Equivalent labels, words ##
## the pairs are considered equivalent ##
## This is non-directional. ##
##------------------------------------------##
EQ_LABEL ADVP PRT
# EQ_WORD Example example
================================================
FILE: README.md
================================================
# Unsupervised Recurrent Neural Network Grammars
This is an implementation of the paper:
[Unsupervised Recurrent Neural Network Grammars](https://arxiv.org/abs/1904.03746)
Yoon Kim, Alexander Rush, Lei Yu, Adhiguna Kuncoro, Chris Dyer, Gabor Melis
NAACL 2019
## Dependencies
The code was tested in `python 3.6` and `pytorch 1.0`.
## Data
Sample train/val/test data is in the `data/` folder, taken from the standard PTB splits.
First preprocess the data:
```
python preprocess.py --trainfile data/train.txt --valfile data/valid.txt --testfile data/test.txt
--outputfile data/ptb --vocabminfreq 1 --lowercase 0 --replace_num 0 --batchsize 16
```
Running this will save the following files in the `data/` folder: `ptb-train.pkl`, `ptb-val.pkl`,
`ptb-test.pkl`, `ptb.dict`. Here `ptb.dict` is the word-idx mapping, and you can change the
output folder/name by changing the argument to `outputfile`. Also, the preprocessing here
will replace singletons with a single `<unk>` rather than with Berkeley parser's mapping rules
(see below for results using this setup).
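If you want to sanity-check the output, the pickled batches can be loaded with the `Dataset`
class in `data.py`. A minimal sketch (assuming the preprocessing command above has been run):
```
from data import Dataset

train_data = Dataset('data/ptb-train.pkl')
print(len(train_data))  # number of length-uniform batches
sents, length, batch_size, actions, spans, binary_tree, other = train_data[0]
# sents includes the <s> and </s> tokens; `length` excludes them
print(sents.size(), length, batch_size)
```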
## Training
To train the URNNG:
```
python train.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --save_path urnng.pt
--mode unsupervised --gpu 0
```
where `save_path` is where you want to save the model, and `gpu 0` selects the first GPU
visible to PyTorch (the mapping from PyTorch GPU indices to your machine's physical GPUs may vary).
Training should take 2 to 3 days depending on your setup.
To train the RNNG:
```
python train.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl --save_path rnng.pt
--mode supervised --train_q_epochs 18 --gpu 0
```
For fine-tuning:
```
python train.py --train_from rnng.pt --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl
--save_path rnng-urnng.pt --mode unsupervised --lr 0.1 --train_q_epochs 10 --epochs 10
--min_epochs 6 --gpu 0 --kl_warmup 0
```
To train the LM:
```
python train_lm.py --train_file data/ptb-train.pkl --val_file data/ptb-val.pkl
--test_file data/ptb-test.pkl --save_path lm.pt
```
## Evaluation
To evaluate perplexity with importance sampling on the test set:
```
python eval_ppl.py --model_file urnng.pt --test_file data/ptb-test.pkl --samples 1000
--is_temp 2 --gpu 0
```
The argument `samples` is for the number of importance weighted samples, and `is_temp` is for
flattening the inference network's distribution (footnote 14 in the paper).
The same evaluation code will work for RNNG.
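For reference, the estimator in `eval_ppl.py` is the standard importance-weighted bound: trees
are sampled from the inference network q, and the log marginal is estimated as
`logsumexp_k(log p(x, z_k) - log q(z_k | x)) - log K`. A minimal sketch of that step (the tensor
names mirror those in `eval_ppl.py`):
```
import math
import torch

def iwae_log_likelihood(ll_word, ll_action_p, ll_action_q):
    # each argument: (batch, K) log terms for K trees sampled from q(tree | x)
    sample_ll = ll_word + ll_action_p - ll_action_q  # log p(x, z) - log q(z | x)
    return torch.logsumexp(sample_ll, dim=1) - math.log(sample_ll.size(1))
```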
For LM evaluation:
```
python train_lm.py --train_from lm.pt --test_file data/ptb-test.pkl --test 1
```
To evaluate F1, first we need to parse the test set:
```
python parse.py --model_file urnng.pt --data_file data/ptb-test.txt --out_file pred-parse.txt
--gold_out_file gold-parse.txt --gpu 0
```
This will output the predicted parse trees into `pred-parse.txt`. We also output a version
of the gold parses to `gold-parse.txt` to be used as input for `evalb`, since sentences with only trivial spans are ignored by `parse.py`. Note that the corpus/sentence F1 printed here does not correspond to the results reported in the paper, since it does not ignore punctuation.
Finally, download/install `evalb`, available [here](https://nlp.cs.nyu.edu/evalb).
Then run:
```
evalb -p COLLINS.prm gold-parse.txt pred-parse.txt
```
where `COLLINS.prm` is the parameter file (provided in this repo) that tells `evalb` to ignore
punctuation and evaluate on unlabeled F1.
## Note Regarding Preprocessing
Note that some of the details of the preprocessing are slightly different from the original
paper. In particular, in this implementation we replace singleton words with a single `<unk>` token
instead of using the Berkeley parser's mapping rules. This results in slightly lower perplexity
for all models, since the vocabulary size is smaller. Here are the perplexity numbers I get
in this setting:
- RNNLM: 89.2
- RNNG: 83.7
- URNNG: 85.1 (F1: 38.4)
- RNNG --> URNNG: 82.5
## Acknowledgements
Some of our preprocessing and evaluation code is based on the following repositories:
- [Recurrent Neural Network Grammars](https://github.com/clab/rnng)
- [Parsing Reading Predict Network](https://github.com/yikangshen/PRPN)
## License
MIT
================================================
FILE: TreeCRF.py
================================================
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import itertools
import utils
import random
class ConstituencyTreeCRF(nn.Module):
def __init__(self):
super(ConstituencyTreeCRF, self).__init__()
self.huge = 1e9
def logadd(self, x, y):
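# numerically stable elementwise log(exp(x) + exp(y)): factor out the max to avoid overflow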
d = torch.max(x,y)
return torch.log(torch.exp(x-d) + torch.exp(y-d)) + d
def logsumexp(self, x, dim=1):
d = torch.max(x, dim)[0]
return torch.log(torch.exp(x - d.unsqueeze(dim).expand_as(x)).sum(dim)) + d
def _init_table(self, scores):
# initialize dynamic programming table
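# alpha[s][t] will hold, for each batch element, the log inside score of span (s, t)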
batch_size = scores.size(0)
n = scores.size(1)
self.alpha = [[scores.new(batch_size).fill_(-self.huge) for _ in range(n)] for _ in range(n)]
def _forward(self, scores):
#inside step
batch_size = scores.size(0)
n = scores.size(1)
self._init_table(scores)
for i in range(n):
self.alpha[i][i] = scores[:, i, i]
for k in np.arange(1, n+1):
for s in range(n):
t = s + k
if t > n-1:
break
tmp = [self.alpha[s][u] + self.alpha[u+1][t] + scores[:, s, t] for u in np.arange(s,t)]
tmp = torch.stack(tmp, 1)
self.alpha[s][t] = self.logsumexp(tmp, 1)
def _backward(self, scores):
#outside step
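# beta[s][t] accumulates the log outside score of span (s, t), seeded at the full-sentence span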
batch_size = scores.size(0)
n = scores.size(1)
self.beta = [[None for _ in range(n)] for _ in range(n)]
self.beta[0][n-1] = scores.new(batch_size).fill_(0)
for k in np.arange(n-1, 0, -1):
for s in range(n):
t = s + k
if t > n-1:
break
for u in np.arange(s, t):
if s < u+1:
tmp = self.beta[s][t] + self.alpha[u+1][t] + scores[:, s, t]
if self.beta[s][u] is None:
self.beta[s][u] = tmp
else:
self.beta[s][u] = self.logadd(self.beta[s][u], tmp)
if u+1 < t+1:
tmp = self.beta[s][t] + self.alpha[s][u] + scores[:, s, t]
if self.beta[u+1][t] is None:
self.beta[u+1][t] = tmp
else:
self.beta[u+1][t] = self.logadd(self.beta[u+1][t], tmp)
def _marginal(self, scores):
batch_size = scores.size(0)
n = scores.size(1)
self.log_marginal = [[None for _ in range(n)] for _ in range(n)]
log_Z = self.alpha[0][n-1]
for s in range(n):
for t in np.arange(s, n):
self.log_marginal[s][t] = self.alpha[s][t] + self.beta[s][t] - log_Z
def _entropy(self, scores):
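# entropy of the CRF tree distribution by dynamic programming, using
# H[s][t] = sum_u p(u) * (H[s][u] + H[u+1][t] - log p(u)) over split points u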
batch_size = scores.size(0)
n = scores.size(1)
self.entropy = [[None for _ in range(n)] for _ in range(n)]
for i in range(n):
self.entropy[i][i] = scores.new(batch_size).fill_(0)
for k in np.arange(1, n+1):
for s in range(n):
t = s + k
if t > n-1:
break
score = []
prev_ent = []
for u in np.arange(s, t):
score.append(self.alpha[s][u] + self.alpha[u+1][t])
prev_ent.append(self.entropy[s][u] + self.entropy[u+1][t])
score = torch.stack(score, 1)
prev_ent = torch.stack(prev_ent, 1)
log_prob = F.log_softmax(score, dim = 1)
prob = log_prob.exp()
entropy = ((prev_ent - log_prob)*prob).sum(1)
self.entropy[s][t] = entropy
def _sample(self, scores, alpha = None, argmax = False):
# sample from p(tree | sent)
# also get the spans
if alpha is None:
self._forward(scores)
alpha = self.alpha
batch_size = scores.size(0)
n = scores.size(1)
tree = scores.new(batch_size, n, n).zero_()
all_log_probs = []
tree_brackets = []
spans = []
for b in range(batch_size):
sampled = [(0, n-1)]
span = [(0, n-1)]
queue = [(0, n-1)] #start, end
log_probs = []
tree_str = get_span_str(0, n-1)
while len(queue) > 0:
node = queue.pop(0)
start, end = node
left_parent = get_span_str(start, None)
right_parent = get_span_str(None, end)
score = []
score_idx = []
for u in np.arange(start, end):
score.append(alpha[start][u][b] + alpha[u+1][end][b])
score_idx.append([(start, u), (u+1, end)])
score = torch.stack(score, 0)
log_prob = F.log_softmax(score, dim = 0)
if argmax:
sample = torch.max(log_prob, 0)[1]
else:
prob = log_prob.exp()
sample = torch.multinomial(prob, 1)
sample_idx = score_idx[sample.item()]
log_probs.append(log_prob[sample.item()])
for idx in sample_idx:
if idx[0] != idx[1]:
queue.append(idx)
span.append(idx)
sampled.append(idx)
left_child = '(' + get_span_str(sample_idx[0][0], sample_idx[0][1])
right_child = get_span_str(sample_idx[1][0], sample_idx[1][1]) + ')'
if sample_idx[0][0] != sample_idx[0][1]:
tree_str = tree_str.replace(left_parent, left_child)
if sample_idx[1][0] != sample_idx[1][1]:
tree_str = tree_str.replace(right_parent, right_child)
all_log_probs.append(torch.stack(log_probs, 0).sum(0))
tree_brackets.append(tree_str)
spans.append(span[::-1])
for idx in sampled:
tree[b][idx[0]][idx[1]] = 1
all_log_probs = torch.stack(all_log_probs, 0)
return tree, all_log_probs, tree_brackets, spans
def _viterbi(self, scores):
# cky algorithm
batch_size = scores.size(0)
n = scores.size(1)
self.max_scores = scores.new(batch_size, n, n).fill_(-self.huge)
self.bp = scores.new(batch_size, n, n).zero_()
self.argmax = scores.new(batch_size, n, n).zero_()
self.spans = [[] for _ in range(batch_size)]
tmp = scores.new(batch_size, n).zero_()
for i in range(n):
self.max_scores[:, i, i] = scores[:, i, i]
for k in np.arange(1, n):
for s in np.arange(n):
t = s + k
if t > n-1:
break
for u in np.arange(s, t):
tmp = self.max_scores[:, s, u] + self.max_scores[:, u+1, t] + scores[:, s, t]
self.bp[:, s, t][self.max_scores[:, s, t] < tmp] = int(u)
self.max_scores[:, s, t] = torch.max(self.max_scores[:, s, t], tmp)
for b in range(batch_size):
self._backtrack(b, 0, n-1)
return self.max_scores[:, 0, n-1], self.argmax, self.spans
def _backtrack(self, b, s, t):
u = int(self.bp[b][s][t])
self.argmax[b][s][t] = 1
if s == t:
return None
else:
self.spans[b].insert(0, (s,t))
self._backtrack(b, s, u)
self._backtrack(b, u+1, t)
return None
def get_span_str(start = None, end = None):
assert(start is not None or end is not None)
if start is None:
return ' ' + str(end) + ')'
elif end is None:
return '(' + str(start) + ' '
else:
return ' (' + str(start) + ' ' + str(end) + ') '
================================================
FILE: data/test.txt
================================================
(S (INTJ (RB No)) (, ,) (NP (PRP it)) (VP (VBD was) (RB n't) (NP (NNP Black) (NNP Monday))) (. .))
(S (CC But) (SBAR (IN while) (S (NP (DT the) (NNP New) (NNP York) (NNP Stock) (NNP Exchange)) (VP (VBD did) (RB n't) (VP (VB fall) (ADVP (RB apart)) (NP (NNP Friday)) (SBAR (IN as) (S (NP (DT the) (NNP Dow) (NNP Jones) (NNP Industrial) (NNP Average)) (VP (VBD plunged) (NP (NP (CD 190.58) (NNS points)) (PRN (: --) (NP (NP (JJS most)) (PP (IN of) (NP (PRP it))) (PP (IN in) (NP (DT the) (JJ final) (NN hour)))) (: --)))))))))) (NP (PRP it)) (ADVP (RB barely)) (VP (VBD managed) (S (VP (TO to) (VP (VB stay) (NP (NP (DT this) (NN side)) (PP (IN of) (NP (NN chaos)))))))) (. .))
(S (NP (NP (DT Some) (`` ``) (NN circuit) (NNS breakers) ('' '')) (VP (VBN installed) (PP (IN after) (NP (DT the) (NNP October) (CD 1987) (NN crash))))) (VP (VBD failed) (NP (PRP$ their) (JJ first) (NN test)) (PRN (, ,) (S (NP (NNS traders)) (VP (VBP say))) (, ,)) (S (ADJP (JJ unable) (S (VP (TO to) (VP (VB cool) (NP (NP (DT the) (NN selling) (NN panic)) (PP (IN in) (NP (DT both) (NNS stocks) (CC and) (NNS futures)))))))))) (. .))
(S (NP (NP (NP (DT The) (CD 49) (NN stock) (NN specialist) (NNS firms)) (PP (IN on) (NP (DT the) (NNP Big) (NNP Board) (NN floor)))) (: --) (NP (NP (DT the) (NNS buyers) (CC and) (NNS sellers)) (PP (IN of) (NP (JJ last) (NN resort))) (SBAR (WHNP (WP who)) (S (VP (VBD were) (VP (VBN criticized) (PP (IN after) (NP (DT the) (CD 1987) (NN crash)))))))) (: --)) (ADVP (RB once) (RB again)) (VP (MD could) (RB n't) (VP (VB handle) (NP (DT the) (NN selling) (NN pressure)))) (. .))
(S (S (NP (JJ Big) (NN investment) (NNS banks)) (VP (VBD refused) (S (VP (TO to) (VP (VB step) (ADVP (IN up) (PP (TO to) (NP (DT the) (NN plate)))) (S (VP (TO to) (VP (VB support) (NP (DT the) (JJ beleaguered) (NN floor) (NNS traders)) (PP (IN by) (S (VP (VBG buying) (NP (NP (JJ big) (NNS blocks)) (PP (IN of) (NP (NN stock))))))))))))))) (, ,) (NP (NNS traders)) (VP (VBP say)) (. .))
(S (NP (NP (JJ Heavy) (NN selling)) (PP (IN of) (NP (NP (NNP Standard) (CC &) (NNP Poor) (POS 's)) (JJ 500-stock) (NN index) (NNS futures))) (PP (IN in) (NP (NNP Chicago)))) (VP (ADVP (RB relentlessly)) (VBD beat) (NP (NNS stocks)) (ADVP (RB downward))) (. .))
(S (NP (NP (CD Seven) (NNP Big) (NNP Board) (NNS stocks)) (: --) (NP (NP (NNP UAL)) (, ,) (NP (NNP AMR)) (, ,) (NP (NNP BankAmerica)) (, ,) (NP (NNP Walt) (NNP Disney)) (, ,) (NP (NNP Capital) (NNP Cities\/ABC)) (, ,) (NP (NNP Philip) (NNP Morris)) (CC and) (NP (NNP Pacific) (NNP Telesis) (NNP Group))) (: --)) (VP (VP (VBD stopped) (S (VP (VBG trading)))) (CC and) (VP (ADVP (RB never)) (VBD resumed))) (. .))
(S (NP (DT The) (NN finger-pointing)) (VP (VBZ has) (ADVP (RB already)) (VP (VBN begun))) (. .))
(S (`` ``) (NP (DT The) (NN equity) (NN market)) (VP (VBD was) (ADJP (JJ illiquid))) (. .))
(SINV (S (ADVP (RB Once) (RB again)) (-LRB- -LCB-) (NP (DT the) (NNS specialists)) (-RRB- -RCB-) (VP (VBD were) (RB not) (ADJP (JJ able) (S (VP (TO to) (VP (VB handle) (NP (NP (DT the) (NNS imbalances)) (PP (IN on) (NP (NP (DT the) (NN floor)) (PP (IN of) (NP (DT the) (NNP New) (NNP York) (NNP Stock) (NNP Exchange)))))))))))) (, ,) ('' '') (VP (VBD said)) (NP (NP (NNP Christopher) (NNP Pedersen)) (, ,) (NP (NP (JJ senior) (NN vice) (NN president)) (PP (IN at) (NP (NNP Twenty-First) (NNP Securities) (NNP Corp))))) (. .))
(SINV (VP (VBD Countered)) (NP (NP (NNP James) (NNP Maguire)) (, ,) (NP (NP (NN chairman)) (PP (IN of) (NP (NNS specialists) (NNP Henderson) (NNP Brothers) (NNP Inc.))))) (: :) (`` ``) (S (NP (PRP It)) (VP (VBZ is) (ADJP (JJ easy)) (S (VP (TO to) (VP (VB say) (SBAR (S (NP (DT the) (NN specialist)) (VP (VBZ is) (RB n't) (VP (VBG doing) (NP (PRP$ his) (NN job))))))))))) (. .))
(S (SBAR (WHADVP (WRB When)) (S (NP (DT the) (NN dollar)) (VP (VBZ is) (PP (IN in) (NP (DT a) (NN free-fall)))))) (, ,) (NP (RB even) (JJ central) (NNS banks)) (VP (MD ca) (RB n't) (VP (VB stop) (NP (PRP it)))) (. .))
(S (NP (NNS Speculators)) (VP (VBP are) (VP (VBG calling) (PP (IN for) (NP (NP (DT a) (NN degree)) (PP (IN of) (NP (NN liquidity))) (SBAR (WHNP (WDT that)) (S (VP (VBZ is) (RB not) (ADVP (RB there)) (PP (IN in) (NP (DT the) (NN market)))))))))) (. .) ('' ''))
(S (NP (NP (JJ Many) (NN money) (NNS managers)) (CC and) (NP (DT some) (NNS traders))) (VP (VBD had) (ADVP (RB already)) (VP (VBN left) (NP (PRP$ their) (NNS offices)) (NP (RB early) (NNP Friday) (NN afternoon)) (PP (IN on) (NP (DT a) (JJ warm) (NN autumn) (NN day))) (: --) (SBAR (IN because) (S (NP (DT the) (NN stock) (NN market)) (VP (VBD was) (ADJP (RB so) (JJ quiet))))))) (. .))
(S (RB Then) (PP (IN in) (NP (DT a) (NN lightning) (NN plunge))) (, ,) (NP (DT the) (NNP Dow) (NNP Jones) (NNS industrials)) (PP (IN in) (NP (QP (RB barely) (DT an)) (NN hour))) (VP (VBD surrendered) (NP (NP (QP (RB about) (DT a)) (JJ third)) (PP (IN of) (NP (NP (PRP$ their) (NNS gains)) (NP (DT this) (NN year))))) (, ,) (S (VP (VBG chalking) (PRT (RP up)) (NP (NP (DT a) (ADJP (ADJP (JJ 190.58-point)) (, ,) (CC or) (ADJP (CD 6.9) (NN %)) (, ,)) (NN loss)) (PP (IN on) (NP (DT the) (NN day)))) (PP (IN in) (NP (JJ gargantuan) (NN trading) (NN volume)))))) (. .))
(S (NP (JJ Final-hour) (NN trading)) (VP (VBD accelerated) (PP (TO to) (NP (NP (QP (CD 108.1) (CD million)) (NNS shares)) (, ,) (NP (NP (DT a) (NN record)) (PP (IN for) (NP (DT the) (NNP Big) (NNP Board))))))) (. .))
(S (PP (IN At) (NP (NP (DT the) (NN end)) (PP (IN of) (NP (DT the) (NN day))))) (, ,) (NP (QP (CD 251.2) (CD million)) (NNS shares)) (VP (VBD were) (VP (VBN traded))) (. .))
(S (NP (DT The) (NNP Dow) (NNP Jones) (NNS industrials)) (VP (VBD closed) (PP (IN at) (NP (CD 2569.26)))) (. .))
(S (NP (NP (DT The) (NNP Dow) (POS 's)) (NN decline)) (VP (VBD was) (ADJP (JJ second) (PP (IN in) (NP (NN point) (NNS terms))) (PP (ADVP (RB only)) (TO to) (NP (NP (DT the) (JJ 508-point) (NNP Black) (NNP Monday) (NN crash)) (SBAR (WHNP (WDT that)) (S (VP (VBD occurred) (NP (NNP Oct.) (CD 19) (, ,) (CD 1987))))))))) (. .))
(S (PP (IN In) (NP (NN percentage) (NNS terms))) (, ,) (ADVP (RB however)) (, ,) (NP (NP (DT the) (NNP Dow) (POS 's)) (NN dive)) (VP (VBD was) (NP (NP (NP (DT the) (JJ 12th-worst)) (ADVP (RB ever))) (CC and) (NP (NP (DT the) (JJS sharpest)) (SBAR (IN since) (S (NP (DT the) (NN market)) (VP (VBD fell) (NP (NP (CD 156.83)) (, ,) (CC or) (NP (CD 8) (NN %))) (, ,) (PP (NP (DT a) (NN week)) (IN after) (NP (NNP Black) (NNP Monday))))))))) (. .))
================================================
FILE: data/train.txt
================================================
(S (PP (IN In) (NP (NP (DT an) (NNP Oct.) (CD 19) (NN review)) (PP (IN of) (NP (`` ``) (NP (DT The) (NN Misanthrope)) ('' '') (PP (IN at) (NP (NP (NNP Chicago) (POS 's)) (NNP Goodman) (NNP Theatre))))) (PRN (-LRB- -LRB-) (`` ``) (S (NP (VBN Revitalized) (NNS Classics)) (VP (VBP Take) (NP (DT the) (NN Stage)) (PP (IN in) (NP (NNP Windy) (NNP City))))) (, ,) ('' '') (NP (NN Leisure) (CC &) (NNS Arts)) (-RRB- -RRB-)))) (, ,) (NP (NP (NP (DT the) (NN role)) (PP (IN of) (NP (NNP Celimene)))) (, ,) (VP (VBN played) (PP (IN by) (NP (NNP Kim) (NNP Cattrall)))) (, ,)) (VP (VBD was) (VP (ADVP (RB mistakenly)) (VBN attributed) (PP (TO to) (NP (NNP Christina) (NNP Haag))))) (. .))
(S (NP (NNP Ms.) (NNP Haag)) (VP (VBZ plays) (NP (NNP Elianti))) (. .))
(S (NP (NNP Rolls-Royce) (NNP Motor) (NNPS Cars) (NNP Inc.)) (VP (VBD said) (SBAR (S (NP (PRP it)) (VP (VBZ expects) (S (NP (PRP$ its) (NNP U.S.) (NNS sales)) (VP (TO to) (VP (VB remain) (ADJP (JJ steady)) (PP (IN at) (NP (QP (IN about) (CD 1,200)) (NNS cars))) (PP (IN in) (NP (CD 1990)))))))))) (. .))
(S (NP (DT The) (NN luxury) (NN auto) (NN maker)) (NP (JJ last) (NN year)) (VP (VBD sold) (NP (CD 1,214) (NNS cars)) (PP (IN in) (NP (DT the) (NNP U.S.)))))
(S (NP (NP (NNP Howard) (NNP Mosher)) (, ,) (NP (NP (NN president)) (CC and) (NP (JJ chief) (NN executive) (NN officer))) (, ,)) (VP (VBD said) (SBAR (S (NP (PRP he)) (VP (VBZ anticipates) (NP (NP (NN growth)) (PP (IN for) (NP (DT the) (NN luxury) (NN auto) (NN maker))) (PP (PP (IN in) (NP (NNP Britain) (CC and) (NNP Europe))) (, ,) (CC and) (PP (IN in) (NP (ADJP (JJ Far) (JJ Eastern)) (NNS markets))))))))) (. .))
(S (NP (NNP BELL) (NNP INDUSTRIES) (NNP Inc.)) (VP (VBD increased) (NP (PRP$ its) (NN quarterly)) (PP (TO to) (NP (CD 10) (NNS cents))) (PP (IN from) (NP (NP (CD seven) (NNS cents)) (NP (DT a) (NN share))))) (. .))
(S (NP (DT The) (JJ new) (NN rate)) (VP (MD will) (VP (VB be) (ADJP (JJ payable) (NP (NNP Feb.) (CD 15))))) (. .))
(S (NP (DT A) (NN record) (NN date)) (VP (VBZ has) (RB n't) (VP (VBN been) (VP (VBN set)))) (. .))
(S (NP (NP (NNP Bell)) (, ,) (VP (VBN based) (PP (IN in) (NP (NNP Los) (NNP Angeles)))) (, ,)) (VP (VBZ makes) (CC and) (VBZ distributes) (NP (UCP (JJ electronic) (, ,) (NN computer) (CC and) (NN building)) (NNS products))) (. .))
(S (NP (NNS Investors)) (VP (VBP are) (VP (VBG appealing) (PP (TO to) (NP (DT the) (NNPS Securities) (CC and) (NNP Exchange) (NNP Commission))) (S (RB not) (VP (TO to) (VP (VB limit) (NP (NP (PRP$ their) (NN access)) (PP (TO to) (NP (NP (NN information)) (PP (IN about) (NP (NP (NN stock) (NNS purchases) (CC and) (NNS sales)) (PP (IN by) (NP (JJ corporate) (NNS insiders))))))))))))) (. .))
(S (S (NP (DT A) (NNP SEC) (NN proposal) (S (VP (TO to) (VP (VB ease) (NP (NP (NN reporting) (NNS requirements)) (PP (IN for) (NP (DT some) (NN company) (NNS executives)))))))) (VP (MD would) (VP (VB undermine) (NP (NP (DT the) (NN usefulness)) (PP (IN of) (NP (NP (NN information)) (PP (IN on) (NP (NN insider) (NNS trades))))) (PP (IN as) (NP (DT a) (JJ stock-picking) (NN tool))))))) (, ,) (NP (NP (JJ individual) (NNS investors)) (CC and) (NP (JJ professional) (NN money) (NNS managers))) (VP (VBP contend)) (. .))
(S (NP (PRP They)) (VP (VBP make) (NP (DT the) (NN argument)) (PP (IN in) (NP (NP (NNS letters)) (PP (TO to) (NP (DT the) (NN agency))) (PP (IN about) (NP (NP (NN rule) (NNS changes)) (VP (VBD proposed) (NP (DT this) (JJ past) (NN summer))) (SBAR (WHNP (IN that)) (, ,) (S (PP (IN among) (NP (JJ other) (NNS things))) (, ,) (VP (MD would) (VP (VB exempt) (NP (JJ many) (JJ middle-management) (NNS executives)) (PP (IN from) (S (VP (VBG reporting) (NP (NP (NNS trades)) (PP (IN in) (NP (NP (PRP$ their) (JJ own) (NNS companies) (POS ')) (NNS shares)))))))))))))))) (. .))
(S (NP (DT The) (VBN proposed) (NNS changes)) (ADVP (RB also)) (VP (MD would) (VP (VB allow) (S (NP (NNS executives)) (VP (TO to) (VP (VB report) (NP (NP (NNS exercises)) (PP (IN of) (NP (NNS options)))) (ADVP (ADVP (RBR later)) (CC and) (ADVP (RBR less) (RB often)))))))) (. .))
(S (NP (NP (JJ Many)) (PP (IN of) (NP (DT the) (NNS letters)))) (VP (VBP maintain) (SBAR (IN that) (S (S (NP (NN investor) (NN confidence)) (VP (VBZ has) (VP (VBN been) (VP (ADVP (RB so)) (VBN shaken) (PP (IN by) (NP (DT the) (CD 1987) (NN stock) (NN market) (NN crash))))))) (: --) (CC and) (S (NP (DT the) (NNS markets)) (ADVP (RB already)) (VP (ADVP (RB so)) (VBN stacked) (PP (IN against) (NP (DT the) (JJ little) (NN guy))))) (: --) (SBAR (IN that) (S (NP (NP (DT any) (NN decrease)) (PP (IN in) (NP (NP (NN information)) (PP (IN on) (NP (NN insider-trading) (NNS patterns)))))) (VP (MD might) (VP (VB prompt) (S (NP (NNS individuals)) (VP (TO to) (VP (VB get) (ADVP (RB out) (PP (IN of) (NP (NNS stocks)))) (ADVP (RB altogether)))))))))))) (. .))
(SINV (`` ``) (S (NP (DT The) (NNP SEC)) (VP (VBZ has) (ADVP (RB historically)) (VP (VBN paid) (NP (NN obeisance)) (PP (TO to) (NP (NP (DT the) (NN ideal)) (PP (IN of) (NP (DT a) (JJ level) (NN playing) (NN field)))))))) (, ,) ('' '') (VP (VBD wrote)) (NP (NP (NNP Clyde) (NNP S.) (NNP McGregor)) (PP (IN of) (NP (NP (NNP Winnetka)) (, ,) (NP (NNP Ill.)) (, ,)))) (PP (IN in) (NP (NP (CD one)) (PP (IN of) (NP (NP (DT the) (CD 92) (NNS letters)) (SBAR (S (NP (DT the) (NN agency)) (VP (VBZ has) (VP (VBN received) (SBAR (IN since) (S (NP (DT the) (NNS changes)) (VP (VBD were) (VP (VBN proposed) (NP (NNP Aug.) (CD 17)))))))))))))) (. .))
(S (`` ``) (ADVP (RB Apparently)) (NP (DT the) (NN commission)) (VP (VBD did) (RB not) (ADVP (RB really)) (VP (VB believe) (PP (IN in) (NP (DT this) (NN ideal))))) (. .) ('' ''))
(S (ADVP (RB Currently)) (, ,) (NP (DT the) (NNS rules)) (VP (VBP force) (S (NP (NP (NNS executives)) (, ,) (NP (NNS directors)) (CC and) (NP (JJ other) (JJ corporate) (NNS insiders))) (VP (TO to) (VP (VB report) (NP (NP (NNS purchases) (CC and) (NNS sales)) (PP (IN of) (NP (NP (PRP$ their) (NNS companies) (POS ')) (NNS shares)))) (PP (IN within) (NP (NP (QP (IN about) (DT a)) (NN month)) (PP (IN after) (NP (DT the) (NN transaction))))))))) (. .))
(S (CC But) (NP (NP (QP (IN about) (CD 25)) (NN %)) (PP (IN of) (NP (DT the) (NNS insiders)))) (, ,) (PP (VBG according) (PP (TO to) (NP (NNP SEC) (NNS figures)))) (, ,) (VP (VBP file) (NP (PRP$ their) (NNS reports)) (ADVP (RB late))) (. .))
(SINV (S (NP (DT The) (NNS changes)) (VP (VBD were) (VP (VBN proposed) (PP (IN in) (NP (DT an) (NN effort) (S (VP (TO to) (VP (VP (VB streamline) (NP (JJ federal) (NN bureaucracy))) (CC and) (VP (VB boost) (NP (NP (NN compliance)) (PP (IN by) (NP (NP (DT the) (NNS executives)) (`` ``) (SBAR (WHNP (WP who)) (S (VP (VBP are) (ADVP (RB really)) (VP (VBG calling) (NP (DT the) (NNS shots)))))))))))))))))) (, ,) ('' '') (VP (VBD said)) (NP (NP (NNP Brian) (NNP Lane)) (, ,) (NP (NP (JJ special) (NN counsel)) (PP (IN at) (NP (NP (NP (NP (DT the) (NNP SEC) (POS 's)) (NN office)) (PP (IN of) (NP (NN disclosure) (NN policy)))) (, ,) (SBAR (WHNP (WDT which)) (S (VP (VBD proposed) (NP (DT the) (NNS changes))))))))) (. .))
(S (S (S (NP (NP (NNS Investors)) (, ,) (NP (NN money) (NNS managers)) (CC and) (NP (JJ corporate) (NNS officials))) (VP (VBD had) (PP (IN until) (NP (NN today))) (S (VP (TO to) (VP (VB comment) (PP (IN on) (NP (DT the) (NNS proposals)))))))) (, ,) (CC and) (S (NP (DT the) (NN issue)) (VP (VBZ has) (VP (VBN produced) (NP (NP (JJR more) (NN mail)) (PP (IN than) (NP (NP (ADJP (RB almost) (DT any)) (JJ other) (NN issue)) (PP (IN in) (NP (NN memory)))))))))) (, ,) (NP (NNP Mr.) (NNP Lane)) (VP (VBD said)) (. .))
================================================
FILE: data/valid.txt
================================================
(S (NP (NP (DT The) (NN economy) (POS 's)) (NN temperature)) (VP (MD will) (VP (VB be) (VP (VBN taken) (PP (IN from) (NP (JJ several) (NN vantage) (NNS points))) (NP (DT this) (NN week)) (, ,) (PP (IN with) (NP (NP (NNS readings)) (PP (IN on) (NP (NP (NN trade)) (, ,) (NP (NN output)) (, ,) (NP (NN housing)) (CC and) (NP (NN inflation))))))))) (. .))
(S (NP (DT The) (ADJP (RBS most) (JJ troublesome)) (NN report)) (VP (MD may) (VP (VB be) (NP (NP (DT the) (NNP August) (NN merchandise) (NN trade) (NN deficit)) (ADJP (JJ due) (ADVP (IN out)) (NP (NN tomorrow)))))) (. .))
(S (NP (DT The) (NN trade) (NN gap)) (VP (VBZ is) (VP (VBN expected) (S (VP (TO to) (VP (VB widen) (PP (TO to) (NP (QP (IN about) ($ $) (CD 9) (CD billion)))) (PP (IN from) (NP (NP (NNP July) (POS 's)) (QP ($ $) (CD 7.6) (CD billion))))))) (, ,) (PP (VBG according) (PP (TO to) (NP (NP (DT a) (NN survey)) (PP (IN by) (NP (NP (NNP MMS) (NNP International)) (, ,) (NP (NP (DT a) (NN unit)) (PP (IN of) (NP (NP (NNP McGraw-Hill) (NNP Inc.)) (, ,) (NP (NNP New) (NNP York)))))))))))) (. .))
(S (NP (NP (NP (NNP Thursday) (POS 's)) (NN report)) (PP (IN on) (NP (DT the) (NNP September) (NN consumer) (NN price) (NN index)))) (VP (VBZ is) (VP (VBN expected) (S (VP (TO to) (VP (VB rise) (, ,) (SBAR (IN although) (ADVP (ADVP (RB not) (RB as) (RB sharply)) (PP (IN as) (NP (NP (DT the) (ADJP (CD 0.9) (NN %)) (NN gain)) (VP (VBN reported) (NP (NNP Friday)) (PP (IN in) (NP (DT the) (NN producer) (NN price) (NN index))))))))))))) (. .))
(S (NP (DT That) (NN gain)) (VP (VBD was) (VP (VBG being) (VP (VBD cited) (PP (IN as) (NP (NP (DT a) (NN reason)) (SBAR (S (NP (DT the) (NN stock) (NN market)) (VP (VBD was) (ADVP (IN down)) (ADVP (RB early) (PP (IN in) (NP (NP (NNP Friday) (POS 's)) (NN session)))) (, ,) (SBAR (IN before) (S (NP (PRP it)) (VP (VBD got) (S (VP (VBN started) (PP (IN on) (NP (PRP$ its) (JJ reckless) (JJ 190-point) (NN plunge)))))))))))))))) (. .))
(S (NP (NNS Economists)) (VP (VBP are) (VP (VBN divided) (PP (IN as) (PP (TO to) (SBAR (WHNP (WHADVP (WRB how) (JJ much)) (VBG manufacturing) (NN strength)) (S (NP (PRP they)) (VP (VBP expect) (S (VP (TO to) (VP (VB see) (PP (IN in) (NP (NP (NP (NNP September) (NNS reports)) (PP (IN on) (NP (NP (JJ industrial) (NN production)) (CC and) (NP (NN capacity) (NN utilization))))) (, ,) (ADJP (ADVP (RB also)) (JJ due) (NP (NN tomorrow))))))))))))))) (. .))
(S (ADVP (RB Meanwhile)) (, ,) (NP (NP (NNP September) (NN housing) (NNS starts)) (, ,) (ADJP (JJ due) (NP (NNP Wednesday))) (, ,)) (VP (VBP are) (VP (VBN thought) (S (VP (TO to) (VP (VB have) (VP (VBN inched) (ADVP (RB upward)))))))) (. .))
(SINV (S (`` ``) (NP (EX There)) (VP (VBZ 's) (NP (NP (DT a) (NN possibility)) (PP (IN of) (NP (NP (DT a) (NN surprise)) ('' '') (PP (IN in) (NP (DT the) (NN trade) (NN report)))))))) (, ,) (VP (VBD said)) (NP (NP (NNP Michael) (NNP Englund)) (, ,) (NP (NP (NN director)) (PP (IN of) (NP (NN research))) (PP (IN at) (NP (NNP MMS))))) (. .))
(S (S (NP (NP (DT A) (NN widening)) (PP (IN of) (NP (DT the) (NN deficit)))) (, ,) (SBAR (IN if) (S (NP (PRP it)) (VP (VBD were) (VP (VBN combined) (PP (IN with) (NP (DT a) (ADJP (RB stubbornly) (JJ strong)) (NN dollar))))))) (, ,) (VP (MD would) (VP (VB exacerbate) (NP (NN trade) (NNS problems)))) (: --)) (CC but) (S (NP (DT the) (NN dollar)) (VP (VBD weakened) (NP (NNP Friday)) (SBAR (IN as) (S (NP (NNS stocks)) (VP (VBD plummeted)))))) (. .))
(S (PP (IN In) (NP (DT any) (NN event))) (, ,) (NP (NP (NNP Mr.) (NNP Englund)) (CC and) (NP (JJ many) (NNS others))) (VP (VBP say) (SBAR (IN that) (S (NP (NP (DT the) (JJ easy) (NNS gains)) (PP (IN in) (S (VP (VBG narrowing) (NP (DT the) (NN trade) (NN gap)))))) (VP (VBP have) (ADVP (RB already)) (VP (VBN been) (VP (VBN made))))))) (. .))
(S (`` ``) (S (NP (NN Trade)) (VP (VBZ is) (ADVP (RB definitely)) (VP (VBG going) (S (VP (TO to) (VP (VB be) (ADJP (RBR more) (RB politically) (JJ sensitive)) (PP (IN over) (NP (DT the) (JJ next) (QP (CD six) (CC or) (CD seven)) (NNS months))) (SBAR (IN as) (S (NP (NN improvement)) (VP (VBZ begins) (S (VP (TO to) (VP (VB slow))))))))))))) (, ,) ('' '') (NP (PRP he)) (VP (VBD said)) (. .))
(S (S (NP (NNS Exports)) (VP (VBP are) (VP (VBN thought) (S (VP (TO to) (VP (VB have) (VP (VBN risen) (ADVP (ADVP (RB strongly) (PP (IN in) (NP (NNP August)))) (, ,) (CC but) (ADVP (ADVP (RB probably)) (RB not) (RB enough) (S (VP (TO to) (VP (VB offset) (NP (NP (DT the) (NN jump)) (PP (IN in) (NP (NNS imports)))))))))))))))) (, ,) (NP (NNS economists)) (VP (VBD said)) (. .))
(S (NP (NP (NNS Views)) (PP (IN on) (NP (VBG manufacturing) (NN strength)))) (VP (VBP are) (ADJP (VBN split) (PP (IN between) (NP (NP (NP (NNS economists)) (SBAR (WHNP (WP who)) (S (VP (VBP read) (NP (NP (NP (NNP September) (POS 's)) (JJ low) (NN level)) (PP (IN of) (NP (NN factory) (NN job) (NN growth)))) (PP (IN as) (NP (NP (DT a) (NN sign)) (PP (IN of) (NP (DT a) (NN slowdown))))))))) (CC and) (NP (NP (DT those)) (SBAR (WHNP (WP who)) (S (VP (VBP use) (NP (DT the) (ADJP (RB somewhat) (JJR more) (VBG comforting)) (JJ total) (NN employment) (NNS figures)) (PP (IN in) (NP (PRP$ their) (NNS calculations))))))))))) (. .))
(S (S (NP (NP (DT The) (JJ wide) (NN range)) (PP (IN of) (NP (NP (NNS estimates)) (PP (IN for) (NP (DT the) (JJ industrial) (NN output) (NN number)))))) (VP (VBZ underscores) (NP (DT the) (NNS differences)))) (: :) (S (NP (DT The) (NNS forecasts)) (VP (VBD run) (PP (IN from) (NP (NP (DT a) (NN drop)) (PP (IN of) (NP (CD 0.5) (NN %))))) (PP (TO to) (NP (NP (DT an) (NN increase)) (PP (IN of) (NP (CD 0.4) (NN %))))) (, ,) (PP (VBG according) (PP (TO to) (NP (NNP MMS)))))) (. .))
(S (NP (NP (DT A) (NN rebound)) (PP (IN in) (NP (NN energy) (NNS prices))) (, ,) (SBAR (WHNP (WDT which)) (S (VP (VBD helped) (VP (VB push) (PRT (RP up)) (NP (DT the) (NN producer) (NN price) (NN index)))))) (, ,)) (VP (VBZ is) (VP (VBN expected) (S (VP (TO to) (VP (VB do) (NP (DT the) (JJ same)) (PP (IN in) (NP (DT the) (NN consumer) (NN price) (NN report)))))))) (. .))
(S (NP (DT The) (NN consensus) (NN view)) (VP (VBZ expects) (NP (NP (DT a) (ADJP (CD 0.4) (NN %)) (NN increase)) (PP (IN in) (NP (DT the) (NNP September) (NNP CPI)))) (PP (IN after) (NP (NP (DT a) (JJ flat) (NN reading)) (PP (IN in) (NP (NNP August)))))) (. .))
(S (NP (NP (NNP Robert) (NNP H.) (NNP Chandross)) (, ,) (NP (NP (DT an) (NN economist)) (PP (IN for) (NP (NP (NP (NNP Lloyd) (POS 's)) (NNP Bank)) (PP (IN in) (NP (NNP New) (NNP York)))))) (, ,)) (VP (VBZ is) (PP (IN among) (NP (NP (DT those)) (VP (VBG expecting) (NP (NP (DT a) (ADJP (RBR more) (JJ moderate)) (NN gain)) (PP (IN in) (NP (DT the) (NNP CPI))) (PP (IN than) (PP (IN in) (NP (NP (NNS prices)) (PP (IN at) (NP (DT the) (NN producer) (NN level))))))))))) (. .))
(S (`` ``) (S (S (NP (NN Auto) (NNS prices)) (VP (VBD had) (NP (DT a) (JJ big) (NN effect)) (PP (IN in) (NP (DT the) (NNP PPI))))) (, ,) (CC and) (S (PP (IN at) (NP (DT the) (NNP CPI) (NN level))) (NP (PRP they)) (VP (MD wo) (RB n't)))) (, ,) ('' '') (NP (PRP he)) (VP (VBD said)) (. .))
(SINV (S (S (NP (NN Food) (NNS prices)) (VP (VBP are) (VP (VBN expected) (S (VP (TO to) (VP (VB be) (ADJP (JJ unchanged)))))))) (, ,) (CC but) (S (NP (NN energy) (NNS costs)) (VP (VBD jumped) (NP (NP (RB as) (RB much) (IN as) (CD 4)) (NN %))))) (, ,) (VP (VBD said)) (NP (NP (NNP Gary) (NNP Ciminero)) (, ,) (NP (NP (NN economist)) (PP (IN at) (NP (NNP Fleet\/Norstar) (NNP Financial) (NNP Group))))) (. .))
(S (NP (PRP He)) (ADVP (RB also)) (VP (VBZ says) (SBAR (S (NP (PRP he)) (VP (VBZ thinks) (SBAR (S (NP (`` ``) (NP (NN core) (NN inflation)) (, ,) ('' '') (SBAR (WHNP (WDT which)) (S (VP (VBZ excludes) (NP (DT the) (JJ volatile) (NN food) (CC and) (NN energy) (NNS prices))))) (, ,)) (VP (VBD was) (ADJP (JJ strong)) (NP (JJ last) (NN month))))))))) (. .))
================================================
FILE: data.py
================================================
#!/usr/bin/env python3
import numpy as np
import torch
import pickle
class Dataset(object):
def __init__(self, data_file):
data = pickle.load(open(data_file, 'rb')) #get text data
self.sents = self._convert(data['source']).long()
self.other_data = data['other_data']
self.sent_lengths = self._convert(data['source_l']).long()
self.batch_size = self._convert(data['batch_l']).long()
self.batch_idx = self._convert(data['batch_idx']).long()
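# sentences are pre-bucketed into batches of uniform length; batch_idx[i] is the
# offset of batch i into self.sents and batch_size[i] its number of sentences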
self.vocab_size = data['vocab_size'][0]
self.num_batches = self.batch_idx.size(0)
self.word2idx = data['word2idx']
self.idx2word = data['idx2word']
def _convert(self, x):
return torch.from_numpy(np.asarray(x))
def __len__(self):
return self.num_batches
def __getitem__(self, idx):
assert(idx < self.num_batches and idx >= 0)
start_idx = self.batch_idx[idx]
end_idx = start_idx + self.batch_size[idx]
length = self.sent_lengths[idx].item()
sents = self.sents[start_idx:end_idx]
other_data = self.other_data[start_idx:end_idx]
sent_str = [d[0] for d in other_data]
tags = [d[1] for d in other_data]
actions = [d[2] for d in other_data]
binary_tree = [d[3] for d in other_data]
spans = [d[5] for d in other_data]
batch_size = self.batch_size[idx].item()
# by default, we return sents with <s> </s> tokens
# hence we subtract 2 from length as these are (by default) not counted for evaluation
data_batch = [sents[:, :length], length-2, batch_size, actions,
spans, binary_tree, other_data]
return data_batch
================================================
FILE: eval_ppl.py
================================================
#!/usr/bin/env python3
import sys
import os
import argparse
import json
import random
import shutil
import copy
import torch
from torch import cuda
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import numpy as np
import time
import logging
from data import Dataset
from models import RNNG
from utils import *
parser = argparse.ArgumentParser()
# Data path options
parser.add_argument('--test_file', default='data/ptb-test.pkl')
parser.add_argument('--model_file', default='')
parser.add_argument('--is_temp', default=2., type=float, help='divide scores by is_temp before CRF')
parser.add_argument('--samples', default=1000, type=int, help='samples for IS calculation')
parser.add_argument('--count_eos_ppl', default=0, type=int, help='whether to count eos in val PPL')
parser.add_argument('--gpu', default=2, type=int, help='which gpu to use')
parser.add_argument('--seed', default=3435, type=int)
def main(args):
np.random.seed(args.seed)
torch.manual_seed(args.seed)
data = Dataset(args.test_file)
checkpoint = torch.load(args.model_file)
model = checkpoint['model']
print("model architecture")
print(model)
cuda.set_device(args.gpu)
model.cuda()
model.eval()
num_sents = 0
num_words = 0
total_nll_recon = 0.
total_kl = 0.
total_nll_iwae = 0.
samples_batch = 50
S = args.samples // samples_batch
samples = S*samples_batch
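# the importance samples are drawn in S chunks of samples_batch each to bound GPU memory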
with torch.no_grad():
for i in list(reversed(range(len(data)))):
sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i]
if length == 1:
# length 1 sents are ignored since our generative model requires sents of length >= 2
continue
if args.count_eos_ppl == 1:
length += 1
else:
sents = sents[:, :-1]
sents = sents.cuda()
ll_word_all2 = []
ll_action_p_all2 = []
ll_action_q_all2 = []
for j in range(S):
ll_word_all, ll_action_p_all, ll_action_q_all, actions_all, q_entropy = model(
sents, samples = samples_batch, is_temp = args.is_temp, has_eos = args.count_eos_ppl == 1)
ll_word_all2.append(ll_word_all.detach().cpu())
ll_action_p_all2.append(ll_action_p_all.detach().cpu())
ll_action_q_all2.append(ll_action_q_all.detach().cpu())
ll_word_all2 = torch.cat(ll_word_all2, 1)
ll_action_p_all2 = torch.cat(ll_action_p_all2, 1)
ll_action_q_all2 = torch.cat(ll_action_q_all2, 1)
sample_ll = torch.zeros(batch_size, ll_word_all2.size(1))
total_nll_recon += -ll_word_all.mean(1).sum().item()
total_kl += (ll_action_q_all - ll_action_p_all).mean(1).sum().item()
for j in range(sample_ll.size(1)):
ll_word_j, ll_action_p_j, ll_action_q_j = ll_word_all2[:, j], ll_action_p_all2[:, j], ll_action_q_all2[:, j]
sample_ll[:, j].copy_(ll_word_j + ll_action_p_j - ll_action_q_j)
ll_iwae = model.logsumexp(sample_ll, 1) - np.log(samples)
total_nll_iwae -= ll_iwae.sum().item()
num_sents += batch_size
num_words += batch_size * length
print('Batch: %d/%d, ElboPPL: %.2f, KL: %.4f, IwaePPL: %.2f' %
(i, len(data), np.exp((total_nll_recon + total_kl) / num_words),
total_kl / num_sents, np.exp(total_nll_iwae / num_words)))
elbo_ppl = np.exp((total_nll_recon + total_kl) / num_words)
recon_ppl = np.exp(total_nll_recon / num_words)
iwae_ppl = np.exp(total_nll_iwae /num_words)
kl = total_kl / num_sents
print('ElboPPL: %.2f, ReconPPL: %.2f, KL: %.4f, IwaePPL: %.2f' %
(elbo_ppl, recon_ppl, kl, iwae_ppl))
if __name__ == '__main__':
args = parser.parse_args()
main(args)
================================================
FILE: models.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from utils import *
from TreeCRF import ConstituencyTreeCRF
from torch.distributions import Bernoulli
class RNNLM(nn.Module):
def __init__(self, vocab=10000,
w_dim=650,
h_dim=650,
num_layers=2,
dropout=0.5):
super(RNNLM, self).__init__()
self.h_dim = h_dim
self.num_layers = num_layers
self.word_vecs = nn.Embedding(vocab, w_dim)
self.dropout = nn.Dropout(dropout)
self.rnn = nn.LSTM(w_dim, h_dim, num_layers = num_layers,
dropout = dropout, batch_first = True)
self.vocab_linear = nn.Linear(h_dim, vocab)
self.vocab_linear.weight = self.word_vecs.weight # weight sharing
def forward(self, sent):
word_vecs = self.dropout(self.word_vecs(sent[:, :-1]))
h, _ = self.rnn(word_vecs)
log_prob = F.log_softmax(self.vocab_linear(self.dropout(h)), 2) # b x l x v
ll = torch.gather(log_prob, 2, sent[:, 1:].unsqueeze(2)).squeeze(2)
return ll.sum(1)
def generate(self, bos = 2, eos = 3, max_len = 150):
x = []
bos = torch.LongTensor(1,1).cuda().fill_(bos)
emb = self.dropout(self.word_vecs(bos))
prev_h = None
for l in range(max_len):
h, prev_h = self.rnn(emb, prev_h)
prob = F.softmax(self.vocab_linear(self.dropout(h.squeeze(1))), 1)
sample = torch.multinomial(prob, 1)
emb = self.dropout(self.word_vecs(sample))
x.append(sample.item())
if x[-1] == eos:
x.pop()
break
return x
class SeqLSTM(nn.Module):
def __init__(self, i_dim = 200,
h_dim = 0,
num_layers = 1,
dropout = 0):
super(SeqLSTM, self).__init__()
self.i_dim = i_dim
self.h_dim = h_dim
self.num_layers = num_layers
self.linears = nn.ModuleList([nn.Linear(h_dim + i_dim, h_dim*4) if l == 0 else
nn.Linear(h_dim*2, h_dim*4) for l in range(num_layers)])
self.dropout = dropout
self.dropout_layer = nn.Dropout(dropout)
def forward(self, x, prev_h = None):
if prev_h is None:
prev_h = [(x.new(x.size(0), self.h_dim).fill_(0),
x.new(x.size(0), self.h_dim).fill_(0)) for _ in range(self.num_layers)]
curr_h = []
for l in range(self.num_layers):
input = x if l == 0 else curr_h[l-1][0]
if l > 0 and self.dropout > 0:
input = self.dropout_layer(input)
concat = torch.cat([input, prev_h[l][0]], 1)
all_sum = self.linears[l](concat)
i, f, o, g = all_sum.split(self.h_dim, 1)
c = F.sigmoid(f)*prev_h[l][1] + F.sigmoid(i)*F.tanh(g)
h = F.sigmoid(o)*F.tanh(c)
curr_h.append((h, c))
return curr_h
class TreeLSTM(nn.Module):
def __init__(self, dim = 200):
super(TreeLSTM, self).__init__()
self.dim = dim
self.linear = nn.Linear(dim*2, dim*5)
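# one projection produces the five gates: input, left/right forget, output, candidate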
def forward(self, x1, x2, e=None):
if not isinstance(x1, tuple):
x1 = (x1, None)
h1, c1 = x1
if x2 is None:
x2 = (torch.zeros_like(h1), torch.zeros_like(h1))
elif not isinstance(x2, tuple):
x2 = (x2, None)
h2, c2 = x2
if c1 is None:
c1 = torch.zeros_like(h1)
if c2 is None:
c2 = torch.zeros_like(h2)
concat = torch.cat([h1, h2], 1)
all_sum = self.linear(concat)
i, f1, f2, o, g = all_sum.split(self.dim, 1)
c = F.sigmoid(f1)*c1 + F.sigmoid(f2)*c2 + F.sigmoid(i)*F.tanh(g)
h = F.sigmoid(o)*F.tanh(c)
return (h, c)
class RNNG(nn.Module):
def __init__(self, vocab = 100,
w_dim = 20,
h_dim = 20,
num_layers = 1,
dropout = 0,
q_dim = 20,
max_len = 250):
super(RNNG, self).__init__()
self.S = 0 #action idx for shift/generate
self.R = 1 #action idx for reduce
self.emb = nn.Embedding(vocab, w_dim)
self.dropout = nn.Dropout(dropout)
self.stack_rnn = SeqLSTM(w_dim, h_dim, num_layers = num_layers, dropout = dropout)
self.tree_rnn = TreeLSTM(w_dim)
self.vocab_mlp = nn.Sequential(nn.Dropout(dropout), nn.Linear(h_dim, vocab))
self.num_layers = num_layers
self.q_binary = nn.Sequential(nn.Linear(q_dim*2, q_dim*2), nn.ReLU(), nn.LayerNorm(q_dim*2),
nn.Dropout(dropout), nn.Linear(q_dim*2, 1))
self.action_mlp_p = nn.Sequential(nn.Dropout(dropout), nn.Linear(h_dim, 1))
self.w_dim = w_dim
self.h_dim = h_dim
self.q_dim = q_dim
self.q_leaf_rnn = nn.LSTM(w_dim, q_dim, bidirectional = True, batch_first = True)
self.q_crf = ConstituencyTreeCRF()
self.pad1 = 0 # idx for <s> token from ptb.dict
self.pad2 = 2 # idx for </s> token from ptb.dict
self.q_pos_emb = nn.Embedding(max_len, w_dim) # position embeddings
self.vocab_mlp[-1].weight = self.emb.weight #share embeddings
def get_span_scores(self, x):
#produces the span scores s_ij
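# span representations use the BiLSTM subtraction trick: for span (i, j), the forward
# feature is fwd[j] - fwd[i-1] and the backward feature is bwd[i] - bwd[j+1]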
bos = x.new(x.size(0), 1).fill_(self.pad1)
eos = x.new(x.size(0), 1).fill_(self.pad2)
x = torch.cat([bos, x, eos], 1)
x_vec = self.dropout(self.emb(x))
pos = torch.arange(0, x.size(1)).unsqueeze(0).expand_as(x).long().cuda()
x_vec = x_vec + self.dropout(self.q_pos_emb(pos))
q_h, _ = self.q_leaf_rnn(x_vec)
fwd = q_h[:, 1:, :self.q_dim]
bwd = q_h[:, :-1, self.q_dim:]
fwd_diff = fwd[:, 1:].unsqueeze(1) - fwd[:, :-1].unsqueeze(2)
bwd_diff = bwd[:, :-1].unsqueeze(2) - bwd[:, 1:].unsqueeze(1)
concat = torch.cat([fwd_diff, bwd_diff], 3)
scores = self.q_binary(concat).squeeze(3)
return scores
def get_action_masks(self, actions, length):
#this masks out actions so that we don't incur a loss if some actions are deterministic
#in practice this doesn't really seem to matter
mask = actions.new(actions.size(0), actions.size(1)).fill_(1)
for b in range(actions.size(0)):
num_shift = 0
stack_len = 0
for l in range(actions.size(1)):
if stack_len < 2:
mask[b][l].fill_(0)
if actions[b][l].item() == self.S:
num_shift += 1
stack_len += 1
else:
stack_len -= 1
return mask
def forward(self, x, samples = 1, is_temp = 1., has_eos=True):
#If has_eos is True, the input ends with </s>, which the inference network ignores.
#Note that </s> is still predicted during training since we want the model to know when to stop.
#However it is ignored for PPL evaluation on the version of the PTB dataset from
#the original RNNG paper (Dyer et al. 2016)
init_emb = self.dropout(self.emb(x[:, 0]))
x = x[:, 1:]
batch, length = x.size(0), x.size(1)
if has_eos:
parse_length = length - 1
parse_x = x[:, :-1]
else:
parse_length = length
parse_x = x
word_vecs = self.dropout(self.emb(x))
scores = self.get_span_scores(parse_x)
self.scores = scores
scores = scores / is_temp
self.q_crf._forward(scores)
self.q_crf._entropy(scores)
entropy = self.q_crf.entropy[0][parse_length-1]
crf_input = scores.unsqueeze(1).expand(batch, samples, parse_length, parse_length)
crf_input = crf_input.contiguous().view(batch*samples, parse_length, parse_length)
for i in range(len(self.q_crf.alpha)):
for j in range(len(self.q_crf.alpha)):
self.q_crf.alpha[i][j] = self.q_crf.alpha[i][j].unsqueeze(1).expand(
batch, samples).contiguous().view(batch*samples)
_, log_probs_action_q, tree_brackets, spans = self.q_crf._sample(crf_input, self.q_crf.alpha)
actions = []
for b in range(crf_input.size(0)):
action = get_actions(tree_brackets[b])
if has_eos:
actions.append(action + [self.S, self.R]) #we train the model to generate </s> and then do a final reduce
else:
actions.append(action)
actions = torch.Tensor(actions).float().cuda()
action_masks = self.get_action_masks(actions, length)
num_action = 2*length - 1
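# a binary tree over length terminals requires length SHIFTs and length-1 REDUCEs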
batch_expand = batch*samples
contexts = []
log_probs_action_p = [] #conditional prior
init_emb = init_emb.unsqueeze(1).expand(batch, samples, self.w_dim)
init_emb = init_emb.contiguous().view(batch_expand, self.w_dim)
init_stack = self.stack_rnn(init_emb, None)
x_expand = x.unsqueeze(1).expand(batch, samples, length)
x_expand = x_expand.contiguous().view(batch_expand, length)
word_vecs = self.dropout(self.emb(x_expand))
word_vecs = word_vecs.unsqueeze(2)
word_vecs_zeros = torch.zeros_like(word_vecs)
stack = [init_stack]
stack_child = [[] for _ in range(batch_expand)]
stack2 = [[] for _ in range(batch_expand)]
for b in range(batch_expand):
stack2[b].append([[init_stack[l][0][b], init_stack[l][1][b]] for l in range(self.num_layers)])
pointer = [0]*batch_expand
for l in range(num_action):
contexts.append(stack[-1][-1][0])
stack_input = []
child1_h = []
child1_c = []
child2_h = []
child2_c = []
stack_context = []
for b in range(batch_expand):
# batch all the shift/reduce operations separately
if actions[b][l].item() == self.R:
child1 = stack_child[b].pop()
child2 = stack_child[b].pop()
child1_h.append(child1[0])
child1_c.append(child1[1])
child2_h.append(child2[0])
child2_c.append(child2[1])
stack2[b].pop()
stack2[b].pop()
if len(child1_h) > 0:
child1_h = torch.cat(child1_h, 0)
child1_c = torch.cat(child1_c, 0)
child2_h = torch.cat(child2_h, 0)
child2_c = torch.cat(child2_c, 0)
new_child = self.tree_rnn((child1_h, child1_c), (child2_h, child2_c))
child_idx = 0
stack_h = [[[], []] for _ in range(self.num_layers)]
for b in range(batch_expand):
assert(len(stack2[b]) - 1 == len(stack_child[b]))
for k in range(self.num_layers):
stack_h[k][0].append(stack2[b][-1][k][0])
stack_h[k][1].append(stack2[b][-1][k][1])
if actions[b][l].item() == self.S:
input_b = word_vecs[b][pointer[b]]
stack_child[b].append((word_vecs[b][pointer[b]], word_vecs_zeros[b][pointer[b]]))
pointer[b] += 1
else:
input_b = new_child[0][child_idx].unsqueeze(0)
stack_child[b].append((input_b, new_child[1][child_idx].unsqueeze(0)))
child_idx += 1
stack_input.append(input_b)
stack_input = torch.cat(stack_input, 0)
stack_h_all = []
for k in range(self.num_layers):
stack_h_all.append((torch.stack(stack_h[k][0], 0), torch.stack(stack_h[k][1], 0)))
stack_h = self.stack_rnn(stack_input, stack_h_all)
stack.append(stack_h)
for b in range(batch_expand):
stack2[b].append([[stack_h[k][0][b], stack_h[k][1][b]] for k in range(self.num_layers)])
contexts = torch.stack(contexts, 1) #stack contexts
action_logit_p = self.action_mlp_p(contexts).squeeze(2)
action_prob_p = F.sigmoid(action_logit_p).clamp(min=1e-7, max=1-1e-7)
action_shift_score = (1 - action_prob_p).log()
action_reduce_score = action_prob_p.log()
action_score = (1-actions)*action_shift_score + actions*action_reduce_score
action_score = (action_score*action_masks).sum(1)
word_contexts = contexts[actions < 1]
word_contexts = word_contexts.contiguous().view(batch*samples, length, self.h_dim)
log_probs_word = F.log_softmax(self.vocab_mlp(word_contexts), 2)
log_probs_word = torch.gather(log_probs_word, 2, x_expand.unsqueeze(2)).squeeze(2)
log_probs_word = log_probs_word.sum(1)
log_probs_word = log_probs_word.contiguous().view(batch, samples)
log_probs_action_p = action_score.contiguous().view(batch, samples)
log_probs_action_q = log_probs_action_q.contiguous().view(batch, samples)
actions = actions.contiguous().view(batch, samples, -1)
return log_probs_word, log_probs_action_p, log_probs_action_q, actions, entropy
def forward_actions(self, x, actions, has_eos=True):
# this is for when ground truth actions are available
init_emb = self.dropout(self.emb(x[:, 0]))
x = x[:, 1:]
if has_eos:
new_actions = []
for action in actions:
new_actions.append(action + [self.S, self.R])
actions = new_actions
batch, length = x.size(0), x.size(1)
word_vecs = self.dropout(self.emb(x))
actions = torch.Tensor(actions).float().cuda()
action_masks = self.get_action_masks(actions, length)
num_action = 2*length - 1
contexts = []
log_probs_action_p = [] #prior
init_stack = self.stack_rnn(init_emb, None)
word_vecs = word_vecs.unsqueeze(2)
word_vecs_zeros = torch.zeros_like(word_vecs)
stack = [init_stack]
stack_child = [[] for _ in range(batch)]
stack2 = [[] for _ in range(batch)]
pointer = [0]*batch
for b in range(batch):
stack2[b].append([[init_stack[l][0][b], init_stack[l][1][b]] for l in range(self.num_layers)])
for l in range(num_action):
contexts.append(stack[-1][-1][0])
stack_input = []
child1_h = []
child1_c = []
child2_h = []
child2_c = []
stack_context = []
for b in range(batch):
if actions[b][l].item() == self.R:
child1 = stack_child[b].pop()
child2 = stack_child[b].pop()
child1_h.append(child1[0])
child1_c.append(child1[1])
child2_h.append(child2[0])
child2_c.append(child2[1])
stack2[b].pop()
stack2[b].pop()
if len(child1_h) > 0:
child1_h = torch.cat(child1_h, 0)
child1_c = torch.cat(child1_c, 0)
child2_h = torch.cat(child2_h, 0)
child2_c = torch.cat(child2_c, 0)
new_child = self.tree_rnn((child1_h, child1_c), (child2_h, child2_c))
child_idx = 0
stack_h = [[[], []] for _ in range(self.num_layers)]
for b in range(batch):
assert(len(stack2[b]) - 1 == len(stack_child[b]))
for k in range(self.num_layers):
stack_h[k][0].append(stack2[b][-1][k][0])
stack_h[k][1].append(stack2[b][-1][k][1])
if actions[b][l].item() == self.S:
input_b = word_vecs[b][pointer[b]]
stack_child[b].append((word_vecs[b][pointer[b]], word_vecs_zeros[b][pointer[b]]))
pointer[b] += 1
else:
input_b = new_child[0][child_idx].unsqueeze(0)
stack_child[b].append((input_b, new_child[1][child_idx].unsqueeze(0)))
child_idx += 1
stack_input.append(input_b)
stack_input = torch.cat(stack_input, 0)
stack_h_all = []
for k in range(self.num_layers):
stack_h_all.append((torch.stack(stack_h[k][0], 0), torch.stack(stack_h[k][1], 0)))
stack_h = self.stack_rnn(stack_input, stack_h_all)
stack.append(stack_h)
for b in range(batch):
stack2[b].append([[stack_h[k][0][b], stack_h[k][1][b]] for k in range(self.num_layers)])
contexts = torch.stack(contexts, 1)
action_logit_p = self.action_mlp_p(contexts).squeeze(2)
action_prob_p = F.sigmoid(action_logit_p).clamp(min=1e-7, max=1-1e-7)
action_shift_score = (1 - action_prob_p).log()
action_reduce_score = action_prob_p.log()
action_score = (1-actions)*action_shift_score + actions*action_reduce_score
action_score = (action_score*action_masks).sum(1)
word_contexts = contexts[actions < 1]
word_contexts = word_contexts.contiguous().view(batch, length, self.h_dim)
log_probs_word = F.log_softmax(self.vocab_mlp(word_contexts), 2)
log_probs_word = torch.gather(log_probs_word, 2, x.unsqueeze(2)).squeeze(2).sum(1)
log_probs_action_p = action_score.contiguous().view(batch)
actions = actions.contiguous().view(batch, 1, -1)
return log_probs_word, log_probs_action_p, actions
def forward_tree(self, x, actions, has_eos=True):
# this is log q( tree | x) for discriminative parser training in supervised RNNG
init_emb = self.dropout(self.emb(x[:, 0]))
x = x[:, 1:-1]
batch, length = x.size(0), x.size(1)
scores = self.get_span_scores(x)
crf_input = scores
gold_spans = scores.new(batch, length, length)
for b in range(batch):
gold_spans[b].copy_(torch.eye(length).cuda())
spans = get_spans(actions[b])
for span in spans:
gold_spans[b][span[0]][span[1]] = 1
self.q_crf._forward(crf_input)
log_Z = self.q_crf.alpha[0][length-1]
span_scores = (gold_spans*scores).sum(2).sum(1)
ll_action_q = span_scores - log_Z
return ll_action_q
def logsumexp(self, x, dim=1):
d = torch.max(x, dim)[0]
if x.dim() == 1:
return torch.log(torch.exp(x - d).sum(dim)) + d
else:
return torch.log(torch.exp(x - d.unsqueeze(dim).expand_as(x)).sum(dim)) + d
================================================
FILE: parse.py
================================================
#!/usr/bin/env python3
import sys
import os
import argparse
import json
import random
import shutil
import copy
import torch
from torch import cuda
import torch.nn as nn
import numpy as np
import time
from utils import *
import utils
import re
parser = argparse.ArgumentParser()
# Data path options
parser.add_argument('--data_file', default='ptb-test.txt')
parser.add_argument('--model_file', default='urnng.pt')
parser.add_argument('--out_file', default='pred-parse.txt')
parser.add_argument('--gold_out_file', default='gold-parse.txt')
parser.add_argument('--lowercase', type=int, default=0)
parser.add_argument('--replace_num', type=int, default=0)
# Inference options
parser.add_argument('--gpu', default=0, type=int, help='which gpu to use')
def is_next_open_bracket(line, start_idx):
for char in line[(start_idx + 1):]:
if char == '(':
return True
elif char == ')':
return False
raise IndexError('Bracket possibly not balanced, open bracket not followed by closed bracket')
def get_between_brackets(line, start_idx):
output = []
for char in line[(start_idx + 1):]:
if char == ')':
break
assert not(char == '(')
output.append(char)
return ''.join(output)
def get_tags_tokens_lowercase(line):
output = []
line_strip = line.rstrip()
for i in range(len(line_strip)):
if i == 0:
assert line_strip[i] == '('
if line_strip[i] == '(' and not(is_next_open_bracket(line_strip, i)): # fulfilling this condition means this is a terminal symbol
output.append(get_between_brackets(line_strip, i))
#print 'output:',output
output_tags = []
output_tokens = []
output_lowercase = []
for terminal in output:
terminal_split = terminal.split()
assert len(terminal_split) == 2 # each terminal contains a POS tag and word
output_tags.append(terminal_split[0])
output_tokens.append(terminal_split[1])
output_lowercase.append(terminal_split[1].lower())
return [output_tags, output_tokens, output_lowercase]
def get_nonterminal(line, start_idx):
assert line[start_idx] == '(' # make sure it's an open bracket
output = []
for char in line[(start_idx + 1):]:
if char == ' ':
break
assert not(char == '(') and not(char == ')')
output.append(char)
return ''.join(output)
def get_actions(line):
output_actions = []
line_strip = line.rstrip()
i = 0
max_idx = (len(line_strip) - 1)
while i <= max_idx:
assert line_strip[i] == '(' or line_strip[i] == ')'
if line_strip[i] == '(':
if is_next_open_bracket(line_strip, i): # open non-terminal
curr_NT = get_nonterminal(line_strip, i)
output_actions.append('NT(' + curr_NT + ')')
i += 1
while line_strip[i] != '(': # get the next open bracket, which may be a terminal or another non-terminal
i += 1
else: # it's a terminal symbol
output_actions.append('SHIFT')
while line_strip[i] != ')':
i += 1
i += 1
while line_strip[i] != ')' and line_strip[i] != '(':
i += 1
else:
output_actions.append('REDUCE')
if i == max_idx:
break
i += 1
while line_strip[i] != ')' and line_strip[i] != '(':
i += 1
assert i == max_idx
return output_actions
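# Worked example (added for illustration): for the bracketed tree
#   (S (NP (DT the) (NN dog)) (VP (VBZ barks)))
# get_actions returns
#   ['NT(S)', 'NT(NP)', 'SHIFT', 'SHIFT', 'REDUCE', 'NT(VP)', 'SHIFT',
#    'REDUCE', 'REDUCE']
# i.e. NT(X) for each non-terminal open bracket, SHIFT for each terminal,
# and REDUCE for each closing bracket of a non-terminal.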
def clean_number(w):
new_w = re.sub('[0-9]{1,}([,.]?[0-9]*)*', 'N', w)
return new_w
def main(args):
print('loading model from ' + args.model_file)
checkpoint = torch.load(args.model_file)
model = checkpoint['model']
word2idx = checkpoint['word2idx']
cuda.set_device(args.gpu)
model.eval()
model.cuda()
corpus_f1 = [0., 0., 0.]
sent_f1 = []
pred_out = open(args.out_file, "w")
gold_out = open(args.gold_out_file, "w")
with torch.no_grad():
for j, gold_tree in enumerate(open(args.data_file, "r")):
tree = gold_tree.strip()
action = get_actions(tree)
tags, sent, sent_lower = get_tags_tokens_lowercase(tree)
sent_orig = sent[::]
if args.lowercase == 1:
sent = sent_lower
gold_span, binary_actions, nonbinary_actions = get_nonbinary_spans(action)
length = len(sent)
if args.replace_num == 1:
sent = [clean_number(w) for w in sent]
if length == 1:
continue # we ignore length 1 sents. this doesn't change F1 since we discard trivial spans
sent_idx = [word2idx["<s>"]] + [word2idx[w] if w in word2idx else word2idx["<unk>"] for w in sent]
sents = torch.from_numpy(np.array(sent_idx)).unsqueeze(0)
sents = sents.cuda()
ll_word_all, ll_action_p_all, ll_action_q_all, actions_all, q_entropy = model(
sents, samples = 1, is_temp = 1, has_eos = False)
_, binary_matrix, argmax_spans = model.q_crf._viterbi(model.scores)
tree = get_tree_from_binary_matrix(binary_matrix[0], len(sent))
actions = utils.get_actions(tree)
pred_span = [(a[0], a[1]) for a in argmax_spans[0]]
pred_span_set = set(pred_span[:-1]) #the last span in the list is always the
gold_span_set = set(gold_span[:-1]) #trivial sent-level span so we ignore it
tp, fp, fn = get_stats(pred_span_set, gold_span_set)
corpus_f1[0] += tp
corpus_f1[1] += fp
corpus_f1[2] += fn
binary_matrix = binary_matrix[0].cpu().numpy()
pred_tree = {}
for i in range(length):
tag = tags[i] # need gold tags so evalb correctly ignores punctuation
pred_tree[i] = "(" + tag + " " + sent_orig[i] + ")"
for k in np.arange(1, length):
for s in np.arange(length):
t = s + k
if t > length - 1: break
if binary_matrix[s][t] == 1:
nt = "NT-1"
span = "(" + nt + " " + pred_tree[s] + " " + pred_tree[t] + ")"
pred_tree[s] = span
pred_tree[t] = span
pred_tree = pred_tree[0]
pred_out.write(pred_tree.strip() + "\n")
gold_out.write(gold_tree.strip() + "\n")
print(pred_tree)
# sent-level F1 is based on L83-89 from https://github.com/yikangshen/PRPN/test_phrase_grammar.py
overlap = pred_span_set.intersection(gold_span_set)
prec = float(len(overlap)) / (len(pred_span_set) + 1e-8)
reca = float(len(overlap)) / (len(gold_span_set) + 1e-8)
if len(gold_span_set) == 0:
reca = 1.
if len(pred_span_set) == 0:
prec = 1.
f1 = 2 * prec * reca / (prec + reca + 1e-8)
sent_f1.append(f1)
pred_out.close()
gold_out.close()
tp, fp, fn = corpus_f1
prec = tp / (tp + fp)
recall = tp / (tp + fn)
corpus_f1 = 2*prec*recall/(prec+recall) if prec+recall > 0 else 0.
print('Corpus F1: %.2f, Sentence F1: %.2f' %
(corpus_f1*100, np.mean(np.array(sent_f1))*100))
if __name__ == '__main__':
args = parser.parse_args()
main(args)
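# Example invocation (hypothetical file paths; the flags are as defined above):
#   python parse.py --model_file urnng.pt --data_file ptb-test.txt \
#       --out_file pred-parse.txt --gold_out_file gold-parse.txt --gpu 0
# The predicted and gold trees can then be scored externally with evalb
# (e.g. using the COLLINS.prm parameter file shipped with this repo).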
================================================
FILE: preprocess.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Create data files
"""
import os
import sys
import argparse
import numpy as np
import pickle
import itertools
from collections import defaultdict
import utils
import re
class Indexer:
def __init__(self, symbols = ["<pad>","<unk>","<s>","</s>"]):
self.vocab = defaultdict(int)
self.PAD = symbols[0]
self.UNK = symbols[1]
self.BOS = symbols[2]
self.EOS = symbols[3]
self.d = {self.PAD: 0, self.UNK: 1, self.BOS: 2, self.EOS: 3}
self.idx2word = {}
def add_w(self, ws):
for w in ws:
if w not in self.d:
self.d[w] = len(self.d)
def convert(self, w):
return self.d[w] if w in self.d else self.d[self.UNK]
def convert_sequence(self, ls):
return [self.convert(l) for l in ls]
def write(self, outfile):
out = open(outfile, "w")
items = [(v, k) for k, v in self.d.items()]
items.sort()
for v, k in items:
out.write(" ".join([k, str(v)]) + "\n")
out.close()
def prune_vocab(self, k, cnt = False):
vocab_list = [(word, count) for word, count in self.vocab.items()]
if cnt:
self.pruned_vocab = {pair[0]:pair[1] for pair in vocab_list if pair[1] > k}
else:
vocab_list.sort(key = lambda x: x[1], reverse=True)
k = min(k, len(vocab_list))
self.pruned_vocab = {pair[0]:pair[1] for pair in vocab_list[:k]}
for word in self.pruned_vocab:
if word not in self.d:
self.d[word] = len(self.d)
for word, idx in self.d.items():
self.idx2word[idx] = word
def load_vocab(self, vocab_file):
self.d = {}
self.idx2word = {}
for line in open(vocab_file, 'r'):
v, k = line.strip().split()
self.d[v] = int(k)
for word, idx in self.d.items():
self.idx2word[idx] = word
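# Illustrative usage of Indexer (a sketch added here, not in the original
# source). With the default special symbols, ids 0-3 are reserved for
# <pad>/<unk>/<s>/</s>:
#   indexer = Indexer()
#   indexer.vocab['dog'] = 5
#   indexer.vocab['cat'] = 1
#   indexer.prune_vocab(2, cnt=True)  # keep words with count > 2
#   indexer.convert('dog')            # -> 4 (first id after the specials)
#   indexer.convert('cat')            # -> 1, i.e. <unk> (pruned away)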
def is_next_open_bracket(line, start_idx):
for char in line[(start_idx + 1):]:
if char == '(':
return True
elif char == ')':
return False
raise IndexError('Bracket possibly not balanced, open bracket not followed by a closing bracket')
def get_between_brackets(line, start_idx):
output = []
for char in line[(start_idx + 1):]:
if char == ')':
break
assert not(char == '(')
output.append(char)
return ''.join(output)
def get_tags_tokens_lowercase(line):
output = []
line_strip = line.rstrip()
for i in range(len(line_strip)):
if i == 0:
assert line_strip[i] == '('
if line_strip[i] == '(' and not(is_next_open_bracket(line_strip, i)): # fulfilling this condition means this is a terminal symbol
output.append(get_between_brackets(line_strip, i))
#print 'output:',output
output_tags = []
output_tokens = []
output_lowercase = []
for terminal in output:
terminal_split = terminal.split()
# print(terminal, terminal_split)
assert len(terminal_split) == 2 # each terminal contains a POS tag and word
output_tags.append(terminal_split[0])
output_tokens.append(terminal_split[1])
output_lowercase.append(terminal_split[1].lower())
return [output_tags, output_tokens, output_lowercase]
def get_nonterminal(line, start_idx):
assert line[start_idx] == '(' # make sure it's an open bracket
output = []
for char in line[(start_idx + 1):]:
if char == ' ':
break
assert not(char == '(') and not(char == ')')
output.append(char)
return ''.join(output)
def get_actions(line):
output_actions = []
line_strip = line.rstrip()
i = 0
max_idx = (len(line_strip) - 1)
while i <= max_idx:
assert line_strip[i] == '(' or line_strip[i] == ')'
if line_strip[i] == '(':
if is_next_open_bracket(line_strip, i): # open non-terminal
curr_NT = get_nonterminal(line_strip, i)
output_actions.append('NT(' + curr_NT + ')')
i += 1
while line_strip[i] != '(': # get the next open bracket, which may be a terminal or another non-terminal
i += 1
else: # it's a terminal symbol
output_actions.append('SHIFT')
while line_strip[i] != ')':
i += 1
i += 1
while line_strip[i] != ')' and line_strip[i] != '(':
i += 1
else:
output_actions.append('REDUCE')
if i == max_idx:
break
i += 1
while line_strip[i] != ')' and line_strip[i] != '(':
i += 1
assert i == max_idx
return output_actions
def pad(ls, length, symbol):
if len(ls) >= length:
return ls[:length]
return ls + [symbol] * (length - len(ls))
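# Worked example (added for clarity):
#   pad([5, 6], 4, 0) -> [5, 6, 0, 0]    pad([5, 6, 7], 2, 0) -> [5, 6]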
def clean_number(w):
new_w = re.sub('[0-9]{1,}([,.]?[0-9]*)*', 'N', w)
return new_w
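# Worked example (added for clarity): digit runs, optionally with , or .
# separators, collapse to a single N:
#   clean_number('1,250.75') -> 'N'    clean_number('10-K') -> 'N-K'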
def get_data(args):
indexer = Indexer(["<pad>","<unk>","<s>","</s>"])
def make_vocab(textfile, seqlength, minseqlength, lowercase, replace_num,
train=1, apply_length_filter=1):
num_sents = 0
max_seqlength = 0
for tree in open(textfile, 'r'):
tree = tree.strip()
tags, sent, sent_lower = get_tags_tokens_lowercase(tree)
assert(len(tags) == len(sent))
if lowercase == 1:
sent = sent_lower
if replace_num == 1:
sent = [clean_number(w) for w in sent]
if (len(sent) > seqlength and apply_length_filter == 1) or len(sent) < minseqlength:
continue
num_sents += 1
max_seqlength = max(max_seqlength, len(sent))
if train == 1:
for word in sent:
indexer.vocab[word] += 1
return num_sents, max_seqlength
def convert(textfile, lowercase, replace_num,
batchsize, seqlength, minseqlength, outfile, num_sents, max_sent_l=0,
shuffle=0, include_boundary=1, apply_length_filter=1):
newseqlength = seqlength
if include_boundary == 1:
newseqlength += 2 #add 2 for EOS and BOS
sents = np.zeros((num_sents, newseqlength), dtype=int)
sent_lengths = np.zeros((num_sents,), dtype=int)
dropped = 0
sent_id = 0
other_data = []
for tree in open(textfile, 'r'):
tree = tree.strip()
action = get_actions(tree)
tags, sent, sent_lower = get_tags_tokens_lowercase(tree)
assert(len(tags) == len(sent))
if lowercase == 1:
sent = sent_lower
if (len(sent) > seqlength and apply_length_filter == 1) or len(sent) < minseqlength:
continue
sent_str = " ".join(sent)
if replace_num == 1:
sent = [clean_number(w) for w in sent]
if include_boundary == 1:
sent = [indexer.BOS] + sent + [indexer.EOS]
max_sent_l = max(len(sent), max_sent_l)
sent_pad = pad(sent, newseqlength, indexer.PAD)
sents[sent_id] = np.array(indexer.convert_sequence(sent_pad), dtype=int)
sent_lengths[sent_id] = (sents[sent_id] != 0).sum()
span, binary_actions, nonbinary_actions = utils.get_nonbinary_spans(action)
other_data.append([sent_str, tags, action,
binary_actions, nonbinary_actions, span, tree])
assert(2*(len(sent)- 2) - 1 == len(binary_actions))
assert(sum(binary_actions) + 1 == len(sent) - 2)
sent_id += 1
if sent_id % 100000 == 0:
print("{}/{} sentences processed".format(sent_id, num_sents))
print(sent_id, num_sents)
if shuffle == 1:
rand_idx = np.random.permutation(sent_id)
sents = sents[rand_idx]
sent_lengths = sent_lengths[rand_idx]
other_data = [other_data[idx] for idx in rand_idx]
print(len(sents), len(other_data))
#break up batches based on source lengths
sent_lengths = sent_lengths[:sent_id]
sent_sort = np.argsort(sent_lengths)
sents = sents[sent_sort]
other_data = [other_data[idx] for idx in sent_sort]
sent_l = sent_lengths[sent_sort]
curr_l = 1
l_location = [] #idx where sent length changes
for j,i in enumerate(sent_sort):
if sent_lengths[i] > curr_l:
curr_l = sent_lengths[i]
l_location.append(j)
l_location.append(len(sents))
#get batch sizes
curr_idx = 0
batch_idx = [0]
nonzeros = []
batch_l = []
batch_w = []
for i in range(len(l_location)-1):
while curr_idx < l_location[i+1]:
curr_idx = min(curr_idx + batchsize, l_location[i+1])
batch_idx.append(curr_idx)
for i in range(len(batch_idx)-1):
batch_l.append(batch_idx[i+1] - batch_idx[i])
batch_w.append(sent_l[batch_idx[i]])
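# Worked example of the batching above (comments added for clarity): with
# sorted lengths [3, 3, 3, 5, 5] and batchsize 2, l_location marks the
# length-change points, so batch_idx becomes [0, 2, 3, 5] -- batches never
# mix lengths, and batch_w records each batch's sentence length.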
# Write output
f = {}
f["source"] = sents
f["other_data"] = other_data
f["batch_l"] = np.array(batch_l, dtype=int)
f["source_l"] = np.array(batch_w, dtype=int)
f["sents_l"] = np.array(sent_l, dtype = int)
f["batch_idx"] = np.array(batch_idx[:-1], dtype=int)
f["vocab_size"] = np.array([len(indexer.d)])
f["idx2word"] = indexer.idx2word
f["word2idx"] = {word : idx for idx, word in indexer.idx2word.items()}
print("Saved {} sentences (dropped {} due to length/unk filter)".format(
len(f["source"]), dropped))
pickle.dump(f, open(outfile, 'wb'))
return max_sent_l
print("First pass through data to get vocab...")
num_sents_train, train_seqlength = make_vocab(args.trainfile, args.seqlength, args.minseqlength,
args.lowercase, args.replace_num, 1, 1)
print("Number of sentences in training: {}".format(num_sents_train))
num_sents_valid, valid_seqlength = make_vocab(args.valfile, args.seqlength, args.minseqlength,
args.lowercase, args.replace_num, 0, 0)
print("Number of sentences in valid: {}".format(num_sents_valid))
num_sents_test, test_seqlength = make_vocab(args.testfile, args.seqlength, args.minseqlength,
args.lowercase, args.replace_num, 0, 0)
print("Number of sentences in test: {}".format(num_sents_test))
if args.vocabminfreq >= 0:
indexer.prune_vocab(args.vocabminfreq, True)
else:
indexer.prune_vocab(args.vocabsize, False)
if args.vocabfile != '':
print('Loading pre-specified source vocab from ' + args.vocabfile)
indexer.load_vocab(args.vocabfile)
indexer.write(args.outputfile + ".dict")
print("Vocab size: Original = {}, Pruned = {}".format(len(indexer.vocab),
len(indexer.d)))
print(train_seqlength, valid_seqlength, test_seqlength)
max_sent_l = 0
max_sent_l = convert(args.testfile, args.lowercase, args.replace_num,
args.batchsize, test_seqlength, args.minseqlength,
args.outputfile + "-test.pkl", num_sents_test,
max_sent_l, args.shuffle, args.include_boundary, 0)
max_sent_l = convert(args.valfile, args.lowercase, args.replace_num,
args.batchsize, valid_seqlength, args.minseqlength,
args.outputfile + "-val.pkl", num_sents_valid,
max_sent_l, args.shuffle, args.include_boundary, 0)
max_sent_l = convert(args.trainfile, args.lowercase, args.replace_num,
args.batchsize, args.seqlength, args.minseqlength,
args.outputfile + "-train.pkl", num_sents_train,
max_sent_l, args.shuffle, args.include_boundary, 1)
print("Max sent length (before dropping): {}".format(max_sent_l))
def main(arguments):
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--vocabsize', help="Size of source vocabulary, constructed "
"by taking the top X most frequent words. "
" Rest are replaced with special UNK tokens.",
type=int, default=10000)
parser.add_argument('--vocabminfreq', help="Minimum frequency for vocab. Use this instead of "
"vocabsize if > 0",
type=int, default=1)
parser.add_argument('--include_boundary', help="Add BOS/EOS tokens", type=int, default=1)
parser.add_argument('--lowercase', help="Lower case", type=int, default=0)
parser.add_argument('--replace_num', help="Replace numbers with N", type=int, default=0)
parser.add_argument('--trainfile', help="Path to training data.", required=True)
parser.add_argument('--valfile', help="Path to validation data.", required=True)
parser.add_argument('--testfile', help="Path to test data.", required=True)
parser.add_argument('--batchsize', help="Size of each minibatch.", type=int, default=16)
parser.add_argument('--seqlength', help="Maximum sequence length. Sequences longer "
"than this are dropped.", type=int, default=200)
parser.add_argument('--minseqlength', help="Minimum sequence length. Sequences shorter "
"than this are dropped.", type=int, default=0)
parser.add_argument('--outputfile', help="Prefix of the output file names. ", type=str,
required=True)
parser.add_argument('--vocabfile', help="If working with a preset vocab, "
"then including this will ignore srcvocabsize and use the"
"vocab provided here.",
type = str, default='')
parser.add_argument('--shuffle', help="If = 1, shuffle sentences before sorting (based on "
"source length).",
type = int, default = 1)
args = parser.parse_args(arguments)
np.random.seed(3435)
get_data(args)
if __name__ == '__main__':
sys.exit(main(sys.argv[1:]))
================================================
FILE: train.py
================================================
#!/usr/bin/env python3
import sys
import os
import argparse
import json
import random
import shutil
import copy
import torch
from torch import cuda
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import numpy as np
import time
import logging
from data import Dataset
from models import RNNG
from utils import *
parser = argparse.ArgumentParser()
# Data path options
parser.add_argument('--train_file', default='data/ptb-1unk-train.pkl')
parser.add_argument('--val_file', default='data/ptb-1unk-val.pkl')
parser.add_argument('--train_from', default='')
# Model options
parser.add_argument('--w_dim', default=650, type=int, help='hidden dimension for LM/RNNG')
parser.add_argument('--h_dim', default=650, type=int, help='hidden dimension for LM/RNNG')
parser.add_argument('--q_dim', default=256, type=int, help='hidden dimension for variational RNN')
parser.add_argument('--num_layers', default=2, type=int, help='number of layers in LM and the stack LSTM (for RNNG)')
parser.add_argument('--dropout', default=0.5, type=float, help='dropout rate')
# Optimization options
parser.add_argument('--count_eos_ppl', default=0, type=int, help='whether to count eos in val PPL')
parser.add_argument('--save_path', default='urnng.pt', help='where to save the model checkpoint')
parser.add_argument('--num_epochs', default=18, type=int, help='number of training epochs')
parser.add_argument('--min_epochs', default=8, type=int, help='do not decay learning rate for at least this many epochs')
parser.add_argument('--mode', default='unsupervised', type=str, choices=['unsupervised', 'supervised'])
parser.add_argument('--mc_samples', default=5, type=int,
help='number of samples for the IWAE bound at evaluation time')
parser.add_argument('--samples', default=8, type=int,
help='how many samples for score function gradients')
parser.add_argument('--lr', default=1, type=float, help='starting learning rate')
parser.add_argument('--q_lr', default=0.0001, type=float, help='learning rate for inference network q')
parser.add_argument('--action_lr', default=0.1, type=float, help='learning rate for action layer')
parser.add_argument('--decay', default=0.5, type=float, help='learning rate decay factor')
parser.add_argument('--kl_warmup', default=2, type=int, help='number of epochs over which to anneal the KL penalty from 0 to 1')
parser.add_argument('--train_q_epochs', default=2, type=int, help='train the inference network q for only this many epochs')
parser.add_argument('--param_init', default=0.1, type=float, help='parameter initialization (uniform over [-param_init, param_init])')
parser.add_argument('--max_grad_norm', default=5, type=float, help='gradient clipping parameter')
parser.add_argument('--q_max_grad_norm', default=1, type=float, help='gradient clipping parameter for q')
parser.add_argument('--gpu', default=2, type=int, help='which gpu to use')
parser.add_argument('--seed', default=3435, type=int, help='random seed')
parser.add_argument('--print_every', type=int, default=500, help='print stats after this many batches')
def main(args):
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_data = Dataset(args.train_file)
val_data = Dataset(args.val_file)
vocab_size = int(train_data.vocab_size)
print('Train: %d sents / %d batches, Val: %d sents / %d batches' %
(train_data.sents.size(0), len(train_data), val_data.sents.size(0),
len(val_data)))
print('Vocab size: %d' % vocab_size)
cuda.set_device(args.gpu)
if args.train_from == '':
model = RNNG(vocab = vocab_size,
w_dim = args.w_dim,
h_dim = args.h_dim,
dropout = args.dropout,
num_layers = args.num_layers,
q_dim = args.q_dim)
if args.param_init > 0:
for param in model.parameters():
param.data.uniform_(-args.param_init, args.param_init)
else:
print('loading model from ' + args.train_from)
checkpoint = torch.load(args.train_from)
model = checkpoint['model']
print("model architecture")
print(model)
q_params = []
action_params = []
model_params = []
for name, param in model.named_parameters():
if 'action' in name:
print(name)
action_params.append(param)
elif 'q_' in name:
print(name)
q_params.append(param)
else:
model_params.append(param)
q_lr = args.q_lr
optimizer = torch.optim.SGD(model_params, lr=args.lr)
q_optimizer = torch.optim.Adam(q_params, lr=q_lr)
action_optimizer = torch.optim.SGD(action_params, lr=args.action_lr)
model.train()
model.cuda()
epoch = 0
decay = 0
if args.kl_warmup > 0:
kl_pen = 0.
kl_warmup_batch = 1./(args.kl_warmup * len(train_data))
else:
kl_pen = 1.
best_val_ppl = 5e5
best_val_f1 = 0
samples = args.samples
best_val_ppl, best_val_f1 = eval(val_data, model, samples = args.mc_samples,
count_eos_ppl = args.count_eos_ppl)
all_stats = [[0., 0., 0.]] #true pos, false pos, false neg for f1 calc
while epoch < args.num_epochs:
start_time = time.time()
epoch += 1
if epoch > args.train_q_epochs:
#stop training q after this many epochs
args.q_lr = 0.
for param_group in q_optimizer.param_groups:
param_group['lr'] = args.q_lr
print('Starting epoch %d' % epoch)
train_nll_recon = 0.
train_nll_iwae = 0.
train_kl = 0.
train_q_entropy = 0.
num_sents = 0.
num_words = 0.
b = 0
for i in np.random.permutation(len(train_data)):
if args.kl_warmup > 0:
kl_pen = min(1., kl_pen + kl_warmup_batch)
sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = train_data[i]
if length == 1:
# we ignore length 1 sents during training/eval since we work with binary trees only
continue
sents = sents.cuda()
b += 1
q_optimizer.zero_grad()
optimizer.zero_grad()
action_optimizer.zero_grad()
if args.mode == 'unsupervised':
ll_word, ll_action_p, ll_action_q, all_actions, q_entropy = model(sents, samples=samples,
has_eos = True)
log_f = ll_word + kl_pen*ll_action_p
iwae_ll = log_f.mean(1).detach() + kl_pen*q_entropy.detach()
obj = log_f.mean(1)
if epoch < args.train_q_epochs:
obj += kl_pen*q_entropy
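# Score-function (REINFORCE) term for the inference network q with a
# leave-one-out baseline: for each sample k, the baseline averages log f
# over the other samples-1 samples. log f and the baseline are detached, so
# this term only propagates gradients through ll_action_q (i.e. log q).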
baseline = torch.zeros_like(log_f)
baseline_k = torch.zeros_like(log_f)
for k in range(samples):
baseline_k.copy_(log_f)
baseline_k[:, k].fill_(0)
baseline[:, k] = baseline_k.detach().sum(1) / (samples - 1)
obj += ((log_f.detach() - baseline.detach())*ll_action_q).mean(1)
kl = (ll_action_q - ll_action_p).mean(1).detach()
ll_word = ll_word.mean(1)
train_q_entropy += q_entropy.sum().item()
else:
gold_actions = gold_binary_trees
ll_action_q = model.forward_tree(sents, gold_actions, has_eos=True)
ll_word, ll_action_p, all_actions = model.forward_actions(sents, gold_actions)
obj = ll_word + ll_action_p + ll_action_q
kl = -ll_action_q
iwae_ll = ll_word + ll_action_p
train_nll_iwae += -iwae_ll.sum().item()
actions = all_actions[:, 0].long().cpu()
train_nll_recon += -ll_word.sum().item()
train_kl += kl.sum().item()
(-obj.mean()).backward()
if args.max_grad_norm > 0:
torch.nn.utils.clip_grad_norm_(model_params + action_params, args.max_grad_norm)
if args.q_max_grad_norm > 0:
torch.nn.utils.clip_grad_norm_(q_params, args.q_max_grad_norm)
q_optimizer.step()
optimizer.step()
action_optimizer.step()
num_sents += batch_size
num_words += batch_size * length
for bb in range(batch_size):
action = list(actions[bb].numpy())
span_b = get_spans(action)
span_b_set = set(span_b[:-1]) #ignore the sentence-level trivial span
update_stats(span_b_set, [set(gold_spans[bb][:-1])], all_stats)
if b % args.print_every == 0:
all_f1 = get_f1(all_stats)
param_norm = sum([p.norm()**2 for p in model.parameters()]).item()**0.5
log_str = 'Epoch: %d, Batch: %d/%d, LR: %.4f, qLR: %.5f, qEnt: %.4f, TrainVAEPPL: %.2f, ' + \
'TrainReconPPL: %.2f, TrainKL: %.2f, TrainIWAEPPL: %.2f, ' + \
'|Param|: %.2f, BestValPerf: %.2f, BestValF1: %.2f, KLPen: %.4f, ' + \
'GoldTreeF1: %.2f, Throughput: %.2f examples/sec'
print(log_str %
(epoch, b, len(train_data), args.lr, args.q_lr, train_q_entropy / num_sents,
np.exp((train_nll_recon + train_kl)/ num_words),
np.exp(train_nll_recon/num_words), train_kl / num_sents,
np.exp(train_nll_iwae/num_words),
param_norm, best_val_ppl, best_val_f1, kl_pen,
all_f1[0], num_sents / (time.time() - start_time)))
sent_str = [train_data.idx2word[word_idx] for word_idx in list(sents[-1][1:-1].cpu().numpy())]
print("PRED:", get_tree(action[:-2], sent_str))
print("GOLD:", get_tree(gold_binary_trees[-1], sent_str))
print('--------------------------------')
print('Checking validation perf...')
val_ppl, val_f1 = eval(val_data, model,
samples = args.mc_samples, count_eos_ppl = args.count_eos_ppl)
print('--------------------------------')
if val_ppl < best_val_ppl:
best_val_ppl = val_ppl
best_val_f1 = val_f1
checkpoint = {
'args': args.__dict__,
'model': model.cpu(),
'word2idx': train_data.word2idx,
'idx2word': train_data.idx2word
}
print('Saving checkpoint to %s' % args.save_path)
torch.save(checkpoint, args.save_path)
model.cuda()
else:
if epoch > args.min_epochs:
decay = 1
if decay == 1:
args.lr = args.decay*args.lr
args.q_lr = args.decay*args.q_lr
args.action_lr = args.decay*args.action_lr
for param_group in optimizer.param_groups:
param_group['lr'] = args.lr
for param_group in q_optimizer.param_groups:
param_group['lr'] = args.q_lr
for param_group in action_optimizer.param_groups:
param_group['lr'] = args.action_lr
if args.lr < 0.03:
break
print("Finished training!")
def eval(data, model, samples = 0, count_eos_ppl = 0):
model.eval()
num_sents = 0
num_words = 0
total_nll_recon = 0.
total_kl = 0.
total_nll_iwae = 0.
corpus_f1 = [0., 0., 0.]
sent_f1 = []
with torch.no_grad():
for i in list(reversed(range(len(data)))):
sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i]
if length == 1: # length 1 sents are ignored since URNNG needs at least length 2 sents
continue
if count_eos_ppl == 1:
tree_length = length
length += 1
else:
sents = sents[:, :-1]
tree_length = length
sents = sents.cuda()
ll_word_all, ll_action_p_all, ll_action_q_all, actions_all, q_entropy = model(sents,
samples = samples, has_eos = count_eos_ppl == 1)
ll_word, ll_action_p, ll_action_q = ll_word_all.mean(1), ll_action_p_all.mean(1), ll_action_q_all.mean(1)
kl = ll_action_q - ll_action_p
_, binary_matrix, argmax_spans = model.q_crf._viterbi(model.scores)
actions = []
for b in range(batch_size):
tree = get_tree_from_binary_matrix(binary_matrix[b], tree_length)
actions.append(get_actions(tree))
actions = torch.Tensor(actions).long()
total_nll_recon += -ll_word.sum().item()
total_kl += kl.sum().item()
num_sents += batch_size
num_words += batch_size * length
if samples > 0:
#PPL estimate based on IWAE
sample_ll = torch.zeros(batch_size, samples)
for j in range(samples):
ll_word_j, ll_action_p_j, ll_action_q_j = ll_word_all[:, j], ll_action_p_all[:, j], ll_action_q_all[:, j]
sample_ll[:, j].copy_(ll_word_j + ll_action_p_j - ll_action_q_j)
ll_iwae = model.logsumexp(sample_ll, 1) - np.log(samples)
total_nll_iwae -= ll_iwae.sum().item()
for b in range(batch_size):
action = list(actions[b].numpy())
span_b = get_spans(action)
span_b = argmax_spans[b] # use the CRF argmax spans (overrides the action-derived spans)
span_b_set = set(span_b[:-1])
gold_b_set = set(gold_spans[b][:-1])
tp, fp, fn = get_stats(span_b_set, gold_b_set)
corpus_f1[0] += tp
corpus_f1[1] += fp
corpus_f1[2] += fn
# sent-level F1 is based on L83-89 from https://github.com/yikangshen/PRPN/test_phrase_grammar.py
model_out = span_b_set
std_out = gold_b_set
overlap = model_out.intersection(std_out)
prec = float(len(overlap)) / (len(model_out) + 1e-8)
reca = float(len(overlap)) / (len(std_out) + 1e-8)
if len(std_out) == 0:
reca = 1.
if len(model_out) == 0:
prec = 1.
f1 = 2 * prec * reca / (prec + reca + 1e-8)
sent_f1.append(f1)
tp, fp, fn = corpus_f1
prec = tp / (tp + fp)
recall = tp / (tp + fn)
corpus_f1 = 2*prec*recall/(prec+recall)*100 if prec+recall > 0 else 0.
sent_f1 = np.mean(np.array(sent_f1))*100
elbo_ppl = np.exp((total_nll_recon + total_kl) / num_words)
recon_ppl = np.exp(total_nll_recon / num_words)
iwae_ppl = np.exp(total_nll_iwae /num_words)
kl = total_kl / num_sents
print('ElboPPL: %.2f, ReconPPL: %.2f, KL: %.4f, IwaePPL: %.2f, CorpusF1: %.2f, SentAvgF1: %.2f' %
(elbo_ppl, recon_ppl, kl, iwae_ppl, corpus_f1, sent_f1))
#note that the corpus F1 printed here is different from what you would get from
#evalb, since we do not ignore any tags (e.g. punctuation), while evalb does
model.train()
return iwae_ppl, corpus_f1
if __name__ == '__main__':
args = parser.parse_args()
main(args)
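# Example invocation (hypothetical paths; the flags are as defined above):
#   python train.py --train_file data/ptb-1unk-train.pkl \
#       --val_file data/ptb-1unk-val.pkl --save_path urnng.pt \
#       --mode unsupervised --gpu 0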
================================================
FILE: train_lm.py
================================================
#!/usr/bin/env python3
import sys
import os
import argparse
import json
import random
import shutil
import copy
import torch
from torch import cuda
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import numpy as np
import time
import logging
from data import Dataset
from models import RNNLM
from utils import *
parser = argparse.ArgumentParser()
# Data path options
parser.add_argument('--train_file', default='data/ptb-train.pkl')
parser.add_argument('--val_file', default='data/ptb-val.pkl')
parser.add_argument('--test_file', default='data/ptb-test.pkl')
parser.add_argument('--train_from', default='')
# Model options
parser.add_argument('--w_dim', default=650, type=int, help='hidden dimension for LM')
parser.add_argument('--h_dim', default=650, type=int, help='hidden dimension for LM')
parser.add_argument('--num_layers', default=2, type=int, help='number of layers in LM and the stack LSTM (for RNNG)')
parser.add_argument('--dropout', default=0.5, type=float, help='dropout rate')
# Optimization options
parser.add_argument('--count_eos_ppl', default=0, type=int, help='whether to count eos in val PPL')
parser.add_argument('--test', default=0, type=int, help='if = 1, evaluate on test data and exit')
parser.add_argument('--save_path', default='urnng.pt', help='where to save the model checkpoint')
parser.add_argument('--num_epochs', default=30, type=int, help='number of training epochs')
parser.add_argument('--min_epochs', default=8, type=int, help='do not decay learning rate for at least this many epochs')
parser.add_argument('--lr', default=1, type=float, help='starting learning rate')
parser.add_argument('--decay', default=0.5, type=float, help='learning rate decay factor')
parser.add_argument('--param_init', default=0.1, type=float, help='parameter initialization (uniform over [-param_init, param_init])')
parser.add_argument('--max_grad_norm', default=5, type=float, help='gradient clipping parameter')
parser.add_argument('--gpu', default=2, type=int, help='which gpu to use')
parser.add_argument('--seed', default=3435, type=int, help='random seed')
parser.add_argument('--print_every', type=int, default=500, help='print stats after this many batches')
def main(args):
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_data = Dataset(args.train_file)
val_data = Dataset(args.val_file)
vocab_size = int(train_data.vocab_size)
print('Train: %d sents / %d batches, Val: %d sents / %d batches' %
(train_data.sents.size(0), len(train_data), val_data.sents.size(0),
len(val_data)))
print('Vocab size: %d' % vocab_size)
cuda.set_device(args.gpu)
if args.train_from == '':
model = RNNLM(vocab = vocab_size,
w_dim = args.w_dim,
h_dim = args.h_dim,
dropout = args.dropout,
num_layers = args.num_layers)
if args.param_init > 0:
for param in model.parameters():
param.data.uniform_(-args.param_init, args.param_init)
else:
print('loading model from ' + args.train_from)
checkpoint = torch.load(args.train_from)
model = checkpoint['model']
print("model architecture")
print(model)
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
model.train()
model.cuda()
epoch = 0
decay = 0
if args.test == 1:
test_data = Dataset(args.test_file)
test_ppl = eval(test_data, model, count_eos_ppl = args.count_eos_ppl)
sys.exit(0)
best_val_ppl = eval(val_data, model, count_eos_ppl = args.count_eos_ppl)
while epoch < args.num_epochs:
start_time = time.time()
epoch += 1
print('Starting epoch %d' % epoch)
train_nll = 0.
num_sents = 0.
num_words = 0.
b = 0
for i in np.random.permutation(len(train_data)):
sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = train_data[i]
if length == 1:
continue
sents = sents.cuda()
b += 1
optimizer.zero_grad()
nll = -model(sents).mean()
train_nll += nll.item()*batch_size
nll.backward()
if args.max_grad_norm > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
num_sents += batch_size
num_words += batch_size * length
if b % args.print_every == 0:
param_norm = sum([p.norm()**2 for p in model.parameters()]).item()**0.5
print('Epoch: %d, Batch: %d/%d, LR: %.4f, TrainPPL: %.2f, |Param|: %.4f, BestValPerf: %.2f, Throughput: %.2f examples/sec' %
(epoch, b, len(train_data), args.lr, np.exp(train_nll / num_words),
param_norm, best_val_ppl, num_sents / (time.time() - start_time)))
print('--------------------------------')
print('Checking validation perf...')
val_ppl = eval(val_data, model, count_eos_ppl = args.count_eos_ppl)
print('--------------------------------')
if val_ppl < best_val_ppl:
best_val_ppl = val_ppl
checkpoint = {
'args': args.__dict__,
'model': model.cpu(),
'word2idx': train_data.word2idx,
'idx2word': train_data.idx2word
}
print('Saving checkpoint to %s' % args.save_path)
torch.save(checkpoint, args.save_path)
model.cuda()
else:
if epoch > args.min_epochs:
decay = 1
if decay == 1:
args.lr = args.decay*args.lr
for param_group in optimizer.param_groups:
param_group['lr'] = args.lr
if args.lr < 0.03:
break
print("Finished training")
def eval(data, model, count_eos_ppl = 0):
model.eval()
num_words = 0
total_nll = 0.
with torch.no_grad():
for i in list(reversed(range(len(data)))):
sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i]
if length == 1: #we ignore length 1 sents in URNNG eval so do this for LM too
continue
if count_eos_ppl == 1:
length += 1
else:
sents = sents[:, :-1]
sents = sents.cuda()
num_words += length * batch_size
nll = -model(sents).mean()
total_nll += nll.item()*batch_size
ppl = np.exp(total_nll / num_words)
print('PPL: %.2f' % (ppl))
model.train()
return ppl
if __name__ == '__main__':
args = parser.parse_args()
main(args)
================================================
FILE: utils.py
================================================
#!/usr/bin/env python3
import numpy as np
import itertools
import random
def get_actions(tree, SHIFT = 0, REDUCE = 1, OPEN='(', CLOSE=')'):
#input tree in bracket form: ((A B) (C D))
#output action sequence: 0 0 1 0 0 1 1, where 0 is SHIFT and 1 is REDUCE
actions = []
tree = tree.strip()
i = 0
num_shift = 0
num_reduce = 0
left = 0
right = 0
while i < len(tree):
if tree[i] != ' ' and tree[i] != OPEN and tree[i] != CLOSE: #terminal
if tree[i-1] == OPEN or tree[i-1] == ' ':
actions.append(SHIFT)
num_shift += 1
elif tree[i] == CLOSE:
actions.append(REDUCE)
num_reduce += 1
right += 1
elif tree[i] == OPEN:
left += 1
i += 1
assert(num_shift == num_reduce + 1)
return actions
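# Worked example (added for clarity):
#   get_actions('((the dog) barks)') -> [0, 0, 1, 0, 1]
#   (SHIFT SHIFT REDUCE SHIFT REDUCE)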
def get_tree(actions, sent = None, SHIFT = 0, REDUCE = 1):
#input action and sent (lists), e.g. S S R S S R R, A B C D
#output tree ((A B) (C D))
stack = []
pointer = 0
if sent is None:
sent = list(map(str, range((len(actions)+1) // 2)))
for action in actions:
if action == SHIFT:
word = sent[pointer]
stack.append(word)
pointer += 1
elif action == REDUCE:
right = stack.pop()
left = stack.pop()
stack.append('(' + left + ' ' + right + ')')
assert(len(stack) == 1)
return stack[-1]
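# Worked example (added for clarity):
#   get_tree([0, 0, 1, 0, 1], ['the', 'dog', 'barks']) -> '((the dog) barks)'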
def get_spans(actions, SHIFT = 0, REDUCE = 1):
sent = list(range((len(actions)+1) // 2))
spans = []
pointer = 0
stack = []
for action in actions:
if action == SHIFT:
word = sent[pointer]
stack.append(word)
pointer += 1
elif action == REDUCE:
right = stack.pop()
left = stack.pop()
if isinstance(left, int):
left = (left, None)
if isinstance(right, int):
right = (None, right)
new_span = (left[0], right[1])
spans.append(new_span)
stack.append(new_span)
return spans
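# Worked example (added for clarity): word indices are the leaves, and each
# REDUCE records the span it creates:
#   get_spans([0, 0, 1, 0, 1]) -> [(0, 1), (0, 2)]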
def get_stats(span1, span2):
tp = 0
fp = 0
fn = 0
for span in span1:
if span in span2:
tp += 1
else:
fp += 1
for span in span2:
if span not in span1:
fn += 1
return tp, fp, fn
def update_stats(pred_span, gold_spans, stats):
for gold_span, stat in zip(gold_spans, stats):
tp, fp, fn = get_stats(pred_span, gold_span)
stat[0] += tp
stat[1] += fp
stat[2] += fn
def get_f1(stats):
f1s = []
for stat in stats:
prec = stat[0] / (stat[0] + stat[1])
recall = stat[0] / (stat[0] + stat[2])
f1 = 2*prec*recall / (prec + recall)*100 if prec+recall > 0 else 0.
f1s.append(f1)
return f1s
def span_str(start = None, end = None):
assert(start is not None or end is not None)
if start is None:
return ' ' + str(end) + ')'
elif end is None:
return '(' + str(start) + ' '
else:
return ' (' + str(start) + ' ' + str(end) + ') '
def get_tree_from_binary_matrix(matrix, length):
sent = list(map(str, range(length)))
n = len(sent)
tree = {}
for i in range(n):
tree[i] = sent[i]
for k in np.arange(1, n):
for s in np.arange(n):
t = s + k
if t > n-1:
break
if matrix[s][t].item() == 1:
span = '(' + tree[s] + ' ' + tree[t] + ')'
tree[s] = span
tree[t] = span
return tree[0]
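# Worked example (added for clarity): for length 3 with
# matrix[0][1] = matrix[0][2] = 1 (all other off-diagonal entries 0),
# this returns '((0 1) 2)'.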
def get_nonbinary_spans(actions, SHIFT = 0, REDUCE = 1):
spans = []
stack = []
pointer = 0
binary_actions = []
nonbinary_actions = []
num_shift = 0
num_reduce = 0
for action in actions:
# print(action, stack)
if action == "SHIFT":
nonbinary_actions.append(SHIFT)
stack.append((pointer, pointer))
pointer += 1
binary_actions.append(SHIFT)
num_shift += 1
elif action[:3] == 'NT(':
stack.append('(')
elif action == "REDUCE":
nonbinary_actions.append(REDUCE)
right = stack.pop()
left = right
n = 1
while stack[-1] != '(':
left = stack.pop()
n += 1
span = (left[0], right[1])
if left[0] != right[1]:
spans.append(span)
stack.pop()
stack.append(span)
while n > 1:
n -= 1
binary_actions.append(REDUCE)
num_reduce += 1
else:
assert False
assert(len(stack) == 1)
assert(num_shift == num_reduce + 1)
return spans, binary_actions, nonbinary_actions
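# Worked example (added for illustration): for the non-binary action sequence
#   ['NT(S)', 'NT(NP)', 'SHIFT', 'SHIFT', 'REDUCE', 'NT(VP)', 'SHIFT',
#    'REDUCE', 'REDUCE']
# (i.e. (S (NP w0 w1) (VP w2))), this returns spans [(0, 1), (0, 2)],
# binary actions [0, 0, 1, 0, 1], and non-binary actions [0, 0, 1, 0, 1, 1].
# Single-word constituents such as (VP w2) are not recorded as spans.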
SYMBOL INDEX (74 symbols across 9 files)
FILE: TreeCRF.py
class ConstituencyTreeCRF (line 11) | class ConstituencyTreeCRF(nn.Module):
method __init__ (line 12) | def __init__(self):
method logadd (line 16) | def logadd(self, x, y):
method logsumexp (line 20) | def logsumexp(self, x, dim=1):
method _init_table (line 24) | def _init_table(self, scores):
method _forward (line 30) | def _forward(self, scores):
method _backward (line 46) | def _backward(self, scores):
method _marginal (line 71) | def _marginal(self, scores):
method _entropy (line 80) | def _entropy(self, scores):
method _sample (line 104) | def _sample(self, scores, alpha = None, argmax = False):
method _viterbi (line 161) | def _viterbi(self, scores):
method _backtrack (line 185) | def _backtrack(self, b, s, t):
function get_span_str (line 196) | def get_span_str(start = None, end = None):
FILE: data.py
class Dataset (line 6) | class Dataset(object):
method __init__ (line 7) | def __init__(self, data_file):
method _convert (line 19) | def _convert(self, x):
method __len__ (line 22) | def __len__(self):
method __getitem__ (line 25) | def __getitem__(self, idx):
FILE: eval_ppl.py
function main (line 37) | def main(args):
FILE: models.py
class RNNLM (line 9) | class RNNLM(nn.Module):
method __init__ (line 10) | def __init__(self, vocab=10000,
method forward (line 25) | def forward(self, sent):
method generate (line 32) | def generate(self, bos = 2, eos = 3, max_len = 150):
class SeqLSTM (line 48) | class SeqLSTM(nn.Module):
method __init__ (line 49) | def __init__(self, i_dim = 200,
method forward (line 62) | def forward(self, x, prev_h = None):
class TreeLSTM (line 79) | class TreeLSTM(nn.Module):
method __init__ (line 80) | def __init__(self, dim = 200):
method forward (line 85) | def forward(self, x1, x2, e=None):
class RNNG (line 106) | class RNNG(nn.Module):
method __init__ (line 107) | def __init__(self, vocab = 100,
method get_span_scores (line 136) | def get_span_scores(self, x):
method get_action_masks (line 153) | def get_action_masks(self, actions, length):
method forward (line 170) | def forward(self, x, samples = 1, is_temp = 1., has_eos=True):
method forward_actions (line 296) | def forward_actions(self, x, actions, has_eos=True):
method forward_tree (line 385) | def forward_tree(self, x, actions, has_eos=True):
method logsumexp (line 404) | def logsumexp(self, x, dim=1):
FILE: parse.py
function is_next_open_bracket (line 33) | def is_next_open_bracket(line, start_idx):
function get_between_brackets (line 41) | def get_between_brackets(line, start_idx):
function get_tags_tokens_lowercase (line 50) | def get_tags_tokens_lowercase(line):
function get_nonterminal (line 70) | def get_nonterminal(line, start_idx):
function get_actions (line 81) | def get_actions(line):
function clean_number (line 112) | def clean_number(w):
function main (line 116) | def main(args):
FILE: preprocess.py
class Indexer (line 17) | class Indexer:
method __init__ (line 18) | def __init__(self, symbols = ["<pad>","<unk>","<s>","</s>"]):
method add_w (line 27) | def add_w(self, ws):
method convert (line 32) | def convert(self, w):
method convert_sequence (line 35) | def convert_sequence(self, ls):
method write (line 38) | def write(self, outfile):
method prune_vocab (line 46) | def prune_vocab(self, k, cnt = False):
method load_vocab (line 60) | def load_vocab(self, vocab_file):
function is_next_open_bracket (line 70) | def is_next_open_bracket(line, start_idx):
function get_between_brackets (line 78) | def get_between_brackets(line, start_idx):
function get_tags_tokens_lowercase (line 87) | def get_tags_tokens_lowercase(line):
function get_nonterminal (line 108) | def get_nonterminal(line, start_idx):
function get_actions (line 119) | def get_actions(line):
function pad (line 150) | def pad(ls, length, symbol):
function clean_number (line 155) | def clean_number(w):
function get_data (line 159) | def get_data(args):
function main (line 311) | def main(arguments):
FILE: train.py
function main (line 61) | def main(args):
function eval (line 243) | def eval(data, model, samples = 0, count_eos_ppl = 0):
FILE: train_lm.py
function main (line 52) | def main(args):
function eval (line 143) | def eval(data, model, count_eos_ppl = 0):
FILE: utils.py
function get_actions (line 7) | def get_actions(tree, SHIFT = 0, REDUCE = 1, OPEN='(', CLOSE=')'):
function get_tree (line 33) | def get_tree(actions, sent = None, SHIFT = 0, REDUCE = 1):
function get_spans (line 52) | def get_spans(actions, SHIFT = 0, REDUCE = 1):
function get_stats (line 74) | def get_stats(span1, span2):
function update_stats (line 88) | def update_stats(pred_span, gold_spans, stats):
function get_f1 (line 95) | def get_f1(stats):
function span_str (line 105) | def span_str(start = None, end = None):
function get_tree_from_binary_matrix (line 115) | def get_tree_from_binary_matrix(matrix, length):
function get_nonbinary_spans (line 132) | def get_nonbinary_spans(actions, SHIFT = 0, REDUCE = 1):