Showing preview only (325K chars total). Download the full file or copy to clipboard to get everything.
Repository: pcyin/NL2code
Branch: master
Commit: f9732f1f5caa
Files: 51
Total size: 309.2 KB
Directory structure:
gitextract__fzn9_a9/
├── .gitignore
├── README.md
├── astnode.py
├── code_gen.py
├── components.py
├── config.py
├── dataset.py
├── decoder.py
├── evaluation.py
├── interactive_mode.py
├── lang/
│ ├── __init__.py
│ ├── grammar.py
│ ├── ifttt/
│ │ ├── __init__.py
│ │ ├── grammar.py
│ │ ├── ifttt_dataset.py
│ │ └── parse.py
│ ├── py/
│ │ ├── __init__.py
│ │ ├── grammar.py
│ │ ├── parse.py
│ │ ├── py_dataset.py
│ │ ├── seq2tree_exp.py
│ │ └── unaryclosure.py
│ ├── type_system.py
│ └── util.py
├── learner.py
├── main.py
├── model.py
├── nn/
│ ├── __init__.py
│ ├── activations.py
│ ├── initializations.py
│ ├── layers/
│ │ ├── __init__.py
│ │ ├── convolution.py
│ │ ├── core.py
│ │ ├── embeddings.py
│ │ └── recurrent.py
│ ├── objectives.py
│ ├── optimizers.py
│ └── utils/
│ ├── __init__.py
│ ├── config_factory.py
│ ├── generic_utils.py
│ ├── io_utils.py
│ ├── np_utils.py
│ ├── test_utils.py
│ └── theano_utils.py
├── parse.py
├── parse_hiro.py
├── run_interactive.sh
├── run_interactive_singlefile.sh
├── run_trained_model.sh
├── train.sh
└── util.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# dotenv
.env
# virtualenv
venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject
================================================
FILE: README.md
================================================
# NL2code
A syntactic neural model for parsing natural language to executable code [paper](https://arxiv.org/abs/1704.01696).
## Dataset and Trained Models
Get serialized datasets and trained models from [here](https://drive.google.com/drive/folders/0B14lJ2VVvtmJWEQ5RlFjQUY2Vzg). Put `models/` and `data/` folders under the root directory of the project.
## Usage
To train a new model
```bash
. train.sh [hs|django]
```
To use a trained model to decode the test sets
```bash
. run_trained_model.sh [hs|django]
```
## Dependencies
* Theano
* vprof
* NLTK 3.2.1
* astor 0.6
## Reference
```
@inproceedings{yin17acl,
title = {A Syntactic Neural Model for General-Purpose Code Generation},
author = {Pengcheng Yin and Graham Neubig},
booktitle = {The 55th Annual Meeting of the Association for Computational Linguistics (ACL)},
address = {Vancouver, Canada},
month = {July},
url = {https://arxiv.org/abs/1704.01696},
year = {2017}
}
```
================================================
FILE: astnode.py
================================================
from collections import namedtuple
import cPickle
from collections import Iterable, OrderedDict, defaultdict
from cStringIO import StringIO
from lang.util import typename
class ASTNode(object):
    """Node of a typed syntax tree.

    A node has a ``type``, an optional ``label``, an optional ``value`` and an
    ordered list of ``children``. Equality and hashing are structural: both
    are derived from type, label, value and, recursively, the children.
    Because those fields are mutable, a node's hash changes whenever the node
    or one of its descendants is modified.
    """

    def __init__(self, node_type, label=None, value=None, children=None):
        """Create a node.

        ``children`` may be a single ASTNode or an iterable of ASTNodes; each
        one is attached through ``add_child``. Any other truthy object raises
        AttributeError. A node may not carry both a value and children.
        """
        self.type = node_type
        self.label = label
        self.value = value

        # Rule exposes `parent` as a read-only property, so only non-Rule
        # nodes get a plain, assignable `parent` attribute.
        if type(self) is not Rule:
            self.parent = None

        self.children = list()

        if children:
            if isinstance(children, Iterable):
                for child in children:
                    self.add_child(child)
            elif isinstance(children, ASTNode):
                self.add_child(children)
            else:
                raise AttributeError('Wrong type for child nodes')

        assert not (bool(children) and bool(value)), 'terminal node with a value cannot have children'

    @property
    def is_leaf(self):
        """True when the node has no children."""
        return len(self.children) == 0

    @property
    def is_preterminal(self):
        """True when the node has exactly one child and that child is a leaf."""
        return len(self.children) == 1 and self.children[0].is_leaf

    @property
    def size(self):
        """Number of nodes in the subtree rooted here (this node included)."""
        return 1 + sum(child.size for child in self.children)

    @property
    def nodes(self):
        """a generator that returns all the nodes (pre-order, this node first)"""
        yield self
        for child in self.children:
            for child_n in child.nodes:
                yield child_n

    @property
    def as_type_node(self):
        """return an ASTNode with type information only"""
        return ASTNode(self.type)

    def __repr__(self):
        """Bracketed one-line form: ``(type{label}{val=value} child ...)``."""
        parts = ['(', typename(self.type)]
        if self.label is not None:
            parts.append('{%s}' % self.label)
        if self.value is not None:
            parts.append('{val=%s}' % self.value)
        for child in self.children:
            parts.append(' ' + child.__repr__())
        parts.append(')')
        return ''.join(parts)

    def __hash__(self):
        """Structural hash over type, label, value and all descendants."""
        code = hash(self.type)
        if self.label is not None:
            code = code * 37 + hash(self.label)
        if self.value is not None:
            code = code * 37 + hash(self.value)
        for child in self.children:
            code = code * 37 + hash(child)
        return code

    def __eq__(self, other):
        """Structural equality; ``other`` must be an instance of this node's class."""
        if not isinstance(other, self.__class__):
            return False
        # Cheap rejection before the field-by-field comparison.
        if hash(self) != hash(other):
            return False
        if self.type != other.type:
            return False
        if self.label != other.label:
            return False
        if self.value != other.value:
            return False
        if len(self.children) != len(other.children):
            return False
        for mine, theirs in zip(self.children, other.children):
            if mine != theirs:
                return False
        return True

    def __ne__(self, other):
        return not self.__eq__(other)

    def __getitem__(self, child_type):
        """Return the first child whose type equals ``child_type``.

        Raises StopIteration (not KeyError) when no child matches; this is
        the original behaviour and is kept for existing callers.
        """
        return next(iter([c for c in self.children if c.type == child_type]))

    def __delitem__(self, child_type):
        """Remove the single child of type ``child_type``.

        Raises KeyError when there is no such child; asserts that at most one
        child has that type.
        """
        tgt_child = [c for c in self.children if c.type == child_type]
        if tgt_child:
            assert len(tgt_child) == 1, 'unsafe deletion for more than one children'
            tgt_child = tgt_child[0]
            self.children.remove(tgt_child)
        else:
            raise KeyError

    def add_child(self, child):
        """Append ``child`` and set its ``parent`` to this node."""
        child.parent = self
        self.children.append(child)

    def get_child_id(self, child):
        """Return the position of ``child`` among this node's children.

        The exact object is looked up first. Structural equality alone cannot
        tell apart siblings with identical content (it would always report the
        first of them), so equality is only used as a fallback for callers
        that pass an equal-but-distinct node. Raises KeyError when nothing
        matches.
        """
        for i, _child in enumerate(self.children):
            if _child is child:
                return i
        for i, _child in enumerate(self.children):
            if child == _child:
                return i
        raise KeyError

    def pretty_print(self):
        """Return the multi-line rendering produced by pretty_print_helper."""
        sb = StringIO()
        new_line = False
        self.pretty_print_helper(sb, 0, new_line)
        return sb.getvalue()

    def pretty_print_helper(self, sb, depth, new_line=False):
        """Write this subtree to the buffer ``sb``.

        ``depth`` is the number of spaces written before the node when it
        starts on a fresh line; children are written at ``depth + 2``.
        """
        if new_line:
            sb.write('\n')
            sb.write(' ' * depth)
        sb.write('(')
        sb.write(typename(self.type))
        if self.label is not None:
            sb.write('{%s}' % self.label)
        if self.value is not None:
            sb.write('{val=%s}' % self.value)
        if len(self.children) == 0:
            sb.write(')')
            return
        sb.write(' ')
        new_line = True
        for child in self.children:
            child.pretty_print_helper(sb, depth + 2, new_line)
        sb.write('\n')
        sb.write(' ' * depth)
        sb.write(')')

    def get_leaves(self):
        """Return the leaf nodes of this subtree, left to right."""
        if self.is_leaf:
            return [self]
        leaves = []
        for child in self.children:
            leaves.extend(child.get_leaves())
        return leaves

    def to_rule(self, include_value=False):
        """
        transform the current AST node to a production rule

        The rule's children copy the type and label of this node's children;
        their values are copied only when ``include_value`` is true.
        """
        rule = Rule(self.type)
        for c in self.children:
            val = c.value if include_value else None
            child = ASTNode(c.type, c.label, val)
            rule.add_child(child)
        return rule

    def get_productions(self, include_value_node=False):
        """
        get the depth-first, left-to-right sequence of rule applications

        returns a list of production rules and a map to their parent rules
        attention: node value is not included in child nodes

        The map is an OrderedDict keyed by ``(rule index, rule)`` whose values
        are ``(parent rule, index of the node under its parent)``. The node the
        traversal starts from is always reported as ``(None, -1)``, even when
        it is itself attached to a parent.
        """
        rule_list = list()
        rule_parents = OrderedDict()
        node_rule_map = dict()
        s = list()
        s.append(self)
        rule_num = 0
        while len(s) > 0:
            node = s.pop()
            # Push children right-to-left so they are popped left-to-right.
            for child in reversed(node.children):
                if not child.is_leaf:
                    s.append(child)
                elif include_value_node:
                    if child.value is not None:
                        s.append(child)
            # only non-terminals and terminal nodes holding values
            # can form a production rule
            if node.children or node.value is not None:
                rule = Rule(node.type)
                if include_value_node:
                    rule.value = node.value
                for c in node.children:
                    rule.add_child(ASTNode(c.type, c.label, None))
                rule_list.append(rule)
                # The start node has no rule recorded for its parent, so it
                # is treated as the root regardless of `self.parent`.
                if node is not self and node.parent:
                    child_id = node.parent.get_child_id(node)
                    parent_rule = node_rule_map[node.parent]
                    rule_parents[(rule_num, rule)] = (parent_rule, child_id)
                else:
                    rule_parents[(rule_num, rule)] = (None, -1)
                rule_num += 1
                node_rule_map[node] = rule
        return rule_list, rule_parents

    def copy(self):
        """Return a deep copy of this subtree built from plain ASTNode objects."""
        new_tree = ASTNode(self.type, self.label, self.value)
        if self.is_leaf:
            return new_tree
        for child in self.children:
            new_tree.add_child(child.copy())
        return new_tree
class DecodeTree(ASTNode):
    """ASTNode variant used while decoding.

    Adds ``t``, the time step at which the subtree was created by a rule
    application, and ``applied_rule``, the rule used to expand the node.
    """

    def __init__(self, node_type, label=None, value=None, children=None, t=-1):
        super(DecodeTree, self).__init__(node_type, label, value, children)
        # time step of the rule application that created this subtree
        self.t = t
        # the ApplyRule action that expanded this node (None until expanded)
        self.applied_rule = None

    def copy(self):
        """Return a deep copy that keeps ``t`` and ``applied_rule`` on every node."""
        clone = DecodeTree(self.type, self.label, value=self.value, t=self.t)
        clone.applied_rule = self.applied_rule
        # A leaf has no children, so the loop simply does nothing for it.
        for kid in self.children:
            clone.add_child(kid.copy())
        return clone
class Rule(ASTNode):
    """A production rule: this node is the left-hand side, its children the right-hand side."""

    def __init__(self, *args, **kwargs):
        super(Rule, self).__init__(*args, **kwargs)
        assert self.value is None and self.label is None, 'Rule LHS cannot have values or labels'

    @property
    def parent(self):
        """The left-hand side as a fresh, type-only ASTNode."""
        return self.as_type_node

    def __repr__(self):
        """Render as ``lhs -> child, child, ...``."""
        lhs = typename(self.type)
        if self.label is not None:
            lhs = lhs + '{%s}' % self.label
        if self.value is not None:
            lhs = lhs + '{val=%s}' % self.value
        rhs = ', '.join(repr(kid) for kid in self.children)
        return '%s -> %s' % (lhs, rhs)
if __name__ == '__main__':
    import ast

    def _sample_tree():
        """Build the small demo tree; every call returns a fresh, equal tree."""
        left_branch = ASTNode(str, 'a1_label', children=[
            ASTNode(int, children=[ASTNode('a21', value=123)]),
            ASTNode(ast.NodeTransformer, children=[ASTNode('a21', value='hahaha')]),
        ])
        right_branch = ASTNode('a2', children=[ASTNode('a21', value='asdf')])
        return ASTNode('root', children=[left_branch, right_branch])

    t1 = _sample_tree()
    t2 = _sample_tree()

    # Two separately built but structurally identical trees compare equal.
    print(t1 == t2)

    a, b = t1.get_productions(include_value_node=True)

    print(repr(t1))
    print(t1.pretty_print())
================================================
FILE: code_gen.py
================================================
import numpy as np
import cProfile
import ast
import traceback
import argparse
import os
import logging
from vprof import profiler
from model import Model
from dataset import DataEntry, DataSet, Vocab, Action
import config
from learner import Learner
from evaluation import *
from decoder import decode_python_dataset
from components import Hyp
from astnode import ASTNode
from nn.utils.generic_utils import init_logging
from nn.utils.io_utils import deserialize_from_file, serialize_to_file
def _add_switch(target, dest, off_flag, default):
    """Register the pair `-<dest>` (store_true) / `off_flag` (store_false) on `dest`."""
    target.add_argument('-' + dest, dest=dest, action='store_true')
    target.add_argument(off_flag, dest=dest, action='store_false')
    target.set_defaults(**{dest: default})


def _add_options(target, value_type, specs):
    """Register one option per (flag, default) pair, each converted with `value_type`."""
    for flag, default in specs:
        target.add_argument(flag, default=default, type=value_type)


parser = argparse.ArgumentParser()
parser.add_argument('-data')
parser.add_argument('-random_seed', default=181783, type=int)
parser.add_argument('-output_dir', default='.outputs')
parser.add_argument('-model', default=None)

# which dataset / target language the model is configured for
parser.add_argument('-data_type', default='django', choices=['django', 'ifttt', 'hs'])

# sizes of the neural model (a size of 0 is filled in from the dataset by the main script)
_add_options(parser, int, [
    ('-source_vocab_size', 0),
    ('-target_vocab_size', 0),
    ('-rule_num', 0),
    ('-node_num', 0),
    ('-word_embed_dim', 128),
    ('-rule_embed_dim', 256),
    ('-node_embed_dim', 256),
    ('-encoder_hidden_dim', 256),
    ('-decoder_hidden_dim', 256),
    ('-attention_hidden_dim', 50),
    ('-ptrnet_hidden_dim', 50),
])
parser.add_argument('-dropout', default=0.2, type=float)

# encoder variant
parser.add_argument('-encoder', default='bilstm', choices=['bilstm', 'lstm'])

# decoder feature switches
_add_switch(parser, 'parent_hidden_state_feed', '-no_parent_hidden_state_feed', True)
_add_switch(parser, 'parent_action_feed', '-no_parent_action_feed', True)
_add_switch(parser, 'frontier_node_type_feed', '-no_frontier_node_type_feed', True)
_add_switch(parser, 'tree_attention', '-no_tree_attention', False)
_add_switch(parser, 'enable_copy', '-no_copy', True)

# training schedule
parser.add_argument('-optimizer', default='adam')
parser.add_argument('-clip_grad', default=0., type=float)
_add_options(parser, int, [
    ('-train_patience', 10),
    ('-max_epoch', 50),
    ('-batch_size', 10),
    ('-valid_per_batch', 4000),
    ('-save_per_batch', 4000),
])
parser.add_argument('-valid_metric', default='bleu')

# beam-search decoding
_add_options(parser, int, [
    ('-beam_size', 15),
    ('-max_query_length', 70),
    ('-decode_max_time_step', 100),
])
_add_switch(parser, 'head_nt_constraint', '-no_head_nt_constraint', True)

# one sub-command per operation; the chosen one is stored in `args.operation`
sub_parsers = parser.add_subparsers(dest='operation', help='operation to take')
train_parser = sub_parsers.add_parser('train')
decode_parser = sub_parsers.add_parser('decode')
interactive_parser = sub_parsers.add_parser('interactive')
evaluate_parser = sub_parsers.add_parser('evaluate')

# options of the `decode` operation
decode_parser.add_argument('-saveto', default='decode_results.bin')
decode_parser.add_argument('-type', default='test_data')

# options of the `evaluate` operation
evaluate_parser.add_argument('-mode', default='self')
evaluate_parser.add_argument('-input', default='decode_results.bin')
evaluate_parser.add_argument('-type', default='test_data')
evaluate_parser.add_argument('-seq2tree_sample_file', default='model.sample')
evaluate_parser.add_argument('-seq2tree_id_file', default='test.id.txt')
evaluate_parser.add_argument('-seq2tree_rareword_map', default=None)
evaluate_parser.add_argument('-seq2seq_decode_file')
evaluate_parser.add_argument('-seq2seq_ref_file')
evaluate_parser.add_argument('-is_nbest', default=False, action='store_true')

# miscellaneous
parser.add_argument('-ifttt_test_split', default='data/ifff.test_data.gold.id')

# options of the `interactive` operation
interactive_parser.add_argument('-mode', default='dataset')
if __name__ == '__main__':
    # `sys` (argv, modules, stdout) is used below but is not among this
    # file's explicit imports, so bring it into scope here.
    import sys

    args = parser.parse_args()
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    np.random.seed(args.random_seed)
    init_logging(os.path.join(args.output_dir, 'parser.log'), logging.INFO)
    logging.info('command line: %s', ' '.join(sys.argv))

    logging.info('loading dataset [%s]', args.data)
    train_data, dev_data, test_data = deserialize_from_file(args.data)

    # Sizes left at 0 on the command line are taken from the training set.
    if not args.source_vocab_size:
        args.source_vocab_size = train_data.annot_vocab.size
    if not args.target_vocab_size:
        args.target_vocab_size = train_data.terminal_vocab.size
    if not args.rule_num:
        args.rule_num = len(train_data.grammar.rules)
    if not args.node_num:
        args.node_num = len(train_data.grammar.node_type_to_id)

    logging.info('current config: %s', args)

    # Publish every parsed option as an attribute of the `config` module.
    config_module = sys.modules['config']
    for name, value in vars(args).items():
        setattr(config_module, name, value)

    # get dataset statistics
    avg_action_num = np.average([len(e.actions) for e in train_data.examples])
    logging.info('avg_action_num: %d', avg_action_num)
    logging.info('grammar rule num.: %d', len(train_data.grammar.rules))
    logging.info('grammar node type num.: %d', len(train_data.grammar.node_type_to_id))
    logging.info('source vocab size: %d', train_data.annot_vocab.size)
    logging.info('target vocab size: %d', train_data.terminal_vocab.size)

    if args.operation in ['train', 'decode', 'interactive']:
        model = Model()
        model.build()
        if args.model:
            model.load(args.model)

    if args.operation == 'train':
        learner = Learner(model, train_data, dev_data)
        learner.train()

    if args.operation == 'decode':
        if args.data_type == 'ifttt':
            decode_results = decode_and_evaluate_ifttt_by_split(model, test_data)
        else:
            # NOTE(security): `eval` on a command-line string. It is meant to
            # pick one of train_data / dev_data / test_data; never feed it
            # untrusted input.
            dataset = eval(args.type)
            decode_results = decode_python_dataset(model, dataset)
        serialize_to_file(decode_results, args.saveto)

    if args.operation == 'evaluate':
        # NOTE(security): same `eval` caveat as in the decode branch.
        dataset = eval(args.type)
        if config.mode == 'self':
            decode_results_file = args.input
            decode_results = deserialize_from_file(decode_results_file)
            evaluate_decode_results(dataset, decode_results)
        elif config.mode == 'seq2tree':
            from evaluation import evaluate_seq2tree_sample_file
            evaluate_seq2tree_sample_file(config.seq2tree_sample_file, config.seq2tree_id_file, dataset)
        elif config.mode == 'seq2seq':
            from evaluation import evaluate_seq2seq_decode_results
            evaluate_seq2seq_decode_results(dataset, config.seq2seq_decode_file, config.seq2seq_ref_file, is_nbest=config.is_nbest)
        elif config.mode == 'analyze':
            from evaluation import analyze_decode_results
            decode_results_file = args.input
            decode_results = deserialize_from_file(decode_results_file)
            analyze_decode_results(dataset, decode_results)

    if args.operation == 'interactive':
        from dataset import canonicalize_query, query_to_data
        from collections import namedtuple
        from lang.py.parse import decode_tree_to_python_ast

        assert model is not None

        # Fail early instead of hitting an unbound `example` inside the loop.
        if args.mode not in ('dataset', 'new'):
            raise ValueError('unknown interactive mode: %s' % args.mode)

        while True:
            cmd = raw_input('example id or query: ')
            if args.mode == 'dataset':
                try:
                    example_id = int(cmd)
                    example = [e for e in test_data.examples if e.raw_id == example_id][0]
                except (ValueError, IndexError):
                    # not a number, or no test example with that id
                    print('something went wrong ...')
                    continue
            else:
                # we play with new examples!
                query, str_map = canonicalize_query(cmd)
                vocab = train_data.annot_vocab
                query_tokens = query.split(' ')
                query_tokens_data = [query_to_data(query, vocab)]
                example = namedtuple('example', ['query', 'data'])(query=query_tokens, data=query_tokens_data)

            if hasattr(example, 'parse_tree'):
                print('gold parse tree:')
                print(example.parse_tree)

            cand_list = model.decode(example, train_data.grammar, train_data.terminal_vocab,
                                     beam_size=args.beam_size, max_time_step=args.decode_max_time_step, log=True)

            has_grammar_error = any(c.has_grammar_error for c in cand_list)
            print('has_grammar_error:  %s' % (has_grammar_error,))

            for cid, cand in enumerate(cand_list[:5]):
                print('*' * 60)
                print('cand #%d, score: %f' % (cid, cand.score))

                try:
                    ast_tree = decode_tree_to_python_ast(cand.tree)
                    # NOTE(review): `astor` is not imported explicitly in this
                    # file; presumably it arrives via `from evaluation import *`
                    # -- confirm.
                    code = astor.to_source(ast_tree)
                    print('code:  %s' % (code,))
                    print('decode log:  %s' % (cand.log,))
                except Exception:
                    print("Exception in converting tree to code:")
                    print('-' * 60)
                    # Ad-hoc queries (`-mode new`) carry no raw_id.
                    print('raw_id: %s, beam pos: %d' % (getattr(example, 'raw_id', 'n/a'), cid))
                    traceback.print_exc(file=sys.stdout)
                    print('-' * 60)
                finally:
                    print('* parse tree *')
                    print(repr(cand.tree))
                    print('n_timestep: %d' % cand.n_timestep)
                    print('ast size: %d' % cand.tree.size)
                    print('*' * 60)
================================================
FILE: components.py
================================================
import theano
import theano.tensor as T
import numpy as np
import logging
import copy
from nn.layers.embeddings import Embedding
from nn.layers.core import Dense, Layer
from nn.layers.recurrent import BiLSTM, LSTM, CondAttLSTM
from nn.utils.theano_utils import ndim_itensor, tensor_right_shift, ndim_tensor, alloc_zeros_matrix, shared_zeros
import nn.initializations as initializations
import nn.activations as activations
import nn.optimizers as optimizers
import config
from lang.grammar import Grammar
from parse import *
from astnode import *
class PointerNet(Layer):
    """Scores every query token against every decoder state and normalises the scores per state."""

    def __init__(self, name='PointerNet'):
        super(PointerNet, self).__init__()

        hidden = config.ptrnet_hidden_dim
        self.dense1_input = Dense(config.encoder_hidden_dim, hidden,
                                  activation='linear', name='Dense1_input')
        self.dense1_h = Dense(config.decoder_hidden_dim + config.encoder_hidden_dim, hidden,
                              activation='linear', name='Dense1_h')
        self.dense2 = Dense(hidden, 1, activation='linear', name='Dense2')

        # register the sub-layers' parameters, in the order the layers were created
        for sub_layer in (self.dense1_input, self.dense1_h, self.dense2):
            self.params += sub_layer.params

        self.set_name(name)

    def __call__(self, query_embed, query_token_embed_mask, decoder_states):
        """Return masked, normalised pointer weights over the query tokens.

        NOTE(review): the dimshuffles below imply 3-d inputs with batch first
        and a 2-d (batch, token) mask -- confirm against the caller.
        """
        # insert broadcastable axes so the two projections can be added
        token_proj = self.dense1_input(query_embed).dimshuffle((0, 'x', 1, 2))
        state_proj = self.dense1_h(decoder_states).dimshuffle((0, 1, 'x', 2))

        # (batch_size, max_decode_step, query_token_num, ptr_net_hidden_dim)
        hidden_act = T.tanh(token_proj + state_proj)

        raw = self.dense2(hidden_act).flatten(3)

        # softmax over the last axis, shifted by the max and with masked tokens zeroed out
        weights = T.exp(raw - T.max(raw, axis=-1, keepdims=True))
        weights = weights * query_token_embed_mask.dimshuffle((0, 'x', 1))
        return weights / T.sum(weights, axis=-1, keepdims=True)
class Hyp:
    """A decoding hypothesis: a partially built DecodeTree plus bookkeeping.

    Construct it from a Grammar (fresh hypothesis whose tree is just the
    grammar's root node type) or from another Hyp (copy).
    """

    def __init__(self, *args):
        source = args[0]
        if isinstance(source, Hyp):
            # copy constructor: the tree and the state history are duplicated
            self.grammar = source.grammar
            self.tree = source.tree.copy()
            self.t = source.t
            self.hist_h = list(source.hist_h)
            self.log = source.log
            self.has_grammar_error = source.has_grammar_error
        else:
            assert isinstance(source, Grammar)
            self.grammar = source
            self.tree = DecodeTree(source.root_node.type)
            self.t = -1
            self.hist_h = []
            self.log = ''
            self.has_grammar_error = False

        # the score is reset on construction, including for copies
        self.score = 0.0

        # frontier cache: the node found and the value of `t` it was found at
        self.__frontier_nt = self.tree
        self.__frontier_nt_t = -1

    def __repr__(self):
        return self.tree.__repr__()

    def can_expand(self, node):
        """Tell whether `node` can still take an action.

        Value nodes are expandable until their value ends with '<eos>';
        other terminals never are; everything else is.
        """
        if self.grammar.is_value_node(node):
            finished = node.value is not None and node.value.endswith('<eos>')
            return not finished
        return not self.grammar.is_terminal(node)

    def apply_rule(self, rule, nt=None):
        """Expand `nt` (default: the current frontier node) with `rule`.

        A type mismatch between the rule's parent and the node does not raise;
        it is recorded in `has_grammar_error`.
        """
        target = self.frontier_nt() if nt is None else nt

        if rule.parent.type != target.type:
            self.has_grammar_error = True

        self.t += 1
        # time step at which this node was expanded
        target.t = self.t
        # the ApplyRule action that expanded it
        target.applied_rule = rule

        for spec in rule.children:
            target.add_child(DecodeTree(spec.type, spec.label, spec.value))

    def append_token(self, token, nt=None):
        """Append `token` to the value of `nt` (default: the current frontier node)."""
        target = self.frontier_nt() if nt is None else nt
        self.t += 1

        if target.value is not None:
            target.value += token
            return

        # first token of an empty terminal: remember when it was started
        target.t = self.t
        target.value = token

    def frontier_nt_helper(self, node):
        """Depth-first, left-to-right search for the first expandable leaf under `node`."""
        if node.is_leaf:
            return node if self.can_expand(node) else None

        for child in node.children:
            found = self.frontier_nt_helper(child)
            if found:
                return found
        return None

    def frontier_nt(self):
        """Return the current frontier node, recomputing it only when `t` has changed."""
        if self.__frontier_nt_t != self.t:
            self.__frontier_nt = self.frontier_nt_helper(self.tree)
            self.__frontier_nt_t = self.t
        return self.__frontier_nt

    def get_action_parent_t(self):
        """
        get the time step when the parent of the current
        action was generated
        WARNING: 0 will be returned if the parent is None
        """
        parent = self.frontier_nt().parent
        return parent.t if parent else 0
class CondAttLSTM(Layer):
"""
Conditional LSTM with Attention
"""
def __init__(self, input_dim, output_dim,
             context_dim, att_hidden_dim,
             init='glorot_uniform', inner_init='orthogonal', forget_bias_init='one',
             activation='tanh', inner_activation='sigmoid', name='CondAttLSTM'):
    """Create the gate, context-attention and history-attention parameters.

    Initialiser and activation arguments are names resolved through
    `initializations.get` / `activations.get`.
    """
    super(CondAttLSTM, self).__init__()

    self.output_dim = output_dim
    self.init = initializations.get(init)
    self.inner_init = initializations.get(inner_init)
    self.forget_bias_init = initializations.get(forget_bias_init)
    self.activation = activations.get(activation)
    self.inner_activation = activations.get(inner_activation)
    self.context_dim = context_dim
    self.input_dim = input_dim

    # regular LSTM layer: per gate an input (W), recurrent (U), context (C),
    # history (H) and parent (P) matrix plus a bias (b).  Parameters are
    # created gate by gate in the order i, f, c, o, and within a gate in
    # the order W, U, C, H, P, b.
    recurrent_shape = (self.output_dim, self.output_dim)
    for gate in ('i', 'f', 'c', 'o'):
        setattr(self, 'W_' + gate, self.init((input_dim, self.output_dim)))
        setattr(self, 'U_' + gate, self.inner_init(recurrent_shape))
        setattr(self, 'C_' + gate, self.inner_init((self.context_dim, self.output_dim)))
        setattr(self, 'H_' + gate, self.inner_init(recurrent_shape))
        setattr(self, 'P_' + gate, self.inner_init(recurrent_shape))
        if gate == 'f':
            # only the forget gate's bias uses the dedicated initialiser
            bias = self.forget_bias_init(self.output_dim)
        else:
            bias = shared_zeros(self.output_dim)
        setattr(self, 'b_' + gate, bias)

    # registration order differs from creation order: gates i, c, f, o
    self.params = [
        getattr(self, prefix + gate)
        for gate in ('i', 'c', 'f', 'o')
        for prefix in ('W_', 'U_', 'b_', 'C_', 'H_', 'P_')
    ]

    # attention layer (over the context)
    self.att_ctx_W1 = self.init((context_dim, att_hidden_dim))
    self.att_h_W1 = self.init((output_dim, att_hidden_dim))
    self.att_b1 = shared_zeros(att_hidden_dim)
    self.att_W2 = self.init((att_hidden_dim, 1))
    self.att_b2 = shared_zeros(1)
    self.params.extend([
        self.att_ctx_W1, self.att_h_W1, self.att_b1,
        self.att_W2, self.att_b2,
    ])

    # attention over history
    self.hatt_h_W1 = self.init((output_dim, att_hidden_dim))
    self.hatt_hist_W1 = self.init((output_dim, att_hidden_dim))
    self.hatt_b1 = shared_zeros(att_hidden_dim)
    self.hatt_W2 = self.init((att_hidden_dim, 1))
    self.hatt_b2 = shared_zeros(1)
    self.params.extend([
        self.hatt_h_W1, self.hatt_hist_W1, self.hatt_b1,
        self.hatt_W2, self.hatt_b2,
    ])

    self.set_name(name)
def _step(self,
          t, xi_t, xf_t, xo_t, xc_t, mask_t, parent_t,
          h_tm1, c_tm1, hist_h,
          u_i, u_f, u_o, u_c,
          c_i, c_f, c_o, c_c,
          h_i, h_f, h_o, h_c,
          p_i, p_f, p_o, p_c,
          att_h_w1, att_w2, att_b2,
          context, context_mask, context_att_trans,
          b_u):
    """Compute one decoder time step.

    Attends over `context`, optionally attends over the state history
    (`config.tree_attention`), optionally feeds in the parent's hidden
    state (`config.parent_hidden_state_feed`), applies the LSTM gates and
    stores the new hidden state in slot `t` of the history.

    Returns `(h_t, c_t, ctx_vec, new_hist_h)`.

    NOTE(review): `t` is treated as the scalar step index, `parent_t` as a
    per-example index into the history and `b_u` as four per-gate dropout
    masks -- confirm against the code that drives this step function.
    """
    # context: (batch_size, context_size, context_dim)
    # (batch_size, att_layer1_dim)
    h_tm1_att_trans = T.dot(h_tm1, att_h_w1)

    # (batch_size, context_size, att_layer1_dim)
    att_hidden = T.tanh(context_att_trans + h_tm1_att_trans[:, None, :])

    # (batch_size, context_size, 1)
    att_raw = T.dot(att_hidden, att_w2) + att_b2
    att_raw = att_raw.reshape((att_raw.shape[0], att_raw.shape[1]))

    # (batch_size, context_size); shifted by the row maximum before exp
    ctx_att = T.exp(att_raw - T.max(att_raw, axis=-1, keepdims=True))
    if context_mask:
        ctx_att = ctx_att * context_mask
    ctx_att = ctx_att / T.sum(ctx_att, axis=-1, keepdims=True)

    # (batch_size, context_dim)
    ctx_vec = T.sum(context * ctx_att[:, :, None], axis=1)

    ##### attention over history #####
    def _attention_over_history():
        # only the first `t` history slots take part in the attention
        hist_h_mask = T.zeros((hist_h.shape[0], hist_h.shape[1]), dtype='int8')
        hist_h_mask = T.set_subtensor(hist_h_mask[:, T.arange(t)], 1)

        hist_h_att_trans = T.dot(hist_h, self.hatt_hist_W1) + self.hatt_b1
        h_tm1_hatt_trans = T.dot(h_tm1, self.hatt_h_W1)
        hatt_hidden = T.tanh(hist_h_att_trans + h_tm1_hatt_trans[:, None, :])
        hatt_raw = T.dot(hatt_hidden, self.hatt_W2) + self.hatt_b2
        hatt_raw = hatt_raw.reshape((hist_h.shape[0], hist_h.shape[1]))

        hatt_exp = T.exp(hatt_raw - T.max(hatt_raw, axis=-1, keepdims=True)) * hist_h_mask
        # the epsilon keeps the division finite when the mask removes every slot
        h_att_weights = hatt_exp / (T.sum(hatt_exp, axis=-1, keepdims=True) + 1e-7)

        # (batch_size, output_dim)
        _h_ctx_vec = T.sum(hist_h * h_att_weights[:, :, None], axis=1)
        return _h_ctx_vec

    # at the first step (t == 0) there is no history to attend over
    h_ctx_vec = T.switch(t,
                         _attention_over_history(),
                         T.zeros_like(h_tm1))
    ##### attention over history #####

    ##### feed in parent hidden state #####
    # Use a separate gate instead of overwriting `t`: `t` is still needed
    # below as the history slot to write to.  Overwriting it made every
    # step write to slot 0 whenever parent feeding was disabled.
    parent_feed_gate = t if config.parent_hidden_state_feed else 0
    par_h = T.switch(parent_feed_gate,
                     hist_h[T.arange(hist_h.shape[0]), parent_t, :],
                     T.zeros_like(h_tm1))
    ##### feed in parent hidden state #####

    def _pre_activation(x_t, drop_mask, u, c, p, h):
        # Sum of all gate inputs; the history-attention term is only added
        # when tree attention is enabled.
        total = x_t + T.dot(h_tm1 * drop_mask, u) + T.dot(ctx_vec, c) + T.dot(par_h, p)
        if config.tree_attention:
            total = total + T.dot(h_ctx_vec, h)
        return total

    i_t = self.inner_activation(_pre_activation(xi_t, b_u[0], u_i, c_i, p_i, h_i))
    f_t = self.inner_activation(_pre_activation(xf_t, b_u[1], u_f, c_f, p_f, h_f))
    c_t = f_t * c_tm1 + i_t * self.activation(_pre_activation(xc_t, b_u[2], u_c, c_c, p_c, h_c))
    o_t = self.inner_activation(_pre_activation(xo_t, b_u[3], u_o, c_o, p_o, h_o))
    h_t = o_t * self.activation(c_t)

    # where mask_t is 0 the previous state is carried over unchanged
    h_t = (1 - mask_t) * h_tm1 + mask_t * h_t
    c_t = (1 - mask_t) * c_tm1 + mask_t * c_t

    new_hist_h = T.set_subtensor(hist_h[:, t, :], h_t)

    return h_t, c_t, ctx_vec, new_hist_h
def _for_step(self,
              xi_t, xf_t, xo_t, xc_t, mask_t,
              h_tm1, c_tm1,
              context, context_mask, context_att_trans,
              hist_h, hist_h_att_trans,
              b_u):
    """Run one LSTM step with attention over the context and over a history of states.

    Args:
        xi_t, xf_t, xo_t, xc_t: pre-computed input terms added inside the
            i/f/o/c gate expressions respectively.
        mask_t: step mask; where it is 0 the previous ``h``/``c`` are kept.
        h_tm1, c_tm1: previous hidden state and cell.
        context: tensor attended over, (batch_size, context_size, context_dim)
            according to the inline shape comments.
        context_mask: optional multiplicative mask applied to the
            un-normalised context attention weights; ``None`` disables it.
        context_att_trans: transformed context that is added to the
            transformed previous hidden state before the ``tanh``.
        hist_h, hist_h_att_trans: sequences of earlier hidden states and their
            attention transforms; they are stacked with ``T.stack``. When
            ``hist_h`` is empty a zero history vector is used instead.
        b_u: indexable with four entries, each scaling ``h_tm1`` for one gate.

    Returns:
        Tuple ``(h_t, c_t, ctx_vec)``.
    """
    # context: (batch_size, context_size, context_dim)

    # (batch_size, att_layer1_dim)
    h_tm1_att_trans = T.dot(h_tm1, self.att_h_W1)

    # (batch_size, context_size, att_layer1_dim)
    att_hidden = T.tanh(context_att_trans + h_tm1_att_trans[:, None, :])

    # (batch_size, context_size, 1)
    att_raw = T.dot(att_hidden, self.att_W2) + self.att_b2

    # (batch_size, context_size)
    ctx_att = T.exp(att_raw).reshape((att_raw.shape[0], att_raw.shape[1]))

    # Identity check instead of truthiness: symbolic variables must not be
    # evaluated as booleans (comparison results refuse it outright).
    if context_mask is not None:
        ctx_att = ctx_att * context_mask

    ctx_att = ctx_att / T.sum(ctx_att, axis=-1, keepdims=True)

    # (batch_size, context_dim)
    ctx_vec = T.sum(context * ctx_att[:, :, None], axis=1)

    ##### attention over history #####
    # hist_h is a plain Python sequence here, so truthiness means "non-empty".
    if hist_h:
        hist_h = T.stack(hist_h).dimshuffle((1, 0, 2))
        hist_h_att_trans = T.stack(hist_h_att_trans).dimshuffle((1, 0, 2))
        h_tm1_hatt_trans = T.dot(h_tm1, self.hatt_h_W1)

        hatt_hidden = T.tanh(hist_h_att_trans + h_tm1_hatt_trans[:, None, :])
        hatt_raw = T.dot(hatt_hidden, self.hatt_W2) + self.hatt_b2
        hatt_raw = hatt_raw.flatten(2)
        h_att_weights = T.nnet.softmax(hatt_raw)

        # (batch_size, output_dim)
        h_ctx_vec = T.sum(hist_h * h_att_weights[:, :, None], axis=1)
    else:
        h_ctx_vec = T.zeros_like(h_tm1)
    ##### attention over history #####

    i_t = self.inner_activation(xi_t + T.dot(h_tm1 * b_u[0], self.U_i) + T.dot(ctx_vec, self.C_i) + T.dot(h_ctx_vec, self.H_i))
    f_t = self.inner_activation(xf_t + T.dot(h_tm1 * b_u[1], self.U_f) + T.dot(ctx_vec, self.C_f) + T.dot(h_ctx_vec, self.H_f))
    c_t = f_t * c_tm1 + i_t * self.activation(xc_t + T.dot(h_tm1 * b_u[2], self.U_c) + T.dot(ctx_vec, self.C_c) + T.dot(h_ctx_vec, self.H_c))
    o_t = self.inner_activation(xo_t + T.dot(h_tm1 * b_u[3], self.U_o) + T.dot(ctx_vec, self.C_o) + T.dot(h_ctx_vec, self.H_o))
    h_t = o_t * self.activation(c_t)

    # Masked positions carry the previous state forward unchanged.
    h_t = (1 - mask_t) * h_tm1 + mask_t * h_t
    c_t = (1 - mask_t) * c_tm1 + mask_t * c_t

    # ctx_vec = theano.printing.Print('ctx_vec')(ctx_vec)

    return h_t, c_t, ctx_vec
def __call__(self, X, context, parent_t_seq, init_state=None, init_cell=None, hist_h=None,
             mask=None, context_mask=None,
             dropout=0, train=True, srng=None,
             time_steps=None):
    """Unroll the recurrent layer over ``X`` with ``theano.scan``.

    Args:
        X: input tensor; it is dimshuffled to (n_timestep, batch_size,
            input_dim), so it is expected batch-major on entry.
        context: tensor attended over at every step.
        parent_t_seq: per-step parent indices; dimshuffled to
            (n_timestep, batch_size) before being scanned over.
        init_state, init_cell: optional initial hidden state / cell; zeros of
            shape (batch_size, output_dim) are used when ``None``.
        hist_h: optional initial history buffer; a zero tensor of shape
            (batch_size, n_timestep, output_dim) is used when ``None``.
        mask: optional mask forwarded to ``get_mask``.
        context_mask: must be an int8 tensor. Although the default is
            ``None``, the dtype assertion below dereferences it, so it is
            effectively required.
        dropout: drop probability; 0 disables dropout.
        train: when true, dropout masks are sampled from ``srng`` and
            ``time_steps`` is derived from ``X``. When false, the retain
            probability is used as a scale and the caller must pass
            ``time_steps``.
        srng: random stream providing ``binomial``; only used when
            ``dropout > 0`` and ``train`` is true.
        time_steps: step indices scanned over; overwritten when ``train``.

    Returns:
        Tuple ``(outputs, cells, ctx_vectors)``, each dimshuffled with
        ``(1, 0, 2)`` from the scan results.
    """
    assert context_mask.dtype == 'int8', 'context_mask is not int8, got %s' % context_mask.dtype

    # (n_timestep, batch_size, 1)
    mask = self.get_mask(mask, X)
    # (n_timestep, batch_size, input_dim)
    X = X.dimshuffle((1, 0, 2))

    retain_prob = 1. - dropout
    B_w = np.ones((4,), dtype=theano.config.floatX)
    B_u = np.ones((4,), dtype=theano.config.floatX)
    if dropout > 0:
        logging.info('applying dropout with p = %f', dropout)
        if train:
            B_w = srng.binomial((4, X.shape[1], self.input_dim), p=retain_prob,
                                dtype=theano.config.floatX)
            B_u = srng.binomial((4, X.shape[1], self.output_dim), p=retain_prob,
                                dtype=theano.config.floatX)
        else:
            B_w *= retain_prob
            B_u *= retain_prob

    # (n_timestep, batch_size, output_dim)
    xi = T.dot(X * B_w[0], self.W_i) + self.b_i
    xf = T.dot(X * B_w[1], self.W_f) + self.b_f
    xc = T.dot(X * B_w[2], self.W_c) + self.b_c
    xo = T.dot(X * B_w[3], self.W_o) + self.b_o

    # (batch_size, context_size, att_layer1_dim)
    context_att_trans = T.dot(context, self.att_ctx_W1) + self.att_b1

    # The optional tensors are compared against None explicitly: relying on
    # the truthiness of a symbolic variable is fragile (comparison results
    # raise TypeError when coerced to bool).
    if init_state is not None:
        # (batch_size, output_dim)
        first_state = T.unbroadcast(init_state, 1)
    else:
        first_state = T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1)

    if init_cell is not None:
        # (batch_size, output_dim)
        first_cell = T.unbroadcast(init_cell, 1)
    else:
        first_cell = T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1)

    if hist_h is None:
        # (batch_size, n_timestep, output_dim)
        hist_h = alloc_zeros_matrix(X.shape[1], X.shape[0], self.output_dim)

    if train:
        # During training every position of X is visited in order.
        n_timestep = X.shape[0]
        time_steps = T.arange(n_timestep, dtype='int32')

    # (n_timestep, batch_size)
    parent_t_seq = parent_t_seq.dimshuffle((1, 0))

    [outputs, cells, ctx_vectors, hist_h_outputs], updates = theano.scan(
        self._step,
        sequences=[time_steps, xi, xf, xo, xc, mask, parent_t_seq],
        outputs_info=[
            first_state,  # for h
            first_cell,  # for cell
            None,  # T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.context_dim), 1), # for ctx vector
            hist_h,  # for hist_h
        ],
        non_sequences=[
            self.U_i, self.U_f, self.U_o, self.U_c,
            self.C_i, self.C_f, self.C_o, self.C_c,
            self.H_i, self.H_f, self.H_o, self.H_c,
            self.P_i, self.P_f, self.P_o, self.P_c,
            self.att_h_W1, self.att_W2, self.att_b2,
            context, context_mask, context_att_trans,
            B_u
        ])

    # Back to batch-major layout for the caller.
    outputs = outputs.dimshuffle((1, 0, 2))
    ctx_vectors = ctx_vectors.dimshuffle((1, 0, 2))
    cells = cells.dimshuffle((1, 0, 2))

    return outputs, cells, ctx_vectors
def get_mask(self, mask, X):
    """Build the int8 step mask consumed by the recurrent scan.

    When ``mask`` is ``None`` an all-ones mask of shape
    ``(X.shape[0], X.shape[1])`` is created. The mask then receives a
    trailing broadcastable axis, its first two axes are swapped and it is
    cast to ``int8``.
    """
    if mask is None:
        mask = T.ones((X.shape[0], X.shape[1]))

    # (nb_samples, time) -> (nb_samples, time, 1), last axis broadcastable
    padded = T.addbroadcast(T.shape_padright(mask), -1)

    # -> (time, nb_samples, 1)
    return padded.dimshuffle(1, 0, 2).astype('int8')
================================================
FILE: config.py
================================================
# MODE = 'django'
#
# SOURCE_VOCAB_SIZE = 2490 # 2492 # 5980
# TARGET_VOCAB_SIZE = 2101 # 2110 # 4830 #
# RULE_NUM = 222 # 228
# NODE_NUM = 96
#
# NODE_EMBED_DIM = 256
# EMBED_DIM = 128
# RULE_EMBED_DIM = 256
# QUERY_DIM = 256
# LSTM_STATE_DIM = 256
# DECODER_ATT_HIDDEN_DIM = 50
# POINTER_NET_HIDDEN_DIM = 50
#
# MAX_QUERY_LENGTH = 70
# MAX_EXAMPLE_ACTION_NUM = 100
#
# DECODER_DROPOUT = 0.2
# WORD_DROPOUT = 0
#
# # encoder
# ENCODER_LSTM = 'bilstm'
#
# # decoder
# PARENT_HIDDEN_STATE_FEEDING = True
# PARENT_RULE_FEEDING = True
# NODE_TYPE_FEEDING = True
# TREE_ATTENTION = True
#
# # training
# TRAIN_PATIENCE = 10
# MAX_EPOCH = 50
# BATCH_SIZE = 10
# VALID_PER_MINIBATCH = 4000
# SAVE_PER_MINIBATCH = 4000
#
# # decoding
# BEAM_SIZE = 15
# DECODE_MAX_TIME_STEP = 100
# Module-level placeholder, initialised to None at import time.
# NOTE(review): presumably filled in at runtime by the entry-point script,
# along with the other settings read as `config.<name>` elsewhere - confirm.
config_info = None
================================================
FILE: dataset.py
================================================
from __future__ import division
import copy
import nltk
from collections import OrderedDict, defaultdict
import logging
import collections
import numpy as np
import string
import re
import astor
from itertools import chain
from nn.utils.io_utils import serialize_to_file, deserialize_from_file
import config
from lang.py.parse import get_grammar
from lang.py.unaryclosure import get_top_unary_closures, apply_unary_closures
# Action type identifiers.
APPLY_RULE = 0
GEN_TOKEN = 1
COPY_TOKEN = 2
GEN_COPY_TOKEN = 3

# Human-readable name of every action type, used by Action.__repr__.
ACTION_NAMES = {
    APPLY_RULE: 'APPLY_RULE',
    GEN_TOKEN: 'GEN_TOKEN',
    COPY_TOKEN: 'COPY_TOKEN',
    GEN_COPY_TOKEN: 'GEN_COPY_TOKEN',
}


class Action(object):
    """One step of an action sequence: an action type plus its payload."""

    def __init__(self, act_type, data):
        self.act_type = act_type
        self.data = data

    def __repr__(self):
        """Render as ``Action{<TYPE>}[<data>]``; dict payloads become ``k: v`` pairs."""
        payload = self.data
        if isinstance(payload, dict):
            payload = ', '.join('%s: %s' % (key, value)
                                for key, value in payload.iteritems())

        return 'Action{%s}[%s]' % (ACTION_NAMES[self.act_type], payload)
class Vocab(object):
    """Token -> integer id mapping with reserved special tokens.

    ``<pad>``, ``<unk>`` and ``<eos>`` are inserted on construction and so
    receive ids 0, 1 and 2. Looking up an unknown token yields the
    ``<unk>`` id. The reverse mapping used by ``get_token`` only exists
    after ``complete()`` has been called.
    """

    def __init__(self):
        self.token_id_map = OrderedDict()
        for special_token in ('<pad>', '<unk>', '<eos>'):
            self.insert_token(special_token)

    @property
    def unk(self):
        """Id of the ``<unk>`` token."""
        return self.token_id_map['<unk>']

    @property
    def eos(self):
        """Id of the ``<eos>`` token."""
        return self.token_id_map['<eos>']

    @property
    def size(self):
        """Number of tokens currently in the vocabulary."""
        return len(self.token_id_map)

    def __len__(self):
        return len(self.token_id_map)

    def __contains__(self, item):
        return item in self.token_id_map

    def __iter__(self):
        return self.token_id_map.iterkeys()

    def __getitem__(self, item):
        """Return the id of ``item``, or the ``<unk>`` id when it is unknown."""
        try:
            return self.token_id_map[item]
        except KeyError:
            logging.debug('encounter one unknown word [%s]' % item)
            return self.token_id_map['<unk>']

    def __setitem__(self, key, value):
        self.token_id_map[key] = value

    def iteritems(self):
        """Iterate over ``(token, id)`` pairs."""
        return self.token_id_map.iteritems()

    def complete(self):
        """Build ``id_token_map``, the id -> token reverse mapping."""
        reverse_map = dict()
        for token, token_id in self.token_id_map.iteritems():
            reverse_map[token_id] = token
        self.id_token_map = reverse_map

    def get_token(self, token_id):
        """Return the token for ``token_id``; requires a prior ``complete()``."""
        return self.id_token_map[token_id]

    def insert_token(self, token):
        """Add ``token`` if it is new and return its id."""
        if token not in self.token_id_map:
            # New tokens take the next free id.
            self[token] = len(self)
        return self[token]
# Translation table mapping every character of string.punctuation to a space.
replace_punctuation = string.maketrans(string.punctuation,
                                       ' ' * len(string.punctuation))


def tokenize(str):
    """Replace punctuation with spaces, then split with ``nltk.word_tokenize``."""
    without_punctuation = str.translate(replace_punctuation)
    return nltk.word_tokenize(without_punctuation)
def gen_vocab(tokens, vocab_size=3000, freq_cutoff=5):
    """Build a ``Vocab`` from the frequent tokens of ``tokens``.

    Tokens occurring fewer than ``freq_cutoff`` times are dropped; of the
    rest, the ``vocab_size - 2`` most frequent are kept. Kept tokens are
    inserted in order of first appearance in ``tokens``, which is iterated
    twice, so it must be a re-iterable sequence.

    Returns:
        The completed ``Vocab``.
    """
    word_freq = collections.Counter(tokens)
    print('total num. of tokens: %d' % len(word_freq))

    frequent_words = [word for word in word_freq if word_freq[word] >= freq_cutoff]
    print('num. of words appear at least %d: %d' % (freq_cutoff, len(frequent_words)))

    frequent_words.sort(key=word_freq.get, reverse=True)
    kept_words = set(frequent_words[:vocab_size - 2])

    vocab = Vocab()
    for token in tokens:
        if token in kept_words:
            vocab.insert_token(token)

    vocab.complete()

    return vocab
class DataEntry:
    """A single example: query tokens, parse tree, code and action sequence.

    ``eid`` starts at -1 and ``dataset`` at ``None``; both are assigned when
    the entry is added to a ``DataSet`` through ``DataSet.add``.
    """

    def __init__(self, raw_id, query, parse_tree, code, actions, meta_data=None):
        self.raw_id = raw_id
        self.eid = -1
        # Owning dataset. Initialised here so that reading ``data`` on an
        # unattached entry fails with the assertion message below rather
        # than with an AttributeError.
        self.dataset = None
        # FIXME: rename to query_token
        self.query = query
        self.parse_tree = parse_tree
        self.actions = actions
        self.code = code
        self.meta_data = meta_data

    @property
    def data(self):
        """Model inputs for this entry, computed once and cached in ``_data``.

        Raises:
            AssertionError: if the entry is not attached to a dataset.
        """
        if not hasattr(self, '_data'):
            # getattr keeps instances created without __init__ (for example
            # ones restored from serialized data) working.
            dataset = getattr(self, 'dataset', None)
            assert dataset is not None, 'No associated dataset for the example'

            self._data = dataset.get_prob_func_inputs([self.eid])

        return self._data

    def copy(self):
        """Return a new, unattached entry sharing this entry's fields."""
        e = DataEntry(self.raw_id, self.query, self.parse_tree, self.code, self.actions, self.meta_data)

        return e
class DataSet:
    """A named collection of examples together with their index matrices.

    ``data_matrix`` maps a field name to an int32 numpy array whose first
    axis indexes examples; it is filled by ``init_data_matrices``.
    """

    def __init__(self, annot_vocab, terminal_vocab, grammar, name='train_data'):
        self.annot_vocab = annot_vocab
        self.terminal_vocab = terminal_vocab
        self.grammar = grammar
        self.name = name

        self.examples = []
        self.data_matrix = dict()

    def add(self, example):
        """Append ``example``, giving it the next ``eid`` and a back-reference."""
        example.eid = len(self.examples)
        example.dataset = self
        self.examples.append(example)

    def get_dataset_by_ids(self, ids, name):
        """Return a new dataset holding copies of the examples at ``ids``."""
        subset = DataSet(self.annot_vocab, self.terminal_vocab,
                         self.grammar, name)
        for eid in ids:
            subset.add(self.examples[eid].copy())

        # Carry over the matching rows of every matrix.
        for key, matrix in self.data_matrix.iteritems():
            subset.data_matrix[key] = matrix[ids]

        return subset

    @property
    def count(self):
        """Number of examples (0 when there are none)."""
        return len(self.examples) if self.examples else 0

    def get_examples(self, ids):
        """Return the example(s) at ``ids``: a list for an iterable, else one item."""
        if not isinstance(ids, collections.Iterable):
            return self.examples[ids]

        return [self.examples[i] for i in ids]

    def get_prob_func_inputs(self, ids):
        """Slice the matrices for ``ids``, trimmed to the longest example.

        The result is ordered as: query_tokens, tgt_action_seq,
        tgt_action_seq_type, tgt_node_seq, tgt_par_rule_seq, tgt_par_t_seq.
        """
        order = ['query_tokens', 'tgt_action_seq', 'tgt_action_seq_type',
                 'tgt_node_seq', 'tgt_par_rule_seq', 'tgt_par_t_seq']

        max_src_seq_len = max(len(self.examples[i].query) for i in ids)
        max_tgt_seq_len = max(len(self.examples[i].actions) for i in ids)

        logging.debug('max. src sequence length: %d', max_src_seq_len)
        logging.debug('max. tgt sequence length: %d', max_tgt_seq_len)

        data = []
        for entry in order:
            # Only the query matrix is source-side; the rest are target-side.
            width = max_src_seq_len if entry == 'query_tokens' else max_tgt_seq_len
            data.append(self.data_matrix[entry][ids, :width])

        return data

    def init_data_matrices(self, max_query_length=70, max_example_action_num=100):
        """Allocate and fill ``data_matrix`` from ``self.examples``.

        Queries and action sequences longer than the given limits are
        truncated. Each action occupies one row of width 3 in
        ``tgt_action_seq``: column 0 holds a rule id, column 1 a terminal
        token id and column 2 a source token index;
        ``tgt_action_seq_type`` marks the columns in use with 1.

        Raises:
            RuntimeError: on an unknown action type.
        """
        logging.info('init data matrices for [%s] dataset', self.name)
        annot_vocab = self.annot_vocab
        terminal_vocab = self.terminal_vocab
        num_examples = self.count

        def new_matrix(key, *trailing_dims):
            # Allocate a zeroed int32 matrix and register it under ``key``.
            matrix = np.zeros((num_examples,) + trailing_dims, dtype='int32')
            self.data_matrix[key] = matrix
            return matrix

        query_tokens = new_matrix('query_tokens', max_query_length)
        tgt_node_seq = new_matrix('tgt_node_seq', max_example_action_num)
        tgt_par_rule_seq = new_matrix('tgt_par_rule_seq', max_example_action_num)
        tgt_par_t_seq = new_matrix('tgt_par_t_seq', max_example_action_num)
        tgt_action_seq = new_matrix('tgt_action_seq', max_example_action_num, 3)
        tgt_action_seq_type = new_matrix('tgt_action_seq_type', max_example_action_num, 3)

        for eid, example in enumerate(self.examples):
            exg_query_tokens = example.query[:max_query_length]
            exg_action_seq = example.actions[:max_example_action_num]

            for tid, token in enumerate(exg_query_tokens):
                query_tokens[eid, tid] = annot_vocab[token]

            assert len(exg_action_seq) > 0

            for t, action in enumerate(exg_action_seq):
                kind = action.act_type
                payload = action.data

                if kind == APPLY_RULE:
                    tgt_action_seq[eid, t, 0] = self.grammar.rule_to_id[payload['rule']]
                    tgt_action_seq_type[eid, t, 0] = 1
                elif kind in (GEN_TOKEN, COPY_TOKEN, GEN_COPY_TOKEN):
                    # Generation fills column 1, copying fills column 2;
                    # GEN_COPY_TOKEN fills both.
                    if kind != COPY_TOKEN:
                        tgt_action_seq[eid, t, 1] = terminal_vocab[payload['literal']]
                        tgt_action_seq_type[eid, t, 1] = 1
                    if kind != GEN_TOKEN:
                        tgt_action_seq[eid, t, 2] = payload['source_idx']
                        tgt_action_seq_type[eid, t, 2] = 1
                else:
                    raise RuntimeError('wrong action type!')

                # parent information
                rule = payload['rule']
                parent_rule = payload['parent_rule']
                tgt_node_seq[eid, t] = self.grammar.get_node_type_id(rule.parent)
                if parent_rule:
                    tgt_par_rule_seq[eid, t] = self.grammar.rule_to_id[parent_rule]
                else:
                    # Only the first action may lack a parent rule.
                    assert t == 0
                    tgt_par_rule_seq[eid, t] = -1

                # parent hidden states
                tgt_par_t_seq[eid, t] = payload['parent_t']

            example.dataset = self
class DataHelper(object):
    """Namespace for query helper routines."""

    @staticmethod
    def canonicalize_query(query):
        """Identity canonicalisation: return ``query`` unchanged."""
        return query
def parse_django_dataset_nt_only():
    """Build train/dev/test datasets from the Django corpus and serialize them.

    Parses the whole code file, pairs each parse tree with the matching
    annotation line (skipping trees that are a single leaf), fills three
    datasets (trees 0-15999, 16000-16999, 17000-18804) and writes them to
    'django.typed_rule.bin'. All input paths are hard-coded.

    NOTE(review): this function looks stale with respect to the rest of this
    file and will not run as written:
      * ``gen_vocab`` iterates over its first argument as tokens, but it is
        given a file path string here;
      * ``DataSet.__init__`` takes (annot_vocab, terminal_vocab, grammar,
        name), but is called with only two positional arguments;
      * ``DataEntry.__init__`` takes (raw_id, query, parse_tree, code,
        actions, ...), but is called with two arguments.
    The annotation files are also opened without being closed.
    """
    from parse import parse_django

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'

    vocab = gen_vocab(annot_file, vocab_size=4500)

    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'

    grammar, all_parse_trees = parse_django(code_file)

    train_data = DataSet(vocab, grammar, name='train')
    dev_data = DataSet(vocab, grammar, name='dev')
    test_data = DataSet(vocab, grammar, name='test')

    # train_data

    train_annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/train.anno'
    train_parse_trees = all_parse_trees[0:16000]
    for line, parse_tree in zip(open(train_annot_file), train_parse_trees):
        # trees consisting of a single leaf are skipped
        if parse_tree.is_leaf:
            continue

        line = line.strip()
        tokens = tokenize(line)
        entry = DataEntry(tokens, parse_tree)

        train_data.add(entry)

    train_data.init_data_matrices()

    # dev_data

    dev_annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/dev.anno'
    dev_parse_trees = all_parse_trees[16000:17000]
    for line, parse_tree in zip(open(dev_annot_file), dev_parse_trees):
        if parse_tree.is_leaf:
            continue

        line = line.strip()
        tokens = tokenize(line)
        entry = DataEntry(tokens, parse_tree)

        dev_data.add(entry)

    dev_data.init_data_matrices()

    # test_data

    test_annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/test.anno'
    test_parse_trees = all_parse_trees[17000:18805]
    for line, parse_tree in zip(open(test_annot_file), test_parse_trees):
        if parse_tree.is_leaf:
            continue

        line = line.strip()
        tokens = tokenize(line)
        entry = DataEntry(tokens, parse_tree)

        test_data.add(entry)

    test_data.init_data_matrices()

    serialize_to_file((train_data, dev_data, test_data), 'django.typed_rule.bin')
def parse_django_dataset():
    """Preprocess the Django corpus into train/dev/test ``DataSet`` objects.

    Steps, in order:
      1. canonicalize every (annotation, code) pair with
         ``preprocess_dataset`` and parse the code into a tree;
      2. extract the grammar from all parse trees and dump its rules to
         'django.grammar.unary_closure.txt';
      3. build the annotation vocabulary from the query tokens;
      4. first pass: collect the tokens of every value-node leaf and build
         the terminal vocabulary from them;
      5. second pass: turn each parse tree into a sequence of ``Action``
         objects (APPLY_RULE for non-value nodes; GEN_TOKEN / COPY_TOKEN /
         GEN_COPY_TOKEN for the tokens of value nodes, each value ending
         with an '<eos>' GEN_TOKEN) and assign the example to a split by
         its original index (< 16000 train, < 17000 dev, otherwise test);
      6. log statistics, fill the data matrices and serialize the datasets.

    Input paths and the output file names are hard-coded.

    Returns:
        Tuple ``(train_data, dev_data, test_data)``.
    """
    from lang.py.parse import parse_raw
    from lang.util import escape  # NOTE(review): not referenced in this function

    # Source tokens found at or beyond this position are not copied (see the
    # GEN_TOKEN / copy decision in the second pass).
    MAX_QUERY_LENGTH = 70
    # Only referenced by the commented-out unary-closure code below.
    UNARY_CUTOFF_FREQ = 30

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'

    data = preprocess_dataset(annot_file, code_file)

    for e in data:
        e['parse_tree'] = parse_raw(e['code'])

    parse_trees = [e['parse_tree'] for e in data]

    # apply unary closures
    # unary_closures = get_top_unary_closures(parse_trees, k=0, freq=UNARY_CUTOFF_FREQ)
    # for i, parse_tree in enumerate(parse_trees):
    #     apply_unary_closures(parse_tree, unary_closures)

    # build the grammar
    grammar = get_grammar(parse_trees)

    # write grammar
    with open('django.grammar.unary_closure.txt', 'w') as f:
        for rule in grammar:
            f.write(rule.__repr__() + '\n')

    # # build grammar ...
    # from lang.py.py_dataset import extract_grammar
    # grammar, all_parse_trees = extract_grammar(code_file)

    annot_tokens = list(chain(*[e['query_tokens'] for e in data]))
    annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=3)  # gen_vocab(annot_tokens, vocab_size=5980)

    terminal_token_seq = []
    empty_actions_count = 0

    # helper function begins
    def get_terminal_tokens(_terminal_str):
        # Split on single spaces but keep each space as its own token, so
        # that joining the returned tokens reproduces the input string.
        # _terminal_tokens = filter(None, re.split('([, .?!])', _terminal_str)) # _terminal_str.split('-SP-')
        # _terminal_tokens = filter(None, re.split('( )', _terminal_str))  # _terminal_str.split('-SP-')
        tmp_terminal_tokens = _terminal_str.split(' ')
        _terminal_tokens = []
        for token in tmp_terminal_tokens:
            if token:
                _terminal_tokens.append(token)
            _terminal_tokens.append(' ')

        # drop the separator appended after the last piece
        return _terminal_tokens[:-1]
        # return _terminal_tokens
    # helper function ends

    # first pass
    # collects every terminal token so the terminal vocabulary can be built
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        for node in parse_tree.get_leaves():
            if grammar.is_value_node(node):
                terminal_val = node.value
                terminal_str = str(terminal_val)

                terminal_tokens = get_terminal_tokens(terminal_str)

                for terminal_token in terminal_tokens:
                    assert len(terminal_token) > 0
                    terminal_token_seq.append(terminal_token)

    terminal_vocab = gen_vocab(terminal_token_seq, vocab_size=5000, freq_cutoff=3)
    assert '_STR:0_' in terminal_vocab

    train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'train_data')
    dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'dev_data')
    test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'test_data')

    all_examples = []

    can_fully_gen_num = 0

    # second pass
    # converts every parse tree into its action sequence
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        str_map = entry['str_map']
        parse_tree = entry['parse_tree']

        rule_list, rule_parents = parse_tree.get_productions(include_value_node=True)

        actions = []
        can_fully_gen = True
        # maps a rule to the index of the action that applied it, so that
        # children can refer to the time step of their parent
        rule_pos_map = dict()

        for rule_count, rule in enumerate(rule_list):
            if not grammar.is_value_node(rule.parent):
                assert rule.value is None
                parent_rule = rule_parents[(rule_count, rule)][0]
                if parent_rule:
                    parent_t = rule_pos_map[parent_rule]
                else:
                    # no parent rule: parent time step defaults to 0
                    parent_t = 0

                rule_pos_map[rule] = len(actions)

                d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule}
                action = Action(APPLY_RULE, d)

                actions.append(action)
            else:
                assert rule.is_leaf

                parent_rule = rule_parents[(rule_count, rule)][0]
                parent_t = rule_pos_map[parent_rule]

                terminal_val = rule.value
                terminal_str = str(terminal_val)
                terminal_tokens = get_terminal_tokens(terminal_str)

                # assert len(terminal_tokens) > 0

                for terminal_token in terminal_tokens:
                    term_tok_id = terminal_vocab[terminal_token]
                    # position of the first occurrence in the query, -1 if absent
                    tok_src_idx = -1
                    try:
                        tok_src_idx = query_tokens.index(terminal_token)
                    except ValueError:
                        pass

                    d = {'literal': terminal_token, 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}

                    # cannot copy, only generation
                    # could be unk!
                    if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH:
                        action = Action(GEN_TOKEN, d)
                        if terminal_token not in terminal_vocab:
                            if terminal_token not in query_tokens:
                                # print terminal_token
                                can_fully_gen = False
                    else:  # copy
                        if term_tok_id != terminal_vocab.unk:
                            # token is both in the vocabulary and in the query
                            d['source_idx'] = tok_src_idx
                            action = Action(GEN_COPY_TOKEN, d)
                        else:
                            # out-of-vocabulary token: it can only be copied
                            d['source_idx'] = tok_src_idx
                            action = Action(COPY_TOKEN, d)

                    actions.append(action)

                # every terminal value is closed by an explicit <eos> token
                d = {'literal': '<eos>', 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}
                actions.append(Action(GEN_TOKEN, d))

        if len(actions) == 0:
            empty_actions_count += 1
            continue

        example = DataEntry(idx, query_tokens, parse_tree, code, actions,
                            {'raw_code': entry['raw_code'], 'str_map': entry['str_map']})

        if can_fully_gen:
            can_fully_gen_num += 1

        # train, valid, test
        if 0 <= idx < 16000:
            train_data.add(example)
        elif 16000 <= idx < 17000:
            dev_data.add(example)
        else:
            test_data.add(example)

        all_examples.append(example)

    # print statistics
    max_query_len = max(len(e.query) for e in all_examples)
    max_actions_len = max(len(e.actions) for e in all_examples)

    serialize_to_file([len(e.query) for e in all_examples], 'query.len')
    serialize_to_file([len(e.actions) for e in all_examples], 'actions.len')

    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_gen_num, len(all_examples),
                 can_fully_gen_num / len(all_examples))
    logging.info('empty_actions_count: %d', empty_actions_count)
    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    train_data.init_data_matrices()
    dev_data.init_data_matrices()
    test_data.init_data_matrices()

    serialize_to_file((train_data, dev_data, test_data),
                      'data/django.cleaned.dataset.freq3.par_info.refact.space_only.order_by_ulink_len.bin')
                      # 'data/django.cleaned.dataset.freq5.par_info.refact.space_only.unary_closure.freq{UNARY_CUTOFF_FREQ}.order_by_ulink_len.bin'.format(UNARY_CUTOFF_FREQ=UNARY_CUTOFF_FREQ))

    return train_data, dev_data, test_data
def check_terminals():
    """Debug helper: print the leaf labels that occur verbatim in their annotation.

    Parses the hard-coded Django code file, walks the annotation file line
    by line and, for each non-empty leaf label of the matching parse tree,
    prints the label when it is a substring of the annotation. Labels that
    are not found are collected (split by length <= 15) for the
    commented-out report at the end.
    """
    from parse import parse_django, unescape
    grammar, parse_trees = parse_django('/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code')
    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'

    unique_terminals = set()
    invalid_terminals = set()

    # The annotation file is closed deterministically; it used to be opened
    # inline and never closed.
    with open(annot_file) as f_annot:
        for i, line in enumerate(f_annot):
            parse_tree = parse_trees[i]
            utterance = line.strip()

            leaves = parse_tree.get_leaves()
            # tokens = set(nltk.word_tokenize(utterance))
            leave_tokens = [l.label for l in leaves if l.label]

            not_included = []
            for leaf_token in leave_tokens:
                leaf_token = str(leaf_token)
                leaf_token = unescape(leaf_token)
                if leaf_token not in utterance:
                    not_included.append(leaf_token)

                    if len(leaf_token) <= 15:
                        unique_terminals.add(leaf_token)
                    else:
                        invalid_terminals.add(leaf_token)
                else:
                    if isinstance(leaf_token, str):
                        print(leaf_token)

            # if not_included:
            #     print str(i) + '---' + ', '.join(not_included)

    # print 'num of unique leaves: %d' % len(unique_terminals)
    # print unique_terminals
    #
    # print 'num of invalid leaves: %d' % len(invalid_terminals)
    # print invalid_terminals
def query_to_data(query, annot_vocab):
    """Convert a space-separated query into a (1, n_tokens) int32 id matrix.

    The query is split on single spaces and truncated to the maximum query
    length configured on the ``config`` module; each token is mapped to its
    id through ``annot_vocab``.
    """
    query_tokens = query.split(' ')

    # The setting was historically read under the misspelled name
    # ``max_qeury_length``. Prefer the correctly spelled attribute and fall
    # back to the old spelling, so either configuration works.
    max_query_length = getattr(config, 'max_query_length', None)
    if max_query_length is None:
        max_query_length = config.max_qeury_length

    token_num = min(max_query_length, len(query_tokens))
    data = np.zeros((1, token_num), dtype='int32')

    for tid, token in enumerate(query_tokens[:token_num]):
        token_id = annot_vocab[token]

        data[0, tid] = token_id

    return data
QUOTED_STRING_RE = re.compile(r"(?P<quote>['\"])(?P<string>.*?)(?<!\\)(?P=quote)")


def canonicalize_query(query):
    """
    canonicalize the query: replace each distinct quoted string literal
    with a placeholder of the form _STR:<n>_, tokenize the result, and
    append a bracketed, dot-separated expansion after every dotted token
    such as foo.bar.func

    returns the space-joined canonical query and the literal -> placeholder map
    """
    str_map = dict()

    # findall yields one (quote, string) tuple per match
    for quote, content in QUOTED_STRING_RE.findall(query):
        str_literal = quote + content + quote

        # de-duplicate: str_map holds exactly the literals replaced so far
        if str_literal in str_map:
            continue

        # FIXME: substitute the ' % s ' with
        if str_literal in ('\'%s\'', '\"%s\"'):
            continue

        # placeholders are numbered in order of first replacement
        str_repr = '_STR:%d_' % len(str_map)
        str_map[str_literal] = str_repr

        query = query.replace(str_literal, str_repr)

    # tokenize
    new_query_tokens = []
    for token in nltk.word_tokenize(query):
        new_query_tokens.append(token)

        # break up function calls like foo.bar.func; the dot must not be
        # the first or the last character of the token
        dot_pos = token.find('.')
        if 0 < dot_pos < len(token) - 1:
            new_query_tokens.append('[')
            new_query_tokens.extend(token.replace('.', ' . ').split(' '))
            new_query_tokens.append(']')

    return ' '.join(new_query_tokens), str_map
def canonicalize_example(query, code):
    """Canonicalize a (query, code) pair consistently.

    String literals replaced by placeholders in the query are replaced by
    the quoted placeholder in the code; the code is then made compilable
    and checked by a round trip through the parse tree.

    Returns:
        Tuple ``(query_tokens, canonical_code, str_map)``.

    Raises:
        AssertionError: if the source regenerated from the parse tree
            differs from the source of the gold AST.
    """
    from lang.py.parse import parse_raw, parse_tree_to_python_ast, canonicalize_code as make_it_compilable
    import astor, ast

    canonical_query, str_map = canonicalize_query(query)

    canonical_code = code
    for str_literal, str_repr in str_map.iteritems():
        quoted_placeholder = '\'' + str_repr + '\''
        canonical_code = canonical_code.replace(str_literal, quoted_placeholder)

    canonical_code = make_it_compilable(canonical_code)

    # sanity check: parse tree -> AST -> source must match the gold source
    parse_tree = parse_raw(canonical_code)
    gold_source = astor.to_source(ast.parse(canonical_code).body[0])
    source = astor.to_source(parse_tree_to_python_ast(parse_tree))

    assert gold_source == source, 'sanity check fails: gold=[%s], actual=[%s]' % (gold_source, source)

    return canonical_query.split(' '), canonical_code, str_map
def process_query(query, code):
    """Replace quoted strings in ``query`` and ``code`` with placeholders.

    Every quoted string found in the query is replaced, one match at a
    time, by ``_STR:<n>_``; the same literal in the code is replaced by the
    quoted placeholder. The '%s' / "%s" literals are restored afterwards.
    The query is then tokenized, dotted tokens get a bracketed expansion,
    and the code is parsed and converted back to source as a compile check.

    Returns:
        Tuple ``(query_tokens, code, str_map)``.
    """
    from parse import code_to_ast, ast_to_tree, tree_to_ast, parse
    import astor

    str_map = dict()
    str_count = 0

    match = QUOTED_STRING_RE.search(query)
    while match:
        str_literal = match.group(0)
        str_repr = '_STR:%d_' % str_count
        str_count += 1

        # replace only the first remaining match, then search again
        query = QUOTED_STRING_RE.sub(str_repr, query, 1)
        str_map[str_literal] = str_repr
        code = code.replace(str_literal, '\'' + str_repr + '\'')

        match = QUOTED_STRING_RE.search(query)

    # clean the annotation
    # query = query.replace('.', ' . ')

    # put the bare format-string literals back
    for literal, placeholder in str_map.iteritems():
        if literal in ('\'%s\'', '\"%s\"'):
            query = query.replace(placeholder, literal)
            code = code.replace('\'' + placeholder + '\'', literal)

    # tokenize
    new_query_tokens = []
    for token in nltk.word_tokenize(query):
        new_query_tokens.append(token)

        # break up function calls
        dot_pos = token.find('.')
        if 0 < dot_pos < len(token) - 1:
            new_query_tokens.append('[')
            new_query_tokens.extend(token.replace('.', ' . ').split(' '))
            new_query_tokens.append(']')

    # check if the code compiles
    astor.to_source(tree_to_ast(parse(code)))

    return new_query_tokens, code, str_map
def preprocess_dataset(annot_file, code_file):
    """Canonicalize every (annotation, code) line pair of the two files.

    Successfully canonicalized examples are written to
    'annot.all.canonicalized.txt' and 'code.all.canonicalized.txt' in the
    current directory. A pair that fails to canonicalize is printed and
    counted, but does not stop processing.

    Args:
        annot_file: path of the annotation file, one query per line.
        code_file: path of the code file, one snippet per line.

    Returns:
        List of dicts with keys 'id' (line index), 'query_tokens', 'code',
        'str_map' and 'raw_code'.
    """
    examples = []
    err_num = 0

    # All four files are closed on exit. Previously the code output file was
    # never closed (the annotation output was closed twice instead) and the
    # input files were leaked. The outputs are still opened first.
    with open('annot.all.canonicalized.txt', 'w') as f_annot, \
            open('code.all.canonicalized.txt', 'w') as f_code, \
            open(annot_file) as f_annot_in, \
            open(code_file) as f_code_in:
        for idx, (annot, code) in enumerate(zip(f_annot_in, f_code_in)):
            annot = annot.strip()
            code = code.strip()
            try:
                clean_query_tokens, clean_code, str_map = canonicalize_example(annot, code)
                example = {'id': idx, 'query_tokens': clean_query_tokens, 'code': clean_code,
                           'str_map': str_map, 'raw_code': code}
                examples.append(example)

                f_annot.write('example# %d\n' % idx)
                f_annot.write(' '.join(clean_query_tokens) + '\n')
                f_annot.write('%d\n' % len(str_map))
                for k, v in str_map.iteritems():
                    f_annot.write('%s ||| %s\n' % (k, v))

                f_code.write('example# %d\n' % idx)
                f_code.write(clean_code + '\n')
            except Exception:
                # Best effort: report the offending snippet and keep going.
                # KeyboardInterrupt / SystemExit are no longer swallowed.
                print(code)
                err_num += 1

    # serialize_to_file(examples, 'django.cleaned.bin')

    print('error num: %d' % err_num)
    print('preprocess_dataset: cleaned example num: %d' % len(examples))

    return examples
if __name__ == '__main__':
    # Script entry point: set up logging to 'parse.log', then build the
    # Django dataset. The commented-out lines are alternative entry points
    # kept from earlier experiments.
    from nn.utils.generic_utils import init_logging

    init_logging('parse.log')

    # annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'
    # code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'

    # preprocess_dataset(annot_file, code_file)

    # parse_django_dataset()
    # check_terminals()

    # print process_query(""" ALLOWED_VARIABLE_CHARS is a string 'abcdefgh"ijklm" nop"%s"qrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_.'.""")

    # for i, query in enumerate(open('/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno')):
    #     print i, process_query(query)

    # clean_dataset()

    parse_django_dataset()

    # from lang.py.py_dataset import parse_hs_dataset
    # parse_hs_dataset()
================================================
FILE: decoder.py
================================================
import traceback
import config
from model import *
def decode_python_dataset(model, dataset, verbose=True):
    """Decode every example of ``dataset`` and convert candidates to Python code.

    For each example, ``model.decode`` is run with the beam size and
    maximum time step from ``config``; each of the first 10 candidates is
    converted to a Python AST and then to source. A candidate whose
    conversion fails is skipped (with a traceback printed when ``verbose``).

    Args:
        model: object exposing ``decode(example, grammar, terminal_vocab,
            beam_size=..., max_time_step=...)``.
        dataset: object exposing ``name``, ``count``, ``examples``,
            ``grammar`` and ``terminal_vocab``.
        verbose: when true, log progress and print conversion failures.

    Returns:
        One list per example, each holding ``(cid, cand, ast_tree, code)``
        tuples for the candidates that converted successfully.
    """
    from lang.py.parse import decode_tree_to_python_ast
    if verbose:
        logging.info('decoding [%s] set, num. examples: %d', dataset.name, dataset.count)

    decode_results = []
    cum_num = 0
    for example in dataset.examples:
        cand_list = model.decode(example, dataset.grammar, dataset.terminal_vocab,
                                 beam_size=config.beam_size, max_time_step=config.decode_max_time_step)

        exg_decode_results = []
        for cid, cand in enumerate(cand_list[:10]):
            try:
                ast_tree = decode_tree_to_python_ast(cand.tree)
                code = astor.to_source(ast_tree)
                exg_decode_results.append((cid, cand, ast_tree, code))
            except Exception:
                # A malformed candidate must not abort decoding, but
                # KeyboardInterrupt / SystemExit are allowed to propagate.
                if verbose:
                    print("Exception in converting tree to code:")
                    print('-' * 60)
                    print('raw_id: %d, beam pos: %d' % (example.raw_id, cid))
                    traceback.print_exc(file=sys.stdout)
                    print('-' * 60)

        cum_num += 1
        if cum_num % 50 == 0 and verbose:
            print('%d examples so far ...' % cum_num)

        decode_results.append(exg_decode_results)

    return decode_results
def decode_ifttt_dataset(model, dataset, verbose=True):
    """Decode every example of ``dataset``, keeping the first 10 candidates.

    Unlike ``decode_python_dataset``, the candidates are not converted to
    code. The former try/except around the plain ``list.append`` was
    unreachable (and carried a copy-pasted "converting tree to code"
    message), so it has been removed.

    Args:
        model: object exposing ``decode(example, grammar, terminal_vocab,
            beam_size=..., max_time_step=...)``.
        dataset: object exposing ``name``, ``count``, ``examples``,
            ``grammar`` and ``terminal_vocab``.
        verbose: when true, log and print progress.

    Returns:
        One list per example, each holding ``(cid, cand)`` tuples.
    """
    if verbose:
        logging.info('decoding [%s] set, num. examples: %d', dataset.name, dataset.count)

    decode_results = []
    cum_num = 0
    for example in dataset.examples:
        cand_list = model.decode(example, dataset.grammar, dataset.terminal_vocab,
                                 beam_size=config.beam_size, max_time_step=config.decode_max_time_step)

        # (beam position, candidate) pairs for the top-10 candidates
        exg_decode_results = list(enumerate(cand_list[:10]))

        cum_num += 1
        if cum_num % 50 == 0 and verbose:
            print('%d examples so far ...' % cum_num)

        decode_results.append(exg_decode_results)

    return decode_results
================================================
FILE: evaluation.py
================================================
# -*- coding: UTF-8 -*-
from __future__ import division
import os
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction
import logging
import traceback
from nn.utils.generic_utils import init_logging
from model import *
# Absolute, machine-specific path of the Django annotation file. It is read
# line by line in evaluate_decode_results when config.data_type == 'django'
# to map raw example ids to their annotation text.
DJANGO_ANNOT_FILE = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'
def tokenize_for_bleu_eval(code):
    """Split a code string into tokens for BLEU evaluation.

    Every character outside ``[A-Za-z0-9_]`` becomes its own token,
    lower-to-upper camelCase boundaries are split, and both quote
    characters are normalised to a backtick.
    """
    substitutions = (
        (r'([^A-Za-z0-9_])', r' \1 '),  # isolate non-identifier characters
        (r'([a-z])([A-Z])', r'\1 \2'),  # split camelCase boundaries
        (r'\s+', ' '),                  # collapse whitespace runs
    )
    for pattern, replacement in substitutions:
        code = re.sub(pattern, replacement, code)

    # both quote styles map to the same token
    code = code.replace('"', '`').replace('\'', '`')

    return [piece for piece in code.split(' ') if piece]
def evaluate(model, dataset, verbose=True):
    """Compute the exact-match ratio of the model's best hypotheses.

    For each example the first hypothesis returned by ``model.decode`` is
    mapped to rules through ``dataset.grammar.id_to_rule``. It counts as an
    exact match when its rules, sorted by their repr, equal the example's
    gold rules sorted the same way. Examples without any hypothesis are
    skipped but still count in the denominator (``dataset.count``).

    Returns:
        The exact-match ratio as a float.
    """
    if verbose:
        logging.info('evaluating [%s] dataset, [%d] examples' % (dataset.name, dataset.count))

    def by_repr(rule):
        # order-insensitive comparison key
        return rule.__repr__()

    num_exact_match = 0.0
    for example in dataset.examples:
        logging.info('evaluating example [%d]' % example.eid)
        hyps, hyp_scores = model.decode(example, max_time_step=config.decode_max_time_step)
        gold_rules = example.rules

        if len(hyps) == 0:
            logging.warning('no decoding result for example [%d]!' % example.eid)
            continue

        predict_rules = [dataset.grammar.id_to_rule[rid] for rid in hyps[0]]

        assert len(predict_rules) > 0 and len(gold_rules) > 0

        if sorted(gold_rules, key=by_repr) == sorted(predict_rules, key=by_repr):
            num_exact_match += 1

        # p = len(predict_rules.intersection(gold_rules)) / len(predict_rules)
        # r = len(predict_rules.intersection(gold_rules)) / len(gold_rules)

    exact_match_ratio = num_exact_match / dataset.count

    logging.info('exact_match_ratio = %f' % exact_match_ratio)

    return exact_match_ratio
def evaluate_decode_results(dataset, decode_results, verbose=True):
    """Score decoded candidates against the reference code of a dataset.

    For every example the top candidate is compared with the reference
    (both regenerated from their ASTs with ``astor`` and tokenized with
    ``tokenize_code``) for exact-match accuracy, and a smoothed sentence-level
    BLEU is computed on the tokens produced by ``tokenize_for_bleu_eval``.
    An oracle BLEU / accuracy over the first ``config.beam_size`` candidates
    is accumulated as well. Corpus BLEU, averaged sentence BLEU, accuracy and
    the oracle figures are logged.

    With ``verbose`` set, the following files are written next to the current
    working directory (prefix ``dataset.name``): ``.exact_match``,
    ``.decode_results.txt``, ``.ref``, ``.hyp`` and ``.geneated_code`` (the
    spelling of that suffix is kept so output file names do not change).

    :param dataset: object exposing ``name``, ``count`` and ``examples``
    :param decode_results: one list per example of
        ``(cid, cand, ast_tree, code)`` tuples, best candidate first
    :param verbose: write the report files and print per-example BLEU
    :return: ``(sentence_bleu, accuracy)``, both divided by ``dataset.count``;
        ``(-1, -1)`` when every candidate list is empty
    """
    from lang.py.parse import tokenize_code, de_canonicalize_code
    # tokenize_code = tokenize_for_bleu_eval
    import ast
    assert dataset.count == len(decode_results)

    f = f_decode = None
    verbose_files = []
    if verbose:
        f = open(dataset.name + '.exact_match', 'w')
        exact_match_ids = []
        f_decode = open(dataset.name + '.decode_results.txt', 'w')
        eid_to_annot = dict()

        if config.data_type == 'django':
            # the annotation file is only read here, so close it right away
            with open(DJANGO_ANNOT_FILE) as f_annot:
                for raw_id, line in enumerate(f_annot):
                    eid_to_annot[raw_id] = line.strip()

        f_bleu_eval_ref = open(dataset.name + '.ref', 'w')
        f_bleu_eval_hyp = open(dataset.name + '.hyp', 'w')
        f_generated_code = open(dataset.name + '.geneated_code', 'w')
        verbose_files = [f, f_decode, f_bleu_eval_ref, f_bleu_eval_hyp, f_generated_code]

        logging.info('evaluating [%s] set, [%d] examples', dataset.name, dataset.count)

    cum_oracle_bleu = 0.0
    cum_oracle_acc = 0.0
    cum_bleu = 0.0
    cum_acc = 0.0
    sm = SmoothingFunction()

    all_references = []
    all_predictions = []

    if all(len(cand) == 0 for cand in decode_results):
        # logging.ERROR is the integer level constant; logging.error() is the call
        logging.error('Empty decoding results for the current dataset!')
        for handle in verbose_files:
            handle.close()
        return -1, -1

    for eid in range(dataset.count):
        example = dataset.examples[eid]
        ref_code = example.code
        # canonicalize the reference by round-tripping it through the AST
        ref_ast_tree = ast.parse(ref_code).body[0]
        refer_source = astor.to_source(ref_ast_tree).strip()
        # refer_source = ref_code
        refer_tokens = tokenize_code(refer_source)

        decode_cands = decode_results[eid]
        if len(decode_cands) == 0:
            continue

        decode_cand = decode_cands[0]

        cid, cand, ast_tree, code = decode_cand
        code = astor.to_source(ast_tree).strip()

        try:
            predict_tokens = tokenize_code(code)
        except Exception:
            logging.error('error in tokenizing [%s]', code)
            continue

        if refer_tokens == predict_tokens:
            cum_acc += 1

            if verbose:
                exact_match_ids.append(example.raw_id)
                f.write('-' * 60 + '\n')
                f.write('example_id: %d\n' % example.raw_id)
                f.write(code + '\n')
                f.write('-' * 60 + '\n')

        if config.data_type == 'django':
            ref_code_for_bleu = example.meta_data['raw_code']
            pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code'])
            # convert canonicalized code to raw code
            for literal, place_holder in example.meta_data['str_map'].iteritems():
                pred_code_for_bleu = pred_code_for_bleu.replace('\'' + place_holder + '\'', literal)
        elif config.data_type == 'hs':
            ref_code_for_bleu = ref_code
            pred_code_for_bleu = code

        # we apply Ling Wang's trick when evaluating BLEU scores
        refer_tokens_for_bleu = tokenize_for_bleu_eval(ref_code_for_bleu)
        pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu)

        # Debugging aid: flags examples whose BLEU tokens agree although the
        # canonicalized tokens differ (e.g. inconsistent quote characters).
        weired = refer_tokens_for_bleu == pred_tokens_for_bleu and refer_tokens != predict_tokens

        shorter = len(pred_tokens_for_bleu) < len(refer_tokens_for_bleu)

        all_references.append([refer_tokens_for_bleu])
        all_predictions.append(pred_tokens_for_bleu)

        # use fewer n-gram orders when the reference has fewer than 4 tokens
        ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu))
        bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights, smoothing_function=sm.method3)
        cum_bleu += bleu_score

        if verbose:
            print('raw_id: %d, bleu_score: %f' % (example.raw_id, bleu_score))

            f_decode.write('-' * 60 + '\n')
            f_decode.write('example_id: %d\n' % example.raw_id)
            f_decode.write('intent: \n')

            if config.data_type == 'django':
                f_decode.write(eid_to_annot[example.raw_id] + '\n')
            elif config.data_type == 'hs':
                f_decode.write(' '.join(example.query) + '\n')

            f_bleu_eval_ref.write(' '.join(refer_tokens_for_bleu) + '\n')
            f_bleu_eval_hyp.write(' '.join(pred_tokens_for_bleu) + '\n')

            f_decode.write('canonicalized reference: \n')
            f_decode.write(refer_source + '\n')
            f_decode.write('canonicalized prediction: \n')
            f_decode.write(code + '\n')
            f_decode.write('reference code for bleu calculation: \n')
            f_decode.write(ref_code_for_bleu + '\n')
            f_decode.write('predicted code for bleu calculation: \n')
            f_decode.write(pred_code_for_bleu + '\n')
            f_decode.write('pred_shorter_than_ref: %s\n' % shorter)
            f_decode.write('weired: %s\n' % weired)
            f_decode.write('-' * 60 + '\n')

            # for Hiro's evaluation
            f_generated_code.write(pred_code_for_bleu.replace('\n', '#NEWLINE#') + '\n')

        # compute oracle: best BLEU / any exact match within the beam
        best_score = 0.
        cur_oracle_acc = 0.
        for decode_cand in decode_cands[:config.beam_size]:
            cid, cand, ast_tree, code = decode_cand
            try:
                code = astor.to_source(ast_tree).strip()
                predict_tokens = tokenize_code(code)

                if predict_tokens == refer_tokens:
                    cur_oracle_acc = 1

                if config.data_type == 'django':
                    pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code'])
                    # convert canonicalized code to raw code
                    for literal, place_holder in example.meta_data['str_map'].iteritems():
                        pred_code_for_bleu = pred_code_for_bleu.replace('\'' + place_holder + '\'', literal)
                elif config.data_type == 'hs':
                    pred_code_for_bleu = code

                # we apply Ling Wang's trick when evaluating BLEU scores
                pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu)

                ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu))
                bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu,
                                           weights=ngram_weights,
                                           smoothing_function=sm.method3)

                if bleu_score > best_score:
                    best_score = bleu_score
            except Exception:
                # a candidate that cannot be rendered/tokenized is skipped
                continue

        cum_oracle_bleu += best_score
        cum_oracle_acc += cur_oracle_acc

    cum_bleu /= dataset.count
    cum_acc /= dataset.count
    cum_oracle_bleu /= dataset.count
    cum_oracle_acc /= dataset.count

    logging.info('corpus level bleu: %f', corpus_bleu(all_references, all_predictions, smoothing_function=sm.method3))
    logging.info('sentence level bleu: %f', cum_bleu)
    logging.info('accuracy: %f', cum_acc)
    logging.info('oracle bleu: %f', cum_oracle_bleu)
    logging.info('oracle accuracy: %f', cum_oracle_acc)

    if verbose:
        f.write(', '.join(str(i) for i in exact_match_ids))
        for handle in verbose_files:
            handle.close()

    return cum_bleu, cum_acc
def analyze_decode_results(dataset, decode_results, verbose=True):
    """Evaluate decoded candidates and break the results down by reference AST size.

    Performs the same per-example scoring as a plain evaluation (exact-match
    accuracy on ``tokenize_code`` tokens, smoothed sentence BLEU on
    ``tokenize_for_bleu_eval`` tokens, oracle figures over the first
    ``config.beam_size`` candidates). In addition every scored example is
    placed in a bin keyed by ``example.parse_tree.size``; per-bin averages are
    printed and BLEU / accuracy are plotted against the bin's lower bound.
    The plot is saved as ``django_perf_ast_size.pdf`` or
    ``hs_perf_ast_size.pdf`` and then passed to the external ``pcrop.sh``
    command.

    With ``verbose`` set, ``.exact_match``, ``.decode_results.txt``, ``.ref``
    and ``.hyp`` files (prefix ``dataset.name``) are written.

    :param dataset: object exposing ``name``, ``count`` and ``examples``
    :param decode_results: one list per example of
        ``(cid, cand, ast_tree, code)`` tuples, best candidate first
    :param verbose: write the report files and print per-example BLEU
    :return: ``(sentence_bleu, accuracy)``, both divided by ``dataset.count``;
        ``(-1, -1)`` when every candidate list is empty
    """
    from lang.py.parse import tokenize_code, de_canonicalize_code
    # tokenize_code = tokenize_for_bleu_eval
    import ast
    assert dataset.count == len(decode_results)

    f = f_decode = None
    verbose_files = []
    if verbose:
        f = open(dataset.name + '.exact_match', 'w')
        exact_match_ids = []
        f_decode = open(dataset.name + '.decode_results.txt', 'w')
        eid_to_annot = dict()

        if config.data_type == 'django':
            # the annotation file is only read here, so close it right away
            with open(DJANGO_ANNOT_FILE) as f_annot:
                for raw_id, line in enumerate(f_annot):
                    eid_to_annot[raw_id] = line.strip()

        f_bleu_eval_ref = open(dataset.name + '.ref', 'w')
        f_bleu_eval_hyp = open(dataset.name + '.hyp', 'w')
        verbose_files = [f, f_decode, f_bleu_eval_ref, f_bleu_eval_hyp]

        logging.info('evaluating [%s] set, [%d] examples', dataset.name, dataset.count)

    cum_oracle_bleu = 0.0
    cum_oracle_acc = 0.0
    cum_bleu = 0.0
    cum_acc = 0.0
    sm = SmoothingFunction()

    all_references = []
    all_predictions = []

    if all(len(cand) == 0 for cand in decode_results):
        # logging.ERROR is the integer level constant; logging.error() is the call
        logging.error('Empty decoding results for the current dataset!')
        for handle in verbose_files:
            handle.close()
        return -1, -1

    # bin label -> list of (bleu, acc, oracle_bleu, oracle_acc) tuples
    binned_results_dict = defaultdict(list)

    def get_binned_key(ast_size):
        """Return the 'lower - upper' label of the size bin containing ast_size."""
        cutoff = 50 if config.data_type == 'django' else 250
        k = 10 if config.data_type == 'django' else 25  # for hs

        if ast_size >= cutoff:
            return '%d - inf' % cutoff

        lower = int(ast_size / k) * k
        upper = lower + k

        key = '%d - %d' % (lower, upper)

        return key

    for eid in range(dataset.count):
        example = dataset.examples[eid]
        ref_code = example.code
        # canonicalize the reference by round-tripping it through the AST
        ref_ast_tree = ast.parse(ref_code).body[0]
        refer_source = astor.to_source(ref_ast_tree).strip()
        # refer_source = ref_code
        refer_tokens = tokenize_code(refer_source)
        cur_example_acc = 0.0

        decode_cands = decode_results[eid]
        if len(decode_cands) == 0:
            continue

        decode_cand = decode_cands[0]

        cid, cand, ast_tree, code = decode_cand
        code = astor.to_source(ast_tree).strip()

        try:
            predict_tokens = tokenize_code(code)
        except Exception:
            logging.error('error in tokenizing [%s]', code)
            continue

        if refer_tokens == predict_tokens:
            cum_acc += 1
            cur_example_acc = 1.0

            if verbose:
                exact_match_ids.append(example.raw_id)
                f.write('-' * 60 + '\n')
                f.write('example_id: %d\n' % example.raw_id)
                f.write(code + '\n')
                f.write('-' * 60 + '\n')

        if config.data_type == 'django':
            ref_code_for_bleu = example.meta_data['raw_code']
            pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code'])
            # convert canonicalized code to raw code
            for literal, place_holder in example.meta_data['str_map'].iteritems():
                pred_code_for_bleu = pred_code_for_bleu.replace('\'' + place_holder + '\'', literal)
        elif config.data_type == 'hs':
            ref_code_for_bleu = ref_code
            pred_code_for_bleu = code

        # we apply Ling Wang's trick when evaluating BLEU scores
        refer_tokens_for_bleu = tokenize_for_bleu_eval(ref_code_for_bleu)
        pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu)

        shorter = len(pred_tokens_for_bleu) < len(refer_tokens_for_bleu)

        all_references.append([refer_tokens_for_bleu])
        all_predictions.append(pred_tokens_for_bleu)

        # use fewer n-gram orders when the reference has fewer than 4 tokens
        ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu))
        bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights, smoothing_function=sm.method3)
        cum_bleu += bleu_score

        if verbose:
            print('raw_id: %d, bleu_score: %f' % (example.raw_id, bleu_score))

            f_decode.write('-' * 60 + '\n')
            f_decode.write('example_id: %d\n' % example.raw_id)
            f_decode.write('intent: \n')

            if config.data_type == 'django':
                f_decode.write(eid_to_annot[example.raw_id] + '\n')
            elif config.data_type == 'hs':
                f_decode.write(' '.join(example.query) + '\n')

            f_bleu_eval_ref.write(' '.join(refer_tokens_for_bleu) + '\n')
            f_bleu_eval_hyp.write(' '.join(pred_tokens_for_bleu) + '\n')

            f_decode.write('canonicalized reference: \n')
            f_decode.write(refer_source + '\n')
            f_decode.write('canonicalized prediction: \n')
            f_decode.write(code + '\n')
            f_decode.write('reference code for bleu calculation: \n')
            f_decode.write(ref_code_for_bleu + '\n')
            f_decode.write('predicted code for bleu calculation: \n')
            f_decode.write(pred_code_for_bleu + '\n')
            f_decode.write('pred_shorter_than_ref: %s\n' % shorter)
            # f_decode.write('weired: %s\n' % weired)
            f_decode.write('-' * 60 + '\n')

        # compute oracle: best BLEU / any exact match within the beam
        best_bleu_score = 0.
        cur_oracle_acc = 0.
        for decode_cand in decode_cands[:config.beam_size]:
            cid, cand, ast_tree, code = decode_cand
            try:
                code = astor.to_source(ast_tree).strip()
                predict_tokens = tokenize_code(code)

                if predict_tokens == refer_tokens:
                    cur_oracle_acc = 1.

                if config.data_type == 'django':
                    pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code'])
                    # convert canonicalized code to raw code
                    for literal, place_holder in example.meta_data['str_map'].iteritems():
                        pred_code_for_bleu = pred_code_for_bleu.replace('\'' + place_holder + '\'', literal)
                elif config.data_type == 'hs':
                    pred_code_for_bleu = code

                # we apply Ling Wang's trick when evaluating BLEU scores
                pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu)

                ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu))
                cand_bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu,
                                                weights=ngram_weights,
                                                smoothing_function=sm.method3)

                if cand_bleu_score > best_bleu_score:
                    best_bleu_score = cand_bleu_score
            except Exception:
                # a candidate that cannot be rendered/tokenized is skipped
                continue

        cum_oracle_bleu += best_bleu_score
        cum_oracle_acc += cur_oracle_acc

        ref_ast_size = example.parse_tree.size
        binned_key = get_binned_key(ref_ast_size)
        binned_results_dict[binned_key].append((bleu_score, cur_example_acc, best_bleu_score, cur_oracle_acc))

    cum_bleu /= dataset.count
    cum_acc /= dataset.count
    cum_oracle_bleu /= dataset.count
    cum_oracle_acc /= dataset.count

    logging.info('corpus level bleu: %f', corpus_bleu(all_references, all_predictions, smoothing_function=sm.method3))
    logging.info('sentence level bleu: %f', cum_bleu)
    logging.info('accuracy: %f', cum_acc)
    logging.info('oracle bleu: %f', cum_oracle_bleu)
    logging.info('oracle accuracy: %f', cum_oracle_acc)

    # bins ordered by their numeric lower bound
    keys = sorted(binned_results_dict, key=lambda x: int(x.split(' - ')[0]))

    Y = [[], [], [], []]
    X = []

    for binned_key in keys:
        entry = binned_results_dict[binned_key]
        avg_bleu = np.average([t[0] for t in entry])
        avg_acc = np.average([t[1] for t in entry])
        avg_oracle_bleu = np.average([t[2] for t in entry])
        avg_oracle_acc = np.average([t[3] for t in entry])
        print('%s %s %s %s %s %s' % (binned_key, avg_bleu, avg_acc, avg_oracle_bleu, avg_oracle_acc, len(entry)))

        Y[0].append(avg_bleu)
        Y[1].append(avg_acc)
        Y[2].append(avg_oracle_bleu)
        Y[3].append(avg_oracle_acc)

        X.append(int(binned_key.split(' - ')[0]))

    import matplotlib.pyplot as plt
    from pylab import rcParams
    rcParams['figure.figsize'] = 6, 2.5

    # both data types produce the same figure; only the output name differs
    if config.data_type == 'django':
        pdf_name = 'django_perf_ast_size.pdf'
    else:
        pdf_name = 'hs_perf_ast_size.pdf'

    fig, ax = plt.subplots()
    ax.plot(X, Y[0], 'bs--', label='BLEU', lw=1.2)
    # ax.plot(X, Y[2], 'r^--', label='oracle BLEU', lw=1.2)
    ax.plot(X, Y[1], 'r^--', label='acc', lw=1.2)
    # ax.plot(X, Y[3], 'r^--', label='oracle acc', lw=1.2)
    ax.set_ylabel('Performance')
    ax.set_xlabel('Reference AST Size (# nodes)')
    plt.legend(loc='upper right', ncol=6)
    plt.tight_layout()
    plt.savefig(pdf_name, dpi=300)
    os.system('pcrop.sh ' + pdf_name)

    if verbose:
        f.write(', '.join(str(i) for i in exact_match_ids))
        for handle in verbose_files:
            handle.close()

    return cum_bleu, cum_acc
def evaluate_seq2seq_decode_results(dataset, seq2seq_decode_file, seq2seq_ref_file, verbose=True, is_nbest=False):
    """Score the output file of a seq2seq baseline against a reference file.

    Both files hold one serialized program per line in which ``#NEWLINE#``,
    ``#INDENT#`` and `` #MERGE# `` markers stand in for layout. Accuracy is
    the fraction of lines whose restored text equals the restored reference
    line; BLEU is the smoothed sentence BLEU on ``tokenize_for_bleu_eval``
    tokens. Both averages are logged; nothing is returned.

    :param dataset: object exposing ``name``, ``count`` and ``examples``
    :param seq2seq_decode_file: path of the decoded hypotheses
    :param seq2seq_ref_file: path of the serialized references
    :param verbose: log a header line
    :param is_nbest: when true, every decode line is ``<eid> ||| <hypothesis>``
        and several lines may share an example id
    """
    from lang.py.parse import parse

    def restore_code(text):
        # undo the one-line serialization markers
        return text.replace('#NEWLINE#', '\n').replace('#INDENT#', ' ').replace(' #MERGE# ', '')

    def is_well_formed_python_code(_hyp):
        try:
            parse(restore_code(_hyp))
            return True
        except Exception:
            return False

    # read both files completely and release the handles immediately
    with open(seq2seq_decode_file) as f_seq2seq_decode, open(seq2seq_ref_file) as f_seq2seq_ref:
        if verbose:
            logging.info('evaluating [%s] set, [%d] examples', dataset.name, dataset.count)

        decode_file_data = [l.strip() for l in f_seq2seq_decode.readlines()]
        ref_code_data = [l.strip() for l in f_seq2seq_ref.readlines()]

    cum_bleu = 0.0
    cum_acc = 0.0
    sm = SmoothingFunction()

    if is_nbest:
        split_lines = [line.split(' ||| ') for line in decode_file_data]
        decode_file_data = [(int(d[0]), d[1]) for d in split_lines]

    for eid in range(dataset.count):
        example = dataset.examples[eid]

        if is_nbest:
            # find the best-scored well-formed code from the n-best list
            n_best_list = [entry for entry in decode_file_data if entry[0] == eid]
            code = top_scored_code = n_best_list[0][1]
            for _, hyp in n_best_list:
                if is_well_formed_python_code(hyp):
                    code = hyp
                    break

            if top_scored_code != code:
                print('*' * 60)
                print(top_scored_code)
                print(code)
                print('*' * 60)

            # NOTE: the well-formed candidate found above is only reported;
            # the top-scored hypothesis is what actually gets evaluated.
            code = n_best_list[0][1]
        else:
            code = decode_file_data[eid]

        code = restore_code(code)
        ref_code = restore_code(ref_code_data[eid])

        if code == ref_code:
            cum_acc += 1

        if config.data_type == 'django':
            ref_code_for_bleu = example.meta_data['raw_code']
            pred_code_for_bleu = code  # de_canonicalize_code(code, example.meta_data['raw_code'])
            # convert canonicalized code to raw code
            for literal, place_holder in example.meta_data['str_map'].iteritems():
                pred_code_for_bleu = pred_code_for_bleu.replace('\'' + place_holder + '\'', literal)
        elif config.data_type == 'hs':
            ref_code_for_bleu = example.code
            pred_code_for_bleu = code

        # we apply Ling Wang's trick when evaluating BLEU scores
        refer_tokens_for_bleu = tokenize_for_bleu_eval(ref_code_for_bleu)
        pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu)

        # use fewer n-gram orders when the reference has fewer than 4 tokens
        ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu))
        bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights, smoothing_function=sm.method3)
        cum_bleu += bleu_score

    cum_bleu /= dataset.count
    cum_acc /= dataset.count

    logging.info('sentence level bleu: %f', cum_bleu)
    logging.info('accuracy: %f', cum_acc)
def evaluate_seq2tree_sample_file(sample_file, id_file, dataset):
    """Score the sample file written by the seq2tree baseline.

    ``id_file`` lists one raw example id per line; ``sample_file`` holds, for
    each of those ids in order, three lines: the reference tree
    representation, the predicted representation and a line that is ignored.
    Each prediction is converted to a parse tree, then to a Python AST and
    source code, and compared with the dataset's reference code (exact match
    on ``tokenize_code`` tokens, smoothed sentence BLEU on
    ``tokenize_for_bleu_eval`` tokens). Predictions that cannot be converted
    are counted as errors and contribute nothing to BLEU or accuracy.
    All figures are logged; nothing is returned.

    When ``config.seq2tree_rareword_map`` is set, it names a file whose i-th
    line holds blank-separated ``<unk_id>:<word>`` pairs used to substitute
    ``str{}{unk_<id>}`` placeholders of the i-th sample.

    :param sample_file: path of the seq2tree sample output
    :param id_file: path of the file with one raw id per line
    :param dataset: object exposing ``examples`` (each with ``raw_id``,
        ``code`` and ``meta_data``)
    """
    from lang.py.parse import tokenize_code, de_canonicalize_code
    import ast, astor
    import sys
    import traceback
    from lang.py.seq2tree_exp import seq2tree_repr_to_ast_tree, merge_broken_value_nodes
    from lang.py.parse import decode_tree_to_python_ast

    line_id_to_raw_id = OrderedDict()
    raw_id_to_eid = OrderedDict()
    with open(id_file) as f_ids:
        for i, line in enumerate(f_ids):
            raw_id = int(line.strip())
            line_id_to_raw_id[i] = raw_id

    for eid in range(len(dataset.examples)):
        raw_id_to_eid[dataset.examples[eid].raw_id] = eid

    # sample index -> {unk id -> original word}
    rare_word_map = defaultdict(dict)
    if config.seq2tree_rareword_map:
        logging.info('use rare word map')
        with open(config.seq2tree_rareword_map) as f_rare:
            for i, line in enumerate(f_rare):
                line = line.strip()
                if line:
                    for e in line.split(' '):
                        d = e.split(':', 1)
                        rare_word_map[i][int(d[0])] = d[1]

    cum_bleu = 0.0
    cum_acc = 0.0
    sm = SmoothingFunction()
    convert_error_num = 0

    with open(sample_file) as f_sample:
        for i in range(len(line_id_to_raw_id)):
            # each sample occupies three lines: reference, prediction, ignored line
            ref_repr = f_sample.readline().strip()
            predict_repr = f_sample.readline().strip()
            predict_repr = predict_repr.replace('<U>', 'str{}{unk}')  # .replace('( )', '( str{}{unk} )')
            f_sample.readline()

            if i in rare_word_map:
                for unk_id, w in rare_word_map[i].iteritems():
                    ref_repr = ref_repr.replace(' str{}{unk_%s} ' % unk_id, ' str{}{%s} ' % w)
                    predict_repr = predict_repr.replace(' str{}{unk_%s} ' % unk_id, ' str{}{%s} ' % w)

            try:
                parse_tree = seq2tree_repr_to_ast_tree(predict_repr)
                merge_broken_value_nodes(parse_tree)
            except Exception:
                print('error when converting:')
                print(predict_repr)
                convert_error_num += 1
                continue

            raw_id = line_id_to_raw_id[i]
            eid = raw_id_to_eid[raw_id]
            example = dataset.examples[eid]

            ref_code = example.code
            # canonicalize the reference by round-tripping it through the AST
            ref_ast_tree = ast.parse(ref_code).body[0]
            refer_source = astor.to_source(ref_ast_tree).strip()
            refer_tokens = tokenize_code(refer_source)

            try:
                ast_tree = decode_tree_to_python_ast(parse_tree)
                code = astor.to_source(ast_tree).strip()
            except Exception:
                print("Exception in converting tree to code:")
                print('-' * 60)
                print('line id: %d' % i)
                traceback.print_exc(file=sys.stdout)
                print('-' * 60)
                convert_error_num += 1
                continue

            if config.data_type == 'django':
                ref_code_for_bleu = example.meta_data['raw_code']
                pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code'])
                # convert canonicalized code to raw code
                for literal, place_holder in example.meta_data['str_map'].iteritems():
                    pred_code_for_bleu = pred_code_for_bleu.replace('\'' + place_holder + '\'', literal)
            elif config.data_type == 'hs':
                ref_code_for_bleu = ref_code
                pred_code_for_bleu = code

            # we apply Ling Wang's trick when evaluating BLEU scores
            refer_tokens_for_bleu = tokenize_for_bleu_eval(ref_code_for_bleu)
            pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu)

            predict_tokens = tokenize_code(code)
            # if ref_repr == predict_repr:
            if predict_tokens == refer_tokens:
                cum_acc += 1

            # use fewer n-gram orders when the reference has fewer than 4 tokens
            ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu))
            bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights,
                                       smoothing_function=sm.method3)
            cum_bleu += bleu_score

    cum_bleu /= len(line_id_to_raw_id)
    cum_acc /= len(line_id_to_raw_id)
    logging.info('nun. examples: %d', len(line_id_to_raw_id))
    logging.info('num. errors when converting repr to tree: %d', convert_error_num)
    logging.info('ratio of grammatically incorrect trees: %f', convert_error_num / float(len(line_id_to_raw_id)))
    logging.info('sentence level bleu: %f', cum_bleu)
    logging.info('accuracy: %f', cum_acc)
def evaluate_ifttt_results(dataset, decode_results, verbose=True):
    """Evaluate decoded IFTTT parse trees against the reference trees.

    For each example the first candidate's tree is scored with
    ``ifttt_metric`` (channel accuracy, channel+function accuracy and
    production overlap); the best production overlap among the first ten
    candidates is accumulated as the oracle figure. Examples without any
    candidate contribute nothing. All averages are taken over
    ``dataset.count`` and logged.

    With ``verbose`` set, the raw ids of exact matches are written to
    ``<dataset.name>.exact_match`` and a per-example report to
    ``<dataset.name>.decode_results.txt`` inside ``config.output_dir``.

    :param dataset: object exposing ``name``, ``count`` and ``examples``
    :param decode_results: one list per example of ``(cid, hypothesis)``
        pairs, where ``hypothesis.tree`` is the predicted tree
    :param verbose: write the report files and print per-example scores
    :return: ``(channel_acc, channel_func_acc, prod_f1)``;
        ``(-1, -1, -1)`` when every candidate list is empty
    """
    assert dataset.count == len(decode_results)

    f = f_decode = None
    if verbose:
        f = open(dataset.name + '.exact_match', 'w')
        exact_match_ids = []
        f_decode = open(os.path.join(config.output_dir, dataset.name + '.decode_results.txt'), 'w')

        logging.info('evaluating [%s] set, [%d] examples', dataset.name, dataset.count)

    cum_channel_acc = 0.0
    cum_channel_func_acc = 0.0
    cum_prod_f1 = 0.0
    cum_oracle_prod_f1 = 0.0

    if all(len(cand) == 0 for cand in decode_results):
        # logging.ERROR is the integer level constant; logging.error() is the call
        logging.error('Empty decoding results for the current dataset!')
        if verbose:
            f.close()
            f_decode.close()
        return -1, -1, -1

    for eid in range(dataset.count):
        example = dataset.examples[eid]
        ref_parse_tree = example.parse_tree

        decode_candidates = decode_results[eid]
        if len(decode_candidates) == 0:
            continue

        decode_cand = decode_candidates[0]

        cid, cand_hyp = decode_cand
        predict_parse_tree = cand_hyp.tree

        exact_match = predict_parse_tree == ref_parse_tree

        channel_acc, channel_func_acc, prod_f1 = ifttt_metric(predict_parse_tree, ref_parse_tree)

        cum_channel_acc += channel_acc
        cum_channel_func_acc += channel_func_acc
        cum_prod_f1 += prod_f1

        if verbose:
            if exact_match:
                exact_match_ids.append(example.raw_id)

            print('raw_id: %d, prod_f1: %f' % (example.raw_id, prod_f1))

            f_decode.write('-' * 60 + '\n')
            f_decode.write('example_id: %d\n' % example.raw_id)
            f_decode.write('intent: \n')
            f_decode.write(' '.join(example.query) + '\n')
            f_decode.write('reference: \n')
            f_decode.write(str(ref_parse_tree) + '\n')
            f_decode.write('prediction: \n')
            f_decode.write(str(predict_parse_tree) + '\n')
            f_decode.write('-' * 60 + '\n')

        # compute oracle: best production overlap among the first ten candidates
        best_prod_f1 = -1.
        for decode_cand in decode_candidates[:10]:
            cid, cand_hyp = decode_cand
            predict_parse_tree = cand_hyp.tree

            channel_acc, channel_func_acc, prod_f1 = ifttt_metric(predict_parse_tree, ref_parse_tree)

            if prod_f1 > best_prod_f1:
                best_prod_f1 = prod_f1

        cum_oracle_prod_f1 += best_prod_f1

    cum_channel_acc /= dataset.count
    cum_channel_func_acc /= dataset.count
    cum_prod_f1 /= dataset.count
    cum_oracle_prod_f1 /= dataset.count

    logging.info('channel_acc: %f', cum_channel_acc)
    logging.info('channel_func_acc: %f', cum_channel_func_acc)
    logging.info('prod_f1: %f', cum_prod_f1)
    logging.info('oracle prod_f1: %f', cum_oracle_prod_f1)

    if verbose:
        f.write(', '.join(str(i) for i in exact_match_ids))
        f.close()
        f_decode.close()

    return cum_channel_acc, cum_channel_func_acc, cum_prod_f1
def ifttt_metric(predict_parse_tree, ref_parse_tree):
    """Compare a predicted IFTTT tree with the reference tree.

    Both trees are indexed with ``'TRIGGER'`` and ``'ACTION'``; the first
    child of each of those nodes is compared as the channel and that child's
    first child as the function.

    :return: tuple ``(channel_acc, channel_func_acc, prod_f1)`` where
        ``channel_acc`` is 1.0 when both channel types agree,
        ``channel_func_acc`` is 1.0 when the channels and both function types
        agree, and ``prod_f1`` is the number of distinct productions shared by
        both trees divided by the number of reference productions
    """
    parts = ('TRIGGER', 'ACTION')

    def channel_node(tree, part):
        return tree[part].children[0]

    def function_node(tree, part):
        return channel_node(tree, part).children[0]

    # channel acc. (all() stops at the first mismatching part)
    channel_match = all(
        channel_node(predict_parse_tree, part).type == channel_node(ref_parse_tree, part).type
        for part in parts)
    channel_acc = 1. if channel_match else 0.

    # channel+func acc. -- functions are only inspected once the channels agree
    # predict_parse_tree is of type DecodingTree, different from reference tree!
    func_match = channel_match and all(
        function_node(predict_parse_tree, part).type == function_node(ref_parse_tree, part).type
        for part in parts)
    channel_func_acc = 1. if func_match else 0.

    # prod. F1
    ref_rules, _ = ref_parse_tree.get_productions()
    predict_rules, _ = predict_parse_tree.get_productions()
    shared_rules = set(ref_rules) & set(predict_rules)
    prod_f1 = len(shared_rules) / len(ref_rules)

    return channel_acc, channel_func_acc, prod_f1
def decode_and_evaluate_ifttt(model, test_data):
    """Decode and evaluate the test examples listed in ``config.ifttt_test_split``.

    The split file holds one raw example id per line. The examples of
    ``test_data`` whose ``raw_id`` appears there are extracted into a subset
    named ``<test_data.name>.subset``, decoded with ``decode_ifttt_dataset``
    and scored with ``evaluate_ifttt_results``.

    :return: the decode results of the subset
    """
    from decoder import decode_ifttt_dataset

    # 'data/ifff.test_data.gold.id'
    with open(config.ifttt_test_split) as f_split:
        # a set keeps the membership test below cheap
        raw_ids = set(int(line.strip()) for line in f_split)

    eids = [i for i, e in enumerate(test_data.examples) if e.raw_id in raw_ids]
    test_data_subset = test_data.get_dataset_by_ids(eids, test_data.name + '.subset')

    decode_results = decode_ifttt_dataset(model, test_data_subset, verbose=True)
    evaluate_ifttt_results(test_data_subset, decode_results)

    return decode_results
def decode_and_evaluate_ifttt_by_split(model, test_data):
    """Decode and evaluate ``test_data`` once per predefined IFTTT test split.

    Each split is an id file (one raw example id per line) located inside the
    directory ``config.ifttt_test_split``. For every split the matching
    examples are extracted into a subset named ``<test_data.name>.<split>``,
    decoded with ``decode_ifttt_dataset`` and scored with
    ``evaluate_ifttt_results``.
    """
    from decoder import decode_ifttt_dataset

    for split in ['ifff.test_data.omit_non_english.id', 'ifff.test_data.omit_unintelligible.id', 'ifff.test_data.gold.id']:
        # The split name belongs inside os.path.join(); it used to be passed
        # to open() as the file mode, which fails for every split.
        with open(os.path.join(config.ifttt_test_split, split)) as f_split:  # 'data/ifff.test_data.gold.id'
            raw_ids = set(int(line.strip()) for line in f_split)

        eids = [i for i, e in enumerate(test_data.examples) if e.raw_id in raw_ids]
        test_data_subset = test_data.get_dataset_by_ids(eids, test_data.name + '.' + split)

        decode_results = decode_ifttt_dataset(model, test_data_subset, verbose=True)
        evaluate_ifttt_results(test_data_subset, decode_results)
if __name__ == '__main__':
    # Script entry point: feeds the reference parse tree of every test example
    # back into the IFTTT evaluation as its single "decoded" candidate.
    from dataset import DataEntry, DataSet, Vocab, Action

    init_logging('parser.log', logging.INFO)
    train_data, dev_data, test_data = deserialize_from_file('data/ifttt.freq3.bin')

    # one candidate list per example: [(example index, reference parse tree)]
    decoding_results = [[(eid, test_data.examples[eid].parse_tree)]
                        for eid in range(test_data.count)]

    evaluate_ifttt_results(test_data, decoding_results, verbose=True)
================================================
FILE: interactive_mode.py
================================================
import argparse, sys
from nn.utils.generic_utils import init_logging
from nn.utils.io_utils import deserialize_from_file, serialize_to_file
from evaluation import *
from dataset import canonicalize_query, query_to_data
from collections import namedtuple
from lang.py.parse import decode_tree_to_python_ast
from model import Model
from dataset import DataEntry, DataSet, Vocab, Action
import config
# Command-line options of the interactive decoder. Every parsed value is later
# copied onto the `config` module as an attribute of the same name.
parser = argparse.ArgumentParser()
parser.add_argument('-data_type', default='django', choices=['django', 'hs'])
parser.add_argument('-data')
parser.add_argument('-random_seed', default=181783, type=int)
parser.add_argument('-model', default=None)

# neural model's parameters
# (sizes left at 0 are filled in from the loaded training data further below)
parser.add_argument('-source_vocab_size', default=0, type=int)
parser.add_argument('-target_vocab_size', default=0, type=int)
parser.add_argument('-rule_num', default=0, type=int)
parser.add_argument('-node_num', default=0, type=int)

parser.add_argument('-word_embed_dim', default=128, type=int)
parser.add_argument('-rule_embed_dim', default=128, type=int)
parser.add_argument('-node_embed_dim', default=64, type=int)
parser.add_argument('-encoder_hidden_dim', default=256, type=int)
parser.add_argument('-decoder_hidden_dim', default=256, type=int)
parser.add_argument('-attention_hidden_dim', default=50, type=int)
parser.add_argument('-ptrnet_hidden_dim', default=50, type=int)
parser.add_argument('-dropout', default=0.2, type=float)

# encoder
parser.add_argument('-encoder', default='bilstm', choices=['bilstm', 'lstm'])

# decoder
# Each feature below is a pair of on/off switches writing to the same dest;
# set_defaults() fixes the value used when neither switch is given.
parser.add_argument('-parent_hidden_state_feed', dest='parent_hidden_state_feed', action='store_true')
parser.add_argument('-no_parent_hidden_state_feed', dest='parent_hidden_state_feed', action='store_false')
parser.set_defaults(parent_hidden_state_feed=True)

parser.add_argument('-parent_action_feed', dest='parent_action_feed', action='store_true')
parser.add_argument('-no_parent_action_feed', dest='parent_action_feed', action='store_false')
parser.set_defaults(parent_action_feed=True)

parser.add_argument('-frontier_node_type_feed', dest='frontier_node_type_feed', action='store_true')
parser.add_argument('-no_frontier_node_type_feed', dest='frontier_node_type_feed', action='store_false')
parser.set_defaults(frontier_node_type_feed=True)

parser.add_argument('-tree_attention', dest='tree_attention', action='store_true')
parser.add_argument('-no_tree_attention', dest='tree_attention', action='store_false')
parser.set_defaults(tree_attention=False)

parser.add_argument('-enable_copy', dest='enable_copy', action='store_true')
parser.add_argument('-no_copy', dest='enable_copy', action='store_false')
parser.set_defaults(enable_copy=True)

# training
parser.add_argument('-optimizer', default='adam')
parser.add_argument('-clip_grad', default=0., type=float)
parser.add_argument('-train_patience', default=10, type=int)
parser.add_argument('-max_epoch', default=50, type=int)
parser.add_argument('-batch_size', default=10, type=int)
parser.add_argument('-valid_per_batch', default=4000, type=int)
parser.add_argument('-save_per_batch', default=4000, type=int)
parser.add_argument('-valid_metric', default='bleu')

# decoding
parser.add_argument('-beam_size', default=15, type=int)
parser.add_argument('-max_query_length', default=70, type=int)
parser.add_argument('-decode_max_time_step', default=100, type=int)
parser.add_argument('-head_nt_constraint', dest='head_nt_constraint', action='store_true')
parser.add_argument('-no_head_nt_constraint', dest='head_nt_constraint', action='store_false')
parser.set_defaults(head_nt_constraint=True)
# NOTE: parse_args() is given an explicit argument list, so options typed on
# the real command line are ignored and the Django dataset/model below is
# always the one that gets loaded.
args = parser.parse_args(args=['-data_type', 'django', '-data', 'data/django.cleaned.dataset.freq5.par_info.refact.space_only.bin',
                               '-model', 'models/model.django_word128_encoder256_rule128_node64.beam15.adam.simple_trans.no_unary_closure.8e39832.run3.best_acc.npz'])

# the 'hs' data type gets a larger decoding step limit
if args.data_type == 'hs':
    args.decode_max_time_step = 350

logging.info('loading dataset [%s]', args.data)
train_data, dev_data, test_data = deserialize_from_file(args.data)

# sizes that were not given explicitly (left at 0) are derived from the training data
if not args.source_vocab_size:
    args.source_vocab_size = train_data.annot_vocab.size
if not args.target_vocab_size:
    args.target_vocab_size = train_data.terminal_vocab.size
if not args.rule_num:
    args.rule_num = len(train_data.grammar.rules)
if not args.node_num:
    args.node_num = len(train_data.grammar.node_type_to_id)

# publish every option as an attribute of the already-imported `config` module
config_module = sys.modules['config']
for name, value in vars(args).iteritems():
    setattr(config_module, name, value)

# build the model
model = Model()
model.build()
model.load(args.model)
def decode_query(query):
    """Decode a natural language query and return the generated candidates.

    The query is canonicalized, split on single blanks into tokens and
    converted to vocabulary data with the training annotation vocabulary;
    both are wrapped in a lightweight ``example`` record that is handed to
    the module-level ``model`` together with the training grammar and
    terminal vocabulary.

    :param query: the raw query string
    :return: whatever ``model.decode`` returns for the query
    """
    canonical_query, _str_map = canonicalize_query(query)
    annot_vocab = train_data.annot_vocab

    # minimal stand-in for a dataset example: only the fields set here exist
    example_type = namedtuple('example', ['query', 'data'])
    example = example_type(query=canonical_query.split(' '),
                           data=[query_to_data(canonical_query, annot_vocab)])

    return model.decode(example, train_data.grammar, train_data.terminal_vocab,
                        beam_size=args.beam_size, max_time_step=args.decode_max_time_step, log=True)
if __name__ == '__main__':
    # Interactive loop: read a query, decode it and show the top candidates.
    print('run in interactive mode')
    while True:
        query = raw_input('input a query: ')
        cand_list = decode_query(query)

        # output top 5 candidates
        for cid, cand in enumerate(cand_list[:5]):
            print('*' * 60)
            print('cand #%d, score: %f' % (cid, cand.score))

            try:
                ast_tree = decode_tree_to_python_ast(cand.tree)
                code = astor.to_source(ast_tree)
                print('code:  %s' % code)
                print('decode log:  %s' % (cand.log,))
            except Exception:
                print("Exception in converting tree to code:")
                print('-' * 60)
                # Only the beam position is reported: this script defines no
                # `example` object, so reading a raw_id here would raise a
                # NameError and hide the traceback printed below.
                print('beam pos: %d' % cid)
                traceback.print_exc(file=sys.stdout)
                print('-' * 60)
            finally:
                # the parse tree is shown whether or not code generation worked
                print('* parse tree *')
                print(cand.tree.__repr__())
                print('n_timestep: %d' % cand.n_timestep)
                print('ast size: %d' % cand.tree.size)
                print('*' * 60)
================================================
FILE: lang/__init__.py
================================================
================================================
FILE: lang/grammar.py
================================================
from collections import OrderedDict, defaultdict
import logging
from astnode import ASTNode
from lang.util import typename
class Grammar(object):
    """A set of production rules with lookup tables for rules and node types."""

    def __init__(self, rules):
        """
        instantiate a grammar with a set of production rules of type Rule

        Builds, from the given rules:
          * ``rule_index``: rule parent -> list of rules with that parent
          * ``rule_to_id`` / ``id_to_rule``: rule <-> position in ``rules``
          * ``node_type_to_id``: type name -> integer id
          * ``root_node``: the only left-hand-side node never used as a child
          * ``terminal_nodes`` / ``terminal_types``: child nodes that never
            appear as a left-hand side, and their types
        """
        self.rules = rules
        self.rule_index = defaultdict(list)
        self.rule_to_id = OrderedDict()

        node_types = set()
        lhs_nodes = set()
        rhs_nodes = set()
        for rule in self.rules:
            self.rule_index[rule.parent].append(rule)

            # we also store all unique node types
            for node in rule.nodes:
                node_types.add(typename(node.type))

            lhs_nodes.add(rule.parent)
            for child in rule.children:
                rhs_nodes.add(child.as_type_node)

        root_node = lhs_nodes - rhs_nodes
        assert len(root_node) == 1
        self.root_node = next(iter(root_node))

        self.terminal_nodes = rhs_nodes - lhs_nodes
        self.terminal_types = set([n.type for n in self.terminal_nodes])

        # ids follow the iteration order of the node_types set
        self.node_type_to_id = OrderedDict()
        for i, type_name in enumerate(node_types, start=0):
            self.node_type_to_id[type_name] = i

        for gid, rule in enumerate(rules, start=0):
            self.rule_to_id[rule] = gid

        self.id_to_rule = OrderedDict((v, k) for (k, v) in self.rule_to_id.iteritems())

        logging.info('num. rules: %d', len(self.rules))
        logging.info('num. types: %d', len(self.node_type_to_id))
        logging.info('root: %s', self.root_node)
        logging.info('terminals: %s', ', '.join(repr(n) for n in self.terminal_nodes))

    def __iter__(self):
        """Iterate over the production rules."""
        return self.rules.__iter__()

    def __len__(self):
        """Return the number of production rules."""
        return len(self.rules)

    def __getitem__(self, lhs):
        """Return the rules whose parent matches the type of ``lhs``.

        :raises KeyError: when no rule is indexed under that type
        """
        key_node = ASTNode(lhs.type, None)  # Rules are indexed by types only
        if key_node in self.rule_index:
            return self.rule_index[key_node]

        # the exception used to be constructed but never raised, so a missing
        # key silently produced None
        raise KeyError('key=%s' % key_node)

    def get_node_type_id(self, node):
        """Return the integer id of a node's type, or of a type passed directly."""
        if isinstance(node, ASTNode):
            type_repr = typename(node.type)
        else:
            # assert isinstance(node, str)
            # it is a type
            type_repr = typename(node)

        return self.node_type_to_id[type_repr]

    def is_terminal(self, node):
        """Tell whether the node's type is one of the grammar's terminal types."""
        return node.type in self.terminal_types

    def is_value_node(self, node):
        """Not implemented in this base class."""
        raise NotImplementedError
================================================
FILE: lang/ifttt/__init__.py
================================================
================================================
FILE: lang/ifttt/grammar.py
================================================
from lang.grammar import Grammar
class IFTTTGrammar(Grammar):
    """Grammar for IFTTT recipes.

    Construction is inherited unchanged from ``Grammar``; the only
    specialisation is that no node is ever reported as a value node.
    """

    def is_value_node(self, node):
        """Always False for this grammar."""
        return False
================================================
FILE: lang/ifttt/ifttt_dataset.py
================================================
# -*- coding: UTF-8 -*-
from __future__ import division
import string
from collections import OrderedDict
from collections import defaultdict
from itertools import count
from nn.utils.io_utils import serialize_to_file, deserialize_from_file
from lang.ifttt.grammar import IFTTTGrammar
from parse import ifttt_ast_to_parse_tree
from lang.grammar import Grammar
import logging
from itertools import chain
from nn.utils.generic_utils import init_logging
from dataset import gen_vocab, DataSet, DataEntry, Action, APPLY_RULE, GEN_TOKEN, COPY_TOKEN, GEN_COPY_TOKEN
def load_examples(data_file):
    """Read a tab-separated recipe file into a list of example dicts.

    The first line is skipped (presumably a header row -- confirm against the
    data file). For every following line, column 4 is used as the description
    and column 9 as the code, which is parsed with ``ifttt_ast_to_parse_tree``.

    Returns:
        list of dicts with keys ``'description'``, ``'parse_tree'`` and ``'code'``.
    """
    examples = []
    # `with` closes the handle even if a line fails to parse
    with open(data_file) as f:
        # tolerate an empty file instead of leaking StopIteration
        next(f, None)
        for line in f:
            d = line.strip().split('\t')
            description = d[4]
            code = d[9]
            parse_tree = ifttt_ast_to_parse_tree(code)

            examples.append({'description': description, 'parse_tree': parse_tree, 'code': code})

    return examples
def analyze_ifttt_dataset(data_file='/Users/yinpengcheng/Research/SemanticParsing/ifttt/recipe_summaries.all.tsv'):
    """Log rule-count statistics over the examples in ``data_file``.

    Logs the average number of productions per example, the largest number of
    productions and the index of the first example reaching it.

    Args:
        data_file: path of the tab-separated recipe file read by
            ``load_examples``; defaults to the previously hard-coded location.
    """
    examples = load_examples(data_file)

    rule_num = 0.
    max_rule_num = -1
    example_with_max_rule_num = -1

    for idx, example in enumerate(examples):
        parse_tree = example['parse_tree']
        rules, _ = parse_tree.get_productions(include_value_node=True)
        rule_num += len(rules)
        # strict comparison keeps the first example that reaches the maximum
        if max_rule_num < len(rules):
            max_rule_num = len(rules)
            example_with_max_rule_num = idx

    # avoid ZeroDivisionError when the file holds no examples
    avg_rule_num = rule_num / len(examples) if examples else 0.

    logging.info('avg. num. of rules: %f', avg_rule_num)
    logging.info('max_rule_num: %d', max_rule_num)
    logging.info('example_with_max_rule_num: %d', example_with_max_rule_num)
def canonicalize_ifttt_example(annot, code):
    """Normalise one (annotation, code) pair.

    Returns a triple ``(query_tokens, clean_code, parse_tree)``: the
    whitespace-split, lower-cased annotation tokens, the string form of the
    parse tree, and the parse tree itself (functions are not attached to
    channels).
    """
    tree = ifttt_ast_to_parse_tree(code, attach_func_to_channel=False)
    canonical_code = str(tree)
    lowered_tokens = [token.lower() for token in annot.split()]

    return lowered_tokens, canonical_code, tree
def preprocess_ifttt_dataset(annot_file, code_file):
f = open('ifttt_dataset.examples.txt', 'w')
examples = []
for idx, (annot, code) in enumerate(zip(open(annot_file), open(code_file))):
annot = annot.strip()
code = code.strip()
clean_query_tokens, clean_code, parse_tree = canonicalize_ifttt_example(annot, code)
example = {'id': idx, 'query_tokens': clean_query_tokens, 'code': clean_code, 'parse_tree': parse_tree,
'str_map': None, 'raw_code': code}
examples.append(example)
f.write('*' * 50 + '\n')
f.write('example# %d\n' % idx)
f.write(' '.join(clean_query_tokens) + '\n')
f.write('\n')
f.write(clean_code + '\n')
f.write('*' * 50 + '\n')
idx += 1
f.close()
print 'preprocess_dataset: cleaned example num: %d' % len(examples)
return examples
def get_grammar(parse_trees):
    """Collect the distinct productions of ``parse_trees`` into an IFTTTGrammar.

    Rules are de-duplicated and ordered by their ``__repr__`` string. As a
    side effect the rules are dumped to ``grammar.txt`` and the trees to
    ``parse_trees.txt`` in the current directory.
    """
    unique_rules = set()
    for tree in parse_trees:
        tree_rules, _ = tree.get_productions()
        unique_rules.update(tree_rules)

    ordered_rules = sorted(unique_rules, key=lambda r: r.__repr__())
    grammar = IFTTTGrammar(ordered_rules)

    logging.info('num. rules: %d', len(ordered_rules))

    with open('grammar.txt', 'w') as rule_file:
        for rule in grammar:
            rule_file.write(rule.__repr__() + '\n')

    with open('parse_trees.txt', 'w') as tree_file:
        for tree in parse_trees:
            tree_file.write(tree.__repr__() + '\n')

    return grammar
def parse_ifttt_dataset(annot_file='/Users/yinpengcheng/Research/SemanticParsing/ifttt/Data/lang.all.txt',
                        code_file='/Users/yinpengcheng/Research/SemanticParsing/ifttt/Data/code.all.txt'):
    """Build, serialize and return the IFTTT train/dev/test ``DataSet`` objects.

    Steps: preprocess the raw files, derive the grammar and the annotation
    vocabulary, turn every parse tree into a sequence of APPLY_RULE actions,
    assign each example to a split by its id, initialise the data matrices and
    serialize the three datasets to ``data/ifttt.freq<cutoff>.bin``.

    Args:
        annot_file: file with one natural language description per line.
        code_file: file with the corresponding code, line-aligned.
            Both default to the previously hard-coded locations.

    Returns:
        ``(train_data, dev_data, test_data)``.

    Raises:
        RuntimeError: if a production is rooted at a value node (the IFTTT
            grammar is expected to have no terminals).
    """
    WORD_FREQ_CUT_OFF = 2

    # Split boundaries on the example id: [0, train_end) is train,
    # [train_end, dev_end) is dev, everything else is test.
    train_end = 77495
    dev_end = train_end + 5171

    data = preprocess_ifttt_dataset(annot_file, code_file)

    # build the grammar
    grammar = get_grammar([e['parse_tree'] for e in data])

    annot_tokens = list(chain(*[e['query_tokens'] for e in data]))
    annot_vocab = gen_vocab(annot_tokens, vocab_size=30000, freq_cutoff=WORD_FREQ_CUT_OFF)

    logging.info('annot vocab. size: %d', annot_vocab.size)

    # we have no terminal tokens in ifttt
    all_terminal_tokens = []
    terminal_vocab = gen_vocab(all_terminal_tokens, vocab_size=4000, freq_cutoff=WORD_FREQ_CUT_OFF)

    # now generate the dataset!

    train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'ifttt.train_data')
    dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'ifttt.dev_data')
    test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'ifttt.test_data')

    all_examples = []

    can_fully_reconstructed_examples_num = 0
    examples_with_empty_actions_num = 0

    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        # check if query tokens are valid
        query_token_ids = [annot_vocab[token] for token in query_tokens if token not in string.punctuation]
        valid_query_tokens_ids = [tid for tid in query_token_ids if tid != annot_vocab.unk]

        # remove examples with rare words from train and dev, avoid overfitting
        if len(valid_query_tokens_ids) == 0 and 0 <= idx < dev_end:
            continue

        rule_list, rule_parents = parse_tree.get_productions(include_value_node=True)

        actions = []
        can_fully_reconstructed = True
        # rule -> time step (index into `actions`) at which it was applied
        rule_pos_map = dict()

        for rule_count, rule in enumerate(rule_list):
            if not grammar.is_value_node(rule.parent):
                assert rule.value is None
                parent_rule = rule_parents[(rule_count, rule)][0]
                if parent_rule:
                    parent_t = rule_pos_map[parent_rule]
                else:
                    parent_t = 0

                rule_pos_map[rule] = len(actions)

                d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule}
                action = Action(APPLY_RULE, d)

                actions.append(action)
            else:
                raise RuntimeError('no terminals should be in ifttt dataset!')

        if not actions:
            examples_with_empty_actions_num += 1
            continue

        example = DataEntry(idx, query_tokens, parse_tree, code, actions,
                            {'str_map': None, 'raw_code': entry['raw_code']})

        if can_fully_reconstructed:
            can_fully_reconstructed_examples_num += 1

        # train, valid, test splits
        if 0 <= idx < train_end:
            train_data.add(example)
        elif idx < dev_end:
            dev_data.add(example)
        else:
            test_data.add(example)

        all_examples.append(example)

    # print statistics; guarded so that an empty dataset does not crash in
    # max() or in the ratio below
    example_num = len(all_examples)
    max_query_len = max(len(e.query) for e in all_examples) if all_examples else 0
    max_actions_len = max(len(e.actions) for e in all_examples) if all_examples else 0
    reconstructed_ratio = can_fully_reconstructed_examples_num / example_num if example_num else 0.

    # serialize_to_file([len(e.query) for e in all_examples], 'query.len')
    # serialize_to_file([len(e.actions) for e in all_examples], 'actions.len')

    logging.info('train_data examples: %d', train_data.count)
    logging.info('dev_data examples: %d', dev_data.count)
    logging.info('test_data examples: %d', test_data.count)

    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_reconstructed_examples_num, example_num,
                 reconstructed_ratio)
    logging.info('empty_actions_count: %d', examples_with_empty_actions_num)

    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    train_data.init_data_matrices(max_query_length=40, max_example_action_num=6)
    dev_data.init_data_matrices()
    test_data.init_data_matrices()

    serialize_to_file((train_data, dev_data, test_data),
                      'data/ifttt.freq{WORD_FREQ_CUT_OFF}.bin'.format(WORD_FREQ_CUT_OFF=WORD_FREQ_CUT_OFF))

    return train_data, dev_data, test_data
def parse_data_for_seq2seq(data_file='data/ifttt.freq3.bin'):
    """Export the serialized IFTTT splits as parallel text files.

    For each split, ``data/seq2seq/ifttt.<split>.desc`` receives the
    space-joined query tokens and ``.code`` receives a line of the form
    ``IF <channel> . <function> THEN <channel> . <function>``. The test split
    is first restricted to the raw ids listed in
    ``data/ifff.test_data.gold.id``.

    NOTE(review): the default file name says ``freq3`` while
    ``parse_ifttt_dataset`` in this module serializes with a cut-off of 2 --
    confirm which file is intended.

    Args:
        data_file: serialized ``(train_data, dev_data, test_data)`` tuple.
    """
    train_data, dev_data, test_data = deserialize_from_file(data_file)
    prefix = 'data/seq2seq/'
    for dataset, output in [(train_data, prefix + 'ifttt.train'),
                            (dev_data, prefix + 'ifttt.dev'),
                            (test_data, prefix + 'ifttt.test')]:
        # context managers close the outputs even if an example is malformed
        with open(output + '.desc', 'w') as f_source, open(output + '.code', 'w') as f_target:
            if 'test' in output:
                with open('data/ifff.test_data.gold.id') as f_ids:
                    # a set gives O(1) membership tests in the filter below
                    raw_ids = set(int(i.strip()) for i in f_ids)
                eids = [i for i, e in enumerate(test_data.examples) if e.raw_id in raw_ids]
                dataset = test_data.get_dataset_by_ids(eids, test_data.name + '.subset')

            for e in dataset.examples:
                query_tokens = e.query
                trigger_channel = e.parse_tree['TRIGGER'].children[0]
                action_channel = e.parse_tree['ACTION'].children[0]
                trigger = trigger_channel.type + ' . ' + trigger_channel.children[0].type
                action = action_channel.type + ' . ' + action_channel.children[0].type
                code = 'IF ' + trigger + ' THEN ' + action

                f_source.write(' '.join(query_tokens) + '\n')
                f_target.write(code + '\n')
def extract_turk_data():
turk_annot_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/public_release/data/turk_public.tsv'
reference_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/public_release/data/ifttt_public.tsv'
f_turk = open(turk_annot_file)
next(f_turk)
annot_data = OrderedDict()
for line in f_turk:
d = line.strip().split('\t')
url = d[0]
if url not in annot_data:
annot_data[url] = list()
annot_data[url].append({'trigger_channel': d[2], 'trigger_func': d[3], 'action_channel': d[4], 'action_func': d[5]})
f_ref = open(reference_file)
next(f_ref)
ref_data = OrderedDict()
for line in f_ref:
d = line.strip().split('\t')
url = d[0]
ref_data[url] = {'trigger_channel': d[2], 'trigger_func': d[3], 'action_channel': d[4], 'action_func': d[5]}
lt_three_agree_with_gold = []
non_english_examples = []
unintelligible_examples = []
for url, annots in annot_data.iteritems():
vote_dict = defaultdict(int)
ref = ref_data[url]
match_with_gold_num = 0
non_english_num = unintelligible_num = 0
non_english_annots = []
unintelligible_annots = []
for annot in annots:
if annot['trigger_channel'] == ref['trigger_channel'] and annot['trigger_func'] == ref['trigger_func'] and \
annot['action_channel'] == ref['action_channel'] and annot['action_func'] == ref['action_func']:
match_with_gold_num += 1
vote_dict['#'.join(annot.values())] += 1
for i, annot in enumerate(annots):
if annot['trigger_channel'] == 'nonenglish' and annot['trigger_func'] == 'nonenglish' and \
annot['action_channel'] == 'nonenglish' and annot['action_func'] == 'nonenglish':
non_english_num += 1
non_english_annots.append(i)
if annot['trigger_channel'] == 'unintelligible' and annot['trigger_func'] == 'unintelligible' and \
annot['action_channel'] == 'unintelligible' and annot['action_func'] == 'unintelligible':
unintelligible_num += 1
unintelligible_annots.append(i)
max_vote_num = max(vote_dict.values())
# omitting descriptions marked as non-English by a majority of the crowdsourced workers
if non_english_num == max_vote_num:
non_english_examples.append(url)
non_english_and_unintelligible_num = len(set(non_english_annots).union(set(unintelligible_annots)))
# if this example has no non_english and unintelligible annotations
if non_english_and_unintelligible_num > 0: # < len(annots) - non_english_and_unintelligible_num:
unintelligible_examples.append(url)
if match_with_gold_num >= 3:
lt_three_agree_with_gold.append(url)
omit_non_english_examples = set(annot_data) - set(non_english_examples)
omit_unintelligible_examples = set(annot_data) - set(unintelligible_examples)
print len(omit_non_english_examples) # should be 3,741
print len(omit_unintelligible_examples) # should be 2,262
print len(lt_three_agree_with_gold) # should be 758
url2id = defaultdict(count(0).next)
for url in ref_data:
url2id[url] = url2id[url] + 77495 + 5171
f_gold = open('data/ifff.test_data.gold.id', 'w')
for url in lt_three_agree_with_gold:
i = url2id[url]
f_gold.write(str(i) + '\n')
f_gold.close()
f_gold = open('data/ifff.test_data.omit_unintelligible.id', 'w')
for url in omit_unintelligible_examples:
i = url2id[url]
f_gold.write(str(i) + '\n')
f_gold.close()
f_gold = open('data/ifff.test_data.omit_non_english.id', 'w')
for url in omit_non_english_examples:
i = url2id[url]
f_gold.write(str(i) + '\n')
f_gold.close()
omit_non_english_examples = [url2id[url] for url in omit_non_english_examples]
omit_unintelligible_examples = [url2id[url] for url in omit_unintelligible_examples]
lt_three_agree_with_gold = [url2id[url] for url in lt_three_agree_with_gold]
return omit_non_english_examples, omit_unintelligible_examples, lt_three_agree_with_gold
if __name__ == '__main__':
    # Script entry point. Other available steps, enable as needed:
    #   parse_ifttt_dataset()
    #   analyze_ifttt_dataset()
    #   parse_data_for_seq2seq()
    init_logging('ifttt.log')

    extract_turk_data()
================================================
FILE: lang/ifttt/parse.py
================================================
from astnode import ASTNode
def ifttt_ast_to_parse_tree_helper(s, offset):
    """
    adapted from ifttt codebase

    Recursively parse one parenthesised node of ``s`` starting at ``offset``.

    A node has the form ``(NAME child child ...)`` where NAME is either a bare
    token (terminated by a space or a close paren) or a double-quoted string
    in which a backslash escapes the next character.

    Returns:
        ``(node, offset)``: the parsed ``ASTNode`` and the index just past its
        closing paren.

    Raises:
        RuntimeError: if the node does not start with ``(`` or a child is not
            separated by a space.
    """
    if s[offset] != '(':
        # format the position explicitly: concatenating str + int raised a
        # TypeError and hid the real parse error
        raise RuntimeError('malformed string: node did not start with open paren at position %d' % offset)

    offset += 1
    # extract node name(type)
    name = ''
    if s[offset] == '\"':
        offset += 1
        while s[offset] != '\"':
            if s[offset] == '\\':
                # skip the backslash, keep the escaped character
                offset += 1
            name += s[offset]
            offset += 1

        # consume the closing quote
        offset += 1
    else:
        while s[offset] != ' ' and s[offset] != ')':
            name += s[offset]
            offset += 1

    node = ASTNode(name)
    while True:
        if s[offset] == ')':
            offset += 1
            return node, offset
        if s[offset] != ' ':
            raise RuntimeError('malformed string: node should have either had a '
                               'close paren or a space at position %d' % offset)
        offset += 1
        child_node, offset = ifttt_ast_to_parse_tree_helper(s, offset)
        node.add_child(child_node)
def ifttt_ast_to_parse_tree(s, attach_func_to_channel=True):
    """Parse an IFTTT recipe string into a parameter-free parse tree.

    The string is parsed from offset 0, PARAMS/OUTPARAMS nodes are stripped
    and, unless ``attach_func_to_channel`` is false, each function node is
    moved under its channel node.
    """
    tree, _ = ifttt_ast_to_parse_tree_helper(s, 0)
    tree = strip_params(tree)

    if not attach_func_to_channel:
        return tree

    return attach_function_to_channel(tree)
def strip_params(parse_tree):
    """Remove every PARAMS / OUTPARAMS child from the tree, in place.

    Returns the same (mutated) node. Raises RuntimeError when called directly
    on a PARAMS node, which cannot happen once parents filter their children.
    """
    if parse_tree.type == 'PARAMS':
        raise RuntimeError('should not go to here!')

    dropped_types = ('PARAMS', 'OUTPARAMS')
    parse_tree.children = [kid for kid in parse_tree.children if kid.type not in dropped_types]

    for pos, kid in enumerate(parse_tree.children):
        parse_tree.children[pos] = strip_params(kid)

    return parse_tree
def attach_function_to_channel(parse_tree):
    """Move the function node of TRIGGER and ACTION under its channel node.

    For each of the two sections, the single child of the ``FUNC`` node is
    appended to the section's first child and the ``FUNC`` node is deleted.
    The tree is modified in place and returned.
    """
    for section in ('TRIGGER', 'ACTION'):
        functions = parse_tree[section]['FUNC'].children
        assert len(functions) == 1

        parse_tree[section].children[0].add_child(functions[0])
        del parse_tree[section]['FUNC']

    return parse_tree
if __name__ == '__main__':
tree_code = """(ROOT (IF) (TRIGGER (Instagram) (FUNC (Any_new_photo_by_you) (PARAMS))) (THEN) (ACTION (Dropbox) (FUNC (Add_file_from_URL) (PARAMS (File_URL ({{Caption}})) (File_name ("")) (Dropbox_folder_path (ifttt/instagram))))))"""
parse_tree = ifttt_ast_to_parse_tree(tree_code)
print parse_tree
print strip_params(parse_tree)
print attach_function_to_channel(parse_tree)
================================================
FILE: lang/py/__init__.py
================================================
================================================
FILE: lang/py/grammar.py
================================================
"""
Python grammar and typing system
"""
import ast
import inspect
import astor
from lang.grammar import Grammar
def _field(field_type, is_list=False, is_optional=False):
    """Build the descriptor of one AST node field."""
    return {'type': field_type, 'is_list': is_list, 'is_optional': is_optional}


# Field table for the Python 2 AST:
#   node class name -> field name -> {'type', 'is_list', 'is_optional'}
# 'type' is the declared type of the field (or of its elements when
# 'is_list' is True); 'is_optional' marks fields that may be absent.
PY_AST_NODE_FIELDS = {
    'FunctionDef': {
        'name': _field(str),
        'args': _field(ast.arguments),
        'body': _field(ast.stmt, is_list=True),
        'decorator_list': _field(ast.expr, is_list=True),
    },
    'ClassDef': {
        # the class name is an identifier, exactly like FunctionDef.name
        # (it used to be declared as ast.arguments by mistake)
        'name': _field(str),
        'bases': _field(ast.expr, is_list=True),
        'body': _field(ast.stmt, is_list=True),
        'decorator_list': _field(ast.expr, is_list=True),
    },
    'Return': {
        'value': _field(ast.expr, is_optional=True),
    },
    'Delete': {
        'targets': _field(ast.expr, is_list=True),
    },
    'Assign': {
        'targets': _field(ast.expr, is_list=True),
        'value': _field(ast.expr),
    },
    'AugAssign': {
        'target': _field(ast.expr),
        'op': _field(ast.operator),
        'value': _field(ast.expr),
    },
    'Print': {
        'dest': _field(ast.expr, is_optional=True),
        'values': _field(ast.expr, is_list=True),
        'nl': _field(bool),
    },
    'For': {
        'target': _field(ast.expr),
        'iter': _field(ast.expr),
        'body': _field(ast.stmt, is_list=True),
        'orelse': _field(ast.stmt, is_list=True),
    },
    'While': {
        'test': _field(ast.expr),
        'body': _field(ast.stmt, is_list=True),
        'orelse': _field(ast.stmt, is_list=True),
    },
    'If': {
        'test': _field(ast.expr),
        'body': _field(ast.stmt, is_list=True),
        'orelse': _field(ast.stmt, is_list=True),
    },
    'With': {
        'context_expr': _field(ast.expr),
        'optional_vars': _field(ast.expr, is_optional=True),
        'body': _field(ast.stmt, is_list=True),
    },
    'Raise': {
        'type': _field(ast.expr, is_optional=True),
        'inst': _field(ast.expr, is_optional=True),
        'tback': _field(ast.expr, is_optional=True),
    },
    'TryExcept': {
        'body': _field(ast.stmt, is_list=True),
        'handlers': _field(ast.excepthandler, is_list=True),
        'orelse': _field(ast.stmt, is_list=True),
    },
    'TryFinally': {
        'body': _field(ast.stmt, is_list=True),
        'finalbody': _field(ast.stmt, is_list=True),
    },
    'Assert': {
        'test': _field(ast.expr),
        'msg': _field(ast.expr, is_optional=True),
    },
    'Import': {
        'names': _field(ast.alias, is_list=True),
    },
    'ImportFrom': {
        'module': _field(str, is_optional=True),
        'names': _field(ast.alias, is_list=True),
        'level': _field(int, is_optional=True),
    },
    'Exec': {
        'body': _field(ast.expr),
        'globals': _field(ast.expr, is_optional=True),
        'locals': _field(ast.expr, is_optional=True),
    },
    'Global': {
        'names': _field(str, is_list=True),
    },
    'Expr': {
        'value': _field(ast.expr),
    },
    'BoolOp': {
        'op': _field(ast.boolop),
        'values': _field(ast.expr, is_list=True),
    },
    'BinOp': {
        'left': _field(ast.expr),
        'op': _field(ast.operator),
        'right': _field(ast.expr),
    },
    'UnaryOp': {
        'op': _field(ast.unaryop),
        'operand': _field(ast.expr),
    },
    'Lambda': {
        'args': _field(ast.arguments),
        'body': _field(ast.expr),
    },
    'IfExp': {
        'test': _field(ast.expr),
        'body': _field(ast.expr),
        'orelse': _field(ast.expr),
    },
    'Dict': {
        'keys': _field(ast.expr, is_list=True),
        'values': _field(ast.expr, is_list=True),
    },
    'Set': {
        'elts': _field(ast.expr, is_list=True),
    },
    'ListComp': {
        'elt': _field(ast.expr),
        'generators': _field(ast.comprehension, is_list=True),
    },
    'SetComp': {
        'elt': _field(ast.expr),
        'generators': _field(ast.comprehension, is_list=True),
    },
    'DictComp': {
        'key': _field(ast.expr),
        'value': _field(ast.expr),
        'generators': _field(ast.comprehension, is_list=True),
    },
    'GeneratorExp': {
        'elt': _field(ast.expr),
        'generators': _field(ast.comprehension, is_list=True),
    },
    'Yield': {
        'value': _field(ast.expr, is_optional=True),
    },
    'Compare': {
        'left': _field(ast.expr),
        'ops': _field(ast.cmpop, is_list=True),
        'comparators': _field(ast.expr, is_list=True),
    },
    'Call': {
        'func': _field(ast.expr),
        'args': _field(ast.expr, is_list=True),
        'keywords': _field(ast.keyword, is_list=True),
        'starargs': _field(ast.expr, is_optional=True),
        'kwargs': _field(ast.expr, is_optional=True),
    },
    'Repr': {
        'value': _field(ast.expr),
    },
    'Num': {
        'n': _field(object),  # FIXME: should be float or int?
    },
    'Str': {
        's': _field(str),
    },
    'Attribute': {
        'value': _field(ast.expr),
        'attr': _field(str),
        'ctx': _field(ast.expr_context),
    },
    'Subscript': {
        'value': _field(ast.expr),
        'slice': _field(ast.slice),
    },
    'Name': {
        'id': _field(str),
    },
    'List': {
        'elts': _field(ast.expr, is_list=True),
        'ctx': _field(ast.expr_context),
    },
    'Tuple': {
        'elts': _field(ast.expr, is_list=True),
        'ctx': _field(ast.expr_context),
    },
    'ExceptHandler': {
        'type': _field(ast.expr, is_optional=True),
        'name': _field(ast.expr, is_optional=True),
        'body': _field(ast.stmt, is_list=True),
    },
    'arguments': {
        'args': _field(ast.expr, is_list=True),
        'vararg': _field(str, is_optional=True),
        'kwarg': _field(str, is_optional=True),
        'defaults': _field(ast.expr, is_list=True),
    },
    'comprehension': {
        'target': _field(ast.expr),
        'iter': _field(ast.expr),
        'ifs': _field(ast.expr, is_list=True),
    },
    'keyword': {
        'arg': _field(str),
        'value': _field(ast.expr, is_optional=True),
    },
    'alias': {
        'name': _field(str),
        'asname': _field(str, is_optional=True),
    },
    'Slice': {
        'lower': _field(ast.expr, is_optional=True),
        'upper': _field(ast.expr, is_optional=True),
        'step': _field(ast.expr, is_optional=True),
    },
    'ExtSlice': {
        'dims': _field(ast.slice, is_list=True),
    },
    'Index': {
        'value': _field(ast.expr),
    },
}
# Names of AST node fields that are skipped when node fields are traversed
# (see is_compositional_leaf in this module).
NODE_FIELD_BLACK_LIST = {'ctx'}

# AST classes accepted by is_terminal_ast_type(): the Pass/Break/Continue
# statements followed by binary, boolean, comparison and unary operator classes.
TERMINAL_AST_TYPES = {
    ast.Pass,
    ast.Break,
    ast.Continue,
    ast.Add,
    ast.BitAnd,
    ast.BitOr,
    ast.BitXor,
    ast.Div,
    ast.FloorDiv,
    ast.LShift,
    ast.Mod,
    ast.Mult,
    ast.Pow,
    ast.Sub,
    ast.And,
    ast.Or,
    ast.Eq,
    ast.Gt,
    ast.GtE,
    ast.In,
    ast.Is,
    ast.IsNot,
    ast.Lt,
    ast.LtE,
    ast.NotEq,
    ast.NotIn,
    ast.Not,
    ast.USub
}
def is_builtin_type(x):
    """Return True if ``x`` is one of the value types used by the grammar.

    Accepted values are the type objects ``str``, ``int``, ``float``,
    ``bool`` and ``object``, plus the literal string ``'identifier'``.
    """
    return x in (str, int, float, bool, object, 'identifier')
def is_terminal_ast_type(x):
    """Return True only if ``x`` is a class listed in TERMINAL_AST_TYPES."""
    # the class check comes first so non-class values never reach the set lookup
    return bool(inspect.isclass(x) and x in TERMINAL_AST_TYPES)
# def is_terminal_type(x):
# if is_builtin_type(x):
# return True
#
# if x == 'epsilon':
# return True
#
# if inspect.isclass(x) and (issubclass(x, ast.Pass) or issubclass(x, ast.Raise) or issubclass(x, ast.Break)
# or issubclass(x, ast.Continue)
# or issubclass(x, ast.Return)
# or issubclass(x, ast.operator) or issubclass(x, ast.boolop)
# or issubclass(x, ast.Ellipsis) or issubclass(x, ast.unaryop)
# or issubclass(x, ast.cmpop)):
# return True
#
# return False
# class Node:
# def __init__(self, node_type, label):
# self.type = node_type
# self.label = label
#
# @property
# def is_preterminal(self):
# return is_terminal_type(self.type)
#
# def __eq__(self, other):
# return self.type == other.type and self.label == other.label
#
# def __hash__(self):
# return typename(self.type).__hash__() ^ self.label.__hash__()
#
# def __repr__(self):
# repr_str = typename(self.type)
# if self.label:
# repr_str += '{%s}' % self.label
# return repr_str
#
#
# class TypedRule:
# def __init__(self, parent, children, tree=None):
# self.parent = parent
# if isinstance(children, list) or isinstance(children, tuple):
# self.children = tuple(children)
# else:
# self.children = (children, )
#
# # tree property is not incorporated in eq, hash
# self.tree = tree
#
# # @property
# # def is_terminal_rule(self):
# # return is_terminal_type(self.parent.type)
#
# def __eq__(self, other):
# return self.parent == other.parent and self.children == other.children
#
# def __hash__(self):
# return self.parent.__hash__() ^ self.children.__hash__()
#
# def __repr__(self):
# return '%s -> %s' % (self.parent, ', '.join([c.__repr__() for c in self.children]))
def type_str_to_type(type_str):
    """Map a type name back to the type object it denotes.

    * strings ending in ``*`` and the markers ``'root'`` / ``'epsilon'`` are
      returned unchanged;
    * ``'str'``, ``'int'``, ``'float'``, ``'bool'`` and ``'object'`` map to
      the corresponding builtin types;
    * anything else is looked up as an attribute of the ``ast`` module.

    The lookups are explicit instead of going through ``eval`` so that a type
    string can never execute arbitrary code.

    Raises:
        RuntimeError: if the string names none of the above.
    """
    if type_str.endswith('*') or type_str == 'root' or type_str == 'epsilon':
        return type_str

    builtin_types = {'str': str, 'int': int, 'float': float, 'bool': bool, 'object': object}
    if type_str in builtin_types:
        return builtin_types[type_str]

    try:
        return getattr(ast, type_str)
    except AttributeError:
        raise RuntimeError('unidentified type string: %s' % type_str)
def is_compositional_leaf(node):
    """Return True if every non-blacklisted field of ``node`` is empty.

    A field counts as empty when it is None or an empty list; fields named in
    NODE_FIELD_BLACK_LIST are ignored. A node without fields is a leaf.
    """
    for field_name, field_value in ast.iter_fields(node):
        if field_name in NODE_FIELD_BLACK_LIST:
            continue
        if field_value is None:
            continue
        if isinstance(field_value, list) and len(field_value) == 0:
            continue

        # a populated field: the node is not a leaf
        return False

    return True
class PythonGrammar(Grammar):
    """Grammar over Python AST node types.

    Construction is inherited unchanged from ``Grammar``; a node is a value
    node when its type satisfies ``is_builtin_type``.
    """

    def is_value_node(self, node):
        """Return True if the node's type is a builtin value type."""
        return is_builtin_type(node.type)
================================================
FILE: lang/py/parse.py
================================================
import ast
import logging
import re
import token as tk
from cStringIO import StringIO
from tokenize import generate_tokens
from astnode import ASTNode
from lang.py.grammar import is_compositional_leaf, PY_AST_NODE_FIELDS, NODE_FIELD_BLACK_LIST, PythonGrammar
from lang.util import escape
from lang.util import typename
def python_ast_to_parse_tree(node):
    """Recursively convert a Python ``ast.AST`` node into an ``ASTNode`` tree.

    The returned tree is rooted at an ``ASTNode`` whose type is the class of
    ``node``. Each non-empty field (except those in NODE_FIELD_BLACK_LIST)
    becomes one child labelled with the field name:

    * an AST-valued field -> a node typed by the declared field type, with
      the converted value as its only child;
    * a str/int/float/bool value -> a node typed by the value's own type that
      carries the value;
    * a list field -> a node typed ``'<type>*'`` holding one child per element.

    Raises:
        RuntimeError: for a non-empty field value matching none of the cases.
    """
    assert isinstance(node, ast.AST)

    node_type = type(node)
    tree = ASTNode(node_type)

    # it's a leaf AST node, e.g., ADD, Break, etc.
    if len(node._fields) == 0:
        return tree

    # if it's a compositional AST node with empty fields
    if is_compositional_leaf(node):
        # mark the absence of children explicitly with an 'epsilon' child
        epsilon = ASTNode('epsilon')
        tree.add_child(epsilon)

        return tree

    fields_info = PY_AST_NODE_FIELDS[node_type.__name__]

    for field_name, field_value in ast.iter_fields(node):
        # remove ctx stuff
        if field_name in NODE_FIELD_BLACK_LIST:
            continue

        # omit empty fields, including empty lists
        if field_value is None or (isinstance(field_value, list) and len(field_value) == 0):
            continue

        # now it's not empty!
        field_type = fields_info[field_name]['type']
        is_list_field = fields_info[field_name]['is_list']

        if isinstance(field_value, ast.AST):
            child = ASTNode(field_type, field_name)
            child.add_child(python_ast_to_parse_tree(field_value))
        elif type(field_value) is str or type(field_value) is int or \
                        type(field_value) is float or type(field_value) is object or \
                        type(field_value) is bool:
            # exact type checks: the node is typed by the runtime type of the
            # value, not by the declared field type
            # if field_type != type(field_value):
            #     print 'expect [%s] type, got [%s]' % (field_type, type(field_value))
            child = ASTNode(type(field_value), field_name, value=field_value)
        elif is_list_field:
            list_node_type = typename(field_type) + '*'
            child = ASTNode(list_node_type, field_name)
            for n in field_value:
                # elements of these types are attached directly; all other
                # elements are wrapped in an intermediate node of the field type
                if field_type in {ast.comprehension, ast.excepthandler, ast.arguments, ast.keyword, ast.alias}:
                    child.add_child(python_ast_to_parse_tree(n))
                else:
                    intermediate_node = ASTNode(field_type)
                    if field_type is str:
                        # plain string elements are stored as the node value
                        intermediate_node.value = n
                    else:
                        intermediate_node.add_child(python_ast_to_parse_tree(n))
                    child.add_child(intermediate_node)
        else:
            # reached by any other non-list value, e.g. a scalar whose exact
            # type is not one of those tested above
            raise RuntimeError('unknown AST node field!')

        tree.add_child(child)

    return tree
def parse_tree_to_python_ast(tree):
    """Recursively rebuild a Python AST node from an ``ASTNode`` parse tree.

    A ``'root'`` node is skipped in favour of its first child. Otherwise the
    tree's type is instantiated and, when it has an entry in
    PY_AST_NODE_FIELDS, every child is converted back into the field named by
    its label. Fields left unset (and not in NODE_FIELD_BLACK_LIST) are filled
    with an empty list for non-optional list fields and with None otherwise.

    Returns:
        the reconstructed AST node.
    """
    node_type = tree.type

    # remove root
    if node_type == 'root':
        return parse_tree_to_python_ast(tree.children[0])

    ast_node = node_type()
    node_type_name = typename(node_type)

    # Stays None for node types without a field table, so the default-filling
    # loop below can test it instead of hitting an unbound local.
    fields_info = None

    # if it's a compositional AST node, populate its children nodes,
    # fill fields with empty(default) values otherwise
    if node_type_name in PY_AST_NODE_FIELDS:
        fields_info = PY_AST_NODE_FIELDS[node_type_name]

        for child_node in tree.children:
            # if it's a compositional leaf
            if child_node.type == 'epsilon':
                continue

            field_type = child_node.type
            field_label = child_node.label
            field_entry = fields_info[field_label]
            is_list = field_entry['is_list']

            if is_list:
                field_type = field_entry['type']
                field_value = []

                if field_type in {ast.comprehension, ast.excepthandler, ast.arguments, ast.keyword, ast.alias}:
                    # elements were attached directly to the list node
                    nodes_in_list = child_node.children
                    for sub_node in nodes_in_list:
                        sub_node_ast = parse_tree_to_python_ast(sub_node)
                        field_value.append(sub_node_ast)
                else:  # expr stuffs
                    # elements are wrapped in intermediate nodes that either
                    # carry a value or have exactly one child
                    inter_nodes = child_node.children
                    for inter_node in inter_nodes:
                        if inter_node.value is None:
                            assert len(inter_node.children) == 1
                            sub_node_ast = parse_tree_to_python_ast(inter_node.children[0])
                            field_value.append(sub_node_ast)
                        else:
                            assert len(inter_node.children) == 0
                            field_value.append(inter_node.value)
            else:
                # this node either holds a value, or is an non-terminal
                if child_node.value is None:
                    assert len(child_node.children) == 1
                    field_value = parse_tree_to_python_ast(child_node.children[0])
                else:
                    assert child_node.is_leaf
                    field_value = child_node.value

            setattr(ast_node, field_label, field_value)

    for field in ast_node._fields:
        if not hasattr(ast_node, field) and field not in NODE_FIELD_BLACK_LIST:
            if fields_info and fields_info[field]['is_list'] and not fields_info[field]['is_optional']:
                setattr(ast_node, field, list())
            else:
                setattr(ast_node, field, None)

    return ast_node
def decode_tree_to_python_ast(decode_tree):
    """Turn a decoded hypothesis tree into a Python ``ast`` node.

    The tree is first expanded in place with ``compressed_ast_to_normal``;
    conversion then starts from the first child of ``decode_tree``.  On every
    leaf, a trailing ``'<eos>'`` is removed from ``str`` values, and values of
    leaves typed ``int``/``float``/``str``/``bool`` are cast by calling that
    type on the value.

    :param decode_tree: decoded tree; its first child is the tree converted
    :return: result of ``parse_tree_to_python_ast`` on the cleaned tree
    """
    from lang.py.unaryclosure import compressed_ast_to_normal

    eos_marker = '<eos>'
    primitive_types = {int, float, str, bool}

    compressed_ast_to_normal(decode_tree)
    decode_tree = decode_tree.children[0]

    for leaf in decode_tree.get_leaves():
        leaf_value = leaf.value
        if leaf_value is not None and type(leaf_value) is str and leaf_value.endswith(eos_marker):
            leaf.value = leaf_value[:-len(eos_marker)]

        if leaf.type in primitive_types:
            # cast to target data type
            leaf.value = leaf.type(leaf.value)

    return parse_tree_to_python_ast(decode_tree)
# Patterns recognising statement fragments that cannot be parsed on their own.
# The keyword must end at a word boundary: with the former optional ``\s?``
# suffix, identifiers that merely start with a keyword (``else_branch``,
# ``try_count``, ...) were treated as fragments and rewritten into invalid code.
p_elif = re.compile(r'^elif\b')
p_else = re.compile(r'^else\b')
p_try = re.compile(r'^try\b')
p_except = re.compile(r'^except\b')
p_finally = re.compile(r'^finally\b')
p_decorator = re.compile(r'^@.*')


def canonicalize_code(code):
    """Wrap a stand-alone statement fragment so that it becomes parseable.

    * ``elif``/``else`` fragments get a leading ``if True: pass``
    * ``try`` fragments get ``pass`` plus a bare ``except: pass`` appended
    * ``except``/``finally`` fragments get a leading ``try: pass``
    * a decorator line gets a dummy function definition appended
    * any code still ending in ``:`` gets a ``pass`` body appended

    Code matching none of these, including the empty string, is returned
    unchanged.

    :param code: source text of a single (possibly incomplete) statement
    :return: the wrapped source text
    """
    if p_elif.match(code):
        code = 'if True: pass\n' + code

    # never true for a fragment rewritten above: it now starts with ``if``
    if p_else.match(code):
        code = 'if True: pass\n' + code

    if p_try.match(code):
        code = code + 'pass\nexcept: pass'
    elif p_except.match(code):
        code = 'try: pass\n' + code
    elif p_finally.match(code):
        code = 'try: pass\n' + code

    if p_decorator.match(code):
        code = code + '\ndef dummy(): pass'

    # endswith() is safe on an empty string, unlike indexing code[-1]
    if code.endswith(':'):
        code = code + 'pass'

    return code
def de_canonicalize_code(code, ref_raw_code):
    """Strip the scaffolding that ``canonicalize_code`` added around a fragment.

    ``ref_raw_code`` (the original, un-canonicalized fragment) decides which
    dummy parts are removed from ``code``.

    NOTE(review): the literals below must match the layout of ``code``
    character for character, including the whitespace before ``pass``.
    Confirm them against the output of the code generator in use.

    :param code: generated source text containing the dummy scaffolding
    :param ref_raw_code: the raw fragment the example was built from
    :return: ``code`` with the scaffolding removed
    """
    dummy_function = 'def dummy():\n pass'
    dummy_if = 'if True:\n pass'
    dummy_try = 'try:\n pass'
    dummy_except = 'except:\n pass'
    dangling_pass = '\n pass'

    if code.endswith(dummy_function):
        code = code.replace(dummy_function, '').strip()

    # remove leading if true
    if p_elif.match(ref_raw_code) or p_else.match(ref_raw_code):
        code = code.replace(dummy_if, '').strip()

    # try/catch/except stuff
    if p_try.match(ref_raw_code):
        code = code.replace(dummy_except, '').strip()
    elif p_except.match(ref_raw_code) or p_finally.match(ref_raw_code):
        code = code.replace(dummy_try, '').strip()

    # remove ending pass
    if code.endswith(':' + dangling_pass):
        code = code[:-len(dangling_pass)]

    return code
def de_canonicalize_code_for_seq2seq(code, ref_raw_code):
    """Strip ``canonicalize_code`` scaffolding from single-line-style code.

    Same idea as ``de_canonicalize_code``, but the removed literals are the
    exact strings ``canonicalize_code`` inserts (``'if True: pass\\n'``,
    ``'try: pass\\n'`` and so on).  ``ref_raw_code`` selects which ones apply.

    :param code: canonicalized source text
    :param ref_raw_code: the raw fragment the example was built from
    :return: the stripped code with surrounding whitespace removed
    """
    dummy_function = '\ndef dummy(): pass'
    dummy_if = 'if True: pass\n'
    dummy_try = 'try: pass\n'

    if code.endswith(dummy_function):
        code = code.replace(dummy_function, '').strip()

    # remove leading if true
    if p_elif.match(ref_raw_code) or p_else.match(ref_raw_code):
        code = code.replace(dummy_if, '').strip()

    # try/catch/except stuff
    if p_try.match(ref_raw_code):
        code = code.replace('pass\nexcept: pass', '').strip()
    elif p_except.match(ref_raw_code) or p_finally.match(ref_raw_code):
        code = code.replace(dummy_try, '').strip()

    # remove ending pass
    if code.endswith(':pass'):
        code = code[:-len('pass')]

    return code.strip()
def add_root(tree):
    """Return a new ``ASTNode('root')`` whose only child is ``tree``."""
    wrapper = ASTNode('root')
    wrapper.add_child(tree)

    return wrapper
def parse(code):
    """
    Parse a Python code snippet into the internal tree structure.

    code -> canonicalized code -> Python AST -> internal parse tree

    Only the first top-level statement of the canonicalized code is converted;
    the result is wrapped in a 'root' node.
    """
    first_statement = ast.parse(canonicalize_code(code)).body[0]

    return add_root(python_ast_to_parse_tree(first_statement))
def parse_raw(code):
    """Parse ``code`` as-is (no canonicalization) into the internal tree.

    Only the first top-level statement is converted; the result is wrapped in
    a 'root' node.
    """
    module_ast = ast.parse(code)
    statement_tree = python_ast_to_parse_tree(module_ast.body[0])

    return add_root(statement_tree)
def get_grammar(parse_trees):
    """Build a ``PythonGrammar`` from the productions of ``parse_trees``.

    Productions are de-duplicated and ordered by their ``__repr__`` string, so
    the rule order does not depend on the order of the input trees.  The
    number of distinct rules is logged.

    :param parse_trees: iterable of trees providing ``get_productions()``
    :return: the resulting ``PythonGrammar``
    """
    unique_rules = set()
    for parse_tree in parse_trees:
        # the second element (rule parents) is not needed here
        tree_rules = parse_tree.get_productions()[0]
        unique_rules.update(tree_rules)

    rules = sorted(unique_rules, key=lambda rule: rule.__repr__())
    grammar = PythonGrammar(rules)

    logging.info('num. rules: %d', len(rules))

    return grammar
def tokenize_code(code):
    """Return the token strings of ``code`` as produced by ``generate_tokens``.

    Every token up to (but excluding) the end marker contributes its string
    value, in order.
    """
    tokens = []
    for token_info in generate_tokens(StringIO(code).readline):
        # token_info is (type, string, start, end, line)
        if token_info[0] == tk.ENDMARKER:
            break
        tokens.append(token_info[1])

    return tokens
def tokenize_code_adv(code, breakCamelStr=False):
    """Tokenize ``code`` with explicit indentation markers.

    * INDENT/DEDENT tokens are replaced by one ``'#INDENT#'`` per level of the
      new indentation depth.
    * A token that follows a ``'\\n'`` token (and is not itself ``'\\n'``) is
      preceded by one ``'#INDENT#'`` per current level.
    * For string tokens the first character is emitted as a separate quote
      token before and after the text with its first and last character
      removed.
    * With ``breakCamelStr`` every token text is split at lower-to-upper case
      boundaries, with ``'#MERGE#'`` emitted between the parts, and at spaces.

    :param code: source text to tokenize
    :param breakCamelStr: whether to split camel-case token text
    :return: list of token strings
    """
    indent_marker = '#INDENT#'
    tokens = []
    depth = 0

    for token_info in generate_tokens(StringIO(code).readline):
        toknum, tokval = token_info[0], token_info[1]
        if toknum == tk.ENDMARKER:
            break

        if toknum == tk.INDENT or toknum == tk.DEDENT:
            depth += 1 if toknum == tk.INDENT else -1
            tokens.extend([indent_marker] * depth)
            continue

        # first token on a new line: re-emit the current indentation
        if tokens and tokens[-1] == '\n' and tokval != '\n':
            tokens.extend([indent_marker] * depth)

        is_string = toknum == tk.STRING
        if is_string:
            quote = tokval[0]
            tokval = tokval[1:-1]
            tokens.append(quote)

        if breakCamelStr:
            marked = re.sub(r'([a-z])([A-Z])', r'\1 #MERGE# \2', tokval)
            tokens.extend(marked.split(' '))
        else:
            tokens.append(tokval)

        if is_string:
            tokens.append(quote)

    return tokens
# Ad-hoc manual check: parse one snippet and print the resulting parse tree.
if __name__ == '__main__':
    from nn.utils.generic_utils import init_logging
    init_logging('misc.log')
    # django_code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'
    #
    # grammar, parse_trees = extract_grammar(django_code_file)
    # id = 1888
    # parse_tree = parse_trees[id]
    # print parse_tree
    # from components import Hyp
    # hyp = Hyp(grammar)
    # rules, rule_parents = parse_tree.get_productions()
    #
    # while hyp.frontier_nt():
    # nt = hyp.frontier_nt()
    # if grammar.is_value_node(nt):
    # hyp.append_token('111<eos>')
    # else:
    # rule = rules[0]
    # hyp.apply_rule(rule)
    # del rules[0]
    #
    # print hyp
    #
    # ast_tree = decode_tree_to_python_ast(hyp.tree)
    # source = astor.to_source(ast_tree)
    # print source
    # for code in open(django_code_file):
    # code = code.strip()
    # ref_ast_tree = ast.parse(canonicalize_code(code)).body[0]
    # parse_tree = parse(code)
    # ast_tree = parse_tree_to_python_ast(parse_tree)
    # source1 = astor.to_source(ast_tree)
    # source2 = astor.to_source(ref_ast_tree)
    #
    # if source1 != source2:
    # pass
    # NOTE(review): this snippet is overwritten by the next assignment, so it
    # never reaches parse().
    code = """
class Demonwrath(SpellCard):
def __init__(self):
super().__init__("Demonwrath", 3, CHARACTER_CLASS.WARLOCK, CARD_RARITY.RARE)
def use(self, player, game):
super().use(player, game)
targets = copy.copy(game.other_player.minions)
targets.extend(game.current_player.minions)
for minion in targets:
if minion.card.minion_type is not MINION_TYPE.DEMON:
minion.damage(player.effective_spell_damage(2), self)
"""
    # the snippet that is actually parsed
    code = """sorted(mydict, key=mydict.get, reverse=True)"""
    # # code = """a = [1,2,3,4,'asdf', 234.3]"""
    parse_tree = parse(code)
    # for leaf in parse_tree.get_leaves():
    # if leaf.value: print escape(leaf.value)
    #
    print parse_tree
    # ast_tree = parse_tree_to_python_ast(parse_tree)
    # print astor.to_source(ast_tree)
================================================
FILE: lang/py/py_dataset.py
================================================
# -*- coding: UTF-8 -*-
from __future__ import division
import ast
import astor
import logging
from itertools import chain
import nltk
import re
from nn.utils.io_utils import serialize_to_file, deserialize_from_file
from nn.utils.generic_utils import init_logging
from dataset import gen_vocab, DataSet, DataEntry, Action, APPLY_RULE, GEN_TOKEN, COPY_TOKEN, GEN_COPY_TOKEN, Vocab
from lang.py.parse import parse, parse_tree_to_python_ast, canonicalize_code, get_grammar, parse_raw, \
de_canonicalize_code, tokenize_code, tokenize_code_adv, de_canonicalize_code_for_seq2seq
from lang.py.unaryclosure import get_top_unary_closures, apply_unary_closures
def extract_grammar(code_file, prefix='py'):
    """Parse every line of ``code_file`` and derive a grammar from the trees.

    Each stripped line is parsed with ``parse``.  As a sanity check the parse
    tree is converted back to an AST, and its ``astor`` source must equal the
    source of the directly parsed, canonicalized line.

    Side effects: prints the number of processed lines and writes
    ``<prefix>.grammar.txt`` (one rule per line) and
    ``<prefix>.parse_trees.txt`` (one tree per line) in the working directory.

    :param code_file: path of a file holding one code snippet per line
    :param prefix: file name prefix of the two dump files
    :return: ``(grammar, parse_trees)``
    :raises AssertionError: if a snippet fails the round-trip sanity check
    """
    parse_trees = []

    # the context manager closes the input even if a snippet fails to parse
    with open(code_file) as f_code:
        for line in f_code:
            code = line.strip()
            parse_tree = parse(code)
            parse_trees.append(parse_tree)

            # sanity check
            ast_tree = parse_tree_to_python_ast(parse_tree)
            ref_ast_tree = ast.parse(canonicalize_code(code)).body[0]
            source1 = astor.to_source(ast_tree)
            source2 = astor.to_source(ref_ast_tree)

            assert source1 == source2

    # exactly one tree is appended per processed line
    print('total line of code: %d' % len(parse_trees))

    grammar = get_grammar(parse_trees)

    with open(prefix + '.grammar.txt', 'w') as f:
        for rule in grammar:
            f.write(rule.__repr__() + '\n')

    with open(prefix + '.parse_trees.txt', 'w') as f:
        for tree in parse_trees:
            f.write(tree.__repr__() + '\n')

    return grammar, parse_trees
def rule_vs_node_stat(code_file='/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out'):
    """Print the average number of tree nodes and of productions per example.

    Every line of ``code_file`` is one example in which ``'§'`` stands for a
    newline.

    :param code_file: path of the code file; defaults to the HearthStone
        location that used to be hard-coded (the django alternative was
        '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code')
    :raises ZeroDivisionError: if the file contains no lines
    """
    line_num = 0
    node_nums = rule_nums = 0.

    # the context manager closes the file even if parsing fails
    with open(code_file) as f_code:
        for line in f_code:
            code = line.replace('§', '\n').strip()
            parse_tree = parse(code)
            node_nums += len(list(parse_tree.nodes))
            rules, _ = parse_tree.get_productions()
            rule_nums += len(rules)

            line_num += 1

    print('avg. nums of nodes: %f' % (node_nums / line_num))
    print('avg. nums of rules: %f' % (rule_nums / line_num))
def process_heart_stone_dataset(data_file='/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out'):
    """Parse the HearthStone code file, dump its grammar and parse trees.

    Every line of ``data_file`` is one example in which ``'§'`` stands for a
    newline.  Each example must survive a round trip: the ``astor`` source of
    the reconstructed AST has to equal the source of the directly parsed code.

    Side effects: writes ``hs.grammar.txt`` and ``hs.parse_trees.txt`` in the
    working directory and prints the average number of productions per
    example (counted with ``include_value_node=False``).

    :param data_file: path of the code file; defaults to the location that
        used to be hard-coded
    :raises RuntimeError: if an example fails the round-trip check
    """
    parse_trees = []
    rule_num = 0.
    example_num = 0

    # the context manager closes the file even if an example is rejected
    with open(data_file) as f_data:
        for line in f_data:
            code = line.replace('§', '\n').strip()
            parse_tree = parse(code)

            # sanity check
            pred_ast = parse_tree_to_python_ast(parse_tree)
            pred_code = astor.to_source(pred_ast)
            ref_ast = ast.parse(code)
            ref_code = astor.to_source(ref_ast)

            if pred_code != ref_code:
                raise RuntimeError('code mismatch!')

            rules, _ = parse_tree.get_productions(include_value_node=False)
            rule_num += len(rules)
            example_num += 1

            parse_trees.append(parse_tree)

    grammar = get_grammar(parse_trees)

    with open('hs.grammar.txt', 'w') as f:
        for rule in grammar:
            f.write(rule.__repr__() + '\n')

    with open('hs.parse_trees.txt', 'w') as f:
        for tree in parse_trees:
            f.write(tree.__repr__() + '\n')

    print('avg. nums of rules: %f' % (rule_num / example_num))
def canonicalize_hs_example(query, code):
    """Clean one HearthStone (query, code) pair and parse the code.

    ``<...>`` spans are removed from the query before it is tokenized with
    ``nltk.word_tokenize``; ``'§'`` in the code is replaced by a newline.  The
    parse tree must convert back to an AST whose ``astor`` source equals that
    of the directly parsed code.

    :param query: raw natural language description
    :param code: raw code line
    :return: ``(query_tokens, cleaned_code, parse_tree)``
    :raises AssertionError: if the round-trip sanity check fails
    """
    query_tokens = nltk.word_tokenize(re.sub(r'<.*?>', '', query))
    code = code.replace('§', '\n').strip()

    # sanity check
    parse_tree = parse_raw(code)
    gold_source = astor.to_source(ast.parse(code).body[0])
    pred_source = astor.to_source(parse_tree_to_python_ast(parse_tree))

    assert gold_source == pred_source, 'sanity check fails: gold=[%s], actual=[%s]' % (gold_source, pred_source)

    return query_tokens, code, parse_tree
def preprocess_hs_dataset(annot_file, code_file):
    """Read the parallel annotation/code files and build example dicts.

    Line ``i`` of both files forms example ``i``.  Every example is cleaned and
    sanity-checked by ``canonicalize_hs_example`` and also written in readable
    form to ``hs_dataset.examples.txt`` in the working directory.

    :param annot_file: path of the file with one description per line
    :param code_file: path of the file with one code snippet per line
    :return: list of dicts with the keys ``id``, ``query_tokens``, ``code``,
        ``parse_tree``, ``str_map`` (always ``None``) and ``raw_code``
    """
    examples = []

    # context managers close all three files, also when an example is rejected
    with open('hs_dataset.examples.txt', 'w') as f, \
            open(annot_file) as f_annot, \
            open(code_file) as f_code:
        for idx, (annot, code) in enumerate(zip(f_annot, f_code)):
            annot = annot.strip()
            code = code.strip()

            clean_query_tokens, clean_code, parse_tree = canonicalize_hs_example(annot, code)
            example = {'id': idx, 'query_tokens': clean_query_tokens, 'code': clean_code, 'parse_tree': parse_tree,
                       'str_map': None, 'raw_code': code}
            examples.append(example)

            f.write('*' * 50 + '\n')
            f.write('example# %d\n' % idx)
            f.write(' '.join(clean_query_tokens) + '\n')
            f.write('\n')
            f.write(clean_code + '\n')
            f.write('*' * 50 + '\n')

    print('preprocess_dataset: cleaned example num: %d' % len(examples))

    return examples
def parse_hs_dataset():
    """Build, serialize and return the HearthStone train/dev/test datasets.

    Steps: preprocess the annotation/code files, compress the parse trees with
    the top unary closures, build the grammar plus the annotation and terminal
    vocabularies, turn every parse tree into a sequence of actions, split the
    examples by id and initialise the data matrices.

    Side effects: writes ``hs.grammar.unary_closure.txt`` and a serialized
    ``(train_data, dev_data, test_data)`` tuple under ``data/``.

    :return: ``(train_data, dev_data, test_data)``
    """
    MAX_QUERY_LENGTH = 70 # FIXME: figure out the best config!
    WORD_FREQ_CUT_OFF = 3
    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.mod.in'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out'
    data = preprocess_hs_dataset(annot_file, code_file)
    parse_trees = [e['parse_tree'] for e in data]
    # apply unary closures
    unary_closures = get_top_unary_closures(parse_trees, k=20)
    for parse_tree in parse_trees:
        apply_unary_closures(parse_tree, unary_closures)
    # build the grammar
    grammar = get_grammar(parse_trees)
    with open('hs.grammar.unary_closure.txt', 'w') as f:
        for rule in grammar:
            f.write(rule.__repr__() + '\n')
    annot_tokens = list(chain(*[e['query_tokens'] for e in data]))
    annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=WORD_FREQ_CUT_OFF)
    def get_terminal_tokens(_terminal_str):
        """
        Split a terminal string into tokens.

        The string is split at spaces (empty pieces dropped) and every piece
        is split again at lower-to-upper case boundaries, e.g. MinionCards
        becomes [Minion, Cards].  A ' ' token is placed between the groups
        that came from different space-separated pieces.
        """
        tmp_terminal_tokens = [t for t in _terminal_str.split(' ') if len(t) > 0]
        _terminal_tokens = []
        for token in tmp_terminal_tokens:
            sub_tokens = re.sub(r'([a-z])([A-Z])', r'\1 \2', token).split(' ')
            _terminal_tokens.extend(sub_tokens)
            _terminal_tokens.append(' ')
        # drop the separator appended after the last piece
        return _terminal_tokens[:-1]
    # enumerate all terminal tokens to build up the terminal tokens vocabulary
    all_terminal_tokens = []
    for entry in data:
        parse_tree = entry['parse_tree']
        for node in parse_tree.get_leaves():
            if grammar.is_value_node(node):
                terminal_val = node.value
                terminal_str = str(terminal_val)
                terminal_tokens = get_terminal_tokens(terminal_str)
                for terminal_token in terminal_tokens:
                    assert len(terminal_token) > 0
                    all_terminal_tokens.append(terminal_token)
    terminal_vocab = gen_vocab(all_terminal_tokens, vocab_size=5000, freq_cutoff=WORD_FREQ_CUT_OFF)
    # now generate the dataset!
    train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.train_data')
    dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.dev_data')
    test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.test_data')
    all_examples = []
    can_fully_reconstructed_examples_num = 0
    examples_with_empty_actions_num = 0
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']
        rule_list, rule_parents = parse_tree.get_productions(include_value_node=True)
        actions = []
        can_fully_reconstructed = True
        # maps an applied rule to the index of its action in `actions`
        rule_pos_map = dict()
        for rule_count, rule in enumerate(rule_list):
            if not grammar.is_value_node(rule.parent):
                # structural rule: one APPLY_RULE action
                assert rule.value is None
                parent_rule = rule_parents[(rule_count, rule)][0]
                if parent_rule:
                    parent_t = rule_pos_map[parent_rule]
                else:
                    parent_t = 0
                rule_pos_map[rule] = len(actions)
                d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule}
                action = Action(APPLY_RULE, d)
                actions.append(action)
            else:
                # value node: one action per terminal token, then an '<eos>' action
                assert rule.is_leaf
                parent_rule = rule_parents[(rule_count, rule)][0]
                parent_t = rule_pos_map[parent_rule]
                terminal_val = rule.value
                terminal_str = str(terminal_val)
                terminal_tokens = get_terminal_tokens(terminal_str)
                # assert len(terminal_tokens) > 0
                for terminal_token in terminal_tokens:
                    term_tok_id = terminal_vocab[terminal_token]
                    # index of the token's first occurrence in the query, -1 if absent
                    tok_src_idx = -1
                    try:
                        tok_src_idx = query_tokens.index(terminal_token)
                    except ValueError:
                        pass
                    d = {'literal': terminal_token, 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}
                    # cannot copy, only generation
                    # could be unk!
                    if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH:
                        action = Action(GEN_TOKEN, d)
                        if terminal_token not in terminal_vocab:
                            if terminal_token not in query_tokens:
                                # print terminal_token
                                can_fully_reconstructed = False
                    else: # copy
                        if term_tok_id != terminal_vocab.unk:
                            d['source_idx'] = tok_src_idx
                            action = Action(GEN_COPY_TOKEN, d)
                        else:
                            d['source_idx'] = tok_src_idx
                            action = Action(COPY_TOKEN, d)
                    actions.append(action)
                d = {'literal': '<eos>', 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}
                actions.append(Action(GEN_TOKEN, d))
        if len(actions) == 0:
            examples_with_empty_actions_num += 1
            continue
        example = DataEntry(idx, query_tokens, parse_tree, code, actions, {'str_map': None, 'raw_code': entry['raw_code']})
        if can_fully_reconstructed:
            can_fully_reconstructed_examples_num += 1
        # train, valid, test splits
        if 0 <= idx < 533:
            train_data.add(example)
        elif idx < 599:
            dev_data.add(example)
        else:
            test_data.add(example)
        all_examples.append(example)
    # print statistics
    max_query_len = max(len(e.query) for e in all_examples)
    max_actions_len = max(len(e.actions) for e in all_examples)
    # serialize_to_file([len(e.query) for e in all_examples], 'query.len')
    # serialize_to_file([len(e.actions) for e in all_examples], 'actions.len')
    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_reconstructed_examples_num, len(all_examples),
                 can_fully_reconstructed_examples_num / len(all_examples))
    logging.info('empty_actions_count: %d', examples_with_empty_actions_num)
    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)
    # NOTE(review): 70 repeats MAX_QUERY_LENGTH as a literal; keep the two in sync
    train_data.init_data_matrices(max_query_length=70, max_example_action_num=350)
    dev_data.init_data_matrices(max_query_length=70, max_example_action_num=350)
    test_data.init_data_matrices(max_query_length=70, max_example_action_num=350)
    serialize_to_file((train_data, dev_data, test_data),
                      'data/hs.freq{WORD_FREQ_CUT_OFF}.max_action350.pre_suf.unary_closure.bin'.format(WORD_FREQ_CUT_OFF=WORD_FREQ_CUT_OFF))
    return train_data, dev_data, test_data
def dump_data_for_evaluation(data_type='django', data_file='', max_query_length=70):
    """Dump a serialized dataset as parallel token files.

    For each of the train/dev/test splits stored in ``data_file`` two files
    are written below the hard-coded output prefix: ``<data_type>.<split>.desc``
    (query tokens) and ``<data_type>.<split>.code`` (code tokens), one example
    per line.

    For ``data_type == 'django'`` the code is de-canonicalized first and token
    text is kept whole; for any other type the stored code is used directly
    and camel-case token text is split.  Newline tokens are written as
    ``#NEWLINE#`` and trailing ``#INDENT#`` tokens are dropped.

    :param data_type: dataset name, used in the output file names
    :param data_file: path of the serialized ``(train, dev, test)`` tuple
    :param max_query_length: queries are truncated to this many tokens
    """
    train_data, dev_data, test_data = deserialize_from_file(data_file)
    prefix = '/Users/yinpengcheng/Projects/dl4mt-tutorial/codegen_data/'
    is_django = data_type == 'django'

    for dataset, output in [(train_data, prefix + '%s.train' % data_type),
                            (dev_data, prefix + '%s.dev' % data_type),
                            (test_data, prefix + '%s.test' % data_type)]:
        # context managers close both files even if an example fails
        with open(output + '.desc', 'w') as f_source, open(output + '.code', 'w') as f_target:
            for e in dataset.examples:
                query_tokens = e.query[:max_query_length]
                if is_django:
                    target_code = de_canonicalize_code_for_seq2seq(e.code, e.meta_data['raw_code'])
                else:
                    target_code = e.code

                # tokenize code
                target_code = target_code.strip()
                tokenized_target = tokenize_code_adv(target_code, breakCamelStr=not is_django)
                tokenized_target = [token.replace('\n', '#NEWLINE#') for token in tokenized_target]

                # strip trailing indentation markers; the emptiness check keeps
                # an example that yields no tokens from raising IndexError
                while tokenized_target and tokenized_target[-1] == '#INDENT#':
                    tokenized_target.pop()

                f_source.write(' '.join(query_tokens) + '\n')
                f_target.write(' '.join(tokenized_target) + '\n')
# Script entry point: log to py.log and build the HearthStone dataset.
# The commented calls are alternative one-off jobs.
if __name__ == '__main__':
    init_logging('py.log')
    # rule_vs_node_stat()
    # process_heart_stone_dataset()
    parse_hs_dataset()
    # dump_data_for_evaluation(data_file='data/django.cleaned.dataset.freq5.par_info.refact.space_only.bin')
    # dump_data_for_evaluation(data_type='hs', data_file='data/hs.freq3.pre_suf.unary_closure.bin')
    # code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'
    # py_grammar, _ = extract_grammar(code_file)
    # serialize_to_file(py_grammar, 'py_grammar.bin')
================================================
FILE: lang/py/seq2tree_exp.py
================================================
import logging
import re
from collections import defaultdict, OrderedDict
from itertools import chain
import sys
from astnode import ASTNode
from dataset import preprocess_dataset, gen_vocab
from lang.py.grammar import type_str_to_type
from lang.py.parse import parse, get_grammar, decode_tree_to_python_ast
from lang.py.unaryclosure import get_top_unary_closures, apply_unary_closures
from lang.util import typename, escape, unescape
from nn.utils.generic_utils import init_logging
from nn.utils.io_utils import serialize_to_file
def ast_tree_to_seq2tree_repr(tree):
    """Serialize ``tree`` into the space-separated seq2tree bracket notation.

    Every node becomes ``<type>{<label>}{<value>}`` (a missing label or value
    is rendered as an empty string); the children of a node follow it wrapped
    in ``( ... )``.

    :param tree: node exposing ``type``, ``label``, ``value`` and ``children``
    :return: the string representation
    """
    label_text = tree.label if tree.label is not None else ''
    value_text = tree.value if tree.value is not None else ''
    pieces = ['%s{%s}{%s}' % (typename(tree.type), label_text, value_text)]

    # wrap children with parentheses
    if tree.children:
        pieces.append('(')
        pieces.extend(ast_tree_to_seq2tree_repr(child) for child in tree.children)
        pieces.append(')')

    return ' '.join(pieces)
# one serialized node: <type>{<label>}{<value>}
node_re = re.compile(r'(?P<type>.*?)\{(?P<label>.*?)\}\{(?P<value>.*)\}')


def seq2tree_repr_to_ast_tree_helper(tree_repr, offset):
    """Parse the node starting at ``offset`` of a seq2tree representation.

    :param tree_repr: the full string representation
    :param offset: index of the first character of the node to parse
    :return: ``(node, next_offset)``, where ``next_offset`` is the position at
        which the caller continues parsing
    """
    # the node name runs up to the next space (or the end of the string)
    name_end = tree_repr.find(' ', offset)
    if name_end < 0:
        name_end = len(tree_repr)

    fields = node_re.match(tree_repr[offset:name_end])
    n_type = type_str_to_type(fields.group('type'))
    n_label = fields.group('label')
    n_value = fields.group('value')

    if n_type in {int, float, str, bool}:
        n_value = n_type(n_value)

    # empty strings stand for "not set"
    if n_label == '':
        n_label = None
    if n_value == '':
        n_value = None

    node = ASTNode(n_type, label=n_label, value=n_value)
    offset = name_end

    if offset == len(tree_repr):
        return node, offset

    # skip the space after the node name
    offset += 1
    if tree_repr[offset] == '(':
        # skip '(' and the following space
        offset += 2
        while True:
            child_node, offset = seq2tree_repr_to_ast_tree_helper(tree_repr, offset=offset)
            node.add_child(child_node)

            if offset >= len(tree_repr) or tree_repr[offset] == ')':
                # skip ')' and the following space
                offset += 2
                break

    return node, offset
def seq2tree_repr_to_ast_tree(tree_repr):
    """Parse a complete seq2tree string representation into a tree."""
    # the helper also returns the end offset, which is not needed here
    return seq2tree_repr_to_ast_tree_helper(tree_repr, 0)[0]
def break_value_nodes(tree, hs=False):
    """Break string value nodes into one child per token, in place.

    A node of type ``str`` that holds a value gets its value replaced by
    ``'NT'`` and receives one child per space-separated token (escaped with
    ``escape``).  With ``hs`` set, a ``#MERGE#`` token is first inserted at
    every lower-to-upper case boundary.  All other nodes are traversed
    recursively.

    :param tree: the (sub)tree to process
    :param hs: whether to additionally split camel-case words
    """
    if not (tree.type == str and tree.value is not None):
        for child in tree.children:
            break_value_nodes(child, hs=hs)
        return

    assert tree.is_leaf

    text = tree.value
    if hs:
        text = re.sub(r'([a-z])([A-Z])', r'\1 #MERGE# \2', text)
    pieces = text.split(' ')

    tree.value = 'NT'
    for piece in pieces:
        assert piece is not None
        tree.add_child(ASTNode(tree.type, value=escape(piece)))
def merge_broken_value_nodes(tree):
    """Undo *break_value_nodes*, in place.

    A non-leaf node of type ``str`` gets the unescaped values of its children
    joined by spaces as its value (children without a value are ignored, and
    every ``' #MERGE# '`` is removed), and its children are dropped.  All
    other nodes are traversed recursively.
    """
    if not (tree.type == str and not tree.is_leaf):
        for child in tree.children:
            merge_broken_value_nodes(child)
        return

    assert tree.value == 'NT'

    parts = [unescape(child.value) for child in tree.children if child.value is not None]
    tree.value = ' '.join(parts).replace(' #MERGE# ', '')
    tree.children = []
def parse_django_dataset_for_seq2tree():
    """Export the django dataset in the seq2tree text format.

    Every example is written as ``<query>\\t<tree representation>`` to the
    train/dev/test file selected by its id (below 16000 / below 17000 / rest),
    and its id to the matching ``*.id.txt`` file.  Examples whose tree
    representation has more than ``MAX_DECODING_TIME_STEP`` space-separated
    tokens are skipped.  For written examples whose representation does not
    parse back into the original tree, the id and code are printed.
    """
    from lang.py.parse import parse_raw

    MAX_QUERY_LENGTH = 70
    MAX_DECODING_TIME_STEP = 300

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'
    out_dir = '/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/'

    data = preprocess_dataset(annot_file, code_file)

    for e in data:
        e['parse_tree'] = parse_raw(e['code'])

    parse_trees = [e['parse_tree'] for e in data]

    # build the grammar (its result is unused here; the call logs the number
    # of rules)
    get_grammar(parse_trees)

    # context managers close all six outputs even if an example raises
    with open(out_dir + 'train.txt', 'w') as f_train, \
            open(out_dir + 'dev.txt', 'w') as f_dev, \
            open(out_dir + 'test.txt', 'w') as f_test, \
            open(out_dir + 'train.id.txt', 'w') as f_train_rawid, \
            open(out_dir + 'dev.id.txt', 'w') as f_dev_rawid, \
            open(out_dir + 'test.id.txt', 'w') as f_test_rawid:
        for entry in data:
            idx = entry['id']
            query_tokens = entry['query_tokens']
            code = entry['code']
            parse_tree = entry['parse_tree']

            original_parse_tree = parse_tree.copy()
            break_value_nodes(parse_tree)
            tree_repr = ast_tree_to_seq2tree_repr(parse_tree)

            num_decode_time_step = len(tree_repr.split(' '))

            # round trip, compared against the original tree below
            new_tree = seq2tree_repr_to_ast_tree(tree_repr)
            merge_broken_value_nodes(new_tree)

            query_tokens = [t for t in query_tokens if t != ''][:MAX_QUERY_LENGTH]
            query = ' '.join(query_tokens)
            line = query + '\t' + tree_repr

            if num_decode_time_step > MAX_DECODING_TIME_STEP:
                continue

            # train, valid, test
            if 0 <= idx < 16000:
                f_train.write(line + '\n')
                f_train_rawid.write(str(idx) + '\n')
            elif 16000 <= idx < 17000:
                f_dev.write(line + '\n')
                f_dev_rawid.write(str(idx) + '\n')
            else:
                f_test.write(line + '\n')
                f_test_rawid.write(str(idx) + '\n')

            if original_parse_tree != new_tree:
                print('*' * 50)
                print(idx)
                print(code)
def parse_hs_dataset_for_seq2tree():
    """Export the HearthStone dataset in the seq2tree text format.

    The parse trees are compressed with the top 20 unary closures.  Every
    example is then written as ``<query>\\t<tree representation>`` to the
    train/dev/test file selected by its id (below 533 / below 599 / rest), and
    its id to the matching ``*.id.txt`` file.  Examples whose tree
    representation has more than ``MAX_DECODING_TIME_STEP`` space-separated
    tokens are skipped.  For written examples whose representation does not
    parse back into the original tree, the id and code are printed.  Finally
    the distribution of representation lengths is printed.
    """
    from lang.py.py_dataset import preprocess_hs_dataset

    MAX_QUERY_LENGTH = 70 # FIXME: figure out the best config!
    MAX_DECODING_TIME_STEP = 800

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.mod.in'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out'
    out_dir = '/Users/yinpengcheng/Research/lang2logic/seq2tree/hs/data_unkreplaced/'

    data = preprocess_hs_dataset(annot_file, code_file)
    parse_trees = [e['parse_tree'] for e in data]

    # apply unary closures
    unary_closures = get_top_unary_closures(parse_trees, k=20)
    for parse_tree in parse_trees:
        apply_unary_closures(parse_tree, unary_closures)

    # build the grammar (its result is unused here; the call logs the number
    # of rules)
    get_grammar(parse_trees)

    decode_time_steps = defaultdict(int)

    # context managers close all six outputs even if an example raises
    with open(out_dir + 'train.txt', 'w') as f_train, \
            open(out_dir + 'dev.txt', 'w') as f_dev, \
            open(out_dir + 'test.txt', 'w') as f_test, \
            open(out_dir + 'train.id.txt', 'w') as f_train_rawid, \
            open(out_dir + 'dev.id.txt', 'w') as f_dev_rawid, \
            open(out_dir + 'test.id.txt', 'w') as f_test_rawid:
        for entry in data:
            idx = entry['id']
            query_tokens = entry['query_tokens']
            # bound here so the mismatch report below prints this example's
            # code; the name was previously never assigned in this function
            code = entry['code']
            parse_tree = entry['parse_tree']

            original_parse_tree = parse_tree.copy()
            break_value_nodes(parse_tree, hs=True)
            tree_repr = ast_tree_to_seq2tree_repr(parse_tree)

            num_decode_time_step = len(tree_repr.split(' '))
            decode_time_steps[num_decode_time_step] += 1

            # round trip, compared against the original tree below
            new_tree = seq2tree_repr_to_ast_tree(tree_repr)
            merge_broken_value_nodes(new_tree)

            query_tokens = [t for t in query_tokens if t != ''][:MAX_QUERY_LENGTH]
            query = ' '.join(query_tokens)
            line = query + '\t' + tree_repr

            if num_decode_time_step > MAX_DECODING_TIME_STEP:
                continue

            # train, valid, test
            if 0 <= idx < 533:
                f_train.write(line + '\n')
                f_train_rawid.write(str(idx) + '\n')
            elif idx < 599:
                f_dev.write(line + '\n')
                f_dev_rawid.write(str(idx) + '\n')
            else:
                f_test.write(line + '\n')
                f_test_rawid.write(str(idx) + '\n')

            if original_parse_tree != new_tree:
                print('*' * 50)
                print(idx)
                print(code)

    # num. of decoding time steps distribution
    for k in sorted(decode_time_steps):
        print('%d\t%d' % (k, decode_time_steps[k]))
# Script entry point: a small parse/break smoke run on one snippet, followed
# by the HearthStone seq2tree export.
if __name__ == '__main__':
    init_logging('py.log')
    # code = "return ( format_html_join ( '' , '_STR:0_' , sorted ( attrs . items ( ) ) ) + format_html_join ( '' , ' {0}' , sorted ( boolean_attrs ) ) )"
    code = "call('{0}')"
    parse_tree = parse(code)
    # parse_tree = ASTNode('root', children=[
    # ASTNode('lambda'),
    # ASTNode('$0'),
    # ASTNode('e', children=[
    # ASTNode('and', children=[
    # ASTNode('>', children=[ASTNode('$0')]),
    # ASTNode('from', children=[ASTNode('$0'), ASTNode('ci0')]),
    # ])
    # ]),
    # ])
    # copy taken before break_value_nodes modifies parse_tree in place; it is
    # only referenced by the commented-out round-trip check below
    original_parse_tree = parse_tree.copy()
    break_value_nodes(parse_tree)
    # tree_repr = """root{}{} ( For{}{} ( expr{target}{} ( Name{}{} ( str{id}{NT} ( ) ) ) expr{iter}{} ( Name{}{} ( str{id}{NT} ( Name{}{} ( str{id}{NT} ( str{}{self} ) ) ) ) ) stmt*{body}{} ( stmt{}{} ( Pass{}{} ) ) ) )"""
    # print tree_repr
    # new_tree = seq2tree_repr_to_ast_tree(tree_repr)
    # merge_broken_value_nodes(new_tree)
    # print str(original_parse_tree)
    # print str(new_tree)
    # assert original_parse_tree == new_tree
    # parse_django_dataset_for_seq2tree()
    parse_hs_dataset_for_seq2tree()
================================================
FILE: lang/py/unaryclosure.py
================================================
# -*- coding: UTF-8 -*-
from astnode import ASTNode
from lang.py.grammar import type_str_to_type
from lang.py.parse import parse
from collections import Counter
import re
def extract_unary_closure_helper(parse_tree, unary_link, last_node):
    """Collect the single-child chains ("unary links") found in ``parse_tree``.

    While the visited node has exactly one child, a copy of that child (type
    and label) is appended below ``last_node``, extending the chain rooted at
    ``unary_link``.  At a leaf or a node with several children the chain ends
    and is reported if its ``size`` exceeds 2.  Each child of a node with
    several children starts a new chain.

    :param parse_tree: node currently visited
    :param unary_link: root of the chain being built
    :param last_node: last node of the chain being built
    :return: list of chain roots
    """
    is_leaf = parse_tree.is_leaf

    if is_leaf or len(parse_tree.children) > 1:
        # the running chain ends here
        links = [unary_link] if (unary_link and unary_link.size > 2) else []
        if not is_leaf:
            for child in parse_tree.children:
                chain_root = ASTNode(child.type)
                links.extend(extract_unary_closure_helper(child, chain_root, chain_root))
        return links

    # has a single child: extend the running chain
    only_child = parse_tree.children[0]
    link_node = ASTNode(only_child.type, label=only_child.label)
    last_node.add_child(link_node)

    return extract_unary_closure_helper(only_child, unary_link, link_node)
def extract_unary_closure(parse_tree):
    """Return the unary links of ``parse_tree``, starting a chain at its root."""
    chain_root = ASTNode(parse_tree.type)

    return extract_unary_closure_helper(parse_tree, chain_root, chain_root)
def get_unary_links():
    """Report unary link statistics for the django code file.

    Prints every unary link with its frequency (most frequent first), applies
    the closures of the 20 most frequent links to every parse tree, checks
    that expanding a compressed tree restores the original, and prints the
    average node and rule counts per tree.

    :raises AssertionError: if a tree is not restored by the round trip
    """
    # data_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out'
    data_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'

    parse_trees = []
    unary_links_counter = Counter()

    # the context manager closes the file even if parsing fails
    with open(data_file) as f_data:
        for line in f_data:
            code = line.replace('§', '\n').strip()
            parse_tree = parse(code)
            parse_trees.append(parse_tree)

            for link in extract_unary_closure(parse_tree):
                unary_links_counter[link] += 1

    ranked_links = sorted(unary_links_counter, key=unary_links_counter.get, reverse=True)
    for link in ranked_links:
        print(str(link) + ' ||| ' + str(unary_links_counter[link]))

    # built as a real list because it is iterated once per parse tree below
    unary_closures = [(link, unary_link_to_closure(link)) for link in ranked_links[:20]]

    node_nums = rule_nums = 0.
    for parse_tree in parse_trees:
        original_parse_tree = parse_tree.copy()
        for link, closure in unary_closures:
            apply_unary_closure(parse_tree, closure, link)

        # expanding the compressed tree must give back the original
        compressed_ast_to_normal(parse_tree)
        assert original_parse_tree == parse_tree

        rules, _ = parse_tree.get_productions()
        rule_nums += len(rules)
        node_nums += len(list(parse_tree.nodes))

    print('**** after applying unary closures ****')
    print('avg. nums of nodes: %f' % (node_nums / len(parse_trees)))
    print('avg. nums of rules: %f' % (rule_nums / len(parse_trees)))
def get_top_unary_closures(parse_trees, k=20, freq=50):
    """Select the most frequent unary links and pair them with their closures.

    With a truthy ``k`` the ``k`` most frequent links are chosen; otherwise
    all links occurring at least ``freq`` times.  The applied cut-off and
    every selected link, closure and frequency are printed.

    :param parse_trees: trees to collect unary links from
    :param k: rank cut-off; pass 0 to use ``freq`` instead
    :param freq: frequency cut-off, only used when ``k`` is falsy
    :return: list of ``(link, closure)`` pairs, most frequent first
    """
    link_counts = Counter()
    for parse_tree in parse_trees:
        link_counts.update(extract_unary_closure(parse_tree))

    if k:
        print('rank cut off: %d' % k)
        selected_links = sorted(link_counts, key=link_counts.get, reverse=True)[:k]
    else:
        print('freq cut off: %d' % freq)
        frequent_links = [link for link in link_counts if link_counts[link] >= freq]
        selected_links = sorted(frequent_links, key=link_counts.get, reverse=True)

    unary_closures = [(link, unary_link_to_closure(link)) for link in selected_links]

    for link, closure in unary_closures:
        print('link: %s ||| closure: %s ||| freq: %d' % (link, closure, link_counts[link]))

    return unary_closures
def apply_unary_closures(parse_tree, unary_closures):
    """Apply all ``(link, closure)`` pairs to ``parse_tree`` in place.

    Pairs are applied in order of decreasing link size.  Afterwards a copy of
    the compressed tree is expanded again and must equal the original tree.

    :param parse_tree: tree to compress in place
    :param unary_closures: iterable of ``(link, closure)`` pairs
    :raises AssertionError: if the compression cannot be undone
    """
    largest_first = sorted(unary_closures, key=lambda pair: pair[0].size, reverse=True)
    reference_tree = parse_tree.copy()

    # apply all unary closures
    for link, closure in largest_first:
        apply_unary_closure(parse_tree, closure, link)

    restored_tree = parse_tree.copy()
    compressed_ast_to_normal(restored_tree)
    assert reference_tree == restored_tree
rule_regex = re.compile(r'(?P<parent>.*?) -> \((?P<child>.*?)(\{(?P<clabel>.*?)\})?\)')
def compressed_ast_to_normal(parse_tree):
    """Expand compressed nodes at and below ``parse_tree`` into node chains.

    A node counts as compressed when its label contains both ``'@'`` and
    ``'$'``.  Such a label is decoded by turning ``'$'`` into spaces and
    splitting on ``'@'``; each piece is parsed with ``rule_regex`` and becomes
    one ``ASTNode`` of the rule's child type, linked to the previous one as
    its only child.  The last node of the chain takes over the compressed
    node's value and children, and the first node of the chain replaces the
    compressed node under its parent.

    Nodes that are not compressed are left as they are and only their
    children are visited.

    Args:
        parse_tree: node to process; the tree is modified through its
            ``children`` lists and ``add_child``.

    Returns:
        None.
    """
    label = parse_tree.label
    if not (label and '@' in label and '$' in label):
        # Walk a snapshot: expanding a child replaces entries in this list.
        for child in list(parse_tree.children):
            compressed_ast_to_normal(child)
        return

    chain_head = chain_tail = None
    for rule_repr in label.replace('$', ' ').split('@'):
        match = rule_regex.match(rule_repr)
        parent_str = match.group('parent')
        child_str = match.group('child')
        child_label = match.group('clabel')

        # The parent type is still resolved, as before, although only the
        # child type is needed to build the node.
        type_str_to_type(parent_str)
        node = ASTNode(type_str_to_type(child_str), label=child_label)

        if chain_tail:
            chain_tail.add_child(node)
        if not chain_head:
            chain_head = node
        chain_tail = node

    # The end of the chain inherits the payload of the compressed node.
    chain_tail.value = parse_tree.value
    for child in parse_tree.children:
        chain_tail.add_child(child)
        compressed_ast_to_normal(child)

    # Swap the compressed node for the head of the rebuilt chain.
    parent_node = parse_tree.parent
    assert len(parent_node.children) == 1
    del parent_node.children[0]
    parent_node.add_child(chain_head)
def match_sub_tree(parse_tree, cur_match_node, is_root=False):
    """Match the chain pattern ``cur_match_node`` against ``parse_tree``.

    A level matches when the node types are equal, the tree node has exactly
    one child (not required once the pattern node is a leaf), and the labels
    are equal (not required when ``is_root`` is true).  Matching then
    continues with the first child of both nodes until the pattern's leaf is
    reached.

    Args:
        parse_tree: tree node to test at the current level.
        cur_match_node: pattern node for the current level.
        is_root: when true, the label comparison is skipped for this level
            only; deeper levels always compare labels.

    Returns:
        The tree node that lines up with the pattern's leaf, or ``None`` if
        any level fails to match.
    """
    if not (parse_tree.type == cur_match_node.type):
        return None
    if not (len(parse_tree.children) == 1 or cur_match_node.is_leaf):
        return None
    if not (is_root or parse_tree.label == cur_match_node.label):
        return None

    if cur_match_node.is_leaf:
        return parse_tree

    # A non-leaf pattern passed the check above, so there is exactly one child.
    return match_sub_tree(parse_tree.children[0], cur_match_node.children[0])
def find(parse_tree, sub_tree):
    """Collect every occurrence of ``sub_tree`` at or below ``parse_tree``.

    Each node of the tree is tried as the starting point of a match via
    ``match_sub_tree(node, sub_tree, True)``, the node itself first and then
    its children in order.

    Args:
        parse_tree: root of the tree to search.
        sub_tree: pattern handed to ``match_sub_tree``.

    Returns:
        A list of ``(start_node, end_node)`` tuples, where ``end_node`` is
        the truthy value returned by ``match_sub_tree`` for ``start_node``.
    """
    end_node = match_sub_tree(parse_tree, sub_tree, True)
    matches = [(parse_tree, end_node)] if end_node else []

    for child in parse_tree.children:
        matches += find(child, sub_tree)

    return matches
def apply_unary_closure(parse_tree, unary_closure, unary_link):
match_results = find(parse_tree, unary_link)
for firs
gitextract__fzn9_a9/ ├── .gitignore ├── README.md ├── astnode.py ├── code_gen.py ├── components.py ├── config.py ├── dataset.py ├── decoder.py ├── evaluation.py ├── interactive_mode.py ├── lang/ │ ├── __init__.py │ ├── grammar.py │ ├── ifttt/ │ │ ├── __init__.py │ │ ├── grammar.py │ │ ├── ifttt_dataset.py │ │ └── parse.py │ ├── py/ │ │ ├── __init__.py │ │ ├── grammar.py │ │ ├── parse.py │ │ ├── py_dataset.py │ │ ├── seq2tree_exp.py │ │ └── unaryclosure.py │ ├── type_system.py │ └── util.py ├── learner.py ├── main.py ├── model.py ├── nn/ │ ├── __init__.py │ ├── activations.py │ ├── initializations.py │ ├── layers/ │ │ ├── __init__.py │ │ ├── convolution.py │ │ ├── core.py │ │ ├── embeddings.py │ │ └── recurrent.py │ ├── objectives.py │ ├── optimizers.py │ └── utils/ │ ├── __init__.py │ ├── config_factory.py │ ├── generic_utils.py │ ├── io_utils.py │ ├── np_utils.py │ ├── test_utils.py │ └── theano_utils.py ├── parse.py ├── parse_hiro.py ├── run_interactive.sh ├── run_interactive_singlefile.sh ├── run_trained_model.sh ├── train.sh └── util.py
SYMBOL INDEX (372 symbols across 35 files)
FILE: astnode.py
class ASTNode (line 8) | class ASTNode(object):
method __init__ (line 9) | def __init__(self, node_type, label=None, value=None, children=None):
method is_leaf (line 31) | def is_leaf(self):
method is_preterminal (line 35) | def is_preterminal(self):
method size (line 39) | def size(self):
method nodes (line 50) | def nodes(self):
method as_type_node (line 59) | def as_type_node(self):
method __repr__ (line 63) | def __repr__(self):
method __hash__ (line 83) | def __hash__(self):
method __eq__ (line 94) | def __eq__(self, other):
method __ne__ (line 118) | def __ne__(self, other):
method __getitem__ (line 121) | def __getitem__(self, child_type):
method __delitem__ (line 124) | def __delitem__(self, child_type):
method add_child (line 133) | def add_child(self, child):
method get_child_id (line 137) | def get_child_id(self, child):
method pretty_print (line 144) | def pretty_print(self):
method pretty_print_helper (line 150) | def pretty_print_helper(self, sb, depth, new_line=False):
method get_leaves (line 176) | def get_leaves(self):
method to_rule (line 186) | def to_rule(self, include_value=False):
method get_productions (line 198) | def get_productions(self, include_value_node=False):
method copy (line 245) | def copy(self):
class DecodeTree (line 264) | class DecodeTree(ASTNode):
method __init__ (line 265) | def __init__(self, node_type, label=None, value=None, children=None, t...
method copy (line 273) | def copy(self):
class Rule (line 285) | class Rule(ASTNode):
method __init__ (line 286) | def __init__(self, *args, **kwargs):
method parent (line 292) | def parent(self):
method __repr__ (line 295) | def __repr__(self):
FILE: components.py
class PointerNet (line 21) | class PointerNet(Layer):
method __init__ (line 22) | def __init__(self, name='PointerNet'):
method __call__ (line 35) | def __call__(self, query_embed, query_token_embed_mask, decoder_states):
class Hyp (line 53) | class Hyp:
method __init__ (line 54) | def __init__(self, *args):
method __repr__ (line 78) | def __repr__(self):
method can_expand (line 81) | def can_expand(self, node):
method apply_rule (line 104) | def apply_rule(self, rule, nt=None):
method append_token (line 126) | def append_token(self, token, nt=None):
method frontier_nt_helper (line 139) | def frontier_nt_helper(self, node):
method frontier_nt (line 153) | def frontier_nt(self):
method get_action_parent_t (line 163) | def get_action_parent_t(self):
class CondAttLSTM (line 195) | class CondAttLSTM(Layer):
method __init__ (line 199) | def __init__(self, input_dim, output_dim,
method _step (line 280) | def _step(self,
method _for_step (line 385) | def _for_step(self,
method __call__ (line 446) | def __call__(self, X, context, parent_t_seq, init_state=None, init_cel...
method get_mask (line 528) | def get_mask(self, mask, X):
FILE: dataset.py
class Action (line 31) | class Action(object):
method __init__ (line 32) | def __init__(self, act_type, data):
method __repr__ (line 36) | def __repr__(self):
class Vocab (line 44) | class Vocab(object):
method __init__ (line 45) | def __init__(self):
method unk (line 52) | def unk(self):
method eos (line 56) | def eos(self):
method __getitem__ (line 59) | def __getitem__(self, item):
method __contains__ (line 66) | def __contains__(self, item):
method size (line 70) | def size(self):
method __setitem__ (line 73) | def __setitem__(self, key, value):
method __len__ (line 76) | def __len__(self):
method __iter__ (line 79) | def __iter__(self):
method iteritems (line 82) | def iteritems(self):
method complete (line 85) | def complete(self):
method get_token (line 88) | def get_token(self, token_id):
method insert_token (line 91) | def insert_token(self, token):
function tokenize (line 104) | def tokenize(str):
function gen_vocab (line 109) | def gen_vocab(tokens, vocab_size=3000, freq_cutoff=5):
class DataEntry (line 133) | class DataEntry:
method __init__ (line 134) | def __init__(self, raw_id, query, parse_tree, code, actions, meta_data...
method data (line 145) | def data(self):
method copy (line 153) | def copy(self):
class DataSet (line 159) | class DataSet:
method __init__ (line 160) | def __init__(self, annot_vocab, terminal_vocab, grammar, name='train_d...
method add (line 168) | def add(self, example):
method get_dataset_by_ids (line 173) | def get_dataset_by_ids(self, ids, name):
method count (line 186) | def count(self):
method get_examples (line 192) | def get_examples(self, ids):
method get_prob_func_inputs (line 198) | def get_prob_func_inputs(self, ids):
method init_data_matrices (line 218) | def init_data_matrices(self, max_query_length=70, max_example_action_n...
class DataHelper (line 287) | class DataHelper(object):
method canonicalize_query (line 289) | def canonicalize_query(query):
function parse_django_dataset_nt_only (line 293) | def parse_django_dataset_nt_only():
function parse_django_dataset (line 359) | def parse_django_dataset():
function check_terminals (line 559) | def check_terminals():
function query_to_data (line 600) | def query_to_data(query, annot_vocab):
function canonicalize_query (line 616) | def canonicalize_query(query):
function canonicalize_example (line 664) | def canonicalize_example(query, code):
function process_query (line 690) | def process_query(query, code):
function preprocess_dataset (line 745) | def preprocess_dataset(annot_file, code_file):
FILE: decoder.py
function decode_python_dataset (line 6) | def decode_python_dataset(model, dataset, verbose=True):
function decode_ifttt_dataset (line 41) | def decode_ifttt_dataset(model, dataset, verbose=True):
FILE: evaluation.py
function tokenize_for_bleu_eval (line 17) | def tokenize_for_bleu_eval(code):
function evaluate (line 28) | def evaluate(model, dataset, verbose=True):
function evaluate_decode_results (line 62) | def evaluate_decode_results(dataset, decode_results, verbose=True):
function analyze_decode_results (line 266) | def analyze_decode_results(dataset, decode_results, verbose=True):
function evaluate_seq2seq_decode_results (line 523) | def evaluate_seq2seq_decode_results(dataset, seq2seq_decode_file, seq2se...
function evaluate_seq2tree_sample_file (line 610) | def evaluate_seq2tree_sample_file(sample_file, id_file, dataset):
function evaluate_ifttt_results (line 720) | def evaluate_ifttt_results(dataset, decode_results, verbose=True):
function ifttt_metric (line 809) | def ifttt_metric(predict_parse_tree, ref_parse_tree):
function decode_and_evaluate_ifttt (line 836) | def decode_and_evaluate_ifttt(model, test_data):
function decode_and_evaluate_ifttt_by_split (line 848) | def decode_and_evaluate_ifttt_by_split(model, test_data):
FILE: interactive_mode.py
function decode_query (line 101) | def decode_query(query):
FILE: lang/grammar.py
class Grammar (line 7) | class Grammar(object):
method __init__ (line 8) | def __init__(self, rules):
method __iter__ (line 51) | def __iter__(self):
method __len__ (line 54) | def __len__(self):
method __getitem__ (line 57) | def __getitem__(self, lhs):
method get_node_type_id (line 64) | def get_node_type_id(self, node):
method is_terminal (line 76) | def is_terminal(self, node):
method is_value_node (line 79) | def is_value_node(self, node):
FILE: lang/ifttt/grammar.py
class IFTTTGrammar (line 3) | class IFTTTGrammar(Grammar):
method __init__ (line 4) | def __init__(self, rules):
method is_value_node (line 7) | def is_value_node(self, node):
FILE: lang/ifttt/ifttt_dataset.py
function load_examples (line 20) | def load_examples(data_file):
function analyze_ifttt_dataset (line 35) | def analyze_ifttt_dataset():
function canonicalize_ifttt_example (line 57) | def canonicalize_ifttt_example(annot, code):
function preprocess_ifttt_dataset (line 66) | def preprocess_ifttt_dataset(annot_file, code_file):
function get_grammar (line 96) | def get_grammar(parse_trees):
function parse_ifttt_dataset (line 121) | def parse_ifttt_dataset():
function parse_data_for_seq2seq (line 239) | def parse_data_for_seq2seq(data_file='data/ifttt.freq3.bin'):
function extract_turk_data (line 267) | def extract_turk_data():
FILE: lang/ifttt/parse.py
function ifttt_ast_to_parse_tree_helper (line 3) | def ifttt_ast_to_parse_tree_helper(s, offset):
function ifttt_ast_to_parse_tree (line 39) | def ifttt_ast_to_parse_tree(s, attach_func_to_channel=True):
function strip_params (line 49) | def strip_params(parse_tree):
function attach_function_to_channel (line 60) | def attach_function_to_channel(parse_tree):
FILE: lang/py/grammar.py
function is_builtin_type (line 722) | def is_builtin_type(x):
function is_terminal_ast_type (line 726) | def is_terminal_ast_type(x):
function type_str_to_type (line 798) | def type_str_to_type(type_str):
function is_compositional_leaf (line 816) | def is_compositional_leaf(node):
class PythonGrammar (line 832) | class PythonGrammar(Grammar):
method __init__ (line 833) | def __init__(self, rules):
method is_value_node (line 836) | def is_value_node(self, node):
FILE: lang/py/parse.py
function python_ast_to_parse_tree (line 14) | def python_ast_to_parse_tree(node):
function parse_tree_to_python_ast (line 77) | def parse_tree_to_python_ast(tree):
function decode_tree_to_python_ast (line 143) | def decode_tree_to_python_ast(decode_tree):
function canonicalize_code (line 172) | def canonicalize_code(code):
function de_canonicalize_code (line 195) | def de_canonicalize_code(code, ref_raw_code):
function de_canonicalize_code_for_seq2seq (line 221) | def de_canonicalize_code_for_seq2seq(code, ref_raw_code):
function add_root (line 247) | def add_root(tree):
function parse (line 254) | def parse(code):
function parse_raw (line 270) | def parse_raw(code):
function get_grammar (line 280) | def get_grammar(parse_trees):
function tokenize_code (line 297) | def tokenize_code(code):
function tokenize_code_adv (line 308) | def tokenize_code_adv(code, breakCamelStr=False):
FILE: lang/py/py_dataset.py
function extract_grammar (line 19) | def extract_grammar(code_file, prefix='py'):
function rule_vs_node_stat (line 76) | def rule_vs_node_stat():
function process_heart_stone_dataset (line 95) | def process_heart_stone_dataset():
function canonicalize_hs_example (line 133) | def canonicalize_hs_example(query, code):
function preprocess_hs_dataset (line 151) | def preprocess_hs_dataset(annot_file, code_file):
function parse_hs_dataset (line 181) | def parse_hs_dataset():
function dump_data_for_evaluation (line 363) | def dump_data_for_evaluation(data_type='django', data_file='', max_query...
FILE: lang/py/seq2tree_exp.py
function ast_tree_to_seq2tree_repr (line 18) | def ast_tree_to_seq2tree_repr(tree):
function seq2tree_repr_to_ast_tree_helper (line 40) | def seq2tree_repr_to_ast_tree_helper(tree_repr, offset):
function seq2tree_repr_to_ast_tree (line 82) | def seq2tree_repr_to_ast_tree(tree_repr):
function break_value_nodes (line 88) | def break_value_nodes(tree, hs=False):
function merge_broken_value_nodes (line 106) | def merge_broken_value_nodes(tree):
function parse_django_dataset_for_seq2tree (line 122) | def parse_django_dataset_for_seq2tree():
function parse_hs_dataset_for_seq2tree (line 213) | def parse_hs_dataset_for_seq2tree():
FILE: lang/py/unaryclosure.py
function extract_unary_closure_helper (line 10) | def extract_unary_closure_helper(parse_tree, unary_link, last_node):
function extract_unary_closure (line 35) | def extract_unary_closure(parse_tree):
function get_unary_links (line 42) | def get_unary_links():
function get_top_unary_closures (line 88) | def get_top_unary_closures(parse_trees, k=20, freq=50):
function apply_unary_closures (line 115) | def apply_unary_closures(parse_tree, unary_closures):
function compressed_ast_to_normal (line 129) | def compressed_ast_to_normal(parse_tree):
function match_sub_tree (line 178) | def match_sub_tree(parse_tree, cur_match_node, is_root=False):
function find (line 194) | def find(parse_tree, sub_tree):
function apply_unary_closure (line 208) | def apply_unary_closure(parse_tree, unary_closure, unary_link):
function unary_link_to_closure (line 223) | def unary_link_to_closure(unary_link):
FILE: lang/util.py
function typename (line 2) | def typename(x):
function escape (line 7) | def escape(text):
function unescape (line 26) | def unescape(text):
FILE: learner.py
class Learner (line 15) | class Learner(object):
method __init__ (line 16) | def __init__(self, model, train_data, val_data=None):
method train (line 27) | def train(self):
class DataIterator (line 164) | class DataIterator:
method __init__ (line 165) | def __init__(self, dataset, batch_size=10):
method reset (line 173) | def reset(self):
method __iter__ (line 178) | def __iter__(self):
method next_batch (line 181) | def next_batch(self):
method next (line 189) | def next(self):
FILE: main.py
function escape (line 14) | def escape(text):
function typename (line 27) | def typename(x):
function get_tree_str_repr (line 31) | def get_tree_str_repr(node):
function get_tree (line 64) | def get_tree(node):
function parse (line 107) | def parse(code):
function parse_django (line 125) | def parse_django(code_file):
FILE: model.py
class Model (line 30) | class Model:
method __init__ (line 31) | def __init__(self):
method build (line 74) | def build(self):
method build_decoder (line 218) | def build_decoder(self, query_tokens, query_token_embed, query_token_e...
method decode (line 324) | def decode(self, example, grammar, terminal_vocab, beam_size, max_time...
method params_name_to_id (line 575) | def params_name_to_id(self):
method params_dict (line 586) | def params_dict(self):
method pull_params (line 590) | def pull_params(self):
method save (line 593) | def save(self, model_file, **kwargs):
method load (line 602) | def load(self, model_file):
FILE: nn/activations.py
function softmax (line 4) | def softmax(x):
function time_distributed_softmax (line 8) | def time_distributed_softmax(x):
function softplus (line 14) | def softplus(x):
function relu (line 18) | def relu(x):
function tanh (line 22) | def tanh(x):
function sigmoid (line 26) | def sigmoid(x):
function hard_sigmoid (line 30) | def hard_sigmoid(x):
function linear (line 34) | def linear(x):
function get (line 42) | def get(identifier):
FILE: nn/initializations.py
function get_fans (line 8) | def get_fans(shape):
function uniform (line 14) | def uniform(shape, scale=0.01, name=None):
function normal (line 18) | def normal(shape, scale=0.01, name=None):
function lecun_uniform (line 22) | def lecun_uniform(shape):
function glorot_normal (line 31) | def glorot_normal(shape):
function glorot_uniform (line 39) | def glorot_uniform(shape, name=None):
function he_normal (line 45) | def he_normal(shape):
function he_uniform (line 53) | def he_uniform(shape):
function orthogonal (line 59) | def orthogonal(shape, scale=1.1):
function identity (line 71) | def identity(shape, scale=1):
function zero (line 78) | def zero(shape):
function one (line 82) | def one(shape):
function get (line 87) | def get(identifier):
FILE: nn/layers/convolution.py
class Convolution2d (line 11) | class Convolution2d(Layer):
method __init__ (line 14) | def __init__(self, max_sent_len, word_embed_dim, filter_num, filter_wi...
method __call__ (line 35) | def __call__(self, X):
FILE: nn/layers/core.py
class Layer (line 15) | class Layer(object):
method __init__ (line 16) | def __init__(self):
method init_updates (line 19) | def init_updates(self):
method __call__ (line 22) | def __call__(self, X):
method supports_masked_input (line 25) | def supports_masked_input(self):
method get_output_mask (line 31) | def get_output_mask(self, train=None):
method set_weights (line 47) | def set_weights(self, weights):
method get_weights (line 53) | def get_weights(self):
method get_params (line 59) | def get_params(self):
method set_name (line 62) | def set_name(self, name):
class MaskedLayer (line 73) | class MaskedLayer(Layer):
method supports_masked_input (line 78) | def supports_masked_input(self):
class Dense (line 82) | class Dense(Layer):
method __init__ (line 83) | def __init__(self, input_dim, output_dim, init='glorot_uniform', activ...
method set_name (line 100) | def set_name(self, name):
method __call__ (line 104) | def __call__(self, X):
class Dropout (line 109) | class Dropout(Layer):
method __init__ (line 110) | def __init__(self, p, srng, name='dropout'):
method __call__ (line 121) | def __call__(self, X, train_only=True):
class WordDropout (line 132) | class WordDropout(Layer):
method __init__ (line 133) | def __init__(self, p, srng, name='WordDropout'):
method __call__ (line 139) | def __call__(self, X, train_only=True):
FILE: nn/layers/embeddings.py
function get_embed_iter (line 11) | def get_embed_iter(file_path):
class Embedding (line 22) | class Embedding(Layer):
method __init__ (line 30) | def __init__(self, input_dim, output_dim, init='uniform', name=None):
method get_output_mask (line 43) | def get_output_mask(self, X):
method init_pretrained (line 46) | def init_pretrained(self, file_path, vocab):
method __call__ (line 59) | def __call__(self, X, mask_zero=False):
class HybridEmbedding (line 67) | class HybridEmbedding(Layer):
method __init__ (line 75) | def __init__(self, embed_size, unfixed_embed_size, embed_dim, init='un...
method get_output_mask (line 95) | def get_output_mask(self, X):
method __call__ (line 98) | def __call__(self, X, mask_zero=False):
FILE: nn/layers/recurrent.py
class GRU (line 11) | class GRU(Layer):
method __init__ (line 33) | def __init__(self, input_dim, output_dim=128,
method _step (line 69) | def _step(self,
method __call__ (line 82) | def __call__(self, X, mask=None, init_state=None):
method get_padded_shuffled_mask (line 106) | def get_padded_shuffled_mask(self, mask, X, pad=0):
class GRU_4BiRNN (line 122) | class GRU_4BiRNN(Layer):
method __init__ (line 144) | def __init__(self, input_dim, output_dim=128,
method _step (line 180) | def _step(self,
method __call__ (line 200) | def __call__(self, X, mask=None, init_state=None):
method get_padded_shuffled_mask (line 232) | def get_padded_shuffled_mask(self, mask, pad=0):
class LSTM (line 246) | class LSTM(Layer):
method __init__ (line 247) | def __init__(self, input_dim, output_dim,
method _step (line 288) | def _step(self,
method __call__ (line 304) | def __call__(self, X, mask=None, init_state=None, dropout=0, train=Tru...
method get_mask (line 346) | def get_mask(self, mask, X):
class BiLSTM (line 358) | class BiLSTM(Layer):
method __init__ (line 359) | def __init__(self, input_dim, output_dim,
method __call__ (line 380) | def __call__(self, X, mask=None, init_state=None, dropout=0, train=Tru...
class CondAttLSTM (line 397) | class CondAttLSTM(Layer):
method __init__ (line 401) | def __init__(self, input_dim, output_dim,
method _step (line 461) | def _step(self,
method __call__ (line 504) | def __call__(self, X, context, init_state=None, init_cell=None, mask=N...
method get_mask (line 568) | def get_mask(self, mask, X):
class GRUDecoder (line 580) | class GRUDecoder(Layer):
method __init__ (line 584) | def __init__(self, input_dim, context_dim, hidden_dim, vocab_num,
method _step (line 631) | def _step(self,
method __call__ (line 642) | def __call__(self, target, context, mask=None):
method get_padded_shuffled_mask (line 674) | def get_padded_shuffled_mask(self, mask, pad=0):
FILE: nn/objectives.py
function mean_squared_error (line 13) | def mean_squared_error(y_true, y_pred):
function mean_absolute_error (line 17) | def mean_absolute_error(y_true, y_pred):
function mean_absolute_percentage_error (line 21) | def mean_absolute_percentage_error(y_true, y_pred):
function mean_squared_logarithmic_error (line 25) | def mean_squared_logarithmic_error(y_true, y_pred):
function squared_hinge (line 29) | def squared_hinge(y_true, y_pred):
function hinge (line 33) | def hinge(y_true, y_pred):
function categorical_crossentropy (line 37) | def categorical_crossentropy(y_true, y_pred):
function binary_crossentropy (line 47) | def binary_crossentropy(y_true, y_pred):
function poisson_loss (line 53) | def poisson_loss(y_true, y_pred):
function get (line 63) | def get(identifier):
FILE: nn/optimizers.py
function clip_norm (line 14) | def clip_norm(g, c, n):
function kl_divergence (line 20) | def kl_divergence(p, p_hat):
class Optimizer (line 24) | class Optimizer(object):
method __init__ (line 25) | def __init__(self, **kwargs):
method get_state (line 29) | def get_state(self):
method set_state (line 32) | def set_state(self, value_list):
method get_updates (line 37) | def get_updates(self, params, constraints, loss, **kwargs):
method get_gradients (line 40) | def get_gradients(self, loss, params, **kwargs):
method get_config (line 51) | def get_config(self):
class SGD (line 55) | class SGD(Optimizer):
method __init__ (line 57) | def __init__(self, lr=0.01, momentum=0., decay=0., nesterov=False, *ar...
method get_updates (line 64) | def get_updates(self, params, loss):
method get_config (line 82) | def get_config(self):
class RMSprop (line 90) | class RMSprop(Optimizer):
method __init__ (line 91) | def __init__(self, lr=0.001, rho=0.9, epsilon=1e-6, *args, **kwargs):
method get_updates (line 97) | def get_updates(self, params, constraints, loss):
method get_config (line 110) | def get_config(self):
class Adagrad (line 117) | class Adagrad(Optimizer):
method __init__ (line 118) | def __init__(self, lr=0.01, epsilon=1e-6, *args, **kwargs):
method get_updates (line 123) | def get_updates(self, params, constraints, loss):
method get_config (line 135) | def get_config(self):
class Adadelta (line 141) | class Adadelta(Optimizer):
method __init__ (line 145) | def __init__(self, lr=1.0, rho=0.95, epsilon=1e-6, *args, **kwargs):
method get_updates (line 150) | def get_updates(self, params, loss):
method get_config (line 172) | def get_config(self):
class Adadelta_GaussianNoise (line 179) | class Adadelta_GaussianNoise(Optimizer):
method __init__ (line 183) | def __init__(self, lr=1.0, rho=0.95, epsilon=1e-6, *args, **kwargs):
method get_updates (line 189) | def get_updates(self, params, loss):
method get_config (line 216) | def get_config(self):
class Adam (line 223) | class Adam(Optimizer):
method __init__ (line 229) | def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8, *...
method get_updates (line 236) | def get_updates(self, params, loss, **kwargs):
method get_config (line 268) | def get_config(self):
function get (line 284) | def get(identifier, kwargs=None):
FILE: nn/utils/config_factory.py
class MetaConfig (line 4) | class MetaConfig(type):
method __getitem__ (line 5) | def __getitem__(self, key):
method __setitem__ (line 8) | def __setitem__(self, key, value):
class config (line 12) | class config(object):
method set (line 17) | def set(key, val):
method init_config (line 21) | def init_config(file='config.py'):
FILE: nn/utils/generic_utils.py
function get_from_module (line 9) | def get_from_module(identifier, module_params, module_name, instantiate=...
function make_tuple (line 23) | def make_tuple(*args):
function printv (line 27) | def printv(v, prefix=''):
function make_batches (line 49) | def make_batches(size, batch_size):
function slice_X (line 54) | def slice_X(X, start=None, stop=None):
function init_logging (line 67) | def init_logging(file_name, level=logging.INFO):
function pad_sequences (line 83) | def pad_sequences(sequences, maxlen=None, dtype='int32',
class Progbar (line 147) | class Progbar(object):
method __init__ (line 148) | def __init__(self, target, width=30, verbose=1):
method update (line 161) | def update(self, current, values=[]):
method add (line 228) | def add(self, n, values=[]):
FILE: nn/utils/io_utils.py
class HDF5Matrix (line 9) | class HDF5Matrix():
method __init__ (line 12) | def __init__(self, datapath, dataset, start, end, normalizer=None):
method __len__ (line 23) | def __len__(self):
method __getitem__ (line 26) | def __getitem__(self, key):
method shape (line 53) | def shape(self):
function save_array (line 57) | def save_array(array, name):
function load_array (line 66) | def load_array(name):
function serialize_to_file (line 76) | def serialize_to_file(obj, path, protocol=cPickle.HIGHEST_PROTOCOL):
function deserialize_from_file (line 82) | def deserialize_from_file(path):
FILE: nn/utils/np_utils.py
function to_categorical (line 8) | def to_categorical(y, nb_classes=None):
function normalize (line 21) | def normalize(a, axis=-1, order=2):
function binary_logloss (line 27) | def binary_logloss(p, y):
function multiclass_logloss (line 36) | def multiclass_logloss(P, Y):
function accuracy (line 43) | def accuracy(p, y):
function probas_to_classes (line 47) | def probas_to_classes(y_pred):
function categorical_probas_to_classes (line 53) | def categorical_probas_to_classes(p):
FILE: nn/utils/test_utils.py
function get_test_data (line 4) | def get_test_data(nb_train=1000, nb_test=500, input_shape=(10,), output_...
FILE: nn/utils/theano_utils.py
function floatX (line 7) | def floatX(X):
function sharedX (line 11) | def sharedX(X, dtype=theano.config.floatX, name=None):
function shared_zeros (line 15) | def shared_zeros(shape, dtype=theano.config.floatX, name=None):
function shared_scalar (line 19) | def shared_scalar(val=0., dtype=theano.config.floatX, name=None):
function shared_ones (line 23) | def shared_ones(shape, dtype=theano.config.floatX, name=None):
function alloc_zeros_matrix (line 27) | def alloc_zeros_matrix(*dims):
function tensor_right_shift (line 31) | def tensor_right_shift(tensor):
function ndim_tensor (line 38) | def ndim_tensor(ndim, name=None):
function ndim_itensor (line 51) | def ndim_itensor(ndim, name=None):
function ndim_btensor (line 62) | def ndim_btensor(ndim, name=None):
FILE: parse_hiro.py
function typename (line 6) | def typename(x):
function escape (line 9) | def escape(text):
function makestr (line 21) | def makestr(node):
function main (line 61) | def main():
FILE: util.py
function is_numeric (line 1) | def is_numeric(s):
Condensed preview — 51 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (329K chars).
[
{
"path": ".gitignore",
"chars": 1045,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": "README.md",
"chars": 970,
"preview": "# NL2code\n\nA syntactic neural model for parsing natural language to executable code [paper](https://arxiv.org/abs/1704.0"
},
{
"path": "astnode.py",
"chars": 9570,
"preview": "from collections import namedtuple\nimport cPickle\nfrom collections import Iterable, OrderedDict, defaultdict\nfrom cStrin"
},
{
"path": "code_gen.py",
"chars": 11213,
"preview": "import numpy as np\nimport cProfile\nimport ast\nimport traceback\nimport argparse\nimport os\nimport logging\nfrom vprof impor"
},
{
"path": "components.py",
"chars": 20459,
"preview": "import theano\nimport theano.tensor as T\nimport numpy as np\nimport logging\nimport copy\n\nfrom nn.layers.embeddings import "
},
{
"path": "config.py",
"chars": 794,
"preview": "# MODE = 'django'\n#\n# SOURCE_VOCAB_SIZE = 2490 # 2492 # 5980\n# TARGET_VOCAB_SIZE = 2101 # 2110 # 4830 #\n# RULE_NUM = 222"
},
{
"path": "dataset.py",
"chars": 26709,
"preview": "from __future__ import division\nimport copy\n\nimport nltk\nfrom collections import OrderedDict, defaultdict\nimport logging"
},
{
"path": "decoder.py",
"chars": 2518,
"preview": "import traceback\nimport config\n\nfrom model import *\n\ndef decode_python_dataset(model, dataset, verbose=True):\n from l"
},
{
"path": "evaluation.py",
"chars": 33709,
"preview": "# -*- coding: UTF-8 -*-\n\nfrom __future__ import division\nimport os\nfrom nltk.translate.bleu_score import sentence_bleu, "
},
{
"path": "interactive_mode.py",
"chars": 6245,
"preview": "import argparse, sys\nfrom nn.utils.generic_utils import init_logging\nfrom nn.utils.io_utils import deserialize_from_file"
},
{
"path": "lang/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "lang/grammar.py",
"chars": 2563,
"preview": "from collections import OrderedDict, defaultdict\nimport logging\n\nfrom astnode import ASTNode\nfrom lang.util import typen"
},
{
"path": "lang/ifttt/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "lang/ifttt/grammar.py",
"chars": 200,
"preview": "from lang.grammar import Grammar\n\nclass IFTTTGrammar(Grammar):\n def __init__(self, rules):\n super(IFTTTGrammar"
},
{
"path": "lang/ifttt/ifttt_dataset.py",
"chars": 13667,
"preview": "# -*- coding: UTF-8 -*-\nfrom __future__ import division\nimport string\nfrom collections import OrderedDict\nfrom collectio"
},
{
"path": "lang/ifttt/parse.py",
"chars": 2669,
"preview": "from astnode import ASTNode\n\ndef ifttt_ast_to_parse_tree_helper(s, offset):\n \"\"\"\n adapted from ifttt codebase\n "
},
{
"path": "lang/py/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "lang/py/grammar.py",
"chars": 19283,
"preview": "\"\"\"\nPython grammar and typing system\n\"\"\"\nimport ast\nimport inspect\nimport astor\n\nfrom lang.grammar import Grammar\n\nPY_AS"
},
{
"path": "lang/py/parse.py",
"chars": 12913,
"preview": "import ast\nimport logging\nimport re\nimport token as tk\nfrom cStringIO import StringIO\nfrom tokenize import generate_toke"
},
{
"path": "lang/py/py_dataset.py",
"chars": 14800,
"preview": "# -*- coding: UTF-8 -*-\nfrom __future__ import division\nimport ast\nimport astor\nimport logging\nfrom itertools import cha"
},
{
"path": "lang/py/seq2tree_exp.py",
"chars": 10961,
"preview": "import logging\nimport re\nfrom collections import defaultdict, OrderedDict\nfrom itertools import chain\n\nimport sys\n\nfrom "
},
{
"path": "lang/py/unaryclosure.py",
"chars": 8802,
"preview": "# -*- coding: UTF-8 -*-\n\nfrom astnode import ASTNode\nfrom lang.py.grammar import type_str_to_type\nfrom lang.py.parse imp"
},
{
"path": "lang/type_system.py",
"chars": 0,
"preview": ""
},
{
"path": "lang/util.py",
"chars": 966,
"preview": "# x is a type\ndef typename(x):\n if isinstance(x, str):\n return x\n return x.__name__\n\ndef escape(text):\n "
},
{
"path": "learner.py",
"chars": 7996,
"preview": "from nn.utils.config_factory import config\nfrom nn.utils.generic_utils import *\n\nimport logging\nimport numpy as np\nimpor"
},
{
"path": "main.py",
"chars": 4763,
"preview": "import ast\nimport re\n\nfrom astnode import *\n\np_elif = re.compile(r'^elif\\s?')\np_else = re.compile(r'^else\\s?')\np_try = r"
},
{
"path": "model.py",
"chars": 28422,
"preview": "import theano\nimport theano.tensor as T\nfrom theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams\nimport num"
},
{
"path": "nn/__init__.py",
"chars": 28,
"preview": "__author__ = 'yinpengcheng'\n"
},
{
"path": "nn/activations.py",
"chars": 789,
"preview": "import theano.tensor as T\n\n\ndef softmax(x):\n return T.nnet.softmax(x.reshape((-1, x.shape[-1]))).reshape(x.shape)\n\n\nd"
},
{
"path": "nn/initializations.py",
"chars": 2298,
"preview": "import theano\nimport theano.tensor as T\nimport numpy as np\n\nfrom .utils.theano_utils import sharedX, shared_zeros, share"
},
{
"path": "nn/layers/__init__.py",
"chars": 28,
"preview": "__author__ = 'yinpengcheng'\n"
},
{
"path": "nn/layers/convolution.py",
"chars": 1971,
"preview": "# -*- coding: utf-8 -*-\n\nfrom .core import Layer\nfrom nn.utils.theano_utils import *\nimport nn.initializations as initia"
},
{
"path": "nn/layers/core.py",
"chars": 4797,
"preview": "# -*- coding: utf-8 -*-\n\nimport theano\nimport theano.tensor as T\nimport numpy as np\n\nfrom nn.utils.theano_utils import *"
},
{
"path": "nn/layers/embeddings.py",
"chars": 3103,
"preview": "# -*- coding: utf-8 -*-\n\nfrom .core import Layer\nfrom nn.utils.theano_utils import *\nimport nn.initializations as initia"
},
{
"path": "nn/layers/recurrent.py",
"chars": 26795,
"preview": "# -*- coding: utf-8 -*-\n\nimport logging\nimport theano\nimport theano.tensor as T\nimport numpy as np\n\nfrom .core import *\n"
},
{
"path": "nn/objectives.py",
"chars": 1873,
"preview": "from __future__ import absolute_import\nimport theano\nimport theano.tensor as T\nimport numpy as np\nfrom six.moves import "
},
{
"path": "nn/optimizers.py",
"chars": 10315,
"preview": "from __future__ import absolute_import\nimport theano\nimport theano.tensor as T\n\nfrom .utils.theano_utils import shared_z"
},
{
"path": "nn/utils/__init__.py",
"chars": 28,
"preview": "__author__ = 'yinpengcheng'\n"
},
{
"path": "nn/utils/config_factory.py",
"chars": 584,
"preview": "import logging\n\n\nclass MetaConfig(type):\n def __getitem__(self, key):\n return config._config[key]\n\n def __s"
},
{
"path": "nn/utils/generic_utils.py",
"chars": 7725,
"preview": "from __future__ import absolute_import\nimport numpy as np\nimport time\nimport sys\nimport six\nimport logging\n\n\ndef get_fro"
},
{
"path": "nn/utils/io_utils.py",
"chars": 2320,
"preview": "from __future__ import absolute_import\n\nimport cPickle\nimport h5py\nimport numpy as np\nfrom collections import defaultdic"
},
{
"path": "nn/utils/np_utils.py",
"chars": 1395,
"preview": "from __future__ import absolute_import\nimport numpy as np\nimport scipy as sp\nfrom six.moves import range\nfrom six.moves "
},
{
"path": "nn/utils/test_utils.py",
"chars": 1091,
"preview": "import numpy as np\n\n\ndef get_test_data(nb_train=1000, nb_test=500, input_shape=(10,), output_shape=(2,),\n "
},
{
"path": "nn/utils/theano_utils.py",
"chars": 1618,
"preview": "from __future__ import absolute_import\nimport numpy as np\nimport theano\nimport theano.tensor as T\n\n\ndef floatX(X):\n r"
},
{
"path": "parse.py",
"chars": 1529,
"preview": "import ast\nimport re\nimport sys, inspect\nfrom StringIO import StringIO\n\nimport astor\nfrom collections import OrderedDict"
},
{
"path": "parse_hiro.py",
"chars": 2503,
"preview": "import ast\nimport sys\nimport re\nimport inspect\n\ndef typename(x):\n return type(x).__name__\n\ndef escape(text):\n text"
},
{
"path": "run_interactive.sh",
"chars": 969,
"preview": "output=\"runs\"\ndevice=\"cpu\"\n\nif [ \"$1\" == \"hs\" ]; then\n\t# hs dataset\n\techo \"run trained model for hs\"\n\tdataset=\"data/hs.f"
},
{
"path": "run_interactive_singlefile.sh",
"chars": 806,
"preview": "device=\"cpu\"\n\nif [ \"$1\" == \"hs\" ]; then\n\t# hs dataset\n\techo \"run trained model for hs\"\n\tdataset=\"data/hs.freq3.pre_suf.u"
},
{
"path": "run_trained_model.sh",
"chars": 1252,
"preview": "output=\"runs\"\ndevice=\"cpu\"\n\nif [ \"$1\" == \"hs\" ]; then\n\t# hs dataset\n\techo \"run trained model for hs\"\n\tdataset=\"data/hs.f"
},
{
"path": "train.sh",
"chars": 1481,
"preview": "output=\"runs\"\ndevice=\"cpu\"\n\nif [ \"$1\" == \"hs\" ]; then\n\t# hs dataset\n\techo \"training hs dataset\"\n\tdataset=\"hs.freq3.pre_s"
},
{
"path": "util.py",
"chars": 99,
"preview": "def is_numeric(s):\n if s[0] in ('-', '+'):\n return s[1:].isdigit()\n return s.isdigit()"
}
]
About this extraction
This page contains the full source code of the pcyin/NL2code GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 51 files (309.2 KB), approximately 78.4k tokens, and a symbol index with 372 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.