Repository: pcyin/NL2code Branch: master Commit: f9732f1f5caa Files: 51 Total size: 309.2 KB Directory structure: gitextract__fzn9_a9/ ├── .gitignore ├── README.md ├── astnode.py ├── code_gen.py ├── components.py ├── config.py ├── dataset.py ├── decoder.py ├── evaluation.py ├── interactive_mode.py ├── lang/ │ ├── __init__.py │ ├── grammar.py │ ├── ifttt/ │ │ ├── __init__.py │ │ ├── grammar.py │ │ ├── ifttt_dataset.py │ │ └── parse.py │ ├── py/ │ │ ├── __init__.py │ │ ├── grammar.py │ │ ├── parse.py │ │ ├── py_dataset.py │ │ ├── seq2tree_exp.py │ │ └── unaryclosure.py │ ├── type_system.py │ └── util.py ├── learner.py ├── main.py ├── model.py ├── nn/ │ ├── __init__.py │ ├── activations.py │ ├── initializations.py │ ├── layers/ │ │ ├── __init__.py │ │ ├── convolution.py │ │ ├── core.py │ │ ├── embeddings.py │ │ └── recurrent.py │ ├── objectives.py │ ├── optimizers.py │ └── utils/ │ ├── __init__.py │ ├── config_factory.py │ ├── generic_utils.py │ ├── io_utils.py │ ├── np_utils.py │ ├── test_utils.py │ └── theano_utils.py ├── parse.py ├── parse_hiro.py ├── run_interactive.sh ├── run_interactive_singlefile.sh ├── run_trained_model.sh ├── train.sh └── util.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ *.egg-info/ .installed.cfg *.egg # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *,cover .hypothesis/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # IPython Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # dotenv .env # virtualenv venv/ ENV/ # Spyder project settings .spyderproject # Rope project settings .ropeproject ================================================ FILE: README.md ================================================ # NL2code A syntactic neural model for parsing natural language to executable code [paper](https://arxiv.org/abs/1704.01696). ## Dataset and Trained Models Get serialized datasets and trained models from [here](https://drive.google.com/drive/folders/0B14lJ2VVvtmJWEQ5RlFjQUY2Vzg). Put `models/` and `data/` folders under the root directory of the project. ## Usage To train new model ```bash . train.sh [hs|django] ``` To use trained model for decoding test sets ```bash . run_trained_model.sh [hs|django] ``` ## Dependencies * Theano * vprof * NLTK 3.2.1 * astor 0.6 ## Reference ``` @inproceedings{yin17acl, title = {A Syntactic Neural Model for General-Purpose Code Generation}, author = {Pengcheng Yin and Graham Neubig}, booktitle = {The 55th Annual Meeting of the Association for Computational Linguistics (ACL)}, address = {Vancouver, Canada}, month = {July}, url = {https://arxiv.org/abs/1704.01696}, year = {2017} } ``` ================================================ FILE: astnode.py ================================================ from collections import namedtuple import cPickle from collections import Iterable, OrderedDict, defaultdict from cStringIO import StringIO from lang.util import typename class ASTNode(object): def __init__(self, node_type, label=None, value=None, children=None): self.type = node_type self.label = label self.value = value if type(self) is not Rule: self.parent = None self.children = list() if children: if isinstance(children, Iterable): for child in children: self.add_child(child) elif isinstance(children, ASTNode): self.add_child(children) else: raise AttributeError('Wrong type for child nodes') assert not (bool(children) and bool(value)), 'terminal node with a value cannot have children' @property def is_leaf(self): return len(self.children) == 0 @property def is_preterminal(self): return len(self.children) == 1 and self.children[0].is_leaf @property def size(self): if self.is_leaf: return 1 node_num = 1 for child in self.children: node_num += child.size return node_num @property def nodes(self): """a generator that returns all the nodes""" yield self for child in self.children: for child_n in child.nodes: yield child_n @property def as_type_node(self): """return an ASTNode with type information only""" return ASTNode(self.type) def __repr__(self): repr_str = '' # if not self.is_leaf: repr_str += '(' repr_str += typename(self.type) if self.label is not None: repr_str += '{%s}' % self.label if self.value is not None: repr_str += '{val=%s}' % self.value # if not self.is_leaf: for child in self.children: repr_str += ' ' + child.__repr__() repr_str += ')' return repr_str def __hash__(self): code = hash(self.type) if self.label is not None: code = code * 37 + hash(self.label) if self.value is not None: code = code * 37 + hash(self.value) for child in self.children: code = code * 37 + hash(child) return code def __eq__(self, other): if not isinstance(other, self.__class__): return False if hash(self) != hash(other): return False if self.type != other.type: return False if self.label != other.label: return False if self.value != other.value: return False if len(self.children) != len(other.children): return False for i in xrange(len(self.children)): if self.children[i] != other.children[i]: return False return True def __ne__(self, other): return not self.__eq__(other) def __getitem__(self, child_type): return next(iter([c for c in self.children if c.type == child_type])) def __delitem__(self, child_type): tgt_child = [c for c in self.children if c.type == child_type] if tgt_child: assert len(tgt_child) == 1, 'unsafe deletion for more than one children' tgt_child = tgt_child[0] self.children.remove(tgt_child) else: raise KeyError def add_child(self, child): child.parent = self self.children.append(child) def get_child_id(self, child): for i, _child in enumerate(self.children): if child == _child: return i raise KeyError def pretty_print(self): sb = StringIO() new_line = False self.pretty_print_helper(sb, 0, new_line) return sb.getvalue() def pretty_print_helper(self, sb, depth, new_line=False): if new_line: sb.write('\n') for i in xrange(depth): sb.write(' ') sb.write('(') sb.write(typename(self.type)) if self.label is not None: sb.write('{%s}' % self.label) if self.value is not None: sb.write('{val=%s}' % self.value) if len(self.children) == 0: sb.write(')') return sb.write(' ') new_line = True for child in self.children: child.pretty_print_helper(sb, depth + 2, new_line) sb.write('\n') for i in xrange(depth): sb.write(' ') sb.write(')') def get_leaves(self): if self.is_leaf: return [self] leaves = [] for child in self.children: leaves.extend(child.get_leaves()) return leaves def to_rule(self, include_value=False): """ transform the current AST node to a production rule """ rule = Rule(self.type) for c in self.children: val = c.value if include_value else None child = ASTNode(c.type, c.label, val) rule.add_child(child) return rule def get_productions(self, include_value_node=False): """ get the depth-first, left-to-right sequence of rule applications returns a list of production rules and a map to their parent rules attention: node value is not included in child nodes """ rule_list = list() rule_parents = OrderedDict() node_rule_map = dict() s = list() s.append(self) rule_num = 0 while len(s) > 0: node = s.pop() for child in reversed(node.children): if not child.is_leaf: s.append(child) elif include_value_node: if child.value is not None: s.append(child) # only non-terminals and terminal nodes holding values # can form a production rule if node.children or node.value is not None: rule = Rule(node.type) if include_value_node: rule.value = node.value for c in node.children: val = None child = ASTNode(c.type, c.label, val) rule.add_child(child) rule_list.append(rule) if node.parent: child_id = node.parent.get_child_id(node) parent_rule = node_rule_map[node.parent] rule_parents[(rule_num, rule)] = (parent_rule, child_id) else: rule_parents[(rule_num, rule)] = (None, -1) rule_num += 1 node_rule_map[node] = rule return rule_list, rule_parents def copy(self): # if not hasattr(self, '_dump'): # dump = cPickle.dumps(self, -1) # setattr(self, '_dump', dump) # # return cPickle.loads(dump) # # return cPickle.loads(self._dump) new_tree = ASTNode(self.type, self.label, self.value) if self.is_leaf: return new_tree for child in self.children: new_tree.add_child(child.copy()) return new_tree class DecodeTree(ASTNode): def __init__(self, node_type, label=None, value=None, children=None, t=-1): super(DecodeTree, self).__init__(node_type, label, value, children) # record the time step when this subtree is created from a rule application self.t = t # record the ApplyRule action that is used to expand the current node self.applied_rule = None def copy(self): new_tree = DecodeTree(self.type, self.label, value=self.value, t=self.t) new_tree.applied_rule = self.applied_rule if self.is_leaf: return new_tree for child in self.children: new_tree.add_child(child.copy()) return new_tree class Rule(ASTNode): def __init__(self, *args, **kwargs): super(Rule, self).__init__(*args, **kwargs) assert self.value is None and self.label is None, 'Rule LHS cannot have values or labels' @property def parent(self): return self.as_type_node def __repr__(self): parent = typename(self.type) if self.label is not None: parent += '{%s}' % self.label if self.value is not None: parent += '{val=%s}' % self.value return '%s -> %s' % (parent, ', '.join([repr(c) for c in self.children])) if __name__ == '__main__': import ast t1 = ASTNode('root', children=[ ASTNode(str, 'a1_label', children=[ASTNode(int, children=[ASTNode('a21', value=123)]), ASTNode(ast.NodeTransformer, children=[ASTNode('a21', value='hahaha')])] ), ASTNode('a2', children=[ASTNode('a21', value='asdf')]) ]) t2 = ASTNode('root', children=[ ASTNode(str, 'a1_label', children=[ASTNode(int, children=[ASTNode('a21', value=123)]), ASTNode(ast.NodeTransformer, children=[ASTNode('a21', value='hahaha')])] ), ASTNode('a2', children=[ASTNode('a21', value='asdf')]) ]) print t1 == t2 a, b = t1.get_productions(include_value_node=True) # t = ASTNode('root', children=ASTNode('sdf')) print t1.__repr__() print t1.pretty_print() ================================================ FILE: code_gen.py ================================================ import numpy as np import cProfile import ast import traceback import argparse import os import logging from vprof import profiler from model import Model from dataset import DataEntry, DataSet, Vocab, Action import config from learner import Learner from evaluation import * from decoder import decode_python_dataset from components import Hyp from astnode import ASTNode from nn.utils.generic_utils import init_logging from nn.utils.io_utils import deserialize_from_file, serialize_to_file parser = argparse.ArgumentParser() parser.add_argument('-data') parser.add_argument('-random_seed', default=181783, type=int) parser.add_argument('-output_dir', default='.outputs') parser.add_argument('-model', default=None) # model's main configuration parser.add_argument('-data_type', default='django', choices=['django', 'ifttt', 'hs']) # neural model's parameters parser.add_argument('-source_vocab_size', default=0, type=int) parser.add_argument('-target_vocab_size', default=0, type=int) parser.add_argument('-rule_num', default=0, type=int) parser.add_argument('-node_num', default=0, type=int) parser.add_argument('-word_embed_dim', default=128, type=int) parser.add_argument('-rule_embed_dim', default=256, type=int) parser.add_argument('-node_embed_dim', default=256, type=int) parser.add_argument('-encoder_hidden_dim', default=256, type=int) parser.add_argument('-decoder_hidden_dim', default=256, type=int) parser.add_argument('-attention_hidden_dim', default=50, type=int) parser.add_argument('-ptrnet_hidden_dim', default=50, type=int) parser.add_argument('-dropout', default=0.2, type=float) # encoder parser.add_argument('-encoder', default='bilstm', choices=['bilstm', 'lstm']) # decoder parser.add_argument('-parent_hidden_state_feed', dest='parent_hidden_state_feed', action='store_true') parser.add_argument('-no_parent_hidden_state_feed', dest='parent_hidden_state_feed', action='store_false') parser.set_defaults(parent_hidden_state_feed=True) parser.add_argument('-parent_action_feed', dest='parent_action_feed', action='store_true') parser.add_argument('-no_parent_action_feed', dest='parent_action_feed', action='store_false') parser.set_defaults(parent_action_feed=True) parser.add_argument('-frontier_node_type_feed', dest='frontier_node_type_feed', action='store_true') parser.add_argument('-no_frontier_node_type_feed', dest='frontier_node_type_feed', action='store_false') parser.set_defaults(frontier_node_type_feed=True) parser.add_argument('-tree_attention', dest='tree_attention', action='store_true') parser.add_argument('-no_tree_attention', dest='tree_attention', action='store_false') parser.set_defaults(tree_attention=False) parser.add_argument('-enable_copy', dest='enable_copy', action='store_true') parser.add_argument('-no_copy', dest='enable_copy', action='store_false') parser.set_defaults(enable_copy=True) # training parser.add_argument('-optimizer', default='adam') parser.add_argument('-clip_grad', default=0., type=float) parser.add_argument('-train_patience', default=10, type=int) parser.add_argument('-max_epoch', default=50, type=int) parser.add_argument('-batch_size', default=10, type=int) parser.add_argument('-valid_per_batch', default=4000, type=int) parser.add_argument('-save_per_batch', default=4000, type=int) parser.add_argument('-valid_metric', default='bleu') # decoding parser.add_argument('-beam_size', default=15, type=int) parser.add_argument('-max_query_length', default=70, type=int) parser.add_argument('-decode_max_time_step', default=100, type=int) parser.add_argument('-head_nt_constraint', dest='head_nt_constraint', action='store_true') parser.add_argument('-no_head_nt_constraint', dest='head_nt_constraint', action='store_false') parser.set_defaults(head_nt_constraint=True) sub_parsers = parser.add_subparsers(dest='operation', help='operation to take') train_parser = sub_parsers.add_parser('train') decode_parser = sub_parsers.add_parser('decode') interactive_parser = sub_parsers.add_parser('interactive') evaluate_parser = sub_parsers.add_parser('evaluate') # decoding operation decode_parser.add_argument('-saveto', default='decode_results.bin') decode_parser.add_argument('-type', default='test_data') # evaluation operation evaluate_parser.add_argument('-mode', default='self') evaluate_parser.add_argument('-input', default='decode_results.bin') evaluate_parser.add_argument('-type', default='test_data') evaluate_parser.add_argument('-seq2tree_sample_file', default='model.sample') evaluate_parser.add_argument('-seq2tree_id_file', default='test.id.txt') evaluate_parser.add_argument('-seq2tree_rareword_map', default=None) evaluate_parser.add_argument('-seq2seq_decode_file') evaluate_parser.add_argument('-seq2seq_ref_file') evaluate_parser.add_argument('-is_nbest', default=False, action='store_true') # misc parser.add_argument('-ifttt_test_split', default='data/ifff.test_data.gold.id') # interactive operation interactive_parser.add_argument('-mode', default='dataset') if __name__ == '__main__': args = parser.parse_args() if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) np.random.seed(args.random_seed) init_logging(os.path.join(args.output_dir, 'parser.log'), logging.INFO) logging.info('command line: %s', ' '.join(sys.argv)) logging.info('loading dataset [%s]', args.data) train_data, dev_data, test_data = deserialize_from_file(args.data) if not args.source_vocab_size: args.source_vocab_size = train_data.annot_vocab.size if not args.target_vocab_size: args.target_vocab_size = train_data.terminal_vocab.size if not args.rule_num: args.rule_num = len(train_data.grammar.rules) if not args.node_num: args.node_num = len(train_data.grammar.node_type_to_id) logging.info('current config: %s', args) config_module = sys.modules['config'] for name, value in vars(args).iteritems(): setattr(config_module, name, value) # get dataset statistics avg_action_num = np.average([len(e.actions) for e in train_data.examples]) logging.info('avg_action_num: %d', avg_action_num) logging.info('grammar rule num.: %d', len(train_data.grammar.rules)) logging.info('grammar node type num.: %d', len(train_data.grammar.node_type_to_id)) logging.info('source vocab size: %d', train_data.annot_vocab.size) logging.info('target vocab size: %d', train_data.terminal_vocab.size) if args.operation in ['train', 'decode', 'interactive']: model = Model() model.build() if args.model: model.load(args.model) if args.operation == 'train': # train_data = train_data.get_dataset_by_ids(range(2000), 'train_sample') # dev_data = dev_data.get_dataset_by_ids(range(10), 'dev_sample') learner = Learner(model, train_data, dev_data) learner.train() if args.operation == 'decode': # ========================== # investigate short examples # ========================== # short_examples = [e for e in test_data.examples if e.parse_tree.size <= 2] # for e in short_examples: # print e.parse_tree # print 'short examples num: ', len(short_examples) # dataset = test_data # test_data.get_dataset_by_ids([1,2,3,4,5,6,7,8,9,10], name='sample') # cProfile.run('decode_dataset(model, dataset)', sort=2) # from evaluation import decode_and_evaluate_ifttt if args.data_type == 'ifttt': decode_results = decode_and_evaluate_ifttt_by_split(model, test_data) else: dataset = eval(args.type) decode_results = decode_python_dataset(model, dataset) serialize_to_file(decode_results, args.saveto) if args.operation == 'evaluate': dataset = eval(args.type) if config.mode == 'self': decode_results_file = args.input decode_results = deserialize_from_file(decode_results_file) evaluate_decode_results(dataset, decode_results) elif config.mode == 'seq2tree': from evaluation import evaluate_seq2tree_sample_file evaluate_seq2tree_sample_file(config.seq2tree_sample_file, config.seq2tree_id_file, dataset) elif config.mode == 'seq2seq': from evaluation import evaluate_seq2seq_decode_results evaluate_seq2seq_decode_results(dataset, config.seq2seq_decode_file, config.seq2seq_ref_file, is_nbest=config.is_nbest) elif config.mode == 'analyze': from evaluation import analyze_decode_results decode_results_file = args.input decode_results = deserialize_from_file(decode_results_file) analyze_decode_results(dataset, decode_results) if args.operation == 'interactive': from dataset import canonicalize_query, query_to_data from collections import namedtuple from lang.py.parse import decode_tree_to_python_ast assert model is not None while True: cmd = raw_input('example id or query: ') if args.mode == 'dataset': try: example_id = int(cmd) example = [e for e in test_data.examples if e.raw_id == example_id][0] except: print 'something went wrong ...' continue elif args.mode == 'new': # we play with new examples! query, str_map = canonicalize_query(cmd) vocab = train_data.annot_vocab query_tokens = query.split(' ') query_tokens_data = [query_to_data(query, vocab)] example = namedtuple('example', ['query', 'data'])(query=query_tokens, data=query_tokens_data) if hasattr(example, 'parse_tree'): print 'gold parse tree:' print example.parse_tree cand_list = model.decode(example, train_data.grammar, train_data.terminal_vocab, beam_size=args.beam_size, max_time_step=args.decode_max_time_step, log=True) has_grammar_error = any([c for c in cand_list if c.has_grammar_error]) print 'has_grammar_error: ', has_grammar_error for cid, cand in enumerate(cand_list[:5]): print '*' * 60 print 'cand #%d, score: %f' % (cid, cand.score) try: ast_tree = decode_tree_to_python_ast(cand.tree) code = astor.to_source(ast_tree) print 'code: ', code print 'decode log: ', cand.log except: print "Exception in converting tree to code:" print '-' * 60 print 'raw_id: %d, beam pos: %d' % (example.raw_id, cid) traceback.print_exc(file=sys.stdout) print '-' * 60 finally: print '* parse tree *' print cand.tree.__repr__() print 'n_timestep: %d' % cand.n_timestep print 'ast size: %d' % cand.tree.size print '*' * 60 ================================================ FILE: components.py ================================================ import theano import theano.tensor as T import numpy as np import logging import copy from nn.layers.embeddings import Embedding from nn.layers.core import Dense, Layer from nn.layers.recurrent import BiLSTM, LSTM, CondAttLSTM from nn.utils.theano_utils import ndim_itensor, tensor_right_shift, ndim_tensor, alloc_zeros_matrix, shared_zeros import nn.initializations as initializations import nn.activations as activations import nn.optimizers as optimizers import config from lang.grammar import Grammar from parse import * from astnode import * class PointerNet(Layer): def __init__(self, name='PointerNet'): super(PointerNet, self).__init__() self.dense1_input = Dense(config.encoder_hidden_dim, config.ptrnet_hidden_dim, activation='linear', name='Dense1_input') self.dense1_h = Dense(config.decoder_hidden_dim + config.encoder_hidden_dim, config.ptrnet_hidden_dim, activation='linear', name='Dense1_h') self.dense2 = Dense(config.ptrnet_hidden_dim, 1, activation='linear', name='Dense2') self.params += self.dense1_input.params + self.dense1_h.params + self.dense2.params self.set_name(name) def __call__(self, query_embed, query_token_embed_mask, decoder_states): query_embed_trans = self.dense1_input(query_embed) h_trans = self.dense1_h(decoder_states) query_embed_trans = query_embed_trans.dimshuffle((0, 'x', 1, 2)) h_trans = h_trans.dimshuffle((0, 1, 'x', 2)) # (batch_size, max_decode_step, query_token_num, ptr_net_hidden_dim) dense1_trans = T.tanh(query_embed_trans + h_trans) scores = self.dense2(dense1_trans).flatten(3) scores = T.exp(scores - T.max(scores, axis=-1, keepdims=True)) scores *= query_token_embed_mask.dimshuffle((0, 'x', 1)) scores = scores / T.sum(scores, axis=-1, keepdims=True) return scores class Hyp: def __init__(self, *args): if isinstance(args[0], Hyp): hyp = args[0] self.grammar = hyp.grammar self.tree = hyp.tree.copy() self.t = hyp.t self.hist_h = list(hyp.hist_h) self.log = hyp.log self.has_grammar_error = hyp.has_grammar_error else: assert isinstance(args[0], Grammar) grammar = args[0] self.grammar = grammar self.tree = DecodeTree(grammar.root_node.type) self.t=-1 self.hist_h = [] self.log = '' self.has_grammar_error = False self.score = 0.0 self.__frontier_nt = self.tree self.__frontier_nt_t = -1 def __repr__(self): return self.tree.__repr__() def can_expand(self, node): if self.grammar.is_value_node(node): # if the node is finished if node.value is not None and node.value.endswith(''): return False return True elif self.grammar.is_terminal(node): return False # elif node.type == 'epsilon': # return False # elif is_terminal_ast_type(node.type): # return False # if node.type == 'root': # return True # elif inspect.isclass(node.type) and issubclass(node.type, ast.AST) and not is_terminal_ast_type(node.type): # return True # elif node.holds_value and not node.label.endswith(''): # return True return True def apply_rule(self, rule, nt=None): if nt is None: nt = self.frontier_nt() # assert rule.parent.type == nt.type if rule.parent.type != nt.type: self.has_grammar_error = True self.t += 1 # set the time step when the rule leading by this nt is applied nt.t = self.t # record the ApplyRule action that is used to expand the current node nt.applied_rule = rule for child_node in rule.children: child = DecodeTree(child_node.type, child_node.label, child_node.value) # if is_builtin_type(rule.parent.type): # child.label = None # child.holds_value = True nt.add_child(child) def append_token(self, token, nt=None): if nt is None: nt = self.frontier_nt() self.t += 1 if nt.value is None: # this terminal node is empty nt.t = self.t nt.value = token else: nt.value += token def frontier_nt_helper(self, node): if node.is_leaf: if self.can_expand(node): return node else: return None for child in node.children: result = self.frontier_nt_helper(child) if result: return result return None def frontier_nt(self): if self.__frontier_nt_t == self.t: return self.__frontier_nt else: _frontier_nt = self.frontier_nt_helper(self.tree) self.__frontier_nt = _frontier_nt self.__frontier_nt_t = self.t return _frontier_nt def get_action_parent_t(self): """ get the time step when the parent of the current action was generated WARNING: 0 will be returned if parent if None """ nt = self.frontier_nt() # if nt is a non-finishing leaf # if nt.holds_value: # return nt.t if nt.parent: return nt.parent.t else: return 0 # def get_action_parent_tree(self): # """ # get the parent tree # """ # nt = self.frontier_nt() # # # if nt is a non-finishing leaf # if nt.holds_value: # return nt # # if nt.parent: # return nt.parent # else: # return None class CondAttLSTM(Layer): """ Conditional LSTM with Attention """ def __init__(self, input_dim, output_dim, context_dim, att_hidden_dim, init='glorot_uniform', inner_init='orthogonal', forget_bias_init='one', activation='tanh', inner_activation='sigmoid', name='CondAttLSTM'): super(CondAttLSTM, self).__init__() self.output_dim = output_dim self.init = initializations.get(init) self.inner_init = initializations.get(inner_init) self.forget_bias_init = initializations.get(forget_bias_init) self.activation = activations.get(activation) self.inner_activation = activations.get(inner_activation) self.context_dim = context_dim self.input_dim = input_dim # regular LSTM layer self.W_i = self.init((input_dim, self.output_dim)) self.U_i = self.inner_init((self.output_dim, self.output_dim)) self.C_i = self.inner_init((self.context_dim, self.output_dim)) self.H_i = self.inner_init((self.output_dim, self.output_dim)) self.P_i = self.inner_init((self.output_dim, self.output_dim)) self.b_i = shared_zeros((self.output_dim)) self.W_f = self.init((input_dim, self.output_dim)) self.U_f = self.inner_init((self.output_dim, self.output_dim)) self.C_f = self.inner_init((self.context_dim, self.output_dim)) self.H_f = self.inner_init((self.output_dim, self.output_dim)) self.P_f = self.inner_init((self.output_dim, self.output_dim)) self.b_f = self.forget_bias_init((self.output_dim)) self.W_c = self.init((input_dim, self.output_dim)) self.U_c = self.inner_init((self.output_dim, self.output_dim)) self.C_c = self.inner_init((self.context_dim, self.output_dim)) self.H_c = self.inner_init((self.output_dim, self.output_dim)) self.P_c = self.inner_init((self.output_dim, self.output_dim)) self.b_c = shared_zeros((self.output_dim)) self.W_o = self.init((input_dim, self.output_dim)) self.U_o = self.inner_init((self.output_dim, self.output_dim)) self.C_o = self.inner_init((self.context_dim, self.output_dim)) self.H_o = self.inner_init((self.output_dim, self.output_dim)) self.P_o = self.inner_init((self.output_dim, self.output_dim)) self.b_o = shared_zeros((self.output_dim)) self.params = [ self.W_i, self.U_i, self.b_i, self.C_i, self.H_i, self.P_i, self.W_c, self.U_c, self.b_c, self.C_c, self.H_c, self.P_c, self.W_f, self.U_f, self.b_f, self.C_f, self.H_f, self.P_f, self.W_o, self.U_o, self.b_o, self.C_o, self.H_o, self.P_o, ] # attention layer self.att_ctx_W1 = self.init((context_dim, att_hidden_dim)) self.att_h_W1 = self.init((output_dim, att_hidden_dim)) self.att_b1 = shared_zeros((att_hidden_dim)) self.att_W2 = self.init((att_hidden_dim, 1)) self.att_b2 = shared_zeros((1)) self.params += [ self.att_ctx_W1, self.att_h_W1, self.att_b1, self.att_W2, self.att_b2 ] # attention over history self.hatt_h_W1 = self.init((output_dim, att_hidden_dim)) self.hatt_hist_W1 = self.init((output_dim, att_hidden_dim)) self.hatt_b1 = shared_zeros((att_hidden_dim)) self.hatt_W2 = self.init((att_hidden_dim, 1)) self.hatt_b2 = shared_zeros((1)) self.params += [ self.hatt_h_W1, self.hatt_hist_W1, self.hatt_b1, self.hatt_W2, self.hatt_b2 ] self.set_name(name) def _step(self, t, xi_t, xf_t, xo_t, xc_t, mask_t, parent_t, h_tm1, c_tm1, hist_h, u_i, u_f, u_o, u_c, c_i, c_f, c_o, c_c, h_i, h_f, h_o, h_c, p_i, p_f, p_o, p_c, att_h_w1, att_w2, att_b2, context, context_mask, context_att_trans, b_u): # context: (batch_size, context_size, context_dim) # (batch_size, att_layer1_dim) h_tm1_att_trans = T.dot(h_tm1, att_h_w1) # h_tm1_att_trans = theano.printing.Print('h_tm1_att_trans')(h_tm1_att_trans) # (batch_size, context_size, att_layer1_dim) att_hidden = T.tanh(context_att_trans + h_tm1_att_trans[:, None, :]) # (batch_size, context_size, 1) att_raw = T.dot(att_hidden, att_w2) + att_b2 att_raw = att_raw.reshape((att_raw.shape[0], att_raw.shape[1])) # (batch_size, context_size) ctx_att = T.exp(att_raw - T.max(att_raw, axis=-1, keepdims=True)) if context_mask: ctx_att = ctx_att * context_mask ctx_att = ctx_att / T.sum(ctx_att, axis=-1, keepdims=True) # (batch_size, context_dim) ctx_vec = T.sum(context * ctx_att[:, :, None], axis=1) # t = theano.printing.Print('t')(t) ##### attention over history ##### def _attention_over_history(): hist_h_mask = T.zeros((hist_h.shape[0], hist_h.shape[1]), dtype='int8') hist_h_mask = T.set_subtensor(hist_h_mask[:, T.arange(t)], 1) hist_h_att_trans = T.dot(hist_h, self.hatt_hist_W1) + self.hatt_b1 h_tm1_hatt_trans = T.dot(h_tm1, self.hatt_h_W1) hatt_hidden = T.tanh(hist_h_att_trans + h_tm1_hatt_trans[:, None, :]) hatt_raw = T.dot(hatt_hidden, self.hatt_W2) + self.hatt_b2 hatt_raw = hatt_raw.reshape((hist_h.shape[0], hist_h.shape[1])) # hatt_raw = theano.printing.Print('hatt_raw')(hatt_raw) hatt_exp = T.exp(hatt_raw - T.max(hatt_raw, axis=-1, keepdims=True)) * hist_h_mask # hatt_exp = theano.printing.Print('hatt_exp')(hatt_exp) # hatt_exp = hatt_exp.flatten(2) h_att_weights = hatt_exp / (T.sum(hatt_exp, axis=-1, keepdims=True) + 1e-7) # h_att_weights = theano.printing.Print('h_att_weights')(h_att_weights) # (batch_size, output_dim) _h_ctx_vec = T.sum(hist_h * h_att_weights[:, :, None], axis=1) return _h_ctx_vec h_ctx_vec = T.switch(t, _attention_over_history(), T.zeros_like(h_tm1)) # h_ctx_vec = theano.printing.Print('h_ctx_vec')(h_ctx_vec) ##### attention over history ##### ##### feed in parent hidden state ##### if not config.parent_hidden_state_feed: t = 0 par_h = T.switch(t, hist_h[T.arange(hist_h.shape[0]), parent_t, :], T.zeros_like(h_tm1)) ##### feed in parent hidden state ##### if config.tree_attention: i_t = self.inner_activation( xi_t + T.dot(h_tm1 * b_u[0], u_i) + T.dot(ctx_vec, c_i) + T.dot(par_h, p_i) + T.dot(h_ctx_vec, h_i)) f_t = self.inner_activation( xf_t + T.dot(h_tm1 * b_u[1], u_f) + T.dot(ctx_vec, c_f) + T.dot(par_h, p_f) + T.dot(h_ctx_vec, h_f)) c_t = f_t * c_tm1 + i_t * self.activation( xc_t + T.dot(h_tm1 * b_u[2], u_c) + T.dot(ctx_vec, c_c) + T.dot(par_h, p_c) + T.dot(h_ctx_vec, h_c)) o_t = self.inner_activation( xo_t + T.dot(h_tm1 * b_u[3], u_o) + T.dot(ctx_vec, c_o) + T.dot(par_h, p_o) + T.dot(h_ctx_vec, h_o)) else: i_t = self.inner_activation( xi_t + T.dot(h_tm1 * b_u[0], u_i) + T.dot(ctx_vec, c_i) + T.dot(par_h, p_i)) # + T.dot(h_ctx_vec, h_i) f_t = self.inner_activation( xf_t + T.dot(h_tm1 * b_u[1], u_f) + T.dot(ctx_vec, c_f) + T.dot(par_h, p_f)) # + T.dot(h_ctx_vec, h_f) c_t = f_t * c_tm1 + i_t * self.activation( xc_t + T.dot(h_tm1 * b_u[2], u_c) + T.dot(ctx_vec, c_c) + T.dot(par_h, p_c)) # + T.dot(h_ctx_vec, h_c) o_t = self.inner_activation( xo_t + T.dot(h_tm1 * b_u[3], u_o) + T.dot(ctx_vec, c_o) + T.dot(par_h, p_o)) # + T.dot(h_ctx_vec, h_o) h_t = o_t * self.activation(c_t) h_t = (1 - mask_t) * h_tm1 + mask_t * h_t c_t = (1 - mask_t) * c_tm1 + mask_t * c_t new_hist_h = T.set_subtensor(hist_h[:, t, :], h_t) return h_t, c_t, ctx_vec, new_hist_h def _for_step(self, xi_t, xf_t, xo_t, xc_t, mask_t, h_tm1, c_tm1, context, context_mask, context_att_trans, hist_h, hist_h_att_trans, b_u): # context: (batch_size, context_size, context_dim) # (batch_size, att_layer1_dim) h_tm1_att_trans = T.dot(h_tm1, self.att_h_W1) # (batch_size, context_size, att_layer1_dim) att_hidden = T.tanh(context_att_trans + h_tm1_att_trans[:, None, :]) # (batch_size, context_size, 1) att_raw = T.dot(att_hidden, self.att_W2) + self.att_b2 # (batch_size, context_size) ctx_att = T.exp(att_raw).reshape((att_raw.shape[0], att_raw.shape[1])) if context_mask: ctx_att = ctx_att * context_mask ctx_att = ctx_att / T.sum(ctx_att, axis=-1, keepdims=True) # (batch_size, context_dim) ctx_vec = T.sum(context * ctx_att[:, :, None], axis=1) ##### attention over history ##### if hist_h: hist_h = T.stack(hist_h).dimshuffle((1, 0, 2)) hist_h_att_trans = T.stack(hist_h_att_trans).dimshuffle((1, 0, 2)) h_tm1_hatt_trans = T.dot(h_tm1, self.hatt_h_W1) hatt_hidden = T.tanh(hist_h_att_trans + h_tm1_hatt_trans[:, None, :]) hatt_raw = T.dot(hatt_hidden, self.hatt_W2) + self.hatt_b2 hatt_raw = hatt_raw.flatten(2) h_att_weights = T.nnet.softmax(hatt_raw) # (batch_size, output_dim) h_ctx_vec = T.sum(hist_h * h_att_weights[:, :, None], axis=1) else: h_ctx_vec = T.zeros_like(h_tm1) ##### attention over history ##### i_t = self.inner_activation(xi_t + T.dot(h_tm1 * b_u[0], self.U_i) + T.dot(ctx_vec, self.C_i) + T.dot(h_ctx_vec, self.H_i)) f_t = self.inner_activation(xf_t + T.dot(h_tm1 * b_u[1], self.U_f) + T.dot(ctx_vec, self.C_f) + T.dot(h_ctx_vec, self.H_f)) c_t = f_t * c_tm1 + i_t * self.activation(xc_t + T.dot(h_tm1 * b_u[2], self.U_c) + T.dot(ctx_vec, self.C_c) + T.dot(h_ctx_vec, self.H_c)) o_t = self.inner_activation(xo_t + T.dot(h_tm1 * b_u[3], self.U_o) + T.dot(ctx_vec, self.C_o) + T.dot(h_ctx_vec, self.H_o)) h_t = o_t * self.activation(c_t) h_t = (1 - mask_t) * h_tm1 + mask_t * h_t c_t = (1 - mask_t) * c_tm1 + mask_t * c_t # ctx_vec = theano.printing.Print('ctx_vec')(ctx_vec) return h_t, c_t, ctx_vec def __call__(self, X, context, parent_t_seq, init_state=None, init_cell=None, hist_h=None, mask=None, context_mask=None, dropout=0, train=True, srng=None, time_steps=None): assert context_mask.dtype == 'int8', 'context_mask is not int8, got %s' % context_mask.dtype # (n_timestep, batch_size) mask = self.get_mask(mask, X) # (n_timestep, batch_size, input_dim) X = X.dimshuffle((1, 0, 2)) retain_prob = 1. - dropout B_w = np.ones((4,), dtype=theano.config.floatX) B_u = np.ones((4,), dtype=theano.config.floatX) if dropout > 0: logging.info('applying dropout with p = %f', dropout) if train: B_w = srng.binomial((4, X.shape[1], self.input_dim), p=retain_prob, dtype=theano.config.floatX) B_u = srng.binomial((4, X.shape[1], self.output_dim), p=retain_prob, dtype=theano.config.floatX) else: B_w *= retain_prob B_u *= retain_prob # (n_timestep, batch_size, output_dim) xi = T.dot(X * B_w[0], self.W_i) + self.b_i xf = T.dot(X * B_w[1], self.W_f) + self.b_f xc = T.dot(X * B_w[2], self.W_c) + self.b_c xo = T.dot(X * B_w[3], self.W_o) + self.b_o # (batch_size, context_size, att_layer1_dim) context_att_trans = T.dot(context, self.att_ctx_W1) + self.att_b1 if init_state: # (batch_size, output_dim) first_state = T.unbroadcast(init_state, 1) else: first_state = T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1) if init_cell: # (batch_size, output_dim) first_cell = T.unbroadcast(init_cell, 1) else: first_cell = T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1) if not hist_h: # (batch_size, n_timestep, output_dim) hist_h = alloc_zeros_matrix(X.shape[1], X.shape[0], self.output_dim) if train: n_timestep = X.shape[0] time_steps = T.arange(n_timestep, dtype='int32') # (n_timestep, batch_size) parent_t_seq = parent_t_seq.dimshuffle((1, 0)) [outputs, cells, ctx_vectors, hist_h_outputs], updates = theano.scan( self._step, sequences=[time_steps, xi, xf, xo, xc, mask, parent_t_seq], outputs_info=[ first_state, # for h first_cell, # for cell None, # T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.context_dim), 1), # for ctx vector hist_h, # for hist_h ], non_sequences=[ self.U_i, self.U_f, self.U_o, self.U_c, self.C_i, self.C_f, self.C_o, self.C_c, self.H_i, self.H_f, self.H_o, self.H_c, self.P_i, self.P_f, self.P_o, self.P_c, self.att_h_W1, self.att_W2, self.att_b2, context, context_mask, context_att_trans, B_u ]) outputs = outputs.dimshuffle((1, 0, 2)) ctx_vectors = ctx_vectors.dimshuffle((1, 0, 2)) cells = cells.dimshuffle((1, 0, 2)) return outputs, cells, ctx_vectors def get_mask(self, mask, X): if mask is None: mask = T.ones((X.shape[0], X.shape[1])) mask = T.shape_padright(mask) # (nb_samples, time, 1) mask = T.addbroadcast(mask, -1) # (time, nb_samples, 1) matrix. mask = mask.dimshuffle(1, 0, 2) # (time, nb_samples, 1) mask = mask.astype('int8') return mask ================================================ FILE: config.py ================================================ # MODE = 'django' # # SOURCE_VOCAB_SIZE = 2490 # 2492 # 5980 # TARGET_VOCAB_SIZE = 2101 # 2110 # 4830 # # RULE_NUM = 222 # 228 # NODE_NUM = 96 # # NODE_EMBED_DIM = 256 # EMBED_DIM = 128 # RULE_EMBED_DIM = 256 # QUERY_DIM = 256 # LSTM_STATE_DIM = 256 # DECODER_ATT_HIDDEN_DIM = 50 # POINTER_NET_HIDDEN_DIM = 50 # # MAX_QUERY_LENGTH = 70 # MAX_EXAMPLE_ACTION_NUM = 100 # # DECODER_DROPOUT = 0.2 # WORD_DROPOUT = 0 # # # encoder # ENCODER_LSTM = 'bilstm' # # # decoder # PARENT_HIDDEN_STATE_FEEDING = True # PARENT_RULE_FEEDING = True # NODE_TYPE_FEEDING = True # TREE_ATTENTION = True # # # training # TRAIN_PATIENCE = 10 # MAX_EPOCH = 50 # BATCH_SIZE = 10 # VALID_PER_MINIBATCH = 4000 # SAVE_PER_MINIBATCH = 4000 # # # decoding # BEAM_SIZE = 15 # DECODE_MAX_TIME_STEP = 100 config_info = None ================================================ FILE: dataset.py ================================================ from __future__ import division import copy import nltk from collections import OrderedDict, defaultdict import logging import collections import numpy as np import string import re import astor from itertools import chain from nn.utils.io_utils import serialize_to_file, deserialize_from_file import config from lang.py.parse import get_grammar from lang.py.unaryclosure import get_top_unary_closures, apply_unary_closures # define actions APPLY_RULE = 0 GEN_TOKEN = 1 COPY_TOKEN = 2 GEN_COPY_TOKEN = 3 ACTION_NAMES = {APPLY_RULE: 'APPLY_RULE', GEN_TOKEN: 'GEN_TOKEN', COPY_TOKEN: 'COPY_TOKEN', GEN_COPY_TOKEN: 'GEN_COPY_TOKEN'} class Action(object): def __init__(self, act_type, data): self.act_type = act_type self.data = data def __repr__(self): data_str = self.data if not isinstance(self.data, dict) else \ ', '.join(['%s: %s' % (k, v) for k, v in self.data.iteritems()]) repr_str = 'Action{%s}[%s]' % (ACTION_NAMES[self.act_type], data_str) return repr_str class Vocab(object): def __init__(self): self.token_id_map = OrderedDict() self.insert_token('') self.insert_token('') self.insert_token('') @property def unk(self): return self.token_id_map[''] @property def eos(self): return self.token_id_map[''] def __getitem__(self, item): if item in self.token_id_map: return self.token_id_map[item] logging.debug('encounter one unknown word [%s]' % item) return self.token_id_map[''] def __contains__(self, item): return item in self.token_id_map @property def size(self): return len(self.token_id_map) def __setitem__(self, key, value): self.token_id_map[key] = value def __len__(self): return len(self.token_id_map) def __iter__(self): return self.token_id_map.iterkeys() def iteritems(self): return self.token_id_map.iteritems() def complete(self): self.id_token_map = dict((v, k) for (k, v) in self.token_id_map.iteritems()) def get_token(self, token_id): return self.id_token_map[token_id] def insert_token(self, token): if token in self.token_id_map: return self[token] else: idx = len(self) self[token] = idx return idx replace_punctuation = string.maketrans(string.punctuation, ' '*len(string.punctuation)) def tokenize(str): str = str.translate(replace_punctuation) return nltk.word_tokenize(str) def gen_vocab(tokens, vocab_size=3000, freq_cutoff=5): word_freq = defaultdict(int) for token in tokens: word_freq[token] += 1 print 'total num. of tokens: %d' % len(word_freq) words_freq_cutoff = [w for w in word_freq if word_freq[w] >= freq_cutoff] print 'num. of words appear at least %d: %d' % (freq_cutoff, len(words_freq_cutoff)) ranked_words = sorted(words_freq_cutoff, key=word_freq.get, reverse=True)[:vocab_size-2] ranked_words = set(ranked_words) vocab = Vocab() for token in tokens: if token in ranked_words: vocab.insert_token(token) vocab.complete() return vocab class DataEntry: def __init__(self, raw_id, query, parse_tree, code, actions, meta_data=None): self.raw_id = raw_id self.eid = -1 # FIXME: rename to query_token self.query = query self.parse_tree = parse_tree self.actions = actions self.code = code self.meta_data = meta_data @property def data(self): if not hasattr(self, '_data'): assert self.dataset is not None, 'No associated dataset for the example' self._data = self.dataset.get_prob_func_inputs([self.eid]) return self._data def copy(self): e = DataEntry(self.raw_id, self.query, self.parse_tree, self.code, self.actions, self.meta_data) return e class DataSet: def __init__(self, annot_vocab, terminal_vocab, grammar, name='train_data'): self.annot_vocab = annot_vocab self.terminal_vocab = terminal_vocab self.name = name self.examples = [] self.data_matrix = dict() self.grammar = grammar def add(self, example): example.eid = len(self.examples) example.dataset = self self.examples.append(example) def get_dataset_by_ids(self, ids, name): dataset = DataSet(self.annot_vocab, self.terminal_vocab, self.grammar, name) for eid in ids: example_copy = self.examples[eid].copy() dataset.add(example_copy) for k, v in self.data_matrix.iteritems(): dataset.data_matrix[k] = v[ids] return dataset @property def count(self): if self.examples: return len(self.examples) return 0 def get_examples(self, ids): if isinstance(ids, collections.Iterable): return [self.examples[i] for i in ids] else: return self.examples[ids] def get_prob_func_inputs(self, ids): order = ['query_tokens', 'tgt_action_seq', 'tgt_action_seq_type', 'tgt_node_seq', 'tgt_par_rule_seq', 'tgt_par_t_seq'] max_src_seq_len = max(len(self.examples[i].query) for i in ids) max_tgt_seq_len = max(len(self.examples[i].actions) for i in ids) logging.debug('max. src sequence length: %d', max_src_seq_len) logging.debug('max. tgt sequence length: %d', max_tgt_seq_len) data = [] for entry in order: if entry == 'query_tokens': data.append(self.data_matrix[entry][ids, :max_src_seq_len]) else: data.append(self.data_matrix[entry][ids, :max_tgt_seq_len]) return data def init_data_matrices(self, max_query_length=70, max_example_action_num=100): logging.info('init data matrices for [%s] dataset', self.name) annot_vocab = self.annot_vocab terminal_vocab = self.terminal_vocab # np.max([len(e.query) for e in self.examples]) # np.max([len(e.rules) for e in self.examples]) query_tokens = self.data_matrix['query_tokens'] = np.zeros((self.count, max_query_length), dtype='int32') tgt_node_seq = self.data_matrix['tgt_node_seq'] = np.zeros((self.count, max_example_action_num), dtype='int32') tgt_par_rule_seq = self.data_matrix['tgt_par_rule_seq'] = np.zeros((self.count, max_example_action_num), dtype='int32') tgt_par_t_seq = self.data_matrix['tgt_par_t_seq'] = np.zeros((self.count, max_example_action_num), dtype='int32') tgt_action_seq = self.data_matrix['tgt_action_seq'] = np.zeros((self.count, max_example_action_num, 3), dtype='int32') tgt_action_seq_type = self.data_matrix['tgt_action_seq_type'] = np.zeros((self.count, max_example_action_num, 3), dtype='int32') for eid, example in enumerate(self.examples): exg_query_tokens = example.query[:max_query_length] exg_action_seq = example.actions[:max_example_action_num] for tid, token in enumerate(exg_query_tokens): token_id = annot_vocab[token] query_tokens[eid, tid] = token_id assert len(exg_action_seq) > 0 for t, action in enumerate(exg_action_seq): if action.act_type == APPLY_RULE: rule = action.data['rule'] tgt_action_seq[eid, t, 0] = self.grammar.rule_to_id[rule] tgt_action_seq_type[eid, t, 0] = 1 elif action.act_type == GEN_TOKEN: token = action.data['literal'] token_id = terminal_vocab[token] tgt_action_seq[eid, t, 1] = token_id tgt_action_seq_type[eid, t, 1] = 1 elif action.act_type == COPY_TOKEN: src_token_idx = action.data['source_idx'] tgt_action_seq[eid, t, 2] = src_token_idx tgt_action_seq_type[eid, t, 2] = 1 elif action.act_type == GEN_COPY_TOKEN: token = action.data['literal'] token_id = terminal_vocab[token] tgt_action_seq[eid, t, 1] = token_id tgt_action_seq_type[eid, t, 1] = 1 src_token_idx = action.data['source_idx'] tgt_action_seq[eid, t, 2] = src_token_idx tgt_action_seq_type[eid, t, 2] = 1 else: raise RuntimeError('wrong action type!') # parent information rule = action.data['rule'] parent_rule = action.data['parent_rule'] tgt_node_seq[eid, t] = self.grammar.get_node_type_id(rule.parent) if parent_rule: tgt_par_rule_seq[eid, t] = self.grammar.rule_to_id[parent_rule] else: assert t == 0 tgt_par_rule_seq[eid, t] = -1 # parent hidden states parent_t = action.data['parent_t'] tgt_par_t_seq[eid, t] = parent_t example.dataset = self class DataHelper(object): @staticmethod def canonicalize_query(query): return query def parse_django_dataset_nt_only(): from parse import parse_django annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno' vocab = gen_vocab(annot_file, vocab_size=4500) code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code' grammar, all_parse_trees = parse_django(code_file) train_data = DataSet(vocab, grammar, name='train') dev_data = DataSet(vocab, grammar, name='dev') test_data = DataSet(vocab, grammar, name='test') # train_data train_annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/train.anno' train_parse_trees = all_parse_trees[0:16000] for line, parse_tree in zip(open(train_annot_file), train_parse_trees): if parse_tree.is_leaf: continue line = line.strip() tokens = tokenize(line) entry = DataEntry(tokens, parse_tree) train_data.add(entry) train_data.init_data_matrices() # dev_data dev_annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/dev.anno' dev_parse_trees = all_parse_trees[16000:17000] for line, parse_tree in zip(open(dev_annot_file), dev_parse_trees): if parse_tree.is_leaf: continue line = line.strip() tokens = tokenize(line) entry = DataEntry(tokens, parse_tree) dev_data.add(entry) dev_data.init_data_matrices() # test_data test_annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/test.anno' test_parse_trees = all_parse_trees[17000:18805] for line, parse_tree in zip(open(test_annot_file), test_parse_trees): if parse_tree.is_leaf: continue line = line.strip() tokens = tokenize(line) entry = DataEntry(tokens, parse_tree) test_data.add(entry) test_data.init_data_matrices() serialize_to_file((train_data, dev_data, test_data), 'django.typed_rule.bin') def parse_django_dataset(): from lang.py.parse import parse_raw from lang.util import escape MAX_QUERY_LENGTH = 70 UNARY_CUTOFF_FREQ = 30 annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno' code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code' data = preprocess_dataset(annot_file, code_file) for e in data: e['parse_tree'] = parse_raw(e['code']) parse_trees = [e['parse_tree'] for e in data] # apply unary closures # unary_closures = get_top_unary_closures(parse_trees, k=0, freq=UNARY_CUTOFF_FREQ) # for i, parse_tree in enumerate(parse_trees): # apply_unary_closures(parse_tree, unary_closures) # build the grammar grammar = get_grammar(parse_trees) # write grammar with open('django.grammar.unary_closure.txt', 'w') as f: for rule in grammar: f.write(rule.__repr__() + '\n') # # build grammar ... # from lang.py.py_dataset import extract_grammar # grammar, all_parse_trees = extract_grammar(code_file) annot_tokens = list(chain(*[e['query_tokens'] for e in data])) annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=3) # gen_vocab(annot_tokens, vocab_size=5980) terminal_token_seq = [] empty_actions_count = 0 # helper function begins def get_terminal_tokens(_terminal_str): # _terminal_tokens = filter(None, re.split('([, .?!])', _terminal_str)) # _terminal_str.split('-SP-') # _terminal_tokens = filter(None, re.split('( )', _terminal_str)) # _terminal_str.split('-SP-') tmp_terminal_tokens = _terminal_str.split(' ') _terminal_tokens = [] for token in tmp_terminal_tokens: if token: _terminal_tokens.append(token) _terminal_tokens.append(' ') return _terminal_tokens[:-1] # return _terminal_tokens # helper function ends # first pass for entry in data: idx = entry['id'] query_tokens = entry['query_tokens'] code = entry['code'] parse_tree = entry['parse_tree'] for node in parse_tree.get_leaves(): if grammar.is_value_node(node): terminal_val = node.value terminal_str = str(terminal_val) terminal_tokens = get_terminal_tokens(terminal_str) for terminal_token in terminal_tokens: assert len(terminal_token) > 0 terminal_token_seq.append(terminal_token) terminal_vocab = gen_vocab(terminal_token_seq, vocab_size=5000, freq_cutoff=3) assert '_STR:0_' in terminal_vocab train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'train_data') dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'dev_data') test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'test_data') all_examples = [] can_fully_gen_num = 0 # second pass for entry in data: idx = entry['id'] query_tokens = entry['query_tokens'] code = entry['code'] str_map = entry['str_map'] parse_tree = entry['parse_tree'] rule_list, rule_parents = parse_tree.get_productions(include_value_node=True) actions = [] can_fully_gen = True rule_pos_map = dict() for rule_count, rule in enumerate(rule_list): if not grammar.is_value_node(rule.parent): assert rule.value is None parent_rule = rule_parents[(rule_count, rule)][0] if parent_rule: parent_t = rule_pos_map[parent_rule] else: parent_t = 0 rule_pos_map[rule] = len(actions) d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule} action = Action(APPLY_RULE, d) actions.append(action) else: assert rule.is_leaf parent_rule = rule_parents[(rule_count, rule)][0] parent_t = rule_pos_map[parent_rule] terminal_val = rule.value terminal_str = str(terminal_val) terminal_tokens = get_terminal_tokens(terminal_str) # assert len(terminal_tokens) > 0 for terminal_token in terminal_tokens: term_tok_id = terminal_vocab[terminal_token] tok_src_idx = -1 try: tok_src_idx = query_tokens.index(terminal_token) except ValueError: pass d = {'literal': terminal_token, 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t} # cannot copy, only generation # could be unk! if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH: action = Action(GEN_TOKEN, d) if terminal_token not in terminal_vocab: if terminal_token not in query_tokens: # print terminal_token can_fully_gen = False else: # copy if term_tok_id != terminal_vocab.unk: d['source_idx'] = tok_src_idx action = Action(GEN_COPY_TOKEN, d) else: d['source_idx'] = tok_src_idx action = Action(COPY_TOKEN, d) actions.append(action) d = {'literal': '', 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t} actions.append(Action(GEN_TOKEN, d)) if len(actions) == 0: empty_actions_count += 1 continue example = DataEntry(idx, query_tokens, parse_tree, code, actions, {'raw_code': entry['raw_code'], 'str_map': entry['str_map']}) if can_fully_gen: can_fully_gen_num += 1 # train, valid, test if 0 <= idx < 16000: train_data.add(example) elif 16000 <= idx < 17000: dev_data.add(example) else: test_data.add(example) all_examples.append(example) # print statistics max_query_len = max(len(e.query) for e in all_examples) max_actions_len = max(len(e.actions) for e in all_examples) serialize_to_file([len(e.query) for e in all_examples], 'query.len') serialize_to_file([len(e.actions) for e in all_examples], 'actions.len') logging.info('examples that can be fully reconstructed: %d/%d=%f', can_fully_gen_num, len(all_examples), can_fully_gen_num / len(all_examples)) logging.info('empty_actions_count: %d', empty_actions_count) logging.info('max_query_len: %d', max_query_len) logging.info('max_actions_len: %d', max_actions_len) train_data.init_data_matrices() dev_data.init_data_matrices() test_data.init_data_matrices() serialize_to_file((train_data, dev_data, test_data), 'data/django.cleaned.dataset.freq3.par_info.refact.space_only.order_by_ulink_len.bin') # 'data/django.cleaned.dataset.freq5.par_info.refact.space_only.unary_closure.freq{UNARY_CUTOFF_FREQ}.order_by_ulink_len.bin'.format(UNARY_CUTOFF_FREQ=UNARY_CUTOFF_FREQ)) return train_data, dev_data, test_data def check_terminals(): from parse import parse_django, unescape grammar, parse_trees = parse_django('/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code') annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno' unique_terminals = set() invalid_terminals = set() for i, line in enumerate(open(annot_file)): parse_tree = parse_trees[i] utterance = line.strip() leaves = parse_tree.get_leaves() # tokens = set(nltk.word_tokenize(utterance)) leave_tokens = [l.label for l in leaves if l.label] not_included = [] for leaf_token in leave_tokens: leaf_token = str(leaf_token) leaf_token = unescape(leaf_token) if leaf_token not in utterance: not_included.append(leaf_token) if len(leaf_token) <= 15: unique_terminals.add(leaf_token) else: invalid_terminals.add(leaf_token) else: if isinstance(leaf_token, str): print leaf_token # if not_included: # print str(i) + '---' + ', '.join(not_included) # print 'num of unique leaves: %d' % len(unique_terminals) # print unique_terminals # # print 'num of invalid leaves: %d' % len(invalid_terminals) # print invalid_terminals def query_to_data(query, annot_vocab): query_tokens = query.split(' ') token_num = min(config.max_qeury_length, len(query_tokens)) data = np.zeros((1, token_num), dtype='int32') for tid, token in enumerate(query_tokens[:token_num]): token_id = annot_vocab[token] data[0, tid] = token_id return data QUOTED_STRING_RE = re.compile(r"(?P['\"])(?P.*?)(? 50: # return # query = QUOTED_STRING_RE.sub(str_repr, query, 1) str_map[str_literal] = str_repr str_count += 1 match = QUOTED_STRING_RE.search(query) code = code.replace(str_literal, '\'' + str_repr + '\'') # clean the annotation # query = query.replace('.', ' . ') for k, v in str_map.iteritems(): if k == '\'%s\'' or k == '\"%s\"': query = query.replace(v, k) code = code.replace('\'' + v + '\'', k) # tokenize query_tokens = nltk.word_tokenize(query) new_query_tokens = [] # break up function calls for token in query_tokens: new_query_tokens.append(token) i = token.find('.') if 0 < i < len(token) - 1: new_tokens = ['['] + token.replace('.', ' . ').split(' ') + [']'] new_query_tokens.extend(new_tokens) # check if the code compiles tree = parse(code) ast_tree = tree_to_ast(tree) astor.to_source(ast_tree) return new_query_tokens, code, str_map def preprocess_dataset(annot_file, code_file): f_annot = open('annot.all.canonicalized.txt', 'w') f_code = open('code.all.canonicalized.txt', 'w') examples = [] err_num = 0 for idx, (annot, code) in enumerate(zip(open(annot_file), open(code_file))): annot = annot.strip() code = code.strip() try: clean_query_tokens, clean_code, str_map = canonicalize_example(annot, code) example = {'id': idx, 'query_tokens': clean_query_tokens, 'code': clean_code, 'str_map': str_map, 'raw_code': code} examples.append(example) f_annot.write('example# %d\n' % idx) f_annot.write(' '.join(clean_query_tokens) + '\n') f_annot.write('%d\n' % len(str_map)) for k, v in str_map.iteritems(): f_annot.write('%s ||| %s\n' % (k, v)) f_code.write('example# %d\n' % idx) f_code.write(clean_code + '\n') except: print code err_num += 1 idx += 1 f_annot.close() f_annot.close() # serialize_to_file(examples, 'django.cleaned.bin') print 'error num: %d' % err_num print 'preprocess_dataset: cleaned example num: %d' % len(examples) return examples if __name__== '__main__': from nn.utils.generic_utils import init_logging init_logging('parse.log') # annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno' # code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code' # preprocess_dataset(annot_file, code_file) # parse_django_dataset() # check_terminals() # print process_query(""" ALLOWED_VARIABLE_CHARS is a string 'abcdefgh"ijklm" nop"%s"qrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_.'.""") # for i, query in enumerate(open('/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno')): # print i, process_query(query) # clean_dataset() parse_django_dataset() # from lang.py.py_dataset import parse_hs_dataset # parse_hs_dataset() ================================================ FILE: decoder.py ================================================ import traceback import config from model import * def decode_python_dataset(model, dataset, verbose=True): from lang.py.parse import decode_tree_to_python_ast if verbose: logging.info('decoding [%s] set, num. examples: %d', dataset.name, dataset.count) decode_results = [] cum_num = 0 for example in dataset.examples: cand_list = model.decode(example, dataset.grammar, dataset.terminal_vocab, beam_size=config.beam_size, max_time_step=config.decode_max_time_step) exg_decode_results = [] for cid, cand in enumerate(cand_list[:10]): try: ast_tree = decode_tree_to_python_ast(cand.tree) code = astor.to_source(ast_tree) exg_decode_results.append((cid, cand, ast_tree, code)) except: if verbose: print "Exception in converting tree to code:" print '-' * 60 print 'raw_id: %d, beam pos: %d' % (example.raw_id, cid) traceback.print_exc(file=sys.stdout) print '-' * 60 cum_num += 1 if cum_num % 50 == 0 and verbose: print '%d examples so far ...' % cum_num decode_results.append(exg_decode_results) return decode_results # serialize_to_file(decode_results, '%s.decode_results.profile' % dataset.name) def decode_ifttt_dataset(model, dataset, verbose=True): if verbose: logging.info('decoding [%s] set, num. examples: %d', dataset.name, dataset.count) decode_results = [] cum_num = 0 for example in dataset.examples: cand_list = model.decode(example, dataset.grammar, dataset.terminal_vocab, beam_size=config.beam_size, max_time_step=config.decode_max_time_step) exg_decode_results = [] for cid, cand in enumerate(cand_list[:10]): try: exg_decode_results.append((cid, cand)) except: if verbose: print "Exception in converting tree to code:" print '-' * 60 print 'raw_id: %d, beam pos: %d' % (example.raw_id, cid) traceback.print_exc(file=sys.stdout) print '-' * 60 cum_num += 1 if cum_num % 50 == 0 and verbose: print '%d examples so far ...' % cum_num decode_results.append(exg_decode_results) return decode_results ================================================ FILE: evaluation.py ================================================ # -*- coding: UTF-8 -*- from __future__ import division import os from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction import logging import traceback from nn.utils.generic_utils import init_logging from model import * DJANGO_ANNOT_FILE = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno' def tokenize_for_bleu_eval(code): code = re.sub(r'([^A-Za-z0-9_])', r' \1 ', code) code = re.sub(r'([a-z])([A-Z])', r'\1 \2', code) code = re.sub(r'\s+', ' ', code) code = code.replace('"', '`') code = code.replace('\'', '`') tokens = [t for t in code.split(' ') if t] return tokens def evaluate(model, dataset, verbose=True): if verbose: logging.info('evaluating [%s] dataset, [%d] examples' % (dataset.name, dataset.count)) exact_match_ratio = 0.0 for example in dataset.examples: logging.info('evaluating example [%d]' % example.eid) hyps, hyp_scores = model.decode(example, max_time_step=config.decode_max_time_step) gold_rules = example.rules if len(hyps) == 0: logging.warning('no decoding result for example [%d]!' % example.eid) continue best_hyp = hyps[0] predict_rules = [dataset.grammar.id_to_rule[rid] for rid in best_hyp] assert len(predict_rules) > 0 and len(gold_rules) > 0 exact_match = sorted(gold_rules, key=lambda x: x.__repr__()) == sorted(predict_rules, key=lambda x: x.__repr__()) if exact_match: exact_match_ratio += 1 # p = len(predict_rules.intersection(gold_rules)) / len(predict_rules) # r = len(predict_rules.intersection(gold_rules)) / len(gold_rules) exact_match_ratio /= dataset.count logging.info('exact_match_ratio = %f' % exact_match_ratio) return exact_match_ratio def evaluate_decode_results(dataset, decode_results, verbose=True): from lang.py.parse import tokenize_code, de_canonicalize_code # tokenize_code = tokenize_for_bleu_eval import ast assert dataset.count == len(decode_results) f = f_decode = None if verbose: f = open(dataset.name + '.exact_match', 'w') exact_match_ids = [] f_decode = open(dataset.name + '.decode_results.txt', 'w') eid_to_annot = dict() if config.data_type == 'django': for raw_id, line in enumerate(open(DJANGO_ANNOT_FILE)): eid_to_annot[raw_id] = line.strip() f_bleu_eval_ref = open(dataset.name + '.ref', 'w') f_bleu_eval_hyp = open(dataset.name + '.hyp', 'w') f_generated_code = open(dataset.name + '.geneated_code', 'w') logging.info('evaluating [%s] set, [%d] examples', dataset.name, dataset.count) cum_oracle_bleu = 0.0 cum_oracle_acc = 0.0 cum_bleu = 0.0 cum_acc = 0.0 sm = SmoothingFunction() all_references = [] all_predictions = [] if all(len(cand) == 0 for cand in decode_results): logging.ERROR('Empty decoding results for the current dataset!') return -1, -1 for eid in range(dataset.count): example = dataset.examples[eid] ref_code = example.code ref_ast_tree = ast.parse(ref_code).body[0] refer_source = astor.to_source(ref_ast_tree).strip() # refer_source = ref_code refer_tokens = tokenize_code(refer_source) cur_example_correct = False decode_cands = decode_results[eid] if len(decode_cands) == 0: continue decode_cand = decode_cands[0] cid, cand, ast_tree, code = decode_cand code = astor.to_source(ast_tree).strip() # simple_url_2_re = re.compile('_STR:0_', re.)) try: predict_tokens = tokenize_code(code) except: logging.error('error in tokenizing [%s]', code) continue if refer_tokens == predict_tokens: cum_acc += 1 cur_example_correct = True if verbose: exact_match_ids.append(example.raw_id) f.write('-' * 60 + '\n') f.write('example_id: %d\n' % example.raw_id) f.write(code + '\n') f.write('-' * 60 + '\n') if config.data_type == 'django': ref_code_for_bleu = example.meta_data['raw_code'] pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code']) # ref_code_for_bleu = de_canonicalize_code(ref_code_for_bleu, example.meta_data['raw_code']) # convert canonicalized code to raw code for literal, place_holder in example.meta_data['str_map'].iteritems(): pred_code_for_bleu = pred_code_for_bleu.replace('\'' + place_holder + '\'', literal) # ref_code_for_bleu = ref_code_for_bleu.replace('\'' + place_holder + '\'', literal) elif config.data_type == 'hs': ref_code_for_bleu = ref_code pred_code_for_bleu = code # we apply Ling Wang's trick when evaluating BLEU scores refer_tokens_for_bleu = tokenize_for_bleu_eval(ref_code_for_bleu) pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu) # The if-chunk below is for debugging purpose, sometimes the reference cannot match with the prediction # because of inconsistent quotes (e.g., single quotes in reference, double quotes in prediction). # However most of these cases are solved by cannonicalizing the reference code using astor (parse the reference # into AST, and regenerate the code. Use this regenerated one as the reference) weired = False if refer_tokens_for_bleu == pred_tokens_for_bleu and refer_tokens != predict_tokens: # cum_acc += 1 weired = True elif refer_tokens == predict_tokens: # weired! # weired = True pass shorter = len(pred_tokens_for_bleu) < len(refer_tokens_for_bleu) all_references.append([refer_tokens_for_bleu]) all_predictions.append(pred_tokens_for_bleu) # try: ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu)) bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights, smoothing_function=sm.method3) cum_bleu += bleu_score # except: # pass if verbose: print 'raw_id: %d, bleu_score: %f' % (example.raw_id, bleu_score) f_decode.write('-' * 60 + '\n') f_decode.write('example_id: %d\n' % example.raw_id) f_decode.write('intent: \n') if config.data_type == 'django': f_decode.write(eid_to_annot[example.raw_id] + '\n') elif config.data_type == 'hs': f_decode.write(' '.join(example.query) + '\n') f_bleu_eval_ref.write(' '.join(refer_tokens_for_bleu) + '\n') f_bleu_eval_hyp.write(' '.join(pred_tokens_for_bleu) + '\n') f_decode.write('canonicalized reference: \n') f_decode.write(refer_source + '\n') f_decode.write('canonicalized prediction: \n') f_decode.write(code + '\n') f_decode.write('reference code for bleu calculation: \n') f_decode.write(ref_code_for_bleu + '\n') f_decode.write('predicted code for bleu calculation: \n') f_decode.write(pred_code_for_bleu + '\n') f_decode.write('pred_shorter_than_ref: %s\n' % shorter) f_decode.write('weired: %s\n' % weired) f_decode.write('-' * 60 + '\n') # for Hiro's evaluation f_generated_code.write(pred_code_for_bleu.replace('\n', '#NEWLINE#') + '\n') # compute oracle best_score = 0. cur_oracle_acc = 0. for decode_cand in decode_cands[:config.beam_size]: cid, cand, ast_tree, code = decode_cand try: code = astor.to_source(ast_tree).strip() predict_tokens = tokenize_code(code) if predict_tokens == refer_tokens: cur_oracle_acc = 1 if config.data_type == 'django': pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code']) # convert canonicalized code to raw code for literal, place_holder in example.meta_data['str_map'].iteritems(): pred_code_for_bleu = pred_code_for_bleu.replace('\'' + place_holder + '\'', literal) elif config.data_type == 'hs': pred_code_for_bleu = code # we apply Ling Wang's trick when evaluating BLEU scores pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu) ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu)) bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights, smoothing_function=sm.method3) if bleu_score > best_score: best_score = bleu_score except: continue cum_oracle_bleu += best_score cum_oracle_acc += cur_oracle_acc cum_bleu /= dataset.count cum_acc /= dataset.count cum_oracle_bleu /= dataset.count cum_oracle_acc /= dataset.count logging.info('corpus level bleu: %f', corpus_bleu(all_references, all_predictions, smoothing_function=sm.method3)) logging.info('sentence level bleu: %f', cum_bleu) logging.info('accuracy: %f', cum_acc) logging.info('oracle bleu: %f', cum_oracle_bleu) logging.info('oracle accuracy: %f', cum_oracle_acc) if verbose: f.write(', '.join(str(i) for i in exact_match_ids)) f.close() f_decode.close() f_bleu_eval_ref.close() f_bleu_eval_hyp.close() f_generated_code.close() return cum_bleu, cum_acc def analyze_decode_results(dataset, decode_results, verbose=True): from lang.py.parse import tokenize_code, de_canonicalize_code # tokenize_code = tokenize_for_bleu_eval import ast assert dataset.count == len(decode_results) f = f_decode = None if verbose: f = open(dataset.name + '.exact_match', 'w') exact_match_ids = [] f_decode = open(dataset.name + '.decode_results.txt', 'w') eid_to_annot = dict() if config.data_type == 'django': for raw_id, line in enumerate(open(DJANGO_ANNOT_FILE)): eid_to_annot[raw_id] = line.strip() f_bleu_eval_ref = open(dataset.name + '.ref', 'w') f_bleu_eval_hyp = open(dataset.name + '.hyp', 'w') logging.info('evaluating [%s] set, [%d] examples', dataset.name, dataset.count) cum_oracle_bleu = 0.0 cum_oracle_acc = 0.0 cum_bleu = 0.0 cum_acc = 0.0 sm = SmoothingFunction() all_references = [] all_predictions = [] if all(len(cand) == 0 for cand in decode_results): logging.ERROR('Empty decoding results for the current dataset!') return -1, -1 binned_results_dict = defaultdict(list) def get_binned_key(ast_size): cutoff = 50 if config.data_type == 'django' else 250 k = 10 if config.data_type == 'django' else 25 # for hs if ast_size >= cutoff: return '%d - inf' % cutoff lower = int(ast_size / k) * k upper = lower + k key = '%d - %d' % (lower, upper) return key for eid in range(dataset.count): example = dataset.examples[eid] ref_code = example.code ref_ast_tree = ast.parse(ref_code).body[0] refer_source = astor.to_source(ref_ast_tree).strip() # refer_source = ref_code refer_tokens = tokenize_code(refer_source) cur_example_acc = 0.0 decode_cands = decode_results[eid] if len(decode_cands) == 0: continue decode_cand = decode_cands[0] cid, cand, ast_tree, code = decode_cand code = astor.to_source(ast_tree).strip() # simple_url_2_re = re.compile('_STR:0_', re.)) try: predict_tokens = tokenize_code(code) except: logging.error('error in tokenizing [%s]', code) continue if refer_tokens == predict_tokens: cum_acc += 1 cur_example_acc = 1.0 if verbose: exact_match_ids.append(example.raw_id) f.write('-' * 60 + '\n') f.write('example_id: %d\n' % example.raw_id) f.write(code + '\n') f.write('-' * 60 + '\n') if config.data_type == 'django': ref_code_for_bleu = example.meta_data['raw_code'] pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code']) # ref_code_for_bleu = de_canonicalize_code(ref_code_for_bleu, example.meta_data['raw_code']) # convert canonicalized code to raw code for literal, place_holder in example.meta_data['str_map'].iteritems(): pred_code_for_bleu = pred_code_for_bleu.replace('\'' + place_holder + '\'', literal) # ref_code_for_bleu = ref_code_for_bleu.replace('\'' + place_holder + '\'', literal) elif config.data_type == 'hs': ref_code_for_bleu = ref_code pred_code_for_bleu = code # we apply Ling Wang's trick when evaluating BLEU scores refer_tokens_for_bleu = tokenize_for_bleu_eval(ref_code_for_bleu) pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu) shorter = len(pred_tokens_for_bleu) < len(refer_tokens_for_bleu) all_references.append([refer_tokens_for_bleu]) all_predictions.append(pred_tokens_for_bleu) # try: ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu)) bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights, smoothing_function=sm.method3) cum_bleu += bleu_score # except: # pass if verbose: print 'raw_id: %d, bleu_score: %f' % (example.raw_id, bleu_score) f_decode.write('-' * 60 + '\n') f_decode.write('example_id: %d\n' % example.raw_id) f_decode.write('intent: \n') if config.data_type == 'django': f_decode.write(eid_to_annot[example.raw_id] + '\n') elif config.data_type == 'hs': f_decode.write(' '.join(example.query) + '\n') f_bleu_eval_ref.write(' '.join(refer_tokens_for_bleu) + '\n') f_bleu_eval_hyp.write(' '.join(pred_tokens_for_bleu) + '\n') f_decode.write('canonicalized reference: \n') f_decode.write(refer_source + '\n') f_decode.write('canonicalized prediction: \n') f_decode.write(code + '\n') f_decode.write('reference code for bleu calculation: \n') f_decode.write(ref_code_for_bleu + '\n') f_decode.write('predicted code for bleu calculation: \n') f_decode.write(pred_code_for_bleu + '\n') f_decode.write('pred_shorter_than_ref: %s\n' % shorter) # f_decode.write('weired: %s\n' % weired) f_decode.write('-' * 60 + '\n') # compute oracle best_bleu_score = 0. cur_oracle_acc = 0. for decode_cand in decode_cands[:config.beam_size]: cid, cand, ast_tree, code = decode_cand try: code = astor.to_source(ast_tree).strip() predict_tokens = tokenize_code(code) if predict_tokens == refer_tokens: cur_oracle_acc = 1. if config.data_type == 'django': pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code']) # convert canonicalized code to raw code for literal, place_holder in example.meta_data['str_map'].iteritems(): pred_code_for_bleu = pred_code_for_bleu.replace('\'' + place_holder + '\'', literal) elif config.data_type == 'hs': pred_code_for_bleu = code # we apply Ling Wang's trick when evaluating BLEU scores pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu) ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu)) cand_bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights, smoothing_function=sm.method3) if cand_bleu_score > best_bleu_score: best_bleu_score = cand_bleu_score except: continue cum_oracle_bleu += best_bleu_score cum_oracle_acc += cur_oracle_acc ref_ast_size = example.parse_tree.size binned_key = get_binned_key(ref_ast_size) binned_results_dict[binned_key].append((bleu_score, cur_example_acc, best_bleu_score, cur_oracle_acc)) cum_bleu /= dataset.count cum_acc /= dataset.count cum_oracle_bleu /= dataset.count cum_oracle_acc /= dataset.count logging.info('corpus level bleu: %f', corpus_bleu(all_references, all_predictions, smoothing_function=sm.method3)) logging.info('sentence level bleu: %f', cum_bleu) logging.info('accuracy: %f', cum_acc) logging.info('oracle bleu: %f', cum_oracle_bleu) logging.info('oracle accuracy: %f', cum_oracle_acc) keys = sorted(binned_results_dict, key=lambda x: int(x.split(' - ')[0])) Y = [[], [], [], []] X = [] for binned_key in keys: entry = binned_results_dict[binned_key] avg_bleu = np.average([t[0] for t in entry]) avg_acc = np.average([t[1] for t in entry]) avg_oracle_bleu = np.average([t[2] for t in entry]) avg_oracle_acc = np.average([t[3] for t in entry]) print binned_key, avg_bleu, avg_acc, avg_oracle_bleu, avg_oracle_acc, len(entry) Y[0].append(avg_bleu) Y[1].append(avg_acc) Y[2].append(avg_oracle_bleu) Y[3].append(avg_oracle_acc) X.append(int(binned_key.split(' - ')[0])) import matplotlib.pyplot as plt from pylab import rcParams rcParams['figure.figsize'] = 6, 2.5 if config.data_type == 'django': fig, ax = plt.subplots() ax.plot(X, Y[0], 'bs--', label='BLEU', lw=1.2) # ax.plot(X, Y[2], 'r^--', label='oracle BLEU', lw=1.2) ax.plot(X, Y[1], 'r^--', label='acc', lw=1.2) # ax.plot(X, Y[3], 'r^--', label='oracle acc', lw=1.2) ax.set_ylabel('Performance') ax.set_xlabel('Reference AST Size (# nodes)') plt.legend(loc='upper right', ncol=6) plt.tight_layout() # plt.savefig('django_acc_ast_size.pdf', dpi=300) # os.system('pcrop.sh django_acc_ast_size.pdf') plt.savefig('django_perf_ast_size.pdf', dpi=300) os.system('pcrop.sh django_perf_ast_size.pdf') else: fig, ax = plt.subplots() ax.plot(X, Y[0], 'bs--', label='BLEU', lw=1.2) # ax.plot(X, Y[2], 'r^--', label='oracle BLEU', lw=1.2) ax.plot(X, Y[1], 'r^--', label='acc', lw=1.2) # ax.plot(X, Y[3], 'r^--', label='oracle acc', lw=1.2) ax.set_ylabel('Performance') ax.set_xlabel('Reference AST Size (# nodes)') plt.legend(loc='upper right', ncol=6) plt.tight_layout() # plt.savefig('hs_bleu_ast_size.pdf', dpi=300) # os.system('pcrop.sh hs_bleu_ast_size.pdf') plt.savefig('hs_perf_ast_size.pdf', dpi=300) os.system('pcrop.sh hs_perf_ast_size.pdf') if verbose: f.write(', '.join(str(i) for i in exact_match_ids)) f.close() f_decode.close() f_bleu_eval_ref.close() f_bleu_eval_hyp.close() return cum_bleu, cum_acc def evaluate_seq2seq_decode_results(dataset, seq2seq_decode_file, seq2seq_ref_file, verbose=True, is_nbest=False): from lang.py.parse import parse f_seq2seq_decode = open(seq2seq_decode_file) f_seq2seq_ref = open(seq2seq_ref_file) if verbose: logging.info('evaluating [%s] set, [%d] examples', dataset.name, dataset.count) cum_bleu = 0.0 cum_acc = 0.0 sm = SmoothingFunction() decode_file_data = [l.strip() for l in f_seq2seq_decode.readlines()] ref_code_data = [l.strip() for l in f_seq2seq_ref.readlines()] if is_nbest: for i in xrange(len(decode_file_data)): d = decode_file_data[i].split(' ||| ') decode_file_data[i] = (int(d[0]), d[1]) def is_well_formed_python_code(_hyp): try: _hyp = _hyp.replace('#NEWLINE#', '\n').replace('#INDENT#', ' ').replace(' #MERGE# ', '') hyp_ast_tree = parse(_hyp) return True except: return False for eid in range(dataset.count): example = dataset.examples[eid] cur_example_correct = False if is_nbest: # find the best-scored well-formed code from the n-best list n_best_list = filter(lambda x: x[0] == eid, decode_file_data) code = top_scored_code = n_best_list[0][1] for _, hyp in n_best_list: if is_well_formed_python_code(hyp): code = hyp break if top_scored_code != code: print '*' * 60 print top_scored_code print code print '*' * 60 code = n_best_list[0][1] else: code = decode_file_data[eid] code = code.replace('#NEWLINE#', '\n').replace('#INDENT#', ' ').replace(' #MERGE# ', '') ref_code = ref_code_data[eid].replace('#NEWLINE#', '\n').replace('#INDENT#', ' ').replace(' #MERGE# ', '') if code == ref_code: cum_acc += 1 cur_example_correct = True if config.data_type == 'django': ref_code_for_bleu = example.meta_data['raw_code'] pred_code_for_bleu = code # de_canonicalize_code(code, example.meta_data['raw_code']) # ref_code_for_bleu = de_canonicalize_code(ref_code_for_bleu, example.meta_data['raw_code']) # convert canonicalized code to raw code for literal, place_holder in example.meta_data['str_map'].iteritems(): pred_code_for_bleu = pred_code_for_bleu.replace('\'' + place_holder + '\'', literal) # ref_code_for_bleu = ref_code_for_bleu.replace('\'' + place_holder + '\'', literal) elif config.data_type == 'hs': ref_code_for_bleu = example.code pred_code_for_bleu = code # we apply Ling Wang's trick when evaluating BLEU scores refer_tokens_for_bleu = tokenize_for_bleu_eval(ref_code_for_bleu) pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu) ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu)) bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights, smoothing_function=sm.method3) cum_bleu += bleu_score cum_bleu /= dataset.count cum_acc /= dataset.count logging.info('sentence level bleu: %f', cum_bleu) logging.info('accuracy: %f', cum_acc) def evaluate_seq2tree_sample_file(sample_file, id_file, dataset): from lang.py.parse import tokenize_code, de_canonicalize_code import ast, astor import traceback from lang.py.seq2tree_exp import seq2tree_repr_to_ast_tree, merge_broken_value_nodes from lang.py.parse import decode_tree_to_python_ast f_sample = open(sample_file) line_id_to_raw_id = OrderedDict() raw_id_to_eid = OrderedDict() for i, line in enumerate(open(id_file)): raw_id = int(line.strip()) line_id_to_raw_id[i] = raw_id for eid in range(len(dataset.examples)): raw_id_to_eid[dataset.examples[eid].raw_id] = eid rare_word_map = defaultdict(dict) if config.seq2tree_rareword_map: logging.info('use rare word map') for i, line in enumerate(open(config.seq2tree_rareword_map)): line = line.strip() if line: for e in line.split(' '): d = e.split(':', 1) rare_word_map[i][int(d[0])] = d[1] cum_bleu = 0.0 cum_acc = 0.0 sm = SmoothingFunction() convert_error_num = 0 for i in range(len(line_id_to_raw_id)): # print 'working on %d' % i ref_repr = f_sample.readline().strip() predict_repr = f_sample.readline().strip() predict_repr = predict_repr.replace('', 'str{}{unk}') # .replace('( )', '( str{}{unk} )') f_sample.readline() # if ' ( ) ' in ref_repr: # print i, ref_repr if i in rare_word_map: for unk_id, w in rare_word_map[i].iteritems(): ref_repr = ref_repr.replace(' str{}{unk_%s} ' % unk_id, ' str{}{%s} ' % w) predict_repr = predict_repr.replace(' str{}{unk_%s} ' % unk_id, ' str{}{%s} ' % w) try: parse_tree = seq2tree_repr_to_ast_tree(predict_repr) merge_broken_value_nodes(parse_tree) except: print 'error when converting:' print predict_repr convert_error_num += 1 continue raw_id = line_id_to_raw_id[i] eid = raw_id_to_eid[raw_id] example = dataset.examples[eid] ref_code = example.code ref_ast_tree = ast.parse(ref_code).body[0] refer_source = astor.to_source(ref_ast_tree).strip() refer_tokens = tokenize_code(refer_source) try: ast_tree = decode_tree_to_python_ast(parse_tree) code = astor.to_source(ast_tree).strip() except: print "Exception in converting tree to code:" print '-' * 60 print 'line id: %d' % i traceback.print_exc(file=sys.stdout) print '-' * 60 convert_error_num += 1 continue if config.data_type == 'django': ref_code_for_bleu = example.meta_data['raw_code'] pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code']) # convert canonicalized code to raw code for literal, place_holder in example.meta_data['str_map'].iteritems(): pred_code_for_bleu = pred_code_for_bleu.replace('\'' + place_holder + '\'', literal) elif config.data_type == 'hs': ref_code_for_bleu = ref_code pred_code_for_bleu = code # we apply Ling Wang's trick when evaluating BLEU scores refer_tokens_for_bleu = tokenize_for_bleu_eval(ref_code_for_bleu) pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu) predict_tokens = tokenize_code(code) # if ref_repr == predict_repr: if predict_tokens == refer_tokens: cum_acc += 1 ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu)) bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights, smoothing_function=sm.method3) cum_bleu += bleu_score cum_bleu /= len(line_id_to_raw_id) cum_acc /= len(line_id_to_raw_id) logging.info('nun. examples: %d', len(line_id_to_raw_id)) logging.info('num. errors when converting repr to tree: %d', convert_error_num) logging.info('ratio of grammatically incorrect trees: %f', convert_error_num / float(len(line_id_to_raw_id))) logging.info('sentence level bleu: %f', cum_bleu) logging.info('accuracy: %f', cum_acc) def evaluate_ifttt_results(dataset, decode_results, verbose=True): assert dataset.count == len(decode_results) f = f_decode = None if verbose: f = open(dataset.name + '.exact_match', 'w') exact_match_ids = [] f_decode = open(os.path.join(config.output_dir, dataset.name + '.decode_results.txt'), 'w') logging.info('evaluating [%s] set, [%d] examples', dataset.name, dataset.count) cum_channel_acc = 0.0 cum_channel_func_acc = 0.0 cum_prod_f1 = 0.0 cum_oracle_prod_f1 = 0.0 if all(len(cand) == 0 for cand in decode_results): logging.ERROR('Empty decoding results for the current dataset!') return -1, -1, -1 for eid in range(dataset.count): example = dataset.examples[eid] ref_parse_tree = example.parse_tree decode_candidates = decode_results[eid] if len(decode_candidates) == 0: continue decode_cand = decode_candidates[0] cid, cand_hyp = decode_cand predict_parse_tree = cand_hyp.tree exact_match = predict_parse_tree == ref_parse_tree channel_acc, channel_func_acc, prod_f1 = ifttt_metric(predict_parse_tree, ref_parse_tree) cum_channel_acc += channel_acc cum_channel_func_acc += channel_func_acc cum_prod_f1 += prod_f1 if verbose: if exact_match: exact_match_ids.append(example.raw_id) print 'raw_id: %d, prod_f1: %f' % (example.raw_id, prod_f1) f_decode.write('-' * 60 + '\n') f_decode.write('example_id: %d\n' % example.raw_id) f_decode.write('intent: \n') f_decode.write(' '.join(example.query) + '\n') f_decode.write('reference: \n') f_decode.write(str(ref_parse_tree) + '\n') f_decode.write('prediction: \n') f_decode.write(str(predict_parse_tree) + '\n') f_decode.write('-' * 60 + '\n') # compute oracle best_prod_f1 = -1. for decode_cand in decode_candidates[:10]: cid, cand_hyp = decode_cand predict_parse_tree = cand_hyp.tree channel_acc, channel_func_acc, prod_f1 = ifttt_metric(predict_parse_tree, ref_parse_tree) if prod_f1 > best_prod_f1: best_prod_f1 = prod_f1 cum_oracle_prod_f1 += best_prod_f1 cum_channel_acc /= dataset.count cum_channel_func_acc /= dataset.count cum_prod_f1 /= dataset.count cum_oracle_prod_f1 /= dataset.count logging.info('channel_acc: %f', cum_channel_acc) logging.info('channel_func_acc: %f', cum_channel_func_acc) logging.info('prod_f1: %f', cum_prod_f1) logging.info('oracle prod_f1: %f', cum_oracle_prod_f1) if verbose: f.write(', '.join(str(i) for i in exact_match_ids)) f.close() f_decode.close() return cum_channel_acc, cum_channel_func_acc, cum_prod_f1 def ifttt_metric(predict_parse_tree, ref_parse_tree): channel_acc = channel_func_acc = prod_f1 = 0. # channel acc. channel_match = False if predict_parse_tree['TRIGGER'].children[0].type == ref_parse_tree['TRIGGER'].children[0].type and \ predict_parse_tree['ACTION'].children[0].type == ref_parse_tree['ACTION'].children[0].type: channel_acc += 1. channel_match = True # channel+func acc. if channel_match and predict_parse_tree['TRIGGER'].children[0].children[0].type == ref_parse_tree['TRIGGER'].children[0].children[0].type and \ predict_parse_tree['ACTION'].children[0].children[0].type == ref_parse_tree['ACTION'].children[0].children[0].type: channel_func_acc += 1. # predict_parse_tree is of type DecodingTree, different from reference tree! # if predict_parse_tree == ref_parse_tree: # channel_func_acc += 1. # prod. F1 ref_rules, _ = ref_parse_tree.get_productions() predict_rules, _ = predict_parse_tree.get_productions() prod_f1 = len(set(ref_rules).intersection(set(predict_rules))) / len(ref_rules) return channel_acc, channel_func_acc, prod_f1 def decode_and_evaluate_ifttt(model, test_data): raw_ids = [int(i.strip()) for i in open(config.ifttt_test_split)] # 'data/ifff.test_data.gold.id' eids = [i for i, e in enumerate(test_data.examples) if e.raw_id in raw_ids] test_data_subset = test_data.get_dataset_by_ids(eids, test_data.name + '.subset') from decoder import decode_ifttt_dataset decode_results = decode_ifttt_dataset(model, test_data_subset, verbose=True) evaluate_ifttt_results(test_data_subset, decode_results) return decode_results def decode_and_evaluate_ifttt_by_split(model, test_data): for split in ['ifff.test_data.omit_non_english.id', 'ifff.test_data.omit_unintelligible.id', 'ifff.test_data.gold.id']: raw_ids = [int(i.strip()) for i in open(os.path.join(config.ifttt_test_split), split)] # 'data/ifff.test_data.gold.id' eids = [i for i, e in enumerate(test_data.examples) if e.raw_id in raw_ids] test_data_subset = test_data.get_dataset_by_ids(eids, test_data.name + '.' + split) from decoder import decode_ifttt_dataset decode_results = decode_ifttt_dataset(model, test_data_subset, verbose=True) evaluate_ifttt_results(test_data_subset, decode_results) if __name__ == '__main__': from dataset import DataEntry, DataSet, Vocab, Action init_logging('parser.log', logging.INFO) train_data, dev_data, test_data = deserialize_from_file('data/ifttt.freq3.bin') decoding_results = [] for eid in range(test_data.count): example = test_data.examples[eid] decoding_results.append([(eid, example.parse_tree)]) evaluate_ifttt_results(test_data, decoding_results, verbose=True) ================================================ FILE: interactive_mode.py ================================================ import argparse, sys from nn.utils.generic_utils import init_logging from nn.utils.io_utils import deserialize_from_file, serialize_to_file from evaluation import * from dataset import canonicalize_query, query_to_data from collections import namedtuple from lang.py.parse import decode_tree_to_python_ast from model import Model from dataset import DataEntry, DataSet, Vocab, Action import config parser = argparse.ArgumentParser() parser.add_argument('-data_type', default='django', choices=['django', 'hs']) parser.add_argument('-data') parser.add_argument('-random_seed', default=181783, type=int) parser.add_argument('-model', default=None) # neural model's parameters parser.add_argument('-source_vocab_size', default=0, type=int) parser.add_argument('-target_vocab_size', default=0, type=int) parser.add_argument('-rule_num', default=0, type=int) parser.add_argument('-node_num', default=0, type=int) parser.add_argument('-word_embed_dim', default=128, type=int) parser.add_argument('-rule_embed_dim', default=128, type=int) parser.add_argument('-node_embed_dim', default=64, type=int) parser.add_argument('-encoder_hidden_dim', default=256, type=int) parser.add_argument('-decoder_hidden_dim', default=256, type=int) parser.add_argument('-attention_hidden_dim', default=50, type=int) parser.add_argument('-ptrnet_hidden_dim', default=50, type=int) parser.add_argument('-dropout', default=0.2, type=float) # encoder parser.add_argument('-encoder', default='bilstm', choices=['bilstm', 'lstm']) # decoder parser.add_argument('-parent_hidden_state_feed', dest='parent_hidden_state_feed', action='store_true') parser.add_argument('-no_parent_hidden_state_feed', dest='parent_hidden_state_feed', action='store_false') parser.set_defaults(parent_hidden_state_feed=True) parser.add_argument('-parent_action_feed', dest='parent_action_feed', action='store_true') parser.add_argument('-no_parent_action_feed', dest='parent_action_feed', action='store_false') parser.set_defaults(parent_action_feed=True) parser.add_argument('-frontier_node_type_feed', dest='frontier_node_type_feed', action='store_true') parser.add_argument('-no_frontier_node_type_feed', dest='frontier_node_type_feed', action='store_false') parser.set_defaults(frontier_node_type_feed=True) parser.add_argument('-tree_attention', dest='tree_attention', action='store_true') parser.add_argument('-no_tree_attention', dest='tree_attention', action='store_false') parser.set_defaults(tree_attention=False) parser.add_argument('-enable_copy', dest='enable_copy', action='store_true') parser.add_argument('-no_copy', dest='enable_copy', action='store_false') parser.set_defaults(enable_copy=True) # training parser.add_argument('-optimizer', default='adam') parser.add_argument('-clip_grad', default=0., type=float) parser.add_argument('-train_patience', default=10, type=int) parser.add_argument('-max_epoch', default=50, type=int) parser.add_argument('-batch_size', default=10, type=int) parser.add_argument('-valid_per_batch', default=4000, type=int) parser.add_argument('-save_per_batch', default=4000, type=int) parser.add_argument('-valid_metric', default='bleu') # decoding parser.add_argument('-beam_size', default=15, type=int) parser.add_argument('-max_query_length', default=70, type=int) parser.add_argument('-decode_max_time_step', default=100, type=int) parser.add_argument('-head_nt_constraint', dest='head_nt_constraint', action='store_true') parser.add_argument('-no_head_nt_constraint', dest='head_nt_constraint', action='store_false') parser.set_defaults(head_nt_constraint=True) args = parser.parse_args(args=['-data_type', 'django', '-data', 'data/django.cleaned.dataset.freq5.par_info.refact.space_only.bin', '-model', 'models/model.django_word128_encoder256_rule128_node64.beam15.adam.simple_trans.no_unary_closure.8e39832.run3.best_acc.npz']) if args.data_type == 'hs': args.decode_max_time_step = 350 logging.info('loading dataset [%s]', args.data) train_data, dev_data, test_data = deserialize_from_file(args.data) if not args.source_vocab_size: args.source_vocab_size = train_data.annot_vocab.size if not args.target_vocab_size: args.target_vocab_size = train_data.terminal_vocab.size if not args.rule_num: args.rule_num = len(train_data.grammar.rules) if not args.node_num: args.node_num = len(train_data.grammar.node_type_to_id) config_module = sys.modules['config'] for name, value in vars(args).iteritems(): setattr(config_module, name, value) # build the model model = Model() model.build() model.load(args.model) def decode_query(query): """decode a given natural language query, return a list of generated candidates""" query, str_map = canonicalize_query(query) vocab = train_data.annot_vocab query_tokens = query.split(' ') query_tokens_data = [query_to_data(query, vocab)] example = namedtuple('example', ['query', 'data'])(query=query_tokens, data=query_tokens_data) cand_list = model.decode(example, train_data.grammar, train_data.terminal_vocab, beam_size=args.beam_size, max_time_step=args.decode_max_time_step, log=True) return cand_list if __name__ == '__main__': print 'run in interactive mode' while True: query = raw_input('input a query: ') cand_list = decode_query(query) # output top 5 candidates for cid, cand in enumerate(cand_list[:5]): print '*' * 60 print 'cand #%d, score: %f' % (cid, cand.score) try: ast_tree = decode_tree_to_python_ast(cand.tree) code = astor.to_source(ast_tree) print 'code: ', code print 'decode log: ', cand.log except: print "Exception in converting tree to code:" print '-' * 60 print 'raw_id: %d, beam pos: %d' % (example.raw_id, cid) traceback.print_exc(file=sys.stdout) print '-' * 60 finally: print '* parse tree *' print cand.tree.__repr__() print 'n_timestep: %d' % cand.n_timestep print 'ast size: %d' % cand.tree.size print '*' * 60 ================================================ FILE: lang/__init__.py ================================================ ================================================ FILE: lang/grammar.py ================================================ from collections import OrderedDict, defaultdict import logging from astnode import ASTNode from lang.util import typename class Grammar(object): def __init__(self, rules): """ instantiate a grammar with a set of production rules of type Rule """ self.rules = rules self.rule_index = defaultdict(list) self.rule_to_id = OrderedDict() node_types = set() lhs_nodes = set() rhs_nodes = set() for rule in self.rules: self.rule_index[rule.parent].append(rule) # we also store all unique node types for node in rule.nodes: node_types.add(typename(node.type)) lhs_nodes.add(rule.parent) for child in rule.children: rhs_nodes.add(child.as_type_node) root_node = lhs_nodes - rhs_nodes assert len(root_node) == 1 self.root_node = next(iter(root_node)) self.terminal_nodes = rhs_nodes - lhs_nodes self.terminal_types = set([n.type for n in self.terminal_nodes]) self.node_type_to_id = OrderedDict() for i, type in enumerate(node_types, start=0): self.node_type_to_id[type] = i for gid, rule in enumerate(rules, start=0): self.rule_to_id[rule] = gid self.id_to_rule = OrderedDict((v, k) for (k, v) in self.rule_to_id.iteritems()) logging.info('num. rules: %d', len(self.rules)) logging.info('num. types: %d', len(self.node_type_to_id)) logging.info('root: %s', self.root_node) logging.info('terminals: %s', ', '.join(repr(n) for n in self.terminal_nodes)) def __iter__(self): return self.rules.__iter__() def __len__(self): return len(self.rules) def __getitem__(self, lhs): key_node = ASTNode(lhs.type, None) # Rules are indexed by types only if key_node in self.rule_index: return self.rule_index[key_node] else: KeyError('key=%s' % key_node) def get_node_type_id(self, node): from astnode import ASTNode if isinstance(node, ASTNode): type_repr = typename(node.type) return self.node_type_to_id[type_repr] else: # assert isinstance(node, str) # it is a type type_repr = typename(node) return self.node_type_to_id[type_repr] def is_terminal(self, node): return node.type in self.terminal_types def is_value_node(self, node): raise NotImplementedError ================================================ FILE: lang/ifttt/__init__.py ================================================ ================================================ FILE: lang/ifttt/grammar.py ================================================ from lang.grammar import Grammar class IFTTTGrammar(Grammar): def __init__(self, rules): super(IFTTTGrammar, self).__init__(rules) def is_value_node(self, node): return False ================================================ FILE: lang/ifttt/ifttt_dataset.py ================================================ # -*- coding: UTF-8 -*- from __future__ import division import string from collections import OrderedDict from collections import defaultdict from itertools import count from nn.utils.io_utils import serialize_to_file, deserialize_from_file from lang.ifttt.grammar import IFTTTGrammar from parse import ifttt_ast_to_parse_tree from lang.grammar import Grammar import logging from itertools import chain from nn.utils.generic_utils import init_logging from dataset import gen_vocab, DataSet, DataEntry, Action, APPLY_RULE, GEN_TOKEN, COPY_TOKEN, GEN_COPY_TOKEN def load_examples(data_file): f = open(data_file) next(f) examples = [] for line in f: d = line.strip().split('\t') description = d[4] code = d[9] parse_tree = ifttt_ast_to_parse_tree(code) examples.append({'description': description, 'parse_tree': parse_tree, 'code': code}) return examples def analyze_ifttt_dataset(): data_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/recipe_summaries.all.tsv' examples = load_examples(data_file) rule_num = 0. max_rule_num = -1 example_with_max_rule_num = -1 for idx, example in enumerate(examples): parse_tree = example['parse_tree'] rules, _ = parse_tree.get_productions(include_value_node=True) rule_num += len(rules) if max_rule_num < len(rules): max_rule_num = len(rules) example_with_max_rule_num = idx logging.info('avg. num. of rules: %f', rule_num / len(examples)) logging.info('max_rule_num: %d', max_rule_num) logging.info('example_with_max_rule_num: %d', example_with_max_rule_num) def canonicalize_ifttt_example(annot, code): parse_tree = ifttt_ast_to_parse_tree(code, attach_func_to_channel=False) clean_code = str(parse_tree) clean_query_tokens = annot.split() clean_query_tokens = [t.lower() for t in clean_query_tokens] return clean_query_tokens, clean_code, parse_tree def preprocess_ifttt_dataset(annot_file, code_file): f = open('ifttt_dataset.examples.txt', 'w') examples = [] for idx, (annot, code) in enumerate(zip(open(annot_file), open(code_file))): annot = annot.strip() code = code.strip() clean_query_tokens, clean_code, parse_tree = canonicalize_ifttt_example(annot, code) example = {'id': idx, 'query_tokens': clean_query_tokens, 'code': clean_code, 'parse_tree': parse_tree, 'str_map': None, 'raw_code': code} examples.append(example) f.write('*' * 50 + '\n') f.write('example# %d\n' % idx) f.write(' '.join(clean_query_tokens) + '\n') f.write('\n') f.write(clean_code + '\n') f.write('*' * 50 + '\n') idx += 1 f.close() print 'preprocess_dataset: cleaned example num: %d' % len(examples) return examples def get_grammar(parse_trees): rules = set() for parse_tree in parse_trees: parse_tree_rules, rule_parents = parse_tree.get_productions() for rule in parse_tree_rules: rules.add(rule) rules = list(sorted(rules, key=lambda x: x.__repr__())) grammar = IFTTTGrammar(rules) logging.info('num. rules: %d', len(rules)) with open('grammar.txt', 'w') as f: for rule in grammar: str = rule.__repr__() f.write(str + '\n') with open('parse_trees.txt', 'w') as f: for tree in parse_trees: f.write(tree.__repr__() + '\n') return grammar def parse_ifttt_dataset(): WORD_FREQ_CUT_OFF = 2 annot_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/Data/lang.all.txt' code_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/Data/code.all.txt' data = preprocess_ifttt_dataset(annot_file, code_file) # build the grammar grammar = get_grammar([e['parse_tree'] for e in data]) annot_tokens = list(chain(*[e['query_tokens'] for e in data])) annot_vocab = gen_vocab(annot_tokens, vocab_size=30000, freq_cutoff=WORD_FREQ_CUT_OFF) logging.info('annot vocab. size: %d', annot_vocab.size) # we have no terminal tokens in ifttt all_terminal_tokens = [] terminal_vocab = gen_vocab(all_terminal_tokens, vocab_size=4000, freq_cutoff=WORD_FREQ_CUT_OFF) # now generate the dataset! train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'ifttt.train_data') dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'ifttt.dev_data') test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'ifttt.test_data') all_examples = [] can_fully_reconstructed_examples_num = 0 examples_with_empty_actions_num = 0 for entry in data: idx = entry['id'] query_tokens = entry['query_tokens'] code = entry['code'] parse_tree = entry['parse_tree'] # check if query tokens are valid query_token_ids = [annot_vocab[token] for token in query_tokens if token not in string.punctuation] valid_query_tokens_ids = [tid for tid in query_token_ids if tid != annot_vocab.unk] # remove examples with rare words from train and dev, avoid overfitting if len(valid_query_tokens_ids) == 0 and 0 <= idx < 77495 + 5171: continue rule_list, rule_parents = parse_tree.get_productions(include_value_node=True) actions = [] can_fully_reconstructed = True rule_pos_map = dict() for rule_count, rule in enumerate(rule_list): if not grammar.is_value_node(rule.parent): assert rule.value is None parent_rule = rule_parents[(rule_count, rule)][0] if parent_rule: parent_t = rule_pos_map[parent_rule] else: parent_t = 0 rule_pos_map[rule] = len(actions) d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule} action = Action(APPLY_RULE, d) actions.append(action) else: raise RuntimeError('no terminals should be in ifttt dataset!') if len(actions) == 0: examples_with_empty_actions_num += 1 continue example = DataEntry(idx, query_tokens, parse_tree, code, actions, {'str_map': None, 'raw_code': entry['raw_code']}) if can_fully_reconstructed: can_fully_reconstructed_examples_num += 1 # train, valid, test splits if 0 <= idx < 77495: train_data.add(example) elif idx < 77495 + 5171: dev_data.add(example) else: test_data.add(example) all_examples.append(example) # print statistics max_query_len = max(len(e.query) for e in all_examples) max_actions_len = max(len(e.actions) for e in all_examples) # serialize_to_file([len(e.query) for e in all_examples], 'query.len') # serialize_to_file([len(e.actions) for e in all_examples], 'actions.len') logging.info('train_data examples: %d', train_data.count) logging.info('dev_data examples: %d', dev_data.count) logging.info('test_data examples: %d', test_data.count) logging.info('examples that can be fully reconstructed: %d/%d=%f', can_fully_reconstructed_examples_num, len(all_examples), can_fully_reconstructed_examples_num / len(all_examples)) logging.info('empty_actions_count: %d', examples_with_empty_actions_num) logging.info('max_query_len: %d', max_query_len) logging.info('max_actions_len: %d', max_actions_len) train_data.init_data_matrices(max_query_length=40, max_example_action_num=6) dev_data.init_data_matrices() test_data.init_data_matrices() serialize_to_file((train_data, dev_data, test_data), 'data/ifttt.freq{WORD_FREQ_CUT_OFF}.bin'.format(WORD_FREQ_CUT_OFF=WORD_FREQ_CUT_OFF)) return train_data, dev_data, test_data def parse_data_for_seq2seq(data_file='data/ifttt.freq3.bin'): train_data, dev_data, test_data = deserialize_from_file(data_file) prefix = 'data/seq2seq/' for dataset, output in [(train_data, prefix + 'ifttt.train'), (dev_data, prefix + 'ifttt.dev'), (test_data, prefix + 'ifttt.test')]: f_source = open(output + '.desc', 'w') f_target = open(output + '.code', 'w') if 'test' in output: raw_ids = [int(i.strip()) for i in open('data/ifff.test_data.gold.id')] eids = [i for i, e in enumerate(test_data.examples) if e.raw_id in raw_ids] dataset = test_data.get_dataset_by_ids(eids, test_data.name + '.subset') for e in dataset.examples: query_tokens = e.query trigger = e.parse_tree['TRIGGER'].children[0].type + ' . ' + e.parse_tree['TRIGGER'].children[0].children[0].type action = e.parse_tree['ACTION'].children[0].type + ' . ' + e.parse_tree['ACTION'].children[0].children[0].type code = 'IF ' + trigger + ' THEN ' + action f_source.write(' '.join(query_tokens) + '\n') f_target.write(code + '\n') f_source.close() f_target.close() def extract_turk_data(): turk_annot_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/public_release/data/turk_public.tsv' reference_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/public_release/data/ifttt_public.tsv' f_turk = open(turk_annot_file) next(f_turk) annot_data = OrderedDict() for line in f_turk: d = line.strip().split('\t') url = d[0] if url not in annot_data: annot_data[url] = list() annot_data[url].append({'trigger_channel': d[2], 'trigger_func': d[3], 'action_channel': d[4], 'action_func': d[5]}) f_ref = open(reference_file) next(f_ref) ref_data = OrderedDict() for line in f_ref: d = line.strip().split('\t') url = d[0] ref_data[url] = {'trigger_channel': d[2], 'trigger_func': d[3], 'action_channel': d[4], 'action_func': d[5]} lt_three_agree_with_gold = [] non_english_examples = [] unintelligible_examples = [] for url, annots in annot_data.iteritems(): vote_dict = defaultdict(int) ref = ref_data[url] match_with_gold_num = 0 non_english_num = unintelligible_num = 0 non_english_annots = [] unintelligible_annots = [] for annot in annots: if annot['trigger_channel'] == ref['trigger_channel'] and annot['trigger_func'] == ref['trigger_func'] and \ annot['action_channel'] == ref['action_channel'] and annot['action_func'] == ref['action_func']: match_with_gold_num += 1 vote_dict['#'.join(annot.values())] += 1 for i, annot in enumerate(annots): if annot['trigger_channel'] == 'nonenglish' and annot['trigger_func'] == 'nonenglish' and \ annot['action_channel'] == 'nonenglish' and annot['action_func'] == 'nonenglish': non_english_num += 1 non_english_annots.append(i) if annot['trigger_channel'] == 'unintelligible' and annot['trigger_func'] == 'unintelligible' and \ annot['action_channel'] == 'unintelligible' and annot['action_func'] == 'unintelligible': unintelligible_num += 1 unintelligible_annots.append(i) max_vote_num = max(vote_dict.values()) # omitting descriptions marked as non-English by a majority of the crowdsourced workers if non_english_num == max_vote_num: non_english_examples.append(url) non_english_and_unintelligible_num = len(set(non_english_annots).union(set(unintelligible_annots))) # if this example has no non_english and unintelligible annotations if non_english_and_unintelligible_num > 0: # < len(annots) - non_english_and_unintelligible_num: unintelligible_examples.append(url) if match_with_gold_num >= 3: lt_three_agree_with_gold.append(url) omit_non_english_examples = set(annot_data) - set(non_english_examples) omit_unintelligible_examples = set(annot_data) - set(unintelligible_examples) print len(omit_non_english_examples) # should be 3,741 print len(omit_unintelligible_examples) # should be 2,262 print len(lt_three_agree_with_gold) # should be 758 url2id = defaultdict(count(0).next) for url in ref_data: url2id[url] = url2id[url] + 77495 + 5171 f_gold = open('data/ifff.test_data.gold.id', 'w') for url in lt_three_agree_with_gold: i = url2id[url] f_gold.write(str(i) + '\n') f_gold.close() f_gold = open('data/ifff.test_data.omit_unintelligible.id', 'w') for url in omit_unintelligible_examples: i = url2id[url] f_gold.write(str(i) + '\n') f_gold.close() f_gold = open('data/ifff.test_data.omit_non_english.id', 'w') for url in omit_non_english_examples: i = url2id[url] f_gold.write(str(i) + '\n') f_gold.close() omit_non_english_examples = [url2id[url] for url in omit_non_english_examples] omit_unintelligible_examples = [url2id[url] for url in omit_unintelligible_examples] lt_three_agree_with_gold = [url2id[url] for url in lt_three_agree_with_gold] return omit_non_english_examples, omit_unintelligible_examples, lt_three_agree_with_gold if __name__ == '__main__': init_logging('ifttt.log') # parse_ifttt_dataset() # analyze_ifttt_dataset() extract_turk_data() # parse_data_for_seq2seq() ================================================ FILE: lang/ifttt/parse.py ================================================ from astnode import ASTNode def ifttt_ast_to_parse_tree_helper(s, offset): """ adapted from ifttt codebase """ if s[offset] != '(': raise RuntimeError('malformed string: node did not start with open paren at position ' + offset) offset += 1 # extract node name(type) name = '' if s[offset] == '\"': offset += 1 while s[offset] != '\"': if s[offset] == '\\': offset += 1 name += s[offset] offset += 1 offset += 1 else: while s[offset] != ' ' and s[offset] != ')': name += s[offset] offset += 1 node = ASTNode(name) while True: if s[offset] == ')': offset += 1 return node, offset if s[offset] != ' ': raise RuntimeError('malformed string: node should have either had a ' 'close paren or a space at position ' + offset) offset += 1 child_node, offset = ifttt_ast_to_parse_tree_helper(s, offset) node.add_child(child_node) def ifttt_ast_to_parse_tree(s, attach_func_to_channel=True): parse_tree, _ = ifttt_ast_to_parse_tree_helper(s, 0) parse_tree = strip_params(parse_tree) if attach_func_to_channel: parse_tree = attach_function_to_channel(parse_tree) return parse_tree def strip_params(parse_tree): if parse_tree.type == 'PARAMS': raise RuntimeError('should not go to here!') parse_tree.children = [c for c in parse_tree.children if c.type != 'PARAMS' and c.type != 'OUTPARAMS'] for i, child in enumerate(parse_tree.children): parse_tree.children[i] = strip_params(child) return parse_tree def attach_function_to_channel(parse_tree): trigger_func = parse_tree['TRIGGER']['FUNC'].children assert len(trigger_func) == 1 trigger_func = trigger_func[0] parse_tree['TRIGGER'].children[0].add_child(trigger_func) del parse_tree['TRIGGER']['FUNC'] action_func = parse_tree['ACTION']['FUNC'].children assert len(action_func) == 1 action_func = action_func[0] parse_tree['ACTION'].children[0].add_child(action_func) del parse_tree['ACTION']['FUNC'] return parse_tree if __name__ == '__main__': tree_code = """(ROOT (IF) (TRIGGER (Instagram) (FUNC (Any_new_photo_by_you) (PARAMS))) (THEN) (ACTION (Dropbox) (FUNC (Add_file_from_URL) (PARAMS (File_URL ({{Caption}})) (File_name ("")) (Dropbox_folder_path (ifttt/instagram))))))""" parse_tree = ifttt_ast_to_parse_tree(tree_code) print parse_tree print strip_params(parse_tree) print attach_function_to_channel(parse_tree) ================================================ FILE: lang/py/__init__.py ================================================ ================================================ FILE: lang/py/grammar.py ================================================ """ Python grammar and typing system """ import ast import inspect import astor from lang.grammar import Grammar PY_AST_NODE_FIELDS = { 'FunctionDef': { 'name': { 'type': str, 'is_list': False, 'is_optional': False }, 'args': { 'type': ast.arguments, 'is_list': False, 'is_optional': False }, 'body': { 'type': ast.stmt, 'is_list': True, 'is_optional': False }, 'decorator_list': { 'type': ast.expr, 'is_list': True, 'is_optional': False } }, 'ClassDef': { 'name': { 'type': ast.arguments, 'is_list': False, 'is_optional': False }, 'bases': { 'type': ast.expr, 'is_list': True, 'is_optional': False }, 'body': { 'type': ast.stmt, 'is_list': True, 'is_optional': False }, 'decorator_list': { 'type': ast.expr, 'is_list': True, 'is_optional': False } }, 'Return': { 'value': { 'type': ast.expr, 'is_list': False, 'is_optional': True }, }, 'Delete': { 'targets': { 'type': ast.expr, 'is_list': True, 'is_optional': False }, }, 'Assign': { 'targets': { 'type': ast.expr, 'is_list': True, 'is_optional': False }, 'value': { 'type': ast.expr, 'is_list': False, 'is_optional': False } }, 'AugAssign': { 'target': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'op': { 'type': ast.operator, 'is_list': False, 'is_optional': False }, 'value': { 'type': ast.expr, 'is_list': False, 'is_optional': False } }, 'Print': { 'dest': { 'type': ast.expr, 'is_list': False, 'is_optional': True }, 'values': { 'type': ast.expr, 'is_list': True, 'is_optional': False }, 'nl': { 'type': bool, 'is_list': False, 'is_optional': False } }, 'For': { 'target': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'iter': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'body': { 'type': ast.stmt, 'is_list': True, 'is_optional': False }, 'orelse': { 'type': ast.stmt, 'is_list': True, 'is_optional': False } }, 'While': { 'test': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'body': { 'type': ast.stmt, 'is_list': True, 'is_optional': False }, 'orelse': { 'type': ast.stmt, 'is_list': True, 'is_optional': False }, }, 'If': { 'test': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'body': { 'type': ast.stmt, 'is_list': True, 'is_optional': False }, 'orelse': { 'type': ast.stmt, 'is_list': True, 'is_optional': False }, }, 'With': { 'context_expr': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'optional_vars': { 'type': ast.expr, 'is_list': False, 'is_optional': True }, 'body': { 'type': ast.stmt, 'is_list': True, 'is_optional': False }, }, 'Raise': { 'type': { 'type': ast.expr, 'is_list': False, 'is_optional': True }, 'inst': { 'type': ast.expr, 'is_list': False, 'is_optional': True }, 'tback': { 'type': ast.expr, 'is_list': False, 'is_optional': True }, }, 'TryExcept': { 'body': { 'type': ast.stmt, 'is_list': True, 'is_optional': False }, 'handlers': { 'type': ast.excepthandler, 'is_list': True, 'is_optional': False }, 'orelse': { 'type': ast.stmt, 'is_list': True, 'is_optional': False }, }, 'TryFinally': { 'body': { 'type': ast.stmt, 'is_list': True, 'is_optional': False }, 'finalbody': { 'type': ast.stmt, 'is_list': True, 'is_optional': False } }, 'Assert': { 'test': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'msg': { 'type': ast.expr, 'is_list': False, 'is_optional': True } }, 'Import': { 'names': { 'type': ast.alias, 'is_list': True, 'is_optional': False } }, 'ImportFrom': { 'module': { 'type': str, 'is_list': False, 'is_optional': True }, 'names': { 'type': ast.alias, 'is_list': True, 'is_optional': False }, 'level': { 'type': int, 'is_list': False, 'is_optional': True } }, 'Exec': { 'body': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'globals': { 'type': ast.expr, 'is_list': False, 'is_optional': True }, 'locals': { 'type': ast.expr, 'is_list': False, 'is_optional': True }, }, 'Global': { 'names': { 'type': str, 'is_list': True, 'is_optional': False }, }, 'Expr': { 'value': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, }, 'BoolOp': { 'op': { 'type': ast.boolop, 'is_list': False, 'is_optional': False }, 'values': { 'type': ast.expr, 'is_list': True, 'is_optional': False }, }, 'BinOp': { 'left': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'op': { 'type': ast.operator, 'is_list': False, 'is_optional': False }, 'right': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, }, 'UnaryOp': { 'op': { 'type': ast.unaryop, 'is_list': False, 'is_optional': False }, 'operand': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, }, 'Lambda': { 'args': { 'type': ast.arguments, 'is_list': False, 'is_optional': False }, 'body': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, }, 'IfExp': { 'test': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'body': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'orelse': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, }, 'Dict': { 'keys': { 'type': ast.expr, 'is_list': True, 'is_optional': False }, 'values': { 'type': ast.expr, 'is_list': True, 'is_optional': False }, }, 'Set': { 'elts': { 'type': ast.expr, 'is_list': True, 'is_optional': False }, }, 'ListComp': { 'elt': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'generators': { 'type': ast.comprehension, 'is_list': True, 'is_optional': False }, }, 'SetComp': { 'elt': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'generators': { 'type': ast.comprehension, 'is_list': True, 'is_optional': False }, }, 'DictComp': { 'key': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'value': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'generators': { 'type': ast.comprehension, 'is_list': True, 'is_optional': False }, }, 'GeneratorExp': { 'elt': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'generators': { 'type': ast.comprehension, 'is_list': True, 'is_optional': False }, }, 'Yield': { 'value': { 'type': ast.expr, 'is_list': False, 'is_optional': True } }, 'Compare': { 'left': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'ops': { 'type': ast.cmpop, 'is_list': True, 'is_optional': False }, 'comparators': { 'type': ast.expr, 'is_list': True, 'is_optional': False }, }, 'Call': { 'func': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'args': { 'type': ast.expr, 'is_list': True, 'is_optional': False }, 'keywords': { 'type': ast.keyword, 'is_list': True, 'is_optional': False }, 'starargs': { 'type': ast.expr, 'is_list': False, 'is_optional': True }, 'kwargs': { 'type': ast.expr, 'is_list': False, 'is_optional': True }, }, 'Repr': { 'value': { 'type': ast.expr, 'is_list': False, 'is_optional': False } }, 'Num': { 'n': { 'type': object, #FIXME: should be float or int? 'is_list': False, 'is_optional': False } }, 'Str': { 's': { 'type': str, 'is_list': False, 'is_optional': False } }, 'Attribute': { 'value': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'attr': { 'type': str, 'is_list': False, 'is_optional': False }, 'ctx': { 'type': ast.expr_context, 'is_list': False, 'is_optional': False }, }, 'Subscript': { 'value': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'slice': { 'type': ast.slice, 'is_list': False, 'is_optional': False }, }, 'Name': { 'id': { 'type': str, 'is_list': False, 'is_optional': False } }, 'List': { 'elts': { 'type': ast.expr, 'is_list': True, 'is_optional': False }, 'ctx': { 'type': ast.expr_context, 'is_list': False, 'is_optional': False }, }, 'Tuple': { 'elts': { 'type': ast.expr, 'is_list': True, 'is_optional': False }, 'ctx': { 'type': ast.expr_context, 'is_list': False, 'is_optional': False }, }, 'ExceptHandler': { 'type': { 'type': ast.expr, 'is_list': False, 'is_optional': True }, 'name': { 'type': ast.expr, 'is_list': False, 'is_optional': True }, 'body': { 'type': ast.stmt, 'is_list': True, 'is_optional': False } }, 'arguments': { 'args': { 'type': ast.expr, 'is_list': True, 'is_optional': False }, 'vararg': { 'type': str, 'is_list': False, 'is_optional': True }, 'kwarg': { 'type': str, 'is_list': False, 'is_optional': True }, 'defaults': { 'type': ast.expr, 'is_list': True, 'is_optional': False }, }, 'comprehension': { 'target': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'iter': { 'type': ast.expr, 'is_list': False, 'is_optional': False }, 'ifs': { 'type': ast.expr, 'is_list': True, 'is_optional': False }, }, 'keyword': { 'arg': { 'type': str, 'is_list': False, 'is_optional': False }, 'value': { 'type': ast.expr, 'is_list': False, 'is_optional': True } }, 'alias': { 'name': { 'type': str, 'is_list': False, 'is_optional': False }, 'asname': { 'type': str, 'is_list': False, 'is_optional': True } }, 'Slice': { 'lower': { 'type': ast.expr, 'is_list': False, 'is_optional': True }, 'upper': { 'type': ast.expr, 'is_list': False, 'is_optional': True }, 'step': { 'type': ast.expr, 'is_list': False, 'is_optional': True } }, 'ExtSlice': { 'dims': { 'type': ast.slice, 'is_list': True, 'is_optional': False } }, 'Index': { 'value': { 'type': ast.expr, 'is_list': False, 'is_optional': False } } } NODE_FIELD_BLACK_LIST = {'ctx'} TERMINAL_AST_TYPES = { ast.Pass, ast.Break, ast.Continue, ast.Add, ast.BitAnd, ast.BitOr, ast.BitXor, ast.Div, ast.FloorDiv, ast.LShift, ast.Mod, ast.Mult, ast.Pow, ast.Sub, ast.And, ast.Or, ast.Eq, ast.Gt, ast.GtE, ast.In, ast.Is, ast.IsNot, ast.Lt, ast.LtE, ast.NotEq, ast.NotIn, ast.Not, ast.USub } def is_builtin_type(x): return x == str or x == int or x == float or x == bool or x == object or x == 'identifier' def is_terminal_ast_type(x): if inspect.isclass(x) and x in TERMINAL_AST_TYPES: return True return False # def is_terminal_type(x): # if is_builtin_type(x): # return True # # if x == 'epsilon': # return True # # if inspect.isclass(x) and (issubclass(x, ast.Pass) or issubclass(x, ast.Raise) or issubclass(x, ast.Break) # or issubclass(x, ast.Continue) # or issubclass(x, ast.Return) # or issubclass(x, ast.operator) or issubclass(x, ast.boolop) # or issubclass(x, ast.Ellipsis) or issubclass(x, ast.unaryop) # or issubclass(x, ast.cmpop)): # return True # # return False # class Node: # def __init__(self, node_type, label): # self.type = node_type # self.label = label # # @property # def is_preterminal(self): # return is_terminal_type(self.type) # # def __eq__(self, other): # return self.type == other.type and self.label == other.label # # def __hash__(self): # return typename(self.type).__hash__() ^ self.label.__hash__() # # def __repr__(self): # repr_str = typename(self.type) # if self.label: # repr_str += '{%s}' % self.label # return repr_str # # # class TypedRule: # def __init__(self, parent, children, tree=None): # self.parent = parent # if isinstance(children, list) or isinstance(children, tuple): # self.children = tuple(children) # else: # self.children = (children, ) # # # tree property is not incorporated in eq, hash # self.tree = tree # # # @property # # def is_terminal_rule(self): # # return is_terminal_type(self.parent.type) # # def __eq__(self, other): # return self.parent == other.parent and self.children == other.children # # def __hash__(self): # return self.parent.__hash__() ^ self.children.__hash__() # # def __repr__(self): # return '%s -> %s' % (self.parent, ', '.join([c.__repr__() for c in self.children])) def type_str_to_type(type_str): if type_str.endswith('*') or type_str == 'root' or type_str == 'epsilon': return type_str else: try: type_obj = eval(type_str) if is_builtin_type(type_obj): return type_obj except: pass try: type_obj = eval('ast.' + type_str) return type_obj except: raise RuntimeError('unidentified type string: %s' % type_str) def is_compositional_leaf(node): is_leaf = True for field_name, field_value in ast.iter_fields(node): if field_name in NODE_FIELD_BLACK_LIST: continue if field_value is None: is_leaf &= True elif isinstance(field_value, list) and len(field_value) == 0: is_leaf &= True else: is_leaf &= False return is_leaf class PythonGrammar(Grammar): def __init__(self, rules): super(PythonGrammar, self).__init__(rules) def is_value_node(self, node): return is_builtin_type(node.type) ================================================ FILE: lang/py/parse.py ================================================ import ast import logging import re import token as tk from cStringIO import StringIO from tokenize import generate_tokens from astnode import ASTNode from lang.py.grammar import is_compositional_leaf, PY_AST_NODE_FIELDS, NODE_FIELD_BLACK_LIST, PythonGrammar from lang.util import escape from lang.util import typename def python_ast_to_parse_tree(node): assert isinstance(node, ast.AST) node_type = type(node) tree = ASTNode(node_type) # it's a leaf AST node, e.g., ADD, Break, etc. if len(node._fields) == 0: return tree # if it's a compositional AST node with empty fields if is_compositional_leaf(node): epsilon = ASTNode('epsilon') tree.add_child(epsilon) return tree fields_info = PY_AST_NODE_FIELDS[node_type.__name__] for field_name, field_value in ast.iter_fields(node): # remove ctx stuff if field_name in NODE_FIELD_BLACK_LIST: continue # omit empty fields, including empty lists if field_value is None or (isinstance(field_value, list) and len(field_value) == 0): continue # now it's not empty! field_type = fields_info[field_name]['type'] is_list_field = fields_info[field_name]['is_list'] if isinstance(field_value, ast.AST): child = ASTNode(field_type, field_name) child.add_child(python_ast_to_parse_tree(field_value)) elif type(field_value) is str or type(field_value) is int or \ type(field_value) is float or type(field_value) is object or \ type(field_value) is bool: # if field_type != type(field_value): # print 'expect [%s] type, got [%s]' % (field_type, type(field_value)) child = ASTNode(type(field_value), field_name, value=field_value) elif is_list_field: list_node_type = typename(field_type) + '*' child = ASTNode(list_node_type, field_name) for n in field_value: if field_type in {ast.comprehension, ast.excepthandler, ast.arguments, ast.keyword, ast.alias}: child.add_child(python_ast_to_parse_tree(n)) else: intermediate_node = ASTNode(field_type) if field_type is str: intermediate_node.value = n else: intermediate_node.add_child(python_ast_to_parse_tree(n)) child.add_child(intermediate_node) else: raise RuntimeError('unknown AST node field!') tree.add_child(child) return tree def parse_tree_to_python_ast(tree): node_type = tree.type node_label = tree.label # remove root if node_type == 'root': return parse_tree_to_python_ast(tree.children[0]) ast_node = node_type() node_type_name = typename(node_type) # if it's a compositional AST node, populate its children nodes, # fill fields with empty(default) values otherwise if node_type_name in PY_AST_NODE_FIELDS: fields_info = PY_AST_NODE_FIELDS[node_type_name] for child_node in tree.children: # if it's a compositional leaf if child_node.type == 'epsilon': continue field_type = child_node.type field_label = child_node.label field_entry = fields_info[field_label] is_list = field_entry['is_list'] if is_list: field_type = field_entry['type'] field_value = [] if field_type in {ast.comprehension, ast.excepthandler, ast.arguments, ast.keyword, ast.alias}: nodes_in_list = child_node.children for sub_node in nodes_in_list: sub_node_ast = parse_tree_to_python_ast(sub_node) field_value.append(sub_node_ast) else: # expr stuffs inter_nodes = child_node.children for inter_node in inter_nodes: if inter_node.value is None: assert len(inter_node.children) == 1 sub_node_ast = parse_tree_to_python_ast(inter_node.children[0]) field_value.append(sub_node_ast) else: assert len(inter_node.children) == 0 field_value.append(inter_node.value) else: # this node either holds a value, or is an non-terminal if child_node.value is None: assert len(child_node.children) == 1 field_value = parse_tree_to_python_ast(child_node.children[0]) else: assert child_node.is_leaf field_value = child_node.value setattr(ast_node, field_label, field_value) for field in ast_node._fields: if not hasattr(ast_node, field) and not field in NODE_FIELD_BLACK_LIST: if fields_info and fields_info[field]['is_list'] and not fields_info[field]['is_optional']: setattr(ast_node, field, list()) else: setattr(ast_node, field, None) return ast_node def decode_tree_to_python_ast(decode_tree): from lang.py.unaryclosure import compressed_ast_to_normal compressed_ast_to_normal(decode_tree) decode_tree = decode_tree.children[0] terminals = decode_tree.get_leaves() for terminal in terminals: if terminal.value is not None and type(terminal.value) is str: if terminal.value.endswith(''): terminal.value = terminal.value[:-5] if terminal.type in {int, float, str, bool}: # cast to target data type terminal.value = terminal.type(terminal.value) ast_tree = parse_tree_to_python_ast(decode_tree) return ast_tree p_elif = re.compile(r'^elif\s?') p_else = re.compile(r'^else\s?') p_try = re.compile(r'^try\s?') p_except = re.compile(r'^except\s?') p_finally = re.compile(r'^finally\s?') p_decorator = re.compile(r'^@.*') def canonicalize_code(code): if p_elif.match(code): code = 'if True: pass\n' + code if p_else.match(code): code = 'if True: pass\n' + code if p_try.match(code): code = code + 'pass\nexcept: pass' elif p_except.match(code): code = 'try: pass\n' + code elif p_finally.match(code): code = 'try: pass\n' + code if p_decorator.match(code): code = code + '\ndef dummy(): pass' if code[-1] == ':': code = code + 'pass' return code def de_canonicalize_code(code, ref_raw_code): if code.endswith('def dummy():\n pass'): code = code.replace('def dummy():\n pass', '').strip() if p_elif.match(ref_raw_code): # remove leading if true code = code.replace('if True:\n pass', '').strip() elif p_else.match(ref_raw_code): # remove leading if true code = code.replace('if True:\n pass', '').strip() # try/catch/except stuff if p_try.match(ref_raw_code): code = code.replace('except:\n pass', '').strip() elif p_except.match(ref_raw_code): code = code.replace('try:\n pass', '').strip() elif p_finally.match(ref_raw_code): code = code.replace('try:\n pass', '').strip() # remove ending pass if code.endswith(':\n pass'): code = code[:-len('\n pass')] return code def de_canonicalize_code_for_seq2seq(code, ref_raw_code): if code.endswith('\ndef dummy(): pass'): code = code.replace('\ndef dummy(): pass', '').strip() if p_elif.match(ref_raw_code): # remove leading if true code = code.replace('if True: pass\n', '').strip() elif p_else.match(ref_raw_code): # remove leading if true code = code.replace('if True: pass\n', '').strip() # try/catch/except stuff if p_try.match(ref_raw_code): code = code.replace('pass\nexcept: pass', '').strip() elif p_except.match(ref_raw_code): code = code.replace('try: pass\n', '').strip() elif p_finally.match(ref_raw_code): code = code.replace('try: pass\n', '').strip() # remove ending pass if code.endswith(':pass'): code = code[:-len('pass')] return code.strip() def add_root(tree): root_node = ASTNode('root') root_node.add_child(tree) return root_node def parse(code): """ parse a python code into a tree structure code -> AST tree -> AST tree to internal tree structure """ code = canonicalize_code(code) py_ast = ast.parse(code) tree = python_ast_to_parse_tree(py_ast.body[0]) tree = add_root(tree) return tree def parse_raw(code): py_ast = ast.parse(code) tree = python_ast_to_parse_tree(py_ast.body[0]) tree = add_root(tree) return tree def get_grammar(parse_trees): rules = set() # rule_num_dist = defaultdict(int) for parse_tree in parse_trees: parse_tree_rules, rule_parents = parse_tree.get_productions() for rule in parse_tree_rules: rules.add(rule) rules = list(sorted(rules, key=lambda x: x.__repr__())) grammar = PythonGrammar(rules) logging.info('num. rules: %d', len(rules)) return grammar def tokenize_code(code): token_stream = generate_tokens(StringIO(code).readline) tokens = [] for toknum, tokval, (srow, scol), (erow, ecol), _ in token_stream: if toknum == tk.ENDMARKER: break tokens.append(tokval) return tokens def tokenize_code_adv(code, breakCamelStr=False): token_stream = generate_tokens(StringIO(code).readline) tokens = [] indent_level = 0 for toknum, tokval, (srow, scol), (erow, ecol), _ in token_stream: if toknum == tk.ENDMARKER: break if toknum == tk.INDENT: indent_level += 1 tokens.extend(['#INDENT#'] * indent_level) continue elif toknum == tk.DEDENT: indent_level -= 1 tokens.extend(['#INDENT#'] * indent_level) continue elif len(tokens) > 0 and tokens[-1] == '\n' and tokval != '\n': tokens.extend(['#INDENT#'] * indent_level) if toknum == tk.STRING: quote = tokval[0] tokval = tokval[1:-1] tokens.append(quote) if breakCamelStr: sub_tokens = re.sub(r'([a-z])([A-Z])', r'\1 #MERGE# \2', tokval).split(' ') tokens.extend(sub_tokens) else: tokens.append(tokval) if toknum == tk.STRING: tokens.append(quote) return tokens if __name__ == '__main__': from nn.utils.generic_utils import init_logging init_logging('misc.log') # django_code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code' # # grammar, parse_trees = extract_grammar(django_code_file) # id = 1888 # parse_tree = parse_trees[id] # print parse_tree # from components import Hyp # hyp = Hyp(grammar) # rules, rule_parents = parse_tree.get_productions() # # while hyp.frontier_nt(): # nt = hyp.frontier_nt() # if grammar.is_value_node(nt): # hyp.append_token('111') # else: # rule = rules[0] # hyp.apply_rule(rule) # del rules[0] # # print hyp # # ast_tree = decode_tree_to_python_ast(hyp.tree) # source = astor.to_source(ast_tree) # print source # for code in open(django_code_file): # code = code.strip() # ref_ast_tree = ast.parse(canonicalize_code(code)).body[0] # parse_tree = parse(code) # ast_tree = parse_tree_to_python_ast(parse_tree) # source1 = astor.to_source(ast_tree) # source2 = astor.to_source(ref_ast_tree) # # if source1 != source2: # pass code = """ class Demonwrath(SpellCard): def __init__(self): super().__init__("Demonwrath", 3, CHARACTER_CLASS.WARLOCK, CARD_RARITY.RARE) def use(self, player, game): super().use(player, game) targets = copy.copy(game.other_player.minions) targets.extend(game.current_player.minions) for minion in targets: if minion.card.minion_type is not MINION_TYPE.DEMON: minion.damage(player.effective_spell_damage(2), self) """ code = """sorted(mydict, key=mydict.get, reverse=True)""" # # code = """a = [1,2,3,4,'asdf', 234.3]""" parse_tree = parse(code) # for leaf in parse_tree.get_leaves(): # if leaf.value: print escape(leaf.value) # print parse_tree # ast_tree = parse_tree_to_python_ast(parse_tree) # print astor.to_source(ast_tree) ================================================ FILE: lang/py/py_dataset.py ================================================ # -*- coding: UTF-8 -*- from __future__ import division import ast import astor import logging from itertools import chain import nltk import re from nn.utils.io_utils import serialize_to_file, deserialize_from_file from nn.utils.generic_utils import init_logging from dataset import gen_vocab, DataSet, DataEntry, Action, APPLY_RULE, GEN_TOKEN, COPY_TOKEN, GEN_COPY_TOKEN, Vocab from lang.py.parse import parse, parse_tree_to_python_ast, canonicalize_code, get_grammar, parse_raw, \ de_canonicalize_code, tokenize_code, tokenize_code_adv, de_canonicalize_code_for_seq2seq from lang.py.unaryclosure import get_top_unary_closures, apply_unary_closures def extract_grammar(code_file, prefix='py'): line_num = 0 parse_trees = [] for line in open(code_file): code = line.strip() parse_tree = parse(code) # leaves = parse_tree.get_leaves() # for leaf in leaves: # if not is_terminal_type(leaf.type): # print parse_tree # parse_tree = add_root(parse_tree) parse_trees.append(parse_tree) # sanity check ast_tree = parse_tree_to_python_ast(parse_tree) ref_ast_tree = ast.parse(canonicalize_code(code)).body[0] source1 = astor.to_source(ast_tree) source2 = astor.to_source(ref_ast_tree) assert source1 == source2 # check rules # rule_list = parse_tree.get_rule_list(include_leaf=True) # for rule in rule_list: # if rule.parent.type == int and rule.children[0].type == int: # # rule.parent.type == str and rule.children[0].type == str: # pass # ast_tree = tree_to_ast(parse_tree) # print astor.to_source(ast_tree) # print parse_tree # except Exception as e: # error_num += 1 # #pass # #print e line_num += 1 print 'total line of code: %d' % line_num grammar = get_grammar(parse_trees) with open(prefix + '.grammar.txt', 'w') as f: for rule in grammar: str = rule.__repr__() f.write(str + '\n') with open(prefix + '.parse_trees.txt', 'w') as f: for tree in parse_trees: f.write(tree.__repr__() + '\n') return grammar, parse_trees def rule_vs_node_stat(): line_num = 0 parse_trees = [] code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out' # '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code' node_nums = rule_nums = 0. for line in open(code_file): code = line.replace('§', '\n').strip() parse_tree = parse(code) node_nums += len(list(parse_tree.nodes)) rules, _ = parse_tree.get_productions() rule_nums += len(rules) parse_trees.append(parse_tree) line_num += 1 print 'avg. nums of nodes: %f' % (node_nums / line_num) print 'avg. nums of rules: %f' % (rule_nums / line_num) def process_heart_stone_dataset(): data_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out' parse_trees = [] rule_num = 0. example_num = 0 for line in open(data_file): code = line.replace('§', '\n').strip() parse_tree = parse(code) # sanity check pred_ast = parse_tree_to_python_ast(parse_tree) pred_code = astor.to_source(pred_ast) ref_ast = ast.parse(code) ref_code = astor.to_source(ref_ast) if pred_code != ref_code: raise RuntimeError('code mismatch!') rules, _ = parse_tree.get_productions(include_value_node=False) rule_num += len(rules) example_num += 1 parse_trees.append(parse_tree) grammar = get_grammar(parse_trees) with open('hs.grammar.txt', 'w') as f: for rule in grammar: str = rule.__repr__() f.write(str + '\n') with open('hs.parse_trees.txt', 'w') as f: for tree in parse_trees: f.write(tree.__repr__() + '\n') print 'avg. nums of rules: %f' % (rule_num / example_num) def canonicalize_hs_example(query, code): query = re.sub(r'<.*?>', '', query) query_tokens = nltk.word_tokenize(query) code = code.replace('§', '\n').strip() # sanity check parse_tree = parse_raw(code) gold_ast_tree = ast.parse(code).body[0] gold_source = astor.to_source(gold_ast_tree) ast_tree = parse_tree_to_python_ast(parse_tree) pred_source = astor.to_source(ast_tree) assert gold_source == pred_source, 'sanity check fails: gold=[%s], actual=[%s]' % (gold_source, pred_source) return query_tokens, code, parse_tree def preprocess_hs_dataset(annot_file, code_file): f = open('hs_dataset.examples.txt', 'w') examples = [] for idx, (annot, code) in enumerate(zip(open(annot_file), open(code_file))): annot = annot.strip() code = code.strip() clean_query_tokens, clean_code, parse_tree = canonicalize_hs_example(annot, code) example = {'id': idx, 'query_tokens': clean_query_tokens, 'code': clean_code, 'parse_tree': parse_tree, 'str_map': None, 'raw_code': code} examples.append(example) f.write('*' * 50 + '\n') f.write('example# %d\n' % idx) f.write(' '.join(clean_query_tokens) + '\n') f.write('\n') f.write(clean_code + '\n') f.write('*' * 50 + '\n') idx += 1 f.close() print 'preprocess_dataset: cleaned example num: %d' % len(examples) return examples def parse_hs_dataset(): MAX_QUERY_LENGTH = 70 # FIXME: figure out the best config! WORD_FREQ_CUT_OFF = 3 annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.mod.in' code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out' data = preprocess_hs_dataset(annot_file, code_file) parse_trees = [e['parse_tree'] for e in data] # apply unary closures unary_closures = get_top_unary_closures(parse_trees, k=20) for parse_tree in parse_trees: apply_unary_closures(parse_tree, unary_closures) # build the grammar grammar = get_grammar(parse_trees) with open('hs.grammar.unary_closure.txt', 'w') as f: for rule in grammar: f.write(rule.__repr__() + '\n') annot_tokens = list(chain(*[e['query_tokens'] for e in data])) annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=WORD_FREQ_CUT_OFF) def get_terminal_tokens(_terminal_str): """ get terminal tokens break words like MinionCards into [Minion, Cards] """ tmp_terminal_tokens = [t for t in _terminal_str.split(' ') if len(t) > 0] _terminal_tokens = [] for token in tmp_terminal_tokens: sub_tokens = re.sub(r'([a-z])([A-Z])', r'\1 \2', token).split(' ') _terminal_tokens.extend(sub_tokens) _terminal_tokens.append(' ') return _terminal_tokens[:-1] # enumerate all terminal tokens to build up the terminal tokens vocabulary all_terminal_tokens = [] for entry in data: parse_tree = entry['parse_tree'] for node in parse_tree.get_leaves(): if grammar.is_value_node(node): terminal_val = node.value terminal_str = str(terminal_val) terminal_tokens = get_terminal_tokens(terminal_str) for terminal_token in terminal_tokens: assert len(terminal_token) > 0 all_terminal_tokens.append(terminal_token) terminal_vocab = gen_vocab(all_terminal_tokens, vocab_size=5000, freq_cutoff=WORD_FREQ_CUT_OFF) # now generate the dataset! train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.train_data') dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.dev_data') test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.test_data') all_examples = [] can_fully_reconstructed_examples_num = 0 examples_with_empty_actions_num = 0 for entry in data: idx = entry['id'] query_tokens = entry['query_tokens'] code = entry['code'] parse_tree = entry['parse_tree'] rule_list, rule_parents = parse_tree.get_productions(include_value_node=True) actions = [] can_fully_reconstructed = True rule_pos_map = dict() for rule_count, rule in enumerate(rule_list): if not grammar.is_value_node(rule.parent): assert rule.value is None parent_rule = rule_parents[(rule_count, rule)][0] if parent_rule: parent_t = rule_pos_map[parent_rule] else: parent_t = 0 rule_pos_map[rule] = len(actions) d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule} action = Action(APPLY_RULE, d) actions.append(action) else: assert rule.is_leaf parent_rule = rule_parents[(rule_count, rule)][0] parent_t = rule_pos_map[parent_rule] terminal_val = rule.value terminal_str = str(terminal_val) terminal_tokens = get_terminal_tokens(terminal_str) # assert len(terminal_tokens) > 0 for terminal_token in terminal_tokens: term_tok_id = terminal_vocab[terminal_token] tok_src_idx = -1 try: tok_src_idx = query_tokens.index(terminal_token) except ValueError: pass d = {'literal': terminal_token, 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t} # cannot copy, only generation # could be unk! if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH: action = Action(GEN_TOKEN, d) if terminal_token not in terminal_vocab: if terminal_token not in query_tokens: # print terminal_token can_fully_reconstructed = False else: # copy if term_tok_id != terminal_vocab.unk: d['source_idx'] = tok_src_idx action = Action(GEN_COPY_TOKEN, d) else: d['source_idx'] = tok_src_idx action = Action(COPY_TOKEN, d) actions.append(action) d = {'literal': '', 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t} actions.append(Action(GEN_TOKEN, d)) if len(actions) == 0: examples_with_empty_actions_num += 1 continue example = DataEntry(idx, query_tokens, parse_tree, code, actions, {'str_map': None, 'raw_code': entry['raw_code']}) if can_fully_reconstructed: can_fully_reconstructed_examples_num += 1 # train, valid, test splits if 0 <= idx < 533: train_data.add(example) elif idx < 599: dev_data.add(example) else: test_data.add(example) all_examples.append(example) # print statistics max_query_len = max(len(e.query) for e in all_examples) max_actions_len = max(len(e.actions) for e in all_examples) # serialize_to_file([len(e.query) for e in all_examples], 'query.len') # serialize_to_file([len(e.actions) for e in all_examples], 'actions.len') logging.info('examples that can be fully reconstructed: %d/%d=%f', can_fully_reconstructed_examples_num, len(all_examples), can_fully_reconstructed_examples_num / len(all_examples)) logging.info('empty_actions_count: %d', examples_with_empty_actions_num) logging.info('max_query_len: %d', max_query_len) logging.info('max_actions_len: %d', max_actions_len) train_data.init_data_matrices(max_query_length=70, max_example_action_num=350) dev_data.init_data_matrices(max_query_length=70, max_example_action_num=350) test_data.init_data_matrices(max_query_length=70, max_example_action_num=350) serialize_to_file((train_data, dev_data, test_data), 'data/hs.freq{WORD_FREQ_CUT_OFF}.max_action350.pre_suf.unary_closure.bin'.format(WORD_FREQ_CUT_OFF=WORD_FREQ_CUT_OFF)) return train_data, dev_data, test_data def dump_data_for_evaluation(data_type='django', data_file='', max_query_length=70): train_data, dev_data, test_data = deserialize_from_file(data_file) prefix = '/Users/yinpengcheng/Projects/dl4mt-tutorial/codegen_data/' for dataset, output in [(train_data, prefix + '%s.train' % data_type), (dev_data, prefix + '%s.dev' % data_type), (test_data, prefix + '%s.test' % data_type)]: f_source = open(output + '.desc', 'w') f_target = open(output + '.code', 'w') for e in dataset.examples: query_tokens = e.query[:max_query_length] code = e.code if data_type == 'django': target_code = de_canonicalize_code_for_seq2seq(code, e.meta_data['raw_code']) else: target_code = code # tokenize code target_code = target_code.strip() tokenized_target = tokenize_code_adv(target_code, breakCamelStr=False if data_type=='django' else True) tokenized_target = [tk.replace('\n', '#NEWLINE#') for tk in tokenized_target] tokenized_target = [tk for tk in tokenized_target if tk is not None] while tokenized_target[-1] == '#INDENT#': tokenized_target = tokenized_target[:-1] f_source.write(' '.join(query_tokens) + '\n') f_target.write(' '.join(tokenized_target) + '\n') f_source.close() f_target.close() if __name__ == '__main__': init_logging('py.log') # rule_vs_node_stat() # process_heart_stone_dataset() parse_hs_dataset() # dump_data_for_evaluation(data_file='data/django.cleaned.dataset.freq5.par_info.refact.space_only.bin') # dump_data_for_evaluation(data_type='hs', data_file='data/hs.freq3.pre_suf.unary_closure.bin') # code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code' # py_grammar, _ = extract_grammar(code_file) # serialize_to_file(py_grammar, 'py_grammar.bin') ================================================ FILE: lang/py/seq2tree_exp.py ================================================ import logging import re from collections import defaultdict, OrderedDict from itertools import chain import sys from astnode import ASTNode from dataset import preprocess_dataset, gen_vocab from lang.py.grammar import type_str_to_type from lang.py.parse import parse, get_grammar, decode_tree_to_python_ast from lang.py.unaryclosure import get_top_unary_closures, apply_unary_closures from lang.util import typename, escape, unescape from nn.utils.generic_utils import init_logging from nn.utils.io_utils import serialize_to_file def ast_tree_to_seq2tree_repr(tree): repr_str = '' # node_name = typename(tree.type) label_val = '' if tree.label is None else tree.label value = '' if tree.value is None else tree.value node_name = '%s{%s}{%s}' % (typename(tree.type), label_val, value) repr_str += node_name # wrap children with parentheses if tree.children: repr_str += ' (' for child in tree.children: child_repr = ast_tree_to_seq2tree_repr(child) repr_str += ' ' + child_repr repr_str += ' )' return repr_str node_re = re.compile(r'(?P.*?)\{(?P