Repository: benbogin/spider-schema-gnn Branch: master Commit: 02f4ae43b891 Files: 28 Total size: 285.7 KB Directory structure: gitextract_4_pkjgz2/ ├── README.md ├── dataset_readers/ │ ├── __init__.py │ ├── dataset_util/ │ │ └── spider_utils.py │ ├── fields/ │ │ └── knowledge_graph_field.py │ └── spider.py ├── models/ │ ├── __init__.py │ └── semantic_parsing/ │ ├── __init__.py │ └── spider_parser.py ├── modules/ │ └── gated_graph_conv.py ├── predictors/ │ └── spider_predictor.py ├── requirements.txt ├── semparse/ │ ├── contexts/ │ │ ├── spider_context_utils.py │ │ ├── spider_db_context.py │ │ └── spider_db_grammar.py │ └── worlds/ │ ├── evaluate.py │ ├── evaluate_spider.py │ └── spider_world.py ├── spider_evaluation/ │ ├── evaluate.py │ └── process_sql.py ├── state_machines/ │ ├── states/ │ │ ├── grammar_based_state.py │ │ ├── rnn_statelet.py │ │ └── sql_state.py │ └── transition_functions/ │ ├── attend_past_schema_items_transition.py │ ├── basic_transition_function.py │ ├── linking_transition_function.py │ └── prefix_attend_transition.py └── train_configs/ ├── defaults.jsonnet └── paper_defaults.jsonnet ================================================ FILE CONTENTS ================================================ ================================================ FILE: README.md ================================================ # Representing Schema Structure with Graph Neural Networks for Text-to-SQL Parsing Author implementation of this [ACL 2019 paper](https://arxiv.org/abs/1905.06241). Please also see the [follow-up repository](https://github.com/benbogin/spider-schema-gnn-global) with improved results, for this [EMNLP paper](https://www.aclweb.org/anthology/D19-1378.pdf). ## Install & Configure 1. Install pytorch version 1.0.1.post2 that fits your CUDA version (this repository should probably work with the latest pytorch version, but wasn't tested for it. 
If you use another version, you'll also need to update the versions of packages in `requirements.txt`.) ``` pip install https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp37-cp37m-linux_x86_64.whl # CUDA 10.0 build ``` 2. Install the rest of the required packages: ``` pip install -r requirements.txt ``` 3. Run this command to install NLTK punkt: ``` python -c "import nltk; nltk.download('punkt')" ``` 4. Download the dataset from the [official Spider dataset website](https://yale-lily.github.io/spider) 5. Edit the config file `train_configs/defaults.jsonnet` to update the location of the dataset: ``` local dataset_path = "dataset/"; ``` ## Training 1. Use the following AllenNLP command to train: ``` allennlp train train_configs/defaults.jsonnet -s experiments/name_of_experiment \ --include-package dataset_readers.spider \ --include-package models.semantic_parsing.spider_parser ``` Loading the dataset for the first time might take a while (a few hours), since the model first loads values from the tables and calculates similarity features with the relevant question. The result is then cached for subsequent runs. You should get results similar to the following (`sql_match` is the metric computed by the official evaluation script): ``` "best_validation__match/exact_match": 0.3715686274509804, "best_validation_sql_match": 0.47549019607843135, "best_validation__others/action_similarity": 0.5731271471206189, "best_validation__match/match_single": 0.6254612546125461, "best_validation__match/match_hard": 0.3054393305439331, "best_validation_beam_hit": 0.6070588235294118, "best_validation_loss": 7.383035182952881, "best_epoch": 32 ``` Note that the hyper-parameters used in `defaults.jsonnet` are different from those mentioned in the paper (most importantly, 3 timesteps are used instead of 2), thanks to the [following contribution from @wlhgtc](https://github.com/benbogin/spider-schema-gnn/pull/13). 
The original training config file is still available in `train_configs/paper_Defaults.jsonnet`. ## Inference Use the following AllenNLP command to output a file with the predicted queries: ``` allennlp predict experiments/name_of_experiment dataset/dev.json \ --predictor spider \ --use-dataset-reader \ --cuda-device=0 \ --output-file experiments/name_of_experiment/prediction.sql \ --silent \ --include-package models.semantic_parsing.spider_parser \ --include-package dataset_readers.spider \ --include-package predictors.spider_predictor \ --weights-file experiments/name_of_experiment/best.th \ -o "{\"dataset_reader\":{\"keep_if_unparsable\":true}}" ``` ================================================ FILE: dataset_readers/__init__.py ================================================ ================================================ FILE: dataset_readers/dataset_util/spider_utils.py ================================================ """ Utility functions for reading the standardised text2sql datasets presented in `"Improving Text to SQL Evaluation Methodology" `_ """ import json import os import sqlite3 from collections import defaultdict from typing import List, Dict, Optional from allennlp.common import JsonDict from spider_evaluation.process_sql import get_tables_with_alias, parse_sql class TableColumn: def __init__(self, name: str, text: str, column_type: str, is_primary_key: bool, foreign_key: Optional[str]): self.name = name self.text = text self.column_type = column_type self.is_primary_key = is_primary_key self.foreign_key = foreign_key class Table: def __init__(self, name: str, text: str, columns: List[TableColumn]): self.name = name self.text = text self.columns = columns def read_dataset_schema(schema_path: str) -> Dict[str, List[Table]]: schemas: Dict[str, Dict[str, Table]] = defaultdict(dict) dbs_json_blob = json.load(open(schema_path, "r")) for db in dbs_json_blob: db_id = db['db_id'] column_id_to_table = {} column_id_to_column = {} for i, (column, text, 
column_type) in enumerate(zip(db['column_names_original'], db['column_names'], db['column_types'])): table_id, column_name = column _, column_text = text table_name = db['table_names_original'][table_id] if table_name not in schemas[db_id]: table_text = db['table_names'][table_id] schemas[db_id][table_name] = Table(table_name, table_text, []) if column_name == "*": continue is_primary_key = i in db['primary_keys'] table_column = TableColumn(column_name.lower(), column_text, column_type, is_primary_key, None) schemas[db_id][table_name].columns.append(table_column) column_id_to_table[i] = table_name column_id_to_column[i] = table_column for (c1, c2) in db['foreign_keys']: foreign_key = column_id_to_table[c2] + ':' + column_id_to_column[c2].name column_id_to_column[c1].foreign_key = foreign_key return {**schemas} def read_dataset_values(db_id: str, dataset_path: str, tables: List[str]): db = os.path.join(dataset_path, db_id, db_id + ".sqlite") try: conn = sqlite3.connect(db) except Exception as e: raise Exception(f"Can't connect to SQL: {e} in path {db}") conn.text_factory = str cursor = conn.cursor() values = {} for table in tables: try: cursor.execute(f"SELECT * FROM {table.name} LIMIT 5000") values[table] = cursor.fetchall() except: conn.text_factory = lambda x: str(x, 'latin1') cursor = conn.cursor() cursor.execute(f"SELECT * FROM {table.name} LIMIT 5000") values[table] = cursor.fetchall() return values def ent_key_to_name(key): parts = key.split(':') if parts[0] == 'table': return parts[1] elif parts[0] == 'column': _, _, table_name, column_name = parts return f'{table_name}@{column_name}' else: return parts[1] def fix_number_value(ex: JsonDict): """ There is something weird in the dataset files - the `query_toks_no_value` field anonymizes all values, which is good since the evaluator doesn't check for the values. But it also anonymizes numbers that should not be anonymized: e.g. LIMIT 3 becomes LIMIT 'value', while the evaluator fails if it is not a number. 
""" def split_and_keep(s, sep): if not s: return [''] # consistent with string.split() # Find replacement character that is not used in string # i.e. just use the highest available character plus one # Note: This fails if ord(max(s)) = 0x10FFFF (ValueError) p = chr(ord(max(s)) + 1) return s.replace(sep, p + sep + p).split(p) # input is tokenized in different ways... so first try to make splits equal query_toks = ex['query_toks'] ex['query_toks'] = [] for q in query_toks: ex['query_toks'] += split_and_keep(q, '.') i_val, i_no_val = 0, 0 while i_val < len(ex['query_toks']) and i_no_val < len(ex['query_toks_no_value']): if ex['query_toks_no_value'][i_no_val] != 'value': i_val += 1 i_no_val += 1 continue i_val_end = i_val while i_val + 1 < len(ex['query_toks']) and \ i_no_val + 1 < len(ex['query_toks_no_value']) and \ ex['query_toks'][i_val_end + 1].lower() != ex['query_toks_no_value'][i_no_val + 1].lower(): i_val_end += 1 if i_val == i_val_end and ex['query_toks'][i_val] in ["1", "2", "3"] and ex['query_toks'][i_val - 1].lower() == "limit": ex['query_toks_no_value'][i_no_val] = ex['query_toks'][i_val] i_val = i_val_end i_val += 1 i_no_val += 1 return ex _schemas_cache = None def disambiguate_items(db_id: str, query_toks: List[str], tables_file: str, allow_aliases: bool) -> List[str]: """ we want the query tokens to be non-ambiguous - so we can change each column name to explicitly tell which table it belongs to parsed sql to sql clause is based on supermodel.gensql from syntaxsql """ class Schema: """ Simple schema which maps table&column to a unique identifier """ def __init__(self, schema, table): self._schema = schema self._table = table self._idMap = self._map(self._schema, self._table) @property def schema(self): return self._schema @property def idMap(self): return self._idMap def _map(self, schema, table): column_names_original = table['column_names_original'] table_names_original = table['table_names_original'] # print 'column_names_original: ', 
column_names_original # print 'table_names_original: ', table_names_original for i, (tab_id, col) in enumerate(column_names_original): if tab_id == -1: idMap = {'*': i} else: key = table_names_original[tab_id].lower() val = col.lower() idMap[key + "." + val] = i for i, tab in enumerate(table_names_original): key = tab.lower() idMap[key] = i return idMap def get_schemas_from_json(fpath): global _schemas_cache if _schemas_cache is not None: return _schemas_cache with open(fpath) as f: data = json.load(f) db_names = [db['db_id'] for db in data] tables = {} schemas = {} for db in data: db_id = db['db_id'] schema = {} # {'table': [col.lower, ..., ]} * -> __all__ column_names_original = db['column_names_original'] table_names_original = db['table_names_original'] tables[db_id] = {'column_names_original': column_names_original, 'table_names_original': table_names_original} for i, tabn in enumerate(table_names_original): table = str(tabn.lower()) cols = [str(col.lower()) for td, col in column_names_original if td == i] schema[table] = cols schemas[db_id] = schema _schemas_cache = schemas, db_names, tables return _schemas_cache schemas, db_names, tables = get_schemas_from_json(tables_file) schema = Schema(schemas[db_id], tables[db_id]) fixed_toks = [] i = 0 while i < len(query_toks): tok = query_toks[i] if tok == 'value' or tok == "'value'": # TODO: value should alawys be between '/" (remove first if clause) new_tok = f'"{tok}"' elif tok in ['!','<','>'] and query_toks[i+1] == '=': new_tok = tok + '=' i += 1 elif i+1 < len(query_toks) and query_toks[i+1] == '.': new_tok = ''.join(query_toks[i:i+3]) i += 2 else: new_tok = tok fixed_toks.append(new_tok) i += 1 toks = fixed_toks tables_with_alias = get_tables_with_alias(schema.schema, toks) _, sql, mapped_entities = parse_sql(toks, 0, tables_with_alias, schema, mapped_entities_fn=lambda: []) for i, new_name in mapped_entities: curr_tok = toks[i] if '.' 
in curr_tok and allow_aliases: parts = curr_tok.split('.') assert(len(parts) == 2) toks[i] = parts[0] + '.' + new_name else: toks[i] = new_name if not allow_aliases: toks = [tok for tok in toks if tok not in ['as', 't1', 't2', 't3', 't4']] toks = [f'\'value\'' if tok == '"value"' else tok for tok in toks] return toks ================================================ FILE: dataset_readers/fields/knowledge_graph_field.py ================================================ """ ``KnowledgeGraphField`` is a ``Field`` which stores a knowledge graph representation. """ from typing import List, Dict from allennlp.data import TokenIndexer, Tokenizer from allennlp.data.fields.knowledge_graph_field import KnowledgeGraphField from allennlp.data.tokenizers.token import Token from allennlp.semparse.contexts.knowledge_graph import KnowledgeGraph class SpiderKnowledgeGraphField(KnowledgeGraphField): """ This implementation calculates all non-graph-related features (i.e. no related_column), then takes each one of the features to calculate related column features, by taking the max score of all neighbours """ def __init__(self, knowledge_graph: KnowledgeGraph, utterance_tokens: List[Token], token_indexers: Dict[str, TokenIndexer], tokenizer: Tokenizer = None, feature_extractors: List[str] = None, entity_tokens: List[List[Token]] = None, linking_features: List[List[List[float]]] = None, include_in_vocab: bool = True, max_table_tokens: int = None) -> None: feature_extractors = feature_extractors if feature_extractors is not None else [ 'number_token_match', 'exact_token_match', 'contains_exact_token_match', 'lemma_match', 'contains_lemma_match', 'edit_distance', 'span_overlap_fraction', 'span_lemma_overlap_fraction', ] super().__init__(knowledge_graph, utterance_tokens, token_indexers, tokenizer=tokenizer, feature_extractors=feature_extractors, entity_tokens=entity_tokens, linking_features=linking_features, include_in_vocab=include_in_vocab, max_table_tokens=max_table_tokens) 
self.linking_features = self._compute_related_linking_features(self.linking_features) # hack needed to fix calculation of feature extractors in the inherited as_tensor method self._feature_extractors = feature_extractors * 2 def _compute_related_linking_features(self, non_related_features: List[List[List[float]]]) -> List[List[List[float]]]: linking_features = non_related_features entity_to_index_map = {} for entity_id, entity in enumerate(self.knowledge_graph.entities): entity_to_index_map[entity] = entity_id for entity_id, (entity, entity_text) in enumerate(zip(self.knowledge_graph.entities, self.entity_texts)): for token_index, token in enumerate(self.utterance_tokens): entity_token_features = linking_features[entity_id][token_index] for feature_index, feature_extractor in enumerate(self._feature_extractors): neighbour_features = [] for neighbor in self.knowledge_graph.neighbors[entity]: # we only care about table/columns relations here, not foreign-primary if entity.startswith('column') and neighbor.startswith('column'): continue neighbor_index = entity_to_index_map[neighbor] neighbour_features.append(non_related_features[neighbor_index][token_index][feature_index]) entity_token_features.append(max(neighbour_features)) return linking_features ================================================ FILE: dataset_readers/spider.py ================================================ import json import logging import os from typing import List, Dict import dill from allennlp.common.checks import ConfigurationError from allennlp.data import DatasetReader, Tokenizer, TokenIndexer, Field, Instance from allennlp.data.fields import TextField, ProductionRuleField, ListField, IndexField, MetadataField from allennlp.data.token_indexers import SingleIdTokenIndexer from allennlp.data.tokenizers import WordTokenizer from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter from overrides import overrides from spacy.symbols import ORTH, LEMMA from 
dataset_readers.dataset_util.spider_utils import fix_number_value, disambiguate_items from dataset_readers.fields.knowledge_graph_field import SpiderKnowledgeGraphField from semparse.contexts.spider_db_context import SpiderDBContext from semparse.worlds.spider_world import SpiderWorld logger = logging.getLogger(__name__) @DatasetReader.register("spider") class SpiderDatasetReader(DatasetReader): def __init__(self, lazy: bool = False, question_token_indexers: Dict[str, TokenIndexer] = None, keep_if_unparsable: bool = True, tables_file: str = None, dataset_path: str = 'dataset/database', load_cache: bool = True, save_cache: bool = True, loading_limit = -1): super().__init__(lazy=lazy) # default spacy tokenizer splits the common token 'id' to ['i', 'd'], we here write a manual fix for that spacy_tokenizer = SpacyWordSplitter(pos_tags=True) spacy_tokenizer.spacy.tokenizer.add_special_case(u'id', [{ORTH: u'id', LEMMA: u'id'}]) self._tokenizer = WordTokenizer(spacy_tokenizer) self._utterance_token_indexers = question_token_indexers or {'tokens': SingleIdTokenIndexer()} self._keep_if_unparsable = keep_if_unparsable self._tables_file = tables_file self._dataset_path = dataset_path self._load_cache = load_cache self._save_cache = save_cache self._loading_limit = loading_limit @overrides def _read(self, file_path: str): if not file_path.endswith('.json'): raise ConfigurationError(f"Don't know how to read filetype of {file_path}") cache_dir = os.path.join('cache', file_path.split("/")[-1]) if self._load_cache: logger.info(f'Trying to load cache from {cache_dir}') if self._save_cache: os.makedirs(cache_dir, exist_ok=True) cnt = 0 with open(file_path, "r") as data_file: json_obj = json.load(data_file) for total_cnt, ex in enumerate(json_obj): cache_filename = f'instance-{total_cnt}.pt' cache_filepath = os.path.join(cache_dir, cache_filename) if self._loading_limit == cnt: break if self._load_cache: try: ins = dill.load(open(cache_filepath, 'rb')) if ins is None and not 
self._keep_if_unparsable: # skip unparsed examples continue yield ins cnt += 1 continue except Exception as e: # could not load from cache - keep loading without cache pass query_tokens = None if 'query_toks' in ex: # we only have 'query_toks' in example for training/dev sets # fix for examples: we want to use the 'query_toks_no_value' field of the example which anonymizes # values. However, it also anonymizes numbers (e.g. LIMIT 3 -> LIMIT 'value', which is not good # since the official evaluator does expect a number and not a value ex = fix_number_value(ex) # we want the query tokens to be non-ambiguous (i.e. know for each column the table it belongs to, # and for each table alias its explicit name) # we thus remove all aliases and make changes such as: # 'name' -> 'singer@name', # 'singer AS T1' -> 'singer', # 'T1.name' -> 'singer@name' try: query_tokens = disambiguate_items(ex['db_id'], ex['query_toks_no_value'], self._tables_file, allow_aliases=False) except Exception as e: # there are two examples in the train set that are wrongly formatted, skip them print(f"error with {ex['query']}") print(e) ins = self.text_to_instance( utterance=ex['question'], db_id=ex['db_id'], sql=query_tokens) if ins is not None: cnt += 1 if self._save_cache: dill.dump(ins, open(cache_filepath, 'wb')) if ins is not None: yield ins def text_to_instance(self, utterance: str, db_id: str, sql: List[str] = None): fields: Dict[str, Field] = {} db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer, tables_file=self._tables_file, dataset_path=self._dataset_path) table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph, db_context.tokenized_utterance, self._utterance_token_indexers, entity_tokens=db_context.entity_tokens, include_in_vocab=False, # TODO: self._use_table_for_vocab, max_table_tokens=None) # self._max_table_tokens) world = SpiderWorld(db_context, query=sql) fields["utterance"] = TextField(db_context.tokenized_utterance, self._utterance_token_indexers) 
action_sequence, all_actions = world.get_action_sequence_and_all_actions() if action_sequence is None and self._keep_if_unparsable: # print("Parse error") action_sequence = [] elif action_sequence is None: return None index_fields: List[Field] = [] production_rule_fields: List[Field] = [] for production_rule in all_actions: nonterminal, rhs = production_rule.split(' -> ') production_rule = ' '.join(production_rule.split(' ')) field = ProductionRuleField(production_rule, world.is_global_rule(rhs), nonterminal=nonterminal) production_rule_fields.append(field) valid_actions_field = ListField(production_rule_fields) fields["valid_actions"] = valid_actions_field action_map = {action.rule: i # type: ignore for i, action in enumerate(valid_actions_field.field_list)} for production_rule in action_sequence: index_fields.append(IndexField(action_map[production_rule], valid_actions_field)) if not action_sequence: index_fields = [IndexField(-1, valid_actions_field)] action_sequence_field = ListField(index_fields) fields["action_sequence"] = action_sequence_field fields["world"] = MetadataField(world) fields["schema"] = table_field return Instance(fields) ================================================ FILE: models/__init__.py ================================================ ================================================ FILE: models/semantic_parsing/__init__.py ================================================ from models.semantic_parsing.spider_parser import SpiderParser ================================================ FILE: models/semantic_parsing/spider_parser.py ================================================ import difflib import os from functools import partial from typing import Dict, List, Tuple, Any, Mapping, Sequence import sqlparse import torch from allennlp.common.util import pad_sequence_to_length from allennlp.data import Vocabulary from allennlp.data.fields.production_rule_field import ProductionRule, ProductionRuleArray from allennlp.models import Model from 
allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder, Seq2VecEncoder, Embedding, Attention, FeedForward, \ TimeDistributed from allennlp.modules.seq2vec_encoders import BagOfEmbeddingsEncoder from allennlp.nn import util, Activation from allennlp.state_machines import BeamSearch from allennlp.state_machines.states import GrammarStatelet from torch_geometric.data import Data, Batch from modules.gated_graph_conv import GatedGraphConv from semparse.worlds.evaluate_spider import evaluate from state_machines.states.rnn_statelet import RnnStatelet from allennlp.state_machines.trainers import MaximumMarginalLikelihood from allennlp.training.metrics import Average from overrides import overrides from semparse.contexts.spider_context_utils import action_sequence_to_sql from semparse.worlds.spider_world import SpiderWorld from state_machines.states.grammar_based_state import GrammarBasedState from state_machines.states.sql_state import SqlState from state_machines.transition_functions.attend_past_schema_items_transition import \ AttendPastSchemaItemsTransitionFunction from state_machines.transition_functions.linking_transition_function import LinkingTransitionFunction @Model.register("spider") class SpiderParser(Model): def __init__(self, vocab: Vocabulary, encoder: Seq2SeqEncoder, entity_encoder: Seq2VecEncoder, decoder_beam_search: BeamSearch, question_embedder: TextFieldEmbedder, input_attention: Attention, past_attention: Attention, max_decoding_steps: int, action_embedding_dim: int, gnn: bool = True, decoder_use_graph_entities: bool = True, decoder_self_attend: bool = True, gnn_timesteps: int = 2, parse_sql_on_decoding: bool = True, add_action_bias: bool = True, use_neighbor_similarity_for_linking: bool = True, dataset_path: str = 'dataset', training_beam_size: int = None, decoder_num_layers: int = 1, dropout: float = 0.0, rule_namespace: str = 'rule_labels', scoring_dev_params: dict = None, debug_parsing: bool = False) -> None: super().__init__(vocab) self.vocab = 
vocab self._encoder = encoder self._max_decoding_steps = max_decoding_steps if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x self._rule_namespace = rule_namespace self._question_embedder = question_embedder self._add_action_bias = add_action_bias self._scoring_dev_params = scoring_dev_params or {} self.parse_sql_on_decoding = parse_sql_on_decoding self._entity_encoder = TimeDistributed(entity_encoder) self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking self._self_attend = decoder_self_attend self._decoder_use_graph_entities = decoder_use_graph_entities self._action_padding_index = -1 # the padding value used by IndexField self._exact_match = Average() self._sql_evaluator_match = Average() self._action_similarity = Average() self._acc_single = Average() self._acc_multi = Average() self._beam_hit = Average() self._action_embedding_dim = action_embedding_dim num_actions = vocab.get_vocab_size(self._rule_namespace) if self._add_action_bias: input_action_dim = action_embedding_dim + 1 else: input_action_dim = action_embedding_dim self._action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=input_action_dim) self._output_action_embedder = Embedding(num_embeddings=num_actions, embedding_dim=action_embedding_dim) encoder_output_dim = encoder.get_output_dim() if gnn: encoder_output_dim += action_embedding_dim self._first_action_embedding = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim)) self._first_attended_utterance = torch.nn.Parameter(torch.FloatTensor(encoder_output_dim)) self._first_attended_output = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim)) torch.nn.init.normal_(self._first_action_embedding) torch.nn.init.normal_(self._first_attended_utterance) torch.nn.init.normal_(self._first_attended_output) self._num_entity_types = 9 self._embedding_dim = question_embedder.get_output_dim() self._entity_type_encoder_embedding = Embedding(self._num_entity_types, 
self._embedding_dim) self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim) self._linking_params = torch.nn.Linear(16, 1) torch.nn.init.uniform_(self._linking_params.weight, 0, 1) num_edge_types = 3 self._gnn = GatedGraphConv(self._embedding_dim, gnn_timesteps, num_edge_types=num_edge_types, dropout=dropout) self._decoder_num_layers = decoder_num_layers self._beam_search = decoder_beam_search self._decoder_trainer = MaximumMarginalLikelihood(training_beam_size) if decoder_self_attend: self._transition_function = AttendPastSchemaItemsTransitionFunction(encoder_output_dim=encoder_output_dim, action_embedding_dim=action_embedding_dim, input_attention=input_attention, past_attention=past_attention, predict_start_type_separately=False, add_action_bias=self._add_action_bias, dropout=dropout, num_layers=self._decoder_num_layers) else: self._transition_function = LinkingTransitionFunction(encoder_output_dim=encoder_output_dim, action_embedding_dim=action_embedding_dim, input_attention=input_attention, predict_start_type_separately=False, add_action_bias=self._add_action_bias, dropout=dropout, num_layers=self._decoder_num_layers) self._ent2ent_ff = FeedForward(action_embedding_dim, 1, action_embedding_dim, Activation.by_name('relu')()) self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim) # TODO: Remove hard-coded dirs self._evaluate_func = partial(evaluate, db_dir=os.path.join(dataset_path, 'database'), table=os.path.join(dataset_path, 'tables.json'), check_valid=False) self.debug_parsing = debug_parsing @overrides def forward(self, # type: ignore utterance: Dict[str, torch.LongTensor], valid_actions: List[List[ProductionRule]], world: List[SpiderWorld], schema: Dict[str, torch.LongTensor], action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]: batch_size = len(world) device = utterance['tokens'].device initial_state = self._get_initial_state(utterance, world, schema, valid_actions) if 
action_sequence is not None: # Remove the trailing dimension (from ListField[ListField[IndexField]]). action_sequence = action_sequence.squeeze(-1) action_mask = action_sequence != self._action_padding_index else: action_mask = None if self.training: decode_output = self._decoder_trainer.decode(initial_state, self._transition_function, (action_sequence.unsqueeze(1), action_mask.unsqueeze(1))) return {'loss': decode_output['loss']} else: loss = torch.tensor([0]).float().to(device) if action_sequence is not None and action_sequence.size(1) > 1: try: loss = self._decoder_trainer.decode(initial_state, self._transition_function, (action_sequence.unsqueeze(1), action_mask.unsqueeze(1)))['loss'] except ZeroDivisionError: # reached a dead-end during beam search pass outputs: Dict[str, Any] = { 'loss': loss } num_steps = self._max_decoding_steps # This tells the state to start keeping track of debug info, which we'll pass along in # our output dictionary. initial_state.debug_info = [[] for _ in range(batch_size)] best_final_states = self._beam_search.search(num_steps, initial_state, self._transition_function, keep_final_unfinished_states=False) self._compute_validation_outputs(valid_actions, best_final_states, world, action_sequence, outputs) return outputs def _get_initial_state(self, utterance: Dict[str, torch.LongTensor], worlds: List[SpiderWorld], schema: Dict[str, torch.LongTensor], actions: List[List[ProductionRule]]) -> GrammarBasedState: schema_text = schema['text'] embedded_schema = self._question_embedder(schema_text, num_wrapping_dims=1) schema_mask = util.get_text_field_mask(schema_text, num_wrapping_dims=1).float() embedded_utterance = self._question_embedder(utterance) utterance_mask = util.get_text_field_mask(utterance).float() batch_size, num_entities, num_entity_tokens, _ = embedded_schema.size() num_entities = max([len(world.db_context.knowledge_graph.entities) for world in worlds]) num_question_tokens = utterance['tokens'].size(1) # entity_types: tensor 
with shape (batch_size, num_entities), where each entry is the # entity's type id. # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index # These encode the same information, but for efficiency reasons later it's nice # to have one version as a tensor and one that's accessible on the cpu. entity_types, entity_type_dict = self._get_type_vector(worlds, num_entities, embedded_schema.device) entity_type_embeddings = self._entity_type_encoder_embedding(entity_types) # Compute entity and question word similarity. We tried using cosine distance here, but # because this similarity is the main mechanism that the model can use to push apart logit # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger # output range than [-1, 1]. question_entity_similarity = torch.bmm(embedded_schema.view(batch_size, num_entities * num_entity_tokens, self._embedding_dim), torch.transpose(embedded_utterance, 1, 2)) question_entity_similarity = question_entity_similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens) # (batch_size, num_entities, num_question_tokens) question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2) # (batch_size, num_entities, num_question_tokens, num_features) linking_features = schema['linking'] linking_scores = question_entity_similarity_max_score feature_scores = self._linking_params(linking_features).squeeze(3) linking_scores = linking_scores + feature_scores # (batch_size, num_question_tokens, num_entities) linking_probabilities = self._get_linking_probabilities(worlds, linking_scores.transpose(1, 2), utterance_mask, entity_type_dict) # (batch_size, num_entities, num_neighbors) or None neighbor_indices = self._get_neighbor_indices(worlds, num_entities, linking_scores.device) if self._use_neighbor_similarity_for_linking and neighbor_indices is not None: # (batch_size, num_entities, embedding_dim) encoded_table = self._entity_encoder(embedded_schema, 
schema_mask) # Neighbor_indices is padded with -1 since 0 is a potential neighbor index. # Thus, the absolute value needs to be taken in the index_select, and 1 needs to # be added for the mask since that method expects 0 for padding. # (batch_size, num_entities, num_neighbors, embedding_dim) embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices)) neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1}, num_wrapping_dims=1).float() # Encoder initialized to easily obtain a masked average. neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True)) # (batch_size, num_entities, embedding_dim) embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask) projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float()) # (batch_size, num_entities, embedding_dim) entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings) else: # (batch_size, num_entities, embedding_dim) entity_embeddings = torch.tanh(entity_type_embeddings) link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities) encoder_input = torch.cat([link_embedding, embedded_utterance], 2) # (batch_size, utterance_length, encoder_output_dim) encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask)) max_entities_relevance = linking_probabilities.max(dim=1)[0] entities_relevance = max_entities_relevance.unsqueeze(-1).detach() graph_initial_embedding = entity_type_embeddings * entities_relevance encoder_output_dim = self._encoder.get_output_dim() if self._gnn: entities_graph_encoding = self._get_schema_graph_encoding(worlds, graph_initial_embedding) graph_link_embedding = util.weighted_sum(entities_graph_encoding, linking_probabilities) encoder_outputs = torch.cat(( encoder_outputs, graph_link_embedding ), dim=-1) encoder_output_dim = self._action_embedding_dim + self._encoder.get_output_dim() else: entities_graph_encoding = None if 
self._self_attend: # linked_actions_linking_scores = self._get_linked_actions_linking_scores(actions, entities_graph_encoding) entities_ff = self._ent2ent_ff(entities_graph_encoding) linked_actions_linking_scores = torch.bmm(entities_ff, entities_ff.transpose(1, 2)) else: linked_actions_linking_scores = [None] * batch_size # This will be our initial hidden state and memory cell for the decoder LSTM. final_encoder_output = util.get_final_encoder_states(encoder_outputs, utterance_mask, self._encoder.is_bidirectional()) memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim) initial_score = embedded_utterance.data.new_zeros(batch_size) # To make grouping states together in the decoder easier, we convert the batch dimension in # all of our tensors into an outer list. For instance, the encoder outputs have shape # `(batch_size, utterance_length, encoder_output_dim)`. We need to convert this into a list # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`. Then we # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s. 
initial_score_list = [initial_score[i] for i in range(batch_size)] encoder_output_list = [encoder_outputs[i] for i in range(batch_size)] utterance_mask_list = [utterance_mask[i] for i in range(batch_size)] initial_rnn_state = [] for i in range(batch_size): initial_rnn_state.append(RnnStatelet(final_encoder_output[i], memory_cell[i], self._first_action_embedding, self._first_attended_utterance, encoder_output_list, utterance_mask_list)) initial_grammar_state = [self._create_grammar_state(worlds[i], actions[i], linking_scores[i], linked_actions_linking_scores[i], entity_types[i], entities_graph_encoding[ i] if entities_graph_encoding is not None else None) for i in range(batch_size)] initial_sql_state = [SqlState(actions[i], self.parse_sql_on_decoding) for i in range(batch_size)] initial_state = GrammarBasedState(batch_indices=list(range(batch_size)), action_history=[[] for _ in range(batch_size)], score=initial_score_list, rnn_state=initial_rnn_state, grammar_state=initial_grammar_state, sql_state=initial_sql_state, possible_actions=actions, action_entity_mapping=[w.get_action_entity_mapping() for w in worlds]) return initial_state @staticmethod def _get_neighbor_indices(worlds: List[SpiderWorld], num_entities: int, device: torch.device) -> torch.LongTensor: """ This method returns the indices of each entity's neighbors. A tensor is accepted as a parameter for copying purposes. Parameters ---------- worlds : ``List[SpiderWorld]`` num_entities : ``int`` tensor : ``torch.Tensor`` Used for copying the constructed list onto the right device. Returns ------- A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded with -1 instead of 0, since 0 is a valid neighbor index. If all the entities in the batch have no neighbors, None will be returned. 
""" num_neighbors = 0 for world in worlds: for entity in world.db_context.knowledge_graph.entities: if len(world.db_context.knowledge_graph.neighbors[entity]) > num_neighbors: num_neighbors = len(world.db_context.knowledge_graph.neighbors[entity]) batch_neighbors = [] no_entities_have_neighbors = True for world in worlds: # Each batch instance has its own world, which has a corresponding table. entities = world.db_context.knowledge_graph.entities entity2index = {entity: i for i, entity in enumerate(entities)} entity2neighbors = world.db_context.knowledge_graph.neighbors neighbor_indexes = [] for entity in entities: entity_neighbors = [entity2index[n] for n in entity2neighbors[entity]] if entity_neighbors: no_entities_have_neighbors = False # Pad with -1 instead of 0, since 0 represents a neighbor index. padded = pad_sequence_to_length(entity_neighbors, num_neighbors, lambda: -1) neighbor_indexes.append(padded) neighbor_indexes = pad_sequence_to_length(neighbor_indexes, num_entities, lambda: [-1] * num_neighbors) batch_neighbors.append(neighbor_indexes) # It is possible that none of the entities has any neighbors, since our definition of the # knowledge graph allows it when no entities or numbers were extracted from the question. 
if no_entities_have_neighbors: return None return torch.tensor(batch_neighbors, device=device, dtype=torch.long) def _get_schema_graph_encoding(self, worlds: List[SpiderWorld], initial_graph_embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: max_num_entities = max([len(world.db_context.knowledge_graph.entities) for world in worlds]) batch_size = initial_graph_embeddings.size(0) graph_data_list = [] for batch_index, world in enumerate(worlds): x = initial_graph_embeddings[batch_index] adj_list = self._get_graph_adj_lists(initial_graph_embeddings.device, world, initial_graph_embeddings.size(1) - 1) graph_data = Data(x) for i, l in enumerate(adj_list): graph_data[f'edge_index_{i}'] = l graph_data_list.append(graph_data) batch = Batch.from_data_list(graph_data_list) gnn_output = self._gnn(batch.x, [batch[f'edge_index_{i}'] for i in range(self._gnn.num_edge_types)]) num_nodes = max_num_entities gnn_output = gnn_output.view(batch_size, num_nodes, -1) # entities_encodings = gnn_output entities_encodings = gnn_output[:, :max_num_entities] # global_node_encodings = gnn_output[:, max_num_entities] return entities_encodings @staticmethod def _get_graph_adj_lists(device, world, global_entity_id, global_node=False): entity_mapping = {} for i, entity in enumerate(world.db_context.knowledge_graph.entities): entity_mapping[entity] = i entity_mapping['_global_'] = global_entity_id adj_list_own = [] # column--table adj_list_link = [] # table->table / foreign->primary adj_list_linked = [] # table<-table / foreign<-primary adj_list_global = [] # node->global # TODO: Prepare in advance? 
for key, neighbors in world.db_context.knowledge_graph.neighbors.items(): idx_source = entity_mapping[key] for n_key in neighbors: idx_target = entity_mapping[n_key] if n_key.startswith("table") or key.startswith("table"): adj_list_own.append((idx_source, idx_target)) elif n_key.startswith("string") or key.startswith("string"): adj_list_own.append((idx_source, idx_target)) elif key.startswith("column:foreign"): adj_list_link.append((idx_source, idx_target)) src_table_key = f"table:{key.split(':')[2]}" tgt_table_key = f"table:{n_key.split(':')[2]}" idx_source_table = entity_mapping[src_table_key] idx_target_table = entity_mapping[tgt_table_key] adj_list_link.append((idx_source_table, idx_target_table)) elif n_key.startswith("column:foreign"): adj_list_linked.append((idx_source, idx_target)) src_table_key = f"table:{key.split(':')[2]}" tgt_table_key = f"table:{n_key.split(':')[2]}" idx_source_table = entity_mapping[src_table_key] idx_target_table = entity_mapping[tgt_table_key] adj_list_linked.append((idx_source_table, idx_target_table)) else: assert False adj_list_global.append((idx_source, entity_mapping['_global_'])) all_adj_types = [adj_list_own, adj_list_link, adj_list_linked] if global_node: all_adj_types.append(adj_list_global) return [torch.tensor(l, device=device, dtype=torch.long).transpose(0, 1) if l else torch.tensor(l, device=device, dtype=torch.long) for l in all_adj_types] def _create_grammar_state(self, world: SpiderWorld, possible_actions: List[ProductionRule], linking_scores: torch.Tensor, linked_actions_linking_scores: torch.Tensor, entity_types: torch.Tensor, entity_graph_encoding: torch.Tensor) -> GrammarStatelet: action_map = {} for action_index, action in enumerate(possible_actions): action_string = action[0] action_map[action_string] = action_index valid_actions = world.valid_actions entity_map = {} entities = world.entities_names for entity_index, entity in enumerate(entities): entity_map[entity] = entity_index translated_valid_actions: 
Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {} for key, action_strings in valid_actions.items(): translated_valid_actions[key] = {} # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid # productions of that non-terminal. We'll first split those productions by global vs. # linked action. action_indices = [action_map[action_string] for action_string in action_strings] production_rule_arrays = [(possible_actions[index], index) for index in action_indices] global_actions = [] linked_actions = [] for production_rule_array, action_index in production_rule_arrays: if production_rule_array[1]: global_actions.append((production_rule_array[2], action_index)) else: linked_actions.append((production_rule_array[0], action_index)) if global_actions: global_action_tensors, global_action_ids = zip(*global_actions) global_action_tensor = torch.cat(global_action_tensors, dim=0).to( global_action_tensors[0].device).long() global_input_embeddings = self._action_embedder(global_action_tensor) global_output_embeddings = self._output_action_embedder(global_action_tensor) translated_valid_actions[key]['global'] = (global_input_embeddings, global_output_embeddings, list(global_action_ids)) if linked_actions: linked_rules, linked_action_ids = zip(*linked_actions) entities = [rule.split(' -> ')[1].strip('[]\"') for rule in linked_rules] entity_ids = [entity_map[entity] for entity in entities] entity_linking_scores = linking_scores[entity_ids] if linked_actions_linking_scores is not None: entity_action_linking_scores = linked_actions_linking_scores[entity_ids] if not self._decoder_use_graph_entities: entity_type_tensor = entity_types[entity_ids] entity_type_embeddings = (self._entity_type_decoder_embedding(entity_type_tensor) .to(entity_types.device) .float()) else: entity_type_embeddings = entity_graph_encoding.index_select( dim=0, index=torch.tensor(entity_ids, device=entity_graph_encoding.device) ) if self._self_attend: 
translated_valid_actions[key]['linked'] = (entity_linking_scores, entity_type_embeddings, list(linked_action_ids), entity_action_linking_scores) else: translated_valid_actions[key]['linked'] = (entity_linking_scores, entity_type_embeddings, list(linked_action_ids)) return GrammarStatelet(['statement'], translated_valid_actions, self.is_nonterminal) @staticmethod def is_nonterminal(token: str): if token[0] == '"' and token[-1] == '"': return False return True def _get_linking_probabilities(self, worlds: List[SpiderWorld], linking_scores: torch.FloatTensor, question_mask: torch.LongTensor, entity_type_dict: Dict[int, int]) -> torch.FloatTensor: """ Produces the probability of an entity given a question word and type. The logic below separates the entities by type since the softmax normalization term sums over entities of a single type. Parameters ---------- worlds : ``List[WikiTablesWorld]`` linking_scores : ``torch.FloatTensor`` Has shape (batch_size, num_question_tokens, num_entities). question_mask: ``torch.LongTensor`` Has shape (batch_size, num_question_tokens). entity_type_dict : ``Dict[int, int]`` This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id. Returns ------- batch_probabilities : ``torch.FloatTensor`` Has shape ``(batch_size, num_question_tokens, num_entities)``. Contains all the probabilities for an entity given a question word. """ _, num_question_tokens, num_entities = linking_scores.size() batch_probabilities = [] for batch_index, world in enumerate(worlds): all_probabilities = [] num_entities_in_instance = 0 # NOTE: The way that we're doing this here relies on the fact that entities are # implicitly sorted by their types when we sort them by name, and that numbers come # before "date_column:", followed by "number_column:", "string:", and "string_column:". # This is not a great assumption, and could easily break later, but it should work for now. 
for type_index in range(self._num_entity_types): # This index of 0 is for the null entity for each type, representing the case where a # word doesn't link to any entity. entity_indices = [0] entities = world.db_context.knowledge_graph.entities for entity_index, _ in enumerate(entities): if entity_type_dict[batch_index * num_entities + entity_index] == type_index: entity_indices.append(entity_index) if len(entity_indices) == 1: # No entities of this type; move along... continue # We're subtracting one here because of the null entity we added above. num_entities_in_instance += len(entity_indices) - 1 # We separate the scores by type, since normalization is done per type. There's an # extra "null" entity per type, also, so we have `num_entities_per_type + 1`. We're # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_, # so we get back something of shape (num_question_tokens,) for each index we're # selecting. All of the selected indices together then make a tensor of shape # (num_question_tokens, num_entities_per_type + 1). indices = linking_scores.new_tensor(entity_indices, dtype=torch.long) entity_scores = linking_scores[batch_index].index_select(1, indices) # We used index 0 for the null entity, so this will actually have some values in it. # But we want the null entity's score to be 0, so we set that here. entity_scores[:, 0] = 0 # No need for a mask here, as this is done per batch instance, with no padding. type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1) all_probabilities.append(type_probabilities[:, 1:]) # We need to add padding here if we don't have the right number of entities. 
if num_entities_in_instance != num_entities: zeros = linking_scores.new_zeros(num_question_tokens, num_entities - num_entities_in_instance) all_probabilities.append(zeros) # (num_question_tokens, num_entities) probabilities = torch.cat(all_probabilities, dim=1) batch_probabilities.append(probabilities) batch_probabilities = torch.stack(batch_probabilities, dim=0) return batch_probabilities * question_mask.unsqueeze(-1).float() @staticmethod def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int: # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something. # Check if target is big enough to cover prediction (including start/end symbols) if len(predicted) > targets.size(0): return 0 predicted_tensor = targets.new_tensor(predicted) targets_trimmed = targets[:len(predicted)] # Return 1 if the predicted sequence is anywhere in the list of targets. return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=0)[0]).item() @staticmethod def _query_difficulty(targets: torch.LongTensor, action_mapping, batch_index): number_tables = len([action_mapping[(batch_index, int(a))] for a in targets if a >= 0 and action_mapping[(batch_index, int(a))].startswith('table_name')]) return number_tables > 1 @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: return { '_match/exact_match': self._exact_match.get_metric(reset), 'sql_match': self._sql_evaluator_match.get_metric(reset), '_others/action_similarity': self._action_similarity.get_metric(reset), '_match/match_single': self._acc_single.get_metric(reset), '_match/match_hard': self._acc_multi.get_metric(reset), 'beam_hit': self._beam_hit.get_metric(reset) } @staticmethod def _get_type_vector(worlds: List[SpiderWorld], num_entities: int, device) -> Tuple[torch.LongTensor, Dict[int, int]]: """ Produces the encoding for each entity's type. 
In addition, a map from a flattened entity index to type is returned to combine entity type operations into one method. Parameters ---------- worlds : ``List[AtisWorld]`` num_entities : ``int`` tensor : ``torch.Tensor`` Used for copying the constructed list onto the right device. Returns ------- A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_types)``. entity_types : ``Dict[int, int]`` This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id. """ entity_types = {} batch_types = [] column_type_ids = ['boolean', 'foreign', 'number', 'others', 'primary', 'text', 'time'] for batch_index, world in enumerate(worlds): types = [] for entity_index, entity in enumerate(world.db_context.knowledge_graph.entities): parts = entity.split(':') entity_main_type = parts[0] if entity_main_type == 'column': column_type = parts[1] entity_type = column_type_ids.index(column_type) elif entity_main_type == 'string': # cell value entity_type = len(column_type_ids) elif entity_main_type == 'table': entity_type = len(column_type_ids) + 1 else: raise (Exception("Unkown entity")) types.append(entity_type) # For easier lookups later, we're actually using a _flattened_ version # of (batch_index, entity_index) for the key, because this is how the # linking scores are stored. 
flattened_entity_index = batch_index * num_entities + entity_index entity_types[flattened_entity_index] = entity_type padded = pad_sequence_to_length(types, num_entities, lambda: 0) batch_types.append(padded) return torch.tensor(batch_types, dtype=torch.long, device=device), entity_types def _compute_validation_outputs(self, actions: List[List[ProductionRuleArray]], best_final_states: Mapping[int, Sequence[GrammarBasedState]], world: List[SpiderWorld], target_list: List[List[str]], outputs: Dict[str, Any]) -> None: batch_size = len(actions) outputs['predicted_sql_query'] = [] action_mapping = {} for batch_index, batch_actions in enumerate(actions): for action_index, action in enumerate(batch_actions): action_mapping[(batch_index, action_index)] = action[0] for i in range(batch_size): # gold sql exactly as given original_gold_sql_query = ' '.join(world[i].get_query_without_table_hints()) if i not in best_final_states: self._exact_match(0) self._action_similarity(0) self._sql_evaluator_match(0) self._acc_multi(0) self._acc_single(0) outputs['predicted_sql_query'].append('') continue best_action_indices = best_final_states[i][0].action_history[0] action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices] predicted_sql_query = action_sequence_to_sql(action_strings, add_table_names=True) outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=False)) if target_list is not None: targets = target_list[i].data target_available = target_list is not None and targets[0] > -1 if target_available: sequence_in_targets = self._action_history_match(best_action_indices, targets) self._exact_match(sequence_in_targets) sql_evaluator_match = self._evaluate_func(original_gold_sql_query, predicted_sql_query, world[i].db_id) self._sql_evaluator_match(sql_evaluator_match) similarity = difflib.SequenceMatcher(None, best_action_indices, targets) self._action_similarity(similarity.ratio()) difficulty = 
self._query_difficulty(targets, action_mapping, i) if difficulty: self._acc_multi(sql_evaluator_match) else: self._acc_single(sql_evaluator_match) beam_hit = False for pos, final_state in enumerate(best_final_states[i]): action_indices = final_state.action_history[0] action_strings = [action_mapping[(i, action_index)] for action_index in action_indices] candidate_sql_query = action_sequence_to_sql(action_strings, add_table_names=True) if target_available: correct = self._evaluate_func(original_gold_sql_query, candidate_sql_query, world[i].db_id) if correct: beam_hit = True self._beam_hit(beam_hit) ================================================ FILE: modules/gated_graph_conv.py ================================================ import math import torch from torch import Tensor from torch.nn import Parameter as Param, init from torch_geometric.data import Data from torch_geometric.nn.conv import MessagePassing class GatedGraphConv(MessagePassing): r"""The gated graph convolution operator from the `"Gated Graph Sequence Neural Networks" `_ paper .. math:: \mathbf{h}_i^{(0)} &= \mathbf{x}_i \, \Vert \, \mathbf{0} \mathbf{m}_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} \mathbf{\Theta} \cdot \mathbf{h}_j^{(l)} \mathbf{h}_i^{(l+1)} &= \textrm{GRU} (\mathbf{m}_i^{(l+1)}, \mathbf{h}_i^{(l)}) up to representation :math:`\mathbf{h}_i^{(L)}`. The number of input channels of :math:`\mathbf{x}_i` needs to be less or equal than :obj:`out_channels`. Args: out_channels (int): Size of each input sample. num_layers (int): The sequence length :math:`L`. aggr (string): The aggregation scheme to use (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). (default: :obj:`"add"`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. 
(default: :obj:`True`) """ def __init__(self, input_dim, num_timesteps, num_edge_types, aggr='add', bias=True, dropout=0): super(GatedGraphConv, self).__init__(aggr) self._input_dim = input_dim self.num_timesteps = num_timesteps self.num_edge_types = num_edge_types self.weight = Param(Tensor(num_timesteps, num_edge_types, input_dim, input_dim)) self.bias = Param(Tensor(num_timesteps, num_edge_types, input_dim)) self.rnn = torch.nn.GRUCell(input_dim, input_dim, bias=bias) self.dropout = torch.nn.Dropout(dropout) self.reset_parameters() def reset_parameters(self): for t in range(self.num_timesteps): for e in range(self.num_edge_types): torch.nn.init.xavier_uniform_(self.weight[t, e]) init.uniform_(self.bias, -0.01, 0.01) self.rnn.reset_parameters() def forward(self, x, edge_indices): """""" if len(edge_indices) != self.num_edge_types: raise ValueError(f'GatedGraphConv constructed with {self.num_edge_types} edge types, ' f'but {len(edge_indices)} were passed') h = x if h.size(1) > self._input_dim: raise ValueError('The number of input channels is not allowed to ' 'be larger than the number of output channels') if h.size(1) < self._input_dim: zero = h.new_zeros(h.size(0), self._input_dim - h.size(1)) h = torch.cat([h, zero], dim=1) for t in range(self.num_timesteps): new_h = [] for e in range(self.num_edge_types): if len(edge_indices[e]) == 0: continue m = self.dropout(torch.matmul(h, self.weight[t, e]) + self.bias[t, e]) new_h.append(self.propagate(edge_indices[e], size=(x.size(0), x.size(0)), x=m)) m_sum = torch.sum(torch.stack(new_h), dim=0) h = self.rnn(m_sum, h) return h def __repr__(self): return '{}({}, num_layers={})'.format( self.__class__.__name__, self._input_dim, self.num_timesteps) if __name__ == '__main__': gcn = GatedGraphConv(input_dim=10, num_timesteps=3, num_edge_types=3) data = Data(torch.zeros((5, 10)), edge_index=[ torch.tensor([[1,2],[2,3]]), torch.tensor([[1,3],[0,1]]), torch.tensor([[1,4],[2,3]]), ]) output = gcn(data.x, data.edge_index) 
print(output) ================================================ FILE: predictors/spider_predictor.py ================================================ from overrides import overrides from allennlp.common.util import JsonDict, sanitize from allennlp.data import DatasetReader, Instance from allennlp.models import Model from allennlp.predictors.predictor import Predictor @Predictor.register("spider") class WikiTablesParserPredictor(Predictor): def __init__(self, model: Model, dataset_reader: DatasetReader) -> None: super().__init__(model, dataset_reader) @overrides def predict_instance(self, instance: Instance) -> JsonDict: json_output = {} outputs = self._model.forward_on_instance(instance) predicted_sql_query = outputs['predicted_sql_query'].replace('\n', ' ') if predicted_sql_query == '': # line must not be empty for the evaluator to consider it predicted_sql_query = 'NO PREDICTION' json_output['predicted_sql_query'] = predicted_sql_query return sanitize(json_output) @overrides def dump_line(self, outputs: JsonDict) -> str: # pylint: disable=no-self-use """ If you don't want your outputs in JSON-lines format you can override this function to output them differently. 
""" return outputs['predicted_sql_query'] + "\n" ================================================ FILE: requirements.txt ================================================ torch==1.0.1.post2 spacy==2.0.18 allennlp==0.8.2 dill torch-scatter==1.1.2 torch-sparse==0.2.4 torch-cluster==1.2.4 torch-geometric==1.1.2 ordered_set https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.0.0/en_core_web_sm-2.0.0.tar.gz ================================================ FILE: semparse/contexts/spider_context_utils.py ================================================ import re from collections import defaultdict from sys import exc_info from typing import List, Dict, Set from overrides import overrides from parsimonious.exceptions import VisitationError, UndefinedLabel from parsimonious.expressions import Literal, OneOf, Sequence from parsimonious.grammar import Grammar from parsimonious.nodes import Node, NodeVisitor from six import reraise WHITESPACE_REGEX = re.compile(" wsp |wsp | wsp| ws |ws | ws") def format_grammar_string(grammar_dictionary: Dict[str, List[str]]) -> str: """ Formats a dictionary of production rules into the string format expected by the Parsimonious Grammar class. """ return '\n'.join([f"{nonterminal} = {' / '.join(right_hand_side)}" for nonterminal, right_hand_side in grammar_dictionary.items()]) def initialize_valid_actions(grammar: Grammar, keywords_to_uppercase: List[str] = None) -> Dict[str, List[str]]: """ We initialize the valid actions with the global actions. These include the valid actions that result from the grammar and also those that result from the tables provided. The keys represent the nonterminals in the grammar and the values are lists of the valid actions of that nonterminal. """ valid_actions: Dict[str, Set[str]] = defaultdict(set) for key in grammar: rhs = grammar[key] # Sequence represents a series of expressions that match pieces of the text in order. # Eg. 
A -> B C if isinstance(rhs, Sequence): valid_actions[key].add( format_action(key, " ".join(rhs._unicode_members()), # pylint: disable=protected-access keywords_to_uppercase=keywords_to_uppercase)) # OneOf represents a series of expressions, one of which matches the text. # Eg. A -> B / C elif isinstance(rhs, OneOf): for option in rhs._unicode_members(): # pylint: disable=protected-access valid_actions[key].add(format_action(key, option, keywords_to_uppercase=keywords_to_uppercase)) # A string literal, eg. "A" elif isinstance(rhs, Literal): if rhs.literal != "": valid_actions[key].add(format_action(key, repr(rhs.literal), keywords_to_uppercase=keywords_to_uppercase)) else: valid_actions[key] = set() valid_action_strings = {key: sorted(value) for key, value in valid_actions.items()} return valid_action_strings def format_action(nonterminal: str, right_hand_side: str, is_string: bool = False, is_number: bool = False, keywords_to_uppercase: List[str] = None) -> str: """ This function formats an action as it appears in models. It splits productions based on the special `ws` and `wsp` rules, which are used in grammars to denote whitespace, and then rejoins these tokens a formatted, comma separated list. Importantly, note that it `does not` split on spaces in the grammar string, because these might not correspond to spaces in the language the grammar recognises. Parameters ---------- nonterminal : ``str``, required. The nonterminal in the action. right_hand_side : ``str``, required. The right hand side of the action (i.e the thing which is produced). is_string : ``bool``, optional (default = False). Whether the production produces a string. If it does, it is formatted as ``nonterminal -> ['string']`` is_number : ``bool``, optional, (default = False). Whether the production produces a string. If it does, it is formatted as ``nonterminal -> ['number']`` keywords_to_uppercase: ``List[str]``, optional, (default = None) Keywords in the grammar to uppercase. 
In the case of sql, this might be SELECT, MAX etc. """ keywords_to_uppercase = keywords_to_uppercase or [] if right_hand_side.upper() in keywords_to_uppercase: right_hand_side = right_hand_side.upper() if is_string: return f'{nonterminal} -> ["\'{right_hand_side}\'"]' elif is_number: return f'{nonterminal} -> ["{right_hand_side}"]' else: right_hand_side = right_hand_side.lstrip("(").rstrip(")") child_strings = [token for token in WHITESPACE_REGEX.split(right_hand_side) if token] child_strings = [tok.upper() if tok.upper() in keywords_to_uppercase else tok for tok in child_strings] return f"{nonterminal} -> [{', '.join(child_strings)}]" def action_sequence_to_sql(action_sequences: List[str], add_table_names: bool=False) -> str: # Convert an action sequence like ['statement -> [query, ";"]', ...] to the # SQL string. query = [] for action in action_sequences: nonterminal, right_hand_side = action.split(' -> ') right_hand_side_tokens = right_hand_side[1:-1].split(', ') if nonterminal == 'statement': query.extend(right_hand_side_tokens) else: for query_index, token in list(enumerate(query)): if token == nonterminal: if nonterminal == 'column_name' and '@' in right_hand_side_tokens[0] and len(right_hand_side_tokens) == 1: if add_table_names: table_name, column_name = right_hand_side_tokens[0].split('@') if '.' in table_name: table_name = table_name.split('.')[0] right_hand_side_tokens = [table_name + '.' + column_name] else: right_hand_side_tokens = [right_hand_side_tokens[0].split('@')[-1]] query = query[:query_index] + \ right_hand_side_tokens + \ query[query_index + 1:] break return ' '.join([token.strip('"') for token in query]) class SqlVisitor(NodeVisitor): """ ``SqlVisitor`` performs a depth-first traversal of the the AST. It takes the parse tree and gives us an action sequence that resulted in that parse. Since the visitor has mutable state, we define a new ``SqlVisitor`` for each query. 
To get the action sequence, we create a ``SqlVisitor`` and call parse on it, which returns a list of actions. Ex. sql_visitor = SqlVisitor(grammar_string) action_sequence = sql_visitor.parse(query) Importantly, this ``SqlVisitor`` skips over ``ws`` and ``wsp`` nodes, because they do not hold any meaning, and make an action sequence much longer than it needs to be. Parameters ---------- grammar : ``Grammar`` A Grammar object that we use to parse the text. keywords_to_uppercase: ``List[str]``, optional, (default = None) Keywords in the grammar to uppercase. In the case of sql, this might be SELECT, MAX etc. """ def __init__(self, grammar: Grammar, keywords_to_uppercase: List[str] = None) -> None: self.action_sequence: List[str] = [] self.grammar: Grammar = grammar self.keywords_to_uppercase = keywords_to_uppercase or [] @overrides def generic_visit(self, node: Node, visited_children: List[None]) -> List[str]: self.add_action(node) if node.expr.name == 'statement': return self.action_sequence return [] def add_action(self, node: Node) -> None: """ For each node, we accumulate the rules that generated its children in a list. 
""" if node.expr.name and node.expr.name not in ['ws', 'wsp']: nonterminal = f'{node.expr.name} -> ' if isinstance(node.expr, Literal): right_hand_side = f'["{node.text}"]' else: child_strings = [] for child in node.__iter__(): if child.expr.name in ['ws', 'wsp']: continue if child.expr.name != '': child_strings.append(child.expr.name) else: child_right_side_string = child.expr._as_rhs().lstrip("(").rstrip( ")") # pylint: disable=protected-access child_right_side_list = [tok for tok in WHITESPACE_REGEX.split(child_right_side_string) if tok] child_right_side_list = [tok.upper() if tok.upper() in self.keywords_to_uppercase else tok for tok in child_right_side_list] child_strings.extend(child_right_side_list) right_hand_side = "[" + ", ".join(child_strings) + "]" rule = nonterminal + right_hand_side self.action_sequence = [rule] + self.action_sequence @overrides def visit(self, node): """ See the ``NodeVisitor`` visit method. This just changes the order in which we visit nonterminals from right to left to left to right. """ method = getattr(self, 'visit_' + node.expr_name, self.generic_visit) # Call that method, and show where in the tree it failed if it blows # up. try: # Changing this to reverse here! return method(node, [self.visit(child) for child in reversed(list(node))]) except (VisitationError, UndefinedLabel): # Don't catch and re-wrap already-wrapped exceptions. raise except self.unwrapped_exceptions: raise except Exception: # pylint: disable=broad-except # Catch any exception, and tack on a parse tree so it's easier to # see where it went wrong. 
exc_class, exc, traceback = exc_info() reraise(VisitationError, VisitationError(exc, exc_class, node), traceback) ================================================ FILE: semparse/contexts/spider_db_context.py ================================================ import re from collections import Set, defaultdict from typing import Dict, Tuple, List from allennlp.data import Tokenizer, Token from ordered_set import OrderedSet from unidecode import unidecode from dataset_readers.dataset_util.spider_utils import TableColumn, read_dataset_schema, read_dataset_values from allennlp.semparse.contexts.knowledge_graph import KnowledgeGraph # == stop words that will be omitted by ContextGenerator STOP_WORDS = {"", "", "all", "being", "-", "over", "through", "yourselves", "its", "before", "hadn", "with", "had", ",", "should", "to", "only", "under", "ours", "has", "ought", "do", "them", "his", "than", "very", "cannot", "they", "not", "during", "yourself", "him", "nor", "did", "didn", "'ve", "this", "she", "each", "where", "because", "doing", "some", "we", "are", "further", "ourselves", "out", "what", "for", "weren", "does", "above", "between", "mustn", "?", "be", "hasn", "who", "were", "here", "shouldn", "let", "hers", "by", "both", "about", "couldn", "of", "could", "against", "isn", "or", "own", "into", "while", "whom", "down", "wasn", "your", "from", "her", "their", "aren", "there", "been", ".", "few", "too", "wouldn", "themselves", ":", "was", "until", "more", "himself", "on", "but", "don", "herself", "haven", "those", "he", "me", "myself", "these", "up", ";", "below", "'re", "can", "theirs", "my", "and", "would", "then", "is", "am", "it", "doesn", "an", "as", "itself", "at", "have", "in", "any", "if", "!", "again", "'ll", "no", "that", "when", "same", "how", "other", "which", "you", "many", "shan", "'t", "'s", "our", "after", "most", "'d", "such", "'m", "why", "a", "off", "i", "yours", "so", "the", "having", "once"} class SpiderDBContext: schemas = {} db_knowledge_graph = {} 
db_tables_data = {} def __init__(self, db_id: str, utterance: str, tokenizer: Tokenizer, tables_file: str, dataset_path: str): self.dataset_path = dataset_path self.tables_file = tables_file self.db_id = db_id self.utterance = utterance tokenized_utterance = tokenizer.tokenize(utterance.lower()) self.tokenized_utterance = [Token(text=t.text, lemma=t.lemma_) for t in tokenized_utterance] if db_id not in SpiderDBContext.schemas: SpiderDBContext.schemas = read_dataset_schema(self.tables_file) self.schema = SpiderDBContext.schemas[db_id] self.knowledge_graph = self.get_db_knowledge_graph(db_id) entity_texts = [self.knowledge_graph.entity_text[entity].lower() for entity in self.knowledge_graph.entities] entity_tokens = tokenizer.batch_tokenize(entity_texts) self.entity_tokens = [[Token(text=t.text, lemma=t.lemma_) for t in et] for et in entity_tokens] @staticmethod def entity_key_for_column(table_name: str, column: TableColumn) -> str: if column.foreign_key is not None: column_type = "foreign" elif column.is_primary_key: column_type = "primary" else: column_type = column.column_type return f"column:{column_type.lower()}:{table_name.lower()}:{column.name.lower()}" def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph: entities: Set[str] = set() neighbors: Dict[str, OrderedSet[str]] = defaultdict(OrderedSet) entity_text: Dict[str, str] = {} foreign_keys_to_column: Dict[str, str] = {} db_schema = self.schema tables = db_schema.values() if db_id not in self.db_tables_data: self.db_tables_data[db_id] = read_dataset_values(db_id, self.dataset_path, tables) tables_data = self.db_tables_data[db_id] string_column_mapping: Dict[str, set] = defaultdict(set) for table, table_data in tables_data.items(): for table_row in table_data: for column, cell_value in zip(db_schema[table.name].columns, table_row): if column.column_type == 'text' and type(cell_value) is str: cell_value_normalized = self.normalize_string(cell_value) column_key = self.entity_key_for_column(table.name, 
column) string_column_mapping[cell_value_normalized].add(column_key) string_entities = self.get_entities_from_question(string_column_mapping) for table in tables: table_key = f"table:{table.name.lower()}" entities.add(table_key) entity_text[table_key] = table.text for column in db_schema[table.name].columns: entity_key = self.entity_key_for_column(table.name, column) entities.add(entity_key) neighbors[entity_key].add(table_key) neighbors[table_key].add(entity_key) entity_text[entity_key] = column.text for string_entity, column_keys in string_entities: entities.add(string_entity) for column_key in column_keys: neighbors[string_entity].add(column_key) neighbors[column_key].add(string_entity) entity_text[string_entity] = string_entity.replace("string:", "").replace("_", " ") # loop again after we have gone through all columns to link foreign keys columns for table_name in db_schema.keys(): for column in db_schema[table_name].columns: if column.foreign_key is None: continue other_column_table, other_column_name = column.foreign_key.split(':') # must have exactly one by design other_column = [col for col in db_schema[other_column_table].columns if col.name == other_column_name][0] entity_key = self.entity_key_for_column(table_name, column) other_entity_key = self.entity_key_for_column(other_column_table, other_column) neighbors[entity_key].add(other_entity_key) neighbors[other_entity_key].add(entity_key) foreign_keys_to_column[entity_key] = other_entity_key kg = KnowledgeGraph(entities, dict(neighbors), entity_text) kg.foreign_keys_to_column = foreign_keys_to_column return kg def _string_in_table(self, candidate: str, string_column_mapping: Dict[str, set]) -> List[str]: """ Checks if the string occurs in the table, and if it does, returns the names of the columns under which it occurs. If it does not, returns an empty list. """ candidate_column_names: List[str] = [] # First check if the entire candidate occurs as a cell. 
if candidate in string_column_mapping: candidate_column_names = string_column_mapping[candidate] # If not, check if it is a substring pf any cell value. if not candidate_column_names: for cell_value, column_names in string_column_mapping.items(): if candidate in cell_value: candidate_column_names.extend(column_names) candidate_column_names = list(set(candidate_column_names)) return candidate_column_names def get_entities_from_question(self, string_column_mapping: Dict[str, set]) -> List[Tuple[str, str]]: entity_data = [] for i, token in enumerate(self.tokenized_utterance): token_text = token.text if token_text in STOP_WORDS: continue normalized_token_text = self.normalize_string(token_text) if not normalized_token_text: continue token_columns = self._string_in_table(normalized_token_text, string_column_mapping) if token_columns: token_type = token_columns[0].split(":")[1] entity_data.append({'value': normalized_token_text, 'token_start': i, 'token_end': i+1, 'token_type': token_type, 'token_in_columns': token_columns}) # extracted_numbers = self._get_numbers_from_tokens(self.question_tokens) # filter out number entities to avoid repetition expanded_entities = [] for entity in self._expand_entities(self.tokenized_utterance, entity_data, string_column_mapping): if entity["token_type"] == "text": expanded_entities.append((f"string:{entity['value']}", entity['token_in_columns'])) # return expanded_entities, extracted_numbers #TODO(shikhar) Handle conjunctions return expanded_entities @staticmethod def normalize_string(string: str) -> str: """ These are the transformation rules used to normalize cell in column names in Sempre. See ``edu.stanford.nlp.sempre.tables.StringNormalizationUtils.characterNormalize`` and ``edu.stanford.nlp.sempre.tables.TableTypeSystem.canonicalizeName``. We reproduce those rules here to normalize and canonicalize cells and columns in the same way so that we can match them against constants in logical forms appropriately. 
""" # Normalization rules from Sempre # \u201A -> , string = re.sub("‚", ",", string) string = re.sub("„", ",,", string) string = re.sub("[·・]", ".", string) string = re.sub("…", "...", string) string = re.sub("ˆ", "^", string) string = re.sub("˜", "~", string) string = re.sub("‹", "<", string) string = re.sub("›", ">", string) string = re.sub("[‘’´`]", "'", string) string = re.sub("[“”«»]", "\"", string) string = re.sub("[•†‡²³]", "", string) string = re.sub("[‐‑–—−]", "-", string) # Oddly, some unicode characters get converted to _ instead of being stripped. Not really # sure how sempre decides what to do with these... TODO(mattg): can we just get rid of the # need for this function somehow? It's causing a whole lot of headaches. string = re.sub("[ðø′″€⁄ªΣ]", "_", string) # This is such a mess. There isn't just a block of unicode that we can strip out, because # sometimes sempre just strips diacritics... We'll try stripping out a few separate # blocks, skipping the ones that sempre skips... string = re.sub("[\\u0180-\\u0210]", "", string).strip() string = re.sub("[\\u0220-\\uFFFF]", "", string).strip() string = string.replace("\\n", "_") string = re.sub("\\s+", " ", string) # Canonicalization rules from Sempre. 
string = re.sub("[^\\w]", "_", string) string = re.sub("_+", "_", string) string = re.sub("_$", "", string) return unidecode(string.lower()) def _expand_entities(self, question, entity_data, string_column_mapping: Dict[str, set]): new_entities = [] for entity in entity_data: # to ensure the same strings are not used over and over if new_entities and entity['token_end'] <= new_entities[-1]['token_end']: continue current_start = entity['token_start'] current_end = entity['token_end'] current_token = entity['value'] current_token_type = entity['token_type'] current_token_columns = entity['token_in_columns'] while current_end < len(question): next_token = question[current_end].text next_token_normalized = self.normalize_string(next_token) if next_token_normalized == "": current_end += 1 continue candidate = "%s_%s" %(current_token, next_token_normalized) candidate_columns = self._string_in_table(candidate, string_column_mapping) candidate_columns = list(set(candidate_columns).intersection(current_token_columns)) if not candidate_columns: break candidate_type = candidate_columns[0].split(":")[1] if candidate_type != current_token_type: break current_end += 1 current_token = candidate current_token_columns = candidate_columns new_entities.append({'token_start': current_start, 'token_end': current_end, 'value': current_token, 'token_type': current_token_type, 'token_in_columns': current_token_columns}) return new_entities ================================================ FILE: semparse/contexts/spider_db_grammar.py ================================================ # pylint: disable=anomalous-backslash-in-string """ A ``Text2SqlTableContext`` represents the SQL context in which an utterance appears for the any of the text2sql datasets, with the grammar and the valid actions. 
""" from typing import List, Dict from dataset_readers.dataset_util.spider_utils import Table GRAMMAR_DICTIONARY = {} GRAMMAR_DICTIONARY["statement"] = ['(query ws iue ws query)', '(query ws)'] GRAMMAR_DICTIONARY["iue"] = ['"intersect"', '"except"', '"union"'] GRAMMAR_DICTIONARY["query"] = ['(ws select_core ws groupby_clause ws orderby_clause ws limit)', '(ws select_core ws groupby_clause ws orderby_clause)', '(ws select_core ws groupby_clause ws limit)', '(ws select_core ws orderby_clause ws limit)', '(ws select_core ws groupby_clause)', '(ws select_core ws orderby_clause)', '(ws select_core)'] GRAMMAR_DICTIONARY["select_core"] = ['(select_with_distinct ws select_results ws from_clause ws where_clause)', '(select_with_distinct ws select_results ws from_clause)', '(select_with_distinct ws select_results ws where_clause)', '(select_with_distinct ws select_results)'] GRAMMAR_DICTIONARY["select_with_distinct"] = ['(ws "select" ws "distinct")', '(ws "select")'] GRAMMAR_DICTIONARY["select_results"] = ['(ws select_result ws "," ws select_results)', '(ws select_result)'] GRAMMAR_DICTIONARY["select_result"] = ['"*"', '(table_source ws ".*")', 'expr', 'col_ref'] GRAMMAR_DICTIONARY["from_clause"] = ['(ws "from" ws table_source ws join_clauses)', '(ws "from" ws source)'] GRAMMAR_DICTIONARY["join_clauses"] = ['(join_clause ws join_clauses)', 'join_clause'] GRAMMAR_DICTIONARY["join_clause"] = ['"join" ws table_source ws "on" ws join_condition_clause'] GRAMMAR_DICTIONARY["join_condition_clause"] = ['(join_condition ws "and" ws join_condition_clause)', 'join_condition'] GRAMMAR_DICTIONARY["join_condition"] = ['ws col_ref ws "=" ws col_ref'] GRAMMAR_DICTIONARY["source"] = ['(ws single_source ws "," ws source)', '(ws single_source)'] GRAMMAR_DICTIONARY["single_source"] = ['table_source', 'source_subq'] GRAMMAR_DICTIONARY["source_subq"] = ['("(" ws query ws ")")'] # GRAMMAR_DICTIONARY["source_subq"] = ['("(" ws query ws ")" ws "as" ws name)', '("(" ws query ws ")")'] 
GRAMMAR_DICTIONARY["limit"] = ['("limit" ws non_literal_number)'] GRAMMAR_DICTIONARY["where_clause"] = ['(ws "where" wsp expr ws where_conj)', '(ws "where" wsp expr)'] GRAMMAR_DICTIONARY["where_conj"] = ['(ws "and" wsp expr ws where_conj)', '(ws "and" wsp expr)'] GRAMMAR_DICTIONARY["groupby_clause"] = ['(ws "group" ws "by" ws group_clause ws "having" ws expr)', '(ws "group" ws "by" ws group_clause)'] GRAMMAR_DICTIONARY["group_clause"] = ['(ws expr ws "," ws group_clause)', '(ws expr)'] GRAMMAR_DICTIONARY["orderby_clause"] = ['ws "order" ws "by" ws order_clause'] GRAMMAR_DICTIONARY["order_clause"] = ['(ordering_term ws "," ws order_clause)', 'ordering_term'] GRAMMAR_DICTIONARY["ordering_term"] = ['(ws expr ws ordering)', '(ws expr)'] GRAMMAR_DICTIONARY["ordering"] = ['(ws "asc")', '(ws "desc")'] GRAMMAR_DICTIONARY["col_ref"] = ['(table_name ws "." ws column_name)', 'column_name'] GRAMMAR_DICTIONARY["table_source"] = ['(table_name ws "as" ws table_alias)', 'table_name'] GRAMMAR_DICTIONARY["table_name"] = ["table_alias"] GRAMMAR_DICTIONARY["table_alias"] = ['"t1"', '"t2"', '"t3"', '"t4"'] GRAMMAR_DICTIONARY["column_name"] = [] GRAMMAR_DICTIONARY["ws"] = ['~"\s*"i'] GRAMMAR_DICTIONARY['wsp'] = ['~"\s+"i'] GRAMMAR_DICTIONARY["expr"] = ['in_expr', # Like expressions. '(value wsp "like" wsp string)', # Between expressions. '(value ws "between" wsp value ws "and" wsp value)', # Binary expressions. '(value ws binaryop wsp expr)', # Unary expressions. 
'(unaryop ws expr)', 'source_subq', 'value'] GRAMMAR_DICTIONARY["in_expr"] = ['(value wsp "not" wsp "in" wsp string_set)', '(value wsp "in" wsp string_set)', '(value wsp "not" wsp "in" wsp expr)', '(value wsp "in" wsp expr)'] GRAMMAR_DICTIONARY["value"] = ['parenval', '"YEAR(CURDATE())"', 'number', 'boolean', 'function', 'col_ref', 'string'] GRAMMAR_DICTIONARY["parenval"] = ['"(" ws expr ws ")"'] GRAMMAR_DICTIONARY["function"] = ['(fname ws "(" ws "distinct" ws arg_list_or_star ws ")")', '(fname ws "(" ws arg_list_or_star ws ")")'] GRAMMAR_DICTIONARY["arg_list_or_star"] = ['arg_list', '"*"'] GRAMMAR_DICTIONARY["arg_list"] = ['(expr ws "," ws arg_list)', 'expr'] # TODO(MARK): Massive hack, remove and modify the grammar accordingly # GRAMMAR_DICTIONARY["number"] = ['~"\d*\.?\d+"i', "'3'", "'4'"] GRAMMAR_DICTIONARY["non_literal_number"] = ['"1"', '"2"', '"3"', '"4"'] GRAMMAR_DICTIONARY["number"] = ['ws "value" ws'] GRAMMAR_DICTIONARY["string_set"] = ['ws "(" ws string_set_vals ws ")"'] GRAMMAR_DICTIONARY["string_set_vals"] = ['(string ws "," ws string_set_vals)', 'string'] # GRAMMAR_DICTIONARY["string"] = ['~"\'.*?\'"i'] GRAMMAR_DICTIONARY["string"] = ['"\'" ws "value" ws "\'"'] GRAMMAR_DICTIONARY["fname"] = ['"count"', '"sum"', '"max"', '"min"', '"avg"', '"all"'] GRAMMAR_DICTIONARY["boolean"] = ['"true"', '"false"'] # TODO(MARK): This is not tight enough. AND/OR are strictly boolean value operators. 
GRAMMAR_DICTIONARY["binaryop"] = ['"+"', '"-"', '"*"', '"/"', '"="', '"!="', '"<>"', '">="', '"<="', '">"', '"<"', '"and"', '"or"', '"like"'] GRAMMAR_DICTIONARY["unaryop"] = ['"+"', '"-"', '"not"', '"not"'] def update_grammar_with_tables(grammar_dictionary: Dict[str, List[str]], schema: Dict[str, Table]) -> None: table_names = sorted([f'"{table.lower()}"' for table in list(schema.keys())], reverse=True) grammar_dictionary['table_name'] += table_names all_columns = set() for table in schema.values(): all_columns.update([f'"{table.name.lower()}@{column.name.lower()}"' for column in table.columns if column.name != '*']) sorted_columns = sorted([column for column in all_columns], reverse=True) grammar_dictionary['column_name'] += sorted_columns def update_grammar_to_be_table_names_free(grammar_dictionary: Dict[str, List[str]]): """ Remove table names from column names, remove aliases """ grammar_dictionary["column_name"] = [] grammar_dictionary["table_name"] = [] grammar_dictionary["col_ref"] = ['column_name'] grammar_dictionary["table_source"] = ['table_name'] del grammar_dictionary["table_alias"] def update_grammar_flip_joins(grammar_dictionary: Dict[str, List[str]]): """ Remove table names from column names, remove aliases """ # using a simple rule such as join_clauses-> [(join_clauses ws join_clause), join_clause] # resulted in a max recursion error, so for now just using a predefined max # number of joins grammar_dictionary["join_clauses"] = ['(join_clauses_1 ws join_clause)', 'join_clause'] grammar_dictionary["join_clauses_1"] = ['(join_clauses_2 ws join_clause)', 'join_clause'] grammar_dictionary["join_clauses_2"] = ['(join_clause ws join_clause)', 'join_clause'] ================================================ FILE: semparse/worlds/evaluate.py ================================================ ################################ # val: number(float)/string(str)/sql(dict) # col_unit: (agg_id, col_id, isDistinct(bool)) # val_unit: (unit_op, col_unit1, col_unit2) # 
table_unit: (table_type, col_unit/sql) # cond_unit: (not_op, op_id, val_unit, val1, val2) # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] # sql { # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} # 'where': condition # 'groupBy': [col_unit1, col_unit2, ...] # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) # 'having': condition # 'limit': None/limit value # 'intersect': None/sql # 'except': None/sql # 'union': None/sql # } ################################ import copy import os import json import sqlite3 import argparse from spider_evaluation.process_sql import get_schema, Schema, get_sql # Flag to disable value evaluation DISABLE_VALUE = True # Flag to disable distinct in select evaluation DISABLE_DISTINCT = True CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') JOIN_KEYWORDS = ('join', 'on', 'as') WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') UNIT_OPS = ('none', '-', '+', "*", '/') AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') TABLE_TYPE = { 'sql': "sql", 'table_unit': "table_unit", } COND_OPS = ('and', 'or') SQL_OPS = ('intersect', 'union', 'except') ORDER_OPS = ('desc', 'asc') HARDNESS = { "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'), "component2": ('except', 'union', 'intersect') } def condition_has_or(conds): return 'or' in conds[1::2] def condition_has_like(conds): return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]] def condition_has_sql(conds): for cond_unit in conds[::2]: val1, val2 = cond_unit[3], cond_unit[4] if val1 is not None and type(val1) is dict: return True if val2 is not None and type(val2) is dict: return True return False def val_has_op(val_unit): return val_unit[0] != UNIT_OPS.index('none') def has_agg(unit): return unit[0] != AGG_OPS.index('none') def 
accuracy(count, total): if count == total: return 1 return 0 def recall(count, total): if count == total: return 1 return 0 def F1(acc, rec): if (acc + rec) == 0: return 0 return (2. * acc * rec) / (acc + rec) def get_scores(count, pred_total, label_total): if pred_total != label_total: return 0,0,0 elif count == pred_total: return 1,1,1 return 0,0,0 def eval_sel(pred, label): pred_sel = pred['select'][1] label_sel = label['select'][1] label_wo_agg = [unit[1] for unit in label_sel] pred_total = len(pred_sel) label_total = len(label_sel) cnt = 0 cnt_wo_agg = 0 for unit in pred_sel: if unit in label_sel: cnt += 1 label_sel.remove(unit) if unit[1] in label_wo_agg: cnt_wo_agg += 1 label_wo_agg.remove(unit[1]) return label_total, pred_total, cnt, cnt_wo_agg def eval_where(pred, label): pred_conds = [unit for unit in pred['where'][::2]] label_conds = [unit for unit in label['where'][::2]] label_wo_agg = [unit[2] for unit in label_conds] pred_total = len(pred_conds) label_total = len(label_conds) cnt = 0 cnt_wo_agg = 0 for unit in pred_conds: if unit in label_conds: cnt += 1 label_conds.remove(unit) if unit[2] in label_wo_agg: cnt_wo_agg += 1 label_wo_agg.remove(unit[2]) return label_total, pred_total, cnt, cnt_wo_agg def eval_group(pred, label): pred_cols = [unit[1] for unit in pred['groupBy']] label_cols = [unit[1] for unit in label['groupBy']] pred_total = len(pred_cols) label_total = len(label_cols) cnt = 0 pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols] label_cols = [label.split(".")[1] if "." 
in label else label for label in label_cols] for col in pred_cols: if col in label_cols: cnt += 1 label_cols.remove(col) return label_total, pred_total, cnt def eval_having(pred, label): pred_total = label_total = cnt = 0 if len(pred['groupBy']) > 0: pred_total = 1 if len(label['groupBy']) > 0: label_total = 1 pred_cols = [unit[1] for unit in pred['groupBy']] label_cols = [unit[1] for unit in label['groupBy']] if pred_total == label_total == 1 \ and pred_cols == label_cols \ and pred['having'] == label['having']: cnt = 1 return label_total, pred_total, cnt def eval_order(pred, label): pred_total = label_total = cnt = 0 if len(pred['orderBy']) > 0: pred_total = 1 if len(label['orderBy']) > 0: label_total = 1 if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \ ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)): cnt = 1 return label_total, pred_total, cnt def eval_and_or(pred, label): pred_ao = pred['where'][1::2] label_ao = label['where'][1::2] pred_ao = set(pred_ao) label_ao = set(label_ao) if pred_ao == label_ao: return 1,1,1 return len(pred_ao),len(label_ao),0 def get_nestedSQL(sql): nested = [] for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]: if type(cond_unit[3]) is dict: nested.append(cond_unit[3]) if type(cond_unit[4]) is dict: nested.append(cond_unit[4]) if sql['intersect'] is not None: nested.append(sql['intersect']) if sql['except'] is not None: nested.append(sql['except']) if sql['union'] is not None: nested.append(sql['union']) return nested def eval_nested(pred, label): label_total = 0 pred_total = 0 cnt = 0 if pred is not None: pred_total += 1 if label is not None: label_total += 1 if pred is not None and label is not None: cnt += Evaluator().eval_exact_match(pred, label) return label_total, pred_total, cnt def eval_IUEN(pred, label): lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect']) lt2, pt2, cnt2 = 
eval_nested(pred['except'], label['except']) lt3, pt3, cnt3 = eval_nested(pred['union'], label['union']) label_total = lt1 + lt2 + lt3 pred_total = pt1 + pt2 + pt3 cnt = cnt1 + cnt2 + cnt3 return label_total, pred_total, cnt def get_keywords(sql): res = set() if len(sql['where']) > 0: res.add('where') if len(sql['groupBy']) > 0: res.add('group') if len(sql['having']) > 0: res.add('having') if len(sql['orderBy']) > 0: res.add(sql['orderBy'][0]) res.add('order') if sql['limit'] is not None: res.add('limit') if sql['except'] is not None: res.add('except') if sql['union'] is not None: res.add('union') if sql['intersect'] is not None: res.add('intersect') # or keyword ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] if len([token for token in ao if token == 'or']) > 0: res.add('or') cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] # not keyword if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0: res.add('not') # in keyword if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0: res.add('in') # like keyword if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0: res.add('like') return res def eval_keywords(pred, label): pred_keywords = get_keywords(pred) label_keywords = get_keywords(label) pred_total = len(pred_keywords) label_total = len(label_keywords) cnt = 0 for k in pred_keywords: if k in label_keywords: cnt += 1 return label_total, pred_total, cnt def count_agg(units): return len([unit for unit in units if has_agg(unit)]) def count_component1(sql): count = 0 if len(sql['where']) > 0: count += 1 if len(sql['groupBy']) > 0: count += 1 if len(sql['orderBy']) > 0: count += 1 if sql['limit'] is not None: count += 1 if len(sql['from']['table_units']) > 0: # JOIN count += len(sql['from']['table_units']) - 1 ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] count += len([token for token in ao if 
token == 'or']) cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) return count def count_component2(sql): nested = get_nestedSQL(sql) return len(nested) def count_others(sql): count = 0 # number of aggregation agg_count = count_agg(sql['select'][1]) agg_count += count_agg(sql['where'][::2]) agg_count += count_agg(sql['groupBy']) if len(sql['orderBy']) > 0: agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] + [unit[2] for unit in sql['orderBy'][1] if unit[2]]) agg_count += count_agg(sql['having']) if agg_count > 1: count += 1 # number of select columns if len(sql['select'][1]) > 1: count += 1 # number of where conditions if len(sql['where']) > 1: count += 1 # number of group by clauses if len(sql['groupBy']) > 1: count += 1 return count class Evaluator: """A simple evaluator""" def __init__(self): self.partial_scores = None def eval_hardness(self, sql): count_comp1_ = count_component1(sql) count_comp2_ = count_component2(sql) count_others_ = count_others(sql) if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0: return "easy" elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \ (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0): return "medium" elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \ (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \ (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1): return "hard" else: return "extra" def eval_exact_match(self, pred, label): partial_scores = self.eval_partial_match(pred, label) self.partial_scores = partial_scores for _, score in partial_scores.items(): if score['f1'] != 1: return 0 if len(label['from']['table_units']) > 0: label_tables = sorted(label['from']['table_units']) pred_tables = sorted(pred['from']['table_units']) if label_tables != pred_tables: return False 
if len(label['from']['conds']) > 0: label_joins = sorted(label['from']['conds'], key=lambda x: str(x)) pred_joins = sorted(pred['from']['conds'], key=lambda x: str(x)) if label_joins != pred_joins: return False return 1 def eval_partial_match(self, pred, label): res = {} label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label) acc, rec, f1 = get_scores(cnt, pred_total, label_total) res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label) acc, rec, f1 = get_scores(cnt, pred_total, label_total) res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} label_total, pred_total, cnt = eval_group(pred, label) acc, rec, f1 = get_scores(cnt, pred_total, label_total) res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} label_total, pred_total, cnt = eval_having(pred, label) acc, rec, f1 = get_scores(cnt, pred_total, label_total) res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} label_total, pred_total, cnt = eval_order(pred, label) acc, rec, f1 = get_scores(cnt, pred_total, label_total) res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} label_total, pred_total, cnt = eval_and_or(pred, label) acc, rec, f1 = get_scores(cnt, pred_total, label_total) res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} label_total, pred_total, cnt = eval_IUEN(pred, label) acc, rec, f1 = 
get_scores(cnt, pred_total, label_total)
        res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_keywords(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        return res


def isValidSQL(sql, db):
    """Return True if `sql` executes on the SQLite database file `db`, False if execution raises.

    NOTE(review): the connection opened here is never closed.
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    try:
        cursor.execute(sql, [])
    except Exception as e:
        return False
    return True


def print_scores(scores, etype):
    """Print a table of evaluation scores, one column per hardness level.

    Args:
        scores: dict keyed by 'easy', 'medium', 'hard', 'extra' and 'all'. Each value holds
            'count' and, depending on `etype`, 'exec', 'exact' and a 'partial' dict mapping
            each component name to its 'acc' / 'rec' / 'f1' values.
        etype: "all", "exec" or "match"; selects which sections are printed.
    """
    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']

    print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
    counts = [scores[level]['count'] for level in levels]
    print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))

    if etype in ["all", "exec"]:
        print('===================== EXECUTION ACCURACY =====================')
        this_scores = [scores[level]['exec'] for level in levels]
        print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))

    if etype in ["all", "match"]:
        print('\n====================== EXACT MATCHING ACCURACY =====================')
        exact_scores = [scores[level]['exact'] for level in levels]
        print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
        print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

        print('---------------------- PARTIAL MATCHING RECALL ----------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

        print('---------------------- PARTIAL MATCHING F1 --------------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))


def evaluate(gold, predict, db_dir, etype, kmaps):
    """Score a prediction file against a gold file and print the results via print_scores.

    Args:
        gold: path to a file with one "<gold sql>\\t<db_id>" line per example (blank lines skipped).
        predict: path to a file with one prediction per line; only the first tab-separated
            field is used (blank lines skipped).
        db_dir: directory containing "<db_id>/<db_id>.sqlite" for every database.
        etype: "all", "exec" or "match"; selects execution and/or exact-match scoring.
        kmaps: dict mapping db_id to its foreign-key column map (as produced by
            build_foreign_key_map_from_json).

    NOTE(review): predictions and gold lines are paired with zip(), so surplus lines in either
    file are silently ignored.
    """
    with open(gold) as f:
        glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]

    with open(predict) as f:
        plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
    # plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
    # glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]
    evaluator = Evaluator()

    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']
    # NOTE(review): entries is filled below but never read in this function.
    entries = []
    scores = {}

    # Initialise accumulators per hardness level; 'acc_count'/'rec_count' track how many
    # examples contributed to each partial metric so they can be averaged later.
    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}

    eval_err_num = 0
    for p, g in zip(plist, glist):
        p_str = p[0]
        g_str, db = g
        db_name = db
        db = os.path.join(db_dir, db, db + ".sqlite")
        # NOTE(review): the schema is re-read from disk for every example.
        schema = Schema(get_schema(db))
        g_sql = get_sql(schema, g_str)
        hardness = evaluator.eval_hardness(g_sql)
        scores[hardness]['count'] += 1
        scores['all']['count'] += 1

        try:
            p_sql = get_sql(schema, p_str)
        except:
            # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
            # NOTE(review): bare except also swallows KeyboardInterrupt/SystemExit.
            p_sql = {
                "except": None,
                "from": {
                    "conds": [],
                    "table_units": []
                },
                "groupBy": [],
                "having": [],
                "intersect": None,
                "limit": None,
                "orderBy": [],
                "select": [
                    False,
                    []
                ],
                "union": None,
                "where": []
            }
            eval_err_num += 1
            print("eval_err_num:{}".format(eval_err_num))
            print(p_str)
            print()

        # rebuild sql for value evaluation
        kmap = kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

        # p_sql_copy = copy.deepcopy(p_sql)
        # g_sql_copy = copy.deepcopy(g_sql)
        # if not eval_exec_match(db, p_str, g_str, p_sql_copy, g_sql_copy) and evaluator.eval_exact_match(p_sql_copy, g_sql_copy):
        #     a = 1

        if etype in ["all", "exec"]:
            exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
            if exec_score:
                scores[hardness]['exec'] += 1
                scores['all']['exec'] += exec_score

        if etype in ["all", "match"]:
            exact_score = evaluator.eval_exact_match(p_sql, g_sql)
            partial_scores = evaluator.partial_scores
            if exact_score == 0:
                # Log every exact-match failure for inspection.
                print("{} pred: {}".format(hardness,p_str))
                print("{} gold: {}".format(hardness,g_str))
                print("")
            scores[hardness]['exact'] += exact_score
            scores['all']['exact'] += exact_score
            for type_ in partial_types:
                # acc is averaged only over examples where the prediction has this component,
                # rec only over examples where the gold query has it.
                if partial_scores[type_]['pred_total'] > 0:
                    scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
                    scores[hardness]['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
                    scores[hardness]['partial'][type_]['rec_count'] += 1
                scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
                if partial_scores[type_]['pred_total'] > 0:
                    scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
                    scores['all']['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
                    scores['all']['partial'][type_]['rec_count'] += 1
                scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']

            entries.append({
                'predictSQL': p_str,
                'goldSQL': g_str,
                'hardness': hardness,
                'exact': exact_score,
                'partial': partial_scores
            })

    # Turn accumulated sums into averages per level.
    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']

        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                if scores[level]['partial'][type_]['acc_count'] == 0:
                    scores[level]['partial'][type_]['acc'] = 0
                else:
                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
                                                             scores[level]['partial'][type_]['acc_count'] * 1.0
                if scores[level]['partial'][type_]['rec_count'] == 0:
                    scores[level]['partial'][type_]['rec'] = 0
                else:
                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
                                                             scores[level]['partial'][type_]['rec_count'] * 1.0
                # NOTE(review): F1 is reported as 1 when both averaged acc and rec are 0,
                # which looks inverted; confirm before relying on partial F1 numbers.
                if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
                    scores[level]['partial'][type_]['f1'] = 1
                else:
                    scores[level]['partial'][type_]['f1'] = \
                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
                        scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])

    print_scores(scores, etype)


def eval_exec_match(db, p_str, g_str, pred, gold):
    """Compare the execution results of the predicted and gold queries on database file `db`.

    Each query's rows are split into one column per SELECT val_unit. The columns are keyed by
    that val_unit's col_unit(s), and the resulting dicts are compared. Returns False if the
    predicted query raises. Otherwise returns whether the two dicts are equal.

    Currently does not support multiple col_unit pairs: two select items with the same key
    overwrite each other. The aggregation id of a select item is not part of the key.

    NOTE(review): the per-key columns are compared as ordered lists, so row order matters even
    without ORDER BY. The gold query is not guarded and may raise. The connection is never
    closed.
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    # Return TEXT values as bytes to avoid decode errors on non-UTF-8 data.
    conn.text_factory = bytes
    try:
        cursor.execute(p_str)
        p_res = cursor.fetchall()
    except:
        return False

    cursor.execute(g_str)
    q_res = cursor.fetchall()

    def res_map(res, val_units):
        # Map each select val_unit's key to the list of values in its result column.
        rmap = {}
        for idx, val_unit in enumerate(val_units):
            key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
            rmap[key] = [r[idx] for r in res]
        return rmap

    p_val_units = [unit[1] for unit in pred['select'][1]]
    q_val_units = [unit[1] for unit in gold['select'][1]]
    return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)


# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
    """If DISABLE_VALUE is set, drop literal values from `cond_unit` (set to None) and recurse into nested SQL dicts."""
    if cond_unit is None or not DISABLE_VALUE:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    if type(val1) is not dict:
        val1 = None
    else:
        val1 = rebuild_sql_val(val1)
    if type(val2) is not dict:
        val2 = None
    else:
        val2 = rebuild_sql_val(val2)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_val(condition):
    """Apply rebuild_cond_unit_val to every cond_unit (even positions); connectors are kept as-is."""
    if condition is None or not DISABLE_VALUE:
        return condition

    res = []
    for idx, it in enumerate(condition):
        if idx % 2 == 0:
            res.append(rebuild_cond_unit_val(it))
        else:
            res.append(it)
    return res


def rebuild_sql_val(sql):
    """Strip literal values from all conditions of `sql` and its set-op subqueries (mutates and returns `sql`)."""
    if sql is None or not DISABLE_VALUE:
        return sql

    sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
    sql['having'] = rebuild_condition_val(sql['having'])
    sql['where'] = rebuild_condition_val(sql['where'])
    sql['intersect'] = rebuild_sql_val(sql['intersect'])
    sql['except'] = rebuild_sql_val(sql['except'])
    sql['union'] = rebuild_sql_val(sql['union'])

    return sql


# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
    """Return the schema.idMap values for columns of the tables referenced in `table_units`.

    Table ids have their last two characters dropped to form a prefix. Presumably ids look like
    "__table__" and column ids like "__table.col__" (the format built in build_foreign_key_map)
    — confirm against Schema.idMap. A column is kept when the text before its first '.' equals
    one of these prefixes.
    """
    col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
    prefixs = [col_id[:-2] for col_id in col_ids]
    valid_col_units= []
    for value in schema.idMap.values():
        if '.' in value and value[:value.index('.')] in prefixs:
            valid_col_units.append(value)
    return valid_col_units


def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
    """Replace a column id by its foreign-key representative from `kmap` when the column belongs
    to a table in `valid_col_units`. Also clears the distinct flag when DISABLE_DISTINCT is set."""
    if col_unit is None:
        return col_unit

    agg_id, col_id, distinct = col_unit
    if col_id in kmap and col_id in valid_col_units:
        col_id = kmap[col_id]
    if DISABLE_DISTINCT:
        distinct = None
    return agg_id, col_id, distinct


def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
    """Apply rebuild_col_unit_col to both col_units of a val_unit."""
    if val_unit is None:
        return val_unit

    unit_op, col_unit1, col_unit2 = val_unit
    col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
    col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
    return unit_op, col_unit1, col_unit2


def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
    """Rebuild the payload of a table_unit when it is a tuple; dict (subquery) payloads are left untouched."""
    if table_unit is None:
        return table_unit

    table_type, col_unit_or_sql = table_unit
    if isinstance(col_unit_or_sql, tuple):
        col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
    return table_type, col_unit_or_sql


def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
    """Rebuild the val_unit of a cond_unit; the values are left unchanged."""
    if cond_unit is None:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_col(valid_col_units, condition, kmap):
    """Rebuild every cond_unit (even positions) of `condition` in place and return it."""
    for idx in range(len(condition)):
        if idx % 2 == 0:
            condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
    return condition


def rebuild_select_col(valid_col_units, sel, kmap):
    """Rebuild each (agg_id, val_unit) of a select clause; clears its distinct flag when DISABLE_DISTINCT is set."""
    if sel is None:
        return sel
    distinct, _list = sel
    new_list = []
    for it in _list:
        agg_id, val_unit = it
        new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
    if DISABLE_DISTINCT:
        distinct = None
    return distinct, new_list


def rebuild_from_col(valid_col_units, from_, kmap):
    """Rebuild the table units and join conditions of a from clause (mutates and returns `from_`)."""
    if from_ is None:
        return from_

    from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
    from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
    return from_


def rebuild_group_by_col(valid_col_units, group_by, kmap):
    """Rebuild each col_unit of a group-by list."""
    if group_by is None:
        return group_by

    return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]


def rebuild_order_by_col(valid_col_units, order_by, kmap):
    """Rebuild the val_units of an order-by (direction, val_units) pair; empty/None is returned unchanged."""
    if order_by is None or len(order_by) == 0:
        return order_by

    direction, val_units = order_by
    new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
    return direction, new_val_units


def rebuild_sql_col(valid_col_units, sql, kmap):
    """Canonicalise foreign-key columns throughout `sql` and its set-op subqueries (mutates and returns `sql`).

    The same `valid_col_units` is reused for intersect/except/union subqueries.
    """
    if sql is None:
        return sql

    sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
    sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
    sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
    sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
    sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
    sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
    sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
    sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
    sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)

    return sql


def build_foreign_key_map(entry):
    """Map each foreign-key-linked column id to the id of the lowest-indexed column in its group.

    Args:
        entry: a schema dict with "column_names_original", "table_names_original" and
            "foreign_keys" keys (presumably one database entry of Spider's tables.json).

    Returns:
        dict from "__table.column__" (lower-cased) to the representative column id of the
        same form.
    """
    cols_orig = entry["column_names_original"]
    tables_orig = entry["table_names_original"]

    # rebuild cols corresponding to idmap in Schema
    cols = []
    for col_orig in cols_orig:
        if col_orig[0] >= 0:
            t = tables_orig[col_orig[0]]
            c = col_orig[1]
            cols.append("__" + t.lower() + "." + c.lower() + "__")
        else:
            # A negative table index gets the "__all__" id.
            cols.append("__all__")

    def keyset_in_list(k1, k2, k_list):
        # Return the first existing group containing k1 or k2, else append and return a new one.
        # NOTE(review): if k1 and k2 are already in two different groups, those groups are not
        # merged.
        for k_set in k_list:
            if k1 in k_set or k2 in k_set:
                return k_set
        new_k_set = set()
        k_list.append(new_k_set)
        return new_k_set

    foreign_key_list = []
    foreign_keys = entry["foreign_keys"]
    for fkey in foreign_keys:
        key1, key2 = fkey
        key_set = keyset_in_list(key1, key2, foreign_key_list)
        key_set.add(key1)
        key_set.add(key2)

    foreign_key_map = {}
    for key_set in foreign_key_list:
        sorted_list = sorted(list(key_set))
        midx = sorted_list[0]
        for idx in sorted_list:
            foreign_key_map[cols[idx]] = cols[midx]

    return foreign_key_map


def build_foreign_key_map_from_json(table):
    """Load the JSON list of schema entries at path `table` and return {db_id: foreign-key map}."""
    with open(table) as f:
        data = json.load(f)
    tables = {}
    for entry in data:
        tables[entry['db_id']] = build_foreign_key_map(entry)
    return tables


# CLI entry point: evaluate a prediction file against a gold file.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--gold', dest='gold', type=str)
    parser.add_argument('--pred', dest='pred', type=str)
    parser.add_argument('--db', dest='db', type=str)
    parser.add_argument('--table', dest='table', type=str)
    parser.add_argument('--etype', dest='etype', type=str)
    args = parser.parse_args()

    gold = args.gold
    pred = args.pred
    db_dir = args.db
    table = args.table
    etype = args.etype

    assert etype in ["all", "exec", "match"], "Unknown evaluation method"

    kmaps = build_foreign_key_map_from_json(table)

    evaluate(gold, pred, db_dir, etype, kmaps)
================================================
FILE: semparse/worlds/evaluate_spider.py
================================================
import os
import sqlite3

from semparse.worlds.evaluate import Evaluator, build_valid_col_units, rebuild_sql_val, rebuild_sql_col, \
    build_foreign_key_map_from_json
from spider_evaluation.process_sql import Schema, get_schema, get_sql

# Cache of parsed schemas keyed by db_name.
_schemas = {}
# Foreign-key maps for all databases; built lazily on the first evaluate() call.
kmaps = None


def evaluate(gold, predict, db_name, db_dir, table, check_valid: bool=True) -> bool:
    """Check whether the SQL string `predict` exactly matches the SQL string `gold` on database `db_name`.

    Both queries go through the same value and foreign-key normalisation
    (rebuild_sql_val / rebuild_sql_col) before being compared with Evaluator.eval_exact_match.

    Args:
        gold: gold SQL string; parse errors on it are not caught.
        predict: predicted SQL string; if it fails to parse, False is returned.
        db_name: database id; the SQLite file is "<db_dir>/<db_name>/<db_name>.sqlite".
        db_dir: root directory of the databases.
        table: path to the schema JSON. It is only read on the first call, after which the
            module-level `kmaps` cache is reused regardless of this argument.
        check_valid: if True, additionally require `predict` to execute (check_valid_sql).

    Returns:
        The exact-match result, truthy on match. NOTE(review): the annotation says bool, but
        the value comes straight from eval_exact_match (or an `and` with it) and may be an int.
    """
    global kmaps

    # try:
    evaluator = Evaluator()

    if kmaps is None:
        kmaps = build_foreign_key_map_from_json(table)

    if db_name in _schemas:
        schema = _schemas[db_name]
    else:
        db = os.path.join(db_dir, db_name, db_name + ".sqlite")
        schema = _schemas[db_name] = Schema(get_schema(db))

    g_sql = get_sql(schema, gold)

    try:
        p_sql = get_sql(schema, predict)
    except Exception as e:
        return False

    # rebuild sql for value evaluation
    kmap = kmaps[db_name]
    g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
    g_sql = rebuild_sql_val(g_sql)
    g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
    p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
    p_sql = rebuild_sql_val(p_sql)
    p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

    exact_score = evaluator.eval_exact_match(p_sql, g_sql)

    if not check_valid:
        return exact_score
    else:
        return exact_score and check_valid_sql(predict, db_name, db_dir)
    # except Exception as e:
    #     return 0


# Cached sqlite3 connections keyed by db_name. They are never closed.
_conns = {}


def check_valid_sql(sql, db_name, db_dir, return_error=False):
    """Execute `sql` on database `db_name` and report whether it ran and fetched without error.

    Args:
        sql: SQL string to execute.
        db_name: database id; the file is "<db_dir>/<db_name>/<db_name>.sqlite".
        db_dir: root directory of the databases.
        return_error: if True, return (ok, error message or None) instead of a bare bool.

    The database 'wta_1' is never executed and is always reported as valid.
    """
    db = os.path.join(db_dir, db_name, db_name + ".sqlite")

    if db_name == 'wta_1':
        # TODO: seems like there is a problem with this dataset - slow response - add limit 1
        return True if not return_error else (True, None)

    if db_name not in _conns:
        _conns[db_name] = sqlite3.connect(db)

        # fixes an encoding bug
        _conns[db_name].text_factory = bytes

    conn = _conns[db_name]
    cursor = conn.cursor()
    try:
        cursor.execute(sql)
        cursor.fetchall()
        return True if not return_error else (True, None)
    except Exception as e:
        return False if not return_error else (False, e.args[0])
================================================
FILE: semparse/worlds/spider_world.py
================================================
from typing import List, Tuple, Dict, Set, Optional
from copy import deepcopy

from parsimonious import Grammar
from parsimonious.exceptions import ParseError

from semparse.contexts.spider_context_utils import format_grammar_string, initialize_valid_actions, SqlVisitor
from semparse.contexts.spider_db_context import SpiderDBContext
from semparse.contexts.spider_db_grammar import GRAMMAR_DICTIONARY, update_grammar_with_tables, \
    update_grammar_to_be_table_names_free, update_grammar_flip_joins


class SpiderWorld:
    """
    World representation for spider dataset.

    Holds the database context, an optional tokenized gold query, and the grammar actions
    (valid_actions / valid_actions_flat) computed for the database.
    """

    def __init__(self, db_context: SpiderDBContext, query: Optional[List[str]], allow_alias: bool = False) -> None:
        """Index the knowledge-graph entities of `db_context` by the names used in SQL queries.

        Entity strings are split on ':'. 'table' and 'string' entities are indexed by
        their second field. All others are expected to have exactly four fields and are indexed
        as "<table>@<column>" from fields 3 and 4.
        """
        self.db_id = db_context.db_id
        self.allow_alias = allow_alias

        # NOTE: This base dictionary should not be modified.
        self.base_grammar_dictionary = deepcopy(GRAMMAR_DICTIONARY)

        self.query = query
        self.db_context = db_context

        # keep a list of entities names as they are given in sql queries
        # (maps name -> index of the entity in knowledge_graph.entities)
        self.entities_names = {}
        for i, entity in enumerate(self.db_context.knowledge_graph.entities):
            parts = entity.split(':')
            if parts[0] in ['table', 'string']:
                self.entities_names[parts[1]] = i
            else:
                _, _, table_name, column_name = parts
                self.entities_names[f'{table_name}@{column_name}'] = i
        self.valid_actions = []
        self.valid_actions_flat = []

    def get_action_sequence_and_all_actions(self, allow_aliases: bool = False) -> Tuple[List[str], List[str]]:
        """Build the grammar for this database's schema and parse self.query with it.

        Side effect: sets self.valid_actions and self.valid_actions_flat.

        Returns:
            (action_sequence, sorted_actions). action_sequence is None when there is no query or
            when parsing raises ParseError; it is [] for an empty query.

        NOTE(review): unlike get_all_actions, this calls update_grammar_with_tables without
        db_id and never applies update_grammar_flip_joins; confirm this is intentional.
        """
        grammar_with_context = deepcopy(self.base_grammar_dictionary)

        if not allow_aliases:
            update_grammar_to_be_table_names_free(grammar_with_context)

        schema = self.db_context.schema

        update_grammar_with_tables(grammar_with_context, schema)

        grammar = Grammar(format_grammar_string(grammar_with_context))

        valid_actions = initialize_valid_actions(grammar)
        all_actions = set()
        for action_list in valid_actions.values():
            all_actions.update(action_list)
        sorted_actions = sorted(all_actions)
        self.valid_actions = valid_actions
        self.valid_actions_flat = sorted_actions

        action_sequence = None
        if self.query is not None:
            sql_visitor = SqlVisitor(grammar)
            # Normalise the tokenized query: lower-case, and turn `` / '' quote pairs into '.
            query = " ".join(self.query).lower().replace("``", "'").replace("''", "'")
            try:
                action_sequence = sql_visitor.parse(query) if query else []
            except ParseError as e:
                pass

        return action_sequence, sorted_actions

    def get_all_actions(self, schema, flip_joins: bool, allow_aliases: bool) -> List[str]:
        """Build the grammar for `schema` and return the sorted list of all its actions.

        Side effect: sets self.valid_actions and self.valid_actions_flat.
        """
        grammar_with_context = deepcopy(self.base_grammar_dictionary)

        if not allow_aliases:
            update_grammar_to_be_table_names_free(grammar_with_context)

        if flip_joins:
            update_grammar_flip_joins(grammar_with_context)

        update_grammar_with_tables(grammar_with_context, schema, self.db_id)

        grammar = Grammar(format_grammar_string(grammar_with_context))

        valid_actions = initialize_valid_actions(grammar)
        all_actions = set()
        for action_list in valid_actions.values():
            all_actions.update(action_list)
        sorted_actions = sorted(all_actions)
        self.valid_actions = valid_actions
        self.valid_actions_flat = sorted_actions

        return sorted_actions

    def is_global_rule(self, rhs: str) -> bool:
        """Return True unless `rhs` (after stripping '[', ']' and spaces) is a quoted entity name.

        NOTE(review): raises IndexError if `rhs` is empty after stripping.
        """
        rhs = rhs.strip('[] ')
        if rhs[0] != '"':
            return True
        return rhs.strip('"') not in self.entities_names

    def get_oracle_relevance_score(self, oracle_entities: set):
        """
        return 0/1 for each schema item if it should be in the graph,
        given the used entities in the gold answer

        Column entities are matched as "<field 3>@<field 4>". Every other entity is matched by
        its last ':'-separated field.
        """
        scores = [0 for _ in range(len(self.db_context.knowledge_graph.entities))]

        for i, entity in enumerate(self.db_context.knowledge_graph.entities):
            parts = entity.split(':')
            if parts[0] == 'column':
                name = parts[2] + '@' + parts[3]
            else:
                name = parts[-1]
            if name in oracle_entities:
                scores[i] = 1

        return scores

    def get_action_entity_mapping(self) -> Dict[int, int]:
        """Map each index in self.valid_actions_flat to an entity index, or -1 when none applies.

        An action is linked when its right-hand side (the text after " -> ", with brackets
        stripped) is a quoted string whose unquoted text is a known entity name.
        """
        mapping = {}

        for action_index, action in enumerate(self.valid_actions_flat):
            # default is padding
            mapping[action_index] = -1

            action = action.split(" -> ")[1].strip('[]')
            action_stripped = action.strip('\"')
            if action[0] != '"' or action_stripped not in self.entities_names:
                continue

            mapping[action_index] = self.entities_names[action_stripped]

        return mapping

    def get_query_without_table_hints(self):
        """Return the query tokens with "@" hints removed.

        "a.b@c" becomes "a.c" and "b@c" becomes "c".

        NOTE(review): returns '' (a str) when there is no query but a list of tokens otherwise.
        """
        if not self.query:
            return ''
        toks = []
        for tok in self.query:
            if '@' in tok:
                parts = tok.split('@')
                if '.' in parts[0]:
                    toks.append(parts[0].split('.')[0] + '.' + parts[1])
                else:
                    toks.append(parts[1])
            else:
                toks.append(tok)
        return toks

    # def is_ambiguous_column(self, action_rhs: str, actions_sequence: List[str], action_index: int):
    #     """
    #     a column would be ambiguous if another table is used in the query and it has a column with the same name
    #     currently, return true only for join clauses
    #     """
    #     if actions_sequence[action_index-1].startswith('join_condition -> ') or \
    #             actions_sequence[action_index-2].startswith('join_condition -> '):
    #         return True
    #
    #     return False
    #     # column_table, column_name = action_rhs.strip('"').split('.')
    #     # tables_used = [a.split(' -> ')[1].strip('[]\"') for a in actions_sequence if a.startswith('table_name -> ')]
    #     # columns_with_same_name = set([a.split(' -> ')[1].strip('[]\"') for a in actions_sequence
    #     #                               if a.startswith('column_name -> ') and a.strip('[]\"').endswith(column_name)])
    #     # table_of_columns_with_same_name = set([c.split('.')[0] for c in columns_with_same_name])
    #     # if len(table_of_columns_with_same_name) > 1:
    #     #     return True
    #     # if column_table not in tables_used:
    #     #     return False
    #     # other_tables = [t for t in tables_used if t != column_table]
    #     # other_tables_columns = [[c.split(':')[-1] for c in self.knowledge_graph.neighbors[f'table:{t}']] for t in other_tables]
    #     # other_tables_columns_set = set([item for sublist in other_tables_columns for item in sublist])
    #     # return column_name in other_tables_columns_set
================================================
FILE: spider_evaluation/evaluate.py
================================================
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################

import copy
import os
import json
import sqlite3
import argparse

from spider_evaluation import get_schema, Schema, get_sql

# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True


CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')

# Operator / aggregation tuples: a parsed SQL stores the index into these tuples.
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
    'sql': "sql",
    'table_unit': "table_unit",
}

COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')


HARDNESS = {
    "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
    "component2": ('except', 'union', 'intersect')
}


def condition_has_or(conds):
    """Return True if any connector (odd position) in the condition list is 'or'."""
    return 'or' in conds[1::2]


def condition_has_like(conds):
    """Return True if any cond_unit in the condition list uses the LIKE operator."""
    return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]


def condition_has_sql(conds):
    """Return True if any cond_unit in the condition list has a nested SQL dict as val1 or val2."""
    for cond_unit in conds[::2]:
        val1, val2 = cond_unit[3], cond_unit[4]
        if val1 is not None and type(val1) is dict:
            return True
        if val2 is not None and type(val2) is dict:
            return True
    return False


def val_has_op(val_unit):
    """Return True if the val_unit combines columns with an arithmetic operator."""
    return val_unit[0] != UNIT_OPS.index('none')


def has_agg(unit):
    """Return True if unit[0] is not the 'none' aggregation index."""
    return unit[0] != AGG_OPS.index('none')


def accuracy(count, total):
    """All-or-nothing accuracy: 1 if count == total, else 0."""
    if count == total:
        return 1
    return 0


def recall(count, total):
    """All-or-nothing recall: 1 if count == total, else 0."""
    if count == total:
        return 1
    return 0


def F1(acc, rec):
    """Harmonic mean of acc and rec; 0 when both are 0."""
    if (acc + rec) == 0:
        return 0
    return (2. * acc * rec) / (acc + rec)


def get_scores(count, pred_total, label_total):
    """Return (acc, rec, f1) as (1, 1, 1) only if pred_total == label_total == count, else (0, 0, 0).

    Note that an empty component on both sides (all totals 0) therefore scores (1, 1, 1).
    """
    if pred_total != label_total:
        return 0,0,0
    elif count == pred_total:
        return 1,1,1
    return 0,0,0


def eval_sel(pred, label):
    """Count matching SELECT items, with and without the aggregation.

    Returns:
        (label_total, pred_total, cnt, cnt_wo_agg)

    NOTE(review): label_sel aliases label['select'][1], so matched items are removed from the
    caller's `label` dict. A second evaluation against the same label sees a shrunken select
    list.
    """
    pred_sel = pred['select'][1]
    label_sel = label['select'][1]
    label_wo_agg = [unit[1] for unit in label_sel]
    pred_total = len(pred_sel)
    label_total = len(label_sel)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_sel:
        if unit in label_sel:
            cnt += 1
            label_sel.remove(unit)
        if unit[1] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[1])

    return label_total, pred_total, cnt, cnt_wo_agg


def eval_where(pred, label):
    """Count matching WHERE cond_units, both whole and by val_unit only (ignoring operator and values).

    Returns:
        (label_total, pred_total, cnt, cnt_wo_agg)
    """
    pred_conds = [unit for unit in pred['where'][::2]]
    label_conds = [unit for unit in label['where'][::2]]
    label_wo_agg = [unit[2] for unit in label_conds]
    pred_total = len(pred_conds)
    label_total = len(label_conds)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_conds:
        if unit in label_conds:
            cnt += 1
            label_conds.remove(unit)
        if unit[2] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[2])

    return label_total, pred_total, cnt, cnt_wo_agg


def eval_group(pred, label):
    """Count matching GROUP BY columns, comparing only the part after the first '.' of each column id.

    Returns:
        (label_total, pred_total, cnt)
    """
    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    pred_total = len(pred_cols)
    label_total = len(label_cols)
    cnt = 0
    pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
    label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
    for col in pred_cols:
        if col in label_cols:
            cnt += 1
            label_cols.remove(col)
    return label_total, pred_total, cnt


def eval_having(pred, label):
    """Score GROUP BY + HAVING as one unit: a match needs identical group-by columns and identical HAVING.

    Returns:
        (label_total, pred_total, cnt), each 0 or 1.
    """
    pred_total = label_total = cnt = 0
    if len(pred['groupBy']) > 0:
        pred_total = 1
    if len(label['groupBy']) > 0:
        label_total = 1

    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    if pred_total == label_total == 1 \
            and pred_cols == label_cols \
            and pred['having'] == label['having']:
        cnt = 1

    return label_total, pred_total, cnt


def eval_order(pred, label):
    """Score ORDER BY: a match needs equal orderBy values and LIMIT present in both or absent in both.

    Limit values themselves are not compared.

    Returns:
        (label_total, pred_total, cnt), each 0 or 1.
    """
    pred_total = label_total = cnt = 0
    if len(pred['orderBy']) > 0:
        pred_total = 1
    if len(label['orderBy']) > 0:
        label_total = 1
    if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
            ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
        cnt = 1
    return label_total, pred_total, cnt


def eval_and_or(pred, label):
    """Compare the sets of WHERE connectors ('and'/'or').

    NOTE(review): on mismatch this returns (len(pred_ao), len(label_ao), 0), but callers
    unpack the result as (label_total, pred_total, cnt), so the two totals are swapped.
    get_scores is unaffected, but the stored label_total/pred_total are.
    """
    pred_ao = pred['where'][1::2]
    label_ao = label['where'][1::2]
    pred_ao = set(pred_ao)
    label_ao = set(label_ao)

    if pred_ao == label_ao:
        return 1,1,1
    return len(pred_ao),len(label_ao),0


def get_nestedSQL(sql):
    """Collect nested SQL dicts from FROM/WHERE/HAVING condition values and intersect/except/union."""
    nested = []
    for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
        if type(cond_unit[3]) is dict:
            nested.append(cond_unit[3])
        if type(cond_unit[4]) is dict:
            nested.append(cond_unit[4])
    if sql['intersect'] is not None:
        nested.append(sql['intersect'])
    if sql['except'] is not None:
        nested.append(sql['except'])
    if sql['union'] is not None:
        nested.append(sql['union'])
    return nested


def eval_nested(pred, label):
    """Score one optional nested query by exact match with a fresh Evaluator.

    Returns:
        (label_total, pred_total, cnt)
    """
    label_total = 0
    pred_total = 0
    cnt = 0
    if pred is not None:
        pred_total += 1
    if label is not None:
        label_total += 1
    if pred is not None and label is not None:
        cnt += Evaluator().eval_exact_match(pred, label)
    return label_total, pred_total, cnt


def eval_IUEN(pred, label):
    """Sum eval_nested over the intersect, except and union subqueries."""
    lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
    lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
    lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
    label_total = lt1 + lt2 + lt3
    pred_total = pt1 + pt2 + pt3
    cnt = cnt1 + cnt2 + cnt3
    return label_total, pred_total, cnt


def get_keywords(sql):
    """Return the set of SQL keywords used by `sql`.

    Possible members: where, group, having, order, the order direction, limit, except, union,
    intersect, or, not, in, like.
    """
    res = set()
    if len(sql['where']) > 0:
        res.add('where')
    if len(sql['groupBy']) > 0:
        res.add('group')
    if len(sql['having']) > 0:
        res.add('having')
    if len(sql['orderBy']) > 0:
        res.add(sql['orderBy'][0])
        res.add('order')
    if sql['limit'] is not None:
        res.add('limit')
    if sql['except'] is not None:
        res.add('except')
    if sql['union'] is not None:
        res.add('union')
    if sql['intersect'] is not None:
        res.add('intersect')

    # or keyword
    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    if len([token for token in ao if token == 'or']) > 0:
        res.add('or')

    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    # not keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
        res.add('not')

    # in keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
        res.add('in')

    # like keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
        res.add('like')

    return res


def eval_keywords(pred, label):
    """Count predicted keywords that also appear in the label.

    Returns:
        (label_total, pred_total, cnt)
    """
    pred_keywords = get_keywords(pred)
    label_keywords = get_keywords(label)
    pred_total = len(pred_keywords)
    label_total = len(label_keywords)
    cnt = 0

    for k in pred_keywords:
        if k in label_keywords:
            cnt += 1
    return label_total, pred_total, cnt


def count_agg(units):
    """Count units whose first element is not the 'none' aggregation index."""
    return len([unit for unit in units if has_agg(unit)])


def count_component1(sql):
    """Count 'component1' features used to grade hardness.

    Features: WHERE, GROUP BY, ORDER BY, LIMIT, each extra FROM table, each 'or' connector and
    each LIKE condition.
    """
    count = 0
    if len(sql['where']) > 0:
        count += 1
    if len(sql['groupBy']) > 0:
        count += 1
    if len(sql['orderBy']) > 0:
        count += 1
    if sql['limit'] is not None:
        count += 1
    if len(sql['from']['table_units']) > 0:  # JOIN
        count += len(sql['from']['table_units']) - 1

    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    count += len([token for token in ao if token == 'or'])
    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])

    return count


def count_component2(sql):
    """Count nested queries ('component2' features)."""
    nested = get_nestedSQL(sql)
    return len(nested)


def count_others(sql):
    """Count 'other' hardness signals.

    Signals: more than one aggregation, more than one select item, more than one element in
    the WHERE list, more than one group-by column.
    """
    count = 0
    # number of aggregation
    agg_count = count_agg(sql['select'][1])
    agg_count += count_agg(sql['where'][::2])
    agg_count += count_agg(sql['groupBy'])
    if len(sql['orderBy']) > 0:
        agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
                            [unit[2] for unit in sql['orderBy'][1] if unit[2]])
    # NOTE(review): sql['having'] is a condition list (cond_units interleaved with 'and'/'or').
    # has_agg therefore inspects not_op, or the first character of a connector, rather than an
    # aggregation id. The same applies to the where cond_units above. Changing this would
    # shift hardness buckets; confirm before altering.
    agg_count += count_agg(sql['having'])
    if agg_count > 1:
        count += 1

    # number of select columns
    if len(sql['select'][1]) > 1:
        count += 1

    # number of where conditions
    if len(sql['where']) > 1:
        count += 1

    # number of group by clauses
    if len(sql['groupBy']) > 1:
        count += 1

    return count


class Evaluator:
    """A simple evaluator"""
    def __init__(self):
        # Partial scores from the most recent eval_exact_match call.
        self.partial_scores = None

    def eval_hardness(self, sql):
        """Classify `sql` as "easy", "medium", "hard" or "extra" from its component counts."""
        count_comp1_ = count_component1(sql)
        count_comp2_ = count_component2(sql)
        count_others_ = count_others(sql)

        if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
            return "easy"
        elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
                (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
            return "medium"
        elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
                (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
                (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
            return "hard"
        else:
            return "extra"

    def eval_exact_match(self, pred, label):
        """Return 1 if every partial component has f1 == 1 and the FROM tables and join conditions match.

        Stores the partial scores in self.partial_scores.

        NOTE(review): mismatches return 0 or False (mixed int/bool). Sorting table_units may
        raise TypeError if two of them hold SQL dicts.
        """
        partial_scores = self.eval_partial_match(pred, label)
        self.partial_scores = partial_scores

        for _, score in partial_scores.items():
            if score['f1'] != 1:
                return 0

        if len(label['from']['table_units']) > 0:
            label_tables = sorted(label['from']['table_units'])
            pred_tables = sorted(pred['from']['table_units'])
            if label_tables != pred_tables:
                return False
        if len(label['from']['conds']) > 0:
            # Order-insensitive comparison of join conditions (including connectors).
            label_joins = sorted(label['from']['conds'], key=lambda x: str(x))
            pred_joins = sorted(pred['from']['conds'], key=lambda x: str(x))
            if label_joins != pred_joins:
                return False
        return 1

    def eval_partial_match(self, pred, label):
        """Score each SQL component of `pred` against `label`.

        Returns:
            dict mapping component name to {'acc', 'rec', 'f1', 'label_total', 'pred_total'}.
        """
        res = {}

        label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_group(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_having(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_order(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_and_or(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_IUEN(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        label_total, pred_total, cnt = eval_keywords(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}

        return res


def isValidSQL(sql, db):
    """Return True if `sql` executes on the SQLite database file `db`, False if execution raises.

    NOTE(review): the connection opened here is never closed.
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    try:
        cursor.execute(sql, [])
    except Exception as e:
        return False
    return True


def print_scores(scores, etype):
    """Print a table of evaluation scores, one column per hardness level.

    Args:
        scores: dict keyed by 'easy', 'medium', 'hard', 'extra' and 'all'. Each value holds
            'count' and, depending on `etype`, 'exec', 'exact' and a 'partial' dict mapping
            each component name to its 'acc' / 'rec' / 'f1' values.
        etype: "all", "exec" or "match"; selects which sections are printed.
    """
    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']

    print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
    counts = [scores[level]['count'] for level in levels]
    print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))

    if etype in ["all", "exec"]:
        print('===================== EXECUTION ACCURACY =====================')
        this_scores = [scores[level]['exec'] for level in levels]
        print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))

    if etype in ["all", "match"]:
        print('\n====================== EXACT MATCHING ACCURACY =====================')
        exact_scores = [scores[level]['exact'] for level in levels]
        print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
        print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

        print('---------------------- PARTIAL MATCHING RECALL ----------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

        print('---------------------- PARTIAL MATCHING F1 --------------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))


def evaluate(gold, predict, db_dir, etype, kmaps):
    """Score a prediction file against a gold file and print the results via print_scores.

    Args:
        gold: path to a file with one "<gold sql>\\t<db_id>" line per example (blank lines skipped).
        predict: path to a file with one prediction per line; only the first tab-separated
            field is used (blank lines skipped).
        db_dir: directory containing "<db_id>/<db_id>.sqlite" for every database.
        etype: "all", "exec" or "match"; selects execution and/or exact-match scoring.
        kmaps: dict mapping db_id to the map used by rebuild_sql_col (presumably the
            foreign-key column maps from build_foreign_key_map_from_json; confirm).

    NOTE(review): predictions and gold lines are paired with zip(), so surplus lines in either
    file are silently ignored.
    """
    with open(gold) as f:
        glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]

    with open(predict) as f:
        plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
    # plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
    # glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]
    evaluator = Evaluator()

    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']
    # NOTE(review): entries is filled below but never read in this function.
    entries = []
    scores = {}

    # Initialise accumulators per hardness level; 'acc_count'/'rec_count' track how many
    # examples contributed to each partial metric so they can be averaged later.
    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}

    eval_err_num = 0
    for p, g in zip(plist, glist):
        p_str = p[0]
        g_str, db = g
        db_name = db
        db = os.path.join(db_dir, db, db + ".sqlite")
        # NOTE(review): the schema is re-read from disk for every example.
        schema = Schema(get_schema(db))
        g_sql = get_sql(schema, g_str)
        hardness = evaluator.eval_hardness(g_sql)
        scores[hardness]['count'] += 1
        scores['all']['count'] += 1

        try:
            p_sql = get_sql(schema, p_str)
        except:
            # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
            # NOTE(review): bare except also swallows KeyboardInterrupt/SystemExit.
            p_sql = {
                "except": None,
                "from": {
                    "conds": [],
                    "table_units": []
                },
                "groupBy": [],
                "having": [],
                "intersect": None,
                "limit": None,
                "orderBy": [],
                "select": [
                    False,
                    []
                ],
                "union": None,
                "where": []
            }
            eval_err_num += 1
            print("eval_err_num:{}".format(eval_err_num))
            print(p_str)
            print()

        # rebuild sql for value evaluation
        kmap = kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

        # p_sql_copy = copy.deepcopy(p_sql)
        # g_sql_copy = copy.deepcopy(g_sql)
        # if not eval_exec_match(db, p_str, g_str, p_sql_copy, g_sql_copy) and evaluator.eval_exact_match(p_sql_copy, g_sql_copy):
        #     a = 1

        if etype in ["all", "exec"]:
            exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
            if exec_score:
                scores[hardness]['exec'] += 1
                scores['all']['exec'] += exec_score

        if etype in ["all", "match"]:
            exact_score = evaluator.eval_exact_match(p_sql, g_sql)
            partial_scores = evaluator.partial_scores
            if exact_score == 0:
                # Log every exact-match failure for inspection.
                print("{} pred: {}".format(hardness,p_str))
                print("{} gold: {}".format(hardness,g_str))
                print("")
            scores[hardness]['exact'] += exact_score
            scores['all']['exact'] += exact_score
            for type_ in partial_types:
                # acc is averaged only over examples where the prediction has this component,
                # rec only over examples where the gold query has it.
                if partial_scores[type_]['pred_total'] > 0:
                    scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
                    scores[hardness]['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
                    scores[hardness]['partial'][type_]['rec_count'] += 1
                scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
                if partial_scores[type_]['pred_total'] > 0:
                    scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
                    scores['all']['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
                    scores['all']['partial'][type_]['rec_count'] += 1
                scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']

            entries.append({
                'predictSQL': p_str,
                'goldSQL': g_str,
                'hardness': hardness,
                'exact': exact_score,
                'partial': partial_scores
            })

    # Turn accumulated sums into averages per level.
    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']

        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                if scores[level]['partial'][type_]['acc_count'] == 0:
                    scores[level]['partial'][type_]['acc'] = 0
                else:
                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
                                                             scores[level]['partial'][type_]['acc_count'] * 1.0
                if scores[level]['partial'][type_]['rec_count'] == 0:
                    scores[level]['partial'][type_]['rec'] = 0
                else:
                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
                                                             scores[level]['partial'][type_]['rec_count'] * 1.0
                # NOTE(review): F1 is reported as 1 when both averaged acc and rec are 0,
                # which looks inverted; confirm before relying on partial F1 numbers.
                if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
                    scores[level]['partial'][type_]['f1'] = 1
                else:
                    scores[level]['partial'][type_]['f1'] = \
                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
                        scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])

    print_scores(scores, etype)


def eval_exec_match(db, p_str, g_str, pred, gold):
    """
    return 1 if the values between prediction and gold are matching
    in the corresponding index. Currently not support multiple col_unit(pairs).
""" conn = sqlite3.connect(db) cursor = conn.cursor() conn.text_factory = bytes try: cursor.execute(p_str) p_res = cursor.fetchall() except: return False cursor.execute(g_str) q_res = cursor.fetchall() def res_map(res, val_units): rmap = {} for idx, val_unit in enumerate(val_units): key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2])) rmap[key] = [r[idx] for r in res] return rmap p_val_units = [unit[1] for unit in pred['select'][1]] q_val_units = [unit[1] for unit in gold['select'][1]] return res_map(p_res, p_val_units) == res_map(q_res, q_val_units) # Rebuild SQL functions for value evaluation def rebuild_cond_unit_val(cond_unit): if cond_unit is None or not DISABLE_VALUE: return cond_unit not_op, op_id, val_unit, val1, val2 = cond_unit if type(val1) is not dict: val1 = None else: val1 = rebuild_sql_val(val1) if type(val2) is not dict: val2 = None else: val2 = rebuild_sql_val(val2) return not_op, op_id, val_unit, val1, val2 def rebuild_condition_val(condition): if condition is None or not DISABLE_VALUE: return condition res = [] for idx, it in enumerate(condition): if idx % 2 == 0: res.append(rebuild_cond_unit_val(it)) else: res.append(it) return res def rebuild_sql_val(sql): if sql is None or not DISABLE_VALUE: return sql sql['from']['conds'] = rebuild_condition_val(sql['from']['conds']) sql['having'] = rebuild_condition_val(sql['having']) sql['where'] = rebuild_condition_val(sql['where']) sql['intersect'] = rebuild_sql_val(sql['intersect']) sql['except'] = rebuild_sql_val(sql['except']) sql['union'] = rebuild_sql_val(sql['union']) return sql # Rebuild SQL functions for foreign key evaluation def build_valid_col_units(table_units, schema): col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']] prefixs = [col_id[:-2] for col_id in col_ids] valid_col_units= [] for value in schema.idMap.values(): if '.' 
in value and value[:value.index('.')] in prefixs: valid_col_units.append(value) return valid_col_units def rebuild_col_unit_col(valid_col_units, col_unit, kmap): if col_unit is None: return col_unit agg_id, col_id, distinct = col_unit if col_id in kmap and col_id in valid_col_units: col_id = kmap[col_id] if DISABLE_DISTINCT: distinct = None return agg_id, col_id, distinct def rebuild_val_unit_col(valid_col_units, val_unit, kmap): if val_unit is None: return val_unit unit_op, col_unit1, col_unit2 = val_unit col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap) col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap) return unit_op, col_unit1, col_unit2 def rebuild_table_unit_col(valid_col_units, table_unit, kmap): if table_unit is None: return table_unit table_type, col_unit_or_sql = table_unit if isinstance(col_unit_or_sql, tuple): col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap) return table_type, col_unit_or_sql def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap): if cond_unit is None: return cond_unit not_op, op_id, val_unit, val1, val2 = cond_unit val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap) return not_op, op_id, val_unit, val1, val2 def rebuild_condition_col(valid_col_units, condition, kmap): for idx in range(len(condition)): if idx % 2 == 0: condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap) return condition def rebuild_select_col(valid_col_units, sel, kmap): if sel is None: return sel distinct, _list = sel new_list = [] for it in _list: agg_id, val_unit = it new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap))) if DISABLE_DISTINCT: distinct = None return distinct, new_list def rebuild_from_col(valid_col_units, from_, kmap): if from_ is None: return from_ from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']] from_['conds'] = 
rebuild_condition_col(valid_col_units, from_['conds'], kmap) return from_ def rebuild_group_by_col(valid_col_units, group_by, kmap): if group_by is None: return group_by return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by] def rebuild_order_by_col(valid_col_units, order_by, kmap): if order_by is None or len(order_by) == 0: return order_by direction, val_units = order_by new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units] return direction, new_val_units def rebuild_sql_col(valid_col_units, sql, kmap): if sql is None: return sql sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap) sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap) sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap) sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap) sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap) sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap) sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap) sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap) sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap) return sql def build_foreign_key_map(entry): cols_orig = entry["column_names_original"] tables_orig = entry["table_names_original"] # rebuild cols corresponding to idmap in Schema cols = [] for col_orig in cols_orig: if col_orig[0] >= 0: t = tables_orig[col_orig[0]] c = col_orig[1] cols.append("__" + t.lower() + "." 
+ c.lower() + "__") else: cols.append("__all__") def keyset_in_list(k1, k2, k_list): for k_set in k_list: if k1 in k_set or k2 in k_set: return k_set new_k_set = set() k_list.append(new_k_set) return new_k_set foreign_key_list = [] foreign_keys = entry["foreign_keys"] for fkey in foreign_keys: key1, key2 = fkey key_set = keyset_in_list(key1, key2, foreign_key_list) key_set.add(key1) key_set.add(key2) foreign_key_map = {} for key_set in foreign_key_list: sorted_list = sorted(list(key_set)) midx = sorted_list[0] for idx in sorted_list: foreign_key_map[cols[idx]] = cols[midx] return foreign_key_map def build_foreign_key_map_from_json(table): with open(table) as f: data = json.load(f) tables = {} for entry in data: tables[entry['db_id']] = build_foreign_key_map(entry) return tables if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--gold', dest='gold', type=str) parser.add_argument('--pred', dest='pred', type=str) parser.add_argument('--db', dest='db', type=str) parser.add_argument('--table', dest='table', type=str) parser.add_argument('--etype', dest='etype', type=str) args = parser.parse_args() gold = args.gold pred = args.pred db_dir = args.db table = args.table etype = args.etype assert etype in ["all", "exec", "match"], "Unknown evaluation method" kmaps = build_foreign_key_map_from_json(table) evaluate(gold, pred, db_dir, etype, kmaps) ================================================ FILE: spider_evaluation/process_sql.py ================================================ ################################ # Assumptions: # 1. sql is correct # 2. only table name has alias # 3. only one intersect/union/except # # val: number(float)/string(str)/sql(dict) # col_unit: (agg_id, col_id, isDistinct(bool)) # val_unit: (unit_op, col_unit1, col_unit2) # table_unit: (table_type, col_unit/sql) # cond_unit: (not_op, op_id, val_unit, val1, val2) # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 
# sql { # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} # 'where': condition # 'groupBy': [col_unit1, col_unit2, ...] # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) # 'having': condition # 'limit': None/limit value # 'intersect': None/sql # 'except': None/sql # 'union': None/sql # } ################################ import json import sqlite3 from nltk import word_tokenize CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') JOIN_KEYWORDS = ('join', 'on', 'as') WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') UNIT_OPS = ('none', '-', '+', "*", '/') AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') TABLE_TYPE = { 'sql': "sql", 'table_unit': "table_unit", } COND_OPS = ('and', 'or') SQL_OPS = ('intersect', 'union', 'except') ORDER_OPS = ('desc', 'asc') mapped_entities = [] class Schema: """ Simple schema which maps table&column to a unique identifier """ def __init__(self, schema): self._schema = schema self._idMap = self._map(self._schema) @property def schema(self): return self._schema @property def idMap(self): return self._idMap def _map(self, schema): idMap = {'*': "__all__"} id = 1 for key, vals in schema.items(): for val in vals: idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." 
+ val.lower() + "__" id += 1 for key in schema: idMap[key.lower()] = "__" + key.lower() + "__" id += 1 return idMap def get_schema(db): """ Get database's schema, which is a dict with table name as key and list of column names as value :param db: database path :return: schema dict """ schema = {} conn = sqlite3.connect(db) cursor = conn.cursor() # fetch table names cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") tables = [str(table[0].lower()) for table in cursor.fetchall()] # fetch table info for table in tables: cursor.execute("PRAGMA table_info({})".format(table)) schema[table] = [str(col[1].lower()) for col in cursor.fetchall()] return schema def get_schema_from_json(fpath): with open(fpath) as f: data = json.load(f) schema = {} for entry in data: table = str(entry['table'].lower()) cols = [str(col['column_name'].lower()) for col in entry['col_data']] schema[table] = cols return schema def tokenize(string): string = str(string) string = string.replace("\'", "\"") # ensures all string values wrapped by "" problem?? 
quote_idxs = [idx for idx, char in enumerate(string) if char == '"'] assert len(quote_idxs) % 2 == 0, "Unexpected quote" # keep string value as token vals = {} for i in range(len(quote_idxs)-1, -1, -2): qidx1 = quote_idxs[i-1] qidx2 = quote_idxs[i] val = string[qidx1: qidx2+1] key = "__val_{}_{}__".format(qidx1, qidx2) string = string[:qidx1] + key + string[qidx2+1:] vals[key] = val toks = [word.lower() for word in word_tokenize(string)] # replace with string value token for i in range(len(toks)): if toks[i] in vals: toks[i] = vals[toks[i]] # find if there exists !=, >=, <= eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="] eq_idxs.reverse() prefix = ('!', '>', '<') for eq_idx in eq_idxs: pre_tok = toks[eq_idx-1] if pre_tok in prefix: toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ] return toks def scan_alias(toks): """Scan the index of 'as' and build the map for all alias""" as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as'] alias = {} for idx in as_idxs: alias[toks[idx+1]] = toks[idx-1] return alias def get_tables_with_alias(schema, toks): tables = scan_alias(toks) for key in schema: assert key not in tables, "Alias {} has the same name in table".format(key) tables[key] = key return tables def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): """ :returns next idx, column id """ global mapped_entities tok = toks[start_idx] if tok == "*": return start_idx + 1, schema.idMap[tok] if '.' in tok: # if token is a composite alias, col = tok.split('.') key = tables_with_alias[alias] + "." + col mapped_entities.append((start_idx, tables_with_alias[alias] + "@" + col)) return start_idx+1, schema.idMap[key] assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty" for alias in default_tables: table = tables_with_alias[alias] if tok in schema.schema[table]: key = table + "." 
+ tok mapped_entities.append((start_idx, table + "@" + tok)) return start_idx+1, schema.idMap[key] assert False, "Error col: {}".format(tok) def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): """ :returns next idx, (agg_op id, col_id) """ idx = start_idx len_ = len(toks) isBlock = False isDistinct = False if toks[idx] == '(': isBlock = True idx += 1 if toks[idx] in AGG_OPS: agg_id = AGG_OPS.index(toks[idx]) idx += 1 assert idx < len_ and toks[idx] == '(' idx += 1 if toks[idx] == "distinct": idx += 1 isDistinct = True idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) assert idx < len_ and toks[idx] == ')' idx += 1 return idx, (agg_id, col_id, isDistinct) if toks[idx] == "distinct": idx += 1 isDistinct = True agg_id = AGG_OPS.index("none") idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) if isBlock: assert toks[idx] == ')' idx += 1 # skip ')' return idx, (agg_id, col_id, isDistinct) def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): idx = start_idx len_ = len(toks) isBlock = False if toks[idx] == '(': isBlock = True idx += 1 col_unit1 = None col_unit2 = None unit_op = UNIT_OPS.index('none') idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) if idx < len_ and toks[idx] in UNIT_OPS: unit_op = UNIT_OPS.index(toks[idx]) idx += 1 idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) if isBlock: assert toks[idx] == ')' idx += 1 # skip ')' return idx, (unit_op, col_unit1, col_unit2) def parse_table_unit(toks, start_idx, tables_with_alias, schema): """ :returns next idx, table id, table name """ idx = start_idx len_ = len(toks) key = tables_with_alias[toks[idx]] if idx + 1 < len_ and toks[idx+1] == "as": idx += 3 else: idx += 1 return idx, schema.idMap[key], key def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None): idx = start_idx len_ = len(toks) isBlock = 
False if toks[idx] == '(': isBlock = True idx += 1 if toks[idx] == 'select': idx, val = parse_sql(toks, idx, tables_with_alias, schema) elif "\"" in toks[idx]: # token is a string value val = toks[idx] idx += 1 else: try: val = float(toks[idx]) idx += 1 except: end_idx = idx while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\ and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS: end_idx += 1 idx, val = parse_col_unit(toks[: end_idx], start_idx, tables_with_alias, schema, default_tables) idx = end_idx if isBlock: assert toks[idx] == ')' idx += 1 return idx, val def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None): idx = start_idx len_ = len(toks) conds = [] while idx < len_: idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) not_op = False if toks[idx] == 'not': not_op = True idx += 1 assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx]) op_id = WHERE_OPS.index(toks[idx]) idx += 1 val1 = val2 = None if op_id == WHERE_OPS.index('between'): # between..and... 
special case: dual values idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) assert toks[idx] == 'and' idx += 1 idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables) else: # normal case: single value idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) val2 = None conds.append((not_op, op_id, val_unit, val1, val2)) if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS): break if idx < len_ and toks[idx] in COND_OPS: conds.append(toks[idx]) idx += 1 # skip and/or return idx, conds def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None): idx = start_idx len_ = len(toks) assert toks[idx] == 'select', "'select' not found" idx += 1 isDistinct = False if idx < len_ and toks[idx] == 'distinct': idx += 1 isDistinct = True val_units = [] while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS: agg_id = AGG_OPS.index("none") if toks[idx] in AGG_OPS: agg_id = AGG_OPS.index(toks[idx]) idx += 1 idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) val_units.append((agg_id, val_unit)) if idx < len_ and toks[idx] == ',': idx += 1 # skip ',' return idx, (isDistinct, val_units) def parse_from(toks, start_idx, tables_with_alias, schema): """ Assume in the from clause, all table units are combined with join """ assert 'from' in toks[start_idx:], "'from' not found" len_ = len(toks) idx = toks.index('from', start_idx) + 1 default_tables = [] table_units = [] conds = [] while idx < len_: isBlock = False if toks[idx] == '(': isBlock = True idx += 1 if toks[idx] == 'select': idx, sql = parse_sql(toks, idx, tables_with_alias, schema) table_units.append((TABLE_TYPE['sql'], sql)) else: if idx < len_ and toks[idx] == 'join': idx += 1 # skip join idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema) table_units.append((TABLE_TYPE['table_unit'],table_unit)) 
default_tables.append(table_name) if idx < len_ and toks[idx] == "on": idx += 1 # skip on idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) if len(conds) > 0: conds.append('and') conds.extend(this_conds) if isBlock: assert toks[idx] == ')' idx += 1 if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): break return idx, table_units, conds, default_tables def parse_where(toks, start_idx, tables_with_alias, schema, default_tables): idx = start_idx len_ = len(toks) if idx >= len_ or toks[idx] != 'where': return idx, [] idx += 1 idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) return idx, conds def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables): idx = start_idx len_ = len(toks) col_units = [] if idx >= len_ or toks[idx] != 'group': return idx, col_units idx += 1 assert toks[idx] == 'by' idx += 1 while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) col_units.append(col_unit) if idx < len_ and toks[idx] == ',': idx += 1 # skip ',' else: break return idx, col_units def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables): idx = start_idx len_ = len(toks) val_units = [] order_type = 'asc' # default type is 'asc' if idx >= len_ or toks[idx] != 'order': return idx, val_units idx += 1 assert toks[idx] == 'by' idx += 1 while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) val_units.append(val_unit) if idx < len_ and toks[idx] in ORDER_OPS: order_type = toks[idx] idx += 1 if idx < len_ and toks[idx] == ',': idx += 1 # skip ',' else: break return idx, (order_type, val_units) def parse_having(toks, start_idx, tables_with_alias, schema, default_tables): idx = start_idx len_ = len(toks) if idx >= len_ or toks[idx] 
!= 'having': return idx, [] idx += 1 idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) return idx, conds def parse_limit(toks, start_idx): idx = start_idx len_ = len(toks) if idx < len_ and toks[idx] == 'limit': idx += 2 try: limit_val = int(toks[idx-1]) except Exception: limit_val = '"value"' return idx, limit_val return idx, None def parse_sql(toks, start_idx, tables_with_alias, schema, mapped_entities_fn=None): global mapped_entities if mapped_entities_fn is not None: mapped_entities = mapped_entities_fn() isBlock = False # indicate whether this is a block of sql/sub-sql len_ = len(toks) idx = start_idx sql = {} if toks[idx] == '(': isBlock = True idx += 1 # parse from clause in order to get default tables from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema) sql['from'] = {'table_units': table_units, 'conds': conds} # select clause _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables) idx = from_end_idx sql['select'] = select_col_units # where clause idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables) sql['where'] = where_conds # group by clause idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables) sql['groupBy'] = group_col_units # having clause idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables) sql['having'] = having_conds # order by clause idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables) sql['orderBy'] = order_col_units # limit clause idx, limit_val = parse_limit(toks, idx) sql['limit'] = limit_val idx = skip_semicolon(toks, idx) if isBlock: assert toks[idx] == ')' idx += 1 # skip ')' idx = skip_semicolon(toks, idx) # intersect/union/except clause for op in SQL_OPS: # initialize IUE sql[op] = None if idx < len_ and toks[idx] in SQL_OPS: sql_op = toks[idx] idx += 1 idx, IUE_sql = 
parse_sql(toks, idx, tables_with_alias, schema) sql[sql_op] = IUE_sql if mapped_entities_fn is not None: return idx, sql, mapped_entities else: return idx, sql def load_data(fpath): with open(fpath) as f: data = json.load(f) return data def get_sql(schema, query): toks = tokenize(query) tables_with_alias = get_tables_with_alias(schema.schema, toks) _, sql = parse_sql(toks, 0, tables_with_alias, schema) return sql def skip_semicolon(toks, start_idx): idx = start_idx while idx < len(toks) and toks[idx] == ";": idx += 1 return idx ================================================ FILE: state_machines/states/grammar_based_state.py ================================================ from typing import Any, Dict, List, Sequence, Tuple import torch from allennlp.data.fields.production_rule_field import ProductionRule from allennlp.state_machines.states.grammar_statelet import GrammarStatelet from allennlp.state_machines.states.rnn_statelet import RnnStatelet from allennlp.state_machines.states.state import State # This syntax is pretty weird and ugly, but it's necessary to make mypy happy with the API that # we've defined. We're using generics to make the type of `combine_states` come out right. See # the note in `state_machines.state.py` for a little more detail. from state_machines.states.sql_state import SqlState class GrammarBasedState(State['GrammarBasedState']): """ A generic State that's suitable for most models that do grammar-based decoding. We keep around a `group` of states, and each element in the group has a few things: a batch index, an action history, a score, an ``RnnStatelet``, and a ``GrammarStatelet``. We additionally have some information that's independent of any particular group element: a list of all possible actions for all batch instances passed to ``model.forward()``, and a ``extras`` field that you can use if you really need some extra information about each batch instance (like a string description, or other metadata). 
Finally, we also have a specially-treated, optional ``debug_info`` field. If this is given, it should be an empty list for each group instance when the initial state is created. In that case, we will keep around information about the actions considered at each timestep of decoding and other things that you might want to visualize in a demo. This probably isn't necessary for training, and to get it right we need to copy a bunch of data structures for each new state, so it's best used only at evaluation / demo time. Parameters ---------- batch_indices : ``List[int]`` Passed to super class; see docs there. action_history : ``List[List[int]]`` Passed to super class; see docs there. score : ``List[torch.Tensor]`` Passed to super class; see docs there. rnn_state : ``List[RnnStatelet]`` An ``RnnStatelet`` for every group element. This keeps track of the current decoder hidden state, the previous decoder output, the output from the encoder (for computing attentions), and other things that are typical seq2seq decoder state things. grammar_state : ``List[GrammarStatelet]`` This hold the current grammar state for each element of the group. The ``GrammarStatelet`` keeps track of which actions are currently valid. possible_actions : ``List[List[ProductionRule]]`` The list of all possible actions that was passed to ``model.forward()``. We need this so we can recover production strings, which we need to update grammar states. extras : ``List[Any]``, optional (default=None) If you need to keep around some extra data for each instance in the batch, you can put that in here, without adding another field. This should be used `very sparingly`, as there is no type checking or anything done on the contents of this field, and it will just be passed around between ``States`` as-is, without copying. debug_info : ``List[Any]``, optional (default=None). 
""" def __init__(self, batch_indices: List[int], action_history: List[List[int]], score: List[torch.Tensor], rnn_state: List[RnnStatelet], grammar_state: List[GrammarStatelet], sql_state: List[SqlState], possible_actions: List[List[ProductionRule]], action_entity_mapping: List[Dict[int, int]], extras: List[Any] = None, debug_info: List = None) -> None: super().__init__(batch_indices, action_history, score) self.rnn_state = rnn_state self.grammar_state = grammar_state self.sql_state = sql_state self.possible_actions = possible_actions self.action_entity_mapping = action_entity_mapping self.extras = extras self.debug_info = debug_info def new_state_from_group_index(self, group_index: int, action: int, new_score: torch.Tensor, new_rnn_state: RnnStatelet, considered_actions: List[int] = None, action_probabilities: List[float] = None, attention_weights: torch.Tensor = None, linking_scores_qst: torch.Tensor = None, linking_scores_past: torch.Tensor = None) -> 'GrammarBasedState': batch_index = self.batch_indices[group_index] new_action_history = self.action_history[group_index] + [action] production_rule = self.possible_actions[batch_index][action][0] new_grammar_state = self.grammar_state[group_index].take_action(production_rule) new_sql_state = self.sql_state[group_index].take_action(production_rule) if self.debug_info is not None: attention = attention_weights[group_index] if attention_weights is not None else None # output_attention = output_attention_weights[group_index] if output_attention_weights is not None else None debug_info = { 'considered_actions': considered_actions, 'question_attention': attention, 'probabilities': action_probabilities, 'linking_scores_qst': linking_scores_qst, 'linking_scores_past': linking_scores_past, } new_debug_info = [self.debug_info[group_index] + [debug_info]] else: new_debug_info = None return GrammarBasedState(batch_indices=[batch_index], action_history=[new_action_history], score=[new_score], rnn_state=[new_rnn_state], 
grammar_state=[new_grammar_state], sql_state=[new_sql_state], possible_actions=self.possible_actions, action_entity_mapping=self.action_entity_mapping, extras=self.extras, debug_info=new_debug_info) def print_action_history(self, group_index: int = None) -> None: scores = self.score if group_index is None else [self.score[group_index]] batch_indices = self.batch_indices if group_index is None else [self.batch_indices[group_index]] histories = self.action_history if group_index is None else [self.action_history[group_index]] for score, batch_index, action_history in zip(scores, batch_indices, histories): print(' ', score.detach().cpu().numpy()[0], [self.possible_actions[batch_index][action][0] for action in action_history]) def get_valid_actions(self) -> List[Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]]: """ Returns a list of valid actions for each element of the group. """ # return [state.get_valid_actions() for state in self.grammar_state] return [sql_state.get_valid_actions(grammar_state.get_valid_actions()) for sql_state, grammar_state in zip(self.sql_state, self.grammar_state)] def is_finished(self) -> bool: if len(self.batch_indices) != 1: raise RuntimeError("is_finished() is only defined with a group_size of 1") return self.grammar_state[0].is_finished() @classmethod def combine_states(cls, states: Sequence['GrammarBasedState']) -> 'GrammarBasedState': batch_indices = [batch_index for state in states for batch_index in state.batch_indices] action_histories = [action_history for state in states for action_history in state.action_history] scores = [score for state in states for score in state.score] rnn_states = [rnn_state for state in states for rnn_state in state.rnn_state] grammar_states = [grammar_state for state in states for grammar_state in state.grammar_state] sql_states = [sql_state for state in states for sql_state in state.sql_state] if states[0].debug_info is not None: debug_info = [debug_info for state in states for debug_info in 
state.debug_info] else: debug_info = None return GrammarBasedState(batch_indices=batch_indices, action_history=action_histories, score=scores, rnn_state=rnn_states, grammar_state=grammar_states, sql_state=sql_states, possible_actions=states[0].possible_actions, action_entity_mapping=states[0].action_entity_mapping, extras=states[0].extras, debug_info=debug_info) ================================================ FILE: state_machines/states/rnn_statelet.py ================================================ from typing import List, Optional import torch from allennlp.nn import util class RnnStatelet: """ This class keeps track of all of decoder-RNN-related variables that you need during decoding. This includes things like the current decoder hidden state, the memory cell (for LSTM decoders), the encoder output that you need for computing attentions, and so on. This is intended to be used `inside` a ``State``, which likely has other things it has to keep track of for doing constrained decoding. Parameters ---------- hidden_state : ``torch.Tensor`` This holds the LSTM hidden state, with shape ``(decoder_output_dim,)`` if the decoder has 1 layer and ``(num_layers, decoder_output_dim)`` otherwise. memory_cell : ``torch.Tensor`` This holds the LSTM memory cell, with shape ``(decoder_output_dim,)`` if the decoder has 1 layer and ``(num_layers, decoder_output_dim)`` otherwise. previous_action_embedding : ``torch.Tensor`` This holds the embedding for the action we took at the last timestep (which gets input to the decoder). Has shape ``(action_embedding_dim,)``. attended_input : ``torch.Tensor`` This holds the attention-weighted sum over the input representations that we computed in the previous timestep. We keep this as part of the state because we use the previous attention as part of our decoder cell update. Has shape ``(encoder_output_dim,)``. 
encoder_outputs : ``List[torch.Tensor]`` A list of variables, each of shape ``(input_sequence_length, encoder_output_dim)``, containing the encoder outputs at each timestep. The list is over batch elements, and we do the input this way so we can easily do a ``torch.cat`` on a list of indices into this batched list. Note that all of the above parameters are single tensors, while the encoder outputs and mask are lists of length ``batch_size``. We always pass around the encoder outputs and mask unmodified, regardless of what's in the grouping for this state. We'll use the ``batch_indices`` for the group to pull pieces out of these lists when we're ready to actually do some computation. encoder_output_mask : ``List[torch.Tensor]`` A list of variables, each of shape ``(input_sequence_length,)``, containing a mask over question tokens for each batch instance. This is a list over batch elements, for the same reasons as above. """ def __init__(self, hidden_state: torch.Tensor, memory_cell: torch.Tensor, previous_action_embedding: torch.Tensor, attended_input: torch.Tensor, encoder_outputs: List[torch.Tensor], encoder_output_mask: List[torch.Tensor], decoder_outputs: Optional[torch.Tensor] = None) -> None: self.hidden_state = hidden_state self.memory_cell = memory_cell self.previous_action_embedding = previous_action_embedding self.attended_input = attended_input self.encoder_outputs = encoder_outputs self.encoder_output_mask = encoder_output_mask self.decoder_outputs = decoder_outputs def __eq__(self, other): if isinstance(self, other.__class__): return all([ util.tensors_equal(self.hidden_state, other.hidden_state, tolerance=1e-5), util.tensors_equal(self.memory_cell, other.memory_cell, tolerance=1e-5), util.tensors_equal(self.previous_action_embedding, other.previous_action_embedding, tolerance=1e-5), util.tensors_equal(self.attended_input, other.attended_input, tolerance=1e-5), ]) return NotImplemented ================================================ FILE: 
state_machines/states/sql_state.py ================================================ import copy import logging logger = logging.getLogger(__name__) # pylint: disable=invalid-name class SqlState: def __init__(self, possible_actions, enabled: bool=True): self.possible_actions = [a[0] for a in possible_actions] self.action_history = [] self.tables_used = set() self.tables_used_by_columns = set() self.current_stack = [] self.subqueries_stack = [] self.enabled = enabled def take_action(self, production_rule: str) -> 'SqlState': if not self.enabled: return self new_sql_state = copy.deepcopy(self) lhs, rhs = production_rule.split(' -> ') rhs_tokens = rhs.strip('[]').split(', ') if lhs == 'table_name': new_sql_state.tables_used.add(rhs_tokens[0].strip('"')) elif lhs == 'column_name': new_sql_state.tables_used_by_columns.add(rhs_tokens[0].strip('"').split('@')[0]) elif lhs == 'iue': new_sql_state.tables_used_by_columns = set() new_sql_state.tables_used = set() elif lhs == "source_subq": new_sql_state.subqueries_stack.append(copy.deepcopy(new_sql_state)) new_sql_state.tables_used = set() new_sql_state.tables_used_by_columns = set() new_sql_state.action_history.append(production_rule) new_sql_state.current_stack.append([lhs, []]) for token in rhs_tokens: is_terminal = token[0] == '"' and token[-1] == '"' if not is_terminal: new_sql_state.current_stack[-1][1].append(token) while len(new_sql_state.current_stack[-1][1]) == 0: finished_item = new_sql_state.current_stack[-1][0] del new_sql_state.current_stack[-1] if finished_item == 'statement': break if new_sql_state.current_stack[-1][1][0] == finished_item: new_sql_state.current_stack[-1][1] = new_sql_state.current_stack[-1][1][1:] if finished_item == 'source_subq': new_sql_state.tables_used = new_sql_state.subqueries_stack[-1].tables_used new_sql_state.tables_used_by_columns = new_sql_state.subqueries_stack[-1].tables_used_by_columns del new_sql_state.subqueries_stack[-1] return new_sql_state def get_valid_actions(self, 
valid_actions: dict): if not self.enabled: return valid_actions valid_actions_ids = [] for key, items in valid_actions.items(): valid_actions_ids += [(key, rule_id) for rule_id in valid_actions[key][2]] valid_actions_rules = [self.possible_actions[rule_id] for rule_type, rule_id in valid_actions_ids] actions_to_remove = {k: set() for k in valid_actions.keys()} current_clause = self._get_current_open_clause() if current_clause in ['where_clause', 'orderby_clause', 'join_condition', 'groupby_clause']: for rule_id, rule in zip(valid_actions_ids, valid_actions_rules): rule_type, rule_id = rule_id lhs, rhs = rule.split(' -> ') rhs_values = rhs.strip('[]').split(', ') if lhs == 'column_name': rule_table = rhs_values[0].strip('"').split('@')[0] if rule_table not in self.tables_used: actions_to_remove[rule_type].add(rule_id) # if len(self.current_stack[-1][1]) < 2: # # disable condition clause when same tables # rule_table = rhs_values[0].strip('"').split('@')[0] # last_table = self.action_history[-1].split(' -> ')[1].strip('[]"').split('@')[0] # if rule_table == last_table: # actions_to_remove[rule_type].add(rule_id) if current_clause in ['join_clause']: for rule_id, rule in zip(valid_actions_ids, valid_actions_rules): rule_type, rule_id = rule_id lhs, rhs = rule.split(' -> ') rhs_values = rhs.strip('[]').split(', ') if lhs == 'table_name': candidate_table = rhs_values[0].strip('"') if current_clause == 'join_clause' and len(self.current_stack[-1][1]) == 2: if candidate_table in self.tables_used: # trying to join an already joined table actions_to_remove[rule_type].add(rule_id) if 'join_clauses' not in self.current_stack[-2][1] and not self.current_stack[-2][0].startswith('join_clauses'): # decided not to join any more tables remaining_joins = self.tables_used_by_columns - self.tables_used if len(remaining_joins) > 0 and candidate_table not in self.tables_used_by_columns: # trying to select a single table but used other table(s) in columns 
actions_to_remove[rule_type].add(rule_id) if current_clause in ['select_core']: for rule_id, rule in zip(valid_actions_ids, valid_actions_rules): rule_type, rule_id = rule_id lhs, rhs = rule.split(' -> ') rhs_values = rhs.strip('[]').split(', ') if self.current_stack[-1][1][0] == 'from_clause' or self.current_stack[-1][1][0] == 'join_clauses': all_tables = set([a.split(' -> ')[1].strip('[]\"') for a in self.possible_actions if a.startswith('table_name ->')]) if len(self.tables_used_by_columns - self.tables_used) > 1: # selected columns from more tables than selected, must join if 'join_clauses' not in rhs: actions_to_remove[rule_type].add(rule_id) if len(all_tables - self.tables_used) <= 1: # don't join 2 tables because otherwise there will be no more tables to join # (assuming no joining twice and no sub-queries) if 'join_clauses' in rhs: actions_to_remove[rule_type].add(rule_id) if lhs == "table_name" and self.current_stack[-1][0] == "single_source": candidate_table = rhs_values[0].strip('"') if len(self.tables_used_by_columns) > 0 and candidate_table not in self.tables_used_by_columns: # trying to select a single table but used other table(s) in columns actions_to_remove[rule_type].add(rule_id) if lhs == 'single_source' and len(self.tables_used_by_columns) == 0 and rhs.strip('[]') == 'source_subq': # prevent cases such as "select count ( * ) from ( select city.district from city ) where city.district = ' value '" search_stack_pos = -1 while self.current_stack[search_stack_pos][0] != 'select_core': # note - should look for other "gateaways" here (i.e. maybe this is not a dead end, if there is # another source_subq. 
This is ignored here search_stack_pos -= 1 if self.current_stack[search_stack_pos][1][-1] == 'where_clause': # planning to add where/group/order later, but no columns were ever selected actions_to_remove[rule_type].add(rule_id) while self.current_stack[search_stack_pos][0] != 'query': search_stack_pos -= 1 if 'orderby_clause' in self.current_stack[search_stack_pos][1]: actions_to_remove[rule_type].add(rule_id) if 'groupby_clause' in self.current_stack[search_stack_pos][1]: actions_to_remove[rule_type].add(rule_id) new_valid_actions = {} new_global_actions = self._remove_actions(valid_actions, 'global', actions_to_remove['global']) if 'global' in valid_actions else None new_linked_actions = self._remove_actions(valid_actions, 'linked', actions_to_remove['linked']) if 'linked' in valid_actions else None if new_linked_actions is not None: new_valid_actions['linked'] = new_linked_actions if new_global_actions is not None: new_valid_actions['global'] = new_global_actions # if len(new_valid_actions) == 0 and valid_actions: # # should not get here! 
implies that a rule should have been disabled in past (bug in this parser) # # log and do not remove rules (otherwise crashes) # # logger.warning("No valid action remains, error in sql decoding parser!") # # logger.warning("Action history: " + str(self.action_history)) # # logger.warning("Tables in db: " + ', '.join([a.split(' -> ')[1].strip('[]\"') for a in self.possible_actions if a.startswith('table_name ->')])) # # return valid_actions return new_valid_actions @staticmethod def _remove_actions(valid_actions, key, ids_to_remove): if len(ids_to_remove) == 0: return valid_actions[key] if len(ids_to_remove) == len(valid_actions[key][2]): return None current_ids = valid_actions[key][2] keep_ids = [] keep_ids_loc = [] for loc, rule_id in enumerate(current_ids): if rule_id not in ids_to_remove: keep_ids.append(rule_id) keep_ids_loc.append(loc) items = list(valid_actions[key]) items[0] = items[0][keep_ids_loc] items[1] = items[1][keep_ids_loc] items[2] = keep_ids if len(items) >= 4: items[3] = items[3][keep_ids_loc] return tuple(items) def _get_current_open_clause(self): relevant_clauses = [ 'where_clause', 'orderby_clause', 'join_clause', 'join_condition', 'select_core', 'groupby_clause', 'source_subq' ] for rule in self.current_stack[::-1]: if rule[0] in relevant_clauses: return rule[0] return None ================================================ FILE: state_machines/transition_functions/attend_past_schema_items_transition.py ================================================ from collections import defaultdict from typing import Dict, Tuple, List, Set, Any, Callable, Optional import torch from allennlp.modules import Attention, FeedForward from allennlp.nn import Activation, util from allennlp.state_machines.states.grammar_based_state import GrammarBasedState from allennlp.state_machines.transition_functions import BasicTransitionFunction from allennlp.state_machines.transition_functions.linking_transition_function import LinkingTransitionFunction from overrides 
import overrides from torch.nn import Linear from state_machines.states.rnn_statelet import RnnStatelet class AttendPastSchemaItemsTransitionFunction(BasicTransitionFunction): def __init__(self, encoder_output_dim: int, action_embedding_dim: int, input_attention: Attention, past_attention: Attention, activation: Activation = Activation.by_name('relu')(), predict_start_type_separately: bool = True, num_start_types: int = None, add_action_bias: bool = True, dropout: float = 0.0, num_layers: int = 1) -> None: super().__init__(encoder_output_dim=encoder_output_dim, action_embedding_dim=action_embedding_dim, input_attention=input_attention, num_start_types=num_start_types, activation=activation, predict_start_type_separately=predict_start_type_separately, add_action_bias=add_action_bias, dropout=dropout, num_layers=num_layers) self._past_attention = past_attention self._ent2ent_ff = FeedForward(1, 1, 1, Activation.by_name('linear')()) @overrides def take_step(self, state: GrammarBasedState, max_actions: int = None, allowed_actions: List[Set[int]] = None) -> List[GrammarBasedState]: if self._predict_start_type_separately and not state.action_history[0]: # The wikitables parser did something different when predicting the start type, which # is our first action. So in this case we break out into a different function. We'll # ignore max_actions on our first step, assuming there aren't that many start types. return self._take_first_step(state, allowed_actions) # Taking a step in the decoder consists of three main parts. First, we'll construct the # input to the decoder and update the decoder's hidden state. Second, we'll use this new # hidden state (and maybe other information) to predict an action. Finally, we will # construct new states for the next step. Each new state corresponds to one valid action # that can be taken from the current state, and they are ordered by their probability of # being selected. 
updated_state = self._update_decoder_state(state) batch_results = self._compute_action_probabilities(state, updated_state['hidden_state'], updated_state['attention_weights'], updated_state['past_schema_items_attention_weights'], updated_state['predicted_action_embeddings']) new_states = self._construct_next_states(state, updated_state, batch_results, max_actions, allowed_actions) return new_states def _update_decoder_state(self, state: GrammarBasedState) -> Dict[str, torch.Tensor]: # For updating the decoder, we're doing a bunch of tensor operations that can be batched # without much difficulty. So, we take all group elements and batch their tensors together # before doing these decoder operations. group_size = len(state.batch_indices) attended_question = torch.stack([rnn_state.attended_input for rnn_state in state.rnn_state]) if self._num_layers > 1: hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state], 1) memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state], 1) else: hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state]) memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state]) previous_action_embedding = torch.stack([rnn_state.previous_action_embedding for rnn_state in state.rnn_state]) # (group_size, decoder_input_dim) projected_input = self._input_projection_layer(torch.cat([attended_question, previous_action_embedding], -1)) decoder_input = self._activation(projected_input) if self._num_layers > 1: _, (hidden_state, memory_cell) = self._decoder_cell(decoder_input.unsqueeze(0), (hidden_state, memory_cell)) else: hidden_state, memory_cell = self._decoder_cell(decoder_input, (hidden_state, memory_cell)) hidden_state = self._dropout(hidden_state) # (group_size, encoder_output_dim) encoder_outputs = torch.stack([state.rnn_state[0].encoder_outputs[i] for i in state.batch_indices]) encoder_output_mask = 
torch.stack([state.rnn_state[0].encoder_output_mask[i] for i in state.batch_indices]) if self._num_layers > 1: attended_question, attention_weights = self.attend_on_question(hidden_state[-1], encoder_outputs, encoder_output_mask) action_query = torch.cat([hidden_state[-1], attended_question], dim=-1) else: attended_question, attention_weights = self.attend_on_question(hidden_state, encoder_outputs, encoder_output_mask) action_query = torch.cat([hidden_state, attended_question], dim=-1) # TODO: Can batch this (need to save ids of states with saved outputs) past_schema_items_attention_weights = [] for i, rnn_state in enumerate(state.rnn_state): if rnn_state.decoder_outputs is not None: decoder_outputs_states, decoder_outputs_ids = rnn_state.decoder_outputs attn_weights = self.attend(self._past_attention, hidden_state[i].unsqueeze(0), decoder_outputs_states.unsqueeze(0), None).squeeze(0) past_schema_items_attention_weights.append((attn_weights, decoder_outputs_ids)) else: past_schema_items_attention_weights.append(None) # past_schema_items_attention_weights = torch.stack(past_schema_items_attention_weights) # (group_size, action_embedding_dim) projected_query = self._activation(self._output_projection_layer(action_query)) predicted_action_embeddings = self._dropout(projected_query) if self._add_action_bias: # NOTE: It's important that this happens right before the dot product with the action # embeddings. Otherwise this isn't a proper bias. We do it here instead of right next # to the `.mm` below just so we only do it once for the whole group. 
ones = predicted_action_embeddings.new([[1] for _ in range(group_size)]) predicted_action_embeddings = torch.cat([predicted_action_embeddings, ones], dim=-1) return { 'hidden_state': hidden_state, 'memory_cell': memory_cell, 'attended_question': attended_question, 'attention_weights': attention_weights, 'past_schema_items_attention_weights': past_schema_items_attention_weights, 'predicted_action_embeddings': predicted_action_embeddings, } @overrides def _compute_action_probabilities(self, state: GrammarBasedState, hidden_state: torch.Tensor, attention_weights: torch.Tensor, past_schema_items_attention_weights: torch.Tensor, predicted_action_embeddings: torch.Tensor ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]: # In this section we take our predicted action embedding and compare it to the available # actions in our current state (which might be different for each group element). For # computing action scores, we'll forget about doing batched / grouped computation, as it # adds too much complexity and doesn't speed things up, anyway, with the operations we're # doing here. This means we don't need any action masks, as we'll only get the right # lengths for what we're computing. group_size = len(state.batch_indices) actions = state.get_valid_actions() batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list) for group_index in range(group_size): batch_id = state.batch_indices[group_index] instance_actions = actions[group_index] predicted_action_embedding = predicted_action_embeddings[group_index] embedded_actions: List[int] = [] output_action_embeddings = None embedded_action_logits = None current_log_probs = None linked_action_logits_encoder = None linked_action_ent2ent_logits = None if 'global' in instance_actions: action_embeddings, output_action_embeddings, embedded_actions = instance_actions['global'] # This is just a matrix product between a (num_actions, embedding_dim) matrix and an # (embedding_dim, 1) matrix. 
embedded_action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1) action_ids = embedded_actions if 'linked' in instance_actions: linking_scores, type_embeddings, linked_actions, entity_action_linking_scores = instance_actions['linked'] action_ids = embedded_actions + linked_actions # (num_question_tokens, 1) # for linked actions, in addition to the linking score with the attended question word, we add # a linking score with an attended previously decoded linked action # num_decoded_entities = 3 if past_schema_items_attention_weights[group_index] is not None: past_items_attention_weights, past_items_action_ids = past_schema_items_attention_weights[group_index] past_schema_items_ids = [state.action_entity_mapping[batch_id][a] for a in past_items_action_ids] # we are only interested about the scores of the entities the decoder has already output past_entity_linking_scores = entity_action_linking_scores[:, past_schema_items_ids] linked_action_ent2ent_logits = past_entity_linking_scores.mm( past_items_attention_weights.unsqueeze(-1)).squeeze(-1) linked_action_ent2ent_logits = self._ent2ent_ff(linked_action_ent2ent_logits.unsqueeze(-1)).squeeze(1) else: linked_action_ent2ent_logits = 0 linked_action_logits_encoder = linking_scores.mm(attention_weights[group_index].unsqueeze(-1)).squeeze(-1) linked_action_logits = linked_action_logits_encoder + linked_action_ent2ent_logits # The `output_action_embeddings` tensor gets used later as the input to the next # decoder step. For linked actions, we don't have any action embedding, so we use # the entity type instead. 
if output_action_embeddings is not None: output_action_embeddings = torch.cat([output_action_embeddings, type_embeddings], dim=0) else: output_action_embeddings = type_embeddings if embedded_action_logits is not None: action_logits = torch.cat([embedded_action_logits, linked_action_logits], dim=-1) else: action_logits = linked_action_logits current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1) elif not instance_actions: action_ids = None current_log_probs = float('inf') else: action_logits = embedded_action_logits current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1) # This is now the total score for each state after taking each action. We're going to # sort by this later, so it's important that this is the total score, not just the # score for the current action. log_probs = state.score[group_index] + current_log_probs batch_results[state.batch_indices[group_index]].append((group_index, log_probs, current_log_probs, output_action_embeddings, action_ids, linked_action_logits_encoder, linked_action_ent2ent_logits)) return batch_results def _construct_next_states(self, state: GrammarBasedState, updated_rnn_state: Dict[str, torch.Tensor], batch_action_probs: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]], max_actions: int, allowed_actions: List[Set[int]]): # pylint: disable=no-self-use # We'll yield a bunch of states here that all have a `group_size` of 1, so that the # learning algorithm can decide how many of these it wants to keep, and it can just regroup # them later, as that's a really easy operation. # # We first define a `make_state` method, as in the logic that follows we want to create # states in a couple of different branches, and we don't want to duplicate the # state-creation logic. This method creates a closure using variables from the method, so # it doesn't make sense to pull it out of here. 
# Each group index here might get accessed multiple times, and doing the slicing operation # each time is more expensive than doing it once upfront. These three lines give about a # 10% speedup in training time. group_size = len(state.batch_indices) chunk_index = 1 if self._num_layers > 1 else 0 hidden_state = [x.squeeze(chunk_index) for x in updated_rnn_state['hidden_state'].chunk(group_size, chunk_index)] memory_cell = [x.squeeze(chunk_index) for x in updated_rnn_state['memory_cell'].chunk(group_size, chunk_index)] attended_question = [x.squeeze(0) for x in updated_rnn_state['attended_question'].chunk(group_size, 0)] def make_state(group_index: int, action: int, new_score: torch.Tensor, action_embedding: torch.Tensor) -> GrammarBasedState: batch_index = state.batch_indices[group_index] decoder_outputs = state.rnn_state[group_index].decoder_outputs is_linked_action = not state.possible_actions[batch_index][action][1] if is_linked_action: if decoder_outputs is None: decoder_outputs = hidden_state[group_index].unsqueeze(0), [action] else: decoder_outputs_states, decoder_outputs_ids = decoder_outputs decoder_outputs = torch.cat(( decoder_outputs_states, hidden_state[group_index].unsqueeze(0) ), dim=0), decoder_outputs_ids + [action] new_rnn_state = RnnStatelet(hidden_state[group_index], memory_cell[group_index], action_embedding, attended_question[group_index], state.rnn_state[group_index].encoder_outputs, state.rnn_state[group_index].encoder_output_mask, decoder_outputs) for i, _, current_log_probs, _, actions, lsq, lsp in batch_action_probs[batch_index]: if i == group_index: considered_actions = actions probabilities = current_log_probs.exp().cpu() considered_lsq = lsq considered_lsp = lsp break return state.new_state_from_group_index(group_index, action, new_score, new_rnn_state, considered_actions, probabilities, updated_rnn_state['attention_weights'], considered_lsq, considered_lsp ) new_states = [] for _, results in batch_action_probs.items(): if 
allowed_actions and not max_actions: # If we're given a set of allowed actions, and we're not just keeping the top k of # them, we don't need to do any sorting, so we can speed things up quite a bit. for group_index, log_probs, _, action_embeddings, actions in results: for log_prob, action_embedding, action in zip(log_probs, action_embeddings, actions): if action in allowed_actions[group_index]: new_states.append(make_state(group_index, action, log_prob, action_embedding)) else: # In this case, we need to sort the actions. We'll do that on CPU, as it's easier, # and our action list is on the CPU, anyway. group_indices = [] group_log_probs: List[torch.Tensor] = [] group_action_embeddings = [] group_actions = [] for group_index, log_probs, _, action_embeddings, actions, _, _ in results: if not actions: continue group_indices.extend([group_index] * len(actions)) group_log_probs.append(log_probs) group_action_embeddings.append(action_embeddings) group_actions.extend(actions) if len(group_log_probs) == 0: continue log_probs = torch.cat(group_log_probs, dim=0) action_embeddings = torch.cat(group_action_embeddings, dim=0) log_probs_cpu = log_probs.data.cpu().numpy().tolist() batch_states = [(log_probs_cpu[i], group_indices[i], log_probs[i], action_embeddings[i], group_actions[i]) for i in range(len(group_actions)) if (not allowed_actions or group_actions[i] in allowed_actions[group_indices[i]])] # We use a key here to make sure we're not trying to compare anything on the GPU. 
batch_states.sort(key=lambda x: x[0], reverse=True) if max_actions: batch_states = batch_states[:max_actions] for _, group_index, log_prob, action_embedding, action in batch_states: new_states.append(make_state(group_index, action, log_prob, action_embedding)) return new_states def attend(self, attention: Attention, query: torch.Tensor, key: torch.Tensor, value: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: """ Given a query (which is typically the decoder hidden state), compute an attention over the output of the question encoder, and return a weighted sum of the question representations given this attention. We also return the attention weights themselves. This is a simple computation, but we have it as a separate method so that the ``forward`` method on the main parser module can call it on the initial hidden state, to simplify the logic in ``take_step``. """ # (group_size, question_length) attention_weights = attention(query, key, None) if value is None: return attention_weights # (group_size, encoder_output_dim) attended_question = util.weighted_sum(value, attention_weights) return attended_question, attention_weights ================================================ FILE: state_machines/transition_functions/basic_transition_function.py ================================================ from collections import defaultdict from typing import Any, Dict, List, Set, Tuple from overrides import overrides import torch from torch.nn.modules.rnn import LSTM, LSTMCell from torch.nn.modules.linear import Linear from allennlp.modules import Attention from allennlp.nn import util, Activation from allennlp.state_machines.states import RnnStatelet, GrammarBasedState from allennlp.state_machines.transition_functions.transition_function import TransitionFunction class BasicTransitionFunction(TransitionFunction[GrammarBasedState]): """ This is a typical transition function for a state-based decoder. 
We use an LSTM to track decoder state, and at every timestep we compute an attention over the input question/utterance to help in selecting the action. All actions have an embedding, and we use a dot product between a predicted action embedding and the allowed actions to compute a distribution over actions at each timestep. We allow the first action to be predicted separately from everything else. This is optional, and is because that's how the original WikiTableQuestions semantic parser was written. The intuition is that maybe you want to predict the type of your output program outside of the typical LSTM decoder (or maybe Jayant just didn't realize this could be treated as another action...). Parameters ---------- encoder_output_dim : ``int`` action_embedding_dim : ``int`` input_attention : ``Attention`` activation : ``Activation``, optional (default=relu) The activation that gets applied to the decoder LSTM input and to the action query. predict_start_type_separately : ``bool``, optional (default=True) If ``True``, we will predict the initial action (which is typically the base type of the logical form) using a different mechanism than our typical action decoder. We basically just do a projection of the hidden state, and don't update the decoder RNN. num_start_types : ``int``, optional (default=None) If ``predict_start_type_separately`` is ``True``, this is the number of start types that are in the grammar. We need this so we can construct parameters with the right shape. This is unused if ``predict_start_type_separately`` is ``False``. add_action_bias : ``bool``, optional (default=True) If ``True``, there has been a bias dimension added to the embedding of each action, which gets used when predicting the next action. We add a dimension of ones to our predicted action vector in this case to account for that. dropout : ``float`` (optional, default=0.0) num_layers: ``int``, (optional, default=1) The number of layers in the decoder LSTM. 
""" def __init__(self, encoder_output_dim: int, action_embedding_dim: int, input_attention: Attention, activation: Activation = Activation.by_name('relu')(), predict_start_type_separately: bool = True, num_start_types: int = None, add_action_bias: bool = True, dropout: float = 0.0, num_layers: int = 1) -> None: super().__init__() self._input_attention = input_attention self._add_action_bias = add_action_bias self._activation = activation self._num_layers = num_layers self._predict_start_type_separately = predict_start_type_separately if predict_start_type_separately: self._start_type_predictor = Linear(encoder_output_dim, num_start_types) self._num_start_types = num_start_types else: self._start_type_predictor = None self._num_start_types = None # Decoder output dim needs to be the same as the encoder output dim since we initialize the # hidden state of the decoder with the final hidden state of the encoder. output_dim = encoder_output_dim input_dim = output_dim # Our decoder input will be the concatenation of the decoder hidden state and the previous # action embedding, and we'll project that down to the decoder's `input_dim`, which we # arbitrarily set to be the same as `output_dim`. self._input_projection_layer = Linear(output_dim + action_embedding_dim, input_dim) # Before making a prediction, we'll compute an attention over the input given our updated # hidden state. Then we concatenate those with the decoder state and project to # `action_embedding_dim` to make a prediction. self._output_projection_layer = Linear(output_dim + encoder_output_dim, action_embedding_dim) if self._num_layers > 1: self._decoder_cell = LSTM(input_dim, output_dim, self._num_layers) else: # We use a ``LSTMCell`` if we just have one layer because it is slightly faster since we are # just running the LSTM for one step each time. 
self._decoder_cell = LSTMCell(input_dim, output_dim) if dropout > 0: self._dropout = torch.nn.Dropout(p=dropout) else: self._dropout = lambda x: x @overrides def take_step(self, state: GrammarBasedState, max_actions: int = None, allowed_actions: List[Set[int]] = None) -> List[GrammarBasedState]: if self._predict_start_type_separately and not state.action_history[0]: # The wikitables parser did something different when predicting the start type, which # is our first action. So in this case we break out into a different function. We'll # ignore max_actions on our first step, assuming there aren't that many start types. return self._take_first_step(state, allowed_actions) # Taking a step in the decoder consists of three main parts. First, we'll construct the # input to the decoder and update the decoder's hidden state. Second, we'll use this new # hidden state (and maybe other information) to predict an action. Finally, we will # construct new states for the next step. Each new state corresponds to one valid action # that can be taken from the current state, and they are ordered by their probability of # being selected. updated_state = self._update_decoder_state(state) batch_results = self._compute_action_probabilities(state, updated_state['hidden_state'], updated_state['attention_weights'], updated_state['predicted_action_embeddings']) new_states = self._construct_next_states(state, updated_state, batch_results, max_actions, allowed_actions) return new_states def _update_decoder_state(self, state: GrammarBasedState) -> Dict[str, torch.Tensor]: # For updating the decoder, we're doing a bunch of tensor operations that can be batched # without much difficulty. So, we take all group elements and batch their tensors together # before doing these decoder operations. 
group_size = len(state.batch_indices) attended_question = torch.stack([rnn_state.attended_input for rnn_state in state.rnn_state]) if self._num_layers > 1: hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state], 1) memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state], 1) else: hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state]) memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state]) previous_action_embedding = torch.stack([rnn_state.previous_action_embedding for rnn_state in state.rnn_state]) # (group_size, decoder_input_dim) projected_input = self._input_projection_layer(torch.cat([attended_question, previous_action_embedding], -1)) decoder_input = self._activation(projected_input) if self._num_layers > 1: _, (hidden_state, memory_cell) = self._decoder_cell(decoder_input.unsqueeze(0), (hidden_state, memory_cell)) else: hidden_state, memory_cell = self._decoder_cell(decoder_input, (hidden_state, memory_cell)) hidden_state = self._dropout(hidden_state) # (group_size, encoder_output_dim) encoder_outputs = torch.stack([state.rnn_state[0].encoder_outputs[i] for i in state.batch_indices]) encoder_output_mask = torch.stack([state.rnn_state[0].encoder_output_mask[i] for i in state.batch_indices]) if self._num_layers > 1: attended_question, attention_weights = self.attend_on_question(hidden_state[-1], encoder_outputs, encoder_output_mask) action_query = torch.cat([hidden_state[-1], attended_question], dim=-1) else: attended_question, attention_weights = self.attend_on_question(hidden_state, encoder_outputs, encoder_output_mask) action_query = torch.cat([hidden_state, attended_question], dim=-1) # (group_size, action_embedding_dim) projected_query = self._activation(self._output_projection_layer(action_query)) predicted_action_embeddings = self._dropout(projected_query) if self._add_action_bias: # NOTE: It's important that this happens right before the 
dot product with the action # embeddings. Otherwise this isn't a proper bias. We do it here instead of right next # to the `.mm` below just so we only do it once for the whole group. ones = predicted_action_embeddings.new([[1] for _ in range(group_size)]) predicted_action_embeddings = torch.cat([predicted_action_embeddings, ones], dim=-1) return { 'hidden_state': hidden_state, 'memory_cell': memory_cell, 'attended_question': attended_question, 'attention_weights': attention_weights, 'predicted_action_embeddings': predicted_action_embeddings, } def _compute_action_probabilities(self, state: GrammarBasedState, hidden_state: torch.Tensor, attention_weights: torch.Tensor, predicted_action_embeddings: torch.Tensor ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]: # We take a couple of extra arguments here because subclasses might use them. # pylint: disable=unused-argument,no-self-use # In this section we take our predicted action embedding and compare it to the available # actions in our current state (which might be different for each group element). For # computing action scores, we'll forget about doing batched / grouped computation, as it # adds too much complexity and doesn't speed things up, anyway, with the operations we're # doing here. This means we don't need any action masks, as we'll only get the right # lengths for what we're computing. group_size = len(state.batch_indices) actions = state.get_valid_actions() batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list) for group_index in range(group_size): instance_actions = actions[group_index] predicted_action_embedding = predicted_action_embeddings[group_index] action_embeddings, output_action_embeddings, action_ids = instance_actions['global'] # This is just a matrix product between a (num_actions, embedding_dim) matrix and an # (embedding_dim, 1) matrix. 
action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1) current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1) # This is now the total score for each state after taking each action. We're going to # sort by this later, so it's important that this is the total score, not just the # score for the current action. log_probs = state.score[group_index] + current_log_probs batch_results[state.batch_indices[group_index]].append((group_index, log_probs, current_log_probs, output_action_embeddings, action_ids)) return batch_results def _construct_next_states(self, state: GrammarBasedState, updated_rnn_state: Dict[str, torch.Tensor], batch_action_probs: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]], max_actions: int, allowed_actions: List[Set[int]]): # pylint: disable=no-self-use # We'll yield a bunch of states here that all have a `group_size` of 1, so that the # learning algorithm can decide how many of these it wants to keep, and it can just regroup # them later, as that's a really easy operation. # # We first define a `make_state` method, as in the logic that follows we want to create # states in a couple of different branches, and we don't want to duplicate the # state-creation logic. This method creates a closure using variables from the method, so # it doesn't make sense to pull it out of here. # Each group index here might get accessed multiple times, and doing the slicing operation # each time is more expensive than doing it once upfront. These three lines give about a # 10% speedup in training time. 
group_size = len(state.batch_indices) chunk_index = 1 if self._num_layers > 1 else 0 hidden_state = [x.squeeze(chunk_index) for x in updated_rnn_state['hidden_state'].chunk(group_size, chunk_index)] memory_cell = [x.squeeze(chunk_index) for x in updated_rnn_state['memory_cell'].chunk(group_size, chunk_index)] attended_question = [x.squeeze(0) for x in updated_rnn_state['attended_question'].chunk(group_size, 0)] def make_state(group_index: int, action: int, new_score: torch.Tensor, action_embedding: torch.Tensor) -> GrammarBasedState: new_rnn_state = RnnStatelet(hidden_state[group_index], memory_cell[group_index], action_embedding, attended_question[group_index], state.rnn_state[group_index].encoder_outputs, state.rnn_state[group_index].encoder_output_mask) batch_index = state.batch_indices[group_index] for i, _, current_log_probs, _, actions in batch_action_probs[batch_index]: if i == group_index: considered_actions = actions probabilities = current_log_probs.exp().cpu() break return state.new_state_from_group_index(group_index, action, new_score, new_rnn_state, considered_actions, probabilities, updated_rnn_state['attention_weights']) new_states = [] for _, results in batch_action_probs.items(): if allowed_actions and not max_actions: # If we're given a set of allowed actions, and we're not just keeping the top k of # them, we don't need to do any sorting, so we can speed things up quite a bit. for group_index, log_probs, _, action_embeddings, actions in results: for log_prob, action_embedding, action in zip(log_probs, action_embeddings, actions): if action in allowed_actions[group_index]: new_states.append(make_state(group_index, action, log_prob, action_embedding)) else: # In this case, we need to sort the actions. We'll do that on CPU, as it's easier, # and our action list is on the CPU, anyway. 
group_indices = [] group_log_probs: List[torch.Tensor] = [] group_action_embeddings = [] group_actions = [] for group_index, log_probs, _, action_embeddings, actions in results: if not actions: continue group_indices.extend([group_index] * len(actions)) group_log_probs.append(log_probs) group_action_embeddings.append(action_embeddings) group_actions.extend(actions) if len(group_log_probs) == 0: continue log_probs = torch.cat(group_log_probs, dim=0) action_embeddings = torch.cat(group_action_embeddings, dim=0) log_probs_cpu = log_probs.data.cpu().numpy().tolist() batch_states = [(log_probs_cpu[i], group_indices[i], log_probs[i], action_embeddings[i], group_actions[i]) for i in range(len(group_actions)) if (not allowed_actions or group_actions[i] in allowed_actions[group_indices[i]])] # We use a key here to make sure we're not trying to compare anything on the GPU. batch_states.sort(key=lambda x: x[0], reverse=True) if max_actions: batch_states = batch_states[:max_actions] for _, group_index, log_prob, action_embedding, action in batch_states: new_states.append(make_state(group_index, action, log_prob, action_embedding)) return new_states def _take_first_step(self, state: GrammarBasedState, allowed_actions: List[Set[int]] = None) -> List[GrammarBasedState]: # We'll just do a projection from the current hidden state (which was initialized with the # final encoder output) to the number of start actions that we have, normalize those # logits, and use that as our score. We end up duplicating some of the logic from # `_compute_new_states` here, but we do things slightly differently, and it's easier to # just copy the parts we need than to try to re-use that code. 
# (group_size, hidden_dim) hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state]) # (group_size, num_start_type) start_action_logits = self._start_type_predictor(hidden_state) log_probs = torch.nn.functional.log_softmax(start_action_logits, dim=-1) sorted_log_probs, sorted_actions = log_probs.sort(dim=-1, descending=True) sorted_actions = sorted_actions.detach().cpu().numpy().tolist() if state.debug_info is not None: probs_cpu = log_probs.exp().detach().cpu().numpy().tolist() else: probs_cpu = [None] * len(state.batch_indices) # state.get_valid_actions() will return a list that is consistently sorted, so as along as # the set of valid start actions never changes, we can just match up the log prob indices # above with the position of each considered action, and we're good. valid_actions = state.get_valid_actions() considered_actions = [actions['global'][2] for actions in valid_actions] if len(considered_actions[0]) != self._num_start_types: raise RuntimeError("Calculated wrong number of initial actions. Expected " f"{self._num_start_types}, found {len(considered_actions[0])}.") best_next_states: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list) for group_index, (batch_index, group_actions) in enumerate(zip(state.batch_indices, sorted_actions)): for action_index, action in enumerate(group_actions): # `action` is currently the index in `log_probs`, not the actual action ID. To get # the action ID, we need to go through `considered_actions`. action = considered_actions[group_index][action] if allowed_actions is not None and action not in allowed_actions[group_index]: # This happens when our _decoder trainer_ wants us to only evaluate certain # actions, likely because they are the gold actions in this state. We just skip # emitting any state that isn't allowed by the trainer, because constructing the # new state can be expensive. 
continue best_next_states[batch_index].append((group_index, action_index, action)) new_states = [] for batch_index, best_states in sorted(best_next_states.items()): for group_index, action_index, action in best_states: # We'll yield a bunch of states here that all have a `group_size` of 1, so that the # learning algorithm can decide how many of these it wants to keep, and it can just # regroup them later, as that's a really easy operation. new_score = state.score[group_index] + sorted_log_probs[group_index, action_index] # This part is different from `_compute_new_states` - we're just passing through # the previous RNN state, as predicting the start type wasn't included in the # decoder RNN in the original model. new_rnn_state = state.rnn_state[group_index] new_state = state.new_state_from_group_index(group_index, action, new_score, new_rnn_state, considered_actions[group_index], probs_cpu[group_index], None) new_states.append(new_state) return new_states def attend_on_question(self, query: torch.Tensor, encoder_outputs: torch.Tensor, encoder_output_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Given a query (which is typically the decoder hidden state), compute an attention over the output of the question encoder, and return a weighted sum of the question representations given this attention. We also return the attention weights themselves. This is a simple computation, but we have it as a separate method so that the ``forward`` method on the main parser module can call it on the initial hidden state, to simplify the logic in ``take_step``. 
""" # (group_size, question_length) question_attention_weights = self._input_attention(query, encoder_outputs, encoder_output_mask) # (group_size, encoder_output_dim) attended_question = util.weighted_sum(encoder_outputs, question_attention_weights) return attended_question, question_attention_weights ================================================ FILE: state_machines/transition_functions/linking_transition_function.py ================================================ from collections import defaultdict from typing import Any, Dict, List, Tuple from overrides import overrides import torch from allennlp.common.checks import check_dimensions_match from allennlp.modules import Attention, FeedForward from allennlp.nn import Activation from allennlp.state_machines.states import GrammarBasedState from state_machines.transition_functions.basic_transition_function import BasicTransitionFunction class LinkingTransitionFunction(BasicTransitionFunction): """ This transition function adds the ability to consider `linked` actions to the ``BasicTransitionFunction`` (which is just an LSTM decoder with attention). These actions are potentially unseen at training time, so we need to handle them without requiring the action to have an embedding. Instead, we rely on a `linking score` between each action and the words in the question/utterance, and use these scores, along with the attention, to do something similar to a copy mechanism when producing these actions. When both linked and global (embedded) actions are available, we need some way to compare the scores for these two sets of actions. The original WikiTableQuestion semantic parser just concatenated the logits together before doing a joint softmax, but this is quite brittle, because the logits might have quite different scales. So we have the option here of predicting a mixture probability between two independently normalized distributions. 
Parameters ---------- encoder_output_dim : ``int`` action_embedding_dim : ``int`` input_attention : ``Attention`` activation : ``Activation``, optional (default=relu) The activation that gets applied to the decoder LSTM input and to the action query. predict_start_type_separately : ``bool``, optional (default=True) If ``True``, we will predict the initial action (which is typically the base type of the logical form) using a different mechanism than our typical action decoder. We basically just do a projection of the hidden state, and don't update the decoder RNN. num_start_types : ``int``, optional (default=None) If ``predict_start_type_separately`` is ``True``, this is the number of start types that are in the grammar. We need this so we can construct parameters with the right shape. This is unused if ``predict_start_type_separately`` is ``False``. add_action_bias : ``bool``, optional (default=True) If ``True``, there has been a bias dimension added to the embedding of each action, which gets used when predicting the next action. We add a dimension of ones to our predicted action vector in this case to account for that. mixture_feedforward : ``FeedForward`` optional (default=None) If given, we'll use this to compute a mixture probability between global actions and linked actions given the hidden state at every timestep of decoding, instead of concatenating the logits for both (where the logits may not be compatible with each other). dropout : ``float`` (optional, default=0.0) num_layers: ``int`` (optional, default=1) The number of layers in the decoder LSTM. 
""" def __init__(self, encoder_output_dim: int, action_embedding_dim: int, input_attention: Attention, activation: Activation = Activation.by_name('relu')(), predict_start_type_separately: bool = True, num_start_types: int = None, add_action_bias: bool = True, mixture_feedforward: FeedForward = None, dropout: float = 0.0, num_layers: int = 1) -> None: super().__init__(encoder_output_dim=encoder_output_dim, action_embedding_dim=action_embedding_dim, input_attention=input_attention, num_start_types=num_start_types, activation=activation, predict_start_type_separately=predict_start_type_separately, add_action_bias=add_action_bias, dropout=dropout, num_layers=num_layers) self._mixture_feedforward = mixture_feedforward if mixture_feedforward is not None: check_dimensions_match(encoder_output_dim, mixture_feedforward.get_input_dim(), "hidden state embedding dim", "mixture feedforward input dim") check_dimensions_match(mixture_feedforward.get_output_dim(), 1, "mixture feedforward output dim", "dimension for scalar value") @overrides def _compute_action_probabilities(self, state: GrammarBasedState, hidden_state: torch.Tensor, attention_weights: torch.Tensor, predicted_action_embeddings: torch.Tensor ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]: # In this section we take our predicted action embedding and compare it to the available # actions in our current state (which might be different for each group element). For # computing action scores, we'll forget about doing batched / grouped computation, as it # adds too much complexity and doesn't speed things up, anyway, with the operations we're # doing here. This means we don't need any action masks, as we'll only get the right # lengths for what we're computing. 
group_size = len(state.batch_indices) actions = state.get_valid_actions() batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list) for group_index in range(group_size): instance_actions = actions[group_index] predicted_action_embedding = predicted_action_embeddings[group_index] embedded_actions: List[int] = [] output_action_embeddings = None embedded_action_logits = None current_log_probs = None if 'global' in instance_actions: action_embeddings, output_action_embeddings, embedded_actions = instance_actions['global'] # This is just a matrix product between a (num_actions, embedding_dim) matrix and an # (embedding_dim, 1) matrix. embedded_action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1) action_ids = embedded_actions if 'linked' in instance_actions: linking_scores, type_embeddings, linked_actions = instance_actions['linked'] action_ids = embedded_actions + linked_actions # (num_question_tokens, 1) linked_action_logits = linking_scores.mm(attention_weights[group_index].unsqueeze(-1)).squeeze(-1) # The `output_action_embeddings` tensor gets used later as the input to the next # decoder step. For linked actions, we don't have any action embedding, so we use # the entity type instead. if output_action_embeddings is not None: output_action_embeddings = torch.cat([output_action_embeddings, type_embeddings], dim=0) else: output_action_embeddings = type_embeddings if self._mixture_feedforward is not None: # The linked and global logits are combined with a mixture weight to prevent the # linked_action_logits from dominating the embedded_action_logits if a softmax # was applied on both together. 
mixture_weight = self._mixture_feedforward(hidden_state[group_index]) mix1 = torch.log(mixture_weight) mix2 = torch.log(1 - mixture_weight) entity_action_probs = torch.nn.functional.log_softmax(linked_action_logits, dim=-1) + mix1 if embedded_action_logits is not None: embedded_action_probs = torch.nn.functional.log_softmax(embedded_action_logits, dim=-1) + mix2 current_log_probs = torch.cat([embedded_action_probs, entity_action_probs], dim=-1) else: current_log_probs = entity_action_probs else: if embedded_action_logits is not None: action_logits = torch.cat([embedded_action_logits, linked_action_logits], dim=-1) else: action_logits = linked_action_logits current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1) elif not instance_actions: action_ids = None current_log_probs = float('inf') else: action_logits = embedded_action_logits current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1) # This is now the total score for each state after taking each action. We're going to # sort by this later, so it's important that this is the total score, not just the # score for the current action. 
log_probs = state.score[group_index] + current_log_probs batch_results[state.batch_indices[group_index]].append((group_index, log_probs, current_log_probs, output_action_embeddings, action_ids)) return batch_results ================================================ FILE: state_machines/transition_functions/prefix_attend_transition.py ================================================ from typing import Dict, Tuple, List, Set, Any, Callable import torch from allennlp.modules import Attention, FeedForward from allennlp.nn import Activation, util from allennlp.state_machines.states.grammar_based_state import GrammarBasedState from allennlp.state_machines.transition_functions.linking_transition_function import LinkingTransitionFunction from overrides import overrides from torch.nn import Linear from state_machines.states.rnn_statelet import RnnStatelet class PrefixAttendTransitionFunction(LinkingTransitionFunction): def __init__(self, encoder_output_dim: int, action_embedding_dim: int, input_attention: Attention, output_attention: Attention, activation: Activation = Activation.by_name('relu')(), predict_start_type_separately: bool = True, num_start_types: int = None, add_action_bias: bool = True, mixture_feedforward: FeedForward = None, dropout: float = 0.0, num_layers: int = 1) -> None: super().__init__(encoder_output_dim=encoder_output_dim, action_embedding_dim=action_embedding_dim, input_attention=input_attention, num_start_types=num_start_types, activation=activation, predict_start_type_separately=predict_start_type_separately, add_action_bias=add_action_bias, dropout=dropout, num_layers=num_layers, mixture_feedforward=mixture_feedforward) self._output_attention = output_attention # override self._input_projection_layer = Linear(encoder_output_dim + action_embedding_dim, encoder_output_dim) self._attend_output_projection_layer = Linear(encoder_output_dim*2, encoder_output_dim) self._first_attended_output = torch.nn.Parameter(torch.FloatTensor(action_embedding_dim)) 
torch.nn.init.normal_(self._first_attended_output) @overrides def take_step(self, state: GrammarBasedState, max_actions: int = None, allowed_actions: List[Set[int]] = None) -> List[GrammarBasedState]: if self._predict_start_type_separately and not state.action_history[0]: # The wikitables parser did something different when predicting the start type, which # is our first action. So in this case we break out into a different function. We'll # ignore max_actions on our first step, assuming there aren't that many start types. return self._take_first_step(state, allowed_actions) # Taking a step in the decoder consists of three main parts. First, we'll construct the # input to the decoder and update the decoder's hidden state. Second, we'll use this new # hidden state (and maybe other information) to predict an action. Finally, we will # construct new states for the next step. Each new state corresponds to one valid action # that can be taken from the current state, and they are ordered by their probability of # being selected. updated_state = self._update_decoder_state(state) batch_results = self._compute_action_probabilities(state, updated_state['hidden_state'], updated_state['attention_weights'], updated_state['predicted_action_embeddings']) new_states = self._construct_next_states(state, updated_state, batch_results, max_actions, allowed_actions) return new_states @overrides def _update_decoder_state(self, state: GrammarBasedState) -> Dict[str, torch.Tensor]: # For updating the decoder, we're doing a bunch of tensor operations that can be batched # without much difficulty. So, we take all group elements and batch their tensors together # before doing these decoder operations. 
group_size = len(state.batch_indices) attended_question = torch.stack([rnn_state.attended_input for rnn_state in state.rnn_state]) hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state]) memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state]) previous_action_embedding = torch.stack([rnn_state.previous_action_embedding for rnn_state in state.rnn_state]) if not state.action_history[0]: attended_output, output_attention_weights = self._first_attended_output.unsqueeze(0).repeat(group_size, 1),\ None else: decoder_outputs = torch.stack([rnn_state.decoder_outputs for rnn_state in state.rnn_state]) decoded_item_embeddings = torch.stack([rnn_state.decoded_item_embeddings for rnn_state in state.rnn_state]) action_query = torch.cat([hidden_state, attended_question], dim=-1) # (group_size, action_embedding_dim) projected_query = self._activation(self._attend_output_projection_layer(action_query)) query = self._dropout(projected_query) attended_output, output_attention_weights = self.attend(query, decoder_outputs, decoded_item_embeddings) # (group_size, decoder_input_dim) projected_input = self._input_projection_layer(torch.cat([attended_question, previous_action_embedding + attended_output], -1)) decoder_input = self._activation(projected_input) hidden_state, memory_cell = self._decoder_cell(decoder_input, (hidden_state, memory_cell)) hidden_state = self._dropout(hidden_state) # (group_size, encoder_output_dim) encoder_outputs = torch.stack([state.rnn_state[0].encoder_outputs[i] for i in state.batch_indices]) encoder_output_mask = torch.stack([state.rnn_state[0].encoder_output_mask[i] for i in state.batch_indices]) attended_question, attention_weights = self.attend_on_question(hidden_state, encoder_outputs, encoder_output_mask) action_query = torch.cat([hidden_state, attended_question], dim=-1) # (group_size, action_embedding_dim) projected_query = self._activation(self._output_projection_layer(action_query)) 
predicted_action_embeddings = self._dropout(projected_query) if self._add_action_bias: # NOTE: It's important that this happens right before the dot product with the action # embeddings. Otherwise this isn't a proper bias. We do it here instead of right next # to the `.mm` below just so we only do it once for the whole group. ones = predicted_action_embeddings.new([[1] for _ in range(group_size)]) predicted_action_embeddings = torch.cat([predicted_action_embeddings, ones], dim=-1) return { 'hidden_state': hidden_state, 'memory_cell': memory_cell, 'attended_question': attended_question, 'attended_output': attended_output, 'attention_weights': attention_weights, 'output_attention_weights': output_attention_weights, 'predicted_action_embeddings': predicted_action_embeddings } def attend(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Given a query (which is typically the decoder hidden state), compute an attention over the output of the question encoder, and return a weighted sum of the question representations given this attention. We also return the attention weights themselves. This is a simple computation, but we have it as a separate method so that the ``forward`` method on the main parser module can call it on the initial hidden state, to simplify the logic in ``take_step``. 
""" # (group_size, question_length) question_attention_weights = self._output_attention(query, key, None) # (group_size, encoder_output_dim) attended_question = util.weighted_sum(value, question_attention_weights) return attended_question, question_attention_weights def _construct_next_states(self, state: GrammarBasedState, updated_rnn_state: Dict[str, torch.Tensor], batch_action_probs: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]], max_actions: int, allowed_actions: List[Set[int]]): # pylint: disable=no-self-use # We'll yield a bunch of states here that all have a `group_size` of 1, so that the # learning algorithm can decide how many of these it wants to keep, and it can just regroup # them later, as that's a really easy operation. # # We first define a `make_state` method, as in the logic that follows we want to create # states in a couple of different branches, and we don't want to duplicate the # state-creation logic. This method creates a closure using variables from the method, so # it doesn't make sense to pull it out of here. # Each group index here might get accessed multiple times, and doing the slicing operation # each time is more expensive than doing it once upfront. These three lines give about a # 10% speedup in training time. 
group_size = len(state.batch_indices) chunk_index = 1 if self._num_layers > 1 else 0 hidden_state = [x.squeeze(chunk_index) for x in updated_rnn_state['hidden_state'].chunk(group_size, chunk_index)] memory_cell = [x.squeeze(chunk_index) for x in updated_rnn_state['memory_cell'].chunk(group_size, chunk_index)] if not state.action_history[0]: decoder_outputs = updated_rnn_state['hidden_state'].unsqueeze(1) else: decoder_outputs = torch.cat(( torch.stack([x.decoder_outputs for x in state.rnn_state]), updated_rnn_state['hidden_state'].unsqueeze(1) ), dim=1) attended_question = [x.squeeze(0) for x in updated_rnn_state['attended_question'].chunk(group_size, 0)] def make_state(group_index: int, action: int, new_score: torch.Tensor, action_embedding: torch.Tensor) -> GrammarBasedState: batch_index = state.batch_indices[group_index] action_entity_id = state.action_entity_mapping[batch_index][action] + 1 # add 1 so that -1 becomes 0 (pad) if not state.action_history[0]: decoded_item_embeddings = state.rnn_state[group_index].item_embeddings[action_entity_id].unsqueeze(0) else: decoded_item_embeddings = torch.cat(( state.rnn_state[group_index].decoded_item_embeddings, state.rnn_state[group_index].item_embeddings[action_entity_id].unsqueeze(0) ), dim=0) new_rnn_state = RnnStatelet(hidden_state[group_index], memory_cell[group_index], action_embedding, attended_question[group_index], state.rnn_state[group_index].encoder_outputs, state.rnn_state[group_index].encoder_output_mask, state.rnn_state[group_index].item_embeddings, decoder_outputs[group_index], decoded_item_embeddings) for i, _, current_log_probs, _, actions in batch_action_probs[batch_index]: if i == group_index: considered_actions = actions probabilities = current_log_probs.exp().cpu() break return state.new_state_from_group_index(group_index, action, new_score, new_rnn_state, considered_actions, probabilities, updated_rnn_state['attention_weights'], updated_rnn_state['output_attention_weights']) new_states = [] for _, 
results in batch_action_probs.items(): if allowed_actions and not max_actions: # If we're given a set of allowed actions, and we're not just keeping the top k of # them, we don't need to do any sorting, so we can speed things up quite a bit. for group_index, log_probs, _, action_embeddings, actions in results: for log_prob, action_embedding, action in zip(log_probs, action_embeddings, actions): if action in allowed_actions[group_index]: new_states.append(make_state(group_index, action, log_prob, action_embedding)) else: # In this case, we need to sort the actions. We'll do that on CPU, as it's easier, # and our action list is on the CPU, anyway. group_indices = [] group_log_probs: List[torch.Tensor] = [] group_action_embeddings = [] group_actions = [] for group_index, log_probs, _, action_embeddings, actions in results: group_indices.extend([group_index] * len(actions)) group_log_probs.append(log_probs) group_action_embeddings.append(action_embeddings) group_actions.extend(actions) log_probs = torch.cat(group_log_probs, dim=0) action_embeddings = torch.cat(group_action_embeddings, dim=0) log_probs_cpu = log_probs.data.cpu().numpy().tolist() batch_states = [(log_probs_cpu[i], group_indices[i], log_probs[i], action_embeddings[i], group_actions[i]) for i in range(len(group_actions)) if (not allowed_actions or group_actions[i] in allowed_actions[group_indices[i]])] # We use a key here to make sure we're not trying to compare anything on the GPU. 
batch_states.sort(key=lambda x: x[0], reverse=True) if max_actions: batch_states = batch_states[:max_actions] for _, group_index, log_prob, action_embedding, action in batch_states: new_states.append(make_state(group_index, action, log_prob, action_embedding)) return new_states ================================================ FILE: train_configs/defaults.jsonnet ================================================ local dataset_path = "dataset/"; { "random_seed": 5, "numpy_seed": 5, "pytorch_seed": 5, "dataset_reader": { "type": "spider", "tables_file": dataset_path + "tables.json", "dataset_path": dataset_path + "database", "lazy": false, "keep_if_unparsable": false, "loading_limit": -1 }, "validation_dataset_reader": { "type": "spider", "tables_file": dataset_path + "tables.json", "dataset_path": dataset_path + "database", "lazy": false, "keep_if_unparsable": true, "loading_limit": -1 }, "train_data_path": dataset_path + "train_spider.json", "validation_data_path": dataset_path + "dev.json", "model": { "type": "spider", "dataset_path": dataset_path, "parse_sql_on_decoding": true, "gnn": true, "gnn_timesteps": 3, "decoder_self_attend": true, "decoder_use_graph_entities": true, "use_neighbor_similarity_for_linking": true, "question_embedder": { "tokens": { "type": "embedding", "embedding_dim": 200, "trainable": true } }, "action_embedding_dim": 200, "encoder": { "type": "lstm", "input_size": 400, "hidden_size": 400, "bidirectional": true, "num_layers": 1 }, "entity_encoder": { "type": "boe", "embedding_dim": 200, "averaged": true }, "decoder_beam_search": { "beam_size": 10 }, "training_beam_size": 1, "max_decoding_steps": 100, "input_attention": {"type": "dot_product"}, "past_attention": {"type": "dot_product"}, "dropout": 0.5 }, "iterator": { "type": "basic", "batch_size" : 15 }, "validation_iterator": { "type": "basic", "batch_size" : 1 }, "trainer": { "num_epochs": 100, "cuda_device": 0, "patience": 20, "validation_metric": "+sql_match", "optimizer": { "type": 
"adam", "lr": 0.001, "weight_decay": 5e-4 }, "num_serialized_models_to_keep": 2 } } ================================================ FILE: train_configs/paper_defaults.jsonnet ================================================ local dataset_path = "dataset/"; { "random_seed": 5, "numpy_seed": 5, "pytorch_seed": 5, "dataset_reader": { "type": "spider", "tables_file": dataset_path + "tables.json", "dataset_path": dataset_path + "database", "lazy": false, "keep_if_unparsable": false, "loading_limit": -1 }, "validation_dataset_reader": { "type": "spider", "tables_file": dataset_path + "tables.json", "dataset_path": dataset_path + "database", "lazy": false, "keep_if_unparsable": true, "loading_limit": -1 }, "train_data_path": dataset_path + "train_spider.json", "validation_data_path": dataset_path + "dev.json", "model": { "type": "spider", "dataset_path": dataset_path, "parse_sql_on_decoding": true, "gnn": true, "gnn_timesteps": 2, "decoder_self_attend": true, "decoder_use_graph_entities": true, "use_neighbor_similarity_for_linking": true, "question_embedder": { "tokens": { "type": "embedding", "embedding_dim": 200, "trainable": true } }, "action_embedding_dim": 200, "encoder": { "type": "lstm", "input_size": 400, "hidden_size": 200, "bidirectional": true, "num_layers": 1 }, "entity_encoder": { "type": "boe", "embedding_dim": 200, "averaged": true }, "decoder_beam_search": { "beam_size": 10 }, "training_beam_size": 1, "max_decoding_steps": 100, "input_attention": {"type": "dot_product"}, "past_attention": {"type": "dot_product"}, "dropout": 0.5 }, "iterator": { "type": "basic", "batch_size" : 15 }, "validation_iterator": { "type": "basic", "batch_size" : 1 }, "trainer": { "num_epochs": 100, "cuda_device": 0, "patience": 20, "validation_metric": "+sql_match", "optimizer": { "type": "adam", "lr": 0.001 }, "num_serialized_models_to_keep": 2 } }