Repository: benbogin/spider-schema-gnn
Branch: master
Commit: 02f4ae43b891
Files: 28
Total size: 285.7 KB

Directory structure:
gitextract_4_pkjgz2/

├── README.md
├── dataset_readers/
│   ├── __init__.py
│   ├── dataset_util/
│   │   └── spider_utils.py
│   ├── fields/
│   │   └── knowledge_graph_field.py
│   └── spider.py
├── models/
│   ├── __init__.py
│   └── semantic_parsing/
│       ├── __init__.py
│       └── spider_parser.py
├── modules/
│   └── gated_graph_conv.py
├── predictors/
│   └── spider_predictor.py
├── requirements.txt
├── semparse/
│   ├── contexts/
│   │   ├── spider_context_utils.py
│   │   ├── spider_db_context.py
│   │   └── spider_db_grammar.py
│   └── worlds/
│       ├── evaluate.py
│       ├── evaluate_spider.py
│       └── spider_world.py
├── spider_evaluation/
│   ├── evaluate.py
│   └── process_sql.py
├── state_machines/
│   ├── states/
│   │   ├── grammar_based_state.py
│   │   ├── rnn_statelet.py
│   │   └── sql_state.py
│   └── transition_functions/
│       ├── attend_past_schema_items_transition.py
│       ├── basic_transition_function.py
│       ├── linking_transition_function.py
│       └── prefix_attend_transition.py
└── train_configs/
    ├── defaults.jsonnet
    └── paper_defaults.jsonnet

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# Representing Schema Structure with Graph Neural Networks for Text-to-SQL Parsing

Author implementation of this [ACL 2019 paper](https://arxiv.org/abs/1905.06241).

Please also see the [follow-up repository](https://github.com/benbogin/spider-schema-gnn-global) with improved results, for this [EMNLP paper](https://www.aclweb.org/anthology/D19-1378.pdf).

## Install & Configure

1. Install PyTorch version 1.0.1.post2 built for your CUDA version.
   
   (This repository should probably work with newer PyTorch versions as well, but it has not been tested with them. If you use another version, you will also need to update the package versions in `requirements.txt`.)
    ```
    pip install https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp37-cp37m-linux_x86_64.whl # CUDA 10.0 build
    ```
    
2. Install the rest of required packages
    ```
    pip install -r requirements.txt
    ```
    
3. Install the NLTK `punkt` tokenizer data:
    ```
    python -c "import nltk; nltk.download('punkt')"
    ```

4. Download the dataset from the [official Spider dataset website](https://yale-lily.github.io/spider)

5. Edit the config file `train_configs/defaults.jsonnet` to point to the location of the dataset:
    ```
    local dataset_path = "dataset/";
    ```

## Training

1. Use the following AllenNLP command to train:
    ```
    allennlp train train_configs/defaults.jsonnet -s experiments/name_of_experiment \
    --include-package dataset_readers.spider \
    --include-package models.semantic_parsing.spider_parser
    ```

Loading the dataset for the first time can take a while (a few hours), because the reader first loads values from the database tables and computes similarity features between them and each question. The processed instances are then cached (under `cache/`) for subsequent runs.
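The per-example caching can be sketched roughly as follows. This is a simplified illustration, not the reader's code: it assumes the stdlib `pickle` module in place of `dill`, uses a hypothetical helper `cached_instances`, and takes a caller-supplied `build` function standing in for `text_to_instance`. It mirrors the cache layout used in `dataset_readers/spider.py` (`cache/<data file name>/instance-<i>.pt`) but leaves out the reader's handling of unparsable examples.

```python
import os
import pickle
from typing import Any, Callable, Iterable, Iterator


def cached_instances(examples: Iterable[Any], data_file: str,
                     build: Callable[[Any], Any],
                     cache_root: str = "cache") -> Iterator[Any]:
    """Yield one built instance per example, reusing a cached pickle when present.

    Hypothetical helper: the cache key is only the data file's base name plus
    the example index, as in the Spider reader.
    """
    cache_dir = os.path.join(cache_root, os.path.basename(data_file))
    os.makedirs(cache_dir, exist_ok=True)
    for i, example in enumerate(examples):
        path = os.path.join(cache_dir, f"instance-{i}.pt")
        if os.path.exists(path):
            # Cache hit: skip the expensive feature computation entirely.
            with open(path, "rb") as f:
                yield pickle.load(f)
            continue
        # Cache miss: do the slow work once and store the result.
        instance = build(example)
        with open(path, "wb") as f:
            pickle.dump(instance, f)
        yield instance
```

Because the key is just the file name and index, the cache is not invalidated when `tables_file`, the database directory, or the tokenization changes; after such changes, delete `cache/` or set the reader's `load_cache` option to false.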

You should get results similar to the following (`sql_match` is the metric measured by the official evaluation script):
```
  "best_validation__match/exact_match": 0.3715686274509804,
  "best_validation_sql_match": 0.47549019607843135,
  "best_validation__others/action_similarity": 0.5731271471206189,
  "best_validation__match/match_single": 0.6254612546125461,
  "best_validation__match/match_hard": 0.3054393305439331,
  "best_validation_beam_hit": 0.6070588235294118,
  "best_validation_loss": 7.383035182952881,
  "best_epoch": 32
```
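To pull the headline numbers out of a finished run, a sketch like the one below should work. It assumes that `allennlp train` has written a `metrics.json` file into the serialization directory (`experiments/name_of_experiment` above) containing `best_validation_*` keys like those listed; `summarize_metrics` and `HEADLINE_KEYS` are hypothetical names.

```python
import json
from pathlib import Path
from typing import Any, Dict

# Hypothetical selection of the metrics worth reporting first.
HEADLINE_KEYS = ("best_validation_sql_match", "best_validation__match/exact_match", "best_epoch")


def summarize_metrics(serialization_dir: str) -> Dict[str, Any]:
    """Return the headline validation metrics of a finished run.

    Assumes `metrics.json` exists in `serialization_dir`; keys missing from it are skipped.
    """
    metrics = json.loads((Path(serialization_dir) / "metrics.json").read_text())
    return {key: metrics[key] for key in HEADLINE_KEYS if key in metrics}
```

For example, `summarize_metrics("experiments/name_of_experiment")` would return the official `sql_match` score next to the exact-match score and the best epoch.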

Note that the hyper-parameters used in `defaults.jsonnet` differ from those reported in the paper
(most importantly, 3 GNN timesteps are used instead of 2), thanks to the [following contribution from @wlhgtc](https://github.com/benbogin/spider-schema-gnn/pull/13).
The original training config file is still available in `train_configs/paper_defaults.jsonnet`.
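One intuition for why the number of timesteps matters: each propagation step lets information travel one edge further through the schema graph, so after T steps an entity can only be influenced by entities at most T edges away. The sketch below is a deliberate simplification, not the model's `GatedGraphConv`: it ignores edge types, learned weights and gating, and only tracks which nodes can reach which. The toy schema (`singer` and `concert` tables joined by a foreign key) and the helper `receptive_fields` are hypothetical.

```python
from typing import Dict, List, Set


def receptive_fields(adjacency: Dict[str, List[str]], timesteps: int) -> Dict[str, Set[str]]:
    """Map each node to the nodes whose initial state can affect it after
    `timesteps` rounds of message passing (edge types, weights and gates ignored)."""
    reach = {node: {node} for node in adjacency}
    for _ in range(timesteps):
        # One round: every node absorbs whatever its neighbours could already see.
        reach = {node: reach[node].union(*(reach[n] for n in adjacency[node]))
                 for node in adjacency}
    return reach


# Toy schema graph (hypothetical): tables link to their columns, and the
# foreign key concert.singer_id -> singer.singer_id links the two columns.
graph = {
    "singer": ["singer.singer_id", "singer.name"],
    "singer.singer_id": ["singer", "concert.singer_id"],
    "singer.name": ["singer"],
    "concert": ["concert.concert_id", "concert.singer_id"],
    "concert.concert_id": ["concert"],
    "concert.singer_id": ["concert", "singer.singer_id"],
}

for t in (2, 3):
    seen = "singer.name" in receptive_fields(graph, t)["concert.singer_id"]
    print(f"timesteps={t}: concert.singer_id sees singer.name -> {seen}")
```

Here `singer.name` is three edges away from `concert.singer_id` (through the foreign key and the `singer` table), so it enters that column's receptive field only at the third step.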

## Inference

Use the following AllenNLP command to output a file with the predicted queries:

```
allennlp predict experiments/name_of_experiment dataset/dev.json \
--predictor spider \
--use-dataset-reader \
--cuda-device=0 \
--output-file experiments/name_of_experiment/prediction.sql \
--silent \
--include-package models.semantic_parsing.spider_parser \
--include-package dataset_readers.spider \
--include-package predictors.spider_predictor \
--weights-file experiments/name_of_experiment/best.th \
-o "{\"dataset_reader\":{\"keep_if_unparsable\":true}}"
```
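For error analysis it can help to line the predictions up with the dev questions. The sketch below assumes that `prediction.sql` contains exactly one predicted query per line, in the same order as `dev.json`; this is an assumption about the predictor's output format, not something stated above. The `keep_if_unparsable` override in the command keeps the reader from dropping examples whose gold SQL cannot be parsed, which is what makes a one-to-one alignment plausible. `pair_predictions` is a hypothetical helper.

```python
import json
from typing import List, Tuple


def pair_predictions(dev_json: str, predictions_file: str) -> List[Tuple[str, str, str]]:
    """Return (db_id, question, predicted SQL) triples.

    Assumes one prediction per line, in dev.json order; raises if the counts differ.
    """
    with open(dev_json) as f:
        examples = json.load(f)
    with open(predictions_file) as f:
        predictions = [line.rstrip("\n") for line in f]
    if len(predictions) != len(examples):
        raise ValueError(f"{len(predictions)} predictions for {len(examples)} examples")
    return [(ex["db_id"], ex["question"], pred) for ex, pred in zip(examples, predictions)]
```

A count mismatch usually means examples were dropped during reading, so the check fails loudly instead of silently misaligning rows.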


================================================
FILE: dataset_readers/__init__.py
================================================


================================================
FILE: dataset_readers/dataset_util/spider_utils.py
================================================
"""
Utility functions for reading the standardised text2sql datasets presented in
`"Improving Text to SQL Evaluation Methodology" <https://arxiv.org/abs/1806.09029>`_
"""
import json
import os
import sqlite3
from collections import defaultdict
from typing import List, Dict, Optional

from allennlp.common import JsonDict

from spider_evaluation.process_sql import get_tables_with_alias, parse_sql


class TableColumn:
    def __init__(self,
                 name: str,
                 text: str,
                 column_type: str,
                 is_primary_key: bool,
                 foreign_key: Optional[str]):
        self.name = name
        self.text = text
        self.column_type = column_type
        self.is_primary_key = is_primary_key
        self.foreign_key = foreign_key


class Table:
    def __init__(self,
                 name: str,
                 text: str,
                 columns: List[TableColumn]):
        self.name = name
        self.text = text
        self.columns = columns


def read_dataset_schema(schema_path: str) -> Dict[str, Dict[str, Table]]:
    schemas: Dict[str, Dict[str, Table]] = defaultdict(dict)
    dbs_json_blob = json.load(open(schema_path, "r"))
    for db in dbs_json_blob:
        db_id = db['db_id']

        column_id_to_table = {}
        column_id_to_column = {}

        for i, (column, text, column_type) in enumerate(zip(db['column_names_original'], db['column_names'], db['column_types'])):
            table_id, column_name = column
            _, column_text = text

            table_name = db['table_names_original'][table_id]

            if table_name not in schemas[db_id]:
                table_text = db['table_names'][table_id]
                schemas[db_id][table_name] = Table(table_name, table_text, [])

            if column_name == "*":
                continue

            is_primary_key = i in db['primary_keys']
            table_column = TableColumn(column_name.lower(), column_text, column_type, is_primary_key, None)
            schemas[db_id][table_name].columns.append(table_column)
            column_id_to_table[i] = table_name
            column_id_to_column[i] = table_column

        for (c1, c2) in db['foreign_keys']:
            foreign_key = column_id_to_table[c2] + ':' + column_id_to_column[c2].name
            column_id_to_column[c1].foreign_key = foreign_key

    return {**schemas}


def read_dataset_values(db_id: str, dataset_path: str, tables: List[Table]):
    db = os.path.join(dataset_path, db_id, db_id + ".sqlite")
    try:
        conn = sqlite3.connect(db)
    except Exception as e:
        raise Exception(f"Can't connect to SQL: {e} in path {db}")
    conn.text_factory = str
    cursor = conn.cursor()

    values = {}

    for table in tables:
        try:
            cursor.execute(f"SELECT * FROM {table.name} LIMIT 5000")
            values[table] = cursor.fetchall()
        except Exception:
            # some tables contain text that is not valid UTF-8; retry, decoding text as latin-1
            conn.text_factory = lambda x: str(x, 'latin1')
            cursor = conn.cursor()
            cursor.execute(f"SELECT * FROM {table.name} LIMIT 5000")
            values[table] = cursor.fetchall()

    return values


def ent_key_to_name(key):
    parts = key.split(':')
    if parts[0] == 'table':
        return parts[1]
    elif parts[0] == 'column':
        _, _, table_name, column_name = parts
        return f'{table_name}@{column_name}'
    else:
        return parts[1]


def fix_number_value(ex: JsonDict):
    """
    The `query_toks_no_value` field in the dataset files anonymizes all values, which is fine since the evaluator
    does not check values. However, it also anonymizes numbers that should be kept: e.g. LIMIT 3 becomes
    LIMIT 'value', and the evaluator fails when LIMIT is not followed by a number.
    """

    def split_and_keep(s, sep):
        if not s: return ['']  # consistent with string.split()

        # Find replacement character that is not used in string
        # i.e. just use the highest available character plus one
        # Note: This fails if ord(max(s)) = 0x10FFFF (ValueError)
        p = chr(ord(max(s)) + 1)

        return s.replace(sep, p + sep + p).split(p)

    # input is tokenized in different ways... so first try to make splits equal
    query_toks = ex['query_toks']
    ex['query_toks'] = []
    for q in query_toks:
        ex['query_toks'] += split_and_keep(q, '.')

    i_val, i_no_val = 0, 0
    while i_val < len(ex['query_toks']) and i_no_val < len(ex['query_toks_no_value']):
        if ex['query_toks_no_value'][i_no_val] != 'value':
            i_val += 1
            i_no_val += 1
            continue

        i_val_end = i_val
        while i_val_end + 1 < len(ex['query_toks']) and \
                i_no_val + 1 < len(ex['query_toks_no_value']) and \
                ex['query_toks'][i_val_end + 1].lower() != ex['query_toks_no_value'][i_no_val + 1].lower():
            i_val_end += 1

        if i_val == i_val_end and ex['query_toks'][i_val] in ["1", "2", "3"] and ex['query_toks'][i_val - 1].lower() == "limit":
            ex['query_toks_no_value'][i_no_val] = ex['query_toks'][i_val]
        i_val = i_val_end

        i_val += 1
        i_no_val += 1

    return ex


_schemas_cache = None


def disambiguate_items(db_id: str, query_toks: List[str], tables_file: str, allow_aliases: bool) -> List[str]:
    """
    Make the query tokens unambiguous by rewriting each column name so that it explicitly states
    which table it belongs to.

    parsed sql to sql clause is based on supermodel.gensql from syntaxsql
    """

    class Schema:
        """
        Simple schema which maps table&column to a unique identifier
        """

        def __init__(self, schema, table):
            self._schema = schema
            self._table = table
            self._idMap = self._map(self._schema, self._table)

        @property
        def schema(self):
            return self._schema

        @property
        def idMap(self):
            return self._idMap

        def _map(self, schema, table):
            column_names_original = table['column_names_original']
            table_names_original = table['table_names_original']
            # print 'column_names_original: ', column_names_original
            # print 'table_names_original: ', table_names_original
            for i, (tab_id, col) in enumerate(column_names_original):
                if tab_id == -1:
                    idMap = {'*': i}
                else:
                    key = table_names_original[tab_id].lower()
                    val = col.lower()
                    idMap[key + "." + val] = i

            for i, tab in enumerate(table_names_original):
                key = tab.lower()
                idMap[key] = i

            return idMap

    def get_schemas_from_json(fpath):
        global _schemas_cache

        if _schemas_cache is not None:
            return _schemas_cache

        with open(fpath) as f:
            data = json.load(f)
        db_names = [db['db_id'] for db in data]

        tables = {}
        schemas = {}
        for db in data:
            db_id = db['db_id']
            schema = {}  # {'table': [col.lower, ..., ]} * -> __all__
            column_names_original = db['column_names_original']
            table_names_original = db['table_names_original']
            tables[db_id] = {'column_names_original': column_names_original,
                             'table_names_original': table_names_original}
            for i, tabn in enumerate(table_names_original):
                table = str(tabn.lower())
                cols = [str(col.lower()) for td, col in column_names_original if td == i]
                schema[table] = cols
            schemas[db_id] = schema

        _schemas_cache = schemas, db_names, tables
        return _schemas_cache

    schemas, db_names, tables = get_schemas_from_json(tables_file)
    schema = Schema(schemas[db_id], tables[db_id])

    fixed_toks = []
    i = 0
    while i < len(query_toks):
        tok = query_toks[i]
        if tok == 'value' or tok == "'value'":
            # TODO: value should always be between '/" (remove first if clause)
            new_tok = f'"{tok}"'
        elif tok in ['!', '<', '>'] and i + 1 < len(query_toks) and query_toks[i + 1] == '=':
            new_tok = tok + '='
            i += 1
        elif i+1 < len(query_toks) and query_toks[i+1] == '.':
            new_tok = ''.join(query_toks[i:i+3])
            i += 2
        else:
            new_tok = tok
        fixed_toks.append(new_tok)
        i += 1

    toks = fixed_toks

    tables_with_alias = get_tables_with_alias(schema.schema, toks)
    _, sql, mapped_entities = parse_sql(toks, 0, tables_with_alias, schema, mapped_entities_fn=lambda: [])

    for i, new_name in mapped_entities:
        curr_tok = toks[i]
        if '.' in curr_tok and allow_aliases:
            parts = curr_tok.split('.')
            assert(len(parts) == 2)
            toks[i] = parts[0] + '.' + new_name
        else:
            toks[i] = new_name

    if not allow_aliases:
        toks = [tok for tok in toks if tok not in ['as', 't1', 't2', 't3', 't4']]

    toks = [f'\'value\'' if tok == '"value"' else tok for tok in toks]

    return toks

================================================
FILE: dataset_readers/fields/knowledge_graph_field.py
================================================
"""
``KnowledgeGraphField`` is a ``Field`` which stores a knowledge graph representation.
"""
from typing import List, Dict

from allennlp.data import TokenIndexer, Tokenizer
from allennlp.data.fields.knowledge_graph_field import KnowledgeGraphField
from allennlp.data.tokenizers.token import Token
from allennlp.semparse.contexts.knowledge_graph import KnowledgeGraph


class SpiderKnowledgeGraphField(KnowledgeGraphField):
    """
    This implementation first computes all non-graph features (i.e. without related_column features),
    then derives a related-column version of each feature by taking the maximum score over the entity's
    neighbours (ignoring column-to-column foreign-key edges)
    """
    def __init__(self,
                 knowledge_graph: KnowledgeGraph,
                 utterance_tokens: List[Token],
                 token_indexers: Dict[str, TokenIndexer],
                 tokenizer: Tokenizer = None,
                 feature_extractors: List[str] = None,
                 entity_tokens: List[List[Token]] = None,
                 linking_features: List[List[List[float]]] = None,
                 include_in_vocab: bool = True,
                 max_table_tokens: int = None) -> None:
        feature_extractors = feature_extractors if feature_extractors is not None else [
                'number_token_match',
                'exact_token_match',
                'contains_exact_token_match',
                'lemma_match',
                'contains_lemma_match',
                'edit_distance',
                'span_overlap_fraction',
                'span_lemma_overlap_fraction',
                ]

        super().__init__(knowledge_graph, utterance_tokens, token_indexers,
                         tokenizer=tokenizer, feature_extractors=feature_extractors, entity_tokens=entity_tokens,
                         linking_features=linking_features, include_in_vocab=include_in_vocab,
                         max_table_tokens=max_table_tokens)

        self.linking_features = self._compute_related_linking_features(self.linking_features)

        # hack needed to fix calculation of feature extractors in the inherited as_tensor method
        self._feature_extractors = feature_extractors * 2

    def _compute_related_linking_features(self,
                                          non_related_features: List[List[List[float]]]) -> List[List[List[float]]]:
        linking_features = non_related_features
        entity_to_index_map = {}
        for entity_id, entity in enumerate(self.knowledge_graph.entities):
            entity_to_index_map[entity] = entity_id
        for entity_id, (entity, entity_text) in enumerate(zip(self.knowledge_graph.entities, self.entity_texts)):
            for token_index, token in enumerate(self.utterance_tokens):
                entity_token_features = linking_features[entity_id][token_index]
                for feature_index, feature_extractor in enumerate(self._feature_extractors):
                    neighbour_features = []
                    for neighbor in self.knowledge_graph.neighbors[entity]:
                        # we only care about table/columns relations here, not foreign-primary
                        if entity.startswith('column') and neighbor.startswith('column'):
                            continue
                        neighbor_index = entity_to_index_map[neighbor]
                        neighbour_features.append(non_related_features[neighbor_index][token_index][feature_index])

                    entity_token_features.append(max(neighbour_features))
        return linking_features


================================================
FILE: dataset_readers/spider.py
================================================
import json
import logging
import os
from typing import List, Dict

import dill
from allennlp.common.checks import ConfigurationError
from allennlp.data import DatasetReader, Tokenizer, TokenIndexer, Field, Instance
from allennlp.data.fields import TextField, ProductionRuleField, ListField, IndexField, MetadataField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import WordTokenizer
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from overrides import overrides
from spacy.symbols import ORTH, LEMMA

from dataset_readers.dataset_util.spider_utils import fix_number_value, disambiguate_items
from dataset_readers.fields.knowledge_graph_field import SpiderKnowledgeGraphField
from semparse.contexts.spider_db_context import SpiderDBContext
from semparse.worlds.spider_world import SpiderWorld

logger = logging.getLogger(__name__)


@DatasetReader.register("spider")
class SpiderDatasetReader(DatasetReader):
    def __init__(self,
                 lazy: bool = False,
                 question_token_indexers: Dict[str, TokenIndexer] = None,
                 keep_if_unparsable: bool = True,
                 tables_file: str = None,
                 dataset_path: str = 'dataset/database',
                 load_cache: bool = True,
                 save_cache: bool = True,
                 loading_limit: int = -1):
        super().__init__(lazy=lazy)

        # spaCy's default tokenizer splits the common token 'id' into ['i', 'd']; add a special case to keep it whole
        spacy_tokenizer = SpacyWordSplitter(pos_tags=True)
        spacy_tokenizer.spacy.tokenizer.add_special_case(u'id', [{ORTH: u'id', LEMMA: u'id'}])
        self._tokenizer = WordTokenizer(spacy_tokenizer)

        self._utterance_token_indexers = question_token_indexers or {'tokens': SingleIdTokenIndexer()}
        self._keep_if_unparsable = keep_if_unparsable

        self._tables_file = tables_file
        self._dataset_path = dataset_path

        self._load_cache = load_cache
        self._save_cache = save_cache
        self._loading_limit = loading_limit

    @overrides
    def _read(self, file_path: str):
        if not file_path.endswith('.json'):
            raise ConfigurationError(f"Don't know how to read filetype of {file_path}")

        cache_dir = os.path.join('cache', file_path.split("/")[-1])

        if self._load_cache:
            logger.info(f'Trying to load cache from {cache_dir}')
        if self._save_cache:
            os.makedirs(cache_dir, exist_ok=True)

        cnt = 0
        with open(file_path, "r") as data_file:
            json_obj = json.load(data_file)
            for total_cnt, ex in enumerate(json_obj):
                cache_filename = f'instance-{total_cnt}.pt'
                cache_filepath = os.path.join(cache_dir, cache_filename)
                if self._loading_limit == cnt:
                    break

                if self._load_cache:
                    try:
                        ins = dill.load(open(cache_filepath, 'rb'))
                        if ins is None and not self._keep_if_unparsable:
                            # skip unparsed examples
                            continue
                        yield ins
                        cnt += 1
                        continue
                    except Exception as e:
                        # could not load from cache - keep loading without cache
                        pass

                query_tokens = None
                if 'query_toks' in ex:
                    # we only have 'query_toks' in example for training/dev sets

                    # fix for examples: we want to use the 'query_toks_no_value' field of the example, which anonymizes
                    # values. However, it also anonymizes numbers (e.g. LIMIT 3 -> LIMIT 'value'), which is a problem
                    # since the official evaluator expects a number there, not a value
                    ex = fix_number_value(ex)

                    # we want the query tokens to be non-ambiguous (i.e. know for each column the table it belongs to,
                    # and for each table alias its explicit name)
                    # we thus remove all aliases and make changes such as:
                    # 'name' -> 'singer@name',
                    # 'singer AS T1' -> 'singer',
                    # 'T1.name' -> 'singer@name'
                    try:
                        query_tokens = disambiguate_items(ex['db_id'], ex['query_toks_no_value'],
                                                          self._tables_file, allow_aliases=False)
                    except Exception as e:
                        # there are two examples in the train set that are wrongly formatted, skip them
                        print(f"error with {ex['query']}")
                        print(e)

                ins = self.text_to_instance(
                    utterance=ex['question'],
                    db_id=ex['db_id'],
                    sql=query_tokens)
                if ins is not None:
                    cnt += 1
                if self._save_cache:
                    dill.dump(ins, open(cache_filepath, 'wb'))

                if ins is not None:
                    yield ins

    def text_to_instance(self,
                         utterance: str,
                         db_id: str,
                         sql: List[str] = None):
        fields: Dict[str, Field] = {}

        db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer,
                                     tables_file=self._tables_file, dataset_path=self._dataset_path)
        table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                                db_context.tokenized_utterance,
                                                self._utterance_token_indexers,
                                                entity_tokens=db_context.entity_tokens,
                                                include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                                max_table_tokens=None)  # self._max_table_tokens)

        world = SpiderWorld(db_context, query=sql)
        fields["utterance"] = TextField(db_context.tokenized_utterance, self._utterance_token_indexers)

        action_sequence, all_actions = world.get_action_sequence_and_all_actions()

        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            production_rule = ' '.join(production_rule.split(' '))
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        action_map = {action.rule: i  # type: ignore
                      for i, action in enumerate(valid_actions_field.field_list)}

        for production_rule in action_sequence:
            index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]

        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        fields["world"] = MetadataField(world)
        fields["schema"] = table_field
        return Instance(fields)


================================================
FILE: models/__init__.py
================================================


================================================
FILE: models/semantic_parsing/__init__.py
================================================
from models.semantic_parsing.spider_parser import SpiderParser

================================================
FILE: models/semantic_parsing/spider_parser.py
================================================
import difflib
import os
from functools import partial
from typing import Dict, List, Tuple, Any, Mapping, Sequence

import sqlparse
import torch
from allennlp.common.util import pad_sequence_to_length
from allennlp.data import Vocabulary
from allennlp.data.fields.production_rule_field import ProductionRule, ProductionRuleArray
from allennlp.models import Model
from allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder, Seq2VecEncoder, Embedding, Attention, FeedForward, \
    TimeDistributed
from allennlp.modules.seq2vec_encoders import BagOfEmbeddingsEncoder
from allennlp.nn import util, Activation
from allennlp.state_machines import BeamSearch
from allennlp.state_machines.states import GrammarStatelet
from torch_geometric.data import Data, Batch

from modules.gated_graph_conv import GatedGraphConv
from semparse.worlds.evaluate_spider import evaluate
from state_machines.states.rnn_statelet import RnnStatelet
from allennlp.state_machines.trainers import MaximumMarginalLikelihood
from allennlp.training.metrics import Average
from overrides import overrides

from semparse.contexts.spider_context_utils import action_sequence_to_sql
from semparse.worlds.spider_world import SpiderWorld
from state_machines.states.grammar_based_state import GrammarBasedState
from state_machines.states.sql_state import SqlState
from state_machines.transition_functions.attend_past_schema_items_transition import \
    AttendPastSchemaItemsTransitionFunction
from state_machines.transition_functions.linking_transition_function import LinkingTransitionFunction


@Model.register("spider")
class SpiderParser(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 encoder: Seq2SeqEncoder,
                 entity_encoder: Seq2VecEncoder,
                 decoder_beam_search: BeamSearch,
                 question_embedder: TextFieldEmbedder,
                 input_attention: Attention,
                 past_attention: Attention,
                 max_decoding_steps: int,
                 action_embedding_dim: int,
                 gnn: bool = True,
                 decoder_use_graph_entities: bool = True,
                 decoder_self_attend: bool = True,
                 gnn_timesteps: int = 2,
                 parse_sql_on_decoding: bool = True,
                 add_action_bias: bool = True,
                 use_neighbor_similarity_for_linking: bool = True,
                 dataset_path: str = 'dataset',
                 training_beam_size: int = None,
                 decoder_num_layers: int = 1,
                 dropout: float = 0.0,
                 rule_namespace: str = 'rule_labels',
                 scoring_dev_params: dict = None,
                 debug_parsing: bool = False) -> None:
        super().__init__(vocab)
        self.vocab = vocab
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._rule_namespace = rule_namespace
        self._question_embedder = question_embedder
        self._add_action_bias = add_action_bias
        self._scoring_dev_params = scoring_dev_params or {}
        self.parse_sql_on_decoding = parse_sql_on_decoding
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
        self._self_attend = decoder_self_attend
        self._decoder_use_graph_entities = decoder_use_graph_entities

        self._action_padding_index = -1  # the padding value used by IndexField

        self._exact_match = Average()
        self._sql_evaluator_match = Average()
        self._action_similarity = Average()
        self._acc_single = Average()
        self._acc_multi = Average()
        self._beam_hit = Average()

        self._action_embedding_dim = action_embedding_dim

        num_actions = vocab.get_vocab_size(self._rule_namespace)
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim)
        self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)

        encoder_output_dim = encoder.get_output_dim()
        if gnn:
            encoder_output_dim += action_embedding_dim

        self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        self._first_attended_utterance = torch.nn.Parameter(torch.FloatTensor(encoder_output_dim))
        self._first_attended_output = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)
        torch.nn.init.normal_(self._first_attended_output)

        self._num_entity_types = 9
        self._embedding_dim = question_embedder.get_output_dim()

        self._entity_type_encoder_embedding = Embedding(self._num_entity_types, self._embedding_dim)
        self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)

        self._linking_params = torch.nn.Linear(16, 1)
        torch.nn.init.uniform_(self._linking_params.weight, 0, 1)

        num_edge_types = 3
        self._gnn = GatedGraphConv(self._embedding_dim, gnn_timesteps, num_edge_types=num_edge_types, dropout=dropout)

        self._decoder_num_layers = decoder_num_layers

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)

        if decoder_self_attend:
            self._transition_function = AttendPastSchemaItemsTransitionFunction(encoder_output_dim=encoder_output_dim,
                                                                                action_embedding_dim=action_embedding_dim,
                                                                                input_attention=input_attention,
                                                                                past_attention=past_attention,
                                                                                predict_start_type_separately=False,
                                                                                add_action_bias=self._add_action_bias,
                                                                                dropout=dropout,
                                                                                num_layers=self._decoder_num_layers)
        else:
            self._transition_function = LinkingTransitionFunction(encoder_output_dim=encoder_output_dim,
                                                                  action_embedding_dim=action_embedding_dim,
                                                                  input_attention=input_attention,
                                                                  predict_start_type_separately=False,
                                                                  add_action_bias=self._add_action_bias,
                                                                  dropout=dropout,
                                                                  num_layers=self._decoder_num_layers)

        self._ent2ent_ff = FeedForward(action_embedding_dim, 1, action_embedding_dim, Activation.by_name('relu')())

        self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim)

        # TODO: Remove hard-coded dirs
        self._evaluate_func = partial(evaluate,
                                      db_dir=os.path.join(dataset_path, 'database'),
                                      table=os.path.join(dataset_path, 'tables.json'),
                                      check_valid=False)

        self.debug_parsing = debug_parsing

    @overrides
    def forward(self,  # type: ignore
                utterance: Dict[str, torch.LongTensor],
                valid_actions: List[List[ProductionRule]],
                world: List[SpiderWorld],
                schema: Dict[str, torch.LongTensor],
                action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]:

        batch_size = len(world)
        device = utterance['tokens'].device

        initial_state = self._get_initial_state(utterance, world, schema, valid_actions)

        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            action_mask = action_sequence != self._action_padding_index
        else:
            action_mask = None

        if self.training:
            decode_output = self._decoder_trainer.decode(initial_state,
                                                         self._transition_function,
                                                         (action_sequence.unsqueeze(1), action_mask.unsqueeze(1)))

            return {'loss': decode_output['loss']}
        else:
            loss = torch.tensor([0]).float().to(device)
            if action_sequence is not None and action_sequence.size(1) > 1:
                try:
                    loss = self._decoder_trainer.decode(initial_state,
                                                        self._transition_function,
                                                        (action_sequence.unsqueeze(1),
                                                         action_mask.unsqueeze(1)))['loss']
                except ZeroDivisionError:
                    # reached a dead-end during beam search
                    pass

            outputs: Dict[str, Any] = {
                'loss': loss
            }

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]

            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._transition_function,
                                                         keep_final_unfinished_states=False)

            self._compute_validation_outputs(valid_actions,
                                             best_final_states,
                                             world,
                                             action_sequence,
                                             outputs)
            return outputs

    def _get_initial_state(self,
                           utterance: Dict[str, torch.LongTensor],
                           worlds: List[SpiderWorld],
                           schema: Dict[str, torch.LongTensor],
                           actions: List[List[ProductionRule]]) -> GrammarBasedState:
        schema_text = schema['text']
        embedded_schema = self._question_embedder(schema_text, num_wrapping_dims=1)
        schema_mask = util.get_text_field_mask(schema_text, num_wrapping_dims=1).float()

        embedded_utterance = self._question_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_schema.size()
        num_entities = max([len(world.db_context.knowledge_graph.entities) for world in worlds])
        num_question_tokens = utterance['tokens'].size(1)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(worlds, num_entities, embedded_schema.device)

        entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)

        # Compute entity and question word similarity.  We tried using cosine similarity here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions, it needs a larger output range than [-1, 1].
        question_entity_similarity = torch.bmm(embedded_schema.view(batch_size,
                                                                    num_entities * num_entity_tokens,
                                                                    self._embedding_dim),
                                               torch.transpose(embedded_utterance, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)
        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = schema['linking']

        linking_scores = question_entity_similarity_max_score

        feature_scores = self._linking_params(linking_features).squeeze(3)

        linking_scores = linking_scores + feature_scores
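        # For batch b, entity e and question token t, the combined score is
        #   linking_scores[b, e, t] = max_k <embedded_schema[b, e, k], embedded_utterance[b, t]>
        #                             + w . linking_features[b, e, t] + bias
        # where k ranges over the entity's tokens and (w, bias) are the parameters of
        # self._linking_params.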

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(worlds, linking_scores.transpose(1, 2),
                                                                utterance_mask, entity_type_dict)

        # (batch_size, num_entities, num_neighbors) or None
        neighbor_indices = self._get_neighbor_indices(worlds, num_entities, linking_scores.device)

        if self._use_neighbor_similarity_for_linking and neighbor_indices is not None:
            # (batch_size, num_entities, embedding_dim)
            encoded_table = self._entity_encoder(embedded_schema, schema_mask)

            # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
            # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
            # be added for the mask since that method expects 0 for padding.
            # (batch_size, num_entities, num_neighbors, embedding_dim)
            embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

            neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                     num_wrapping_dims=1).float()
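            # Worked example (illustrative values): an entity whose padded neighbor indices are
            # [1, 2, -1] gathers entities [1, 2, 1] (abs() turns the padding into a harmless
            # duplicate of entity 1), while neighbor_indices + 1 = [2, 3, 0] yields the mask
            # [1, 1, 0], so the duplicate is excluded from the averaged neighbor embedding.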

            # Encoder initialized to easily obtain a masked average.
            neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
            # (batch_size, num_entities, embedding_dim)
            embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)
            projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())

            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)
        else:
            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings)

        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_utterance], 2)
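        # util.weighted_sum with linking_probabilities of shape (batch_size, num_question_tokens,
        # num_entities) and entity_embeddings of shape (batch_size, num_entities, embedding_dim)
        # gives link_embedding of shape (batch_size, num_question_tokens, embedding_dim): each
        # token's expected entity embedding, which is concatenated to the token's word embedding.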

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask))

        max_entities_relevance = linking_probabilities.max(dim=1)[0]
        entities_relevance = max_entities_relevance.unsqueeze(-1).detach()

        graph_initial_embedding = entity_type_embeddings * entities_relevance

        encoder_output_dim = self._encoder.get_output_dim()
        if self._gnn:
            entities_graph_encoding = self._get_schema_graph_encoding(worlds,
                                                                      graph_initial_embedding)
            graph_link_embedding = util.weighted_sum(entities_graph_encoding, linking_probabilities)
            encoder_outputs = torch.cat((
                encoder_outputs,
                graph_link_embedding
            ), dim=-1)
            encoder_output_dim = self._action_embedding_dim + self._encoder.get_output_dim()
        else:
            entities_graph_encoding = None

        if self._self_attend:
            # linked_actions_linking_scores = self._get_linked_actions_linking_scores(actions, entities_graph_encoding)
            entities_ff = self._ent2ent_ff(entities_graph_encoding)
            linked_actions_linking_scores = torch.bmm(entities_ff, entities_ff.transpose(1, 2))
        else:
            linked_actions_linking_scores = [None] * batch_size

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             utterance_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_utterance,
                                                 encoder_output_list,
                                                 utterance_mask_list))

        initial_grammar_state = [self._create_grammar_state(worlds[i],
                                                            actions[i],
                                                            linking_scores[i],
                                                            linked_actions_linking_scores[i],
                                                            entity_types[i],
                                                            (entities_graph_encoding[i]
                                                             if entities_graph_encoding is not None
                                                             else None))
                                 for i in range(batch_size)]

        initial_sql_state = [SqlState(actions[i], self.parse_sql_on_decoding) for i in range(batch_size)]

        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          sql_state=initial_sql_state,
                                          possible_actions=actions,
                                          action_entity_mapping=[w.get_action_entity_mapping() for w in worlds])

        return initial_state

    @staticmethod
    def _get_neighbor_indices(worlds: List[SpiderWorld],
                              num_entities: int,
                              device: torch.device) -> torch.LongTensor:
        """
        This method returns the indices of each entity's neighbors. The device is accepted as a
        parameter so that the returned tensor is created on the right device.

        Parameters
        ----------
        worlds : ``List[SpiderWorld]``
        num_entities : ``int``
        device : ``torch.device``
            The device on which the returned tensor is created.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded
        with -1 instead of 0, since 0 is a valid neighbor index. If all the entities in the batch
        have no neighbors, None will be returned.
        """

        num_neighbors = 0
        for world in worlds:
            for entity in world.db_context.knowledge_graph.entities:
                if len(world.db_context.knowledge_graph.neighbors[entity]) > num_neighbors:
                    num_neighbors = len(world.db_context.knowledge_graph.neighbors[entity])

        batch_neighbors = []
        no_entities_have_neighbors = True
        for world in worlds:
            # Each batch instance has its own world, which has a corresponding table.
            entities = world.db_context.knowledge_graph.entities
            entity2index = {entity: i for i, entity in enumerate(entities)}
            entity2neighbors = world.db_context.knowledge_graph.neighbors
            neighbor_indexes = []
            for entity in entities:
                entity_neighbors = [entity2index[n] for n in entity2neighbors[entity]]
                if entity_neighbors:
                    no_entities_have_neighbors = False
                # Pad with -1 instead of 0, since 0 represents a neighbor index.
                padded = pad_sequence_to_length(entity_neighbors, num_neighbors, lambda: -1)
                neighbor_indexes.append(padded)
            neighbor_indexes = pad_sequence_to_length(neighbor_indexes,
                                                      num_entities,
                                                      lambda: [-1] * num_neighbors)
            batch_neighbors.append(neighbor_indexes)
        # It is possible that none of the entities in the batch has any neighbors; in that case we
        # return None and the caller skips the neighbor-based entity embeddings.
        if no_entities_have_neighbors:
            return None
        return torch.tensor(batch_neighbors, device=device, dtype=torch.long)

    def _get_schema_graph_encoding(self,
                                   worlds: List[SpiderWorld],
                                   initial_graph_embeddings: torch.Tensor) -> torch.Tensor:
        max_num_entities = max([len(world.db_context.knowledge_graph.entities) for world in worlds])
        batch_size = initial_graph_embeddings.size(0)

        graph_data_list = []

        for batch_index, world in enumerate(worlds):
            x = initial_graph_embeddings[batch_index]

            adj_list = self._get_graph_adj_lists(initial_graph_embeddings.device,
                                                 world, initial_graph_embeddings.size(1) - 1)
            graph_data = Data(x)
            for i, l in enumerate(adj_list):
                graph_data[f'edge_index_{i}'] = l
            graph_data_list.append(graph_data)

        batch = Batch.from_data_list(graph_data_list)

        gnn_output = self._gnn(batch.x, [batch[f'edge_index_{i}'] for i in range(self._gnn.num_edge_types)])

        num_nodes = max_num_entities
        gnn_output = gnn_output.view(batch_size, num_nodes, -1)
        # entities_encodings = gnn_output
        entities_encodings = gnn_output[:, :max_num_entities]
        # global_node_encodings = gnn_output[:, max_num_entities]

        return entities_encodings

    @staticmethod
    def _get_graph_adj_lists(device, world, global_entity_id, global_node=False):
        entity_mapping = {}
        for i, entity in enumerate(world.db_context.knowledge_graph.entities):
            entity_mapping[entity] = i
        entity_mapping['_global_'] = global_entity_id
        adj_list_own = []  # column--table and cell value (string) edges
        adj_list_link = []  # table->table / foreign->primary
        adj_list_linked = []  # table<-table / foreign<-primary
        adj_list_global = []  # node->global

        # TODO: Prepare in advance?
        for key, neighbors in world.db_context.knowledge_graph.neighbors.items():
            idx_source = entity_mapping[key]
            for n_key in neighbors:
                idx_target = entity_mapping[n_key]
                if n_key.startswith("table") or key.startswith("table"):
                    adj_list_own.append((idx_source, idx_target))
                elif n_key.startswith("string") or key.startswith("string"):
                    adj_list_own.append((idx_source, idx_target))
                elif key.startswith("column:foreign"):
                    adj_list_link.append((idx_source, idx_target))
                    src_table_key = f"table:{key.split(':')[2]}"
                    tgt_table_key = f"table:{n_key.split(':')[2]}"
                    idx_source_table = entity_mapping[src_table_key]
                    idx_target_table = entity_mapping[tgt_table_key]
                    adj_list_link.append((idx_source_table, idx_target_table))
                elif n_key.startswith("column:foreign"):
                    adj_list_linked.append((idx_source, idx_target))
                    src_table_key = f"table:{key.split(':')[2]}"
                    tgt_table_key = f"table:{n_key.split(':')[2]}"
                    idx_source_table = entity_mapping[src_table_key]
                    idx_target_table = entity_mapping[tgt_table_key]
                    adj_list_linked.append((idx_source_table, idx_target_table))
                else:
                    raise ValueError(f"Unexpected knowledge graph edge: {key} -> {n_key}")

            adj_list_global.append((idx_source, entity_mapping['_global_']))

        all_adj_types = [adj_list_own, adj_list_link, adj_list_linked]
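        # Each non-empty list is returned below as a (2, num_edges) LongTensor, following the
        # PyTorch Geometric edge_index convention (row 0: source nodes, row 1: target nodes);
        # e.g. [(0, 3), (1, 3)] becomes tensor([[0, 1], [3, 3]]). Empty lists stay 1-D empty tensors.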

        if global_node:
            all_adj_types.append(adj_list_global)

        return [torch.tensor(l, device=device, dtype=torch.long).transpose(0, 1) if l
                else torch.tensor(l, device=device, dtype=torch.long)
                for l in all_adj_types]

    def _create_grammar_state(self,
                              world: SpiderWorld,
                              possible_actions: List[ProductionRule],
                              linking_scores: torch.Tensor,
                              linked_actions_linking_scores: torch.Tensor,
                              entity_types: torch.Tensor,
                              entity_graph_encoding: torch.Tensor) -> GrammarStatelet:
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index

        valid_actions = world.valid_actions
        entity_map = {}
        entities = world.entities_names

        for entity_index, entity in enumerate(entities):
            entity_map[entity] = entity_index

        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.

            action_indices = [action_map[action_string] for action_string in action_strings]
            production_rule_arrays = [(possible_actions[index], index) for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append((production_rule_array[2], action_index))
                else:
                    linked_actions.append((production_rule_array[0], action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors, dim=0).to(
                    global_action_tensors[0].device).long()
                global_input_embeddings = self._action_embedder(global_action_tensor)
                global_output_embeddings = self._output_action_embedder(global_action_tensor)
                translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                           global_output_embeddings,
                                                           list(global_action_ids))
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [rule.split(' -> ')[1].strip('[]\"') for rule in linked_rules]

                entity_ids = [entity_map[entity] for entity in entities]

                entity_linking_scores = linking_scores[entity_ids]

                if linked_actions_linking_scores is not None:
                    entity_action_linking_scores = linked_actions_linking_scores[entity_ids]

                if not self._decoder_use_graph_entities:
                    entity_type_tensor = entity_types[entity_ids]
                    entity_type_embeddings = (self._entity_type_decoder_embedding(entity_type_tensor)
                                              .to(entity_types.device)
                                              .float())
                else:
                    entity_type_embeddings = entity_graph_encoding.index_select(
                        dim=0,
                        index=torch.tensor(entity_ids, device=entity_graph_encoding.device)
                    )

                if self._self_attend:
                    translated_valid_actions[key]['linked'] = (entity_linking_scores,
                                                               entity_type_embeddings,
                                                               list(linked_action_ids),
                                                               entity_action_linking_scores)
                else:
                    translated_valid_actions[key]['linked'] = (entity_linking_scores,
                                                               entity_type_embeddings,
                                                               list(linked_action_ids))

        return GrammarStatelet(['statement'],
                               translated_valid_actions,
                               self.is_nonterminal)

    @staticmethod
    def is_nonterminal(token: str):
        if token[0] == '"' and token[-1] == '"':
            return False
        return True

    def _get_linking_probabilities(self,
                                   worlds: List[SpiderWorld],
                                   linking_scores: torch.FloatTensor,
                                   question_mask: torch.LongTensor,
                                   entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[SpiderWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name: "column:<type>:..."
            # entities come first (their column types are in the same alphabetical order as the
            # type ids), followed by "string:" and then "table:" entities. This is not a great
            # assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                entities = world.db_context.knowledge_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities + entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices, dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1)
                all_probabilities.append(type_probabilities[:, 1:])
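                # Worked example (illustrative values): if one question token scores
                # [0.0, 1.0, 2.0] over [null, e1, e2], the softmax gives roughly
                # [0.090, 0.245, 0.665]; after dropping the null column, e1 and e2 keep 0.245
                # and 0.665, and the remaining ~0.09 is the probability of linking to neither.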

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(num_question_tokens,
                                                 num_entities - num_entities_in_instance)
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()

    @staticmethod
    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(0):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:len(predicted)]
        # Return 1 if the predicted sequence matches the first len(predicted) target actions.
        return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=0)[0]).item()

    @staticmethod
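    def _query_difficulty(targets: torch.LongTensor, action_mapping, batch_index):
        """
        Returns True if the (unpadded) target action sequence contains more than one
        ``table_name`` production, used as a proxy for multi-table (harder) queries.
        """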
        number_tables = len([action_mapping[(batch_index, int(a))] for a in targets if
                             a >= 0 and action_mapping[(batch_index, int(a))].startswith('table_name')])
        return number_tables > 1

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            '_match/exact_match': self._exact_match.get_metric(reset),
            'sql_match': self._sql_evaluator_match.get_metric(reset),
            '_others/action_similarity': self._action_similarity.get_metric(reset),
            '_match/match_single': self._acc_single.get_metric(reset),
            '_match/match_hard': self._acc_multi.get_metric(reset),
            'beam_hit': self._beam_hit.get_metric(reset)
        }

    @staticmethod
    def _get_type_vector(worlds: List[SpiderWorld],
                         num_entities: int,
                         device) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces the encoding for each entity's type. In addition, a map from a flattened entity
        index to type is returned to combine entity type operations into one method.

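        Parameters
        ----------
        worlds : ``List[SpiderWorld]``
        num_entities : ``int``
            The padded number of entities per instance.
        device : ``torch.device``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        batch_types : ``torch.LongTensor``
            Shape ``(batch_size, num_entities)``. Entry ``[b, e]`` is the type id of entity ``e``
            in instance ``b`` (padding positions get 0). Column types map to ids 0-6, cell values
            (``string:`` entities) to 7 and tables to 8.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """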
        entity_types = {}
        batch_types = []

        column_type_ids = ['boolean', 'foreign', 'number', 'others', 'primary', 'text', 'time']

        for batch_index, world in enumerate(worlds):
            types = []

            for entity_index, entity in enumerate(world.db_context.knowledge_graph.entities):
                parts = entity.split(':')
                entity_main_type = parts[0]
                if entity_main_type == 'column':
                    column_type = parts[1]
                    entity_type = column_type_ids.index(column_type)
                elif entity_main_type == 'string':
                    # cell value
                    entity_type = len(column_type_ids)
                elif entity_main_type == 'table':
                    entity_type = len(column_type_ids) + 1
                else:
                    raise Exception(f"Unknown entity: {entity}")
                types.append(entity_type)

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities, lambda: 0)
            batch_types.append(padded)

        return torch.tensor(batch_types, dtype=torch.long, device=device), entity_types

    def _compute_validation_outputs(self,
                                    actions: List[List[ProductionRuleArray]],
                                    best_final_states: Mapping[int, Sequence[GrammarBasedState]],
                                    world: List[SpiderWorld],
                                    target_list: List[List[str]],
                                    outputs: Dict[str, Any]) -> None:
        batch_size = len(actions)

        outputs['predicted_sql_query'] = []

        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]

        for i in range(batch_size):
            # gold sql exactly as given
            original_gold_sql_query = ' '.join(world[i].get_query_without_table_hints())

            if i not in best_final_states:
                self._exact_match(0)
                self._action_similarity(0)
                self._sql_evaluator_match(0)
                self._acc_multi(0)
                self._acc_single(0)
                outputs['predicted_sql_query'].append('')
                continue

            best_action_indices = best_final_states[i][0].action_history[0]

            action_strings = [action_mapping[(i, action_index)]
                              for action_index in best_action_indices]
            predicted_sql_query = action_sequence_to_sql(action_strings, add_table_names=True)
            outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=False))

            if target_list is not None:
                targets = target_list[i].data
            target_available = target_list is not None and targets[0] > -1

            if target_available:
                sequence_in_targets = self._action_history_match(best_action_indices, targets)
                self._exact_match(sequence_in_targets)

                sql_evaluator_match = self._evaluate_func(original_gold_sql_query, predicted_sql_query, world[i].db_id)
                self._sql_evaluator_match(sql_evaluator_match)

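                # Compare plain ints: tensor elements hash by identity, so difflib would never match them.
                target_indices = [a for a in targets.reshape(-1).tolist() if a >= 0]
                similarity = difflib.SequenceMatcher(None, best_action_indices, target_indices)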
                self._action_similarity(similarity.ratio())

                difficulty = self._query_difficulty(targets, action_mapping, i)
                if difficulty:
                    self._acc_multi(sql_evaluator_match)
                else:
                    self._acc_single(sql_evaluator_match)

            if target_available:
                # Record once per instance whether any candidate in the beam is correct.
                beam_hit = False
                for final_state in best_final_states[i]:
                    action_indices = final_state.action_history[0]
                    action_strings = [action_mapping[(i, action_index)]
                                      for action_index in action_indices]
                    candidate_sql_query = action_sequence_to_sql(action_strings, add_table_names=True)
                    if self._evaluate_func(original_gold_sql_query, candidate_sql_query, world[i].db_id):
                        beam_hit = True
                        break
                self._beam_hit(beam_hit)

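
# Illustrative sketch (hypothetical helpers, not called by ``SpiderParser``): a stdlib-only
# restatement of the entity type ids assigned in ``SpiderParser._get_type_vector``, assuming
# entity keys of the form ``column:<type>:<table>:<column>``, ``string:<value>`` and
# ``table:<name>`` as built by ``SpiderDBContext``. Column types map to ids 0-6, cell values
# to 7 and tables to 8.
_SKETCH_COLUMN_TYPE_IDS = ['boolean', 'foreign', 'number', 'others', 'primary', 'text', 'time']


def _sketch_entity_type_id(entity):
    main_type, *rest = entity.split(':')
    if main_type == 'column':
        # The second field of a column key is its column type (foreign/primary take precedence).
        return _SKETCH_COLUMN_TYPE_IDS.index(rest[0])
    if main_type == 'string':
        return len(_SKETCH_COLUMN_TYPE_IDS)
    if main_type == 'table':
        return len(_SKETCH_COLUMN_TYPE_IDS) + 1
    raise ValueError(f"Unknown entity: {entity}")


def _sketch_flattened_entity_types(batch_entities, num_entities):
    # Same flattened key as in ``_get_type_vector``: batch_index * num_entities + entity_index.
    return {batch_index * num_entities + entity_index: _sketch_entity_type_id(entity)
            for batch_index, entities in enumerate(batch_entities)
            for entity_index, entity in enumerate(entities)}

# Example: _sketch_flattened_entity_types([['table:a', 'column:number:a:x'], ['string:b']], 3)
# returns {0: 8, 1: 2, 3: 7}.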

================================================
FILE: modules/gated_graph_conv.py
================================================
import math

import torch
from torch import Tensor
from torch.nn import Parameter as Param, init
from torch_geometric.data import Data
from torch_geometric.nn.conv import MessagePassing


class GatedGraphConv(MessagePassing):
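    r"""The gated graph convolution operator from the `"Gated Graph Sequence
    Neural Networks" <https://arxiv.org/abs/1511.05493>`_ paper, with a separate
    weight matrix and bias for every timestep and edge type:

    .. math::
        \mathbf{h}_i^{(0)} &= \mathbf{x}_i \, \Vert \, \mathbf{0}

        \mathbf{m}_i^{(l+1)} &= \sum_{e} \sum_{j \in \mathcal{N}_e(i)}
        \left( \mathbf{\Theta}_{l,e} \cdot \mathbf{h}_j^{(l)} + \mathbf{b}_{l,e} \right)

        \mathbf{h}_i^{(l+1)} &= \textrm{GRU} (\mathbf{m}_i^{(l+1)},
        \mathbf{h}_i^{(l)})

    up to representation :math:`\mathbf{h}_i^{(L)}`, where :math:`\mathcal{N}_e(i)` are the
    neighbours of node :math:`i` under edge type :math:`e`. The number of input channels of
    :math:`\mathbf{x}_i` must be less than or equal to :obj:`input_dim`; smaller inputs are
    zero-padded.

    Args:
        input_dim (int): Size of each node representation.
        num_timesteps (int): The number of propagation steps :math:`L`.
        num_edge_types (int): The number of edge types; one weight matrix and bias are
            learned per edge type and timestep.
        aggr (string): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"add"`)
        bias (bool, optional): If set to :obj:`False`, the GRU cell will not learn
            an additive bias. (default: :obj:`True`)
        dropout (float, optional): Dropout probability applied to the transformed
            messages. (default: :obj:`0`)
    """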

    def __init__(self, input_dim, num_timesteps, num_edge_types, aggr='add', bias=True, dropout=0):
        super(GatedGraphConv, self).__init__(aggr)

        self._input_dim = input_dim
        self.num_timesteps = num_timesteps
        self.num_edge_types = num_edge_types

        self.weight = Param(Tensor(num_timesteps, num_edge_types, input_dim, input_dim))
        self.bias = Param(Tensor(num_timesteps, num_edge_types, input_dim))
        self.rnn = torch.nn.GRUCell(input_dim, input_dim, bias=bias)
        self.dropout = torch.nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        for t in range(self.num_timesteps):
            for e in range(self.num_edge_types):
                torch.nn.init.xavier_uniform_(self.weight[t, e])
        init.uniform_(self.bias, -0.01, 0.01)
        self.rnn.reset_parameters()

    def forward(self, x, edge_indices):
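        """
        Runs ``num_timesteps`` rounds of typed message passing, each followed by a GRU update.
        ``x`` has shape ``(num_nodes, num_features)`` and ``edge_indices`` holds one
        ``(2, num_edges)`` index tensor per edge type.
        """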
        if len(edge_indices) != self.num_edge_types:
            raise ValueError(f'GatedGraphConv constructed with {self.num_edge_types} edge types, '
                             f'but {len(edge_indices)} were passed')

        h = x

        if h.size(1) > self._input_dim:
            raise ValueError('The number of input channels is not allowed to '
                             'be larger than the number of output channels')

        if h.size(1) < self._input_dim:
            zero = h.new_zeros(h.size(0), self._input_dim - h.size(1))
            h = torch.cat([h, zero], dim=1)

        for t in range(self.num_timesteps):
            new_h = []
            for e in range(self.num_edge_types):
                if len(edge_indices[e]) == 0:
                    continue
                m = self.dropout(torch.matmul(h, self.weight[t, e]) + self.bias[t, e])
                new_h.append(self.propagate(edge_indices[e], size=(x.size(0), x.size(0)), x=m))
            m_sum = torch.sum(torch.stack(new_h), dim=0)
            h = self.rnn(m_sum, h)

        return h

    def __repr__(self):
        return '{}({}, num_layers={})'.format(
            self.__class__.__name__, self._input_dim, self.num_timesteps)

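# Illustrative sketch (hypothetical helper, plain Python lists instead of tensors): the
# aggregation done in one timestep of ``GatedGraphConv.forward`` before the GRU update,
# assuming PyTorch Geometric's default source-to-target flow, so ``edge_indices[e]`` is a pair
# ``[sources, targets]``, and ``messages[e][j]`` is node j's already-transformed message for
# edge type e. With ``aggr='add'`` each node sums the messages arriving over every edge type;
# nodes without incoming edges receive a zero vector.
def _sketch_typed_sum_aggregation(messages, edge_indices, num_nodes):
    dim = len(messages[0][0])
    aggregated = [[0.0] * dim for _ in range(num_nodes)]
    for edge_type, (sources, targets) in enumerate(edge_indices):
        for source, target in zip(sources, targets):
            for k in range(dim):
                aggregated[target][k] += messages[edge_type][source][k]
    return aggregated

# Example: with messages [[[1.0], [2.0], [3.0]], [[10.0], [20.0], [30.0]]] and edge indices
# [[[0, 1], [2, 2]], [[2], [0]]], the result for 3 nodes is [[30.0], [0.0], [3.0]].
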

if __name__ == '__main__':
    gcn = GatedGraphConv(input_dim=10, num_timesteps=3, num_edge_types=3)
    data = Data(torch.zeros((5, 10)), edge_index=[
        torch.tensor([[1,2],[2,3]]),
        torch.tensor([[1,3],[0,1]]),
        torch.tensor([[1,4],[2,3]]),
    ])
    output = gcn(data.x, data.edge_index)

    print(output)


================================================
FILE: predictors/spider_predictor.py
================================================
from overrides import overrides

from allennlp.common.util import JsonDict, sanitize
from allennlp.data import DatasetReader, Instance
from allennlp.models import Model
from allennlp.predictors.predictor import Predictor


@Predictor.register("spider")
class WikiTablesParserPredictor(Predictor):
    def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
        super().__init__(model, dataset_reader)

    @overrides
    def predict_instance(self, instance: Instance) -> JsonDict:
        json_output = {}
        outputs = self._model.forward_on_instance(instance)
        predicted_sql_query = outputs['predicted_sql_query'].replace('\n', ' ')
        if predicted_sql_query == '':
            # line must not be empty for the evaluator to consider it
            predicted_sql_query = 'NO PREDICTION'
        json_output['predicted_sql_query'] = predicted_sql_query
        return sanitize(json_output)

    @overrides
    def dump_line(self, outputs: JsonDict) -> str:  # pylint: disable=no-self-use
        """
        Writes one predicted SQL query per line instead of JSON lines, which is the
        plain-text format the Spider evaluation script reads.
        """
        return outputs['predicted_sql_query'] + "\n"

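# Illustrative sketch (hypothetical helper, not used by the predictor): the one-line-per-question
# contract that ``predict_instance`` and ``dump_line`` implement together. Newlines inside the
# query become spaces, and an empty prediction becomes a placeholder so the evaluator still
# sees exactly one line per question.
def _sketch_prediction_line(predicted_sql_query):
    predicted_sql_query = predicted_sql_query.replace('\n', ' ')
    if predicted_sql_query == '':
        predicted_sql_query = 'NO PREDICTION'
    return predicted_sql_query + "\n"
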

================================================
FILE: requirements.txt
================================================
torch==1.0.1.post2
spacy==2.0.18
allennlp==0.8.2
dill
torch-scatter==1.1.2
torch-sparse==0.2.4
torch-cluster==1.2.4
torch-geometric==1.1.2
ordered_set
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.0.0/en_core_web_sm-2.0.0.tar.gz

================================================
FILE: semparse/contexts/spider_context_utils.py
================================================
import re
from collections import defaultdict
from sys import exc_info
from typing import List, Dict, Set

from overrides import overrides
from parsimonious.exceptions import VisitationError, UndefinedLabel
from parsimonious.expressions import Literal, OneOf, Sequence
from parsimonious.grammar import Grammar
from parsimonious.nodes import Node, NodeVisitor
from six import reraise

WHITESPACE_REGEX = re.compile(" wsp |wsp | wsp| ws |ws | ws")


def format_grammar_string(grammar_dictionary: Dict[str, List[str]]) -> str:
    """
    Formats a dictionary of production rules into the string format expected
    by the Parsimonious Grammar class.
    """
    return '\n'.join([f"{nonterminal} = {' / '.join(right_hand_side)}"
                      for nonterminal, right_hand_side in grammar_dictionary.items()])


def initialize_valid_actions(grammar: Grammar,
                             keywords_to_uppercase: List[str] = None) -> Dict[str, List[str]]:
    """
    We initialize the valid actions with the global actions, i.e. the actions that result
    from the productions of ``grammar`` (which may already include rules generated from the
    database tables). The keys represent the nonterminals in the grammar and the values are
    sorted lists of the valid actions of that nonterminal.
    """
    valid_actions: Dict[str, Set[str]] = defaultdict(set)

    for key in grammar:
        rhs = grammar[key]

        # Sequence represents a series of expressions that match pieces of the text in order.
        # Eg. A -> B C
        if isinstance(rhs, Sequence):
            valid_actions[key].add(
                format_action(key, " ".join(rhs._unicode_members()),  # pylint: disable=protected-access
                              keywords_to_uppercase=keywords_to_uppercase))

        # OneOf represents a series of expressions, one of which matches the text.
        # Eg. A -> B / C
        elif isinstance(rhs, OneOf):
            for option in rhs._unicode_members():  # pylint: disable=protected-access
                valid_actions[key].add(format_action(key, option,
                                                     keywords_to_uppercase=keywords_to_uppercase))

        # A string literal, eg. "A"
        elif isinstance(rhs, Literal):
            if rhs.literal != "":
                valid_actions[key].add(format_action(key, repr(rhs.literal),
                                                     keywords_to_uppercase=keywords_to_uppercase))
            else:
                valid_actions[key] = set()

    valid_action_strings = {key: sorted(value) for key, value in valid_actions.items()}
    return valid_action_strings


def format_action(nonterminal: str,
                  right_hand_side: str,
                  is_string: bool = False,
                  is_number: bool = False,
                  keywords_to_uppercase: List[str] = None) -> str:
    """
    This function formats an action as it appears in models. It
    splits productions based on the special `ws` and `wsp` rules,
    which are used in grammars to denote whitespace, and then
    rejoins these tokens into a formatted, comma-separated list.
    Importantly, note that it *does not* split on spaces in
    the grammar string, because these might not correspond
    to spaces in the language the grammar recognises.

    Parameters
    ----------
    nonterminal : ``str``, required.
        The nonterminal in the action.
    right_hand_side : ``str``, required.
        The right hand side of the action
        (i.e. the thing which is produced).
    is_string : ``bool``, optional (default = False).
        Whether the production produces a string.
        If it does, it is formatted as ``nonterminal -> ["'string'"]``
    is_number : ``bool``, optional, (default = False).
        Whether the production produces a number.
        If it does, it is formatted as ``nonterminal -> ["number"]``
    keywords_to_uppercase: ``List[str]``, optional, (default = None)
        Keywords in the grammar to uppercase. In the case of sql,
        this might be SELECT, MAX etc.
    """
    keywords_to_uppercase = keywords_to_uppercase or []
    if right_hand_side.upper() in keywords_to_uppercase:
        right_hand_side = right_hand_side.upper()

    if is_string:
        return f'{nonterminal} -> ["\'{right_hand_side}\'"]'

    elif is_number:
        return f'{nonterminal} -> ["{right_hand_side}"]'

    else:
        right_hand_side = right_hand_side.lstrip("(").rstrip(")")
        child_strings = [token for token in WHITESPACE_REGEX.split(right_hand_side) if token]
        child_strings = [tok.upper() if tok.upper() in keywords_to_uppercase else tok for tok in child_strings]
        return f"{nonterminal} -> [{', '.join(child_strings)}]"

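# Illustrative sketch (hypothetical helper, standalone copy of ``WHITESPACE_REGEX``): the
# non-string, non-number branch of ``format_action`` without keyword upper-casing. Only the
# grammar's ``ws``/``wsp`` rules act as separators, so the space inside a literal such as
# "ORDER BY" survives.
import re

_SKETCH_WHITESPACE_REGEX = re.compile(" wsp |wsp | wsp| ws |ws | ws")


def _sketch_format_rhs(nonterminal, right_hand_side):
    right_hand_side = right_hand_side.lstrip("(").rstrip(")")
    children = [token for token in _SKETCH_WHITESPACE_REGEX.split(right_hand_side) if token]
    return f"{nonterminal} -> [{', '.join(children)}]"

# Example: _sketch_format_rhs('statement', 'query ws ";"') returns 'statement -> [query, ";"]'.
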

def action_sequence_to_sql(action_sequences: List[str], add_table_names: bool=False) -> str:
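    """
    Converts an action sequence like ``['statement -> [query, ";"]', ...]`` back into a SQL
    string. Each action replaces the first remaining occurrence of its nonterminal with its
    right-hand-side tokens. ``column_name`` actions of the form ``table@column`` are rendered
    as ``table.column`` when ``add_table_names`` is True, and as ``column`` otherwise.
    """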
    query = []
    for action in action_sequences:
        nonterminal, right_hand_side = action.split(' -> ')
        right_hand_side_tokens = right_hand_side[1:-1].split(', ')
        if nonterminal == 'statement':
            query.extend(right_hand_side_tokens)
        else:
            for query_index, token in list(enumerate(query)):
                if token == nonterminal:
                    if nonterminal == 'column_name' and '@' in right_hand_side_tokens[0] and len(right_hand_side_tokens) == 1:
                        if add_table_names:
                            table_name, column_name = right_hand_side_tokens[0].split('@')
                            if '.' in table_name:
                                table_name = table_name.split('.')[0]
                            right_hand_side_tokens = [table_name + '.' + column_name]
                        else:
                            right_hand_side_tokens = [right_hand_side_tokens[0].split('@')[-1]]
                    query = query[:query_index] + \
                            right_hand_side_tokens + \
                            query[query_index + 1:]
                    break
    return ' '.join([token.strip('"') for token in query])

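# Illustrative sketch (hypothetical, simplified): the core of ``action_sequence_to_sql``. The
# ``statement`` action seeds the token list, and every later action ``X -> [a, b]`` replaces
# the first remaining occurrence of nonterminal ``X`` with its children; assuming the actions
# are in top-down, left-to-right order, this is a leftmost derivation. This standalone version
# omits the ``table@column`` handling.
def _sketch_leftmost_derivation(actions):
    tokens = []
    for action in actions:
        nonterminal, right_hand_side = action.split(' -> ')
        children = right_hand_side[1:-1].split(', ')
        if nonterminal == 'statement':
            tokens.extend(children)
        elif nonterminal in tokens:
            position = tokens.index(nonterminal)
            tokens[position:position + 1] = children
    # Terminals keep their grammar quotes until the very end.
    return ' '.join(token.strip('"') for token in tokens)

# Example: ['statement -> [query, ";"]', 'query -> ["SELECT", col, "FROM", tab]',
#           'col -> ["name"]', 'tab -> ["singer"]'] yields 'SELECT name FROM singer ;'.
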

class SqlVisitor(NodeVisitor):
    """
    ``SqlVisitor`` performs a depth-first traversal of the AST. It takes the parse tree
    and gives us an action sequence that resulted in that parse. Since the visitor has mutable
    state, we define a new ``SqlVisitor`` for each query. To get the action sequence, we create
    a ``SqlVisitor`` and call parse on it, which returns a list of actions. Ex.

        sql_visitor = SqlVisitor(grammar_string)
        action_sequence = sql_visitor.parse(query)

    Importantly, this ``SqlVisitor`` skips over ``ws`` and ``wsp`` nodes,
    because they do not hold any meaning, and make an action sequence
    much longer than it needs to be.

    Parameters
    ----------
    grammar : ``Grammar``
        A Grammar object that we use to parse the text.
    keywords_to_uppercase: ``List[str]``, optional, (default = None)
        Keywords in the grammar to uppercase. In the case of sql,
        this might be SELECT, MAX etc.
    """

    def __init__(self, grammar: Grammar, keywords_to_uppercase: List[str] = None) -> None:
        self.action_sequence: List[str] = []
        self.grammar: Grammar = grammar
        self.keywords_to_uppercase = keywords_to_uppercase or []

    @overrides
    def generic_visit(self, node: Node, visited_children: List[None]) -> List[str]:
        self.add_action(node)
        if node.expr.name == 'statement':
            return self.action_sequence
        return []

    def add_action(self, node: Node) -> None:
        """
        For each node, we accumulate the rules that generated its children in a list.
        """
        if node.expr.name and node.expr.name not in ['ws', 'wsp']:
            nonterminal = f'{node.expr.name} -> '

            if isinstance(node.expr, Literal):
                right_hand_side = f'["{node.text}"]'

            else:
                child_strings = []
                for child in node.__iter__():
                    if child.expr.name in ['ws', 'wsp']:
                        continue
                    if child.expr.name != '':
                        child_strings.append(child.expr.name)
                    else:
                        child_right_side_string = child.expr._as_rhs().lstrip("(").rstrip(")")  # pylint: disable=protected-access
                        child_right_side_list = [tok for tok in
                                                 WHITESPACE_REGEX.split(child_right_side_string) if tok]
                        child_right_side_list = [tok.upper() if tok.upper() in
                                                                self.keywords_to_uppercase else tok
                                                 for tok in child_right_side_list]
                        child_strings.extend(child_right_side_list)
                right_hand_side = "[" + ", ".join(child_strings) + "]"
            rule = nonterminal + right_hand_side
            self.action_sequence = [rule] + self.action_sequence

    @overrides
    def visit(self, node):
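        """
        See the ``NodeVisitor`` visit method. Unlike the default implementation, this visits
        children right to left. Because ``add_action`` prepends each node's rule after its
        children have been visited, the resulting action sequence is in top-down,
        left-to-right (pre-order) order.
        """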
        method = getattr(self, 'visit_' + node.expr_name, self.generic_visit)

        # Call that method, and show where in the tree it failed if it blows
        # up.
        try:
            # Changing this to reverse here!
            return method(node, [self.visit(child) for child in reversed(list(node))])
        except (VisitationError, UndefinedLabel):
            # Don't catch and re-wrap already-wrapped exceptions.
            raise
        except self.unwrapped_exceptions:
            raise
        except Exception:  # pylint: disable=broad-except
            # Catch any exception, and tack on a parse tree so it's easier to
            # see where it went wrong.
            exc_class, exc, traceback = exc_info()
            reraise(VisitationError, VisitationError(exc, exc_class, node), traceback)

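# Illustrative sketch (hypothetical, standalone): why ``SqlVisitor.visit`` reverses the
# children. A node's rule is prepended only after all of its children were visited, so visiting
# children right to left leaves the rules in top-down, left-to-right (pre-order) order, which is
# the order ``action_sequence_to_sql`` consumes. Trees here are ``(name, children)`` tuples.
def _sketch_prepend_order(node, sequence):
    name, children = node
    for child in reversed(children):
        sequence = _sketch_prepend_order(child, sequence)
    return [name] + sequence

# Example: _sketch_prepend_order(('statement', [('query', [('select', []), ('from', [])]),
#                                               ('semicolon', [])]), [])
# returns ['statement', 'query', 'select', 'from', 'semicolon'].
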

================================================
FILE: semparse/contexts/spider_db_context.py
================================================
import re
from collections import defaultdict
from typing import Dict, List, Set, Tuple

from allennlp.data import Tokenizer, Token
from ordered_set import OrderedSet
from unidecode import unidecode

from dataset_readers.dataset_util.spider_utils import TableColumn, read_dataset_schema, read_dataset_values
from allennlp.semparse.contexts.knowledge_graph import KnowledgeGraph


# == stop words that will be omitted by ContextGenerator
STOP_WORDS = {"", "", "all", "being", "-", "over", "through", "yourselves", "its", "before",
              "hadn", "with", "had", ",", "should", "to", "only", "under", "ours", "has", "ought", "do",
              "them", "his", "than", "very", "cannot", "they", "not", "during", "yourself", "him",
              "nor", "did", "didn", "'ve", "this", "she", "each", "where", "because", "doing", "some", "we", "are",
              "further", "ourselves", "out", "what", "for", "weren", "does", "above", "between", "mustn", "?",
              "be", "hasn", "who", "were", "here", "shouldn", "let", "hers", "by", "both", "about", "couldn",
              "of", "could", "against", "isn", "or", "own", "into", "while", "whom", "down", "wasn", "your",
              "from", "her", "their", "aren", "there", "been", ".", "few", "too", "wouldn", "themselves",
              ":", "was", "until", "more", "himself", "on", "but", "don", "herself", "haven", "those", "he",
              "me", "myself", "these", "up", ";", "below", "'re", "can", "theirs", "my", "and", "would", "then",
              "is", "am", "it", "doesn", "an", "as", "itself", "at", "have", "in", "any", "if", "!",
              "again", "'ll", "no", "that", "when", "same", "how", "other", "which", "you", "many", "shan",
              "'t", "'s", "our", "after", "most", "'d", "such", "'m", "why", "a", "off", "i", "yours", "so",
              "the", "having", "once"}

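# Illustrative sketch (hypothetical helper, standalone): the three kinds of undirected edges that
# ``SpiderDBContext.get_db_knowledge_graph`` adds, assuming a toy schema given as plain dicts of
# entity keys instead of ``TableColumn`` objects, and plain sets instead of ``OrderedSet``.
# Tables link to their columns, foreign-key columns link to the columns they reference, and
# ``string:`` entities (question tokens found in cell values) link to the columns containing them.
from collections import defaultdict


def _sketch_schema_edges(table_columns, foreign_keys, string_entities):
    neighbors = defaultdict(set)

    def link(first, second):
        neighbors[first].add(second)
        neighbors[second].add(first)

    for table, column_keys in table_columns.items():
        for column_key in column_keys:
            link(f"table:{table}", column_key)
    for column_key, referenced_column_key in foreign_keys:
        link(column_key, referenced_column_key)
    for string_entity, column_keys in string_entities.items():
        for column_key in column_keys:
            link(string_entity, column_key)
    return dict(neighbors)
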

class SpiderDBContext:
    schemas = {}
    db_knowledge_graph = {}
    db_tables_data = {}

    def __init__(self, db_id: str, utterance: str, tokenizer: Tokenizer, tables_file: str, dataset_path: str):
        self.dataset_path = dataset_path
        self.tables_file = tables_file
        self.db_id = db_id
        self.utterance = utterance

        tokenized_utterance = tokenizer.tokenize(utterance.lower())
        self.tokenized_utterance = [Token(text=t.text, lemma=t.lemma_) for t in tokenized_utterance]

        if db_id not in SpiderDBContext.schemas:
            SpiderDBContext.schemas = read_dataset_schema(self.tables_file)
        self.schema = SpiderDBContext.schemas[db_id]

        self.knowledge_graph = self.get_db_knowledge_graph(db_id)

        entity_texts = [self.knowledge_graph.entity_text[entity].lower()
                        for entity in self.knowledge_graph.entities]
        entity_tokens = tokenizer.batch_tokenize(entity_texts)
        self.entity_tokens = [[Token(text=t.text, lemma=t.lemma_) for t in et] for et in entity_tokens]

    @staticmethod
    def entity_key_for_column(table_name: str, column: TableColumn) -> str:
        if column.foreign_key is not None:
            column_type = "foreign"
        elif column.is_primary_key:
            column_type = "primary"
        else:
            column_type = column.column_type
        return f"column:{column_type.lower()}:{table_name.lower()}:{column.name.lower()}"

    def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph:
        entities: Set[str] = set()
        neighbors: Dict[str, OrderedSet[str]] = defaultdict(OrderedSet)
        entity_text: Dict[str, str] = {}
        foreign_keys_to_column: Dict[str, str] = {}

        db_schema = self.schema
        tables = db_schema.values()

        if db_id not in self.db_tables_data:
            self.db_tables_data[db_id] = read_dataset_values(db_id, self.dataset_path, tables)

        tables_data = self.db_tables_data[db_id]

        string_column_mapping: Dict[str, set] = defaultdict(set)

        for table, table_data in tables_data.items():
            for table_row in table_data:
                for column, cell_value in zip(db_schema[table.name].columns, table_row):
                    if column.column_type == 'text' and type(cell_value) is str:
                        cell_value_normalized = self.normalize_string(cell_value)
                        column_key = self.entity_key_for_column(table.name, column)
                        string_column_mapping[cell_value_normalized].add(column_key)

        string_entities = self.get_entities_from_question(string_column_mapping)

        for table in tables:
            table_key = f"table:{table.name.lower()}"
            entities.add(table_key)
            entity_text[table_key] = table.text

            for column in db_schema[table.name].columns:
                entity_key = self.entity_key_for_column(table.name, column)
                entities.add(entity_key)
                neighbors[entity_key].add(table_key)
                neighbors[table_key].add(entity_key)
                entity_text[entity_key] = column.text

        for string_entity, column_keys in string_entities:
            entities.add(string_entity)
            for column_key in column_keys:
                neighbors[string_entity].add(column_key)
                neighbors[column_key].add(string_entity)
            entity_text[string_entity] = string_entity.replace("string:", "").replace("_", " ")

        # loop again after we have gone through all columns to link foreign keys columns
        for table_name in db_schema.keys():
            for column in db_schema[table_name].columns:
                if column.foreign_key is None:
                    continue

                other_column_table, other_column_name = column.foreign_key.split(':')

                # must have exactly one by design
                other_column = [col for col in db_schema[other_column_table].columns if col.name == other_column_name][0]

                entity_key = self.entity_key_for_column(table_name, column)
                other_entity_key = self.entity_key_for_column(other_column_table, other_column)

                neighbors[entity_key].add(other_entity_key)
                neighbors[other_entity_key].add(entity_key)

                foreign_keys_to_column[entity_key] = other_entity_key

        kg = KnowledgeGraph(entities, dict(neighbors), entity_text)
        kg.foreign_keys_to_column = foreign_keys_to_column

        return kg

    def _string_in_table(self, candidate: str,
                         string_column_mapping: Dict[str, set]) -> List[str]:
        """
        Checks if the string occurs in the table, and if it does, returns the names of the columns
        under which it occurs. If it does not, returns an empty list.
        """
        candidate_column_names: List[str] = []
        # First check if the entire candidate occurs as a cell.
        if candidate in string_column_mapping:
            candidate_column_names = string_column_mapping[candidate]
        # If not, check if it is a substring of any cell value.
        if not candidate_column_names:
            for cell_value, column_names in string_column_mapping.items():
                if candidate in cell_value:
                    candidate_column_names.extend(column_names)
        candidate_column_names = list(set(candidate_column_names))
        return candidate_column_names

    def get_entities_from_question(self,
                                   string_column_mapping: Dict[str, set]) -> List[Tuple[str, List[str]]]:
        entity_data = []
        for i, token in enumerate(self.tokenized_utterance):
            token_text = token.text
            if token_text in STOP_WORDS:
                continue
            normalized_token_text = self.normalize_string(token_text)
            if not normalized_token_text:
                continue
            token_columns = self._string_in_table(normalized_token_text, string_column_mapping)
            if token_columns:
                token_type = token_columns[0].split(":")[1]
                entity_data.append({'value': normalized_token_text,
                                    'token_start': i,
                                    'token_end': i+1,
                                    'token_type': token_type,
                                    'token_in_columns': token_columns})

        # extracted_numbers = self._get_numbers_from_tokens(self.question_tokens)
        # filter out number entities to avoid repetition
        expanded_entities = []
        for entity in self._expand_entities(self.tokenized_utterance, entity_data, string_column_mapping):
            if entity["token_type"] == "text":
                expanded_entities.append((f"string:{entity['value']}", entity['token_in_columns']))
        # return expanded_entities, extracted_numbers  #TODO(shikhar) Handle conjunctions

        return expanded_entities

    @staticmethod
    def normalize_string(string: str) -> str:
        """
        These are the transformation rules used to normalize cell and column names in Sempre.  See
        ``edu.stanford.nlp.sempre.tables.StringNormalizationUtils.characterNormalize`` and
        ``edu.stanford.nlp.sempre.tables.TableTypeSystem.canonicalizeName``.  We reproduce those
        rules here to normalize and canonicalize cells and columns in the same way so that we can
        match them against constants in logical forms appropriately.
        """
        # Normalization rules from Sempre
        # \u201A -> ,
        string = re.sub("‚", ",", string)
        string = re.sub("„", ",,", string)
        string = re.sub("[·・]", ".", string)
        string = re.sub("…", "...", string)
        string = re.sub("ˆ", "^", string)
        string = re.sub("˜", "~", string)
        string = re.sub("‹", "<", string)
        string = re.sub("›", ">", string)
        string = re.sub("[‘’´`]", "'", string)
        string = re.sub("[“”«»]", "\"", string)
        string = re.sub("[•†‡²³]", "", string)
        string = re.sub("[‐‑–—−]", "-", string)
        # Oddly, some unicode characters get converted to _ instead of being stripped.  Not really
        # sure how sempre decides what to do with these...  TODO(mattg): can we just get rid of the
        # need for this function somehow?  It's causing a whole lot of headaches.
        string = re.sub("[ðø′″€⁄ªΣ]", "_", string)
        # This is such a mess.  There isn't just a block of unicode that we can strip out, because
        # sometimes sempre just strips diacritics...  We'll try stripping out a few separate
        # blocks, skipping the ones that sempre skips...
        string = re.sub("[\\u0180-\\u0210]", "", string).strip()
        string = re.sub("[\\u0220-\\uFFFF]", "", string).strip()
        string = string.replace("\\n", "_")
        string = re.sub("\\s+", " ", string)
        # Canonicalization rules from Sempre.
        string = re.sub("[^\\w]", "_", string)
        string = re.sub("_+", "_", string)
        string = re.sub("_$", "", string)
        return unidecode(string.lower())

    def _expand_entities(self, question, entity_data, string_column_mapping: Dict[str, set]):
        new_entities = []
        for entity in entity_data:
            # to ensure the same strings are not used over and over
            if new_entities and entity['token_end'] <= new_entities[-1]['token_end']:
                continue
            current_start = entity['token_start']
            current_end = entity['token_end']
            current_token = entity['value']
            current_token_type = entity['token_type']
            current_token_columns = entity['token_in_columns']

            while current_end < len(question):
                next_token = question[current_end].text
                next_token_normalized = self.normalize_string(next_token)
                if next_token_normalized == "":
                    current_end += 1
                    continue
                candidate = "%s_%s" %(current_token, next_token_normalized)
                candidate_columns = self._string_in_table(candidate, string_column_mapping)
                candidate_columns = list(set(candidate_columns).intersection(current_token_columns))
                if not candidate_columns:
                    break
                candidate_type = candidate_columns[0].split(":")[1]
                if candidate_type != current_token_type:
                    break
                current_end += 1
                current_token = candidate
                current_token_columns = candidate_columns

            new_entities.append({'token_start': current_start,
                                 'token_end': current_end,
                                 'value': current_token,
                                 'token_type': current_token_type,
                                 'token_in_columns': current_token_columns})
        return new_entities
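

# Illustrative sketch (hypothetical helper, not called by the parser): the greedy
# span expansion performed by ``_expand_entities`` above, reduced to its core.
# Assumptions: tokens are already normalized strings, and ``string_to_columns``
# maps an underscore-joined string to the set of column identifiers that contain
# it. The token-type check done by the real method is omitted here.
def _sketch_expand_span(tokens, start_value, end, columns, string_to_columns):
    value = start_value
    while end < len(tokens):
        if tokens[end] == "":
            # Tokens that normalize to the empty string are skipped, not appended.
            end += 1
            continue
        candidate = f"{value}_{tokens[end]}"
        # Keep only columns that contain both the current span and its extension.
        shared = set(string_to_columns.get(candidate, set())) & set(columns)
        if not shared:
            break
        value, columns, end = candidate, shared, end + 1
    return value, end, columns


# Usage (hypothetical column id "city:name"): "new" matched at token 0 grows
# while the longer string is still found in the same column.
# _sketch_expand_span(["new", "york", "city", "hotels"], "new", 1, {"city:name"},
#                     {"new_york": {"city:name"}, "new_york_city": {"city:name"}})
# -> ("new_york_city", 3, {"city:name"})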


================================================
FILE: semparse/contexts/spider_db_grammar.py
================================================
# pylint: disable=anomalous-backslash-in-string
"""
The SQL grammar used for Spider. ``GRAMMAR_DICTIONARY`` maps each nonterminal to its list
of alternatives; the functions below add schema-specific table and column names and apply
optional grammar variants.
"""
from typing import List, Dict

from dataset_readers.dataset_util.spider_utils import Table


GRAMMAR_DICTIONARY = {}
GRAMMAR_DICTIONARY["statement"] = ['(query ws iue ws query)', '(query ws)']
GRAMMAR_DICTIONARY["iue"] = ['"intersect"', '"except"', '"union"']
GRAMMAR_DICTIONARY["query"] = ['(ws select_core ws groupby_clause ws orderby_clause ws limit)',
                               '(ws select_core ws groupby_clause ws orderby_clause)',
                               '(ws select_core ws groupby_clause ws limit)',
                               '(ws select_core ws orderby_clause ws limit)',
                               '(ws select_core ws groupby_clause)',
                               '(ws select_core ws orderby_clause)',
                               '(ws select_core)']

GRAMMAR_DICTIONARY["select_core"] = ['(select_with_distinct ws select_results ws from_clause ws where_clause)',
                                     '(select_with_distinct ws select_results ws from_clause)',
                                     '(select_with_distinct ws select_results ws where_clause)',
                                     '(select_with_distinct ws select_results)']
GRAMMAR_DICTIONARY["select_with_distinct"] = ['(ws "select" ws "distinct")', '(ws "select")']
GRAMMAR_DICTIONARY["select_results"] = ['(ws select_result ws "," ws select_results)', '(ws select_result)']
GRAMMAR_DICTIONARY["select_result"] = ['"*"', '(table_source ws ".*")',
                                       'expr', 'col_ref']

GRAMMAR_DICTIONARY["from_clause"] = ['(ws "from" ws table_source ws join_clauses)',
                                     '(ws "from" ws source)']
GRAMMAR_DICTIONARY["join_clauses"] = ['(join_clause ws join_clauses)', 'join_clause']
GRAMMAR_DICTIONARY["join_clause"] = ['"join" ws table_source ws "on" ws join_condition_clause']
GRAMMAR_DICTIONARY["join_condition_clause"] = ['(join_condition ws "and" ws join_condition_clause)', 'join_condition']
GRAMMAR_DICTIONARY["join_condition"] = ['ws col_ref ws "=" ws col_ref']
GRAMMAR_DICTIONARY["source"] = ['(ws single_source ws "," ws source)', '(ws single_source)']
GRAMMAR_DICTIONARY["single_source"] = ['table_source', 'source_subq']
GRAMMAR_DICTIONARY["source_subq"] = ['("(" ws query ws ")")']
# GRAMMAR_DICTIONARY["source_subq"] = ['("(" ws query ws ")" ws "as" ws name)', '("(" ws query ws ")")']
GRAMMAR_DICTIONARY["limit"] = ['("limit" ws non_literal_number)']

GRAMMAR_DICTIONARY["where_clause"] = ['(ws "where" wsp expr ws where_conj)', '(ws "where" wsp expr)']
GRAMMAR_DICTIONARY["where_conj"] = ['(ws "and" wsp expr ws where_conj)', '(ws "and" wsp expr)']

GRAMMAR_DICTIONARY["groupby_clause"] = ['(ws "group" ws "by" ws group_clause ws "having" ws expr)',
                                        '(ws "group" ws "by" ws group_clause)']
GRAMMAR_DICTIONARY["group_clause"] = ['(ws expr ws "," ws group_clause)', '(ws expr)']

GRAMMAR_DICTIONARY["orderby_clause"] = ['ws "order" ws "by" ws order_clause']
GRAMMAR_DICTIONARY["order_clause"] = ['(ordering_term ws "," ws order_clause)', 'ordering_term']
GRAMMAR_DICTIONARY["ordering_term"] = ['(ws expr ws ordering)', '(ws expr)']
GRAMMAR_DICTIONARY["ordering"] = ['(ws "asc")', '(ws "desc")']

GRAMMAR_DICTIONARY["col_ref"] = ['(table_name ws "." ws column_name)', 'column_name']
GRAMMAR_DICTIONARY["table_source"] = ['(table_name ws "as" ws table_alias)', 'table_name']
GRAMMAR_DICTIONARY["table_name"] = ["table_alias"]
GRAMMAR_DICTIONARY["table_alias"] = ['"t1"', '"t2"', '"t3"', '"t4"']
GRAMMAR_DICTIONARY["column_name"] = []

GRAMMAR_DICTIONARY["ws"] = ['~"\s*"i']
GRAMMAR_DICTIONARY['wsp'] = ['~"\s+"i']

GRAMMAR_DICTIONARY["expr"] = ['in_expr',
                              # Like expressions.
                              '(value wsp "like" wsp string)',
                              # Between expressions.
                              '(value ws "between" wsp value ws "and" wsp value)',
                              # Binary expressions.
                              '(value ws binaryop wsp expr)',
                              # Unary expressions.
                              '(unaryop ws expr)',
                              'source_subq',
                              'value']
GRAMMAR_DICTIONARY["in_expr"] = ['(value wsp "not" wsp "in" wsp string_set)',
                                 '(value wsp "in" wsp string_set)',
                                 '(value wsp "not" wsp "in" wsp expr)',
                                 '(value wsp "in" wsp expr)']

GRAMMAR_DICTIONARY["value"] = ['parenval', '"YEAR(CURDATE())"', 'number', 'boolean',
                               'function', 'col_ref', 'string']
GRAMMAR_DICTIONARY["parenval"] = ['"(" ws expr ws ")"']
GRAMMAR_DICTIONARY["function"] = ['(fname ws "(" ws "distinct" ws arg_list_or_star ws ")")',
                                  '(fname ws "(" ws arg_list_or_star ws ")")']

GRAMMAR_DICTIONARY["arg_list_or_star"] = ['arg_list', '"*"']
GRAMMAR_DICTIONARY["arg_list"] = ['(expr ws "," ws arg_list)', 'expr']
# TODO(MARK): Massive hack, remove and modify the grammar accordingly
# GRAMMAR_DICTIONARY["number"] = ['~"\d*\.?\d+"i', "'3'", "'4'"]
GRAMMAR_DICTIONARY["non_literal_number"] = ['"1"', '"2"', '"3"', '"4"']
GRAMMAR_DICTIONARY["number"] = ['ws "value" ws']
GRAMMAR_DICTIONARY["string_set"] = ['ws "(" ws string_set_vals ws ")"']
GRAMMAR_DICTIONARY["string_set_vals"] = ['(string ws "," ws string_set_vals)', 'string']
# GRAMMAR_DICTIONARY["string"] = ['~"\'.*?\'"i']
GRAMMAR_DICTIONARY["string"] = ['"\'" ws "value" ws "\'"']
GRAMMAR_DICTIONARY["fname"] = ['"count"', '"sum"', '"max"', '"min"', '"avg"', '"all"']
GRAMMAR_DICTIONARY["boolean"] = ['"true"', '"false"']

# TODO(MARK): This is not tight enough. AND/OR are strictly boolean value operators.
GRAMMAR_DICTIONARY["binaryop"] = ['"+"', '"-"', '"*"', '"/"', '"="', '"!="', '"<>"',
                                  '">="', '"<="', '">"', '"<"', '"and"', '"or"', '"like"']
GRAMMAR_DICTIONARY["unaryop"] = ['"+"', '"-"', '"not"', '"not"']


def update_grammar_with_tables(grammar_dictionary: Dict[str, List[str]],
                               schema: Dict[str, Table]) -> None:
    table_names = sorted([f'"{table.lower()}"' for table in
                          list(schema.keys())], reverse=True)
    grammar_dictionary['table_name'] += table_names

    all_columns = set()
    for table in schema.values():
        all_columns.update([f'"{table.name.lower()}@{column.name.lower()}"' for column in table.columns if column.name != '*'])
    sorted_columns = sorted(all_columns, reverse=True)
    grammar_dictionary['column_name'] += sorted_columns


def update_grammar_to_be_table_names_free(grammar_dictionary: Dict[str, List[str]]):
    """
    Remove table names from column names, remove aliases
    """

    grammar_dictionary["column_name"] = []
    grammar_dictionary["table_name"] = []
    grammar_dictionary["col_ref"] = ['column_name']
    grammar_dictionary["table_source"] = ['table_name']

    del grammar_dictionary["table_alias"]


def update_grammar_flip_joins(grammar_dictionary: Dict[str, List[str]]):
    """
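    Replace the right-recursive ``join_clauses`` rule with a left-nested version. The
    recursion is unrolled into helper rules, which allows at most four join clauses.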
    """

    # using a simple rule such as join_clauses-> [(join_clauses ws join_clause), join_clause]
    # resulted in a max recursion error, so for now just using a predefined max
    # number of joins
    grammar_dictionary["join_clauses"] = ['(join_clauses_1 ws join_clause)', 'join_clause']
    grammar_dictionary["join_clauses_1"] = ['(join_clauses_2 ws join_clause)', 'join_clause']
    grammar_dictionary["join_clauses_2"] = ['(join_clause ws join_clause)', 'join_clause']
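

# Illustrative sketch (hypothetical helper, not used by the parser): compute the
# largest number of ``join_clause`` occurrences that a family of join rules can
# derive. Assumes each alternative is a (possibly parenthesized) sequence of
# symbol names separated by whitespace, and that the rules passed in are not
# recursive (a recursive rule would recurse without bound).
def _sketch_max_join_clauses(rules, symbol="join_clauses"):
    if symbol == "join_clause":
        return 1
    best = 0
    for alternative in rules.get(symbol, []):
        names = alternative.strip("()").split()
        # ``ws`` and any symbol without rules in ``rules`` contribute no joins.
        best = max(best, sum(_sketch_max_join_clauses(rules, name) for name in names))
    return best


# The rules installed by update_grammar_flip_joins allow up to four joins:
# _sketch_max_join_clauses({
#     "join_clauses": ['(join_clauses_1 ws join_clause)', 'join_clause'],
#     "join_clauses_1": ['(join_clauses_2 ws join_clause)', 'join_clause'],
#     "join_clauses_2": ['(join_clause ws join_clause)', 'join_clause'],
# })  -> 4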

================================================
FILE: semparse/worlds/evaluate.py
================================================
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################
import copy
import os
import json
import sqlite3
import argparse

from spider_evaluation.process_sql import get_schema, Schema, get_sql

# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True


CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')

WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
    'sql': "sql",
    'table_unit': "table_unit",
}

COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')


HARDNESS = {
    "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
    "component2": ('except', 'union', 'intersect')
}
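

# Illustrative sketch (hypothetical helper, not used by the evaluator): a
# ``condition`` interleaves cond_units with 'and'/'or' connectives, which is why
# the helpers below separate them by slicing with ``[::2]`` and ``[1::2]``.
def _sketch_split_condition(conds):
    cond_units = conds[::2]    # positions 0, 2, 4, ...: cond_unit tuples
    connectives = conds[1::2]  # positions 1, 3, ...: 'and' / 'or'
    return cond_units, connectives


# Usage with placeholder cond_units (not real parser output):
# _sketch_split_condition(["c1", "and", "c2", "or", "c3"])
# -> (["c1", "c2", "c3"], ["and", "or"])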


def condition_has_or(conds):
    return 'or' in conds[1::2]


def condition_has_like(conds):
    return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]


def condition_has_sql(conds):
    for cond_unit in conds[::2]:
        val1, val2 = cond_unit[3], cond_unit[4]
        if val1 is not None and type(val1) is dict:
            return True
        if val2 is not None and type(val2) is dict:
            return True
    return False


def val_has_op(val_unit):
    return val_unit[0] != UNIT_OPS.index('none')


def has_agg(unit):
    return unit[0] != AGG_OPS.index('none')


def accuracy(count, total):
    if count == total:
        return 1
    return 0


def recall(count, total):
    if count == total:
        return 1
    return 0


def F1(acc, rec):
    if (acc + rec) == 0:
        return 0
    return (2. * acc * rec) / (acc + rec)


def get_scores(count, pred_total, label_total):
    if pred_total != label_total:
        return 0,0,0
    elif count == pred_total:
        return 1,1,1
    return 0,0,0


def eval_sel(pred, label):
    pred_sel = pred['select'][1]
    label_sel = label['select'][1]
    label_wo_agg = [unit[1] for unit in label_sel]
    pred_total = len(pred_sel)
    label_total = len(label_sel)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_sel:
        if unit in label_sel:
            cnt += 1
            label_sel.remove(unit)
        if unit[1] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[1])

    return label_total, pred_total, cnt, cnt_wo_agg
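

# Illustrative sketch (hypothetical helper, not used by the evaluator): the
# count-and-remove loop above computes the size of the multiset intersection of
# predicted and gold units. Assuming the units are hashable (e.g. nested tuples),
# collections.Counter gives the same count.
def _sketch_multiset_match_count(pred_units, label_units):
    from collections import Counter
    # ``&`` keeps, for each unit, the minimum of its two multiplicities.
    return sum((Counter(pred_units) & Counter(label_units)).values())


# _sketch_multiset_match_count(["a", "a", "b"], ["a", "b", "b"]) -> 2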


def eval_where(pred, label):
    pred_conds = [unit for unit in pred['where'][::2]]
    label_conds = [unit for unit in label['where'][::2]]
    label_wo_agg = [unit[2] for unit in label_conds]
    pred_total = len(pred_conds)
    label_total = len(label_conds)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_conds:
        if unit in label_conds:
            cnt += 1
            label_conds.remove(unit)
        if unit[2] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[2])

    return label_total, pred_total, cnt, cnt_wo_agg


def eval_group(pred, label):
    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    pred_total = len(pred_cols)
    label_total = len(label_cols)
    cnt = 0
    pred_cols = [col.split(".")[1] if "." in col else col for col in pred_cols]
    label_cols = [col.split(".")[1] if "." in col else col for col in label_cols]
    for col in pred_cols:
        if col in label_cols:
            cnt += 1
            label_cols.remove(col)
    return label_total, pred_total, cnt


def eval_having(pred, label):
    pred_total = label_total = cnt = 0
    if len(pred['groupBy']) > 0:
        pred_total = 1
    if len(label['groupBy']) > 0:
        label_total = 1

    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    if pred_total == label_total == 1 \
            and pred_cols == label_cols \
            and pred['having'] == label['having']:
        cnt = 1

    return label_total, pred_total, cnt


def eval_order(pred, label):
    pred_total = label_total = cnt = 0
    if len(pred['orderBy']) > 0:
        pred_total = 1
    if len(label['orderBy']) > 0:
        label_total = 1
    if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
            ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
        cnt = 1
    return label_total, pred_total, cnt


def eval_and_or(pred, label):
    pred_ao = pred['where'][1::2]
    label_ao = label['where'][1::2]
    pred_ao = set(pred_ao)
    label_ao = set(label_ao)

    if pred_ao == label_ao:
        return 1,1,1
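    # Note: callers unpack the result as (label_total, pred_total, cnt), so in
    # this mismatch case the predicted and gold totals end up swapped.
    return len(pred_ao),len(label_ao),0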


def get_nestedSQL(sql):
    nested = []
    for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
        if type(cond_unit[3]) is dict:
            nested.append(cond_unit[3])
        if type(cond_unit[4]) is dict:
            nested.append(cond_unit[4])
    if sql['intersect'] is not None:
        nested.append(sql['intersect'])
    if sql['except'] is not None:
        nested.append(sql['except'])
    if sql['union'] is not None:
        nested.append(sql['union'])
    return nested


def eval_nested(pred, label):
    label_total = 0
    pred_total = 0
    cnt = 0
    if pred is not None:
        pred_total += 1
    if label is not None:
        label_total += 1
    if pred is not None and label is not None:
        cnt += Evaluator().eval_exact_match(pred, label)
    return label_total, pred_total, cnt


def eval_IUEN(pred, label):
    lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
    lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
    lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
    label_total = lt1 + lt2 + lt3
    pred_total = pt1 + pt2 + pt3
    cnt = cnt1 + cnt2 + cnt3
    return label_total, pred_total, cnt


def get_keywords(sql):
    res = set()
    if len(sql['where']) > 0:
        res.add('where')
    if len(sql['groupBy']) > 0:
        res.add('group')
    if len(sql['having']) > 0:
        res.add('having')
    if len(sql['orderBy']) > 0:
        res.add(sql['orderBy'][0])
        res.add('order')
    if sql['limit'] is not None:
        res.add('limit')
    if sql['except'] is not None:
        res.add('except')
    if sql['union'] is not None:
        res.add('union')
    if sql['intersect'] is not None:
        res.add('intersect')

    # or keyword
    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    if len([token for token in ao if token == 'or']) > 0:
        res.add('or')

    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    # not keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
        res.add('not')

    # in keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
        res.add('in')

    # like keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
        res.add('like')

    return res


def eval_keywords(pred, label):
    pred_keywords = get_keywords(pred)
    label_keywords = get_keywords(label)
    pred_total = len(pred_keywords)
    label_total = len(label_keywords)
    cnt = 0

    for k in pred_keywords:
        if k in label_keywords:
            cnt += 1
    return label_total, pred_total, cnt


def count_agg(units):
    return len([unit for unit in units if has_agg(unit)])


def count_component1(sql):
    count = 0
    if len(sql['where']) > 0:
        count += 1
    if len(sql['groupBy']) > 0:
        count += 1
    if len(sql['orderBy']) > 0:
        count += 1
    if sql['limit'] is not None:
        count += 1
    if len(sql['from']['table_units']) > 0:  # JOIN
        count += len(sql['from']['table_units']) - 1

    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    count += len([token for token in ao if token == 'or'])
    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])

    return count


def count_component2(sql):
    nested = get_nestedSQL(sql)
    return len(nested)


def count_others(sql):
    count = 0
    # number of aggregation
    agg_count = count_agg(sql['select'][1])
    agg_count += count_agg(sql['where'][::2])
    agg_count += count_agg(sql['groupBy'])
    if len(sql['orderBy']) > 0:
        agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
                            [unit[2] for unit in sql['orderBy'][1] if unit[2]])
    agg_count += count_agg(sql['having'])
    if agg_count > 1:
        count += 1

    # number of select columns
    if len(sql['select'][1]) > 1:
        count += 1

    # number of where conditions
    if len(sql['where']) > 1:
        count += 1

    # number of group by clauses
    if len(sql['groupBy']) > 1:
        count += 1

    return count


class Evaluator:
    """A simple evaluator"""
    def __init__(self):
        self.partial_scores = None

    def eval_hardness(self, sql):
        count_comp1_ = count_component1(sql)
        count_comp2_ = count_component2(sql)
        count_others_ = count_others(sql)

        if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
            return "easy"
        elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
                (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
            return "medium"
        elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
                (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
                (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
            return "hard"
        else:
            return "extra"

    def eval_exact_match(self, pred, label):
        partial_scores = self.eval_partial_match(pred, label)
        self.partial_scores = partial_scores

        for _, score in partial_scores.items():
            if score['f1'] != 1:
                return 0
        if len(label['from']['table_units']) > 0:
            label_tables = sorted(label['from']['table_units'])
            pred_tables = sorted(pred['from']['table_units'])
            if label_tables != pred_tables:
                return 0
        if len(label['from']['conds']) > 0:
            label_joins = sorted(label['from']['conds'], key=lambda x: str(x))
            pred_joins = sorted(pred['from']['conds'], key=lambda x: str(x))
            if label_joins != pred_joins:
                return 0
        return 1

    def eval_partial_match(self, pred, label):
        res = {}

        label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_group(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_having(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_order(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_and_or(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_IUEN(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_keywords(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        return res
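

# Illustrative sketch (hypothetical helper, not used by the evaluator): the
# thresholds of Evaluator.eval_hardness as a pure function of the three counts
# (count_component1, count_component2, count_others), so they can be checked
# without building a parsed SQL dict.
def _sketch_hardness(comp1, comp2, others):
    if comp1 <= 1 and others == 0 and comp2 == 0:
        return "easy"
    if (others <= 2 and comp1 <= 1 and comp2 == 0) or \
            (comp1 <= 2 and others < 2 and comp2 == 0):
        return "medium"
    if (others > 2 and comp1 <= 2 and comp2 == 0) or \
            (2 < comp1 <= 3 and others <= 2 and comp2 == 0) or \
            (comp1 <= 1 and others == 0 and comp2 <= 1):
        return "hard"
    return "extra"


# _sketch_hardness(1, 0, 0) -> "easy"
# _sketch_hardness(3, 0, 0) -> "hard"
# _sketch_hardness(0, 2, 0) -> "extra"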


def isValidSQL(sql, db):
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    try:
        cursor.execute(sql, [])
    except Exception:
        return False
    finally:
        # Release the database handle whether or not the query succeeded.
        conn.close()
    return True


def print_scores(scores, etype):
    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']

    print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
    counts = [scores[level]['count'] for level in levels]
    print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))

    if etype in ["all", "exec"]:
        print('=====================   EXECUTION ACCURACY     =====================')
        this_scores = [scores[level]['exec'] for level in levels]
        print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))

    if etype in ["all", "match"]:
        print('\n====================== EXACT MATCHING ACCURACY =====================')
        exact_scores = [scores[level]['exact'] for level in levels]
        print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
        print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

        print('---------------------- PARTIAL MATCHING RECALL ----------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

        print('---------------------- PARTIAL MATCHING F1 --------------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))


def evaluate(gold, predict, db_dir, etype, kmaps):
    with open(gold) as f:
        glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]

    with open(predict) as f:
        plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
    # plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
    # glist = [("SELECT max(SHARE) ,  min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]
    evaluator = Evaluator()

    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']
    entries = []
    scores = {}

    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}

    eval_err_num = 0
    for p, g in zip(plist, glist):
        p_str = p[0]
        g_str, db = g
        db_name = db
        db = os.path.join(db_dir, db, db + ".sqlite")
        schema = Schema(get_schema(db))
        g_sql = get_sql(schema, g_str)
        hardness = evaluator.eval_hardness(g_sql)
        scores[hardness]['count'] += 1
        scores['all']['count'] += 1

        try:
            p_sql = get_sql(schema, p_str)
        except Exception:
            # If the prediction cannot be parsed, score it as an empty SQL query
            # against the gold query.
            p_sql = {
                "except": None,
                "from": {
                    "conds": [],
                    "table_units": []
                },
                "groupBy": [],
                "having": [],
                "intersect": None,
                "limit": None,
                "orderBy": [],
                "select": [
                    False,
                    []
                ],
                "union": None,
                "where": []
            }
            eval_err_num += 1
            print("eval_err_num:{}".format(eval_err_num))
            print(p_str)
            print()

        # rebuild sql for value evaluation
        kmap = kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

        # p_sql_copy = copy.deepcopy(p_sql)
        # g_sql_copy = copy.deepcopy(g_sql)
        # if not eval_exec_match(db, p_str, g_str, p_sql_copy, g_sql_copy) and evaluator.eval_exact_match(p_sql_copy, g_sql_copy):
        #     a = 1

        if etype in ["all", "exec"]:
            exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
            if exec_score:
                scores[hardness]['exec'] += 1
            scores['all']['exec'] += exec_score

        if etype in ["all", "match"]:
            exact_score = evaluator.eval_exact_match(p_sql, g_sql)
            partial_scores = evaluator.partial_scores
            if exact_score == 0:
                print("{} pred: {}".format(hardness,p_str))
                print("{} gold: {}".format(hardness,g_str))
                print("")
            scores[hardness]['exact'] += exact_score
            scores['all']['exact'] += exact_score
            for type_ in partial_types:
                if partial_scores[type_]['pred_total'] > 0:
                    scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
                    scores[hardness]['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
                    scores[hardness]['partial'][type_]['rec_count'] += 1
                scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
                if partial_scores[type_]['pred_total'] > 0:
                    scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
                    scores['all']['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
                    scores['all']['partial'][type_]['rec_count'] += 1
                scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']

            entries.append({
                'predictSQL': p_str,
                'goldSQL': g_str,
                'hardness': hardness,
                'exact': exact_score,
                'partial': partial_scores
            })

    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']

        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                if scores[level]['partial'][type_]['acc_count'] == 0:
                    scores[level]['partial'][type_]['acc'] = 0
                else:
                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
                                                             scores[level]['partial'][type_]['acc_count'] * 1.0
                if scores[level]['partial'][type_]['rec_count'] == 0:
                    scores[level]['partial'][type_]['rec'] = 0
                else:
                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
                                                             scores[level]['partial'][type_]['rec_count'] * 1.0
                if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
                    scores[level]['partial'][type_]['f1'] = 1
                else:
                    scores[level]['partial'][type_]['f1'] = \
                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
                        scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])

    print_scores(scores, etype)
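

# Illustrative sketch (hypothetical helper, not called by evaluate). The loop
# above averages each component's accuracy over the examples whose prediction
# contains that component ('acc_count'). It averages recall over the examples
# whose gold query contains it ('rec_count'), then combines the two averages
# into an F1. Like the loop, this helper sets F1 to 1 when both averages are 0.
def _sketch_aggregate_partial(acc_sum, acc_count, rec_sum, rec_count):
    acc = acc_sum / acc_count if acc_count else 0
    rec = rec_sum / rec_count if rec_count else 0
    if acc == 0 and rec == 0:
        f1 = 1
    else:
        f1 = 2.0 * acc * rec / (acc + rec)
    return acc, rec, f1


# Example: acc was 1 on 3 of the 4 examples with the component in the
# prediction, and rec was 1 on 3 of the 6 examples with it in the gold query.
# _sketch_aggregate_partial(3, 4, 3, 6) gives (0.75, 0.5, 0.6) up to rounding.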


def eval_exec_match(db, p_str, g_str, pred, gold):
    """
    Return True if executing the prediction and the gold query yields the same
    values for each corresponding SELECT item.
    Multiple col_units (pairs) are not currently supported.
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    conn.text_factory = bytes
    try:
        try:
            cursor.execute(p_str)
            p_res = cursor.fetchall()
        except Exception:
            # A prediction that cannot be executed never matches.
            return False

        cursor.execute(g_str)
        q_res = cursor.fetchall()
    finally:
        conn.close()

    def res_map(res, val_units):
        # Key each result column by its SELECT item, so column order does not matter.
        rmap = {}
        for idx, val_unit in enumerate(val_units):
            key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
            rmap[key] = [r[idx] for r in res]
        return rmap

    p_val_units = [unit[1] for unit in pred['select'][1]]
    q_val_units = [unit[1] for unit in gold['select'][1]]
    return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)
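

# Illustrative sketch (hypothetical helpers, for illustration only).
# eval_exec_match compares the two result sets column by column. Columns are
# keyed by their SELECT item rather than by position, so reordering the SELECT
# columns does not change the outcome. This sketch shows the same idea on an
# in-memory SQLite database. As an assumption, it uses the column names from
# cursor.description as stand-ins for the parsed val_units.
import sqlite3


def _sketch_columns_by_name(conn, query):
    cursor = conn.execute(query)
    rows = cursor.fetchall()
    names = [d[0] for d in cursor.description]
    # Map each column name to the list of its values, in row order.
    return {name: [row[i] for row in rows] for i, name in enumerate(names)}


def _sketch_exec_match_demo():
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute("CREATE TABLE singer (name TEXT, age INTEGER)")
        conn.executemany("INSERT INTO singer VALUES (?, ?)", [("Ann", 30), ("Bo", 25)])
        pred = _sketch_columns_by_name(conn, "SELECT age, name FROM singer ORDER BY name")
        gold = _sketch_columns_by_name(conn, "SELECT name, age FROM singer ORDER BY name")
        return pred == gold
    finally:
        conn.close()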


# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
    if cond_unit is None or not DISABLE_VALUE:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    if type(val1) is not dict:
        val1 = None
    else:
        val1 = rebuild_sql_val(val1)
    if type(val2) is not dict:
        val2 = None
    else:
        val2 = rebuild_sql_val(val2)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_val(condition):
    if condition is None or not DISABLE_VALUE:
        return condition

    res = []
    for idx, it in enumerate(condition):
        if idx % 2 == 0:
            res.append(rebuild_cond_unit_val(it))
        else:
            res.append(it)
    return res


def rebuild_sql_val(sql):
    if sql is None or not DISABLE_VALUE:
        return sql

    sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
    sql['having'] = rebuild_condition_val(sql['having'])
    sql['where'] = rebuild_condition_val(sql['where'])
    sql['intersect'] = rebuild_sql_val(sql['intersect'])
    sql['except'] = rebuild_sql_val(sql['except'])
    sql['union'] = rebuild_sql_val(sql['union'])

    return sql
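

# Illustrative sketch (hypothetical helper; assumes value evaluation is
# disabled, i.e. DISABLE_VALUE = True). Literal values in a cond_unit are
# replaced by None, while nested sub-queries (dicts) are kept. As a result,
# "age > 30" and "age > 40" compare as equal. Unlike rebuild_cond_unit_val,
# this sketch does not recurse into the kept sub-queries.
def _sketch_strip_literal_values(cond_unit):
    not_op, op_id, val_unit, val1, val2 = cond_unit
    val1 = val1 if isinstance(val1, dict) else None
    val2 = val2 if isinstance(val2, dict) else None
    return not_op, op_id, val_unit, val1, val2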


# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
    col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
    prefixes = [col_id[:-2] for col_id in col_ids]
    valid_col_units = []
    for value in schema.idMap.values():
        if '.' in value and value[:value.index('.')] in prefixes:
            valid_col_units.append(value)
    return valid_col_units
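

# Illustrative sketch (hypothetical helper). build_valid_col_units keeps every
# column whose table appears in the FROM clause. Table ids look like
# "__singer__" and column ids look like "__singer.name__". Dropping the trailing
# "__" from a table id therefore yields the prefix that a column id has before
# its ".".
def _sketch_valid_columns(from_table_ids, id_map_values):
    prefixes = {table_id[:-2] for table_id in from_table_ids}
    return [v for v in id_map_values if '.' in v and v[:v.index('.')] in prefixes]


# Example:
# _sketch_valid_columns(['__singer__'],
#                       ['__all__', '__singer__', '__singer.name__', '__concert.year__'])
# returns ['__singer.name__'].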


def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
    if col_unit is None:
        return col_unit

    agg_id, col_id, distinct = col_unit
    if col_id in kmap and col_id in valid_col_units:
        col_id = kmap[col_id]
    if DISABLE_DISTINCT:
        distinct = None
    return agg_id, col_id, distinct
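

# Illustrative sketch (hypothetical helper; assumes DISABLE_DISTINCT = True). A
# column is rewritten to its foreign-key representative only when it appears in
# kmap *and* belongs to a table in the query's FROM clause. DISTINCT is
# dropped so that it does not affect matching.
def _sketch_canonical_col_unit(col_unit, kmap, valid_col_units):
    agg_id, col_id, _distinct = col_unit
    if col_id in kmap and col_id in valid_col_units:
        col_id = kmap[col_id]
    return agg_id, col_id, None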


def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
    if val_unit is None:
        return val_unit

    unit_op, col_unit1, col_unit2 = val_unit
    col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
    col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
    return unit_op, col_unit1, col_unit2


def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
    if table_unit is None:
        return table_unit

    table_type, col_unit_or_sql = table_unit
    if isinstance(col_unit_or_sql, tuple):
        col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
    return table_type, col_unit_or_sql


def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
    if cond_unit is None:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_col(valid_col_units, condition, kmap):
    for idx in range(len(condition)):
        if idx % 2 == 0:
            condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
    return condition


def rebuild_select_col(valid_col_units, sel, kmap):
    if sel is None:
        return sel
    distinct, _list = sel
    new_list = []
    for it in _list:
        agg_id, val_unit = it
        new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
    if DISABLE_DISTINCT:
        distinct = None
    return distinct, new_list


def rebuild_from_col(valid_col_units, from_, kmap):
    if from_ is None:
        return from_

    from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
    from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
    return from_


def rebuild_group_by_col(valid_col_units, group_by, kmap):
    if group_by is None:
        return group_by

    return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]


def rebuild_order_by_col(valid_col_units, order_by, kmap):
    if order_by is None or len(order_by) == 0:
        return order_by

    direction, val_units = order_by
    new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
    return direction, new_val_units


def rebuild_sql_col(valid_col_units, sql, kmap):
    if sql is None:
        return sql

    sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
    sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
    sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
    sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
    sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
    sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
    sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
    sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
    sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)

    return sql


def build_foreign_key_map(entry):
    cols_orig = entry["column_names_original"]
    tables_orig = entry["table_names_original"]

    # rebuild cols corresponding to idmap in Schema
    cols = []
    for col_orig in cols_orig:
        if col_orig[0] >= 0:
            t = tables_orig[col_orig[0]]
            c = col_orig[1]
            cols.append("__" + t.lower() + "." + c.lower() + "__")
        else:
            cols.append("__all__")

    def keyset_in_list(k1, k2, k_list):
        for k_set in k_list:
            if k1 in k_set or k2 in k_set:
                return k_set
        new_k_set = set()
        k_list.append(new_k_set)
        return new_k_set

    foreign_key_list = []
    foreign_keys = entry["foreign_keys"]
    for fkey in foreign_keys:
        key1, key2 = fkey
        key_set = keyset_in_list(key1, key2, foreign_key_list)
        key_set.add(key1)
        key_set.add(key2)

    foreign_key_map = {}
    for key_set in foreign_key_list:
        sorted_list = sorted(list(key_set))
        midx = sorted_list[0]
        for idx in sorted_list:
            foreign_key_map[cols[idx]] = cols[midx]

    return foreign_key_map
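

# Illustrative sketch (hypothetical helper; a union-find variant, not the code
# above). build_foreign_key_map groups foreign-key column indices into sets and
# maps every column in a set to the lowest-index column of that set.
# keyset_in_list adds a pair to the first set containing either key, so a pair
# that bridges two existing sets does not merge them. A union-find grouping
# merges such chains. This sketch returns index -> representative index;
# mapping the indices to column names would follow as in build_foreign_key_map.
def _sketch_fk_groups(foreign_keys):
    parent = {}

    def find(x):
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for a, b in foreign_keys:
        ra, rb = find(a), find(b)
        if ra != rb:
            # Keep the smaller index as the root, so each root is its group's minimum.
            parent[max(ra, rb)] = min(ra, rb)
    return {x: find(x) for x in list(parent)}


# Example: _sketch_fk_groups([[1, 2], [3, 4], [2, 3]]) maps 1, 2, 3 and 4 to 1.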


def build_foreign_key_map_from_json(table):
    with open(table) as f:
        data = json.load(f)
    tables = {}
    for entry in data:
        tables[entry['db_id']] = build_foreign_key_map(entry)
    return tables


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--gold', dest='gold', type=str)
    parser.add_argument('--pred', dest='pred', type=str)
    parser.add_argument('--db', dest='db', type=str)
    parser.add_argument('--table', dest='table', type=str)
    parser.add_argument('--etype', dest='etype', type=str)
    args = parser.parse_args()

    gold = args.gold
    pred = args.pred
    db_dir = args.db
    table = args.table
    etype = args.etype

    assert etype in ["all", "exec", "match"], "Unknown evaluation method"

    kmaps = build_foreign_key_map_from_json(table)

    evaluate(gold, pred, db_dir, etype, kmaps)

================================================
FILE: semparse/worlds/evaluate_spider.py
================================================
import os
import sqlite3

from semparse.worlds.evaluate import Evaluator, build_valid_col_units, rebuild_sql_val, rebuild_sql_col, \
    build_foreign_key_map_from_json
from spider_evaluation.process_sql import Schema, get_schema, get_sql

_schemas = {}
kmaps = None


def evaluate(gold, predict, db_name, db_dir, table, check_valid: bool = True) -> bool:
    global kmaps

    # try:
    evaluator = Evaluator()

    if kmaps is None:
        kmaps = build_foreign_key_map_from_json(table)

    if db_name in _schemas:
        schema = _schemas[db_name]
    else:
        db = os.path.join(db_dir, db_name, db_name + ".sqlite")
        schema = _schemas[db_name] = Schema(get_schema(db))

    g_sql = get_sql(schema, gold)

    try:
        p_sql = get_sql(schema, predict)
    except Exception as e:
        return False

    # rebuild sql for value evaluation
    kmap = kmaps[db_name]
    g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
    g_sql = rebuild_sql_val(g_sql)
    g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
    p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
    p_sql = rebuild_sql_val(p_sql)
    p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

    exact_score = evaluator.eval_exact_match(p_sql, g_sql)

    if not check_valid:
        return exact_score
    else:
        return exact_score and check_valid_sql(predict, db_name, db_dir)
    # except Exception as e:
    #     return 0


_conns = {}


def check_valid_sql(sql, db_name, db_dir, return_error=False):
    db = os.path.join(db_dir, db_name, db_name + ".sqlite")

    if db_name == 'wta_1':
        # TODO: queries on this database seem to respond very slowly, so they are
        # not executed and are treated as valid; consider appending LIMIT 1 instead.
        return True if not return_error else (True, None)

    if db_name not in _conns:
        _conns[db_name] = sqlite3.connect(db)

        # fixes an encoding bug
        _conns[db_name].text_factory = bytes

    conn = _conns[db_name]
    cursor = conn.cursor()
    try:
        cursor.execute(sql)
        cursor.fetchall()
        return True if not return_error else (True, None)
    except Exception as e:
        return False if not return_error else (False, e.args[0])
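

# Illustrative sketch (hypothetical helper). check_valid_sql treats a query as
# valid if SQLite can execute it and fetch all rows, and it can optionally
# return the error message. The same idea, catching only sqlite3.Error (the
# function above catches any Exception):
import sqlite3


def _sketch_try_execute(conn, sql):
    try:
        conn.execute(sql).fetchall()
        return True, None
    except sqlite3.Error as e:
        return False, e.args[0]


# Example: on an in-memory database with "CREATE TABLE t (a INTEGER)",
# "SELECT a FROM t" gives (True, None), while "SELECT b FROM t" gives False
# together with SQLite's "no such column" message.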


================================================
FILE: semparse/worlds/spider_world.py
================================================
from typing import List, Tuple, Dict, Set, Optional
from copy import deepcopy

from parsimonious import Grammar
from parsimonious.exceptions import ParseError

from semparse.contexts.spider_context_utils import format_grammar_string, initialize_valid_actions, SqlVisitor
from semparse.contexts.spider_db_context import SpiderDBContext
from semparse.contexts.spider_db_grammar import GRAMMAR_DICTIONARY, update_grammar_with_tables, \
    update_grammar_to_be_table_names_free, update_grammar_flip_joins


class SpiderWorld:
    """
    World representation for spider dataset.
    """

    def __init__(self, db_context: SpiderDBContext, query: Optional[List[str]], allow_alias: bool = False) -> None:
        self.db_id = db_context.db_id
        self.allow_alias = allow_alias

        # NOTE: This base dictionary should not be modified.
        self.base_grammar_dictionary = deepcopy(GRAMMAR_DICTIONARY)
        self.query = query
        self.db_context = db_context

        # keep a list of entities names as they are given in sql queries
        self.entities_names = {}
        for i, entity in enumerate(self.db_context.knowledge_graph.entities):
            parts = entity.split(':')
            if parts[0] in ['table', 'string']:
                self.entities_names[parts[1]] = i
            else:
                _, _, table_name, column_name = parts
                self.entities_names[f'{table_name}@{column_name}'] = i
        self.valid_actions = []
        self.valid_actions_flat = []

    def get_action_sequence_and_all_actions(self,
                                            allow_aliases: bool = False) -> Tuple[List[str], List[str]]:
        grammar_with_context = deepcopy(self.base_grammar_dictionary)
        if not allow_aliases:
            update_grammar_to_be_table_names_free(grammar_with_context)

        schema = self.db_context.schema

        update_grammar_with_tables(grammar_with_context, schema)
        grammar = Grammar(format_grammar_string(grammar_with_context))

        valid_actions = initialize_valid_actions(grammar)
        all_actions = set()
        for action_list in valid_actions.values():
            all_actions.update(action_list)
        sorted_actions = sorted(all_actions)
        self.valid_actions = valid_actions
        self.valid_actions_flat = sorted_actions

        action_sequence = None
        if self.query is not None:
            sql_visitor = SqlVisitor(grammar)
            query = " ".join(self.query).lower().replace("``", "'").replace("''", "'")
            try:
                action_sequence = sql_visitor.parse(query) if query else []
            except ParseError as e:
                pass

        return action_sequence, sorted_actions

    def get_all_actions(self, schema,
                        flip_joins: bool,
                        allow_aliases: bool) -> Tuple[List[str], List[str]]:
        grammar_with_context = deepcopy(self.base_grammar_dictionary)
        if not allow_aliases:
            update_grammar_to_be_table_names_free(grammar_with_context)

        if flip_joins:
            update_grammar_flip_joins(grammar_with_context)

        update_grammar_with_tables(grammar_with_context, schema, self.db_id)
        grammar = Grammar(format_grammar_string(grammar_with_context))

        valid_actions = initialize_valid_actions(grammar)
        all_actions = set()
        for action_list in valid_actions.values():
            all_actions.update(action_list)
        sorted_actions = sorted(all_actions)
        self.valid_actions = valid_actions
        self.valid_actions_flat = sorted_actions

        return sorted_actions

    def is_global_rule(self, rhs: str) -> bool:
        rhs = rhs.strip('[] ')
        if rhs[0] != '"':
            return True
        return rhs.strip('"') not in self.entities_names

    def get_oracle_relevance_score(self, oracle_entities: set):
        """
        return 0/1 for each schema item if it should be in the graph,
        given the used entities in the gold answer
        """
        scores = [0 for _ in range(len(self.db_context.knowledge_graph.entities))]

        for i, entity in enumerate(self.db_context.knowledge_graph.entities):
            parts = entity.split(':')
            if parts[0] == 'column':
                name = parts[2] + '@' + parts[3]
            else:
                name = parts[-1]
            if name in oracle_entities:
                scores[i] = 1

        return scores

    def get_action_entity_mapping(self) -> Dict[int, int]:
        mapping = {}

        for action_index, action in enumerate(self.valid_actions_flat):
            # default is padding
            mapping[action_index] = -1

            action = action.split(" -> ")[1].strip('[]')
            action_stripped = action.strip('\"')
            if action[0] != '"' or action_stripped not in self.entities_names:
                continue

            mapping[action_index] = self.entities_names[action_stripped]

        return mapping

    def get_query_without_table_hints(self):
        if not self.query:
            return ''
        toks = []
        for tok in self.query:
            if '@' in tok:
                parts = tok.split('@')
                if '.' in parts[0]:
                    toks.append(parts[0].split('.')[0] + '.' + parts[1])
                else:
                    toks.append(parts[1])
            else:
                toks.append(tok)
        return toks

    # def is_ambiguous_column(self, action_rhs: str, actions_sequence: List[str], action_index: int):
    #     """
    #     a column would be ambiguous if another table is used in the query and it has a column with the same name
    #     currently, return true only for join clauses
    #     """
    #     if actions_sequence[action_index-1].startswith('join_condition -> ') or \
    #             actions_sequence[action_index-2].startswith('join_condition -> '):
    #         return True
    #
    #     return False
    #     # column_table, column_name = action_rhs.strip('"').split('.')
    #     # tables_used = [a.split(' -> ')[1].strip('[]\"') for a in actions_sequence if a.startswith('table_name -> ')]
    #     # columns_with_same_name = set([a.split(' -> ')[1].strip('[]\"') for a in actions_sequence
    #     #                 if a.startswith('column_name -> ') and a.strip('[]\"').endswith(column_name)])
    #     # table_of_columns_with_same_name = set([c.split('.')[0] for c in columns_with_same_name])
    #     # if len(table_of_columns_with_same_name) > 1:
    #     #     return True
    #     # if column_table not in tables_used:
    #     #     return False
    #     # other_tables = [t for t in tables_used if t != column_table]
    #     # other_tables_columns = [[c.split(':')[-1] for c in self.knowledge_graph.neighbors[f'table:{t}']] for t in other_tables]
    #     # other_tables_columns_set = set([item for sublist in other_tables_columns for item in sublist])
    #     # return column_name in other_tables_columns_set
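

# Illustrative sketches (hypothetical standalone helpers, not used by
# SpiderWorld).
#
# The first helper mirrors get_action_entity_mapping and the check in
# is_global_rule. An action is tied to a schema entity only when its right-hand
# side is a quoted terminal whose text is a known entity name. Every other
# action maps to the padding index -1. Action strings are assumed to look like
# 'table_name -> ["singer"]' or 'column_name -> ["singer@name"]'.
#
# The second helper mirrors get_query_without_table_hints: "table@column"
# becomes "column", and "alias.table@column" becomes "alias.column".
def _sketch_action_entity_mapping(actions, entity_names):
    mapping = {}
    for index, action in enumerate(actions):
        rhs = action.split(" -> ")[1].strip('[]')
        name = rhs.strip('"')
        is_entity = rhs[0] == '"' and name in entity_names
        mapping[index] = entity_names[name] if is_entity else -1
    return mapping


def _sketch_strip_table_hints(tokens):
    stripped = []
    for tok in tokens:
        if '@' in tok:
            parts = tok.split('@')
            if '.' in parts[0]:
                # "t1.singer@age" -> "t1.age"
                stripped.append(parts[0].split('.')[0] + '.' + parts[1])
            else:
                # "singer@name" -> "name"
                stripped.append(parts[1])
        else:
            stripped.append(tok)
    return stripped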


================================================
FILE: spider_evaluation/evaluate.py
================================================
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################
import copy
import os
import json
import sqlite3
import argparse

from spider_evaluation.process_sql import get_schema, Schema, get_sql

# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True


CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')

WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
    'sql': "sql",
    'table_unit': "table_unit",
}

COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')


HARDNESS = {
    "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
    "component2": ('except', 'union', 'intersect')
}


def condition_has_or(conds):
    return 'or' in conds[1::2]


def condition_has_like(conds):
    return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]


def condition_has_sql(conds):
    for cond_unit in conds[::2]:
        val1, val2 = cond_unit[3], cond_unit[4]
        if val1 is not None and type(val1) is dict:
            return True
        if val2 is not None and type(val2) is dict:
            return True
    return False


def val_has_op(val_unit):
    return val_unit[0] != UNIT_OPS.index('none')


def has_agg(unit):
    return unit[0] != AGG_OPS.index('none')


def accuracy(count, total):
    if count == total:
        return 1
    return 0


def recall(count, total):
    if count == total:
        return 1
    return 0


def F1(acc, rec):
    if (acc + rec) == 0:
        return 0
    return (2. * acc * rec) / (acc + rec)


def get_scores(count, pred_total, label_total):
    if pred_total != label_total:
        return 0, 0, 0
    elif count == pred_total:
        return 1, 1, 1
    return 0, 0, 0


def eval_sel(pred, label):
    pred_sel = pred['select'][1]
    # Copy the list so that removing matched items does not mutate the gold SQL.
    label_sel = list(label['select'][1])
    label_wo_agg = [unit[1] for unit in label_sel]
    pred_total = len(pred_sel)
    label_total = len(label_sel)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_sel:
        if unit in label_sel:
            cnt += 1
            label_sel.remove(unit)
        if unit[1] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[1])

    return label_total, pred_total, cnt, cnt_wo_agg
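

# Illustrative sketch (hypothetical helper). It uses collections.Counter instead
# of the list.remove loop above. eval_sel counts the predicted SELECT items that
# each match a distinct gold item, which treats repeated items as a multiset.
# The size of the Counter intersection gives the same count without touching
# the gold list. Items must be hashable; the parsed tuples are.
from collections import Counter


def _sketch_multiset_matches(pred_items, gold_items):
    return sum((Counter(pred_items) & Counter(gold_items)).values())


# Example: pred [(0, 'a'), (0, 'a'), (3, 'b')] against
# gold [(0, 'a'), (3, 'b'), (3, 'b')] gives 2 matches.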


def eval_where(pred, label):
    pred_conds = [unit for unit in pred['where'][::2]]
    label_conds = [unit for unit in label['where'][::2]]
    label_wo_agg = [unit[2] for unit in label_conds]
    pred_total = len(pred_conds)
    label_total = len(label_conds)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_conds:
        if unit in label_conds:
            cnt += 1
            label_conds.remove(unit)
        if unit[2] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[2])

    return label_total, pred_total, cnt, cnt_wo_agg


def eval_group(pred, label):
    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    pred_total = len(pred_cols)
    label_total = len(label_cols)
    cnt = 0
    pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
    label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
    for col in pred_cols:
        if col in label_cols:
            cnt += 1
            label_cols.remove(col)
    return label_total, pred_total, cnt


def eval_having(pred, label):
    pred_total = label_total = cnt = 0
    if len(pred['groupBy']) > 0:
        pred_total = 1
    if len(label['groupBy']) > 0:
        label_total = 1

    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    if pred_total == label_total == 1 \
            and pred_cols == label_cols \
            and pred['having'] == label['having']:
        cnt = 1

    return label_total, pred_total, cnt


def eval_order(pred, label):
    pred_total = label_total = cnt = 0
    if len(pred['orderBy']) > 0:
        pred_total = 1
    if len(label['orderBy']) > 0:
        label_total = 1
    if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
            ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
        cnt = 1
    return label_total, pred_total, cnt


def eval_and_or(pred, label):
    pred_ao = pred['where'][1::2]
    label_ao = label['where'][1::2]
    pred_ao = set(pred_ao)
    label_ao = set(label_ao)

    if pred_ao == label_ao:
        return 1, 1, 1
    return len(label_ao), len(pred_ao), 0


def get_nestedSQL(sql):
    nested = []
    for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
        if type(cond_unit[3]) is dict:
            nested.append(cond_unit[3])
        if type(cond_unit[4]) is dict:
            nested.append(cond_unit[4])
    if sql['intersect'] is not None:
        nested.append(sql['intersect'])
    if sql['except'] is not None:
        nested.append(sql['except'])
    if sql['union'] is not None:
        nested.append(sql['union'])
    return nested


def eval_nested(pred, label):
    label_total = 0
    pred_total = 0
    cnt = 0
    if pred is not None:
        pred_total += 1
    if label is not None:
        label_total += 1
    if pred is not None and label is not None:
        cnt += Evaluator().eval_exact_match(pred, label)
    return label_total, pred_total, cnt


def eval_IUEN(pred, label):
    lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
    lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
    lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
    label_total = lt1 + lt2 + lt3
    pred_total = pt1 + pt2 + pt3
    cnt = cnt1 + cnt2 + cnt3
    return label_total, pred_total, cnt


def get_keywords(sql):
    res = set()
    if len(sql['where']) > 0:
        res.add('where')
    if len(sql['groupBy']) > 0:
        res.add('group')
    if len(sql['having']) > 0:
        res.add('having')
    if len(sql['orderBy']) > 0:
        res.add(sql['orderBy'][0])
        res.add('order')
    if sql['limit'] is not None:
        res.add('limit')
    if sql['except'] is not None:
        res.add('except')
    if sql['union'] is not None:
        res.add('union')
    if sql['intersect'] is not None:
        res.add('intersect')

    # or keyword
    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    if len([token for token in ao if token == 'or']) > 0:
        res.add('or')

    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    # not keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
        res.add('not')

    # in keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
        res.add('in')

    # like keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
        res.add('like')

    return res


def eval_keywords(pred, label):
    pred_keywords = get_keywords(pred)
    label_keywords = get_keywords(label)
    pred_total = len(pred_keywords)
    label_total = len(label_keywords)
    cnt = 0

    for k in pred_keywords:
        if k in label_keywords:
            cnt += 1
    return label_total, pred_total, cnt


def count_agg(units):
    return len([unit for unit in units if has_agg(unit)])


def count_component1(sql):
    count = 0
    if len(sql['where']) > 0:
        count += 1
    if len(sql['groupBy']) > 0:
        count += 1
    if len(sql['orderBy']) > 0:
        count += 1
    if sql['limit'] is not None:
        count += 1
    if len(sql['from']['table_units']) > 0:  # JOIN
        count += len(sql['from']['table_units']) - 1

    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    count += len([token for token in ao if token == 'or'])
    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])

    return count


def count_component2(sql):
    nested = get_nestedSQL(sql)
    return len(nested)


def count_others(sql):
    count = 0
    # number of aggregation
    agg_count = count_agg(sql['select'][1])
    agg_count += count_agg(sql['where'][::2])
    agg_count += count_agg(sql['groupBy'])
    if len(sql['orderBy']) > 0:
        agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
                            [unit[2] for unit in sql['orderBy'][1] if unit[2]])
    agg_count += count_agg(sql['having'])
    if agg_count > 1:
        count += 1

    # number of select columns
    if len(sql['select'][1]) > 1:
        count += 1

    # number of where conditions
    if len(sql['where']) > 1:
        count += 1

    # number of group by clauses
    if len(sql['groupBy']) > 1:
        count += 1

    return count
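

# Illustrative sketch (hypothetical helper). Evaluator.eval_hardness depends
# only on three counts:
# - component1: WHERE, GROUP BY, ORDER BY, LIMIT, joins, OR and LIKE.
# - component2: nested queries and set operations.
# - others: more than one aggregation, SELECT column, WHERE entry or GROUP BY
#   column.
# This helper applies the same decision rule to plain integers.
def _sketch_hardness(comp1, others, comp2):
    if comp1 <= 1 and others == 0 and comp2 == 0:
        return "easy"
    if (others <= 2 and comp1 <= 1 and comp2 == 0) or \
            (comp1 <= 2 and others < 2 and comp2 == 0):
        return "medium"
    if (others > 2 and comp1 <= 2 and comp2 == 0) or \
            (2 < comp1 <= 3 and others <= 2 and comp2 == 0) or \
            (comp1 <= 1 and others == 0 and comp2 <= 1):
        return "hard"
    return "extra"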


class Evaluator:
    """A simple evaluator"""
    def __init__(self):
        self.partial_scores = None

    def eval_hardness(self, sql):
        count_comp1_ = count_component1(sql)
        count_comp2_ = count_component2(sql)
        count_others_ = count_others(sql)

        if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
            return "easy"
        elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
                (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
            return "medium"
        elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
                (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
                (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
            return "hard"
        else:
            return "extra"

    def eval_exact_match(self, pred, label):
        partial_scores = self.eval_partial_match(pred, label)
        self.partial_scores = partial_scores

        for _, score in partial_scores.items():
            if score['f1'] != 1:
                return 0
        if len(label['from']['table_units']) > 0:
            label_tables = sorted(label['from']['table_units'])
            pred_tables = sorted(pred['from']['table_units'])
            if label_tables != pred_tables:
                return 0
        if len(label['from']['conds']) > 0:
            label_joins = sorted(label['from']['conds'], key=lambda x: str(x))
            pred_joins = sorted(pred['from']['conds'], key=lambda x: str(x))
            if label_joins != pred_joins:
                return 0
        return 1

    def eval_partial_match(self, pred, label):
        def component_score(cnt, pred_total, label_total):
            # Bundle accuracy/recall/F1 with the totals that evaluate() uses
            # to decide whether this component counts toward the averages
            acc, rec, f1 = get_scores(cnt, pred_total, label_total)
            return {'acc': acc, 'rec': rec, 'f1': f1,
                    'label_total': label_total, 'pred_total': pred_total}

        res = {}

        label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
        res['select'] = component_score(cnt, pred_total, label_total)
        res['select(no AGG)'] = component_score(cnt_wo_agg, pred_total, label_total)

        label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
        res['where'] = component_score(cnt, pred_total, label_total)
        res['where(no OP)'] = component_score(cnt_wo_agg, pred_total, label_total)

        label_total, pred_total, cnt = eval_group(pred, label)
        res['group(no Having)'] = component_score(cnt, pred_total, label_total)

        label_total, pred_total, cnt = eval_having(pred, label)
        res['group'] = component_score(cnt, pred_total, label_total)

        label_total, pred_total, cnt = eval_order(pred, label)
        res['order'] = component_score(cnt, pred_total, label_total)

        label_total, pred_total, cnt = eval_and_or(pred, label)
        res['and/or'] = component_score(cnt, pred_total, label_total)

        label_total, pred_total, cnt = eval_IUEN(pred, label)
        res['IUEN'] = component_score(cnt, pred_total, label_total)

        label_total, pred_total, cnt = eval_keywords(pred, label)
        res['keywords'] = component_score(cnt, pred_total, label_total)

        return res


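def isValidSQL(sql, db):
    """Return True if `sql` executes without error against the SQLite database file `db`."""
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    try:
        cursor.execute(sql, [])
    except Exception:
        return False
    finally:
        conn.close()
    return True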

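# Sketch (hypothetical helper, not part of the Spider evaluation scripts): the
# same validity check as isValidSQL, but taking an already open sqlite3
# connection so it can be tried against an in-memory database. It only tells
# you whether SQLite accepts and runs the statement, not whether the result
# is correct.
import sqlite3


def _sketch_is_valid_sql(sql, conn):
    """Return True if `sql` executes on the hypothetical connection `conn` without raising."""
    try:
        conn.execute(sql)
    except (sqlite3.Error, sqlite3.Warning):
        # sqlite3.Warning is caught too: some Python versions raise it for
        # strings that contain more than one statement
        return False
    return True

# Usage:
#   conn = sqlite3.connect(":memory:")
#   conn.execute("CREATE TABLE singer (name TEXT, age INTEGER)")
#   _sketch_is_valid_sql("SELECT name FROM singer", conn)  # True
#   _sketch_is_valid_sql("SELECT nme FROM singer", conn)   # False (no such column)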

def print_scores(scores, etype):
    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']

    print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
    counts = [scores[level]['count'] for level in levels]
    print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))

    if etype in ["all", "exec"]:
        print('=====================   EXECUTION ACCURACY     =====================')
        this_scores = [scores[level]['exec'] for level in levels]
        print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))

    if etype in ["all", "match"]:
        print('\n====================== EXACT MATCHING ACCURACY =====================')
        exact_scores = [scores[level]['exact'] for level in levels]
        print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
        print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

        print('---------------------- PARTIAL MATCHING RECALL ----------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

        print('---------------------- PARTIAL MATCHING F1 --------------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))


def evaluate(gold, predict, db_dir, etype, kmaps):
    with open(gold) as f:
        glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]

    with open(predict) as f:
        plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
    evaluator = Evaluator()

    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']
    entries = []
    scores = {}

    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}

    eval_err_num = 0
    for p, g in zip(plist, glist):
        p_str = p[0]
        g_str, db = g
        db_name = db
        db = os.path.join(db_dir, db, db + ".sqlite")
        schema = Schema(get_schema(db))
        g_sql = get_sql(schema, g_str)
        hardness = evaluator.eval_hardness(g_sql)
        scores[hardness]['count'] += 1
        scores['all']['count'] += 1

        try:
            p_sql = get_sql(schema, p_str)
        except Exception:
            # If the prediction cannot be parsed, score an empty SQL structure
            # against the gold query instead
            p_sql = {
                "except": None,
                "from": {
                    "conds": [],
                    "table_units": []
                },
                "groupBy": [],
                "having": [],
                "intersect": None,
                "limit": None,
                "orderBy": [],
                "select": [
                    False,
                    []
                ],
                "union": None,
                "where": []
            }
            eval_err_num += 1
            print("eval_err_num:{}".format(eval_err_num))
            print(p_str)
            print()

        # rebuild sql for value evaluation
        kmap = kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)


        if etype in ["all", "exec"]:
            exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
            if exec_score:
                scores[hardness]['exec'] += 1
            scores['all']['exec'] += exec_score

        if etype in ["all", "match"]:
            exact_score = evaluator.eval_exact_match(p_sql, g_sql)
            partial_scores = evaluator.partial_scores
            if exact_score == 0:
                print("{} pred: {}".format(hardness,p_str))
                print("{} gold: {}".format(hardness,g_str))
                print("")
            scores[hardness]['exact'] += exact_score
            scores['all']['exact'] += exact_score
            for type_ in partial_types:
                if partial_scores[type_]['pred_total'] > 0:
                    scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
                    scores[hardness]['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
                    scores[hardness]['partial'][type_]['rec_count'] += 1
                scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
                if partial_scores[type_]['pred_total'] > 0:
                    scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
                    scores['all']['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
                    scores['all']['partial'][type_]['rec_count'] += 1
                scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']

            entries.append({
                'predictSQL': p_str,
                'goldSQL': g_str,
                'hardness': hardness,
                'exact': exact_score,
                'partial': partial_scores
            })

    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']

        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                if scores[level]['partial'][type_]['acc_count'] == 0:
                    scores[level]['partial'][type_]['acc'] = 0
                else:
                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
                                                             scores[level]['partial'][type_]['acc_count'] * 1.0
                if scores[level]['partial'][type_]['rec_count'] == 0:
                    scores[level]['partial'][type_]['rec'] = 0
                else:
                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
                                                             scores[level]['partial'][type_]['rec_count'] * 1.0
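                if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
                    # Note: F1 is reported as 1 (not 0) when both averaged accuracy
                    # and recall are 0, which includes components that never occur
                    # at this level (acc_count == rec_count == 0)
                    scores[level]['partial'][type_]['f1'] = 1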
                else:
                    scores[level]['partial'][type_]['f1'] = \
                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
                        scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])

    print_scores(scores, etype)

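# Sketch (hypothetical helper) of how evaluate() turns per-example partial
# scores into one number per hardness level. Accuracy is averaged only over
# examples whose prediction contains the component (pred_total > 0), recall
# only over examples whose gold query contains it (label_total > 0), and F1 is
# the harmonic mean of those two averages. The "both zero -> F1 = 1" branch
# mirrors the code above as written.


def _sketch_aggregate_partial(per_example):
    """per_example: list of dicts with 'acc', 'rec', 'pred_total', 'label_total'."""
    acc_vals = [e['acc'] for e in per_example if e['pred_total'] > 0]
    rec_vals = [e['rec'] for e in per_example if e['label_total'] > 0]
    acc = sum(acc_vals) / len(acc_vals) if acc_vals else 0
    rec = sum(rec_vals) / len(rec_vals) if rec_vals else 0
    if acc == 0 and rec == 0:
        f1 = 1  # same special case as in evaluate()
    else:
        f1 = 2.0 * acc * rec / (acc + rec)
    return acc, rec, f1

# Usage: an example whose gold query lacks the component still counts toward
# accuracy if the prediction contains it, but never toward recall.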

def eval_exec_match(db, p_str, g_str, pred, gold):
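    """
    Return True if, for every SELECT column, the prediction and the gold query
    return the same values in the same row order. Columns are matched by their
    val_unit rather than by position. Multiple col_unit pairs are not currently
    supported.
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    conn.text_factory = bytes  # return TEXT values as bytes to avoid decoding errors
    try:
        cursor.execute(p_str)
        p_res = cursor.fetchall()
    except Exception:
        # A prediction that fails to execute is counted as a mismatch
        conn.close()
        return False

    cursor.execute(g_str)
    q_res = cursor.fetchall()
    conn.close()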

    def res_map(res, val_units):
        rmap = {}
        for idx, val_unit in enumerate(val_units):
            key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
            rmap[key] = [r[idx] for r in res]
        return rmap

    p_val_units = [unit[1] for unit in pred['select'][1]]
    q_val_units = [unit[1] for unit in gold['select'][1]]
    return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)

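# Sketch (hypothetical helper) of the comparison inside eval_exec_match: each
# result set becomes a dict from a column key to that column's values, and the
# two dicts are compared. Plain strings stand in here for the val_unit-derived
# keys that res_map builds. Consequences: swapping SELECT columns does not
# matter, but row order does, because each column is kept as an ordered list.


def _sketch_columns_by_key(rows, keys):
    """Map each key to the list of values in the corresponding column of `rows`."""
    return {key: [row[idx] for row in rows] for idx, key in enumerate(keys)}

# Usage:
#   _sketch_columns_by_key([(1, 'a'), (2, 'b')], ['id', 'name'])
#   == _sketch_columns_by_key([('a', 1), ('b', 2)], ['name', 'id'])  # True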

# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
    if cond_unit is None or not DISABLE_VALUE:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    if type(val1) is not dict:
        val1 = None
    else:
        val1 = rebuild_sql_val(val1)
    if type(val2) is not dict:
        val2 = None
    else:
        val2 = rebuild_sql_val(val2)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_val(condition):
    if condition is None or not DISABLE_VALUE:
        return condition

    res = []
    for idx, it in enumerate(condition):
        if idx % 2 == 0:
            res.append(rebuild_cond_unit_val(it))
        else:
            res.append(it)
    return res


def rebuild_sql_val(sql):
    if sql is None or not DISABLE_VALUE:
        return sql

    sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
    sql['having'] = rebuild_condition_val(sql['having'])
    sql['where'] = rebuild_condition_val(sql['where'])
    sql['intersect'] = rebuild_sql_val(sql['intersect'])
    sql['except'] = rebuild_sql_val(sql['except'])
    sql['union'] = rebuild_sql_val(sql['union'])

    return sql

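# Sketch (hypothetical helper) of what rebuild_condition_val does when
# DISABLE_VALUE is set: a condition alternates cond_units and 'and'/'or'
# connectors, and in each cond_unit literal values are replaced with None so
# that only the query structure is compared. Subqueries (dict values) are
# kept as-is here; the real code additionally recurses into them through
# rebuild_sql_val.


def _sketch_strip_condition_values(condition):
    stripped = []
    for idx, item in enumerate(condition):
        if idx % 2 == 1:  # 'and' / 'or' connector
            stripped.append(item)
            continue
        not_op, op_id, val_unit, val1, val2 = item
        val1 = val1 if isinstance(val1, dict) else None
        val2 = val2 if isinstance(val2, dict) else None
        stripped.append((not_op, op_id, val_unit, val1, val2))
    return stripped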

# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
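    # Table ids look like "__singer__"; dropping the trailing "__" gives the
    # prefix "__singer" shared by that table's column ids ("__singer.name__")
    table_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
    prefixes = [table_id[:-2] for table_id in table_ids]
    valid_col_units = []
    for value in schema.idMap.values():
        if '.' in value and value[:value.index('.')] in prefixes:
            valid_col_units.append(value)
    return valid_col_units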

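# Sketch (hypothetical helper) of the prefix test in build_valid_col_units,
# with plain lists standing in for the FROM table units and the
# schema.idMap values: "__singer__" becomes the prefix "__singer", and a
# column id is kept when the part before its '.' equals one of the prefixes.


def _sketch_valid_columns(table_ids, column_ids):
    prefixes = {table_id[:-2] for table_id in table_ids}
    return [col for col in column_ids
            if '.' in col and col[:col.index('.')] in prefixes]

# Usage: only columns of tables that appear in FROM are eligible for the
# foreign-key rewriting done by rebuild_col_unit_col.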

def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
    if col_unit is None:
        return col_unit

    agg_id, col_id, distinct = col_unit
    if col_id in kmap and col_id in valid_col_units:
        col_id = kmap[col_id]
    if DISABLE_DISTINCT:
        distinct = None
    return agg_id, col_id, distinct


def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
    if val_unit is None:
        return val_unit

    unit_op, col_unit1, col_unit2 = val_unit
    col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
    col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
    return unit_op, col_unit1, col_unit2


def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
    if table_unit is None:
        return table_unit

    table_type, col_unit_or_sql = table_unit
    if isinstance(col_unit_or_sql, tuple):
        col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
    return table_type, col_unit_or_sql


def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
    if cond_unit is None:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_col(valid_col_units, condition, kmap):
    for idx in range(len(condition)):
        if idx % 2 == 0:
            condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
    return condition


def rebuild_select_col(valid_col_units, sel, kmap):
    if sel is None:
        return sel
    distinct, _list = sel
    new_list = []
    for it in _list:
        agg_id, val_unit = it
        new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
    if DISABLE_DISTINCT:
        distinct = None
    return distinct, new_list


def rebuild_from_col(valid_col_units, from_, kmap):
    if from_ is None:
        return from_

    from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
    from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
    return from_


def rebuild_group_by_col(valid_col_units, group_by, kmap):
    if group_by is None:
        return group_by

    return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]


def rebuild_order_by_col(valid_col_units, order_by, kmap):
    if order_by is None or len(order_by) == 0:
        return order_by

    direction, val_units = order_by
    new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
    return direction, new_val_units


def rebuild_sql_col(valid_col_units, sql, kmap):
    if sql is None:
        return sql

    sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
    sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
    sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
    sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
    sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
    sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
    sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
    sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
    sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)

    return sql


def build_foreign_key_map(entry):
    cols_orig = entry["column_names_original"]
    tables_orig = entry["table_names_original"]

    # rebuild cols corresponding to idmap in Schema
    cols = []
    for col_orig in cols_orig:
        if col_orig[0] >= 0:
            t = tables_orig[col_orig[0]]
            c = col_orig[1]
            cols.append("__" + t.lower() + "." + c.lower() + "__")
        else:
            cols.append("__all__")

    def keyset_in_list(k1, k2, k_list):
        for k_set in k_list:
            if k1 in k_set or k2 in k_set:
                return k_set
        new_k_set = set()
        k_list.append(new_k_set)
        return new_k_set

    foreign_key_list = []
    foreign_keys = entry["foreign_keys"]
    for fkey in foreign_keys:
        key1, key2 = fkey
        key_set = keyset_in_list(key1, key2, foreign_key_list)
        key_set.add(key1)
        key_set.add(key2)

    foreign_key_map = {}
    for key_set in foreign_key_list:
        sorted_list = sorted(list(key_set))
        midx = sorted_list[0]
        for idx in sorted_list:
            foreign_key_map[cols[idx]] = cols[midx]

    return foreign_key_map


def build_foreign_key_map_from_json(table):
    with open(table) as f:
        data = json.load(f)
    tables = {}
    for entry in data:
        tables[entry['db_id']] = build_foreign_key_map(entry)
    return tables

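# Sketch (hypothetical helper) of the grouping done by build_foreign_key_map,
# using a plain list of column ids instead of a tables.json entry. Each
# foreign-key pair is added to the first existing key set that already holds
# either column (or to a new set), and every column in a set is mapped to the
# column with the smallest index. rebuild_col_unit_col later uses this map to
# rewrite linked columns to that shared id.
# Note, as in the code above: a pair joins only the first matching set, so two
# existing sets bridged by a later pair are not merged.


def _sketch_foreign_key_map(cols, foreign_keys):
    key_sets = []
    for key1, key2 in foreign_keys:
        target = next((s for s in key_sets if key1 in s or key2 in s), None)
        if target is None:
            target = set()
            key_sets.append(target)
        target.update((key1, key2))
    fk_map = {}
    for key_set in key_sets:
        ordered = sorted(key_set)
        for idx in ordered:
            fk_map[cols[idx]] = cols[ordered[0]]
    return fk_map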

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--gold', dest='gold', type=str)
    parser.add_argument('--pred', dest='pred', type=str)
    parser.add_argument('--db', dest='db', type=str)
    parser.add_argument('--table', dest='table', type=str)
    parser.add_argument('--etype', dest='etype', type=str)
    args = parser.parse_args()

    gold = args.gold
    pred = args.pred
    db_dir = args.db
    table = args.table
    etype = args.etype

    assert etype in ["all", "exec", "match"], "Unknown evaluation method"

    kmaps = build_foreign_key_map_from_json(table)

    evaluate(gold, pred, db_dir, etype, kmaps)

================================================
FILE: spider_evaluation/process_sql.py
================================================
################################
# Assumptions:
#   1. the SQL is well-formed
#   2. only tables (not columns) have aliases
#   3. at most one INTERSECT/UNION/EXCEPT
#
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################

import json
import sqlite3
from nltk import word_tokenize

CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')

WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
    'sql': "sql",
    'table_unit': "table_unit",
}

COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
mapped_entities = []

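# Sketch (hypothetical, written by hand rather than produced by running the
# parser): the structure described in the header comment for
#   SELECT count(*) FROM singer WHERE age > 20
# as we read parse_sql, assuming a schema where table "singer" has a column
# "age" and that the set-operation keys are None when absent. The helper turns
# a cond_unit back into text by indexing operator tuples copied from
# WHERE_OPS and UNIT_OPS above, so the sketch is self-contained.

_SKETCH_PARSED_SQL = {
    'select': (False, [(3, (0, (0, '__all__', False), None))]),  # agg 3 = count
    'from': {'table_units': [('table_unit', '__singer__')], 'conds': []},
    'where': [(False, 3, (0, (0, '__singer.age__', False), None), 20.0, None)],  # op 3 = '>'
    'groupBy': [],
    'having': [],
    'orderBy': [],
    'limit': None,
    'intersect': None,
    'union': None,
    'except': None,
}


def _sketch_describe_cond_unit(cond_unit):
    where_ops = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
    unit_ops = ('none', '-', '+', '*', '/')
    not_op, op_id, val_unit, val1, val2 = cond_unit
    unit_op, col_unit1, col_unit2 = val_unit
    lhs = col_unit1[1]  # col_unit = (agg_id, col_id, isDistinct)
    if unit_op != 0:
        lhs = '{} {} {}'.format(lhs, unit_ops[unit_op], col_unit2[1])
    text = '{}{} {} {}'.format('not ' if not_op else '', lhs, where_ops[op_id], val1)
    if val2 is not None:  # BETWEEN carries a second value
        text += ' and {}'.format(val2)
    return text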

class Schema:
    """
    Simple schema that maps each table and column name to a unique identifier,
    e.g. "singer.name" -> "__singer.name__" and "singer" -> "__singer__"
    """
    def __init__(self, schema):
        self._schema = schema
        self._idMap = self._map(self._schema)

    @property
    def schema(self):
        return self._schema

    @property
    def idMap(self):
        return self._idMap

    def _map(self, schema):
        idMap = {'*': "__all__"}
        # "table.column" -> "__table.column__"
        for key, vals in schema.items():
            for val in vals:
                idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"

        # "table" -> "__table__"
        for key in schema:
            idMap[key.lower()] = "__" + key.lower() + "__"

        return idMap


def get_schema(db):
    """
    Get database's schema, which is a dict with table name as key
    and list of column names as value
    :param db: database path
    :return: schema dict
    """

    schema = {}
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    # fetch table names
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [str(table[0].lower()) for table in cursor.fetchall()]

    # fetch table info
    for table in tables:
        cursor.execute("PRAGMA table_info({})".format(table))
        schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]

    conn.close()
    return schema

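# Sketch (hypothetical helper): the same two catalog queries as get_schema,
# run against an already open connection so they can be tried on an in-memory
# database. Table and column names are lowercased, as get_schema does.
import sqlite3


def _sketch_get_schema_from_conn(conn):
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [str(row[0].lower()) for row in cursor.fetchall()]
    schema = {}
    for table in tables:
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk)
        cursor.execute("PRAGMA table_info({})".format(table))
        schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
    return schema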

def get_schema_from_json(fpath):
    with open(fpath) as f:
        data = json.load(f)

    schema = {}
    for entry in data:
        table = str(entry['table'].lower())
        cols = [str(col['column_name'].lower()) for col in entry['col_data']]
        schema[table] = cols

    return schema


def tokenize(string):
    string = str(string)
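    # Normalize single quotes to double quotes so every string literal is
    # wrapped in "". Caveat: a literal containing an apostrophe (e.g. "O'Brien")
    # ends up with an odd number of quotes and trips the assertion below.
    string = string.replace("\'", "\"")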
    quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
    assert len(quote_idxs) % 2 == 0, "Unexpected quote"

    # keep string value as token
    vals = {}
    for i in range(len(quote_idxs)-1, -1, -2):
        qidx1 = quote_idxs[i-1]
        qidx2 = quote_idxs[i]
        val = string[qidx1: qidx2+1]
        key = "__val_{}_{}__".format(qidx1, qidx2)
        string = string[:qidx1] + key + string[qidx2+1:]
        vals[key] = val

    toks = [word.lower() for word in word_tokenize(string)]
    # replace with string value token
    for i in range(len(toks)):
        if toks[i] in vals:
            toks[i] = vals[toks[i]]

    # find if there exists !=, >=, <=
    eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
    eq_idxs.reverse()
    prefix = ('!', '>', '<')
    for eq_idx in eq_idxs:
        pre_tok = toks[eq_idx-1]
        if pre_tok in prefix:
            toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ]

    return toks

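# Sketch (hypothetical helpers) of the two steps around word_tokenize in
# tokenize(), without the NLTK call itself:
#  1. every double-quoted literal is swapped for a placeholder
#     "__val_<start>_<end>__" so the tokenizer cannot split it; quote pairs are
#     processed right to left so earlier indices stay valid;
#  2. after tokenization, a '=' that follows '!', '>' or '<' is merged with it
#     into a single '!=', '>=' or '<=' token.
# Unlike tokenize(), the merge adds an eq_idx > 0 guard so a leading '=' is
# never paired with the last token.


def _sketch_mask_literals(string):
    quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
    assert len(quote_idxs) % 2 == 0, "Unexpected quote"
    vals = {}
    for i in range(len(quote_idxs) - 1, -1, -2):
        start, end = quote_idxs[i - 1], quote_idxs[i]
        key = "__val_{}_{}__".format(start, end)
        vals[key] = string[start:end + 1]
        string = string[:start] + key + string[end + 1:]
    return string, vals


def _sketch_merge_comparison_ops(toks):
    # Walk '=' positions from the right so merges do not shift earlier indices
    for eq_idx in reversed([idx for idx, tok in enumerate(toks) if tok == "="]):
        if eq_idx > 0 and toks[eq_idx - 1] in ('!', '>', '<'):
            toks = toks[:eq_idx - 1] + [toks[eq_idx - 1] + "="] + toks[eq_idx + 1:]
    return toks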

def scan_alias(toks):
    """Scan the index of 'as' and build the map for all alias"""
    as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
    alias = {}
    for idx in as_idxs:
        alias[toks[idx+1]] = toks[idx-1]
    return alias


def get_tables_with_alias(schema, toks):
    tables = scan_alias(toks)
    for key in schema:
        assert key not in tables, "Alias {} has the same name as a table".format(key)
        tables[key] = key
    return tables

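# Sketch (hypothetical helper) combining scan_alias and get_tables_with_alias:
# every "<table> as <alias>" in the token list adds alias -> table, and every
# real table name maps to itself, so later lookups such as "t1.name" work for
# either form. Like scan_alias, it assumes each 'as' has a token on both sides
# and, per the header assumptions, that only tables are aliased.


def _sketch_alias_table(toks, table_names):
    tables = {toks[idx + 1]: toks[idx - 1] for idx, tok in enumerate(toks) if tok == 'as'}
    for name in table_names:
        assert name not in tables, "Alias {} has the same name as a table".format(name)
        tables[name] = name
    return tables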

def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
        :returns next idx, column id
    """
    global mapped_entities
    tok = toks[start_idx]
    if tok == "*":
        return start_idx + 1, schema.idMap[tok]

    if '.' in tok:  # if token is a composite
        alias, col = tok.split('.')
        key = tables_with_alias[alias] + "." + col
        mapped_entities.append((start_idx, tables_with_alias[alias] + "@" + col))
        return start_idx+1, schema.idMap[key]

    assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"

    for alias in default_tables:
        table = tables_with_alias[alias]
        if tok in schema.schema[table]:
            key = table + "." + tok
            mapped_entities.append((start_idx, table + "@" + tok))
            return start_idx+1, schema.idMap[key]

    assert False, "Error col: {}".format(tok)

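# Sketch (hypothetical helper) of how parse_col resolves a column token, with
# a plain dict {table: [columns]} in place of Schema: "*" maps to "__all__",
# "alias.col" is resolved through the alias table, and a bare column is looked
# up in the FROM-clause tables in order, so when several tables have the
# column the first one listed wins. Unlike parse_col, it does not check the
# qualified form against schema.idMap.


def _sketch_resolve_column(tok, tables_with_alias, schema, default_tables):
    if tok == "*":
        return "__all__"
    if '.' in tok:
        alias, col = tok.split('.')
        return "__{}.{}__".format(tables_with_alias[alias], col)
    for alias in default_tables:
        table = tables_with_alias[alias]
        if tok in schema[table]:
            return "__{}.{}__".format(table, tok)
    raise ValueError("Error col: {}".format(tok))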

def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
        :returns next idx, (agg_op id, col_id, isDistinct)
    """
    idx = start_idx
    len_ = len(toks)
    isBlock = False
    isDistinct = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    if toks[idx] in AGG_OPS:
        agg_id = AGG_OPS.index(toks[idx])
        idx += 1
        assert idx < len_ and toks[idx] == '('
        idx += 1
        if toks[idx] == "distinct":
            idx += 1
            isDistinct = True
        idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
        assert idx < len_ and toks[idx] == ')'
        idx += 1
        return idx, (agg_id, col_id, isDistinct)

    if toks[idx] == "distinct":
        idx += 1
        isDistinct = True
    agg_id = AGG_OPS.index("none")
    idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)

    if isBlock:
        assert toks[idx] == ')'
        idx += 1  # skip ')'

    return idx, (agg_id, col_id, isDistinct)


def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)
    isBlock = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    col_unit1 = None
    col_unit2 = None
    unit_op = UNIT_OPS.index('none')

    idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
    if idx < len_ and toks[idx] in UNIT_OPS:
        unit_op = UNIT_OPS.index(toks[idx])
        idx += 1
        idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)

    if isBlock:
        assert toks[idx] == ')'
        idx += 1  # skip ')'

    return idx, (unit_op, col_unit1, col_unit2)


def parse_table_unit(toks, start_idx, tables_with_alias, schema):
    """
        :returns next idx, table id, table name
    """
    idx = start_idx
    len_ = len(toks)
    key = tables_with_alias[toks[idx]]

    if idx + 1 < len_ and toks[idx+1] == "as":
        idx += 3
    else:
        idx += 1

    return idx, schema.idMap[key], key


def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)

    isBlock = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    if toks[idx] == 'select':
        idx, val = parse_sql(toks, idx, tables_with_alias, schema)
    elif "\"" in toks[idx]:  # token is a string value
        val = toks[idx]
        idx += 1
    else:
        try:
            val = float(toks[idx])
            idx += 1
        except ValueError:
            # Not a number: treat the tokens up to the next delimiter as a column reference
            end_idx = idx
            while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')' \
                    and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS \
                    and toks[end_idx] not in JOIN_KEYWORDS:
                end_idx += 1

            idx, val = parse_col_unit(toks[: end_idx], start_idx, tables_with_alias, schema, default_tables)
            idx = end_idx

    if isBlock:
        assert toks[idx] == ')'
        idx += 1

    return idx, val


def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)
    conds = []

    while idx < len_:
        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        not_op = False
        if toks[idx] == 'not':
            not_op = True
            idx += 1

        assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
        op_id = WHERE_OPS.index(toks[idx])
        idx += 1
        val1 = val2 = None
        if op_id == WHERE_OPS.index('between'):  # between..and... special case: dual values
            idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
            assert toks[idx] == 'and'
            idx += 1
            idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
        else:  # normal case: single value
            idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
            val2 = None

        conds.append((not_op, op_id, val_unit, val1, val2))

        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
            break

        if idx < len_ and toks[idx] in COND_OPS:
            conds.append(toks[idx])
            idx += 1  # skip and/or

    return idx, conds

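# Sketch (hypothetical helper): parse_condition returns a flat list that
# alternates cond_units and 'and'/'or' connectors, e.g.
#   [cond_unit1, 'and', cond_unit2, 'or', cond_unit3]
# Even positions hold cond_units and odd positions hold connectors (the same
# convention rebuild_condition_val and rebuild_condition_col rely on), so
# slicing separates them.


def _sketch_split_condition(conds):
    return conds[0::2], conds[1::2]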

def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)

    assert toks[idx] == 'select', "'select' not found"
    idx += 1
    isDistinct = False
    if idx < len_ and toks[idx] == 'distinct':
        idx += 1
        isDistinct = True
    val_units = []

    while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
        agg_id = AGG_OPS.index("none")
        if toks[idx] in AGG_OPS:
            agg_id = AGG_OPS.index(toks[idx])
            idx += 1
        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        val_units.append((agg_id, val_unit))
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip ','

    return idx, (isDistinct, val_units)


def parse_from(toks, start_idx, tables_with_alias, schema):
    """
    Assumes that all table units in the FROM clause are combined with JOIN.
    """
    assert 'from' in toks[start_idx:], "'from' not found"

    len_ = len(toks)
    idx = toks.index('from', start_idx) + 1
    default_tables = []
    table_units = []
    conds = []

    while idx < len_:
        isBlock = False
        if toks[idx] == '(':
            isBlock = True
            idx += 1

        if toks[idx] == 'select':
            idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
            table_units.append((TABLE_TYPE['sql'], sql))
        else:
            if idx < len_ and toks[idx] == 'join':
                idx += 1  # skip join
            idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
            table_units.append((TABLE_TYPE['table_unit'],table_unit))
            default_tables.append(table_name)
        if idx < len_ and toks[idx] == "on":
            idx += 1  # skip on
            idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
            if len(conds) > 0:
                conds.append('and')
            conds.extend(this_conds)

        if isBlock:
            assert toks[idx] == ')'
            idx += 1
        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
            break

    return idx, table_units, conds, default_tables


def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)

    if idx >= len_ or toks[idx] != 'where':
        return idx, []

    idx += 1
    idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
    return idx, conds


def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)
    col_units = []

    if idx >= len_ or toks[idx] != 'group':
        return idx, col_units

    idx += 1
    assert toks[idx] == 'by'
    idx += 1

    while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
        idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
        col_units.append(col_unit)
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip ','
        else:
            break

    return idx, col_units


def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)
    val_units = []
    order_type = 'asc'  # default; each later 'asc'/'desc' token overwrites it for the whole clause

    if idx >= len_ or toks[idx] != 'order':
        return idx, val_units

    idx += 1
    assert toks[idx] == 'by'
    idx += 1

    while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        val_units.append(val_unit)
        if idx < len_ and toks[idx] in ORDER_OPS:
            order_type = toks[idx]
            idx += 1
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip ','
        else:
            break

    return idx, (order_type, val_units)

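# Sketch (hypothetical helper) of one property of parse_order_by: the result
# keeps a single direction for the whole clause, overwritten by each
# 'asc'/'desc' token, so "ORDER BY a ASC, b DESC" is stored as
# ('desc', [a, b]). Plain column names stand in for val_units here.


def _sketch_order_by(toks):
    order_type = 'asc'  # default, as in parse_order_by
    columns = []
    idx = 0
    while idx < len(toks):
        columns.append(toks[idx])
        idx += 1
        if idx < len(toks) and toks[idx] in ('desc', 'asc'):
            order_type = toks[idx]
            idx += 1
        if idx < len(toks) and toks[idx] == ',':
            idx += 1  # skip ','
        else:
            break
    return order_type, columns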

def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)

    if idx >= len_ or toks[idx] != 'having':
        return idx, []

    idx += 1
    idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
    return idx, conds


def parse_limit(toks, start_idx):
    idx = start_idx
    len_ = len(toks)

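    if idx < len_ and toks[idx] == 'limit':
        idx += 2  # skip 'limit' and its argument
        try:
            limit_val = int(toks[idx-1])
        except Exception:
            # Non-integer (or missing) argument: fall back to a placeholder value.
            limit_val = '"value"'
        return idx, limit_val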

    return idx, None


def parse_sql(toks, start_idx, tables_with_alias, schema, mapped_entities_fn=None):
    global mapped_entities

    if mapped_entities_fn is not None:
        mapped_entities = mapped_entities_fn()
    isBlock = False  # indicate whether this is a block of sql/sub-sql
    len_ = len(toks)
    idx = start_idx

    sql = {}
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    # parse from clause in order to get default tables
    from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
    sql['from'] = {'table_units': table_units, 'conds': conds}
    # select clause
    _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
    idx = from_end_idx
    sql['select'] = select_col_units
    # where clause
    idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
    sql['where'] = where_conds
    # group by clause
    idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
    sql['groupBy'] = group_col_units
    # having clause
    idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
    sql['having'] = having_conds
    # order by clause
    idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
    sql['orderBy'] = order_col_units
    # limit clause
    idx, limit_val = parse_limit(toks, idx)
    sql['limit'] = limit_val

    idx = skip_semicolon(toks, idx)
    if isBlock:
        assert toks[idx] == ')'
        idx += 1  # skip ')'
    idx = skip_semicolon(toks, idx)

    # intersect/union/except clause
    for op in SQL_OPS:  # initialize IUE
        sql[op] = None
    if idx < len_ and toks[idx] in SQL_OPS:
        sql_op = toks[idx]
        idx += 1
        idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
        sql[sql_op] = IUE_sql

    if mapped_entities_fn is not None:
        return idx, sql, mapped_entities
    else:
        return idx, sql


def load_data(fpath):
    with open(fpath) as f:
        data = json.load(f)
    return data


def get_sql(schema, query):
    toks = tokenize(query)
    tables_with_alias = get_tables_with_alias(schema.schema, toks)
    _, sql = parse_sql(toks, 0, tables_with_alias, schema)

    return sql


def skip_semicolon(toks, start_idx):
    idx = start_idx
    while idx < len(toks) and toks[idx] == ";":
        idx += 1
    return idx
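
The clause parsers above share one convention: each takes the token list and a cursor (`start_idx`) and returns the advanced cursor together with whatever it parsed. When its clause keyword is absent, it returns the cursor unchanged, which is why `parse_sql` can call `parse_where`, `parse_group_by`, `parse_having`, `parse_order_by` and `parse_limit` one after another without checking first. The minimal sketch below reproduces `parse_limit` and `skip_semicolon` (with type hints and comments added) and runs them on a hand-written token list. That token list is an assumption for illustration, not output of this module's `tokenize`.

```python
from typing import List, Optional, Tuple, Union


def skip_semicolon(toks: List[str], start_idx: int) -> int:
    # Advance the cursor past any run of ';' tokens.
    idx = start_idx
    while idx < len(toks) and toks[idx] == ";":
        idx += 1
    return idx


def parse_limit(toks: List[str], start_idx: int) -> Tuple[int, Optional[Union[int, str]]]:
    idx = start_idx
    len_ = len(toks)
    if idx < len_ and toks[idx] == 'limit':
        idx += 2  # consume 'limit' and its argument
        try:
            limit_val = int(toks[idx - 1])
        except Exception:
            limit_val = '"value"'  # non-integer or missing argument becomes a placeholder
        return idx, limit_val
    return idx, None  # clause absent: the cursor does not move


# Hypothetical token list; the cursor starts at the 'limit' keyword (index 3).
toks = ['order', 'by', 'age', 'limit', '3', ';']
idx, limit_val = parse_limit(toks, 3)
idx = skip_semicolon(toks, idx)
print(idx, limit_val)  # 6 3
```

Because a missing clause leaves the cursor untouched, each parser's output index can be fed straight into the next one. Note that a non-integer `LIMIT` argument is silently replaced by the `'"value"'` placeholder instead of raising an error.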

================================================
FILE: state_machines/states/grammar_based_state.py
================================================
from typing import Any, Dict, List, Sequence, Tuple

import torch

from allennlp.data.fields.production_rule_field import ProductionRule
from allennlp.state_machines.states.grammar_statelet import GrammarStatelet
from allennlp.state_machines.states.rnn_statelet import RnnStatelet
from allennlp.state_machines.states.state import State

from state_machines.states.sql_state import SqlState


# This syntax is pretty weird and ugly, but it's necessary to make mypy happy with the API that
# we've defined.  We're using generics to make the type of `combine_states` come out right.  See
# the note in `allennlp.state_machines.states.state` for a little more detail.
class GrammarBasedState(State['GrammarBasedState']):
    """
    A generic State that's suitable for most models that do grammar-based decoding.  We keep around
    a `group` of states, and each element in the group has a few things: a batch index, an action
    history, a score, an ``RnnStatelet``, a ``GrammarStatelet`` and a ``SqlState``.  We additionally
    have some information that's independent of any particular group element: a list of all possible
    actions for all batch instances passed to ``model.forward()``, an action-to-entity mapping for
    each batch instance, and an ``extras`` field that you can use if you really need some extra
    information about each batch instance (like a string description, or other metadata).

    Finally, we also have a specially-treated, optional ``debug_info`` field.  If this is given, it
    should be an empty list for each group instance when the initial state is created.  In that
    case, we will keep around information about the actions considered at each timestep of decoding
    and other things that you might want to visualize in a demo.  This probably isn't necessary for
    training, and to get it right we need to copy a bunch of data structures for each new state, so
    it's best used only at evaluation / demo time.

    Parameters
    ----------
    batch_indices : ``List[int]``
        Passed to super class; see docs there.
    action_history : ``List[List[int]]``
        Passed to super class; see docs there.
    score : ``List[torch.Tensor]``
        Passed to super class; see docs there.
    rnn_state : ``List[RnnStatelet]``
        An ``RnnStatelet`` for every group element.  This keeps track of the current decoder hidden
        state, the previous decoder output, the output from the encoder (for computing attentions),
        and other things that are typical seq2seq decoder state things.
    grammar_state : ``List[GrammarStatelet]``
        This holds the current grammar state for each element of the group.  The ``GrammarStatelet``
        keeps track of which actions are currently valid.
    sql_state : ``List[SqlState]``
        A ``SqlState`` for every group element.  This keeps track of the SQL decoded so far and is
        used to further restrict which actions are valid.
    possible_actions : ``List[List[ProductionRule]]``
        The list of all possible actions that were passed to ``model.forward()``.  We need this so
        we can recover production strings, which we need to update grammar states.
    action_entity_mapping : ``List[Dict[int, int]]``
        For each batch instance, a mapping from action indices to the indices of the entities
        those actions produce.
    extras : ``List[Any]``, optional (default=None)
        If you need to keep around some extra data for each instance in the batch, you can put that
        in here, without adding another field.  This should be used `very sparingly`, as there is
        no type checking or anything done on the contents of this field, and it will just be passed
        around between ``States`` as-is, without copying.
    debug_info : ``List[Any]``, optional (default=None)
        See the description of the ``debug_info`` field above.
    """
    def __init__(self,
                 batch_indices: List[int],
                 action_history: List[List[int]],
                 score: List[torch.Tensor],
                 rnn_state: List[RnnStatelet],
                 grammar_state: List[GrammarStatelet],
                 sql_state: List[SqlState],
                 possible_actions: List[List[ProductionRule]],
                 action_entity_mapping: List[Dict[int, int]],
                 extras: List[Any] = None,
                 debug_info: List = None) -> None:
        super().__init__(batch_indices, action_history, score)
        self.rnn_state = rnn_state
        self.grammar_state = grammar_state
        self.sql_state = sql_state
        self.possible_actions = possible_actions
        self.action_entity_mapping = action_entity_mapping
        self.extras = extras
        self.debug_info = debug_info

    def new_state_from_group_index(self,
                                   group_index: int,
                                   action: int,
                                   new_score: torch.Tensor,
                                   new_rnn_state: RnnStatelet,
                                   considered_actions: List[int] = None,
                                   action_probabilities: List[float] = None,
                                   attention_weights: torch.Tensor = None,
                                   linking_scores_qst: torch.Tensor = None,
                                   linking_scores_past: torch.Tensor = None) -> 'GrammarBasedState':
        batch_index = self.batch_indices[group_index]
        new_action_history = self.action_history[group_index] + [action]
        production_rule = self.possible_actions[batch_index][action][0]
        new_grammar_state = self.grammar_state[group_index].take_action(production_rule)
        new_sql_state = self.sql_state[group_index].take_action(production_rule)
        if self.debug_info is not None:
            attention = attention_weights[group_index] if attention_weights is not None else None
            # output_attention = output_attention_weights[group_index] if output_attention_weights is not None else None
            debug_info = {
                    'considered_actions': considered_actions,
                    'question_attention': attention,
                    'probabilities': action_probabilities,
                    'linking_scores_qst': linking_scores_qst,
                    'linking_scores_past': linking_scores_past,
                    }
            new_debug_info = [self.debug_info[group_index] + [debug_info]]
        else:
            new_debug_info = None
        return GrammarBasedState(batch_indices=[batch_index],
                                 action_history=[new_action_history],
                                 score=[new_score],
 
SYMBOL INDEX (235 symbols across 21 files)

FILE: dataset_readers/dataset_util/spider_utils.py
  class TableColumn (line 16) | class TableColumn:
    method __init__ (line 17) | def __init__(self,
  class Table (line 30) | class Table:
    method __init__ (line 31) | def __init__(self,
  function read_dataset_schema (line 40) | def read_dataset_schema(schema_path: str) -> Dict[str, List[Table]]:
  function read_dataset_values (line 75) | def read_dataset_values(db_id: str, dataset_path: str, tables: List[str]):
  function ent_key_to_name (line 99) | def ent_key_to_name(key):
  function fix_number_value (line 110) | def fix_number_value(ex: JsonDict):
  function disambiguate_items (line 159) | def disambiguate_items(db_id: str, query_toks: List[str], tables_file: s...

FILE: dataset_readers/fields/knowledge_graph_field.py
  class SpiderKnowledgeGraphField (line 12) | class SpiderKnowledgeGraphField(KnowledgeGraphField):
    method __init__ (line 17) | def __init__(self,
    method _compute_related_linking_features (line 48) | def _compute_related_linking_features(self,

FILE: dataset_readers/spider.py
  class SpiderDatasetReader (line 25) | class SpiderDatasetReader(DatasetReader):
    method __init__ (line 26) | def __init__(self,
    method _read (line 53) | def _read(self, file_path: str):
    method text_to_instance (line 121) | def text_to_instance(self,

FILE: models/semantic_parsing/spider_parser.py
  class SpiderParser (line 37) | class SpiderParser(Model):
    method __init__ (line 38) | def __init__(self,
    method forward (line 158) | def forward(self,  # type: ignore
    method _get_initial_state (line 216) | def _get_initial_state(self,
    method _get_neighbor_indices (line 376) | def _get_neighbor_indices(worlds: List[SpiderWorld],
    method _get_schema_graph_encoding (line 428) | def _get_schema_graph_encoding(self,
    method _get_graph_adj_lists (line 459) | def _get_graph_adj_lists(device, world, global_entity_id, global_node=...
    method _create_grammar_state (line 506) | def _create_grammar_state(self,
    method is_nonterminal (line 588) | def is_nonterminal(token: str):
    method _get_linking_probabilities (line 593) | def _get_linking_probabilities(self,
    method _action_history_match (line 676) | def _action_history_match(predicted: List[int], targets: torch.LongTen...
    method _query_difficulty (line 687) | def _query_difficulty(targets: torch.LongTensor, action_mapping, batch...
    method get_metrics (line 693) | def get_metrics(self, reset: bool = False) -> Dict[str, float]:
    method _get_type_vector (line 704) | def _get_type_vector(worlds: List[SpiderWorld],
    method _compute_validation_outputs (line 757) | def _compute_validation_outputs(self,

FILE: modules/gated_graph_conv.py
  class GatedGraphConv (line 10) | class GatedGraphConv(MessagePassing):
    method __init__ (line 32) | def __init__(self, input_dim, num_timesteps, num_edge_types, aggr='add...
    method reset_parameters (line 46) | def reset_parameters(self):
    method forward (line 53) | def forward(self, x, edge_indices):
    method __repr__ (line 81) | def __repr__(self):

FILE: predictors/spider_predictor.py
  class WikiTablesParserPredictor (line 10) | class WikiTablesParserPredictor(Predictor):
    method __init__ (line 11) | def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
    method predict_instance (line 15) | def predict_instance(self, instance: Instance) -> JsonDict:
    method dump_line (line 26) | def dump_line(self, outputs: JsonDict) -> str:  # pylint: disable=no-s...

FILE: semparse/contexts/spider_context_utils.py
  function format_grammar_string (line 16) | def format_grammar_string(grammar_dictionary: Dict[str, List[str]]) -> str:
  function initialize_valid_actions (line 25) | def initialize_valid_actions(grammar: Grammar,
  function format_action (line 64) | def format_action(nonterminal: str,
  function action_sequence_to_sql (line 112) | def action_sequence_to_sql(action_sequences: List[str], add_table_names:...
  class SqlVisitor (line 139) | class SqlVisitor(NodeVisitor):
    method __init__ (line 162) | def __init__(self, grammar: Grammar, keywords_to_uppercase: List[str] ...
    method generic_visit (line 168) | def generic_visit(self, node: Node, visited_children: List[None]) -> L...
    method add_action (line 174) | def add_action(self, node: Node) -> None:
    method visit (line 205) | def visit(self, node):

FILE: semparse/contexts/spider_db_context.py
  class SpiderDBContext (line 30) | class SpiderDBContext:
    method __init__ (line 35) | def __init__(self, db_id: str, utterance: str, tokenizer: Tokenizer, t...
    method entity_key_for_column (line 56) | def entity_key_for_column(table_name: str, column: TableColumn) -> str:
    method get_db_knowledge_graph (line 65) | def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph:
    method _string_in_table (line 134) | def _string_in_table(self, candidate: str,
    method get_entities_from_question (line 152) | def get_entities_from_question(self,
    method normalize_string (line 182) | def normalize_string(string: str) -> str:
    method _expand_entities (line 221) | def _expand_entities(self, question, entity_data, string_column_mappin...

FILE: semparse/contexts/spider_db_grammar.py
  function update_grammar_with_tables (line 105) | def update_grammar_with_tables(grammar_dictionary: Dict[str, List[str]],
  function update_grammar_to_be_table_names_free (line 118) | def update_grammar_to_be_table_names_free(grammar_dictionary: Dict[str, ...
  function update_grammar_flip_joins (line 131) | def update_grammar_flip_joins(grammar_dictionary: Dict[str, List[str]]):

FILE: semparse/worlds/evaluate.py
  function condition_has_or (line 57) | def condition_has_or(conds):
  function condition_has_like (line 61) | def condition_has_like(conds):
  function condition_has_sql (line 65) | def condition_has_sql(conds):
  function val_has_op (line 75) | def val_has_op(val_unit):
  function has_agg (line 79) | def has_agg(unit):
  function accuracy (line 83) | def accuracy(count, total):
  function recall (line 89) | def recall(count, total):
  function F1 (line 95) | def F1(acc, rec):
  function get_scores (line 101) | def get_scores(count, pred_total, label_total):
  function eval_sel (line 109) | def eval_sel(pred, label):
  function eval_where (line 129) | def eval_where(pred, label):
  function eval_group (line 149) | def eval_group(pred, label):
  function eval_having (line 164) | def eval_having(pred, label):
  function eval_order (line 181) | def eval_order(pred, label):
  function eval_and_or (line 193) | def eval_and_or(pred, label):
  function get_nestedSQL (line 204) | def get_nestedSQL(sql):
  function eval_nested (line 220) | def eval_nested(pred, label):
  function eval_IUEN (line 233) | def eval_IUEN(pred, label):
  function get_keywords (line 243) | def get_keywords(sql):
  function eval_keywords (line 284) | def eval_keywords(pred, label):
  function count_agg (line 297) | def count_agg(units):
  function count_component1 (line 301) | def count_component1(sql):
  function count_component2 (line 322) | def count_component2(sql):
  function count_others (line 327) | def count_others(sql):
  class Evaluator (line 355) | class Evaluator:
    method __init__ (line 357) | def __init__(self):
    method eval_hardness (line 360) | def eval_hardness(self, sql):
    method eval_exact_match (line 377) | def eval_exact_match(self, pred, label):
    method eval_partial_match (line 396) | def eval_partial_match(self, pred, label):
  function isValidSQL (line 438) | def isValidSQL(sql, db):
  function print_scores (line 448) | def print_scores(scores, etype):
  function evaluate (line 482) | def evaluate(gold, predict, db_dir, etype, kmaps):
  function eval_exec_match (line 625) | def eval_exec_match(db, p_str, g_str, pred, gold):
  function rebuild_cond_unit_val (line 655) | def rebuild_cond_unit_val(cond_unit):
  function rebuild_condition_val (line 671) | def rebuild_condition_val(condition):
  function rebuild_sql_val (line 684) | def rebuild_sql_val(sql):
  function build_valid_col_units (line 699) | def build_valid_col_units(table_units, schema):
  function rebuild_col_unit_col (line 709) | def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
  function rebuild_val_unit_col (line 721) | def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
  function rebuild_table_unit_col (line 731) | def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
  function rebuild_cond_unit_col (line 741) | def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
  function rebuild_condition_col (line 750) | def rebuild_condition_col(valid_col_units, condition, kmap):
  function rebuild_select_col (line 757) | def rebuild_select_col(valid_col_units, sel, kmap):
  function rebuild_from_col (line 770) | def rebuild_from_col(valid_col_units, from_, kmap):
  function rebuild_group_by_col (line 779) | def rebuild_group_by_col(valid_col_units, group_by, kmap):
  function rebuild_order_by_col (line 786) | def rebuild_order_by_col(valid_col_units, order_by, kmap):
  function rebuild_sql_col (line 795) | def rebuild_sql_col(valid_col_units, sql, kmap):
  function build_foreign_key_map (line 812) | def build_foreign_key_map(entry):
  function build_foreign_key_map_from_json (line 852) | def build_foreign_key_map_from_json(table):

FILE: semparse/worlds/evaluate_spider.py
  function evaluate (line 12) | def evaluate(gold, predict, db_name, db_dir, table, check_valid: bool=Tr...
  function check_valid_sql (line 56) | def check_valid_sql(sql, db_name, db_dir, return_error=False):

FILE: semparse/worlds/spider_world.py
  class SpiderWorld (line 13) | class SpiderWorld:
    method __init__ (line 18) | def __init__(self, db_context: SpiderDBContext, query: Optional[List[s...
    method get_action_sequence_and_all_actions (line 39) | def get_action_sequence_and_all_actions(self,
    method get_all_actions (line 69) | def get_all_actions(self, schema,
    method is_global_rule (line 92) | def is_global_rule(self, rhs: str) -> bool:
    method get_oracle_relevance_score (line 98) | def get_oracle_relevance_score(self, oracle_entities: set):
    method get_action_entity_mapping (line 116) | def get_action_entity_mapping(self) -> Dict[int, int]:
    method get_query_without_table_hints (line 132) | def get_query_without_table_hints(self):

FILE: spider_evaluation/evaluate.py
  function condition_has_or (line 57) | def condition_has_or(conds):
  function condition_has_like (line 61) | def condition_has_like(conds):
  function condition_has_sql (line 65) | def condition_has_sql(conds):
  function val_has_op (line 75) | def val_has_op(val_unit):
  function has_agg (line 79) | def has_agg(unit):
  function accuracy (line 83) | def accuracy(count, total):
  function recall (line 89) | def recall(count, total):
  function F1 (line 95) | def F1(acc, rec):
  function get_scores (line 101) | def get_scores(count, pred_total, label_total):
  function eval_sel (line 109) | def eval_sel(pred, label):
  function eval_where (line 129) | def eval_where(pred, label):
  function eval_group (line 149) | def eval_group(pred, label):
  function eval_having (line 164) | def eval_having(pred, label):
  function eval_order (line 181) | def eval_order(pred, label):
  function eval_and_or (line 193) | def eval_and_or(pred, label):
  function get_nestedSQL (line 204) | def get_nestedSQL(sql):
  function eval_nested (line 220) | def eval_nested(pred, label):
  function eval_IUEN (line 233) | def eval_IUEN(pred, label):
  function get_keywords (line 243) | def get_keywords(sql):
  function eval_keywords (line 284) | def eval_keywords(pred, label):
  function count_agg (line 297) | def count_agg(units):
  function count_component1 (line 301) | def count_component1(sql):
  function count_component2 (line 322) | def count_component2(sql):
  function count_others (line 327) | def count_others(sql):
  class Evaluator (line 355) | class Evaluator:
    method __init__ (line 357) | def __init__(self):
    method eval_hardness (line 360) | def eval_hardness(self, sql):
    method eval_exact_match (line 377) | def eval_exact_match(self, pred, label):
    method eval_partial_match (line 396) | def eval_partial_match(self, pred, label):
  function isValidSQL (line 438) | def isValidSQL(sql, db):
  function print_scores (line 448) | def print_scores(scores, etype):
  function evaluate (line 482) | def evaluate(gold, predict, db_dir, etype, kmaps):
  function eval_exec_match (line 625) | def eval_exec_match(db, p_str, g_str, pred, gold):
  function rebuild_cond_unit_val (line 655) | def rebuild_cond_unit_val(cond_unit):
  function rebuild_condition_val (line 671) | def rebuild_condition_val(condition):
  function rebuild_sql_val (line 684) | def rebuild_sql_val(sql):
  function build_valid_col_units (line 699) | def build_valid_col_units(table_units, schema):
  function rebuild_col_unit_col (line 709) | def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
  function rebuild_val_unit_col (line 721) | def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
  function rebuild_table_unit_col (line 731) | def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
  function rebuild_cond_unit_col (line 741) | def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
  function rebuild_condition_col (line 750) | def rebuild_condition_col(valid_col_units, condition, kmap):
  function rebuild_select_col (line 757) | def rebuild_select_col(valid_col_units, sel, kmap):
  function rebuild_from_col (line 770) | def rebuild_from_col(valid_col_units, from_, kmap):
  function rebuild_group_by_col (line 779) | def rebuild_group_by_col(valid_col_units, group_by, kmap):
  function rebuild_order_by_col (line 786) | def rebuild_order_by_col(valid_col_units, order_by, kmap):
  function rebuild_sql_col (line 795) | def rebuild_sql_col(valid_col_units, sql, kmap):
  function build_foreign_key_map (line 812) | def build_foreign_key_map(entry):
  function build_foreign_key_map_from_json (line 852) | def build_foreign_key_map_from_json(table):

FILE: spider_evaluation/process_sql.py
  class Schema (line 48) | class Schema:
    method __init__ (line 52) | def __init__(self, schema):
    method schema (line 57) | def schema(self):
    method idMap (line 61) | def idMap(self):
    method _map (line 64) | def _map(self, schema):
  function get_schema (line 79) | def get_schema(db):
  function get_schema_from_json (line 103) | def get_schema_from_json(fpath):
  function tokenize (line 116) | def tokenize(string):
  function scan_alias (line 150) | def scan_alias(toks):
  function get_tables_with_alias (line 159) | def get_tables_with_alias(schema, toks):
  function parse_col (line 167) | def parse_col(toks, start_idx, tables_with_alias, schema, default_tables...
  function parse_col_unit (line 194) | def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_t...
  function parse_val_unit (line 232) | def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_t...
  function parse_table_unit (line 257) | def parse_table_unit(toks, start_idx, tables_with_alias, schema):
  function parse_value (line 273) | def parse_value(toks, start_idx, tables_with_alias, schema, default_tabl...
  function parse_condition (line 307) | def parse_condition(toks, start_idx, tables_with_alias, schema, default_...
  function parse_select (line 344) | def parse_select(toks, start_idx, tables_with_alias, schema, default_tab...
  function parse_from (line 369) | def parse_from(toks, start_idx, tables_with_alias, schema):
  function parse_where (line 412) | def parse_where(toks, start_idx, tables_with_alias, schema, default_tabl...
  function parse_group_by (line 424) | def parse_group_by(toks, start_idx, tables_with_alias, schema, default_t...
  function parse_order_by (line 447) | def parse_order_by(toks, start_idx, tables_with_alias, schema, default_t...
  function parse_having (line 474) | def parse_having(toks, start_idx, tables_with_alias, schema, default_tab...
  function parse_limit (line 486) | def parse_limit(toks, start_idx):
  function parse_sql (line 501) | def parse_sql(toks, start_idx, tables_with_alias, schema, mapped_entitie...
  function load_data (line 559) | def load_data(fpath):
  function get_sql (line 565) | def get_sql(schema, query):
  function skip_semicolon (line 573) | def skip_semicolon(toks, start_idx):

FILE: state_machines/states/grammar_based_state.py
  class GrammarBasedState (line 16) | class GrammarBasedState(State['GrammarBasedState']):
    method __init__ (line 58) | def __init__(self,
    method new_state_from_group_index (line 78) | def new_state_from_group_index(self,
    method print_action_history (line 117) | def print_action_history(self, group_index: int = None) -> None:
    method get_valid_actions (line 125) | def get_valid_actions(self) -> List[Dict[str, Tuple[torch.Tensor, torc...
    method is_finished (line 133) | def is_finished(self) -> bool:
    method combine_states (line 139) | def combine_states(cls, states: Sequence['GrammarBasedState']) -> 'Gra...

FILE: state_machines/states/rnn_statelet.py
  class RnnStatelet (line 8) | class RnnStatelet:
    method __init__ (line 48) | def __init__(self,
    method __eq__ (line 64) | def __eq__(self, other):

FILE: state_machines/states/sql_state.py
  class SqlState (line 7) | class SqlState:
    method __init__ (line 8) | def __init__(self,
    method take_action (line 19) | def take_action(self, production_rule: str) -> 'SqlState':
    method get_valid_actions (line 63) | def get_valid_actions(self, valid_actions: dict):
    method _remove_actions (line 177) | def _remove_actions(valid_actions, key, ids_to_remove):
    method _get_current_open_clause (line 202) | def _get_current_open_clause(self):

FILE: state_machines/transition_functions/attend_past_schema_items_transition.py
  class AttendPastSchemaItemsTransitionFunction (line 17) | class AttendPastSchemaItemsTransitionFunction(BasicTransitionFunction):
    method __init__ (line 18) | def __init__(self,
    method take_step (line 43) | def take_step(self,
    method _update_decoder_state (line 74) | def _update_decoder_state(self, state: GrammarBasedState) -> Dict[str,...
    method _compute_action_probabilities (line 149) | def _compute_action_probabilities(self,
    method _construct_next_states (line 244) | def _construct_next_states(self,
    method attend (line 364) | def attend(self,

FILE: state_machines/transition_functions/basic_transition_function.py
  class BasicTransitionFunction (line 16) | class BasicTransitionFunction(TransitionFunction[GrammarBasedState]):
    method __init__ (line 53) | def __init__(self,
    method take_step (line 102) | def take_step(self,
    method _update_decoder_state (line 132) | def _update_decoder_state(self, state: GrammarBasedState) -> Dict[str,...
    method _compute_action_probabilities (line 192) | def _compute_action_probabilities(self,
    method _construct_next_states (line 232) | def _construct_next_states(self,
    method _take_first_step (line 333) | def _take_first_step(self,
    method attend_on_question (line 400) | def attend_on_question(self,

FILE: state_machines/transition_functions/linking_transition_function.py
  class LinkingTransitionFunction (line 16) | class LinkingTransitionFunction(BasicTransitionFunction):
    method __init__ (line 58) | def __init__(self,
    method _compute_action_probabilities (line 87) | def _compute_action_probabilities(self,

FILE: state_machines/transition_functions/prefix_attend_transition.py
  class PrefixAttendTransitionFunction (line 15) | class PrefixAttendTransitionFunction(LinkingTransitionFunction):
    method __init__ (line 16) | def __init__(self,
    method take_step (line 49) | def take_step(self,
    method _update_decoder_state (line 80) | def _update_decoder_state(self, state: GrammarBasedState) -> Dict[str,...
    method attend (line 146) | def attend(self,
    method _construct_next_states (line 166) | def _construct_next_states(self,
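
The comment in `state_machines/states/grammar_based_state.py` about generics refers to a self-referential pattern. The base state class is generic in its own subclass type, so `combine_states` called on a subclass is typed to return that subclass rather than the base class. The minimal sketch below uses hypothetical `State` and `ToyGrammarState` classes (simplified stand-ins, not AllenNLP's actual classes). It also mimics how `new_state_from_group_index` builds a fresh single-element state instead of mutating the group it came from.

```python
from typing import Generic, List, Sequence, TypeVar

# T is bound to State itself, so subclasses can pass their own type as the parameter.
T = TypeVar('T', bound='State')


class State(Generic[T]):
    """Hypothetical, simplified stand-in for a decoder state base class."""

    def __init__(self, batch_indices: List[int], action_history: List[List[int]]) -> None:
        self.batch_indices = batch_indices
        self.action_history = action_history

    @classmethod
    def combine_states(cls, states: Sequence[T]) -> T:
        raise NotImplementedError


class ToyGrammarState(State['ToyGrammarState']):
    """Hypothetical subclass: combine_states is typed to return ToyGrammarState."""

    def new_state_from_group_index(self, group_index: int, action: int) -> 'ToyGrammarState':
        # `+` builds a new list, so the parent group's history is left untouched.
        return ToyGrammarState([self.batch_indices[group_index]],
                               [self.action_history[group_index] + [action]])

    @classmethod
    def combine_states(cls, states: Sequence['ToyGrammarState']) -> 'ToyGrammarState':
        # Concatenate the per-element data of all states into one group.
        batch_indices = [i for s in states for i in s.batch_indices]
        histories = [h for s in states for h in s.action_history]
        return cls(batch_indices, histories)


group = ToyGrammarState([0, 1], [[4], [7]])
children = [group.new_state_from_group_index(0, 5), group.new_state_from_group_index(1, 2)]
merged = ToyGrammarState.combine_states(children)
print(merged.batch_indices, merged.action_history)  # [0, 1] [[4, 5], [7, 2]]
print(group.action_history)  # [[4], [7]]
```

Splitting a group into single-element states and recombining them keeps each beam element independent, at the cost of copying small lists on every step.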
Condensed preview: 28 files, each showing path, character count, and a content snippet.
[
  {
    "path": "README.md",
    "chars": 3185,
    "preview": "# Representing Schema Structure with Graph Neural Networks for Text-to-SQL Parsing\n\nAuthor implementation of this [ACL 2"
  },
  {
    "path": "dataset_readers/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "dataset_readers/dataset_util/spider_utils.py",
    "chars": 9182,
    "preview": "\"\"\"\nUtility functions for reading the standardised text2sql datasets presented in\n`\"Improving Text to SQL Evaluation Met"
  },
  {
    "path": "dataset_readers/fields/knowledge_graph_field.py",
    "chars": 3549,
    "preview": "\"\"\"\n``KnowledgeGraphField`` is a ``Field`` which stores a knowledge graph representation.\n\"\"\"\nfrom typing import List, D"
  },
  {
    "path": "dataset_readers/spider.py",
    "chars": 7828,
    "preview": "import json\nimport logging\nimport os\nfrom typing import List, Dict\n\nimport dill\nfrom allennlp.common.checks import Confi"
  },
  {
    "path": "models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "models/semantic_parsing/__init__.py",
    "chars": 62,
    "preview": "from models.semantic_parsing.spider_parser import SpiderParser"
  },
  {
    "path": "models/semantic_parsing/spider_parser.py",
    "chars": 43124,
    "preview": "import difflib\nimport os\nfrom functools import partial\nfrom typing import Dict, List, Tuple, Any, Mapping, Sequence\n\nimp"
  },
  {
    "path": "modules/gated_graph_conv.py",
    "chars": 3703,
    "preview": "import math\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter as Param, init\nfrom torch_geometric.da"
  },
  {
    "path": "predictors/spider_predictor.py",
    "chars": 1229,
    "preview": "from overrides import overrides\n\nfrom allennlp.common.util import JsonDict, sanitize\nfrom allennlp.data import DatasetRe"
  },
  {
    "path": "requirements.txt",
    "chars": 259,
    "preview": "torch==1.0.1.post2\nspacy==2.0.18\nallennlp==0.8.2\ndill\ntorch-scatter==1.1.2\ntorch-sparse==0.2.4\ntorch-cluster==1.2.4\ntorc"
  },
  {
    "path": "semparse/contexts/spider_context_utils.py",
    "chars": 10171,
    "preview": "import re\nfrom collections import defaultdict\nfrom sys import exc_info\nfrom typing import List, Dict, Set\n\nfrom override"
  },
  {
    "path": "semparse/contexts/spider_db_context.py",
    "chars": 12887,
    "preview": "import re\nfrom collections import Set, defaultdict\nfrom typing import Dict, Tuple, List\n\nfrom allennlp.data import Token"
  },
  {
    "path": "semparse/contexts/spider_db_grammar.py",
    "chars": 7872,
    "preview": "# pylint: disable=anomalous-backslash-in-string\n\"\"\"\nA ``Text2SqlTableContext`` represents the SQL context in which an ut"
  },
  {
    "path": "semparse/worlds/evaluate.py",
    "chars": 30758,
    "preview": "################################\n# val: number(float)/string(str)/sql(dict)\n# col_unit: (agg_id, col_id, isDistinct(bool"
  },
  {
    "path": "semparse/worlds/evaluate_spider.py",
    "chars": 2230,
    "preview": "import os\nimport sqlite3\n\nfrom semparse.worlds.evaluate import Evaluator, build_valid_col_units, rebuild_sql_val, rebuil"
  },
  {
    "path": "semparse/worlds/spider_world.py",
    "chars": 7018,
    "preview": "from typing import List, Tuple, Dict, Set, Optional\nfrom copy import deepcopy\n\nfrom parsimonious import Grammar\nfrom par"
  },
  {
    "path": "spider_evaluation/evaluate.py",
    "chars": 30746,
    "preview": "################################\n# val: number(float)/string(str)/sql(dict)\n# col_unit: (agg_id, col_id, isDistinct(bool"
  },
  {
    "path": "spider_evaluation/process_sql.py",
    "chars": 17219,
    "preview": "################################\n# Assumptions:\n#   1. sql is correct\n#   2. only table name has alias\n#   3. only one i"
  },
  {
    "path": "state_machines/states/grammar_based_state.py",
    "chars": 9387,
    "preview": "from typing import Any, Dict, List, Sequence, Tuple\n\nimport torch\n\nfrom allennlp.data.fields.production_rule_field impor"
  },
  {
    "path": "state_machines/states/rnn_statelet.py",
    "chars": 3883,
    "preview": "from typing import List, Optional\n\nimport torch\n\nfrom allennlp.nn import util\n\n\nclass RnnStatelet:\n    \"\"\"\n    This clas"
  },
  {
    "path": "state_machines/states/sql_state.py",
    "chars": 10509,
    "preview": "import copy\nimport logging\n\nlogger = logging.getLogger(__name__)  # pylint: disable=invalid-name\n\n\nclass SqlState:\n    d"
  },
  {
    "path": "state_machines/transition_functions/attend_past_schema_items_transition.py",
    "chars": 22504,
    "preview": "from collections import defaultdict\nfrom typing import Dict, Tuple, List, Set, Any, Callable, Optional\n\nimport torch\nfro"
  },
  {
    "path": "state_machines/transition_functions/basic_transition_function.py",
    "chars": 25189,
    "preview": "from collections import defaultdict\nfrom typing import Any, Dict, List, Set, Tuple\n\nfrom overrides import overrides\n\nimp"
  },
  {
    "path": "state_machines/transition_functions/linking_transition_function.py",
    "chars": 10082,
    "preview": "from collections import defaultdict\nfrom typing import Any, Dict, List, Tuple\n\nfrom overrides import overrides\n\nimport t"
  },
  {
    "path": "state_machines/transition_functions/prefix_attend_transition.py",
    "chars": 16028,
    "preview": "from typing import Dict, Tuple, List, Set, Any, Callable\n\nimport torch\nfrom allennlp.modules import Attention, FeedForwa"
  },
  {
    "path": "train_configs/defaults.jsonnet",
    "chars": 1987,
    "preview": "local dataset_path = \"dataset/\";\n\n{\n  \"random_seed\": 5,\n  \"numpy_seed\": 5,\n  \"pytorch_seed\": 5,\n  \"dataset_reader\": {\n  "
  },
  {
    "path": "train_configs/paper_defaults.jsonnet",
    "chars": 1958,
    "preview": "local dataset_path = \"dataset/\";\n\n{\n  \"random_seed\": 5,\n  \"numpy_seed\": 5,\n  \"pytorch_seed\": 5,\n  \"dataset_reader\": {\n  "
  }
]