"""
Elementary Dependency Matching
"""
__all__ = ["compute"]
import logging
from collections import Counter
from collections.abc import Iterable
from itertools import zip_longest
from typing import Any, NamedTuple, TypeVar
# Default modules need to import the PyDelphin version
from delphin.__about__ import __version__ # noqa: F401
from delphin.dmrs import DMRS
from delphin.eds import EDS
from delphin.sembase import Predication
# Configure the root logger at import time with an ERROR threshold
# (``basicConfig`` does nothing if the root logger already has handlers).
logging.basicConfig(level=logging.ERROR)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type variable constrained to the two representations handled here.
SR = TypeVar("SR", EDS, DMRS)
# Character span of a node as (cfrom, cto).
_Span = tuple[int, int]
# A scored triple: (node span, label, value).  The label is a predicate,
# role, property name, or "carg"; the value depends on the triple kind.
_Triple = tuple[_Span, str, Any]
class _Count(NamedTuple):
gold: int
test: int
both: int
def add(self, other: "_Count") -> "_Count":
return _Count(
self.gold + other.gold,
self.test + other.test,
self.both + other.both,
)
class _Match(NamedTuple):
    """Per-category triple counts for one or more gold/test pairs."""

    name: _Count  # predicate (name) triples
    argument: _Count  # argument (role) triples
    property: _Count  # node property triples
    constant: _Count  # constant (carg) triples
    top: _Count  # top node

    def add(self, other: "_Match") -> "_Match":
        """Return a new match summing each category with that of *other*."""
        return _Match(
            name=self.name.add(other.name),
            argument=self.argument.add(other.argument),
            property=self.property.add(other.property),
            constant=self.constant.add(other.constant),
            top=self.top.add(other.top),
        )
class _Score(NamedTuple):
precision: float
recall: float
fscore: float
def _span(node: Predication) -> _Span:
    """Return the ``(cfrom, cto)`` character span of *node*."""
    start, end = node.cfrom, node.cto
    return (start, end)
def _names(sr: SR) -> list[_Triple]:
    """Collect one ``(span, predicate, None)`` triple per node of *sr*."""
    # The trailing None only pads the tuple to the common triple shape
    # so that it satisfies the type checker.
    return [(_span(node), node.predicate, None) for node in sr.nodes]
def _arguments(sr: SR) -> list[_Triple]:
    """
    Collect ``(source span, role, target span)`` triples for *sr*.

    Arguments whose target is not a node of *sr* are skipped.
    """
    outgoing = sr.arguments()
    found: list[_Triple] = []
    for node in sr.nodes:
        origin = _span(node)
        found.extend(
            (origin, role, _span(sr[target]))
            for role, target in outgoing[node.id]
            if target in sr
        )
    return found
def _properties(sr: SR) -> list[_Triple]:
    """Collect ``(span, feature, value)`` triples for every node property."""
    found: list[_Triple] = []
    for node in sr.nodes:
        where = _span(node)
        found.extend(
            (where, feature, value)
            for feature, value in node.properties.items()
        )
    return found
def _constants(sr: SR) -> list[_Triple]:
    """
    Collect ``(span, "carg", value)`` triples for *sr*.

    Only nodes with a truthy ``carg`` contribute a triple.
    """
    return [
        (_span(node), "carg", node.carg)
        for node in sr.nodes
        if node.carg
    ]
def _match(gold: SR, test: SR) -> _Match:
    """
    Count the *gold*, *test*, and shared triples in every category.

    The result holds one ``(gold, test, both)`` count per category::

        #         gold test both
        name      (gn,  tn,  bn)
        argument  (ga,  ta,  ba)
        property  (gp,  tp,  bp)
        constant  (gc,  tc,  bc)
        top       (gt,  tt,  bt)

    A top counts for a side when that side's top is one of its own
    nodes; it counts as shared when both sides have one and the two
    top nodes have equal spans.
    """
    has_gold_top = gold.top in gold
    has_test_top = test.top in test
    tops_agree = (
        has_gold_top
        and has_test_top
        and _span(gold[gold.top]) == _span(test[test.top])
    )
    top_count = _Count(int(has_gold_top), int(has_test_top), int(tops_agree))

    # Same category order as the leading fields of _Match.
    extractors = (_names, _arguments, _properties, _constants)
    category_counts = [_count(extract, gold, test) for extract in extractors]
    return _Match(*category_counts, top_count)
def _count(func, gold, test) -> _Count:
    """
    Count the triples that *func* extracts from *gold* and *test*.

    Triples are compared as multisets: a triple occurring several times
    is shared as many times as it occurs on both sides.
    """
    gold_triples = func(gold)
    test_triples = func(test)
    # Counter intersection keeps the minimum count of each common triple.
    shared = Counter(gold_triples) & Counter(test_triples)
    return _Count(len(gold_triples), len(test_triples), sum(shared.values()))
def _accumulate(
    golds: Iterable[SR | None],
    tests: Iterable[SR | None],
    ignore_missing_gold: bool,
    ignore_missing_test: bool,
) -> _Match:
    """
    Add up the per-category counts over all gold/test pairs.

    The two iterables are paired positionally; when one is shorter it
    is padded with ``None``. A pair where both items are ``None`` is
    skipped. A pair with exactly one ``None`` is skipped when the
    matching *ignore_missing_gold* / *ignore_missing_test* flag is set;
    otherwise the missing item is replaced by a fresh instance of the
    other item's type created without arguments.
    """
    verbose = logger.isEnabledFor(logging.INFO)
    zero = _Count(0, 0, 0)
    totals = _Match(zero, zero, zero, zero, zero)

    for index, (gold, test) in enumerate(zip_longest(golds, tests), 1):
        logger.info("pair %d", index)

        if gold is None and test is None:
            logger.info("no gold or test representation; skipping")
            continue

        if gold is None:
            assert test is not None
            if ignore_missing_gold:
                logger.info("no gold representation; skipping")
                continue
            logger.debug("missing gold representation")
            gold = type(test)()
        elif test is None:
            assert gold is not None
            if ignore_missing_test:
                logger.info("no test representation; skipping")
                continue
            logger.debug("missing test representation")
            test = type(gold)()

        result = _match(gold, test)

        if verbose:
            logger.info(" gold\ttest\tboth\tPrec.\tRec.\tF-Score")
            fmt = "%11s: %4d\t%4d\t%4d\t%5.3f\t%5.3f\t%5.3f"
            labelled = (
                ("Names", result.name),
                ("Arguments", result.argument),
                ("Properties", result.property),
                ("Constants", result.constant),
                ("Tops", result.top),
            )
            for label, count in labelled:
                logger.info(fmt, label, *count, *_prf(*count))

        totals = totals.add(result)

    return totals
# Public API: ``compute`` is the only name exported via ``__all__``.
def compute(
    golds: Iterable[SR | None],
    tests: Iterable[SR | None],
    name_weight: float = 1.0,
    argument_weight: float = 1.0,
    property_weight: float = 1.0,
    constant_weight: float = 1.0,
    top_weight: float = 1.0,
    ignore_missing_gold: bool = False,
    ignore_missing_test: bool = False,
) -> _Score:
    """
    Compute the precision, recall, and f-score for all pairs.

    The *golds* and *tests* arguments are iterables of PyDelphin
    dependency representations, such as EDS or DMRS. The triple counts
    of each category (names, arguments, properties, constants, tops)
    are summed over all pairs, multiplied by the category's weight, and
    added together. The scores are then computed as follows:

    - Precision = *matching_triples* / *test_triples*
    - Recall = *matching_triples* / *gold_triples*
    - F-score = 2 * (Precision * Recall) / (Precision + Recall)

    If any of the three weighted totals is zero, all three scores are
    ``0.0``.

    Arguments:
        golds: gold semantic representations
        tests: test semantic representations
        name_weight: weight applied to the name score
        argument_weight: weight applied to the argument score
        property_weight: weight applied to the property score
        constant_weight: weight applied to the constant score
        top_weight: weight applied to the top score
        ignore_missing_gold: if ``True``, don't count missing gold
            items as mismatches
        ignore_missing_test: if ``True``, don't count missing test
            items as mismatches
    Returns:
        A tuple of (precision, recall, f-score)
    """
    # Report every weight, including the constant weight.
    logger.info(
        "Computing EDM (N=%g, A=%g, P=%g, C=%g, T=%g)",
        name_weight,
        argument_weight,
        property_weight,
        constant_weight,
        top_weight,
    )

    totals: _Match = _accumulate(
        golds, tests, ignore_missing_gold, ignore_missing_test
    )

    # Pair each category's summed counts with its weight so the three
    # weighted totals are computed from a single table.
    weighted = (
        (totals.name, name_weight),
        (totals.argument, argument_weight),
        (totals.property, property_weight),
        (totals.constant, constant_weight),
        (totals.top, top_weight),
    )
    gold_total = sum(count.gold * weight for count, weight in weighted)
    test_total = sum(count.test * weight for count, weight in weighted)
    both_total = sum(count.both * weight for count, weight in weighted)

    return _prf(gold_total, test_total, both_total)
def _prf(g, t, b) -> _Score:
    """
    Return precision, recall, and f-score from three totals.

    *g*, *t*, and *b* are the gold, test, and shared totals. If any of
    them is zero, all three scores are ``0.0``.
    """
    if g == 0 or t == 0 or b == 0:
        return _Score(0.0, 0.0, 0.0)

    precision = b / t
    recall = b / g
    fscore = 2 * (precision * recall) / (precision + recall)
    return _Score(precision, recall, fscore)