"""
TSQL -- Test Suite Query Language
"""
from typing import (
List, Tuple, Dict, Set, Optional, Union, Any, Type,
Iterator, Callable, cast as typing_cast)
import operator
import re
from datetime import datetime
from delphin.exceptions import PyDelphinException, PyDelphinSyntaxError
from delphin import util
from delphin import tsdb
# Default modules need to import the PyDelphin version
from delphin.__about__ import __version__ # noqa: F401
# CUSTOM EXCEPTIONS ###########################################################
class TSQLError(PyDelphinException):
    """
    Raised when a TSQL query cannot be carried out.

    For example: an undefined column is referenced, or the requested
    relations cannot be joined.
    """
class TSQLSyntaxError(PyDelphinSyntaxError):
    """Raised when a TSQL query is malformed or of an unsupported type."""
# LOCAL TYPES #################################################################
# a list of column names, plain (e.g., 'i-id') or qualified
# (e.g., 'item.i-id')
_Names = List[str]
# a leaf comparison: (operator, (column, value)), e.g., ('<', ('i-id', 100))
_Comparison = Tuple[str, Tuple[str, tsdb.Value]]
# a boolean combination: ('and', [conditions]), ('or', [conditions]), or
# ('not', condition). The precise type should be recursive:
# _Boolean = Tuple[str, Union[_Condition, List[_Condition]]]
# but use Any until the type checker supports recursive types.
# see: https://github.com/python/mypy/issues/731
_Boolean = Tuple[str, Any]
# any parsed WHERE condition
_Condition = Union[_Comparison, _Boolean]
# predicate applied to each row of a Selection to decide whether to keep it
_FilterFunction = Callable[[tsdb.Record], bool]
# maps a column name to its qualified name and its schema Field
_QNameResolver = Callable[[str], Tuple[str, tsdb.Field]]
class _Record(tsdb.Record):
    """
    Dummy record type that mimics the call signature of itsdb.Row.

    Calling it returns a plain tuple of *data*. The *fields* and
    *field_index* arguments are accepted only for signature
    compatibility and are ignored.
    """

    def __new__(cls,
                fields: tsdb.Fields,
                data: tsdb.Record,
                field_index: Optional[tsdb.FieldIndex] = None):
        # returning a non-instance means __init__ is never invoked
        return tuple(data)
class Selection(tsdb.Records):
    """
    The results of a 'select' query.

    Rows are held in :attr:`data` and described positionally by
    :attr:`fields`. Iterating a Selection yields records restricted to
    the columns in :attr:`projection`, or all columns when no
    projection is set.
    """

    def __init__(self,
                 record_class: Optional[Type[_Record]] = None) -> None:
        # field descriptions for each position of a row in *data*
        self.fields: List[tsdb.Field] = []
        # maps plain and relation-qualified column names to row positions
        self._field_index: tsdb.FieldIndex = {}
        self.data: tsdb.Records = []
        # column names yielded on iteration; None means every column
        self.projection: Optional[List[str]] = None
        self.record_class = _Record if record_class is None else record_class
        # names of relations already merged into this selection
        self.joined: Set[str] = set()

    def __iter__(self) -> Iterator[tsdb.Record]:
        columns = () if self.projection is None else self.projection
        return self.select(*columns)

    def select(self,
               *names: str,
               cast: bool = False) -> Iterator[tsdb.Record]:
        """
        Lazily yield records containing the columns *names*.

        Args:
            names: column names (plain or qualified); if none are
                given, all columns are yielded
            cast: if `True`, rows whose values are all strings or
                `None` are converted with :func:`tsdb.cast` according
                to each field's datatype
        Yields:
            instances built by :attr:`record_class`
        """
        if names:
            positions = [self._field_index[name] for name in names]
        else:
            positions = list(range(len(self.fields)))
        chosen = [self.fields[pos] for pos in positions]
        chosen_index = tsdb.make_field_index(chosen)
        make_record = self.record_class
        for row in self.data:
            # rows may be Row-like objects that wrap a ``data`` sequence
            raw = getattr(row, 'data', row)
            values = tuple(raw[pos] for pos in positions)
            if cast and all(v is None or isinstance(v, str)
                            for v in values):
                strings = typing_cast(Tuple[Optional[str], ...], values)
                values = tuple(tsdb.cast(fld.datatype, val)
                               for fld, val in zip(chosen, strings))
            yield make_record(chosen, values, field_index=chosen_index)
# QUERY INSPECTION ############################################################
def inspect_query(querystring: str) -> dict:
    """
    Parse *querystring* and return the interpreted query dictionary.

    The query is parsed but not executed, which is useful for seeing
    how TSQL interprets a query.

    Example:
        >>> from delphin import tsql
        >>> from pprint import pprint
        >>> pprint(tsql.inspect_query(
        ...     'select i-input from item where i-id < 100'))
        {'type': 'select',
         'projection': ['i-input'],
         'relations': ['item'],
         'condition': ('<', ('i-id', 100))}
    """
    parsed = _parse_query(querystring)
    return parsed
# QUERY PROCESSING ############################################################
def query(querystring: str,
          db: tsdb.Database,
          **kwargs) -> Selection:
    """
    Perform query *querystring* on the test suite database *db*.

    Note: currently only 'select' (or its synonym 'retrieve') queries
    are supported.

    Args:
        querystring (str): TSQL query string
        db (:class:`delphin.tsdb.Database`): database to query over
        kwargs: keyword arguments passed to the more specific query
            function (e.g., :func:`select`); currently only
            *record_class* is used
    Returns:
        :class:`Selection`: the query results
    Raises:
        TSQLSyntaxError: if the query is malformed or of an
            unsupported type
        TSQLError: if the query cannot be executed on *db*
    Example:
        >>> list(tsql.query('select i-id where i-length < 4', db))
        [[142], [1061]]
    """
    queryobj = _parse_query(querystring)
    if queryobj['type'] in ('select', 'retrieve'):
        return _select(
            queryobj['projection'],
            queryobj['relations'],
            queryobj['condition'],
            db,
            record_class=kwargs.get('record_class', None))
    else:
        # not really a syntax error; replace with TSQLError or something
        # when the proper exception class exists
        raise TSQLSyntaxError(queryobj['type'] + ' queries are not supported',
                              text=querystring)
def select(querystring: str,
           db: tsdb.Database,
           record_class: Optional[Type[_Record]] = None) -> Selection:
    """
    Perform the TSQL selection query *querystring* on database *db*.

    Note: The `select`/`retrieve` keyword is not included in
    *querystring*; the string begins with the projection.

    Args:
        querystring: TSQL select query, without the leading keyword
        db: TSDB database to query over
        record_class: class used to build each record yielded when
            iterating the result; if `None`, plain tuples are produced
    Returns:
        :class:`Selection`: the query results
    Raises:
        TSQLSyntaxError: if the query is malformed
        TSQLError: if the query cannot be executed on *db*
    Example:
        >>> list(tsql.select('i-id where i-length < 4', db))
        [[142], [1061]]
    """
    queryobj = _parse_select(querystring)
    return _select(
        queryobj['projection'],
        queryobj['relations'],
        queryobj['condition'],
        db,
        record_class=record_class)
def _select(projection: _Names,
            relations: List[str],
            condition: Optional[_Condition],
            db: tsdb.Database,
            record_class: Optional[Type[_Record]]) -> Selection:
    """
    Execute a parsed select query against *db*.

    All planned relations are inner-joined, rows failing *condition*
    are dropped, and the qualified projection is attached to the
    returned :class:`Selection`.
    """
    plan_projection, plan_joins, plan_condition = _make_execution_plan(
        projection, relations, condition, db)
    result = Selection(record_class=record_class)
    for relname, columns in plan_joins:
        _join(result, db, relname, columns, 'inner')
    if plan_condition:
        keep = _process_condition_function(plan_condition, result)
        result.data = [row for row in result.data if keep(row)]
    result.projection = plan_projection
    return result
def _make_execution_plan(
        projection: _Names,
        relations: List[str],
        condition: Optional[_Condition],
        db: tsdb.Database) -> Tuple:
    """
    Make a plan for all relations to join and columns to keep.

    Returns a triple ``(projection, joins, condition)``:

    - projection: qualified column names (``['*']`` expands to every
      column of *relations*)
    - joins: the ``(relation, columns)`` plan from :func:`_plan_joins`
    - condition: *condition* with qualified column names, or `None`
    """
    resolve_qname = _make_qname_resolver(db, relations)
    if projection == ['*']:
        qualified = _project_all(relations, db)
    else:
        qualified = [resolve_qname(colname)[0] for colname in projection]
    if condition:
        resolved, condition_columns = _process_condition_fields(
            condition, resolve_qname)
    else:
        resolved, condition_columns = None, []
    joins = _plan_joins(qualified, condition_columns, relations, db)
    return qualified, joins, resolved
def _project_all(relations: List[str], db: tsdb.Database) -> List[str]:
    """
    Return qualified names of every column of *relations*, in order.

    A key column is included only the first time its (unqualified)
    name is seen, so shared keys are not repeated; non-key columns are
    always included.
    """
    seen_keys: Set[str] = set()
    qnames = []
    for relname in relations:
        for fld in db.schema[relname]:
            if fld.is_key:
                if fld.name in seen_keys:
                    continue
                seen_keys.add(fld.name)
            qnames.append(f'{relname}.{fld.name}')
    return qnames
def _make_qname_resolver(
        db: tsdb.Database,
        relations: List[str]) -> _QNameResolver:
    """
    Return a function that turns column names into qualified names.

    For example, `i-input` becomes `item.i-input`. When an unqualified
    column exists in several relations, relations listed in
    *relations* are preferred; otherwise schema order decides.

    The returned function gives ``(qualified_name, field)`` and raises
    :class:`TSQLError` if the column (or, for a qualified name, its
    relation) is not defined in the schema of *db*.
    """
    index = {rel: tsdb.make_field_index(db.schema[rel]) for rel in db.schema}
    schema_map: Dict[str, List[str]] = {}
    for relname, fields in db.schema.items():
        for field in fields:
            schema_map.setdefault(field.name, []).append(relname)
    # prefer those appearing in specified relations (stable sort keeps
    # schema order among equally preferred relations)
    for colname in schema_map:
        schema_map[colname] = sorted(schema_map[colname],
                                     key=relations.__contains__,
                                     reverse=True)

    def resolve(colname: str) -> Tuple[str, tsdb.Field]:
        rel, _, col = colname.rpartition('.')
        if rel:
            # qualified name: validate instead of failing with KeyError
            if rel not in index or col not in index[rel]:
                raise TSQLError(f'undefined column: {colname}')
        elif col in schema_map:
            rel = schema_map[col][0]
        else:
            raise TSQLError(f'undefined column: {colname}')
        return f'{rel}.{col}', db.schema[rel][index[rel][col]]

    return resolve
def _plan_joins(projection, condition_fields, relations, db):
    """
    Calculate the relations and columns needed for the query.

    Args:
        projection: qualified names of projected columns
        condition_fields: qualified names of columns used in conditions
        relations: relations named in the query's 'from' clause
        db: the database whose schema is consulted
    Returns:
        a list of ``(relation, columns)`` pairs. The first relation is
        the one of the first projected column; each later relation has
        a column named like a key of an already-listed relation. Every
        relation's key columns are always included.
    Raises:
        TSQLError: if the relations cannot be connected or ordered
    """
    joinmap = {}
    added = set()
    relset = set(relations)
    for qname in projection + list(condition_fields):
        if qname not in added:
            rel, _, col = qname.rpartition('.')
            relset.add(rel)
            joinmap.setdefault(rel, []).append(col)
            added.add(qname)
    # add necessary relations to span all requested relations
    keymap = _make_keymap(db)
    relset.update(_pivot_relations(relset, keymap, db))
    # always add keys; iterate in sorted order so the join order (and
    # thus the row order of results) does not depend on set hashing
    for relation in sorted(relset):
        for field in db.schema[relation]:
            if field.is_key and f'{relation}.{field.name}' not in added:
                joinmap.setdefault(relation, []).append(field.name)
    # finally ensure joins occur in a valid order
    joined_keys = set()
    joins = []
    while joinmap:
        for rel in list(joinmap):
            if not joins or joined_keys.intersection(joinmap[rel]):
                joins.append((rel, joinmap.pop(rel)))
                joined_keys.update(keymap[rel])
                break
        else:
            # no remaining relation shares a key with the joined ones
            raise TSQLError('infinite loop detected!')
    return joins
def _make_keymap(db):
    """
    Map each relation in the schema of *db* to its key field names.

    Key names are listed in schema order, and relations appear in the
    order of ``db.schema``.
    """
    return {rel: [field.name for field in fields if field.is_key]
            for rel, fields in db.schema.items()}
def _pivot_relations(relset, keymap, db):
    """
    Find extra relations that connect otherwise disjoint relations.

    Key field names are graph nodes; keys that occur together in one
    relation are linked. While the key graph of *relset* (plus chosen
    pivots) has more than one connected component, the first relation
    in *keymap* not yet used that has two or more keys touching at
    least two components is added as a pivot. *db* is accepted for
    signature compatibility but is not used.

    Returns:
        the set of pivot relation names (possibly empty)
    Raises:
        TSQLError: if the components cannot be connected this way
    """
    nodes = set()
    edges = []

    def connect(keys):
        # link every pair of keys that co-occur in a single relation
        edges.extend((first, second)
                     for pos, first in enumerate(keys)
                     for second in keys[pos + 1:])

    for rel in relset:
        rel_keys = keymap[rel]
        nodes.update(rel_keys)
        connect(rel_keys)

    pivots = set()
    components = util._connected_components(nodes, edges)
    while len(components) > 1:
        excluded = relset.union(pivots)
        candidate = next(
            (rel for rel, rel_keys in keymap.items()
             if rel not in excluded and len(rel_keys) > 1
             and sum(1 for comp in components
                     if comp.intersection(rel_keys)) > 1),
            None)
        if candidate is None:
            raise TSQLError('could not find relation to join: {}'
                            .format(', '.join(sorted(relset))))
        nodes.update(keymap[candidate])
        connect(keymap[candidate])
        pivots.add(candidate)
        components = util._connected_components(nodes, edges)
    return pivots
# Comparison functions for the TSQL operators that compare a column
# value with a literal ('=' is normalized to '==' by the parser; the
# regex operators '~' and '!~' are handled separately).
_operator_functions = {
    '==': operator.eq,
    '!=': operator.ne,
    '<': operator.lt,
    '<=': operator.le,
    '>': operator.gt,
    '>=': operator.ge,
}
def _process_condition_fields(
        condition: _Condition,
        resolve_qname: _QNameResolver) -> Tuple[_Condition, _Names]:
    """
    Qualify column names in *condition* and type-check its values.

    Conditions look like ``('==', ('i-id', 11))``, or boolean
    combinations ``('and', [...])``, ``('or', [...])``, and
    ``('not', cond)``.

    Returns:
        a pair of the condition with qualified column names and the
        list of qualified column names it uses
    Raises:
        TSQLError: if a value's type does not match its column's
            datatype (or a column is undefined, via *resolve_qname*)
    """
    op, body = condition
    if op in ('and', 'or'):
        body = typing_cast(List[_Condition], body)
        fieldset: Set[str] = set()
        conditions = []
        for cond in body:
            _cond, _fields = _process_condition_fields(
                cond, resolve_qname)
            fieldset.update(_fields)
            conditions.append(_cond)
        return (op, conditions), sorted(fieldset)
    elif op == 'not':
        ncond, fields = _process_condition_fields(
            body, resolve_qname)
        return ('not', ncond), fields
    else:
        colname, value = body
        qname, field = resolve_qname(colname)
        # check if the condition's type matches the column
        typ = _expected_type(field.datatype)
        if not isinstance(value, typ):
            # typ may be a tuple of types (e.g., (int, float) for
            # :float columns), which has no __name__ attribute
            if isinstance(typ, tuple):
                typname = '/'.join(t.__name__ for t in typ)
            else:
                typname = typ.__name__
            raise TSQLError(
                'type mismatch in condition on {}: {} {} {}'
                .format(qname, typname, op, type(value).__name__))
        return (op, (qname, value)), [qname]
def _expected_type(datatype):
    """
    Return the Python type(s) a condition value must have for *datatype*.

    The result is suitable as the second argument of :func:`isinstance`.
    Integers are accepted for ``:float`` columns. Datatypes are matched
    exactly; an unrecognized datatype yields `None`.
    """
    # exact comparisons: substring tests (`in`) would wrongly accept
    # fragments such as '' or ':int'
    if datatype == ':string':
        return str
    elif datatype == ':integer':
        return int
    elif datatype == ':float':
        return (int, float)
    elif datatype == ':date':
        return datetime
    return None
def _process_condition_function(
        condition: _Condition,
        selection: Selection) -> _FilterFunction:
    """
    Compile a resolved *condition* into a row predicate for *selection*.

    Column names in *condition* must be keys of the selection's field
    index. Row values are cast with :func:`tsdb.cast` before comparing;
    a `None` value fails every comparison except '!~'.

    Returns:
        a function taking a row and returning `True` if the row
        satisfies *condition*
    """
    field_index = selection._field_index
    fields = selection.fields
    # conditions are something like:
    #   ('==', ('i-id', 11))
    op, body = condition
    if op in ('and', 'or'):
        body = typing_cast(List[_Condition], body)
        subpredicates = [_process_condition_function(cond, selection)
                         for cond in body]
        combine = all if op == 'and' else any

        def func(row):
            return combine(pred(row) for pred in subpredicates)

    elif op == 'not':
        negated = _process_condition_function(body, selection)

        def func(row):
            return not negated(row)

    else:
        column, target = body
        # look the column up once rather than once per row
        index = field_index[column]
        datatype = fields[index].datatype

        if op == '~':
            def func(row):
                value = tsdb.cast(datatype, row[index])
                return (value is not None
                        and re.search(target, value) is not None)

        elif op == '!~':
            def func(row):
                value = tsdb.cast(datatype, row[index])
                return value is None or re.search(target, value) is None

        else:
            compare = _operator_functions[op]

            def func(row):
                value = tsdb.cast(datatype, row[index])
                return value is not None and compare(value, target)

    return func
# RELATION JOINS ##############################################################
def _join(selection: Selection,
          db: tsdb.Database,
          name: str,
          columns: _Names,
          how: str = 'inner') -> None:
    """
    Join *columns* of relation *name* in *db* into *selection*.

    If nothing has been joined into *selection* yet, its fields and
    data are initialized from *name*. Otherwise rows are matched on
    those key fields of *name* (among *columns*) already present in
    *selection*, comparing cast values.

    If *how* is `"inner"`, then only matched rows persist after
    the join; if *how* is `"left"`, all existing rows are kept and
    those without a match are padded with `None` values.

    Raises:
        TSQLError: if *how* is invalid, if *name* was already joined,
            or if there are no shared key fields to join on
    """
    if how not in ('inner', 'left'):
        raise TSQLError("only 'inner' and 'left' join methods are allowed")
    if name in selection.joined:
        raise TSQLError('cannot join the same relation twice')
    all_fields = db.schema[name]
    field_index = tsdb.make_field_index(all_fields)
    fields = [all_fields[field_index[col]] for col in columns]
    data: List[tsdb.Record] = []
    if not selection.joined:
        _merge_fields(selection, name, [], fields)
        data.extend(db._select_raw(name, columns))
    else:
        on = [f.name for f in fields
              if f.is_key and f.name in selection._field_index]
        if not on:
            raise TSQLError('no shared keys for joining')
        # shared keys are already in the selection; only add the rest
        fields = [f for f in fields if f.name not in on]
        cols = [f.name for f in fields]
        # index rows of the right relation by their cast key values
        right: Dict[Tuple[tsdb.Value, ...], List[tsdb.Record]] = {}
        for keys, row in zip(db.select_from(name, on, cast=True),
                             db._select_raw(name, cols)):
            right.setdefault(tuple(keys), []).append(tuple(row))
        rfill = (None,) * len(fields)
        for keys, lrow in zip(selection.select(*on, cast=True), selection):
            keys = tuple(keys)
            if how == 'left' or keys in right:
                for rrow in right.get(keys, [rfill]):
                    data.append(tuple(lrow) + tuple(rrow))
        _merge_fields(selection, name, on, fields)
    selection.data = data
def _merge_fields(selection: Selection,
                  relationname: str,
                  on: _Names,
                  fields: tsdb.Fields) -> None:
    """
    Append *fields* of *relationname* to *selection* and index them.

    Each new field is indexed by its qualified name and, if not already
    present, its plain name. The *on* (join key) fields are already in
    the selection, so only their qualified names are added, pointing at
    the existing positions; this keeps lookups working regardless of
    join order. *relationname* is then recorded as joined.
    """
    index = selection._field_index
    start = len(selection.fields)
    for pos, fld in enumerate(fields, start):
        selection.fields.append(fld)
        index.setdefault(fld.name, pos)
        index[f'{relationname}.{fld.name}'] = pos
    for keyname in on:
        index[f'{relationname}.{keyname}'] = index[keyname]
    selection.joined.add(relationname)
# QUERY PARSING ###############################################################
# Regular-expression fragments for date and identifier literals, combined
# below into the YYYYMMDD, DDMMYY, QID, and ID token patterns.
# four-digit year (for YYYY-MM-DD dates)
_year = r'[0-9]{4}'
# two- or four-digit year (for DD-MM-YY dates)
_yr = r'(?:[0-9]{2})?[0-9]{2}'
# month as one or two digits or a lowercase three-letter abbreviation
_month = r'(?:[0-9][0-9]?|jan|feb|mar|apr|may|jun|jul|aug|sep|oct|nov|dec)'
_day = r'[0-9]{1,2}'
# optional time of day: either "(HH:MM[:SS])" or " HH:MM:SS".
# NOTE(review): the parenthesized form makes seconds optional but the
# space-separated form requires them; confirm this asymmetry is intended.
_time = (r'\s*\({t}:{t}(?::{t})?\)'
r'|\s+{t}:{t}(?::{t})').format(t=r'[0-9]{2}')
# e.g., 2018-11-12 or 2018-nov-12 (13:41)
_yyyy_mm_dd = r'{year}-{month}(?:-{day})?(?:{time})?'.format(
year=_year, month=_month, day=_day, time=_time)
# e.g., 12-nov-2018 or nov-18
_dd_mm_yy = r'(?:{day}-)?{month}-{year}(?:{time})?'.format(
year=_yr, month=_month, day=_day, time=_time)
# simple identifier (relation or column name); hyphens are allowed
_id = r'[a-zA-Z][-_a-zA-Z0-9]*'
_qid = r'{id}\.{id}'.format(id=_id) # qualified id: "table.column"
# Lexer for TSQL query bodies. Each token is a (pattern, 'NAME:description')
# pair; NAME becomes an attribute of `_TSQLLexer.tokentypes` (see the
# aliases below), and the description presumably appears in syntax-error
# messages -- confirm against util.Lexer. Lexing errors use TSQLSyntaxError.
# NOTE(review): alternatives appear to rely on list order (e.g. '<=' before
# '<', QID before ID). If util.Lexer takes the first matching alternative,
# the keyword patterns ('from', 'where', 'and', 'or', 'not', 'now', ...)
# have no trailing boundary and could split an identifier that begins with
# one of them (e.g. a column named 'order-x') -- confirm.
_TSQLLexer = util.Lexer(
tokens=[
(r'from', 'FROM:from'),
(r'where', 'WHERE:where'),
(r'report', 'REPORT:report'),
(r'\*', 'STAR:*'),
# also used as the end-of-query sentinel appended by _parse_select()
(r'\.', 'DOT:.'),
# listed before NOT so that '!=' and '!~' are not lexed as '!'
(r'==|=|!=|~|!~|<=|<|>=|>', 'OP:a comparison operator'),
(r'&&|&|and', "AND:'&&', '&', or 'and'"),
(r'\|\||\||or', "OR:'||', '|', or 'or'"),
(r'!|not', "NOT:'!' or 'not'"),
(r'\(', 'LPAREN:('),
(r'\)', 'RPAREN:)'),
# quoted strings allow backslash-escaped characters; the capture group
# holds the text between the quotes
(r'"([^"\\]*(?:\\.[^"\\]*)*)"', 'DQSTRING:a double-quoted string'),
(r"'([^'\\]*(?:\\.[^'\\]*)*)'", 'SQSTRING:a single-quoted string'),
# dates precede INT so a date is not lexed as a leading integer
(_yyyy_mm_dd, 'YYYYMMDD:a YYYY-MM-DD date'),
(_dd_mm_yy, 'DDMMYY: a DD-MM-YY date'),
(r':today|now', "KWDATE:'now' or ':today'"),
(r'[+-]?\d+', 'INT:an integer'),
(_qid, 'QID:a qualified identifier'),
(_id, 'ID:a simple identifier'),
# any other non-space character is reported as unexpected
(r'[^\s]', 'UNEXPECTED')
],
error_class=TSQLSyntaxError)
# Module-level aliases for the token types, used by the parser functions.
_FROM = _TSQLLexer.tokentypes.FROM
_WHERE = _TSQLLexer.tokentypes.WHERE
_REPORT = _TSQLLexer.tokentypes.REPORT
_STAR = _TSQLLexer.tokentypes.STAR
_DOT = _TSQLLexer.tokentypes.DOT
_OP = _TSQLLexer.tokentypes.OP
_AND = _TSQLLexer.tokentypes.AND
_OR = _TSQLLexer.tokentypes.OR
_NOT = _TSQLLexer.tokentypes.NOT
_LPAREN = _TSQLLexer.tokentypes.LPAREN
_RPAREN = _TSQLLexer.tokentypes.RPAREN
_DQSTRING = _TSQLLexer.tokentypes.DQSTRING
_SQSTRING = _TSQLLexer.tokentypes.SQSTRING
_YYYYMMDD = _TSQLLexer.tokentypes.YYYYMMDD
_DDMMYY = _TSQLLexer.tokentypes.DDMMYY
_KWDATE = _TSQLLexer.tokentypes.KWDATE
_INT = _TSQLLexer.tokentypes.INT
_QID = _TSQLLexer.tokentypes.QID
_ID = _TSQLLexer.tokentypes.ID
_UNEXPECTED = _TSQLLexer.tokentypes.UNEXPECTED
def _parse_query(querystring: str) -> dict:
    """
    Parse a complete TSQL query string into a query dictionary.

    The first whitespace-delimited word gives the query type
    (case-insensitive). Only 'select' and its synonym 'retrieve' are
    supported.

    Raises:
        TSQLSyntaxError: for unsupported query types or malformed queries
    """
    # split on any whitespace so that, e.g., 'select\ni-id ...' works,
    # not only queries whose keyword is followed by a single space
    parts = querystring.split(None, 1)
    querytype = parts[0].lower() if parts else ''
    querybody = parts[1] if len(parts) > 1 else ''
    if querytype not in ('select', 'retrieve'):
        raise TSQLSyntaxError(f"'{querytype}' queries are not supported",
                              lineno=1)
    return _parse_select(querybody)
def _parse_select(querystring: str) -> dict:
    """
    Parse the body of a select query (everything after 'select').

    Returns:
        a dictionary with keys 'type' (always 'select'),
        'projection', 'relations', and 'condition'
    Raises:
        TSQLSyntaxError: if the query is malformed (lexing errors use
            the lexer's error class) or uses 'select *' without a
            'from' clause
    """
    # append a '.' sentinel so the end of the query is an explicit token;
    # *querystring* itself stays unmodified for use in error messages
    lexer = _TSQLLexer.lex((querystring + '.').splitlines())
    projection = _parse_select_projection(lexer)
    relations = _parse_select_from(lexer)
    condition = _parse_select_where(lexer)
    lexer.expect_type(_DOT)
    if projection == ['*'] and not relations:
        raise TSQLSyntaxError(
            "'select *' requires a 'from' clause",
            text=querystring)
    return {'type': 'select',
            'projection': projection,
            'relations': relations,
            'condition': condition}
def _parse_select_projection(lexer: util.LookaheadLexer) -> List[str]:
    """
    Parse the projection: either '*' or one or more column names.

    Column names may be plain or qualified and are returned as written.
    """
    typ, col_id = lexer.choice_type(_STAR, _QID, _ID)
    if typ not in (_QID, _ID):
        # '*' selects everything; no further columns are read
        return [col_id]
    columns = []
    while col_id:
        columns.append(col_id)
        col_id = lexer.accept_type(_QID) or lexer.accept_type(_ID)
    return columns
def _parse_select_from(lexer: util.LookaheadLexer) -> List[str]:
    """
    Parse an optional 'from' clause and return its relation names.

    Returns an empty list when there is no 'from' clause; when present,
    at least one relation name is required.
    """
    relations: List[str] = []
    if not lexer.accept_type(_FROM):
        return relations
    relname = lexer.expect_type(_ID)
    while relname:
        relations.append(relname)
        relname = lexer.accept_type(_ID)
    return relations
def _parse_select_where(
        lexer: util.LookaheadLexer) -> Optional[_Condition]:
    """
    Parse zero or more 'where' clauses.

    Returns `None` if there are none, the single condition if there is
    one, and otherwise ``('and', [conditions])`` combining them.
    """
    clauses: List[_Condition] = []
    while lexer.accept_type(_WHERE):
        clauses.append(_parse_condition_disjunction(lexer))
    if not clauses:
        return None
    if len(clauses) == 1:
        return clauses[0]
    return ('and', clauses)
def _parse_condition_disjunction(
        lexer: util.LookaheadLexer) -> _Condition:
    """
    Parse one or more conjunctions separated by OR.

    Returns the single conjunction when there is no OR, otherwise
    ``('or', [conjunctions])``.
    """
    # at least one conjunction is always parsed, so the list is never
    # empty (the former empty-list check was unreachable)
    conds = [_parse_condition_conjunction(lexer)]
    while lexer.accept_type(_OR):
        conds.append(_parse_condition_conjunction(lexer))
    if len(conds) == 1:
        return conds[0]
    return ('or', conds)
def _parse_condition_conjunction(
        lexer: util.LookaheadLexer) -> _Condition:
    """
    Parse one or more terms separated by AND.

    Returns the single term when there is no AND, otherwise
    ``('and', [terms])``.
    """
    conds: List[_Condition] = [_parse_condition_term(lexer)]
    while lexer.accept_type(_AND):
        conds.append(_parse_condition_term(lexer))
    if len(conds) == 1:
        return conds[0]
    return ('and', conds)


def _parse_condition_term(
        lexer: util.LookaheadLexer) -> _Condition:
    """
    Parse a negated term, a parenthesized disjunction, or a comparison.

    NOT applies only to the term immediately after it, so it binds more
    tightly than AND and OR (like C's '!', '&&', '||'). For example,
    ``not a = 1 and b = 2`` means ``(not a = 1) and b = 2``.
    """
    typ, token = lexer.choice_type(_NOT, _LPAREN, _QID, _ID)
    if typ == _NOT:
        return ('not', _parse_condition_term(lexer))
    elif typ == _LPAREN:
        cond = _parse_condition_disjunction(lexer)
        lexer.expect_type(_RPAREN)
        return cond
    else:
        return _parse_condition_statement(token, lexer)
def _parse_condition_statement(
        column: str,
        lexer: util.LookaheadLexer) -> _Comparison:
    """
    Parse the operator and value of a comparison on *column*.

    '=' is accepted as a synonym for '=='. The regex operators
    '~'/'!~' take a quoted string; '<', '<=', '>', '>=' take an
    integer or a date; '==' and '!=' take any of these. Integers are
    converted to `int` and dates with :func:`tsdb.cast`.

    Returns:
        ``(operator, (column, value))``
    """
    op = lexer.expect_type(_OP)
    if op == '=':
        op = '=='
    if op in ('~', '!~'):
        allowed = (_DQSTRING, _SQSTRING)
    elif op in ('<', '<=', '>', '>='):
        allowed = (_INT, _YYYYMMDD, _DDMMYY, _KWDATE)
    else:  # '==' or '!='
        allowed = (_INT, _DQSTRING, _SQSTRING,
                   _YYYYMMDD, _DDMMYY, _KWDATE)
    typ, value = lexer.choice_type(*allowed)
    if typ == _INT:
        value = int(value)
    elif typ in (_YYYYMMDD, _DDMMYY, _KWDATE):
        value = tsdb.cast(':date', value)
    return (op, (column, value))