"""
TSQL -- Test Suite Query Language
"""
import operator
import re
from collections.abc import Callable, Iterator
from datetime import datetime
from typing import (
Any,
cast as typing_cast,
)
from delphin import tsdb, util
# Default modules need to import the PyDelphin version
from delphin.__about__ import __version__ # noqa: F401
from delphin.exceptions import PyDelphinException, PyDelphinSyntaxError
# CUSTOM EXCEPTIONS ###########################################################
[docs]
class TSQLError(PyDelphinException):
    """
    Raised on invalid TSQL operations.

    This is used for problems found while planning or executing a query,
    such as undefined columns or relations that cannot be joined.
    """
[docs]
class TSQLSyntaxError(PyDelphinSyntaxError):
    """
    Raised when encountering an invalid TSQL query.

    The TSQL lexer uses this as its error class, and it is also raised
    for query types other than 'select' / 'retrieve'.
    """
# LOCAL TYPES #################################################################
# Column names (qualified like "item.i-id" or unqualified like "i-id").
_Names = list[str]
# A single comparison: (operator, (column, value)), e.g. ('==', ('i-id', 11)).
_Comparison = tuple[str, tuple[str, tsdb.Value]]
# A boolean combination: ('and' | 'or', [condition, ...]) or ('not', condition).
# _Boolean should be recursive but use Any until the type checker
# supports recursive types.
# see: https://github.com/python/mypy/issues/731
_Boolean = tuple[str, Any]
_Condition = _Comparison | _Boolean
# Predicate applied to each record when filtering a selection.
_FilterFunction = Callable[[tsdb.Record], bool]
# Maps a (possibly unqualified) column name to its qualified name and Field.
_QNameResolver = Callable[[str], tuple[str, tsdb.Field]]
class _Record(tsdb.Record):
    """
    Minimal record type that accepts the same arguments as itsdb.Row.

    Instantiating it yields a plain tuple of *data*; *fields* and
    *field_index* are accepted only for call-signature compatibility.
    """

    def __new__(
        cls,
        fields: tsdb.Fields,
        data: tsdb.Record,
        field_index: tsdb.FieldIndex | None = None,
    ):
        # the schema arguments are ignored; only the values are kept
        return (*data,)
class Selection(tsdb.Records):
    """
    The results of a 'select' query.

    Rows are stored in :attr:`data` with column metadata in
    :attr:`fields`. Iterating a selection yields records restricted to
    :attr:`projection` (or all columns when no projection is set), each
    built with :attr:`record_class`.
    """

    def __init__(self, record_class: type[_Record] | None = None) -> None:
        self.fields: list[tsdb.Field] = []
        # maps both bare and "relation.column" names to column positions
        self._field_index: tsdb.FieldIndex = {}
        self.data: tsdb.Records = []
        self.projection: list[str] | None = None
        self.record_class = _Record if record_class is None else record_class
        # names of relations already joined into this selection
        self.joined: set[str] = set()

    def __iter__(self) -> Iterator[tsdb.Record]:
        # no projection (or an empty one) means every column
        return self.select(*(self.projection or ()))

    def select(self, *names: str, cast: bool = False) -> Iterator[tsdb.Record]:
        """
        Yield records holding the columns *names* (all columns if none).

        Args:
            names: column names to include, in the order given
            cast: if `True`, convert values with :func:`tsdb.cast` for
                each record whose selected values are all strings or `None`
        """
        if names:
            positions = [self._field_index[name] for name in names]
        else:
            positions = list(range(len(self.fields)))
        chosen = [self.fields[pos] for pos in positions]
        chosen_index = tsdb.make_field_index(chosen)
        make_record = self.record_class
        for item in self.data:
            raw = getattr(item, "data", item)  # unwrap Row-like objects
            values = tuple(raw[pos] for pos in positions)
            if cast and all(v is None or isinstance(v, str) for v in values):
                strings = typing_cast(tuple[str | None, ...], values)
                values = tuple(
                    tsdb.cast(field.datatype, v)
                    for field, v in zip(chosen, strings, strict=True)
                )
            yield make_record(chosen, values, field_index=chosen_index)
# QUERY INSPECTION ############################################################
[docs]
def inspect_query(querystring: str) -> dict:
    """
    Parse *querystring* and return the interpreted query dictionary.

    The query is only parsed, not executed, so no database is needed.

    Example:
        >>> from delphin import tsql
        >>> from pprint import pprint
        >>> pprint(tsql.inspect_query("select i-input from item where i-id < 100"))
        {'type': 'select',
         'projection': ['i-input'],
         'relations': ['item'],
         'condition': ('<', ('i-id', 100))}
    """
    parsed = _parse_query(querystring)
    return parsed
# QUERY PROCESSING ############################################################
[docs]
def query(querystring: str, db: tsdb.Database, **kwargs):
    """
    Perform query *querystring* on the database *db*.

    Note: currently only 'select' (or 'retrieve') queries are supported.

    Args:
        querystring (str): TSQL query string, including the query type
        db (:class:`delphin.tsdb.Database`): database to query over
        kwargs: keyword arguments passed to the more specific query
            function (e.g., :func:`select`); currently only
            *record_class* is used and other keywords are ignored
    Returns:
        :class:`Selection`: the query results
    Raises:
        TSQLSyntaxError: if the query cannot be parsed or its type is
            not supported
        TSQLError: if the query cannot be executed on *db* (e.g., it
            uses an undefined column)
    Example:
        >>> list(tsql.query("select i-id where i-length < 4", db))
        [[142], [1061]]
    """
    queryobj = _parse_query(querystring)
    if queryobj["type"] in ("select", "retrieve"):
        return _select(
            queryobj["projection"],
            queryobj["relations"],
            queryobj["condition"],
            db,
            record_class=kwargs.get("record_class", None),
        )
    else:
        # not really a syntax error; replace with TSQLError or something
        # when the proper exception class exists
        raise TSQLSyntaxError(
            queryobj["type"] + " queries are not supported", text=querystring
        )
[docs]
def select(
    querystring: str, db: tsdb.Database, record_class: type[_Record] | None = None
) -> Selection:
    """
    Perform the TSQL selection query *querystring* on the database *db*.

    Note: The `select`/`retrieve` part of the query is not included.

    Args:
        querystring: TSQL select query (without the leading `select`)
        db: TSDB database to query over
        record_class: class used to build each record when iterating
            the results; it is called as
            ``record_class(fields, data, field_index=...)``. If `None`,
            records are plain tuples.
    Returns:
        The :class:`Selection` of matching rows
    Raises:
        TSQLSyntaxError: if the query cannot be parsed
        TSQLError: if the query cannot be executed on *db*
    Example:
        >>> list(tsql.select("i-id where i-length < 4", db))
        [[142], [1061]]
    """
    queryobj = _parse_select(querystring)
    return _select(
        queryobj["projection"],
        queryobj["relations"],
        queryobj["condition"],
        db,
        record_class=record_class,
    )
def _select(
    projection: _Names,
    relations: list[str],
    condition: _Condition | None,
    db: tsdb.Database,
    record_class: type[_Record] | None,
) -> Selection:
    """
    Execute a parsed select query against *db*.

    The planned relations are inner-joined into a new :class:`Selection`,
    rows failing the (resolved) condition are dropped, and the resolved
    projection is recorded on the selection.
    """
    columns, join_plan, resolved = _make_execution_plan(
        projection, relations, condition, db
    )
    result = Selection(record_class=record_class)
    for relname, relcols in join_plan:
        _join(result, db, relname, relcols, "inner")
    if resolved:
        predicate = _process_condition_function(resolved, result)
        result.data = [record for record in result.data if predicate(record)]
    result.projection = columns
    return result
def _make_execution_plan(
    projection: _Names,
    relations: list[str],
    condition: _Condition | None,
    db: tsdb.Database,
) -> tuple:
    """
    Make a plan for all relations to join and columns to keep.

    Returns:
        A triple of the qualified projection, the ordered list of
        ``(relation, columns)`` joins, and the condition with qualified
        column names (or `None` if there is no condition)
    """
    resolve_qname = _make_qname_resolver(db, relations)
    if projection == ["*"]:
        qualified = _project_all(relations, db)
    else:
        qualified = [resolve_qname(colname)[0] for colname in projection]

    resolved: _Condition | None = None
    used_fields: _Names = []
    if condition:
        resolved, used_fields = _process_condition_fields(condition, resolve_qname)

    join_plan = _plan_joins(qualified, used_fields, relations, db)
    return qualified, join_plan, resolved
def _project_all(relations: list[str], db: tsdb.Database) -> list[str]:
    """
    Return qualified names for every column of *relations*, in order.

    A key column whose name was already emitted for an earlier relation
    is skipped, so shared keys appear only once.
    """
    seen_keys: set[str] = set()
    qualified = []
    for relname in relations:
        for field in db.schema[relname]:
            if field.is_key:
                if field.name in seen_keys:
                    continue
                seen_keys.add(field.name)
            qualified.append(f"{relname}.{field.name}")
    return qualified
def _make_qname_resolver(db: tsdb.Database, relations: list[str]) -> _QNameResolver:
    """
    Return a function that turns column names into qualified names.

    For example, `i-input` becomes `item.i-input`. An unqualified column
    name found in several relations resolves to one listed in
    *relations* if possible; otherwise the first relation in schema
    order that defines it is used.

    The returned function gives the qualified name and the column's
    :class:`tsdb.Field`. It raises :class:`TSQLError` for undefined
    columns and relations, whether the name is qualified or not.
    """
    index = {rel: tsdb.make_field_index(fields) for rel, fields in db.schema.items()}
    schema_map: dict[str, list[str]] = {}
    for relname, fields in db.schema.items():
        for field in fields:
            schema_map.setdefault(field.name, []).append(relname)
    # prefer those appearing in specified relations; the sort is stable,
    # so schema order is kept otherwise
    for colname in schema_map:
        schema_map[colname] = sorted(
            schema_map[colname], key=relations.__contains__, reverse=True
        )

    def resolve(colname: str) -> tuple[str, tsdb.Field]:
        rel, _, col = colname.rpartition(".")
        if not rel:
            if col not in schema_map:
                raise TSQLError(f"undefined column: {colname}")
            rel = schema_map[col][0]
        # qualified names are user-supplied and may be wrong; report them
        # as TSQL errors rather than leaking a KeyError
        if rel not in index:
            raise TSQLError(f"undefined relation: {rel}")
        if col not in index[rel]:
            raise TSQLError(f"undefined column: {colname}")
        return f"{rel}.{col}", db.schema[rel][index[rel][col]]

    return resolve
def _plan_joins(projection, condition_fields, relations, db):
    """
    Calculate the relations and columns needed for the query.

    Args:
        projection: qualified (``relation.column``) names to return
        condition_fields: qualified names used by the query condition
        relations: relation names from the query's ``from`` clause
        db: database whose schema is consulted
    Returns:
        A list of ``(relation, columns)`` pairs ordered so that each
        relation after the first shares a key column with a relation
        joined before it. Every relation's key columns are included.
    Raises:
        TSQLError: if the relations cannot be connected or ordered
    """
    joinmap = {}
    added = set()
    for qname in projection + list(condition_fields):
        if qname not in added:
            rel, _, col = qname.rpartition(".")
            joinmap.setdefault(rel, []).append(col)
            added.add(qname)
    relset = set(relations).union(joinmap)

    # add necessary relations to span all requested relations
    keymap = _make_keymap(db)
    pivots = _pivot_relations(relset, keymap, db)
    relset.update(pivots)

    # always add keys; visit relations in a reproducible order (from-clause
    # order, then referenced relations, then pivots by name) because
    # iterating the set directly depends on string hash randomization and
    # the order relations enter *joinmap* decides the join (and row) order
    ordered = list(dict.fromkeys([*relations, *joinmap, *sorted(pivots)]))
    for relation in ordered:
        for field in db.schema[relation]:
            if field.is_key and f"{relation}.{field.name}" not in added:
                joinmap.setdefault(relation, []).append(field.name)

    # finally ensure joins occur in a valid order: the first relation is
    # taken as-is, each later one must share a key with those joined so far
    joined_keys = set()
    joins = []
    while joinmap:
        for rel in joinmap:
            if not joins or joined_keys.intersection(joinmap[rel]):
                break
        else:
            raise TSQLError("infinite loop detected!")
        joins.append((rel, joinmap.pop(rel)))
        joined_keys.update(keymap[rel])
    return joins
def _make_keymap(db):
keymap = {}
for rel, _fields in db.schema.items():
keys = [field.name for field in db.schema[rel] if field.is_key]
keymap[rel] = keys
return keymap
def _pivot_relations(relset, keymap, db):
    """
    Find extra relations that connect the relations in *relset*.

    Key names are graph nodes, with an edge between every pair of keys
    found in the same relation. While the keys of the relations used so
    far form more than one connected component, the first relation in
    *keymap* that is unused, has at least two keys, and touches two or
    more components is added as a pivot.

    Args:
        relset: names of relations that must be joined
        keymap: mapping of relation name to its key names
        db: unused
    Returns:
        The set of pivot relation names (empty when none are needed)
    Raises:
        TSQLError: if the components cannot all be connected
    """
    nodes = set()
    edges = []

    def connect(keys):
        # one edge per unordered pair of keys sharing a relation
        for pos, first in enumerate(keys):
            for second in keys[pos + 1:]:
                edges.append((first, second))

    for rel in relset:
        rel_keys = keymap[rel]
        nodes.update(rel_keys)
        connect(rel_keys)

    pivots = set()
    components = util._connected_components(nodes, edges)
    while len(components) > 1:
        used = relset.union(pivots)
        found = None
        for rel, keys in keymap.items():
            if rel in used or len(keys) <= 1:
                continue
            touched = sum(1 for comp in components if comp.intersection(keys))
            if touched > 1:
                found = rel
                break
        if found is None:
            raise TSQLError(
                "could not find relation to join: {}".format(", ".join(sorted(relset)))
            )
        nodes.update(keymap[found])
        connect(keymap[found])
        pivots.add(found)
        components = util._connected_components(nodes, edges)
    return pivots
_operator_functions = {
"==": operator.eq,
"!=": operator.ne,
"<": operator.lt,
"<=": operator.le,
">": operator.gt,
">=": operator.ge,
}
def _process_condition_fields(
    condition: _Condition, resolve_qname: _QNameResolver
) -> tuple[_Condition, _Names]:
    """
    Qualify the column names in *condition* and collect them.

    Conditions look like ``('==', ('i-id', 11))`` for comparisons,
    ``('and', [cond, ...])`` or ``('or', [cond, ...])`` for boolean
    combinations, and ``('not', cond)`` for negation.

    Args:
        condition: the parsed condition
        resolve_qname: function mapping a column name to its qualified
            name and :class:`tsdb.Field`
    Returns:
        The condition rewritten with qualified column names, and the
        list of qualified names it uses (sorted for and/or conditions)
    Raises:
        TSQLError: if a comparison value's type does not match its
            column's datatype (and whatever *resolve_qname* raises for
            undefined columns)
    """
    op, body = condition
    if op in ("and", "or"):
        body = typing_cast(list[_Condition], body)
        fieldset = set()
        conditions = []
        for cond in body:
            _cond, _fields = _process_condition_fields(cond, resolve_qname)
            fieldset.update(_fields)
            conditions.append(_cond)
        return (op, conditions), sorted(fieldset)
    elif op == "not":
        ncond, fields = _process_condition_fields(body, resolve_qname)
        return ("not", ncond), fields

    qname, field = resolve_qname(body[0])
    value = body[1]
    # check if the condition's type matches the column; _expected_type may
    # give None for a datatype it does not recognize, in which case there
    # is nothing to check against
    expected = _expected_type(field.datatype)
    if expected is not None and not isinstance(value, expected):
        if isinstance(expected, tuple):
            # e.g., (int, float) for :float columns; tuples have no __name__
            expected_name = " or ".join(t.__name__ for t in expected)
        else:
            expected_name = expected.__name__
        raise TSQLError(
            f"type mismatch in condition on {qname}: "
            f"{expected_name} {op} {type(value).__name__}"
        )
    return (op, (qname, value)), [qname]
def _expected_type(datatype):
if datatype == ":string":
return str
elif datatype in ":integer":
return int
elif datatype in ":float":
return (int, float)
elif datatype in ":date":
return datetime
def _process_condition_function(
    condition: _Condition, selection: Selection
) -> _FilterFunction:
    """
    Build a predicate over the rows of *selection* for *condition*.

    Column names in *condition* must already be the qualified names
    present in *selection* (as produced by
    :func:`_process_condition_fields`). Boolean conditions (`and`, `or`,
    `not`) are built recursively. Single comparisons are built by
    :func:`_make_comparison_function`.
    """
    # conditions are something like: ('==', ('i-id', 11))
    op, body = condition
    if op in ("and", "or"):
        body = typing_cast(list[_Condition], body)
        subfuncs = [_process_condition_function(cond, selection) for cond in body]
        combine = all if op == "and" else any

        def combined(row):
            return combine(subfunc(row) for subfunc in subfuncs)

        return combined
    if op == "not":
        negated = _process_condition_function(body, selection)

        def negation(row):
            return not negated(row)

        return negation
    return _make_comparison_function(op, body, selection)


def _make_comparison_function(
    op: str, body: tuple[str, tsdb.Value], selection: Selection
) -> _FilterFunction:
    """
    Return a predicate for one comparison ``op`` on ``body = (qname, value)``.

    Values are cast from the row with the column's datatype. `None`
    values never match, except under ``!~``, where they always match.
    """
    qname, target = body
    # the column position and datatype do not change while filtering, so
    # look them up once instead of for every row
    index = selection._field_index[qname]
    datatype = selection.fields[index].datatype

    if op in ("~", "!~"):
        pattern = re.compile(typing_cast(str, target))
        if op == "~":

            def matches(row):
                value = tsdb.cast(datatype, row[index])
                return value is not None and pattern.search(value) is not None

            return matches

        def excludes(row):
            value = tsdb.cast(datatype, row[index])
            return value is None or pattern.search(value) is None

        return excludes

    compare = _operator_functions[op]

    def compares(row):
        value = tsdb.cast(datatype, row[index])
        return value is not None and compare(value, target)

    return compares
# RELATION JOINS ##############################################################
def _join(
    selection: Selection,
    db: tsdb.Database,
    name: str,
    columns: _Names,
    how: str = "inner",
) -> None:
    """
    Join *columns* of relation *name* in *db* into *selection*.

    The first relation joined into a selection provides its initial
    rows. Every later relation is matched on the key columns it shares
    with the selection. If *how* is `"inner"`, then only matched rows
    persist after the join; if *how* is `"left"`, all existing rows are
    kept and those without a match are padded with `None` values.

    Raises:
        TSQLError: if *how* is invalid, if *name* was already joined, or
            if the relation shares no key columns with the selection
    """
    if how not in ("inner", "left"):
        raise TSQLError("only 'inner' and 'left' join methods are allowed")
    if name in selection.joined:
        raise TSQLError("cannot join the same relation twice")
    all_fields = db.schema[name]
    field_index = tsdb.make_field_index(all_fields)
    fields = [all_fields[field_index[col]] for col in columns]

    if not selection.joined:
        _merge_fields(selection, name, [], fields)
        selection.data = list(db._select_raw(name, columns))
        return

    # match on the key columns the selection already has
    on = [f.name for f in fields if f.is_key and f.name in selection._field_index]
    if not on:
        raise TSQLError("no shared keys for joining")
    fields = [f for f in fields if f.name not in on]
    cols = [f.name for f in fields]

    # index the right-hand rows by their (cast) key values
    right: dict[tuple[tsdb.Value, ...], list[tsdb.Record]] = {}
    for keys, row in zip(
        db.select_from(name, on, cast=True), db._select_raw(name, cols), strict=True
    ):
        right.setdefault(tuple(keys), []).append(tuple(row))

    padding = [(None,) * len(fields)]
    data: list[tsdb.Record] = []
    for keys, lrow in zip(selection.select(*on, cast=True), selection, strict=True):
        matches = right.get(tuple(keys))
        if matches is None:
            if how == "inner":
                continue
            matches = padding
        left = tuple(lrow)
        data.extend(left + rrow for rrow in matches)

    # update the field metadata only after reading the old rows, which
    # relies on the selection's current fields
    _merge_fields(selection, name, on, fields)
    selection.data = data
def _merge_fields(
    selection: Selection, relationname: str, on: _Names, fields: tsdb.Fields
) -> None:
    """
    Append *fields* of *relationname* to *selection*'s column metadata.

    Each new column is indexed by its qualified name and, if not already
    taken, by its bare name. The join columns *on* also get qualified
    names for this relation, pointing at their existing positions. The
    relation is then recorded as joined.
    """
    index = selection._field_index
    start = len(selection.fields)
    for pos, field in enumerate(fields, start):
        selection.fields.append(field)
        index.setdefault(field.name, pos)
        index[f"{relationname}.{field.name}"] = pos
    # also add qualified names for 'on' fields in case the joins
    # happen in a strange order
    for keyname in on:
        index[f"{relationname}.{keyname}"] = index[keyname]
    selection.joined.add(relationname)
# QUERY PARSING ###############################################################
_year = r"[0-9]{4}"
_yr = r"(?:[0-9]{2})?[0-9]{2}"
_month = r"(?:[0-9][0-9]?|jan|feb|mar|apr|may|jun|jul|aug|sep|oct|nov|dec)"
_day = r"[0-9]{1,2}"
_time = (
r"\s*\({t}:{t}(?::{t})?\)"
r"|\s+{t}:{t}(?::{t})"
).format(t=r"[0-9]{2}")
_yyyy_mm_dd = rf"{_year}-{_month}(?:-{_day})?(?:{_time})?"
_dd_mm_yy = rf"(?:{_day}-)?{_month}-{_yr}(?:{_time})?"
_id = r"[a-zA-Z][-_a-zA-Z0-9]*"
_qid = rf"{_id}\.{_id}" # qualified id: "table.column"
# Lexer for TSQL select queries. Each entry pairs a regular expression with
# "TYPENAME:description"; the type names are read back below as attributes
# of ``_TSQLLexer.tokentypes``, and the descriptions presumably appear in
# syntax-error messages -- TODO confirm in util.Lexer.
# NOTE(review): entries appear to be tried in order (dates before INT, QID
# before ID, the catch-all UNEXPECTED last), so keep this order.
# NOTE(review): there is no float literal token, so a value like 0.5
# cannot be written in a condition.
_TSQLLexer = util.Lexer(
    tokens=[
        (r"from", "FROM:from"),
        (r"where", "WHERE:where"),
        (r"report", "REPORT:report"),
        (r"\*", "STAR:*"),
        # also used as the end-of-query sentinel appended by _parse_select
        (r"\.", "DOT:."),
        (r"==|=|!=|~|!~|<=|<|>=|>", "OP:a comparison operator"),
        (r"&&|&|and", "AND:'&&', '&', or 'and'"),
        (r"\|\||\||or", "OR:'||', '|', or 'or'"),
        (r"!|not", "NOT:'!' or 'not'"),
        (r"\(", "LPAREN:("),
        (r"\)", "RPAREN:)"),
        # quoted strings allowing backslash-escaped characters
        (r'"([^"\\]*(?:\\.[^"\\]*)*)"', "DQSTRING:a double-quoted string"),
        (r"'([^'\\]*(?:\\.[^'\\]*)*)'", "SQSTRING:a single-quoted string"),
        (_yyyy_mm_dd, "YYYYMMDD:a YYYY-MM-DD date"),
        # NOTE(review): unlike the others, this description starts with a
        # space after the colon -- confirm whether util.Lexer strips it
        (_dd_mm_yy, "DDMMYY: a DD-MM-YY date"),
        (r":today|now", "KWDATE:'now' or ':today'"),
        (r"[+-]?\d+", "INT:an integer"),
        (_qid, "QID:a qualified identifier"),
        (_id, "ID:a simple identifier"),
        # any other non-whitespace character
        (r"[^\s]", "UNEXPECTED"),
    ],
    error_class=TSQLSyntaxError,
)
# Module-level aliases for the lexer's token types, used by the parser.
_tokentypes = _TSQLLexer.tokentypes
_FROM = _tokentypes.FROM
_WHERE = _tokentypes.WHERE
_REPORT = _tokentypes.REPORT
_STAR = _tokentypes.STAR
_DOT = _tokentypes.DOT
_OP = _tokentypes.OP
_AND = _tokentypes.AND
_OR = _tokentypes.OR
_NOT = _tokentypes.NOT
_LPAREN = _tokentypes.LPAREN
_RPAREN = _tokentypes.RPAREN
_DQSTRING = _tokentypes.DQSTRING
_SQSTRING = _tokentypes.SQSTRING
_YYYYMMDD = _tokentypes.YYYYMMDD
_DDMMYY = _tokentypes.DDMMYY
_KWDATE = _tokentypes.KWDATE
_INT = _tokentypes.INT
_QID = _tokentypes.QID
_ID = _tokentypes.ID
_UNEXPECTED = _tokentypes.UNEXPECTED
del _tokentypes  # temporary alias only; keep the module namespace unchanged
def _parse_query(querystring: str) -> dict:
    """
    Parse a complete TSQL query, including its leading query type.

    The query type (case-insensitive) may be separated from the body by
    any whitespace, including a newline, since query bodies may span
    several lines.

    Raises:
        TSQLSyntaxError: if the query type is not 'select' or 'retrieve',
            or if the body is invalid
    """
    # split on any whitespace rather than a single space so that e.g.
    # "select\ni-id from item" is recognized as a select query
    parts = querystring.split(None, 1)
    querytype = parts[0].lower() if parts else ""
    querybody = parts[1] if len(parts) > 1 else ""
    if querytype not in ("select", "retrieve"):
        raise TSQLSyntaxError(f"'{querytype}' queries are not supported", lineno=1)
    return _parse_select(querybody)
def _parse_select(querystring: str) -> dict:
    """
    Parse the body of a select query (everything after `select`).

    Returns a dictionary with the keys `type` (always `"select"`),
    `projection`, `relations`, and `condition`.

    Raises:
        TSQLSyntaxError: on invalid syntax, or for `select *` without a
            `from` clause
    """
    # a trailing "." is a sentinel marking the end of the query
    text = querystring + "."
    lexer = _TSQLLexer.lex(text.splitlines())
    projection = _parse_select_projection(lexer)
    relations = _parse_select_from(lexer)
    condition = _parse_select_where(lexer)
    # NOTE(review): nothing after this DOT is checked, so a stray "." in
    # the query (e.g. in "0.5", which presumably lexes as INT DOT INT)
    # ends parsing early and the rest is silently ignored -- confirm.
    lexer.expect_type(_DOT)
    if not relations and projection == ["*"]:
        raise TSQLSyntaxError("'select *' requires a 'from' clause", text=text)
    return dict(
        type="select",
        projection=projection,
        relations=relations,
        condition=condition,
    )
def _parse_select_projection(lexer: util.LookaheadLexer) -> list[str]:
    """Parse either `*` or one or more (qualified) column names."""
    typ, first = lexer.choice_type(_STAR, _QID, _ID)
    if typ not in (_QID, _ID):
        return [first]
    columns = []
    current = first
    while current:
        columns.append(current)
        current = lexer.accept_type(_QID) or lexer.accept_type(_ID)
    return columns
def _parse_select_from(lexer: util.LookaheadLexer) -> list[str]:
    """Parse an optional `from` clause; return its relation names (or [])."""
    names: list[str] = []
    if not lexer.accept_type(_FROM):
        return names
    name = lexer.expect_type(_ID)  # at least one relation is required
    while name:
        names.append(name)
        name = lexer.accept_type(_ID)
    return names
def _parse_select_where(lexer: util.LookaheadLexer) -> _Condition | None:
    """
    Parse zero or more `where` clauses.

    Returns `None` if there are none, the lone condition if there is one,
    or an `and` of all of them otherwise.
    """
    clauses: list[_Condition] = []
    while lexer.accept_type(_WHERE):
        clauses.append(_parse_condition_disjunction(lexer))
    if not clauses:
        return None
    if len(clauses) == 1:
        return clauses[0]
    return ("and", clauses)
def _parse_condition_disjunction(lexer: util.LookaheadLexer) -> _Condition:
    """Parse conjunctions separated by `or`; a single one is returned as-is."""
    disjuncts = [_parse_condition_conjunction(lexer)]
    while lexer.accept_type(_OR):
        disjuncts.append(_parse_condition_conjunction(lexer))
    if len(disjuncts) == 1:
        return disjuncts[0]
    return ("or", disjuncts)
def _parse_condition_conjunction(lexer: util.LookaheadLexer) -> _Condition:
    """
    Parse terms separated by `and`; a single term is returned as-is.

    A term is a negation, a parenthesized condition, or a comparison.
    """
    conjuncts: list[_Condition] = []
    while True:
        typ, token = lexer.choice_type(_NOT, _LPAREN, _QID, _ID)
        if typ == _NOT:
            # NOTE(review): the negated part is a full disjunction, so
            # "not a = 1 and b = 2" parses as not(a = 1 and b = 2) --
            # confirm this loose binding of `not` is intended.
            conjuncts.append(("not", _parse_condition_disjunction(lexer)))
        elif typ == _LPAREN:
            conjuncts.append(_parse_condition_disjunction(lexer))
            lexer.expect_type(_RPAREN)
        elif typ in (_QID, _ID):
            conjuncts.append(_parse_condition_statement(token, lexer))
        if not lexer.accept_type(_AND):
            break
    if not conjuncts:
        raise TSQLSyntaxError("invalid query")
    if len(conjuncts) == 1:
        return conjuncts[0]
    return ("and", conjuncts)
def _parse_condition_statement(column: str, lexer: util.LookaheadLexer) -> _Comparison:
    """
    Parse the operator and value of a comparison on *column*.

    `=` is normalized to `==`. Regex operators (`~`, `!~`) take a quoted
    string. Ordering operators take an integer or a date. `==` and `!=`
    take any of these. Integers become `int`, and dates are cast with
    ``tsdb.cast(":date", ...)``.
    """
    op = lexer.expect_type(_OP)
    op = "==" if op == "=" else op
    if op in ("~", "!~"):
        allowed = (_DQSTRING, _SQSTRING)
    elif op in ("<", "<=", ">", ">="):
        allowed = (_INT, _YYYYMMDD, _DDMMYY, _KWDATE)
    else:  # == or !=
        allowed = (_INT, _DQSTRING, _SQSTRING, _YYYYMMDD, _DDMMYY, _KWDATE)
    typ, value = lexer.choice_type(*allowed)
    if typ == _INT:
        value = int(value)
    elif typ in (_YYYYMMDD, _DDMMYY, _KWDATE):
        value = tsdb.cast(":date", value)
    return (op, (column, value))