Showing preview only (298K chars total). Download the full file or copy to clipboard to get everything.
Repository: benbogin/spider-schema-gnn
Branch: master
Commit: 02f4ae43b891
Files: 28
Total size: 285.7 KB
Directory structure:
gitextract_4_pkjgz2/
├── README.md
├── dataset_readers/
│ ├── __init__.py
│ ├── dataset_util/
│ │ └── spider_utils.py
│ ├── fields/
│ │ └── knowledge_graph_field.py
│ └── spider.py
├── models/
│ ├── __init__.py
│ └── semantic_parsing/
│ ├── __init__.py
│ └── spider_parser.py
├── modules/
│ └── gated_graph_conv.py
├── predictors/
│ └── spider_predictor.py
├── requirements.txt
├── semparse/
│ ├── contexts/
│ │ ├── spider_context_utils.py
│ │ ├── spider_db_context.py
│ │ └── spider_db_grammar.py
│ └── worlds/
│ ├── evaluate.py
│ ├── evaluate_spider.py
│ └── spider_world.py
├── spider_evaluation/
│ ├── evaluate.py
│ └── process_sql.py
├── state_machines/
│ ├── states/
│ │ ├── grammar_based_state.py
│ │ ├── rnn_statelet.py
│ │ └── sql_state.py
│ └── transition_functions/
│ ├── attend_past_schema_items_transition.py
│ ├── basic_transition_function.py
│ ├── linking_transition_function.py
│ └── prefix_attend_transition.py
└── train_configs/
├── defaults.jsonnet
└── paper_defaults.jsonnet
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
# Representing Schema Structure with Graph Neural Networks for Text-to-SQL Parsing
Author implementation of this [ACL 2019 paper](https://arxiv.org/abs/1905.06241).
Please also see the [follow-up repository](https://github.com/benbogin/spider-schema-gnn-global) with improved results, for this [EMNLP paper](https://www.aclweb.org/anthology/D19-1378.pdf).
## Install & Configure
1. Install pytorch version 1.0.1.post2 that fits your CUDA version
(this repository should probably work with the latest pytorch version, but wasn't tested for it. If you use another version, you'll need to also update the versions of packages in `requirements.txt`)
```
pip install https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp37-cp37m-linux_x86_64.whl # CUDA 10.0 build
```
2. Install the rest of required packages
```
pip install -r requirements.txt
```
3. Run this command to install NLTK punkt.
```
python -c "import nltk; nltk.download('punkt')"
```
4. Download the dataset from the [official Spider dataset website](https://yale-lily.github.io/spider)
5. Edit the config file `train_configs/defaults.jsonnet` to update the location of the dataset:
```
local dataset_path = "dataset/";
```
## Training
1. Use the following AllenNLP command to train:
```
allennlp train train_configs/defaults.jsonnet -s experiments/name_of_experiment \
--include-package dataset_readers.spider \
--include-package models.semantic_parsing.spider_parser
```
First time loading of the dataset might take a while (a few hours) since the model first loads values from tables and calculates similarity features with the relevant question. It will then be cached for subsequent runs.
You should get results similar to the following (the `sql_match` is the one measured in the official evaluation test):
```
"best_validation__match/exact_match": 0.3715686274509804,
"best_validation_sql_match": 0.47549019607843135,
"best_validation__others/action_similarity": 0.5731271471206189,
"best_validation__match/match_single": 0.6254612546125461,
"best_validation__match/match_hard": 0.3054393305439331,
"best_validation_beam_hit": 0.6070588235294118,
"best_validation_loss": 7.383035182952881
"best_epoch": 32
```
Note that the hyper-parameters used in `defaults.jsonnet` are different than those mentioned in the paper
(most importantly, 3 timesteps are used instead of 2), thanks to the [following contribution from @wlhgtc](https://github.com/benbogin/spider-schema-gnn/pull/13).
The original training config file is still available in `train_configs/paper_Defaults.jsonnet`.
## Inference
Use the following AllenNLP command to output a file with the predicted queries:
```
allennlp predict experiments/name_of_experiment dataset/dev.json \
--predictor spider \
--use-dataset-reader \
--cuda-device=0 \
--output-file experiments/name_of_experiment/prediction.sql \
--silent \
--include-package models.semantic_parsing.spider_parser \
--include-package dataset_readers.spider \
--include-package predictors.spider_predictor \
--weights-file experiments/name_of_experiment/best.th \
-o "{\"dataset_reader\":{\"keep_if_unparsable\":true}}"
```
================================================
FILE: dataset_readers/__init__.py
================================================
================================================
FILE: dataset_readers/dataset_util/spider_utils.py
================================================
"""
Utility functions for reading the standardised text2sql datasets presented in
`"Improving Text to SQL Evaluation Methodology" <https://arxiv.org/abs/1806.09029>`_
"""
import json
import os
import sqlite3
from collections import defaultdict
from typing import List, Dict, Optional
from allennlp.common import JsonDict
from spider_evaluation.process_sql import get_tables_with_alias, parse_sql
class TableColumn:
def __init__(self,
name: str,
text: str,
column_type: str,
is_primary_key: bool,
foreign_key: Optional[str]):
self.name = name
self.text = text
self.column_type = column_type
self.is_primary_key = is_primary_key
self.foreign_key = foreign_key
class Table:
def __init__(self,
name: str,
text: str,
columns: List[TableColumn]):
self.name = name
self.text = text
self.columns = columns
def read_dataset_schema(schema_path: str) -> Dict[str, List[Table]]:
schemas: Dict[str, Dict[str, Table]] = defaultdict(dict)
dbs_json_blob = json.load(open(schema_path, "r"))
for db in dbs_json_blob:
db_id = db['db_id']
column_id_to_table = {}
column_id_to_column = {}
for i, (column, text, column_type) in enumerate(zip(db['column_names_original'], db['column_names'], db['column_types'])):
table_id, column_name = column
_, column_text = text
table_name = db['table_names_original'][table_id]
if table_name not in schemas[db_id]:
table_text = db['table_names'][table_id]
schemas[db_id][table_name] = Table(table_name, table_text, [])
if column_name == "*":
continue
is_primary_key = i in db['primary_keys']
table_column = TableColumn(column_name.lower(), column_text, column_type, is_primary_key, None)
schemas[db_id][table_name].columns.append(table_column)
column_id_to_table[i] = table_name
column_id_to_column[i] = table_column
for (c1, c2) in db['foreign_keys']:
foreign_key = column_id_to_table[c2] + ':' + column_id_to_column[c2].name
column_id_to_column[c1].foreign_key = foreign_key
return {**schemas}
def read_dataset_values(db_id: str, dataset_path: str, tables: List[str]):
db = os.path.join(dataset_path, db_id, db_id + ".sqlite")
try:
conn = sqlite3.connect(db)
except Exception as e:
raise Exception(f"Can't connect to SQL: {e} in path {db}")
conn.text_factory = str
cursor = conn.cursor()
values = {}
for table in tables:
try:
cursor.execute(f"SELECT * FROM {table.name} LIMIT 5000")
values[table] = cursor.fetchall()
except:
conn.text_factory = lambda x: str(x, 'latin1')
cursor = conn.cursor()
cursor.execute(f"SELECT * FROM {table.name} LIMIT 5000")
values[table] = cursor.fetchall()
return values
def ent_key_to_name(key):
parts = key.split(':')
if parts[0] == 'table':
return parts[1]
elif parts[0] == 'column':
_, _, table_name, column_name = parts
return f'{table_name}@{column_name}'
else:
return parts[1]
def fix_number_value(ex: JsonDict):
"""
There is something weird in the dataset files - the `query_toks_no_value` field anonymizes all values,
which is good since the evaluator doesn't check for the values. But it also anonymizes numbers that
should not be anonymized: e.g. LIMIT 3 becomes LIMIT 'value', while the evaluator fails if it is not a number.
"""
def split_and_keep(s, sep):
if not s: return [''] # consistent with string.split()
# Find replacement character that is not used in string
# i.e. just use the highest available character plus one
# Note: This fails if ord(max(s)) = 0x10FFFF (ValueError)
p = chr(ord(max(s)) + 1)
return s.replace(sep, p + sep + p).split(p)
# input is tokenized in different ways... so first try to make splits equal
query_toks = ex['query_toks']
ex['query_toks'] = []
for q in query_toks:
ex['query_toks'] += split_and_keep(q, '.')
i_val, i_no_val = 0, 0
while i_val < len(ex['query_toks']) and i_no_val < len(ex['query_toks_no_value']):
if ex['query_toks_no_value'][i_no_val] != 'value':
i_val += 1
i_no_val += 1
continue
i_val_end = i_val
while i_val + 1 < len(ex['query_toks']) and \
i_no_val + 1 < len(ex['query_toks_no_value']) and \
ex['query_toks'][i_val_end + 1].lower() != ex['query_toks_no_value'][i_no_val + 1].lower():
i_val_end += 1
if i_val == i_val_end and ex['query_toks'][i_val] in ["1", "2", "3"] and ex['query_toks'][i_val - 1].lower() == "limit":
ex['query_toks_no_value'][i_no_val] = ex['query_toks'][i_val]
i_val = i_val_end
i_val += 1
i_no_val += 1
return ex
_schemas_cache = None
def disambiguate_items(db_id: str, query_toks: List[str], tables_file: str, allow_aliases: bool) -> List[str]:
"""
we want the query tokens to be non-ambiguous - so we can change each column name to explicitly
tell which table it belongs to
parsed sql to sql clause is based on supermodel.gensql from syntaxsql
"""
class Schema:
"""
Simple schema which maps table&column to a unique identifier
"""
def __init__(self, schema, table):
self._schema = schema
self._table = table
self._idMap = self._map(self._schema, self._table)
@property
def schema(self):
return self._schema
@property
def idMap(self):
return self._idMap
def _map(self, schema, table):
column_names_original = table['column_names_original']
table_names_original = table['table_names_original']
# print 'column_names_original: ', column_names_original
# print 'table_names_original: ', table_names_original
for i, (tab_id, col) in enumerate(column_names_original):
if tab_id == -1:
idMap = {'*': i}
else:
key = table_names_original[tab_id].lower()
val = col.lower()
idMap[key + "." + val] = i
for i, tab in enumerate(table_names_original):
key = tab.lower()
idMap[key] = i
return idMap
def get_schemas_from_json(fpath):
global _schemas_cache
if _schemas_cache is not None:
return _schemas_cache
with open(fpath) as f:
data = json.load(f)
db_names = [db['db_id'] for db in data]
tables = {}
schemas = {}
for db in data:
db_id = db['db_id']
schema = {} # {'table': [col.lower, ..., ]} * -> __all__
column_names_original = db['column_names_original']
table_names_original = db['table_names_original']
tables[db_id] = {'column_names_original': column_names_original,
'table_names_original': table_names_original}
for i, tabn in enumerate(table_names_original):
table = str(tabn.lower())
cols = [str(col.lower()) for td, col in column_names_original if td == i]
schema[table] = cols
schemas[db_id] = schema
_schemas_cache = schemas, db_names, tables
return _schemas_cache
schemas, db_names, tables = get_schemas_from_json(tables_file)
schema = Schema(schemas[db_id], tables[db_id])
fixed_toks = []
i = 0
while i < len(query_toks):
tok = query_toks[i]
if tok == 'value' or tok == "'value'":
# TODO: value should alawys be between '/" (remove first if clause)
new_tok = f'"{tok}"'
elif tok in ['!','<','>'] and query_toks[i+1] == '=':
new_tok = tok + '='
i += 1
elif i+1 < len(query_toks) and query_toks[i+1] == '.':
new_tok = ''.join(query_toks[i:i+3])
i += 2
else:
new_tok = tok
fixed_toks.append(new_tok)
i += 1
toks = fixed_toks
tables_with_alias = get_tables_with_alias(schema.schema, toks)
_, sql, mapped_entities = parse_sql(toks, 0, tables_with_alias, schema, mapped_entities_fn=lambda: [])
for i, new_name in mapped_entities:
curr_tok = toks[i]
if '.' in curr_tok and allow_aliases:
parts = curr_tok.split('.')
assert(len(parts) == 2)
toks[i] = parts[0] + '.' + new_name
else:
toks[i] = new_name
if not allow_aliases:
toks = [tok for tok in toks if tok not in ['as', 't1', 't2', 't3', 't4']]
toks = [f'\'value\'' if tok == '"value"' else tok for tok in toks]
return toks
================================================
FILE: dataset_readers/fields/knowledge_graph_field.py
================================================
"""
``KnowledgeGraphField`` is a ``Field`` which stores a knowledge graph representation.
"""
from typing import List, Dict
from allennlp.data import TokenIndexer, Tokenizer
from allennlp.data.fields.knowledge_graph_field import KnowledgeGraphField
from allennlp.data.tokenizers.token import Token
from allennlp.semparse.contexts.knowledge_graph import KnowledgeGraph
class SpiderKnowledgeGraphField(KnowledgeGraphField):
"""
This implementation calculates all non-graph-related features (i.e. no related_column),
then takes each one of the features to calculate related column features, by taking the max score of all neighbours
"""
def __init__(self,
knowledge_graph: KnowledgeGraph,
utterance_tokens: List[Token],
token_indexers: Dict[str, TokenIndexer],
tokenizer: Tokenizer = None,
feature_extractors: List[str] = None,
entity_tokens: List[List[Token]] = None,
linking_features: List[List[List[float]]] = None,
include_in_vocab: bool = True,
max_table_tokens: int = None) -> None:
feature_extractors = feature_extractors if feature_extractors is not None else [
'number_token_match',
'exact_token_match',
'contains_exact_token_match',
'lemma_match',
'contains_lemma_match',
'edit_distance',
'span_overlap_fraction',
'span_lemma_overlap_fraction',
]
super().__init__(knowledge_graph, utterance_tokens, token_indexers,
tokenizer=tokenizer, feature_extractors=feature_extractors, entity_tokens=entity_tokens,
linking_features=linking_features, include_in_vocab=include_in_vocab,
max_table_tokens=max_table_tokens)
self.linking_features = self._compute_related_linking_features(self.linking_features)
# hack needed to fix calculation of feature extractors in the inherited as_tensor method
self._feature_extractors = feature_extractors * 2
def _compute_related_linking_features(self,
non_related_features: List[List[List[float]]]) -> List[List[List[float]]]:
linking_features = non_related_features
entity_to_index_map = {}
for entity_id, entity in enumerate(self.knowledge_graph.entities):
entity_to_index_map[entity] = entity_id
for entity_id, (entity, entity_text) in enumerate(zip(self.knowledge_graph.entities, self.entity_texts)):
for token_index, token in enumerate(self.utterance_tokens):
entity_token_features = linking_features[entity_id][token_index]
for feature_index, feature_extractor in enumerate(self._feature_extractors):
neighbour_features = []
for neighbor in self.knowledge_graph.neighbors[entity]:
# we only care about table/columns relations here, not foreign-primary
if entity.startswith('column') and neighbor.startswith('column'):
continue
neighbor_index = entity_to_index_map[neighbor]
neighbour_features.append(non_related_features[neighbor_index][token_index][feature_index])
entity_token_features.append(max(neighbour_features))
return linking_features
================================================
FILE: dataset_readers/spider.py
================================================
import json
import logging
import os
from typing import List, Dict
import dill
from allennlp.common.checks import ConfigurationError
from allennlp.data import DatasetReader, Tokenizer, TokenIndexer, Field, Instance
from allennlp.data.fields import TextField, ProductionRuleField, ListField, IndexField, MetadataField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import WordTokenizer
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from overrides import overrides
from spacy.symbols import ORTH, LEMMA
from dataset_readers.dataset_util.spider_utils import fix_number_value, disambiguate_items
from dataset_readers.fields.knowledge_graph_field import SpiderKnowledgeGraphField
from semparse.contexts.spider_db_context import SpiderDBContext
from semparse.worlds.spider_world import SpiderWorld
logger = logging.getLogger(__name__)
@DatasetReader.register("spider")
class SpiderDatasetReader(DatasetReader):
def __init__(self,
lazy: bool = False,
question_token_indexers: Dict[str, TokenIndexer] = None,
keep_if_unparsable: bool = True,
tables_file: str = None,
dataset_path: str = 'dataset/database',
load_cache: bool = True,
save_cache: bool = True,
loading_limit = -1):
super().__init__(lazy=lazy)
# default spacy tokenizer splits the common token 'id' to ['i', 'd'], we here write a manual fix for that
spacy_tokenizer = SpacyWordSplitter(pos_tags=True)
spacy_tokenizer.spacy.tokenizer.add_special_case(u'id', [{ORTH: u'id', LEMMA: u'id'}])
self._tokenizer = WordTokenizer(spacy_tokenizer)
self._utterance_token_indexers = question_token_indexers or {'tokens': SingleIdTokenIndexer()}
self._keep_if_unparsable = keep_if_unparsable
self._tables_file = tables_file
self._dataset_path = dataset_path
self._load_cache = load_cache
self._save_cache = save_cache
self._loading_limit = loading_limit
@overrides
def _read(self, file_path: str):
if not file_path.endswith('.json'):
raise ConfigurationError(f"Don't know how to read filetype of {file_path}")
cache_dir = os.path.join('cache', file_path.split("/")[-1])
if self._load_cache:
logger.info(f'Trying to load cache from {cache_dir}')
if self._save_cache:
os.makedirs(cache_dir, exist_ok=True)
cnt = 0
with open(file_path, "r") as data_file:
json_obj = json.load(data_file)
for total_cnt, ex in enumerate(json_obj):
cache_filename = f'instance-{total_cnt}.pt'
cache_filepath = os.path.join(cache_dir, cache_filename)
if self._loading_limit == cnt:
break
if self._load_cache:
try:
ins = dill.load(open(cache_filepath, 'rb'))
if ins is None and not self._keep_if_unparsable:
# skip unparsed examples
continue
yield ins
cnt += 1
continue
except Exception as e:
# could not load from cache - keep loading without cache
pass
query_tokens = None
if 'query_toks' in ex:
# we only have 'query_toks' in example for training/dev sets
# fix for examples: we want to use the 'query_toks_no_value' field of the example which anonymizes
# values. However, it also anonymizes numbers (e.g. LIMIT 3 -> LIMIT 'value', which is not good
# since the official evaluator does expect a number and not a value
ex = fix_number_value(ex)
# we want the query tokens to be non-ambiguous (i.e. know for each column the table it belongs to,
# and for each table alias its explicit name)
# we thus remove all aliases and make changes such as:
# 'name' -> 'singer@name',
# 'singer AS T1' -> 'singer',
# 'T1.name' -> 'singer@name'
try:
query_tokens = disambiguate_items(ex['db_id'], ex['query_toks_no_value'],
self._tables_file, allow_aliases=False)
except Exception as e:
# there are two examples in the train set that are wrongly formatted, skip them
print(f"error with {ex['query']}")
print(e)
ins = self.text_to_instance(
utterance=ex['question'],
db_id=ex['db_id'],
sql=query_tokens)
if ins is not None:
cnt += 1
if self._save_cache:
dill.dump(ins, open(cache_filepath, 'wb'))
if ins is not None:
yield ins
def text_to_instance(self,
utterance: str,
db_id: str,
sql: List[str] = None):
fields: Dict[str, Field] = {}
db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer,
tables_file=self._tables_file, dataset_path=self._dataset_path)
table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,
db_context.tokenized_utterance,
self._utterance_token_indexers,
entity_tokens=db_context.entity_tokens,
include_in_vocab=False, # TODO: self._use_table_for_vocab,
max_table_tokens=None) # self._max_table_tokens)
world = SpiderWorld(db_context, query=sql)
fields["utterance"] = TextField(db_context.tokenized_utterance, self._utterance_token_indexers)
action_sequence, all_actions = world.get_action_sequence_and_all_actions()
if action_sequence is None and self._keep_if_unparsable:
# print("Parse error")
action_sequence = []
elif action_sequence is None:
return None
index_fields: List[Field] = []
production_rule_fields: List[Field] = []
for production_rule in all_actions:
nonterminal, rhs = production_rule.split(' -> ')
production_rule = ' '.join(production_rule.split(' '))
field = ProductionRuleField(production_rule,
world.is_global_rule(rhs),
nonterminal=nonterminal)
production_rule_fields.append(field)
valid_actions_field = ListField(production_rule_fields)
fields["valid_actions"] = valid_actions_field
action_map = {action.rule: i # type: ignore
for i, action in enumerate(valid_actions_field.field_list)}
for production_rule in action_sequence:
index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
if not action_sequence:
index_fields = [IndexField(-1, valid_actions_field)]
action_sequence_field = ListField(index_fields)
fields["action_sequence"] = action_sequence_field
fields["world"] = MetadataField(world)
fields["schema"] = table_field
return Instance(fields)
================================================
FILE: models/__init__.py
================================================
================================================
FILE: models/semantic_parsing/__init__.py
================================================
from models.semantic_parsing.spider_parser import SpiderParser
================================================
FILE: models/semantic_parsing/spider_parser.py
================================================
import difflib
import os
from functools import partial
from typing import Dict, List, Tuple, Any, Mapping, Sequence
import sqlparse
import torch
from allennlp.common.util import pad_sequence_to_length
from allennlp.data import Vocabulary
from allennlp.data.fields.production_rule_field import ProductionRule, ProductionRuleArray
from allennlp.models import Model
from allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder, Seq2VecEncoder, Embedding, Attention, FeedForward, \
TimeDistributed
from allennlp.modules.seq2vec_encoders import BagOfEmbeddingsEncoder
from allennlp.nn import util, Activation
from allennlp.state_machines import BeamSearch
from allennlp.state_machines.states import GrammarStatelet
from torch_geometric.data import Data, Batch
from modules.gated_graph_conv import GatedGraphConv
from semparse.worlds.evaluate_spider import evaluate
from state_machines.states.rnn_statelet import RnnStatelet
from allennlp.state_machines.trainers import MaximumMarginalLikelihood
from allennlp.training.metrics import Average
from overrides import overrides
from semparse.contexts.spider_context_utils import action_sequence_to_sql
from semparse.worlds.spider_world import SpiderWorld
from state_machines.states.grammar_based_state import GrammarBasedState
from state_machines.states.sql_state import SqlState
from state_machines.transition_functions.attend_past_schema_items_transition import \
AttendPastSchemaItemsTransitionFunction
from state_machines.transition_functions.linking_transition_function import LinkingTransitionFunction
@Model.register("spider")
class SpiderParser(Model):
def __init__(self,
vocab: Vocabulary,
encoder: Seq2SeqEncoder,
entity_encoder: Seq2VecEncoder,
decoder_beam_search: BeamSearch,
question_embedder: TextFieldEmbedder,
input_attention: Attention,
past_attention: Attention,
max_decoding_steps: int,
action_embedding_dim: int,
gnn: bool = True,
decoder_use_graph_entities: bool = True,
decoder_self_attend: bool = True,
gnn_timesteps: int = 2,
parse_sql_on_decoding: bool = True,
add_action_bias: bool = True,
use_neighbor_similarity_for_linking: bool = True,
dataset_path: str = 'dataset',
training_beam_size: int = None,
decoder_num_layers: int = 1,
dropout: float = 0.0,
rule_namespace: str = 'rule_labels',
scoring_dev_params: dict = None,
debug_parsing: bool = False) -> None:
super().__init__(vocab)
self.vocab = vocab
self._encoder = encoder
self._max_decoding_steps = max_decoding_steps
if dropout > 0:
self._dropout = torch.nn.Dropout(p=dropout)
else:
self._dropout = lambda x: x
self._rule_namespace = rule_namespace
self._question_embedder = question_embedder
self._add_action_bias = add_action_bias
self._scoring_dev_params = scoring_dev_params or {}
self.parse_sql_on_decoding = parse_sql_on_decoding
self._entity_encoder = TimeDistributed(entity_encoder)
self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
self._self_attend = decoder_self_attend
self._decoder_use_graph_entities = decoder_use_graph_entities
self._action_padding_index = -1 # the padding value used by IndexField
self._exact_match = Average()
self._sql_evaluator_match = Average()
self._action_similarity = Average()
self._acc_single = Average()
self._acc_multi = Average()
self._beam_hit = Average()
self._action_embedding_dim = action_embedding_dim
num_actions = vocab.get_vocab_size(self._rule_namespace)
if self._add_action_bias:
input_action_dim = action_embedding_dim + 1
else:
input_action_dim = action_embedding_dim
self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim)
self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim)
encoder_output_dim = encoder.get_output_dim()
if gnn:
encoder_output_dim += action_embedding_dim
self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
self._first_attended_utterance = torch.nn.Parameter(torch.FloatTensor(encoder_output_dim))
self._first_attended_output = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim))
torch.nn.init.normal_(self._first_action_embedding)
torch.nn.init.normal_(self._first_attended_utterance)
torch.nn.init.normal_(self._first_attended_output)
self._num_entity_types = 9
self._embedding_dim = question_embedder.get_output_dim()
self._entity_type_encoder_embedding = Embedding(self._num_entity_types, self._embedding_dim)
self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)
self._linking_params = torch.nn.Linear(16, 1)
torch.nn.init.uniform_(self._linking_params.weight, 0, 1)
num_edge_types = 3
self._gnn = GatedGraphConv(self._embedding_dim, gnn_timesteps, num_edge_types=num_edge_types, dropout=dropout)
self._decoder_num_layers = decoder_num_layers
self._beam_search = decoder_beam_search
self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size)
if decoder_self_attend:
self._transition_function = AttendPastSchemaItemsTransitionFunction(encoder_output_dim=encoder_output_dim,
action_embedding_dim=action_embedding_dim,
input_attention=input_attention,
past_attention=past_attention,
predict_start_type_separately=False,
add_action_bias=self._add_action_bias,
dropout=dropout,
num_layers=self._decoder_num_layers)
else:
self._transition_function = LinkingTransitionFunction(encoder_output_dim=encoder_output_dim,
action_embedding_dim=action_embedding_dim,
input_attention=input_attention,
predict_start_type_separately=False,
add_action_bias=self._add_action_bias,
dropout=dropout,
num_layers=self._decoder_num_layers)
self._ent2ent_ff = FeedForward(action_embedding_dim, 1, action_embedding_dim, Activation.by_name('relu')())
self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim)
# TODO: Remove hard-coded dirs
self._evaluate_func = partial(evaluate,
db_dir=os.path.join(dataset_path, 'database'),
table=os.path.join(dataset_path, 'tables.json'),
check_valid=False)
self.debug_parsing = debug_parsing
@overrides
def forward(self, # type: ignore
utterance: Dict[str, torch.LongTensor],
valid_actions: List[List[ProductionRule]],
world: List[SpiderWorld],
schema: Dict[str, torch.LongTensor],
action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
batch_size = len(world)
device = utterance['tokens'].device
initial_state = self._get_initial_state(utterance, world, schema, valid_actions)
if action_sequence is not None:
# Remove the trailing dimension (from ListField[ListField[IndexField]]).
action_sequence = action_sequence.squeeze(-1)
action_mask = action_sequence != self._action_padding_index
else:
action_mask = None
if self.training:
decode_output = self._decoder_trainer.decode(initial_state,
self._transition_function,
(action_sequence.unsqueeze(1), action_mask.unsqueeze(1)))
return {'loss': decode_output['loss']}
else:
loss = torch.tensor([0]).float().to(device)
if action_sequence is not None and action_sequence.size(1) > 1:
try:
loss = self._decoder_trainer.decode(initial_state,
self._transition_function,
(action_sequence.unsqueeze(1),
action_mask.unsqueeze(1)))['loss']
except ZeroDivisionError:
# reached a dead-end during beam search
pass
outputs: Dict[str, Any] = {
'loss': loss
}
num_steps = self._max_decoding_steps
# This tells the state to start keeping track of debug info, which we'll pass along in
# our output dictionary.
initial_state.debug_info = [[] for _ in range(batch_size)]
best_final_states = self._beam_search.search(num_steps,
initial_state,
self._transition_function,
keep_final_unfinished_states=False)
self._compute_validation_outputs(valid_actions,
best_final_states,
world,
action_sequence,
outputs)
return outputs
def _get_initial_state(self,
utterance: Dict[str, torch.LongTensor],
worlds: List[SpiderWorld],
schema: Dict[str, torch.LongTensor],
actions: List[List[ProductionRule]]) -> GrammarBasedState:
schema_text = schema['text']
embedded_schema = self._question_embedder(schema_text, num_wrapping_dims=1)
schema_mask = util.get_text_field_mask(schema_text, num_wrapping_dims=1).float()
embedded_utterance = self._question_embedder(utterance)
utterance_mask = util.get_text_field_mask(utterance).float()
batch_size, num_entities, num_entity_tokens, _ = embedded_schema.size()
num_entities = max([len(world.db_context.knowledge_graph.entities) for world in worlds])
num_question_tokens = utterance['tokens'].size(1)
# entity_types: tensor with shape (batch_size, num_entities), where each entry is the
# entity's type id.
# entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
# These encode the same information, but for efficiency reasons later it's nice
# to have one version as a tensor and one that's accessible on the cpu.
entity_types, entity_type_dict = self._get_type_vector(worlds, num_entities, embedded_schema.device)
entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)
# Compute entity and question word similarity. We tried using cosine distance here, but
# because this similarity is the main mechanism that the model can use to push apart logit
# scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
# output range than [-1, 1].
question_entity_similarity = torch.bmm(embedded_schema.view(batch_size,
num_entities * num_entity_tokens,
self._embedding_dim),
torch.transpose(embedded_utterance, 1, 2))
question_entity_similarity = question_entity_similarity.view(batch_size,
num_entities,
num_entity_tokens,
num_question_tokens)
# (batch_size, num_entities, num_question_tokens)
question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)
# (batch_size, num_entities, num_question_tokens, num_features)
linking_features = schema['linking']
linking_scores = question_entity_similarity_max_score
feature_scores = self._linking_params(linking_features).squeeze(3)
linking_scores = linking_scores + feature_scores
# (batch_size, num_question_tokens, num_entities)
linking_probabilities = self._get_linking_probabilities(worlds, linking_scores.transpose(1, 2),
utterance_mask, entity_type_dict)
# (batch_size, num_entities, num_neighbors) or None
neighbor_indices = self._get_neighbor_indices(worlds, num_entities, linking_scores.device)
if self._use_neighbor_similarity_for_linking and neighbor_indices is not None:
# (batch_size, num_entities, embedding_dim)
encoded_table = self._entity_encoder(embedded_schema, schema_mask)
# Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
# Thus, the absolute value needs to be taken in the index_select, and 1 needs to
# be added for the mask since that method expects 0 for padding.
# (batch_size, num_entities, num_neighbors, embedding_dim)
embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))
neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
num_wrapping_dims=1).float()
# Encoder initialized to easily obtain a masked average.
neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
# (batch_size, num_entities, embedding_dim)
embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)
projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
# (batch_size, num_entities, embedding_dim)
entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)
else:
# (batch_size, num_entities, embedding_dim)
entity_embeddings = torch.tanh(entity_type_embeddings)
link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
encoder_input = torch.cat([link_embedding, embedded_utterance], 2)
# (batch_size, utterance_length, encoder_output_dim)
encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask))
max_entities_relevance = linking_probabilities.max(dim=1)[0]
entities_relevance = max_entities_relevance.unsqueeze(-1).detach()
graph_initial_embedding = entity_type_embeddings * entities_relevance
encoder_output_dim = self._encoder.get_output_dim()
if self._gnn:
entities_graph_encoding = self._get_schema_graph_encoding(worlds,
graph_initial_embedding)
graph_link_embedding = util.weighted_sum(entities_graph_encoding, linking_probabilities)
encoder_outputs = torch.cat((
encoder_outputs,
graph_link_embedding
), dim=-1)
encoder_output_dim = self._action_embedding_dim + self._encoder.get_output_dim()
else:
entities_graph_encoding = None
if self._self_attend:
# linked_actions_linking_scores = self._get_linked_actions_linking_scores(actions, entities_graph_encoding)
entities_ff = self._ent2ent_ff(entities_graph_encoding)
linked_actions_linking_scores = torch.bmm(entities_ff, entities_ff.transpose(1, 2))
else:
linked_actions_linking_scores = [None] * batch_size
# This will be our initial hidden state and memory cell for the decoder LSTM.
final_encoder_output = util.get_final_encoder_states(encoder_outputs,
utterance_mask,
self._encoder.is_bidirectional())
memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
initial_score = embedded_utterance.data.new_zeros(batch_size)
# To make grouping states together in the decoder easier, we convert the batch dimension in
# all of our tensors into an outer list. For instance, the encoder outputs have shape
# `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list
# of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we
# won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
initial_score_list = [initial_score[i] for i in range(batch_size)]
encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
initial_rnn_state = []
for i in range(batch_size):
initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
memory_cell[i],
self._first_action_embedding,
self._first_attended_utterance,
encoder_output_list,
utterance_mask_list))
initial_grammar_state = [self._create_grammar_state(worlds[i],
actions[i],
linking_scores[i],
linked_actions_linking_scores[i],
entity_types[i],
entities_graph_encoding[
i] if entities_graph_encoding is not None else None)
for i in range(batch_size)]
initial_sql_state = [SqlState(actions[i], self.parse_sql_on_decoding) for i in range(batch_size)]
initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
action_history=[[] for _ in range(batch_size)],
score=initial_score_list,
rnn_state=initial_rnn_state,
grammar_state=initial_grammar_state,
sql_state=initial_sql_state,
possible_actions=actions,
action_entity_mapping=[w.get_action_entity_mapping() for w in worlds])
return initial_state
@staticmethod
def _get_neighbor_indices(worlds: List[SpiderWorld],
num_entities: int,
device: torch.device) -> torch.LongTensor:
"""
This method returns the indices of each entity's neighbors. A tensor
is accepted as a parameter for copying purposes.
Parameters
----------
worlds : ``List[SpiderWorld]``
num_entities : ``int``
tensor : ``torch.Tensor``
Used for copying the constructed list onto the right device.
Returns
-------
A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded
with -1 instead of 0, since 0 is a valid neighbor index. If all the entities in the batch
have no neighbors, None will be returned.
"""
num_neighbors = 0
for world in worlds:
for entity in world.db_context.knowledge_graph.entities:
if len(world.db_context.knowledge_graph.neighbors[entity]) > num_neighbors:
num_neighbors = len(world.db_context.knowledge_graph.neighbors[entity])
batch_neighbors = []
no_entities_have_neighbors = True
for world in worlds:
# Each batch instance has its own world, which has a corresponding table.
entities = world.db_context.knowledge_graph.entities
entity2index = {entity: i for i, entity in enumerate(entities)}
entity2neighbors = world.db_context.knowledge_graph.neighbors
neighbor_indexes = []
for entity in entities:
entity_neighbors = [entity2index[n] for n in entity2neighbors[entity]]
if entity_neighbors:
no_entities_have_neighbors = False
# Pad with -1 instead of 0, since 0 represents a neighbor index.
padded = pad_sequence_to_length(entity_neighbors, num_neighbors, lambda: -1)
neighbor_indexes.append(padded)
neighbor_indexes = pad_sequence_to_length(neighbor_indexes,
num_entities,
lambda: [-1] * num_neighbors)
batch_neighbors.append(neighbor_indexes)
# It is possible that none of the entities has any neighbors, since our definition of the
# knowledge graph allows it when no entities or numbers were extracted from the question.
if no_entities_have_neighbors:
return None
return torch.tensor(batch_neighbors, device=device, dtype=torch.long)
def _get_schema_graph_encoding(self,
worlds: List[SpiderWorld],
initial_graph_embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
max_num_entities = max([len(world.db_context.knowledge_graph.entities) for world in worlds])
batch_size = initial_graph_embeddings.size(0)
graph_data_list = []
for batch_index, world in enumerate(worlds):
x = initial_graph_embeddings[batch_index]
adj_list = self._get_graph_adj_lists(initial_graph_embeddings.device,
world, initial_graph_embeddings.size(1) - 1)
graph_data = Data(x)
for i, l in enumerate(adj_list):
graph_data[f'edge_index_{i}'] = l
graph_data_list.append(graph_data)
batch = Batch.from_data_list(graph_data_list)
gnn_output = self._gnn(batch.x, [batch[f'edge_index_{i}'] for i in range(self._gnn.num_edge_types)])
num_nodes = max_num_entities
gnn_output = gnn_output.view(batch_size, num_nodes, -1)
# entities_encodings = gnn_output
entities_encodings = gnn_output[:, :max_num_entities]
# global_node_encodings = gnn_output[:, max_num_entities]
return entities_encodings
@staticmethod
def _get_graph_adj_lists(device, world, global_entity_id, global_node=False):
entity_mapping = {}
for i, entity in enumerate(world.db_context.knowledge_graph.entities):
entity_mapping[entity] = i
entity_mapping['_global_'] = global_entity_id
adj_list_own = [] # column--table
adj_list_link = [] # table->table / foreign->primary
adj_list_linked = [] # table<-table / foreign<-primary
adj_list_global = [] # node->global
# TODO: Prepare in advance?
for key, neighbors in world.db_context.knowledge_graph.neighbors.items():
idx_source = entity_mapping[key]
for n_key in neighbors:
idx_target = entity_mapping[n_key]
if n_key.startswith("table") or key.startswith("table"):
adj_list_own.append((idx_source, idx_target))
elif n_key.startswith("string") or key.startswith("string"):
adj_list_own.append((idx_source, idx_target))
elif key.startswith("column:foreign"):
adj_list_link.append((idx_source, idx_target))
src_table_key = f"table:{key.split(':')[2]}"
tgt_table_key = f"table:{n_key.split(':')[2]}"
idx_source_table = entity_mapping[src_table_key]
idx_target_table = entity_mapping[tgt_table_key]
adj_list_link.append((idx_source_table, idx_target_table))
elif n_key.startswith("column:foreign"):
adj_list_linked.append((idx_source, idx_target))
src_table_key = f"table:{key.split(':')[2]}"
tgt_table_key = f"table:{n_key.split(':')[2]}"
idx_source_table = entity_mapping[src_table_key]
idx_target_table = entity_mapping[tgt_table_key]
adj_list_linked.append((idx_source_table, idx_target_table))
else:
assert False
adj_list_global.append((idx_source, entity_mapping['_global_']))
all_adj_types = [adj_list_own, adj_list_link, adj_list_linked]
if global_node:
all_adj_types.append(adj_list_global)
return [torch.tensor(l, device=device, dtype=torch.long).transpose(0, 1) if l
else torch.tensor(l, device=device, dtype=torch.long)
for l in all_adj_types]
def _create_grammar_state(self,
world: SpiderWorld,
possible_actions: List[ProductionRule],
linking_scores: torch.Tensor,
linked_actions_linking_scores: torch.Tensor,
entity_types: torch.Tensor,
entity_graph_encoding: torch.Tensor) -> GrammarStatelet:
action_map = {}
for action_index, action in enumerate(possible_actions):
action_string = action[0]
action_map[action_string] = action_index
valid_actions = world.valid_actions
entity_map = {}
entities = world.entities_names
for entity_index, entity in enumerate(entities):
entity_map[entity] = entity_index
translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
for key, action_strings in valid_actions.items():
translated_valid_actions[key] = {}
# `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
# productions of that non-terminal. We'll first split those productions by global vs.
# linked action.
action_indices = [action_map[action_string] for action_string in action_strings]
production_rule_arrays = [(possible_actions[index], index) for index in action_indices]
global_actions = []
linked_actions = []
for production_rule_array, action_index in production_rule_arrays:
if production_rule_array[1]:
global_actions.append((production_rule_array[2], action_index))
else:
linked_actions.append((production_rule_array[0], action_index))
if global_actions:
global_action_tensors, global_action_ids = zip(*global_actions)
global_action_tensor = torch.cat(global_action_tensors, dim=0).to(
global_action_tensors[0].device).long()
global_input_embeddings = self._action_embedder(global_action_tensor)
global_output_embeddings = self._output_action_embedder(global_action_tensor)
translated_valid_actions[key]['global'] = (global_input_embeddings,
global_output_embeddings,
list(global_action_ids))
if linked_actions:
linked_rules, linked_action_ids = zip(*linked_actions)
entities = [rule.split(' -> ')[1].strip('[]\"') for rule in linked_rules]
entity_ids = [entity_map[entity] for entity in entities]
entity_linking_scores = linking_scores[entity_ids]
if linked_actions_linking_scores is not None:
entity_action_linking_scores = linked_actions_linking_scores[entity_ids]
if not self._decoder_use_graph_entities:
entity_type_tensor = entity_types[entity_ids]
entity_type_embeddings = (self._entity_type_decoder_embedding(entity_type_tensor)
.to(entity_types.device)
.float())
else:
entity_type_embeddings = entity_graph_encoding.index_select(
dim=0,
index=torch.tensor(entity_ids, device=entity_graph_encoding.device)
)
if self._self_attend:
translated_valid_actions[key]['linked'] = (entity_linking_scores,
entity_type_embeddings,
list(linked_action_ids),
entity_action_linking_scores)
else:
translated_valid_actions[key]['linked'] = (entity_linking_scores,
entity_type_embeddings,
list(linked_action_ids))
return GrammarStatelet(['statement'],
translated_valid_actions,
self.is_nonterminal)
@staticmethod
def is_nonterminal(token: str):
if token[0] == '"' and token[-1] == '"':
return False
return True
def _get_linking_probabilities(self,
worlds: List[SpiderWorld],
linking_scores: torch.FloatTensor,
question_mask: torch.LongTensor,
entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
"""
Produces the probability of an entity given a question word and type. The logic below
separates the entities by type since the softmax normalization term sums over entities
of a single type.
Parameters
----------
worlds : ``List[WikiTablesWorld]``
linking_scores : ``torch.FloatTensor``
Has shape (batch_size, num_question_tokens, num_entities).
question_mask: ``torch.LongTensor``
Has shape (batch_size, num_question_tokens).
entity_type_dict : ``Dict[int, int]``
This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
Returns
-------
batch_probabilities : ``torch.FloatTensor``
Has shape ``(batch_size, num_question_tokens, num_entities)``.
Contains all the probabilities for an entity given a question word.
"""
_, num_question_tokens, num_entities = linking_scores.size()
batch_probabilities = []
for batch_index, world in enumerate(worlds):
all_probabilities = []
num_entities_in_instance = 0
# NOTE: The way that we're doing this here relies on the fact that entities are
# implicitly sorted by their types when we sort them by name, and that numbers come
# before "date_column:", followed by "number_column:", "string:", and "string_column:".
# This is not a great assumption, and could easily break later, but it should work for now.
for type_index in range(self._num_entity_types):
# This index of 0 is for the null entity for each type, representing the case where a
# word doesn't link to any entity.
entity_indices = [0]
entities = world.db_context.knowledge_graph.entities
for entity_index, _ in enumerate(entities):
if entity_type_dict[batch_index * num_entities + entity_index] == type_index:
entity_indices.append(entity_index)
if len(entity_indices) == 1:
# No entities of this type; move along...
continue
# We're subtracting one here because of the null entity we added above.
num_entities_in_instance += len(entity_indices) - 1
# We separate the scores by type, since normalization is done per type. There's an
# extra "null" entity per type, also, so we have `num_entities_per_type + 1`. We're
# selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
# so we get back something of shape (num_question_tokens,) for each index we're
# selecting. All of the selected indices together then make a tensor of shape
# (num_question_tokens, num_entities_per_type + 1).
indices = linking_scores.new_tensor(entity_indices, dtype=torch.long)
entity_scores = linking_scores[batch_index].index_select(1, indices)
# We used index 0 for the null entity, so this will actually have some values in it.
# But we want the null entity's score to be 0, so we set that here.
entity_scores[:, 0] = 0
# No need for a mask here, as this is done per batch instance, with no padding.
type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1)
all_probabilities.append(type_probabilities[:, 1:])
# We need to add padding here if we don't have the right number of entities.
if num_entities_in_instance != num_entities:
zeros = linking_scores.new_zeros(num_question_tokens,
num_entities - num_entities_in_instance)
all_probabilities.append(zeros)
# (num_question_tokens, num_entities)
probabilities = torch.cat(all_probabilities, dim=1)
batch_probabilities.append(probabilities)
batch_probabilities = torch.stack(batch_probabilities, dim=0)
return batch_probabilities * question_mask.unsqueeze(-1).float()
@staticmethod
def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
# TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
# Check if target is big enough to cover prediction (including start/end symbols)
if len(predicted) > targets.size(0):
return 0
predicted_tensor = targets.new_tensor(predicted)
targets_trimmed = targets[:len(predicted)]
# Return 1 if the predicted sequence is anywhere in the list of targets.
return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=0)[0]).item()
@staticmethod
def _query_difficulty(targets: torch.LongTensor, action_mapping, batch_index):
number_tables = len([action_mapping[(batch_index, int(a))] for a in targets if
a >= 0 and action_mapping[(batch_index, int(a))].startswith('table_name')])
return number_tables > 1
@overrides
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
return {
'_match/exact_match': self._exact_match.get_metric(reset),
'sql_match': self._sql_evaluator_match.get_metric(reset),
'_others/action_similarity': self._action_similarity.get_metric(reset),
'_match/match_single': self._acc_single.get_metric(reset),
'_match/match_hard': self._acc_multi.get_metric(reset),
'beam_hit': self._beam_hit.get_metric(reset)
}
@staticmethod
def _get_type_vector(worlds: List[SpiderWorld],
num_entities: int,
device) -> Tuple[torch.LongTensor, Dict[int, int]]:
"""
Produces the encoding for each entity's type. In addition, a map from a flattened entity
index to type is returned to combine entity type operations into one method.
Parameters
----------
worlds : ``List[AtisWorld]``
num_entities : ``int``
tensor : ``torch.Tensor``
Used for copying the constructed list onto the right device.
Returns
-------
A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_types)``.
entity_types : ``Dict[int, int]``
This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
"""
entity_types = {}
batch_types = []
column_type_ids = ['boolean', 'foreign', 'number', 'others', 'primary', 'text', 'time']
for batch_index, world in enumerate(worlds):
types = []
for entity_index, entity in enumerate(world.db_context.knowledge_graph.entities):
parts = entity.split(':')
entity_main_type = parts[0]
if entity_main_type == 'column':
column_type = parts[1]
entity_type = column_type_ids.index(column_type)
elif entity_main_type == 'string':
# cell value
entity_type = len(column_type_ids)
elif entity_main_type == 'table':
entity_type = len(column_type_ids) + 1
else:
raise (Exception("Unkown entity"))
types.append(entity_type)
# For easier lookups later, we're actually using a _flattened_ version
# of (batch_index, entity_index) for the key, because this is how the
# linking scores are stored.
flattened_entity_index = batch_index * num_entities + entity_index
entity_types[flattened_entity_index] = entity_type
padded = pad_sequence_to_length(types, num_entities, lambda: 0)
batch_types.append(padded)
return torch.tensor(batch_types, dtype=torch.long, device=device), entity_types
def _compute_validation_outputs(self,
actions: List[List[ProductionRuleArray]],
best_final_states: Mapping[int, Sequence[GrammarBasedState]],
world: List[SpiderWorld],
target_list: List[List[str]],
outputs: Dict[str, Any]) -> None:
batch_size = len(actions)
outputs['predicted_sql_query'] = []
action_mapping = {}
for batch_index, batch_actions in enumerate(actions):
for action_index, action in enumerate(batch_actions):
action_mapping[(batch_index, action_index)] = action[0]
for i in range(batch_size):
# gold sql exactly as given
original_gold_sql_query = ' '.join(world[i].get_query_without_table_hints())
if i not in best_final_states:
self._exact_match(0)
self._action_similarity(0)
self._sql_evaluator_match(0)
self._acc_multi(0)
self._acc_single(0)
outputs['predicted_sql_query'].append('')
continue
best_action_indices = best_final_states[i][0].action_history[0]
action_strings = [action_mapping[(i, action_index)]
for action_index in best_action_indices]
predicted_sql_query = action_sequence_to_sql(action_strings, add_table_names=True)
outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=False))
if target_list is not None:
targets = target_list[i].data
target_available = target_list is not None and targets[0] > -1
if target_available:
sequence_in_targets = self._action_history_match(best_action_indices, targets)
self._exact_match(sequence_in_targets)
sql_evaluator_match = self._evaluate_func(original_gold_sql_query, predicted_sql_query, world[i].db_id)
self._sql_evaluator_match(sql_evaluator_match)
similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
self._action_similarity(similarity.ratio())
difficulty = self._query_difficulty(targets, action_mapping, i)
if difficulty:
self._acc_multi(sql_evaluator_match)
else:
self._acc_single(sql_evaluator_match)
beam_hit = False
for pos, final_state in enumerate(best_final_states[i]):
action_indices = final_state.action_history[0]
action_strings = [action_mapping[(i, action_index)]
for action_index in action_indices]
candidate_sql_query = action_sequence_to_sql(action_strings, add_table_names=True)
if target_available:
correct = self._evaluate_func(original_gold_sql_query, candidate_sql_query, world[i].db_id)
if correct:
beam_hit = True
self._beam_hit(beam_hit)
================================================
FILE: modules/gated_graph_conv.py
================================================
import math
import torch
from torch import Tensor
from torch.nn import Parameter as Param, init
from torch_geometric.data import Data
from torch_geometric.nn.conv import MessagePassing
class GatedGraphConv(MessagePassing):
r"""The gated graph convolution operator from the `"Gated Graph Sequence
Neural Networks" <https://arxiv.org/abs/1511.05493>`_ paper
.. math::
\mathbf{h}_i^{(0)} &= \mathbf{x}_i \, \Vert \, \mathbf{0}
\mathbf{m}_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} \mathbf{\Theta}
\cdot \mathbf{h}_j^{(l)}
\mathbf{h}_i^{(l+1)} &= \textrm{GRU} (\mathbf{m}_i^{(l+1)},
\mathbf{h}_i^{(l)})
up to representation :math:`\mathbf{h}_i^{(L)}`.
The number of input channels of :math:`\mathbf{x}_i` needs to be less or
equal than :obj:`out_channels`.
Args:
out_channels (int): Size of each input sample.
num_layers (int): The sequence length :math:`L`.
aggr (string): The aggregation scheme to use
(:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
(default: :obj:`"add"`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
"""
def __init__(self, input_dim, num_timesteps, num_edge_types, aggr='add', bias=True, dropout=0):
super(GatedGraphConv, self).__init__(aggr)
self._input_dim = input_dim
self.num_timesteps = num_timesteps
self.num_edge_types = num_edge_types
self.weight = Param(Tensor(num_timesteps, num_edge_types, input_dim, input_dim))
self.bias = Param(Tensor(num_timesteps, num_edge_types, input_dim))
self.rnn = torch.nn.GRUCell(input_dim, input_dim, bias=bias)
self.dropout = torch.nn.Dropout(dropout)
self.reset_parameters()
def reset_parameters(self):
for t in range(self.num_timesteps):
for e in range(self.num_edge_types):
torch.nn.init.xavier_uniform_(self.weight[t, e])
init.uniform_(self.bias, -0.01, 0.01)
self.rnn.reset_parameters()
def forward(self, x, edge_indices):
""""""
if len(edge_indices) != self.num_edge_types:
raise ValueError(f'GatedGraphConv constructed with {self.num_edge_types} edge types, '
f'but {len(edge_indices)} were passed')
h = x
if h.size(1) > self._input_dim:
raise ValueError('The number of input channels is not allowed to '
'be larger than the number of output channels')
if h.size(1) < self._input_dim:
zero = h.new_zeros(h.size(0), self._input_dim - h.size(1))
h = torch.cat([h, zero], dim=1)
for t in range(self.num_timesteps):
new_h = []
for e in range(self.num_edge_types):
if len(edge_indices[e]) == 0:
continue
m = self.dropout(torch.matmul(h, self.weight[t, e]) + self.bias[t, e])
new_h.append(self.propagate(edge_indices[e], size=(x.size(0), x.size(0)), x=m))
m_sum = torch.sum(torch.stack(new_h), dim=0)
h = self.rnn(m_sum, h)
return h
def __repr__(self):
return '{}({}, num_layers={})'.format(
self.__class__.__name__, self._input_dim, self.num_timesteps)
if __name__ == '__main__':
gcn = GatedGraphConv(input_dim=10, num_timesteps=3, num_edge_types=3)
data = Data(torch.zeros((5, 10)), edge_index=[
torch.tensor([[1,2],[2,3]]),
torch.tensor([[1,3],[0,1]]),
torch.tensor([[1,4],[2,3]]),
])
output = gcn(data.x, data.edge_index)
print(output)
================================================
FILE: predictors/spider_predictor.py
================================================
from overrides import overrides
from allennlp.common.util import JsonDict, sanitize
from allennlp.data import DatasetReader, Instance
from allennlp.models import Model
from allennlp.predictors.predictor import Predictor
@Predictor.register("spider")
class WikiTablesParserPredictor(Predictor):
def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
super().__init__(model, dataset_reader)
@overrides
def predict_instance(self, instance: Instance) -> JsonDict:
json_output = {}
outputs = self._model.forward_on_instance(instance)
predicted_sql_query = outputs['predicted_sql_query'].replace('\n', ' ')
if predicted_sql_query == '':
# line must not be empty for the evaluator to consider it
predicted_sql_query = 'NO PREDICTION'
json_output['predicted_sql_query'] = predicted_sql_query
return sanitize(json_output)
@overrides
def dump_line(self, outputs: JsonDict) -> str: # pylint: disable=no-self-use
"""
If you don't want your outputs in JSON-lines format
you can override this function to output them differently.
"""
return outputs['predicted_sql_query'] + "\n"
================================================
FILE: requirements.txt
================================================
torch==1.0.1.post2
spacy==2.0.18
allennlp==0.8.2
dill
torch-scatter==1.1.2
torch-sparse==0.2.4
torch-cluster==1.2.4
torch-geometric==1.1.2
ordered_set
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.0.0/en_core_web_sm-2.0.0.tar.gz
================================================
FILE: semparse/contexts/spider_context_utils.py
================================================
import re
from collections import defaultdict
from sys import exc_info
from typing import List, Dict, Set
from overrides import overrides
from parsimonious.exceptions import VisitationError, UndefinedLabel
from parsimonious.expressions import Literal, OneOf, Sequence
from parsimonious.grammar import Grammar
from parsimonious.nodes import Node, NodeVisitor
from six import reraise
WHITESPACE_REGEX = re.compile(" wsp |wsp | wsp| ws |ws | ws")
def format_grammar_string(grammar_dictionary: Dict[str, List[str]]) -> str:
"""
Formats a dictionary of production rules into the string format expected
by the Parsimonious Grammar class.
"""
return '\n'.join([f"{nonterminal} = {' / '.join(right_hand_side)}"
for nonterminal, right_hand_side in grammar_dictionary.items()])
def initialize_valid_actions(grammar: Grammar,
keywords_to_uppercase: List[str] = None) -> Dict[str, List[str]]:
"""
We initialize the valid actions with the global actions. These include the
valid actions that result from the grammar and also those that result from
the tables provided. The keys represent the nonterminals in the grammar
and the values are lists of the valid actions of that nonterminal.
"""
valid_actions: Dict[str, Set[str]] = defaultdict(set)
for key in grammar:
rhs = grammar[key]
# Sequence represents a series of expressions that match pieces of the text in order.
# Eg. A -> B C
if isinstance(rhs, Sequence):
valid_actions[key].add(
format_action(key, " ".join(rhs._unicode_members()), # pylint: disable=protected-access
keywords_to_uppercase=keywords_to_uppercase))
# OneOf represents a series of expressions, one of which matches the text.
# Eg. A -> B / C
elif isinstance(rhs, OneOf):
for option in rhs._unicode_members(): # pylint: disable=protected-access
valid_actions[key].add(format_action(key, option,
keywords_to_uppercase=keywords_to_uppercase))
# A string literal, eg. "A"
elif isinstance(rhs, Literal):
if rhs.literal != "":
valid_actions[key].add(format_action(key, repr(rhs.literal),
keywords_to_uppercase=keywords_to_uppercase))
else:
valid_actions[key] = set()
valid_action_strings = {key: sorted(value) for key, value in valid_actions.items()}
return valid_action_strings
def format_action(nonterminal: str,
right_hand_side: str,
is_string: bool = False,
is_number: bool = False,
keywords_to_uppercase: List[str] = None) -> str:
"""
This function formats an action as it appears in models. It
splits productions based on the special `ws` and `wsp` rules,
which are used in grammars to denote whitespace, and then
rejoins these tokens a formatted, comma separated list.
Importantly, note that it `does not` split on spaces in
the grammar string, because these might not correspond
to spaces in the language the grammar recognises.
Parameters
----------
nonterminal : ``str``, required.
The nonterminal in the action.
right_hand_side : ``str``, required.
The right hand side of the action
(i.e the thing which is produced).
is_string : ``bool``, optional (default = False).
Whether the production produces a string.
If it does, it is formatted as ``nonterminal -> ['string']``
is_number : ``bool``, optional, (default = False).
Whether the production produces a string.
If it does, it is formatted as ``nonterminal -> ['number']``
keywords_to_uppercase: ``List[str]``, optional, (default = None)
Keywords in the grammar to uppercase. In the case of sql,
this might be SELECT, MAX etc.
"""
keywords_to_uppercase = keywords_to_uppercase or []
if right_hand_side.upper() in keywords_to_uppercase:
right_hand_side = right_hand_side.upper()
if is_string:
return f'{nonterminal} -> ["\'{right_hand_side}\'"]'
elif is_number:
return f'{nonterminal} -> ["{right_hand_side}"]'
else:
right_hand_side = right_hand_side.lstrip("(").rstrip(")")
child_strings = [token for token in WHITESPACE_REGEX.split(right_hand_side) if token]
child_strings = [tok.upper() if tok.upper() in keywords_to_uppercase else tok for tok in child_strings]
return f"{nonterminal} -> [{', '.join(child_strings)}]"
def action_sequence_to_sql(action_sequences: List[str], add_table_names: bool=False) -> str:
# Convert an action sequence like ['statement -> [query, ";"]', ...] to the
# SQL string.
query = []
for action in action_sequences:
nonterminal, right_hand_side = action.split(' -> ')
right_hand_side_tokens = right_hand_side[1:-1].split(', ')
if nonterminal == 'statement':
query.extend(right_hand_side_tokens)
else:
for query_index, token in list(enumerate(query)):
if token == nonterminal:
if nonterminal == 'column_name' and '@' in right_hand_side_tokens[0] and len(right_hand_side_tokens) == 1:
if add_table_names:
table_name, column_name = right_hand_side_tokens[0].split('@')
if '.' in table_name:
table_name = table_name.split('.')[0]
right_hand_side_tokens = [table_name + '.' + column_name]
else:
right_hand_side_tokens = [right_hand_side_tokens[0].split('@')[-1]]
query = query[:query_index] + \
right_hand_side_tokens + \
query[query_index + 1:]
break
return ' '.join([token.strip('"') for token in query])
class SqlVisitor(NodeVisitor):
"""
``SqlVisitor`` performs a depth-first traversal of the the AST. It takes the parse tree
and gives us an action sequence that resulted in that parse. Since the visitor has mutable
state, we define a new ``SqlVisitor`` for each query. To get the action sequence, we create
a ``SqlVisitor`` and call parse on it, which returns a list of actions. Ex.
sql_visitor = SqlVisitor(grammar_string)
action_sequence = sql_visitor.parse(query)
Importantly, this ``SqlVisitor`` skips over ``ws`` and ``wsp`` nodes,
because they do not hold any meaning, and make an action sequence
much longer than it needs to be.
Parameters
----------
grammar : ``Grammar``
A Grammar object that we use to parse the text.
keywords_to_uppercase: ``List[str]``, optional, (default = None)
Keywords in the grammar to uppercase. In the case of sql,
this might be SELECT, MAX etc.
"""
def __init__(self, grammar: Grammar, keywords_to_uppercase: List[str] = None) -> None:
self.action_sequence: List[str] = []
self.grammar: Grammar = grammar
self.keywords_to_uppercase = keywords_to_uppercase or []
@overrides
def generic_visit(self, node: Node, visited_children: List[None]) -> List[str]:
self.add_action(node)
if node.expr.name == 'statement':
return self.action_sequence
return []
def add_action(self, node: Node) -> None:
"""
For each node, we accumulate the rules that generated its children in a list.
"""
if node.expr.name and node.expr.name not in ['ws', 'wsp']:
nonterminal = f'{node.expr.name} -> '
if isinstance(node.expr, Literal):
right_hand_side = f'["{node.text}"]'
else:
child_strings = []
for child in node.__iter__():
if child.expr.name in ['ws', 'wsp']:
continue
if child.expr.name != '':
child_strings.append(child.expr.name)
else:
child_right_side_string = child.expr._as_rhs().lstrip("(").rstrip(
")") # pylint: disable=protected-access
child_right_side_list = [tok for tok in
WHITESPACE_REGEX.split(child_right_side_string) if tok]
child_right_side_list = [tok.upper() if tok.upper() in
self.keywords_to_uppercase else tok
for tok in child_right_side_list]
child_strings.extend(child_right_side_list)
right_hand_side = "[" + ", ".join(child_strings) + "]"
rule = nonterminal + right_hand_side
self.action_sequence = [rule] + self.action_sequence
@overrides
def visit(self, node):
"""
See the ``NodeVisitor`` visit method. This just changes the order in which
we visit nonterminals from right to left to left to right.
"""
method = getattr(self, 'visit_' + node.expr_name, self.generic_visit)
# Call that method, and show where in the tree it failed if it blows
# up.
try:
# Changing this to reverse here!
return method(node, [self.visit(child) for child in reversed(list(node))])
except (VisitationError, UndefinedLabel):
# Don't catch and re-wrap already-wrapped exceptions.
raise
except self.unwrapped_exceptions:
raise
except Exception: # pylint: disable=broad-except
# Catch any exception, and tack on a parse tree so it's easier to
# see where it went wrong.
exc_class, exc, traceback = exc_info()
reraise(VisitationError, VisitationError(exc, exc_class, node), traceback)
================================================
FILE: semparse/contexts/spider_db_context.py
================================================
import re
from collections import Set, defaultdict
from typing import Dict, Tuple, List
from allennlp.data import Tokenizer, Token
from ordered_set import OrderedSet
from unidecode import unidecode
from dataset_readers.dataset_util.spider_utils import TableColumn, read_dataset_schema, read_dataset_values
from allennlp.semparse.contexts.knowledge_graph import KnowledgeGraph
# == stop words that will be omitted by ContextGenerator
STOP_WORDS = {"", "", "all", "being", "-", "over", "through", "yourselves", "its", "before",
"hadn", "with", "had", ",", "should", "to", "only", "under", "ours", "has", "ought", "do",
"them", "his", "than", "very", "cannot", "they", "not", "during", "yourself", "him",
"nor", "did", "didn", "'ve", "this", "she", "each", "where", "because", "doing", "some", "we", "are",
"further", "ourselves", "out", "what", "for", "weren", "does", "above", "between", "mustn", "?",
"be", "hasn", "who", "were", "here", "shouldn", "let", "hers", "by", "both", "about", "couldn",
"of", "could", "against", "isn", "or", "own", "into", "while", "whom", "down", "wasn", "your",
"from", "her", "their", "aren", "there", "been", ".", "few", "too", "wouldn", "themselves",
":", "was", "until", "more", "himself", "on", "but", "don", "herself", "haven", "those", "he",
"me", "myself", "these", "up", ";", "below", "'re", "can", "theirs", "my", "and", "would", "then",
"is", "am", "it", "doesn", "an", "as", "itself", "at", "have", "in", "any", "if", "!",
"again", "'ll", "no", "that", "when", "same", "how", "other", "which", "you", "many", "shan",
"'t", "'s", "our", "after", "most", "'d", "such", "'m", "why", "a", "off", "i", "yours", "so",
"the", "having", "once"}
class SpiderDBContext:
schemas = {}
db_knowledge_graph = {}
db_tables_data = {}
def __init__(self, db_id: str, utterance: str, tokenizer: Tokenizer, tables_file: str, dataset_path: str):
self.dataset_path = dataset_path
self.tables_file = tables_file
self.db_id = db_id
self.utterance = utterance
tokenized_utterance = tokenizer.tokenize(utterance.lower())
self.tokenized_utterance = [Token(text=t.text, lemma=t.lemma_) for t in tokenized_utterance]
if db_id not in SpiderDBContext.schemas:
SpiderDBContext.schemas = read_dataset_schema(self.tables_file)
self.schema = SpiderDBContext.schemas[db_id]
self.knowledge_graph = self.get_db_knowledge_graph(db_id)
entity_texts = [self.knowledge_graph.entity_text[entity].lower()
for entity in self.knowledge_graph.entities]
entity_tokens = tokenizer.batch_tokenize(entity_texts)
self.entity_tokens = [[Token(text=t.text, lemma=t.lemma_) for t in et] for et in entity_tokens]
@staticmethod
def entity_key_for_column(table_name: str, column: TableColumn) -> str:
if column.foreign_key is not None:
column_type = "foreign"
elif column.is_primary_key:
column_type = "primary"
else:
column_type = column.column_type
return f"column:{column_type.lower()}:{table_name.lower()}:{column.name.lower()}"
def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph:
entities: Set[str] = set()
neighbors: Dict[str, OrderedSet[str]] = defaultdict(OrderedSet)
entity_text: Dict[str, str] = {}
foreign_keys_to_column: Dict[str, str] = {}
db_schema = self.schema
tables = db_schema.values()
if db_id not in self.db_tables_data:
self.db_tables_data[db_id] = read_dataset_values(db_id, self.dataset_path, tables)
tables_data = self.db_tables_data[db_id]
string_column_mapping: Dict[str, set] = defaultdict(set)
for table, table_data in tables_data.items():
for table_row in table_data:
for column, cell_value in zip(db_schema[table.name].columns, table_row):
if column.column_type == 'text' and type(cell_value) is str:
cell_value_normalized = self.normalize_string(cell_value)
column_key = self.entity_key_for_column(table.name, column)
string_column_mapping[cell_value_normalized].add(column_key)
string_entities = self.get_entities_from_question(string_column_mapping)
for table in tables:
table_key = f"table:{table.name.lower()}"
entities.add(table_key)
entity_text[table_key] = table.text
for column in db_schema[table.name].columns:
entity_key = self.entity_key_for_column(table.name, column)
entities.add(entity_key)
neighbors[entity_key].add(table_key)
neighbors[table_key].add(entity_key)
entity_text[entity_key] = column.text
for string_entity, column_keys in string_entities:
entities.add(string_entity)
for column_key in column_keys:
neighbors[string_entity].add(column_key)
neighbors[column_key].add(string_entity)
entity_text[string_entity] = string_entity.replace("string:", "").replace("_", " ")
# loop again after we have gone through all columns to link foreign keys columns
for table_name in db_schema.keys():
for column in db_schema[table_name].columns:
if column.foreign_key is None:
continue
other_column_table, other_column_name = column.foreign_key.split(':')
# must have exactly one by design
other_column = [col for col in db_schema[other_column_table].columns if col.name == other_column_name][0]
entity_key = self.entity_key_for_column(table_name, column)
other_entity_key = self.entity_key_for_column(other_column_table, other_column)
neighbors[entity_key].add(other_entity_key)
neighbors[other_entity_key].add(entity_key)
foreign_keys_to_column[entity_key] = other_entity_key
kg = KnowledgeGraph(entities, dict(neighbors), entity_text)
kg.foreign_keys_to_column = foreign_keys_to_column
return kg
def _string_in_table(self, candidate: str,
string_column_mapping: Dict[str, set]) -> List[str]:
"""
Checks if the string occurs in the table, and if it does, returns the names of the columns
under which it occurs. If it does not, returns an empty list.
"""
candidate_column_names: List[str] = []
# First check if the entire candidate occurs as a cell.
if candidate in string_column_mapping:
candidate_column_names = string_column_mapping[candidate]
# If not, check if it is a substring pf any cell value.
if not candidate_column_names:
for cell_value, column_names in string_column_mapping.items():
if candidate in cell_value:
candidate_column_names.extend(column_names)
candidate_column_names = list(set(candidate_column_names))
return candidate_column_names
def get_entities_from_question(self,
string_column_mapping: Dict[str, set]) -> List[Tuple[str, str]]:
entity_data = []
for i, token in enumerate(self.tokenized_utterance):
token_text = token.text
if token_text in STOP_WORDS:
continue
normalized_token_text = self.normalize_string(token_text)
if not normalized_token_text:
continue
token_columns = self._string_in_table(normalized_token_text, string_column_mapping)
if token_columns:
token_type = token_columns[0].split(":")[1]
entity_data.append({'value': normalized_token_text,
'token_start': i,
'token_end': i+1,
'token_type': token_type,
'token_in_columns': token_columns})
# extracted_numbers = self._get_numbers_from_tokens(self.question_tokens)
# filter out number entities to avoid repetition
expanded_entities = []
for entity in self._expand_entities(self.tokenized_utterance, entity_data, string_column_mapping):
if entity["token_type"] == "text":
expanded_entities.append((f"string:{entity['value']}", entity['token_in_columns']))
# return expanded_entities, extracted_numbers #TODO(shikhar) Handle conjunctions
return expanded_entities
@staticmethod
def normalize_string(string: str) -> str:
"""
These are the transformation rules used to normalize cell in column names in Sempre. See
``edu.stanford.nlp.sempre.tables.StringNormalizationUtils.characterNormalize`` and
``edu.stanford.nlp.sempre.tables.TableTypeSystem.canonicalizeName``. We reproduce those
rules here to normalize and canonicalize cells and columns in the same way so that we can
match them against constants in logical forms appropriately.
"""
# Normalization rules from Sempre
# \u201A -> ,
string = re.sub("‚", ",", string)
string = re.sub("„", ",,", string)
string = re.sub("[·・]", ".", string)
string = re.sub("…", "...", string)
string = re.sub("ˆ", "^", string)
string = re.sub("˜", "~", string)
string = re.sub("‹", "<", string)
string = re.sub("›", ">", string)
string = re.sub("[‘’´`]", "'", string)
string = re.sub("[“”«»]", "\"", string)
string = re.sub("[•†‡²³]", "", string)
string = re.sub("[‐‑–—−]", "-", string)
# Oddly, some unicode characters get converted to _ instead of being stripped. Not really
# sure how sempre decides what to do with these... TODO(mattg): can we just get rid of the
# need for this function somehow? It's causing a whole lot of headaches.
string = re.sub("[ðø′″€⁄ªΣ]", "_", string)
# This is such a mess. There isn't just a block of unicode that we can strip out, because
# sometimes sempre just strips diacritics... We'll try stripping out a few separate
# blocks, skipping the ones that sempre skips...
string = re.sub("[\\u0180-\\u0210]", "", string).strip()
string = re.sub("[\\u0220-\\uFFFF]", "", string).strip()
string = string.replace("\\n", "_")
string = re.sub("\\s+", " ", string)
# Canonicalization rules from Sempre.
string = re.sub("[^\\w]", "_", string)
string = re.sub("_+", "_", string)
string = re.sub("_$", "", string)
return unidecode(string.lower())
def _expand_entities(self, question, entity_data, string_column_mapping: Dict[str, set]):
new_entities = []
for entity in entity_data:
# to ensure the same strings are not used over and over
if new_entities and entity['token_end'] <= new_entities[-1]['token_end']:
continue
current_start = entity['token_start']
current_end = entity['token_end']
current_token = entity['value']
current_token_type = entity['token_type']
current_token_columns = entity['token_in_columns']
while current_end < len(question):
next_token = question[current_end].text
next_token_normalized = self.normalize_string(next_token)
if next_token_normalized == "":
current_end += 1
continue
candidate = "%s_%s" %(current_token, next_token_normalized)
candidate_columns = self._string_in_table(candidate, string_column_mapping)
candidate_columns = list(set(candidate_columns).intersection(current_token_columns))
if not candidate_columns:
break
candidate_type = candidate_columns[0].split(":")[1]
if candidate_type != current_token_type:
break
current_end += 1
current_token = candidate
current_token_columns = candidate_columns
new_entities.append({'token_start': current_start,
'token_end': current_end,
'value': current_token,
'token_type': current_token_type,
'token_in_columns': current_token_columns})
return new_entities
================================================
FILE: semparse/contexts/spider_db_grammar.py
================================================
# pylint: disable=anomalous-backslash-in-string
"""
A ``Text2SqlTableContext`` represents the SQL context in which an utterance appears
for the any of the text2sql datasets, with the grammar and the valid actions.
"""
from typing import List, Dict
from dataset_readers.dataset_util.spider_utils import Table
GRAMMAR_DICTIONARY = {}
GRAMMAR_DICTIONARY["statement"] = ['(query ws iue ws query)', '(query ws)']
GRAMMAR_DICTIONARY["iue"] = ['"intersect"', '"except"', '"union"']
GRAMMAR_DICTIONARY["query"] = ['(ws select_core ws groupby_clause ws orderby_clause ws limit)',
'(ws select_core ws groupby_clause ws orderby_clause)',
'(ws select_core ws groupby_clause ws limit)',
'(ws select_core ws orderby_clause ws limit)',
'(ws select_core ws groupby_clause)',
'(ws select_core ws orderby_clause)',
'(ws select_core)']
GRAMMAR_DICTIONARY["select_core"] = ['(select_with_distinct ws select_results ws from_clause ws where_clause)',
'(select_with_distinct ws select_results ws from_clause)',
'(select_with_distinct ws select_results ws where_clause)',
'(select_with_distinct ws select_results)']
GRAMMAR_DICTIONARY["select_with_distinct"] = ['(ws "select" ws "distinct")', '(ws "select")']
GRAMMAR_DICTIONARY["select_results"] = ['(ws select_result ws "," ws select_results)', '(ws select_result)']
GRAMMAR_DICTIONARY["select_result"] = ['"*"', '(table_source ws ".*")',
'expr', 'col_ref']
GRAMMAR_DICTIONARY["from_clause"] = ['(ws "from" ws table_source ws join_clauses)',
'(ws "from" ws source)']
GRAMMAR_DICTIONARY["join_clauses"] = ['(join_clause ws join_clauses)', 'join_clause']
GRAMMAR_DICTIONARY["join_clause"] = ['"join" ws table_source ws "on" ws join_condition_clause']
GRAMMAR_DICTIONARY["join_condition_clause"] = ['(join_condition ws "and" ws join_condition_clause)', 'join_condition']
GRAMMAR_DICTIONARY["join_condition"] = ['ws col_ref ws "=" ws col_ref']
GRAMMAR_DICTIONARY["source"] = ['(ws single_source ws "," ws source)', '(ws single_source)']
GRAMMAR_DICTIONARY["single_source"] = ['table_source', 'source_subq']
GRAMMAR_DICTIONARY["source_subq"] = ['("(" ws query ws ")")']
# GRAMMAR_DICTIONARY["source_subq"] = ['("(" ws query ws ")" ws "as" ws name)', '("(" ws query ws ")")']
GRAMMAR_DICTIONARY["limit"] = ['("limit" ws non_literal_number)']
GRAMMAR_DICTIONARY["where_clause"] = ['(ws "where" wsp expr ws where_conj)', '(ws "where" wsp expr)']
GRAMMAR_DICTIONARY["where_conj"] = ['(ws "and" wsp expr ws where_conj)', '(ws "and" wsp expr)']
GRAMMAR_DICTIONARY["groupby_clause"] = ['(ws "group" ws "by" ws group_clause ws "having" ws expr)',
'(ws "group" ws "by" ws group_clause)']
GRAMMAR_DICTIONARY["group_clause"] = ['(ws expr ws "," ws group_clause)', '(ws expr)']
GRAMMAR_DICTIONARY["orderby_clause"] = ['ws "order" ws "by" ws order_clause']
GRAMMAR_DICTIONARY["order_clause"] = ['(ordering_term ws "," ws order_clause)', 'ordering_term']
GRAMMAR_DICTIONARY["ordering_term"] = ['(ws expr ws ordering)', '(ws expr)']
GRAMMAR_DICTIONARY["ordering"] = ['(ws "asc")', '(ws "desc")']
GRAMMAR_DICTIONARY["col_ref"] = ['(table_name ws "." ws column_name)', 'column_name']
GRAMMAR_DICTIONARY["table_source"] = ['(table_name ws "as" ws table_alias)', 'table_name']
GRAMMAR_DICTIONARY["table_name"] = ["table_alias"]
GRAMMAR_DICTIONARY["table_alias"] = ['"t1"', '"t2"', '"t3"', '"t4"']
GRAMMAR_DICTIONARY["column_name"] = []
GRAMMAR_DICTIONARY["ws"] = ['~"\s*"i']
GRAMMAR_DICTIONARY['wsp'] = ['~"\s+"i']
GRAMMAR_DICTIONARY["expr"] = ['in_expr',
# Like expressions.
'(value wsp "like" wsp string)',
# Between expressions.
'(value ws "between" wsp value ws "and" wsp value)',
# Binary expressions.
'(value ws binaryop wsp expr)',
# Unary expressions.
'(unaryop ws expr)',
'source_subq',
'value']
GRAMMAR_DICTIONARY["in_expr"] = ['(value wsp "not" wsp "in" wsp string_set)',
'(value wsp "in" wsp string_set)',
'(value wsp "not" wsp "in" wsp expr)',
'(value wsp "in" wsp expr)']
GRAMMAR_DICTIONARY["value"] = ['parenval', '"YEAR(CURDATE())"', 'number', 'boolean',
'function', 'col_ref', 'string']
GRAMMAR_DICTIONARY["parenval"] = ['"(" ws expr ws ")"']
GRAMMAR_DICTIONARY["function"] = ['(fname ws "(" ws "distinct" ws arg_list_or_star ws ")")',
'(fname ws "(" ws arg_list_or_star ws ")")']
GRAMMAR_DICTIONARY["arg_list_or_star"] = ['arg_list', '"*"']
GRAMMAR_DICTIONARY["arg_list"] = ['(expr ws "," ws arg_list)', 'expr']
# TODO(MARK): Massive hack, remove and modify the grammar accordingly
# GRAMMAR_DICTIONARY["number"] = ['~"\d*\.?\d+"i', "'3'", "'4'"]
GRAMMAR_DICTIONARY["non_literal_number"] = ['"1"', '"2"', '"3"', '"4"']
GRAMMAR_DICTIONARY["number"] = ['ws "value" ws']
GRAMMAR_DICTIONARY["string_set"] = ['ws "(" ws string_set_vals ws ")"']
GRAMMAR_DICTIONARY["string_set_vals"] = ['(string ws "," ws string_set_vals)', 'string']
# GRAMMAR_DICTIONARY["string"] = ['~"\'.*?\'"i']
GRAMMAR_DICTIONARY["string"] = ['"\'" ws "value" ws "\'"']
GRAMMAR_DICTIONARY["fname"] = ['"count"', '"sum"', '"max"', '"min"', '"avg"', '"all"']
GRAMMAR_DICTIONARY["boolean"] = ['"true"', '"false"']
# TODO(MARK): This is not tight enough. AND/OR are strictly boolean value operators.
GRAMMAR_DICTIONARY["binaryop"] = ['"+"', '"-"', '"*"', '"/"', '"="', '"!="', '"<>"',
'">="', '"<="', '">"', '"<"', '"and"', '"or"', '"like"']
GRAMMAR_DICTIONARY["unaryop"] = ['"+"', '"-"', '"not"', '"not"']
def update_grammar_with_tables(grammar_dictionary: Dict[str, List[str]],
schema: Dict[str, Table]) -> None:
table_names = sorted([f'"{table.lower()}"' for table in
list(schema.keys())], reverse=True)
grammar_dictionary['table_name'] += table_names
all_columns = set()
for table in schema.values():
all_columns.update([f'"{table.name.lower()}@{column.name.lower()}"' for column in table.columns if column.name != '*'])
sorted_columns = sorted([column for column in all_columns], reverse=True)
grammar_dictionary['column_name'] += sorted_columns
def update_grammar_to_be_table_names_free(grammar_dictionary: Dict[str, List[str]]):
"""
Remove table names from column names, remove aliases
"""
grammar_dictionary["column_name"] = []
grammar_dictionary["table_name"] = []
grammar_dictionary["col_ref"] = ['column_name']
grammar_dictionary["table_source"] = ['table_name']
del grammar_dictionary["table_alias"]
def update_grammar_flip_joins(grammar_dictionary: Dict[str, List[str]]):
"""
Remove table names from column names, remove aliases
"""
# using a simple rule such as join_clauses-> [(join_clauses ws join_clause), join_clause]
# resulted in a max recursion error, so for now just using a predefined max
# number of joins
grammar_dictionary["join_clauses"] = ['(join_clauses_1 ws join_clause)', 'join_clause']
grammar_dictionary["join_clauses_1"] = ['(join_clauses_2 ws join_clause)', 'join_clause']
grammar_dictionary["join_clauses_2"] = ['(join_clause ws join_clause)', 'join_clause']
================================================
FILE: semparse/worlds/evaluate.py
================================================
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
import copy
import os
import json
import sqlite3
import argparse
from spider_evaluation.process_sql import get_schema, Schema, get_sql
# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
HARDNESS = {
"component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
"component2": ('except', 'union', 'intersect')
}
def condition_has_or(conds):
return 'or' in conds[1::2]
def condition_has_like(conds):
return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
def condition_has_sql(conds):
for cond_unit in conds[::2]:
val1, val2 = cond_unit[3], cond_unit[4]
if val1 is not None and type(val1) is dict:
return True
if val2 is not None and type(val2) is dict:
return True
return False
def val_has_op(val_unit):
return val_unit[0] != UNIT_OPS.index('none')
def has_agg(unit):
return unit[0] != AGG_OPS.index('none')
def accuracy(count, total):
if count == total:
return 1
return 0
def recall(count, total):
if count == total:
return 1
return 0
def F1(acc, rec):
if (acc + rec) == 0:
return 0
return (2. * acc * rec) / (acc + rec)
def get_scores(count, pred_total, label_total):
if pred_total != label_total:
return 0,0,0
elif count == pred_total:
return 1,1,1
return 0,0,0
def eval_sel(pred, label):
pred_sel = pred['select'][1]
label_sel = label['select'][1]
label_wo_agg = [unit[1] for unit in label_sel]
pred_total = len(pred_sel)
label_total = len(label_sel)
cnt = 0
cnt_wo_agg = 0
for unit in pred_sel:
if unit in label_sel:
cnt += 1
label_sel.remove(unit)
if unit[1] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[1])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_where(pred, label):
pred_conds = [unit for unit in pred['where'][::2]]
label_conds = [unit for unit in label['where'][::2]]
label_wo_agg = [unit[2] for unit in label_conds]
pred_total = len(pred_conds)
label_total = len(label_conds)
cnt = 0
cnt_wo_agg = 0
for unit in pred_conds:
if unit in label_conds:
cnt += 1
label_conds.remove(unit)
if unit[2] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[2])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_group(pred, label):
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
pred_total = len(pred_cols)
label_total = len(label_cols)
cnt = 0
pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
for col in pred_cols:
if col in label_cols:
cnt += 1
label_cols.remove(col)
return label_total, pred_total, cnt
def eval_having(pred, label):
pred_total = label_total = cnt = 0
if len(pred['groupBy']) > 0:
pred_total = 1
if len(label['groupBy']) > 0:
label_total = 1
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
if pred_total == label_total == 1 \
and pred_cols == label_cols \
and pred['having'] == label['having']:
cnt = 1
return label_total, pred_total, cnt
def eval_order(pred, label):
pred_total = label_total = cnt = 0
if len(pred['orderBy']) > 0:
pred_total = 1
if len(label['orderBy']) > 0:
label_total = 1
if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
cnt = 1
return label_total, pred_total, cnt
def eval_and_or(pred, label):
pred_ao = pred['where'][1::2]
label_ao = label['where'][1::2]
pred_ao = set(pred_ao)
label_ao = set(label_ao)
if pred_ao == label_ao:
return 1,1,1
return len(pred_ao),len(label_ao),0
def get_nestedSQL(sql):
nested = []
for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
if type(cond_unit[3]) is dict:
nested.append(cond_unit[3])
if type(cond_unit[4]) is dict:
nested.append(cond_unit[4])
if sql['intersect'] is not None:
nested.append(sql['intersect'])
if sql['except'] is not None:
nested.append(sql['except'])
if sql['union'] is not None:
nested.append(sql['union'])
return nested
def eval_nested(pred, label):
label_total = 0
pred_total = 0
cnt = 0
if pred is not None:
pred_total += 1
if label is not None:
label_total += 1
if pred is not None and label is not None:
cnt += Evaluator().eval_exact_match(pred, label)
return label_total, pred_total, cnt
def eval_IUEN(pred, label):
lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
label_total = lt1 + lt2 + lt3
pred_total = pt1 + pt2 + pt3
cnt = cnt1 + cnt2 + cnt3
return label_total, pred_total, cnt
def get_keywords(sql):
res = set()
if len(sql['where']) > 0:
res.add('where')
if len(sql['groupBy']) > 0:
res.add('group')
if len(sql['having']) > 0:
res.add('having')
if len(sql['orderBy']) > 0:
res.add(sql['orderBy'][0])
res.add('order')
if sql['limit'] is not None:
res.add('limit')
if sql['except'] is not None:
res.add('except')
if sql['union'] is not None:
res.add('union')
if sql['intersect'] is not None:
res.add('intersect')
# or keyword
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
if len([token for token in ao if token == 'or']) > 0:
res.add('or')
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
# not keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
res.add('not')
# in keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
res.add('in')
# like keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
res.add('like')
return res
def eval_keywords(pred, label):
pred_keywords = get_keywords(pred)
label_keywords = get_keywords(label)
pred_total = len(pred_keywords)
label_total = len(label_keywords)
cnt = 0
for k in pred_keywords:
if k in label_keywords:
cnt += 1
return label_total, pred_total, cnt
def count_agg(units):
return len([unit for unit in units if has_agg(unit)])
def count_component1(sql):
count = 0
if len(sql['where']) > 0:
count += 1
if len(sql['groupBy']) > 0:
count += 1
if len(sql['orderBy']) > 0:
count += 1
if sql['limit'] is not None:
count += 1
if len(sql['from']['table_units']) > 0: # JOIN
count += len(sql['from']['table_units']) - 1
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
count += len([token for token in ao if token == 'or'])
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
return count
def count_component2(sql):
nested = get_nestedSQL(sql)
return len(nested)
def count_others(sql):
count = 0
# number of aggregation
agg_count = count_agg(sql['select'][1])
agg_count += count_agg(sql['where'][::2])
agg_count += count_agg(sql['groupBy'])
if len(sql['orderBy']) > 0:
agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
[unit[2] for unit in sql['orderBy'][1] if unit[2]])
agg_count += count_agg(sql['having'])
if agg_count > 1:
count += 1
# number of select columns
if len(sql['select'][1]) > 1:
count += 1
# number of where conditions
if len(sql['where']) > 1:
count += 1
# number of group by clauses
if len(sql['groupBy']) > 1:
count += 1
return count
class Evaluator:
"""A simple evaluator"""
def __init__(self):
self.partial_scores = None
def eval_hardness(self, sql):
count_comp1_ = count_component1(sql)
count_comp2_ = count_component2(sql)
count_others_ = count_others(sql)
if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
return "easy"
elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
(count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
return "medium"
elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
(2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
(count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
return "hard"
else:
return "extra"
def eval_exact_match(self, pred, label):
partial_scores = self.eval_partial_match(pred, label)
self.partial_scores = partial_scores
for _, score in partial_scores.items():
if score['f1'] != 1:
return 0
if len(label['from']['table_units']) > 0:
label_tables = sorted(label['from']['table_units'])
pred_tables = sorted(pred['from']['table_units'])
if label_tables != pred_tables:
return False
if len(label['from']['conds']) > 0:
label_joins = sorted(label['from']['conds'], key=lambda x: str(x))
pred_joins = sorted(pred['from']['conds'], key=lambda x: str(x))
if label_joins != pred_joins:
return False
return 1
def eval_partial_match(self, pred, label):
res = {}
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_group(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_having(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_order(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_and_or(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_IUEN(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_keywords(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
return res
def isValidSQL(sql, db):
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(sql, [])
except Exception as e:
return False
return True
def print_scores(scores, etype):
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
counts = [scores[level]['count'] for level in levels]
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
if etype in ["all", "exec"]:
print('===================== EXECUTION ACCURACY =====================')
this_scores = [scores[level]['exec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
if etype in ["all", "match"]:
print('\n====================== EXACT MATCHING ACCURACY =====================')
exact_scores = [scores[level]['exact'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING RECALL ----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING F1 --------------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
def evaluate(gold, predict, db_dir, etype, kmaps):
with open(gold) as f:
glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
with open(predict) as f:
plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
# plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
# glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]
evaluator = Evaluator()
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
entries = []
scores = {}
for level in levels:
scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
scores[level]['exec'] = 0
for type_ in partial_types:
scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
eval_err_num = 0
for p, g in zip(plist, glist):
p_str = p[0]
g_str, db = g
db_name = db
db = os.path.join(db_dir, db, db + ".sqlite")
schema = Schema(get_schema(db))
g_sql = get_sql(schema, g_str)
hardness = evaluator.eval_hardness(g_sql)
scores[hardness]['count'] += 1
scores['all']['count'] += 1
try:
p_sql = get_sql(schema, p_str)
except:
# If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
p_sql = {
"except": None,
"from": {
"conds": [],
"table_units": []
},
"groupBy": [],
"having": [],
"intersect": None,
"limit": None,
"orderBy": [],
"select": [
False,
[]
],
"union": None,
"where": []
}
eval_err_num += 1
print("eval_err_num:{}".format(eval_err_num))
print(p_str)
print()
# rebuild sql for value evaluation
kmap = kmaps[db_name]
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
g_sql = rebuild_sql_val(g_sql)
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
p_sql = rebuild_sql_val(p_sql)
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
# p_sql_copy = copy.deepcopy(p_sql)
# g_sql_copy = copy.deepcopy(g_sql)
# if not eval_exec_match(db, p_str, g_str, p_sql_copy, g_sql_copy) and evaluator.eval_exact_match(p_sql_copy, g_sql_copy):
# a = 1
if etype in ["all", "exec"]:
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
if exec_score:
scores[hardness]['exec'] += 1
scores['all']['exec'] += exec_score
if etype in ["all", "match"]:
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
partial_scores = evaluator.partial_scores
if exact_score == 0:
print("{} pred: {}".format(hardness,p_str))
print("{} gold: {}".format(hardness,g_str))
print("")
scores[hardness]['exact'] += exact_score
scores['all']['exact'] += exact_score
for type_ in partial_types:
if partial_scores[type_]['pred_total'] > 0:
scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores[hardness]['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores[hardness]['partial'][type_]['rec_count'] += 1
scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
if partial_scores[type_]['pred_total'] > 0:
scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores['all']['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores['all']['partial'][type_]['rec_count'] += 1
scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
entries.append({
'predictSQL': p_str,
'goldSQL': g_str,
'hardness': hardness,
'exact': exact_score,
'partial': partial_scores
})
for level in levels:
if scores[level]['count'] == 0:
continue
if etype in ["all", "exec"]:
scores[level]['exec'] /= scores[level]['count']
if etype in ["all", "match"]:
scores[level]['exact'] /= scores[level]['count']
for type_ in partial_types:
if scores[level]['partial'][type_]['acc_count'] == 0:
scores[level]['partial'][type_]['acc'] = 0
else:
scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
scores[level]['partial'][type_]['acc_count'] * 1.0
if scores[level]['partial'][type_]['rec_count'] == 0:
scores[level]['partial'][type_]['rec'] = 0
else:
scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
scores[level]['partial'][type_]['rec_count'] * 1.0
if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
scores[level]['partial'][type_]['f1'] = 1
else:
scores[level]['partial'][type_]['f1'] = \
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
print_scores(scores, etype)
def eval_exec_match(db, p_str, g_str, pred, gold):
"""
return 1 if the values between prediction and gold are matching
in the corresponding index. Currently not support multiple col_unit(pairs).
"""
conn = sqlite3.connect(db)
cursor = conn.cursor()
conn.text_factory = bytes
try:
cursor.execute(p_str)
p_res = cursor.fetchall()
except:
return False
cursor.execute(g_str)
q_res = cursor.fetchall()
def res_map(res, val_units):
rmap = {}
for idx, val_unit in enumerate(val_units):
key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
rmap[key] = [r[idx] for r in res]
return rmap
p_val_units = [unit[1] for unit in pred['select'][1]]
q_val_units = [unit[1] for unit in gold['select'][1]]
return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)
# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
if cond_unit is None or not DISABLE_VALUE:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
if type(val1) is not dict:
val1 = None
else:
val1 = rebuild_sql_val(val1)
if type(val2) is not dict:
val2 = None
else:
val2 = rebuild_sql_val(val2)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_val(condition):
if condition is None or not DISABLE_VALUE:
return condition
res = []
for idx, it in enumerate(condition):
if idx % 2 == 0:
res.append(rebuild_cond_unit_val(it))
else:
res.append(it)
return res
def rebuild_sql_val(sql):
if sql is None or not DISABLE_VALUE:
return sql
sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
sql['having'] = rebuild_condition_val(sql['having'])
sql['where'] = rebuild_condition_val(sql['where'])
sql['intersect'] = rebuild_sql_val(sql['intersect'])
sql['except'] = rebuild_sql_val(sql['except'])
sql['union'] = rebuild_sql_val(sql['union'])
return sql
# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
prefixs = [col_id[:-2] for col_id in col_ids]
valid_col_units= []
for value in schema.idMap.values():
if '.' in value and value[:value.index('.')] in prefixs:
valid_col_units.append(value)
return valid_col_units
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
if col_unit is None:
return col_unit
agg_id, col_id, distinct = col_unit
if col_id in kmap and col_id in valid_col_units:
col_id = kmap[col_id]
if DISABLE_DISTINCT:
distinct = None
return agg_id, col_id, distinct
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
if val_unit is None:
return val_unit
unit_op, col_unit1, col_unit2 = val_unit
col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
return unit_op, col_unit1, col_unit2
def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
if table_unit is None:
return table_unit
table_type, col_unit_or_sql = table_unit
if isinstance(col_unit_or_sql, tuple):
col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
return table_type, col_unit_or_sql
def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
if cond_unit is None:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_col(valid_col_units, condition, kmap):
for idx in range(len(condition)):
if idx % 2 == 0:
condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
return condition
def rebuild_select_col(valid_col_units, sel, kmap):
if sel is None:
return sel
distinct, _list = sel
new_list = []
for it in _list:
agg_id, val_unit = it
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
if DISABLE_DISTINCT:
distinct = None
return distinct, new_list
def rebuild_from_col(valid_col_units, from_, kmap):
if from_ is None:
return from_
from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
return from_
def rebuild_group_by_col(valid_col_units, group_by, kmap):
if group_by is None:
return group_by
return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
def rebuild_order_by_col(valid_col_units, order_by, kmap):
if order_by is None or len(order_by) == 0:
return order_by
direction, val_units = order_by
new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
return direction, new_val_units
def rebuild_sql_col(valid_col_units, sql, kmap):
if sql is None:
return sql
sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
return sql
def build_foreign_key_map(entry):
cols_orig = entry["column_names_original"]
tables_orig = entry["table_names_original"]
# rebuild cols corresponding to idmap in Schema
cols = []
for col_orig in cols_orig:
if col_orig[0] >= 0:
t = tables_orig[col_orig[0]]
c = col_orig[1]
cols.append("__" + t.lower() + "." + c.lower() + "__")
else:
cols.append("__all__")
def keyset_in_list(k1, k2, k_list):
for k_set in k_list:
if k1 in k_set or k2 in k_set:
return k_set
new_k_set = set()
k_list.append(new_k_set)
return new_k_set
foreign_key_list = []
foreign_keys = entry["foreign_keys"]
for fkey in foreign_keys:
key1, key2 = fkey
key_set = keyset_in_list(key1, key2, foreign_key_list)
key_set.add(key1)
key_set.add(key2)
foreign_key_map = {}
for key_set in foreign_key_list:
sorted_list = sorted(list(key_set))
midx = sorted_list[0]
for idx in sorted_list:
foreign_key_map[cols[idx]] = cols[midx]
return foreign_key_map
def build_foreign_key_map_from_json(table):
with open(table) as f:
data = json.load(f)
tables = {}
for entry in data:
tables[entry['db_id']] = build_foreign_key_map(entry)
return tables
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--gold', dest='gold', type=str)
parser.add_argument('--pred', dest='pred', type=str)
parser.add_argument('--db', dest='db', type=str)
parser.add_argument('--table', dest='table', type=str)
parser.add_argument('--etype', dest='etype', type=str)
args = parser.parse_args()
gold = args.gold
pred = args.pred
db_dir = args.db
table = args.table
etype = args.etype
assert etype in ["all", "exec", "match"], "Unknown evaluation method"
kmaps = build_foreign_key_map_from_json(table)
evaluate(gold, pred, db_dir, etype, kmaps)
================================================
FILE: semparse/worlds/evaluate_spider.py
================================================
import os
import sqlite3
from semparse.worlds.evaluate import Evaluator, build_valid_col_units, rebuild_sql_val, rebuild_sql_col, \
build_foreign_key_map_from_json
from spider_evaluation.process_sql import Schema, get_schema, get_sql
_schemas = {}
kmaps = None
def evaluate(gold, predict, db_name, db_dir, table, check_valid: bool=True) -> bool:
global kmaps
# try:
evaluator = Evaluator()
if kmaps is None:
kmaps = build_foreign_key_map_from_json(table)
if db_name in _schemas:
schema = _schemas[db_name]
else:
db = os.path.join(db_dir, db_name, db_name + ".sqlite")
schema = _schemas[db_name] = Schema(get_schema(db))
g_sql = get_sql(schema, gold)
try:
p_sql = get_sql(schema, predict)
except Exception as e:
return False
# rebuild sql for value evaluation
kmap = kmaps[db_name]
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
g_sql = rebuild_sql_val(g_sql)
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
p_sql = rebuild_sql_val(p_sql)
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
if not check_valid:
return exact_score
else:
return exact_score and check_valid_sql(predict, db_name, db_dir)
# except Exception as e:
# return 0
_conns = {}
def check_valid_sql(sql, db_name, db_dir, return_error=False):
db = os.path.join(db_dir, db_name, db_name + ".sqlite")
if db_name == 'wta_1':
# TODO: seems like there is a problem with this dataset - slow response - add limit 1
return True if not return_error else (True, None)
if db_name not in _conns:
_conns[db_name] = sqlite3.connect(db)
# fixes an encoding bug
_conns[db_name].text_factory = bytes
conn = _conns[db_name]
cursor = conn.cursor()
try:
cursor.execute(sql)
cursor.fetchall()
return True if not return_error else (True, None)
except Exception as e:
return False if not return_error else (False, e.args[0])
================================================
FILE: semparse/worlds/spider_world.py
================================================
from typing import List, Tuple, Dict, Set, Optional
from copy import deepcopy
from parsimonious import Grammar
from parsimonious.exceptions import ParseError
from semparse.contexts.spider_context_utils import format_grammar_string, initialize_valid_actions, SqlVisitor
from semparse.contexts.spider_db_context import SpiderDBContext
from semparse.contexts.spider_db_grammar import GRAMMAR_DICTIONARY, update_grammar_with_tables, \
update_grammar_to_be_table_names_free, update_grammar_flip_joins
class SpiderWorld:
"""
World representation for spider dataset.
"""
def __init__(self, db_context: SpiderDBContext, query: Optional[List[str]], allow_alias: bool = False) -> None:
self.db_id = db_context.db_id
self.allow_alias = allow_alias
# NOTE: This base dictionary should not be modified.
self.base_grammar_dictionary = deepcopy(GRAMMAR_DICTIONARY)
self.query = query
self.db_context = db_context
# keep a list of entities names as they are given in sql queries
self.entities_names = {}
for i, entity in enumerate(self.db_context.knowledge_graph.entities):
parts = entity.split(':')
if parts[0] in ['table', 'string']:
self.entities_names[parts[1]] = i
else:
_, _, table_name, column_name = parts
self.entities_names[f'{table_name}@{column_name}'] = i
self.valid_actions = []
self.valid_actions_flat = []
def get_action_sequence_and_all_actions(self,
allow_aliases: bool = False) -> Tuple[List[str], List[str]]:
grammar_with_context = deepcopy(self.base_grammar_dictionary)
if not allow_aliases:
update_grammar_to_be_table_names_free(grammar_with_context)
schema = self.db_context.schema
update_grammar_with_tables(grammar_with_context, schema)
grammar = Grammar(format_grammar_string(grammar_with_context))
valid_actions = initialize_valid_actions(grammar)
all_actions = set()
for action_list in valid_actions.values():
all_actions.update(action_list)
sorted_actions = sorted(all_actions)
self.valid_actions = valid_actions
self.valid_actions_flat = sorted_actions
action_sequence = None
if self.query is not None:
sql_visitor = SqlVisitor(grammar)
query = " ".join(self.query).lower().replace("``", "'").replace("''", "'")
try:
action_sequence = sql_visitor.parse(query) if query else []
except ParseError as e:
pass
return action_sequence, sorted_actions
def get_all_actions(self, schema,
flip_joins: bool,
allow_aliases: bool) -> Tuple[List[str], List[str]]:
grammar_with_context = deepcopy(self.base_grammar_dictionary)
if not allow_aliases:
update_grammar_to_be_table_names_free(grammar_with_context)
if flip_joins:
update_grammar_flip_joins(grammar_with_context)
update_grammar_with_tables(grammar_with_context, schema, self.db_id)
grammar = Grammar(format_grammar_string(grammar_with_context))
valid_actions = initialize_valid_actions(grammar)
all_actions = set()
for action_list in valid_actions.values():
all_actions.update(action_list)
sorted_actions = sorted(all_actions)
self.valid_actions = valid_actions
self.valid_actions_flat = sorted_actions
return sorted_actions
def is_global_rule(self, rhs: str) -> bool:
rhs = rhs.strip('[] ')
if rhs[0] != '"':
return True
return rhs.strip('"') not in self.entities_names
def get_oracle_relevance_score(self, oracle_entities: set):
"""
return 0/1 for each schema item if it should be in the graph,
given the used entities in the gold answer
"""
scores = [0 for _ in range(len(self.db_context.knowledge_graph.entities))]
for i, entity in enumerate(self.db_context.knowledge_graph.entities):
parts = entity.split(':')
if parts[0] == 'column':
name = parts[2] + '@' + parts[3]
else:
name = parts[-1]
if name in oracle_entities:
scores[i] = 1
return scores
def get_action_entity_mapping(self) -> Dict[int, int]:
mapping = {}
for action_index, action in enumerate(self.valid_actions_flat):
# default is padding
mapping[action_index] = -1
action = action.split(" -> ")[1].strip('[]')
action_stripped = action.strip('\"')
if action[0] != '"' or action_stripped not in self.entities_names:
continue
mapping[action_index] = self.entities_names[action_stripped]
return mapping
def get_query_without_table_hints(self):
if not self.query:
return ''
toks = []
for tok in self.query:
if '@' in tok:
parts = tok.split('@')
if '.' in parts[0]:
toks.append(parts[0].split('.')[0] + '.' + parts[1])
else:
toks.append(parts[1])
else:
toks.append(tok)
return toks
# def is_ambiguous_column(self, action_rhs: str, actions_sequence: List[str], action_index: int):
# """
# a column would be ambiguous if another table is used in the query and it has a column with the same name
# currently, return true only for join clauses
# """
# if actions_sequence[action_index-1].startswith('join_condition -> ') or \
# actions_sequence[action_index-2].startswith('join_condition -> '):
# return True
#
# return False
# # column_table, column_name = action_rhs.strip('"').split('.')
# # tables_used = [a.split(' -> ')[1].strip('[]\"') for a in actions_sequence if a.startswith('table_name -> ')]
# # columns_with_same_name = set([a.split(' -> ')[1].strip('[]\"') for a in actions_sequence
# # if a.startswith('column_name -> ') and a.strip('[]\"').endswith(column_name)])
# # table_of_columns_with_same_name = set([c.split('.')[0] for c in columns_with_same_name])
# # if len(table_of_columns_with_same_name) > 1:
# # return True
# # if column_table not in tables_used:
# # return False
# # other_tables = [t for t in tables_used if t != column_table]
# # other_tables_columns = [[c.split(':')[-1] for c in self.knowledge_graph.neighbors[f'table:{t}']] for t in other_tables]
# # other_tables_columns_set = set([item for sublist in other_tables_columns for item in sublist])
# # return column_name in other_tables_columns_set
================================================
FILE: spider_evaluation/evaluate.py
================================================
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
import copy
import os
import json
import sqlite3
import argparse
from spider_evaluation import get_schema, Schema, get_sql
# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
HARDNESS = {
"component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
"component2": ('except', 'union', 'intersect')
}
def condition_has_or(conds):
return 'or' in conds[1::2]
def condition_has_like(conds):
return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
def condition_has_sql(conds):
for cond_unit in conds[::2]:
val1, val2 = cond_unit[3], cond_unit[4]
if val1 is not None and type(val1) is dict:
return True
if val2 is not None and type(val2) is dict:
return True
return False
def val_has_op(val_unit):
return val_unit[0] != UNIT_OPS.index('none')
def has_agg(unit):
return unit[0] != AGG_OPS.index('none')
def accuracy(count, total):
if count == total:
return 1
return 0
def recall(count, total):
if count == total:
return 1
return 0
def F1(acc, rec):
if (acc + rec) == 0:
return 0
return (2. * acc * rec) / (acc + rec)
def get_scores(count, pred_total, label_total):
if pred_total != label_total:
return 0,0,0
elif count == pred_total:
return 1,1,1
return 0,0,0
def eval_sel(pred, label):
pred_sel = pred['select'][1]
label_sel = label['select'][1]
label_wo_agg = [unit[1] for unit in label_sel]
pred_total = len(pred_sel)
label_total = len(label_sel)
cnt = 0
cnt_wo_agg = 0
for unit in pred_sel:
if unit in label_sel:
cnt += 1
label_sel.remove(unit)
if unit[1] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[1])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_where(pred, label):
pred_conds = [unit for unit in pred['where'][::2]]
label_conds = [unit for unit in label['where'][::2]]
label_wo_agg = [unit[2] for unit in label_conds]
pred_total = len(pred_conds)
label_total = len(label_conds)
cnt = 0
cnt_wo_agg = 0
for unit in pred_conds:
if unit in label_conds:
cnt += 1
label_conds.remove(unit)
if unit[2] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[2])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_group(pred, label):
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
pred_total = len(pred_cols)
label_total = len(label_cols)
cnt = 0
pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
for col in pred_cols:
if col in label_cols:
cnt += 1
label_cols.remove(col)
return label_total, pred_total, cnt
def eval_having(pred, label):
pred_total = label_total = cnt = 0
if len(pred['groupBy']) > 0:
pred_total = 1
if len(label['groupBy']) > 0:
label_total = 1
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
if pred_total == label_total == 1 \
and pred_cols == label_cols \
and pred['having'] == label['having']:
cnt = 1
return label_total, pred_total, cnt
def eval_order(pred, label):
pred_total = label_total = cnt = 0
if len(pred['orderBy']) > 0:
pred_total = 1
if len(label['orderBy']) > 0:
label_total = 1
if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
cnt = 1
return label_total, pred_total, cnt
def eval_and_or(pred, label):
pred_ao = pred['where'][1::2]
label_ao = label['where'][1::2]
pred_ao = set(pred_ao)
label_ao = set(label_ao)
if pred_ao == label_ao:
return 1,1,1
return len(pred_ao),len(label_ao),0
def get_nestedSQL(sql):
nested = []
for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
if type(cond_unit[3]) is dict:
nested.append(cond_unit[3])
if type(cond_unit[4]) is dict:
nested.append(cond_unit[4])
if sql['intersect'] is not None:
nested.append(sql['intersect'])
if sql['except'] is not None:
nested.append(sql['except'])
if sql['union'] is not None:
nested.append(sql['union'])
return nested
def eval_nested(pred, label):
label_total = 0
pred_total = 0
cnt = 0
if pred is not None:
pred_total += 1
if label is not None:
label_total += 1
if pred is not None and label is not None:
cnt += Evaluator().eval_exact_match(pred, label)
return label_total, pred_total, cnt
def eval_IUEN(pred, label):
lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
label_total = lt1 + lt2 + lt3
pred_total = pt1 + pt2 + pt3
cnt = cnt1 + cnt2 + cnt3
return label_total, pred_total, cnt
def get_keywords(sql):
res = set()
if len(sql['where']) > 0:
res.add('where')
if len(sql['groupBy']) > 0:
res.add('group')
if len(sql['having']) > 0:
res.add('having')
if len(sql['orderBy']) > 0:
res.add(sql['orderBy'][0])
res.add('order')
if sql['limit'] is not None:
res.add('limit')
if sql['except'] is not None:
res.add('except')
if sql['union'] is not None:
res.add('union')
if sql['intersect'] is not None:
res.add('intersect')
# or keyword
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
if len([token for token in ao if token == 'or']) > 0:
res.add('or')
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
# not keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
res.add('not')
# in keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
res.add('in')
# like keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
res.add('like')
return res
def eval_keywords(pred, label):
pred_keywords = get_keywords(pred)
label_keywords = get_keywords(label)
pred_total = len(pred_keywords)
label_total = len(label_keywords)
cnt = 0
for k in pred_keywords:
if k in label_keywords:
cnt += 1
return label_total, pred_total, cnt
def count_agg(units):
return len([unit for unit in units if has_agg(unit)])
def count_component1(sql):
count = 0
if len(sql['where']) > 0:
count += 1
if len(sql['groupBy']) > 0:
count += 1
if len(sql['orderBy']) > 0:
count += 1
if sql['limit'] is not None:
count += 1
if len(sql['from']['table_units']) > 0: # JOIN
count += len(sql['from']['table_units']) - 1
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
count += len([token for token in ao if token == 'or'])
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
return count
def count_component2(sql):
nested = get_nestedSQL(sql)
return len(nested)
def count_others(sql):
count = 0
# number of aggregation
agg_count = count_agg(sql['select'][1])
agg_count += count_agg(sql['where'][::2])
agg_count += count_agg(sql['groupBy'])
if len(sql['orderBy']) > 0:
agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
[unit[2] for unit in sql['orderBy'][1] if unit[2]])
agg_count += count_agg(sql['having'])
if agg_count > 1:
count += 1
# number of select columns
if len(sql['select'][1]) > 1:
count += 1
# number of where conditions
if len(sql['where']) > 1:
count += 1
# number of group by clauses
if len(sql['groupBy']) > 1:
count += 1
return count
class Evaluator:
"""A simple evaluator"""
def __init__(self):
self.partial_scores = None
def eval_hardness(self, sql):
count_comp1_ = count_component1(sql)
count_comp2_ = count_component2(sql)
count_others_ = count_others(sql)
if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
return "easy"
elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
(count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
return "medium"
elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
(2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
(count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
return "hard"
else:
return "extra"
def eval_exact_match(self, pred, label):
partial_scores = self.eval_partial_match(pred, label)
self.partial_scores = partial_scores
for _, score in partial_scores.items():
if score['f1'] != 1:
return 0
if len(label['from']['table_units']) > 0:
label_tables = sorted(label['from']['table_units'])
pred_tables = sorted(pred['from']['table_units'])
if label_tables != pred_tables:
return False
if len(label['from']['conds']) > 0:
label_joins = sorted(label['from']['conds'], key=lambda x: str(x))
pred_joins = sorted(pred['from']['conds'], key=lambda x: str(x))
if label_joins != pred_joins:
return False
return 1
def eval_partial_match(self, pred, label):
res = {}
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_group(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_having(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_order(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_and_or(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_IUEN(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_keywords(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
return res
def isValidSQL(sql, db):
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(sql, [])
except Exception as e:
return False
return True
def print_scores(scores, etype):
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
counts = [scores[level]['count'] for level in levels]
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
if etype in ["all", "exec"]:
print('===================== EXECUTION ACCURACY =====================')
this_scores = [scores[level]['exec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
if etype in ["all", "match"]:
print('\n====================== EXACT MATCHING ACCURACY =====================')
exact_scores = [scores[level]['exact'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING RECALL ----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING F1 --------------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
def evaluate(gold, predict, db_dir, etype, kmaps):
with open(gold) as f:
glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
with open(predict) as f:
plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
# plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
# glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]
evaluator = Evaluator()
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
entries = []
scores = {}
for level in levels:
scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
scores[level]['exec'] = 0
for type_ in partial_types:
scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
eval_err_num = 0
for p, g in zip(plist, glist):
p_str = p[0]
g_str, db = g
db_name = db
db = os.path.join(db_dir, db, db + ".sqlite")
schema = Schema(get_schema(db))
g_sql = get_sql(schema, g_str)
hardness = evaluator.eval_hardness(g_sql)
scores[hardness]['count'] += 1
scores['all']['count'] += 1
try:
p_sql = get_sql(schema, p_str)
except:
# If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
p_sql = {
"except": None,
"from": {
"conds": [],
"table_units": []
},
"groupBy": [],
"having": [],
"intersect": None,
"limit": None,
"orderBy": [],
"select": [
False,
[]
],
"union": None,
"where": []
}
eval_err_num += 1
print("eval_err_num:{}".format(eval_err_num))
print(p_str)
print()
# rebuild sql for value evaluation
kmap = kmaps[db_name]
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
g_sql = rebuild_sql_val(g_sql)
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
p_sql = rebuild_sql_val(p_sql)
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
# p_sql_copy = copy.deepcopy(p_sql)
# g_sql_copy = copy.deepcopy(g_sql)
# if not eval_exec_match(db, p_str, g_str, p_sql_copy, g_sql_copy) and evaluator.eval_exact_match(p_sql_copy, g_sql_copy):
# a = 1
if etype in ["all", "exec"]:
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
if exec_score:
scores[hardness]['exec'] += 1
scores['all']['exec'] += exec_score
if etype in ["all", "match"]:
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
partial_scores = evaluator.partial_scores
if exact_score == 0:
print("{} pred: {}".format(hardness,p_str))
print("{} gold: {}".format(hardness,g_str))
print("")
scores[hardness]['exact'] += exact_score
scores['all']['exact'] += exact_score
for type_ in partial_types:
if partial_scores[type_]['pred_total'] > 0:
scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores[hardness]['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores[hardness]['partial'][type_]['rec_count'] += 1
scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
if partial_scores[type_]['pred_total'] > 0:
scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores['all']['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores['all']['partial'][type_]['rec_count'] += 1
scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
entries.append({
'predictSQL': p_str,
'goldSQL': g_str,
'hardness': hardness,
'exact': exact_score,
'partial': partial_scores
})
for level in levels:
if scores[level]['count'] == 0:
continue
if etype in ["all", "exec"]:
scores[level]['exec'] /= scores[level]['count']
if etype in ["all", "match"]:
scores[level]['exact'] /= scores[level]['count']
for type_ in partial_types:
if scores[level]['partial'][type_]['acc_count'] == 0:
scores[level]['partial'][type_]['acc'] = 0
else:
scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
scores[level]['partial'][type_]['acc_count'] * 1.0
if scores[level]['partial'][type_]['rec_count'] == 0:
scores[level]['partial'][type_]['rec'] = 0
else:
scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
scores[level]['partial'][type_]['rec_count'] * 1.0
if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
scores[level]['partial'][type_]['f1'] = 1
else:
scores[level]['partial'][type_]['f1'] = \
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
print_scores(scores, etype)
def eval_exec_match(db, p_str, g_str, pred, gold):
"""
return 1 if the values between prediction and gold are matching
in the corresponding index. Currently not support multiple col_unit(pairs).
"""
conn = sqlite3.connect(db)
cursor = conn.cursor()
conn.text_factory = bytes
try:
cursor.execute(p_str)
p_res = cursor.fetchall()
except:
return False
cursor.execute(g_str)
q_res = cursor.fetchall()
def res_map(res, val_units):
rmap = {}
for idx, val_unit in enumerate(val_units):
key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
rmap[key] = [r[idx] for r in res]
return rmap
p_val_units = [unit[1] for unit in pred['select'][1]]
q_val_units = [unit[1] for unit in gold['select'][1]]
return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)
# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
if cond_unit is None or not DISABLE_VALUE:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
if type(val1) is not dict:
val1 = None
else:
val1 = rebuild_sql_val(val1)
if type(val2) is not dict:
val2 = None
else:
val2 = rebuild_sql_val(val2)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_val(condition):
if condition is None or not DISABLE_VALUE:
return condition
res = []
for idx, it in enumerate(condition):
if idx % 2 == 0:
res.append(rebuild_cond_unit_val(it))
else:
res.append(it)
return res
def rebuild_sql_val(sql):
if sql is None or not DISABLE_VALUE:
return sql
sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
sql['having'] = rebuild_condition_val(sql['having'])
sql['where'] = rebuild_condition_val(sql['where'])
sql['intersect'] = rebuild_sql_val(sql['intersect'])
sql['except'] = rebuild_sql_val(sql['except'])
sql['union'] = rebuild_sql_val(sql['union'])
return sql
# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
prefixs = [col_id[:-2] for col_id in col_ids]
valid_col_units= []
for value in schema.idMap.values():
if '.' in value and value[:value.index('.')] in prefixs:
valid_col_units.append(value)
return valid_col_units
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
if col_unit is None:
return col_unit
agg_id, col_id, distinct = col_unit
if col_id in kmap and col_id in valid_col_units:
col_id = kmap[col_id]
if DISABLE_DISTINCT:
distinct = None
return agg_id, col_id, distinct
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
if val_unit is None:
return val_unit
unit_op, col_unit1, col_unit2 = val_unit
col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
return unit_op, col_unit1, col_unit2
def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
if table_unit is None:
return table_unit
table_type, col_unit_or_sql = table_unit
if isinstance(col_unit_or_sql, tuple):
col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
return table_type, col_unit_or_sql
def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
if cond_unit is None:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_col(valid_col_units, condition, kmap):
for idx in range(len(condition)):
if idx % 2 == 0:
condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
return condition
def rebuild_select_col(valid_col_units, sel, kmap):
if sel is None:
return sel
distinct, _list = sel
new_list = []
for it in _list:
agg_id, val_unit = it
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
if DISABLE_DISTINCT:
distinct = None
return distinct, new_list
def rebuild_from_col(valid_col_units, from_, kmap):
if from_ is None:
return from_
from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
return from_
def rebuild_group_by_col(valid_col_units, group_by, kmap):
if group_by is None:
return group_by
return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
def rebuild_order_by_col(valid_col_units, order_by, kmap):
if order_by is None or len(order_by) == 0:
return order_by
direction, val_units = order_by
new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
return direction, new_val_units
def rebuild_sql_col(valid_col_units, sql, kmap):
if sql is None:
return sql
sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
return sql
def build_foreign_key_map(entry):
cols_orig = entry["column_names_original"]
tables_orig = entry["table_names_original"]
# rebuild cols corresponding to idmap in Schema
cols = []
for col_orig in cols_orig:
if col_orig[0] >= 0:
t = tables_orig[col_orig[0]]
c = col_orig[1]
cols.append("__" + t.lower() + "." + c.lower() + "__")
else:
cols.append("__all__")
def keyset_in_list(k1, k2, k_list):
for k_set in k_list:
if k1 in k_set or k2 in k_set:
return k_set
new_k_set = set()
k_list.append(new_k_set)
return new_k_set
foreign_key_list = []
foreign_keys = entry["foreign_keys"]
for fkey in foreign_keys:
key1, key2 = fkey
key_set = keyset_in_list(key1, key2, foreign_key_list)
key_set.add(key1)
key_set.add(key2)
foreign_key_map = {}
for key_set in foreign_key_list:
sorted_list = sorted(list(key_set))
midx = sorted_list[0]
for idx in sorted_list:
foreign_key_map[cols[idx]] = cols[midx]
return foreign_key_map
def build_foreign_key_map_from_json(table):
with open(table) as f:
data = json.load(f)
tables = {}
for entry in data:
tables[entry['db_id']] = build_foreign_key_map(entry)
return tables
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--gold', dest='gold', type=str)
parser.add_argument('--pred', dest='pred', type=str)
parser.add_argument('--db', dest='db', type=str)
parser.add_argument('--table', dest='table', type=str)
parser.add_argument('--etype', dest='etype', type=str)
args = parser.parse_args()
gold = args.gold
pred = args.pred
db_dir = args.db
table = args.table
etype = args.etype
assert etype in ["all", "exec", "match"], "Unknown evaluation method"
kmaps = build_foreign_key_map_from_json(table)
evaluate(gold, pred, db_dir, etype, kmaps)
================================================
FILE: spider_evaluation/process_sql.py
================================================
################################
# Assumptions:
# 1. sql is correct
# 2. only table name has alias
# 3. only one intersect/union/except
#
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
import json
import sqlite3
from nltk import word_tokenize
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
mapped_entities = []
class Schema:
"""
Simple schema which maps table&column to a unique identifier
"""
def __init__(self, schema):
self._schema = schema
self._idMap = self._map(self._schema)
@property
def schema(self):
return self._schema
@property
def idMap(self):
return self._idMap
def _map(self, schema):
idMap = {'*': "__all__"}
id = 1
for key, vals in schema.items():
for val in vals:
idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"
id += 1
for key in schema:
idMap[key.lower()] = "__" + key.lower() + "__"
id += 1
return idMap
def get_schema(db):
"""
Get database's schema, which is a dict with table name as key
and list of column names as value
:param db: database path
:return: schema dict
"""
schema = {}
conn = sqlite3.connect(db)
cursor = conn.cursor()
# fetch table names
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = [str(table[0].lower()) for table in cursor.fetchall()]
# fetch table info
for table in tables:
cursor.execute("PRAGMA table_info({})".format(table))
schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
return schema
def get_schema_from_json(fpath):
with open(fpath) as f:
data = json.load(f)
schema = {}
for entry in data:
table = str(entry['table'].lower())
cols = [str(col['column_name'].lower()) for col in entry['col_data']]
schema[table] = cols
return schema
def tokenize(string):
string = str(string)
string = string.replace("\'", "\"") # ensures all string values wrapped by "" problem??
quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
assert len(quote_idxs) % 2 == 0, "Unexpected quote"
# keep string value as token
vals = {}
for i in range(len(quote_idxs)-1, -1, -2):
qidx1 = quote_idxs[i-1]
qidx2 = quote_idxs[i]
val = string[qidx1: qidx2+1]
key = "__val_{}_{}__".format(qidx1, qidx2)
string = string[:qidx1] + key + string[qidx2+1:]
vals[key] = val
toks = [word.lower() for word in word_tokenize(string)]
# replace with string value token
for i in range(len(toks)):
if toks[i] in vals:
toks[i] = vals[toks[i]]
# find if there exists !=, >=, <=
eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
eq_idxs.reverse()
prefix = ('!', '>', '<')
for eq_idx in eq_idxs:
pre_tok = toks[eq_idx-1]
if pre_tok in prefix:
toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ]
return toks
def scan_alias(toks):
"""Scan the index of 'as' and build the map for all alias"""
as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
alias = {}
for idx in as_idxs:
alias[toks[idx+1]] = toks[idx-1]
return alias
def get_tables_with_alias(schema, toks):
tables = scan_alias(toks)
for key in schema:
assert key not in tables, "Alias {} has the same name in table".format(key)
tables[key] = key
return tables
def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
"""
:returns next idx, column id
"""
global mapped_entities
tok = toks[start_idx]
if tok == "*":
return start_idx + 1, schema.idMap[tok]
if '.' in tok: # if token is a composite
alias, col = tok.split('.')
key = tables_with_alias[alias] + "." + col
mapped_entities.append((start_idx, tables_with_alias[alias] + "@" + col))
return start_idx+1, schema.idMap[key]
assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"
for alias in default_tables:
table = tables_with_alias[alias]
if tok in schema.schema[table]:
key = table + "." + tok
mapped_entities.append((start_idx, table + "@" + tok))
return start_idx+1, schema.idMap[key]
assert False, "Error col: {}".format(tok)
def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
"""
:returns next idx, (agg_op id, col_id)
"""
idx = start_idx
len_ = len(toks)
isBlock = False
isDistinct = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] in AGG_OPS:
agg_id = AGG_OPS.index(toks[idx])
idx += 1
assert idx < len_ and toks[idx] == '('
idx += 1
if toks[idx] == "distinct":
idx += 1
isDistinct = True
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
assert idx < len_ and toks[idx] == ')'
idx += 1
return idx, (agg_id, col_id, isDistinct)
if toks[idx] == "distinct":
idx += 1
isDistinct = True
agg_id = AGG_OPS.index("none")
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
return idx, (agg_id, col_id, isDistinct)
def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
col_unit1 = None
col_unit2 = None
unit_op = UNIT_OPS.index('none')
idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
if idx < len_ and toks[idx] in UNIT_OPS:
unit_op = UNIT_OPS.index(toks[idx])
idx += 1
idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
return idx, (unit_op, col_unit1, col_unit2)
def parse_table_unit(toks, start_idx, tables_with_alias, schema):
"""
:returns next idx, table id, table name
"""
idx = start_idx
len_ = len(toks)
key = tables_with_alias[toks[idx]]
if idx + 1 < len_ and toks[idx+1] == "as":
idx += 3
else:
idx += 1
return idx, schema.idMap[key], key
def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] == 'select':
idx, val = parse_sql(toks, idx, tables_with_alias, schema)
elif "\"" in toks[idx]: # token is a string value
val = toks[idx]
idx += 1
else:
try:
val = float(toks[idx])
idx += 1
except:
end_idx = idx
while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\
and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS:
end_idx += 1
idx, val = parse_col_unit(toks[: end_idx], start_idx, tables_with_alias, schema, default_tables)
idx = end_idx
if isBlock:
assert toks[idx] == ')'
idx += 1
return idx, val
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
conds = []
while idx < len_:
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
not_op = False
if toks[idx] == 'not':
not_op = True
idx += 1
assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
op_id = WHERE_OPS.index(toks[idx])
idx += 1
val1 = val2 = None
if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
assert toks[idx] == 'and'
idx += 1
idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
else: # normal case: single value
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
val2 = None
conds.append((not_op, op_id, val_unit, val1, val2))
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
break
if idx < len_ and toks[idx] in COND_OPS:
conds.append(toks[idx])
idx += 1 # skip and/or
return idx, conds
def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
assert toks[idx] == 'select', "'select' not found"
idx += 1
isDistinct = False
if idx < len_ and toks[idx] == 'distinct':
idx += 1
isDistinct = True
val_units = []
while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
agg_id = AGG_OPS.index("none")
if toks[idx] in AGG_OPS:
agg_id = AGG_OPS.index(toks[idx])
idx += 1
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
val_units.append((agg_id, val_unit))
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
return idx, (isDistinct, val_units)
def parse_from(toks, start_idx, tables_with_alias, schema):
"""
Assume in the from clause, all table units are combined with join
"""
assert 'from' in toks[start_idx:], "'from' not found"
len_ = len(toks)
idx = toks.index('from', start_idx) + 1
default_tables = []
table_units = []
conds = []
while idx < len_:
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] == 'select':
idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
table_units.append((TABLE_TYPE['sql'], sql))
else:
if idx < len_ and toks[idx] == 'join':
idx += 1 # skip join
idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
table_units.append((TABLE_TYPE['table_unit'],table_unit))
default_tables.append(table_name)
if idx < len_ and toks[idx] == "on":
idx += 1 # skip on
idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
if len(conds) > 0:
conds.append('and')
conds.extend(this_conds)
if isBlock:
assert toks[idx] == ')'
idx += 1
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
break
return idx, table_units, conds, default_tables
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
if idx >= len_ or toks[idx] != 'where':
return idx, []
idx += 1
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
return idx, conds
def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
col_units = []
if idx >= len_ or toks[idx] != 'group':
return idx, col_units
idx += 1
assert toks[idx] == 'by'
idx += 1
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
col_units.append(col_unit)
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
else:
break
return idx, col_units
def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
val_units = []
order_type = 'asc' # default type is 'asc'
if idx >= len_ or toks[idx] != 'order':
return idx, val_units
idx += 1
assert toks[idx] == 'by'
idx += 1
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
val_units.append(val_unit)
if idx < len_ and toks[idx] in ORDER_OPS:
order_type = toks[idx]
idx += 1
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
else:
break
return idx, (order_type, val_units)
def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
if idx >= len_ or toks[idx] != 'having':
return idx, []
idx += 1
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
return idx, conds
def parse_limit(toks, start_idx):
idx = start_idx
len_ = len(toks)
if idx < len_ and toks[idx] == 'limit':
idx += 2
try:
limit_val = int(toks[idx-1])
except Exception:
limit_val = '"value"'
return idx, limit_val
return idx, None
def parse_sql(toks, start_idx, tables_with_alias, schema, mapped_entities_fn=None):
global mapped_entities
if mapped_entities_fn is not None:
mapped_entities = mapped_entities_fn()
isBlock = False # indicate whether this is a block of sql/sub-sql
len_ = len(toks)
idx = start_idx
sql = {}
if toks[idx] == '(':
isBlock = True
idx += 1
# parse from clause in order to get default tables
from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
sql['from'] = {'table_units': table_units, 'conds': conds}
# select clause
_, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
idx = from_end_idx
sql['select'] = select_col_units
# where clause
idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
sql['where'] = where_conds
# group by clause
idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
sql['groupBy'] = group_col_units
# having clause
idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
sql['having'] = having_conds
# order by clause
idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
sql['orderBy'] = order_col_units
# limit clause
idx, limit_val = parse_limit(toks, idx)
sql['limit'] = limit_val
idx = skip_semicolon(toks, idx)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
idx = skip_semicolon(toks, idx)
# intersect/union/except clause
for op in SQL_OPS: # initialize IUE
sql[op] = None
if idx < len_ and toks[idx] in SQL_OPS:
sql_op = toks[idx]
idx += 1
idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
sql[sql_op] = IUE_sql
if mapped_entities_fn is not None:
return idx, sql, mapped_entities
else:
return idx, sql
def load_data(fpath):
with open(fpath) as f:
data = json.load(f)
return data
def get_sql(schema, query):
toks = tokenize(query)
tables_with_alias = get_tables_with_alias(schema.schema, toks)
_, sql = parse_sql(toks, 0, tables_with_alias, schema)
return sql
def skip_semicolon(toks, start_idx):
idx = start_idx
while idx < len(toks) and toks[idx] == ";":
idx += 1
return idx
================================================
FILE: state_machines/states/grammar_based_state.py
================================================
from typing import Any, Dict, List, Sequence, Tuple
import torch
from allennlp.data.fields.production_rule_field import ProductionRule
from allennlp.state_machines.states.grammar_statelet import GrammarStatelet
from allennlp.state_machines.states.rnn_statelet import RnnStatelet
from allennlp.state_machines.states.state import State
# This syntax is pretty weird and ugly, but it's necessary to make mypy happy with the API that
# we've defined. We're using generics to make the type of `combine_states` come out right. See
# the note in `state_machines.state.py` for a little more detail.
from state_machines.states.sql_state import SqlState
class GrammarBasedState(State['GrammarBasedState']):
"""
A generic State that's suitable for most models that do grammar-based decoding. We keep around
a `group` of states, and each element in the group has a few things: a batch index, an action
history, a score, an ``RnnStatelet``, and a ``GrammarStatelet``. We additionally have some
information that's independent of any particular group element: a list of all possible actions
for all batch instances passed to ``model.forward()``, and a ``extras`` field that you can use
if you really need some extra information about each batch instance (like a string description,
or other metadata).
Finally, we also have a specially-treated, optional ``debug_info`` field. If this is given, it
should be an empty list for each group instance when the initial state is created. In that
case, we will keep around information about the actions considered at each timestep of decoding
and other things that you might want to visualize in a demo. This probably isn't necessary for
training, and to get it right we need to copy a bunch of data structures for each new state, so
it's best used only at evaluation / demo time.
Parameters
----------
batch_indices : ``List[int]``
Passed to super class; see docs there.
action_history : ``List[List[int]]``
Passed to super class; see docs there.
score : ``List[torch.Tensor]``
Passed to super class; see docs there.
rnn_state : ``List[RnnStatelet]``
An ``RnnStatelet`` for every group element. This keeps track of the current decoder hidden
state, the previous decoder output, the output from the encoder (for computing attentions),
and other things that are typical seq2seq decoder state things.
grammar_state : ``List[GrammarStatelet]``
This hold the current grammar state for each element of the group. The ``GrammarStatelet``
keeps track of which actions are currently valid.
possible_actions : ``List[List[ProductionRule]]``
The list of all possible actions that was passed to ``model.forward()``. We need this so
we can recover production strings, which we need to update grammar states.
extras : ``List[Any]``, optional (default=None)
If you need to keep around some extra data for each instance in the batch, you can put that
in here, without adding another field. This should be used `very sparingly`, as there is
no type checking or anything done on the contents of this field, and it will just be passed
around between ``States`` as-is, without copying.
debug_info : ``List[Any]``, optional (default=None).
"""
def __init__(self,
batch_indices: List[int],
action_history: List[List[int]],
score: List[torch.Tensor],
rnn_state: List[RnnStatelet],
grammar_state: List[GrammarStatelet],
sql_state: List[SqlState],
possible_actions: List[List[ProductionRule]],
action_entity_mapping: List[Dict[int, int]],
extras: List[Any] = None,
debug_info: List = None) -> None:
super().__init__(batch_indices, action_history, score)
self.rnn_state = rnn_state
self.grammar_state = grammar_state
self.sql_state = sql_state
self.possible_actions = possible_actions
self.action_entity_mapping = action_entity_mapping
self.extras = extras
self.debug_info = debug_info
def new_state_from_group_index(self,
group_index: int,
action: int,
new_score: torch.Tensor,
new_rnn_state: RnnStatelet,
considered_actions: List[int] = None,
action_probabilities: List[float] = None,
attention_weights: torch.Tensor = None,
linking_scores_qst: torch.Tensor = None,
linking_scores_past: torch.Tensor = None) -> 'GrammarBasedState':
batch_index = self.batch_indices[group_index]
new_action_history = self.action_history[group_index] + [action]
production_rule = self.possible_actions[batch_index][action][0]
new_grammar_state = self.grammar_state[group_index].take_action(production_rule)
new_sql_state = self.sql_state[group_index].take_action(production_rule)
if self.debug_info is not None:
attention = attention_weights[group_index] if attention_weights is not None else None
# output_attention = output_attention_weights[group_index] if output_attention_weights is not None else None
debug_info = {
'considered_actions': considered_actions,
'question_attention': attention,
'probabilities': action_probabilities,
'linking_scores_qst': linking_scores_qst,
'linking_scores_past': linking_scores_past,
}
new_debug_info = [self.debug_info[group_index] + [debug_info]]
else:
new_debug_info = None
return GrammarBasedState(batch_indices=[batch_index],
action_history=[new_action_history],
score=[new_score],
gitextract_4_pkjgz2/
├── README.md
├── dataset_readers/
│ ├── __init__.py
│ ├── dataset_util/
│ │ └── spider_utils.py
│ ├── fields/
│ │ └── knowledge_graph_field.py
│ └── spider.py
├── models/
│ ├── __init__.py
│ └── semantic_parsing/
│ ├── __init__.py
│ └── spider_parser.py
├── modules/
│ └── gated_graph_conv.py
├── predictors/
│ └── spider_predictor.py
├── requirements.txt
├── semparse/
│ ├── contexts/
│ │ ├── spider_context_utils.py
│ │ ├── spider_db_context.py
│ │ └── spider_db_grammar.py
│ └── worlds/
│ ├── evaluate.py
│ ├── evaluate_spider.py
│ └── spider_world.py
├── spider_evaluation/
│ ├── evaluate.py
│ └── process_sql.py
├── state_machines/
│ ├── states/
│ │ ├── grammar_based_state.py
│ │ ├── rnn_statelet.py
│ │ └── sql_state.py
│ └── transition_functions/
│ ├── attend_past_schema_items_transition.py
│ ├── basic_transition_function.py
│ ├── linking_transition_function.py
│ └── prefix_attend_transition.py
└── train_configs/
├── defaults.jsonnet
└── paper_defaults.jsonnet
SYMBOL INDEX (235 symbols across 21 files)
FILE: dataset_readers/dataset_util/spider_utils.py
class TableColumn (line 16) | class TableColumn:
method __init__ (line 17) | def __init__(self,
class Table (line 30) | class Table:
method __init__ (line 31) | def __init__(self,
function read_dataset_schema (line 40) | def read_dataset_schema(schema_path: str) -> Dict[str, List[Table]]:
function read_dataset_values (line 75) | def read_dataset_values(db_id: str, dataset_path: str, tables: List[str]):
function ent_key_to_name (line 99) | def ent_key_to_name(key):
function fix_number_value (line 110) | def fix_number_value(ex: JsonDict):
function disambiguate_items (line 159) | def disambiguate_items(db_id: str, query_toks: List[str], tables_file: s...
FILE: dataset_readers/fields/knowledge_graph_field.py
class SpiderKnowledgeGraphField (line 12) | class SpiderKnowledgeGraphField(KnowledgeGraphField):
method __init__ (line 17) | def __init__(self,
method _compute_related_linking_features (line 48) | def _compute_related_linking_features(self,
FILE: dataset_readers/spider.py
class SpiderDatasetReader (line 25) | class SpiderDatasetReader(DatasetReader):
method __init__ (line 26) | def __init__(self,
method _read (line 53) | def _read(self, file_path: str):
method text_to_instance (line 121) | def text_to_instance(self,
FILE: models/semantic_parsing/spider_parser.py
class SpiderParser (line 37) | class SpiderParser(Model):
method __init__ (line 38) | def __init__(self,
method forward (line 158) | def forward(self, # type: ignore
method _get_initial_state (line 216) | def _get_initial_state(self,
method _get_neighbor_indices (line 376) | def _get_neighbor_indices(worlds: List[SpiderWorld],
method _get_schema_graph_encoding (line 428) | def _get_schema_graph_encoding(self,
method _get_graph_adj_lists (line 459) | def _get_graph_adj_lists(device, world, global_entity_id, global_node=...
method _create_grammar_state (line 506) | def _create_grammar_state(self,
method is_nonterminal (line 588) | def is_nonterminal(token: str):
method _get_linking_probabilities (line 593) | def _get_linking_probabilities(self,
method _action_history_match (line 676) | def _action_history_match(predicted: List[int], targets: torch.LongTen...
method _query_difficulty (line 687) | def _query_difficulty(targets: torch.LongTensor, action_mapping, batch...
method get_metrics (line 693) | def get_metrics(self, reset: bool = False) -> Dict[str, float]:
method _get_type_vector (line 704) | def _get_type_vector(worlds: List[SpiderWorld],
method _compute_validation_outputs (line 757) | def _compute_validation_outputs(self,
FILE: modules/gated_graph_conv.py
class GatedGraphConv (line 10) | class GatedGraphConv(MessagePassing):
method __init__ (line 32) | def __init__(self, input_dim, num_timesteps, num_edge_types, aggr='add...
method reset_parameters (line 46) | def reset_parameters(self):
method forward (line 53) | def forward(self, x, edge_indices):
method __repr__ (line 81) | def __repr__(self):
FILE: predictors/spider_predictor.py
class WikiTablesParserPredictor (line 10) | class WikiTablesParserPredictor(Predictor):
method __init__ (line 11) | def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
method predict_instance (line 15) | def predict_instance(self, instance: Instance) -> JsonDict:
method dump_line (line 26) | def dump_line(self, outputs: JsonDict) -> str: # pylint: disable=no-s...
FILE: semparse/contexts/spider_context_utils.py
function format_grammar_string (line 16) | def format_grammar_string(grammar_dictionary: Dict[str, List[str]]) -> str:
function initialize_valid_actions (line 25) | def initialize_valid_actions(grammar: Grammar,
function format_action (line 64) | def format_action(nonterminal: str,
function action_sequence_to_sql (line 112) | def action_sequence_to_sql(action_sequences: List[str], add_table_names:...
class SqlVisitor (line 139) | class SqlVisitor(NodeVisitor):
method __init__ (line 162) | def __init__(self, grammar: Grammar, keywords_to_uppercase: List[str] ...
method generic_visit (line 168) | def generic_visit(self, node: Node, visited_children: List[None]) -> L...
method add_action (line 174) | def add_action(self, node: Node) -> None:
method visit (line 205) | def visit(self, node):
FILE: semparse/contexts/spider_db_context.py
class SpiderDBContext (line 30) | class SpiderDBContext:
method __init__ (line 35) | def __init__(self, db_id: str, utterance: str, tokenizer: Tokenizer, t...
method entity_key_for_column (line 56) | def entity_key_for_column(table_name: str, column: TableColumn) -> str:
method get_db_knowledge_graph (line 65) | def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph:
method _string_in_table (line 134) | def _string_in_table(self, candidate: str,
method get_entities_from_question (line 152) | def get_entities_from_question(self,
method normalize_string (line 182) | def normalize_string(string: str) -> str:
method _expand_entities (line 221) | def _expand_entities(self, question, entity_data, string_column_mappin...
FILE: semparse/contexts/spider_db_grammar.py
function update_grammar_with_tables (line 105) | def update_grammar_with_tables(grammar_dictionary: Dict[str, List[str]],
function update_grammar_to_be_table_names_free (line 118) | def update_grammar_to_be_table_names_free(grammar_dictionary: Dict[str, ...
function update_grammar_flip_joins (line 131) | def update_grammar_flip_joins(grammar_dictionary: Dict[str, List[str]]):
FILE: semparse/worlds/evaluate.py
function condition_has_or (line 57) | def condition_has_or(conds):
function condition_has_like (line 61) | def condition_has_like(conds):
function condition_has_sql (line 65) | def condition_has_sql(conds):
function val_has_op (line 75) | def val_has_op(val_unit):
function has_agg (line 79) | def has_agg(unit):
function accuracy (line 83) | def accuracy(count, total):
function recall (line 89) | def recall(count, total):
function F1 (line 95) | def F1(acc, rec):
function get_scores (line 101) | def get_scores(count, pred_total, label_total):
function eval_sel (line 109) | def eval_sel(pred, label):
function eval_where (line 129) | def eval_where(pred, label):
function eval_group (line 149) | def eval_group(pred, label):
function eval_having (line 164) | def eval_having(pred, label):
function eval_order (line 181) | def eval_order(pred, label):
function eval_and_or (line 193) | def eval_and_or(pred, label):
function get_nestedSQL (line 204) | def get_nestedSQL(sql):
function eval_nested (line 220) | def eval_nested(pred, label):
function eval_IUEN (line 233) | def eval_IUEN(pred, label):
function get_keywords (line 243) | def get_keywords(sql):
function eval_keywords (line 284) | def eval_keywords(pred, label):
function count_agg (line 297) | def count_agg(units):
function count_component1 (line 301) | def count_component1(sql):
function count_component2 (line 322) | def count_component2(sql):
function count_others (line 327) | def count_others(sql):
class Evaluator (line 355) | class Evaluator:
method __init__ (line 357) | def __init__(self):
method eval_hardness (line 360) | def eval_hardness(self, sql):
method eval_exact_match (line 377) | def eval_exact_match(self, pred, label):
method eval_partial_match (line 396) | def eval_partial_match(self, pred, label):
function isValidSQL (line 438) | def isValidSQL(sql, db):
function print_scores (line 448) | def print_scores(scores, etype):
function evaluate (line 482) | def evaluate(gold, predict, db_dir, etype, kmaps):
function eval_exec_match (line 625) | def eval_exec_match(db, p_str, g_str, pred, gold):
function rebuild_cond_unit_val (line 655) | def rebuild_cond_unit_val(cond_unit):
function rebuild_condition_val (line 671) | def rebuild_condition_val(condition):
function rebuild_sql_val (line 684) | def rebuild_sql_val(sql):
function build_valid_col_units (line 699) | def build_valid_col_units(table_units, schema):
function rebuild_col_unit_col (line 709) | def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
function rebuild_val_unit_col (line 721) | def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
function rebuild_table_unit_col (line 731) | def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
function rebuild_cond_unit_col (line 741) | def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
function rebuild_condition_col (line 750) | def rebuild_condition_col(valid_col_units, condition, kmap):
function rebuild_select_col (line 757) | def rebuild_select_col(valid_col_units, sel, kmap):
function rebuild_from_col (line 770) | def rebuild_from_col(valid_col_units, from_, kmap):
function rebuild_group_by_col (line 779) | def rebuild_group_by_col(valid_col_units, group_by, kmap):
function rebuild_order_by_col (line 786) | def rebuild_order_by_col(valid_col_units, order_by, kmap):
function rebuild_sql_col (line 795) | def rebuild_sql_col(valid_col_units, sql, kmap):
function build_foreign_key_map (line 812) | def build_foreign_key_map(entry):
function build_foreign_key_map_from_json (line 852) | def build_foreign_key_map_from_json(table):
FILE: semparse/worlds/evaluate_spider.py
function evaluate (line 12) | def evaluate(gold, predict, db_name, db_dir, table, check_valid: bool=Tr...
function check_valid_sql (line 56) | def check_valid_sql(sql, db_name, db_dir, return_error=False):
FILE: semparse/worlds/spider_world.py
class SpiderWorld (line 13) | class SpiderWorld:
method __init__ (line 18) | def __init__(self, db_context: SpiderDBContext, query: Optional[List[s...
method get_action_sequence_and_all_actions (line 39) | def get_action_sequence_and_all_actions(self,
method get_all_actions (line 69) | def get_all_actions(self, schema,
method is_global_rule (line 92) | def is_global_rule(self, rhs: str) -> bool:
method get_oracle_relevance_score (line 98) | def get_oracle_relevance_score(self, oracle_entities: set):
method get_action_entity_mapping (line 116) | def get_action_entity_mapping(self) -> Dict[int, int]:
method get_query_without_table_hints (line 132) | def get_query_without_table_hints(self):
FILE: spider_evaluation/evaluate.py
function condition_has_or (line 57) | def condition_has_or(conds):
function condition_has_like (line 61) | def condition_has_like(conds):
function condition_has_sql (line 65) | def condition_has_sql(conds):
function val_has_op (line 75) | def val_has_op(val_unit):
function has_agg (line 79) | def has_agg(unit):
function accuracy (line 83) | def accuracy(count, total):
function recall (line 89) | def recall(count, total):
function F1 (line 95) | def F1(acc, rec):
function get_scores (line 101) | def get_scores(count, pred_total, label_total):
function eval_sel (line 109) | def eval_sel(pred, label):
function eval_where (line 129) | def eval_where(pred, label):
function eval_group (line 149) | def eval_group(pred, label):
function eval_having (line 164) | def eval_having(pred, label):
function eval_order (line 181) | def eval_order(pred, label):
function eval_and_or (line 193) | def eval_and_or(pred, label):
function get_nestedSQL (line 204) | def get_nestedSQL(sql):
function eval_nested (line 220) | def eval_nested(pred, label):
function eval_IUEN (line 233) | def eval_IUEN(pred, label):
function get_keywords (line 243) | def get_keywords(sql):
function eval_keywords (line 284) | def eval_keywords(pred, label):
function count_agg (line 297) | def count_agg(units):
function count_component1 (line 301) | def count_component1(sql):
function count_component2 (line 322) | def count_component2(sql):
function count_others (line 327) | def count_others(sql):
class Evaluator (line 355) | class Evaluator:
method __init__ (line 357) | def __init__(self):
method eval_hardness (line 360) | def eval_hardness(self, sql):
method eval_exact_match (line 377) | def eval_exact_match(self, pred, label):
method eval_partial_match (line 396) | def eval_partial_match(self, pred, label):
function isValidSQL (line 438) | def isValidSQL(sql, db):
function print_scores (line 448) | def print_scores(scores, etype):
function evaluate (line 482) | def evaluate(gold, predict, db_dir, etype, kmaps):
function eval_exec_match (line 625) | def eval_exec_match(db, p_str, g_str, pred, gold):
function rebuild_cond_unit_val (line 655) | def rebuild_cond_unit_val(cond_unit):
function rebuild_condition_val (line 671) | def rebuild_condition_val(condition):
function rebuild_sql_val (line 684) | def rebuild_sql_val(sql):
function build_valid_col_units (line 699) | def build_valid_col_units(table_units, schema):
function rebuild_col_unit_col (line 709) | def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
function rebuild_val_unit_col (line 721) | def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
function rebuild_table_unit_col (line 731) | def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
function rebuild_cond_unit_col (line 741) | def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
function rebuild_condition_col (line 750) | def rebuild_condition_col(valid_col_units, condition, kmap):
function rebuild_select_col (line 757) | def rebuild_select_col(valid_col_units, sel, kmap):
function rebuild_from_col (line 770) | def rebuild_from_col(valid_col_units, from_, kmap):
function rebuild_group_by_col (line 779) | def rebuild_group_by_col(valid_col_units, group_by, kmap):
function rebuild_order_by_col (line 786) | def rebuild_order_by_col(valid_col_units, order_by, kmap):
function rebuild_sql_col (line 795) | def rebuild_sql_col(valid_col_units, sql, kmap):
function build_foreign_key_map (line 812) | def build_foreign_key_map(entry):
function build_foreign_key_map_from_json (line 852) | def build_foreign_key_map_from_json(table):
FILE: spider_evaluation/process_sql.py
class Schema (line 48) | class Schema:
method __init__ (line 52) | def __init__(self, schema):
method schema (line 57) | def schema(self):
method idMap (line 61) | def idMap(self):
method _map (line 64) | def _map(self, schema):
function get_schema (line 79) | def get_schema(db):
function get_schema_from_json (line 103) | def get_schema_from_json(fpath):
function tokenize (line 116) | def tokenize(string):
function scan_alias (line 150) | def scan_alias(toks):
function get_tables_with_alias (line 159) | def get_tables_with_alias(schema, toks):
function parse_col (line 167) | def parse_col(toks, start_idx, tables_with_alias, schema, default_tables...
function parse_col_unit (line 194) | def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_t...
function parse_val_unit (line 232) | def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_t...
function parse_table_unit (line 257) | def parse_table_unit(toks, start_idx, tables_with_alias, schema):
function parse_value (line 273) | def parse_value(toks, start_idx, tables_with_alias, schema, default_tabl...
function parse_condition (line 307) | def parse_condition(toks, start_idx, tables_with_alias, schema, default_...
function parse_select (line 344) | def parse_select(toks, start_idx, tables_with_alias, schema, default_tab...
function parse_from (line 369) | def parse_from(toks, start_idx, tables_with_alias, schema):
function parse_where (line 412) | def parse_where(toks, start_idx, tables_with_alias, schema, default_tabl...
function parse_group_by (line 424) | def parse_group_by(toks, start_idx, tables_with_alias, schema, default_t...
function parse_order_by (line 447) | def parse_order_by(toks, start_idx, tables_with_alias, schema, default_t...
function parse_having (line 474) | def parse_having(toks, start_idx, tables_with_alias, schema, default_tab...
function parse_limit (line 486) | def parse_limit(toks, start_idx):
function parse_sql (line 501) | def parse_sql(toks, start_idx, tables_with_alias, schema, mapped_entitie...
function load_data (line 559) | def load_data(fpath):
function get_sql (line 565) | def get_sql(schema, query):
function skip_semicolon (line 573) | def skip_semicolon(toks, start_idx):
FILE: state_machines/states/grammar_based_state.py
class GrammarBasedState (line 16) | class GrammarBasedState(State['GrammarBasedState']):
method __init__ (line 58) | def __init__(self,
method new_state_from_group_index (line 78) | def new_state_from_group_index(self,
method print_action_history (line 117) | def print_action_history(self, group_index: int = None) -> None:
method get_valid_actions (line 125) | def get_valid_actions(self) -> List[Dict[str, Tuple[torch.Tensor, torc...
method is_finished (line 133) | def is_finished(self) -> bool:
method combine_states (line 139) | def combine_states(cls, states: Sequence['GrammarBasedState']) -> 'Gra...
FILE: state_machines/states/rnn_statelet.py
class RnnStatelet (line 8) | class RnnStatelet:
method __init__ (line 48) | def __init__(self,
method __eq__ (line 64) | def __eq__(self, other):
FILE: state_machines/states/sql_state.py
class SqlState (line 7) | class SqlState:
method __init__ (line 8) | def __init__(self,
method take_action (line 19) | def take_action(self, production_rule: str) -> 'SqlState':
method get_valid_actions (line 63) | def get_valid_actions(self, valid_actions: dict):
method _remove_actions (line 177) | def _remove_actions(valid_actions, key, ids_to_remove):
method _get_current_open_clause (line 202) | def _get_current_open_clause(self):
FILE: state_machines/transition_functions/attend_past_schema_items_transition.py
class AttendPastSchemaItemsTransitionFunction (line 17) | class AttendPastSchemaItemsTransitionFunction(BasicTransitionFunction):
method __init__ (line 18) | def __init__(self,
method take_step (line 43) | def take_step(self,
method _update_decoder_state (line 74) | def _update_decoder_state(self, state: GrammarBasedState) -> Dict[str,...
method _compute_action_probabilities (line 149) | def _compute_action_probabilities(self,
method _construct_next_states (line 244) | def _construct_next_states(self,
method attend (line 364) | def attend(self,
FILE: state_machines/transition_functions/basic_transition_function.py
class BasicTransitionFunction (line 16) | class BasicTransitionFunction(TransitionFunction[GrammarBasedState]):
method __init__ (line 53) | def __init__(self,
method take_step (line 102) | def take_step(self,
method _update_decoder_state (line 132) | def _update_decoder_state(self, state: GrammarBasedState) -> Dict[str,...
method _compute_action_probabilities (line 192) | def _compute_action_probabilities(self,
method _construct_next_states (line 232) | def _construct_next_states(self,
method _take_first_step (line 333) | def _take_first_step(self,
method attend_on_question (line 400) | def attend_on_question(self,
FILE: state_machines/transition_functions/linking_transition_function.py
class LinkingTransitionFunction (line 16) | class LinkingTransitionFunction(BasicTransitionFunction):
method __init__ (line 58) | def __init__(self,
method _compute_action_probabilities (line 87) | def _compute_action_probabilities(self,
FILE: state_machines/transition_functions/prefix_attend_transition.py
class PrefixAttendTransitionFunction (line 15) | class PrefixAttendTransitionFunction(LinkingTransitionFunction):
method __init__ (line 16) | def __init__(self,
method take_step (line 49) | def take_step(self,
method _update_decoder_state (line 80) | def _update_decoder_state(self, state: GrammarBasedState) -> Dict[str,...
method attend (line 146) | def attend(self,
method _construct_next_states (line 166) | def _construct_next_states(self,
Condensed preview — 28 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (303K chars).
[
{
"path": "README.md",
"chars": 3185,
"preview": "# Representing Schema Structure with Graph Neural Networks for Text-to-SQL Parsing\n\nAuthor implementation of this [ACL 2"
},
{
"path": "dataset_readers/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "dataset_readers/dataset_util/spider_utils.py",
"chars": 9182,
"preview": "\"\"\"\nUtility functions for reading the standardised text2sql datasets presented in\n`\"Improving Text to SQL Evaluation Met"
},
{
"path": "dataset_readers/fields/knowledge_graph_field.py",
"chars": 3549,
"preview": "\"\"\"\n``KnowledgeGraphField`` is a ``Field`` which stores a knowledge graph representation.\n\"\"\"\nfrom typing import List, D"
},
{
"path": "dataset_readers/spider.py",
"chars": 7828,
"preview": "import json\nimport logging\nimport os\nfrom typing import List, Dict\n\nimport dill\nfrom allennlp.common.checks import Confi"
},
{
"path": "models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "models/semantic_parsing/__init__.py",
"chars": 62,
"preview": "from models.semantic_parsing.spider_parser import SpiderParser"
},
{
"path": "models/semantic_parsing/spider_parser.py",
"chars": 43124,
"preview": "import difflib\nimport os\nfrom functools import partial\nfrom typing import Dict, List, Tuple, Any, Mapping, Sequence\n\nimp"
},
{
"path": "modules/gated_graph_conv.py",
"chars": 3703,
"preview": "import math\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter as Param, init\nfrom torch_geometric.da"
},
{
"path": "predictors/spider_predictor.py",
"chars": 1229,
"preview": "from overrides import overrides\n\nfrom allennlp.common.util import JsonDict, sanitize\nfrom allennlp.data import DatasetRe"
},
{
"path": "requirements.txt",
"chars": 259,
"preview": "torch==1.0.1.post2\nspacy==2.0.18\nallennlp==0.8.2\ndill\ntorch-scatter==1.1.2\ntorch-sparse==0.2.4\ntorch-cluster==1.2.4\ntorc"
},
{
"path": "semparse/contexts/spider_context_utils.py",
"chars": 10171,
"preview": "import re\nfrom collections import defaultdict\nfrom sys import exc_info\nfrom typing import List, Dict, Set\n\nfrom override"
},
{
"path": "semparse/contexts/spider_db_context.py",
"chars": 12887,
"preview": "import re\nfrom collections import Set, defaultdict\nfrom typing import Dict, Tuple, List\n\nfrom allennlp.data import Token"
},
{
"path": "semparse/contexts/spider_db_grammar.py",
"chars": 7872,
"preview": "# pylint: disable=anomalous-backslash-in-string\n\"\"\"\nA ``Text2SqlTableContext`` represents the SQL context in which an ut"
},
{
"path": "semparse/worlds/evaluate.py",
"chars": 30758,
"preview": "################################\n# val: number(float)/string(str)/sql(dict)\n# col_unit: (agg_id, col_id, isDistinct(bool"
},
{
"path": "semparse/worlds/evaluate_spider.py",
"chars": 2230,
"preview": "import os\nimport sqlite3\n\nfrom semparse.worlds.evaluate import Evaluator, build_valid_col_units, rebuild_sql_val, rebuil"
},
{
"path": "semparse/worlds/spider_world.py",
"chars": 7018,
"preview": "from typing import List, Tuple, Dict, Set, Optional\nfrom copy import deepcopy\n\nfrom parsimonious import Grammar\nfrom par"
},
{
"path": "spider_evaluation/evaluate.py",
"chars": 30746,
"preview": "################################\n# val: number(float)/string(str)/sql(dict)\n# col_unit: (agg_id, col_id, isDistinct(bool"
},
{
"path": "spider_evaluation/process_sql.py",
"chars": 17219,
"preview": "################################\n# Assumptions:\n# 1. sql is correct\n# 2. only table name has alias\n# 3. only one i"
},
{
"path": "state_machines/states/grammar_based_state.py",
"chars": 9387,
"preview": "from typing import Any, Dict, List, Sequence, Tuple\n\nimport torch\n\nfrom allennlp.data.fields.production_rule_field impor"
},
{
"path": "state_machines/states/rnn_statelet.py",
"chars": 3883,
"preview": "from typing import List, Optional\n\nimport torch\n\nfrom allennlp.nn import util\n\n\nclass RnnStatelet:\n \"\"\"\n This clas"
},
{
"path": "state_machines/states/sql_state.py",
"chars": 10509,
"preview": "import copy\nimport logging\n\nlogger = logging.getLogger(__name__) # pylint: disable=invalid-name\n\n\nclass SqlState:\n d"
},
{
"path": "state_machines/transition_functions/attend_past_schema_items_transition.py",
"chars": 22504,
"preview": "from collections import defaultdict\nfrom typing import Dict, Tuple, List, Set, Any, Callable, Optional\n\nimport torch\nfro"
},
{
"path": "state_machines/transition_functions/basic_transition_function.py",
"chars": 25189,
"preview": "from collections import defaultdict\nfrom typing import Any, Dict, List, Set, Tuple\n\nfrom overrides import overrides\n\nimp"
},
{
"path": "state_machines/transition_functions/linking_transition_function.py",
"chars": 10082,
"preview": "from collections import defaultdict\nfrom typing import Any, Dict, List, Tuple\n\nfrom overrides import overrides\n\nimport t"
},
{
"path": "state_machines/transition_functions/prefix_attend_transition.py",
"chars": 16028,
"preview": "from typing import Dict, Tuple, List, Set, Any, Callable\n\nimport torch\nfrom allennlp.modules import Attention, FeedForwa"
},
{
"path": "train_configs/defaults.jsonnet",
"chars": 1987,
"preview": "local dataset_path = \"dataset/\";\n\n{\n \"random_seed\": 5,\n \"numpy_seed\": 5,\n \"pytorch_seed\": 5,\n \"dataset_reader\": {\n "
},
{
"path": "train_configs/paper_defaults.jsonnet",
"chars": 1958,
"preview": "local dataset_path = \"dataset/\";\n\n{\n \"random_seed\": 5,\n \"numpy_seed\": 5,\n \"pytorch_seed\": 5,\n \"dataset_reader\": {\n "
}
]
About this extraction
This page contains the full source code of the benbogin/spider-schema-gnn GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 28 files (285.7 KB), approximately 65.2k tokens, and a symbol index with 235 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.