# Source code for delphin.edm

"""
Elementary Dependency Matching
"""

__all__ = ["compute"]

import logging
from collections import Counter
from collections.abc import Iterable
from itertools import zip_longest
from typing import Any, NamedTuple, TypeVar

# Default modules need to import the PyDelphin version
from delphin.__about__ import __version__  # noqa: F401
from delphin.dmrs import DMRS
from delphin.eds import EDS
from delphin.sembase import Predication

# NOTE(review): this configures logging as a side effect of importing the
# module, with ERROR as the level -- confirm this is intended for library use.
logging.basicConfig(level=logging.ERROR)
# Module-level logger; the INFO/DEBUG messages below are emitted through it.
logger = logging.getLogger(__name__)

# Semantic representation type: constrained to either EDS or DMRS.
SR = TypeVar("SR", EDS, DMRS)
# A (cfrom, cto) pair as built by _span().
_Span = tuple[int, int]
# (span, label, value): the value is None for names, the target's span for
# arguments, the property value for properties, and the CARG for constants.
_Triple = tuple[_Span, str, Any]


class _Count(NamedTuple):
    gold: int
    test: int
    both: int

    def add(self, other: "_Count") -> "_Count":
        return _Count(
            self.gold + other.gold,
            self.test + other.test,
            self.both + other.both,
        )


class _Match(NamedTuple):
    """One ``_Count`` per triple category, for a pair or a running total."""

    name: _Count
    argument: _Count
    property: _Count
    constant: _Count
    top: _Count

    def add(self, other: "_Match") -> "_Match":
        """Return a new match with each category summed with *other*'s."""
        return _Match(
            name=self.name.add(other.name),
            argument=self.argument.add(other.argument),
            property=self.property.add(other.property),
            constant=self.constant.add(other.constant),
            top=self.top.add(other.top),
        )


class _Score(NamedTuple):
    """Result tuple of (precision, recall, fscore) as floats."""

    # matching count divided by the test count (see _prf)
    precision: float
    # matching count divided by the gold count (see _prf)
    recall: float
    # 2 * (precision * recall) / (precision + recall) (see _prf)
    fscore: float


def _span(node: Predication) -> _Span:
    """Build the ``(cfrom, cto)`` pair from *node*'s Lnk attributes."""
    start, end = node.cfrom, node.cto
    return (start, end)


def _names(sr: SR) -> list[_Triple]:
    """Collect one (span, predicate, None) triple per node of *sr*."""
    # The trailing None only pads the tuple to the common triple shape so
    # that it type-checks like the other categories.
    return [(_span(node), node.predicate, None) for node in sr.nodes]


def _arguments(sr: SR) -> list[_Triple]:
    """
    Collect (source span, role, target span) triples for *sr*.

    Arguments whose target is not a member of *sr* are left out.
    """
    outgoing = sr.arguments()
    collected: list[_Triple] = []
    for source in sr.nodes:
        start = _span(source)
        collected.extend(
            (start, role, _span(sr[target]))
            for role, target in outgoing[source.id]
            if target in sr
        )
    return collected


def _properties(sr: SR) -> list[_Triple]:
    """Collect (span, feature, value) triples for every node property."""
    collected: list[_Triple] = []
    for node in sr.nodes:
        where = _span(node)
        collected.extend(
            (where, feature, value)
            for feature, value in node.properties.items()
        )
    return collected


def _constants(sr: SR) -> list[_Triple]:
    """Collect (span, "carg", value) triples for nodes with a truthy CARG."""
    return [
        (_span(node), "carg", node.carg)
        for node in sr.nodes
        if node.carg
    ]


def _match(gold: SR, test: SR) -> _Match:
    """
    Count the *gold* and *test* triples in every category.

    The result is a ``_Match`` whose fields are ``_Count`` tuples of
    (gold, test, both) values, in this order::

        name      # predicate triples
        argument  # argument triples
        property  # property triples
        constant  # CARG triples
        top       # 0 or 1 each, depending on whether a top node exists

    The tops count as matching only when both representations have a
    top node and the two nodes cover the same span.
    """
    gold_has_top = gold.top in gold
    test_has_top = test.top in test
    tops_agree = (
        gold_has_top
        and test_has_top
        and _span(gold[gold.top]) == _span(test[test.top])
    )
    tops = _Count(int(gold_has_top), int(test_has_top), int(tops_agree))

    extractors = (_names, _arguments, _properties, _constants)
    per_category = [_count(extract, gold, test) for extract in extractors]
    return _Match(*per_category, tops)


def _count(func, gold, test) -> _Count:
    """
    Count the triples *func* extracts from *gold* and from *test*.

    Triples are compared as multisets: a triple occurring several times
    is matched at most as often as it occurs on the scarcer side.
    """
    from_gold = func(gold)
    from_test = func(test)
    # Counter intersection keeps the minimum count of each shared triple
    shared = Counter(from_gold) & Counter(from_test)
    return _Count(len(from_gold), len(from_test), sum(shared.values()))


def _accumulate(
    golds: Iterable[SR | None],
    tests: Iterable[SR | None],
    ignore_missing_gold: bool,
    ignore_missing_test: bool,
) -> _Match:
    """
    Add up the per-pair matches over *golds* and *tests*.

    Items are paired positionally; when one iterable runs out first, the
    remaining items of the other are paired with ``None``. A ``None`` on
    one side is either skipped (when the corresponding *ignore_missing_*
    flag is set) or replaced by an empty representation of the other
    side's type, so that every triple of the present side is a mismatch.
    Pairs where both sides are ``None`` are always skipped.
    """
    verbose = logger.isEnabledFor(logging.INFO)
    row = "%11s: %4d\t%4d\t%4d\t%5.3f\t%5.3f\t%5.3f"
    zero = _Count(0, 0, 0)
    running = _Match(zero, zero, zero, zero, zero)

    for index, (gold, test) in enumerate(zip_longest(golds, tests), 1):
        logger.info("pair %d", index)

        if gold is None and test is None:
            logger.info("no gold or test representation; skipping")
            continue

        if gold is None:
            assert test is not None
            if ignore_missing_gold:
                logger.info("no gold representation; skipping")
                continue
            logger.debug("missing gold representation")
            gold = type(test)()  # empty stand-in of the same class
        elif test is None:
            assert gold is not None
            if ignore_missing_test:
                logger.info("no test representation; skipping")
                continue
            logger.debug("missing test representation")
            test = type(gold)()  # empty stand-in of the same class

        pair_result = _match(gold, test)

        # only compute the per-pair scores when they will actually be logged
        if verbose:
            logger.info("             gold\ttest\tboth\tPrec.\tRec.\tF-Score")
            breakdown = (
                ("Names", pair_result.name),
                ("Arguments", pair_result.argument),
                ("Properties", pair_result.property),
                ("Constants", pair_result.constant),
                ("Tops", pair_result.top),
            )
            for label, counts in breakdown:
                logger.info(row, label, *counts, *_prf(*counts))

        running = running.add(pair_result)

    return running


def compute(
    golds: Iterable[SR | None],
    tests: Iterable[SR | None],
    name_weight: float = 1.0,
    argument_weight: float = 1.0,
    property_weight: float = 1.0,
    constant_weight: float = 1.0,
    top_weight: float = 1.0,
    ignore_missing_gold: bool = False,
    ignore_missing_test: bool = False,
) -> _Score:
    """
    Compute the precision, recall, and f-score for all pairs.

    The *golds* and *tests* arguments are iterables of PyDelphin
    dependency representations, such as EDS or DMRS.

    The precision and recall are computed as follows:

    - Precision = *matching_triples* / *test_triples*
    - Recall = *matching_triples* / *gold_triples*
    - F-score = 2 * (Precision * Recall) / (Precision + Recall)

    Each triple count is the weighted sum over the five categories
    (names, arguments, properties, constants, tops). If any of the
    three weighted totals is zero, all three scores are ``0.0``.

    Arguments:
        golds: gold semantic representations
        tests: test semantic representations
        name_weight: weight applied to the name score
        argument_weight: weight applied to the argument score
        property_weight: weight applied to the property score
        constant_weight: weight applied to the constant score
        top_weight: weight applied to the top score
        ignore_missing_gold: if ``True``, don't count missing gold
            items as mismatches
        ignore_missing_test: if ``True``, don't count missing test
            items as mismatches
    Returns:
        A tuple of (precision, recall, f-score)
    """
    # Report every weight that feeds the totals below, including the
    # constant (CARG) weight.
    logger.info(
        "Computing EDM (N=%g, A=%g, P=%g, C=%g, T=%g)",
        name_weight,
        argument_weight,
        property_weight,
        constant_weight,
        top_weight,
    )

    totals: _Match = _accumulate(
        golds, tests, ignore_missing_gold, ignore_missing_test
    )

    gold_total = (
        totals.name.gold * name_weight
        + totals.argument.gold * argument_weight
        + totals.property.gold * property_weight
        + totals.constant.gold * constant_weight
        + totals.top.gold * top_weight
    )
    test_total = (
        totals.name.test * name_weight
        + totals.argument.test * argument_weight
        + totals.property.test * property_weight
        + totals.constant.test * constant_weight
        + totals.top.test * top_weight
    )
    both_total = (
        totals.name.both * name_weight
        + totals.argument.both * argument_weight
        + totals.property.both * property_weight
        + totals.constant.both * constant_weight
        + totals.top.both * top_weight
    )

    return _prf(gold_total, test_total, both_total)
def _prf(g, t, b) -> _Score:
    """
    Turn gold (*g*), test (*t*), and matching (*b*) counts into a score.

    All three scores are ``0.0`` when any of the counts is zero, which
    also avoids dividing by zero.
    """
    if 0 in (g, t, b):
        return _Score(0.0, 0.0, 0.0)

    precision = b / t
    recall = b / g
    fscore = 2 * (precision * recall) / (precision + recall)
    return _Score(precision, recall, fscore)