[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nenv/\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*,cover\n.hypothesis/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# IPython Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# dotenv\n.env\n\n# virtualenv\nvenv/\nENV/\n\n# Spyder project settings\n.spyderproject\n\n# Rope project settings\n.ropeproject\n"
  },
  {
    "path": "README.md",
    "content": "# NL2code\n\nA syntactic neural model for parsing natural language to executable code [paper](https://arxiv.org/abs/1704.01696). \n\n## Dataset and Trained Models\n\nGet serialized datasets and trained models from [here](https://drive.google.com/drive/folders/0B14lJ2VVvtmJWEQ5RlFjQUY2Vzg). Put `models/` and `data/` folders under the root directory of the project.\n\n## Usage\n\nTo train new model\n\n```bash\n. train.sh [hs|django]\n```\n\nTo use trained model for decoding test sets\n\n```bash\n. run_trained_model.sh [hs|django]\n```\n\n## Dependencies\n\n* Theano\n* vprof\n* NLTK 3.2.1\n* astor 0.6\n\n## Reference\n\n```\n@inproceedings{yin17acl,\n    title = {A Syntactic Neural Model for General-Purpose Code Generation},\n    author = {Pengcheng Yin and Graham Neubig},\n    booktitle = {The 55th Annual Meeting of the Association for Computational Linguistics (ACL)},\n    address = {Vancouver, Canada},\n    month = {July},\n    url = {https://arxiv.org/abs/1704.01696},\n    year = {2017}\n}\n```\n"
  },
  {
    "path": "astnode.py",
    "content": "from collections import namedtuple\nimport cPickle\nfrom collections import Iterable, OrderedDict, defaultdict\nfrom cStringIO import StringIO\n\nfrom lang.util import typename\n\nclass ASTNode(object):\n    def __init__(self, node_type, label=None, value=None, children=None):\n        self.type = node_type\n        self.label = label\n        self.value = value\n\n        if type(self) is not Rule:\n            self.parent = None\n\n        self.children = list()\n\n        if children:\n            if isinstance(children, Iterable):\n                for child in children:\n                    self.add_child(child)\n            elif isinstance(children, ASTNode):\n                self.add_child(children)\n            else:\n                raise AttributeError('Wrong type for child nodes')\n\n        assert not (bool(children) and bool(value)), 'terminal node with a value cannot have children'\n\n    @property\n    def is_leaf(self):\n        return len(self.children) == 0\n\n    @property\n    def is_preterminal(self):\n        return len(self.children) == 1 and self.children[0].is_leaf\n\n    @property\n    def size(self):\n        if self.is_leaf:\n            return 1\n\n        node_num = 1\n        for child in self.children:\n            node_num += child.size\n\n        return node_num\n\n    @property\n    def nodes(self):\n        \"\"\"a generator that returns all the nodes\"\"\"\n\n        yield self\n        for child in self.children:\n            for child_n in child.nodes:\n                yield child_n\n\n    @property\n    def as_type_node(self):\n        \"\"\"return an ASTNode with type information only\"\"\"\n        return ASTNode(self.type)\n\n    def __repr__(self):\n        repr_str = ''\n        # if not self.is_leaf:\n        repr_str += '('\n\n        repr_str += typename(self.type)\n\n        if self.label is not None:\n            repr_str += '{%s}' % self.label\n\n        if self.value is not None:\n            
repr_str += '{val=%s}' % self.value\n\n        # if not self.is_leaf:\n        for child in self.children:\n            repr_str += ' ' + child.__repr__()\n        repr_str += ')'\n\n        return repr_str\n\n    def __hash__(self):\n        code = hash(self.type)\n        if self.label is not None:\n            code = code * 37 + hash(self.label)\n        if self.value is not None:\n            code = code * 37 + hash(self.value)\n        for child in self.children:\n            code = code * 37 + hash(child)\n\n        return code\n\n    def __eq__(self, other):\n        if not isinstance(other, self.__class__):\n            return False\n        if hash(self) != hash(other):\n            return False\n\n        if self.type != other.type:\n            return False\n\n        if self.label != other.label:\n            return False\n\n        if self.value != other.value:\n            return False\n\n        if len(self.children) != len(other.children):\n            return False\n\n        for i in xrange(len(self.children)):\n            if self.children[i] != other.children[i]:\n                return False\n\n        return True\n\n    def __ne__(self, other):\n        return not self.__eq__(other)\n\n    def __getitem__(self, child_type):\n        return next(iter([c for c in self.children if c.type == child_type]))\n\n    def __delitem__(self, child_type):\n        tgt_child = [c for c in self.children if c.type == child_type]\n        if tgt_child:\n            assert len(tgt_child) == 1, 'unsafe deletion for more than one children'\n            tgt_child = tgt_child[0]\n            self.children.remove(tgt_child)\n        else:\n            raise KeyError\n\n    def add_child(self, child):\n        child.parent = self\n        self.children.append(child)\n\n    def get_child_id(self, child):\n        for i, _child in enumerate(self.children):\n            if child == _child:\n                return i\n\n        raise KeyError\n\n    def 
pretty_print(self):\n        sb = StringIO()\n        new_line = False\n        self.pretty_print_helper(sb, 0, new_line)\n        return sb.getvalue()\n\n    def pretty_print_helper(self, sb, depth, new_line=False):\n        if new_line:\n            sb.write('\\n')\n            for i in xrange(depth): sb.write(' ')\n\n        sb.write('(')\n        sb.write(typename(self.type))\n        if self.label is not None:\n            sb.write('{%s}' % self.label)\n\n        if self.value is not None:\n            sb.write('{val=%s}' % self.value)\n\n        if len(self.children) == 0:\n            sb.write(')')\n            return\n\n        sb.write(' ')\n        new_line = True\n        for child in self.children:\n            child.pretty_print_helper(sb, depth + 2, new_line)\n\n        sb.write('\\n')\n        for i in xrange(depth): sb.write(' ')\n        sb.write(')')\n\n    def get_leaves(self):\n        if self.is_leaf:\n            return [self]\n\n        leaves = []\n        for child in self.children:\n            leaves.extend(child.get_leaves())\n\n        return leaves\n\n    def to_rule(self, include_value=False):\n        \"\"\"\n        transform the current AST node to a production rule\n        \"\"\"\n        rule = Rule(self.type)\n        for c in self.children:\n            val = c.value if include_value else None\n            child = ASTNode(c.type, c.label, val)\n            rule.add_child(child)\n\n        return rule\n\n    def get_productions(self, include_value_node=False):\n        \"\"\"\n        get the depth-first, left-to-right sequence of rule applications\n        returns a list of production rules and a map to their parent rules\n        attention: node value is not included in child nodes\n        \"\"\"\n        rule_list = list()\n        rule_parents = OrderedDict()\n        node_rule_map = dict()\n        s = list()\n        s.append(self)\n        rule_num = 0\n\n        while len(s) > 0:\n            node = s.pop()\n           
 for child in reversed(node.children):\n                if not child.is_leaf:\n                    s.append(child)\n                elif include_value_node:\n                    if child.value is not None:\n                        s.append(child)\n\n            # only non-terminals and terminal nodes holding values\n            # can form a production rule\n            if node.children or node.value is not None:\n                rule = Rule(node.type)\n                if include_value_node:\n                    rule.value = node.value\n\n                for c in node.children:\n                    val = None\n                    child = ASTNode(c.type, c.label, val)\n                    rule.add_child(child)\n\n                rule_list.append(rule)\n                if node.parent:\n                    child_id = node.parent.get_child_id(node)\n                    parent_rule = node_rule_map[node.parent]\n                    rule_parents[(rule_num, rule)] = (parent_rule, child_id)\n                else:\n                    rule_parents[(rule_num, rule)] = (None, -1)\n                rule_num += 1\n\n                node_rule_map[node] = rule\n\n        return rule_list, rule_parents\n\n    def copy(self):\n        # if not hasattr(self, '_dump'):\n        #     dump = cPickle.dumps(self, -1)\n        #     setattr(self, '_dump', dump)\n        #\n        #     return cPickle.loads(dump)\n        #\n        # return cPickle.loads(self._dump)\n\n        new_tree = ASTNode(self.type, self.label, self.value)\n        if self.is_leaf:\n            return new_tree\n\n        for child in self.children:\n            new_tree.add_child(child.copy())\n\n        return new_tree\n\n\nclass DecodeTree(ASTNode):\n    def __init__(self, node_type, label=None, value=None, children=None, t=-1):\n        super(DecodeTree, self).__init__(node_type, label, value, children)\n\n        # record the time step when this subtree is created from a rule application\n        self.t = t\n    
    # record the ApplyRule action that is used to expand the current node\n        self.applied_rule = None\n\n    def copy(self):\n        new_tree = DecodeTree(self.type, self.label, value=self.value, t=self.t)\n        new_tree.applied_rule = self.applied_rule\n        if self.is_leaf:\n            return new_tree\n\n        for child in self.children:\n            new_tree.add_child(child.copy())\n\n        return new_tree\n\n\nclass Rule(ASTNode):\n    def __init__(self, *args, **kwargs):\n        super(Rule, self).__init__(*args, **kwargs)\n\n        assert self.value is None and self.label is None, 'Rule LHS cannot have values or labels'\n\n    @property\n    def parent(self):\n        return self.as_type_node\n\n    def __repr__(self):\n        parent = typename(self.type)\n\n        if self.label is not None:\n            parent += '{%s}' % self.label\n\n        if self.value is not None:\n            parent += '{val=%s}' % self.value\n\n        return '%s -> %s' % (parent, ', '.join([repr(c) for c in self.children]))\n\n\nif __name__ == '__main__':\n    import ast\n    t1 = ASTNode('root', children=[\n        ASTNode(str, 'a1_label', children=[ASTNode(int, children=[ASTNode('a21', value=123)]),\n                                            ASTNode(ast.NodeTransformer, children=[ASTNode('a21', value='hahaha')])]\n                ),\n        ASTNode('a2', children=[ASTNode('a21', value='asdf')])\n    ])\n\n    t2 = ASTNode('root', children=[\n        ASTNode(str, 'a1_label', children=[ASTNode(int, children=[ASTNode('a21', value=123)]),\n                                           ASTNode(ast.NodeTransformer, children=[ASTNode('a21', value='hahaha')])]\n                ),\n        ASTNode('a2', children=[ASTNode('a21', value='asdf')])\n    ])\n\n    print t1 == t2\n\n\n    a, b = t1.get_productions(include_value_node=True)\n\n    # t = ASTNode('root', children=ASTNode('sdf'))\n\n    print t1.__repr__()\n    print t1.pretty_print()"
  },
  {
    "path": "code_gen.py",
    "content": "import numpy as np\nimport cProfile\nimport ast\nimport traceback\nimport argparse\nimport os\nimport logging\nfrom vprof import profiler\n\nfrom model import Model\nfrom dataset import DataEntry, DataSet, Vocab, Action\nimport config\nfrom learner import Learner\nfrom evaluation import *\nfrom decoder import decode_python_dataset\nfrom components import Hyp\nfrom astnode import ASTNode\n\nfrom nn.utils.generic_utils import init_logging\nfrom nn.utils.io_utils import deserialize_from_file, serialize_to_file\n\nparser = argparse.ArgumentParser()\nparser.add_argument('-data')\nparser.add_argument('-random_seed', default=181783, type=int)\nparser.add_argument('-output_dir', default='.outputs')\nparser.add_argument('-model', default=None)\n\n# model's main configuration\nparser.add_argument('-data_type', default='django', choices=['django', 'ifttt', 'hs'])\n\n# neural model's parameters\nparser.add_argument('-source_vocab_size', default=0, type=int)\nparser.add_argument('-target_vocab_size', default=0, type=int)\nparser.add_argument('-rule_num', default=0, type=int)\nparser.add_argument('-node_num', default=0, type=int)\n\nparser.add_argument('-word_embed_dim', default=128, type=int)\nparser.add_argument('-rule_embed_dim', default=256, type=int)\nparser.add_argument('-node_embed_dim', default=256, type=int)\nparser.add_argument('-encoder_hidden_dim', default=256, type=int)\nparser.add_argument('-decoder_hidden_dim', default=256, type=int)\nparser.add_argument('-attention_hidden_dim', default=50, type=int)\nparser.add_argument('-ptrnet_hidden_dim', default=50, type=int)\nparser.add_argument('-dropout', default=0.2, type=float)\n\n# encoder\nparser.add_argument('-encoder', default='bilstm', choices=['bilstm', 'lstm'])\n\n# decoder\nparser.add_argument('-parent_hidden_state_feed', dest='parent_hidden_state_feed', action='store_true')\nparser.add_argument('-no_parent_hidden_state_feed', dest='parent_hidden_state_feed', 
action='store_false')\nparser.set_defaults(parent_hidden_state_feed=True)\n\nparser.add_argument('-parent_action_feed', dest='parent_action_feed', action='store_true')\nparser.add_argument('-no_parent_action_feed', dest='parent_action_feed', action='store_false')\nparser.set_defaults(parent_action_feed=True)\n\nparser.add_argument('-frontier_node_type_feed', dest='frontier_node_type_feed', action='store_true')\nparser.add_argument('-no_frontier_node_type_feed', dest='frontier_node_type_feed', action='store_false')\nparser.set_defaults(frontier_node_type_feed=True)\n\nparser.add_argument('-tree_attention', dest='tree_attention', action='store_true')\nparser.add_argument('-no_tree_attention', dest='tree_attention', action='store_false')\nparser.set_defaults(tree_attention=False)\n\nparser.add_argument('-enable_copy', dest='enable_copy', action='store_true')\nparser.add_argument('-no_copy', dest='enable_copy', action='store_false')\nparser.set_defaults(enable_copy=True)\n\n# training\nparser.add_argument('-optimizer', default='adam')\nparser.add_argument('-clip_grad', default=0., type=float)\nparser.add_argument('-train_patience', default=10, type=int)\nparser.add_argument('-max_epoch', default=50, type=int)\nparser.add_argument('-batch_size', default=10, type=int)\nparser.add_argument('-valid_per_batch', default=4000, type=int)\nparser.add_argument('-save_per_batch', default=4000, type=int)\nparser.add_argument('-valid_metric', default='bleu')\n\n# decoding\nparser.add_argument('-beam_size', default=15, type=int)\nparser.add_argument('-max_query_length', default=70, type=int)\nparser.add_argument('-decode_max_time_step', default=100, type=int)\nparser.add_argument('-head_nt_constraint', dest='head_nt_constraint', action='store_true')\nparser.add_argument('-no_head_nt_constraint', dest='head_nt_constraint', action='store_false')\nparser.set_defaults(head_nt_constraint=True)\n\nsub_parsers = parser.add_subparsers(dest='operation', help='operation to 
take')\ntrain_parser = sub_parsers.add_parser('train')\ndecode_parser = sub_parsers.add_parser('decode')\ninteractive_parser = sub_parsers.add_parser('interactive')\nevaluate_parser = sub_parsers.add_parser('evaluate')\n\n# decoding operation\ndecode_parser.add_argument('-saveto', default='decode_results.bin')\ndecode_parser.add_argument('-type', default='test_data')\n\n# evaluation operation\nevaluate_parser.add_argument('-mode', default='self')\nevaluate_parser.add_argument('-input', default='decode_results.bin')\nevaluate_parser.add_argument('-type', default='test_data')\nevaluate_parser.add_argument('-seq2tree_sample_file', default='model.sample')\nevaluate_parser.add_argument('-seq2tree_id_file', default='test.id.txt')\nevaluate_parser.add_argument('-seq2tree_rareword_map', default=None)\nevaluate_parser.add_argument('-seq2seq_decode_file')\nevaluate_parser.add_argument('-seq2seq_ref_file')\nevaluate_parser.add_argument('-is_nbest', default=False, action='store_true')\n\n# misc\nparser.add_argument('-ifttt_test_split', default='data/ifff.test_data.gold.id')\n\n# interactive operation\ninteractive_parser.add_argument('-mode', default='dataset')\n\nif __name__ == '__main__':\n    args = parser.parse_args()\n\n    if not os.path.exists(args.output_dir):\n        os.makedirs(args.output_dir)\n\n    np.random.seed(args.random_seed)\n    init_logging(os.path.join(args.output_dir, 'parser.log'), logging.INFO)\n    logging.info('command line: %s', ' '.join(sys.argv))\n\n    logging.info('loading dataset [%s]', args.data)\n    train_data, dev_data, test_data = deserialize_from_file(args.data)\n\n    if not args.source_vocab_size:\n        args.source_vocab_size = train_data.annot_vocab.size\n    if not args.target_vocab_size:\n        args.target_vocab_size = train_data.terminal_vocab.size\n    if not args.rule_num:\n        args.rule_num = len(train_data.grammar.rules)\n    if not args.node_num:\n        args.node_num = len(train_data.grammar.node_type_to_id)\n\n    
logging.info('current config: %s', args)\n    config_module = sys.modules['config']\n    for name, value in vars(args).iteritems():\n        setattr(config_module, name, value)\n\n    # get dataset statistics\n    avg_action_num = np.average([len(e.actions) for e in train_data.examples])\n    logging.info('avg_action_num: %d', avg_action_num)\n\n    logging.info('grammar rule num.: %d', len(train_data.grammar.rules))\n    logging.info('grammar node type num.: %d', len(train_data.grammar.node_type_to_id))\n\n    logging.info('source vocab size: %d', train_data.annot_vocab.size)\n    logging.info('target vocab size: %d', train_data.terminal_vocab.size)\n\n    if args.operation in ['train', 'decode', 'interactive']:\n        model = Model()\n        model.build()\n\n        if args.model:\n            model.load(args.model)\n\n    if args.operation == 'train':\n        # train_data = train_data.get_dataset_by_ids(range(2000), 'train_sample')\n        # dev_data = dev_data.get_dataset_by_ids(range(10), 'dev_sample')\n        learner = Learner(model, train_data, dev_data)\n        learner.train()\n\n    if args.operation == 'decode':\n        # ==========================\n        # investigate short examples\n        # ==========================\n\n        # short_examples = [e for e in test_data.examples if e.parse_tree.size <= 2]\n        # for e in short_examples:\n        #     print e.parse_tree\n        # print 'short examples num: ', len(short_examples)\n\n        # dataset = test_data # test_data.get_dataset_by_ids([1,2,3,4,5,6,7,8,9,10], name='sample')\n        # cProfile.run('decode_dataset(model, dataset)', sort=2)\n\n        # from evaluation import decode_and_evaluate_ifttt\n        if args.data_type == 'ifttt':\n            decode_results = decode_and_evaluate_ifttt_by_split(model, test_data)\n        else:\n            dataset = eval(args.type)\n            decode_results = decode_python_dataset(model, dataset)\n\n        serialize_to_file(decode_results, 
args.saveto)\n\n    if args.operation == 'evaluate':\n        dataset = eval(args.type)\n        if config.mode == 'self':\n            decode_results_file = args.input\n            decode_results = deserialize_from_file(decode_results_file)\n\n            evaluate_decode_results(dataset, decode_results)\n        elif config.mode == 'seq2tree':\n            from evaluation import evaluate_seq2tree_sample_file\n            evaluate_seq2tree_sample_file(config.seq2tree_sample_file, config.seq2tree_id_file, dataset)\n        elif config.mode == 'seq2seq':\n            from evaluation import evaluate_seq2seq_decode_results\n            evaluate_seq2seq_decode_results(dataset, config.seq2seq_decode_file, config.seq2seq_ref_file, is_nbest=config.is_nbest)\n        elif config.mode == 'analyze':\n            from evaluation import analyze_decode_results\n\n            decode_results_file = args.input\n            decode_results = deserialize_from_file(decode_results_file)\n            analyze_decode_results(dataset, decode_results)\n\n    if args.operation == 'interactive':\n        from dataset import canonicalize_query, query_to_data\n        from collections import namedtuple\n        from lang.py.parse import decode_tree_to_python_ast\n        assert model is not None\n\n        while True:\n            cmd = raw_input('example id or query: ')\n            if args.mode == 'dataset':\n                try:\n                    example_id = int(cmd)\n                    example = [e for e in test_data.examples if e.raw_id == example_id][0]\n                except:\n                    print 'something went wrong ...'\n                    continue\n            elif args.mode == 'new':\n                # we play with new examples!\n                query, str_map = canonicalize_query(cmd)\n                vocab = train_data.annot_vocab\n                query_tokens = query.split(' ')\n                query_tokens_data = [query_to_data(query, vocab)]\n                example 
= namedtuple('example', ['query', 'data'])(query=query_tokens, data=query_tokens_data)\n\n            if hasattr(example, 'parse_tree'):\n                print 'gold parse tree:'\n                print example.parse_tree\n\n            cand_list = model.decode(example, train_data.grammar, train_data.terminal_vocab,\n                                     beam_size=args.beam_size, max_time_step=args.decode_max_time_step, log=True)\n\n            has_grammar_error = any([c for c in cand_list if c.has_grammar_error])\n            print 'has_grammar_error: ', has_grammar_error\n\n            for cid, cand in enumerate(cand_list[:5]):\n                print '*' * 60\n                print 'cand #%d, score: %f' % (cid, cand.score)\n\n                try:\n                    ast_tree = decode_tree_to_python_ast(cand.tree)\n                    code = astor.to_source(ast_tree)\n                    print 'code: ', code\n                    print 'decode log: ', cand.log\n                except:\n                    print \"Exception in converting tree to code:\"\n                    print '-' * 60\n                    print 'raw_id: %d, beam pos: %d' % (example.raw_id, cid)\n                    traceback.print_exc(file=sys.stdout)\n                    print '-' * 60\n                finally:\n                    print '* parse tree *'\n                    print cand.tree.__repr__()\n                    print 'n_timestep: %d' % cand.n_timestep\n                    print 'ast size: %d' % cand.tree.size\n                    print '*' * 60\n"
  },
  {
    "path": "components.py",
    "content": "import theano\nimport theano.tensor as T\nimport numpy as np\nimport logging\nimport copy\n\nfrom nn.layers.embeddings import Embedding\nfrom nn.layers.core import Dense, Layer\nfrom nn.layers.recurrent import BiLSTM, LSTM, CondAttLSTM\nfrom nn.utils.theano_utils import ndim_itensor, tensor_right_shift, ndim_tensor, alloc_zeros_matrix, shared_zeros\nimport nn.initializations as initializations\nimport nn.activations as activations\nimport nn.optimizers as optimizers\n\nimport config\nfrom lang.grammar import Grammar\nfrom parse import *\nfrom astnode import *\n\n\nclass PointerNet(Layer):\n    def __init__(self, name='PointerNet'):\n        super(PointerNet, self).__init__()\n\n        self.dense1_input = Dense(config.encoder_hidden_dim, config.ptrnet_hidden_dim, activation='linear', name='Dense1_input')\n\n        self.dense1_h = Dense(config.decoder_hidden_dim + config.encoder_hidden_dim, config.ptrnet_hidden_dim, activation='linear', name='Dense1_h')\n\n        self.dense2 = Dense(config.ptrnet_hidden_dim, 1, activation='linear', name='Dense2')\n\n        self.params += self.dense1_input.params + self.dense1_h.params + self.dense2.params\n\n        self.set_name(name)\n\n    def __call__(self, query_embed, query_token_embed_mask, decoder_states):\n        query_embed_trans = self.dense1_input(query_embed)\n        h_trans = self.dense1_h(decoder_states)\n\n        query_embed_trans = query_embed_trans.dimshuffle((0, 'x', 1, 2))\n        h_trans = h_trans.dimshuffle((0, 1, 'x', 2))\n\n        # (batch_size, max_decode_step, query_token_num, ptr_net_hidden_dim)\n        dense1_trans = T.tanh(query_embed_trans + h_trans)\n\n        scores = self.dense2(dense1_trans).flatten(3)\n\n        scores = T.exp(scores - T.max(scores, axis=-1, keepdims=True))\n        scores *= query_token_embed_mask.dimshuffle((0, 'x', 1))\n        scores = scores / T.sum(scores, axis=-1, keepdims=True)\n\n        return scores\n\nclass Hyp:\n    def __init__(self, *args):\n 
       if isinstance(args[0], Hyp):\n            hyp = args[0]\n            self.grammar = hyp.grammar\n            self.tree = hyp.tree.copy()\n            self.t = hyp.t\n            self.hist_h = list(hyp.hist_h)\n            self.log = hyp.log\n            self.has_grammar_error = hyp.has_grammar_error\n        else:\n            assert isinstance(args[0], Grammar)\n            grammar = args[0]\n            self.grammar = grammar\n            self.tree = DecodeTree(grammar.root_node.type)\n            self.t=-1\n            self.hist_h = []\n            self.log = ''\n            self.has_grammar_error = False\n\n        self.score = 0.0\n\n        self.__frontier_nt = self.tree\n        self.__frontier_nt_t = -1\n\n    def __repr__(self):\n        return self.tree.__repr__()\n\n    def can_expand(self, node):\n        if self.grammar.is_value_node(node):\n            # if the node is finished\n            if node.value is not None and node.value.endswith('<eos>'):\n                return False\n            return True\n        elif self.grammar.is_terminal(node):\n            return False\n\n        # elif node.type == 'epsilon':\n        #     return False\n        # elif is_terminal_ast_type(node.type):\n        #     return False\n\n        # if node.type == 'root':\n        #     return True\n        # elif inspect.isclass(node.type) and issubclass(node.type, ast.AST) and not is_terminal_ast_type(node.type):\n        #     return True\n        # elif node.holds_value and not node.label.endswith('<eos>'):\n        #     return True\n\n        return True\n\n    def apply_rule(self, rule, nt=None):\n        if nt is None:\n            nt = self.frontier_nt()\n\n        # assert rule.parent.type == nt.type\n        if rule.parent.type != nt.type:\n            self.has_grammar_error = True\n\n        self.t += 1\n        # set the time step when the rule leading by this nt is applied\n        nt.t = self.t\n        # record the ApplyRule action that is used 
to expand the current node\n        nt.applied_rule = rule\n\n        for child_node in rule.children:\n            child = DecodeTree(child_node.type, child_node.label, child_node.value)\n            # if is_builtin_type(rule.parent.type):\n            #     child.label = None\n            #     child.holds_value = True\n\n            nt.add_child(child)\n\n    def append_token(self, token, nt=None):\n        if nt is None:\n            nt = self.frontier_nt()\n\n        self.t += 1\n\n        if nt.value is None:\n            # this terminal node is empty\n            nt.t = self.t\n            nt.value = token\n        else:\n            nt.value += token\n\n    def frontier_nt_helper(self, node):\n        if node.is_leaf:\n            if self.can_expand(node):\n                return node\n            else:\n                return None\n\n        for child in node.children:\n            result = self.frontier_nt_helper(child)\n            if result:\n                return result\n\n        return None\n\n    def frontier_nt(self):\n        if self.__frontier_nt_t == self.t:\n            return self.__frontier_nt\n        else:\n            _frontier_nt = self.frontier_nt_helper(self.tree)\n            self.__frontier_nt = _frontier_nt\n            self.__frontier_nt_t = self.t\n\n            return _frontier_nt\n\n    def get_action_parent_t(self):\n        \"\"\"\n        get the time step when the parent of the current\n        action was generated\n        WARNING: 0 will be returned if parent if None\n        \"\"\"\n        nt = self.frontier_nt()\n\n        # if nt is a non-finishing leaf\n        # if nt.holds_value:\n        #     return nt.t\n\n        if nt.parent:\n            return nt.parent.t\n        else:\n            return 0\n\n    # def get_action_parent_tree(self):\n    #     \"\"\"\n    #     get the parent tree\n    #     \"\"\"\n    #     nt = self.frontier_nt()\n    #\n    #     # if nt is a non-finishing leaf\n    #     if 
nt.holds_value:\n    #         return nt\n    #\n    #     if nt.parent:\n    #         return nt.parent\n    #     else:\n    #         return None\n\nclass CondAttLSTM(Layer):\n    \"\"\"\n    Conditional LSTM with Attention\n    \"\"\"\n    def __init__(self, input_dim, output_dim,\n                 context_dim, att_hidden_dim,\n                 init='glorot_uniform', inner_init='orthogonal', forget_bias_init='one',\n                 activation='tanh', inner_activation='sigmoid', name='CondAttLSTM'):\n\n        super(CondAttLSTM, self).__init__()\n\n        self.output_dim = output_dim\n        self.init = initializations.get(init)\n        self.inner_init = initializations.get(inner_init)\n        self.forget_bias_init = initializations.get(forget_bias_init)\n        self.activation = activations.get(activation)\n        self.inner_activation = activations.get(inner_activation)\n        self.context_dim = context_dim\n        self.input_dim = input_dim\n\n        # regular LSTM layer\n\n        self.W_i = self.init((input_dim, self.output_dim))\n        self.U_i = self.inner_init((self.output_dim, self.output_dim))\n        self.C_i = self.inner_init((self.context_dim, self.output_dim))\n        self.H_i = self.inner_init((self.output_dim, self.output_dim))\n        self.P_i = self.inner_init((self.output_dim, self.output_dim))\n        self.b_i = shared_zeros((self.output_dim))\n\n        self.W_f = self.init((input_dim, self.output_dim))\n        self.U_f = self.inner_init((self.output_dim, self.output_dim))\n        self.C_f = self.inner_init((self.context_dim, self.output_dim))\n        self.H_f = self.inner_init((self.output_dim, self.output_dim))\n        self.P_f = self.inner_init((self.output_dim, self.output_dim))\n        self.b_f = self.forget_bias_init((self.output_dim))\n\n        self.W_c = self.init((input_dim, self.output_dim))\n        self.U_c = self.inner_init((self.output_dim, self.output_dim))\n        self.C_c = 
self.inner_init((self.context_dim, self.output_dim))\n        self.H_c = self.inner_init((self.output_dim, self.output_dim))\n        self.P_c = self.inner_init((self.output_dim, self.output_dim))\n        self.b_c = shared_zeros((self.output_dim))\n\n        self.W_o = self.init((input_dim, self.output_dim))\n        self.U_o = self.inner_init((self.output_dim, self.output_dim))\n        self.C_o = self.inner_init((self.context_dim, self.output_dim))\n        self.H_o = self.inner_init((self.output_dim, self.output_dim))\n        self.P_o = self.inner_init((self.output_dim, self.output_dim))\n        self.b_o = shared_zeros((self.output_dim))\n\n        self.params = [\n            self.W_i, self.U_i, self.b_i, self.C_i, self.H_i, self.P_i,\n            self.W_c, self.U_c, self.b_c, self.C_c, self.H_c, self.P_c,\n            self.W_f, self.U_f, self.b_f, self.C_f, self.H_f, self.P_f,\n            self.W_o, self.U_o, self.b_o, self.C_o, self.H_o, self.P_o,\n        ]\n\n        # attention layer\n        self.att_ctx_W1 = self.init((context_dim, att_hidden_dim))\n        self.att_h_W1 = self.init((output_dim, att_hidden_dim))\n        self.att_b1 = shared_zeros((att_hidden_dim))\n\n        self.att_W2 = self.init((att_hidden_dim, 1))\n        self.att_b2 = shared_zeros((1))\n\n        self.params += [\n            self.att_ctx_W1, self.att_h_W1, self.att_b1,\n            self.att_W2, self.att_b2\n        ]\n\n        # attention over history\n        self.hatt_h_W1 = self.init((output_dim, att_hidden_dim))\n        self.hatt_hist_W1 = self.init((output_dim, att_hidden_dim))\n        self.hatt_b1 = shared_zeros((att_hidden_dim))\n\n        self.hatt_W2 = self.init((att_hidden_dim, 1))\n        self.hatt_b2 = shared_zeros((1))\n\n        self.params += [\n            self.hatt_h_W1, self.hatt_hist_W1, self.hatt_b1,\n            self.hatt_W2, self.hatt_b2\n        ]\n\n        self.set_name(name)\n\n    def _step(self,\n              t, xi_t, xf_t, xo_t, xc_t, mask_t, 
parent_t,\n              h_tm1, c_tm1, hist_h,\n              u_i, u_f, u_o, u_c,\n              c_i, c_f, c_o, c_c,\n              h_i, h_f, h_o, h_c,\n              p_i, p_f, p_o, p_c,\n              att_h_w1, att_w2, att_b2,\n              context, context_mask, context_att_trans,\n              b_u):\n\n        # context: (batch_size, context_size, context_dim)\n\n        # (batch_size, att_layer1_dim)\n        h_tm1_att_trans = T.dot(h_tm1, att_h_w1)\n\n        # h_tm1_att_trans = theano.printing.Print('h_tm1_att_trans')(h_tm1_att_trans)\n\n        # (batch_size, context_size, att_layer1_dim)\n        att_hidden = T.tanh(context_att_trans + h_tm1_att_trans[:, None, :])\n        # (batch_size, context_size, 1)\n        att_raw = T.dot(att_hidden, att_w2) + att_b2\n        att_raw = att_raw.reshape((att_raw.shape[0], att_raw.shape[1]))\n\n        # (batch_size, context_size)\n        ctx_att = T.exp(att_raw - T.max(att_raw, axis=-1, keepdims=True))\n\n        if context_mask:\n            ctx_att = ctx_att * context_mask\n\n        ctx_att = ctx_att / T.sum(ctx_att, axis=-1, keepdims=True)\n        # (batch_size, context_dim)\n        ctx_vec = T.sum(context * ctx_att[:, :, None], axis=1)\n\n        # t = theano.printing.Print('t')(t)\n\n        ##### attention over history #####\n\n        def _attention_over_history():\n            hist_h_mask = T.zeros((hist_h.shape[0], hist_h.shape[1]), dtype='int8')\n            hist_h_mask = T.set_subtensor(hist_h_mask[:, T.arange(t)], 1)\n\n            hist_h_att_trans = T.dot(hist_h, self.hatt_hist_W1) + self.hatt_b1\n            h_tm1_hatt_trans = T.dot(h_tm1, self.hatt_h_W1)\n\n            hatt_hidden = T.tanh(hist_h_att_trans + h_tm1_hatt_trans[:, None, :])\n            hatt_raw = T.dot(hatt_hidden, self.hatt_W2) + self.hatt_b2\n            hatt_raw = hatt_raw.reshape((hist_h.shape[0], hist_h.shape[1]))\n            # hatt_raw = theano.printing.Print('hatt_raw')(hatt_raw)\n            hatt_exp = T.exp(hatt_raw - 
T.max(hatt_raw, axis=-1, keepdims=True)) * hist_h_mask\n            # hatt_exp = theano.printing.Print('hatt_exp')(hatt_exp)\n            # hatt_exp = hatt_exp.flatten(2)\n            h_att_weights = hatt_exp / (T.sum(hatt_exp, axis=-1, keepdims=True) + 1e-7)\n            # h_att_weights = theano.printing.Print('h_att_weights')(h_att_weights)\n\n            # (batch_size, output_dim)\n            _h_ctx_vec = T.sum(hist_h * h_att_weights[:, :, None], axis=1)\n\n            return _h_ctx_vec\n\n        h_ctx_vec = T.switch(t,\n                             _attention_over_history(),\n                             T.zeros_like(h_tm1))\n\n        # h_ctx_vec = theano.printing.Print('h_ctx_vec')(h_ctx_vec)\n\n        ##### attention over history #####\n\n        ##### feed in parent hidden state #####\n\n        if not config.parent_hidden_state_feed:\n            t = 0\n\n        par_h = T.switch(t,\n                         hist_h[T.arange(hist_h.shape[0]), parent_t, :],\n                         T.zeros_like(h_tm1))\n\n        ##### feed in parent hidden state #####\n        if config.tree_attention:\n            i_t = self.inner_activation(\n                xi_t + T.dot(h_tm1 * b_u[0], u_i) + T.dot(ctx_vec, c_i) + T.dot(par_h, p_i) + T.dot(h_ctx_vec, h_i))\n            f_t = self.inner_activation(\n                xf_t + T.dot(h_tm1 * b_u[1], u_f) + T.dot(ctx_vec, c_f) + T.dot(par_h, p_f) + T.dot(h_ctx_vec, h_f))\n            c_t = f_t * c_tm1 + i_t * self.activation(\n                xc_t + T.dot(h_tm1 * b_u[2], u_c) + T.dot(ctx_vec, c_c) + T.dot(par_h, p_c) + T.dot(h_ctx_vec, h_c))\n            o_t = self.inner_activation(\n                xo_t + T.dot(h_tm1 * b_u[3], u_o) + T.dot(ctx_vec, c_o) + T.dot(par_h, p_o) + T.dot(h_ctx_vec, h_o))\n        else:\n            i_t = self.inner_activation(\n                xi_t + T.dot(h_tm1 * b_u[0], u_i) + T.dot(ctx_vec, c_i) + T.dot(par_h, p_i))  # + T.dot(h_ctx_vec, h_i)\n            f_t = self.inner_activation(\n          
      xf_t + T.dot(h_tm1 * b_u[1], u_f) + T.dot(ctx_vec, c_f) + T.dot(par_h, p_f))  # + T.dot(h_ctx_vec, h_f)\n            c_t = f_t * c_tm1 + i_t * self.activation(\n                xc_t + T.dot(h_tm1 * b_u[2], u_c) + T.dot(ctx_vec, c_c) + T.dot(par_h, p_c))  # + T.dot(h_ctx_vec, h_c)\n            o_t = self.inner_activation(\n                xo_t + T.dot(h_tm1 * b_u[3], u_o) + T.dot(ctx_vec, c_o) + T.dot(par_h, p_o))  # + T.dot(h_ctx_vec, h_o)\n        h_t = o_t * self.activation(c_t)\n\n        h_t = (1 - mask_t) * h_tm1 + mask_t * h_t\n        c_t = (1 - mask_t) * c_tm1 + mask_t * c_t\n\n        new_hist_h = T.set_subtensor(hist_h[:, t, :], h_t)\n\n        return h_t, c_t, ctx_vec, new_hist_h\n\n    def _for_step(self,\n                  xi_t, xf_t, xo_t, xc_t, mask_t,\n                  h_tm1, c_tm1,\n                  context, context_mask, context_att_trans,\n                  hist_h, hist_h_att_trans,\n                  b_u):\n\n        # context: (batch_size, context_size, context_dim)\n\n        # (batch_size, att_layer1_dim)\n        h_tm1_att_trans = T.dot(h_tm1, self.att_h_W1)\n\n        # (batch_size, context_size, att_layer1_dim)\n        att_hidden = T.tanh(context_att_trans + h_tm1_att_trans[:, None, :])\n\n        # (batch_size, context_size, 1)\n        att_raw = T.dot(att_hidden, self.att_W2) + self.att_b2\n\n        # (batch_size, context_size)\n        ctx_att = T.exp(att_raw).reshape((att_raw.shape[0], att_raw.shape[1]))\n\n        if context_mask:\n            ctx_att = ctx_att * context_mask\n\n        ctx_att = ctx_att / T.sum(ctx_att, axis=-1, keepdims=True)\n\n        # (batch_size, context_dim)\n        ctx_vec = T.sum(context * ctx_att[:, :, None], axis=1)\n\n        ##### attention over history #####\n\n        if hist_h:\n            hist_h = T.stack(hist_h).dimshuffle((1, 0, 2))\n            hist_h_att_trans = T.stack(hist_h_att_trans).dimshuffle((1, 0, 2))\n            h_tm1_hatt_trans = T.dot(h_tm1, self.hatt_h_W1)\n\n            
hatt_hidden = T.tanh(hist_h_att_trans + h_tm1_hatt_trans[:, None, :])\n            hatt_raw = T.dot(hatt_hidden, self.hatt_W2) + self.hatt_b2\n            hatt_raw = hatt_raw.flatten(2)\n            h_att_weights = T.nnet.softmax(hatt_raw)\n\n            # (batch_size, output_dim)\n            h_ctx_vec = T.sum(hist_h * h_att_weights[:, :, None], axis=1)\n        else:\n            h_ctx_vec = T.zeros_like(h_tm1)\n\n        ##### attention over history #####\n\n        i_t = self.inner_activation(xi_t + T.dot(h_tm1 * b_u[0], self.U_i) + T.dot(ctx_vec, self.C_i) + T.dot(h_ctx_vec, self.H_i))\n        f_t = self.inner_activation(xf_t + T.dot(h_tm1 * b_u[1], self.U_f) + T.dot(ctx_vec, self.C_f) + T.dot(h_ctx_vec, self.H_f))\n        c_t = f_t * c_tm1 + i_t * self.activation(xc_t + T.dot(h_tm1 * b_u[2], self.U_c) + T.dot(ctx_vec, self.C_c) + T.dot(h_ctx_vec, self.H_c))\n        o_t = self.inner_activation(xo_t + T.dot(h_tm1 * b_u[3], self.U_o) + T.dot(ctx_vec, self.C_o) + T.dot(h_ctx_vec, self.H_o))\n        h_t = o_t * self.activation(c_t)\n\n        h_t = (1 - mask_t) * h_tm1 + mask_t * h_t\n        c_t = (1 - mask_t) * c_tm1 + mask_t * c_t\n\n        # ctx_vec = theano.printing.Print('ctx_vec')(ctx_vec)\n\n        return h_t, c_t, ctx_vec\n\n    def __call__(self, X, context, parent_t_seq, init_state=None, init_cell=None, hist_h=None,\n                 mask=None, context_mask=None,\n                 dropout=0, train=True, srng=None,\n                 time_steps=None):\n        assert context_mask.dtype == 'int8', 'context_mask is not int8, got %s' % context_mask.dtype\n\n        # (n_timestep, batch_size)\n        mask = self.get_mask(mask, X)\n        # (n_timestep, batch_size, input_dim)\n        X = X.dimshuffle((1, 0, 2))\n\n        retain_prob = 1. 
- dropout\n        B_w = np.ones((4,), dtype=theano.config.floatX)\n        B_u = np.ones((4,), dtype=theano.config.floatX)\n        if dropout > 0:\n            logging.info('applying dropout with p = %f', dropout)\n            if train:\n                B_w = srng.binomial((4, X.shape[1], self.input_dim), p=retain_prob,\n                                    dtype=theano.config.floatX)\n                B_u = srng.binomial((4, X.shape[1], self.output_dim), p=retain_prob,\n                                    dtype=theano.config.floatX)\n            else:\n                B_w *= retain_prob\n                B_u *= retain_prob\n\n        # (n_timestep, batch_size, output_dim)\n        xi = T.dot(X * B_w[0], self.W_i) + self.b_i\n        xf = T.dot(X * B_w[1], self.W_f) + self.b_f\n        xc = T.dot(X * B_w[2], self.W_c) + self.b_c\n        xo = T.dot(X * B_w[3], self.W_o) + self.b_o\n\n        # (batch_size, context_size, att_layer1_dim)\n        context_att_trans = T.dot(context, self.att_ctx_W1) + self.att_b1\n\n        if init_state:\n            # (batch_size, output_dim)\n            first_state = T.unbroadcast(init_state, 1)\n        else:\n            first_state = T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1)\n\n        if init_cell:\n            # (batch_size, output_dim)\n            first_cell = T.unbroadcast(init_cell, 1)\n        else:\n            first_cell = T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1)\n\n        if not hist_h:\n            # (batch_size, n_timestep, output_dim)\n            hist_h = alloc_zeros_matrix(X.shape[1], X.shape[0], self.output_dim)\n\n        if train:\n            n_timestep = X.shape[0]\n            time_steps = T.arange(n_timestep, dtype='int32')\n\n        # (n_timestep, batch_size)\n        parent_t_seq = parent_t_seq.dimshuffle((1, 0))\n\n        [outputs, cells, ctx_vectors, hist_h_outputs], updates = theano.scan(\n            self._step,\n            sequences=[time_steps, xi, 
xf, xo, xc, mask, parent_t_seq],\n            outputs_info=[\n                first_state,  # for h\n                first_cell,  # for cell\n                None, # T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.context_dim), 1),  # for ctx vector\n                hist_h,  # for hist_h\n            ],\n            non_sequences=[\n                self.U_i, self.U_f, self.U_o, self.U_c,\n                self.C_i, self.C_f, self.C_o, self.C_c,\n                self.H_i, self.H_f, self.H_o, self.H_c,\n                self.P_i, self.P_f, self.P_o, self.P_c,\n                self.att_h_W1, self.att_W2, self.att_b2,\n                context, context_mask, context_att_trans,\n                B_u\n            ])\n\n        outputs = outputs.dimshuffle((1, 0, 2))\n        ctx_vectors = ctx_vectors.dimshuffle((1, 0, 2))\n        cells = cells.dimshuffle((1, 0, 2))\n\n        return outputs, cells, ctx_vectors\n\n    def get_mask(self, mask, X):\n        if mask is None:\n            mask = T.ones((X.shape[0], X.shape[1]))\n\n        mask = T.shape_padright(mask)  # (nb_samples, time, 1)\n        mask = T.addbroadcast(mask, -1)  # (time, nb_samples, 1) matrix.\n        mask = mask.dimshuffle(1, 0, 2)  # (time, nb_samples, 1)\n        mask = mask.astype('int8')\n\n        return mask"
  },
  {
    "path": "config.py",
    "content": "# MODE = 'django'\n#\n# SOURCE_VOCAB_SIZE = 2490 # 2492 # 5980\n# TARGET_VOCAB_SIZE = 2101 # 2110 # 4830 #\n# RULE_NUM = 222 # 228\n# NODE_NUM = 96\n#\n# NODE_EMBED_DIM = 256\n# EMBED_DIM = 128\n# RULE_EMBED_DIM = 256\n# QUERY_DIM = 256\n# LSTM_STATE_DIM = 256\n# DECODER_ATT_HIDDEN_DIM = 50\n# POINTER_NET_HIDDEN_DIM = 50\n#\n# MAX_QUERY_LENGTH = 70\n# MAX_EXAMPLE_ACTION_NUM = 100\n#\n# DECODER_DROPOUT = 0.2\n# WORD_DROPOUT = 0\n#\n# # encoder\n# ENCODER_LSTM = 'bilstm'\n#\n# # decoder\n# PARENT_HIDDEN_STATE_FEEDING = True\n# PARENT_RULE_FEEDING = True\n# NODE_TYPE_FEEDING = True\n# TREE_ATTENTION = True\n#\n# # training\n# TRAIN_PATIENCE = 10\n# MAX_EPOCH = 50\n# BATCH_SIZE = 10\n# VALID_PER_MINIBATCH = 4000\n# SAVE_PER_MINIBATCH = 4000\n#\n# # decoding\n# BEAM_SIZE = 15\n# DECODE_MAX_TIME_STEP = 100\n\nconfig_info = None\n\n"
  },
  {
    "path": "dataset.py",
    "content": "from __future__ import division\nimport copy\n\nimport nltk\nfrom collections import OrderedDict, defaultdict\nimport logging\nimport collections\nimport numpy as np\nimport string\nimport re\nimport astor\nfrom itertools import chain\n\nfrom nn.utils.io_utils import serialize_to_file, deserialize_from_file\n\nimport config\nfrom lang.py.parse import get_grammar\nfrom lang.py.unaryclosure import get_top_unary_closures, apply_unary_closures\n\n# define actions\nAPPLY_RULE = 0\nGEN_TOKEN = 1\nCOPY_TOKEN = 2\nGEN_COPY_TOKEN = 3\n\nACTION_NAMES = {APPLY_RULE: 'APPLY_RULE',\n                GEN_TOKEN: 'GEN_TOKEN',\n                COPY_TOKEN: 'COPY_TOKEN',\n                GEN_COPY_TOKEN: 'GEN_COPY_TOKEN'}\n\nclass Action(object):\n    def __init__(self, act_type, data):\n        self.act_type = act_type\n        self.data = data\n\n    def __repr__(self):\n        data_str = self.data if not isinstance(self.data, dict) else \\\n            ', '.join(['%s: %s' % (k, v) for k, v in self.data.iteritems()])\n        repr_str = 'Action{%s}[%s]' % (ACTION_NAMES[self.act_type], data_str)\n\n        return repr_str\n\n\nclass Vocab(object):\n    def __init__(self):\n        self.token_id_map = OrderedDict()\n        self.insert_token('<pad>')\n        self.insert_token('<unk>')\n        self.insert_token('<eos>')\n\n    @property\n    def unk(self):\n        return self.token_id_map['<unk>']\n\n    @property\n    def eos(self):\n        return self.token_id_map['<eos>']\n\n    def __getitem__(self, item):\n        if item in self.token_id_map:\n            return self.token_id_map[item]\n\n        logging.debug('encounter one unknown word [%s]' % item)\n        return self.token_id_map['<unk>']\n\n    def __contains__(self, item):\n        return item in self.token_id_map\n\n    @property\n    def size(self):\n        return len(self.token_id_map)\n\n    def __setitem__(self, key, value):\n        self.token_id_map[key] = value\n\n    def __len__(self):\n       
 return len(self.token_id_map)\n\n    def __iter__(self):\n        return self.token_id_map.iterkeys()\n\n    def iteritems(self):\n        return self.token_id_map.iteritems()\n\n    def complete(self):\n        self.id_token_map = dict((v, k) for (k, v) in self.token_id_map.iteritems())\n\n    def get_token(self, token_id):\n        return self.id_token_map[token_id]\n\n    def insert_token(self, token):\n        if token in self.token_id_map:\n            return self[token]\n        else:\n            idx = len(self)\n            self[token] = idx\n\n            return idx\n\n\nreplace_punctuation = string.maketrans(string.punctuation, ' '*len(string.punctuation))\n\n\ndef tokenize(str):\n    str = str.translate(replace_punctuation)\n    return nltk.word_tokenize(str)\n\n\ndef gen_vocab(tokens, vocab_size=3000, freq_cutoff=5):\n    word_freq = defaultdict(int)\n\n    for token in tokens:\n        word_freq[token] += 1\n\n    print 'total num. of tokens: %d' % len(word_freq)\n\n    words_freq_cutoff = [w for w in word_freq if word_freq[w] >= freq_cutoff]\n    print 'num. 
of words appear at least %d: %d' % (freq_cutoff, len(words_freq_cutoff))\n\n    ranked_words = sorted(words_freq_cutoff, key=word_freq.get, reverse=True)[:vocab_size-2]\n    ranked_words = set(ranked_words)\n\n    vocab = Vocab()\n    for token in tokens:\n        if token in ranked_words:\n            vocab.insert_token(token)\n\n    vocab.complete()\n\n    return vocab\n\n\nclass DataEntry:\n    def __init__(self, raw_id, query, parse_tree, code, actions, meta_data=None):\n        self.raw_id = raw_id\n        self.eid = -1\n        # FIXME: rename to query_token\n        self.query = query\n        self.parse_tree = parse_tree\n        self.actions = actions\n        self.code = code\n        self.meta_data = meta_data\n\n    @property\n    def data(self):\n        if not hasattr(self, '_data'):\n            assert self.dataset is not None, 'No associated dataset for the example'\n\n            self._data = self.dataset.get_prob_func_inputs([self.eid])\n\n        return self._data\n\n    def copy(self):\n        e = DataEntry(self.raw_id, self.query, self.parse_tree, self.code, self.actions, self.meta_data)\n\n        return e\n\n\nclass DataSet:\n    def __init__(self, annot_vocab, terminal_vocab, grammar, name='train_data'):\n        self.annot_vocab = annot_vocab\n        self.terminal_vocab = terminal_vocab\n        self.name = name\n        self.examples = []\n        self.data_matrix = dict()\n        self.grammar = grammar\n\n    def add(self, example):\n        example.eid = len(self.examples)\n        example.dataset = self\n        self.examples.append(example)\n\n    def get_dataset_by_ids(self, ids, name):\n        dataset = DataSet(self.annot_vocab, self.terminal_vocab,\n                          self.grammar, name)\n        for eid in ids:\n            example_copy = self.examples[eid].copy()\n            dataset.add(example_copy)\n\n        for k, v in self.data_matrix.iteritems():\n            dataset.data_matrix[k] = v[ids]\n\n        return 
dataset\n\n    @property\n    def count(self):\n        if self.examples:\n            return len(self.examples)\n\n        return 0\n\n    def get_examples(self, ids):\n        if isinstance(ids, collections.Iterable):\n            return [self.examples[i] for i in ids]\n        else:\n            return self.examples[ids]\n\n    def get_prob_func_inputs(self, ids):\n        order = ['query_tokens', 'tgt_action_seq', 'tgt_action_seq_type',\n                 'tgt_node_seq', 'tgt_par_rule_seq', 'tgt_par_t_seq']\n\n        max_src_seq_len = max(len(self.examples[i].query) for i in ids)\n        max_tgt_seq_len = max(len(self.examples[i].actions) for i in ids)\n\n        logging.debug('max. src sequence length: %d', max_src_seq_len)\n        logging.debug('max. tgt sequence length: %d', max_tgt_seq_len)\n\n        data = []\n        for entry in order:\n            if entry == 'query_tokens':\n                data.append(self.data_matrix[entry][ids, :max_src_seq_len])\n            else:\n                data.append(self.data_matrix[entry][ids, :max_tgt_seq_len])\n\n        return data\n\n\n    def init_data_matrices(self, max_query_length=70, max_example_action_num=100):\n        logging.info('init data matrices for [%s] dataset', self.name)\n        annot_vocab = self.annot_vocab\n        terminal_vocab = self.terminal_vocab\n\n        # np.max([len(e.query) for e in self.examples])\n        # np.max([len(e.rules) for e in self.examples])\n\n        query_tokens = self.data_matrix['query_tokens'] = np.zeros((self.count, max_query_length), dtype='int32')\n        tgt_node_seq = self.data_matrix['tgt_node_seq'] = np.zeros((self.count, max_example_action_num), dtype='int32')\n        tgt_par_rule_seq = self.data_matrix['tgt_par_rule_seq'] = np.zeros((self.count, max_example_action_num), dtype='int32')\n        tgt_par_t_seq = self.data_matrix['tgt_par_t_seq'] = np.zeros((self.count, max_example_action_num), dtype='int32')\n        tgt_action_seq = 
self.data_matrix['tgt_action_seq'] = np.zeros((self.count, max_example_action_num, 3), dtype='int32')\n        tgt_action_seq_type = self.data_matrix['tgt_action_seq_type'] = np.zeros((self.count, max_example_action_num, 3), dtype='int32')\n\n        for eid, example in enumerate(self.examples):\n            exg_query_tokens = example.query[:max_query_length]\n            exg_action_seq = example.actions[:max_example_action_num]\n\n            for tid, token in enumerate(exg_query_tokens):\n                token_id = annot_vocab[token]\n\n                query_tokens[eid, tid] = token_id\n\n            assert len(exg_action_seq) > 0\n\n            for t, action in enumerate(exg_action_seq):\n                if action.act_type == APPLY_RULE:\n                    rule = action.data['rule']\n                    tgt_action_seq[eid, t, 0] = self.grammar.rule_to_id[rule]\n                    tgt_action_seq_type[eid, t, 0] = 1\n                elif action.act_type == GEN_TOKEN:\n                    token = action.data['literal']\n                    token_id = terminal_vocab[token]\n                    tgt_action_seq[eid, t, 1] = token_id\n                    tgt_action_seq_type[eid, t, 1] = 1\n                elif action.act_type == COPY_TOKEN:\n                    src_token_idx = action.data['source_idx']\n                    tgt_action_seq[eid, t, 2] = src_token_idx\n                    tgt_action_seq_type[eid, t, 2] = 1\n                elif action.act_type == GEN_COPY_TOKEN:\n                    token = action.data['literal']\n                    token_id = terminal_vocab[token]\n                    tgt_action_seq[eid, t, 1] = token_id\n                    tgt_action_seq_type[eid, t, 1] = 1\n\n                    src_token_idx = action.data['source_idx']\n                    tgt_action_seq[eid, t, 2] = src_token_idx\n                    tgt_action_seq_type[eid, t, 2] = 1\n                else:\n                    raise RuntimeError('wrong action type!')\n\n          
      # parent information\n                rule = action.data['rule']\n                parent_rule = action.data['parent_rule']\n                tgt_node_seq[eid, t] = self.grammar.get_node_type_id(rule.parent)\n                if parent_rule:\n                    tgt_par_rule_seq[eid, t] = self.grammar.rule_to_id[parent_rule]\n                else:\n                    assert t == 0\n                    tgt_par_rule_seq[eid, t] = -1\n\n                # parent hidden states\n                parent_t = action.data['parent_t']\n                tgt_par_t_seq[eid, t] = parent_t\n\n            example.dataset = self\n\n\nclass DataHelper(object):\n    @staticmethod\n    def canonicalize_query(query):\n        return query\n\n\ndef parse_django_dataset_nt_only():\n    from parse import parse_django\n\n    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'\n\n    vocab = gen_vocab(annot_file, vocab_size=4500)\n\n    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'\n\n    grammar, all_parse_trees = parse_django(code_file)\n\n    train_data = DataSet(vocab, grammar, name='train')\n    dev_data = DataSet(vocab, grammar, name='dev')\n    test_data = DataSet(vocab, grammar, name='test')\n\n    # train_data\n\n    train_annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/train.anno'\n    train_parse_trees = all_parse_trees[0:16000]\n    for line, parse_tree in zip(open(train_annot_file), train_parse_trees):\n        if parse_tree.is_leaf:\n            continue\n\n        line = line.strip()\n        tokens = tokenize(line)\n        entry = DataEntry(tokens, parse_tree)\n\n        train_data.add(entry)\n\n    train_data.init_data_matrices()\n\n    # dev_data\n\n    dev_annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/dev.anno'\n    dev_parse_trees = all_parse_trees[16000:17000]\n    for line, parse_tree in 
zip(open(dev_annot_file), dev_parse_trees):\n        if parse_tree.is_leaf:\n            continue\n\n        line = line.strip()\n        tokens = tokenize(line)\n        entry = DataEntry(tokens, parse_tree)\n\n        dev_data.add(entry)\n\n    dev_data.init_data_matrices()\n\n    # test_data\n\n    test_annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/test.anno'\n    test_parse_trees = all_parse_trees[17000:18805]\n    for line, parse_tree in zip(open(test_annot_file), test_parse_trees):\n        if parse_tree.is_leaf:\n            continue\n\n        line = line.strip()\n        tokens = tokenize(line)\n        entry = DataEntry(tokens, parse_tree)\n\n        test_data.add(entry)\n\n    test_data.init_data_matrices()\n\n    serialize_to_file((train_data, dev_data, test_data), 'django.typed_rule.bin')\n\n\ndef parse_django_dataset():\n    from lang.py.parse import parse_raw\n    from lang.util import escape\n    MAX_QUERY_LENGTH = 70\n    UNARY_CUTOFF_FREQ = 30\n\n    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'\n    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'\n\n    data = preprocess_dataset(annot_file, code_file)\n\n    for e in data:\n        e['parse_tree'] = parse_raw(e['code'])\n\n    parse_trees = [e['parse_tree'] for e in data]\n\n    # apply unary closures\n    # unary_closures = get_top_unary_closures(parse_trees, k=0, freq=UNARY_CUTOFF_FREQ)\n    # for i, parse_tree in enumerate(parse_trees):\n    #     apply_unary_closures(parse_tree, unary_closures)\n\n    # build the grammar\n    grammar = get_grammar(parse_trees)\n\n    # write grammar\n    with open('django.grammar.unary_closure.txt', 'w') as f:\n        for rule in grammar:\n            f.write(rule.__repr__() + '\\n')\n\n    # # build grammar ...\n    # from lang.py.py_dataset import extract_grammar\n    # grammar, all_parse_trees = extract_grammar(code_file)\n\n  
  annot_tokens = list(chain(*[e['query_tokens'] for e in data]))\n    annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=3) # gen_vocab(annot_tokens, vocab_size=5980)\n\n    terminal_token_seq = []\n    empty_actions_count = 0\n\n    # helper function begins\n    def get_terminal_tokens(_terminal_str):\n        # _terminal_tokens = filter(None, re.split('([, .?!])', _terminal_str)) # _terminal_str.split('-SP-')\n        # _terminal_tokens = filter(None, re.split('( )', _terminal_str))  # _terminal_str.split('-SP-')\n        tmp_terminal_tokens = _terminal_str.split(' ')\n        _terminal_tokens = []\n        for token in tmp_terminal_tokens:\n            if token:\n                _terminal_tokens.append(token)\n            _terminal_tokens.append(' ')\n\n        return _terminal_tokens[:-1]\n        # return _terminal_tokens\n    # helper function ends\n\n    # first pass\n    for entry in data:\n        idx = entry['id']\n        query_tokens = entry['query_tokens']\n        code = entry['code']\n        parse_tree = entry['parse_tree']\n\n        for node in parse_tree.get_leaves():\n            if grammar.is_value_node(node):\n                terminal_val = node.value\n                terminal_str = str(terminal_val)\n\n                terminal_tokens = get_terminal_tokens(terminal_str)\n\n                for terminal_token in terminal_tokens:\n                    assert len(terminal_token) > 0\n                    terminal_token_seq.append(terminal_token)\n\n    terminal_vocab = gen_vocab(terminal_token_seq, vocab_size=5000, freq_cutoff=3)\n    assert '_STR:0_' in terminal_vocab\n\n    train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'train_data')\n    dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'dev_data')\n    test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'test_data')\n\n    all_examples = []\n\n    can_fully_gen_num = 0\n\n    # second pass\n    for entry in data:\n        idx = entry['id']\n        
query_tokens = entry['query_tokens']\n        code = entry['code']\n        str_map = entry['str_map']\n        parse_tree = entry['parse_tree']\n\n        rule_list, rule_parents = parse_tree.get_productions(include_value_node=True)\n\n        actions = []\n        can_fully_gen = True\n        rule_pos_map = dict()\n\n        for rule_count, rule in enumerate(rule_list):\n            if not grammar.is_value_node(rule.parent):\n                assert rule.value is None\n                parent_rule = rule_parents[(rule_count, rule)][0]\n                if parent_rule:\n                    parent_t = rule_pos_map[parent_rule]\n                else:\n                    parent_t = 0\n\n                rule_pos_map[rule] = len(actions)\n\n                d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule}\n                action = Action(APPLY_RULE, d)\n\n                actions.append(action)\n            else:\n                assert rule.is_leaf\n\n                parent_rule = rule_parents[(rule_count, rule)][0]\n                parent_t = rule_pos_map[parent_rule]\n\n                terminal_val = rule.value\n                terminal_str = str(terminal_val)\n                terminal_tokens = get_terminal_tokens(terminal_str)\n\n                # assert len(terminal_tokens) > 0\n\n                for terminal_token in terminal_tokens:\n                    term_tok_id = terminal_vocab[terminal_token]\n                    tok_src_idx = -1\n                    try:\n                        tok_src_idx = query_tokens.index(terminal_token)\n                    except ValueError:\n                        pass\n\n                    d = {'literal': terminal_token, 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}\n\n                    # cannot copy, only generation\n                    # could be unk!\n                    if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH:\n                        action = Action(GEN_TOKEN, d)\n         
               if terminal_token not in terminal_vocab:\n                            if terminal_token not in query_tokens:\n                                # print terminal_token\n                                can_fully_gen = False\n                    else:  # copy\n                        if term_tok_id != terminal_vocab.unk:\n                            d['source_idx'] = tok_src_idx\n                            action = Action(GEN_COPY_TOKEN, d)\n                        else:\n                            d['source_idx'] = tok_src_idx\n                            action = Action(COPY_TOKEN, d)\n\n                    actions.append(action)\n\n                d = {'literal': '<eos>', 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}\n                actions.append(Action(GEN_TOKEN, d))\n\n        if len(actions) == 0:\n            empty_actions_count += 1\n            continue\n\n        example = DataEntry(idx, query_tokens, parse_tree, code, actions,\n                            {'raw_code': entry['raw_code'], 'str_map': entry['str_map']})\n\n        if can_fully_gen:\n            can_fully_gen_num += 1\n\n        # train, valid, test\n        if 0 <= idx < 16000:\n            train_data.add(example)\n        elif 16000 <= idx < 17000:\n            dev_data.add(example)\n        else:\n            test_data.add(example)\n\n        all_examples.append(example)\n\n    # print statistics\n    max_query_len = max(len(e.query) for e in all_examples)\n    max_actions_len = max(len(e.actions) for e in all_examples)\n\n    serialize_to_file([len(e.query) for e in all_examples], 'query.len')\n    serialize_to_file([len(e.actions) for e in all_examples], 'actions.len')\n\n    logging.info('examples that can be fully reconstructed: %d/%d=%f',\n                 can_fully_gen_num, len(all_examples),\n                 can_fully_gen_num / len(all_examples))\n    logging.info('empty_actions_count: %d', empty_actions_count)\n    logging.info('max_query_len: %d', 
max_query_len)\n    logging.info('max_actions_len: %d', max_actions_len)\n\n    train_data.init_data_matrices()\n    dev_data.init_data_matrices()\n    test_data.init_data_matrices()\n\n    serialize_to_file((train_data, dev_data, test_data),\n                      'data/django.cleaned.dataset.freq3.par_info.refact.space_only.order_by_ulink_len.bin')\n                      # 'data/django.cleaned.dataset.freq5.par_info.refact.space_only.unary_closure.freq{UNARY_CUTOFF_FREQ}.order_by_ulink_len.bin'.format(UNARY_CUTOFF_FREQ=UNARY_CUTOFF_FREQ))\n\n    return train_data, dev_data, test_data\n\n\ndef check_terminals():\n    from parse import parse_django, unescape\n    grammar, parse_trees = parse_django('/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code')\n    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'\n\n    unique_terminals = set()\n    invalid_terminals = set()\n\n    for i, line in enumerate(open(annot_file)):\n        parse_tree = parse_trees[i]\n        utterance = line.strip()\n\n        leaves = parse_tree.get_leaves()\n        # tokens = set(nltk.word_tokenize(utterance))\n        leave_tokens = [l.label for l in leaves if l.label]\n\n        not_included = []\n        for leaf_token in leave_tokens:\n            leaf_token = str(leaf_token)\n            leaf_token = unescape(leaf_token)\n            if leaf_token not in utterance:\n                not_included.append(leaf_token)\n\n                if len(leaf_token) <= 15:\n                    unique_terminals.add(leaf_token)\n                else:\n                    invalid_terminals.add(leaf_token)\n            else:\n                if isinstance(leaf_token, str):\n                    print leaf_token\n\n        # if not_included:\n        #     print str(i) + '---' + ', '.join(not_included)\n\n    # print 'num of unique leaves: %d' % len(unique_terminals)\n    # print unique_terminals\n    #\n    # print 'num of invalid 
leaves: %d' % len(invalid_terminals)\n    # print invalid_terminals\n\n\ndef query_to_data(query, annot_vocab):\n    query_tokens = query.split(' ')\n    token_num = min(config.max_qeury_length, len(query_tokens))\n    data = np.zeros((1, token_num), dtype='int32')\n\n    for tid, token in enumerate(query_tokens[:token_num]):\n        token_id = annot_vocab[token]\n\n        data[0, tid] = token_id\n\n    return data\n\n\nQUOTED_STRING_RE = re.compile(r\"(?P<quote>['\\\"])(?P<string>.*?)(?<!\\\\)(?P=quote)\")\n\n\ndef canonicalize_query(query):\n    \"\"\"\n    canonicalize the query, replace strings to a special place holder\n    \"\"\"\n    str_count = 0\n    str_map = dict()\n\n    matches = QUOTED_STRING_RE.findall(query)\n    # de-duplicate\n    cur_replaced_strs = set()\n    for match in matches:\n        # If one or more groups are present in the pattern,\n        # it returns a list of groups\n        quote = match[0]\n        str_literal = quote + match[1] + quote\n\n        if str_literal in cur_replaced_strs:\n            continue\n\n        # FIXME: substitute the ' % s ' with\n        if str_literal in ['\\'%s\\'', '\\\"%s\\\"']:\n            continue\n\n        str_repr = '_STR:%d_' % str_count\n        str_map[str_literal] = str_repr\n\n        query = query.replace(str_literal, str_repr)\n\n        str_count += 1\n        cur_replaced_strs.add(str_literal)\n\n    # tokenize\n    query_tokens = nltk.word_tokenize(query)\n\n    new_query_tokens = []\n    # break up function calls like foo.bar.func\n    for token in query_tokens:\n        new_query_tokens.append(token)\n        i = token.find('.')\n        if 0 < i < len(token) - 1:\n            new_tokens = ['['] + token.replace('.', ' . 
').split(' ') + [']']\n            new_query_tokens.extend(new_tokens)\n\n    query = ' '.join(new_query_tokens)\n\n    return query, str_map\n\n\ndef canonicalize_example(query, code):\n    from lang.py.parse import parse_raw, parse_tree_to_python_ast, canonicalize_code as make_it_compilable\n    import astor, ast\n\n    canonical_query, str_map = canonicalize_query(query)\n    canonical_code = code\n\n    for str_literal, str_repr in str_map.iteritems():\n        canonical_code = canonical_code.replace(str_literal, '\\'' + str_repr + '\\'')\n\n    canonical_code = make_it_compilable(canonical_code)\n\n    # sanity check\n    parse_tree = parse_raw(canonical_code)\n    gold_ast_tree = ast.parse(canonical_code).body[0]\n    gold_source = astor.to_source(gold_ast_tree)\n    ast_tree = parse_tree_to_python_ast(parse_tree)\n    source = astor.to_source(ast_tree)\n\n    assert gold_source == source, 'sanity check fails: gold=[%s], actual=[%s]' % (gold_source, source)\n\n    query_tokens = canonical_query.split(' ')\n\n    return query_tokens, canonical_code, str_map\n\n\ndef process_query(query, code):\n    from parse import code_to_ast, ast_to_tree, tree_to_ast, parse\n    import astor\n    str_count = 0\n    str_map = dict()\n\n    match_count = 1\n    match = QUOTED_STRING_RE.search(query)\n    while match:\n        str_repr = '_STR:%d_' % str_count\n        str_literal = match.group(0)\n        str_string = match.group(2)\n\n        match_count += 1\n\n        # if match_count > 50:\n        #     return\n        #\n\n        query = QUOTED_STRING_RE.sub(str_repr, query, 1)\n        str_map[str_literal] = str_repr\n\n        str_count += 1\n        match = QUOTED_STRING_RE.search(query)\n\n        code = code.replace(str_literal, '\\'' + str_repr + '\\'')\n\n    # clean the annotation\n    # query = query.replace('.', ' . 
')\n\n    for k, v in str_map.iteritems():\n        if k == '\\'%s\\'' or k == '\\\"%s\\\"':\n            query = query.replace(v, k)\n            code = code.replace('\\'' + v + '\\'', k)\n\n    # tokenize\n    query_tokens = nltk.word_tokenize(query)\n\n    new_query_tokens = []\n    # break up function calls\n    for token in query_tokens:\n        new_query_tokens.append(token)\n        i = token.find('.')\n        if 0 < i < len(token) - 1:\n            new_tokens = ['['] + token.replace('.', ' . ').split(' ') + [']']\n            new_query_tokens.extend(new_tokens)\n\n    # check if the code compiles\n    tree = parse(code)\n    ast_tree = tree_to_ast(tree)\n    astor.to_source(ast_tree)\n\n    return new_query_tokens, code, str_map\n\n\ndef preprocess_dataset(annot_file, code_file):\n    f_annot = open('annot.all.canonicalized.txt', 'w')\n    f_code = open('code.all.canonicalized.txt', 'w')\n\n    examples = []\n\n    err_num = 0\n    for idx, (annot, code) in enumerate(zip(open(annot_file), open(code_file))):\n        annot = annot.strip()\n        code = code.strip()\n        try:\n            clean_query_tokens, clean_code, str_map = canonicalize_example(annot, code)\n            example = {'id': idx, 'query_tokens': clean_query_tokens, 'code': clean_code,\n                       'str_map': str_map, 'raw_code': code}\n            examples.append(example)\n\n            f_annot.write('example# %d\\n' % idx)\n            f_annot.write(' '.join(clean_query_tokens) + '\\n')\n            f_annot.write('%d\\n' % len(str_map))\n            for k, v in str_map.iteritems():\n                f_annot.write('%s ||| %s\\n' % (k, v))\n\n            f_code.write('example# %d\\n' % idx)\n            f_code.write(clean_code + '\\n')\n        except:\n            print code\n            err_num += 1\n\n        idx += 1\n\n    f_annot.close()\n    f_annot.close()\n\n    # serialize_to_file(examples, 'django.cleaned.bin')\n\n    print 'error num: %d' % err_num\n    print 
'preprocess_dataset: cleaned example num: %d' % len(examples)\n\n    return examples\n\n\nif __name__== '__main__':\n    from nn.utils.generic_utils import init_logging\n    init_logging('parse.log')\n\n    # annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'\n    # code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'\n\n    # preprocess_dataset(annot_file, code_file)\n\n    # parse_django_dataset()\n    # check_terminals()\n\n    # print process_query(\"\"\" ALLOWED_VARIABLE_CHARS is a string 'abcdefgh\"ijklm\" nop\"%s\"qrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_.'.\"\"\")\n\n    # for i, query in enumerate(open('/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno')):\n    #     print i, process_query(query)\n\n    # clean_dataset()\n\n    parse_django_dataset()\n    # from lang.py.py_dataset import parse_hs_dataset\n    # parse_hs_dataset()\n"
  },
  {
    "path": "decoder.py",
    "content": "import traceback\nimport config\n\nfrom model import *\n\ndef decode_python_dataset(model, dataset, verbose=True):\n    from lang.py.parse import decode_tree_to_python_ast\n    if verbose:\n        logging.info('decoding [%s] set, num. examples: %d', dataset.name, dataset.count)\n\n    decode_results = []\n    cum_num = 0\n    for example in dataset.examples:\n        cand_list = model.decode(example, dataset.grammar, dataset.terminal_vocab,\n                                 beam_size=config.beam_size, max_time_step=config.decode_max_time_step)\n\n        exg_decode_results = []\n        for cid, cand in enumerate(cand_list[:10]):\n            try:\n                ast_tree = decode_tree_to_python_ast(cand.tree)\n                code = astor.to_source(ast_tree)\n                exg_decode_results.append((cid, cand, ast_tree, code))\n            except:\n                if verbose:\n                    print \"Exception in converting tree to code:\"\n                    print '-' * 60\n                    print 'raw_id: %d, beam pos: %d' % (example.raw_id, cid)\n                    traceback.print_exc(file=sys.stdout)\n                    print '-' * 60\n\n        cum_num += 1\n        if cum_num % 50 == 0 and verbose:\n            print '%d examples so far ...' % cum_num\n\n        decode_results.append(exg_decode_results)\n\n    return decode_results\n\n    # serialize_to_file(decode_results, '%s.decode_results.profile' % dataset.name)\n\ndef decode_ifttt_dataset(model, dataset, verbose=True):\n    if verbose:\n        logging.info('decoding [%s] set, num. 
examples: %d', dataset.name, dataset.count)\n\n    decode_results = []\n    cum_num = 0\n    for example in dataset.examples:\n        cand_list = model.decode(example, dataset.grammar, dataset.terminal_vocab,\n                                 beam_size=config.beam_size, max_time_step=config.decode_max_time_step)\n\n        exg_decode_results = []\n        for cid, cand in enumerate(cand_list[:10]):\n            try:\n                exg_decode_results.append((cid, cand))\n            except:\n                if verbose:\n                    print \"Exception in converting tree to code:\"\n                    print '-' * 60\n                    print 'raw_id: %d, beam pos: %d' % (example.raw_id, cid)\n                    traceback.print_exc(file=sys.stdout)\n                    print '-' * 60\n\n        cum_num += 1\n        if cum_num % 50 == 0 and verbose:\n            print '%d examples so far ...' % cum_num\n\n        decode_results.append(exg_decode_results)\n\n    return decode_results"
  },
  {
    "path": "evaluation.py",
    "content": "# -*- coding: UTF-8 -*-\n\nfrom __future__ import division\nimport os\nfrom nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction\nimport logging\nimport traceback\n\nfrom nn.utils.generic_utils import init_logging\n\nfrom model import *\n\n\nDJANGO_ANNOT_FILE = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'\n\n\ndef tokenize_for_bleu_eval(code):\n    code = re.sub(r'([^A-Za-z0-9_])', r' \\1 ', code)\n    code = re.sub(r'([a-z])([A-Z])', r'\\1 \\2', code)\n    code = re.sub(r'\\s+', ' ', code)\n    code = code.replace('\"', '`')\n    code = code.replace('\\'', '`')\n    tokens = [t for t in code.split(' ') if t]\n\n    return tokens\n\n\ndef evaluate(model, dataset, verbose=True):\n    if verbose:\n        logging.info('evaluating [%s] dataset, [%d] examples' % (dataset.name, dataset.count))\n\n    exact_match_ratio = 0.0\n\n    for example in dataset.examples:\n        logging.info('evaluating example [%d]' % example.eid)\n        hyps, hyp_scores = model.decode(example, max_time_step=config.decode_max_time_step)\n        gold_rules = example.rules\n\n        if len(hyps) == 0:\n            logging.warning('no decoding result for example [%d]!' 
% example.eid)\n            continue\n\n        best_hyp = hyps[0]\n        predict_rules = [dataset.grammar.id_to_rule[rid] for rid in best_hyp]\n\n        assert len(predict_rules) > 0 and len(gold_rules) > 0\n\n        exact_match = sorted(gold_rules, key=lambda x: x.__repr__()) == sorted(predict_rules, key=lambda x: x.__repr__())\n        if exact_match:\n            exact_match_ratio += 1\n\n        # p = len(predict_rules.intersection(gold_rules)) / len(predict_rules)\n        # r = len(predict_rules.intersection(gold_rules)) / len(gold_rules)\n\n    exact_match_ratio /= dataset.count\n\n    logging.info('exact_match_ratio = %f' % exact_match_ratio)\n\n    return exact_match_ratio\n\n\ndef evaluate_decode_results(dataset, decode_results, verbose=True):\n    from lang.py.parse import tokenize_code, de_canonicalize_code\n    # tokenize_code = tokenize_for_bleu_eval\n    import ast\n    assert dataset.count == len(decode_results)\n\n    f = f_decode = None\n    if verbose:\n        f = open(dataset.name + '.exact_match', 'w')\n        exact_match_ids = []\n        f_decode = open(dataset.name + '.decode_results.txt', 'w')\n        eid_to_annot = dict()\n\n        if config.data_type == 'django':\n            for raw_id, line in enumerate(open(DJANGO_ANNOT_FILE)):\n                eid_to_annot[raw_id] = line.strip()\n\n        f_bleu_eval_ref = open(dataset.name + '.ref', 'w')\n        f_bleu_eval_hyp = open(dataset.name + '.hyp', 'w')\n        f_generated_code = open(dataset.name + '.geneated_code', 'w')\n\n        logging.info('evaluating [%s] set, [%d] examples', dataset.name, dataset.count)\n\n    cum_oracle_bleu = 0.0\n    cum_oracle_acc = 0.0\n    cum_bleu = 0.0\n    cum_acc = 0.0\n    sm = SmoothingFunction()\n\n    all_references = []\n    all_predictions = []\n\n    if all(len(cand) == 0 for cand in decode_results):\n        logging.ERROR('Empty decoding results for the current dataset!')\n        return -1, -1\n\n    for eid in range(dataset.count):\n   
     example = dataset.examples[eid]\n        ref_code = example.code\n        ref_ast_tree = ast.parse(ref_code).body[0]\n        refer_source = astor.to_source(ref_ast_tree).strip()\n        # refer_source = ref_code\n        refer_tokens = tokenize_code(refer_source)\n        cur_example_correct = False\n\n        decode_cands = decode_results[eid]\n        if len(decode_cands) == 0:\n            continue\n\n        decode_cand = decode_cands[0]\n\n        cid, cand, ast_tree, code = decode_cand\n        code = astor.to_source(ast_tree).strip()\n\n        # simple_url_2_re = re.compile('_STR:0_', re.))\n        try:\n            predict_tokens = tokenize_code(code)\n        except:\n            logging.error('error in tokenizing [%s]', code)\n            continue\n\n        if refer_tokens == predict_tokens:\n            cum_acc += 1\n            cur_example_correct = True\n\n            if verbose:\n                exact_match_ids.append(example.raw_id)\n                f.write('-' * 60 + '\\n')\n                f.write('example_id: %d\\n' % example.raw_id)\n                f.write(code + '\\n')\n                f.write('-' * 60 + '\\n')\n\n        if config.data_type == 'django':\n            ref_code_for_bleu = example.meta_data['raw_code']\n            pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code'])\n            # ref_code_for_bleu = de_canonicalize_code(ref_code_for_bleu, example.meta_data['raw_code'])\n            # convert canonicalized code to raw code\n            for literal, place_holder in example.meta_data['str_map'].iteritems():\n                pred_code_for_bleu = pred_code_for_bleu.replace('\\'' + place_holder + '\\'', literal)\n                # ref_code_for_bleu = ref_code_for_bleu.replace('\\'' + place_holder + '\\'', literal)\n        elif config.data_type == 'hs':\n            ref_code_for_bleu = ref_code\n            pred_code_for_bleu = code\n\n        # we apply Ling Wang's trick when evaluating BLEU 
scores\n        refer_tokens_for_bleu = tokenize_for_bleu_eval(ref_code_for_bleu)\n        pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu)\n\n        # The if-chunk below is for debugging purpose, sometimes the reference cannot match with the prediction\n        # because of inconsistent quotes (e.g., single quotes in reference, double quotes in prediction).\n        # However most of these cases are solved by cannonicalizing the reference code using astor (parse the reference\n        # into AST, and regenerate the code. Use this regenerated one as the reference)\n        weired = False\n        if refer_tokens_for_bleu == pred_tokens_for_bleu and refer_tokens != predict_tokens:\n            # cum_acc += 1\n            weired = True\n        elif refer_tokens == predict_tokens:\n            # weired!\n            # weired = True\n            pass\n\n        shorter = len(pred_tokens_for_bleu) < len(refer_tokens_for_bleu)\n\n        all_references.append([refer_tokens_for_bleu])\n        all_predictions.append(pred_tokens_for_bleu)\n\n        # try:\n        ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu))\n        bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights, smoothing_function=sm.method3)\n        cum_bleu += bleu_score\n        # except:\n        #    pass\n\n        if verbose:\n            print 'raw_id: %d, bleu_score: %f' % (example.raw_id, bleu_score)\n\n            f_decode.write('-' * 60 + '\\n')\n            f_decode.write('example_id: %d\\n' % example.raw_id)\n            f_decode.write('intent: \\n')\n\n            if config.data_type == 'django':\n                f_decode.write(eid_to_annot[example.raw_id] + '\\n')\n            elif config.data_type == 'hs':\n                f_decode.write(' '.join(example.query) + '\\n')\n\n            f_bleu_eval_ref.write(' '.join(refer_tokens_for_bleu) + '\\n')\n            f_bleu_eval_hyp.write(' '.join(pred_tokens_for_bleu) + 
'\\n')\n\n            f_decode.write('canonicalized reference: \\n')\n            f_decode.write(refer_source + '\\n')\n            f_decode.write('canonicalized prediction: \\n')\n            f_decode.write(code + '\\n')\n            f_decode.write('reference code for bleu calculation: \\n')\n            f_decode.write(ref_code_for_bleu + '\\n')\n            f_decode.write('predicted code for bleu calculation: \\n')\n            f_decode.write(pred_code_for_bleu + '\\n')\n            f_decode.write('pred_shorter_than_ref: %s\\n' % shorter)\n            f_decode.write('weired: %s\\n' % weired)\n            f_decode.write('-' * 60 + '\\n')\n\n            # for Hiro's evaluation\n            f_generated_code.write(pred_code_for_bleu.replace('\\n', '#NEWLINE#') + '\\n')\n\n\n        # compute oracle\n        best_score = 0.\n        cur_oracle_acc = 0.\n        for decode_cand in decode_cands[:config.beam_size]:\n            cid, cand, ast_tree, code = decode_cand\n            try:\n                code = astor.to_source(ast_tree).strip()\n                predict_tokens = tokenize_code(code)\n\n                if predict_tokens == refer_tokens:\n                    cur_oracle_acc = 1\n\n                if config.data_type == 'django':\n                    pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code'])\n                    # convert canonicalized code to raw code\n                    for literal, place_holder in example.meta_data['str_map'].iteritems():\n                        pred_code_for_bleu = pred_code_for_bleu.replace('\\'' + place_holder + '\\'', literal)\n                elif config.data_type == 'hs':\n                    pred_code_for_bleu = code\n\n                # we apply Ling Wang's trick when evaluating BLEU scores\n                pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu)\n\n                ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu))\n                bleu_score = 
sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu,\n                                           weights=ngram_weights,\n                                           smoothing_function=sm.method3)\n\n                if bleu_score > best_score:\n                    best_score = bleu_score\n\n            except:\n                continue\n\n        cum_oracle_bleu += best_score\n        cum_oracle_acc += cur_oracle_acc\n\n    cum_bleu /= dataset.count\n    cum_acc /= dataset.count\n    cum_oracle_bleu /= dataset.count\n    cum_oracle_acc /= dataset.count\n\n    logging.info('corpus level bleu: %f', corpus_bleu(all_references, all_predictions, smoothing_function=sm.method3))\n    logging.info('sentence level bleu: %f', cum_bleu)\n    logging.info('accuracy: %f', cum_acc)\n    logging.info('oracle bleu: %f', cum_oracle_bleu)\n    logging.info('oracle accuracy: %f', cum_oracle_acc)\n\n    if verbose:\n        f.write(', '.join(str(i) for i in exact_match_ids))\n        f.close()\n        f_decode.close()\n\n        f_bleu_eval_ref.close()\n        f_bleu_eval_hyp.close()\n        f_generated_code.close()\n\n    return cum_bleu, cum_acc\n\n\ndef analyze_decode_results(dataset, decode_results, verbose=True):\n    from lang.py.parse import tokenize_code, de_canonicalize_code\n    # tokenize_code = tokenize_for_bleu_eval\n    import ast\n    assert dataset.count == len(decode_results)\n\n    f = f_decode = None\n    if verbose:\n        f = open(dataset.name + '.exact_match', 'w')\n        exact_match_ids = []\n        f_decode = open(dataset.name + '.decode_results.txt', 'w')\n        eid_to_annot = dict()\n\n        if config.data_type == 'django':\n            for raw_id, line in enumerate(open(DJANGO_ANNOT_FILE)):\n                eid_to_annot[raw_id] = line.strip()\n\n        f_bleu_eval_ref = open(dataset.name + '.ref', 'w')\n        f_bleu_eval_hyp = open(dataset.name + '.hyp', 'w')\n\n        logging.info('evaluating [%s] set, [%d] examples', dataset.name, 
dataset.count)\n\n    cum_oracle_bleu = 0.0\n    cum_oracle_acc = 0.0\n    cum_bleu = 0.0\n    cum_acc = 0.0\n    sm = SmoothingFunction()\n\n    all_references = []\n    all_predictions = []\n\n    if all(len(cand) == 0 for cand in decode_results):\n        logging.ERROR('Empty decoding results for the current dataset!')\n        return -1, -1\n\n    binned_results_dict = defaultdict(list)\n    def get_binned_key(ast_size):\n        cutoff = 50 if config.data_type == 'django' else 250\n        k = 10 if config.data_type == 'django' else 25 # for hs\n\n        if ast_size >= cutoff:\n            return '%d - inf' % cutoff\n\n        lower = int(ast_size / k) * k\n        upper = lower + k\n\n        key = '%d - %d' % (lower, upper)\n\n        return key\n\n\n    for eid in range(dataset.count):\n        example = dataset.examples[eid]\n        ref_code = example.code\n        ref_ast_tree = ast.parse(ref_code).body[0]\n        refer_source = astor.to_source(ref_ast_tree).strip()\n        # refer_source = ref_code\n        refer_tokens = tokenize_code(refer_source)\n        cur_example_acc = 0.0\n\n        decode_cands = decode_results[eid]\n        if len(decode_cands) == 0:\n            continue\n\n        decode_cand = decode_cands[0]\n\n        cid, cand, ast_tree, code = decode_cand\n        code = astor.to_source(ast_tree).strip()\n\n        # simple_url_2_re = re.compile('_STR:0_', re.))\n        try:\n            predict_tokens = tokenize_code(code)\n        except:\n            logging.error('error in tokenizing [%s]', code)\n            continue\n\n        if refer_tokens == predict_tokens:\n            cum_acc += 1\n            cur_example_acc = 1.0\n\n            if verbose:\n                exact_match_ids.append(example.raw_id)\n                f.write('-' * 60 + '\\n')\n                f.write('example_id: %d\\n' % example.raw_id)\n                f.write(code + '\\n')\n                f.write('-' * 60 + '\\n')\n\n        if config.data_type == 
'django':\n            ref_code_for_bleu = example.meta_data['raw_code']\n            pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code'])\n            # ref_code_for_bleu = de_canonicalize_code(ref_code_for_bleu, example.meta_data['raw_code'])\n            # convert canonicalized code to raw code\n            for literal, place_holder in example.meta_data['str_map'].iteritems():\n                pred_code_for_bleu = pred_code_for_bleu.replace('\\'' + place_holder + '\\'', literal)\n                # ref_code_for_bleu = ref_code_for_bleu.replace('\\'' + place_holder + '\\'', literal)\n        elif config.data_type == 'hs':\n            ref_code_for_bleu = ref_code\n            pred_code_for_bleu = code\n\n        # we apply Ling Wang's trick when evaluating BLEU scores\n        refer_tokens_for_bleu = tokenize_for_bleu_eval(ref_code_for_bleu)\n        pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu)\n\n        shorter = len(pred_tokens_for_bleu) < len(refer_tokens_for_bleu)\n\n        all_references.append([refer_tokens_for_bleu])\n        all_predictions.append(pred_tokens_for_bleu)\n\n        # try:\n        ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu))\n        bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights, smoothing_function=sm.method3)\n        cum_bleu += bleu_score\n        # except:\n        #    pass\n\n        if verbose:\n            print 'raw_id: %d, bleu_score: %f' % (example.raw_id, bleu_score)\n\n            f_decode.write('-' * 60 + '\\n')\n            f_decode.write('example_id: %d\\n' % example.raw_id)\n            f_decode.write('intent: \\n')\n\n            if config.data_type == 'django':\n                f_decode.write(eid_to_annot[example.raw_id] + '\\n')\n            elif config.data_type == 'hs':\n                f_decode.write(' '.join(example.query) + '\\n')\n\n            f_bleu_eval_ref.write(' '.join(refer_tokens_for_bleu) + 
'\\n')\n            f_bleu_eval_hyp.write(' '.join(pred_tokens_for_bleu) + '\\n')\n\n            f_decode.write('canonicalized reference: \\n')\n            f_decode.write(refer_source + '\\n')\n            f_decode.write('canonicalized prediction: \\n')\n            f_decode.write(code + '\\n')\n            f_decode.write('reference code for bleu calculation: \\n')\n            f_decode.write(ref_code_for_bleu + '\\n')\n            f_decode.write('predicted code for bleu calculation: \\n')\n            f_decode.write(pred_code_for_bleu + '\\n')\n            f_decode.write('pred_shorter_than_ref: %s\\n' % shorter)\n            # f_decode.write('weired: %s\\n' % weired)\n            f_decode.write('-' * 60 + '\\n')\n\n        # compute oracle\n        best_bleu_score = 0.\n        cur_oracle_acc = 0.\n        for decode_cand in decode_cands[:config.beam_size]:\n            cid, cand, ast_tree, code = decode_cand\n            try:\n                code = astor.to_source(ast_tree).strip()\n                predict_tokens = tokenize_code(code)\n\n                if predict_tokens == refer_tokens:\n                    cur_oracle_acc = 1.\n\n                if config.data_type == 'django':\n                    pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code'])\n                    # convert canonicalized code to raw code\n                    for literal, place_holder in example.meta_data['str_map'].iteritems():\n                        pred_code_for_bleu = pred_code_for_bleu.replace('\\'' + place_holder + '\\'', literal)\n                elif config.data_type == 'hs':\n                    pred_code_for_bleu = code\n\n                # we apply Ling Wang's trick when evaluating BLEU scores\n                pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu)\n\n                ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu))\n                cand_bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu,\n  
                                              weights=ngram_weights,\n                                                smoothing_function=sm.method3)\n\n                if cand_bleu_score > best_bleu_score:\n                    best_bleu_score = cand_bleu_score\n\n            except:\n                continue\n\n        cum_oracle_bleu += best_bleu_score\n        cum_oracle_acc += cur_oracle_acc\n\n        ref_ast_size = example.parse_tree.size\n        binned_key = get_binned_key(ref_ast_size)\n        binned_results_dict[binned_key].append((bleu_score, cur_example_acc, best_bleu_score, cur_oracle_acc))\n\n    cum_bleu /= dataset.count\n    cum_acc /= dataset.count\n    cum_oracle_bleu /= dataset.count\n    cum_oracle_acc /= dataset.count\n\n    logging.info('corpus level bleu: %f', corpus_bleu(all_references, all_predictions, smoothing_function=sm.method3))\n    logging.info('sentence level bleu: %f', cum_bleu)\n    logging.info('accuracy: %f', cum_acc)\n    logging.info('oracle bleu: %f', cum_oracle_bleu)\n    logging.info('oracle accuracy: %f', cum_oracle_acc)\n\n    keys = sorted(binned_results_dict, key=lambda x: int(x.split(' - ')[0]))\n\n    Y = [[], [], [], []]\n    X = []\n\n    for binned_key in keys:\n        entry = binned_results_dict[binned_key]\n        avg_bleu = np.average([t[0] for t in entry])\n        avg_acc = np.average([t[1] for t in entry])\n        avg_oracle_bleu = np.average([t[2] for t in entry])\n        avg_oracle_acc = np.average([t[3] for t in entry])\n        print binned_key, avg_bleu, avg_acc, avg_oracle_bleu, avg_oracle_acc, len(entry)\n\n        Y[0].append(avg_bleu)\n        Y[1].append(avg_acc)\n        Y[2].append(avg_oracle_bleu)\n        Y[3].append(avg_oracle_acc)\n\n        X.append(int(binned_key.split(' - ')[0]))\n\n    import matplotlib.pyplot as plt\n    from pylab import rcParams\n    rcParams['figure.figsize'] = 6, 2.5\n\n    if config.data_type == 'django':\n        fig, ax = plt.subplots()\n        ax.plot(X, 
Y[0], 'bs--', label='BLEU', lw=1.2)\n        # ax.plot(X, Y[2], 'r^--', label='oracle BLEU', lw=1.2)\n        ax.plot(X, Y[1], 'r^--', label='acc', lw=1.2)\n        # ax.plot(X, Y[3], 'r^--', label='oracle acc', lw=1.2)\n        ax.set_ylabel('Performance')\n        ax.set_xlabel('Reference AST Size (# nodes)')\n        plt.legend(loc='upper right', ncol=6)\n        plt.tight_layout()\n        # plt.savefig('django_acc_ast_size.pdf', dpi=300)\n        # os.system('pcrop.sh django_acc_ast_size.pdf')\n        plt.savefig('django_perf_ast_size.pdf', dpi=300)\n        os.system('pcrop.sh django_perf_ast_size.pdf')\n    else:\n        fig, ax = plt.subplots()\n        ax.plot(X, Y[0], 'bs--', label='BLEU', lw=1.2)\n        # ax.plot(X, Y[2], 'r^--', label='oracle BLEU', lw=1.2)\n        ax.plot(X, Y[1], 'r^--', label='acc', lw=1.2)\n        # ax.plot(X, Y[3], 'r^--', label='oracle acc', lw=1.2)\n        ax.set_ylabel('Performance')\n        ax.set_xlabel('Reference AST Size (# nodes)')\n        plt.legend(loc='upper right', ncol=6)\n        plt.tight_layout()\n        # plt.savefig('hs_bleu_ast_size.pdf', dpi=300)\n        # os.system('pcrop.sh hs_bleu_ast_size.pdf')\n        plt.savefig('hs_perf_ast_size.pdf', dpi=300)\n        os.system('pcrop.sh hs_perf_ast_size.pdf')\n    if verbose:\n        f.write(', '.join(str(i) for i in exact_match_ids))\n        f.close()\n        f_decode.close()\n\n        f_bleu_eval_ref.close()\n        f_bleu_eval_hyp.close()\n\n    return cum_bleu, cum_acc\n\n\ndef evaluate_seq2seq_decode_results(dataset, seq2seq_decode_file, seq2seq_ref_file, verbose=True, is_nbest=False):\n    from lang.py.parse import parse\n\n    f_seq2seq_decode = open(seq2seq_decode_file)\n    f_seq2seq_ref = open(seq2seq_ref_file)\n\n    if verbose:\n        logging.info('evaluating [%s] set, [%d] examples', dataset.name, dataset.count)\n\n    cum_bleu = 0.0\n    cum_acc = 0.0\n    sm = SmoothingFunction()\n\n    decode_file_data = [l.strip() for l in 
f_seq2seq_decode.readlines()]\n    ref_code_data = [l.strip() for l in f_seq2seq_ref.readlines()]\n\n    if is_nbest:\n        for i in xrange(len(decode_file_data)):\n            d = decode_file_data[i].split(' ||| ')\n            decode_file_data[i] = (int(d[0]), d[1])\n\n    def is_well_formed_python_code(_hyp):\n        try:\n            _hyp = _hyp.replace('#NEWLINE#', '\\n').replace('#INDENT#', '    ').replace(' #MERGE# ', '')\n            hyp_ast_tree = parse(_hyp)\n            return True\n        except:\n            return False\n\n    for eid in range(dataset.count):\n        example = dataset.examples[eid]\n        cur_example_correct = False\n\n        if is_nbest:\n            # find the best-scored well-formed code from the n-best list\n            n_best_list = filter(lambda x: x[0] == eid, decode_file_data)\n            code = top_scored_code = n_best_list[0][1]\n            for _, hyp in n_best_list:\n                if is_well_formed_python_code(hyp):\n                    code = hyp\n                    break\n\n            if top_scored_code != code:\n                print '*' * 60\n                print top_scored_code\n                print code\n                print '*' * 60\n\n            code = n_best_list[0][1]\n        else:\n            code = decode_file_data[eid]\n\n        code = code.replace('#NEWLINE#', '\\n').replace('#INDENT#', '    ').replace(' #MERGE# ', '')\n        ref_code = ref_code_data[eid].replace('#NEWLINE#', '\\n').replace('#INDENT#', '    ').replace(' #MERGE# ', '')\n\n        if code == ref_code:\n            cum_acc += 1\n            cur_example_correct = True\n\n\n        if config.data_type == 'django':\n            ref_code_for_bleu = example.meta_data['raw_code']\n            pred_code_for_bleu = code # de_canonicalize_code(code, example.meta_data['raw_code'])\n            # ref_code_for_bleu = de_canonicalize_code(ref_code_for_bleu, example.meta_data['raw_code'])\n            # convert canonicalized code to raw 
code\n            for literal, place_holder in example.meta_data['str_map'].iteritems():\n                pred_code_for_bleu = pred_code_for_bleu.replace('\\'' + place_holder + '\\'', literal)\n                # ref_code_for_bleu = ref_code_for_bleu.replace('\\'' + place_holder + '\\'', literal)\n        elif config.data_type == 'hs':\n            ref_code_for_bleu = example.code\n            pred_code_for_bleu = code\n\n        # we apply Ling Wang's trick when evaluating BLEU scores\n        refer_tokens_for_bleu = tokenize_for_bleu_eval(ref_code_for_bleu)\n        pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu)\n\n        ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu))\n        bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights, smoothing_function=sm.method3)\n        cum_bleu += bleu_score\n\n    cum_bleu /= dataset.count\n    cum_acc /= dataset.count\n\n    logging.info('sentence level bleu: %f', cum_bleu)\n    logging.info('accuracy: %f', cum_acc)\n\n\ndef evaluate_seq2tree_sample_file(sample_file, id_file, dataset):\n    from lang.py.parse import tokenize_code, de_canonicalize_code\n    import ast, astor\n    import traceback\n    from lang.py.seq2tree_exp import seq2tree_repr_to_ast_tree, merge_broken_value_nodes\n    from lang.py.parse import decode_tree_to_python_ast\n\n    f_sample = open(sample_file)\n    line_id_to_raw_id = OrderedDict()\n    raw_id_to_eid = OrderedDict()\n    for i, line in enumerate(open(id_file)):\n        raw_id = int(line.strip())\n        line_id_to_raw_id[i] = raw_id\n\n    for eid in range(len(dataset.examples)):\n        raw_id_to_eid[dataset.examples[eid].raw_id] = eid\n\n    rare_word_map = defaultdict(dict)\n    if config.seq2tree_rareword_map:\n        logging.info('use rare word map')\n        for i, line in enumerate(open(config.seq2tree_rareword_map)):\n            line = line.strip()\n            if line:\n                for e in line.split(' 
'):\n                    d = e.split(':', 1)\n                    rare_word_map[i][int(d[0])] = d[1]\n\n    cum_bleu = 0.0\n    cum_acc = 0.0\n    sm = SmoothingFunction()\n    convert_error_num = 0\n\n    for i in range(len(line_id_to_raw_id)):\n        # print 'working on %d' % i\n        ref_repr = f_sample.readline().strip()\n        predict_repr = f_sample.readline().strip()\n        predict_repr = predict_repr.replace('<U>', 'str{}{unk}') # .replace('( )', '( str{}{unk} )')\n        f_sample.readline()\n\n        # if ' ( ) ' in ref_repr:\n        #     print i, ref_repr\n\n        if i in rare_word_map:\n            for unk_id, w in rare_word_map[i].iteritems():\n                ref_repr = ref_repr.replace(' str{}{unk_%s} ' % unk_id, ' str{}{%s} ' % w)\n                predict_repr = predict_repr.replace(' str{}{unk_%s} ' % unk_id, ' str{}{%s} ' % w)\n\n        try:\n            parse_tree = seq2tree_repr_to_ast_tree(predict_repr)\n            merge_broken_value_nodes(parse_tree)\n        except:\n            print 'error when converting:'\n            print predict_repr\n            convert_error_num += 1\n            continue\n\n        raw_id = line_id_to_raw_id[i]\n        eid = raw_id_to_eid[raw_id]\n        example = dataset.examples[eid]\n\n        ref_code = example.code\n        ref_ast_tree = ast.parse(ref_code).body[0]\n        refer_source = astor.to_source(ref_ast_tree).strip()\n        refer_tokens = tokenize_code(refer_source)\n\n        try:\n            ast_tree = decode_tree_to_python_ast(parse_tree)\n            code = astor.to_source(ast_tree).strip()\n        except:\n            print \"Exception in converting tree to code:\"\n            print '-' * 60\n            print 'line id: %d' % i\n            traceback.print_exc(file=sys.stdout)\n            print '-' * 60\n            convert_error_num += 1\n            continue\n\n        if config.data_type == 'django':\n            ref_code_for_bleu = example.meta_data['raw_code']\n        
    pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code'])\n            # convert canonicalized code to raw code\n            for literal, place_holder in example.meta_data['str_map'].iteritems():\n                pred_code_for_bleu = pred_code_for_bleu.replace('\\'' + place_holder + '\\'', literal)\n        elif config.data_type == 'hs':\n            ref_code_for_bleu = ref_code\n            pred_code_for_bleu = code\n\n        # we apply Ling Wang's trick when evaluating BLEU scores\n        refer_tokens_for_bleu = tokenize_for_bleu_eval(ref_code_for_bleu)\n        pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu)\n\n        predict_tokens = tokenize_code(code)\n        # if ref_repr == predict_repr:\n        if predict_tokens == refer_tokens:\n            cum_acc += 1\n\n        ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu))\n        bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights,\n                                   smoothing_function=sm.method3)\n        cum_bleu += bleu_score\n\n    cum_bleu /= len(line_id_to_raw_id)\n    cum_acc /= len(line_id_to_raw_id)\n    logging.info('nun. examples: %d', len(line_id_to_raw_id))\n    logging.info('num. 
errors when converting repr to tree: %d', convert_error_num)\n    logging.info('ratio of grammatically incorrect trees: %f', convert_error_num / float(len(line_id_to_raw_id)))\n    logging.info('sentence level bleu: %f', cum_bleu)\n    logging.info('accuracy: %f', cum_acc)\n\n\ndef evaluate_ifttt_results(dataset, decode_results, verbose=True):\n    assert dataset.count == len(decode_results)\n\n    f = f_decode = None\n    if verbose:\n        f = open(dataset.name + '.exact_match', 'w')\n        exact_match_ids = []\n        f_decode = open(os.path.join(config.output_dir, dataset.name + '.decode_results.txt'), 'w')\n\n        logging.info('evaluating [%s] set, [%d] examples', dataset.name, dataset.count)\n\n    cum_channel_acc = 0.0\n    cum_channel_func_acc = 0.0\n    cum_prod_f1 = 0.0\n    cum_oracle_prod_f1 = 0.0\n\n    if all(len(cand) == 0 for cand in decode_results):\n        logging.ERROR('Empty decoding results for the current dataset!')\n        return -1, -1, -1\n\n    for eid in range(dataset.count):\n        example = dataset.examples[eid]\n        ref_parse_tree = example.parse_tree\n        decode_candidates = decode_results[eid]\n\n        if len(decode_candidates) == 0:\n            continue\n\n        decode_cand = decode_candidates[0]\n\n        cid, cand_hyp = decode_cand\n        predict_parse_tree = cand_hyp.tree\n\n        exact_match = predict_parse_tree == ref_parse_tree\n\n        channel_acc, channel_func_acc, prod_f1 = ifttt_metric(predict_parse_tree, ref_parse_tree)\n        cum_channel_acc += channel_acc\n        cum_channel_func_acc += channel_func_acc\n        cum_prod_f1 += prod_f1\n\n        if verbose:\n            if exact_match:\n                exact_match_ids.append(example.raw_id)\n\n            print 'raw_id: %d, prod_f1: %f' % (example.raw_id, prod_f1)\n\n            f_decode.write('-' * 60 + '\\n')\n            f_decode.write('example_id: %d\\n' % example.raw_id)\n            f_decode.write('intent: \\n')\n\n            
f_decode.write(' '.join(example.query) + '\\n')\n\n            f_decode.write('reference: \\n')\n            f_decode.write(str(ref_parse_tree) + '\\n')\n            f_decode.write('prediction: \\n')\n            f_decode.write(str(predict_parse_tree) + '\\n')\n            f_decode.write('-' * 60 + '\\n')\n\n        # compute oracle\n        best_prod_f1 = -1.\n        for decode_cand in decode_candidates[:10]:\n            cid, cand_hyp = decode_cand\n            predict_parse_tree = cand_hyp.tree\n\n            channel_acc, channel_func_acc, prod_f1 = ifttt_metric(predict_parse_tree, ref_parse_tree)\n\n            if prod_f1 > best_prod_f1:\n                best_prod_f1 = prod_f1\n\n        cum_oracle_prod_f1 += best_prod_f1\n\n    cum_channel_acc /= dataset.count\n    cum_channel_func_acc /= dataset.count\n    cum_prod_f1 /= dataset.count\n    cum_oracle_prod_f1 /= dataset.count\n\n    logging.info('channel_acc: %f', cum_channel_acc)\n    logging.info('channel_func_acc: %f', cum_channel_func_acc)\n    logging.info('prod_f1: %f', cum_prod_f1)\n    logging.info('oracle prod_f1: %f', cum_oracle_prod_f1)\n\n    if verbose:\n        f.write(', '.join(str(i) for i in exact_match_ids))\n        f.close()\n        f_decode.close()\n\n    return cum_channel_acc, cum_channel_func_acc, cum_prod_f1\n\n\ndef ifttt_metric(predict_parse_tree, ref_parse_tree):\n    channel_acc = channel_func_acc = prod_f1 = 0.\n    # channel acc.\n    channel_match = False\n    if predict_parse_tree['TRIGGER'].children[0].type == ref_parse_tree['TRIGGER'].children[0].type and \\\n                    predict_parse_tree['ACTION'].children[0].type == ref_parse_tree['ACTION'].children[0].type:\n        channel_acc += 1.\n        channel_match = True\n\n    # channel+func acc.\n    if channel_match and predict_parse_tree['TRIGGER'].children[0].children[0].type == ref_parse_tree['TRIGGER'].children[0].children[0].type and \\\n                    
predict_parse_tree['ACTION'].children[0].children[0].type == ref_parse_tree['ACTION'].children[0].children[0].type:\n        channel_func_acc += 1.\n\n    # predict_parse_tree is of type DecodingTree, different from reference tree!\n    # if predict_parse_tree == ref_parse_tree:\n    #     channel_func_acc += 1.\n\n    # prod. F1\n    ref_rules, _ = ref_parse_tree.get_productions()\n    predict_rules, _ = predict_parse_tree.get_productions()\n\n    prod_f1 = len(set(ref_rules).intersection(set(predict_rules))) / len(ref_rules)\n\n    return channel_acc, channel_func_acc, prod_f1\n\n\ndef decode_and_evaluate_ifttt(model, test_data):\n    raw_ids = [int(i.strip()) for i in open(config.ifttt_test_split)]  # 'data/ifff.test_data.gold.id'\n    eids  = [i for i, e in enumerate(test_data.examples) if e.raw_id in raw_ids]\n    test_data_subset = test_data.get_dataset_by_ids(eids, test_data.name + '.subset')\n\n    from decoder import decode_ifttt_dataset\n    decode_results = decode_ifttt_dataset(model, test_data_subset, verbose=True)\n    evaluate_ifttt_results(test_data_subset, decode_results)\n\n    return decode_results\n\n\ndef decode_and_evaluate_ifttt_by_split(model, test_data):\n    for split in ['ifff.test_data.omit_non_english.id', 'ifff.test_data.omit_unintelligible.id', 'ifff.test_data.gold.id']:\n        raw_ids = [int(i.strip()) for i in open(os.path.join(config.ifttt_test_split), split)]  # 'data/ifff.test_data.gold.id'\n        eids = [i for i, e in enumerate(test_data.examples) if e.raw_id in raw_ids]\n        test_data_subset = test_data.get_dataset_by_ids(eids, test_data.name + '.' 
+ split)\n\n        from decoder import decode_ifttt_dataset\n        decode_results = decode_ifttt_dataset(model, test_data_subset, verbose=True)\n        evaluate_ifttt_results(test_data_subset, decode_results)\n\n\nif __name__ == '__main__':\n    from dataset import DataEntry, DataSet, Vocab, Action\n    init_logging('parser.log', logging.INFO)\n\n    train_data, dev_data, test_data = deserialize_from_file('data/ifttt.freq3.bin')\n    decoding_results = []\n    for eid in range(test_data.count):\n        example = test_data.examples[eid]\n        decoding_results.append([(eid, example.parse_tree)])\n\n    evaluate_ifttt_results(test_data, decoding_results, verbose=True)\n"
  },
  {
    "path": "interactive_mode.py",
    "content": "import argparse, sys\nfrom nn.utils.generic_utils import init_logging\nfrom nn.utils.io_utils import deserialize_from_file, serialize_to_file\nfrom evaluation import *\nfrom dataset import canonicalize_query, query_to_data\nfrom collections import namedtuple\nfrom lang.py.parse import decode_tree_to_python_ast\nfrom model import Model\nfrom dataset import DataEntry, DataSet, Vocab, Action\nimport config\n\nparser = argparse.ArgumentParser()\nparser.add_argument('-data_type', default='django', choices=['django', 'hs'])\nparser.add_argument('-data')\nparser.add_argument('-random_seed', default=181783, type=int)\nparser.add_argument('-model', default=None)\n\n# neural model's parameters\nparser.add_argument('-source_vocab_size', default=0, type=int)\nparser.add_argument('-target_vocab_size', default=0, type=int)\nparser.add_argument('-rule_num', default=0, type=int)\nparser.add_argument('-node_num', default=0, type=int)\n\nparser.add_argument('-word_embed_dim', default=128, type=int)\nparser.add_argument('-rule_embed_dim', default=128, type=int)\nparser.add_argument('-node_embed_dim', default=64, type=int)\nparser.add_argument('-encoder_hidden_dim', default=256, type=int)\nparser.add_argument('-decoder_hidden_dim', default=256, type=int)\nparser.add_argument('-attention_hidden_dim', default=50, type=int)\nparser.add_argument('-ptrnet_hidden_dim', default=50, type=int)\nparser.add_argument('-dropout', default=0.2, type=float)\n\n# encoder\nparser.add_argument('-encoder', default='bilstm', choices=['bilstm', 'lstm'])\n\n# decoder\nparser.add_argument('-parent_hidden_state_feed', dest='parent_hidden_state_feed', action='store_true')\nparser.add_argument('-no_parent_hidden_state_feed', dest='parent_hidden_state_feed', action='store_false')\nparser.set_defaults(parent_hidden_state_feed=True)\n\nparser.add_argument('-parent_action_feed', dest='parent_action_feed', action='store_true')\nparser.add_argument('-no_parent_action_feed', dest='parent_action_feed', 
action='store_false')\nparser.set_defaults(parent_action_feed=True)\n\nparser.add_argument('-frontier_node_type_feed', dest='frontier_node_type_feed', action='store_true')\nparser.add_argument('-no_frontier_node_type_feed', dest='frontier_node_type_feed', action='store_false')\nparser.set_defaults(frontier_node_type_feed=True)\n\nparser.add_argument('-tree_attention', dest='tree_attention', action='store_true')\nparser.add_argument('-no_tree_attention', dest='tree_attention', action='store_false')\nparser.set_defaults(tree_attention=False)\n\nparser.add_argument('-enable_copy', dest='enable_copy', action='store_true')\nparser.add_argument('-no_copy', dest='enable_copy', action='store_false')\nparser.set_defaults(enable_copy=True)\n\n# training\nparser.add_argument('-optimizer', default='adam')\nparser.add_argument('-clip_grad', default=0., type=float)\nparser.add_argument('-train_patience', default=10, type=int)\nparser.add_argument('-max_epoch', default=50, type=int)\nparser.add_argument('-batch_size', default=10, type=int)\nparser.add_argument('-valid_per_batch', default=4000, type=int)\nparser.add_argument('-save_per_batch', default=4000, type=int)\nparser.add_argument('-valid_metric', default='bleu')\n\n# decoding\nparser.add_argument('-beam_size', default=15, type=int)\nparser.add_argument('-max_query_length', default=70, type=int)\nparser.add_argument('-decode_max_time_step', default=100, type=int)\nparser.add_argument('-head_nt_constraint', dest='head_nt_constraint', action='store_true')\nparser.add_argument('-no_head_nt_constraint', dest='head_nt_constraint', action='store_false')\nparser.set_defaults(head_nt_constraint=True)\n\nargs = parser.parse_args(args=['-data_type', 'django', '-data', 'data/django.cleaned.dataset.freq5.par_info.refact.space_only.bin',\n                               '-model', 'models/model.django_word128_encoder256_rule128_node64.beam15.adam.simple_trans.no_unary_closure.8e39832.run3.best_acc.npz'])\nif args.data_type == 'hs':\n    
args.decode_max_time_step = 350\n\nlogging.info('loading dataset [%s]', args.data)\ntrain_data, dev_data, test_data = deserialize_from_file(args.data)\n\nif not args.source_vocab_size:\n    args.source_vocab_size = train_data.annot_vocab.size\nif not args.target_vocab_size:\n    args.target_vocab_size = train_data.terminal_vocab.size\nif not args.rule_num:\n    args.rule_num = len(train_data.grammar.rules)\nif not args.node_num:\n    args.node_num = len(train_data.grammar.node_type_to_id)\n\nconfig_module = sys.modules['config']\nfor name, value in vars(args).iteritems():\n    setattr(config_module, name, value)\n\n# build the model\nmodel = Model()\nmodel.build()\nmodel.load(args.model)\n\ndef decode_query(query):\n    \"\"\"decode a given natural language query, return a list of generated candidates\"\"\"\n    query, str_map = canonicalize_query(query)\n    vocab = train_data.annot_vocab\n    query_tokens = query.split(' ')\n    query_tokens_data = [query_to_data(query, vocab)]\n    example = namedtuple('example', ['query', 'data'])(query=query_tokens, data=query_tokens_data)\n\n    cand_list = model.decode(example, train_data.grammar, train_data.terminal_vocab,\n                             beam_size=args.beam_size, max_time_step=args.decode_max_time_step, log=True)\n\n    return cand_list\n\nif __name__ == '__main__':\n    print 'run in interactive mode'\n    while True:\n        query = raw_input('input a query: ')\n        cand_list = decode_query(query)\n\n        # output top 5 candidates\n        for cid, cand in enumerate(cand_list[:5]):\n            print '*' * 60\n            print 'cand #%d, score: %f' % (cid, cand.score)\n\n            try:\n                ast_tree = decode_tree_to_python_ast(cand.tree)\n                code = astor.to_source(ast_tree)\n                print 'code: ', code\n                print 'decode log: ', cand.log\n            except:\n                print \"Exception in converting tree to code:\"\n                print '-' * 
60\n                print 'raw_id: %d, beam pos: %d' % (example.raw_id, cid)\n                traceback.print_exc(file=sys.stdout)\n                print '-' * 60\n            finally:\n                print '* parse tree *'\n                print cand.tree.__repr__()\n                print 'n_timestep: %d' % cand.n_timestep\n                print 'ast size: %d' % cand.tree.size\n                print '*' * 60"
  },
  {
    "path": "lang/__init__.py",
    "content": ""
  },
  {
    "path": "lang/grammar.py",
    "content": "from collections import OrderedDict, defaultdict\nimport logging\n\nfrom astnode import ASTNode\nfrom lang.util import typename\n\nclass Grammar(object):\n    def __init__(self, rules):\n        \"\"\"\n        instantiate a grammar with a set of production rules of type Rule\n        \"\"\"\n        self.rules = rules\n        self.rule_index = defaultdict(list)\n        self.rule_to_id = OrderedDict()\n\n        node_types = set()\n        lhs_nodes = set()\n        rhs_nodes = set()\n        for rule in self.rules:\n            self.rule_index[rule.parent].append(rule)\n\n            # we also store all unique node types\n            for node in rule.nodes:\n                node_types.add(typename(node.type))\n\n            lhs_nodes.add(rule.parent)\n            for child in rule.children:\n                rhs_nodes.add(child.as_type_node)\n\n        root_node = lhs_nodes - rhs_nodes\n        assert len(root_node) == 1\n        self.root_node = next(iter(root_node))\n\n        self.terminal_nodes = rhs_nodes - lhs_nodes\n        self.terminal_types = set([n.type for n in self.terminal_nodes])\n\n        self.node_type_to_id = OrderedDict()\n        for i, type in enumerate(node_types, start=0):\n            self.node_type_to_id[type] = i\n\n        for gid, rule in enumerate(rules, start=0):\n            self.rule_to_id[rule] = gid\n\n        self.id_to_rule = OrderedDict((v, k) for (k, v) in self.rule_to_id.iteritems())\n\n        logging.info('num. rules: %d', len(self.rules))\n        logging.info('num. 
types: %d', len(self.node_type_to_id))\n        logging.info('root: %s', self.root_node)\n        logging.info('terminals: %s', ', '.join(repr(n) for n in self.terminal_nodes))\n\n    def __iter__(self):\n        return self.rules.__iter__()\n\n    def __len__(self):\n        return len(self.rules)\n\n    def __getitem__(self, lhs):\n        key_node = ASTNode(lhs.type, None)  # Rules are indexed by types only\n        if key_node in self.rule_index:\n            return self.rule_index[key_node]\n        else:\n            KeyError('key=%s' % key_node)\n\n    def get_node_type_id(self, node):\n        from astnode import ASTNode\n\n        if isinstance(node, ASTNode):\n            type_repr = typename(node.type)\n            return self.node_type_to_id[type_repr]\n        else:\n            # assert isinstance(node, str)\n            # it is a type\n            type_repr = typename(node)\n            return self.node_type_to_id[type_repr]\n\n    def is_terminal(self, node):\n        return node.type in self.terminal_types\n\n    def is_value_node(self, node):\n        raise NotImplementedError\n"
  },
  {
    "path": "lang/ifttt/__init__.py",
    "content": ""
  },
  {
    "path": "lang/ifttt/grammar.py",
    "content": "from lang.grammar import Grammar\n\nclass IFTTTGrammar(Grammar):\n    def __init__(self, rules):\n        super(IFTTTGrammar, self).__init__(rules)\n\n    def is_value_node(self, node):\n        return False"
  },
  {
    "path": "lang/ifttt/ifttt_dataset.py",
    "content": "# -*- coding: UTF-8 -*-\nfrom __future__ import division\nimport string\nfrom collections import OrderedDict\nfrom collections import defaultdict\nfrom itertools import count\n\nfrom nn.utils.io_utils import serialize_to_file, deserialize_from_file\n\nfrom lang.ifttt.grammar import IFTTTGrammar\nfrom parse import ifttt_ast_to_parse_tree\nfrom lang.grammar import Grammar\nimport logging\nfrom itertools import chain\n\nfrom nn.utils.generic_utils import init_logging\n\nfrom dataset import gen_vocab, DataSet, DataEntry, Action, APPLY_RULE, GEN_TOKEN, COPY_TOKEN, GEN_COPY_TOKEN\n\ndef load_examples(data_file):\n    f = open(data_file)\n    next(f)\n    examples = []\n    for line in f:\n        d = line.strip().split('\\t')\n        description = d[4]\n        code = d[9]\n        parse_tree = ifttt_ast_to_parse_tree(code)\n\n        examples.append({'description': description, 'parse_tree': parse_tree, 'code': code})\n\n    return examples\n\n\ndef analyze_ifttt_dataset():\n    data_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/recipe_summaries.all.tsv'\n    examples = load_examples(data_file)\n\n    rule_num = 0.\n    max_rule_num = -1\n    example_with_max_rule_num = -1\n\n    for idx, example in enumerate(examples):\n        parse_tree = example['parse_tree']\n        rules, _ = parse_tree.get_productions(include_value_node=True)\n\n        rule_num += len(rules)\n        if max_rule_num < len(rules):\n            max_rule_num = len(rules)\n            example_with_max_rule_num = idx\n\n    logging.info('avg. num. 
of rules: %f', rule_num / len(examples))\n    logging.info('max_rule_num: %d', max_rule_num)\n    logging.info('example_with_max_rule_num: %d', example_with_max_rule_num)\n\n\ndef canonicalize_ifttt_example(annot, code):\n    parse_tree = ifttt_ast_to_parse_tree(code, attach_func_to_channel=False)\n    clean_code = str(parse_tree)\n    clean_query_tokens = annot.split()\n    clean_query_tokens = [t.lower() for t in clean_query_tokens]\n\n    return clean_query_tokens, clean_code, parse_tree\n\n\ndef preprocess_ifttt_dataset(annot_file, code_file):\n    f = open('ifttt_dataset.examples.txt', 'w')\n\n    examples = []\n\n    for idx, (annot, code) in enumerate(zip(open(annot_file), open(code_file))):\n        annot = annot.strip()\n        code = code.strip()\n\n        clean_query_tokens, clean_code, parse_tree = canonicalize_ifttt_example(annot, code)\n        example = {'id': idx, 'query_tokens': clean_query_tokens, 'code': clean_code, 'parse_tree': parse_tree,\n                   'str_map': None, 'raw_code': code}\n        examples.append(example)\n\n        f.write('*' * 50 + '\\n')\n        f.write('example# %d\\n' % idx)\n        f.write(' '.join(clean_query_tokens) + '\\n')\n        f.write('\\n')\n        f.write(clean_code + '\\n')\n        f.write('*' * 50 + '\\n')\n\n        idx += 1\n\n    f.close()\n\n    print 'preprocess_dataset: cleaned example num: %d' % len(examples)\n\n    return examples\n\n\ndef get_grammar(parse_trees):\n    rules = set()\n\n    for parse_tree in parse_trees:\n        parse_tree_rules, rule_parents = parse_tree.get_productions()\n        for rule in parse_tree_rules:\n            rules.add(rule)\n\n    rules = list(sorted(rules, key=lambda x: x.__repr__()))\n    grammar = IFTTTGrammar(rules)\n\n    logging.info('num. 
rules: %d', len(rules))\n\n    with open('grammar.txt', 'w') as f:\n        for rule in grammar:\n            str = rule.__repr__()\n            f.write(str + '\\n')\n\n    with open('parse_trees.txt', 'w') as f:\n        for tree in parse_trees:\n            f.write(tree.__repr__() + '\\n')\n\n    return grammar\n\n\ndef parse_ifttt_dataset():\n    WORD_FREQ_CUT_OFF = 2\n\n    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/Data/lang.all.txt'\n    code_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/Data/code.all.txt'\n\n    data = preprocess_ifttt_dataset(annot_file, code_file)\n\n    # build the grammar\n    grammar = get_grammar([e['parse_tree'] for e in data])\n\n    annot_tokens = list(chain(*[e['query_tokens'] for e in data]))\n    annot_vocab = gen_vocab(annot_tokens, vocab_size=30000, freq_cutoff=WORD_FREQ_CUT_OFF)\n\n    logging.info('annot vocab. size: %d', annot_vocab.size)\n\n    # we have no terminal tokens in ifttt\n    all_terminal_tokens = []\n    terminal_vocab = gen_vocab(all_terminal_tokens, vocab_size=4000, freq_cutoff=WORD_FREQ_CUT_OFF)\n\n    # now generate the dataset!\n\n    train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'ifttt.train_data')\n    dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'ifttt.dev_data')\n    test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'ifttt.test_data')\n\n    all_examples = []\n\n    can_fully_reconstructed_examples_num = 0\n    examples_with_empty_actions_num = 0\n\n    for entry in data:\n        idx = entry['id']\n        query_tokens = entry['query_tokens']\n        code = entry['code']\n        parse_tree = entry['parse_tree']\n\n        # check if query tokens are valid\n        query_token_ids = [annot_vocab[token] for token in query_tokens if token not in string.punctuation]\n        valid_query_tokens_ids = [tid for tid in query_token_ids if tid != annot_vocab.unk]\n\n        # remove examples with rare words from train and dev, avoid 
overfitting\n        if len(valid_query_tokens_ids) == 0 and 0 <= idx < 77495 + 5171:\n            continue\n\n        rule_list, rule_parents = parse_tree.get_productions(include_value_node=True)\n\n        actions = []\n        can_fully_reconstructed = True\n        rule_pos_map = dict()\n\n        for rule_count, rule in enumerate(rule_list):\n            if not grammar.is_value_node(rule.parent):\n                assert rule.value is None\n                parent_rule = rule_parents[(rule_count, rule)][0]\n                if parent_rule:\n                    parent_t = rule_pos_map[parent_rule]\n                else:\n                    parent_t = 0\n\n                rule_pos_map[rule] = len(actions)\n\n                d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule}\n                action = Action(APPLY_RULE, d)\n\n                actions.append(action)\n            else:\n                raise RuntimeError('no terminals should be in ifttt dataset!')\n\n        if len(actions) == 0:\n            examples_with_empty_actions_num += 1\n            continue\n\n        example = DataEntry(idx, query_tokens, parse_tree, code, actions,\n                            {'str_map': None, 'raw_code': entry['raw_code']})\n\n        if can_fully_reconstructed:\n            can_fully_reconstructed_examples_num += 1\n\n        # train, valid, test splits\n        if 0 <= idx < 77495:\n            train_data.add(example)\n        elif idx < 77495 + 5171:\n            dev_data.add(example)\n        else:\n            test_data.add(example)\n\n        all_examples.append(example)\n\n    # print statistics\n    max_query_len = max(len(e.query) for e in all_examples)\n    max_actions_len = max(len(e.actions) for e in all_examples)\n\n    # serialize_to_file([len(e.query) for e in all_examples], 'query.len')\n    # serialize_to_file([len(e.actions) for e in all_examples], 'actions.len')\n\n    logging.info('train_data examples: %d', train_data.count)\n    
logging.info('dev_data examples: %d', dev_data.count)\n    logging.info('test_data examples: %d', test_data.count)\n\n    logging.info('examples that can be fully reconstructed: %d/%d=%f',\n                 can_fully_reconstructed_examples_num, len(all_examples),\n                 can_fully_reconstructed_examples_num / len(all_examples))\n    logging.info('empty_actions_count: %d', examples_with_empty_actions_num)\n\n    logging.info('max_query_len: %d', max_query_len)\n    logging.info('max_actions_len: %d', max_actions_len)\n\n    train_data.init_data_matrices(max_query_length=40, max_example_action_num=6)\n    dev_data.init_data_matrices()\n    test_data.init_data_matrices()\n\n    serialize_to_file((train_data, dev_data, test_data),\n                      'data/ifttt.freq{WORD_FREQ_CUT_OFF}.bin'.format(WORD_FREQ_CUT_OFF=WORD_FREQ_CUT_OFF))\n\n    return train_data, dev_data, test_data\n\n\ndef parse_data_for_seq2seq(data_file='data/ifttt.freq3.bin'):\n    train_data, dev_data, test_data = deserialize_from_file(data_file)\n    prefix = 'data/seq2seq/'\n\n    for dataset, output in [(train_data, prefix + 'ifttt.train'),\n                            (dev_data, prefix + 'ifttt.dev'),\n                            (test_data, prefix + 'ifttt.test')]:\n        f_source = open(output + '.desc', 'w')\n        f_target = open(output + '.code', 'w')\n\n        if 'test' in output:\n            raw_ids = [int(i.strip()) for i in open('data/ifff.test_data.gold.id')]\n            eids = [i for i, e in enumerate(test_data.examples) if e.raw_id in raw_ids]\n            dataset = test_data.get_dataset_by_ids(eids, test_data.name + '.subset')\n\n        for e in dataset.examples:\n            query_tokens = e.query\n            trigger = e.parse_tree['TRIGGER'].children[0].type + ' . ' + e.parse_tree['TRIGGER'].children[0].children[0].type\n            action = e.parse_tree['ACTION'].children[0].type + ' . 
' + e.parse_tree['ACTION'].children[0].children[0].type\n            code = 'IF ' + trigger + ' THEN ' + action\n\n            f_source.write(' '.join(query_tokens) + '\\n')\n            f_target.write(code + '\\n')\n\n        f_source.close()\n        f_target.close()\n\n\ndef extract_turk_data():\n    turk_annot_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/public_release/data/turk_public.tsv'\n    reference_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/public_release/data/ifttt_public.tsv'\n\n    f_turk = open(turk_annot_file)\n    next(f_turk)\n\n    annot_data = OrderedDict()\n    for line in f_turk:\n        d = line.strip().split('\\t')\n        url = d[0]\n        if url not in annot_data:\n            annot_data[url] = list()\n\n        annot_data[url].append({'trigger_channel': d[2], 'trigger_func': d[3], 'action_channel': d[4], 'action_func': d[5]})\n\n    f_ref = open(reference_file)\n    next(f_ref)\n    ref_data = OrderedDict()\n    for line in f_ref:\n        d = line.strip().split('\\t')\n        url = d[0]\n\n        ref_data[url] = {'trigger_channel': d[2], 'trigger_func': d[3], 'action_channel': d[4], 'action_func': d[5]}\n\n    lt_three_agree_with_gold = []\n    non_english_examples = []\n    unintelligible_examples = []\n    for url, annots in annot_data.iteritems():\n        vote_dict = defaultdict(int)\n        ref = ref_data[url]\n        match_with_gold_num = 0\n        non_english_num = unintelligible_num = 0\n        non_english_annots = []\n        unintelligible_annots = []\n\n        for annot in annots:\n            if annot['trigger_channel'] == ref['trigger_channel'] and annot['trigger_func'] == ref['trigger_func'] and \\\n                annot['action_channel'] == ref['action_channel'] and annot['action_func'] == ref['action_func']:\n                match_with_gold_num += 1\n            vote_dict['#'.join(annot.values())] += 1\n\n        for i, annot in enumerate(annots):\n            if 
annot['trigger_channel'] == 'nonenglish' and annot['trigger_func'] == 'nonenglish' and \\\n                annot['action_channel'] == 'nonenglish' and annot['action_func'] == 'nonenglish':\n                non_english_num += 1\n                non_english_annots.append(i)\n\n            if annot['trigger_channel'] == 'unintelligible' and annot['trigger_func'] == 'unintelligible' and \\\n                annot['action_channel'] == 'unintelligible' and annot['action_func'] == 'unintelligible':\n                unintelligible_num += 1\n                unintelligible_annots.append(i)\n\n        max_vote_num = max(vote_dict.values())\n\n        # omitting descriptions marked as non-English by a majority of the crowdsourced workers\n        if non_english_num == max_vote_num:\n            non_english_examples.append(url)\n\n        non_english_and_unintelligible_num = len(set(non_english_annots).union(set(unintelligible_annots)))\n        # if this example has no non_english and unintelligible annotations\n        if non_english_and_unintelligible_num > 0: # < len(annots) - non_english_and_unintelligible_num:\n            unintelligible_examples.append(url)\n\n        if match_with_gold_num >= 3:\n            lt_three_agree_with_gold.append(url)\n\n    omit_non_english_examples = set(annot_data) - set(non_english_examples)\n    omit_unintelligible_examples = set(annot_data) - set(unintelligible_examples)\n    print len(omit_non_english_examples) # should be 3,741\n    print len(omit_unintelligible_examples) # should be 2,262\n    print len(lt_three_agree_with_gold) # should be 758\n\n    url2id = defaultdict(count(0).next)\n    for url in ref_data:\n        url2id[url] = url2id[url] + 77495 + 5171\n\n    f_gold = open('data/ifff.test_data.gold.id', 'w')\n    for url in lt_three_agree_with_gold:\n        i = url2id[url]\n        f_gold.write(str(i) + '\\n')\n    f_gold.close()\n\n    f_gold = open('data/ifff.test_data.omit_unintelligible.id', 'w')\n    for url in 
omit_unintelligible_examples:\n        i = url2id[url]\n        f_gold.write(str(i) + '\\n')\n    f_gold.close()\n\n    f_gold = open('data/ifff.test_data.omit_non_english.id', 'w')\n    for url in omit_non_english_examples:\n        i = url2id[url]\n        f_gold.write(str(i) + '\\n')\n    f_gold.close()\n\n    omit_non_english_examples = [url2id[url] for url in omit_non_english_examples]\n    omit_unintelligible_examples = [url2id[url] for url in omit_unintelligible_examples]\n    lt_three_agree_with_gold = [url2id[url] for url in lt_three_agree_with_gold]\n\n    return omit_non_english_examples, omit_unintelligible_examples, lt_three_agree_with_gold\n\nif __name__ == '__main__':\n    init_logging('ifttt.log')\n    # parse_ifttt_dataset()\n    # analyze_ifttt_dataset()\n    extract_turk_data()\n    # parse_data_for_seq2seq()\n"
  },
  {
    "path": "lang/ifttt/parse.py",
    "content": "from astnode import ASTNode\n\ndef ifttt_ast_to_parse_tree_helper(s, offset):\n    \"\"\"\n    adapted from ifttt codebase\n    \"\"\"\n    if s[offset] != '(':\n        raise RuntimeError('malformed string: node did not start with open paren at position ' + offset)\n\n    offset += 1\n    # extract node name(type)\n    name = ''\n    if s[offset] == '\\\"':\n        offset += 1\n        while s[offset] != '\\\"':\n            if s[offset] == '\\\\':\n                offset += 1\n            name += s[offset]\n            offset += 1\n        offset += 1\n    else:\n        while s[offset] != ' ' and s[offset] != ')':\n            name += s[offset]\n            offset += 1\n\n    node = ASTNode(name)\n    while True:\n        if s[offset] == ')':\n            offset += 1\n            return node, offset\n        if s[offset] != ' ':\n            raise RuntimeError('malformed string: node should have either had a '\n                               'close paren or a space at position ' + offset)\n        offset += 1\n        child_node, offset = ifttt_ast_to_parse_tree_helper(s, offset)\n        node.add_child(child_node)\n\n\ndef ifttt_ast_to_parse_tree(s, attach_func_to_channel=True):\n    parse_tree, _ = ifttt_ast_to_parse_tree_helper(s, 0)\n    parse_tree = strip_params(parse_tree)\n\n    if attach_func_to_channel:\n        parse_tree = attach_function_to_channel(parse_tree)\n\n    return parse_tree\n\n\ndef strip_params(parse_tree):\n    if parse_tree.type == 'PARAMS':\n        raise RuntimeError('should not go to here!')\n\n    parse_tree.children = [c for c in parse_tree.children if c.type != 'PARAMS' and c.type != 'OUTPARAMS']\n    for i, child in enumerate(parse_tree.children):\n        parse_tree.children[i] = strip_params(child)\n\n    return parse_tree\n\n\ndef attach_function_to_channel(parse_tree):\n    trigger_func = parse_tree['TRIGGER']['FUNC'].children\n    assert len(trigger_func) == 1\n\n    trigger_func = trigger_func[0]\n    
parse_tree['TRIGGER'].children[0].add_child(trigger_func)\n\n    del parse_tree['TRIGGER']['FUNC']\n\n    action_func = parse_tree['ACTION']['FUNC'].children\n    assert len(action_func) == 1\n\n    action_func = action_func[0]\n    parse_tree['ACTION'].children[0].add_child(action_func)\n\n    del parse_tree['ACTION']['FUNC']\n\n    return parse_tree\n\n\nif __name__ == '__main__':\n    tree_code = \"\"\"(ROOT (IF) (TRIGGER (Instagram) (FUNC (Any_new_photo_by_you) (PARAMS))) (THEN) (ACTION (Dropbox) (FUNC (Add_file_from_URL) (PARAMS (File_URL ({{Caption}})) (File_name (\"\")) (Dropbox_folder_path (ifttt/instagram))))))\"\"\"\n    parse_tree = ifttt_ast_to_parse_tree(tree_code)\n\n    print parse_tree\n    print strip_params(parse_tree)\n    print attach_function_to_channel(parse_tree)"
  },
  {
    "path": "lang/py/__init__.py",
    "content": ""
  },
  {
    "path": "lang/py/grammar.py",
    "content": "\"\"\"\nPython grammar and typing system\n\"\"\"\nimport ast\nimport inspect\nimport astor\n\nfrom lang.grammar import Grammar\n\nPY_AST_NODE_FIELDS = {\n    'FunctionDef': {\n        'name': {\n            'type': str,\n            'is_list': False,\n            'is_optional': False\n        },\n        'args': {\n            'type': ast.arguments,\n            'is_list': False,\n            'is_optional': False\n        },\n        'body': {\n            'type': ast.stmt,\n            'is_list': True,\n            'is_optional': False\n        },\n        'decorator_list': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        }\n    },\n    'ClassDef': {\n        'name': {\n            'type': ast.arguments,\n            'is_list': False,\n            'is_optional': False\n        },\n        'bases': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        },\n        'body': {\n            'type': ast.stmt,\n            'is_list': True,\n            'is_optional': False\n        },\n        'decorator_list': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        }\n    },\n    'Return': {\n        'value': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        },\n    },\n    'Delete': {\n        'targets': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n    'Assign': {\n        'targets': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        },\n        'value': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        }\n    },\n    'AugAssign': {\n        'target': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n  
      'op': {\n            'type': ast.operator,\n            'is_list': False,\n            'is_optional': False\n        },\n        'value': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        }\n    },\n    'Print': {\n        'dest': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        },\n        'values': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        },\n        'nl': {\n            'type': bool,\n            'is_list': False,\n            'is_optional': False\n        }\n    },\n    'For': {\n        'target': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'iter': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'body': {\n            'type': ast.stmt,\n            'is_list': True,\n            'is_optional': False\n        },\n        'orelse': {\n            'type': ast.stmt,\n            'is_list': True,\n            'is_optional': False\n        }\n    },\n    'While': {\n        'test': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'body': {\n            'type': ast.stmt,\n            'is_list': True,\n            'is_optional': False\n        },\n        'orelse': {\n            'type': ast.stmt,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n    'If': {\n        'test': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'body': {\n            'type': ast.stmt,\n            'is_list': True,\n            'is_optional': False\n        },\n        'orelse': {\n            'type': ast.stmt,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n  
  'With': {\n        'context_expr': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'optional_vars': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        },\n        'body': {\n            'type': ast.stmt,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n    'Raise': {\n        'type': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        },\n        'inst': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        },\n        'tback': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        },\n    },\n    'TryExcept': {\n        'body': {\n            'type': ast.stmt,\n            'is_list': True,\n            'is_optional': False\n        },\n        'handlers': {\n            'type': ast.excepthandler,\n            'is_list': True,\n            'is_optional': False\n        },\n        'orelse': {\n            'type': ast.stmt,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n    'TryFinally': {\n        'body': {\n            'type': ast.stmt,\n            'is_list': True,\n            'is_optional': False\n        },\n        'finalbody': {\n            'type': ast.stmt,\n            'is_list': True,\n            'is_optional': False\n        }\n    },\n    'Assert': {\n        'test': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'msg': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        }\n    },\n    'Import': {\n        'names': {\n            'type': ast.alias,\n            'is_list': True,\n            'is_optional': False\n        }\n    },\n    'ImportFrom': {\n        'module': {\n    
        'type': str,\n            'is_list': False,\n            'is_optional': True\n        },\n        'names': {\n            'type': ast.alias,\n            'is_list': True,\n            'is_optional': False\n        },\n        'level': {\n            'type': int,\n            'is_list': False,\n            'is_optional': True\n        }\n    },\n    'Exec': {\n        'body': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'globals': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        },\n        'locals': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        },\n    },\n    'Global': {\n        'names': {\n            'type': str,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n    'Expr': {\n        'value': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n    },\n    'BoolOp': {\n        'op': {\n            'type': ast.boolop,\n            'is_list': False,\n            'is_optional': False\n        },\n        'values': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n    'BinOp': {\n        'left': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'op': {\n            'type': ast.operator,\n            'is_list': False,\n            'is_optional': False\n        },\n        'right': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n    },\n    'UnaryOp': {\n        'op': {\n            'type': ast.unaryop,\n            'is_list': False,\n            'is_optional': False\n        },\n        'operand': {\n            'type': ast.expr,\n            'is_list': False,\n            
'is_optional': False\n        },\n    },\n    'Lambda': {\n        'args': {\n            'type': ast.arguments,\n            'is_list': False,\n            'is_optional': False\n        },\n        'body': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n    },\n    'IfExp': {\n        'test': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'body': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'orelse': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n    },\n    'Dict': {\n        'keys': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        },\n        'values': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n    'Set': {\n        'elts': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n    'ListComp': {\n        'elt': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'generators': {\n            'type': ast.comprehension,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n    'SetComp': {\n        'elt': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'generators': {\n            'type': ast.comprehension,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n    'DictComp': {\n        'key': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'value': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': 
False\n        },\n        'generators': {\n            'type': ast.comprehension,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n    'GeneratorExp': {\n        'elt': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'generators': {\n            'type': ast.comprehension,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n    'Yield': {\n        'value': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        }\n    },\n    'Compare': {\n        'left': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'ops': {\n            'type': ast.cmpop,\n            'is_list': True,\n            'is_optional': False\n        },\n        'comparators': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n    'Call': {\n        'func': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'args': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        },\n        'keywords': {\n            'type': ast.keyword,\n            'is_list': True,\n            'is_optional': False\n        },\n        'starargs': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        },\n        'kwargs': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        },\n    },\n    'Repr': {\n        'value': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        }\n    },\n    'Num': {\n        'n': {\n            'type': object,  #FIXME: should be float or int?\n            'is_list': False,\n            'is_optional': 
False\n        }\n    },\n    'Str': {\n        's': {\n            'type': str,\n            'is_list': False,\n            'is_optional': False\n        }\n    },\n    'Attribute': {\n        'value': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'attr': {\n            'type': str,\n            'is_list': False,\n            'is_optional': False\n        },\n        'ctx': {\n            'type': ast.expr_context,\n            'is_list': False,\n            'is_optional': False\n        },\n    },\n    'Subscript': {\n        'value': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'slice': {\n            'type': ast.slice,\n            'is_list': False,\n            'is_optional': False\n        },\n    },\n    'Name': {\n        'id': {\n            'type': str,\n            'is_list': False,\n            'is_optional': False\n        }\n    },\n    'List': {\n        'elts': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        },\n        'ctx': {\n            'type': ast.expr_context,\n            'is_list': False,\n            'is_optional': False\n        },\n    },\n    'Tuple': {\n        'elts': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        },\n        'ctx': {\n            'type': ast.expr_context,\n            'is_list': False,\n            'is_optional': False\n        },\n    },\n    'ExceptHandler': {\n        'type': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        },\n        'name': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        },\n        'body': {\n            'type': ast.stmt,\n            'is_list': True,\n            'is_optional': False\n        }\n    },\n    'arguments': 
{\n        'args': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        },\n        'vararg': {\n            'type': str,\n            'is_list': False,\n            'is_optional': True\n        },\n        'kwarg': {\n            'type': str,\n            'is_list': False,\n            'is_optional': True\n        },\n        'defaults': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n    'comprehension': {\n        'target': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'iter': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        },\n        'ifs': {\n            'type': ast.expr,\n            'is_list': True,\n            'is_optional': False\n        },\n    },\n    'keyword': {\n        'arg': {\n            'type': str,\n            'is_list': False,\n            'is_optional': False\n        },\n        'value': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        }\n    },\n    'alias': {\n        'name': {\n            'type': str,\n            'is_list': False,\n            'is_optional': False\n        },\n        'asname': {\n            'type': str,\n            'is_list': False,\n            'is_optional': True\n        }\n    },\n    'Slice': {\n        'lower': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        },\n        'upper': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        },\n        'step': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': True\n        }\n    },\n    'ExtSlice': {\n        'dims': {\n            'type': ast.slice,\n            'is_list': True,\n            'is_optional': 
False\n        }\n    },\n    'Index': {\n        'value': {\n            'type': ast.expr,\n            'is_list': False,\n            'is_optional': False\n        }\n    }\n}\n\nNODE_FIELD_BLACK_LIST = {'ctx'}\n\nTERMINAL_AST_TYPES = {\n    ast.Pass,\n    ast.Break,\n    ast.Continue,\n    ast.Add,\n    ast.BitAnd,\n    ast.BitOr,\n    ast.BitXor,\n    ast.Div,\n    ast.FloorDiv,\n    ast.LShift,\n    ast.Mod,\n    ast.Mult,\n    ast.Pow,\n    ast.Sub,\n    ast.And,\n    ast.Or,\n    ast.Eq,\n    ast.Gt,\n    ast.GtE,\n    ast.In,\n    ast.Is,\n    ast.IsNot,\n    ast.Lt,\n    ast.LtE,\n    ast.NotEq,\n    ast.NotIn,\n    ast.Not,\n    ast.USub\n}\n\n\ndef is_builtin_type(x):\n    return x == str or x == int or x == float or x == bool or x == object or x == 'identifier'\n\n\ndef is_terminal_ast_type(x):\n    if inspect.isclass(x) and x in TERMINAL_AST_TYPES:\n        return True\n\n    return False\n\n\n# def is_terminal_type(x):\n#     if is_builtin_type(x):\n#         return True\n#\n#     if x == 'epsilon':\n#         return True\n#\n#     if inspect.isclass(x) and (issubclass(x, ast.Pass) or issubclass(x, ast.Raise) or issubclass(x, ast.Break)\n#                                or issubclass(x, ast.Continue)\n#                                or issubclass(x, ast.Return)\n#                                or issubclass(x, ast.operator) or issubclass(x, ast.boolop)\n#                                or issubclass(x, ast.Ellipsis) or issubclass(x, ast.unaryop)\n#                                or issubclass(x, ast.cmpop)):\n#         return True\n#\n#     return False\n\n\n# class Node:\n#     def __init__(self, node_type, label):\n#         self.type = node_type\n#         self.label = label\n#\n#     @property\n#     def is_preterminal(self):\n#         return is_terminal_type(self.type)\n#\n#     def __eq__(self, other):\n#         return self.type == other.type and self.label == other.label\n#\n#     def __hash__(self):\n#         return 
typename(self.type).__hash__() ^ self.label.__hash__()\n#\n#     def __repr__(self):\n#         repr_str = typename(self.type)\n#         if self.label:\n#             repr_str += '{%s}' % self.label\n#         return repr_str\n#\n#\n# class TypedRule:\n#     def __init__(self, parent, children, tree=None):\n#         self.parent = parent\n#         if isinstance(children, list) or isinstance(children, tuple):\n#             self.children = tuple(children)\n#         else:\n#             self.children = (children, )\n#\n#         # tree property is not incorporated in eq, hash\n#         self.tree = tree\n#\n#     # @property\n#     # def is_terminal_rule(self):\n#     #     return is_terminal_type(self.parent.type)\n#\n#     def __eq__(self, other):\n#         return self.parent == other.parent and self.children == other.children\n#\n#     def __hash__(self):\n#         return self.parent.__hash__() ^ self.children.__hash__()\n#\n#     def __repr__(self):\n#         return '%s -> %s' % (self.parent, ', '.join([c.__repr__() for c in self.children]))\n\n\ndef type_str_to_type(type_str):\n    if type_str.endswith('*') or type_str == 'root' or type_str == 'epsilon':\n        return type_str\n    else:\n        try:\n            type_obj = eval(type_str)\n            if is_builtin_type(type_obj):\n                return type_obj\n        except:\n            pass\n\n        try:\n            type_obj = eval('ast.' 
+ type_str)\n            return type_obj\n        except:\n            raise RuntimeError('unidentified type string: %s' % type_str)\n\n\ndef is_compositional_leaf(node):\n    is_leaf = True\n\n    for field_name, field_value in ast.iter_fields(node):\n        if field_name in NODE_FIELD_BLACK_LIST:\n            continue\n\n        if field_value is None:\n            is_leaf &= True\n        elif isinstance(field_value, list) and len(field_value) == 0:\n            is_leaf &= True\n        else:\n            is_leaf &= False\n    return is_leaf\n\n\nclass PythonGrammar(Grammar):\n    def __init__(self, rules):\n        super(PythonGrammar, self).__init__(rules)\n\n    def is_value_node(self, node):\n        return is_builtin_type(node.type)\n"
  },
  {
    "path": "lang/py/parse.py",
    "content": "import ast\nimport logging\nimport re\nimport token as tk\nfrom cStringIO import StringIO\nfrom tokenize import generate_tokens\n\nfrom astnode import ASTNode\nfrom lang.py.grammar import is_compositional_leaf, PY_AST_NODE_FIELDS, NODE_FIELD_BLACK_LIST, PythonGrammar\nfrom lang.util import escape\nfrom lang.util import typename\n\n\ndef python_ast_to_parse_tree(node):\n    assert isinstance(node, ast.AST)\n\n    node_type = type(node)\n    tree = ASTNode(node_type)\n\n    # it's a leaf AST node, e.g., ADD, Break, etc.\n    if len(node._fields) == 0:\n        return tree\n\n    # if it's a compositional AST node with empty fields\n    if is_compositional_leaf(node):\n        epsilon = ASTNode('epsilon')\n        tree.add_child(epsilon)\n\n        return tree\n\n    fields_info = PY_AST_NODE_FIELDS[node_type.__name__]\n\n    for field_name, field_value in ast.iter_fields(node):\n        # remove ctx stuff\n        if field_name in NODE_FIELD_BLACK_LIST:\n            continue\n\n        # omit empty fields, including empty lists\n        if field_value is None or (isinstance(field_value, list) and len(field_value) == 0):\n            continue\n\n        # now it's not empty!\n        field_type = fields_info[field_name]['type']\n        is_list_field = fields_info[field_name]['is_list']\n\n        if isinstance(field_value, ast.AST):\n            child = ASTNode(field_type, field_name)\n            child.add_child(python_ast_to_parse_tree(field_value))\n        elif type(field_value) is str or type(field_value) is int or \\\n                        type(field_value) is float or type(field_value) is object or \\\n                        type(field_value) is bool:\n            # if field_type != type(field_value):\n            #     print 'expect [%s] type, got [%s]' % (field_type, type(field_value))\n            child = ASTNode(type(field_value), field_name, value=field_value)\n        elif is_list_field:\n            list_node_type = 
typename(field_type) + '*'\n            child = ASTNode(list_node_type, field_name)\n            for n in field_value:\n                if field_type in {ast.comprehension, ast.excepthandler, ast.arguments, ast.keyword, ast.alias}:\n                    child.add_child(python_ast_to_parse_tree(n))\n                else:\n                    intermediate_node = ASTNode(field_type)\n                    if field_type is str:\n                        intermediate_node.value = n\n                    else:\n                        intermediate_node.add_child(python_ast_to_parse_tree(n))\n                    child.add_child(intermediate_node)\n\n        else:\n            raise RuntimeError('unknown AST node field!')\n\n        tree.add_child(child)\n\n    return tree\n\n\ndef parse_tree_to_python_ast(tree):\n    node_type = tree.type\n    node_label = tree.label\n\n    # remove root\n    if node_type == 'root':\n        return parse_tree_to_python_ast(tree.children[0])\n\n    ast_node = node_type()\n    node_type_name = typename(node_type)\n\n    # if it's a compositional AST node, populate its children nodes,\n    # fill fields with empty(default) values otherwise\n    if node_type_name in PY_AST_NODE_FIELDS:\n        fields_info = PY_AST_NODE_FIELDS[node_type_name]\n\n        for child_node in tree.children:\n            # if it's a compositional leaf\n            if child_node.type == 'epsilon':\n                continue\n\n            field_type = child_node.type\n            field_label = child_node.label\n            field_entry = fields_info[field_label]\n            is_list = field_entry['is_list']\n\n            if is_list:\n                field_type = field_entry['type']\n                field_value = []\n\n                if field_type in {ast.comprehension, ast.excepthandler, ast.arguments, ast.keyword, ast.alias}:\n                    nodes_in_list = child_node.children\n                    for sub_node in nodes_in_list:\n                        sub_node_ast 
= parse_tree_to_python_ast(sub_node)\n                        field_value.append(sub_node_ast)\n                else:  # expr stuffs\n                    inter_nodes = child_node.children\n                    for inter_node in inter_nodes:\n                        if inter_node.value is None:\n                            assert len(inter_node.children) == 1\n                            sub_node_ast = parse_tree_to_python_ast(inter_node.children[0])\n                            field_value.append(sub_node_ast)\n                        else:\n                            assert len(inter_node.children) == 0\n                            field_value.append(inter_node.value)\n            else:\n                # this node either holds a value, or is an non-terminal\n                if child_node.value is None:\n                    assert len(child_node.children) == 1\n                    field_value = parse_tree_to_python_ast(child_node.children[0])\n                else:\n                    assert child_node.is_leaf\n                    field_value = child_node.value\n\n            setattr(ast_node, field_label, field_value)\n\n    for field in ast_node._fields:\n        if not hasattr(ast_node, field) and not field in NODE_FIELD_BLACK_LIST:\n            if fields_info and fields_info[field]['is_list'] and not fields_info[field]['is_optional']:\n                setattr(ast_node, field, list())\n            else:\n                setattr(ast_node, field, None)\n\n    return ast_node\n\n\ndef decode_tree_to_python_ast(decode_tree):\n    from lang.py.unaryclosure import compressed_ast_to_normal\n\n    compressed_ast_to_normal(decode_tree)\n    decode_tree = decode_tree.children[0]\n    terminals = decode_tree.get_leaves()\n\n    for terminal in terminals:\n        if terminal.value is not None and type(terminal.value) is str:\n            if terminal.value.endswith('<eos>'):\n                terminal.value = terminal.value[:-5]\n\n        if terminal.type in {int, float, 
str, bool}:\n            # cast to target data type\n            terminal.value = terminal.type(terminal.value)\n\n    ast_tree = parse_tree_to_python_ast(decode_tree)\n\n    return ast_tree\n\n\np_elif = re.compile(r'^elif\\s?')\np_else = re.compile(r'^else\\s?')\np_try = re.compile(r'^try\\s?')\np_except = re.compile(r'^except\\s?')\np_finally = re.compile(r'^finally\\s?')\np_decorator = re.compile(r'^@.*')\n\n\ndef canonicalize_code(code):\n    if p_elif.match(code):\n        code = 'if True: pass\\n' + code\n\n    if p_else.match(code):\n        code = 'if True: pass\\n' + code\n\n    if p_try.match(code):\n        code = code + 'pass\\nexcept: pass'\n    elif p_except.match(code):\n        code = 'try: pass\\n' + code\n    elif p_finally.match(code):\n        code = 'try: pass\\n' + code\n\n    if p_decorator.match(code):\n        code = code + '\\ndef dummy(): pass'\n\n    if code[-1] == ':':\n        code = code + 'pass'\n\n    return code\n\n\ndef de_canonicalize_code(code, ref_raw_code):\n    if code.endswith('def dummy():\\n    pass'):\n        code = code.replace('def dummy():\\n    pass', '').strip()\n\n    if p_elif.match(ref_raw_code):\n        # remove leading if true\n        code = code.replace('if True:\\n    pass', '').strip()\n    elif p_else.match(ref_raw_code):\n        # remove leading if true\n        code = code.replace('if True:\\n    pass', '').strip()\n\n    # try/catch/except stuff\n    if p_try.match(ref_raw_code):\n        code = code.replace('except:\\n    pass', '').strip()\n    elif p_except.match(ref_raw_code):\n        code = code.replace('try:\\n    pass', '').strip()\n    elif p_finally.match(ref_raw_code):\n        code = code.replace('try:\\n    pass', '').strip()\n\n    # remove ending pass\n    if code.endswith(':\\n    pass'):\n        code = code[:-len('\\n    pass')]\n\n    return code\n\n\ndef de_canonicalize_code_for_seq2seq(code, ref_raw_code):\n    if code.endswith('\\ndef dummy(): pass'):\n        code = 
code.replace('\\ndef dummy(): pass', '').strip()\n\n    if p_elif.match(ref_raw_code):\n        # remove leading if true\n        code = code.replace('if True: pass\\n', '').strip()\n    elif p_else.match(ref_raw_code):\n        # remove leading if true\n        code = code.replace('if True: pass\\n', '').strip()\n\n    # try/catch/except stuff\n    if p_try.match(ref_raw_code):\n        code = code.replace('pass\\nexcept: pass', '').strip()\n    elif p_except.match(ref_raw_code):\n        code = code.replace('try: pass\\n', '').strip()\n    elif p_finally.match(ref_raw_code):\n        code = code.replace('try: pass\\n', '').strip()\n\n    # remove ending pass\n    if code.endswith(':pass'):\n        code = code[:-len('pass')]\n\n    return code.strip()\n\n\ndef add_root(tree):\n    root_node = ASTNode('root')\n    root_node.add_child(tree)\n\n    return root_node\n\n\ndef parse(code):\n    \"\"\"\n    parse a python code into a tree structure\n    code -> AST tree -> AST tree to internal tree structure\n    \"\"\"\n\n    code = canonicalize_code(code)\n    py_ast = ast.parse(code)\n\n    tree = python_ast_to_parse_tree(py_ast.body[0])\n\n    tree = add_root(tree)\n\n    return tree\n\n\ndef parse_raw(code):\n    py_ast = ast.parse(code)\n\n    tree = python_ast_to_parse_tree(py_ast.body[0])\n\n    tree = add_root(tree)\n\n    return tree\n\n\ndef get_grammar(parse_trees):\n    rules = set()\n    # rule_num_dist = defaultdict(int)\n\n    for parse_tree in parse_trees:\n        parse_tree_rules, rule_parents = parse_tree.get_productions()\n        for rule in parse_tree_rules:\n            rules.add(rule)\n\n    rules = list(sorted(rules, key=lambda x: x.__repr__()))\n    grammar = PythonGrammar(rules)\n\n    logging.info('num. 
rules: %d', len(rules))\n\n    return grammar\n\n\ndef tokenize_code(code):\n    token_stream = generate_tokens(StringIO(code).readline)\n    tokens = []\n    for toknum, tokval, (srow, scol), (erow, ecol), _ in token_stream:\n        if toknum == tk.ENDMARKER:\n            break\n        tokens.append(tokval)\n\n    return tokens\n\n\ndef tokenize_code_adv(code, breakCamelStr=False):\n    token_stream = generate_tokens(StringIO(code).readline)\n    tokens = []\n    indent_level = 0\n    for toknum, tokval, (srow, scol), (erow, ecol), _ in token_stream:\n        if toknum == tk.ENDMARKER:\n            break\n\n        if toknum == tk.INDENT:\n            indent_level += 1\n            tokens.extend(['#INDENT#'] * indent_level)\n            continue\n        elif toknum == tk.DEDENT:\n            indent_level -= 1\n            tokens.extend(['#INDENT#'] * indent_level)\n            continue\n        elif len(tokens) > 0 and tokens[-1] == '\\n' and tokval != '\\n':\n            tokens.extend(['#INDENT#'] * indent_level)\n\n        if toknum == tk.STRING:\n            quote = tokval[0]\n            tokval = tokval[1:-1]\n            tokens.append(quote)\n\n        if breakCamelStr:\n            sub_tokens = re.sub(r'([a-z])([A-Z])', r'\\1 #MERGE# \\2', tokval).split(' ')\n            tokens.extend(sub_tokens)\n        else:\n            tokens.append(tokval)\n\n        if toknum == tk.STRING:\n            tokens.append(quote)\n\n    return tokens\n\n\nif __name__ == '__main__':\n    from nn.utils.generic_utils import init_logging\n    init_logging('misc.log')\n\n    # django_code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'\n    #\n    # grammar, parse_trees = extract_grammar(django_code_file)\n    # id = 1888\n    # parse_tree = parse_trees[id]\n    # print parse_tree\n    # from components import Hyp\n    # hyp = Hyp(grammar)\n    # rules, rule_parents = parse_tree.get_productions()\n    #\n    # while hyp.frontier_nt():\n  
  #     nt = hyp.frontier_nt()\n    #     if grammar.is_value_node(nt):\n    #         hyp.append_token('111<eos>')\n    #     else:\n    #         rule = rules[0]\n    #         hyp.apply_rule(rule)\n    #         del rules[0]\n    #\n    # print hyp\n    #\n    # ast_tree = decode_tree_to_python_ast(hyp.tree)\n    # source = astor.to_source(ast_tree)\n    # print source\n\n    # for code in open(django_code_file):\n    #     code = code.strip()\n    #     ref_ast_tree = ast.parse(canonicalize_code(code)).body[0]\n    #     parse_tree = parse(code)\n    #     ast_tree = parse_tree_to_python_ast(parse_tree)\n    #     source1 = astor.to_source(ast_tree)\n    #     source2 = astor.to_source(ref_ast_tree)\n    #\n    #     if source1 != source2:\n    #         pass\n\n    code = \"\"\"\nclass Demonwrath(SpellCard):\n    def __init__(self):\n        super().__init__(\"Demonwrath\", 3, CHARACTER_CLASS.WARLOCK, CARD_RARITY.RARE)\n\n    def use(self, player, game):\n        super().use(player, game)\n        targets = copy.copy(game.other_player.minions)\n        targets.extend(game.current_player.minions)\n        for minion in targets:\n            if minion.card.minion_type is not MINION_TYPE.DEMON:\n                minion.damage(player.effective_spell_damage(2), self)\n\"\"\"\n    code = \"\"\"sorted(mydict, key=mydict.get, reverse=True)\"\"\"\n    # # code = \"\"\"a = [1,2,3,4,'asdf', 234.3]\"\"\"\n    parse_tree = parse(code)\n    # for leaf in parse_tree.get_leaves():\n    #     if leaf.value: print escape(leaf.value)\n    #\n    print parse_tree\n    # ast_tree = parse_tree_to_python_ast(parse_tree)\n    # print astor.to_source(ast_tree)"
  },
  {
    "path": "lang/py/py_dataset.py",
    "content": "# -*- coding: UTF-8 -*-\nfrom __future__ import division\nimport ast\nimport astor\nimport logging\nfrom itertools import chain\nimport nltk\nimport re\n\nfrom nn.utils.io_utils import serialize_to_file, deserialize_from_file\nfrom nn.utils.generic_utils import init_logging\n\nfrom dataset import gen_vocab, DataSet, DataEntry, Action, APPLY_RULE, GEN_TOKEN, COPY_TOKEN, GEN_COPY_TOKEN, Vocab\nfrom lang.py.parse import parse, parse_tree_to_python_ast, canonicalize_code, get_grammar, parse_raw, \\\n    de_canonicalize_code, tokenize_code, tokenize_code_adv, de_canonicalize_code_for_seq2seq\nfrom lang.py.unaryclosure import get_top_unary_closures, apply_unary_closures\n\n\ndef extract_grammar(code_file, prefix='py'):\n    line_num = 0\n    parse_trees = []\n    for line in open(code_file):\n        code = line.strip()\n        parse_tree = parse(code)\n\n        # leaves = parse_tree.get_leaves()\n        # for leaf in leaves:\n        #     if not is_terminal_type(leaf.type):\n        #         print parse_tree\n\n        # parse_tree = add_root(parse_tree)\n\n        parse_trees.append(parse_tree)\n\n        # sanity check\n        ast_tree = parse_tree_to_python_ast(parse_tree)\n        ref_ast_tree = ast.parse(canonicalize_code(code)).body[0]\n        source1 = astor.to_source(ast_tree)\n        source2 = astor.to_source(ref_ast_tree)\n\n        assert source1 == source2\n\n        # check rules\n        # rule_list = parse_tree.get_rule_list(include_leaf=True)\n        # for rule in rule_list:\n        #     if rule.parent.type == int and rule.children[0].type == int:\n        #         # rule.parent.type == str and rule.children[0].type == str:\n        #         pass\n\n        # ast_tree = tree_to_ast(parse_tree)\n        # print astor.to_source(ast_tree)\n            # print parse_tree\n        # except Exception as e:\n        #     error_num += 1\n        #     #pass\n        #     #print e\n\n        line_num += 1\n\n    print 'total line of 
code: %d' % line_num\n\n    grammar = get_grammar(parse_trees)\n\n    with open(prefix + '.grammar.txt', 'w') as f:\n        for rule in grammar:\n            str = rule.__repr__()\n            f.write(str + '\\n')\n\n    with open(prefix + '.parse_trees.txt', 'w') as f:\n        for tree in parse_trees:\n            f.write(tree.__repr__() + '\\n')\n\n    return grammar, parse_trees\n\n\ndef rule_vs_node_stat():\n    line_num = 0\n    parse_trees = []\n    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out' # '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'\n    node_nums = rule_nums = 0.\n    for line in open(code_file):\n        code = line.replace('§', '\\n').strip()\n        parse_tree = parse(code)\n        node_nums += len(list(parse_tree.nodes))\n        rules, _ = parse_tree.get_productions()\n        rule_nums += len(rules)\n        parse_trees.append(parse_tree)\n\n        line_num += 1\n\n    print 'avg. nums of nodes: %f' % (node_nums / line_num)\n    print 'avg. 
nums of rules: %f' % (rule_nums / line_num)\n\n\ndef process_heart_stone_dataset():\n    data_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out'\n    parse_trees = []\n    rule_num = 0.\n    example_num = 0\n    for line in open(data_file):\n        code = line.replace('§', '\\n').strip()\n        parse_tree = parse(code)\n        # sanity check\n        pred_ast = parse_tree_to_python_ast(parse_tree)\n        pred_code = astor.to_source(pred_ast)\n        ref_ast = ast.parse(code)\n        ref_code = astor.to_source(ref_ast)\n\n        if pred_code != ref_code:\n            raise RuntimeError('code mismatch!')\n\n        rules, _ = parse_tree.get_productions(include_value_node=False)\n        rule_num += len(rules)\n        example_num += 1\n\n        parse_trees.append(parse_tree)\n\n    grammar = get_grammar(parse_trees)\n\n    with open('hs.grammar.txt', 'w') as f:\n        for rule in grammar:\n            str = rule.__repr__()\n            f.write(str + '\\n')\n\n    with open('hs.parse_trees.txt', 'w') as f:\n        for tree in parse_trees:\n            f.write(tree.__repr__() + '\\n')\n\n\n    print 'avg. 
nums of rules: %f' % (rule_num / example_num)\n\n\ndef canonicalize_hs_example(query, code):\n    query = re.sub(r'<.*?>', '', query)\n    query_tokens = nltk.word_tokenize(query)\n\n    code = code.replace('§', '\\n').strip()\n\n    # sanity check\n    parse_tree = parse_raw(code)\n    gold_ast_tree = ast.parse(code).body[0]\n    gold_source = astor.to_source(gold_ast_tree)\n    ast_tree = parse_tree_to_python_ast(parse_tree)\n    pred_source = astor.to_source(ast_tree)\n\n    assert gold_source == pred_source, 'sanity check fails: gold=[%s], actual=[%s]' % (gold_source, pred_source)\n\n    return query_tokens, code, parse_tree\n\n\ndef preprocess_hs_dataset(annot_file, code_file):\n    f = open('hs_dataset.examples.txt', 'w')\n\n    examples = []\n\n    for idx, (annot, code) in enumerate(zip(open(annot_file), open(code_file))):\n        annot = annot.strip()\n        code = code.strip()\n\n        clean_query_tokens, clean_code, parse_tree = canonicalize_hs_example(annot, code)\n        example = {'id': idx, 'query_tokens': clean_query_tokens, 'code': clean_code, 'parse_tree': parse_tree,\n                   'str_map': None, 'raw_code': code}\n        examples.append(example)\n\n        f.write('*' * 50 + '\\n')\n        f.write('example# %d\\n' % idx)\n        f.write(' '.join(clean_query_tokens) + '\\n')\n        f.write('\\n')\n        f.write(clean_code + '\\n')\n        f.write('*' * 50 + '\\n')\n\n        idx += 1\n\n    f.close()\n\n    print 'preprocess_dataset: cleaned example num: %d' % len(examples)\n\n    return examples\n\n\ndef parse_hs_dataset():\n    MAX_QUERY_LENGTH = 70 # FIXME: figure out the best config!\n    WORD_FREQ_CUT_OFF = 3\n\n    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.mod.in'\n    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out'\n\n    data = preprocess_hs_dataset(annot_file, code_file)\n    parse_trees = 
[e['parse_tree'] for e in data]\n\n    # apply unary closures\n    unary_closures = get_top_unary_closures(parse_trees, k=20)\n    for parse_tree in parse_trees:\n        apply_unary_closures(parse_tree, unary_closures)\n\n    # build the grammar\n    grammar = get_grammar(parse_trees)\n\n    with open('hs.grammar.unary_closure.txt', 'w') as f:\n        for rule in grammar:\n            f.write(rule.__repr__() + '\\n')\n\n    annot_tokens = list(chain(*[e['query_tokens'] for e in data]))\n    annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=WORD_FREQ_CUT_OFF)\n\n    def get_terminal_tokens(_terminal_str):\n        \"\"\"\n        get terminal tokens\n        break words like MinionCards into [Minion, Cards]\n        \"\"\"\n        tmp_terminal_tokens = [t for t in _terminal_str.split(' ') if len(t) > 0]\n        _terminal_tokens = []\n        for token in tmp_terminal_tokens:\n            sub_tokens = re.sub(r'([a-z])([A-Z])', r'\\1 \\2', token).split(' ')\n            _terminal_tokens.extend(sub_tokens)\n\n            _terminal_tokens.append(' ')\n\n        return _terminal_tokens[:-1]\n\n    # enumerate all terminal tokens to build up the terminal tokens vocabulary\n    all_terminal_tokens = []\n    for entry in data:\n        parse_tree = entry['parse_tree']\n        for node in parse_tree.get_leaves():\n            if grammar.is_value_node(node):\n                terminal_val = node.value\n                terminal_str = str(terminal_val)\n\n                terminal_tokens = get_terminal_tokens(terminal_str)\n\n                for terminal_token in terminal_tokens:\n                    assert len(terminal_token) > 0\n                    all_terminal_tokens.append(terminal_token)\n\n    terminal_vocab = gen_vocab(all_terminal_tokens, vocab_size=5000, freq_cutoff=WORD_FREQ_CUT_OFF)\n\n    # now generate the dataset!\n\n    train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.train_data')\n    dev_data = DataSet(annot_vocab, 
terminal_vocab, grammar, 'hs.dev_data')\n    test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.test_data')\n\n    all_examples = []\n\n    can_fully_reconstructed_examples_num = 0\n    examples_with_empty_actions_num = 0\n\n    for entry in data:\n        idx = entry['id']\n        query_tokens = entry['query_tokens']\n        code = entry['code']\n        parse_tree = entry['parse_tree']\n\n        rule_list, rule_parents = parse_tree.get_productions(include_value_node=True)\n\n        actions = []\n        can_fully_reconstructed = True\n        rule_pos_map = dict()\n\n        for rule_count, rule in enumerate(rule_list):\n            if not grammar.is_value_node(rule.parent):\n                assert rule.value is None\n                parent_rule = rule_parents[(rule_count, rule)][0]\n                if parent_rule:\n                    parent_t = rule_pos_map[parent_rule]\n                else:\n                    parent_t = 0\n\n                rule_pos_map[rule] = len(actions)\n\n                d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule}\n                action = Action(APPLY_RULE, d)\n\n                actions.append(action)\n            else:\n                assert rule.is_leaf\n\n                parent_rule = rule_parents[(rule_count, rule)][0]\n                parent_t = rule_pos_map[parent_rule]\n\n                terminal_val = rule.value\n                terminal_str = str(terminal_val)\n                terminal_tokens = get_terminal_tokens(terminal_str)\n\n                # assert len(terminal_tokens) > 0\n\n                for terminal_token in terminal_tokens:\n                    term_tok_id = terminal_vocab[terminal_token]\n                    tok_src_idx = -1\n                    try:\n                        tok_src_idx = query_tokens.index(terminal_token)\n                    except ValueError:\n                        pass\n\n                    d = {'literal': terminal_token, 'rule': rule, 
'parent_rule': parent_rule, 'parent_t': parent_t}\n\n                    # cannot copy, only generation\n                    # could be unk!\n                    if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH:\n                        action = Action(GEN_TOKEN, d)\n                        if terminal_token not in terminal_vocab:\n                            if terminal_token not in query_tokens:\n                                # print terminal_token\n                                can_fully_reconstructed = False\n                    else:  # copy\n                        if term_tok_id != terminal_vocab.unk:\n                            d['source_idx'] = tok_src_idx\n                            action = Action(GEN_COPY_TOKEN, d)\n                        else:\n                            d['source_idx'] = tok_src_idx\n                            action = Action(COPY_TOKEN, d)\n\n                    actions.append(action)\n\n                d = {'literal': '<eos>', 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}\n                actions.append(Action(GEN_TOKEN, d))\n\n        if len(actions) == 0:\n            examples_with_empty_actions_num += 1\n            continue\n\n        example = DataEntry(idx, query_tokens, parse_tree, code, actions, {'str_map': None, 'raw_code': entry['raw_code']})\n\n        if can_fully_reconstructed:\n            can_fully_reconstructed_examples_num += 1\n\n        # train, valid, test splits\n        if 0 <= idx < 533:\n            train_data.add(example)\n        elif idx < 599:\n            dev_data.add(example)\n        else:\n            test_data.add(example)\n\n        all_examples.append(example)\n\n    # print statistics\n    max_query_len = max(len(e.query) for e in all_examples)\n    max_actions_len = max(len(e.actions) for e in all_examples)\n\n    # serialize_to_file([len(e.query) for e in all_examples], 'query.len')\n    # serialize_to_file([len(e.actions) for e in all_examples], 
'actions.len')\n\n    logging.info('examples that can be fully reconstructed: %d/%d=%f',\n                 can_fully_reconstructed_examples_num, len(all_examples),\n                 can_fully_reconstructed_examples_num / len(all_examples))\n    logging.info('empty_actions_count: %d', examples_with_empty_actions_num)\n\n    logging.info('max_query_len: %d', max_query_len)\n    logging.info('max_actions_len: %d', max_actions_len)\n\n    train_data.init_data_matrices(max_query_length=70, max_example_action_num=350)\n    dev_data.init_data_matrices(max_query_length=70, max_example_action_num=350)\n    test_data.init_data_matrices(max_query_length=70, max_example_action_num=350)\n\n    serialize_to_file((train_data, dev_data, test_data),\n                      'data/hs.freq{WORD_FREQ_CUT_OFF}.max_action350.pre_suf.unary_closure.bin'.format(WORD_FREQ_CUT_OFF=WORD_FREQ_CUT_OFF))\n\n    return train_data, dev_data, test_data\n\n\ndef dump_data_for_evaluation(data_type='django', data_file='', max_query_length=70):\n    train_data, dev_data, test_data = deserialize_from_file(data_file)\n    prefix = '/Users/yinpengcheng/Projects/dl4mt-tutorial/codegen_data/'\n    for dataset, output in [(train_data, prefix + '%s.train' % data_type),\n                            (dev_data, prefix + '%s.dev' % data_type),\n                            (test_data, prefix + '%s.test' % data_type)]:\n        f_source = open(output + '.desc', 'w')\n        f_target = open(output + '.code', 'w')\n\n        for e in dataset.examples:\n            query_tokens = e.query[:max_query_length]\n            code = e.code\n            if data_type == 'django':\n                target_code = de_canonicalize_code_for_seq2seq(code, e.meta_data['raw_code'])\n            else:\n                target_code = code\n\n            # tokenize code\n            target_code = target_code.strip()\n            tokenized_target = tokenize_code_adv(target_code, breakCamelStr=False if data_type=='django' else True)\n         
   tokenized_target = [tk.replace('\\n', '#NEWLINE#') for tk in tokenized_target]\n            tokenized_target = [tk for tk in tokenized_target if tk is not None]\n\n            while tokenized_target[-1] == '#INDENT#':\n                tokenized_target = tokenized_target[:-1]\n\n            f_source.write(' '.join(query_tokens) + '\\n')\n            f_target.write(' '.join(tokenized_target) + '\\n')\n\n        f_source.close()\n        f_target.close()\n\n\nif __name__ == '__main__':\n    init_logging('py.log')\n    # rule_vs_node_stat()\n    # process_heart_stone_dataset()\n    parse_hs_dataset()\n    # dump_data_for_evaluation(data_file='data/django.cleaned.dataset.freq5.par_info.refact.space_only.bin')\n    # dump_data_for_evaluation(data_type='hs', data_file='data/hs.freq3.pre_suf.unary_closure.bin')\n    # code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'\n    # py_grammar, _ = extract_grammar(code_file)\n    # serialize_to_file(py_grammar, 'py_grammar.bin')"
  },
  {
    "path": "lang/py/seq2tree_exp.py",
    "content": "import logging\nimport re\nfrom collections import defaultdict, OrderedDict\nfrom itertools import chain\n\nimport sys\n\nfrom astnode import ASTNode\nfrom dataset import preprocess_dataset, gen_vocab\nfrom lang.py.grammar import type_str_to_type\nfrom lang.py.parse import parse, get_grammar, decode_tree_to_python_ast\nfrom lang.py.unaryclosure import get_top_unary_closures, apply_unary_closures\nfrom lang.util import typename, escape, unescape\nfrom nn.utils.generic_utils import init_logging\nfrom nn.utils.io_utils import serialize_to_file\n\n\ndef ast_tree_to_seq2tree_repr(tree):\n    repr_str = ''\n\n    # node_name = typename(tree.type)\n    label_val = '' if tree.label is None else tree.label\n    value = '' if tree.value is None else tree.value\n    node_name = '%s{%s}{%s}' % (typename(tree.type), label_val, value)\n    repr_str += node_name\n\n    # wrap children with parentheses\n    if tree.children:\n        repr_str += ' ('\n\n        for child in tree.children:\n            child_repr = ast_tree_to_seq2tree_repr(child)\n            repr_str += ' ' + child_repr\n\n        repr_str += ' )'\n\n    return repr_str\n\nnode_re = re.compile(r'(?P<type>.*?)\\{(?P<label>.*?)\\}\\{(?P<value>.*)\\}')\ndef seq2tree_repr_to_ast_tree_helper(tree_repr, offset):\n    \"\"\"convert a seq2tree representation to AST tree\"\"\"\n\n    # extract node name\n    node_name_end = offset\n    while node_name_end < len(tree_repr) and tree_repr[node_name_end] != ' ':\n        node_name_end += 1\n\n    node_repr = tree_repr[offset:node_name_end]\n\n    m = node_re.match(node_repr)\n    n_type = m.group('type')\n    n_type = type_str_to_type(n_type)\n    n_label = m.group('label')\n    n_value = m.group('value')\n\n    if n_type in {int, float, str, bool}:\n        n_value = n_type(n_value)\n\n    n_label = None if n_label == '' else n_label\n    n_value = None if n_value == '' else n_value\n\n    node = ASTNode(n_type, label=n_label, value=n_value)\n    offset = 
node_name_end\n\n    if offset == len(tree_repr):\n        return node, offset\n\n    offset += 1\n    if tree_repr[offset] == '(':\n        offset += 2\n        while True:\n            child_node, offset = seq2tree_repr_to_ast_tree_helper(tree_repr, offset=offset)\n            node.add_child(child_node)\n\n            if offset >= len(tree_repr) or tree_repr[offset] == ')':\n                offset += 2\n                break\n\n    return node, offset\n\n\ndef seq2tree_repr_to_ast_tree(tree_repr):\n    tree, _ = seq2tree_repr_to_ast_tree_helper(tree_repr, 0)\n\n    return tree\n\n\ndef break_value_nodes(tree, hs=False):\n    \"\"\"inplace break value nodes with a string separaed by spaces\"\"\"\n    if tree.type == str and tree.value is not None:\n        assert tree.is_leaf\n\n        if hs:\n            tokens = re.sub(r'([a-z])([A-Z])', r'\\1 #MERGE# \\2', tree.value).split(' ')\n        else:\n            tokens = tree.value.split(' ')\n        tree.value = 'NT'\n        for token in tokens:\n            assert token is not None\n            tree.add_child(ASTNode(tree.type, value=escape(token)))\n    else:\n        for child in tree.children:\n            break_value_nodes(child, hs=hs)\n\n\ndef merge_broken_value_nodes(tree):\n    \"\"\"redo *break_value_nodes*\"\"\"\n    if tree.type == str and not tree.is_leaf:\n        assert tree.value == 'NT'\n\n        valid_children = [c for c in tree.children if c.value is not None]\n        value = ' '.join(unescape(c.value) for c in valid_children)\n        value = value.replace(' #MERGE# ', '')\n        tree.value = value\n\n        tree.children = []\n    else:\n        for child in tree.children:\n            merge_broken_value_nodes(child)\n\n\ndef parse_django_dataset_for_seq2tree():\n    from lang.py.parse import parse_raw\n    MAX_QUERY_LENGTH = 70\n    MAX_DECODING_TIME_STEP = 300\n    UNARY_CUTOFF_FREQ = 30\n\n    annot_file = 
'/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'\n    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'\n\n    data = preprocess_dataset(annot_file, code_file)\n\n    for e in data:\n        e['parse_tree'] = parse_raw(e['code'])\n\n    parse_trees = [e['parse_tree'] for e in data]\n\n    # apply unary closures\n    # unary_closures = get_top_unary_closures(parse_trees, k=0, freq=UNARY_CUTOFF_FREQ)\n    # for i, parse_tree in enumerate(parse_trees):\n    #     apply_unary_closures(parse_tree, unary_closures)\n\n    # build the grammar\n    grammar = get_grammar(parse_trees)\n\n    # # build grammar ...\n    # from lang.py.py_dataset import extract_grammar\n    # grammar, all_parse_trees = extract_grammar(code_file)\n\n    f_train = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/train.txt', 'w')\n    f_dev = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/dev.txt', 'w')\n    f_test = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/test.txt', 'w')\n\n    f_train_rawid = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/train.id.txt', 'w')\n    f_dev_rawid = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/dev.id.txt', 'w')\n    f_test_rawid = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/test.id.txt', 'w')\n\n    decode_time_steps = defaultdict(int)\n\n    # first pass\n    for entry in data:\n        idx = entry['id']\n        query_tokens = entry['query_tokens']\n        code = entry['code']\n        parse_tree = entry['parse_tree']\n\n        original_parse_tree = parse_tree.copy()\n        break_value_nodes(parse_tree)\n        tree_repr = ast_tree_to_seq2tree_repr(parse_tree)\n\n        num_decode_time_step = len(tree_repr.split(' '))\n        decode_time_steps[num_decode_time_step] += 1\n\n        new_tree = seq2tree_repr_to_ast_tree(tree_repr)\n        
merge_broken_value_nodes(new_tree)\n\n        query_tokens = [t for t in query_tokens if t != ''][:MAX_QUERY_LENGTH]\n        query = ' '.join(query_tokens)\n        line = query + '\\t' + tree_repr\n\n        if num_decode_time_step > MAX_DECODING_TIME_STEP:\n            continue\n\n        # train, valid, test\n        if 0 <= idx < 16000:\n            f_train.write(line + '\\n')\n            f_train_rawid.write(str(idx) + '\\n')\n        elif 16000 <= idx < 17000:\n            f_dev.write(line + '\\n')\n            f_dev_rawid.write(str(idx) + '\\n')\n        else:\n            f_test.write(line + '\\n')\n            f_test_rawid.write(str(idx) + '\\n')\n\n        if original_parse_tree != new_tree:\n            print '*' * 50\n            print idx\n            print code\n\n    f_train.close()\n    f_dev.close()\n    f_test.close()\n\n    f_train_rawid.close()\n    f_dev_rawid.close()\n    f_test_rawid.close()\n\n    # print 'num. of decoding time steps distribution:'\n    # for k in sorted(decode_time_steps):\n    #     print '%d\\t%d' % (k, decode_time_steps[k])\n\n\ndef parse_hs_dataset_for_seq2tree():\n    from lang.py.py_dataset import preprocess_hs_dataset\n    MAX_QUERY_LENGTH = 70 # FIXME: figure out the best config!\n    WORD_FREQ_CUT_OFF = 3\n    MAX_DECODING_TIME_STEP = 800\n\n    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.mod.in'\n    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out'\n\n    data = preprocess_hs_dataset(annot_file, code_file)\n    parse_trees = [e['parse_tree'] for e in data]\n\n    # apply unary closures\n    unary_closures = get_top_unary_closures(parse_trees, k=20)\n    for parse_tree in parse_trees:\n        apply_unary_closures(parse_tree, unary_closures)\n\n    # build the grammar\n    grammar = get_grammar(parse_trees)\n\n    decode_time_steps = defaultdict(int)\n\n    f_train = 
open('/Users/yinpengcheng/Research/lang2logic/seq2tree/hs/data_unkreplaced/train.txt', 'w')\n    f_dev = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/hs/data_unkreplaced/dev.txt', 'w')\n    f_test = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/hs/data_unkreplaced/test.txt', 'w')\n\n    f_train_rawid = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/hs/data_unkreplaced/train.id.txt', 'w')\n    f_dev_rawid = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/hs/data_unkreplaced/dev.id.txt', 'w')\n    f_test_rawid = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/hs/data_unkreplaced/test.id.txt', 'w')\n\n    # first pass\n    for entry in data:\n        idx = entry['id']\n        query_tokens = entry['query_tokens']\n        parse_tree = entry['parse_tree']\n\n        original_parse_tree = parse_tree.copy()\n        break_value_nodes(parse_tree, hs=True)\n        tree_repr = ast_tree_to_seq2tree_repr(parse_tree)\n\n        num_decode_time_step = len(tree_repr.split(' '))\n        decode_time_steps[num_decode_time_step] += 1\n\n        new_tree = seq2tree_repr_to_ast_tree(tree_repr)\n        merge_broken_value_nodes(new_tree)\n\n        query_tokens = [t for t in query_tokens if t != ''][:MAX_QUERY_LENGTH]\n        query = ' '.join(query_tokens)\n        line = query + '\\t' + tree_repr\n\n        if num_decode_time_step > MAX_DECODING_TIME_STEP:\n            continue\n\n        # train, valid, test\n        if 0 <= idx < 533:\n            f_train.write(line + '\\n')\n            f_train_rawid.write(str(idx) + '\\n')\n        elif idx < 599:\n            f_dev.write(line + '\\n')\n            f_dev_rawid.write(str(idx) + '\\n')\n        else:\n            f_test.write(line + '\\n')\n            f_test_rawid.write(str(idx) + '\\n')\n\n        if original_parse_tree != new_tree:\n            print '*' * 50\n            print idx\n            print code\n\n    f_train.close()\n    f_dev.close()\n    f_test.close()\n\n    
f_train_rawid.close()\n    f_dev_rawid.close()\n    f_test_rawid.close()\n\n    # print 'num. of decoding time steps distribution:'\n    for k in sorted(decode_time_steps):\n        print '%d\\t%d' % (k, decode_time_steps[k])\n\n\nif __name__ == '__main__':\n    init_logging('py.log')\n    # code = \"return (  format_html_join ( '' , '_STR:0_' , sorted ( attrs . items ( ) ) ) +  format_html_join ( '' , ' {0}' , sorted ( boolean_attrs ) )  )\"\n    code = \"call('{0}')\"\n    parse_tree = parse(code)\n\n    # parse_tree = ASTNode('root', children=[\n    #     ASTNode('lambda'),\n    #     ASTNode('$0'),\n    #     ASTNode('e', children=[\n    #         ASTNode('and', children=[\n    #             ASTNode('>', children=[ASTNode('$0')]),\n    #             ASTNode('from', children=[ASTNode('$0'), ASTNode('ci0')]),\n    #         ])\n    #     ]),\n    # ])\n\n    original_parse_tree = parse_tree.copy()\n    break_value_nodes(parse_tree)\n\n    # tree_repr = \"\"\"root{}{} ( For{}{} ( expr{target}{} ( Name{}{} ( str{id}{NT} ( ) ) ) expr{iter}{} ( Name{}{} ( str{id}{NT} ( Name{}{} ( str{id}{NT} ( str{}{self} ) ) ) ) ) stmt*{body}{} ( stmt{}{} ( Pass{}{} ) ) ) )\"\"\"\n    # print tree_repr\n\n    # new_tree = seq2tree_repr_to_ast_tree(tree_repr)\n    # merge_broken_value_nodes(new_tree)\n\n    # print str(original_parse_tree)\n    # print str(new_tree)\n\n    # assert original_parse_tree == new_tree\n\n    # parse_django_dataset_for_seq2tree()\n    parse_hs_dataset_for_seq2tree()"
  },
  {
    "path": "lang/py/unaryclosure.py",
    "content": "# -*- coding: UTF-8 -*-\n\nfrom astnode import ASTNode\nfrom lang.py.grammar import type_str_to_type\nfrom lang.py.parse import parse\nfrom collections import Counter\nimport re\n\n\ndef extract_unary_closure_helper(parse_tree, unary_link, last_node):\n    if parse_tree.is_leaf:\n        if unary_link and unary_link.size > 2:\n            return [unary_link]\n        else:\n            return []\n    elif len(parse_tree.children) > 1:\n        unary_links = []\n        if unary_link and unary_link.size > 2:\n            unary_links.append(unary_link)\n        for child in parse_tree.children:\n            new_node = ASTNode(child.type)\n            child_unary_links = extract_unary_closure_helper(child, new_node, new_node)\n            unary_links.extend(child_unary_links)\n\n        return unary_links\n    else:  # has a single child\n        child = parse_tree.children[0]\n        new_node = ASTNode(child.type, label=child.label)\n        last_node.add_child(new_node)\n        last_node = new_node\n\n        return extract_unary_closure_helper(child, unary_link, last_node)\n\n\ndef extract_unary_closure(parse_tree):\n    root_node_copy = ASTNode(parse_tree.type)\n    unary_links = extract_unary_closure_helper(parse_tree, root_node_copy, root_node_copy)\n\n    return unary_links\n\n\ndef get_unary_links():\n    # data_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out'\n    data_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'\n    parse_trees = []\n    unary_links_counter = Counter()\n\n    for line in open(data_file):\n        code = line.replace('§', '\\n').strip()\n        parse_tree = parse(code)\n        parse_trees.append(parse_tree)\n\n        example_unary_links = extract_unary_closure(parse_tree)\n        for link in example_unary_links:\n            unary_links_counter[link] += 1\n\n    ranked_links = sorted(unary_links_counter, 
key=unary_links_counter.get, reverse=True)\n    for link in ranked_links:\n        print str(link) + ' ||| ' + str(unary_links_counter[link])\n\n    unary_links = ranked_links[:20]\n    unary_closures = []\n    for link in unary_links:\n        unary_closures.append(unary_link_to_closure(link))\n\n    unary_closures = zip(unary_links, unary_closures)\n\n    node_nums = rule_nums = 0.\n    for parse_tree in parse_trees:\n        original_parse_tree = parse_tree.copy()\n        for link, closure in unary_closures:\n            apply_unary_closure(parse_tree, closure, link)\n\n        # assert original_parse_tree != parse_tree\n        compressed_ast_to_normal(parse_tree)\n        assert original_parse_tree == parse_tree\n\n        rules, _ = parse_tree.get_productions()\n        rule_nums += len(rules)\n        node_nums += len(list(parse_tree.nodes))\n\n    print '**** after applying unary closures ****'\n    print 'avg. nums of nodes: %f' % (node_nums / len(parse_trees))\n    print 'avg. 
nums of rules: %f' % (rule_nums / len(parse_trees))\n\n\n\ndef get_top_unary_closures(parse_trees, k=20, freq=50):\n    unary_links_counter = Counter()\n    for parse_tree in parse_trees:\n        example_unary_links = extract_unary_closure(parse_tree)\n        for link in example_unary_links:\n            unary_links_counter[link] += 1\n\n    ranked_links = sorted(unary_links_counter, key=unary_links_counter.get, reverse=True)\n    if k:\n        print 'rank cut off: %d' % k\n        unary_links = ranked_links[:k]\n    else:\n        print 'freq cut off: %d' % freq\n        unary_links = sorted([l for l in unary_links_counter if unary_links_counter[l] >= freq], key=unary_links_counter.get, reverse=True)\n\n    unary_closures = []\n    for link in unary_links:\n        unary_closures.append(unary_link_to_closure(link))\n\n    unary_closures = zip(unary_links, unary_closures)\n\n    for link, closure in unary_closures:\n        print 'link: %s ||| closure: %s ||| freq: %d' % (link, closure, unary_links_counter[link])\n\n    return unary_closures\n\n\ndef apply_unary_closures(parse_tree, unary_closures):\n    unary_closures = sorted(unary_closures, key=lambda x: x[0].size, reverse=True)\n    original_parse_tree = parse_tree.copy()\n\n    # apply all unary closures\n    for link, closure in unary_closures:\n        apply_unary_closure(parse_tree, closure, link)\n\n    new_tree_copy = parse_tree.copy()\n    compressed_ast_to_normal(new_tree_copy)\n    assert original_parse_tree == new_tree_copy\n\n\nrule_regex = re.compile(r'(?P<parent>.*?) 
-> \\((?P<child>.*?)(\\{(?P<clabel>.*?)\\})?\\)')\ndef compressed_ast_to_normal(parse_tree):\n    if parse_tree.label and '@' in parse_tree.label and '$' in parse_tree.label:\n        label = parse_tree.label\n        label = label.replace('$', ' ')\n        rule_reprs = label.split('@')\n\n        intermediate_nodes = []\n        first_node = last_node = None\n        for rule_repr in rule_reprs:\n            m = rule_regex.match(rule_repr)\n            p = m.group('parent')\n            c = m.group('child')\n            cl = m.group('clabel')\n\n            p_type = type_str_to_type(p)\n            c_type = type_str_to_type(c)\n\n            node = ASTNode(c_type, label=cl)\n            if last_node:\n                last_node.add_child(node)\n            if not first_node:\n                first_node = node\n\n            last_node = node\n            intermediate_nodes.append(node)\n\n        last_node.value = parse_tree.value\n        for child in parse_tree.children:\n            last_node.add_child(child)\n            compressed_ast_to_normal(child)\n\n\n        parent_node = parse_tree.parent\n        assert len(parent_node.children) == 1\n        del parent_node.children[0]\n        parent_node.add_child(first_node)\n        # return first_node\n    else:\n        new_child_trees = []\n        for child in parse_tree.children[:]:\n            compressed_ast_to_normal(child)\n        #     new_child_trees.append(new_child_tree)\n        # del parse_tree.children[:]\n        # for child_tree in new_child_trees:\n        #     parse_tree.add_child(child_tree)\n        #\n        # return parse_tree\n\n\ndef match_sub_tree(parse_tree, cur_match_node, is_root=False):\n    cur_level_match = False\n    if parse_tree.type == cur_match_node.type and (len(parse_tree.children) == 1 or cur_match_node.is_leaf) and \\\n            (is_root or parse_tree.label == cur_match_node.label):\n        cur_level_match = True\n\n    if cur_level_match:\n        if 
cur_match_node.is_leaf:\n            return parse_tree\n\n        last_node = match_sub_tree(parse_tree.children[0], cur_match_node.children[0])\n        return last_node\n    else:\n        return None\n\n\ndef find(parse_tree, sub_tree):\n    match_results = []\n    last_node = match_sub_tree(parse_tree, sub_tree, True)\n\n    if last_node:\n        match_results.append((parse_tree, last_node))\n\n    for child in parse_tree.children:\n        child_match_results = find(child, sub_tree)\n        match_results.extend(child_match_results)\n\n    return match_results\n\n\ndef apply_unary_closure(parse_tree, unary_closure, unary_link):\n    match_results = find(parse_tree, unary_link)\n    for first_node, last_node in match_results:\n        closure_copy = unary_closure.copy()\n\n        leaf = closure_copy.get_leaves()[0]\n        leaf.value = last_node.value\n        for child in last_node.children:\n            leaf.add_child(child)\n\n        new_node = closure_copy.children[0]\n        first_node.children.remove(first_node.children[0])\n        first_node.add_child(new_node)\n\n\ndef unary_link_to_closure(unary_link):\n    closure = ASTNode(unary_link.type)\n    last_node = unary_link.get_leaves()[0]\n    closure_child = ASTNode(last_node.type)\n    prod, _ = unary_link.get_productions()\n    closure_child_label = '@'.join(str(rule).replace(' ', '$') for rule in prod)\n    closure_child.label = closure_child_label\n\n    closure.add_child(closure_child)\n\n    return closure\n\nif __name__ == '__main__':\n#     code = \"\"\"\n# class Demonwrath(SpellCard):\n#     def __init__(self):\n#         super().__init__(\"Demonwrath\", 3, CHARACTER_CLASS.WARLOCK, CARD_RARITY.RARE)\n#\n#     def use(self, player, game):\n#         super().use(player, game)\n#         targets = copy.copy(game.other_player.minions)\n#         targets.extend(game.current_player.minions)\n#         for minion in targets:\n#             if minion.card.minion_type is not MINION_TYPE.DEMON:\n#    
             minion.damage(player.effective_spell_damage(2), self)\n#     \"\"\"\n#     parse_tree = parse(code)\n#     original_parse_tree = parse_tree.copy()\n#     unary_links = extract_unary_closure(parse_tree)\n#\n#     for link in unary_links:\n#         closure = unary_link_to_closure(link)\n#         print closure, link\n#         apply_unary_closure(parse_tree, closure, link)\n#\n#     compressed_ast_to_normal(parse_tree)\n#     print parse_tree\n#     print original_parse_tree\n#     print parse_tree == original_parse_tree\n    get_unary_links()"
  },
  {
    "path": "lang/type_system.py",
    "content": ""
  },
  {
    "path": "lang/util.py",
    "content": "# x is a type\ndef typename(x):\n    if isinstance(x, str):\n        return x\n    return x.__name__\n\ndef escape(text):\n    text = text \\\n        .replace('\"', '-``-') \\\n        .replace('\\'', '-`-') \\\n        .replace(' ', '-SP-') \\\n        .replace('\\t', '-TAB-') \\\n        .replace('\\n', '-NL-') \\\n        .replace('\\r', '-NL2-') \\\n        .replace('(', '-LRB-') \\\n        .replace(')', '-RRB-') \\\n        .replace('|', '-BAR-')\n\n    if text is None:\n        return '-NONE-'\n    elif text == '':\n        return '-EMPTY-'\n\n    return text\n\ndef unescape(text):\n    if text == '-NONE-':\n        return None\n\n    text = text \\\n        .replace('-``-', '\"') \\\n        .replace('-`-', '\\'') \\\n        .replace('-SP-', ' ') \\\n        .replace('-TAB-', '\\t') \\\n        .replace('-NL-', '\\n') \\\n        .replace('-NL2-', '\\r') \\\n        .replace('-LRB-', '(') \\\n        .replace('-RRB-', ')') \\\n        .replace('-BAR-', '|') \\\n        .replace('-EMPTY-', '')\n\n    return text"
  },
  {
    "path": "learner.py",
    "content": "from nn.utils.config_factory import config\nfrom nn.utils.generic_utils import *\n\nimport logging\nimport numpy as np\nimport sys, os\nimport time\n\nimport decoder\nimport evaluation\nfrom dataset import *\nimport config\n\n\nclass Learner(object):\n    def __init__(self, model, train_data, val_data=None):\n        self.model = model\n        self.train_data = train_data\n        self.val_data = val_data\n\n        logging.info('initial learner with training set [%s] (%d examples)',\n                     train_data.name,\n                     train_data.count)\n        if val_data:\n            logging.info('validation set [%s] (%d examples)', val_data.name, val_data.count)\n\n    def train(self):\n        dataset = self.train_data\n        nb_train_sample = dataset.count\n        index_array = np.arange(nb_train_sample)\n\n        nb_epoch = config.max_epoch\n        batch_size = config.batch_size\n\n        logging.info('begin training')\n        cum_updates = 0\n        patience_counter = 0\n        early_stop = False\n        history_valid_perf = []\n        history_valid_bleu = []\n        history_valid_acc = []\n        best_model_params = best_model_by_acc = best_model_by_bleu = None\n\n        # train_data_iter = DataIterator(self.train_data, batch_size)\n\n        for epoch in range(nb_epoch):\n            # train_data_iter.reset()\n            # if shuffle:\n            np.random.shuffle(index_array)\n\n            batches = make_batches(nb_train_sample, batch_size)\n\n            # epoch begin\n            sys.stdout.write('Epoch %d' % epoch)\n            begin_time = time.time()\n            cum_nb_examples = 0\n            loss = 0.0\n\n            for batch_index, (batch_start, batch_end) in enumerate(batches):\n            # for batch_index, (examples, batch_ids) in enumerate(train_data_iter):\n                cum_updates += 1\n\n                batch_ids = index_array[batch_start:batch_end]\n                examples = 
dataset.get_examples(batch_ids)\n                cur_batch_size = len(examples)\n\n                inputs = dataset.get_prob_func_inputs(batch_ids)\n\n                if not config.enable_copy:\n                    tgt_action_seq = inputs[1]\n                    tgt_action_seq_type = inputs[2]\n\n                    for i in xrange(cur_batch_size):\n                        for t in xrange(tgt_action_seq[i].shape[0]):\n                            if tgt_action_seq_type[i, t, 2] == 1:\n                                # can only be copied\n                                if tgt_action_seq_type[i, t, 1] == 0:\n                                    tgt_action_seq_type[i, t, 1] = 1\n                                    tgt_action_seq[i, t, 1] = 1  # index of <unk>\n\n                                tgt_action_seq_type[i, t, 2] = 0\n\n                train_func_outputs = self.model.train_func(*inputs)\n                batch_loss = train_func_outputs[0]\n                logging.debug('prob_func finished computing')\n\n                cum_nb_examples += cur_batch_size\n                loss += batch_loss * batch_size\n\n                logging.debug('Batch %d, avg. 
loss = %f', batch_index, batch_loss)\n\n                if batch_index == 4:\n                    elapsed = time.time() - begin_time\n                    eta = nb_train_sample / (cum_nb_examples / elapsed)\n                    print ', eta %ds' % (eta)\n                    sys.stdout.flush()\n\n                if cum_updates % config.valid_per_batch == 0:\n                    logging.info('begin validation')\n\n                    if config.data_type == 'ifttt':\n                        decode_results = decoder.decode_ifttt_dataset(self.model, self.val_data, verbose=False)\n                        channel_acc, channel_func_acc, prod_f1 = evaluation.evaluate_ifttt_results(self.val_data, decode_results, verbose=False)\n\n                        val_perf = channel_func_acc\n                        logging.info('channel accuracy: %f', channel_acc)\n                        logging.info('channel+func accuracy: %f', channel_func_acc)\n                        logging.info('prod F1: %f', prod_f1)\n                    else:\n                        decode_results = decoder.decode_python_dataset(self.model, self.val_data, verbose=False)\n                        bleu, accuracy = evaluation.evaluate_decode_results(self.val_data, decode_results, verbose=False)\n\n                        val_perf = eval(config.valid_metric)\n\n                        logging.info('avg. 
example bleu: %f', bleu)\n                        logging.info('accuracy: %f', accuracy)\n\n                        if len(history_valid_acc) == 0 or accuracy > np.array(history_valid_acc).max():\n                            best_model_by_acc = self.model.pull_params()\n                            # logging.info('current model has best accuracy')\n                        history_valid_acc.append(accuracy)\n\n                        if len(history_valid_bleu) == 0 or bleu > np.array(history_valid_bleu).max():\n                            best_model_by_bleu = self.model.pull_params()\n                            # logging.info('current model has best accuracy')\n                        history_valid_bleu.append(bleu)\n\n                    if len(history_valid_perf) == 0 or val_perf > np.array(history_valid_perf).max():\n                        best_model_params = self.model.pull_params()\n                        patience_counter = 0\n                        logging.info('save current best model')\n                        self.model.save(os.path.join(config.output_dir, 'model.npz'))\n                    else:\n                        patience_counter += 1\n                        logging.info('hitting patience_counter: %d', patience_counter)\n                        if patience_counter >= config.train_patience:\n                            logging.info('Early Stop!')\n                            early_stop = True\n                            break\n                    history_valid_perf.append(val_perf)\n\n                if cum_updates % config.save_per_batch == 0:\n                    self.model.save(os.path.join(config.output_dir, 'model.iter%d' % cum_updates))\n\n            logging.info('[Epoch %d] cumulative loss = %f, (took %ds)',\n                         epoch,\n                         loss / cum_nb_examples,\n                         time.time() - begin_time)\n\n            if early_stop:\n                break\n\n        logging.info('training finished, 
save the best model')\n        np.savez(os.path.join(config.output_dir, 'model.npz'), **best_model_params)\n\n        if config.data_type == 'django' or config.data_type == 'hs':\n            logging.info('save the best model by accuracy')\n            np.savez(os.path.join(config.output_dir, 'model.best_acc.npz'), **best_model_by_acc)\n\n            logging.info('save the best model by bleu')\n            np.savez(os.path.join(config.output_dir, 'model.best_bleu.npz'), **best_model_by_bleu)\n\n\nclass DataIterator:\n    def __init__(self, dataset, batch_size=10):\n        self.dataset = dataset\n        self.batch_size = batch_size\n        self.index_array = np.arange(self.dataset.count)\n        self.ptr = 0\n        self.buffer_size = batch_size * 5\n        self.buffer = []\n\n    def reset(self):\n        self.ptr = 0\n        self.buffer = []\n        np.random.shuffle(self.index_array)\n\n    def __iter__(self):\n        return self\n\n    def next_batch(self):\n        batch = self.buffer[:self.batch_size]\n        del self.buffer[:self.batch_size]\n\n        batch_ids = [e.eid for e in batch]\n\n        return batch, batch_ids\n\n    def next(self):\n        if self.buffer:\n            return self.next_batch()\n        else:\n            if self.ptr >= self.dataset.count:\n                raise StopIteration\n\n            self.buffer = self.index_array[self.ptr:self.ptr + self.buffer_size]\n\n            # sort buffer contents\n            examples = self.dataset.get_examples(self.buffer)\n            self.buffer = sorted(examples, key=lambda e: len(e.actions))\n\n            self.ptr += self.buffer_size\n\n            return self.next_batch()"
  },
  {
    "path": "main.py",
    "content": "import ast\nimport re\n\nfrom astnode import *\n\np_elif = re.compile(r'^elif\\s?')\np_else = re.compile(r'^else\\s?')\np_try = re.compile(r'^try\\s?')\np_except = re.compile(r'^except\\s?')\np_finally = re.compile(r'^finally\\s?')\np_decorator = re.compile(r'^@.*')\n\n\ndef escape(text):\n    text = text \\\n        .replace('\"', '`') \\\n        .replace('\\'', '`') \\\n        .replace(' ', '-SP-') \\\n        .replace('\\t', '-TAB-') \\\n        .replace('\\n', '-NL-') \\\n        .replace('(', '-LRB-') \\\n        .replace(')', '-RRB-') \\\n        .replace('|', '-BAR-')\n    return repr(text)[1:-1] if text else '-NONE-'\n\n\ndef typename(x):\n    return type(x).__name__\n\n\ndef get_tree_str_repr(node):\n    treeStr = ''\n    if type(node) == list:\n        for n in node:\n            treeStr += get_tree_str_repr(n)\n\n        return treeStr\n\n    node_name = str(type(node))\n    begin = node_name.find('ast.') + len('ast.')\n    end = node_name.rfind('\\'')\n    node_name = node_name[begin: end]\n    treeStr = '(' + node_name + ' '\n    for field_name in node._fields:\n        field = getattr(node, field_name)\n        if hasattr(field, '_fields') and len(field._fields) == 0:\n            continue\n        if field:\n            if type(field) == list:\n                fieldRepr = get_tree_str_repr(field)\n                fieldRepr = '(' + field_name + ' ' + fieldRepr + ') '\n            elif type(field) == str or type(field) == int:\n                fieldRepr = '(' + field_name + ' ' + str(field) + ') '\n            else:\n                fieldRepr = get_tree_str_repr(field)\n                fieldRepr = '(' + field_name + ' ' + fieldRepr + ') '\n\n            treeStr += fieldRepr\n    treeStr += ') '\n\n    return treeStr\n\n\ndef get_tree(node):\n\n    if isinstance(node, str):\n        node_name = escape(node)\n    elif isinstance(node, int):\n        node_name = node\n    else:\n        node_name = typename(node)\n\n    tree = 
ASTNode(node_name)\n\n    if not isinstance(node, ast.AST):\n        return tree\n\n    for field_name, field in ast.iter_fields(node):\n        # omit empty fields\n        if isinstance(field, ast.AST):\n            if len(field._fields) == 0:\n                continue\n\n            child = get_tree(field)\n\n            tree.children.append(ASTNode(field_name, child))\n        elif isinstance(field, str):\n            field_val = escape(field)\n            child = ASTNode(field_name, ASTNode(field_val))\n\n            tree.children.append(child)\n        elif isinstance(field, int):\n            child = ASTNode(field_name, ASTNode(field))\n\n            tree.children.append(child)\n        elif isinstance(field, list) and field:\n            child = ASTNode(field_name)\n\n            for n in field:\n                child.children.append(get_tree(n))\n\n            tree.children.append(child)\n\n    return tree\n\n\ndef parse(code):\n    if p_elif.match(code): code = 'if True: pass\\n' + code\n    if p_else.match(code): code = 'if True: pass\\n' + code\n\n    if p_try.match(code): code = code + 'pass\\nexcept: pass'\n    elif p_except.match(code): code = 'try: pass\\n' + code\n    elif p_finally.match(code): code = 'try: pass\\n' + code\n\n    if p_decorator.match(code): code = code + '\\ndef dummy(): pass'\n    if code[-1] == ':': code = code + 'pass'\n\n    root_node = ast.parse(code)\n\n    tree = get_tree(root_node.body[0])\n\n    return tree\n\n\ndef parse_django(code_file):\n    line_num = 0\n    error_num = 0\n    parse_trees = []\n    for line in open(code_file):\n        code = line.strip()\n        try:\n            parse_tree = parse(code)\n            # rule_list = parse_tree.get_rule_list(include_leaf=False)\n            parse_trees.append(parse_tree)\n            print parse_tree\n        except Exception as e:\n            error_num += 1\n            pass\n            # print e\n\n        line_num += 1\n\n    print 'total line of code: %d' % 
line_num\n    print 'error num: %d' % error_num\n\n    assert error_num == 0\n\n    grammar = get_grammar(parse_trees)\n\n    with open('grammar.txt', 'w') as f:\n        for rule in grammar:\n            str = rule.parent + ' -> ' + ', '.join(rule.children)\n            f.write(str + '\\n')\n\n    return grammar, parse_trees\n\n\nif __name__ == '__main__':\n#     node = ast.parse('''\n# # for i in range(1, 100):\n# #  sum = sum + i\n# #\n# # sorted(arr, reverse=True)\n# # sorted(my_dict, key=lambda x: my_dict[x], reverse=True)\n# # m = dict ( zip ( new_keys , keys ) )\n# # for f in sorted ( os . listdir ( self . path ) ) :\n# #     pass\n# for f in sorted ( os . listdir ( self . path ) ) : pass\n# ''')\n    # print ast.dump(node, annotate_fields=False)\n    # print get_tree_str_repr(node)\n    # print parse('for f in sorted ( os . listdir ( self . path ) ) : sum = sum + 1; sum = \"(hello there)\" ')\n    # print parse('global _standard_context_processors')\n\n    parse_django()\n\n\n"
  },
  {
    "path": "model.py",
    "content": "import theano\nimport theano.tensor as T\nfrom theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams\nimport numpy as np\n\nfrom collections import OrderedDict\nimport logging\nimport copy\nimport heapq\nimport sys\n\nfrom nn.layers.embeddings import Embedding\nfrom nn.layers.core import Dense, Dropout, WordDropout\nfrom nn.layers.recurrent import BiLSTM, LSTM\nimport nn.optimizers as optimizers\nimport nn.initializations as initializations\nfrom nn.activations import softmax\nfrom nn.utils.theano_utils import *\n\nfrom config import config_info\nimport config\nfrom lang.grammar import Grammar\nfrom parse import *\nfrom astnode import *\nfrom util import is_numeric\nfrom components import Hyp, PointerNet, CondAttLSTM\n\nsys.setrecursionlimit(50000)\n\nclass Model:\n    def __init__(self):\n        # self.node_embedding = Embedding(config.node_num, config.node_embed_dim, name='node_embed')\n\n        self.query_embedding = Embedding(config.source_vocab_size, config.word_embed_dim, name='query_embed')\n\n        if config.encoder == 'bilstm':\n            self.query_encoder_lstm = BiLSTM(config.word_embed_dim, config.encoder_hidden_dim / 2, return_sequences=True,\n                                             name='query_encoder_lstm')\n        else:\n            self.query_encoder_lstm = LSTM(config.word_embed_dim, config.encoder_hidden_dim, return_sequences=True,\n                                           name='query_encoder_lstm')\n\n        self.decoder_lstm = CondAttLSTM(config.rule_embed_dim + config.node_embed_dim + config.rule_embed_dim,\n                                        config.decoder_hidden_dim, config.encoder_hidden_dim, config.attention_hidden_dim,\n                                        name='decoder_lstm')\n\n        self.src_ptr_net = PointerNet()\n\n        self.terminal_gen_softmax = Dense(config.decoder_hidden_dim, 2, activation='softmax', name='terminal_gen_softmax')\n\n        self.rule_embedding_W = 
initializations.get('normal')((config.rule_num, config.rule_embed_dim), name='rule_embedding_W', scale=0.1)\n        self.rule_embedding_b = shared_zeros(config.rule_num, name='rule_embedding_b')\n\n        self.node_embedding = initializations.get('normal')((config.node_num, config.node_embed_dim), name='node_embed', scale=0.1)\n\n        self.vocab_embedding_W = initializations.get('normal')((config.target_vocab_size, config.rule_embed_dim), name='vocab_embedding_W', scale=0.1)\n        self.vocab_embedding_b = shared_zeros(config.target_vocab_size, name='vocab_embedding_b')\n\n        # decoder_hidden_dim -> action embed\n        self.decoder_hidden_state_W_rule = Dense(config.decoder_hidden_dim, config.rule_embed_dim, name='decoder_hidden_state_W_rule')\n\n        # decoder_hidden_dim -> action embed\n        self.decoder_hidden_state_W_token= Dense(config.decoder_hidden_dim + config.encoder_hidden_dim, config.rule_embed_dim,\n                                                 name='decoder_hidden_state_W_token')\n\n        # self.rule_encoder_lstm.params\n        self.params = self.query_embedding.params + self.query_encoder_lstm.params + \\\n                      self.decoder_lstm.params + self.src_ptr_net.params + self.terminal_gen_softmax.params + \\\n                      [self.rule_embedding_W, self.rule_embedding_b, self.node_embedding, self.vocab_embedding_W, self.vocab_embedding_b] + \\\n                      self.decoder_hidden_state_W_rule.params + self.decoder_hidden_state_W_token.params\n\n        self.srng = RandomStreams()\n\n    def build(self):\n        # (batch_size, max_example_action_num, action_type)\n        tgt_action_seq = ndim_itensor(3, 'tgt_action_seq')\n\n        # (batch_size, max_example_action_num, action_type)\n        tgt_action_seq_type = ndim_itensor(3, 'tgt_action_seq_type')\n\n        # (batch_size, max_example_action_num)\n        tgt_node_seq = ndim_itensor(2, 'tgt_node_seq')\n\n        # (batch_size, 
max_example_action_num)\n        tgt_par_rule_seq = ndim_itensor(2, 'tgt_par_rule_seq')\n\n        # (batch_size, max_example_action_num)\n        tgt_par_t_seq = ndim_itensor(2, 'tgt_par_t_seq')\n\n        # (batch_size, max_example_action_num, symbol_embed_dim)\n        # tgt_node_embed = self.node_embedding(tgt_node_seq, mask_zero=False)\n        tgt_node_embed = self.node_embedding[tgt_node_seq]\n\n        # (batch_size, max_query_length)\n        query_tokens = ndim_itensor(2, 'query_tokens')\n\n        # (batch_size, max_query_length, query_token_embed_dim)\n        # (batch_size, max_query_length)\n        query_token_embed, query_token_embed_mask = self.query_embedding(query_tokens, mask_zero=True)\n\n        # if WORD_DROPOUT > 0:\n        #     logging.info('used word dropout for source, p = %f', WORD_DROPOUT)\n        #     query_token_embed, query_token_embed_intact = WordDropout(WORD_DROPOUT, self.srng)(query_token_embed, False)\n\n        batch_size = tgt_action_seq.shape[0]\n        max_example_action_num = tgt_action_seq.shape[1]\n\n        # previous action embeddings\n        # (batch_size, max_example_action_num, action_embed_dim)\n        tgt_action_seq_embed = T.switch(T.shape_padright(tgt_action_seq[:, :, 0] > 0),\n                                        self.rule_embedding_W[tgt_action_seq[:, :, 0]],\n                                        self.vocab_embedding_W[tgt_action_seq[:, :, 1]])\n\n        tgt_action_seq_embed_tm1 = tensor_right_shift(tgt_action_seq_embed)\n\n        # parent rule application embeddings\n        tgt_par_rule_embed = T.switch(tgt_par_rule_seq[:, :, None] < 0,\n                                      T.alloc(0., 1, config.rule_embed_dim),\n                                      self.rule_embedding_W[tgt_par_rule_seq])\n\n        if not config.frontier_node_type_feed:\n            tgt_node_embed *= 0.\n\n        if not config.parent_action_feed:\n            tgt_par_rule_embed *= 0.\n\n        # (batch_size, 
max_example_action_num, action_embed_dim + symbol_embed_dim + action_embed_dim)\n        decoder_input = T.concatenate([tgt_action_seq_embed_tm1, tgt_node_embed, tgt_par_rule_embed], axis=-1)\n\n        # (batch_size, max_query_length, query_embed_dim)\n        query_embed = self.query_encoder_lstm(query_token_embed, mask=query_token_embed_mask,\n                                              dropout=config.dropout, srng=self.srng)\n\n        # (batch_size, max_example_action_num)\n        tgt_action_seq_mask = T.any(tgt_action_seq_type, axis=-1)\n        \n        # decoder_hidden_states: (batch_size, max_example_action_num, lstm_hidden_state)\n        # ctx_vectors: (batch_size, max_example_action_num, encoder_hidden_dim)\n        decoder_hidden_states, _, ctx_vectors = self.decoder_lstm(decoder_input,\n                                                                  context=query_embed,\n                                                                  context_mask=query_token_embed_mask,\n                                                                  mask=tgt_action_seq_mask,\n                                                                  parent_t_seq=tgt_par_t_seq,\n                                                                  dropout=config.dropout,\n                                                                  srng=self.srng)\n\n        # if DECODER_DROPOUT > 0:\n        #     logging.info('used dropout for decoder output, p = %f', DECODER_DROPOUT)\n        #     decoder_hidden_states = Dropout(DECODER_DROPOUT, self.srng)(decoder_hidden_states)\n\n        # ====================================================\n        # apply additional non-linearity transformation before\n        # predicting actions\n        # ====================================================\n\n        decoder_hidden_state_trans_rule = self.decoder_hidden_state_W_rule(decoder_hidden_states)\n        decoder_hidden_state_trans_token = 
self.decoder_hidden_state_W_token(T.concatenate([decoder_hidden_states, ctx_vectors], axis=-1))\n\n        # (batch_size, max_example_action_num, rule_num)\n        rule_predict = softmax(T.dot(decoder_hidden_state_trans_rule, T.transpose(self.rule_embedding_W)) + self.rule_embedding_b)\n\n        # (batch_size, max_example_action_num, 2)\n        terminal_gen_action_prob = self.terminal_gen_softmax(decoder_hidden_states)\n\n        # (batch_size, max_example_action_num, target_vocab_size)\n        vocab_predict = softmax(T.dot(decoder_hidden_state_trans_token, T.transpose(self.vocab_embedding_W)) + self.vocab_embedding_b)\n\n        # (batch_size, max_example_action_num, lstm_hidden_state + encoder_hidden_dim)\n        ptr_net_decoder_state = T.concatenate([decoder_hidden_states, ctx_vectors], axis=-1)\n\n        # (batch_size, max_example_action_num, max_query_length)\n        copy_prob = self.src_ptr_net(query_embed, query_token_embed_mask, ptr_net_decoder_state)\n\n        # (batch_size, max_example_action_num)\n        rule_tgt_prob = rule_predict[T.shape_padright(T.arange(batch_size)),\n                                     T.shape_padleft(T.arange(max_example_action_num)),\n                                     tgt_action_seq[:, :, 0]]\n\n        # (batch_size, max_example_action_num)\n        vocab_tgt_prob = vocab_predict[T.shape_padright(T.arange(batch_size)),\n                                       T.shape_padleft(T.arange(max_example_action_num)),\n                                       tgt_action_seq[:, :, 1]]\n\n        # (batch_size, max_example_action_num)\n        copy_tgt_prob = copy_prob[T.shape_padright(T.arange(batch_size)),\n                                  T.shape_padleft(T.arange(max_example_action_num)),\n                                  tgt_action_seq[:, :, 2]]\n\n\n        # (batch_size, max_example_action_num)\n        tgt_prob = tgt_action_seq_type[:, :, 0] * rule_tgt_prob + \\\n                   tgt_action_seq_type[:, :, 1] * 
terminal_gen_action_prob[:, :, 0] * vocab_tgt_prob + \\\n                   tgt_action_seq_type[:, :, 2] * terminal_gen_action_prob[:, :, 1] * copy_tgt_prob\n\n        likelihood = T.log(tgt_prob + 1.e-7 * (1 - tgt_action_seq_mask))\n        loss = - (likelihood * tgt_action_seq_mask).sum(axis=-1) # / tgt_action_seq_mask.sum(axis=-1)\n        loss = T.mean(loss)\n\n        # let's build the function!\n        train_inputs = [query_tokens, tgt_action_seq, tgt_action_seq_type,\n                        tgt_node_seq, tgt_par_rule_seq, tgt_par_t_seq]\n        optimizer = optimizers.get(config.optimizer)\n        optimizer.clip_grad = config.clip_grad\n        updates, grads = optimizer.get_updates(self.params, loss)\n        self.train_func = theano.function(train_inputs, [loss],\n                                          # [loss, tgt_action_seq_type, tgt_action_seq,\n                                          #  rule_tgt_prob, vocab_tgt_prob, copy_tgt_prob,\n                                          #  copy_prob, terminal_gen_action_prob],\n                                          updates=updates)\n\n        # if WORD_DROPOUT > 0:\n        #     self.build_decoder(query_tokens, query_token_embed_intact, query_token_embed_mask)\n        # else:\n        #     self.build_decoder(query_tokens, query_token_embed, query_token_embed_mask)\n\n        self.build_decoder(query_tokens, query_token_embed, query_token_embed_mask)\n\n    def build_decoder(self, query_tokens, query_token_embed, query_token_embed_mask):\n        logging.info('building decoder ...')\n\n        # (batch_size, decoder_state_dim)\n        decoder_prev_state = ndim_tensor(2, name='decoder_prev_state')\n\n        # (batch_size, decoder_state_dim)\n        decoder_prev_cell = ndim_tensor(2, name='decoder_prev_cell')\n\n        # (batch_size, n_timestep, decoder_state_dim)\n        hist_h = ndim_tensor(3, name='hist_h')\n\n        # (batch_size, decoder_state_dim)\n        prev_action_embed = ndim_tensor(2, 
name='prev_action_embed')\n\n        # (batch_size)\n        node_id = T.ivector(name='node_id')\n\n        # (batch_size, node_embed_dim)\n        node_embed = self.node_embedding[node_id]\n\n        # (batch_size)\n        par_rule_id = T.ivector(name='par_rule_id')\n\n        # (batch_size, decoder_state_dim)\n        par_rule_embed = T.switch(par_rule_id[:, None] < 0,\n                                  T.alloc(0., 1, config.rule_embed_dim),\n                                  self.rule_embedding_W[par_rule_id])\n\n        # ([time_step])\n        time_steps = T.ivector(name='time_steps')\n\n        # (batch_size)\n        parent_t = T.ivector(name='parent_t')\n\n        # (batch_size, 1)\n        parent_t_reshaped = T.shape_padright(parent_t)\n\n        query_embed = self.query_encoder_lstm(query_token_embed, mask=query_token_embed_mask,\n                                              dropout=config.dropout, train=False)\n\n        # (batch_size, 1, decoder_state_dim)\n        prev_action_embed_reshaped = prev_action_embed.dimshuffle((0, 'x', 1))\n\n        # (batch_size, 1, node_embed_dim)\n        node_embed_reshaped = node_embed.dimshuffle((0, 'x', 1))\n\n        # (batch_size, 1, node_embed_dim)\n        par_rule_embed_reshaped = par_rule_embed.dimshuffle((0, 'x', 1))\n\n        if not config.frontier_node_type_feed:\n            node_embed_reshaped *= 0.\n\n        if not config.parent_action_feed:\n            par_rule_embed_reshaped *= 0.\n\n        decoder_input = T.concatenate([prev_action_embed_reshaped, node_embed_reshaped, par_rule_embed_reshaped], axis=-1)\n\n        # (batch_size, 1, decoder_state_dim)\n        # (batch_size, 1, decoder_state_dim)\n        # (batch_size, 1, field_token_encode_dim)\n        decoder_next_state_dim3, decoder_next_cell_dim3, ctx_vectors = self.decoder_lstm(decoder_input,\n                                                                                         init_state=decoder_prev_state,\n                              
                                                           init_cell=decoder_prev_cell,\n                                                                                         hist_h=hist_h,\n                                                                                         context=query_embed,\n                                                                                         context_mask=query_token_embed_mask,\n                                                                                         parent_t_seq=parent_t_reshaped,\n                                                                                         dropout=config.dropout,\n                                                                                         train=False,\n                                                                                         time_steps=time_steps)\n\n        decoder_next_state = decoder_next_state_dim3.flatten(2)\n        # decoder_output = decoder_next_state * (1 - DECODER_DROPOUT)\n\n        decoder_next_cell = decoder_next_cell_dim3.flatten(2)\n\n        decoder_next_state_trans_rule = self.decoder_hidden_state_W_rule(decoder_next_state)\n        decoder_next_state_trans_token = self.decoder_hidden_state_W_token(T.concatenate([decoder_next_state, ctx_vectors.flatten(2)], axis=-1))\n\n        rule_prob = softmax(T.dot(decoder_next_state_trans_rule, T.transpose(self.rule_embedding_W)) + self.rule_embedding_b)\n\n        gen_action_prob = self.terminal_gen_softmax(decoder_next_state)\n\n        vocab_prob = softmax(T.dot(decoder_next_state_trans_token, T.transpose(self.vocab_embedding_W)) + self.vocab_embedding_b)\n\n        ptr_net_decoder_state = T.concatenate([decoder_next_state_dim3, ctx_vectors], axis=-1)\n\n        copy_prob = self.src_ptr_net(query_embed, query_token_embed_mask, ptr_net_decoder_state)\n\n        copy_prob = copy_prob.flatten(2)\n\n        inputs = [query_tokens]\n        outputs = [query_embed, 
query_token_embed_mask]\n\n        self.decoder_func_init = theano.function(inputs, outputs)\n\n        inputs = [time_steps, decoder_prev_state, decoder_prev_cell, hist_h, prev_action_embed,\n                  node_id, par_rule_id, parent_t,\n                  query_embed, query_token_embed_mask]\n\n        outputs = [decoder_next_state, decoder_next_cell,\n                   rule_prob, gen_action_prob, vocab_prob, copy_prob]\n\n        self.decoder_func_next_step = theano.function(inputs, outputs)\n\n    def decode(self, example, grammar, terminal_vocab, beam_size, max_time_step, log=False):\n        # beam search decoding\n\n        eos = 1\n        unk = terminal_vocab.unk\n        vocab_embedding = self.vocab_embedding_W.get_value(borrow=True)\n        rule_embedding = self.rule_embedding_W.get_value(borrow=True)\n\n        query_tokens = example.data[0]\n\n        query_embed, query_token_embed_mask = self.decoder_func_init(query_tokens)\n\n        completed_hyps = []\n        completed_hyp_num = 0\n        live_hyp_num = 1\n\n        root_hyp = Hyp(grammar)\n        root_hyp.state = np.zeros(config.decoder_hidden_dim).astype('float32')\n        root_hyp.cell = np.zeros(config.decoder_hidden_dim).astype('float32')\n        root_hyp.action_embed = np.zeros(config.rule_embed_dim).astype('float32')\n        root_hyp.node_id = grammar.get_node_type_id(root_hyp.tree.type)\n        root_hyp.parent_rule_id = -1\n\n        hyp_samples = [root_hyp]  # [list() for i in range(live_hyp_num)]\n\n        # source word id in the terminal vocab\n        src_token_id = [terminal_vocab[t] for t in example.query][:config.max_query_length]\n        unk_pos_list = [x for x, t in enumerate(src_token_id) if t == unk]\n\n        # sometimes a word may appear multi-times in the source, in this case,\n        # we just copy its first appearing position. 
Therefore we mask the words\n        # appearing second and onwards to -1\n        token_set = set()\n        for i, tid in enumerate(src_token_id):\n            if tid in token_set:\n                src_token_id[i] = -1\n            else: token_set.add(tid)\n\n        for t in xrange(max_time_step):\n            hyp_num = len(hyp_samples)\n            # print 'time step [%d]' % t\n            decoder_prev_state = np.array([hyp.state for hyp in hyp_samples]).astype('float32')\n            decoder_prev_cell = np.array([hyp.cell for hyp in hyp_samples]).astype('float32')\n\n            hist_h = np.zeros((hyp_num, max_time_step, config.decoder_hidden_dim)).astype('float32')\n\n            if t > 0:\n                for i, hyp in enumerate(hyp_samples):\n                    hist_h[i, :len(hyp.hist_h), :] = hyp.hist_h\n                    # for j, h in enumerate(hyp.hist_h):\n                    #    hist_h[i, j] = h\n\n            prev_action_embed = np.array([hyp.action_embed for hyp in hyp_samples]).astype('float32')\n            node_id = np.array([hyp.node_id for hyp in hyp_samples], dtype='int32')\n            parent_rule_id = np.array([hyp.parent_rule_id for hyp in hyp_samples], dtype='int32')\n            parent_t = np.array([hyp.get_action_parent_t() for hyp in hyp_samples], dtype='int32')\n            query_embed_tiled = np.tile(query_embed, [live_hyp_num, 1, 1])\n            query_token_embed_mask_tiled = np.tile(query_token_embed_mask, [live_hyp_num, 1])\n\n            inputs = [np.array([t], dtype='int32'), decoder_prev_state, decoder_prev_cell, hist_h, prev_action_embed,\n                      node_id, parent_rule_id, parent_t,\n                      query_embed_tiled, query_token_embed_mask_tiled]\n\n            decoder_next_state, decoder_next_cell, \\\n            rule_prob, gen_action_prob, vocab_prob, copy_prob  = self.decoder_func_next_step(*inputs)\n\n            new_hyp_samples = []\n\n            cut_off_k = beam_size\n            score_heap = 
[]\n\n            # iterating over items in the beam\n            # print 'time step: %d, hyp num: %d' % (t, live_hyp_num)\n\n            word_prob = gen_action_prob[:, 0:1] * vocab_prob\n            word_prob[:, unk] = 0\n\n            hyp_scores = np.array([hyp.score for hyp in hyp_samples])\n\n            # word_prob[:, src_token_id] += gen_action_prob[:, 1:2] * copy_prob[:, :len(src_token_id)]\n            # word_prob[:, unk] = 0\n\n            rule_apply_cand_hyp_ids = []\n            rule_apply_cand_scores = []\n            rule_apply_cand_rules = []\n            rule_apply_cand_rule_ids = []\n\n            hyp_frontier_nts = []\n            word_gen_hyp_ids = []\n            cand_copy_probs = []\n            unk_words = []\n\n            for k in xrange(live_hyp_num):\n                hyp = hyp_samples[k]\n\n                # if k == 0:\n                #     print 'Top Hyp: %s' % hyp.tree.__repr__()\n\n                frontier_nt = hyp.frontier_nt()\n                hyp_frontier_nts.append(frontier_nt)\n\n                assert hyp, 'none hyp!'\n\n                # if it's not a leaf\n                if not grammar.is_value_node(frontier_nt):\n                    # iterate over all the possible rules\n                    rules = grammar[frontier_nt.as_type_node] if config.head_nt_constraint else grammar\n                    assert len(rules) > 0, 'fail to expand nt node %s' % frontier_nt\n                    for rule in rules:\n                        rule_id = grammar.rule_to_id[rule]\n\n                        cur_rule_score = np.log(rule_prob[k, rule_id])\n                        new_hyp_score = hyp.score + cur_rule_score\n\n                        rule_apply_cand_hyp_ids.append(k)\n                        rule_apply_cand_scores.append(new_hyp_score)\n                        rule_apply_cand_rules.append(rule)\n                        rule_apply_cand_rule_ids.append(rule_id)\n\n                else:  # it's a leaf that holds values\n                    
cand_copy_prob = 0.0\n                    for i, tid in enumerate(src_token_id):\n                        if tid != -1:\n                            word_prob[k, tid] += gen_action_prob[k, 1] * copy_prob[k, i]\n                            cand_copy_prob = gen_action_prob[k, 1]\n\n                    # and unk copy probability\n                    if len(unk_pos_list) > 0:\n                        unk_pos = copy_prob[k, unk_pos_list].argmax()\n                        unk_pos = unk_pos_list[unk_pos]\n\n                        unk_copy_score = gen_action_prob[k, 1] * copy_prob[k, unk_pos]\n                        word_prob[k, unk] = unk_copy_score\n\n                        unk_word = example.query[unk_pos]\n                        unk_words.append(unk_word)\n\n                        cand_copy_prob = gen_action_prob[k, 1]\n\n                    word_gen_hyp_ids.append(k)\n                    cand_copy_probs.append(cand_copy_prob)\n\n            # prune the hyp space\n            if completed_hyp_num >= beam_size:\n                break\n\n            word_prob = np.log(word_prob)\n\n            word_gen_hyp_num = len(word_gen_hyp_ids)\n            rule_apply_cand_num = len(rule_apply_cand_scores)\n\n            if word_gen_hyp_num > 0:\n                word_gen_cand_scores = hyp_scores[word_gen_hyp_ids, None] + word_prob[word_gen_hyp_ids, :]\n                word_gen_cand_scores_flat = word_gen_cand_scores.flatten()\n\n                cand_scores = np.concatenate([rule_apply_cand_scores, word_gen_cand_scores_flat])\n            else:\n                cand_scores = np.array(rule_apply_cand_scores)\n\n            top_cand_ids = (-cand_scores).argsort()[:beam_size - completed_hyp_num]\n\n            # expand_cand_num = 0\n            for cand_id in top_cand_ids:\n                # cand is rule application\n                new_hyp = None\n                if cand_id < rule_apply_cand_num:\n                    hyp_id = rule_apply_cand_hyp_ids[cand_id]\n                    
hyp = hyp_samples[hyp_id]\n                    rule_id = rule_apply_cand_rule_ids[cand_id]\n                    rule = rule_apply_cand_rules[cand_id]\n                    new_hyp_score = rule_apply_cand_scores[cand_id]\n\n                    new_hyp = Hyp(hyp)\n                    new_hyp.apply_rule(rule)\n\n                    new_hyp.score = new_hyp_score\n                    new_hyp.state = copy.copy(decoder_next_state[hyp_id])\n                    new_hyp.hist_h.append(copy.copy(new_hyp.state))\n                    new_hyp.cell = copy.copy(decoder_next_cell[hyp_id])\n                    new_hyp.action_embed = rule_embedding[rule_id]\n                else:\n                    tid = (cand_id - rule_apply_cand_num) % word_prob.shape[1]\n                    word_gen_hyp_id = (cand_id - rule_apply_cand_num) / word_prob.shape[1]\n                    hyp_id = word_gen_hyp_ids[word_gen_hyp_id]\n\n                    if tid == unk:\n                        token = unk_words[word_gen_hyp_id]\n                    else:\n                        token = terminal_vocab.id_token_map[tid]\n\n                    frontier_nt = hyp_frontier_nts[hyp_id]\n                    # if frontier_nt.type == int and (not (is_numeric(token) or token == '<eos>')):\n                    #     continue\n\n                    hyp = hyp_samples[hyp_id]\n                    new_hyp_score = word_gen_cand_scores[word_gen_hyp_id, tid]\n\n                    new_hyp = Hyp(hyp)\n                    new_hyp.append_token(token)\n\n                    if log:\n                        cand_copy_prob = cand_copy_probs[word_gen_hyp_id]\n                        if cand_copy_prob > 0.5:\n                            new_hyp.log += ' || ' + str(new_hyp.frontier_nt()) + '{copy[%s][p=%f]}' % (token ,cand_copy_prob)\n\n                    new_hyp.score = new_hyp_score\n                    new_hyp.state = copy.copy(decoder_next_state[hyp_id])\n                    new_hyp.hist_h.append(copy.copy(new_hyp.state))\n     
               new_hyp.cell = copy.copy(decoder_next_cell[hyp_id])\n                    new_hyp.action_embed = vocab_embedding[tid]\n                    new_hyp.node_id = grammar.get_node_type_id(frontier_nt)\n\n\n                # get the new frontier nt after rule application\n                new_frontier_nt = new_hyp.frontier_nt()\n\n                # if new_frontier_nt is None, then we have a new completed hyp!\n                if new_frontier_nt is None:\n                    # if t <= 1:\n                    #     continue\n\n                    new_hyp.n_timestep = t + 1\n                    completed_hyps.append(new_hyp)\n                    completed_hyp_num += 1\n\n                else:\n                    new_hyp.node_id = grammar.get_node_type_id(new_frontier_nt.type)\n                    # new_hyp.parent_rule_id = grammar.rule_to_id[\n                    #     new_frontier_nt.parent.to_rule(include_value=False)]\n                    new_hyp.parent_rule_id = grammar.rule_to_id[new_frontier_nt.parent.applied_rule]\n\n                    new_hyp_samples.append(new_hyp)\n\n                # expand_cand_num += 1\n                # if expand_cand_num >= beam_size - completed_hyp_num:\n                #     break\n\n                # cand is word generation\n\n            live_hyp_num = min(len(new_hyp_samples), beam_size - completed_hyp_num)\n            if live_hyp_num < 1:\n                break\n\n            hyp_samples = new_hyp_samples\n            # hyp_samples = sorted(new_hyp_samples, key=lambda x: x.score, reverse=True)[:live_hyp_num]\n\n        completed_hyps = sorted(completed_hyps, key=lambda x: x.score, reverse=True)\n\n        return completed_hyps\n\n    @property\n    def params_name_to_id(self):\n        name_to_id = dict()\n        for i, p in enumerate(self.params):\n            assert p.name is not None\n            # print 'parameter [%s]' % p.name\n\n            name_to_id[p.name] = i\n\n        return name_to_id\n\n    @property\n    
def params_dict(self):\n        assert len(set(p.name for p in self.params)) == len(self.params), 'param name clashes!'\n        return OrderedDict((p.name, p) for p in self.params)\n\n    def pull_params(self):\n        return OrderedDict([(p_name, p.get_value(borrow=False)) for (p_name, p) in self.params_dict.iteritems()])\n\n    def save(self, model_file, **kwargs):\n        logging.info('save model to [%s]', model_file)\n\n        weights_dict = self.pull_params()\n        for k, v in kwargs.iteritems():\n            weights_dict[k] = v\n\n        np.savez(model_file, **weights_dict)\n\n    def load(self, model_file):\n        logging.info('load model from [%s]', model_file)\n        weights_dict = np.load(model_file)\n\n        # assert len(weights_dict.files) == len(self.params_dict)\n\n        for p_name, p in self.params_dict.iteritems():\n            if p_name not in weights_dict:\n                raise RuntimeError('parameter [%s] not in saved weights file', p_name)\n            else:\n                logging.info('loading parameter [%s]', p_name)\n                assert np.array_equal(p.shape.eval(), weights_dict[p_name].shape), \\\n                    'shape mis-match for [%s]!, %s != %s' % (p_name, p.shape.eval(), weights_dict[p_name].shape)\n\n                p.set_value(weights_dict[p_name])\n"
  },
  {
    "path": "nn/__init__.py",
    "content": "__author__ = 'yinpengcheng'\n"
  },
  {
    "path": "nn/activations.py",
    "content": "import theano.tensor as T\n\n\ndef softmax(x):\n    return T.nnet.softmax(x.reshape((-1, x.shape[-1]))).reshape(x.shape)\n\n\ndef time_distributed_softmax(x):\n    import warnings\n    warnings.warn(\"time_distributed_softmax is deprecated. Just use softmax!\", DeprecationWarning)\n    return softmax(x)\n\n\ndef softplus(x):\n    return T.nnet.softplus(x)\n\n\ndef relu(x):\n    return T.nnet.relu(x)\n\n\ndef tanh(x):\n    return T.tanh(x)\n\n\ndef sigmoid(x):\n    return T.nnet.sigmoid(x)\n\n\ndef hard_sigmoid(x):\n    return T.nnet.hard_sigmoid(x)\n\n\ndef linear(x):\n    '''\n    The function returns the variable that is passed in, so all types work\n    '''\n    return x\n\n\nfrom .utils.generic_utils import get_from_module\ndef get(identifier):\n    return get_from_module(identifier, globals(), 'activation function')\n"
  },
  {
    "path": "nn/initializations.py",
    "content": "import theano\nimport theano.tensor as T\nimport numpy as np\n\nfrom .utils.theano_utils import sharedX, shared_zeros, shared_ones\n\n\ndef get_fans(shape):\n    fan_in = shape[0] if len(shape) == 2 else np.prod(shape[1:])\n    fan_out = shape[1] if len(shape) == 2 else shape[0]\n    return fan_in, fan_out\n\n\ndef uniform(shape, scale=0.01, name=None):\n    return sharedX(np.random.uniform(low=-scale, high=scale, size=shape), name=name)\n\n\ndef normal(shape, scale=0.01, name=None):\n    return sharedX(np.random.randn(*shape) * scale, name=name)\n\n\ndef lecun_uniform(shape):\n    ''' Reference: LeCun 98, Efficient Backprop\n        http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf\n    '''\n    fan_in, fan_out = get_fans(shape)\n    scale = np.sqrt(3. / fan_in)\n    return uniform(shape, scale)\n\n\ndef glorot_normal(shape):\n    ''' Reference: Glorot & Bengio, AISTATS 2010\n    '''\n    fan_in, fan_out = get_fans(shape)\n    s = np.sqrt(2. / (fan_in + fan_out))\n    return normal(shape, s)\n\n\ndef glorot_uniform(shape, name=None):\n    fan_in, fan_out = get_fans(shape)\n    s = np.sqrt(6. / (fan_in + fan_out))\n    return uniform(shape, s, name=name)\n\n\ndef he_normal(shape):\n    ''' Reference:  He et al., http://arxiv.org/abs/1502.01852\n    '''\n    fan_in, fan_out = get_fans(shape)\n    s = np.sqrt(2. / fan_in)\n    return normal(shape, s)\n\n\ndef he_uniform(shape):\n    fan_in, fan_out = get_fans(shape)\n    s = np.sqrt(6. 
/ fan_in)\n    return uniform(shape, s)\n\n\ndef orthogonal(shape, scale=1.1):\n    ''' From Lasagne\n    '''\n    flat_shape = (shape[0], np.prod(shape[1:]))\n    a = np.random.normal(0.0, 1.0, flat_shape)\n    u, _, v = np.linalg.svd(a, full_matrices=False)\n    # pick the one with the correct shape\n    q = u if u.shape == flat_shape else v\n    q = q.reshape(shape)\n    return sharedX(scale * q[:shape[0], :shape[1]])\n\n\ndef identity(shape, scale=1):\n    if len(shape) != 2 or shape[0] != shape[1]:\n        raise Exception(\"Identity matrix initialization can only be used for 2D square matrices\")\n    else:\n        return sharedX(scale * np.identity(shape[0]))\n\n\ndef zero(shape):\n    return shared_zeros(shape)\n\n\ndef one(shape):\n    return shared_ones(shape)\n\n\nfrom .utils.generic_utils import get_from_module\ndef get(identifier):\n    return get_from_module(identifier, globals(), 'initialization')\n"
  },
  {
    "path": "nn/layers/__init__.py",
    "content": "__author__ = 'yinpengcheng'\n"
  },
  {
    "path": "nn/layers/convolution.py",
    "content": "# -*- coding: utf-8 -*-\n\nfrom .core import Layer\nfrom nn.utils.theano_utils import *\nimport nn.initializations as initializations\nimport nn.activations as activations\nfrom theano.tensor.nnet import conv\nfrom theano.tensor.signal import pool\n\n\nclass Convolution2d(Layer):\n    \"\"\"a convolutional layer with max pooling\"\"\"\n\n    def __init__(self, max_sent_len, word_embed_dim, filter_num, filter_window_size,\n                 border_mode='valid', activation='relu', name='Convolution2d'):\n        super(Convolution2d, self).__init__()\n\n        self.init = initializations.get('uniform')\n        self.activation = activations.get(activation)\n        self.border_mode = border_mode\n\n        self.W = self.init((filter_num, 1, filter_window_size, word_embed_dim), scale=0.01, name='W')\n        self.b = shared_zeros((filter_num), name='b')\n\n        self.params = [self.W, self.b]\n\n        if self.border_mode == 'valid':\n            self.ds = (max_sent_len - filter_window_size + 1, 1)\n        elif self.border_mode == 'full':\n            self.ds = (max_sent_len + filter_window_size - 1, 1)\n\n        if name is not None:\n            self.set_name(name)\n\n    def __call__(self, X):\n        # X: (batch_size, max_sent_len, word_embed_dim)\n\n        # valid: (batch_size, nb_filters, max_sent_len - filter_window_size + 1, 1)\n        # full: (batch_size, nb_filters, max_sent_len + filter_window_size - 1, 1)\n        conv_output = conv.conv2d(X.reshape((X.shape[0], 1, X.shape[1], X.shape[2])),\n                                  filters=self.W,\n                                  filter_shape=self.W.shape.eval(),\n                                  border_mode=self.border_mode)\n\n        output = self.activation(conv_output + self.b.dimshuffle(('x', 0, 'x', 'x')))\n\n        # (batch_size, nb_filters, 1, 1)\n        output = pool.pool_2d(output, ds=self.ds, ignore_border=True, mode='max')\n        # (batch_size, nb_filters)\n        output 
= output.flatten(2)\n        return output\n"
  },
  {
    "path": "nn/layers/core.py",
    "content": "# -*- coding: utf-8 -*-\n\nimport theano\nimport theano.tensor as T\nimport numpy as np\n\nfrom nn.utils.theano_utils import *\nimport nn.initializations as initializations\nimport nn.activations as activations\n\nfrom theano.tensor.shared_randomstreams import RandomStreams\nfrom theano.sandbox.rng_mrg import MRG_RandomStreams\n\n\nclass Layer(object):\n    def __init__(self):\n        self.params = []\n\n    def init_updates(self):\n        self.updates = []\n\n    def __call__(self, X):\n        return X\n\n    def supports_masked_input(self):\n        ''' Whether or not this layer respects the output mask of its previous layer in its calculations. If you try\n        to attach a layer that does *not* support masked_input to a layer that gives a non-None output_mask() that is\n        an error'''\n        return False\n\n    def get_output_mask(self, train=None):\n        '''\n        For some models (such as RNNs) you want a way of being able to mark some output data-points as\n        \"masked\", so they are not used in future calculations. In such a model, get_output_mask() should return a mask\n        of one less dimension than get_output() (so if get_output is (nb_samples, nb_timesteps, nb_dimensions), then the mask\n        is (nb_samples, nb_timesteps), with a one for every unmasked datapoint, and a zero for every masked one.\n\n        If there is *no* masking then it shall return None. For instance if you attach an Activation layer (they support masking)\n        to a layer with an output_mask, then that Activation shall also have an output_mask. 
If you attach it to a layer with no\n        such mask, then the Activation's get_output_mask shall return None.\n\n        Some layers have an output_mask even if their input is unmasked, notably Embedding which can turn the entry \"0\" into\n        a mask.\n        '''\n        return None\n\n    def set_weights(self, weights):\n        for p, w in zip(self.params, weights):\n            if p.eval().shape != w.shape:\n                raise Exception(\"Layer shape %s not compatible with weight shape %s.\" % (p.eval().shape, w.shape))\n            p.set_value(floatX(w))\n\n    def get_weights(self):\n        weights = []\n        for p in self.params:\n            weights.append(p.get_value())\n        return weights\n\n    def get_params(self):\n        return self.params\n\n    def set_name(self, name):\n        if name:\n            for i in range(len(self.params)):\n                if self.params[i].name is None:\n                    self.params[i].name = '%s_p%d' % (name, i)\n                else:\n                    self.params[i].name = name + '_' + self.params[i].name\n\n        self.name = name\n\n\nclass MaskedLayer(Layer):\n    '''\n    If your layer trivially supports masking (by simply copying the input mask to the output), then subclass MaskedLayer\n    instead of Layer, and make sure that you incorporate the input mask into your calculation of get_output()\n    '''\n    def supports_masked_input(self):\n        return True\n\n\nclass Dense(Layer):\n    def __init__(self, input_dim, output_dim, init='glorot_uniform', activation='tanh', name='Dense'):\n\n        super(Dense, self).__init__()\n        self.init = initializations.get(init)\n        self.activation = activations.get(activation)\n        self.input_dim = input_dim\n        self.output_dim = output_dim\n\n        self.input = T.matrix()\n        self.W = self.init((self.input_dim, self.output_dim))\n        self.b = shared_zeros((self.output_dim))\n\n        self.params = [self.W, 
self.b]\n\n        if name is not None:\n            self.set_name(name)\n\n    def set_name(self, name):\n        self.W.name = '%s_W' % name\n        self.b.name = '%s_b' % name\n\n    def __call__(self, X):\n        output = self.activation(T.dot(X, self.W) + self.b)\n        return output\n\n\nclass Dropout(Layer):\n    def __init__(self, p, srng, name='dropout'):\n        super(Dropout, self).__init__()\n\n        assert 0. < p < 1.\n\n        self.p = p\n        self.srng = srng\n\n        if name is not None:\n            self.set_name(name)\n\n    def __call__(self, X, train_only=True):\n        retain_prob = 1. - self.p\n\n        X_train = X * self.srng.binomial(X.shape, p=retain_prob, dtype=theano.config.floatX)\n        X_test = X * retain_prob\n\n        if train_only:\n            return X_train\n        else:\n            return X_train, X_test\n\nclass WordDropout(Layer):\n    def __init__(self, p, srng, name='WordDropout'):\n        super(WordDropout, self).__init__()\n\n        self.p = p\n        self.srng = srng\n\n    def __call__(self, X, train_only=True):\n        retain_prob = 1. - self.p\n\n        mask = self.srng.binomial(X.shape[:-1], p=retain_prob, dtype=theano.config.floatX)\n        X_train = X * T.shape_padright(mask)\n\n        if train_only:\n            return X_train\n        else:\n            return X_train, X\n\n\n\n"
  },
  {
    "path": "nn/layers/embeddings.py",
    "content": "# -*- coding: utf-8 -*-\n\nfrom .core import Layer\nfrom nn.utils.theano_utils import *\nimport nn.initializations as initializations\n\nimport nn.activations as activations\nfrom theano.ifelse import ifelse\n\n\ndef get_embed_iter(file_path):\n    for line in open(file_path):\n        line = line.strip()\n        data = line.split(' ')\n\n        word = data[0]\n        embed = np.asarray([float(e) for e in data[1:]], dtype='float32')\n\n        yield word, embed\n\n\nclass Embedding(Layer):\n    '''\n        Turn positive integers (indexes) into dense vectors of fixed size.\n        eg. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]\n\n        @input_dim: size of vocabulary (highest input integer + 1)\n        @out_dim: size of dense representation\n    '''\n    def __init__(self, input_dim, output_dim, init='uniform', name=None):\n\n        super(Embedding, self).__init__()\n        self.init = initializations.get(init)\n        self.input_dim = input_dim\n        self.output_dim = output_dim\n\n        self.W = self.init((self.input_dim, self.output_dim), scale=0.1)\n        self.params = [self.W]\n\n        if name is not None:\n            self.set_name(name)\n\n    def get_output_mask(self, X):\n        return (T.ones_like(X) * (1 - T.eq(X, 0))).astype('int8')\n\n    def init_pretrained(self, file_path, vocab):\n        W = self.W.get_value(borrow=True)\n        inited_words = set()\n\n        for word, embed in get_embed_iter(file_path):\n            if word in vocab:\n                idx = vocab[word]\n                W[idx] = embed\n\n                inited_words.add(word)\n\n        return inited_words\n\n    def __call__(self, X, mask_zero=False):\n        out = self.W[X]\n        if mask_zero:\n            return out, self.get_output_mask(X)\n        else:\n            return out\n\n\nclass HybridEmbedding(Layer):\n    '''\n        Turn positive integers (indexes) into dense vectors of fixed size.\n        eg. 
[[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]\n\n        @input_dim: size of vocabulary (highest input integer + 1)\n        @out_dim: size of dense representation\n    '''\n    def __init__(self, embed_size, unfixed_embed_size, embed_dim, init='uniform', name='HybridEmbedding'):\n\n        super(HybridEmbedding, self).__init__()\n        self.init = initializations.get(init)\n\n        self.unfixed_embed_size = unfixed_embed_size\n\n        self.W_unfixed = self.init((embed_size, embed_dim))\n        self.W_fixed = self.init((embed_size, embed_dim))\n        self.W_fixed.name = 'HybridEmbedding_fiexed_embed_matrix'\n\n        # print W_fixed\n        # for id, row in enumerate(self.W_fixed.get_value()):\n        #     if id >= 400: print '[word %d]' % id, row\n\n        self.params = [self.W_unfixed]\n\n        if name is not None:\n            self.set_name(name)\n\n    def get_output_mask(self, X):\n        return T.ones_like(X) * (1 - T.eq(X, 0))\n\n    def __call__(self, X, mask_zero=False):\n        cond = T.lt(X, self.unfixed_embed_size)\n        out = T.switch(T.shape_padright(cond), self.W_unfixed[X], self.W_fixed[X])\n\n        if mask_zero:\n            return out, self.get_output_mask(X)\n        else:\n            return out"
  },
  {
    "path": "nn/layers/recurrent.py",
    "content": "# -*- coding: utf-8 -*-\n\nimport logging\nimport theano\nimport theano.tensor as T\nimport numpy as np\n\nfrom .core import *\n\n\nclass GRU(Layer):\n    '''\n        Gated Recurrent Unit - Cho et al. 2014\n\n        Acts as a spatiotemporal projection,\n        turning a sequence of vectors into a single vector.\n\n        Eats inputs with shape:\n        (nb_samples, max_sample_length (samples shorter than this are padded with zeros at the end), input_dim)\n\n        and returns outputs with shape:\n        if not return_sequences:\n            (nb_samples, output_dim)\n        if return_sequences:\n            (nb_samples, max_sample_length, output_dim)\n\n        References:\n            On the Properties of Neural Machine Translation: Encoder–Decoder Approaches\n                http://www.aclweb.org/anthology/W14-4012\n            Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling\n                http://arxiv.org/pdf/1412.3555v1.pdf\n    '''\n    def __init__(self, input_dim, output_dim=128,\n                 init='glorot_uniform', inner_init='orthogonal',\n                 activation='tanh', inner_activation='sigmoid',\n                 return_sequences=False, name='GRU'):\n\n        super(GRU, self).__init__()\n        self.input_dim = input_dim\n        self.output_dim = output_dim\n        self.return_sequences = return_sequences\n\n        self.init = initializations.get(init)\n        self.inner_init = initializations.get(inner_init)\n        self.activation = activations.get(activation)\n        self.inner_activation = activations.get(inner_activation)\n\n        self.W_z = self.init((self.input_dim, self.output_dim))\n        self.U_z = self.inner_init((self.output_dim, self.output_dim))\n        self.b_z = shared_zeros((self.output_dim))\n\n        self.W_r = self.init((self.input_dim, self.output_dim))\n        self.U_r = self.inner_init((self.output_dim, self.output_dim))\n        self.b_r = 
shared_zeros((self.output_dim))\n\n        self.W_h = self.init((self.input_dim, self.output_dim))\n        self.U_h = self.inner_init((self.output_dim, self.output_dim))\n        self.b_h = shared_zeros((self.output_dim))\n\n        self.params = [\n            self.W_z, self.U_z, self.b_z,\n            self.W_r, self.U_r, self.b_r,\n            self.W_h, self.U_h, self.b_h,\n        ]\n\n        if name is not None:\n            self.set_name(name)\n\n    def _step(self,\n              xz_t, xr_t, xh_t, mask_tm1,\n              h_tm1,\n              u_z, u_r, u_h):\n        # h_tm1 = theano.printing.Print(self.name + 'h_tm1::')(h_tm1)\n        h_mask_tm1 = mask_tm1 * h_tm1\n        # h_mask_tm1 = theano.printing.Print(self.name + 'h_mask_tm1::')(h_mask_tm1)\n        z = self.inner_activation(xz_t + T.dot(h_mask_tm1, u_z))\n        r = self.inner_activation(xr_t + T.dot(h_mask_tm1, u_r))\n        hh_t = self.activation(xh_t + T.dot(r * h_mask_tm1, u_h))\n        h_t = z * h_mask_tm1 + (1 - z) * hh_t\n        return h_t\n\n    def __call__(self, X, mask=None, init_state=None):\n        padded_mask = self.get_padded_shuffled_mask(mask, X, pad=1)\n        X = X.dimshuffle((1, 0, 2))\n\n        x_z = T.dot(X, self.W_z) + self.b_z\n        x_r = T.dot(X, self.W_r) + self.b_r\n        x_h = T.dot(X, self.W_h) + self.b_h\n\n        if init_state:\n            # (batch_size, output_dim)\n            outputs_info = T.unbroadcast(init_state, 1)\n        else:\n            outputs_info = T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1)\n\n        outputs, updates = theano.scan(\n            self._step,\n            sequences=[x_z, x_r, x_h, padded_mask],\n            outputs_info=outputs_info,\n            non_sequences=[self.U_z, self.U_r, self.U_h])\n\n        if self.return_sequences:\n            return outputs.dimshuffle((1, 0, 2))\n        return outputs[-1]\n\n    def get_padded_shuffled_mask(self, mask, X, pad=0):\n        # mask is (nb_samples, 
time)\n        if mask is None:\n            mask = T.ones((X.shape[0], X.shape[1]))\n\n        mask = T.shape_padright(mask)  # (nb_samples, time, 1)\n        mask = T.addbroadcast(mask, -1)  # (time, nb_samples, 1) matrix.\n        mask = mask.dimshuffle(1, 0, 2)  # (time, nb_samples, 1)\n\n        if pad > 0:\n            # left-pad in time with 0\n            padding = alloc_zeros_matrix(pad, mask.shape[1], 1)\n            mask = T.concatenate([padding, mask], axis=0)\n        return mask.astype('int8')\n\n\nclass GRU_4BiRNN(Layer):\n    '''\n        Gated Recurrent Unit - Cho et al. 2014\n\n        Acts as a spatiotemporal projection,\n        turning a sequence of vectors into a single vector.\n\n        Eats inputs with shape:\n        (nb_samples, max_sample_length (samples shorter than this are padded with zeros at the end), input_dim)\n\n        and returns outputs with shape:\n        if not return_sequences:\n            (nb_samples, output_dim)\n        if return_sequences:\n            (nb_samples, max_sample_length, output_dim)\n\n        References:\n            On the Properties of Neural Machine Translation: Encoder–Decoder Approaches\n                http://www.aclweb.org/anthology/W14-4012\n            Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling\n                http://arxiv.org/pdf/1412.3555v1.pdf\n    '''\n    def __init__(self, input_dim, output_dim=128,\n                 init='glorot_uniform', inner_init='orthogonal',\n                 activation='tanh', inner_activation='sigmoid',\n                 return_sequences=False, name=None):\n\n        super(GRU_4BiRNN, self).__init__()\n        self.input_dim = input_dim\n        self.output_dim = output_dim\n        self.return_sequences = return_sequences\n\n        self.init = initializations.get(init)\n        self.inner_init = initializations.get(inner_init)\n        self.activation = activations.get(activation)\n        self.inner_activation = 
activations.get(inner_activation)\n\n        self.W_z = self.init((self.input_dim, self.output_dim))\n        self.U_z = self.inner_init((self.output_dim, self.output_dim))\n        self.b_z = shared_zeros((self.output_dim))\n\n        self.W_r = self.init((self.input_dim, self.output_dim))\n        self.U_r = self.inner_init((self.output_dim, self.output_dim))\n        self.b_r = shared_zeros((self.output_dim))\n\n        self.W_h = self.init((self.input_dim, self.output_dim))\n        self.U_h = self.inner_init((self.output_dim, self.output_dim))\n        self.b_h = shared_zeros((self.output_dim))\n\n        self.params = [\n            self.W_z, self.U_z, self.b_z,\n            self.W_r, self.U_r, self.b_r,\n            self.W_h, self.U_h, self.b_h,\n        ]\n\n        if name is not None:\n            self.set_name(name)\n\n    def _step(self,\n              # xz_t, xr_t, xh_t, mask_tm1, mask,\n              xz_t, xr_t, xh_t, mask,\n              h_tm1,\n              u_z, u_r, u_h):\n        # h_mask_tm1 = mask_tm1 * h_tm1\n        # h_tm1 = theano.printing.Print(self.name + '::h_tm1::')(h_tm1)\n        # mask = theano.printing.Print(self.name + '::mask::')(mask)\n\n        z = self.inner_activation(xz_t + T.dot(h_tm1, u_z))\n        r = self.inner_activation(xr_t + T.dot(h_tm1, u_r))\n        hh_t = self.activation(xh_t + T.dot(r * h_tm1, u_h))\n        h_t = z * h_tm1 + (1 - z) * hh_t\n\n        # mask\n        h_t = (1 - mask) * h_tm1 + mask * h_t\n        # h_t = theano.printing.Print(self.name + '::h_t::')(h_t)\n\n        return h_t\n\n    def __call__(self, X, mask=None, init_state=None):\n        if mask is None:\n            mask = T.ones((X.shape[0], X.shape[1]))\n\n        mask = T.shape_padright(mask)  # (nb_samples, time, 1)\n        mask = T.addbroadcast(mask, -1)  # (time, nb_samples, 1) matrix.\n        mask = mask.dimshuffle(1, 0, 2)  # (time, nb_samples, 1)\n        mask = mask.astype('int8')\n        # mask, padded_mask = 
self.get_padded_shuffled_mask(mask, pad=1)\n        X = X.dimshuffle((1, 0, 2))\n\n        x_z = T.dot(X, self.W_z) + self.b_z\n        x_r = T.dot(X, self.W_r) + self.b_r\n        x_h = T.dot(X, self.W_h) + self.b_h\n\n        if init_state:\n            # (batch_size, output_dim)\n            outputs_info = T.unbroadcast(init_state, 1)\n        else:\n            outputs_info = T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1)\n\n        outputs, updates = theano.scan(\n            self._step,\n            # sequences=[x_z, x_r, x_h, padded_mask, mask],\n            sequences=[x_z, x_r, x_h, mask],\n            outputs_info=outputs_info,\n            non_sequences=[self.U_z, self.U_r, self.U_h])\n\n        if self.return_sequences:\n            return outputs.dimshuffle((1, 0, 2))\n        return outputs[-1]\n\n    def get_padded_shuffled_mask(self, mask, pad=0):\n        assert mask, 'mask cannot be None'\n        # mask is (nb_samples, time)\n        mask = T.shape_padright(mask)  # (nb_samples, time, 1)\n        mask = T.addbroadcast(mask, -1)  # (time, nb_samples, 1) matrix.\n        mask = mask.dimshuffle(1, 0, 2)  # (time, nb_samples, 1)\n\n        if pad > 0:\n            # left-pad in time with 0\n            padding = alloc_zeros_matrix(pad, mask.shape[1], 1)\n            padded_mask = T.concatenate([padding, mask], axis=0)\n        return mask.astype('int8'), padded_mask.astype('int8')\n\n\nclass LSTM(Layer):\n    def __init__(self, input_dim, output_dim,\n                 init='glorot_uniform', inner_init='orthogonal', forget_bias_init='one',\n                 activation='tanh', inner_activation='sigmoid', return_sequences=False, name='LSTM'):\n\n        super(LSTM, self).__init__()\n\n        self.output_dim = output_dim\n        self.init = initializations.get(init)\n        self.inner_init = initializations.get(inner_init)\n        self.forget_bias_init = initializations.get(forget_bias_init)\n        self.activation = 
activations.get(activation)\n        self.inner_activation = activations.get(inner_activation)\n        self.return_sequences = return_sequences\n\n        self.input_dim = input_dim\n\n        self.W_i = self.init((input_dim, self.output_dim))\n        self.U_i = self.inner_init((self.output_dim, self.output_dim))\n        self.b_i = shared_zeros((self.output_dim))\n\n        self.W_f = self.init((input_dim, self.output_dim))\n        self.U_f = self.inner_init((self.output_dim, self.output_dim))\n        self.b_f = self.forget_bias_init((self.output_dim))\n\n        self.W_c = self.init((input_dim, self.output_dim))\n        self.U_c = self.inner_init((self.output_dim, self.output_dim))\n        self.b_c = shared_zeros((self.output_dim))\n\n        self.W_o = self.init((input_dim, self.output_dim))\n        self.U_o = self.inner_init((self.output_dim, self.output_dim))\n        self.b_o = shared_zeros((self.output_dim))\n\n        self.params = [\n            self.W_i, self.U_i, self.b_i,\n            self.W_c, self.U_c, self.b_c,\n            self.W_f, self.U_f, self.b_f,\n            self.W_o, self.U_o, self.b_o,\n        ]\n\n        self.set_name(name)\n\n    def _step(self,\n              xi_t, xf_t, xo_t, xc_t, mask_t,\n              h_tm1, c_tm1,\n              u_i, u_f, u_o, u_c, b_u):\n\n        i_t = self.inner_activation(xi_t + T.dot(h_tm1 * b_u[0], u_i))\n        f_t = self.inner_activation(xf_t + T.dot(h_tm1 * b_u[1], u_f))\n        c_t = f_t * c_tm1 + i_t * self.activation(xc_t + T.dot(h_tm1 * b_u[2], u_c))\n        o_t = self.inner_activation(xo_t + T.dot(h_tm1 * b_u[3], u_o))\n        h_t = o_t * self.activation(c_t)\n\n        h_t = (1 - mask_t) * h_tm1 + mask_t * h_t\n        c_t = (1 - mask_t) * c_tm1 + mask_t * c_t\n\n        return h_t, c_t\n\n    def __call__(self, X, mask=None, init_state=None, dropout=0, train=True, srng=None):\n        mask = self.get_mask(mask, X)\n        X = X.dimshuffle((1, 0, 2))\n\n        retain_prob = 1. 
- dropout\n        B_w = np.ones((4,), dtype=theano.config.floatX)\n        B_u = np.ones((4,), dtype=theano.config.floatX)\n        if dropout > 0:\n            logging.info('applying dropout with p = %f', dropout)\n            if train:\n                B_w = srng.binomial((4, X.shape[1], self.input_dim), p=retain_prob,\n                    dtype=theano.config.floatX)\n                B_u = srng.binomial((4, X.shape[1], self.output_dim), p=retain_prob,\n                    dtype=theano.config.floatX)\n            else:\n                B_w *= retain_prob\n                B_u *= retain_prob\n\n        xi = T.dot(X * B_w[0], self.W_i) + self.b_i\n        xf = T.dot(X * B_w[1], self.W_f) + self.b_f\n        xc = T.dot(X * B_w[2], self.W_c) + self.b_c\n        xo = T.dot(X * B_w[3], self.W_o) + self.b_o\n\n        if init_state:\n            # (batch_size, output_dim)\n            first_state = T.unbroadcast(init_state, 1)\n        else:\n            first_state = T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1)\n\n        [outputs, memories], updates = theano.scan(\n            self._step,\n            sequences=[xi, xf, xo, xc, mask],\n            outputs_info=[\n                first_state,\n                T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1)\n            ],\n            non_sequences=[self.U_i, self.U_f, self.U_o, self.U_c, B_u])\n\n        if self.return_sequences:\n            return outputs.dimshuffle((1, 0, 2))\n        return outputs[-1]\n\n    def get_mask(self, mask, X):\n        if mask is None:\n            mask = T.ones((X.shape[0], X.shape[1]))\n\n        mask = T.shape_padright(mask)  # (nb_samples, time, 1)\n        mask = T.addbroadcast(mask, -1)  # (time, nb_samples, 1) matrix.\n        mask = mask.dimshuffle(1, 0, 2)  # (time, nb_samples, 1)\n        mask = mask.astype('int8')\n\n        return mask\n\n\nclass BiLSTM(Layer):\n    def __init__(self, input_dim, output_dim,\n                 
init='glorot_uniform', inner_init='orthogonal', forget_bias_init='one',\n                 activation='tanh', inner_activation='sigmoid', return_sequences=False, name='BiLSTM'):\n        super(BiLSTM, self).__init__()\n\n        self.input_dim = input_dim\n        self.output_dim = output_dim\n        self.return_sequences = return_sequences\n\n        params = dict(locals())\n        del params['self']\n\n        params['name'] = 'foward_lstm'\n        self.forward_lstm = LSTM(**params)\n        params['name'] = 'backward_lstm'\n        self.backward_lstm = LSTM(**params)\n\n        self.params = self.forward_lstm.params + self.backward_lstm.params\n\n        self.set_name(name)\n\n    def __call__(self, X, mask=None, init_state=None, dropout=0, train=True, srng=None):\n        # X: (nb_samples, nb_time_steps, embed_dim)\n        # mask: (nb_samples, nb_time_steps)\n        if mask is None:\n            mask = T.ones((X.shape[0], X.shape[1]))\n\n        hidden_states_forward = self.forward_lstm(X, mask, init_state, dropout, train, srng)\n        hidden_states_backward = self.backward_lstm(X[:, ::-1, :], mask[:, ::-1], init_state, dropout, train, srng)\n\n        if self.return_sequences:\n            hidden_states = T.concatenate([hidden_states_forward, hidden_states_backward[:, ::-1, :]], axis=-1)\n        else:\n            raise NotImplementedError()\n\n        return hidden_states\n\n\nclass CondAttLSTM(Layer):\n    \"\"\"\n    Conditional LSTM with Attention\n    \"\"\"\n    def __init__(self, input_dim, output_dim,\n                 context_dim, att_hidden_dim,\n                 init='glorot_uniform', inner_init='orthogonal', forget_bias_init='one',\n                 activation='tanh', inner_activation='sigmoid', name='CondAttLSTM'):\n\n        super(CondAttLSTM, self).__init__()\n\n        self.output_dim = output_dim\n        self.init = initializations.get(init)\n        self.inner_init = initializations.get(inner_init)\n        self.forget_bias_init = 
initializations.get(forget_bias_init)\n        self.activation = activations.get(activation)\n        self.inner_activation = activations.get(inner_activation)\n        self.context_dim = context_dim\n        self.input_dim = input_dim\n\n        # regular LSTM layer\n\n        self.W_i = self.init((input_dim, self.output_dim))\n        self.U_i = self.inner_init((self.output_dim, self.output_dim))\n        self.C_i = self.inner_init((self.context_dim, self.output_dim))\n        self.b_i = shared_zeros((self.output_dim))\n\n        self.W_f = self.init((input_dim, self.output_dim))\n        self.U_f = self.inner_init((self.output_dim, self.output_dim))\n        self.C_f = self.inner_init((self.context_dim, self.output_dim))\n        self.b_f = self.forget_bias_init((self.output_dim))\n\n        self.W_c = self.init((input_dim, self.output_dim))\n        self.U_c = self.inner_init((self.output_dim, self.output_dim))\n        self.C_c = self.inner_init((self.context_dim, self.output_dim))\n        self.b_c = shared_zeros((self.output_dim))\n\n        self.W_o = self.init((input_dim, self.output_dim))\n        self.U_o = self.inner_init((self.output_dim, self.output_dim))\n        self.C_o = self.inner_init((self.context_dim, self.output_dim))\n        self.b_o = shared_zeros((self.output_dim))\n\n        self.params = [\n            self.W_i, self.U_i, self.b_i, self.C_i,\n            self.W_c, self.U_c, self.b_c, self.C_c,\n            self.W_f, self.U_f, self.b_f, self.C_f,\n            self.W_o, self.U_o, self.b_o, self.C_o,\n        ]\n\n        # attention layer\n        self.att_ctx_W1 = self.init((context_dim, att_hidden_dim))\n        self.att_h_W1 = self.init((output_dim, att_hidden_dim))\n        self.att_b1 = shared_zeros((att_hidden_dim))\n\n        self.att_W2 = self.init((att_hidden_dim, 1))\n        self.att_b2 = shared_zeros((1))\n\n        self.params += [\n            self.att_ctx_W1, self.att_h_W1, self.att_b1,\n            self.att_W2, 
self.att_b2\n        ]\n\n        self.set_name(name)\n\n    def _step(self,\n              xi_t, xf_t, xo_t, xc_t, mask_t,\n              h_tm1, c_tm1, ctx_vec_tm1,\n              u_i, u_f, u_o, u_c, c_i, c_f, c_o, c_c,\n              att_h_w1, att_w2, att_b2,\n              context, context_mask, context_att_trans,\n              b_u):\n\n        # context: (batch_size, context_size, context_dim)\n\n        # (batch_size, att_layer1_dim)\n        h_tm1_att_trans = T.dot(h_tm1, att_h_w1)\n\n        # h_tm1_att_trans = theano.printing.Print('h_tm1_att_trans')(h_tm1_att_trans)\n\n        # (batch_size, context_size, att_layer1_dim)\n        att_hidden = T.tanh(context_att_trans + h_tm1_att_trans[:, None, :])\n        # (batch_size, context_size, 1)\n        att_raw = T.dot(att_hidden, att_w2) + att_b2\n\n        # (batch_size, context_size)\n        ctx_att = T.exp(att_raw).reshape((att_raw.shape[0], att_raw.shape[1]))\n\n        if context_mask:\n            ctx_att = ctx_att * context_mask\n\n        ctx_att = ctx_att / T.sum(ctx_att, axis=-1, keepdims=True)\n        # (batch_size, context_dim)\n        ctx_vec = T.sum(context * ctx_att[:, :, None], axis=1)\n\n        i_t = self.inner_activation(xi_t + T.dot(h_tm1 * b_u[0], u_i) + T.dot(ctx_vec, c_i))\n        f_t = self.inner_activation(xf_t + T.dot(h_tm1 * b_u[1], u_f) + T.dot(ctx_vec, c_f))\n        c_t = f_t * c_tm1 + i_t * self.activation(xc_t + T.dot(h_tm1 * b_u[2], u_c) + T.dot(ctx_vec, c_c))\n        o_t = self.inner_activation(xo_t + T.dot(h_tm1 * b_u[3], u_o) + T.dot(ctx_vec, c_o))\n        h_t = o_t * self.activation(c_t)\n\n        h_t = (1 - mask_t) * h_tm1 + mask_t * h_t\n        c_t = (1 - mask_t) * c_tm1 + mask_t * c_t\n\n        # ctx_vec = theano.printing.Print('ctx_vec')(ctx_vec)\n\n        return h_t, c_t, ctx_vec\n\n    def __call__(self, X, context, init_state=None, init_cell=None, mask=None, context_mask=None,\n                 dropout=0, train=True, srng=None):\n        assert 
context_mask.dtype == 'int8', 'context_mask is not int8, got %s' % context_mask.dtype\n\n        mask = self.get_mask(mask, X)\n        X = X.dimshuffle((1, 0, 2))\n\n        retain_prob = 1. - dropout\n        B_w = np.ones((4,), dtype=theano.config.floatX)\n        B_u = np.ones((4,), dtype=theano.config.floatX)\n        if dropout > 0:\n            logging.info('applying dropout with p = %f', dropout)\n            if train:\n                B_w = srng.binomial((4, X.shape[1], self.input_dim), p=retain_prob,\n                                    dtype=theano.config.floatX)\n                B_u = srng.binomial((4, X.shape[1], self.output_dim), p=retain_prob,\n                                    dtype=theano.config.floatX)\n            else:\n                B_w *= retain_prob\n                B_u *= retain_prob\n\n        xi = T.dot(X * B_w[0], self.W_i) + self.b_i\n        xf = T.dot(X * B_w[1], self.W_f) + self.b_f\n        xc = T.dot(X * B_w[2], self.W_c) + self.b_c\n        xo = T.dot(X * B_w[3], self.W_o) + self.b_o\n\n        # (batch_size, context_size, att_layer1_dim)\n        context_att_trans = T.dot(context, self.att_ctx_W1) + self.att_b1\n\n        if init_state:\n            # (batch_size, output_dim)\n            first_state = T.unbroadcast(init_state, 1)\n        else:\n            first_state = T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1)\n\n        if init_cell:\n            # (batch_size, output_dim)\n            first_cell = T.unbroadcast(init_cell, 1)\n        else:\n            first_cell = T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1)\n\n        [outputs, cells, ctx_vectors], updates = theano.scan(\n            self._step,\n            sequences=[xi, xf, xo, xc, mask],\n            outputs_info=[\n                first_state,  # for h\n                first_cell,  # for cell   T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1)\n                T.unbroadcast(alloc_zeros_matrix(X.shape[1], 
self.context_dim), 1)  # for ctx vector\n            ],\n            non_sequences=[\n                self.U_i, self.U_f, self.U_o, self.U_c,\n                self.C_i, self.C_f, self.C_o, self.C_c,\n                self.att_h_W1, self.att_W2, self.att_b2,\n                context, context_mask, context_att_trans,\n                B_u\n            ])\n\n        outputs = outputs.dimshuffle((1, 0, 2))\n        ctx_vectors = ctx_vectors.dimshuffle((1, 0, 2))\n        cells = cells.dimshuffle((1, 0, 2))\n\n        return outputs, cells, ctx_vectors\n        # return outputs[-1]\n\n    def get_mask(self, mask, X):\n        if mask is None:\n            mask = T.ones((X.shape[0], X.shape[1]))\n\n        mask = T.shape_padright(mask)  # (nb_samples, time, 1)\n        mask = T.addbroadcast(mask, -1)  # (time, nb_samples, 1) matrix.\n        mask = mask.dimshuffle(1, 0, 2)  # (time, nb_samples, 1)\n        mask = mask.astype('int8')\n\n        return mask\n\n\nclass GRUDecoder(Layer):\n    '''\n        GRU Decoder\n    '''\n    def __init__(self, input_dim, context_dim, hidden_dim, vocab_num,\n                 init='glorot_uniform', inner_init='orthogonal',\n                 activation='tanh', inner_activation='sigmoid',\n                 name='GRUDecoder'):\n\n        super(GRUDecoder, self).__init__()\n        self.input_dim = input_dim\n        self.context_dim = context_dim\n        self.hidden_dim = hidden_dim\n        self.vocab_num = vocab_num\n\n        self.init = initializations.get(init)\n        self.inner_init = initializations.get(inner_init)\n        self.activation = activations.get(activation)\n        self.inner_activation = activations.get(inner_activation)\n\n        self.W_z = self.init((self.input_dim, self.hidden_dim))\n        self.U_z = self.inner_init((self.hidden_dim, self.hidden_dim))\n        self.C_z = self.init((self.context_dim, self.hidden_dim))\n        self.b_z = shared_zeros((self.hidden_dim))\n\n        self.W_r = 
self.init((self.input_dim, self.hidden_dim))\n        self.U_r = self.inner_init((self.hidden_dim, self.hidden_dim))\n        self.C_r = self.init((self.context_dim, self.hidden_dim))\n        self.b_r = shared_zeros((self.hidden_dim))\n\n        self.W_h = self.init((self.input_dim, self.hidden_dim))\n        self.U_h = self.inner_init((self.hidden_dim, self.hidden_dim))\n        self.C_h = self.init((self.context_dim, self.hidden_dim))\n        self.b_h = shared_zeros((self.hidden_dim))\n\n        # self.W_y = self.init((self.input_dim, self.vocab_num))\n        self.U_y = self.init((self.hidden_dim, self.vocab_num))\n        self.C_y = self.init((self.context_dim, self.vocab_num))\n        self.b_y = shared_zeros((self.vocab_num))\n\n        self.params = [\n            self.W_z, self.U_z, self.b_z,\n            self.W_r, self.U_r, self.b_r,\n            self.W_h, self.U_h, self.b_h,\n            self.C_z, self.C_r, self.C_h,\n            self.U_y, self.C_y, self.b_y, #self.W_y\n        ]\n\n        if name is not None:\n            self.set_name(name)\n\n    def _step(self,\n              xz_t, xr_t, xh_t, mask_tm1,\n              h_tm1,\n              u_z, u_r, u_h):\n        h_mask_tm1 = mask_tm1 * h_tm1\n        z = self.inner_activation(xz_t + T.dot(h_mask_tm1, u_z))\n        r = self.inner_activation(xr_t + T.dot(h_mask_tm1, u_r))\n        hh_t = self.activation(xh_t + T.dot(r * h_mask_tm1, u_h))\n        h_t = z * h_mask_tm1 + (1 - z) * hh_t\n        return h_t\n\n    def __call__(self, target, context, mask=None):\n        target = target * T.cast(T.shape_padright(mask), 'float32')\n        padded_mask = self.get_padded_shuffled_mask(mask, pad=1)\n        # target = theano.printing.Print('X::' + self.name)(target)\n        X_shifted = T.concatenate([alloc_zeros_matrix(target.shape[0], 1, self.input_dim), target[:, 0:-1, :]], axis=-2)\n\n        # X = theano.printing.Print('X::' + self.name)(X)\n        # X = T.zeros_like(target)\n        # 
T.set_subtensor(X[:, 1:, :], target[:, 0:-1, :])\n\n        X = X_shifted.dimshuffle((1, 0, 2))\n\n        ctx_step = context.dimshuffle(('x', 0, 1))\n        x_z = T.dot(X, self.W_z) + T.dot(ctx_step, self.C_z) + self.b_z\n        x_r = T.dot(X, self.W_r) + T.dot(ctx_step, self.C_r) + self.b_r\n        x_h = T.dot(X, self.W_h) + T.dot(ctx_step, self.C_h) + self.b_h\n\n        h, updates = theano.scan(\n            self._step,\n            sequences=[x_z, x_r, x_h, padded_mask],\n            outputs_info=T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.hidden_dim), 1),\n            non_sequences=[self.U_z, self.U_r, self.U_h])\n\n        # (batch_size, max_token_len, hidden_dim)\n        h = h.dimshuffle((1, 0, 2))\n\n        # (batch_size, max_token_len, vocab_size)\n        predicts = T.dot(h, self.U_y) + T.dot(context.dimshuffle((0, 'x', 1)), self.C_y) + self.b_y # + T.dot(X_shifted, self.W_y)\n\n        predicts_flatten = predicts.reshape((-1, predicts.shape[2]))\n        return T.nnet.softmax(predicts_flatten).reshape((predicts.shape[0], predicts.shape[1], predicts.shape[2]))\n\n    def get_padded_shuffled_mask(self, mask, pad=0):\n        assert mask, 'mask cannot be None'\n        # mask is (nb_samples, time)\n        mask = T.shape_padright(mask)  # (nb_samples, time, 1)\n        mask = T.addbroadcast(mask, -1)  # still (nb_samples, time, 1); last axis made broadcastable\n        mask = mask.dimshuffle(1, 0, 2)  # (time, nb_samples, 1)\n\n        if pad > 0:\n            # left-pad in time with 0\n            padding = alloc_zeros_matrix(pad, mask.shape[1], 1)\n            mask = T.concatenate([padding, mask], axis=0)\n        return mask.astype('int8')\n"
  },
  {
    "path": "nn/objectives.py",
    "content": "from __future__ import absolute_import\nimport theano\nimport theano.tensor as T\nimport numpy as np\nfrom six.moves import range\n\nif theano.config.floatX == 'float64':\n    epsilon = 1.0e-9\nelse:\n    epsilon = 1.0e-7\n\n\ndef mean_squared_error(y_true, y_pred):\n    return T.sqr(y_pred - y_true).mean(axis=-1)\n\n\ndef mean_absolute_error(y_true, y_pred):\n    return T.abs_(y_pred - y_true).mean(axis=-1)\n\n\ndef mean_absolute_percentage_error(y_true, y_pred):\n    return T.abs_((y_true - y_pred) / T.clip(T.abs_(y_true), epsilon, np.inf)).mean(axis=-1) * 100.\n\n\ndef mean_squared_logarithmic_error(y_true, y_pred):\n    return T.sqr(T.log(T.clip(y_pred, epsilon, np.inf) + 1.) - T.log(T.clip(y_true, epsilon, np.inf) + 1.)).mean(axis=-1)\n\n\ndef squared_hinge(y_true, y_pred):\n    return T.sqr(T.maximum(1. - y_true * y_pred, 0.)).mean(axis=-1)\n\n\ndef hinge(y_true, y_pred):\n    return T.maximum(1. - y_true * y_pred, 0.).mean(axis=-1)\n\n\ndef categorical_crossentropy(y_true, y_pred):\n    '''Expects a binary class matrix instead of a vector of scalar classes\n    '''\n    y_pred = T.clip(y_pred, epsilon, 1.0 - epsilon)\n    # scale preds so that the class probas of each sample sum to 1\n    y_pred /= y_pred.sum(axis=-1, keepdims=True)\n    cce = T.nnet.categorical_crossentropy(y_pred, y_true)\n    return cce\n\n\ndef binary_crossentropy(y_true, y_pred):\n    y_pred = T.clip(y_pred, epsilon, 1.0 - epsilon)\n    bce = T.nnet.binary_crossentropy(y_pred, y_true).mean(axis=-1)\n    return bce\n\n\ndef poisson_loss(y_true, y_pred):\n    return T.mean(y_pred - y_true * T.log(y_pred + epsilon), axis=-1)\n\n# aliases\nmse = MSE = mean_squared_error\nmae = MAE = mean_absolute_error\nmape = MAPE = mean_absolute_percentage_error\nmsle = MSLE = mean_squared_logarithmic_error\n\nfrom .utils.generic_utils import get_from_module\ndef get(identifier):\n    return get_from_module(identifier, globals(), 'objective')\n"
  },
  {
    "path": "nn/optimizers.py",
    "content": "from __future__ import absolute_import\nimport theano\nimport theano.tensor as T\n\nfrom .utils.theano_utils import shared_zeros, shared_scalar, floatX\nfrom .utils.generic_utils import get_from_module\nfrom six.moves import zip\nfrom theano.sandbox.rng_mrg import MRG_RandomStreams\nfrom theano.tensor.shared_randomstreams import RandomStreams\nimport math\nfrom nn.utils.config_factory import config\n\n\ndef clip_norm(g, c, n):\n    if c > 0:\n        g = T.switch(T.ge(n, c), g * c / n, g)\n    return g\n\n\ndef kl_divergence(p, p_hat):\n    return p_hat - p + p * T.log(p / p_hat)\n\n\nclass Optimizer(object):\n    def __init__(self, **kwargs):\n        self.__dict__.update(kwargs)\n        self.updates = []\n\n    def get_state(self):\n        return [u[0].get_value() for u in self.updates]\n\n    def set_state(self, value_list):\n        assert len(self.updates) == len(value_list)\n        for u, v in zip(self.updates, value_list):\n            u[0].set_value(floatX(v))\n\n    def get_updates(self, params, constraints, loss, **kwargs):\n        raise NotImplementedError\n\n    def get_gradients(self, loss, params, **kwargs):\n\n        grads = T.grad(loss, params, disconnected_inputs='warn', **kwargs)\n\n        if hasattr(self, 'clip_grad') and self.clip_grad > 0:\n            norm = T.sqrt(sum([T.sum(g ** 2) for g in grads]))\n            # norm = theano.printing.Print('gradient norm::')(norm)\n            grads = [clip_norm(g, self.clip_grad, norm) for g in grads]\n\n        return grads\n\n    def get_config(self):\n        return {\"name\": self.__class__.__name__}\n\n\nclass SGD(Optimizer):\n\n    def __init__(self, lr=0.01, momentum=0., decay=0., nesterov=False, *args, **kwargs):\n        super(SGD, self).__init__(**kwargs)\n        self.__dict__.update(locals())\n        self.iterations = shared_scalar(0)\n        self.lr = shared_scalar(lr)\n        self.momentum = shared_scalar(momentum)\n\n    def get_updates(self, params, loss):\n       
 grads = self.get_gradients(loss, params)\n        lr = self.lr * (1.0 / (1.0 + self.decay * self.iterations))\n        self.updates = [(self.iterations, self.iterations + 1.)]\n\n        for p, g in zip(params, grads):\n            m = shared_zeros(p.get_value().shape)  # momentum\n            v = self.momentum * m - lr * g  # velocity\n            self.updates.append((m, v))\n\n            if self.nesterov:\n                new_p = p + self.momentum * v - lr * g\n            else:\n                new_p = p + v\n\n            self.updates.append((p, new_p))\n        return self.updates\n\n    def get_config(self):\n        return {\"name\": self.__class__.__name__,\n                \"lr\": float(self.lr.get_value()),\n                \"momentum\": float(self.momentum.get_value()),\n                \"decay\": float(self.decay.get_value()),\n                \"nesterov\": self.nesterov}\n\n\nclass RMSprop(Optimizer):\n    def __init__(self, lr=0.001, rho=0.9, epsilon=1e-6, *args, **kwargs):\n        super(RMSprop, self).__init__(**kwargs)\n        self.__dict__.update(locals())\n        self.lr = shared_scalar(lr)\n        self.rho = shared_scalar(rho)\n\n    def get_updates(self, params, constraints, loss):\n        grads = self.get_gradients(loss, params)\n        accumulators = [shared_zeros(p.get_value().shape) for p in params]\n        self.updates = []\n\n        for p, g, a, c in zip(params, grads, accumulators, constraints):\n            new_a = self.rho * a + (1 - self.rho) * g ** 2  # update accumulator\n            self.updates.append((a, new_a))\n\n            new_p = p - self.lr * g / T.sqrt(new_a + self.epsilon)\n            self.updates.append((p, c(new_p)))  # apply constraints\n        return self.updates\n\n    def get_config(self):\n        return {\"name\": self.__class__.__name__,\n                \"lr\": float(self.lr.get_value()),\n                \"rho\": float(self.rho.get_value()),\n                \"epsilon\": self.epsilon}\n\n\nclass 
Adagrad(Optimizer):\n    def __init__(self, lr=0.01, epsilon=1e-6, *args, **kwargs):\n        super(Adagrad, self).__init__(**kwargs)\n        self.__dict__.update(locals())\n        self.lr = shared_scalar(lr)\n\n    def get_updates(self, params, constraints, loss):\n        grads = self.get_gradients(loss, params)\n        accumulators = [shared_zeros(p.get_value().shape) for p in params]\n        self.updates = []\n\n        for p, g, a, c in zip(params, grads, accumulators, constraints):\n            new_a = a + g ** 2  # update accumulator\n            self.updates.append((a, new_a))\n            new_p = p - self.lr * g / T.sqrt(new_a + self.epsilon)\n            self.updates.append((p, c(new_p)))  # apply constraints\n        return self.updates\n\n    def get_config(self):\n        return {\"name\": self.__class__.__name__,\n                \"lr\": float(self.lr.get_value()),\n                \"epsilon\": self.epsilon}\n\n\nclass Adadelta(Optimizer):\n    '''\n        Reference: http://arxiv.org/abs/1212.5701\n    '''\n    def __init__(self, lr=1.0, rho=0.95, epsilon=1e-6, *args, **kwargs):\n        super(Adadelta, self).__init__(**kwargs)\n        self.__dict__.update(locals())\n        self.lr = shared_scalar(lr)\n\n    def get_updates(self, params, loss):\n        grads = self.get_gradients(loss, params)\n        accumulators = [shared_zeros(p.get_value().shape) for p in params]\n        delta_accumulators = [shared_zeros(p.get_value().shape) for p in params]\n        self.updates = []\n\n        for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators):\n            new_a = self.rho * a + (1 - self.rho) * g ** 2  # update accumulator\n            self.updates.append((a, new_a))\n\n            # use the new accumulator and the *old* delta_accumulator\n            update = g * T.sqrt(d_a + self.epsilon) / T.sqrt(new_a +\n                                                             self.epsilon)\n\n            new_p = p - self.lr * update\n   
         self.updates.append((p, new_p))\n\n            # update delta_accumulator\n            new_d_a = self.rho * d_a + (1 - self.rho) * update ** 2\n            self.updates.append((d_a, new_d_a))\n        return self.updates, grads\n\n    def get_config(self):\n        return {\"name\": self.__class__.__name__,\n                \"lr\": float(self.lr.get_value()),\n                \"rho\": self.rho,\n                \"epsilon\": self.epsilon}\n\n\nclass Adadelta_GaussianNoise(Optimizer):\n    '''\n        Reference: http://arxiv.org/abs/1212.5701\n    '''\n    def __init__(self, lr=1.0, rho=0.95, epsilon=1e-6, *args, **kwargs):\n        super(Adadelta_GaussianNoise, self).__init__(**kwargs)\n        self.__dict__.update(locals())\n        self.lr = shared_scalar(lr)\n        self.rng = MRG_RandomStreams(use_cuda=config.get('run.use_cuda')) #RandomStreams() #(use_cuda=False)\n\n    def get_updates(self, params, loss):\n        grads = self.get_gradients(loss, params)\n        accumulators = [shared_zeros(p.get_value().shape) for p in params]\n        delta_accumulators = [shared_zeros(p.get_value().shape) for p in params]\n        self.updates = []\n        n_step = theano.shared(1.0)\n        self.updates.append((n_step, n_step + 1))\n\n        for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators):\n            g_noise = self.rng.normal(p.shape, 0, T.sqrt(n_step ** - 0.55), dtype='float32')\n            g_deviated = g + g_noise\n\n            new_a = self.rho * a + (1 - self.rho) * g_deviated ** 2  # update accumulator\n            self.updates.append((a, new_a))\n\n            # use the new accumulator and the *old* delta_accumulator\n            update = g_deviated * T.sqrt(d_a + self.epsilon) / T.sqrt(new_a +\n                                                             self.epsilon)\n\n            new_p = p - self.lr * update\n            self.updates.append((p, new_p))\n\n            # update delta_accumulator\n            new_d_a = 
self.rho * d_a + (1 - self.rho) * update ** 2\n            self.updates.append((d_a, new_d_a))\n        return self.updates\n\n    def get_config(self):\n        return {\"name\": self.__class__.__name__,\n                \"lr\": float(self.lr.get_value()),\n                \"rho\": self.rho,\n                \"epsilon\": self.epsilon}\n\n\nclass Adam(Optimizer):\n    '''\n        Reference: http://arxiv.org/abs/1412.6980v8\n\n        Default parameters follow those provided in the original paper.\n    '''\n    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8, *args, **kwargs):\n        super(Adam, self).__init__(**kwargs)\n        self.__dict__.update(locals())\n        self.iterations = shared_scalar(0)\n        self.lr = shared_scalar(lr)\n        # self.rng = MRG_RandomStreams(use_cuda=config['use_gpu']) #RandomStreams() #(use_cuda=False)\n\n    def get_updates(self, params, loss, **kwargs):\n        grads = self.get_gradients(loss, params, **kwargs)\n        self.updates = [(self.iterations, self.iterations+1.)]\n\n        t = self.iterations + 1\n        lr_t = self.lr * T.sqrt(1-self.beta_2**t)/(1-self.beta_1**t)\n\n        # n_step = theano.shared(1.0)\n        # self.updates.append((n_step, n_step + 1))\n\n        gradients = []\n\n        for p, g in zip(params, grads):\n            m = theano.shared(p.get_value() * 0.)  # zero init of moment\n            v = theano.shared(p.get_value() * 0.)  
# zero init of velocity\n\n            # g_noise = self.rng.normal(g.shape, 0, T.sqrt(0.5 * n_step ** - 0.55), dtype='float32')\n            # g_deviated = g + g_noise\n            g_deviated = g\n\n            # for debug purposes\n            gradients.append(g)\n\n            m_t = (self.beta_1 * m) + (1 - self.beta_1) * g_deviated\n            v_t = (self.beta_2 * v) + (1 - self.beta_2) * (g_deviated**2)\n            p_t = p - lr_t * m_t / (T.sqrt(v_t) + self.epsilon)\n\n            self.updates.append((m, m_t))\n            self.updates.append((v, v_t))\n            self.updates.append((p, p_t))  # apply constraints\n        return self.updates, gradients\n\n    def get_config(self):\n        return {\"name\": self.__class__.__name__,\n                \"lr\": float(self.lr.get_value()),\n                \"beta_1\": self.beta_1,\n                \"beta_2\": self.beta_2,\n                \"epsilon\": self.epsilon}\n\n# aliases\nsgd = SGD\nrmsprop = RMSprop\nadagrad = Adagrad\nadadelta = Adadelta\nadam = Adam\nadadelta_noise = Adadelta_GaussianNoise\n\n\ndef get(identifier, kwargs=None):\n    return get_from_module(identifier, globals(), 'optimizer', instantiate=True,\n                           kwargs=kwargs)\n"
  },
  {
    "path": "nn/utils/__init__.py",
    "content": "__author__ = 'yinpengcheng'\n"
  },
  {
    "path": "nn/utils/config_factory.py",
    "content": "import logging\n\n\nclass MetaConfig(type):\n    def __getitem__(self, key):\n        return config._config[key]\n\n    def __setitem__(self, key, value):\n        config._config[key] = value\n\n\nclass config(object):\n    _config = {}\n    __metaclass__ = MetaConfig\n\n    @staticmethod\n    def set(key, val):\n        config._config[key] = val\n\n    @staticmethod\n    def init_config(file='config.py'):\n        if len(config._config) > 0:\n            return\n\n        logging.info('use configuration: %s', file)\n        data = {}\n        execfile(file, data)\n        config._config = data['config']"
  },
  {
    "path": "nn/utils/generic_utils.py",
    "content": "from __future__ import absolute_import\nimport numpy as np\nimport time\nimport sys\nimport six\nimport logging\n\n\ndef get_from_module(identifier, module_params, module_name, instantiate=False, kwargs=None):\n    if isinstance(identifier, six.string_types):\n        res = module_params.get(identifier)\n        if not res:\n            raise Exception('Invalid ' + str(module_name) + ': ' + str(identifier))\n        if instantiate and not kwargs:\n            return res()\n        elif instantiate and kwargs:\n            return res(**kwargs)\n        else:\n            return res\n    return identifier\n\n\ndef make_tuple(*args):\n    return args\n\n\ndef printv(v, prefix=''):\n    if type(v) == dict:\n        if 'name' in v:\n            print(prefix + '#' + v['name'])\n            del v['name']\n        prefix += '...'\n        for nk, nv in v.items():\n            if type(nv) in [dict, list]:\n                print(prefix + nk + ':')\n                printv(nv, prefix)\n            else:\n                print(prefix + nk + ':' + str(nv))\n    elif type(v) == list:\n        prefix += '...'\n        for i, nv in enumerate(v):\n            print(prefix + '#' + str(i))\n            printv(nv, prefix)\n    else:\n        prefix += '...'\n        print(prefix + str(v))\n\n\ndef make_batches(size, batch_size):\n    nb_batch = int(np.ceil(size/float(batch_size)))\n    return [(i*batch_size, min(size, (i+1)*batch_size)) for i in range(0, nb_batch)]\n\n\ndef slice_X(X, start=None, stop=None):\n    if type(X) == list:\n        if hasattr(start, '__len__'):\n            return [x[start] for x in X]\n        else:\n            return [x[start:stop] for x in X]\n    else:\n        if hasattr(start, '__len__'):\n            return X[start]\n        else:\n            return X[start:stop]\n\n\ndef init_logging(file_name, level=logging.INFO):\n    formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(module)s: %(message)s', datefmt='%m/%d/%Y %H:%M:%S')\n 
   fh = logging.FileHandler(file_name)\n    ch = logging.StreamHandler()\n\n    fh.setFormatter(formatter)\n    ch.setFormatter(formatter)\n\n    logging.getLogger().handlers = []\n    logging.getLogger().addHandler(ch)\n    logging.getLogger().addHandler(fh)\n    logging.getLogger().setLevel(level)\n\n    logging.info('init logging file [%s]' % file_name)\n\n\ndef pad_sequences(sequences, maxlen=None, dtype='int32',\n                  padding='pre', truncating='pre', value=0.):\n    '''Pads each sequence to the same length:\n    the length of the longest sequence.\n\n    If maxlen is provided, any sequence longer\n    than maxlen is truncated to maxlen.\n    Truncation happens off either the beginning (default) or\n    the end of the sequence.\n\n    Supports post-padding and pre-padding (default).\n\n    # Arguments\n        sequences: list of lists where each element is a sequence\n        maxlen: int, maximum length\n        dtype: type to cast the resulting sequence.\n        padding: 'pre' or 'post', pad either before or after each sequence.\n        truncating: 'pre' or 'post', remove values from sequences larger than\n            maxlen either in the beginning or in the end of the sequence\n        value: float, value to pad the sequences to the desired value.\n\n    # Returns\n        x: numpy array with dimensions (number_of_sequences, maxlen)\n    '''\n    lengths = [len(s) for s in sequences]\n\n    nb_samples = len(sequences)\n    if maxlen is None:\n        maxlen = np.max(lengths)\n\n    # take the sample shape from the first non empty sequence\n    # checking for consistency in the main loop below.\n    sample_shape = tuple()\n    for s in sequences:\n        if len(s) > 0:\n            sample_shape = np.asarray(s).shape[1:]\n            break\n\n    x = (np.ones((nb_samples, maxlen) + sample_shape) * value).astype(dtype)\n    for idx, s in enumerate(sequences):\n        if len(s) == 0:\n            continue  # empty list was found\n        if 
truncating == 'pre':\n            trunc = s[-maxlen:]\n        elif truncating == 'post':\n            trunc = s[:maxlen]\n        else:\n            raise ValueError('Truncating type \"%s\" not understood' % truncating)\n\n        # check `trunc` has expected shape\n        trunc = np.asarray(trunc, dtype=dtype)\n        if trunc.shape[1:] != sample_shape:\n            raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' %\n                             (trunc.shape[1:], idx, sample_shape))\n\n        if padding == 'post':\n            x[idx, :len(trunc)] = trunc\n        elif padding == 'pre':\n            x[idx, -len(trunc):] = trunc\n        else:\n            raise ValueError('Padding type \"%s\" not understood' % padding)\n    return x\n\n\nclass Progbar(object):\n    def __init__(self, target, width=30, verbose=1):\n        '''\n            @param target: total number of steps expected\n        '''\n        self.width = width\n        self.target = target\n        self.sum_values = {}\n        self.unique_values = []\n        self.start = time.time()\n        self.total_width = 0\n        self.seen_so_far = 0\n        self.verbose = verbose\n\n    def update(self, current, values=[]):\n        '''\n            @param current: index of current step\n            @param values: list of tuples (name, value_for_last_step).\n            The progress bar will display averages for these values.\n        '''\n        for k, v in values:\n            if k not in self.sum_values:\n                self.sum_values[k] = [v * (current - self.seen_so_far), current - self.seen_so_far]\n                self.unique_values.append(k)\n            else:\n                self.sum_values[k][0] += v * (current - self.seen_so_far)\n                self.sum_values[k][1] += (current - self.seen_so_far)\n        self.seen_so_far = current\n\n        now = time.time()\n        if self.verbose == 1:\n            prev_total_width = 
self.total_width\n            sys.stdout.write(\"\\b\" * prev_total_width)\n            sys.stdout.write(\"\\r\")\n\n            numdigits = int(np.floor(np.log10(self.target))) + 1\n            barstr = '%%%dd/%%%dd [' % (numdigits, numdigits)\n            bar = barstr % (current, self.target)\n            prog = float(current)/self.target\n            prog_width = int(self.width*prog)\n            if prog_width > 0:\n                bar += ('='*(prog_width-1))\n                if current < self.target:\n                    bar += '>'\n                else:\n                    bar += '='\n            bar += ('.'*(self.width-prog_width))\n            bar += ']'\n            sys.stdout.write(bar)\n            self.total_width = len(bar)\n\n            if current:\n                time_per_unit = (now - self.start) / current\n            else:\n                time_per_unit = 0\n            eta = time_per_unit*(self.target - current)\n            info = ''\n            if current < self.target:\n                info += ' - ETA: %ds' % eta\n            else:\n                info += ' - %ds' % (now - self.start)\n            for k in self.unique_values:\n                info += ' - %s: %.4f' % (k, self.sum_values[k][0] / max(1, self.sum_values[k][1]))\n\n            self.total_width += len(info)\n            if prev_total_width > self.total_width:\n                info += ((prev_total_width-self.total_width) * \" \")\n\n            sys.stdout.write(info)\n            sys.stdout.flush()\n\n            if current >= self.target:\n                sys.stdout.write(\"\\n\")\n\n        if self.verbose == 2:\n            if current >= self.target:\n                info = '%ds' % (now - self.start)\n                for k in self.unique_values:\n                    info += ' - %s: %.4f' % (k, self.sum_values[k][0] / max(1, self.sum_values[k][1]))\n                sys.stdout.write(info + \"\\n\")\n\n    def add(self, n, values=[]):\n        self.update(self.seen_so_far+n, 
values)\n\n\n"
  },
  {
    "path": "nn/utils/io_utils.py",
    "content": "from __future__ import absolute_import\n\nimport cPickle\nimport h5py\nimport numpy as np\nfrom collections import defaultdict\n\n\nclass HDF5Matrix():\n    refs = defaultdict(int)\n\n    def __init__(self, datapath, dataset, start, end, normalizer=None):\n        if datapath not in list(self.refs.keys()):\n            f = h5py.File(datapath)\n            self.refs[datapath] = f\n        else:\n            f = self.refs[datapath]\n        self.start = start\n        self.end = end\n        self.data = f[dataset]\n        self.normalizer = normalizer\n\n    def __len__(self):\n        return self.end - self.start\n\n    def __getitem__(self, key):\n        if isinstance(key, slice):\n            if key.stop + self.start <= self.end:\n                idx = slice(key.start+self.start, key.stop + self.start)\n            else:\n                raise IndexError\n        elif isinstance(key, int):\n            if key + self.start < self.end:\n                idx = key+self.start\n            else:\n                raise IndexError\n        elif isinstance(key, np.ndarray):\n            if np.max(key) + self.start < self.end:\n                idx = (self.start + key).tolist()\n            else:\n                raise IndexError\n        elif isinstance(key, list):\n            if max(key) + self.start < self.end:\n                idx = [x + self.start for x in key]\n            else:\n                raise IndexError\n        if self.normalizer is not None:\n            return self.normalizer(self.data[idx])\n        else:\n            return self.data[idx]\n\n    @property\n    def shape(self):\n        return tuple([self.end - self.start, self.data.shape[1]])\n\n\ndef save_array(array, name):\n    import tables\n    f = tables.open_file(name, 'w')\n    atom = tables.Atom.from_dtype(array.dtype)\n    ds = f.createCArray(f.root, 'data', atom, array.shape)\n    ds[:] = array\n    f.close()\n\n\ndef load_array(name):\n    import tables\n    f = 
tables.open_file(name)\n    array = f.root.data\n    a = np.empty(shape=array.shape, dtype=array.dtype)\n    a[:] = array[:]\n    f.close()\n    return a\n\n\ndef serialize_to_file(obj, path, protocol=cPickle.HIGHEST_PROTOCOL):\n    f = open(path, 'wb')\n    cPickle.dump(obj, f, protocol=protocol)\n    f.close()\n\n\ndef deserialize_from_file(path):\n    f = open(path, 'rb')\n    obj = cPickle.load(f)\n    f.close()\n    return obj"
  },
  {
    "path": "nn/utils/np_utils.py",
    "content": "from __future__ import absolute_import\nimport numpy as np\nimport scipy as sp\nfrom six.moves import range\nfrom six.moves import zip\n\n\ndef to_categorical(y, nb_classes=None):\n    '''Convert class vector (integers from 0 to nb_classes)\n    to binary class matrix, for use with categorical_crossentropy\n    '''\n    y = np.asarray(y, dtype='int32')\n    if not nb_classes:\n        nb_classes = np.max(y)+1\n    Y = np.zeros((len(y), nb_classes))\n    for i in range(len(y)):\n        Y[i, y[i]] = 1.\n    return Y\n\n\ndef normalize(a, axis=-1, order=2):\n    l2 = np.atleast_1d(np.linalg.norm(a, order, axis))\n    l2[l2 == 0] = 1\n    return a / np.expand_dims(l2, axis)\n\n\ndef binary_logloss(p, y):\n    epsilon = 1e-15\n    p = sp.maximum(epsilon, p)\n    p = sp.minimum(1-epsilon, p)\n    res = sum(y * sp.log(p) + sp.subtract(1, y) * sp.log(sp.subtract(1, p)))\n    res *= -1.0/len(y)\n    return res\n\n\ndef multiclass_logloss(P, Y):\n    score = 0.\n    npreds = [P[i][Y[i]-1] for i in range(len(Y))]\n    score = -(1. / len(Y)) * np.sum(np.log(npreds))\n    return score\n\n\ndef accuracy(p, y):\n    return np.mean([a == b for a, b in zip(p, y)])\n\n\ndef probas_to_classes(y_pred):\n    if len(y_pred.shape) > 1 and y_pred.shape[1] > 1:\n        return categorical_probas_to_classes(y_pred)\n    return np.array([1 if p > 0.5 else 0 for p in y_pred])\n\n\ndef categorical_probas_to_classes(p):\n    return np.argmax(p, axis=1)\n"
  },
  {
    "path": "nn/utils/test_utils.py",
    "content": "import numpy as np\n\n\ndef get_test_data(nb_train=1000, nb_test=500, input_shape=(10,), output_shape=(2,),\n                  classification=True, nb_class=2):\n    '''\n        classification=True overrides output_shape\n        (i.e. output_shape is set to (1,)) and the output\n        consists in integers in [0, nb_class-1].\n\n        Otherwise: float output with shape output_shape.\n    '''\n    nb_sample = nb_train + nb_test\n    if classification:\n        y = np.random.randint(0, nb_class, size=(nb_sample, 1))\n        X = np.zeros((nb_sample,) + input_shape)\n        for i in range(nb_sample):\n            X[i] = np.random.normal(loc=y[i], scale=1.0, size=input_shape)\n    else:\n        y_loc = np.random.random((nb_sample,))\n        X = np.zeros((nb_sample,) + input_shape)\n        y = np.zeros((nb_sample,) + output_shape)\n        for i in range(nb_sample):\n            X[i] = np.random.normal(loc=y_loc[i], scale=1.0, size=input_shape)\n            y[i] = np.random.normal(loc=y_loc[i], scale=1.0, size=output_shape)\n\n    return (X[:nb_train], y[:nb_train]), (X[nb_train:], y[nb_train:])\n"
  },
  {
    "path": "nn/utils/theano_utils.py",
    "content": "from __future__ import absolute_import\nimport numpy as np\nimport theano\nimport theano.tensor as T\n\n\ndef floatX(X):\n    return np.asarray(X, dtype=theano.config.floatX)\n\n\ndef sharedX(X, dtype=theano.config.floatX, name=None):\n    return theano.shared(np.asarray(X, dtype=dtype), name=name)\n\n\ndef shared_zeros(shape, dtype=theano.config.floatX, name=None):\n    return sharedX(np.zeros(shape), dtype=dtype, name=name)\n\n\ndef shared_scalar(val=0., dtype=theano.config.floatX, name=None):\n    return theano.shared(np.cast[dtype](val))\n\n\ndef shared_ones(shape, dtype=theano.config.floatX, name=None):\n    return sharedX(np.ones(shape), dtype=dtype, name=name)\n\n\ndef alloc_zeros_matrix(*dims):\n    return T.alloc(np.cast[theano.config.floatX](0.), *dims)\n\n\ndef tensor_right_shift(tensor):\n    temp = T.zeros_like(tensor)\n    temp = T.set_subtensor(temp[:, 1:, :], tensor[:, :-1, :])\n\n    return temp\n\n\ndef ndim_tensor(ndim, name=None):\n    if ndim == 1:\n        return T.vector()\n    elif ndim == 2:\n        return T.matrix()\n    elif ndim == 3:\n        return T.tensor3()\n    elif ndim == 4:\n        return T.tensor4()\n    return T.matrix(name=name)\n\n\n# get int32 tensor\ndef ndim_itensor(ndim, name=None):\n    if ndim == 2:\n        return T.imatrix(name)\n    elif ndim == 3:\n        return T.itensor3(name)\n    elif ndim == 4:\n        return T.itensor4(name)\n    return T.imatrix(name=name)\n\n\n# get int8 tensor\ndef ndim_btensor(ndim, name=None):\n    if ndim == 2:\n        return T.bmatrix(name)\n    elif ndim == 3:\n        return T.btensor3(name)\n    elif ndim == 4:\n        return T.btensor4(name)\n    return T.imatrix(name)"
  },
  {
    "path": "parse.py",
    "content": "import ast\nimport re\nimport sys, inspect\nfrom StringIO import StringIO\n\nimport astor\nfrom collections import OrderedDict\nfrom tokenize import generate_tokens, tokenize\nimport token as tk\n\nfrom nn.utils.io_utils import deserialize_from_file, serialize_to_file\n\nfrom astnode import *\n\n\nif __name__ == '__main__':\n    #     node = ast.parse('''\n    # # for i in range(1, 100):\n    # #  sum = sum + i\n    # #\n    # # sorted(arr, reverse=True)\n    # # sorted(my_dict, key=lambda x: my_dict[x], reverse=True)\n    # # m = dict ( zip ( new_keys , keys ) )\n    # # for f in sorted ( os . listdir ( self . path ) ) :\n    # #     pass\n    # for f in sorted ( os . listdir ( self . path ) ) : pass\n    # ''')\n    # print ast.dump(node, annotate_fields=False)\n    # print get_tree_str_repr(node)\n    # print parse('sorted(my_dict, key=lambda x: my_dict[x], reverse=True)')\n    # print parse('global _standard_context_processors')\n\n    # parse_django('/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code')\n\n    # code = 'sum = True'\n    # print parse_tree\n    # ast_tree = tree_to_ast(parse_tree)\n    # # # # #\n    # import astor\n    # print astor.to_source(ast_tree)\n\n    from dataset import DataSet, Vocab, DataEntry, Action\n    # train_data, dev_data, test_data = deserialize_from_file('django.cleaned.dataset.bin')\n    # cand_list = deserialize_from_file('cand_hyps.18771.bin')\n    # hyp_tree = cand_list[3].tree\n    #\n    # ast_tree = decode_tree_to_ast(hyp_tree)\n    # print astor.to_source(ast_tree)\n\n    pass"
  },
  {
    "path": "parse_hiro.py",
    "content": "import ast\nimport sys\nimport re\nimport inspect\n\ndef typename(x):\n    return type(x).__name__\n\ndef escape(text):\n    text = text \\\n        .replace('\"', '`') \\\n        .replace('\\'', '`') \\\n        .replace(' ', '-SP-') \\\n        .replace('\\t', '-TAB-') \\\n        .replace('\\n', '-NL-') \\\n        .replace('(', '-LRB-') \\\n        .replace(')', '-RRB-') \\\n        .replace('|', '-BAR-')\n    return repr(text)[1:-1] if text else '-NONE-'\n\ndef makestr(node):\n\n    #if node is None or isinstance(node, ast.Pass):\n    #    return ''\n\n    if isinstance(node, ast.AST):\n        n = 0\n        nodename = typename(node)\n        s = '(' + nodename\n        for chname, chval in ast.iter_fields(node):\n            chstr = makestr(chval)\n            if chstr:\n                s += ' (' + chname + ' ' + chstr + ')'\n                n += 1\n        if not n:\n            s += ' -' + nodename + '-' # (Foo) -> (Foo -Foo-)\n        s += ')'\n        return s\n\n    elif isinstance(node, list):\n        n = 0\n        s = '(list'\n        for ch in node:\n            chstr = makestr(ch)\n            if chstr:\n                s += ' ' + chstr\n                n += 1\n        s += ')'\n        return s if n else ''\n\n    elif isinstance(node, str):\n        return '(str ' + escape(node) + ')'\n\n    elif isinstance(node, bytes):\n        return '(bytes ' + escape(str(node)) + ')'\n\n    else:\n        return '(' + typename(node) + ' ' + str(node) + ')'\n\n\ndef main():\n    p_elif = re.compile(r'^elif\\s?')\n    p_else = re.compile(r'^else\\s?')\n    p_try = re.compile(r'^try\\s?')\n    p_except = re.compile(r'^except\\s?')\n    p_finally = re.compile(r'^finally\\s?')\n    p_decorator = re.compile(r'^@.*')\n\n    for l in [\"\"\"val = Header ( val , encoding ) . encode ( )\"\"\"]:  # val = ', ' . 
join ( sanitize_address ( addr , encoding )  for addr in getaddresses ( ( val , ) ) )\n        l = l.strip()\n        if not l:\n            print()\n            sys.stdout.flush()\n            continue\n\n        if p_elif.match(l): l = 'if True: pass\\n' + l\n        if p_else.match(l): l = 'if True: pass\\n' + l\n\n        if p_try.match(l): l = l + 'pass\\nexcept: pass'\n        elif p_except.match(l): l = 'try: pass\\n' + l\n        elif p_finally.match(l): l = 'try: pass\\n' + l\n\n        if p_decorator.match(l): l = l + '\\ndef dummy(): pass'\n        if l[-1] == ':': l = l + 'pass'\n\n        parse = ast.parse(l)\n        parse = parse.body[0]\n        dump = makestr(parse)\n        print(dump)\n        sys.stdout.flush()\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "run_interactive.sh",
    "content": "output=\"runs\"\ndevice=\"cpu\"\n\nif [ \"$1\" == \"hs\" ]; then\n\t# hs dataset\n\techo \"run trained model for hs\"\n\tdataset=\"data/hs.freq3.pre_suf.unary_closure.bin\"\n\tmodel=\"model.hs_unary_closure_top20_word128_encoder256_rule128_node64.beam15.adadelta.simple_trans.8e39832.iter5600.npz\"\n\tcommandline=\"-decode_max_time_step 350 -rule_embed_dim 128 -node_embed_dim 64\"\n\tdatatype=\"hs\"\nelse\n\t# django dataset\n\techo \"run trained model for django\"\n\tdataset=\"data/django.cleaned.dataset.freq5.par_info.refact.space_only.bin\"\n\tmodel=\"model.django_word128_encoder256_rule128_node64.beam15.adam.simple_trans.no_unary_closure.8e39832.run3.best_acc.npz\"\n\tcommandline=\"-rule_embed_dim 128 -node_embed_dim 64\"\n\tdatatype=\"django\"\nfi\n\n# run interactive mode on trained models\nTHEANO_FLAGS=\"mode=FAST_RUN,device=${device},floatX=float32\" python code_gen.py \\\n\t-data_type ${datatype} \\\n\t-data ${dataset} \\\n\t-output_dir ${output} \\\n\t-model models/${model} \\\n\t${commandline} \\\n\tinteractive \\\n\t-mode new"
  },
  {
    "path": "run_interactive_singlefile.sh",
    "content": "device=\"cpu\"\n\nif [ \"$1\" == \"hs\" ]; then\n\t# hs dataset\n\techo \"run trained model for hs\"\n\tdataset=\"data/hs.freq3.pre_suf.unary_closure.bin\"\n\tmodel=\"models/model.hs_unary_closure_top20_word128_encoder256_rule128_node64.beam15.adadelta.simple_trans.8e39832.iter5600.npz\"\n\tdatatype=\"hs\"\nelse\n\t# django dataset\n\techo \"run trained model for django\"\n\tdataset=\"data/django.cleaned.dataset.freq5.par_info.refact.space_only.bin\"\n\tmodel=\"models/model.django_word128_encoder256_rule128_node64.beam15.adam.simple_trans.no_unary_closure.8e39832.run3.best_acc.npz\"\n\tdatatype=\"django\"\nfi\n\n# run interactive mode on trained models\n# run interactive mode on trained models\nTHEANO_FLAGS=\"mode=FAST_RUN,device=${device},floatX=float32\" python interactive_mode.py \\\n\t-data_type ${datatype} \\\n\t-data ${dataset} \\\n\t-model ${model}"
  },
  {
    "path": "run_trained_model.sh",
    "content": "output=\"runs\"\ndevice=\"cpu\"\n\nif [ \"$1\" == \"hs\" ]; then\n\t# hs dataset\n\techo \"run trained model for hs\"\n\tdataset=\"data/hs.freq3.pre_suf.unary_closure.bin\"\n\tmodel=\"model.hs_unary_closure_top20_word128_encoder256_rule128_node64.beam15.adadelta.simple_trans.8e39832.iter5600.npz\"\n\tcommandline=\"-decode_max_time_step 350 -rule_embed_dim 128 -node_embed_dim 64\"\n\tdatatype=\"hs\"\nelse\n\t# django dataset\n\techo \"run trained model for django\"\n\tdataset=\"data/django.cleaned.dataset.freq5.par_info.refact.space_only.bin\"\n\tmodel=\"model.django_word128_encoder256_rule128_node64.beam15.adam.simple_trans.no_unary_closure.8e39832.run3.best_acc.npz\"\n\tcommandline=\"-rule_embed_dim 128 -node_embed_dim 64\"\n\tdatatype=\"django\"\nfi\n\n# decode the test set and save the nbest decoding results\nTHEANO_FLAGS=\"mode=FAST_RUN,device=${device},floatX=float32\" python code_gen.py \\\n-data_type ${datatype} \\\n-data ${dataset} \\\n-output_dir ${output} \\\n-model models/${model} \\\n${commandline} \\\ndecode \\\n-saveto ${output}/${model}.decode_results.test.bin\n\n# evaluate the decoding result\npython code_gen.py \\\n\t-data_type ${datatype} \\\n\t-data ${dataset} \\\n\t-output_dir ${output} \\\n\tevaluate \\\n\t-input ${output}/${model}.decode_results.test.bin | tee ${output}/${model}.decode_results.test.log\n"
  },
  {
    "path": "train.sh",
    "content": "output=\"runs\"\ndevice=\"cpu\"\n\nif [ \"$1\" == \"hs\" ]; then\n\t# hs dataset\n\techo \"training hs dataset\"\n\tdataset=\"hs.freq3.pre_suf.unary_closure.bin\"\n\tcommandline=\"-batch_size 10 -max_epoch 200 -valid_per_batch 280 -save_per_batch 280 -decode_max_time_step 350 -optimizer adadelta -rule_embed_dim 128 -node_embed_dim 64 -valid_metric bleu\"\n\tdatatype=\"hs\"\nelse\n\t# django dataset\n\techo \"training django dataset\"\n\tdataset=\"django.cleaned.dataset.freq5.par_info.refact.space_only.bin\"\n\tcommandline=\"-batch_size 10 -max_epoch 50 -valid_per_batch 4000 -save_per_batch 4000 -decode_max_time_step 100 -optimizer adam -rule_embed_dim 128 -node_embed_dim 64 -valid_metric bleu\"\n\tdatatype=\"django\"\nfi\n\n# train the model\nTHEANO_FLAGS=\"mode=FAST_RUN,device=${device},floatX=float32\" python -u code_gen.py \\\n\t-data_type ${datatype} \\\n\t-data data/${dataset} \\\n\t-output_dir ${output} \\\n\t${commandline} \\\n\ttrain\n\n# decode testing set, and evaluate the model which achieves the best bleu and accuracy, resp.\nfor model in \"model.best_bleu.npz\" \"model.best_acc.npz\"; do\n\tTHEANO_FLAGS=\"mode=FAST_RUN,device=${device},floatX=float32\" python code_gen.py \\\n\t-data_type ${datatype} \\\n\t-data data/${dataset} \\\n\t-output_dir ${output} \\\n\t-model ${output}/${model} \\\n\t${commandline} \\\n\tdecode \\\n\t-saveto ${output}/${model}.decode_results.test.bin\n\n\tpython code_gen.py \\\n\t\t-data_type ${datatype} \\\n\t\t-data data/${dataset} \\\n\t\t-output_dir ${output} \\\n\t\tevaluate \\\n\t\t-input ${output}/${model}.decode_results.test.bin\ndone"
  },
  {
    "path": "util.py",
    "content": "def is_numeric(s):\n    if s[0] in ('-', '+'):\n        return s[1:].isdigit()\n    return s.isdigit()"
  }
]