[
  {
    "path": "README.md",
    "content": "# Representing Schema Structure with Graph Neural Networks for Text-to-SQL Parsing\n\nAuthor implementation of this [ACL 2019 paper](https://arxiv.org/abs/1905.06241).\n\nPlease also see the [follow-up repository](https://github.com/benbogin/spider-schema-gnn-global), which has improved results, for this [EMNLP paper](https://www.aclweb.org/anthology/D19-1378.pdf).\n\n## Install & Configure\n\n1. Install the PyTorch 1.0.1.post2 build that matches your CUDA version.\n\n   (This repository will probably work with the latest PyTorch version, but it has not been tested with it. If you use another version, you will also need to update the package versions in `requirements.txt`.)\n    ```\n    pip install https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp37-cp37m-linux_x86_64.whl # CUDA 10.0 build\n    ```\n\n2. Install the rest of the required packages:\n    ```\n    pip install -r requirements.txt\n    ```\n\n3. Run this command to download the NLTK punkt tokenizer:\n    ```\n    python -c \"import nltk; nltk.download('punkt')\"\n    ```\n\n4. Download the dataset from the [official Spider dataset website](https://yale-lily.github.io/spider).\n\n5. Edit the config file `train_configs/defaults.jsonnet` to update the location of the dataset:\n    ```\n    local dataset_path = \"dataset/\";\n    ```\n\n## Training\n\n1. Use the following AllenNLP command to train:\n```\nallennlp train train_configs/defaults.jsonnet -s experiments/name_of_experiment \\\n--include-package dataset_readers.spider \\\n--include-package models.semantic_parsing.spider_parser\n```\n\nThe first load of the dataset might take a while (a few hours), since the model first loads values from the tables and calculates similarity features with the relevant question. 
It will then be cached for subsequent runs.\n\nYou should get results similar to the following (`sql_match` is the metric measured by the official evaluation script):\n```\n  \"best_validation__match/exact_match\": 0.3715686274509804,\n  \"best_validation_sql_match\": 0.47549019607843135,\n  \"best_validation__others/action_similarity\": 0.5731271471206189,\n  \"best_validation__match/match_single\": 0.6254612546125461,\n  \"best_validation__match/match_hard\": 0.3054393305439331,\n  \"best_validation_beam_hit\": 0.6070588235294118,\n  \"best_validation_loss\": 7.383035182952881,\n  \"best_epoch\": 32\n```\n\nNote that the hyper-parameters used in `defaults.jsonnet` are different from those mentioned in the paper\n(most importantly, 3 GNN timesteps are used instead of 2), thanks to the [following contribution from @wlhgtc](https://github.com/benbogin/spider-schema-gnn/pull/13).\nThe original training config file is still available in `train_configs/paper_Defaults.jsonnet`.\n\n## Inference\n\nUse the following AllenNLP command to output a file with the predicted queries:\n\n```\nallennlp predict experiments/name_of_experiment dataset/dev.json \\\n--predictor spider \\\n--use-dataset-reader \\\n--cuda-device=0 \\\n--output-file experiments/name_of_experiment/prediction.sql \\\n--silent \\\n--include-package models.semantic_parsing.spider_parser \\\n--include-package dataset_readers.spider \\\n--include-package predictors.spider_predictor \\\n--weights-file experiments/name_of_experiment/best.th \\\n-o \"{\\\"dataset_reader\\\":{\\\"keep_if_unparsable\\\":true}}\"\n```\n"
  },
  {
    "path": "dataset_readers/__init__.py",
    "content": ""
  },
  {
    "path": "dataset_readers/dataset_util/spider_utils.py",
    "content": "\"\"\"\nUtility functions for reading the standardised text2sql datasets presented in\n`\"Improving Text to SQL Evaluation Methodology\" <https://arxiv.org/abs/1806.09029>`_\n\"\"\"\nimport json\nimport os\nimport sqlite3\nfrom collections import defaultdict\nfrom typing import List, Dict, Optional\n\nfrom allennlp.common import JsonDict\n\nfrom spider_evaluation.process_sql import get_tables_with_alias, parse_sql\n\n\nclass TableColumn:\n    def __init__(self,\n                 name: str,\n                 text: str,\n                 column_type: str,\n                 is_primary_key: bool,\n                 foreign_key: Optional[str]):\n        self.name = name\n        self.text = text\n        self.column_type = column_type\n        self.is_primary_key = is_primary_key\n        self.foreign_key = foreign_key\n\n\nclass Table:\n    def __init__(self,\n                 name: str,\n                 text: str,\n                 columns: List[TableColumn]):\n        self.name = name\n        self.text = text\n        self.columns = columns\n\n\ndef read_dataset_schema(schema_path: str) -> Dict[str, List[Table]]:\n    schemas: Dict[str, Dict[str, Table]] = defaultdict(dict)\n    dbs_json_blob = json.load(open(schema_path, \"r\"))\n    for db in dbs_json_blob:\n        db_id = db['db_id']\n\n        column_id_to_table = {}\n        column_id_to_column = {}\n\n        for i, (column, text, column_type) in enumerate(zip(db['column_names_original'], db['column_names'], db['column_types'])):\n            table_id, column_name = column\n            _, column_text = text\n\n            table_name = db['table_names_original'][table_id]\n\n            if table_name not in schemas[db_id]:\n                table_text = db['table_names'][table_id]\n                schemas[db_id][table_name] = Table(table_name, table_text, [])\n\n            if column_name == \"*\":\n                continue\n\n            is_primary_key = i in db['primary_keys']\n            
table_column = TableColumn(column_name.lower(), column_text, column_type, is_primary_key, None)\n            schemas[db_id][table_name].columns.append(table_column)\n            column_id_to_table[i] = table_name\n            column_id_to_column[i] = table_column\n\n        for (c1, c2) in db['foreign_keys']:\n            foreign_key = column_id_to_table[c2] + ':' + column_id_to_column[c2].name\n            column_id_to_column[c1].foreign_key = foreign_key\n\n    return {**schemas}\n\n\ndef read_dataset_values(db_id: str, dataset_path: str, tables: List[Table]):\n    db = os.path.join(dataset_path, db_id, db_id + \".sqlite\")\n    try:\n        conn = sqlite3.connect(db)\n    except Exception as e:\n        raise Exception(f\"Can't connect to SQL: {e} in path {db}\")\n    conn.text_factory = str\n    cursor = conn.cursor()\n\n    values = {}\n\n    for table in tables:\n        try:\n            cursor.execute(f\"SELECT * FROM {table.name} LIMIT 5000\")\n            values[table] = cursor.fetchall()\n        except Exception:\n            # some tables contain non-UTF-8 text: fall back to latin-1 decoding and retry\n            conn.text_factory = lambda x: str(x, 'latin1')\n            cursor = conn.cursor()\n            cursor.execute(f\"SELECT * FROM {table.name} LIMIT 5000\")\n            values[table] = cursor.fetchall()\n\n    return values\n\n\ndef ent_key_to_name(key):\n    parts = key.split(':')\n    if parts[0] == 'table':\n        return parts[1]\n    elif parts[0] == 'column':\n        _, _, table_name, column_name = parts\n        return f'{table_name}@{column_name}'\n    else:\n        return parts[1]\n\n\ndef fix_number_value(ex: JsonDict):\n    \"\"\"\n    The dataset files have a quirk: the `query_toks_no_value` field anonymizes all values,\n    which is fine since the evaluator doesn't check values. However, it also anonymizes numbers that\n    should not be anonymized: e.g. 
LIMIT 3 becomes LIMIT 'value', while the evaluator fails if it is not a number.\n    \"\"\"\n\n    def split_and_keep(s, sep):\n        if not s: return ['']  # consistent with string.split()\n\n        # Find replacement character that is not used in string\n        # i.e. just use the highest available character plus one\n        # Note: This fails if ord(max(s)) = 0x10FFFF (ValueError)\n        p = chr(ord(max(s)) + 1)\n\n        return s.replace(sep, p + sep + p).split(p)\n\n    # the two token lists are tokenized differently, so first try to make the splits equal\n    query_toks = ex['query_toks']\n    ex['query_toks'] = []\n    for q in query_toks:\n        ex['query_toks'] += split_and_keep(q, '.')\n\n    i_val, i_no_val = 0, 0\n    while i_val < len(ex['query_toks']) and i_no_val < len(ex['query_toks_no_value']):\n        if ex['query_toks_no_value'][i_no_val] != 'value':\n            i_val += 1\n            i_no_val += 1\n            continue\n\n        # advance i_val_end until the next token matches the token following 'value' (bounds-checked)\n        i_val_end = i_val\n        while i_val_end + 1 < len(ex['query_toks']) and \\\n                i_no_val + 1 < len(ex['query_toks_no_value']) and \\\n                ex['query_toks'][i_val_end + 1].lower() != ex['query_toks_no_value'][i_no_val + 1].lower():\n            i_val_end += 1\n\n        # restore a single-token LIMIT 1/2/3 value from the original query\n        if i_val == i_val_end and i_val > 0 and ex['query_toks'][i_val] in [\"1\", \"2\", \"3\"] and ex['query_toks'][i_val - 1].lower() == \"limit\":\n            ex['query_toks_no_value'][i_no_val] = ex['query_toks'][i_val]\n        i_val = i_val_end\n\n        i_val += 1\n        i_no_val += 1\n\n    return ex\n\n\n_schemas_cache = None\n\n\ndef disambiguate_items(db_id: str, query_toks: List[str], tables_file: str, allow_aliases: bool) -> List[str]:\n    \"\"\"\n    Make the query tokens non-ambiguous by rewriting each column name to explicitly\n    state which table it belongs to.\n\n    The conversion of parsed SQL back to SQL clauses is based on supermodel.gensql from syntaxsql.\n    \"\"\"\n\n    class Schema:\n        \"\"\"\n        Simple schema which 
maps table&column to a unique identifier\n        \"\"\"\n\n        def __init__(self, schema, table):\n            self._schema = schema\n            self._table = table\n            self._idMap = self._map(self._schema, self._table)\n\n        @property\n        def schema(self):\n            return self._schema\n\n        @property\n        def idMap(self):\n            return self._idMap\n\n        def _map(self, schema, table):\n            column_names_original = table['column_names_original']\n            table_names_original = table['table_names_original']\n            # print 'column_names_original: ', column_names_original\n            # print 'table_names_original: ', table_names_original\n            for i, (tab_id, col) in enumerate(column_names_original):\n                if tab_id == -1:\n                    idMap = {'*': i}\n                else:\n                    key = table_names_original[tab_id].lower()\n                    val = col.lower()\n                    idMap[key + \".\" + val] = i\n\n            for i, tab in enumerate(table_names_original):\n                key = tab.lower()\n                idMap[key] = i\n\n            return idMap\n\n    def get_schemas_from_json(fpath):\n        global _schemas_cache\n\n        if _schemas_cache is not None:\n            return _schemas_cache\n\n        with open(fpath) as f:\n            data = json.load(f)\n        db_names = [db['db_id'] for db in data]\n\n        tables = {}\n        schemas = {}\n        for db in data:\n            db_id = db['db_id']\n            schema = {}  # {'table': [col.lower, ..., ]} * -> __all__\n            column_names_original = db['column_names_original']\n            table_names_original = db['table_names_original']\n            tables[db_id] = {'column_names_original': column_names_original,\n                             'table_names_original': table_names_original}\n            for i, tabn in enumerate(table_names_original):\n                table = 
str(tabn.lower())\n                cols = [str(col.lower()) for td, col in column_names_original if td == i]\n                schema[table] = cols\n            schemas[db_id] = schema\n\n        _schemas_cache = schemas, db_names, tables\n        return _schemas_cache\n\n    schemas, db_names, tables = get_schemas_from_json(tables_file)\n    schema = Schema(schemas[db_id], tables[db_id])\n\n    fixed_toks = []\n    i = 0\n    while i < len(query_toks):\n        tok = query_toks[i]\n        if tok == 'value' or tok == \"'value'\":\n            # TODO: value should always be between '/\" (remove first if clause)\n            new_tok = f'\"{tok}\"'\n        elif tok in ['!','<','>'] and i+1 < len(query_toks) and query_toks[i+1] == '=':\n            new_tok = tok + '='\n            i += 1\n        elif i+1 < len(query_toks) and query_toks[i+1] == '.':\n            new_tok = ''.join(query_toks[i:i+3])\n            i += 2\n        else:\n            new_tok = tok\n        fixed_toks.append(new_tok)\n        i += 1\n\n    toks = fixed_toks\n\n    tables_with_alias = get_tables_with_alias(schema.schema, toks)\n    _, sql, mapped_entities = parse_sql(toks, 0, tables_with_alias, schema, mapped_entities_fn=lambda: [])\n\n    for i, new_name in mapped_entities:\n        curr_tok = toks[i]\n        if '.' in curr_tok and allow_aliases:\n            parts = curr_tok.split('.')\n            assert(len(parts) == 2)\n            toks[i] = parts[0] + '.' + new_name\n        else:\n            toks[i] = new_name\n\n    if not allow_aliases:\n        toks = [tok for tok in toks if tok not in ['as', 't1', 't2', 't3', 't4']]\n\n    toks = [f'\\'value\\'' if tok == '\"value\"' else tok for tok in toks]\n\n    return toks"
  },
  {
    "path": "dataset_readers/fields/knowledge_graph_field.py",
    "content": "\"\"\"\n``KnowledgeGraphField`` is a ``Field`` which stores a knowledge graph representation.\n\"\"\"\nfrom typing import List, Dict\n\nfrom allennlp.data import TokenIndexer, Tokenizer\nfrom allennlp.data.fields.knowledge_graph_field import KnowledgeGraphField\nfrom allennlp.data.tokenizers.token import Token\nfrom allennlp.semparse.contexts.knowledge_graph import KnowledgeGraph\n\n\nclass SpiderKnowledgeGraphField(KnowledgeGraphField):\n    \"\"\"\n    This implementation calculates all non-graph-related features (i.e. no related_column),\n    then takes each one of the features to calculate related column features, by taking the max score of all neighbours\n    \"\"\"\n    def __init__(self,\n                 knowledge_graph: KnowledgeGraph,\n                 utterance_tokens: List[Token],\n                 token_indexers: Dict[str, TokenIndexer],\n                 tokenizer: Tokenizer = None,\n                 feature_extractors: List[str] = None,\n                 entity_tokens: List[List[Token]] = None,\n                 linking_features: List[List[List[float]]] = None,\n                 include_in_vocab: bool = True,\n                 max_table_tokens: int = None) -> None:\n        feature_extractors = feature_extractors if feature_extractors is not None else [\n                'number_token_match',\n                'exact_token_match',\n                'contains_exact_token_match',\n                'lemma_match',\n                'contains_lemma_match',\n                'edit_distance',\n                'span_overlap_fraction',\n                'span_lemma_overlap_fraction',\n                ]\n\n        super().__init__(knowledge_graph, utterance_tokens, token_indexers,\n                         tokenizer=tokenizer, feature_extractors=feature_extractors, entity_tokens=entity_tokens,\n                         linking_features=linking_features, include_in_vocab=include_in_vocab,\n                         max_table_tokens=max_table_tokens)\n\n 
       self.linking_features = self._compute_related_linking_features(self.linking_features)\n\n        # hack needed to fix calculation of feature extractors in the inherited as_tensor method\n        self._feature_extractors = feature_extractors * 2\n\n    def _compute_related_linking_features(self,\n                                          non_related_features: List[List[List[float]]]) -> List[List[List[float]]]:\n        linking_features = non_related_features\n        entity_to_index_map = {}\n        for entity_id, entity in enumerate(self.knowledge_graph.entities):\n            entity_to_index_map[entity] = entity_id\n        for entity_id, (entity, entity_text) in enumerate(zip(self.knowledge_graph.entities, self.entity_texts)):\n            for token_index, token in enumerate(self.utterance_tokens):\n                entity_token_features = linking_features[entity_id][token_index]\n                for feature_index, feature_extractor in enumerate(self._feature_extractors):\n                    neighbour_features = []\n                    for neighbor in self.knowledge_graph.neighbors[entity]:\n                        # we only care about table/columns relations here, not foreign-primary\n                        if entity.startswith('column') and neighbor.startswith('column'):\n                            continue\n                        neighbor_index = entity_to_index_map[neighbor]\n                        neighbour_features.append(non_related_features[neighbor_index][token_index][feature_index])\n\n                    entity_token_features.append(max(neighbour_features))\n        return linking_features\n"
  },
  {
    "path": "dataset_readers/spider.py",
    "content": "import json\nimport logging\nimport os\nfrom typing import List, Dict\n\nimport dill\nfrom allennlp.common.checks import ConfigurationError\nfrom allennlp.data import DatasetReader, Tokenizer, TokenIndexer, Field, Instance\nfrom allennlp.data.fields import TextField, ProductionRuleField, ListField, IndexField, MetadataField\nfrom allennlp.data.token_indexers import SingleIdTokenIndexer\nfrom allennlp.data.tokenizers import WordTokenizer\nfrom allennlp.data.tokenizers.word_splitter import SpacyWordSplitter\nfrom overrides import overrides\nfrom spacy.symbols import ORTH, LEMMA\n\nfrom dataset_readers.dataset_util.spider_utils import fix_number_value, disambiguate_items\nfrom dataset_readers.fields.knowledge_graph_field import SpiderKnowledgeGraphField\nfrom semparse.contexts.spider_db_context import SpiderDBContext\nfrom semparse.worlds.spider_world import SpiderWorld\n\nlogger = logging.getLogger(__name__)\n\n\n@DatasetReader.register(\"spider\")\nclass SpiderDatasetReader(DatasetReader):\n    def __init__(self,\n                 lazy: bool = False,\n                 question_token_indexers: Dict[str, TokenIndexer] = None,\n                 keep_if_unparsable: bool = True,\n                 tables_file: str = None,\n                 dataset_path: str = 'dataset/database',\n                 load_cache: bool = True,\n                 save_cache: bool = True,\n                 loading_limit = -1):\n        super().__init__(lazy=lazy)\n\n        # default spacy tokenizer splits the common token 'id' to ['i', 'd'], we here write a manual fix for that\n        spacy_tokenizer = SpacyWordSplitter(pos_tags=True)\n        spacy_tokenizer.spacy.tokenizer.add_special_case(u'id', [{ORTH: u'id', LEMMA: u'id'}])\n        self._tokenizer = WordTokenizer(spacy_tokenizer)\n\n        self._utterance_token_indexers = question_token_indexers or {'tokens': SingleIdTokenIndexer()}\n        self._keep_if_unparsable = keep_if_unparsable\n\n        self._tables_file = 
tables_file\n        self._dataset_path = dataset_path\n\n        self._load_cache = load_cache\n        self._save_cache = save_cache\n        self._loading_limit = loading_limit\n\n    @overrides\n    def _read(self, file_path: str):\n        if not file_path.endswith('.json'):\n            raise ConfigurationError(f\"Don't know how to read filetype of {file_path}\")\n\n        cache_dir = os.path.join('cache', file_path.split(\"/\")[-1])\n\n        if self._load_cache:\n            logger.info(f'Trying to load cache from {cache_dir}')\n        if self._save_cache:\n            os.makedirs(cache_dir, exist_ok=True)\n\n        cnt = 0\n        with open(file_path, \"r\") as data_file:\n            json_obj = json.load(data_file)\n            for total_cnt, ex in enumerate(json_obj):\n                cache_filename = f'instance-{total_cnt}.pt'\n                cache_filepath = os.path.join(cache_dir, cache_filename)\n                if self._loading_limit == cnt:\n                    break\n\n                if self._load_cache:\n                    try:\n                        ins = dill.load(open(cache_filepath, 'rb'))\n                        if ins is None and not self._keep_if_unparsable:\n                            # skip unparsed examples\n                            continue\n                        yield ins\n                        cnt += 1\n                        continue\n                    except Exception as e:\n                        # could not load from cache - keep loading without cache\n                        pass\n\n                query_tokens = None\n                if 'query_toks' in ex:\n                    # we only have 'query_toks' in example for training/dev sets\n\n                    # fix for examples: we want to use the 'query_toks_no_value' field of the example which anonymizes\n                    # values. However, it also anonymizes numbers (e.g. 
LIMIT 3 -> LIMIT 'value'), which is not good\n                    # since the official evaluator expects a number, not a value\n                    ex = fix_number_value(ex)\n\n                    # we want the query tokens to be non-ambiguous (i.e. know for each column the table it belongs to,\n                    # and for each table alias its explicit name)\n                    # we thus remove all aliases and make changes such as:\n                    # 'name' -> 'singer@name',\n                    # 'singer AS T1' -> 'singer',\n                    # 'T1.name' -> 'singer@name'\n                    try:\n                        query_tokens = disambiguate_items(ex['db_id'], ex['query_toks_no_value'],\n                                                                self._tables_file, allow_aliases=False)\n                    except Exception as e:\n                        # two examples in the train set are wrongly formatted: log them and\n                        # continue without gold query tokens (query_tokens stays None)\n                        logger.warning(f\"error with {ex['query']}: {e}\")\n\n                ins = self.text_to_instance(\n                    utterance=ex['question'],\n                    db_id=ex['db_id'],\n                    sql=query_tokens)\n                if ins is not None:\n                    cnt += 1\n                if self._save_cache:\n                    with open(cache_filepath, 'wb') as cache_file:\n                        dill.dump(ins, cache_file)\n\n                if ins is not None:\n                    yield ins\n\n    def text_to_instance(self,\n                         utterance: str,\n                         db_id: str,\n                         sql: List[str] = None):\n        fields: Dict[str, Field] = {}\n\n        db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer,\n                                     tables_file=self._tables_file, dataset_path=self._dataset_path)\n        table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,\n                                                
db_context.tokenized_utterance,\n                                                self._utterance_token_indexers,\n                                                entity_tokens=db_context.entity_tokens,\n                                                include_in_vocab=False,  # TODO: self._use_table_for_vocab,\n                                                max_table_tokens=None)  # self._max_table_tokens)\n\n        world = SpiderWorld(db_context, query=sql)\n        fields[\"utterance\"] = TextField(db_context.tokenized_utterance, self._utterance_token_indexers)\n\n        action_sequence, all_actions = world.get_action_sequence_and_all_actions()\n\n        if action_sequence is None and self._keep_if_unparsable:\n            # print(\"Parse error\")\n            action_sequence = []\n        elif action_sequence is None:\n            return None\n\n        index_fields: List[Field] = []\n        production_rule_fields: List[Field] = []\n\n        for production_rule in all_actions:\n            nonterminal, rhs = production_rule.split(' -> ')\n            production_rule = ' '.join(production_rule.split(' '))\n            field = ProductionRuleField(production_rule,\n                                        world.is_global_rule(rhs),\n                                        nonterminal=nonterminal)\n            production_rule_fields.append(field)\n\n        valid_actions_field = ListField(production_rule_fields)\n        fields[\"valid_actions\"] = valid_actions_field\n\n        action_map = {action.rule: i  # type: ignore\n                      for i, action in enumerate(valid_actions_field.field_list)}\n\n        for production_rule in action_sequence:\n            index_fields.append(IndexField(action_map[production_rule], valid_actions_field))\n        if not action_sequence:\n            index_fields = [IndexField(-1, valid_actions_field)]\n\n        action_sequence_field = ListField(index_fields)\n        fields[\"action_sequence\"] = 
action_sequence_field\n        fields[\"world\"] = MetadataField(world)\n        fields[\"schema\"] = table_field\n        return Instance(fields)\n"
  },
  {
    "path": "models/__init__.py",
    "content": ""
  },
  {
    "path": "models/semantic_parsing/__init__.py",
    "content": "from models.semantic_parsing.spider_parser import SpiderParser"
  },
  {
    "path": "models/semantic_parsing/spider_parser.py",
    "content": "import difflib\nimport os\nfrom functools import partial\nfrom typing import Dict, List, Tuple, Any, Mapping, Sequence\n\nimport sqlparse\nimport torch\nfrom allennlp.common.util import pad_sequence_to_length\nfrom allennlp.data import Vocabulary\nfrom allennlp.data.fields.production_rule_field import ProductionRule, ProductionRuleArray\nfrom allennlp.models import Model\nfrom allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder, Seq2VecEncoder, Embedding, Attention, FeedForward, \\\n    TimeDistributed\nfrom allennlp.modules.seq2vec_encoders import BagOfEmbeddingsEncoder\nfrom allennlp.nn import util, Activation\nfrom allennlp.state_machines import BeamSearch\nfrom allennlp.state_machines.states import GrammarStatelet\nfrom torch_geometric.data import Data, Batch\n\nfrom modules.gated_graph_conv import GatedGraphConv\nfrom semparse.worlds.evaluate_spider import evaluate\nfrom state_machines.states.rnn_statelet import RnnStatelet\nfrom allennlp.state_machines.trainers import MaximumMarginalLikelihood\nfrom allennlp.training.metrics import Average\nfrom overrides import overrides\n\nfrom semparse.contexts.spider_context_utils import action_sequence_to_sql\nfrom semparse.worlds.spider_world import SpiderWorld\nfrom state_machines.states.grammar_based_state import GrammarBasedState\nfrom state_machines.states.sql_state import SqlState\nfrom state_machines.transition_functions.attend_past_schema_items_transition import \\\n    AttendPastSchemaItemsTransitionFunction\nfrom state_machines.transition_functions.linking_transition_function import LinkingTransitionFunction\n\n\n@Model.register(\"spider\")\nclass SpiderParser(Model):\n    def __init__(self,\n                 vocab: Vocabulary,\n                 encoder: Seq2SeqEncoder,\n                 entity_encoder: Seq2VecEncoder,\n                 decoder_beam_search: BeamSearch,\n                 question_embedder: TextFieldEmbedder,\n                 input_attention: Attention,\n                 
past_attention: Attention,\n                 max_decoding_steps: int,\n                 action_embedding_dim: int,\n                 gnn: bool = True,\n                 decoder_use_graph_entities: bool = True,\n                 decoder_self_attend: bool = True,\n                 gnn_timesteps: int = 2,\n                 parse_sql_on_decoding: bool = True,\n                 add_action_bias: bool = True,\n                 use_neighbor_similarity_for_linking: bool = True,\n                 dataset_path: str = 'dataset',\n                 training_beam_size: int = None,\n                 decoder_num_layers: int = 1,\n                 dropout: float = 0.0,\n                 rule_namespace: str = 'rule_labels',\n                 scoring_dev_params: dict = None,\n                 debug_parsing: bool = False) -> None:\n        super().__init__(vocab)\n        self.vocab = vocab\n        self._encoder = encoder\n        self._max_decoding_steps = max_decoding_steps\n        if dropout > 0:\n            self._dropout = torch.nn.Dropout(p=dropout)\n        else:\n            self._dropout = lambda x: x\n        self._rule_namespace = rule_namespace\n        self._question_embedder = question_embedder\n        self._add_action_bias = add_action_bias\n        self._scoring_dev_params = scoring_dev_params or {}\n        self.parse_sql_on_decoding = parse_sql_on_decoding\n        self._entity_encoder = TimeDistributed(entity_encoder)\n        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking\n        self._self_attend = decoder_self_attend\n        self._decoder_use_graph_entities = decoder_use_graph_entities\n\n        self._action_padding_index = -1  # the padding value used by IndexField\n\n        self._exact_match = Average()\n        self._sql_evaluator_match = Average()\n        self._action_similarity = Average()\n        self._acc_single = Average()\n        self._acc_multi = Average()\n        self._beam_hit = Average()\n\n        
self._action_embedding_dim = action_embedding_dim\n\n        num_actions = vocab.get_vocab_size(self._rule_namespace)\n        if self._add_action_bias:\n            input_action_dim = action_embedding_dim + 1\n        else:\n            input_action_dim = action_embedding_dim\n        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim)\n        self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)\n\n        encoder_output_dim = encoder.get_output_dim()\n        if gnn:\n            encoder_output_dim += action_embedding_dim\n\n        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))\n        self._first_attended_utterance = torch.nn.Parameter(torch.FloatTensor(encoder_output_dim))\n        self._first_attended_output = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))\n        torch.nn.init.normal_(self._first_action_embedding)\n        torch.nn.init.normal_(self._first_attended_utterance)\n        torch.nn.init.normal_(self._first_attended_output)\n\n        self._num_entity_types = 9\n        self._embedding_dim = question_embedder.get_output_dim()\n\n        self._entity_type_encoder_embedding = Embedding(self._num_entity_types, self._embedding_dim)\n        self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)\n\n        self._linking_params = torch.nn.Linear(16, 1)\n        torch.nn.init.uniform_(self._linking_params.weight, 0, 1)\n\n        num_edge_types = 3\n        self._gnn = GatedGraphConv(self._embedding_dim, gnn_timesteps, num_edge_types=num_edge_types, dropout=dropout)\n\n        self._decoder_num_layers = decoder_num_layers\n\n        self._beam_search = decoder_beam_search\n        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)\n\n        if decoder_self_attend:\n            self._transition_function = 
AttendPastSchemaItemsTransitionFunction(encoder_output_dim=encoder_output_dim,\n                                                                                action_embedding_dim=action_embedding_dim,\n                                                                                input_attention=input_attention,\n                                                                                past_attention=past_attention,\n                                                                                predict_start_type_separately=False,\n                                                                                add_action_bias=self._add_action_bias,\n                                                                                dropout=dropout,\n                                                                                num_layers=self._decoder_num_layers)\n        else:\n            self._transition_function = LinkingTransitionFunction(encoder_output_dim=encoder_output_dim,\n                                                                  action_embedding_dim=action_embedding_dim,\n                                                                  input_attention=input_attention,\n                                                                  predict_start_type_separately=False,\n                                                                  add_action_bias=self._add_action_bias,\n                                                                  dropout=dropout,\n                                                                  num_layers=self._decoder_num_layers)\n\n        self._ent2ent_ff = FeedForward(action_embedding_dim, 1, action_embedding_dim, Activation.by_name('relu')())\n\n        self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim)\n\n        # TODO: Remove hard-coded dirs\n        self._evaluate_func = partial(evaluate,\n                                      db_dir=os.path.join(dataset_path, 
'database'),\n                                      table=os.path.join(dataset_path, 'tables.json'),\n                                      check_valid=False)\n\n        self.debug_parsing = debug_parsing\n\n    @overrides\n    def forward(self,  # type: ignore\n                utterance: Dict[str, torch.LongTensor],\n                valid_actions: List[List[ProductionRule]],\n                world: List[SpiderWorld],\n                schema: Dict[str, torch.LongTensor],\n                action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]:\n\n        batch_size = len(world)\n        device = utterance['tokens'].device\n\n        initial_state = self._get_initial_state(utterance, world, schema, valid_actions)\n\n        if action_sequence is not None:\n            # Remove the trailing dimension (from ListField[ListField[IndexField]]).\n            action_sequence = action_sequence.squeeze(-1)\n            action_mask = action_sequence != self._action_padding_index\n        else:\n            action_mask = None\n\n        if self.training:\n            decode_output = self._decoder_trainer.decode(initial_state,\n                                                         self._transition_function,\n                                                         (action_sequence.unsqueeze(1), action_mask.unsqueeze(1)))\n\n            return {'loss': decode_output['loss']}\n        else:\n            loss = torch.tensor([0]).float().to(device)\n            if action_sequence is not None and action_sequence.size(1) > 1:\n                try:\n                    loss = self._decoder_trainer.decode(initial_state,\n                                                        self._transition_function,\n                                                        (action_sequence.unsqueeze(1),\n                                                         action_mask.unsqueeze(1)))['loss']\n                except ZeroDivisionError:\n                    # reached a dead-end during 
beam search\n                    pass\n\n            outputs: Dict[str, Any] = {\n                'loss': loss\n            }\n\n            num_steps = self._max_decoding_steps\n            # This tells the state to start keeping track of debug info, which we'll pass along in\n            # our output dictionary.\n            initial_state.debug_info = [[] for _ in range(batch_size)]\n\n            best_final_states = self._beam_search.search(num_steps,\n                                                         initial_state,\n                                                         self._transition_function,\n                                                         keep_final_unfinished_states=False)\n\n            self._compute_validation_outputs(valid_actions,\n                                             best_final_states,\n                                             world,\n                                             action_sequence,\n                                             outputs)\n            return outputs\n\n    def _get_initial_state(self,\n                           utterance: Dict[str, torch.LongTensor],\n                           worlds: List[SpiderWorld],\n                           schema: Dict[str, torch.LongTensor],\n                           actions: List[List[ProductionRule]]) -> GrammarBasedState:\n        schema_text = schema['text']\n        embedded_schema = self._question_embedder(schema_text, num_wrapping_dims=1)\n        schema_mask = util.get_text_field_mask(schema_text, num_wrapping_dims=1).float()\n\n        embedded_utterance = self._question_embedder(utterance)\n        utterance_mask = util.get_text_field_mask(utterance).float()\n\n        batch_size, num_entities, num_entity_tokens, _ = embedded_schema.size()\n        num_entities = max([len(world.db_context.knowledge_graph.entities) for world in worlds])\n        num_question_tokens = utterance['tokens'].size(1)\n\n        # entity_types: tensor with shape (batch_size, 
num_entities), where each entry is the\n        # entity's type id.\n        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index\n        # These encode the same information, but for efficiency reasons later it's nice\n        # to have one version as a tensor and one that's accessible on the cpu.\n        entity_types, entity_type_dict = self._get_type_vector(worlds, num_entities, embedded_schema.device)\n\n        entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)\n\n        # Compute entity and question word similarity.  We tried using cosine distance here, but\n        # because this similarity is a main mechanism the model can use to push apart the linking\n        # scores of different entities, it needs a larger output range than [-1, 1], so a raw\n        # dot product is used instead.\n        question_entity_similarity = torch.bmm(embedded_schema.view(batch_size,\n                                                                    num_entities * num_entity_tokens,\n                                                                    self._embedding_dim),\n                                               torch.transpose(embedded_utterance, 1, 2))\n\n        question_entity_similarity = question_entity_similarity.view(batch_size,\n                                                                     num_entities,\n                                                                     num_entity_tokens,\n                                                                     num_question_tokens)\n        # (batch_size, num_entities, num_question_tokens)\n        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)\n\n        # (batch_size, num_entities, num_question_tokens, num_features)\n        linking_features = schema['linking']\n\n        linking_scores = question_entity_similarity_max_score\n\n        feature_scores = self._linking_params(linking_features).squeeze(3)\n\n     
   linking_scores = linking_scores + feature_scores\n\n        # (batch_size, num_question_tokens, num_entities)\n        linking_probabilities = self._get_linking_probabilities(worlds, linking_scores.transpose(1, 2),\n                                                                utterance_mask, entity_type_dict)\n\n        # (batch_size, num_entities, num_neighbors) or None\n        neighbor_indices = self._get_neighbor_indices(worlds, num_entities, linking_scores.device)\n\n        if self._use_neighbor_similarity_for_linking and neighbor_indices is not None:\n            # (batch_size, num_entities, embedding_dim)\n            encoded_table = self._entity_encoder(embedded_schema, schema_mask)\n\n            # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.\n            # Thus, the absolute value needs to be taken in the index_select, and 1 needs to\n            # be added for the mask since that method expects 0 for padding.\n            # (batch_size, num_entities, num_neighbors, embedding_dim)\n            embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))\n\n            neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},\n                                                     num_wrapping_dims=1).float()\n\n            # Encoder initialized to easily obtain a masked average.\n            neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))\n            # (batch_size, num_entities, embedding_dim)\n            embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)\n            projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())\n\n            # (batch_size, num_entities, embedding_dim)\n            entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)\n        else:\n            # (batch_size, num_entities, embedding_dim)\n            entity_embeddings 
= torch.tanh(entity_type_embeddings)\n\n        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)\n        encoder_input = torch.cat([link_embedding, embedded_utterance], 2)\n\n        # (batch_size, utterance_length, encoder_output_dim)\n        encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask))\n\n        max_entities_relevance = linking_probabilities.max(dim=1)[0]\n        entities_relevance = max_entities_relevance.unsqueeze(-1).detach()\n\n        graph_initial_embedding = entity_type_embeddings * entities_relevance\n\n        encoder_output_dim = self._encoder.get_output_dim()\n        if self._gnn:\n            entities_graph_encoding = self._get_schema_graph_encoding(worlds,\n                                                                      graph_initial_embedding)\n            graph_link_embedding = util.weighted_sum(entities_graph_encoding, linking_probabilities)\n            encoder_outputs = torch.cat((\n                encoder_outputs,\n                graph_link_embedding\n            ), dim=-1)\n            encoder_output_dim = self._action_embedding_dim + self._encoder.get_output_dim()\n        else:\n            entities_graph_encoding = None\n\n        if self._self_attend:\n            # linked_actions_linking_scores = self._get_linked_actions_linking_scores(actions, entities_graph_encoding)\n            entities_ff = self._ent2ent_ff(entities_graph_encoding)\n            linked_actions_linking_scores = torch.bmm(entities_ff, entities_ff.transpose(1, 2))\n        else:\n            linked_actions_linking_scores = [None] * batch_size\n\n        # This will be our initial hidden state and memory cell for the decoder LSTM.\n        final_encoder_output = util.get_final_encoder_states(encoder_outputs,\n                                                             utterance_mask,\n                                                             self._encoder.is_bidirectional())\n        memory_cell 
= encoder_outputs.new_zeros(batch_size, encoder_output_dim)\n        initial_score = embedded_utterance.data.new_zeros(batch_size)\n\n        # To make grouping states together in the decoder easier, we convert the batch dimension in\n        # all of our tensors into an outer list.  For instance, the encoder outputs have shape\n        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list\n        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we\n        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.\n        initial_score_list = [initial_score[i] for i in range(batch_size)]\n        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]\n        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]\n        initial_rnn_state = []\n        for i in range(batch_size):\n            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],\n                                                 memory_cell[i],\n                                                 self._first_action_embedding,\n                                                 self._first_attended_utterance,\n                                                 encoder_output_list,\n                                                 utterance_mask_list))\n\n        initial_grammar_state = [self._create_grammar_state(worlds[i],\n                                                            actions[i],\n                                                            linking_scores[i],\n                                                            linked_actions_linking_scores[i],\n                                                            entity_types[i],\n                                                            entities_graph_encoding[\n                                                                i] if entities_graph_encoding is not None else None)\n                   
              for i in range(batch_size)]\n\n        initial_sql_state = [SqlState(actions[i], self.parse_sql_on_decoding) for i in range(batch_size)]\n\n        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),\n                                          action_history=[[] for _ in range(batch_size)],\n                                          score=initial_score_list,\n                                          rnn_state=initial_rnn_state,\n                                          grammar_state=initial_grammar_state,\n                                          sql_state=initial_sql_state,\n                                          possible_actions=actions,\n                                          action_entity_mapping=[w.get_action_entity_mapping() for w in worlds])\n\n        return initial_state\n\n    @staticmethod\n    def _get_neighbor_indices(worlds: List[SpiderWorld],\n                              num_entities: int,\n                              device: torch.device) -> torch.LongTensor:\n        \"\"\"\n        This method returns the indices of each entity's neighbors. The target device is\n        passed in so that the returned tensor is created directly on it.\n\n        Parameters\n        ----------\n        worlds : ``List[SpiderWorld]``\n        num_entities : ``int``\n        device : ``torch.device``\n            The device on which the returned tensor is created.\n\n        Returns\n        -------\n        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded\n        with -1 instead of 0, since 0 is a valid neighbor index. 
If all the entities in the batch\n        have no neighbors, None will be returned.\n        \"\"\"\n\n        num_neighbors = 0\n        for world in worlds:\n            for entity in world.db_context.knowledge_graph.entities:\n                if len(world.db_context.knowledge_graph.neighbors[entity]) > num_neighbors:\n                    num_neighbors = len(world.db_context.knowledge_graph.neighbors[entity])\n\n        batch_neighbors = []\n        no_entities_have_neighbors = True\n        for world in worlds:\n            # Each batch instance has its own world, which has a corresponding table.\n            entities = world.db_context.knowledge_graph.entities\n            entity2index = {entity: i for i, entity in enumerate(entities)}\n            entity2neighbors = world.db_context.knowledge_graph.neighbors\n            neighbor_indexes = []\n            for entity in entities:\n                entity_neighbors = [entity2index[n] for n in entity2neighbors[entity]]\n                if entity_neighbors:\n                    no_entities_have_neighbors = False\n                # Pad with -1 instead of 0, since 0 represents a neighbor index.\n                padded = pad_sequence_to_length(entity_neighbors, num_neighbors, lambda: -1)\n                neighbor_indexes.append(padded)\n            neighbor_indexes = pad_sequence_to_length(neighbor_indexes,\n                                                      num_entities,\n                                                      lambda: [-1] * num_neighbors)\n            batch_neighbors.append(neighbor_indexes)\n        # It is possible that none of the entities has any neighbors, since our definition of the\n        # knowledge graph allows it when no entities or numbers were extracted from the question.\n        if no_entities_have_neighbors:\n            return None\n        return torch.tensor(batch_neighbors, device=device, dtype=torch.long)\n\n    def _get_schema_graph_encoding(self,\n                            
       worlds: List[SpiderWorld],\n                                   initial_graph_embeddings: torch.Tensor) -> torch.Tensor:\n        max_num_entities = max([len(world.db_context.knowledge_graph.entities) for world in worlds])\n        batch_size = initial_graph_embeddings.size(0)\n\n        graph_data_list = []\n\n        for batch_index, world in enumerate(worlds):\n            x = initial_graph_embeddings[batch_index]\n\n            adj_list = self._get_graph_adj_lists(initial_graph_embeddings.device,\n                                                 world, initial_graph_embeddings.size(1) - 1)\n            graph_data = Data(x)\n            for i, l in enumerate(adj_list):\n                graph_data[f'edge_index_{i}'] = l\n            graph_data_list.append(graph_data)\n\n        batch = Batch.from_data_list(graph_data_list)\n\n        gnn_output = self._gnn(batch.x, [batch[f'edge_index_{i}'] for i in range(self._gnn.num_edge_types)])\n\n        num_nodes = max_num_entities\n        gnn_output = gnn_output.view(batch_size, num_nodes, -1)\n        # entities_encodings = gnn_output\n        entities_encodings = gnn_output[:, :max_num_entities]\n        # global_node_encodings = gnn_output[:, max_num_entities]\n\n        return entities_encodings\n\n    @staticmethod\n    def _get_graph_adj_lists(device, world, global_entity_id, global_node=False):\n        entity_mapping = {}\n        for i, entity in enumerate(world.db_context.knowledge_graph.entities):\n            entity_mapping[entity] = i\n        entity_mapping['_global_'] = global_entity_id\n        adj_list_own = []  # column--table\n        adj_list_link = []  # table->table / foreign->primary\n        adj_list_linked = []  # table<-table / foreign<-primary\n        adj_list_global = []  # node->global\n\n        # TODO: Prepare in advance?\n        for key, neighbors in world.db_context.knowledge_graph.neighbors.items():\n            idx_source = entity_mapping[key]\n            
for n_key in neighbors:\n                idx_target = entity_mapping[n_key]\n                if n_key.startswith(\"table\") or key.startswith(\"table\"):\n                    adj_list_own.append((idx_source, idx_target))\n                elif n_key.startswith(\"string\") or key.startswith(\"string\"):\n                    adj_list_own.append((idx_source, idx_target))\n                elif key.startswith(\"column:foreign\"):\n                    adj_list_link.append((idx_source, idx_target))\n                    src_table_key = f\"table:{key.split(':')[2]}\"\n                    tgt_table_key = f\"table:{n_key.split(':')[2]}\"\n                    idx_source_table = entity_mapping[src_table_key]\n                    idx_target_table = entity_mapping[tgt_table_key]\n                    adj_list_link.append((idx_source_table, idx_target_table))\n                elif n_key.startswith(\"column:foreign\"):\n                    adj_list_linked.append((idx_source, idx_target))\n                    src_table_key = f\"table:{key.split(':')[2]}\"\n                    tgt_table_key = f\"table:{n_key.split(':')[2]}\"\n                    idx_source_table = entity_mapping[src_table_key]\n                    idx_target_table = entity_mapping[tgt_table_key]\n                    adj_list_linked.append((idx_source_table, idx_target_table))\n                else:\n                    assert False\n\n            adj_list_global.append((idx_source, entity_mapping['_global_']))\n\n        all_adj_types = [adj_list_own, adj_list_link, adj_list_linked]\n\n        if global_node:\n            all_adj_types.append(adj_list_global)\n\n        return [torch.tensor(l, device=device, dtype=torch.long).transpose(0, 1) if l\n                else torch.tensor(l, device=device, dtype=torch.long)\n                for l in all_adj_types]\n\n    def _create_grammar_state(self,\n                              world: SpiderWorld,\n                              possible_actions: List[ProductionRule],\n    
                          linking_scores: torch.Tensor,\n                              linked_actions_linking_scores: torch.Tensor,\n                              entity_types: torch.Tensor,\n                              entity_graph_encoding: torch.Tensor) -> GrammarStatelet:\n        action_map = {}\n        for action_index, action in enumerate(possible_actions):\n            action_string = action[0]\n            action_map[action_string] = action_index\n\n        valid_actions = world.valid_actions\n        entity_map = {}\n        entities = world.entities_names\n\n        for entity_index, entity in enumerate(entities):\n            entity_map[entity] = entity_index\n\n        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}\n        for key, action_strings in valid_actions.items():\n            translated_valid_actions[key] = {}\n            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid\n            # productions of that non-terminal.  
We'll first split those productions by global vs.\n            # linked action.\n\n            action_indices = [action_map[action_string] for action_string in action_strings]\n            production_rule_arrays = [(possible_actions[index], index) for index in action_indices]\n            global_actions = []\n            linked_actions = []\n            for production_rule_array, action_index in production_rule_arrays:\n                if production_rule_array[1]:\n                    global_actions.append((production_rule_array[2], action_index))\n                else:\n                    linked_actions.append((production_rule_array[0], action_index))\n\n            if global_actions:\n                global_action_tensors, global_action_ids = zip(*global_actions)\n                global_action_tensor = torch.cat(global_action_tensors, dim=0).to(\n                    global_action_tensors[0].device).long()\n                global_input_embeddings = self._action_embedder(global_action_tensor)\n                global_output_embeddings = self._output_action_embedder(global_action_tensor)\n                translated_valid_actions[key]['global'] = (global_input_embeddings,\n                                                           global_output_embeddings,\n                                                           list(global_action_ids))\n            if linked_actions:\n                linked_rules, linked_action_ids = zip(*linked_actions)\n                entities = [rule.split(' -> ')[1].strip('[]\\\"') for rule in linked_rules]\n\n                entity_ids = [entity_map[entity] for entity in entities]\n\n                entity_linking_scores = linking_scores[entity_ids]\n\n                if linked_actions_linking_scores is not None:\n                    entity_action_linking_scores = linked_actions_linking_scores[entity_ids]\n\n                if not self._decoder_use_graph_entities:\n                    entity_type_tensor = entity_types[entity_ids]\n          
          entity_type_embeddings = (self._entity_type_decoder_embedding(entity_type_tensor)\n                                              .to(entity_types.device)\n                                              .float())\n                else:\n                    entity_type_embeddings = entity_graph_encoding.index_select(\n                        dim=0,\n                        index=torch.tensor(entity_ids, device=entity_graph_encoding.device)\n                    )\n\n                if self._self_attend:\n                    translated_valid_actions[key]['linked'] = (entity_linking_scores,\n                                                               entity_type_embeddings,\n                                                               list(linked_action_ids),\n                                                               entity_action_linking_scores)\n                else:\n                    translated_valid_actions[key]['linked'] = (entity_linking_scores,\n                                                               entity_type_embeddings,\n                                                               list(linked_action_ids))\n\n        return GrammarStatelet(['statement'],\n                               translated_valid_actions,\n                               self.is_nonterminal)\n\n    @staticmethod\n    def is_nonterminal(token: str):\n        if token[0] == '\"' and token[-1] == '\"':\n            return False\n        return True\n\n    def _get_linking_probabilities(self,\n                                   worlds: List[SpiderWorld],\n                                   linking_scores: torch.FloatTensor,\n                                   question_mask: torch.LongTensor,\n                                   entity_type_dict: Dict[int, int]) -> torch.FloatTensor:\n        \"\"\"\n        Produces the probability of an entity given a question word and type. 
The logic below\n        separates the entities by type since the softmax normalization term sums over entities\n        of a single type.\n\n        Parameters\n        ----------\n        worlds : ``List[SpiderWorld]``\n        linking_scores : ``torch.FloatTensor``\n            Has shape (batch_size, num_question_tokens, num_entities).\n        question_mask: ``torch.LongTensor``\n            Has shape (batch_size, num_question_tokens).\n        entity_type_dict : ``Dict[int, int]``\n            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.\n\n        Returns\n        -------\n        batch_probabilities : ``torch.FloatTensor``\n            Has shape ``(batch_size, num_question_tokens, num_entities)``.\n            Contains all the probabilities for an entity given a question word.\n        \"\"\"\n        _, num_question_tokens, num_entities = linking_scores.size()\n        batch_probabilities = []\n\n        for batch_index, world in enumerate(worlds):\n            all_probabilities = []\n            num_entities_in_instance = 0\n\n            # NOTE: The way that we're doing this here relies on the fact that entities are\n            # implicitly sorted by their types when we sort them by name: the column entities\n            # (\"column:boolean:\" through \"column:time:\") come first, in the same order as their\n            # type ids in ``_get_type_vector``, followed by \"string:\" and then \"table:\".\n            # This is not a great assumption, and could easily break later, but it should work for now.\n            for type_index in range(self._num_entity_types):\n                # This index of 0 is for the null entity for each type, representing the case where a\n                # word doesn't link to any entity.\n                entity_indices = [0]\n                entities = world.db_context.knowledge_graph.entities\n                for entity_index, _ in enumerate(entities):\n                    if entity_type_dict[batch_index * num_entities + entity_index] == type_index:\n            
            entity_indices.append(entity_index)\n\n                if len(entity_indices) == 1:\n                    # No entities of this type; move along...\n                    continue\n\n                # We're subtracting one here because of the null entity we added above.\n                num_entities_in_instance += len(entity_indices) - 1\n\n                # We separate the scores by type, since normalization is done per type.  There's an\n                # extra \"null\" entity per type, also, so we have `num_entities_per_type + 1`.  We're\n                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,\n                # so we get back something of shape (num_question_tokens,) for each index we're\n                # selecting.  All of the selected indices together then make a tensor of shape\n                # (num_question_tokens, num_entities_per_type + 1).\n                indices = linking_scores.new_tensor(entity_indices, dtype=torch.long)\n                entity_scores = linking_scores[batch_index].index_select(1, indices)\n\n                # We used index 0 for the null entity, so this will actually have some values in it.\n                # But we want the null entity's score to be 0, so we set that here.\n                entity_scores[:, 0] = 0\n\n                # No need for a mask here, as this is done per batch instance, with no padding.\n                type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1)\n                all_probabilities.append(type_probabilities[:, 1:])\n\n            # We need to add padding here if we don't have the right number of entities.\n            if num_entities_in_instance != num_entities:\n                zeros = linking_scores.new_zeros(num_question_tokens,\n                                                 num_entities - num_entities_in_instance)\n                all_probabilities.append(zeros)\n\n            # (num_question_tokens, num_entities)\n   
         probabilities = torch.cat(all_probabilities, dim=1)\n            batch_probabilities.append(probabilities)\n        batch_probabilities = torch.stack(batch_probabilities, dim=0)\n        return batch_probabilities * question_mask.unsqueeze(-1).float()\n\n    @staticmethod\n    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:\n        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.\n        # Check if target is big enough to cover prediction (including start/end symbols)\n        if len(predicted) > targets.size(0):\n            return 0\n        predicted_tensor = targets.new_tensor(predicted)\n        targets_trimmed = targets[:len(predicted)]\n        # Return 1 if the predicted sequence is anywhere in the list of targets.\n        return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=0)[0]).item()\n\n    @staticmethod\n    def _query_difficulty(targets: torch.LongTensor, action_mapping, batch_index):\n        number_tables = len([action_mapping[(batch_index, int(a))] for a in targets if\n                             a >= 0 and action_mapping[(batch_index, int(a))].startswith('table_name')])\n        return number_tables > 1\n\n    @overrides\n    def get_metrics(self, reset: bool = False) -> Dict[str, float]:\n        return {\n            '_match/exact_match': self._exact_match.get_metric(reset),\n            'sql_match': self._sql_evaluator_match.get_metric(reset),\n            '_others/action_similarity': self._action_similarity.get_metric(reset),\n            '_match/match_single': self._acc_single.get_metric(reset),\n            '_match/match_hard': self._acc_multi.get_metric(reset),\n            'beam_hit': self._beam_hit.get_metric(reset)\n        }\n\n    @staticmethod\n    def _get_type_vector(worlds: List[SpiderWorld],\n                         num_entities: int,\n                         device) -> Tuple[torch.LongTensor, Dict[int, int]]:\n        
\"\"\"\n        Produces the encoding for each entity's type. In addition, a map from a flattened entity\n        index to type is returned to combine entity type operations into one method.\n\n        Parameters\n        ----------\n        worlds : ``List[SpiderWorld]``\n        num_entities : ``int``\n        device : ``torch.device``\n            The device on which the returned tensor is created.\n\n        Returns\n        -------\n        A ``torch.LongTensor`` with shape ``(batch_size, num_entities)`` holding each entity's type id\n        (padding positions get type 0).\n        entity_types : ``Dict[int, int]``\n            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.\n        \"\"\"\n        entity_types = {}\n        batch_types = []\n\n        # Type ids 0-6 are the column types below, 7 is a cell value (\"string\") and 8 is a table.\n        column_type_ids = ['boolean', 'foreign', 'number', 'others', 'primary', 'text', 'time']\n\n        for batch_index, world in enumerate(worlds):\n            types = []\n\n            for entity_index, entity in enumerate(world.db_context.knowledge_graph.entities):\n                parts = entity.split(':')\n                entity_main_type = parts[0]\n                if entity_main_type == 'column':\n                    column_type = parts[1]\n                    entity_type = column_type_ids.index(column_type)\n                elif entity_main_type == 'string':\n                    # cell value\n                    entity_type = len(column_type_ids)\n                elif entity_main_type == 'table':\n                    entity_type = len(column_type_ids) + 1\n                else:\n                    raise ValueError(f\"Unknown entity type: {entity}\")\n                types.append(entity_type)\n\n                # For easier lookups later, we're actually using a _flattened_ version\n                # of (batch_index, entity_index) for the key, because this is how the\n                # linking scores are stored.\n                flattened_entity_index = batch_index * num_entities + entity_index\n                
entity_types[flattened_entity_index] = entity_type\n            padded = pad_sequence_to_length(types, num_entities, lambda: 0)\n            batch_types.append(padded)\n\n        return torch.tensor(batch_types, dtype=torch.long, device=device), entity_types\n\n    def _compute_validation_outputs(self,\n                                    actions: List[List[ProductionRuleArray]],\n                                    best_final_states: Mapping[int, Sequence[GrammarBasedState]],\n                                    world: List[SpiderWorld],\n                                    target_list: List[List[str]],\n                                    outputs: Dict[str, Any]) -> None:\n        batch_size = len(actions)\n\n        outputs['predicted_sql_query'] = []\n\n        action_mapping = {}\n        for batch_index, batch_actions in enumerate(actions):\n            for action_index, action in enumerate(batch_actions):\n                action_mapping[(batch_index, action_index)] = action[0]\n\n        for i in range(batch_size):\n            # gold sql exactly as given\n            original_gold_sql_query = ' '.join(world[i].get_query_without_table_hints())\n\n            if i not in best_final_states:\n                self._exact_match(0)\n                self._action_similarity(0)\n                self._sql_evaluator_match(0)\n                self._acc_multi(0)\n                self._acc_single(0)\n                outputs['predicted_sql_query'].append('')\n                continue\n\n            best_action_indices = best_final_states[i][0].action_history[0]\n\n            action_strings = [action_mapping[(i, action_index)]\n                              for action_index in best_action_indices]\n            predicted_sql_query = action_sequence_to_sql(action_strings, add_table_names=True)\n            outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=False))\n\n            if target_list is not None:\n                targets = 
target_list[i].data\n            target_available = target_list is not None and targets[0] > -1\n\n            if target_available:\n                sequence_in_targets = self._action_history_match(best_action_indices, targets)\n                self._exact_match(sequence_in_targets)\n\n                sql_evaluator_match = self._evaluate_func(original_gold_sql_query, predicted_sql_query, world[i].db_id)\n                self._sql_evaluator_match(sql_evaluator_match)\n\n                similarity = difflib.SequenceMatcher(None, best_action_indices, targets)\n                self._action_similarity(similarity.ratio())\n\n                difficulty = self._query_difficulty(targets, action_mapping, i)\n                if difficulty:\n                    self._acc_multi(sql_evaluator_match)\n                else:\n                    self._acc_single(sql_evaluator_match)\n\n            beam_hit = False\n            for pos, final_state in enumerate(best_final_states[i]):\n                action_indices = final_state.action_history[0]\n                action_strings = [action_mapping[(i, action_index)]\n                                  for action_index in action_indices]\n                candidate_sql_query = action_sequence_to_sql(action_strings, add_table_names=True)\n\n                if target_available:\n                    correct = self._evaluate_func(original_gold_sql_query, candidate_sql_query, world[i].db_id)\n                    if correct:\n                        beam_hit = True\n                    self._beam_hit(beam_hit)\n"
  },
  {
    "path": "modules/gated_graph_conv.py",
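    "content": "import torch\nfrom torch import Tensor\nfrom torch.nn import Parameter as Param, init\nfrom torch_geometric.data import Data\nfrom torch_geometric.nn.conv import MessagePassing\n\n\nclass GatedGraphConv(MessagePassing):\n    r\"\"\"The gated graph convolution operator from the `\"Gated Graph Sequence\n    Neural Networks\" <https://arxiv.org/abs/1511.05493>`_ paper. In this\n    implementation, every timestep and edge type has its own weights and biases\n    (shown here for the default :obj:`\"add\"` aggregation):\n\n    .. math::\n\n        \\mathbf{h}_i^{(0)} &= \\mathbf{x}_i \\, \\Vert \\, \\mathbf{0} \\\\\n        \\mathbf{m}_i^{(t+1)} &= \\sum_{e} \\sum_{j \\in \\mathcal{N}_e(i)}\n        \\left( \\mathbf{\\Theta}_{t,e} \\cdot \\mathbf{h}_j^{(t)} + \\mathbf{b}_{t,e} \\right) \\\\\n        \\mathbf{h}_i^{(t+1)} &= \\textrm{GRU} (\\mathbf{m}_i^{(t+1)},\n        \\mathbf{h}_i^{(t)})\n\n    up to representation :math:`\\mathbf{h}_i^{(T)}`, where :math:`T` is\n    :obj:`num_timesteps` and dropout is applied to each transformed state before aggregation.\n    The number of input channels of :math:`\\mathbf{x}_i` needs to be less than\n    or equal to :obj:`input_dim`; smaller inputs are zero-padded.\n    Args:\n        input_dim (int): Size of each node representation.\n        num_timesteps (int): The number of propagation steps :math:`T`.\n        num_edge_types (int): The number of edge types; each type has its own\n            weights and biases at every timestep.\n        aggr (string): The aggregation scheme to use\n            (:obj:`\"add\"`, :obj:`\"mean\"`, :obj:`\"max\"`).\n            (default: :obj:`\"add\"`)\n        dropout (float, optional): Dropout probability applied to the messages.\n            (default: :obj:`0`)\n        bias (bool, optional): If set to :obj:`False`, the GRU cell will not learn\n            an additive bias (the message biases are always learned). 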
(default: :obj:`True`)\n    \"\"\"\n\n    def __init__(self, input_dim, num_timesteps, num_edge_types, aggr='add', bias=True, dropout=0):\n        super(GatedGraphConv, self).__init__(aggr)\n\n        self._input_dim = input_dim\n        self.num_timesteps = num_timesteps\n        self.num_edge_types = num_edge_types\n\n        self.weight = Param(Tensor(num_timesteps, num_edge_types, input_dim, input_dim))\n        self.bias = Param(Tensor(num_timesteps, num_edge_types, input_dim))\n        self.rnn = torch.nn.GRUCell(input_dim, input_dim, bias=bias)\n        self.dropout = torch.nn.Dropout(dropout)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for t in range(self.num_timesteps):\n            for e in range(self.num_edge_types):\n                torch.nn.init.xavier_uniform_(self.weight[t, e])\n        init.uniform_(self.bias, -0.01, 0.01)\n        self.rnn.reset_parameters()\n\n    def forward(self, x, edge_indices):\n        \"\"\"\"\"\"\n        if len(edge_indices) != self.num_edge_types:\n            raise ValueError(f'GatedGraphConv constructed with {self.num_edge_types} edge types, '\n                             f'but {len(edge_indices)} were passed')\n\n        h = x\n\n        if h.size(1) > self._input_dim:\n            raise ValueError('The number of input channels is not allowed to '\n                             'be larger than the number of output channels')\n\n        if h.size(1) < self._input_dim:\n            zero = h.new_zeros(h.size(0), self._input_dim - h.size(1))\n            h = torch.cat([h, zero], dim=1)\n\n        for t in range(self.num_timesteps):\n            new_h = []\n            for e in range(self.num_edge_types):\n                if len(edge_indices[e]) == 0:\n                    continue\n                m = self.dropout(torch.matmul(h, self.weight[t, e]) + self.bias[t, e])\n                new_h.append(self.propagate(edge_indices[e], size=(x.size(0), x.size(0)), x=m))\n            m_sum = 
torch.sum(torch.stack(new_h), dim=0) if new_h else h.new_zeros(h.size())\n            h = self.rnn(m_sum, h)\n\n        return h\n\n    def __repr__(self):\n        return '{}({}, num_timesteps={})'.format(\n            self.__class__.__name__, self._input_dim, self.num_timesteps)\n\n\nif __name__ == '__main__':\n    gcn = GatedGraphConv(input_dim=10, num_timesteps=3, num_edge_types=3)\n    data = Data(torch.zeros((5, 10)), edge_index=[\n        torch.tensor([[1,2],[2,3]]),\n        torch.tensor([[1,3],[0,1]]),\n        torch.tensor([[1,4],[2,3]]),\n    ])\n    output = gcn(data.x, data.edge_index)\n\n    print(output)\n"
  },
  {
    "path": "predictors/spider_predictor.py",
    "content": "from overrides import overrides\n\nfrom allennlp.common.util import JsonDict, sanitize\nfrom allennlp.data import DatasetReader, Instance\nfrom allennlp.models import Model\nfrom allennlp.predictors.predictor import Predictor\n\n\n@Predictor.register(\"spider\")\nclass SpiderPredictor(Predictor):\n    def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:\n        super().__init__(model, dataset_reader)\n\n    @overrides\n    def predict_instance(self, instance: Instance) -> JsonDict:\n        json_output = {}\n        outputs = self._model.forward_on_instance(instance)\n        predicted_sql_query = outputs['predicted_sql_query'].replace('\\n', ' ')\n        if predicted_sql_query == '':\n            # the line must not be empty for the evaluator to consider it\n            predicted_sql_query = 'NO PREDICTION'\n        json_output['predicted_sql_query'] = predicted_sql_query\n        return sanitize(json_output)\n\n    @overrides\n    def dump_line(self, outputs: JsonDict) -> str:  # pylint: disable=no-self-use\n        \"\"\"\n        Writes one predicted SQL query per line instead of the default JSON-lines\n        format, so that the output file can be passed directly to the evaluator.\n        \"\"\"\n        return outputs['predicted_sql_query'] + \"\\n\"\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch==1.0.1.post2\nspacy==2.0.18\nallennlp==0.8.2\ndill\ntorch-scatter==1.1.2\ntorch-sparse==0.2.4\ntorch-cluster==1.2.4\ntorch-geometric==1.1.2\nordered_set\nhttps://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.0.0/en_core_web_sm-2.0.0.tar.gz"
  },
  {
    "path": "semparse/contexts/spider_context_utils.py",
    "content": "import re\nfrom collections import defaultdict\nfrom sys import exc_info\nfrom typing import List, Dict, Set\n\nfrom overrides import overrides\nfrom parsimonious.exceptions import VisitationError, UndefinedLabel\nfrom parsimonious.expressions import Literal, OneOf, Sequence\nfrom parsimonious.grammar import Grammar\nfrom parsimonious.nodes import Node, NodeVisitor\nfrom six import reraise\n\nWHITESPACE_REGEX = re.compile(\" wsp |wsp | wsp| ws |ws | ws\")\n\n\ndef format_grammar_string(grammar_dictionary: Dict[str, List[str]]) -> str:\n    \"\"\"\n    Formats a dictionary of production rules into the string format expected\n    by the Parsimonious Grammar class.\n    \"\"\"\n    return '\\n'.join([f\"{nonterminal} = {' / '.join(right_hand_side)}\"\n                      for nonterminal, right_hand_side in grammar_dictionary.items()])\n\n\ndef initialize_valid_actions(grammar: Grammar,\n                             keywords_to_uppercase: List[str] = None) -> Dict[str, List[str]]:\n    \"\"\"\n    We initialize the valid actions with the global actions. These include the\n    valid actions that result from the grammar and also those that result from\n    the tables provided. The keys represent the nonterminals in the grammar\n    and the values are lists of the valid actions of that nonterminal.\n    \"\"\"\n    valid_actions: Dict[str, Set[str]] = defaultdict(set)\n\n    for key in grammar:\n        rhs = grammar[key]\n\n        # Sequence represents a series of expressions that match pieces of the text in order.\n        # Eg. A -> B C\n        if isinstance(rhs, Sequence):\n            valid_actions[key].add(\n                format_action(key, \" \".join(rhs._unicode_members()),  # pylint: disable=protected-access\n                              keywords_to_uppercase=keywords_to_uppercase))\n\n        # OneOf represents a series of expressions, one of which matches the text.\n        # Eg. 
A -> B / C\n        elif isinstance(rhs, OneOf):\n            for option in rhs._unicode_members():  # pylint: disable=protected-access\n                valid_actions[key].add(format_action(key, option,\n                                                     keywords_to_uppercase=keywords_to_uppercase))\n\n        # A string literal, eg. \"A\"\n        elif isinstance(rhs, Literal):\n            if rhs.literal != \"\":\n                valid_actions[key].add(format_action(key, repr(rhs.literal),\n                                                     keywords_to_uppercase=keywords_to_uppercase))\n            else:\n                valid_actions[key] = set()\n\n    valid_action_strings = {key: sorted(value) for key, value in valid_actions.items()}\n    return valid_action_strings\n\n\ndef format_action(nonterminal: str,\n                  right_hand_side: str,\n                  is_string: bool = False,\n                  is_number: bool = False,\n                  keywords_to_uppercase: List[str] = None) -> str:\n    \"\"\"\n    This function formats an action as it appears in models. 
It\n    splits productions based on the special `ws` and `wsp` rules,\n    which are used in grammars to denote whitespace, and then\n    rejoins these tokens into a formatted, comma-separated list.\n    Importantly, note that it `does not` split on spaces in\n    the grammar string, because these might not correspond\n    to spaces in the language the grammar recognises.\n\n    Parameters\n    ----------\n    nonterminal : ``str``, required.\n        The nonterminal in the action.\n    right_hand_side : ``str``, required.\n        The right hand side of the action\n        (i.e. the thing which is produced).\n    is_string : ``bool``, optional (default = False).\n        Whether the production produces a string.\n        If it does, it is formatted as ``nonterminal -> [\"'string'\"]``\n    is_number : ``bool``, optional, (default = False).\n        Whether the production produces a number.\n        If it does, it is formatted as ``nonterminal -> [\"number\"]``\n    keywords_to_uppercase: ``List[str]``, optional, (default = None)\n        Keywords in the grammar to uppercase. 
In the case of sql,\n        this might be SELECT, MAX etc.\n    \"\"\"\n    keywords_to_uppercase = keywords_to_uppercase or []\n    if right_hand_side.upper() in keywords_to_uppercase:\n        right_hand_side = right_hand_side.upper()\n\n    if is_string:\n        return f'{nonterminal} -> [\"\\'{right_hand_side}\\'\"]'\n\n    elif is_number:\n        return f'{nonterminal} -> [\"{right_hand_side}\"]'\n\n    else:\n        right_hand_side = right_hand_side.lstrip(\"(\").rstrip(\")\")\n        child_strings = [token for token in WHITESPACE_REGEX.split(right_hand_side) if token]\n        child_strings = [tok.upper() if tok.upper() in keywords_to_uppercase else tok for tok in child_strings]\n        return f\"{nonterminal} -> [{', '.join(child_strings)}]\"\n\n\ndef action_sequence_to_sql(action_sequences: List[str], add_table_names: bool=False) -> str:\n    # Convert an action sequence like ['statement -> [query, \";\"]', ...] to the\n    # SQL string.\n    query = []\n    for action in action_sequences:\n        nonterminal, right_hand_side = action.split(' -> ')\n        right_hand_side_tokens = right_hand_side[1:-1].split(', ')\n        if nonterminal == 'statement':\n            query.extend(right_hand_side_tokens)\n        else:\n            for query_index, token in list(enumerate(query)):\n                if token == nonterminal:\n                    if nonterminal == 'column_name' and '@' in right_hand_side_tokens[0] and len(right_hand_side_tokens) == 1:\n                        if add_table_names:\n                            table_name, column_name = right_hand_side_tokens[0].split('@')\n                            if '.' in table_name:\n                                table_name = table_name.split('.')[0]\n                            right_hand_side_tokens = [table_name + '.' 
+ column_name]\n                        else:\n                            right_hand_side_tokens = [right_hand_side_tokens[0].split('@')[-1]]\n                    query = query[:query_index] + \\\n                            right_hand_side_tokens + \\\n                            query[query_index + 1:]\n                    break\n    return ' '.join([token.strip('\"') for token in query])\n\n\nclass SqlVisitor(NodeVisitor):\n    \"\"\"\n    ``SqlVisitor`` performs a depth-first traversal of the AST. It takes the parse tree\n    and gives us the action sequence that produced that parse. Since the visitor has mutable\n    state, we define a new ``SqlVisitor`` for each query. To get the action sequence, we create\n    a ``SqlVisitor`` and call ``parse`` on it, which returns a list of actions. For example::\n\n        sql_visitor = SqlVisitor(Grammar(grammar_string))\n        action_sequence = sql_visitor.parse(query)\n\n    Importantly, this ``SqlVisitor`` skips over ``ws`` and ``wsp`` nodes,\n    because they carry no meaning and would make an action sequence\n    much longer than it needs to be.\n\n    Parameters\n    ----------\n    grammar : ``Grammar``\n        A Grammar object that we use to parse the text.\n    keywords_to_uppercase: ``List[str]``, optional, (default = None)\n        Keywords in the grammar to uppercase. 
In the case of sql,\n        this might be SELECT, MAX etc.\n    \"\"\"\n\n    def __init__(self, grammar: Grammar, keywords_to_uppercase: List[str] = None) -> None:\n        self.action_sequence: List[str] = []\n        self.grammar: Grammar = grammar\n        self.keywords_to_uppercase = keywords_to_uppercase or []\n\n    @overrides\n    def generic_visit(self, node: Node, visited_children: List[None]) -> List[str]:\n        self.add_action(node)\n        if node.expr.name == 'statement':\n            return self.action_sequence\n        return []\n\n    def add_action(self, node: Node) -> None:\n        \"\"\"\n        For each node, we accumulate the rules that generated its children in a list.\n        \"\"\"\n        if node.expr.name and node.expr.name not in ['ws', 'wsp']:\n            nonterminal = f'{node.expr.name} -> '\n\n            if isinstance(node.expr, Literal):\n                right_hand_side = f'[\"{node.text}\"]'\n\n            else:\n                child_strings = []\n                for child in node.__iter__():\n                    if child.expr.name in ['ws', 'wsp']:\n                        continue\n                    if child.expr.name != '':\n                        child_strings.append(child.expr.name)\n                    else:\n                        child_right_side_string = child.expr._as_rhs().lstrip(\"(\").rstrip(\n                            \")\")  # pylint: disable=protected-access\n                        child_right_side_list = [tok for tok in\n                                                 WHITESPACE_REGEX.split(child_right_side_string) if tok]\n                        child_right_side_list = [tok.upper() if tok.upper() in\n                                                                self.keywords_to_uppercase else tok\n                                                 for tok in child_right_side_list]\n                        child_strings.extend(child_right_side_list)\n                right_hand_side = \"[\" + \", 
\".join(child_strings) + \"]\"\n            rule = nonterminal + right_hand_side\n            self.action_sequence = [rule] + self.action_sequence\n\n    @overrides\n    def visit(self, node):\n        \"\"\"\n        See the ``NodeVisitor`` visit method. Unlike the default implementation, this visits\n        each node's children from right to left. Because ``add_action`` prepends every rule to\n        ``self.action_sequence``, the resulting sequence lists the rules top-down, left to right.\n        \"\"\"\n        method = getattr(self, 'visit_' + node.expr_name, self.generic_visit)\n\n        # Call that method, and show where in the tree it failed if it blows\n        # up.\n        try:\n            # Visit the children in reverse (right-to-left) order.\n            return method(node, [self.visit(child) for child in reversed(list(node))])\n        except (VisitationError, UndefinedLabel):\n            # Don't catch and re-wrap already-wrapped exceptions.\n            raise\n        except self.unwrapped_exceptions:\n            raise\n        except Exception:  # pylint: disable=broad-except\n            # Catch any exception, and tack on a parse tree so it's easier to\n            # see where it went wrong.\n            exc_class, exc, traceback = exc_info()\n            reraise(VisitationError, VisitationError(exc, exc_class, node), traceback)\n"
  },
  {
    "path": "semparse/contexts/spider_db_context.py",
    "content": "import re\nfrom collections import defaultdict\nfrom typing import Dict, List, Set, Tuple\n\nfrom allennlp.data import Tokenizer, Token\nfrom ordered_set import OrderedSet\nfrom unidecode import unidecode\n\nfrom dataset_readers.dataset_util.spider_utils import TableColumn, read_dataset_schema, read_dataset_values\nfrom allennlp.semparse.contexts.knowledge_graph import KnowledgeGraph\n\n\n# Stop words that are skipped when linking question tokens to cell values\nSTOP_WORDS = {\"\", \"all\", \"being\", \"-\", \"over\", \"through\", \"yourselves\", \"its\", \"before\",\n              \"hadn\", \"with\", \"had\", \",\", \"should\", \"to\", \"only\", \"under\", \"ours\", \"has\", \"ought\", \"do\",\n              \"them\", \"his\", \"than\", \"very\", \"cannot\", \"they\", \"not\", \"during\", \"yourself\", \"him\",\n              \"nor\", \"did\", \"didn\", \"'ve\", \"this\", \"she\", \"each\", \"where\", \"because\", \"doing\", \"some\", \"we\", \"are\",\n              \"further\", \"ourselves\", \"out\", \"what\", \"for\", \"weren\", \"does\", \"above\", \"between\", \"mustn\", \"?\",\n              \"be\", \"hasn\", \"who\", \"were\", \"here\", \"shouldn\", \"let\", \"hers\", \"by\", \"both\", \"about\", \"couldn\",\n              \"of\", \"could\", \"against\", \"isn\", \"or\", \"own\", \"into\", \"while\", \"whom\", \"down\", \"wasn\", \"your\",\n              \"from\", \"her\", \"their\", \"aren\", \"there\", \"been\", \".\", \"few\", \"too\", \"wouldn\", \"themselves\",\n              \":\", \"was\", \"until\", \"more\", \"himself\", \"on\", \"but\", \"don\", \"herself\", \"haven\", \"those\", \"he\",\n              \"me\", \"myself\", \"these\", \"up\", \";\", \"below\", \"'re\", \"can\", \"theirs\", \"my\", \"and\", \"would\", \"then\",\n              \"is\", \"am\", \"it\", \"doesn\", \"an\", \"as\", \"itself\", \"at\", \"have\", \"in\", \"any\", \"if\", \"!\",\n              \"again\", \"'ll\", \"no\", \"that\", \"when\", \"same\", \"how\", \"other\", 
\"which\", \"you\", \"many\", \"shan\",\n              \"'t\", \"'s\", \"our\", \"after\", \"most\", \"'d\", \"such\", \"'m\", \"why\", \"a\", \"off\", \"i\", \"yours\", \"so\",\n              \"the\", \"having\", \"once\"}\n\n\nclass SpiderDBContext:\n    schemas = {}\n    db_knowledge_graph = {}\n    db_tables_data = {}\n\n    def __init__(self, db_id: str, utterance: str, tokenizer: Tokenizer, tables_file: str, dataset_path: str):\n        self.dataset_path = dataset_path\n        self.tables_file = tables_file\n        self.db_id = db_id\n        self.utterance = utterance\n\n        tokenized_utterance = tokenizer.tokenize(utterance.lower())\n        self.tokenized_utterance = [Token(text=t.text, lemma=t.lemma_) for t in tokenized_utterance]\n\n        if db_id not in SpiderDBContext.schemas:\n            SpiderDBContext.schemas = read_dataset_schema(self.tables_file)\n        self.schema = SpiderDBContext.schemas[db_id]\n\n        self.knowledge_graph = self.get_db_knowledge_graph(db_id)\n\n        entity_texts = [self.knowledge_graph.entity_text[entity].lower()\n                        for entity in self.knowledge_graph.entities]\n        entity_tokens = tokenizer.batch_tokenize(entity_texts)\n        self.entity_tokens = [[Token(text=t.text, lemma=t.lemma_) for t in et] for et in entity_tokens]\n\n    @staticmethod\n    def entity_key_for_column(table_name: str, column: TableColumn) -> str:\n        if column.foreign_key is not None:\n            column_type = \"foreign\"\n        elif column.is_primary_key:\n            column_type = \"primary\"\n        else:\n            column_type = column.column_type\n        return f\"column:{column_type.lower()}:{table_name.lower()}:{column.name.lower()}\"\n\n    def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph:\n        entities: Set[str] = set()\n        neighbors: Dict[str, OrderedSet[str]] = defaultdict(OrderedSet)\n        entity_text: Dict[str, str] = {}\n        foreign_keys_to_column: Dict[str, 
str] = {}\n\n        db_schema = self.schema\n        tables = db_schema.values()\n\n        if db_id not in self.db_tables_data:\n            self.db_tables_data[db_id] = read_dataset_values(db_id, self.dataset_path, tables)\n\n        tables_data = self.db_tables_data[db_id]\n\n        string_column_mapping: Dict[str, set] = defaultdict(set)\n\n        for table, table_data in tables_data.items():\n            for table_row in table_data:\n                for column, cell_value in zip(db_schema[table.name].columns, table_row):\n                    if column.column_type == 'text' and type(cell_value) is str:\n                        cell_value_normalized = self.normalize_string(cell_value)\n                        column_key = self.entity_key_for_column(table.name, column)\n                        string_column_mapping[cell_value_normalized].add(column_key)\n\n        string_entities = self.get_entities_from_question(string_column_mapping)\n\n        for table in tables:\n            table_key = f\"table:{table.name.lower()}\"\n            entities.add(table_key)\n            entity_text[table_key] = table.text\n\n            for column in db_schema[table.name].columns:\n                entity_key = self.entity_key_for_column(table.name, column)\n                entities.add(entity_key)\n                neighbors[entity_key].add(table_key)\n                neighbors[table_key].add(entity_key)\n                entity_text[entity_key] = column.text\n\n        for string_entity, column_keys in string_entities:\n            entities.add(string_entity)\n            for column_key in column_keys:\n                neighbors[string_entity].add(column_key)\n                neighbors[column_key].add(string_entity)\n            entity_text[string_entity] = string_entity.replace(\"string:\", \"\").replace(\"_\", \" \")\n\n        # loop again after we have gone through all columns to link foreign keys columns\n        for table_name in db_schema.keys():\n            for 
column in db_schema[table_name].columns:\n                if column.foreign_key is None:\n                    continue\n\n                other_column_table, other_column_name = column.foreign_key.split(':')\n\n                # must have exactly one by design\n                other_column = [col for col in db_schema[other_column_table].columns if col.name == other_column_name][0]\n\n                entity_key = self.entity_key_for_column(table_name, column)\n                other_entity_key = self.entity_key_for_column(other_column_table, other_column)\n\n                neighbors[entity_key].add(other_entity_key)\n                neighbors[other_entity_key].add(entity_key)\n\n                foreign_keys_to_column[entity_key] = other_entity_key\n\n        kg = KnowledgeGraph(entities, dict(neighbors), entity_text)\n        kg.foreign_keys_to_column = foreign_keys_to_column\n\n        return kg\n\n    def _string_in_table(self, candidate: str,\n                         string_column_mapping: Dict[str, set]) -> List[str]:\n        \"\"\"\n        Checks if the string occurs in the table, and if it does, returns the names of the columns\n        under which it occurs. 
If it does not, returns an empty list.\n        \"\"\"\n        candidate_column_names: List[str] = []\n        # First check if the entire candidate occurs as a cell.\n        if candidate in string_column_mapping:\n            candidate_column_names = list(string_column_mapping[candidate])\n        # If not, check if it is a substring of any cell value.\n        if not candidate_column_names:\n            for cell_value, column_names in string_column_mapping.items():\n                if candidate in cell_value:\n                    candidate_column_names.extend(column_names)\n        candidate_column_names = list(set(candidate_column_names))\n        return candidate_column_names\n\n    def get_entities_from_question(self,\n                                   string_column_mapping: Dict[str, set]) -> List[Tuple[str, List[str]]]:\n        entity_data = []\n        for i, token in enumerate(self.tokenized_utterance):\n            token_text = token.text\n            if token_text in STOP_WORDS:\n                continue\n            normalized_token_text = self.normalize_string(token_text)\n            if not normalized_token_text:\n                continue\n            token_columns = self._string_in_table(normalized_token_text, string_column_mapping)\n            if token_columns:\n                token_type = token_columns[0].split(\":\")[1]\n                entity_data.append({'value': normalized_token_text,\n                                    'token_start': i,\n                                    'token_end': i+1,\n                                    'token_type': token_type,\n                                    'token_in_columns': token_columns})\n\n        # extracted_numbers = self._get_numbers_from_tokens(self.question_tokens)\n        # filter out number entities to avoid repetition\n        expanded_entities = []\n        for entity in self._expand_entities(self.tokenized_utterance, entity_data, string_column_mapping):\n            if entity[\"token_type\"] == 
\"text\":\n                expanded_entities.append((f\"string:{entity['value']}\", entity['token_in_columns']))\n        # return expanded_entities, extracted_numbers  #TODO(shikhar) Handle conjunctions\n\n        return expanded_entities\n\n    @staticmethod\n    def normalize_string(string: str) -> str:\n        \"\"\"\n        These are the transformation rules used to normalize cell in column names in Sempre.  See\n        ``edu.stanford.nlp.sempre.tables.StringNormalizationUtils.characterNormalize`` and\n        ``edu.stanford.nlp.sempre.tables.TableTypeSystem.canonicalizeName``.  We reproduce those\n        rules here to normalize and canonicalize cells and columns in the same way so that we can\n        match them against constants in logical forms appropriately.\n        \"\"\"\n        # Normalization rules from Sempre\n        # \\u201A -> ,\n        string = re.sub(\"‚\", \",\", string)\n        string = re.sub(\"„\", \",,\", string)\n        string = re.sub(\"[·・]\", \".\", string)\n        string = re.sub(\"…\", \"...\", string)\n        string = re.sub(\"ˆ\", \"^\", string)\n        string = re.sub(\"˜\", \"~\", string)\n        string = re.sub(\"‹\", \"<\", string)\n        string = re.sub(\"›\", \">\", string)\n        string = re.sub(\"[‘’´`]\", \"'\", string)\n        string = re.sub(\"[“”«»]\", \"\\\"\", string)\n        string = re.sub(\"[•†‡²³]\", \"\", string)\n        string = re.sub(\"[‐‑–—−]\", \"-\", string)\n        # Oddly, some unicode characters get converted to _ instead of being stripped.  Not really\n        # sure how sempre decides what to do with these...  TODO(mattg): can we just get rid of the\n        # need for this function somehow?  It's causing a whole lot of headaches.\n        string = re.sub(\"[ðø′″€⁄ªΣ]\", \"_\", string)\n        # This is such a mess.  There isn't just a block of unicode that we can strip out, because\n        # sometimes sempre just strips diacritics...  
We'll try stripping out a few separate\n        # blocks, skipping the ones that sempre skips...\n        string = re.sub(\"[\\\\u0180-\\\\u0210]\", \"\", string).strip()\n        string = re.sub(\"[\\\\u0220-\\\\uFFFF]\", \"\", string).strip()\n        string = string.replace(\"\\\\n\", \"_\")\n        string = re.sub(\"\\\\s+\", \" \", string)\n        # Canonicalization rules from Sempre.\n        string = re.sub(\"[^\\\\w]\", \"_\", string)\n        string = re.sub(\"_+\", \"_\", string)\n        string = re.sub(\"_$\", \"\", string)\n        return unidecode(string.lower())\n\n    def _expand_entities(self, question, entity_data, string_column_mapping: Dict[str, set]):\n        new_entities = []\n        for entity in entity_data:\n            # to ensure the same strings are not used over and over\n            if new_entities and entity['token_end'] <= new_entities[-1]['token_end']:\n                continue\n            current_start = entity['token_start']\n            current_end = entity['token_end']\n            current_token = entity['value']\n            current_token_type = entity['token_type']\n            current_token_columns = entity['token_in_columns']\n\n            while current_end < len(question):\n                next_token = question[current_end].text\n                next_token_normalized = self.normalize_string(next_token)\n                if next_token_normalized == \"\":\n                    current_end += 1\n                    continue\n                candidate = \"%s_%s\" %(current_token, next_token_normalized)\n                candidate_columns = self._string_in_table(candidate, string_column_mapping)\n                candidate_columns = list(set(candidate_columns).intersection(current_token_columns))\n                if not candidate_columns:\n                    break\n                candidate_type = candidate_columns[0].split(\":\")[1]\n                if candidate_type != current_token_type:\n                    break\n        
        current_end += 1\n                current_token = candidate\n                current_token_columns = candidate_columns\n\n            new_entities.append({'token_start': current_start,\n                                 'token_end': current_end,\n                                 'value': current_token,\n                                 'token_type': current_token_type,\n                                 'token_in_columns': current_token_columns})\n        return new_entities\n"
  },
  {
    "path": "semparse/contexts/spider_db_grammar.py",
    "content": "# pylint: disable=anomalous-backslash-in-string\n\"\"\"\nA ``Text2SqlTableContext`` represents the SQL context in which an utterance appears\nfor any of the text2sql datasets, with the grammar and the valid actions.\n\"\"\"\nfrom typing import List, Dict\n\nfrom dataset_readers.dataset_util.spider_utils import Table\n\n\nGRAMMAR_DICTIONARY = {}\nGRAMMAR_DICTIONARY[\"statement\"] = ['(query ws iue ws query)', '(query ws)']\nGRAMMAR_DICTIONARY[\"iue\"] = ['\"intersect\"', '\"except\"', '\"union\"']\nGRAMMAR_DICTIONARY[\"query\"] = ['(ws select_core ws groupby_clause ws orderby_clause ws limit)',\n                               '(ws select_core ws groupby_clause ws orderby_clause)',\n                               '(ws select_core ws groupby_clause ws limit)',\n                               '(ws select_core ws orderby_clause ws limit)',\n                               '(ws select_core ws groupby_clause)',\n                               '(ws select_core ws orderby_clause)',\n                               '(ws select_core)']\n\nGRAMMAR_DICTIONARY[\"select_core\"] = ['(select_with_distinct ws select_results ws from_clause ws where_clause)',\n                                     '(select_with_distinct ws select_results ws from_clause)',\n                                     '(select_with_distinct ws select_results ws where_clause)',\n                                     '(select_with_distinct ws select_results)']\nGRAMMAR_DICTIONARY[\"select_with_distinct\"] = ['(ws \"select\" ws \"distinct\")', '(ws \"select\")']\nGRAMMAR_DICTIONARY[\"select_results\"] = ['(ws select_result ws \",\" ws select_results)', '(ws select_result)']\nGRAMMAR_DICTIONARY[\"select_result\"] = ['\"*\"', '(table_source ws \".*\")',\n                                       'expr', 'col_ref']\n\nGRAMMAR_DICTIONARY[\"from_clause\"] = ['(ws \"from\" ws table_source ws join_clauses)',\n                                     '(ws \"from\" ws 
source)']\nGRAMMAR_DICTIONARY[\"join_clauses\"] = ['(join_clause ws join_clauses)', 'join_clause']\nGRAMMAR_DICTIONARY[\"join_clause\"] = ['\"join\" ws table_source ws \"on\" ws join_condition_clause']\nGRAMMAR_DICTIONARY[\"join_condition_clause\"] = ['(join_condition ws \"and\" ws join_condition_clause)', 'join_condition']\nGRAMMAR_DICTIONARY[\"join_condition\"] = ['ws col_ref ws \"=\" ws col_ref']\nGRAMMAR_DICTIONARY[\"source\"] = ['(ws single_source ws \",\" ws source)', '(ws single_source)']\nGRAMMAR_DICTIONARY[\"single_source\"] = ['table_source', 'source_subq']\nGRAMMAR_DICTIONARY[\"source_subq\"] = ['(\"(\" ws query ws \")\")']\n# GRAMMAR_DICTIONARY[\"source_subq\"] = ['(\"(\" ws query ws \")\" ws \"as\" ws name)', '(\"(\" ws query ws \")\")']\nGRAMMAR_DICTIONARY[\"limit\"] = ['(\"limit\" ws non_literal_number)']\n\nGRAMMAR_DICTIONARY[\"where_clause\"] = ['(ws \"where\" wsp expr ws where_conj)', '(ws \"where\" wsp expr)']\nGRAMMAR_DICTIONARY[\"where_conj\"] = ['(ws \"and\" wsp expr ws where_conj)', '(ws \"and\" wsp expr)']\n\nGRAMMAR_DICTIONARY[\"groupby_clause\"] = ['(ws \"group\" ws \"by\" ws group_clause ws \"having\" ws expr)',\n                                        '(ws \"group\" ws \"by\" ws group_clause)']\nGRAMMAR_DICTIONARY[\"group_clause\"] = ['(ws expr ws \",\" ws group_clause)', '(ws expr)']\n\nGRAMMAR_DICTIONARY[\"orderby_clause\"] = ['ws \"order\" ws \"by\" ws order_clause']\nGRAMMAR_DICTIONARY[\"order_clause\"] = ['(ordering_term ws \",\" ws order_clause)', 'ordering_term']\nGRAMMAR_DICTIONARY[\"ordering_term\"] = ['(ws expr ws ordering)', '(ws expr)']\nGRAMMAR_DICTIONARY[\"ordering\"] = ['(ws \"asc\")', '(ws \"desc\")']\n\nGRAMMAR_DICTIONARY[\"col_ref\"] = ['(table_name ws \".\" ws column_name)', 'column_name']\nGRAMMAR_DICTIONARY[\"table_source\"] = ['(table_name ws \"as\" ws table_alias)', 'table_name']\nGRAMMAR_DICTIONARY[\"table_name\"] = [\"table_alias\"]\nGRAMMAR_DICTIONARY[\"table_alias\"] = ['\"t1\"', '\"t2\"', '\"t3\"', 
'\"t4\"']\nGRAMMAR_DICTIONARY[\"column_name\"] = []\n\nGRAMMAR_DICTIONARY[\"ws\"] = ['~\"\\s*\"i']\nGRAMMAR_DICTIONARY['wsp'] = ['~\"\\s+\"i']\n\nGRAMMAR_DICTIONARY[\"expr\"] = ['in_expr',\n                              # Like expressions.\n                              '(value wsp \"like\" wsp string)',\n                              # Between expressions.\n                              '(value ws \"between\" wsp value ws \"and\" wsp value)',\n                              # Binary expressions.\n                              '(value ws binaryop wsp expr)',\n                              # Unary expressions.\n                              '(unaryop ws expr)',\n                              'source_subq',\n                              'value']\nGRAMMAR_DICTIONARY[\"in_expr\"] = ['(value wsp \"not\" wsp \"in\" wsp string_set)',\n                                 '(value wsp \"in\" wsp string_set)',\n                                 '(value wsp \"not\" wsp \"in\" wsp expr)',\n                                 '(value wsp \"in\" wsp expr)']\n\nGRAMMAR_DICTIONARY[\"value\"] = ['parenval', '\"YEAR(CURDATE())\"', 'number', 'boolean',\n                               'function', 'col_ref', 'string']\nGRAMMAR_DICTIONARY[\"parenval\"] = ['\"(\" ws expr ws \")\"']\nGRAMMAR_DICTIONARY[\"function\"] = ['(fname ws \"(\" ws \"distinct\" ws arg_list_or_star ws \")\")',\n                                  '(fname ws \"(\" ws arg_list_or_star ws \")\")']\n\nGRAMMAR_DICTIONARY[\"arg_list_or_star\"] = ['arg_list', '\"*\"']\nGRAMMAR_DICTIONARY[\"arg_list\"] = ['(expr ws \",\" ws arg_list)', 'expr']\n # TODO(MARK): Massive hack, remove and modify the grammar accordingly\n# GRAMMAR_DICTIONARY[\"number\"] = ['~\"\\d*\\.?\\d+\"i', \"'3'\", \"'4'\"]\nGRAMMAR_DICTIONARY[\"non_literal_number\"] = ['\"1\"', '\"2\"', '\"3\"', '\"4\"']\nGRAMMAR_DICTIONARY[\"number\"] = ['ws \"value\" ws']\nGRAMMAR_DICTIONARY[\"string_set\"] = ['ws \"(\" ws string_set_vals ws 
\")\"']\nGRAMMAR_DICTIONARY[\"string_set_vals\"] = ['(string ws \",\" ws string_set_vals)', 'string']\n# GRAMMAR_DICTIONARY[\"string\"] = ['~\"\\'.*?\\'\"i']\nGRAMMAR_DICTIONARY[\"string\"] = ['\"\\'\" ws \"value\" ws \"\\'\"']\nGRAMMAR_DICTIONARY[\"fname\"] = ['\"count\"', '\"sum\"', '\"max\"', '\"min\"', '\"avg\"', '\"all\"']\nGRAMMAR_DICTIONARY[\"boolean\"] = ['\"true\"', '\"false\"']\n\n# TODO(MARK): This is not tight enough. AND/OR are strictly boolean value operators.\nGRAMMAR_DICTIONARY[\"binaryop\"] = ['\"+\"', '\"-\"', '\"*\"', '\"/\"', '\"=\"', '\"!=\"', '\"<>\"',\n                                  '\">=\"', '\"<=\"', '\">\"', '\"<\"', '\"and\"', '\"or\"', '\"like\"']\nGRAMMAR_DICTIONARY[\"unaryop\"] = ['\"+\"', '\"-\"', '\"not\"', '\"not\"']\n\n\ndef update_grammar_with_tables(grammar_dictionary: Dict[str, List[str]],\n                               schema: Dict[str, Table]) -> None:\n    table_names = sorted([f'\"{table.lower()}\"' for table in\n                          list(schema.keys())], reverse=True)\n    grammar_dictionary['table_name'] += table_names\n\n    all_columns = set()\n    for table in schema.values():\n        all_columns.update([f'\"{table.name.lower()}@{column.name.lower()}\"' for column in table.columns if column.name != '*'])\n    sorted_columns = sorted([column for column in all_columns], reverse=True)\n    grammar_dictionary['column_name'] += sorted_columns\n\n\ndef update_grammar_to_be_table_names_free(grammar_dictionary: Dict[str, List[str]]):\n    \"\"\"\n    Remove table names from column names, remove aliases\n    \"\"\"\n\n    grammar_dictionary[\"column_name\"] = []\n    grammar_dictionary[\"table_name\"] = []\n    grammar_dictionary[\"col_ref\"] = ['column_name']\n    grammar_dictionary[\"table_source\"] = ['table_name']\n\n    del grammar_dictionary[\"table_alias\"]\n\n\ndef update_grammar_flip_joins(grammar_dictionary: Dict[str, List[str]]):\n    \"\"\"\n    Flip the join_clauses rule to be left-branching, limited to at most four join clauses\n    
\"\"\"\n\n    # using a simple rule such as join_clauses -> [(join_clauses ws join_clause), join_clause]\n    # resulted in a max recursion error, so for now just using a predefined max\n    # number of joins\n    grammar_dictionary[\"join_clauses\"] = ['(join_clauses_1 ws join_clause)', 'join_clause']\n    grammar_dictionary[\"join_clauses_1\"] = ['(join_clauses_2 ws join_clause)', 'join_clause']\n    grammar_dictionary[\"join_clauses_2\"] = ['(join_clause ws join_clause)', 'join_clause']"
  },
  {
    "path": "semparse/worlds/evaluate.py",
    "content": "################################\n# val: number(float)/string(str)/sql(dict)\n# col_unit: (agg_id, col_id, isDistinct(bool))\n# val_unit: (unit_op, col_unit1, col_unit2)\n# table_unit: (table_type, col_unit/sql)\n# cond_unit: (not_op, op_id, val_unit, val1, val2)\n# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]\n# sql {\n#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])\n#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}\n#   'where': condition\n#   'groupBy': [col_unit1, col_unit2, ...]\n#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])\n#   'having': condition\n#   'limit': None/limit value\n#   'intersect': None/sql\n#   'except': None/sql\n#   'union': None/sql\n# }\n################################\nimport copy\nimport os\nimport json\nimport sqlite3\nimport argparse\n\nfrom spider_evaluation.process_sql import get_schema, Schema, get_sql\n\n# Flag to disable value evaluation\nDISABLE_VALUE = True\n# Flag to disable distinct in select evaluation\nDISABLE_DISTINCT = True\n\n\nCLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')\nJOIN_KEYWORDS = ('join', 'on', 'as')\n\nWHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')\nUNIT_OPS = ('none', '-', '+', \"*\", '/')\nAGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')\nTABLE_TYPE = {\n    'sql': \"sql\",\n    'table_unit': \"table_unit\",\n}\n\nCOND_OPS = ('and', 'or')\nSQL_OPS = ('intersect', 'union', 'except')\nORDER_OPS = ('desc', 'asc')\n\n\nHARDNESS = {\n    \"component1\": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),\n    \"component2\": ('except', 'union', 'intersect')\n}\n\n\ndef condition_has_or(conds):\n    return 'or' in conds[1::2]\n\n\ndef condition_has_like(conds):\n    return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]\n\n\ndef condition_has_sql(conds):\n    for cond_unit in 
conds[::2]:\n        val1, val2 = cond_unit[3], cond_unit[4]\n        if val1 is not None and type(val1) is dict:\n            return True\n        if val2 is not None and type(val2) is dict:\n            return True\n    return False\n\n\ndef val_has_op(val_unit):\n    return val_unit[0] != UNIT_OPS.index('none')\n\n\ndef has_agg(unit):\n    return unit[0] != AGG_OPS.index('none')\n\n\ndef accuracy(count, total):\n    if count == total:\n        return 1\n    return 0\n\n\ndef recall(count, total):\n    if count == total:\n        return 1\n    return 0\n\n\ndef F1(acc, rec):\n    if (acc + rec) == 0:\n        return 0\n    return (2. * acc * rec) / (acc + rec)\n\n\ndef get_scores(count, pred_total, label_total):\n    if pred_total != label_total:\n        return 0,0,0\n    elif count == pred_total:\n        return 1,1,1\n    return 0,0,0\n\n\ndef eval_sel(pred, label):\n    pred_sel = pred['select'][1]\n    label_sel = label['select'][1]\n    label_wo_agg = [unit[1] for unit in label_sel]\n    pred_total = len(pred_sel)\n    label_total = len(label_sel)\n    cnt = 0\n    cnt_wo_agg = 0\n\n    for unit in pred_sel:\n        if unit in label_sel:\n            cnt += 1\n            label_sel.remove(unit)\n        if unit[1] in label_wo_agg:\n            cnt_wo_agg += 1\n            label_wo_agg.remove(unit[1])\n\n    return label_total, pred_total, cnt, cnt_wo_agg\n\n\ndef eval_where(pred, label):\n    pred_conds = [unit for unit in pred['where'][::2]]\n    label_conds = [unit for unit in label['where'][::2]]\n    label_wo_agg = [unit[2] for unit in label_conds]\n    pred_total = len(pred_conds)\n    label_total = len(label_conds)\n    cnt = 0\n    cnt_wo_agg = 0\n\n    for unit in pred_conds:\n        if unit in label_conds:\n            cnt += 1\n            label_conds.remove(unit)\n        if unit[2] in label_wo_agg:\n            cnt_wo_agg += 1\n            label_wo_agg.remove(unit[2])\n\n    return label_total, pred_total, cnt, cnt_wo_agg\n\n\ndef 
eval_group(pred, label):\n    pred_cols = [unit[1] for unit in pred['groupBy']]\n    label_cols = [unit[1] for unit in label['groupBy']]\n    pred_total = len(pred_cols)\n    label_total = len(label_cols)\n    cnt = 0\n    pred_cols = [pred.split(\".\")[1] if \".\" in pred else pred for pred in pred_cols]\n    label_cols = [label.split(\".\")[1] if \".\" in label else label for label in label_cols]\n    for col in pred_cols:\n        if col in label_cols:\n            cnt += 1\n            label_cols.remove(col)\n    return label_total, pred_total, cnt\n\n\ndef eval_having(pred, label):\n    pred_total = label_total = cnt = 0\n    if len(pred['groupBy']) > 0:\n        pred_total = 1\n    if len(label['groupBy']) > 0:\n        label_total = 1\n\n    pred_cols = [unit[1] for unit in pred['groupBy']]\n    label_cols = [unit[1] for unit in label['groupBy']]\n    if pred_total == label_total == 1 \\\n            and pred_cols == label_cols \\\n            and pred['having'] == label['having']:\n        cnt = 1\n\n    return label_total, pred_total, cnt\n\n\ndef eval_order(pred, label):\n    pred_total = label_total = cnt = 0\n    if len(pred['orderBy']) > 0:\n        pred_total = 1\n    if len(label['orderBy']) > 0:\n        label_total = 1\n    if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \\\n            ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):\n        cnt = 1\n    return label_total, pred_total, cnt\n\n\ndef eval_and_or(pred, label):\n    pred_ao = pred['where'][1::2]\n    label_ao = label['where'][1::2]\n    pred_ao = set(pred_ao)\n    label_ao = set(label_ao)\n\n    if pred_ao == label_ao:\n        return 1,1,1\n    return len(pred_ao),len(label_ao),0\n\n\ndef get_nestedSQL(sql):\n    nested = []\n    for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:\n        if type(cond_unit[3]) is dict:\n            
nested.append(cond_unit[3])\n        if type(cond_unit[4]) is dict:\n            nested.append(cond_unit[4])\n    if sql['intersect'] is not None:\n        nested.append(sql['intersect'])\n    if sql['except'] is not None:\n        nested.append(sql['except'])\n    if sql['union'] is not None:\n        nested.append(sql['union'])\n    return nested\n\n\ndef eval_nested(pred, label):\n    label_total = 0\n    pred_total = 0\n    cnt = 0\n    if pred is not None:\n        pred_total += 1\n    if label is not None:\n        label_total += 1\n    if pred is not None and label is not None:\n        cnt += Evaluator().eval_exact_match(pred, label)\n    return label_total, pred_total, cnt\n\n\ndef eval_IUEN(pred, label):\n    lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])\n    lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])\n    lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])\n    label_total = lt1 + lt2 + lt3\n    pred_total = pt1 + pt2 + pt3\n    cnt = cnt1 + cnt2 + cnt3\n    return label_total, pred_total, cnt\n\n\ndef get_keywords(sql):\n    res = set()\n    if len(sql['where']) > 0:\n        res.add('where')\n    if len(sql['groupBy']) > 0:\n        res.add('group')\n    if len(sql['having']) > 0:\n        res.add('having')\n    if len(sql['orderBy']) > 0:\n        res.add(sql['orderBy'][0])\n        res.add('order')\n    if sql['limit'] is not None:\n        res.add('limit')\n    if sql['except'] is not None:\n        res.add('except')\n    if sql['union'] is not None:\n        res.add('union')\n    if sql['intersect'] is not None:\n        res.add('intersect')\n\n    # or keyword\n    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]\n    if len([token for token in ao if token == 'or']) > 0:\n        res.add('or')\n\n    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]\n    # not keyword\n    if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:\n   
     res.add('not')\n\n    # in keyword\n    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:\n        res.add('in')\n\n    # like keyword\n    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:\n        res.add('like')\n\n    return res\n\n\ndef eval_keywords(pred, label):\n    pred_keywords = get_keywords(pred)\n    label_keywords = get_keywords(label)\n    pred_total = len(pred_keywords)\n    label_total = len(label_keywords)\n    cnt = 0\n\n    for k in pred_keywords:\n        if k in label_keywords:\n            cnt += 1\n    return label_total, pred_total, cnt\n\n\ndef count_agg(units):\n    return len([unit for unit in units if has_agg(unit)])\n\n\ndef count_component1(sql):\n    count = 0\n    if len(sql['where']) > 0:\n        count += 1\n    if len(sql['groupBy']) > 0:\n        count += 1\n    if len(sql['orderBy']) > 0:\n        count += 1\n    if sql['limit'] is not None:\n        count += 1\n    if len(sql['from']['table_units']) > 0:  # JOIN\n        count += len(sql['from']['table_units']) - 1\n\n    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]\n    count += len([token for token in ao if token == 'or'])\n    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]\n    count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])\n\n    return count\n\n\ndef count_component2(sql):\n    nested = get_nestedSQL(sql)\n    return len(nested)\n\n\ndef count_others(sql):\n    count = 0\n    # number of aggregation\n    agg_count = count_agg(sql['select'][1])\n    agg_count += count_agg(sql['where'][::2])\n    agg_count += count_agg(sql['groupBy'])\n    if len(sql['orderBy']) > 0:\n        agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +\n                            [unit[2] for unit in sql['orderBy'][1] if unit[2]])\n    agg_count += 
count_agg(sql['having'])\n    if agg_count > 1:\n        count += 1\n\n    # number of select columns\n    if len(sql['select'][1]) > 1:\n        count += 1\n\n    # number of where conditions\n    if len(sql['where']) > 1:\n        count += 1\n\n    # number of group by clauses\n    if len(sql['groupBy']) > 1:\n        count += 1\n\n    return count\n\n\nclass Evaluator:\n    \"\"\"A simple evaluator\"\"\"\n    def __init__(self):\n        self.partial_scores = None\n\n    def eval_hardness(self, sql):\n        count_comp1_ = count_component1(sql)\n        count_comp2_ = count_component2(sql)\n        count_others_ = count_others(sql)\n\n        if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:\n            return \"easy\"\n        elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \\\n                (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):\n            return \"medium\"\n        elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \\\n                (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \\\n                (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):\n            return \"hard\"\n        else:\n            return \"extra\"\n\n    def eval_exact_match(self, pred, label):\n        partial_scores = self.eval_partial_match(pred, label)\n        self.partial_scores = partial_scores\n\n        for _, score in partial_scores.items():\n            if score['f1'] != 1:\n                return 0\n        if len(label['from']['table_units']) > 0:\n            label_tables = sorted(label['from']['table_units'])\n            pred_tables = sorted(pred['from']['table_units'])\n            if label_tables != pred_tables:\n                return False\n        if len(label['from']['conds']) > 0:\n            label_joins = sorted(label['from']['conds'], key=lambda x: str(x))\n            pred_joins = sorted(pred['from']['conds'], key=lambda x: 
str(x))\n            if label_joins != pred_joins:\n                return False\n        return 1\n\n    def eval_partial_match(self, pred, label):\n        res = {}\n\n        label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)\n        res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n\n        label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)\n        res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n\n        label_total, pred_total, cnt = eval_group(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n\n        label_total, pred_total, cnt = eval_having(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n\n        label_total, pred_total, cnt = eval_order(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n\n        label_total, pred_total, cnt = eval_and_or(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['and/or'] = {'acc': acc, 'rec': rec, 'f1': 
f1,'label_total':label_total,'pred_total':pred_total}\n\n        label_total, pred_total, cnt = eval_IUEN(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n\n        label_total, pred_total, cnt = eval_keywords(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n\n        return res\n\n\ndef isValidSQL(sql, db):\n    conn = sqlite3.connect(db)\n    cursor = conn.cursor()\n    try:\n        cursor.execute(sql, [])\n    except Exception as e:\n        return False\n    return True\n\n\ndef print_scores(scores, etype):\n    levels = ['easy', 'medium', 'hard', 'extra', 'all']\n    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',\n                     'group', 'order', 'and/or', 'IUEN', 'keywords']\n\n    print(\"{:20} {:20} {:20} {:20} {:20} {:20}\".format(\"\", *levels))\n    counts = [scores[level]['count'] for level in levels]\n    print(\"{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}\".format(\"count\", *counts))\n\n    if etype in [\"all\", \"exec\"]:\n        print('=====================   EXECUTION ACCURACY     =====================')\n        this_scores = [scores[level]['exec'] for level in levels]\n        print(\"{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}\".format(\"execution\", *this_scores))\n\n    if etype in [\"all\", \"match\"]:\n        print('\\n====================== EXACT MATCHING ACCURACY =====================')\n        exact_scores = [scores[level]['exact'] for level in levels]\n        print(\"{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}\".format(\"exact match\", *exact_scores))\n        print('\\n---------------------PARTIAL MATCHING ACCURACY----------------------')\n        for type_ in partial_types:\n       
     this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]\n            print(\"{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}\".format(type_, *this_scores))\n\n        print('---------------------- PARTIAL MATCHING RECALL ----------------------')\n        for type_ in partial_types:\n            this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]\n            print(\"{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}\".format(type_, *this_scores))\n\n        print('---------------------- PARTIAL MATCHING F1 --------------------------')\n        for type_ in partial_types:\n            this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]\n            print(\"{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}\".format(type_, *this_scores))\n\n\ndef evaluate(gold, predict, db_dir, etype, kmaps):\n    with open(gold) as f:\n        glist = [l.strip().split('\\t') for l in f.readlines() if len(l.strip()) > 0]\n\n    with open(predict) as f:\n        plist = [l.strip().split('\\t') for l in f.readlines() if len(l.strip()) > 0]\n    # plist = [(\"select max(Share),min(Share) from performance where Type != 'terminal'\", \"orchestra\")]\n    # glist = [(\"SELECT max(SHARE) ,  min(SHARE) FROM performance WHERE TYPE != 'Live final'\", \"orchestra\")]\n    evaluator = Evaluator()\n\n    levels = ['easy', 'medium', 'hard', 'extra', 'all']\n    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',\n                     'group', 'order', 'and/or', 'IUEN', 'keywords']\n    entries = []\n    scores = {}\n\n    for level in levels:\n        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}\n        scores[level]['exec'] = 0\n        for type_ in partial_types:\n            scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}\n\n    eval_err_num = 0\n    for p, g in zip(plist, glist):\n        p_str = p[0]\n        
g_str, db = g\n        db_name = db\n        db = os.path.join(db_dir, db, db + \".sqlite\")\n        schema = Schema(get_schema(db))\n        g_sql = get_sql(schema, g_str)\n        hardness = evaluator.eval_hardness(g_sql)\n        scores[hardness]['count'] += 1\n        scores['all']['count'] += 1\n\n        try:\n            p_sql = get_sql(schema, p_str)\n        except:\n            # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql\n            p_sql = {\n            \"except\": None,\n            \"from\": {\n                \"conds\": [],\n                \"table_units\": []\n            },\n            \"groupBy\": [],\n            \"having\": [],\n            \"intersect\": None,\n            \"limit\": None,\n            \"orderBy\": [],\n            \"select\": [\n                False,\n                []\n            ],\n            \"union\": None,\n            \"where\": []\n            }\n            eval_err_num += 1\n            print(\"eval_err_num:{}\".format(eval_err_num))\n            print(p_str)\n            print()\n\n        # rebuild sql for value evaluation\n        kmap = kmaps[db_name]\n        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)\n        g_sql = rebuild_sql_val(g_sql)\n        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)\n        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)\n        p_sql = rebuild_sql_val(p_sql)\n        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)\n\n        # p_sql_copy = copy.deepcopy(p_sql)\n        # g_sql_copy = copy.deepcopy(g_sql)\n        # if not eval_exec_match(db, p_str, g_str, p_sql_copy, g_sql_copy) and evaluator.eval_exact_match(p_sql_copy, g_sql_copy):\n        #     a = 1\n\n        if etype in [\"all\", \"exec\"]:\n            exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)\n            if exec_score:\n                scores[hardness]['exec'] += 1\n 
           scores['all']['exec'] += exec_score\n\n        if etype in [\"all\", \"match\"]:\n            exact_score = evaluator.eval_exact_match(p_sql, g_sql)\n            partial_scores = evaluator.partial_scores\n            if exact_score == 0:\n                print(\"{} pred: {}\".format(hardness,p_str))\n                print(\"{} gold: {}\".format(hardness,g_str))\n                print(\"\")\n            scores[hardness]['exact'] += exact_score\n            scores['all']['exact'] += exact_score\n            for type_ in partial_types:\n                if partial_scores[type_]['pred_total'] > 0:\n                    scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']\n                    scores[hardness]['partial'][type_]['acc_count'] += 1\n                if partial_scores[type_]['label_total'] > 0:\n                    scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']\n                    scores[hardness]['partial'][type_]['rec_count'] += 1\n                scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']\n                if partial_scores[type_]['pred_total'] > 0:\n                    scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']\n                    scores['all']['partial'][type_]['acc_count'] += 1\n                if partial_scores[type_]['label_total'] > 0:\n                    scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']\n                    scores['all']['partial'][type_]['rec_count'] += 1\n                scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']\n\n            entries.append({\n                'predictSQL': p_str,\n                'goldSQL': g_str,\n                'hardness': hardness,\n                'exact': exact_score,\n                'partial': partial_scores\n            })\n\n    for level in levels:\n        if scores[level]['count'] == 0:\n            continue\n        if etype in [\"all\", 
\"exec\"]:\n            scores[level]['exec'] /= scores[level]['count']\n\n        if etype in [\"all\", \"match\"]:\n            scores[level]['exact'] /= scores[level]['count']\n            for type_ in partial_types:\n                if scores[level]['partial'][type_]['acc_count'] == 0:\n                    scores[level]['partial'][type_]['acc'] = 0\n                else:\n                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \\\n                                                             scores[level]['partial'][type_]['acc_count'] * 1.0\n                if scores[level]['partial'][type_]['rec_count'] == 0:\n                    scores[level]['partial'][type_]['rec'] = 0\n                else:\n                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \\\n                                                             scores[level]['partial'][type_]['rec_count'] * 1.0\n                if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:\n                    scores[level]['partial'][type_]['f1'] = 1\n                else:\n                    scores[level]['partial'][type_]['f1'] = \\\n                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (\n                        scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])\n\n    print_scores(scores, etype)\n\n\ndef eval_exec_match(db, p_str, g_str, pred, gold):\n    \"\"\"\n    return 1 if the values between prediction and gold are matching\n    in the corresponding index. 
Multiple col_unit pairs are currently not supported.\n    \"\"\"\n    conn = sqlite3.connect(db)\n    cursor = conn.cursor()\n    conn.text_factory = bytes\n    try:\n        cursor.execute(p_str)\n        p_res = cursor.fetchall()\n    except:\n        return False\n\n    cursor.execute(g_str)\n    q_res = cursor.fetchall()\n\n    def res_map(res, val_units):\n        rmap = {}\n        for idx, val_unit in enumerate(val_units):\n            key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))\n            rmap[key] = [r[idx] for r in res]\n        return rmap\n\n    p_val_units = [unit[1] for unit in pred['select'][1]]\n    q_val_units = [unit[1] for unit in gold['select'][1]]\n    return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)\n\n\n# Rebuild SQL functions for value evaluation\ndef rebuild_cond_unit_val(cond_unit):\n    if cond_unit is None or not DISABLE_VALUE:\n        return cond_unit\n\n    not_op, op_id, val_unit, val1, val2 = cond_unit\n    if type(val1) is not dict:\n        val1 = None\n    else:\n        val1 = rebuild_sql_val(val1)\n    if type(val2) is not dict:\n        val2 = None\n    else:\n        val2 = rebuild_sql_val(val2)\n    return not_op, op_id, val_unit, val1, val2\n\n\ndef rebuild_condition_val(condition):\n    if condition is None or not DISABLE_VALUE:\n        return condition\n\n    res = []\n    for idx, it in enumerate(condition):\n        if idx % 2 == 0:\n            res.append(rebuild_cond_unit_val(it))\n        else:\n            res.append(it)\n    return res\n\n\ndef rebuild_sql_val(sql):\n    if sql is None or not DISABLE_VALUE:\n        return sql\n\n    sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])\n    sql['having'] = rebuild_condition_val(sql['having'])\n    sql['where'] = rebuild_condition_val(sql['where'])\n    sql['intersect'] = rebuild_sql_val(sql['intersect'])\n    sql['except'] = rebuild_sql_val(sql['except'])\n    sql['union'] = 
rebuild_sql_val(sql['union'])\n\n    return sql\n\n\n# Rebuild SQL functions for foreign key evaluation\ndef build_valid_col_units(table_units, schema):\n    col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]\n    prefixs = [col_id[:-2] for col_id in col_ids]\n    valid_col_units= []\n    for value in schema.idMap.values():\n        if '.' in value and value[:value.index('.')] in prefixs:\n            valid_col_units.append(value)\n    return valid_col_units\n\n\ndef rebuild_col_unit_col(valid_col_units, col_unit, kmap):\n    if col_unit is None:\n        return col_unit\n\n    agg_id, col_id, distinct = col_unit\n    if col_id in kmap and col_id in valid_col_units:\n        col_id = kmap[col_id]\n    if DISABLE_DISTINCT:\n        distinct = None\n    return agg_id, col_id, distinct\n\n\ndef rebuild_val_unit_col(valid_col_units, val_unit, kmap):\n    if val_unit is None:\n        return val_unit\n\n    unit_op, col_unit1, col_unit2 = val_unit\n    col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)\n    col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)\n    return unit_op, col_unit1, col_unit2\n\n\ndef rebuild_table_unit_col(valid_col_units, table_unit, kmap):\n    if table_unit is None:\n        return table_unit\n\n    table_type, col_unit_or_sql = table_unit\n    if isinstance(col_unit_or_sql, tuple):\n        col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)\n    return table_type, col_unit_or_sql\n\n\ndef rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):\n    if cond_unit is None:\n        return cond_unit\n\n    not_op, op_id, val_unit, val1, val2 = cond_unit\n    val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)\n    return not_op, op_id, val_unit, val1, val2\n\n\ndef rebuild_condition_col(valid_col_units, condition, kmap):\n    for idx in range(len(condition)):\n        if idx % 2 == 0:\n            condition[idx] = 
rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)\n    return condition\n\n\ndef rebuild_select_col(valid_col_units, sel, kmap):\n    if sel is None:\n        return sel\n    distinct, _list = sel\n    new_list = []\n    for it in _list:\n        agg_id, val_unit = it\n        new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))\n    if DISABLE_DISTINCT:\n        distinct = None\n    return distinct, new_list\n\n\ndef rebuild_from_col(valid_col_units, from_, kmap):\n    if from_ is None:\n        return from_\n\n    from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]\n    from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)\n    return from_\n\n\ndef rebuild_group_by_col(valid_col_units, group_by, kmap):\n    if group_by is None:\n        return group_by\n\n    return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]\n\n\ndef rebuild_order_by_col(valid_col_units, order_by, kmap):\n    if order_by is None or len(order_by) == 0:\n        return order_by\n\n    direction, val_units = order_by\n    new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]\n    return direction, new_val_units\n\n\ndef rebuild_sql_col(valid_col_units, sql, kmap):\n    if sql is None:\n        return sql\n\n    sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)\n    sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)\n    sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)\n    sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)\n    sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)\n    sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)\n    sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)\n    sql['except'] = 
rebuild_sql_col(valid_col_units, sql['except'], kmap)\n    sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)\n\n    return sql\n\n\ndef build_foreign_key_map(entry):\n    cols_orig = entry[\"column_names_original\"]\n    tables_orig = entry[\"table_names_original\"]\n\n    # rebuild cols corresponding to idmap in Schema\n    cols = []\n    for col_orig in cols_orig:\n        if col_orig[0] >= 0:\n            t = tables_orig[col_orig[0]]\n            c = col_orig[1]\n            cols.append(\"__\" + t.lower() + \".\" + c.lower() + \"__\")\n        else:\n            cols.append(\"__all__\")\n\n    def keyset_in_list(k1, k2, k_list):\n        for k_set in k_list:\n            if k1 in k_set or k2 in k_set:\n                return k_set\n        new_k_set = set()\n        k_list.append(new_k_set)\n        return new_k_set\n\n    foreign_key_list = []\n    foreign_keys = entry[\"foreign_keys\"]\n    for fkey in foreign_keys:\n        key1, key2 = fkey\n        key_set = keyset_in_list(key1, key2, foreign_key_list)\n        key_set.add(key1)\n        key_set.add(key2)\n\n    foreign_key_map = {}\n    for key_set in foreign_key_list:\n        sorted_list = sorted(list(key_set))\n        midx = sorted_list[0]\n        for idx in sorted_list:\n            foreign_key_map[cols[idx]] = cols[midx]\n\n    return foreign_key_map\n\n\ndef build_foreign_key_map_from_json(table):\n    with open(table) as f:\n        data = json.load(f)\n    tables = {}\n    for entry in data:\n        tables[entry['db_id']] = build_foreign_key_map(entry)\n    return tables\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--gold', dest='gold', type=str)\n    parser.add_argument('--pred', dest='pred', type=str)\n    parser.add_argument('--db', dest='db', type=str)\n    parser.add_argument('--table', dest='table', type=str)\n    parser.add_argument('--etype', dest='etype', type=str)\n    args = parser.parse_args()\n\n    gold = 
args.gold\n    pred = args.pred\n    db_dir = args.db\n    table = args.table\n    etype = args.etype\n\n    assert etype in [\"all\", \"exec\", \"match\"], \"Unknown evaluation method\"\n\n    kmaps = build_foreign_key_map_from_json(table)\n\n    evaluate(gold, pred, db_dir, etype, kmaps)"
  },
  {
    "path": "semparse/worlds/evaluate_spider.py",
    "content": "import os\nimport sqlite3\n\nfrom semparse.worlds.evaluate import Evaluator, build_valid_col_units, rebuild_sql_val, rebuild_sql_col, \\\n    build_foreign_key_map_from_json\nfrom spider_evaluation.process_sql import Schema, get_schema, get_sql\n\n_schemas = {}\nkmaps = None\n\n\ndef evaluate(gold, predict, db_name, db_dir, table, check_valid: bool=True) -> bool:\n    global kmaps\n\n    # try:\n    evaluator = Evaluator()\n\n    if kmaps is None:\n        kmaps = build_foreign_key_map_from_json(table)\n\n    if db_name in _schemas:\n        schema = _schemas[db_name]\n    else:\n        db = os.path.join(db_dir, db_name, db_name + \".sqlite\")\n        schema = _schemas[db_name] = Schema(get_schema(db))\n\n    g_sql = get_sql(schema, gold)\n\n    try:\n        p_sql = get_sql(schema, predict)\n    except Exception as e:\n        return False\n\n    # rebuild sql for value evaluation\n    kmap = kmaps[db_name]\n    g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)\n    g_sql = rebuild_sql_val(g_sql)\n    g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)\n    p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)\n    p_sql = rebuild_sql_val(p_sql)\n    p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)\n\n    exact_score = evaluator.eval_exact_match(p_sql, g_sql)\n\n    if not check_valid:\n        return exact_score\n    else:\n        return exact_score and check_valid_sql(predict, db_name, db_dir)\n    # except Exception as e:\n    #     return 0\n\n\n_conns = {}\n\n\ndef check_valid_sql(sql, db_name, db_dir, return_error=False):\n    db = os.path.join(db_dir, db_name, db_name + \".sqlite\")\n\n    if db_name == 'wta_1':\n        # TODO: seems like there is a problem with this dataset - slow response - add limit 1\n        return True if not return_error else (True, None)\n\n    if db_name not in _conns:\n        _conns[db_name] = sqlite3.connect(db)\n\n        # fixes an encoding 
bug\n        _conns[db_name].text_factory = bytes\n\n    conn = _conns[db_name]\n    cursor = conn.cursor()\n    try:\n        cursor.execute(sql)\n        cursor.fetchall()\n        return True if not return_error else (True, None)\n    except Exception as e:\n        return False if not return_error else (False, e.args[0])\n"
  },
  {
    "path": "semparse/worlds/spider_world.py",
    "content": "from typing import List, Tuple, Dict, Set, Optional\nfrom copy import deepcopy\n\nfrom parsimonious import Grammar\nfrom parsimonious.exceptions import ParseError\n\nfrom semparse.contexts.spider_context_utils import format_grammar_string, initialize_valid_actions, SqlVisitor\nfrom semparse.contexts.spider_db_context import SpiderDBContext\nfrom semparse.contexts.spider_db_grammar import GRAMMAR_DICTIONARY, update_grammar_with_tables, \\\n    update_grammar_to_be_table_names_free, update_grammar_flip_joins\n\n\nclass SpiderWorld:\n    \"\"\"\n    World representation for spider dataset.\n    \"\"\"\n\n    def __init__(self, db_context: SpiderDBContext, query: Optional[List[str]], allow_alias: bool = False) -> None:\n        self.db_id = db_context.db_id\n        self.allow_alias = allow_alias\n\n        # NOTE: This base dictionary should not be modified.\n        self.base_grammar_dictionary = deepcopy(GRAMMAR_DICTIONARY)\n        self.query = query\n        self.db_context = db_context\n\n        # keep a list of entities names as they are given in sql queries\n        self.entities_names = {}\n        for i, entity in enumerate(self.db_context.knowledge_graph.entities):\n            parts = entity.split(':')\n            if parts[0] in ['table', 'string']:\n                self.entities_names[parts[1]] = i\n            else:\n                _, _, table_name, column_name = parts\n                self.entities_names[f'{table_name}@{column_name}'] = i\n        self.valid_actions = []\n        self.valid_actions_flat = []\n\n    def get_action_sequence_and_all_actions(self,\n                                            allow_aliases: bool = False) -> Tuple[List[str], List[str]]:\n        grammar_with_context = deepcopy(self.base_grammar_dictionary)\n        if not allow_aliases:\n            update_grammar_to_be_table_names_free(grammar_with_context)\n\n        schema = self.db_context.schema\n\n        
update_grammar_with_tables(grammar_with_context, schema)\n        grammar = Grammar(format_grammar_string(grammar_with_context))\n\n        valid_actions = initialize_valid_actions(grammar)\n        all_actions = set()\n        for action_list in valid_actions.values():\n            all_actions.update(action_list)\n        sorted_actions = sorted(all_actions)\n        self.valid_actions = valid_actions\n        self.valid_actions_flat = sorted_actions\n\n        action_sequence = None\n        if self.query is not None:\n            sql_visitor = SqlVisitor(grammar)\n            query = \" \".join(self.query).lower().replace(\"``\", \"'\").replace(\"''\", \"'\")\n            try:\n                action_sequence = sql_visitor.parse(query) if query else []\n            except ParseError as e:\n                pass\n\n        return action_sequence, sorted_actions\n\n    def get_all_actions(self, schema,\n                        flip_joins: bool,\n                        allow_aliases: bool) -> Tuple[List[str], List[str]]:\n        grammar_with_context = deepcopy(self.base_grammar_dictionary)\n        if not allow_aliases:\n            update_grammar_to_be_table_names_free(grammar_with_context)\n\n        if flip_joins:\n            update_grammar_flip_joins(grammar_with_context)\n\n        update_grammar_with_tables(grammar_with_context, schema, self.db_id)\n        grammar = Grammar(format_grammar_string(grammar_with_context))\n\n        valid_actions = initialize_valid_actions(grammar)\n        all_actions = set()\n        for action_list in valid_actions.values():\n            all_actions.update(action_list)\n        sorted_actions = sorted(all_actions)\n        self.valid_actions = valid_actions\n        self.valid_actions_flat = sorted_actions\n\n        return sorted_actions\n\n    def is_global_rule(self, rhs: str) -> bool:\n        rhs = rhs.strip('[] ')\n        if rhs[0] != '\"':\n            return True\n        return rhs.strip('\"') not in 
self.entities_names\n\n    def get_oracle_relevance_score(self, oracle_entities: set):\n        \"\"\"\n        return 0/1 for each schema item if it should be in the graph,\n        given the used entities in the gold answer\n        \"\"\"\n        scores = [0 for _ in range(len(self.db_context.knowledge_graph.entities))]\n\n        for i, entity in enumerate(self.db_context.knowledge_graph.entities):\n            parts = entity.split(':')\n            if parts[0] == 'column':\n                name = parts[2] + '@' + parts[3]\n            else:\n                name = parts[-1]\n            if name in oracle_entities:\n                scores[i] = 1\n\n        return scores\n\n    def get_action_entity_mapping(self) -> Dict[int, int]:\n        mapping = {}\n\n        for action_index, action in enumerate(self.valid_actions_flat):\n            # default is padding\n            mapping[action_index] = -1\n\n            action = action.split(\" -> \")[1].strip('[]')\n            action_stripped = action.strip('\\\"')\n            if action[0] != '\"' or action_stripped not in self.entities_names:\n                continue\n\n            mapping[action_index] = self.entities_names[action_stripped]\n\n        return mapping\n\n    def get_query_without_table_hints(self):\n        if not self.query:\n            return ''\n        toks = []\n        for tok in self.query:\n            if '@' in tok:\n                parts = tok.split('@')\n                if '.' in parts[0]:\n                    toks.append(parts[0].split('.')[0] + '.' 
+ parts[1])\n                else:\n                    toks.append(parts[1])\n            else:\n                toks.append(tok)\n        return toks\n\n    # def is_ambiguous_column(self, action_rhs: str, actions_sequence: List[str], action_index: int):\n    #     \"\"\"\n    #     a column would be ambiguous if another table is used in the query and it has a column with the same name\n    #     currently, return true only for join clauses\n    #     \"\"\"\n    #     if actions_sequence[action_index-1].startswith('join_condition -> ') or \\\n    #             actions_sequence[action_index-2].startswith('join_condition -> '):\n    #         return True\n    #\n    #     return False\n    #     # column_table, column_name = action_rhs.strip('\"').split('.')\n    #     # tables_used = [a.split(' -> ')[1].strip('[]\\\"') for a in actions_sequence if a.startswith('table_name -> ')]\n    #     # columns_with_same_name = set([a.split(' -> ')[1].strip('[]\\\"') for a in actions_sequence\n    #     #                 if a.startswith('column_name -> ') and a.strip('[]\\\"').endswith(column_name)])\n    #     # table_of_columns_with_same_name = set([c.split('.')[0] for c in columns_with_same_name])\n    #     # if len(table_of_columns_with_same_name) > 1:\n    #     #     return True\n    #     # if column_table not in tables_used:\n    #     #     return False\n    #     # other_tables = [t for t in tables_used if t != column_table]\n    #     # other_tables_columns = [[c.split(':')[-1] for c in self.knowledge_graph.neighbors[f'table:{t}']] for t in other_tables]\n    #     # other_tables_columns_set = set([item for sublist in other_tables_columns for item in sublist])\n    #     # return column_name in other_tables_columns_set\n"
  },
  {
    "path": "spider_evaluation/evaluate.py",
    "content": "################################\n# val: number(float)/string(str)/sql(dict)\n# col_unit: (agg_id, col_id, isDistinct(bool))\n# val_unit: (unit_op, col_unit1, col_unit2)\n# table_unit: (table_type, col_unit/sql)\n# cond_unit: (not_op, op_id, val_unit, val1, val2)\n# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]\n# sql {\n#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])\n#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}\n#   'where': condition\n#   'groupBy': [col_unit1, col_unit2, ...]\n#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])\n#   'having': condition\n#   'limit': None/limit value\n#   'intersect': None/sql\n#   'except': None/sql\n#   'union': None/sql\n# }\n################################\nimport copy\nimport os\nimport json\nimport sqlite3\nimport argparse\n\nfrom spider_evaluation import get_schema, Schema, get_sql\n\n# Flag to disable value evaluation\nDISABLE_VALUE = True\n# Flag to disable distinct in select evaluation\nDISABLE_DISTINCT = True\n\n\nCLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')\nJOIN_KEYWORDS = ('join', 'on', 'as')\n\nWHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')\nUNIT_OPS = ('none', '-', '+', \"*\", '/')\nAGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')\nTABLE_TYPE = {\n    'sql': \"sql\",\n    'table_unit': \"table_unit\",\n}\n\nCOND_OPS = ('and', 'or')\nSQL_OPS = ('intersect', 'union', 'except')\nORDER_OPS = ('desc', 'asc')\n\n\nHARDNESS = {\n    \"component1\": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),\n    \"component2\": ('except', 'union', 'intersect')\n}\n\n\ndef condition_has_or(conds):\n    return 'or' in conds[1::2]\n\n\ndef condition_has_like(conds):\n    return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]\n\n\ndef condition_has_sql(conds):\n    for cond_unit in 
conds[::2]:\n        val1, val2 = cond_unit[3], cond_unit[4]\n        if val1 is not None and type(val1) is dict:\n            return True\n        if val2 is not None and type(val2) is dict:\n            return True\n    return False\n\n\ndef val_has_op(val_unit):\n    return val_unit[0] != UNIT_OPS.index('none')\n\n\ndef has_agg(unit):\n    return unit[0] != AGG_OPS.index('none')\n\n\ndef accuracy(count, total):\n    if count == total:\n        return 1\n    return 0\n\n\ndef recall(count, total):\n    if count == total:\n        return 1\n    return 0\n\n\ndef F1(acc, rec):\n    if (acc + rec) == 0:\n        return 0\n    return (2. * acc * rec) / (acc + rec)\n\n\ndef get_scores(count, pred_total, label_total):\n    if pred_total != label_total:\n        return 0,0,0\n    elif count == pred_total:\n        return 1,1,1\n    return 0,0,0\n\n\ndef eval_sel(pred, label):\n    pred_sel = pred['select'][1]\n    label_sel = label['select'][1]\n    label_wo_agg = [unit[1] for unit in label_sel]\n    pred_total = len(pred_sel)\n    label_total = len(label_sel)\n    cnt = 0\n    cnt_wo_agg = 0\n\n    for unit in pred_sel:\n        if unit in label_sel:\n            cnt += 1\n            label_sel.remove(unit)\n        if unit[1] in label_wo_agg:\n            cnt_wo_agg += 1\n            label_wo_agg.remove(unit[1])\n\n    return label_total, pred_total, cnt, cnt_wo_agg\n\n\ndef eval_where(pred, label):\n    pred_conds = [unit for unit in pred['where'][::2]]\n    label_conds = [unit for unit in label['where'][::2]]\n    label_wo_agg = [unit[2] for unit in label_conds]\n    pred_total = len(pred_conds)\n    label_total = len(label_conds)\n    cnt = 0\n    cnt_wo_agg = 0\n\n    for unit in pred_conds:\n        if unit in label_conds:\n            cnt += 1\n            label_conds.remove(unit)\n        if unit[2] in label_wo_agg:\n            cnt_wo_agg += 1\n            label_wo_agg.remove(unit[2])\n\n    return label_total, pred_total, cnt, cnt_wo_agg\n\n\ndef 
eval_group(pred, label):\n    pred_cols = [unit[1] for unit in pred['groupBy']]\n    label_cols = [unit[1] for unit in label['groupBy']]\n    pred_total = len(pred_cols)\n    label_total = len(label_cols)\n    cnt = 0\n    pred_cols = [pred.split(\".\")[1] if \".\" in pred else pred for pred in pred_cols]\n    label_cols = [label.split(\".\")[1] if \".\" in label else label for label in label_cols]\n    for col in pred_cols:\n        if col in label_cols:\n            cnt += 1\n            label_cols.remove(col)\n    return label_total, pred_total, cnt\n\n\ndef eval_having(pred, label):\n    pred_total = label_total = cnt = 0\n    if len(pred['groupBy']) > 0:\n        pred_total = 1\n    if len(label['groupBy']) > 0:\n        label_total = 1\n\n    pred_cols = [unit[1] for unit in pred['groupBy']]\n    label_cols = [unit[1] for unit in label['groupBy']]\n    if pred_total == label_total == 1 \\\n            and pred_cols == label_cols \\\n            and pred['having'] == label['having']:\n        cnt = 1\n\n    return label_total, pred_total, cnt\n\n\ndef eval_order(pred, label):\n    pred_total = label_total = cnt = 0\n    if len(pred['orderBy']) > 0:\n        pred_total = 1\n    if len(label['orderBy']) > 0:\n        label_total = 1\n    if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \\\n            ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):\n        cnt = 1\n    return label_total, pred_total, cnt\n\n\ndef eval_and_or(pred, label):\n    pred_ao = pred['where'][1::2]\n    label_ao = label['where'][1::2]\n    pred_ao = set(pred_ao)\n    label_ao = set(label_ao)\n\n    if pred_ao == label_ao:\n        return 1,1,1\n    return len(pred_ao),len(label_ao),0\n\n\ndef get_nestedSQL(sql):\n    nested = []\n    for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:\n        if type(cond_unit[3]) is dict:\n            
nested.append(cond_unit[3])\n        if type(cond_unit[4]) is dict:\n            nested.append(cond_unit[4])\n    if sql['intersect'] is not None:\n        nested.append(sql['intersect'])\n    if sql['except'] is not None:\n        nested.append(sql['except'])\n    if sql['union'] is not None:\n        nested.append(sql['union'])\n    return nested\n\n\ndef eval_nested(pred, label):\n    label_total = 0\n    pred_total = 0\n    cnt = 0\n    if pred is not None:\n        pred_total += 1\n    if label is not None:\n        label_total += 1\n    if pred is not None and label is not None:\n        cnt += Evaluator().eval_exact_match(pred, label)\n    return label_total, pred_total, cnt\n\n\ndef eval_IUEN(pred, label):\n    lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])\n    lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])\n    lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])\n    label_total = lt1 + lt2 + lt3\n    pred_total = pt1 + pt2 + pt3\n    cnt = cnt1 + cnt2 + cnt3\n    return label_total, pred_total, cnt\n\n\ndef get_keywords(sql):\n    res = set()\n    if len(sql['where']) > 0:\n        res.add('where')\n    if len(sql['groupBy']) > 0:\n        res.add('group')\n    if len(sql['having']) > 0:\n        res.add('having')\n    if len(sql['orderBy']) > 0:\n        res.add(sql['orderBy'][0])\n        res.add('order')\n    if sql['limit'] is not None:\n        res.add('limit')\n    if sql['except'] is not None:\n        res.add('except')\n    if sql['union'] is not None:\n        res.add('union')\n    if sql['intersect'] is not None:\n        res.add('intersect')\n\n    # or keyword\n    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]\n    if len([token for token in ao if token == 'or']) > 0:\n        res.add('or')\n\n    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]\n    # not keyword\n    if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:\n   
     res.add('not')\n\n    # in keyword\n    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:\n        res.add('in')\n\n    # like keyword\n    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:\n        res.add('like')\n\n    return res\n\n\ndef eval_keywords(pred, label):\n    pred_keywords = get_keywords(pred)\n    label_keywords = get_keywords(label)\n    pred_total = len(pred_keywords)\n    label_total = len(label_keywords)\n    cnt = 0\n\n    for k in pred_keywords:\n        if k in label_keywords:\n            cnt += 1\n    return label_total, pred_total, cnt\n\n\ndef count_agg(units):\n    return len([unit for unit in units if has_agg(unit)])\n\n\ndef count_component1(sql):\n    count = 0\n    if len(sql['where']) > 0:\n        count += 1\n    if len(sql['groupBy']) > 0:\n        count += 1\n    if len(sql['orderBy']) > 0:\n        count += 1\n    if sql['limit'] is not None:\n        count += 1\n    if len(sql['from']['table_units']) > 0:  # JOIN\n        count += len(sql['from']['table_units']) - 1\n\n    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]\n    count += len([token for token in ao if token == 'or'])\n    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]\n    count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])\n\n    return count\n\n\ndef count_component2(sql):\n    nested = get_nestedSQL(sql)\n    return len(nested)\n\n\ndef count_others(sql):\n    count = 0\n    # number of aggregation\n    agg_count = count_agg(sql['select'][1])\n    agg_count += count_agg(sql['where'][::2])\n    agg_count += count_agg(sql['groupBy'])\n    if len(sql['orderBy']) > 0:\n        agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +\n                            [unit[2] for unit in sql['orderBy'][1] if unit[2]])\n    agg_count += 
count_agg(sql['having'])\n    if agg_count > 1:\n        count += 1\n\n    # number of select columns\n    if len(sql['select'][1]) > 1:\n        count += 1\n\n    # number of where conditions\n    if len(sql['where']) > 1:\n        count += 1\n\n    # number of group by clauses\n    if len(sql['groupBy']) > 1:\n        count += 1\n\n    return count\n\n\nclass Evaluator:\n    \"\"\"A simple evaluator\"\"\"\n    def __init__(self):\n        self.partial_scores = None\n\n    def eval_hardness(self, sql):\n        count_comp1_ = count_component1(sql)\n        count_comp2_ = count_component2(sql)\n        count_others_ = count_others(sql)\n\n        if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:\n            return \"easy\"\n        elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \\\n                (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):\n            return \"medium\"\n        elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \\\n                (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \\\n                (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):\n            return \"hard\"\n        else:\n            return \"extra\"\n\n    def eval_exact_match(self, pred, label):\n        partial_scores = self.eval_partial_match(pred, label)\n        self.partial_scores = partial_scores\n\n        for _, score in partial_scores.items():\n            if score['f1'] != 1:\n                return 0\n        if len(label['from']['table_units']) > 0:\n            label_tables = sorted(label['from']['table_units'])\n            pred_tables = sorted(pred['from']['table_units'])\n            if label_tables != pred_tables:\n                return False\n        if len(label['from']['conds']) > 0:\n            label_joins = sorted(label['from']['conds'], key=lambda x: str(x))\n            pred_joins = sorted(pred['from']['conds'], key=lambda x: 
str(x))\n            if label_joins != pred_joins:\n                return False\n        return 1\n\n    def eval_partial_match(self, pred, label):\n        res = {}\n\n        label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)\n        res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n\n        label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)\n        res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n\n        label_total, pred_total, cnt = eval_group(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n\n        label_total, pred_total, cnt = eval_having(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n\n        label_total, pred_total, cnt = eval_order(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n\n        label_total, pred_total, cnt = eval_and_or(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['and/or'] = {'acc': acc, 'rec': rec, 'f1': 
f1,'label_total':label_total,'pred_total':pred_total}\n\n        label_total, pred_total, cnt = eval_IUEN(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n\n        label_total, pred_total, cnt = eval_keywords(pred, label)\n        acc, rec, f1 = get_scores(cnt, pred_total, label_total)\n        res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}\n\n        return res\n\n\ndef isValidSQL(sql, db):\n    conn = sqlite3.connect(db)\n    cursor = conn.cursor()\n    try:\n        cursor.execute(sql, [])\n    except Exception as e:\n        return False\n    return True\n\n\ndef print_scores(scores, etype):\n    levels = ['easy', 'medium', 'hard', 'extra', 'all']\n    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',\n                     'group', 'order', 'and/or', 'IUEN', 'keywords']\n\n    print(\"{:20} {:20} {:20} {:20} {:20} {:20}\".format(\"\", *levels))\n    counts = [scores[level]['count'] for level in levels]\n    print(\"{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}\".format(\"count\", *counts))\n\n    if etype in [\"all\", \"exec\"]:\n        print('=====================   EXECUTION ACCURACY     =====================')\n        this_scores = [scores[level]['exec'] for level in levels]\n        print(\"{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}\".format(\"execution\", *this_scores))\n\n    if etype in [\"all\", \"match\"]:\n        print('\\n====================== EXACT MATCHING ACCURACY =====================')\n        exact_scores = [scores[level]['exact'] for level in levels]\n        print(\"{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}\".format(\"exact match\", *exact_scores))\n        print('\\n---------------------PARTIAL MATCHING ACCURACY----------------------')\n        for type_ in partial_types:\n       
     this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]\n            print(\"{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}\".format(type_, *this_scores))\n\n        print('---------------------- PARTIAL MATCHING RECALL ----------------------')\n        for type_ in partial_types:\n            this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]\n            print(\"{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}\".format(type_, *this_scores))\n\n        print('---------------------- PARTIAL MATCHING F1 --------------------------')\n        for type_ in partial_types:\n            this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]\n            print(\"{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}\".format(type_, *this_scores))\n\n\ndef evaluate(gold, predict, db_dir, etype, kmaps):\n    with open(gold) as f:\n        glist = [l.strip().split('\\t') for l in f.readlines() if len(l.strip()) > 0]\n\n    with open(predict) as f:\n        plist = [l.strip().split('\\t') for l in f.readlines() if len(l.strip()) > 0]\n    # plist = [(\"select max(Share),min(Share) from performance where Type != 'terminal'\", \"orchestra\")]\n    # glist = [(\"SELECT max(SHARE) ,  min(SHARE) FROM performance WHERE TYPE != 'Live final'\", \"orchestra\")]\n    evaluator = Evaluator()\n\n    levels = ['easy', 'medium', 'hard', 'extra', 'all']\n    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',\n                     'group', 'order', 'and/or', 'IUEN', 'keywords']\n    entries = []\n    scores = {}\n\n    for level in levels:\n        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}\n        scores[level]['exec'] = 0\n        for type_ in partial_types:\n            scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}\n\n    eval_err_num = 0\n    for p, g in zip(plist, glist):\n        p_str = p[0]\n        
g_str, db = g\n        db_name = db\n        db = os.path.join(db_dir, db, db + \".sqlite\")\n        schema = Schema(get_schema(db))\n        g_sql = get_sql(schema, g_str)\n        hardness = evaluator.eval_hardness(g_sql)\n        scores[hardness]['count'] += 1\n        scores['all']['count'] += 1\n\n        try:\n            p_sql = get_sql(schema, p_str)\n        except:\n            # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql\n            p_sql = {\n            \"except\": None,\n            \"from\": {\n                \"conds\": [],\n                \"table_units\": []\n            },\n            \"groupBy\": [],\n            \"having\": [],\n            \"intersect\": None,\n            \"limit\": None,\n            \"orderBy\": [],\n            \"select\": [\n                False,\n                []\n            ],\n            \"union\": None,\n            \"where\": []\n            }\n            eval_err_num += 1\n            print(\"eval_err_num:{}\".format(eval_err_num))\n            print(p_str)\n            print()\n\n        # rebuild sql for value evaluation\n        kmap = kmaps[db_name]\n        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)\n        g_sql = rebuild_sql_val(g_sql)\n        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)\n        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)\n        p_sql = rebuild_sql_val(p_sql)\n        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)\n\n        # p_sql_copy = copy.deepcopy(p_sql)\n        # g_sql_copy = copy.deepcopy(g_sql)\n        # if not eval_exec_match(db, p_str, g_str, p_sql_copy, g_sql_copy) and evaluator.eval_exact_match(p_sql_copy, g_sql_copy):\n        #     a = 1\n\n        if etype in [\"all\", \"exec\"]:\n            exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)\n            if exec_score:\n                scores[hardness]['exec'] += 1\n 
           scores['all']['exec'] += exec_score\n\n        if etype in [\"all\", \"match\"]:\n            exact_score = evaluator.eval_exact_match(p_sql, g_sql)\n            partial_scores = evaluator.partial_scores\n            if exact_score == 0:\n                print(\"{} pred: {}\".format(hardness,p_str))\n                print(\"{} gold: {}\".format(hardness,g_str))\n                print(\"\")\n            scores[hardness]['exact'] += exact_score\n            scores['all']['exact'] += exact_score\n            for type_ in partial_types:\n                if partial_scores[type_]['pred_total'] > 0:\n                    scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']\n                    scores[hardness]['partial'][type_]['acc_count'] += 1\n                if partial_scores[type_]['label_total'] > 0:\n                    scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']\n                    scores[hardness]['partial'][type_]['rec_count'] += 1\n                scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']\n                if partial_scores[type_]['pred_total'] > 0:\n                    scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']\n                    scores['all']['partial'][type_]['acc_count'] += 1\n                if partial_scores[type_]['label_total'] > 0:\n                    scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']\n                    scores['all']['partial'][type_]['rec_count'] += 1\n                scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']\n\n            entries.append({\n                'predictSQL': p_str,\n                'goldSQL': g_str,\n                'hardness': hardness,\n                'exact': exact_score,\n                'partial': partial_scores\n            })\n\n    for level in levels:\n        if scores[level]['count'] == 0:\n            continue\n        if etype in [\"all\", 
\"exec\"]:\n            scores[level]['exec'] /= scores[level]['count']\n\n        if etype in [\"all\", \"match\"]:\n            scores[level]['exact'] /= scores[level]['count']\n            for type_ in partial_types:\n                if scores[level]['partial'][type_]['acc_count'] == 0:\n                    scores[level]['partial'][type_]['acc'] = 0\n                else:\n                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \\\n                                                             scores[level]['partial'][type_]['acc_count'] * 1.0\n                if scores[level]['partial'][type_]['rec_count'] == 0:\n                    scores[level]['partial'][type_]['rec'] = 0\n                else:\n                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \\\n                                                             scores[level]['partial'][type_]['rec_count'] * 1.0\n                if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:\n                    scores[level]['partial'][type_]['f1'] = 1\n                else:\n                    scores[level]['partial'][type_]['f1'] = \\\n                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (\n                        scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])\n\n    print_scores(scores, etype)\n\n\ndef eval_exec_match(db, p_str, g_str, pred, gold):\n    \"\"\"\n    return 1 if the values between prediction and gold are matching\n    in the corresponding index. 
Multiple col_unit pairs are currently not supported.\n    \"\"\"\n    conn = sqlite3.connect(db)\n    cursor = conn.cursor()\n    conn.text_factory = bytes\n    try:\n        cursor.execute(p_str)\n        p_res = cursor.fetchall()\n    except Exception:\n        return False\n\n    cursor.execute(g_str)\n    q_res = cursor.fetchall()\n\n    def res_map(res, val_units):\n        rmap = {}\n        for idx, val_unit in enumerate(val_units):\n            key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))\n            rmap[key] = [r[idx] for r in res]\n        return rmap\n\n    p_val_units = [unit[1] for unit in pred['select'][1]]\n    q_val_units = [unit[1] for unit in gold['select'][1]]\n    return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)\n\n\n# Rebuild SQL functions for value evaluation\ndef rebuild_cond_unit_val(cond_unit):\n    if cond_unit is None or not DISABLE_VALUE:\n        return cond_unit\n\n    not_op, op_id, val_unit, val1, val2 = cond_unit\n    if type(val1) is not dict:\n        val1 = None\n    else:\n        val1 = rebuild_sql_val(val1)\n    if type(val2) is not dict:\n        val2 = None\n    else:\n        val2 = rebuild_sql_val(val2)\n    return not_op, op_id, val_unit, val1, val2\n\n\ndef rebuild_condition_val(condition):\n    if condition is None or not DISABLE_VALUE:\n        return condition\n\n    res = []\n    for idx, it in enumerate(condition):\n        if idx % 2 == 0:\n            res.append(rebuild_cond_unit_val(it))\n        else:\n            res.append(it)\n    return res\n\n\ndef rebuild_sql_val(sql):\n    if sql is None or not DISABLE_VALUE:\n        return sql\n\n    sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])\n    sql['having'] = rebuild_condition_val(sql['having'])\n    sql['where'] = rebuild_condition_val(sql['where'])\n    sql['intersect'] = rebuild_sql_val(sql['intersect'])\n    sql['except'] = rebuild_sql_val(sql['except'])\n    sql['union'] = 
rebuild_sql_val(sql['union'])\n\n    return sql\n\n\n# Rebuild SQL functions for foreign key evaluation\ndef build_valid_col_units(table_units, schema):\n    col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]\n    prefixs = [col_id[:-2] for col_id in col_ids]\n    valid_col_units= []\n    for value in schema.idMap.values():\n        if '.' in value and value[:value.index('.')] in prefixs:\n            valid_col_units.append(value)\n    return valid_col_units\n\n\ndef rebuild_col_unit_col(valid_col_units, col_unit, kmap):\n    if col_unit is None:\n        return col_unit\n\n    agg_id, col_id, distinct = col_unit\n    if col_id in kmap and col_id in valid_col_units:\n        col_id = kmap[col_id]\n    if DISABLE_DISTINCT:\n        distinct = None\n    return agg_id, col_id, distinct\n\n\ndef rebuild_val_unit_col(valid_col_units, val_unit, kmap):\n    if val_unit is None:\n        return val_unit\n\n    unit_op, col_unit1, col_unit2 = val_unit\n    col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)\n    col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)\n    return unit_op, col_unit1, col_unit2\n\n\ndef rebuild_table_unit_col(valid_col_units, table_unit, kmap):\n    if table_unit is None:\n        return table_unit\n\n    table_type, col_unit_or_sql = table_unit\n    if isinstance(col_unit_or_sql, tuple):\n        col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)\n    return table_type, col_unit_or_sql\n\n\ndef rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):\n    if cond_unit is None:\n        return cond_unit\n\n    not_op, op_id, val_unit, val1, val2 = cond_unit\n    val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)\n    return not_op, op_id, val_unit, val1, val2\n\n\ndef rebuild_condition_col(valid_col_units, condition, kmap):\n    for idx in range(len(condition)):\n        if idx % 2 == 0:\n            condition[idx] = 
rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)\n    return condition\n\n\ndef rebuild_select_col(valid_col_units, sel, kmap):\n    if sel is None:\n        return sel\n    distinct, _list = sel\n    new_list = []\n    for it in _list:\n        agg_id, val_unit = it\n        new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))\n    if DISABLE_DISTINCT:\n        distinct = None\n    return distinct, new_list\n\n\ndef rebuild_from_col(valid_col_units, from_, kmap):\n    if from_ is None:\n        return from_\n\n    from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]\n    from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)\n    return from_\n\n\ndef rebuild_group_by_col(valid_col_units, group_by, kmap):\n    if group_by is None:\n        return group_by\n\n    return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]\n\n\ndef rebuild_order_by_col(valid_col_units, order_by, kmap):\n    if order_by is None or len(order_by) == 0:\n        return order_by\n\n    direction, val_units = order_by\n    new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]\n    return direction, new_val_units\n\n\ndef rebuild_sql_col(valid_col_units, sql, kmap):\n    if sql is None:\n        return sql\n\n    sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)\n    sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)\n    sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)\n    sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)\n    sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)\n    sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)\n    sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)\n    sql['except'] = 
rebuild_sql_col(valid_col_units, sql['except'], kmap)\n    sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)\n\n    return sql\n\n\ndef build_foreign_key_map(entry):\n    cols_orig = entry[\"column_names_original\"]\n    tables_orig = entry[\"table_names_original\"]\n\n    # rebuild cols corresponding to idmap in Schema\n    cols = []\n    for col_orig in cols_orig:\n        if col_orig[0] >= 0:\n            t = tables_orig[col_orig[0]]\n            c = col_orig[1]\n            cols.append(\"__\" + t.lower() + \".\" + c.lower() + \"__\")\n        else:\n            cols.append(\"__all__\")\n\n    def keyset_in_list(k1, k2, k_list):\n        for k_set in k_list:\n            if k1 in k_set or k2 in k_set:\n                return k_set\n        new_k_set = set()\n        k_list.append(new_k_set)\n        return new_k_set\n\n    foreign_key_list = []\n    foreign_keys = entry[\"foreign_keys\"]\n    for fkey in foreign_keys:\n        key1, key2 = fkey\n        key_set = keyset_in_list(key1, key2, foreign_key_list)\n        key_set.add(key1)\n        key_set.add(key2)\n\n    foreign_key_map = {}\n    for key_set in foreign_key_list:\n        sorted_list = sorted(list(key_set))\n        midx = sorted_list[0]\n        for idx in sorted_list:\n            foreign_key_map[cols[idx]] = cols[midx]\n\n    return foreign_key_map\n\n\ndef build_foreign_key_map_from_json(table):\n    with open(table) as f:\n        data = json.load(f)\n    tables = {}\n    for entry in data:\n        tables[entry['db_id']] = build_foreign_key_map(entry)\n    return tables\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--gold', dest='gold', type=str)\n    parser.add_argument('--pred', dest='pred', type=str)\n    parser.add_argument('--db', dest='db', type=str)\n    parser.add_argument('--table', dest='table', type=str)\n    parser.add_argument('--etype', dest='etype', type=str)\n    args = parser.parse_args()\n\n    gold = 
args.gold\n    pred = args.pred\n    db_dir = args.db\n    table = args.table\n    etype = args.etype\n\n    assert etype in [\"all\", \"exec\", \"match\"], \"Unknown evaluation method\"\n\n    kmaps = build_foreign_key_map_from_json(table)\n\n    evaluate(gold, pred, db_dir, etype, kmaps)"
  },
  {
    "path": "spider_evaluation/process_sql.py",
    "content": "################################\n# Assumptions:\n#   1. sql is correct\n#   2. only table name has alias\n#   3. only one intersect/union/except\n#\n# val: number(float)/string(str)/sql(dict)\n# col_unit: (agg_id, col_id, isDistinct(bool))\n# val_unit: (unit_op, col_unit1, col_unit2)\n# table_unit: (table_type, col_unit/sql)\n# cond_unit: (not_op, op_id, val_unit, val1, val2)\n# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]\n# sql {\n#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])\n#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}\n#   'where': condition\n#   'groupBy': [col_unit1, col_unit2, ...]\n#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])\n#   'having': condition\n#   'limit': None/limit value\n#   'intersect': None/sql\n#   'except': None/sql\n#   'union': None/sql\n# }\n################################\n\nimport json\nimport sqlite3\nfrom nltk import word_tokenize\n\nCLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')\nJOIN_KEYWORDS = ('join', 'on', 'as')\n\nWHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')\nUNIT_OPS = ('none', '-', '+', \"*\", '/')\nAGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')\nTABLE_TYPE = {\n    'sql': \"sql\",\n    'table_unit': \"table_unit\",\n}\n\nCOND_OPS = ('and', 'or')\nSQL_OPS = ('intersect', 'union', 'except')\nORDER_OPS = ('desc', 'asc')\nmapped_entities = []\n\n\nclass Schema:\n    \"\"\"\n    Simple schema which maps table&column to a unique identifier\n    \"\"\"\n    def __init__(self, schema):\n        self._schema = schema\n        self._idMap = self._map(self._schema)\n\n    @property\n    def schema(self):\n        return self._schema\n\n    @property\n    def idMap(self):\n        return self._idMap\n\n    def _map(self, schema):\n        idMap = {'*': \"__all__\"}\n        id = 1\n        for key, vals in 
schema.items():\n            for val in vals:\n                idMap[key.lower() + \".\" + val.lower()] = \"__\" + key.lower() + \".\" + val.lower() + \"__\"\n                id += 1\n\n        for key in schema:\n            idMap[key.lower()] = \"__\" + key.lower() + \"__\"\n            id += 1\n\n        return idMap\n\n\ndef get_schema(db):\n    \"\"\"\n    Get database's schema, which is a dict with table name as key\n    and list of column names as value\n    :param db: database path\n    :return: schema dict\n    \"\"\"\n\n    schema = {}\n    conn = sqlite3.connect(db)\n    cursor = conn.cursor()\n\n    # fetch table names\n    cursor.execute(\"SELECT name FROM sqlite_master WHERE type='table';\")\n    tables = [str(table[0].lower()) for table in cursor.fetchall()]\n\n    # fetch table info\n    for table in tables:\n        cursor.execute(\"PRAGMA table_info({})\".format(table))\n        schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]\n\n    return schema\n\n\ndef get_schema_from_json(fpath):\n    with open(fpath) as f:\n        data = json.load(f)\n\n    schema = {}\n    for entry in data:\n        table = str(entry['table'].lower())\n        cols = [str(col['column_name'].lower()) for col in entry['col_data']]\n        schema[table] = cols\n\n    return schema\n\n\ndef tokenize(string):\n    string = str(string)\n    string = string.replace(\"\\'\", \"\\\"\")  # normalize single quotes to double quotes so every string value is wrapped in \"\"\n    quote_idxs = [idx for idx, char in enumerate(string) if char == '\"']\n    assert len(quote_idxs) % 2 == 0, \"Unexpected quote\"\n\n    # keep string value as token\n    vals = {}\n    for i in range(len(quote_idxs)-1, -1, -2):\n        qidx1 = quote_idxs[i-1]\n        qidx2 = quote_idxs[i]\n        val = string[qidx1: qidx2+1]\n        key = \"__val_{}_{}__\".format(qidx1, qidx2)\n        string = string[:qidx1] + key + string[qidx2+1:]\n        vals[key] = val\n\n    toks = [word.lower() for word in 
word_tokenize(string)]\n    # replace with string value token\n    for i in range(len(toks)):\n        if toks[i] in vals:\n            toks[i] = vals[toks[i]]\n\n    # find if there exists !=, >=, <=\n    eq_idxs = [idx for idx, tok in enumerate(toks) if tok == \"=\"]\n    eq_idxs.reverse()\n    prefix = ('!', '>', '<')\n    for eq_idx in eq_idxs:\n        pre_tok = toks[eq_idx-1]\n        if pre_tok in prefix:\n            toks = toks[:eq_idx-1] + [pre_tok + \"=\"] + toks[eq_idx+1: ]\n\n    return toks\n\n\ndef scan_alias(toks):\n    \"\"\"Scan the index of 'as' and build the map for all alias\"\"\"\n    as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']\n    alias = {}\n    for idx in as_idxs:\n        alias[toks[idx+1]] = toks[idx-1]\n    return alias\n\n\ndef get_tables_with_alias(schema, toks):\n    tables = scan_alias(toks)\n    for key in schema:\n        assert key not in tables, \"Alias {} has the same name in table\".format(key)\n        tables[key] = key\n    return tables\n\n\ndef parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):\n    \"\"\"\n        :returns next idx, column id\n    \"\"\"\n    global mapped_entities\n    tok = toks[start_idx]\n    if tok == \"*\":\n        return start_idx + 1, schema.idMap[tok]\n\n    if '.' 
in tok:  # if token is a composite\n        alias, col = tok.split('.')\n        key = tables_with_alias[alias] + \".\" + col\n        mapped_entities.append((start_idx, tables_with_alias[alias] + \"@\" + col))\n        return start_idx+1, schema.idMap[key]\n\n    assert default_tables is not None and len(default_tables) > 0, \"Default tables should not be None or empty\"\n\n    for alias in default_tables:\n        table = tables_with_alias[alias]\n        if tok in schema.schema[table]:\n            key = table + \".\" + tok\n            mapped_entities.append((start_idx, table + \"@\" + tok))\n            return start_idx+1, schema.idMap[key]\n\n    assert False, \"Error col: {}\".format(tok)\n\n\ndef parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):\n    \"\"\"\n        :returns next idx, (agg_id, col_id, isDistinct)\n    \"\"\"\n    idx = start_idx\n    len_ = len(toks)\n    isBlock = False\n    isDistinct = False\n    if toks[idx] == '(':\n        isBlock = True\n        idx += 1\n\n    if toks[idx] in AGG_OPS:\n        agg_id = AGG_OPS.index(toks[idx])\n        idx += 1\n        assert idx < len_ and toks[idx] == '('\n        idx += 1\n        if toks[idx] == \"distinct\":\n            idx += 1\n            isDistinct = True\n        idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)\n        assert idx < len_ and toks[idx] == ')'\n        idx += 1\n        return idx, (agg_id, col_id, isDistinct)\n\n    if toks[idx] == \"distinct\":\n        idx += 1\n        isDistinct = True\n    agg_id = AGG_OPS.index(\"none\")\n    idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)\n\n    if isBlock:\n        assert toks[idx] == ')'\n        idx += 1  # skip ')'\n\n    return idx, (agg_id, col_id, isDistinct)\n\n\ndef parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):\n    idx = start_idx\n    len_ = len(toks)\n    isBlock = False\n    if toks[idx] == '(':\n        
isBlock = True\n        idx += 1\n\n    col_unit1 = None\n    col_unit2 = None\n    unit_op = UNIT_OPS.index('none')\n\n    idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)\n    if idx < len_ and toks[idx] in UNIT_OPS:\n        unit_op = UNIT_OPS.index(toks[idx])\n        idx += 1\n        idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)\n\n    if isBlock:\n        assert toks[idx] == ')'\n        idx += 1  # skip ')'\n\n    return idx, (unit_op, col_unit1, col_unit2)\n\n\ndef parse_table_unit(toks, start_idx, tables_with_alias, schema):\n    \"\"\"\n        :returns next idx, table id, table name\n    \"\"\"\n    idx = start_idx\n    len_ = len(toks)\n    key = tables_with_alias[toks[idx]]\n\n    if idx + 1 < len_ and toks[idx+1] == \"as\":\n        idx += 3\n    else:\n        idx += 1\n\n    return idx, schema.idMap[key], key\n\n\ndef parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):\n    idx = start_idx\n    len_ = len(toks)\n\n    isBlock = False\n    if toks[idx] == '(':\n        isBlock = True\n        idx += 1\n\n    if toks[idx] == 'select':\n        idx, val = parse_sql(toks, idx, tables_with_alias, schema)\n    elif \"\\\"\" in toks[idx]:  # token is a string value\n        val = toks[idx]\n        idx += 1\n    else:\n        try:\n            val = float(toks[idx])\n            idx += 1\n        except:\n            end_idx = idx\n            while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\\\n                and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS:\n                    end_idx += 1\n\n            idx, val = parse_col_unit(toks[: end_idx], start_idx, tables_with_alias, schema, default_tables)\n            idx = end_idx\n\n    if isBlock:\n        assert toks[idx] == ')'\n        idx += 1\n\n    return idx, val\n\n\ndef parse_condition(toks, start_idx, 
tables_with_alias, schema, default_tables=None):\n    idx = start_idx\n    len_ = len(toks)\n    conds = []\n\n    while idx < len_:\n        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)\n        not_op = False\n        if toks[idx] == 'not':\n            not_op = True\n            idx += 1\n\n        assert idx < len_ and toks[idx] in WHERE_OPS, \"Error condition: idx: {}, tok: {}\".format(idx, toks[idx])\n        op_id = WHERE_OPS.index(toks[idx])\n        idx += 1\n        val1 = val2 = None\n        if op_id == WHERE_OPS.index('between'):  # between..and... special case: dual values\n            idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)\n            assert toks[idx] == 'and'\n            idx += 1\n            idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)\n        else:  # normal case: single value\n            idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)\n            val2 = None\n\n        conds.append((not_op, op_id, val_unit, val1, val2))\n\n        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (\")\", \";\") or toks[idx] in JOIN_KEYWORDS):\n            break\n\n        if idx < len_ and toks[idx] in COND_OPS:\n            conds.append(toks[idx])\n            idx += 1  # skip and/or\n\n    return idx, conds\n\n\ndef parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):\n    idx = start_idx\n    len_ = len(toks)\n\n    assert toks[idx] == 'select', \"'select' not found\"\n    idx += 1\n    isDistinct = False\n    if idx < len_ and toks[idx] == 'distinct':\n        idx += 1\n        isDistinct = True\n    val_units = []\n\n    while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:\n        agg_id = AGG_OPS.index(\"none\")\n        if toks[idx] in AGG_OPS:\n            agg_id = AGG_OPS.index(toks[idx])\n            idx += 1\n        idx, val_unit = parse_val_unit(toks, idx, 
tables_with_alias, schema, default_tables)\n        val_units.append((agg_id, val_unit))\n        if idx < len_ and toks[idx] == ',':\n            idx += 1  # skip ','\n\n    return idx, (isDistinct, val_units)\n\n\ndef parse_from(toks, start_idx, tables_with_alias, schema):\n    \"\"\"\n    Assume in the from clause, all table units are combined with join\n    \"\"\"\n    assert 'from' in toks[start_idx:], \"'from' not found\"\n\n    len_ = len(toks)\n    idx = toks.index('from', start_idx) + 1\n    default_tables = []\n    table_units = []\n    conds = []\n\n    while idx < len_:\n        isBlock = False\n        if toks[idx] == '(':\n            isBlock = True\n            idx += 1\n\n        if toks[idx] == 'select':\n            idx, sql = parse_sql(toks, idx, tables_with_alias, schema)\n            table_units.append((TABLE_TYPE['sql'], sql))\n        else:\n            if idx < len_ and toks[idx] == 'join':\n                idx += 1  # skip join\n            idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)\n            table_units.append((TABLE_TYPE['table_unit'],table_unit))\n            default_tables.append(table_name)\n        if idx < len_ and toks[idx] == \"on\":\n            idx += 1  # skip on\n            idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)\n            if len(conds) > 0:\n                conds.append('and')\n            conds.extend(this_conds)\n\n        if isBlock:\n            assert toks[idx] == ')'\n            idx += 1\n        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (\")\", \";\")):\n            break\n\n    return idx, table_units, conds, default_tables\n\n\ndef parse_where(toks, start_idx, tables_with_alias, schema, default_tables):\n    idx = start_idx\n    len_ = len(toks)\n\n    if idx >= len_ or toks[idx] != 'where':\n        return idx, []\n\n    idx += 1\n    idx, conds = parse_condition(toks, idx, tables_with_alias, 
schema, default_tables)\n    return idx, conds\n\n\ndef parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):\n    idx = start_idx\n    len_ = len(toks)\n    col_units = []\n\n    if idx >= len_ or toks[idx] != 'group':\n        return idx, col_units\n\n    idx += 1\n    assert toks[idx] == 'by'\n    idx += 1\n\n    while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (\")\", \";\")):\n        idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)\n        col_units.append(col_unit)\n        if idx < len_ and toks[idx] == ',':\n            idx += 1  # skip ','\n        else:\n            break\n\n    return idx, col_units\n\n\ndef parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):\n    idx = start_idx\n    len_ = len(toks)\n    val_units = []\n    order_type = 'asc' # default type is 'asc'\n\n    if idx >= len_ or toks[idx] != 'order':\n        return idx, val_units\n\n    idx += 1\n    assert toks[idx] == 'by'\n    idx += 1\n\n    while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (\")\", \";\")):\n        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)\n        val_units.append(val_unit)\n        if idx < len_ and toks[idx] in ORDER_OPS:\n            order_type = toks[idx]\n            idx += 1\n        if idx < len_ and toks[idx] == ',':\n            idx += 1  # skip ','\n        else:\n            break\n\n    return idx, (order_type, val_units)\n\n\ndef parse_having(toks, start_idx, tables_with_alias, schema, default_tables):\n    idx = start_idx\n    len_ = len(toks)\n\n    if idx >= len_ or toks[idx] != 'having':\n        return idx, []\n\n    idx += 1\n    idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)\n    return idx, conds\n\n\ndef parse_limit(toks, start_idx):\n    idx = start_idx\n    len_ = len(toks)\n\n    if idx < len_ and toks[idx] == 'limit':\n        idx += 2\n   
     try:\n            limit_val = int(toks[idx-1])\n        except Exception:\n            limit_val = '\"value\"'\n        return idx, limit_val\n\n    return idx, None\n\n\ndef parse_sql(toks, start_idx, tables_with_alias, schema, mapped_entities_fn=None):\n    global mapped_entities\n\n    if mapped_entities_fn is not None:\n        mapped_entities = mapped_entities_fn()\n    isBlock = False # indicate whether this is a block of sql/sub-sql\n    len_ = len(toks)\n    idx = start_idx\n\n    sql = {}\n    if toks[idx] == '(':\n        isBlock = True\n        idx += 1\n\n    # parse from clause in order to get default tables\n    from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)\n    sql['from'] = {'table_units': table_units, 'conds': conds}\n    # select clause\n    _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)\n    idx = from_end_idx\n    sql['select'] = select_col_units\n    # where clause\n    idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)\n    sql['where'] = where_conds\n    # group by clause\n    idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)\n    sql['groupBy'] = group_col_units\n    # having clause\n    idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)\n    sql['having'] = having_conds\n    # order by clause\n    idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)\n    sql['orderBy'] = order_col_units\n    # limit clause\n    idx, limit_val = parse_limit(toks, idx)\n    sql['limit'] = limit_val\n\n    idx = skip_semicolon(toks, idx)\n    if isBlock:\n        assert toks[idx] == ')'\n        idx += 1  # skip ')'\n    idx = skip_semicolon(toks, idx)\n\n    # intersect/union/except clause\n    for op in SQL_OPS:  # initialize IUE\n        sql[op] = None\n    if idx < len_ and toks[idx] in SQL_OPS:\n       
 sql_op = toks[idx]\n        idx += 1\n        idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)\n        sql[sql_op] = IUE_sql\n\n    if mapped_entities_fn is not None:\n        return idx, sql, mapped_entities\n    else:\n        return idx, sql\n\n\ndef load_data(fpath):\n    with open(fpath) as f:\n        data = json.load(f)\n    return data\n\n\ndef get_sql(schema, query):\n    toks = tokenize(query)\n    tables_with_alias = get_tables_with_alias(schema.schema, toks)\n    _, sql = parse_sql(toks, 0, tables_with_alias, schema)\n\n    return sql\n\n\ndef skip_semicolon(toks, start_idx):\n    idx = start_idx\n    while idx < len(toks) and toks[idx] == \";\":\n        idx += 1\n    return idx"
  },
  {
    "path": "state_machines/states/grammar_based_state.py",
    "content": "from typing import Any, Dict, List, Sequence, Tuple\n\nimport torch\n\nfrom allennlp.data.fields.production_rule_field import ProductionRule\nfrom allennlp.state_machines.states.grammar_statelet import GrammarStatelet\nfrom allennlp.state_machines.states.rnn_statelet import RnnStatelet\nfrom allennlp.state_machines.states.state import State\n\nfrom state_machines.states.sql_state import SqlState\n\n\n# This syntax is pretty weird and ugly, but it's necessary to make mypy happy with the API that\n# we've defined.  We're using generics to make the type of `combine_states` come out right.  See\n# the note in `state_machines.state.py` for a little more detail.\nclass GrammarBasedState(State['GrammarBasedState']):\n    \"\"\"\n    A generic State that's suitable for most models that do grammar-based decoding.  We keep around\n    a `group` of states, and each element in the group has a few things: a batch index, an action\n    history, a score, an ``RnnStatelet``, a ``GrammarStatelet`` and an ``SqlState``.  We additionally\n    have some information that's independent of any particular group element: a list of all possible\n    actions for all batch instances passed to ``model.forward()``, and an ``extras`` field that you\n    can use if you really need some extra information about each batch instance (like a string\n    description, or other metadata).\n\n    Finally, we also have a specially-treated, optional ``debug_info`` field.  If this is given, it\n    should be an empty list for each group instance when the initial state is created.  In that\n    case, we will keep around information about the actions considered at each timestep of decoding\n    and other things that you might want to visualize in a demo.  
This probably isn't necessary for\n    training, and to get it right we need to copy a bunch of data structures for each new state, so\n    it's best used only at evaluation / demo time.\n\n    Parameters\n    ----------\n    batch_indices : ``List[int]``\n        Passed to super class; see docs there.\n    action_history : ``List[List[int]]``\n        Passed to super class; see docs there.\n    score : ``List[torch.Tensor]``\n        Passed to super class; see docs there.\n    rnn_state : ``List[RnnStatelet]``\n        An ``RnnStatelet`` for every group element.  This keeps track of the current decoder hidden\n        state, the previous decoder output, the output from the encoder (for computing attentions),\n        and other things that are typical seq2seq decoder state things.\n    grammar_state : ``List[GrammarStatelet]``\n        This holds the current grammar state for each element of the group.  The ``GrammarStatelet``\n        keeps track of which actions are currently valid.\n    sql_state : ``List[SqlState]``\n        An ``SqlState`` for every group element.  It tracks the SQL decoded so far (e.g. the tables\n        used and the currently open clause) and further filters the actions that the\n        ``GrammarStatelet`` reports as valid.\n    possible_actions : ``List[List[ProductionRule]]``\n        The list of all possible actions that was passed to ``model.forward()``.  We need this so\n        we can recover production strings, which we need to update grammar states.\n    extras : ``List[Any]``, optional (default=None)\n        If you need to keep around some extra data for each instance in the batch, you can put that\n        in here, without adding another field.  
This should be used `very sparingly`, as there is\n        no type checking or anything done on the contents of this field, and it will just be passed\n        around between ``States`` as-is, without copying.\n    debug_info : ``List[Any]``, optional (default=None).\n    \"\"\"\n    def __init__(self,\n                 batch_indices: List[int],\n                 action_history: List[List[int]],\n                 score: List[torch.Tensor],\n                 rnn_state: List[RnnStatelet],\n                 grammar_state: List[GrammarStatelet],\n                 sql_state: List[SqlState],\n                 possible_actions: List[List[ProductionRule]],\n                 action_entity_mapping: List[Dict[int, int]],\n                 extras: List[Any] = None,\n                 debug_info: List = None) -> None:\n        super().__init__(batch_indices, action_history, score)\n        self.rnn_state = rnn_state\n        self.grammar_state = grammar_state\n        self.sql_state = sql_state\n        self.possible_actions = possible_actions\n        self.action_entity_mapping = action_entity_mapping\n        self.extras = extras\n        self.debug_info = debug_info\n\n    def new_state_from_group_index(self,\n                                   group_index: int,\n                                   action: int,\n                                   new_score: torch.Tensor,\n                                   new_rnn_state: RnnStatelet,\n                                   considered_actions: List[int] = None,\n                                   action_probabilities: List[float] = None,\n                                   attention_weights: torch.Tensor = None,\n                                   linking_scores_qst: torch.Tensor = None,\n                                   linking_scores_past: torch.Tensor = None) -> 'GrammarBasedState':\n        batch_index = self.batch_indices[group_index]\n        new_action_history = self.action_history[group_index] + [action]\n        
production_rule = self.possible_actions[batch_index][action][0]\n        new_grammar_state = self.grammar_state[group_index].take_action(production_rule)\n        new_sql_state = self.sql_state[group_index].take_action(production_rule)\n        if self.debug_info is not None:\n            attention = attention_weights[group_index] if attention_weights is not None else None\n            # output_attention = output_attention_weights[group_index] if output_attention_weights is not None else None\n            debug_info = {\n                    'considered_actions': considered_actions,\n                    'question_attention': attention,\n                    'probabilities': action_probabilities,\n                    'linking_scores_qst': linking_scores_qst,\n                    'linking_scores_past': linking_scores_past,\n                    }\n            new_debug_info = [self.debug_info[group_index] + [debug_info]]\n        else:\n            new_debug_info = None\n        return GrammarBasedState(batch_indices=[batch_index],\n                                 action_history=[new_action_history],\n                                 score=[new_score],\n                                 rnn_state=[new_rnn_state],\n                                 grammar_state=[new_grammar_state],\n                                 sql_state=[new_sql_state],\n                                 possible_actions=self.possible_actions,\n                                 action_entity_mapping=self.action_entity_mapping,\n                                 extras=self.extras,\n                                 debug_info=new_debug_info)\n\n    def print_action_history(self, group_index: int = None) -> None:\n        scores = self.score if group_index is None else [self.score[group_index]]\n        batch_indices = self.batch_indices if group_index is None else [self.batch_indices[group_index]]\n        histories = self.action_history if group_index is None else [self.action_history[group_index]]\n   
     for score, batch_index, action_history in zip(scores, batch_indices, histories):\n            print('  ', score.detach().cpu().numpy()[0],\n                  [self.possible_actions[batch_index][action][0] for action in action_history])\n\n    def get_valid_actions(self) -> List[Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]]:\n        \"\"\"\n        Returns a list of valid actions for each element of the group.\n        \"\"\"\n        # return [state.get_valid_actions() for state in self.grammar_state]\n        return [sql_state.get_valid_actions(grammar_state.get_valid_actions())\n                for sql_state, grammar_state in zip(self.sql_state, self.grammar_state)]\n\n    def is_finished(self) -> bool:\n        if len(self.batch_indices) != 1:\n            raise RuntimeError(\"is_finished() is only defined with a group_size of 1\")\n        return self.grammar_state[0].is_finished()\n\n    @classmethod\n    def combine_states(cls, states: Sequence['GrammarBasedState']) -> 'GrammarBasedState':\n        batch_indices = [batch_index for state in states for batch_index in state.batch_indices]\n        action_histories = [action_history for state in states for action_history in state.action_history]\n        scores = [score for state in states for score in state.score]\n        rnn_states = [rnn_state for state in states for rnn_state in state.rnn_state]\n        grammar_states = [grammar_state for state in states for grammar_state in state.grammar_state]\n        sql_states = [sql_state for state in states for sql_state in state.sql_state]\n        if states[0].debug_info is not None:\n            debug_info = [debug_info for state in states for debug_info in state.debug_info]\n        else:\n            debug_info = None\n        return GrammarBasedState(batch_indices=batch_indices,\n                                 action_history=action_histories,\n                                 score=scores,\n                                 
rnn_state=rnn_states,\n                                 grammar_state=grammar_states,\n                                 sql_state=sql_states,\n                                 possible_actions=states[0].possible_actions,\n                                 action_entity_mapping=states[0].action_entity_mapping,\n                                 extras=states[0].extras,\n                                 debug_info=debug_info)\n"
  },
  {
    "path": "state_machines/states/rnn_statelet.py",
    "content": "from typing import List, Optional\n\nimport torch\n\nfrom allennlp.nn import util\n\n\nclass RnnStatelet:\n    \"\"\"\n    This class keeps track of all of the decoder-RNN-related variables that you need during decoding.\n    This includes things like the current decoder hidden state, the memory cell (for LSTM\n    decoders), the encoder output that you need for computing attentions, and so on.\n\n    This is intended to be used `inside` a ``State``, which likely has other things it has to keep\n    track of for doing constrained decoding.\n\n    Parameters\n    ----------\n    hidden_state : ``torch.Tensor``\n        This holds the LSTM hidden state, with shape ``(decoder_output_dim,)`` if the decoder\n        has 1 layer and ``(num_layers, decoder_output_dim)`` otherwise.\n    memory_cell : ``torch.Tensor``\n        This holds the LSTM memory cell, with shape ``(decoder_output_dim,)`` if the decoder has\n        1 layer and ``(num_layers, decoder_output_dim)`` otherwise.\n    previous_action_embedding : ``torch.Tensor``\n        This holds the embedding for the action we took at the last timestep (which gets input to\n        the decoder).  Has shape ``(action_embedding_dim,)``.\n    attended_input : ``torch.Tensor``\n        This holds the attention-weighted sum over the input representations that we computed in\n        the previous timestep.  We keep this as part of the state because we use the previous\n        attention as part of our decoder cell update.  Has shape ``(encoder_output_dim,)``.\n    encoder_outputs : ``List[torch.Tensor]``\n        A list of variables, each of shape ``(input_sequence_length, encoder_output_dim)``,\n        containing the encoder outputs at each timestep.  
The list is over batch elements, and we\n        do the input this way so we can easily do a ``torch.cat`` on a list of indices into this\n        batched list.\n\n        Note that all of the above parameters are single tensors, while the encoder outputs and\n        mask are lists of length ``batch_size``.  We always pass around the encoder outputs and\n        mask unmodified, regardless of what's in the grouping for this state.  We'll use the\n        ``batch_indices`` for the group to pull pieces out of these lists when we're ready to\n        actually do some computation.\n    encoder_output_mask : ``List[torch.Tensor]``\n        A list of variables, each of shape ``(input_sequence_length,)``, containing a mask over\n        question tokens for each batch instance.  This is a list over batch elements, for the same\n        reasons as above.\n    \"\"\"\n    def __init__(self,\n                 hidden_state: torch.Tensor,\n                 memory_cell: torch.Tensor,\n                 previous_action_embedding: torch.Tensor,\n                 attended_input: torch.Tensor,\n                 encoder_outputs: List[torch.Tensor],\n                 encoder_output_mask: List[torch.Tensor],\n                 decoder_outputs: Optional[torch.Tensor] = None) -> None:\n        self.hidden_state = hidden_state\n        self.memory_cell = memory_cell\n        self.previous_action_embedding = previous_action_embedding\n        self.attended_input = attended_input\n        self.encoder_outputs = encoder_outputs\n        self.encoder_output_mask = encoder_output_mask\n        self.decoder_outputs = decoder_outputs\n\n    def __eq__(self, other):\n        if isinstance(self, other.__class__):\n            return all([\n                    util.tensors_equal(self.hidden_state, other.hidden_state, tolerance=1e-5),\n                    util.tensors_equal(self.memory_cell, other.memory_cell, tolerance=1e-5),\n                    util.tensors_equal(self.previous_action_embedding, 
other.previous_action_embedding,\n                                       tolerance=1e-5),\n                    util.tensors_equal(self.attended_input, other.attended_input, tolerance=1e-5),\n                    ])\n        return NotImplemented\n"
  },
  {
    "path": "state_machines/states/sql_state.py",
    "content": "import copy\nimport logging\n\nlogger = logging.getLogger(__name__)  # pylint: disable=invalid-name\n\n\nclass SqlState:\n    def __init__(self,\n                 possible_actions,\n                 enabled: bool=True):\n        self.possible_actions = [a[0] for a in possible_actions]\n        self.action_history = []\n        self.tables_used = set()\n        self.tables_used_by_columns = set()\n        self.current_stack = []\n        self.subqueries_stack = []\n        self.enabled = enabled\n\n    def take_action(self, production_rule: str) -> 'SqlState':\n        if not self.enabled:\n            return self\n\n        new_sql_state = copy.deepcopy(self)\n\n        lhs, rhs = production_rule.split(' -> ')\n        rhs_tokens = rhs.strip('[]').split(', ')\n        if lhs == 'table_name':\n            new_sql_state.tables_used.add(rhs_tokens[0].strip('\"'))\n        elif lhs == 'column_name':\n            new_sql_state.tables_used_by_columns.add(rhs_tokens[0].strip('\"').split('@')[0])\n        elif lhs == 'iue':\n            new_sql_state.tables_used_by_columns = set()\n            new_sql_state.tables_used = set()\n        elif lhs == \"source_subq\":\n            new_sql_state.subqueries_stack.append(copy.deepcopy(new_sql_state))\n            new_sql_state.tables_used = set()\n            new_sql_state.tables_used_by_columns = set()\n\n        new_sql_state.action_history.append(production_rule)\n\n        new_sql_state.current_stack.append([lhs, []])\n\n        for token in rhs_tokens:\n            is_terminal = token[0] == '\"' and token[-1] == '\"'\n            if not is_terminal:\n                new_sql_state.current_stack[-1][1].append(token)\n\n        while len(new_sql_state.current_stack[-1][1]) == 0:\n            finished_item = new_sql_state.current_stack[-1][0]\n            del new_sql_state.current_stack[-1]\n            if finished_item == 'statement':\n                break\n            if new_sql_state.current_stack[-1][1][0] 
== finished_item:\n                new_sql_state.current_stack[-1][1] = new_sql_state.current_stack[-1][1][1:]\n\n            if finished_item == 'source_subq':\n                new_sql_state.tables_used = new_sql_state.subqueries_stack[-1].tables_used\n                new_sql_state.tables_used_by_columns = new_sql_state.subqueries_stack[-1].tables_used_by_columns\n                del new_sql_state.subqueries_stack[-1]\n\n        return new_sql_state\n\n    def get_valid_actions(self, valid_actions: dict):\n        if not self.enabled:\n            return valid_actions\n\n        valid_actions_ids = []\n        for key, items in valid_actions.items():\n            valid_actions_ids += [(key, rule_id) for rule_id in valid_actions[key][2]]\n        valid_actions_rules = [self.possible_actions[rule_id] for rule_type, rule_id in valid_actions_ids]\n\n        actions_to_remove = {k: set() for k in valid_actions.keys()}\n\n        current_clause = self._get_current_open_clause()\n\n        if current_clause in ['where_clause', 'orderby_clause', 'join_condition', 'groupby_clause']:\n            for rule_id, rule in zip(valid_actions_ids, valid_actions_rules):\n                rule_type, rule_id = rule_id\n                lhs, rhs = rule.split(' -> ')\n                rhs_values = rhs.strip('[]').split(', ')\n                if lhs == 'column_name':\n                    rule_table = rhs_values[0].strip('\"').split('@')[0]\n                    if rule_table not in self.tables_used:\n                        actions_to_remove[rule_type].add(rule_id)\n\n                # if len(self.current_stack[-1][1]) < 2:\n                #     # disable condition clause when same tables\n                #     rule_table = rhs_values[0].strip('\"').split('@')[0]\n                #     last_table = self.action_history[-1].split(' -> ')[1].strip('[]\"').split('@')[0]\n                #     if rule_table == last_table:\n                #         actions_to_remove[rule_type].add(rule_id)\n\n   
     if current_clause in ['join_clause']:\n            for rule_id, rule in zip(valid_actions_ids, valid_actions_rules):\n                rule_type, rule_id = rule_id\n                lhs, rhs = rule.split(' -> ')\n                rhs_values = rhs.strip('[]').split(', ')\n                if lhs == 'table_name':\n                    candidate_table = rhs_values[0].strip('\"')\n\n                    if current_clause == 'join_clause' and len(self.current_stack[-1][1]) == 2:\n                        if candidate_table in self.tables_used:\n                            # trying to join an already joined table\n                            actions_to_remove[rule_type].add(rule_id)\n\n                    if 'join_clauses' not in self.current_stack[-2][1] and not self.current_stack[-2][0].startswith('join_clauses'):\n                        # decided not to join any more tables\n                        remaining_joins = self.tables_used_by_columns - self.tables_used\n                        if len(remaining_joins) > 0 and candidate_table not in self.tables_used_by_columns:\n                            # trying to select a single table but used other table(s) in columns\n                            actions_to_remove[rule_type].add(rule_id)\n\n        if current_clause in ['select_core']:\n            for rule_id, rule in zip(valid_actions_ids, valid_actions_rules):\n                rule_type, rule_id = rule_id\n                lhs, rhs = rule.split(' -> ')\n                rhs_values = rhs.strip('[]').split(', ')\n                if self.current_stack[-1][1][0] == 'from_clause' or self.current_stack[-1][1][0] == 'join_clauses':\n                    all_tables = set([a.split(' -> ')[1].strip('[]\\\"') for a in self.possible_actions if\n                                      a.startswith('table_name ->')])\n                    if len(self.tables_used_by_columns - self.tables_used) > 1:\n                        # selected columns from more tables than selected, must join\n      
                  if 'join_clauses' not in rhs:\n                            actions_to_remove[rule_type].add(rule_id)\n                    if len(all_tables - self.tables_used) <= 1:\n                        # don't join 2 tables because otherwise there will be no more tables to join\n                        # (assuming no joining twice and no sub-queries)\n                        if 'join_clauses' in rhs:\n                            actions_to_remove[rule_type].add(rule_id)\n                if lhs == \"table_name\" and self.current_stack[-1][0] == \"single_source\":\n                    candidate_table = rhs_values[0].strip('\"')\n                    if len(self.tables_used_by_columns) > 0 and candidate_table not in self.tables_used_by_columns:\n                        # trying to select a single table but used other table(s) in columns\n                        actions_to_remove[rule_type].add(rule_id)\n\n                if lhs == 'single_source' and len(self.tables_used_by_columns) == 0 and rhs.strip('[]') == 'source_subq':\n                    # prevent cases such as \"select count ( * ) from ( select city.district from city ) where city.district = ' value '\"\n                    search_stack_pos = -1\n                    while self.current_stack[search_stack_pos][0] != 'select_core':\n                        # note - should look for other \"gateways\" here (i.e. this may not be a dead end if there is\n                        # another source_subq). 
This is ignored here\n                        search_stack_pos -= 1\n                    if self.current_stack[search_stack_pos][1][-1] == 'where_clause':\n                        # planning to add where/group/order later, but no columns were ever selected\n                        actions_to_remove[rule_type].add(rule_id)\n\n                    while self.current_stack[search_stack_pos][0] != 'query':\n                        search_stack_pos -= 1\n                    if 'orderby_clause' in self.current_stack[search_stack_pos][1]:\n                        actions_to_remove[rule_type].add(rule_id)\n                    if 'groupby_clause' in self.current_stack[search_stack_pos][1]:\n                        actions_to_remove[rule_type].add(rule_id)\n\n        new_valid_actions = {}\n        new_global_actions = self._remove_actions(valid_actions, 'global',\n                                                  actions_to_remove['global']) if 'global' in valid_actions else None\n        new_linked_actions = self._remove_actions(valid_actions, 'linked',\n                                                  actions_to_remove['linked']) if 'linked' in valid_actions else None\n\n        if new_linked_actions is not None:\n            new_valid_actions['linked'] = new_linked_actions\n        if new_global_actions is not None:\n            new_valid_actions['global'] = new_global_actions\n\n        # if len(new_valid_actions) == 0 and valid_actions:\n        #     # should not get here! 
implies that a rule should have been disabled in past (bug in this parser)\n        #     # log and do not remove rules (otherwise crashes)\n        #     # logger.warning(\"No valid action remains, error in sql decoding parser!\")\n        #     # logger.warning(\"Action history: \" + str(self.action_history))\n        #     # logger.warning(\"Tables in db: \" + ', '.join([a.split(' -> ')[1].strip('[]\\\"') for a in self.possible_actions if a.startswith('table_name ->')]))\n        #\n        #     return valid_actions\n\n        return new_valid_actions\n\n    @staticmethod\n    def _remove_actions(valid_actions, key, ids_to_remove):\n        if len(ids_to_remove) == 0:\n            return valid_actions[key]\n\n        if len(ids_to_remove) == len(valid_actions[key][2]):\n            return None\n\n        current_ids = valid_actions[key][2]\n        keep_ids = []\n        keep_ids_loc = []\n\n        for loc, rule_id in enumerate(current_ids):\n            if rule_id not in ids_to_remove:\n                keep_ids.append(rule_id)\n                keep_ids_loc.append(loc)\n\n        items = list(valid_actions[key])\n        items[0] = items[0][keep_ids_loc]\n        items[1] = items[1][keep_ids_loc]\n        items[2] = keep_ids\n\n        if len(items) >= 4:\n            items[3] = items[3][keep_ids_loc]\n        return tuple(items)\n\n    def _get_current_open_clause(self):\n        relevant_clauses = [\n            'where_clause',\n            'orderby_clause',\n            'join_clause',\n            'join_condition',\n            'select_core',\n            'groupby_clause',\n            'source_subq'\n        ]\n        for rule in self.current_stack[::-1]:\n            if rule[0] in relevant_clauses:\n                return rule[0]\n\n        return None\n"
  },
  {
    "path": "state_machines/transition_functions/attend_past_schema_items_transition.py",
    "content": "from collections import defaultdict\nfrom typing import Dict, Tuple, List, Set, Any, Callable, Optional\n\nimport torch\nfrom allennlp.modules import Attention, FeedForward\nfrom allennlp.nn import Activation, util\n\nfrom allennlp.state_machines.states.grammar_based_state import GrammarBasedState\nfrom allennlp.state_machines.transition_functions import BasicTransitionFunction\nfrom allennlp.state_machines.transition_functions.linking_transition_function import LinkingTransitionFunction\nfrom overrides import overrides\nfrom torch.nn import Linear\n\nfrom state_machines.states.rnn_statelet import RnnStatelet\n\n\nclass AttendPastSchemaItemsTransitionFunction(BasicTransitionFunction):\n    def __init__(self,\n                 encoder_output_dim: int,\n                 action_embedding_dim: int,\n                 input_attention: Attention,\n                 past_attention: Attention,\n                 activation: Activation = Activation.by_name('relu')(),\n                 predict_start_type_separately: bool = True,\n                 num_start_types: int = None,\n                 add_action_bias: bool = True,\n                 dropout: float = 0.0,\n                 num_layers: int = 1) -> None:\n        super().__init__(encoder_output_dim=encoder_output_dim,\n                         action_embedding_dim=action_embedding_dim,\n                         input_attention=input_attention,\n                         num_start_types=num_start_types,\n                         activation=activation,\n                         predict_start_type_separately=predict_start_type_separately,\n                         add_action_bias=add_action_bias,\n                         dropout=dropout,\n                         num_layers=num_layers)\n\n        self._past_attention = past_attention\n        self._ent2ent_ff = FeedForward(1, 1, 1, Activation.by_name('linear')())\n\n    @overrides\n    def take_step(self,\n                  state: GrammarBasedState,\n           
       max_actions: int = None,\n                  allowed_actions: List[Set[int]] = None) -> List[GrammarBasedState]:\n        if self._predict_start_type_separately and not state.action_history[0]:\n            # The wikitables parser did something different when predicting the start type, which\n            # is our first action.  So in this case we break out into a different function.  We'll\n            # ignore max_actions on our first step, assuming there aren't that many start types.\n            return self._take_first_step(state, allowed_actions)\n\n        # Taking a step in the decoder consists of three main parts.  First, we'll construct the\n        # input to the decoder and update the decoder's hidden state.  Second, we'll use this new\n        # hidden state (and maybe other information) to predict an action.  Finally, we will\n        # construct new states for the next step.  Each new state corresponds to one valid action\n        # that can be taken from the current state, and they are ordered by their probability of\n        # being selected.\n\n        updated_state = self._update_decoder_state(state)\n        batch_results = self._compute_action_probabilities(state,\n                                                           updated_state['hidden_state'],\n                                                           updated_state['attention_weights'],\n                                                           updated_state['past_schema_items_attention_weights'],\n                                                           updated_state['predicted_action_embeddings'])\n        new_states = self._construct_next_states(state,\n                                                 updated_state,\n                                                 batch_results,\n                                                 max_actions,\n                                                 allowed_actions)\n\n        return new_states\n\n    def _update_decoder_state(self, 
state: GrammarBasedState) -> Dict[str, torch.Tensor]:\n        # For updating the decoder, we're doing a bunch of tensor operations that can be batched\n        # without much difficulty.  So, we take all group elements and batch their tensors together\n        # before doing these decoder operations.\n\n        group_size = len(state.batch_indices)\n        attended_question = torch.stack([rnn_state.attended_input for rnn_state in state.rnn_state])\n        if self._num_layers > 1:\n            hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state], 1)\n            memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state], 1)\n        else:\n            hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state])\n            memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state])\n\n        previous_action_embedding = torch.stack([rnn_state.previous_action_embedding\n                                                 for rnn_state in state.rnn_state])\n\n        # (group_size, decoder_input_dim)\n        projected_input = self._input_projection_layer(torch.cat([attended_question,\n                                                                  previous_action_embedding], -1))\n        decoder_input = self._activation(projected_input)\n        if self._num_layers > 1:\n            _, (hidden_state, memory_cell) = self._decoder_cell(decoder_input.unsqueeze(0),\n                                                                (hidden_state, memory_cell))\n        else:\n            hidden_state, memory_cell = self._decoder_cell(decoder_input, (hidden_state, memory_cell))\n        hidden_state = self._dropout(hidden_state)\n\n        # (group_size, input_sequence_length, encoder_output_dim)\n        encoder_outputs = torch.stack([state.rnn_state[0].encoder_outputs[i] for i in state.batch_indices])\n        encoder_output_mask = torch.stack([state.rnn_state[0].encoder_output_mask[i] 
for i in state.batch_indices])\n\n        if self._num_layers > 1:\n            attended_question, attention_weights = self.attend_on_question(hidden_state[-1],\n                                                                           encoder_outputs,\n                                                                           encoder_output_mask)\n            action_query = torch.cat([hidden_state[-1], attended_question], dim=-1)\n        else:\n            attended_question, attention_weights = self.attend_on_question(hidden_state,\n                                                                           encoder_outputs,\n                                                                           encoder_output_mask)\n            action_query = torch.cat([hidden_state, attended_question], dim=-1)\n\n        # With a multi-layer decoder, hidden_state has shape (num_layers, group_size, dim), so the\n        # per-group-element query for attending over past schema items must come from the top layer.\n        top_hidden_state = hidden_state[-1] if self._num_layers > 1 else hidden_state\n\n        # TODO: Can batch this (need to save ids of states with saved outputs)\n        past_schema_items_attention_weights = []\n        for i, rnn_state in enumerate(state.rnn_state):\n            if rnn_state.decoder_outputs is not None:\n                decoder_outputs_states, decoder_outputs_ids = rnn_state.decoder_outputs\n                attn_weights = self.attend(self._past_attention, top_hidden_state[i].unsqueeze(0),\n                                           decoder_outputs_states.unsqueeze(0), None).squeeze(0)\n                past_schema_items_attention_weights.append((attn_weights, decoder_outputs_ids))\n            else:\n                past_schema_items_attention_weights.append(None)\n\n        # past_schema_items_attention_weights = torch.stack(past_schema_items_attention_weights)\n\n        # (group_size, action_embedding_dim)\n        projected_query = self._activation(self._output_projection_layer(action_query))\n        predicted_action_embeddings = self._dropout(projected_query)\n        if self._add_action_bias:\n            # NOTE: It's important that this happens right before the dot product with the action\n            # embeddings.  
Otherwise this isn't a proper bias.  We do it here instead of right next\n            # to the `.mm` below just so we only do it once for the whole group.\n            ones = predicted_action_embeddings.new([[1] for _ in range(group_size)])\n            predicted_action_embeddings = torch.cat([predicted_action_embeddings, ones], dim=-1)\n        return {\n                'hidden_state': hidden_state,\n                'memory_cell': memory_cell,\n                'attended_question': attended_question,\n                'attention_weights': attention_weights,\n                'past_schema_items_attention_weights': past_schema_items_attention_weights,\n                'predicted_action_embeddings': predicted_action_embeddings,\n                }\n\n    @overrides\n    def _compute_action_probabilities(self,\n                                      state: GrammarBasedState,\n                                      hidden_state: torch.Tensor,\n                                      attention_weights: torch.Tensor,\n                                      past_schema_items_attention_weights: torch.Tensor,\n                                      predicted_action_embeddings: torch.Tensor\n                                      ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:\n        # In this section we take our predicted action embedding and compare it to the available\n        # actions in our current state (which might be different for each group element).  For\n        # computing action scores, we'll forget about doing batched / grouped computation, as it\n        # adds too much complexity and doesn't speed things up, anyway, with the operations we're\n        # doing here.  
This means we don't need any action masks, as we'll only get the right\n        # lengths for what we're computing.\n\n        group_size = len(state.batch_indices)\n        actions = state.get_valid_actions()\n\n        batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)\n        for group_index in range(group_size):\n            batch_id = state.batch_indices[group_index]\n            instance_actions = actions[group_index]\n            predicted_action_embedding = predicted_action_embeddings[group_index]\n            embedded_actions: List[int] = []\n\n            output_action_embeddings = None\n            embedded_action_logits = None\n            current_log_probs = None\n            linked_action_logits_encoder = None\n            linked_action_ent2ent_logits = None\n\n            if 'global' in instance_actions:\n                action_embeddings, output_action_embeddings, embedded_actions = instance_actions['global']\n                # This is just a matrix product between a (num_actions, embedding_dim) matrix and an\n                # (embedding_dim, 1) matrix.\n                embedded_action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)\n                action_ids = embedded_actions\n\n            if 'linked' in instance_actions:\n                linking_scores, type_embeddings, linked_actions, entity_action_linking_scores = instance_actions['linked']\n                action_ids = embedded_actions + linked_actions\n\n                # For linked actions, in addition to the linking score with the attended question word, we add\n                # a linking score with an attended, previously decoded linked action.\n                if past_schema_items_attention_weights[group_index] is not None:\n                    past_items_attention_weights, past_items_action_ids = 
past_schema_items_attention_weights[group_index]\n                    past_schema_items_ids = [state.action_entity_mapping[batch_id][a]\n                                             for a in past_items_action_ids]\n\n                    # We only need the scores of the entities the decoder has already output.\n                    past_entity_linking_scores = entity_action_linking_scores[:, past_schema_items_ids]\n\n                    linked_action_ent2ent_logits = past_entity_linking_scores.mm(\n                        past_items_attention_weights.unsqueeze(-1)).squeeze(-1)\n                    linked_action_ent2ent_logits = self._ent2ent_ff(linked_action_ent2ent_logits.unsqueeze(-1)).squeeze(1)\n                else:\n                    linked_action_ent2ent_logits = 0\n\n                linked_action_logits_encoder = linking_scores.mm(attention_weights[group_index].unsqueeze(-1)).squeeze(-1)\n                linked_action_logits = linked_action_logits_encoder + linked_action_ent2ent_logits\n\n                # The `output_action_embeddings` tensor gets used later as the input to the next\n                # decoder step.  
For linked actions, we don't have any action embedding, so we use\n                # the entity type instead.\n                if output_action_embeddings is not None:\n                    output_action_embeddings = torch.cat([output_action_embeddings, type_embeddings], dim=0)\n                else:\n                    output_action_embeddings = type_embeddings\n\n                if embedded_action_logits is not None:\n                    action_logits = torch.cat([embedded_action_logits, linked_action_logits], dim=-1)\n                else:\n                    action_logits = linked_action_logits\n                current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)\n            elif not instance_actions:\n                action_ids = None\n                current_log_probs = float('inf')\n            else:\n                action_logits = embedded_action_logits\n                current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)\n\n            # This is now the total score for each state after taking each action.  
We're going to\n            # sort by this later, so it's important that this is the total score, not just the\n            # score for the current action.\n            log_probs = state.score[group_index] + current_log_probs\n            batch_results[state.batch_indices[group_index]].append((group_index,\n                                                                    log_probs,\n                                                                    current_log_probs,\n                                                                    output_action_embeddings,\n                                                                    action_ids,\n                                                                    linked_action_logits_encoder,\n                                                                    linked_action_ent2ent_logits))\n        return batch_results\n\n    def _construct_next_states(self,\n                               state: GrammarBasedState,\n                               updated_rnn_state: Dict[str, torch.Tensor],\n                               batch_action_probs: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]],\n                               max_actions: int,\n                               allowed_actions: List[Set[int]]):\n        # pylint: disable=no-self-use\n\n        # We'll yield a bunch of states here that all have a `group_size` of 1, so that the\n        # learning algorithm can decide how many of these it wants to keep, and it can just regroup\n        # them later, as that's a really easy operation.\n        #\n        # We first define a `make_state` method, as in the logic that follows we want to create\n        # states in a couple of different branches, and we don't want to duplicate the\n        # state-creation logic.  
This method creates a closure using variables from the method, so\n        # it doesn't make sense to pull it out of here.\n\n        # Each group index here might get accessed multiple times, and doing the slicing operation\n        # each time is more expensive than doing it once upfront.  These three lines give about a\n        # 10% speedup in training time.\n        group_size = len(state.batch_indices)\n\n        chunk_index = 1 if self._num_layers > 1 else 0\n        hidden_state = [x.squeeze(chunk_index)\n                        for x in updated_rnn_state['hidden_state'].chunk(group_size, chunk_index)]\n        memory_cell = [x.squeeze(chunk_index)\n                       for x in updated_rnn_state['memory_cell'].chunk(group_size, chunk_index)]\n\n        attended_question = [x.squeeze(0) for x in updated_rnn_state['attended_question'].chunk(group_size, 0)]\n\n        def make_state(group_index: int,\n                       action: int,\n                       new_score: torch.Tensor,\n                       action_embedding: torch.Tensor) -> GrammarBasedState:\n            batch_index = state.batch_indices[group_index]\n\n            decoder_outputs = state.rnn_state[group_index].decoder_outputs\n            is_linked_action = not state.possible_actions[batch_index][action][1]\n            if is_linked_action:\n                if decoder_outputs is None:\n                    decoder_outputs = hidden_state[group_index].unsqueeze(0), [action]\n                else:\n                    decoder_outputs_states, decoder_outputs_ids = decoder_outputs\n                    decoder_outputs = torch.cat((\n                        decoder_outputs_states,\n                        hidden_state[group_index].unsqueeze(0)\n                    ), dim=0), decoder_outputs_ids + [action]\n\n            new_rnn_state = RnnStatelet(hidden_state[group_index],\n                                        memory_cell[group_index],\n                                        
action_embedding,\n                                        attended_question[group_index],\n                                        state.rnn_state[group_index].encoder_outputs,\n                                        state.rnn_state[group_index].encoder_output_mask,\n                                        decoder_outputs)\n            for i, _, current_log_probs, _, actions, lsq, lsp in batch_action_probs[batch_index]:\n                if i == group_index:\n                    considered_actions = actions\n                    probabilities = current_log_probs.exp().cpu()\n                    considered_lsq = lsq\n                    considered_lsp = lsp\n                    break\n            return state.new_state_from_group_index(group_index,\n                                                    action,\n                                                    new_score,\n                                                    new_rnn_state,\n                                                    considered_actions,\n                                                    probabilities,\n                                                    updated_rnn_state['attention_weights'],\n                                                    considered_lsq,\n                                                    considered_lsp\n                                                    )\n\n        new_states = []\n        for _, results in batch_action_probs.items():\n            if allowed_actions and not max_actions:\n                # If we're given a set of allowed actions, and we're not just keeping the top k of\n                # them, we don't need to do any sorting, so we can speed things up quite a bit.\n                # Each result also carries the two linking-logit tensors, which are not needed here.\n                # Instances with no valid actions have `actions=None`, so we skip them.\n                for group_index, log_probs, _, action_embeddings, actions, _, _ in results:\n                    if not actions:\n                        continue\n                    for log_prob, action_embedding, action in 
zip(log_probs, action_embeddings, actions):\n                        if action in allowed_actions[group_index]:\n                            
new_states.append(make_state(group_index, action, log_prob, action_embedding))\n            else:\n                # In this case, we need to sort the actions.  We'll do that on CPU, as it's easier,\n                # and our action list is on the CPU, anyway.\n                group_indices = []\n                group_log_probs: List[torch.Tensor] = []\n                group_action_embeddings = []\n                group_actions = []\n                for group_index, log_probs, _, action_embeddings, actions, _, _ in results:\n                    if not actions:\n                        continue\n\n                    group_indices.extend([group_index] * len(actions))\n                    group_log_probs.append(log_probs)\n                    group_action_embeddings.append(action_embeddings)\n                    group_actions.extend(actions)\n                    \n                if len(group_log_probs) == 0:\n                    continue\n\n                log_probs = torch.cat(group_log_probs, dim=0)\n                action_embeddings = torch.cat(group_action_embeddings, dim=0)\n                log_probs_cpu = log_probs.data.cpu().numpy().tolist()\n                batch_states = [(log_probs_cpu[i],\n                                 group_indices[i],\n                                 log_probs[i],\n                                 action_embeddings[i],\n                                 group_actions[i])\n                                for i in range(len(group_actions))\n                                if (not allowed_actions or\n                                    group_actions[i] in allowed_actions[group_indices[i]])]\n                # We use a key here to make sure we're not trying to compare anything on the GPU.\n                batch_states.sort(key=lambda x: x[0], reverse=True)\n                if max_actions:\n                    batch_states = batch_states[:max_actions]\n                for _, group_index, log_prob, action_embedding, action in 
batch_states:\n                    new_states.append(make_state(group_index, action, log_prob, action_embedding))\n        return new_states\n\n    def attend(self,\n               attention: Attention,\n               query: torch.Tensor,\n               key: torch.Tensor,\n               value: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Given a query (which is typically the decoder hidden state), compute ``attention`` over\n        ``key``.  If ``value`` is given, return a weighted sum of ``value`` under this attention\n        together with the attention weights themselves.  If ``value`` is ``None``, return only the\n        attention weights; this is how ``_update_decoder_state`` attends over the hidden states of\n        previously decoded schema items.\n\n        This is a simple computation, but we have it as a separate method so that the ``forward``\n        method on the main parser module can call it on the initial hidden state, to simplify the\n        logic in ``take_step``.\n        \"\"\"\n        # (group_size, key_length)\n        attention_weights = attention(query, key, None)\n\n        if value is None:\n            return attention_weights\n\n        # (group_size, value_dim)\n        attended_question = util.weighted_sum(value, attention_weights)\n        return attended_question, attention_weights\n"
  },
  {
    "path": "state_machines/transition_functions/basic_transition_function.py",
    "content": "from collections import defaultdict\nfrom typing import Any, Dict, List, Set, Tuple\n\nfrom overrides import overrides\n\nimport torch\nfrom torch.nn.modules.rnn import LSTM, LSTMCell\nfrom torch.nn.modules.linear import Linear\n\nfrom allennlp.modules import Attention\nfrom allennlp.nn import util, Activation\nfrom allennlp.state_machines.states import RnnStatelet, GrammarBasedState\nfrom allennlp.state_machines.transition_functions.transition_function import TransitionFunction\n\n\nclass BasicTransitionFunction(TransitionFunction[GrammarBasedState]):\n    \"\"\"\n    This is a typical transition function for a state-based decoder.  We use an LSTM to track\n    decoder state, and at every timestep we compute an attention over the input question/utterance\n    to help in selecting the action.  All actions have an embedding, and we use a dot product\n    between a predicted action embedding and the allowed actions to compute a distribution over\n    actions at each timestep.\n\n    We allow the first action to be predicted separately from everything else.  This is optional,\n    and is because that's how the original WikiTableQuestions semantic parser was written.  The\n    intuition is that maybe you want to predict the type of your output program outside of the\n    typical LSTM decoder (or maybe Jayant just didn't realize this could be treated as another\n    action...).\n\n    Parameters\n    ----------\n    encoder_output_dim : ``int``\n    action_embedding_dim : ``int``\n    input_attention : ``Attention``\n    activation : ``Activation``, optional (default=relu)\n        The activation that gets applied to the decoder LSTM input and to the action query.\n    predict_start_type_separately : ``bool``, optional (default=True)\n        If ``True``, we will predict the initial action (which is typically the base type of the\n        logical form) using a different mechanism than our typical action decoder.  
We basically\n        just do a projection of the hidden state, and don't update the decoder RNN.\n    num_start_types : ``int``, optional (default=None)\n        If ``predict_start_type_separately`` is ``True``, this is the number of start types that\n        are in the grammar.  We need this so we can construct parameters with the right shape.\n        This is unused if ``predict_start_type_separately`` is ``False``.\n    add_action_bias : ``bool``, optional (default=True)\n        If ``True``, a bias dimension has been added to the embedding of each action, which\n        gets used when predicting the next action.  We add a dimension of ones to our predicted\n        action vector in this case to account for that.\n    dropout : ``float``, optional (default=0.0)\n        Dropout probability applied to the decoder hidden state and to the predicted action embedding.\n    num_layers : ``int``, optional (default=1)\n        The number of layers in the decoder LSTM.\n    \"\"\"\n    def __init__(self,\n                 encoder_output_dim: int,\n                 action_embedding_dim: int,\n                 input_attention: Attention,\n                 activation: Activation = Activation.by_name('relu')(),\n                 predict_start_type_separately: bool = True,\n                 num_start_types: int = None,\n                 add_action_bias: bool = True,\n                 dropout: float = 0.0,\n                 num_layers: int = 1) -> None:\n        super().__init__()\n        self._input_attention = input_attention\n        self._add_action_bias = add_action_bias\n        self._activation = activation\n        self._num_layers = num_layers\n\n        self._predict_start_type_separately = predict_start_type_separately\n        if predict_start_type_separately:\n            self._start_type_predictor = Linear(encoder_output_dim, num_start_types)\n            self._num_start_types = num_start_types\n        else:\n            self._start_type_predictor = None\n            self._num_start_types = None\n\n        # Decoder output dim needs to be the same as the 
encoder output dim since we initialize the\n        # hidden state of the decoder with the final hidden state of the encoder.\n        output_dim = encoder_output_dim\n        input_dim = output_dim\n        # Our decoder input will be the concatenation of the attended question (the attention-weighted\n        # encoder output from the previous step) and the previous action embedding, and we'll project\n        # that down to the decoder's `input_dim`, which we arbitrarily set to be the same as `output_dim`.\n        self._input_projection_layer = Linear(output_dim + action_embedding_dim, input_dim)\n        # Before making a prediction, we'll compute an attention over the input given our updated\n        # hidden state. Then we concatenate the attended question with the decoder state and project to\n        # `action_embedding_dim` to make a prediction.\n        self._output_projection_layer = Linear(output_dim + encoder_output_dim, action_embedding_dim)\n        if self._num_layers > 1:\n            self._decoder_cell = LSTM(input_dim, output_dim, self._num_layers)\n        else:\n            # We use an ``LSTMCell`` if we just have one layer because it is slightly faster, since we are\n            # just running the LSTM for one step each time.\n            self._decoder_cell = LSTMCell(input_dim, output_dim)\n\n        if dropout > 0:\n            self._dropout = torch.nn.Dropout(p=dropout)\n        else:\n            self._dropout = lambda x: x\n\n    @overrides\n    def take_step(self,\n                  state: GrammarBasedState,\n                  max_actions: int = None,\n                  allowed_actions: List[Set[int]] = None) -> List[GrammarBasedState]:\n        if self._predict_start_type_separately and not state.action_history[0]:\n            # The wikitables parser did something different when predicting the start type, which\n            # is our first action.  So in this case we break out into a different function.  
We'll\n            # ignore max_actions on our first step, assuming there aren't that many start types.\n            return self._take_first_step(state, allowed_actions)\n\n        # Taking a step in the decoder consists of three main parts.  First, we'll construct the\n        # input to the decoder and update the decoder's hidden state.  Second, we'll use this new\n        # hidden state (and maybe other information) to predict an action.  Finally, we will\n        # construct new states for the next step.  Each new state corresponds to one valid action\n        # that can be taken from the current state, and they are ordered by their probability of\n        # being selected.\n\n        updated_state = self._update_decoder_state(state)\n        batch_results = self._compute_action_probabilities(state,\n                                                           updated_state['hidden_state'],\n                                                           updated_state['attention_weights'],\n                                                           updated_state['predicted_action_embeddings'])\n        new_states = self._construct_next_states(state,\n                                                 updated_state,\n                                                 batch_results,\n                                                 max_actions,\n                                                 allowed_actions)\n\n        return new_states\n\n    def _update_decoder_state(self, state: GrammarBasedState) -> Dict[str, torch.Tensor]:\n        # For updating the decoder, we're doing a bunch of tensor operations that can be batched\n        # without much difficulty.  
So, we take all group elements and batch their tensors together\n        # before doing these decoder operations.\n\n        group_size = len(state.batch_indices)\n        attended_question = torch.stack([rnn_state.attended_input for rnn_state in state.rnn_state])\n        if self._num_layers > 1:\n            hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state], 1)\n            memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state], 1)\n        else:\n            hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state])\n            memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state])\n\n        previous_action_embedding = torch.stack([rnn_state.previous_action_embedding\n                                                 for rnn_state in state.rnn_state])\n\n        # (group_size, decoder_input_dim)\n        projected_input = self._input_projection_layer(torch.cat([attended_question,\n                                                                  previous_action_embedding], -1))\n        decoder_input = self._activation(projected_input)\n        if self._num_layers > 1:\n            _, (hidden_state, memory_cell) = self._decoder_cell(decoder_input.unsqueeze(0),\n                                                                (hidden_state, memory_cell))\n        else:\n            hidden_state, memory_cell = self._decoder_cell(decoder_input, (hidden_state, memory_cell))\n        hidden_state = self._dropout(hidden_state)\n\n        # (group_size, encoder_output_dim)\n        encoder_outputs = torch.stack([state.rnn_state[0].encoder_outputs[i] for i in state.batch_indices])\n        encoder_output_mask = torch.stack([state.rnn_state[0].encoder_output_mask[i] for i in state.batch_indices])\n\n        if self._num_layers > 1:\n            attended_question, attention_weights = self.attend_on_question(hidden_state[-1],\n                              
                                             encoder_outputs,\n                                                                           encoder_output_mask)\n            action_query = torch.cat([hidden_state[-1], attended_question], dim=-1)\n        else:\n            attended_question, attention_weights = self.attend_on_question(hidden_state,\n                                                                           encoder_outputs,\n                                                                           encoder_output_mask)\n            action_query = torch.cat([hidden_state, attended_question], dim=-1)\n\n        # (group_size, action_embedding_dim)\n        projected_query = self._activation(self._output_projection_layer(action_query))\n        predicted_action_embeddings = self._dropout(projected_query)\n        if self._add_action_bias:\n            # NOTE: It's important that this happens right before the dot product with the action\n            # embeddings.  Otherwise this isn't a proper bias.  
We do it here instead of right next\n            # to the `.mm` below just so we only do it once for the whole group.\n            ones = predicted_action_embeddings.new([[1] for _ in range(group_size)])\n            predicted_action_embeddings = torch.cat([predicted_action_embeddings, ones], dim=-1)\n        return {\n                'hidden_state': hidden_state,\n                'memory_cell': memory_cell,\n                'attended_question': attended_question,\n                'attention_weights': attention_weights,\n                'predicted_action_embeddings': predicted_action_embeddings,\n                }\n\n    def _compute_action_probabilities(self,\n                                      state: GrammarBasedState,\n                                      hidden_state: torch.Tensor,\n                                      attention_weights: torch.Tensor,\n                                      predicted_action_embeddings: torch.Tensor\n                                     ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:\n        # We take a couple of extra arguments here because subclasses might use them.\n        # pylint: disable=unused-argument,no-self-use\n\n        # In this section we take our predicted action embedding and compare it to the available\n        # actions in our current state (which might be different for each group element).  For\n        # computing action scores, we'll forget about doing batched / grouped computation, as it\n        # adds too much complexity and doesn't speed things up, anyway, with the operations we're\n        # doing here.  
This means we don't need any action masks, as we'll only get the right\n        # lengths for what we're computing.\n\n        group_size = len(state.batch_indices)\n        actions = state.get_valid_actions()\n\n        batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)\n        for group_index in range(group_size):\n            instance_actions = actions[group_index]\n            predicted_action_embedding = predicted_action_embeddings[group_index]\n            action_embeddings, output_action_embeddings, action_ids = instance_actions['global']\n            # This is just a matrix product between a (num_actions, embedding_dim) matrix and an\n            # (embedding_dim, 1) matrix.\n            action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)\n            current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)\n\n            # This is now the total score for each state after taking each action.  
We're going to\n            # sort by this later, so it's important that this is the total score, not just the\n            # score for the current action.\n            log_probs = state.score[group_index] + current_log_probs\n            batch_results[state.batch_indices[group_index]].append((group_index,\n                                                                    log_probs,\n                                                                    current_log_probs,\n                                                                    output_action_embeddings,\n                                                                    action_ids))\n        return batch_results\n\n    def _construct_next_states(self,\n                               state: GrammarBasedState,\n                               updated_rnn_state: Dict[str, torch.Tensor],\n                               batch_action_probs: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]],\n                               max_actions: int,\n                               allowed_actions: List[Set[int]]):\n        # pylint: disable=no-self-use\n\n        # We'll yield a bunch of states here that all have a `group_size` of 1, so that the\n        # learning algorithm can decide how many of these it wants to keep, and it can just regroup\n        # them later, as that's a really easy operation.\n        #\n        # We first define a `make_state` method, as in the logic that follows we want to create\n        # states in a couple of different branches, and we don't want to duplicate the\n        # state-creation logic.  This method creates a closure using variables from the method, so\n        # it doesn't make sense to pull it out of here.\n\n        # Each group index here might get accessed multiple times, and doing the slicing operation\n        # each time is more expensive than doing it once upfront.  
These three lines give about a\n        # 10% speedup in training time.\n        group_size = len(state.batch_indices)\n\n        chunk_index = 1 if self._num_layers > 1 else 0\n        hidden_state = [x.squeeze(chunk_index)\n                        for x in updated_rnn_state['hidden_state'].chunk(group_size, chunk_index)]\n        memory_cell = [x.squeeze(chunk_index)\n                       for x in updated_rnn_state['memory_cell'].chunk(group_size, chunk_index)]\n\n        attended_question = [x.squeeze(0) for x in updated_rnn_state['attended_question'].chunk(group_size, 0)]\n\n        def make_state(group_index: int,\n                       action: int,\n                       new_score: torch.Tensor,\n                       action_embedding: torch.Tensor) -> GrammarBasedState:\n            new_rnn_state = RnnStatelet(hidden_state[group_index],\n                                        memory_cell[group_index],\n                                        action_embedding,\n                                        attended_question[group_index],\n                                        state.rnn_state[group_index].encoder_outputs,\n                                        state.rnn_state[group_index].encoder_output_mask)\n            batch_index = state.batch_indices[group_index]\n            for i, _, current_log_probs, _, actions in batch_action_probs[batch_index]:\n                if i == group_index:\n                    considered_actions = actions\n                    probabilities = current_log_probs.exp().cpu()\n                    break\n            return state.new_state_from_group_index(group_index,\n                                                    action,\n                                                    new_score,\n                                                    new_rnn_state,\n                                                    considered_actions,\n                                                    probabilities,\n                              
                      updated_rnn_state['attention_weights'])\n\n        new_states = []\n        for _, results in batch_action_probs.items():\n            if allowed_actions and not max_actions:\n                # If we're given a set of allowed actions, and we're not just keeping the top k of\n                # them, we don't need to do any sorting, so we can speed things up quite a bit.\n                for group_index, log_probs, _, action_embeddings, actions in results:\n                    for log_prob, action_embedding, action in zip(log_probs, action_embeddings, actions):\n                        if action in allowed_actions[group_index]:\n                            new_states.append(make_state(group_index, action, log_prob, action_embedding))\n            else:\n                # In this case, we need to sort the actions.  We'll do that on CPU, as it's easier,\n                # and our action list is on the CPU, anyway.\n                group_indices = []\n                group_log_probs: List[torch.Tensor] = []\n                group_action_embeddings = []\n                group_actions = []\n\n                for group_index, log_probs, _, action_embeddings, actions in results:\n                    if not actions:\n                        continue\n\n                    group_indices.extend([group_index] * len(actions))\n                    group_log_probs.append(log_probs)\n                    group_action_embeddings.append(action_embeddings)\n                    group_actions.extend(actions)\n\n                if len(group_log_probs) == 0:\n                    continue\n                log_probs = torch.cat(group_log_probs, dim=0)\n                action_embeddings = torch.cat(group_action_embeddings, dim=0)\n                log_probs_cpu = log_probs.data.cpu().numpy().tolist()\n                batch_states = [(log_probs_cpu[i],\n                                 group_indices[i],\n                                 log_probs[i],\n                      
           action_embeddings[i],\n                                 group_actions[i])\n                                for i in range(len(group_actions))\n                                if (not allowed_actions or\n                                    group_actions[i] in allowed_actions[group_indices[i]])]\n                # We use a key here to make sure we're not trying to compare anything on the GPU.\n                batch_states.sort(key=lambda x: x[0], reverse=True)\n                if max_actions:\n                    batch_states = batch_states[:max_actions]\n                for _, group_index, log_prob, action_embedding, action in batch_states:\n                    new_states.append(make_state(group_index, action, log_prob, action_embedding))\n        return new_states\n\n    def _take_first_step(self,\n                         state: GrammarBasedState,\n                         allowed_actions: List[Set[int]] = None) -> List[GrammarBasedState]:\n        # We'll just do a projection from the current hidden state (which was initialized with the\n        # final encoder output) to the number of start actions that we have, normalize those\n        # logits, and use that as our score.  
We end up duplicating some of the logic from\n        # `_compute_new_states` here, but we do things slightly differently, and it's easier to\n        # just copy the parts we need than to try to re-use that code.\n\n        # (group_size, hidden_dim)\n        hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state])\n        # (group_size, num_start_type)\n        start_action_logits = self._start_type_predictor(hidden_state)\n        log_probs = torch.nn.functional.log_softmax(start_action_logits, dim=-1)\n        sorted_log_probs, sorted_actions = log_probs.sort(dim=-1, descending=True)\n\n        sorted_actions = sorted_actions.detach().cpu().numpy().tolist()\n        if state.debug_info is not None:\n            probs_cpu = log_probs.exp().detach().cpu().numpy().tolist()\n        else:\n            probs_cpu = [None] * len(state.batch_indices)\n\n        # state.get_valid_actions() will return a list that is consistently sorted, so as long as\n        # the set of valid start actions never changes, we can just match up the log prob indices\n        # above with the position of each considered action, and we're good.\n        valid_actions = state.get_valid_actions()\n        considered_actions = [actions['global'][2] for actions in valid_actions]\n        if len(considered_actions[0]) != self._num_start_types:\n            raise RuntimeError(\"Calculated wrong number of initial actions.  Expected \"\n                               f\"{self._num_start_types}, found {len(considered_actions[0])}.\")\n\n        best_next_states: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list)\n        for group_index, (batch_index, group_actions) in enumerate(zip(state.batch_indices, sorted_actions)):\n            for action_index, action in enumerate(group_actions):\n                # `action` is currently the index in `log_probs`, not the actual action ID.  
To get\n                # the action ID, we need to go through `considered_actions`.\n                action = considered_actions[group_index][action]\n                if allowed_actions is not None and action not in allowed_actions[group_index]:\n                    # This happens when our _decoder trainer_ wants us to only evaluate certain\n                    # actions, likely because they are the gold actions in this state.  We just skip\n                    # emitting any state that isn't allowed by the trainer, because constructing the\n                    # new state can be expensive.\n                    continue\n                best_next_states[batch_index].append((group_index, action_index, action))\n\n        new_states = []\n        for batch_index, best_states in sorted(best_next_states.items()):\n            for group_index, action_index, action in best_states:\n                # We'll yield a bunch of states here that all have a `group_size` of 1, so that the\n                # learning algorithm can decide how many of these it wants to keep, and it can just\n                # regroup them later, as that's a really easy operation.\n                new_score = state.score[group_index] + sorted_log_probs[group_index, action_index]\n\n                # This part is different from `_compute_new_states` - we're just passing through\n                # the previous RNN state, as predicting the start type wasn't included in the\n                # decoder RNN in the original model.\n                new_rnn_state = state.rnn_state[group_index]\n                new_state = state.new_state_from_group_index(group_index,\n                                                             action,\n                                                             new_score,\n                                                             new_rnn_state,\n                                                             considered_actions[group_index],\n                                 
                            probs_cpu[group_index],\n                                                             None)\n                new_states.append(new_state)\n        return new_states\n\n    def attend_on_question(self,\n                           query: torch.Tensor,\n                           encoder_outputs: torch.Tensor,\n                           encoder_output_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Given a query (which is typically the decoder hidden state), compute an attention over the\n        output of the question encoder, and return a weighted sum of the question representations\n        given this attention.  We also return the attention weights themselves.\n\n        This is a simple computation, but we have it as a separate method so that the ``forward``\n        method on the main parser module can call it on the initial hidden state, to simplify the\n        logic in ``take_step``.\n        \"\"\"\n        # (group_size, question_length)\n        question_attention_weights = self._input_attention(query,\n                                                           encoder_outputs,\n                                                           encoder_output_mask)\n        # (group_size, encoder_output_dim)\n        attended_question = util.weighted_sum(encoder_outputs, question_attention_weights)\n        return attended_question, question_attention_weights\n"
  },
  {
    "path": "state_machines/transition_functions/linking_transition_function.py",
    "content": "from collections import defaultdict\nfrom typing import Any, Dict, List, Tuple\n\nfrom overrides import overrides\n\nimport torch\n\nfrom allennlp.common.checks import check_dimensions_match\nfrom allennlp.modules import Attention, FeedForward\nfrom allennlp.nn import Activation\nfrom allennlp.state_machines.states import GrammarBasedState\n\nfrom state_machines.transition_functions.basic_transition_function import BasicTransitionFunction\n\n\nclass LinkingTransitionFunction(BasicTransitionFunction):\n    \"\"\"\n    This transition function adds the ability to consider `linked` actions to the\n    ``BasicTransitionFunction`` (which is just an LSTM decoder with attention).  These actions are\n    potentially unseen at training time, so we need to handle them without requiring the action to\n    have an embedding.  Instead, we rely on a `linking score` between each action and the words in\n    the question/utterance, and use these scores, along with the attention, to do something similar\n    to a copy mechanism when producing these actions.\n\n    When both linked and global (embedded) actions are available, we need some way to compare the\n    scores for these two sets of actions.  The original WikiTableQuestions semantic parser just\n    concatenated the logits together before doing a joint softmax, but this is quite brittle,\n    because the logits might have quite different scales.  
So we have the option here of predicting\n    a mixture probability between two independently normalized distributions.\n\n    Parameters\n    ----------\n    encoder_output_dim : ``int``\n    action_embedding_dim : ``int``\n    input_attention : ``Attention``\n    activation : ``Activation``, optional (default=relu)\n        The activation that gets applied to the decoder LSTM input and to the action query.\n    predict_start_type_separately : ``bool``, optional (default=True)\n        If ``True``, we will predict the initial action (which is typically the base type of the\n        logical form) using a different mechanism than our typical action decoder.  We basically\n        just do a projection of the hidden state, and don't update the decoder RNN.\n    num_start_types : ``int``, optional (default=None)\n        If ``predict_start_type_separately`` is ``True``, this is the number of start types that\n        are in the grammar.  We need this so we can construct parameters with the right shape.\n        This is unused if ``predict_start_type_separately`` is ``False``.\n    add_action_bias : ``bool``, optional (default=True)\n        If ``True``, there has been a bias dimension added to the embedding of each action, which\n        gets used when predicting the next action.  
We add a dimension of ones to our predicted\n        action vector in this case to account for that.\n    mixture_feedforward : ``FeedForward``, optional (default=None)\n        If given, we'll use this to compute a mixture probability between global actions and linked\n        actions given the hidden state at every timestep of decoding, instead of concatenating the\n        logits for both (where the logits may not be compatible with each other).\n    dropout : ``float``, optional (default=0.0)\n    num_layers : ``int``, optional (default=1)\n        The number of layers in the decoder LSTM.\n    \"\"\"\n    def __init__(self,\n                 encoder_output_dim: int,\n                 action_embedding_dim: int,\n                 input_attention: Attention,\n                 activation: Activation = Activation.by_name('relu')(),\n                 predict_start_type_separately: bool = True,\n                 num_start_types: int = None,\n                 add_action_bias: bool = True,\n                 mixture_feedforward: FeedForward = None,\n                 dropout: float = 0.0,\n                 num_layers: int = 1) -> None:\n        super().__init__(encoder_output_dim=encoder_output_dim,\n                         action_embedding_dim=action_embedding_dim,\n                         input_attention=input_attention,\n                         num_start_types=num_start_types,\n                         activation=activation,\n                         predict_start_type_separately=predict_start_type_separately,\n                         add_action_bias=add_action_bias,\n                         dropout=dropout,\n                         num_layers=num_layers)\n        self._mixture_feedforward = mixture_feedforward\n\n        if mixture_feedforward is not None:\n            check_dimensions_match(encoder_output_dim, mixture_feedforward.get_input_dim(),\n                                   \"hidden state embedding dim\", \"mixture feedforward input dim\")\n            check_dimensions_match(mixture_feedforward.get_output_dim(), 1,\n                                   \"mixture feedforward output dim\", \"dimension for scalar value\")\n\n    @overrides\n    def _compute_action_probabilities(self,\n                                      state: GrammarBasedState,\n                                      hidden_state: torch.Tensor,\n                                      attention_weights: torch.Tensor,\n                                      predicted_action_embeddings: torch.Tensor\n                                     ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:\n        # In this section we take our predicted action embedding and compare it to the available\n        # actions in our current state (which might be different for each group element).  For\n        # computing action scores, we'll forget about doing batched / grouped computation, as it\n        # adds too much complexity and doesn't speed things up, anyway, with the operations we're\n        # doing here.  
This means we don't need any action masks, as we'll only get the right\n        # lengths for what we're computing.\n\n        group_size = len(state.batch_indices)\n        actions = state.get_valid_actions()\n\n        batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)\n        for group_index in range(group_size):\n            instance_actions = actions[group_index]\n            predicted_action_embedding = predicted_action_embeddings[group_index]\n            embedded_actions: List[int] = []\n\n            output_action_embeddings = None\n            embedded_action_logits = None\n            current_log_probs = None\n\n            if 'global' in instance_actions:\n                action_embeddings, output_action_embeddings, embedded_actions = instance_actions['global']\n                # This is just a matrix product between a (num_actions, embedding_dim) matrix and an\n                # (embedding_dim, 1) matrix.\n                embedded_action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)\n                action_ids = embedded_actions\n\n            if 'linked' in instance_actions:\n                linking_scores, type_embeddings, linked_actions = instance_actions['linked']\n                action_ids = embedded_actions + linked_actions\n                # (num_linked_actions, num_question_tokens) x (num_question_tokens, 1) -> (num_linked_actions,)\n                linked_action_logits = linking_scores.mm(attention_weights[group_index].unsqueeze(-1)).squeeze(-1)\n\n                # The `output_action_embeddings` tensor gets used later as the input to the next\n                # decoder step.  
For linked actions, we don't have any action embedding, so we use\n                # the entity type instead.\n                if output_action_embeddings is not None:\n                    output_action_embeddings = torch.cat([output_action_embeddings, type_embeddings], dim=0)\n                else:\n                    output_action_embeddings = type_embeddings\n\n                if self._mixture_feedforward is not None:\n                    # The linked and global logits are combined with a mixture weight to prevent the\n                    # linked_action_logits from dominating the embedded_action_logits if a softmax\n                    # was applied on both together.\n                    mixture_weight = self._mixture_feedforward(hidden_state[group_index])\n                    mix1 = torch.log(mixture_weight)\n                    mix2 = torch.log(1 - mixture_weight)\n\n                    entity_action_probs = torch.nn.functional.log_softmax(linked_action_logits, dim=-1) + mix1\n                    if embedded_action_logits is not None:\n                        embedded_action_probs = torch.nn.functional.log_softmax(embedded_action_logits,\n                                                                                dim=-1) + mix2\n                        current_log_probs = torch.cat([embedded_action_probs, entity_action_probs], dim=-1)\n                    else:\n                        current_log_probs = entity_action_probs\n                else:\n                    if embedded_action_logits is not None:\n                        action_logits = torch.cat([embedded_action_logits, linked_action_logits], dim=-1)\n                    else:\n                        action_logits = linked_action_logits\n                    current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)\n            elif not instance_actions:\n                action_ids = None\n                current_log_probs = float('inf')\n            else:\n                
action_logits = embedded_action_logits\n                current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)\n\n            # This is now the total score for each state after taking each action.  We're going to\n            # sort by this later, so it's important that this is the total score, not just the\n            # score for the current action.\n            log_probs = state.score[group_index] + current_log_probs\n            batch_results[state.batch_indices[group_index]].append((group_index,\n                                                                    log_probs,\n                                                                    current_log_probs,\n                                                                    output_action_embeddings,\n                                                                    action_ids))\n        return batch_results\n"
  },
  {
    "path": "state_machines/transition_functions/prefix_attend_transition.py",
    "content": "from typing import Dict, Tuple, List, Set, Any, Callable\n\nimport torch\nfrom allennlp.modules import Attention, FeedForward\nfrom allennlp.nn import Activation, util\n\nfrom allennlp.state_machines.states.grammar_based_state import GrammarBasedState\nfrom allennlp.state_machines.transition_functions.linking_transition_function import LinkingTransitionFunction\nfrom overrides import overrides\nfrom torch.nn import Linear\n\nfrom state_machines.states.rnn_statelet import RnnStatelet\n\n\nclass PrefixAttendTransitionFunction(LinkingTransitionFunction):\n    def __init__(self,\n                 encoder_output_dim: int,\n                 action_embedding_dim: int,\n                 input_attention: Attention,\n                 output_attention: Attention,\n                 activation: Activation = Activation.by_name('relu')(),\n                 predict_start_type_separately: bool = True,\n                 num_start_types: int = None,\n                 add_action_bias: bool = True,\n                 mixture_feedforward: FeedForward = None,\n                 dropout: float = 0.0,\n                 num_layers: int = 1) -> None:\n        super().__init__(encoder_output_dim=encoder_output_dim,\n                         action_embedding_dim=action_embedding_dim,\n                         input_attention=input_attention,\n                         num_start_types=num_start_types,\n                         activation=activation,\n                         predict_start_type_separately=predict_start_type_separately,\n                         add_action_bias=add_action_bias,\n                         dropout=dropout,\n                         num_layers=num_layers,\n                         mixture_feedforward=mixture_feedforward)\n        self._output_attention = output_attention\n\n        # override\n        self._input_projection_layer = Linear(encoder_output_dim + action_embedding_dim, encoder_output_dim)\n\n        self._attend_output_projection_layer = 
Linear(encoder_output_dim*2, encoder_output_dim)\n\n        self._first_attended_output = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))\n        torch.nn.init.normal_(self._first_attended_output)\n\n    @overrides\n    def take_step(self,\n                  state: GrammarBasedState,\n                  max_actions: int = None,\n                  allowed_actions: List[Set[int]] = None) -> List[GrammarBasedState]:\n        if self._predict_start_type_separately and not state.action_history[0]:\n            # The wikitables parser did something different when predicting the start type, which\n            # is our first action.  So in this case we break out into a different function.  We'll\n            # ignore max_actions on our first step, assuming there aren't that many start types.\n            return self._take_first_step(state, allowed_actions)\n\n        # Taking a step in the decoder consists of three main parts.  First, we'll construct the\n        # input to the decoder and update the decoder's hidden state.  Second, we'll use this new\n        # hidden state (and maybe other information) to predict an action.  Finally, we will\n        # construct new states for the next step.  
Each new state corresponds to one valid action\n        # that can be taken from the current state, and they are ordered by their probability of\n        # being selected.\n\n        updated_state = self._update_decoder_state(state)\n        batch_results = self._compute_action_probabilities(state,\n                                                           updated_state['hidden_state'],\n                                                           updated_state['attention_weights'],\n                                                           updated_state['predicted_action_embeddings'])\n        new_states = self._construct_next_states(state,\n                                                 updated_state,\n                                                 batch_results,\n                                                 max_actions,\n                                                 allowed_actions)\n\n        return new_states\n\n    @overrides\n    def _update_decoder_state(self, state: GrammarBasedState) -> Dict[str, torch.Tensor]:\n        # For updating the decoder, we're doing a bunch of tensor operations that can be batched\n        # without much difficulty.  
So, we take all group elements and batch their tensors together\n        # before doing these decoder operations.\n\n        group_size = len(state.batch_indices)\n        attended_question = torch.stack([rnn_state.attended_input for rnn_state in state.rnn_state])\n\n        hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state])\n        memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state])\n\n        previous_action_embedding = torch.stack([rnn_state.previous_action_embedding\n                                                 for rnn_state in state.rnn_state])\n\n        if not state.action_history[0]:\n            attended_output, output_attention_weights = self._first_attended_output.unsqueeze(0).repeat(group_size, 1),\\\n                                                        None\n        else:\n            decoder_outputs = torch.stack([rnn_state.decoder_outputs for rnn_state in state.rnn_state])\n            decoded_item_embeddings = torch.stack([rnn_state.decoded_item_embeddings for rnn_state in state.rnn_state])\n\n            action_query = torch.cat([hidden_state, attended_question], dim=-1)\n\n            # (group_size, encoder_output_dim)\n            projected_query = self._activation(self._attend_output_projection_layer(action_query))\n            query = self._dropout(projected_query)\n\n            attended_output, output_attention_weights = self.attend(query,\n                                                                    decoder_outputs,\n                                                                    decoded_item_embeddings)\n\n        # (group_size, decoder_input_dim)\n        projected_input = self._input_projection_layer(torch.cat([attended_question,\n                                                                  previous_action_embedding + attended_output], -1))\n        decoder_input = self._activation(projected_input)\n        hidden_state, memory_cell = self._decoder_cell(decoder_input, (hidden_state, memory_cell))\n        hidden_state = self._dropout(hidden_state)\n\n        # (group_size, question_length, encoder_output_dim)\n        encoder_outputs = torch.stack([state.rnn_state[0].encoder_outputs[i] for i in state.batch_indices])\n        encoder_output_mask = torch.stack([state.rnn_state[0].encoder_output_mask[i] for i in state.batch_indices])\n\n        attended_question, attention_weights = self.attend_on_question(hidden_state,\n                                                                       encoder_outputs,\n                                                                       encoder_output_mask)\n        action_query = torch.cat([hidden_state, attended_question], dim=-1)\n\n        # (group_size, action_embedding_dim)\n        projected_query = self._activation(self._output_projection_layer(action_query))\n        predicted_action_embeddings = self._dropout(projected_query)\n        if self._add_action_bias:\n            # NOTE: It's important that this happens right before the dot product with the action\n            # embeddings.  Otherwise this isn't a proper bias.  
We do it here instead of right next\n            # to the `.mm` below just so we only do it once for the whole group.\n            ones = predicted_action_embeddings.new([[1] for _ in range(group_size)])\n            predicted_action_embeddings = torch.cat([predicted_action_embeddings, ones], dim=-1)\n        return {\n            'hidden_state': hidden_state,\n            'memory_cell': memory_cell,\n            'attended_question': attended_question,\n            'attended_output': attended_output,\n            'attention_weights': attention_weights,\n            'output_attention_weights': output_attention_weights,\n            'predicted_action_embeddings': predicted_action_embeddings\n        }\n\n    def attend(self,\n               query: torch.Tensor,\n               key: torch.Tensor,\n               value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Given a query (a projection of the decoder hidden state and the attended question), compute\n        an attention over ``key`` (the decoder outputs from the previous decoding steps), and return a\n        weighted sum of ``value`` (the embeddings of the previously decoded items) given this\n        attention.  
We also return the attention weights themselves.\n        \"\"\"\n        # (group_size, num_previous_steps)\n        question_attention_weights = self._output_attention(query, key, None)\n\n        # (group_size, action_embedding_dim)\n        attended_question = util.weighted_sum(value, question_attention_weights)\n        return attended_question, question_attention_weights\n\n    def _construct_next_states(self,\n                               state: GrammarBasedState,\n                               updated_rnn_state: Dict[str, torch.Tensor],\n                               batch_action_probs: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]],\n                               max_actions: int,\n                               allowed_actions: List[Set[int]]):\n        # pylint: disable=no-self-use\n\n        # We'll yield a bunch of states here that all have a `group_size` of 1, so that the\n        # learning algorithm can decide how many of these it wants to keep, and it can just regroup\n        # them later, as that's a really easy operation.\n        #\n        # We first define a `make_state` method, as in the logic that follows we want to create\n        # states in a couple of different branches, and we don't want to duplicate the\n        # state-creation logic.  This method creates a closure using variables from the method, so\n        # it doesn't make sense to pull it out of here.\n\n        # Each group index here might get accessed multiple times, and doing the slicing operation\n        # each time is more expensive than doing it once upfront.  
These three lines give about a\n        # 10% speedup in training time.\n        group_size = len(state.batch_indices)\n\n        chunk_index = 1 if self._num_layers > 1 else 0\n        hidden_state = [x.squeeze(chunk_index)\n                        for x in updated_rnn_state['hidden_state'].chunk(group_size, chunk_index)]\n        memory_cell = [x.squeeze(chunk_index)\n                       for x in updated_rnn_state['memory_cell'].chunk(group_size, chunk_index)]\n\n        if not state.action_history[0]:\n            decoder_outputs = updated_rnn_state['hidden_state'].unsqueeze(1)\n        else:\n            decoder_outputs = torch.cat((\n                torch.stack([x.decoder_outputs for x in state.rnn_state]),\n                updated_rnn_state['hidden_state'].unsqueeze(1)\n            ), dim=1)\n\n        attended_question = [x.squeeze(0) for x in updated_rnn_state['attended_question'].chunk(group_size, 0)]\n\n        def make_state(group_index: int,\n                       action: int,\n                       new_score: torch.Tensor,\n                       action_embedding: torch.Tensor) -> GrammarBasedState:\n            batch_index = state.batch_indices[group_index]\n            action_entity_id = state.action_entity_mapping[batch_index][action] + 1  # add 1 so that -1 becomes 0 (pad)\n            if not state.action_history[0]:\n                decoded_item_embeddings = state.rnn_state[group_index].item_embeddings[action_entity_id].unsqueeze(0)\n            else:\n                decoded_item_embeddings = torch.cat((\n                    state.rnn_state[group_index].decoded_item_embeddings,\n                    state.rnn_state[group_index].item_embeddings[action_entity_id].unsqueeze(0)\n                ), dim=0)\n\n            new_rnn_state = RnnStatelet(hidden_state[group_index],\n                                        memory_cell[group_index],\n                                        action_embedding,\n                                        
attended_question[group_index],\n                                        state.rnn_state[group_index].encoder_outputs,\n                                        state.rnn_state[group_index].encoder_output_mask,\n                                        state.rnn_state[group_index].item_embeddings,\n                                        decoder_outputs[group_index],\n                                        decoded_item_embeddings)\n            for i, _, current_log_probs, _, actions in batch_action_probs[batch_index]:\n                if i == group_index:\n                    considered_actions = actions\n                    probabilities = current_log_probs.exp().cpu()\n                    break\n            return state.new_state_from_group_index(group_index,\n                                                    action,\n                                                    new_score,\n                                                    new_rnn_state,\n                                                    considered_actions,\n                                                    probabilities,\n                                                    updated_rnn_state['attention_weights'],\n                                                    updated_rnn_state['output_attention_weights'])\n\n        new_states = []\n        for _, results in batch_action_probs.items():\n            if allowed_actions and not max_actions:\n                # If we're given a set of allowed actions, and we're not just keeping the top k of\n                # them, we don't need to do any sorting, so we can speed things up quite a bit.\n                for group_index, log_probs, _, action_embeddings, actions in results:\n                    for log_prob, action_embedding, action in zip(log_probs, action_embeddings, actions):\n                        if action in allowed_actions[group_index]:\n                            new_states.append(make_state(group_index, action, log_prob, 
action_embedding))\n            else:\n                # In this case, we need to sort the actions.  We'll do that on CPU, as it's easier,\n                # and our action list is on the CPU, anyway.\n                group_indices = []\n                group_log_probs: List[torch.Tensor] = []\n                group_action_embeddings = []\n                group_actions = []\n                for group_index, log_probs, _, action_embeddings, actions in results:\n                    group_indices.extend([group_index] * len(actions))\n                    group_log_probs.append(log_probs)\n                    group_action_embeddings.append(action_embeddings)\n                    group_actions.extend(actions)\n                log_probs = torch.cat(group_log_probs, dim=0)\n                action_embeddings = torch.cat(group_action_embeddings, dim=0)\n                log_probs_cpu = log_probs.data.cpu().numpy().tolist()\n                batch_states = [(log_probs_cpu[i],\n                                 group_indices[i],\n                                 log_probs[i],\n                                 action_embeddings[i],\n                                 group_actions[i])\n                                for i in range(len(group_actions))\n                                if (not allowed_actions or\n                                    group_actions[i] in allowed_actions[group_indices[i]])]\n                # We use a key here to make sure we're not trying to compare anything on the GPU.\n                batch_states.sort(key=lambda x: x[0], reverse=True)\n                if max_actions:\n                    batch_states = batch_states[:max_actions]\n                for _, group_index, log_prob, action_embedding, action in batch_states:\n                    new_states.append(make_state(group_index, action, log_prob, action_embedding))\n        return new_states\n"
  },
  {
    "path": "train_configs/defaults.jsonnet",
    "content": "local dataset_path = \"dataset/\";\n\n{\n  \"random_seed\": 5,\n  \"numpy_seed\": 5,\n  \"pytorch_seed\": 5,\n  \"dataset_reader\": {\n    \"type\": \"spider\",\n    \"tables_file\": dataset_path + \"tables.json\",\n    \"dataset_path\": dataset_path + \"database\",\n    \"lazy\": false,\n    \"keep_if_unparsable\": false,\n    \"loading_limit\": -1\n  },\n  \"validation_dataset_reader\": {\n    \"type\": \"spider\",\n    \"tables_file\": dataset_path + \"tables.json\",\n    \"dataset_path\": dataset_path + \"database\",\n    \"lazy\": false,\n    \"keep_if_unparsable\": true,\n    \"loading_limit\": -1\n  },\n  \"train_data_path\": dataset_path + \"train_spider.json\",\n  \"validation_data_path\": dataset_path + \"dev.json\",\n  \"model\": {\n    \"type\": \"spider\",\n    \"dataset_path\": dataset_path,\n    \"parse_sql_on_decoding\": true,\n    \"gnn\": true,\n    \"gnn_timesteps\": 3,\n    \"decoder_self_attend\": true,\n    \"decoder_use_graph_entities\": true,\n    \"use_neighbor_similarity_for_linking\": true,\n    \"question_embedder\": {\n      \"tokens\": {\n        \"type\": \"embedding\",\n        \"embedding_dim\": 200,\n        \"trainable\": true\n      }\n    },\n    \"action_embedding_dim\": 200,\n    \"encoder\": {\n      \"type\": \"lstm\",\n      \"input_size\": 400,\n      \"hidden_size\": 400,\n      \"bidirectional\": true,\n      \"num_layers\": 1\n    },\n    \"entity_encoder\": {\n      \"type\": \"boe\",\n      \"embedding_dim\": 200,\n      \"averaged\": true\n    },\n    \"decoder_beam_search\": {\n      \"beam_size\": 10\n    },\n    \"training_beam_size\": 1,\n    \"max_decoding_steps\": 100,\n    \"input_attention\": {\"type\": \"dot_product\"},\n    \"past_attention\": {\"type\": \"dot_product\"},\n    \"dropout\": 0.5\n  },\n  \"iterator\": {\n    \"type\": \"basic\",\n    \"batch_size\" : 15\n  },\n  \"validation_iterator\": {\n    \"type\": \"basic\",\n    \"batch_size\" : 1\n  },\n  \"trainer\": {\n    
\"num_epochs\": 100,\n    \"cuda_device\": 0,\n    \"patience\": 20,\n    \"validation_metric\": \"+sql_match\",\n    \"optimizer\": {\n      \"type\": \"adam\",\n      \"lr\": 0.001,\n      \"weight_decay\": 5e-4\n    },\n    \"num_serialized_models_to_keep\": 2\n  }\n}\n"
  },
  {
    "path": "train_configs/paper_defaults.jsonnet",
    "content": "local dataset_path = \"dataset/\";\n\n{\n  \"random_seed\": 5,\n  \"numpy_seed\": 5,\n  \"pytorch_seed\": 5,\n  \"dataset_reader\": {\n    \"type\": \"spider\",\n    \"tables_file\": dataset_path + \"tables.json\",\n    \"dataset_path\": dataset_path + \"database\",\n    \"lazy\": false,\n    \"keep_if_unparsable\": false,\n    \"loading_limit\": -1\n  },\n  \"validation_dataset_reader\": {\n    \"type\": \"spider\",\n    \"tables_file\": dataset_path + \"tables.json\",\n    \"dataset_path\": dataset_path + \"database\",\n    \"lazy\": false,\n    \"keep_if_unparsable\": true,\n    \"loading_limit\": -1\n  },\n  \"train_data_path\": dataset_path + \"train_spider.json\",\n  \"validation_data_path\": dataset_path + \"dev.json\",\n  \"model\": {\n    \"type\": \"spider\",\n    \"dataset_path\": dataset_path,\n    \"parse_sql_on_decoding\": true,\n    \"gnn\": true,\n    \"gnn_timesteps\": 2,\n    \"decoder_self_attend\": true,\n    \"decoder_use_graph_entities\": true,\n    \"use_neighbor_similarity_for_linking\": true,\n    \"question_embedder\": {\n      \"tokens\": {\n        \"type\": \"embedding\",\n        \"embedding_dim\": 200,\n        \"trainable\": true\n      }\n    },\n    \"action_embedding_dim\": 200,\n    \"encoder\": {\n      \"type\": \"lstm\",\n      \"input_size\": 400,\n      \"hidden_size\": 200,\n      \"bidirectional\": true,\n      \"num_layers\": 1\n    },\n    \"entity_encoder\": {\n      \"type\": \"boe\",\n      \"embedding_dim\": 200,\n      \"averaged\": true\n    },\n    \"decoder_beam_search\": {\n      \"beam_size\": 10\n    },\n    \"training_beam_size\": 1,\n    \"max_decoding_steps\": 100,\n    \"input_attention\": {\"type\": \"dot_product\"},\n    \"past_attention\": {\"type\": \"dot_product\"},\n    \"dropout\": 0.5\n  },\n  \"iterator\": {\n    \"type\": \"basic\",\n    \"batch_size\" : 15\n  },\n  \"validation_iterator\": {\n    \"type\": \"basic\",\n    \"batch_size\" : 1\n  },\n  \"trainer\": {\n    
\"num_epochs\": 100,\n    \"cuda_device\": 0,\n    \"patience\": 20,\n    \"validation_metric\": \"+sql_match\",\n    \"optimizer\": {\n      \"type\": \"adam\",\n      \"lr\": 0.001\n    },\n    \"num_serialized_models_to_keep\": 2\n  }\n}"
  }
]