# Source code for delphin.edm


"""
Elementary Dependency Matching
"""

import logging
from collections import Counter
from itertools import zip_longest
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, Union

# Default modules need to import the PyDelphin version
from delphin.__about__ import __version__  # noqa: F401
from delphin.dmrs import DMRS
from delphin.eds import EDS
from delphin.lnk import LnkMixin

# NOTE(review): basicConfig() is invoked when this module is imported, so
# importing it may configure the root logger (at ERROR level) if the
# application has not set up logging yet.
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

# The representations this module can score.
_SemanticRepresentation = Union[EDS, DMRS]
# A (cfrom, cto) pair taken from a node.
_Span = Tuple[int, int]
# A (span, label, value) triple; the last slot is None for names, the
# target's span for arguments, and the property or CARG value otherwise.
_Triple = Tuple[_Span, str, Any]


class _Count(NamedTuple):
    """
    Triple counts for one category of one or more gold/test pairs.
    """
    gold: int  # number of triples found in the gold representation(s)
    test: int  # number of triples found in the test representation(s)
    both: int  # number of triples shared by gold and test

    def add(self, other: '_Count') -> '_Count':
        """Return a new count holding the field-wise sum with *other*."""
        return _Count._make(
            getattr(self, field) + getattr(other, field)
            for field in self._fields
        )


class _Match(NamedTuple):
    """
    The per-category counts of a gold/test comparison.
    """
    name: _Count      # predicate (name) triples
    argument: _Count  # argument triples
    property: _Count  # property triples
    constant: _Count  # constant (CARG) triples
    top: _Count       # top counts

    def add(self, other: '_Match') -> '_Match':
        """Return a new match with each category summed with *other*'s."""
        return _Match._make(
            getattr(self, field).add(getattr(other, field))
            for field in self._fields
        )


class _Score(NamedTuple):
    """
    A (precision, recall, f-score) result.
    """
    precision: float  # matching triples / test triples
    recall: float     # matching triples / gold triples
    fscore: float     # 2 * (precision * recall) / (precision + recall)


def _span(node: LnkMixin) -> _Span:
    """Give the character span of *node* as a ``(cfrom, cto)`` pair."""
    start, end = node.cfrom, node.cto
    return start, end


def _names(sr: _SemanticRepresentation) -> List[_Triple]:
    """Collect one (span, predicate, None) triple per node of *sr*."""
    # The trailing None only pads the tuple to the common triple shape so
    # that it type-checks; names carry no third value.
    return [(_span(node), node.predicate, None) for node in sr.nodes]


def _arguments(sr: _SemanticRepresentation) -> List[_Triple]:
    """
    Collect the (source span, role, target span) triples of *sr*.

    Arguments whose target is not contained in *sr* are left out.
    """
    outgoing = sr.arguments()
    found = []
    for node in sr.nodes:
        origin = _span(node)
        found.extend(
            (origin, role, _span(sr[target]))
            for role, target in outgoing[node.id]
            if target in sr
        )
    return found


def _properties(sr: _SemanticRepresentation) -> List[_Triple]:
    """Collect a (span, feature, value) triple for every node property."""
    found = []
    for node in sr.nodes:
        where = _span(node)
        found.extend(
            (where, feature, value)
            for feature, value in node.properties.items()
        )
    return found


def _constants(sr: _SemanticRepresentation) -> List[_Triple]:
    """
    Collect a (span, 'carg', value) triple for each node with a CARG.

    Nodes whose ``carg`` is falsy (e.g., ``None`` or empty) are skipped.
    """
    return [
        (_span(node), 'carg', node.carg)
        for node in sr.nodes
        if node.carg
    ]


def _match(gold: _SemanticRepresentation,
           test: _SemanticRepresentation) -> _Match:
    """
    Compare *gold* against *test* and count triples in every category.

    The result groups one (gold, test, both) count per category::

        #         gold test both
        name     = (gn,  tn,  bn)
        argument = (ga,  ta,  ba)
        property = (gp,  tp,  bp)
        constant = (gc,  tc,  bc)
        top      = (gt,  tt,  bt)

    A representation contributes a top count of 1 only if its top is
    one of its own nodes; the tops match when both exist and their
    spans are equal.
    """
    gold_has_top = gold.top in gold
    test_has_top = test.top in test
    tops_agree = (
        gold_has_top
        and test_has_top
        and _span(gold[gold.top]) == _span(test[test.top])
    )
    top_count = _Count(int(gold_has_top), int(test_has_top), int(tops_agree))

    # Same order as the leading fields of _Match.
    extractors = (_names, _arguments, _properties, _constants)
    category_counts = [_count(extract, gold, test) for extract in extractors]
    return _Match(*category_counts, top_count)


def _count(func, gold, test) -> _Count:
    """
    Count the triples *func* extracts from *gold* and from *test*.

    Triples are compared as multisets: a triple occurring several times
    is matched at most as often as it occurs on the rarer side.
    """
    gold_triples = func(gold)
    test_triples = func(test)
    # Counter intersection keeps, per triple, the smaller of the two
    # occurrence counts and drops triples missing from either side.
    shared = Counter(gold_triples) & Counter(test_triples)
    return _Count(len(gold_triples), len(test_triples), sum(shared.values()))


def _accumulate(
    golds: Iterable[Optional[_SemanticRepresentation]],
    tests: Iterable[Optional[_SemanticRepresentation]],
    ignore_missing_gold: bool,
    ignore_missing_test: bool,
) -> _Match:
    """
    Add up the match counts over every (gold, test) pair.

    The two iterables are paired positionally; when one is shorter, the
    absent items are treated as ``None``. A pair in which both items are
    ``None`` is always skipped. A pair missing only one item is skipped
    if the corresponding *ignore_missing_gold* / *ignore_missing_test*
    flag is set; otherwise the missing item is replaced by an empty EDS
    so the other side's triples count as unmatched.
    """
    verbose = logger.isEnabledFor(logging.INFO)
    zero = _Count(0, 0, 0)
    totals = _Match(zero, zero, zero, zero, zero)
    row_format = '%11s: %4d\t%4d\t%4d\t%5.3f\t%5.3f\t%5.3f'

    for index, (gold, test) in enumerate(zip_longest(golds, tests), 1):
        logger.info('pair %d', index)

        if gold is None and test is None:
            logger.info('no gold or test representation; skipping')
            continue

        if gold is None:
            if ignore_missing_gold:
                logger.info('no gold representation; skipping')
                continue
            logger.debug('missing gold representation')
            gold = EDS()
        elif test is None:
            if ignore_missing_test:
                logger.info('no test representation; skipping')
                continue
            logger.debug('missing test representation')
            test = EDS()

        assert isinstance(gold, (EDS, DMRS)) and isinstance(test, (EDS, DMRS))
        result = _match(gold, test)

        # Only build the per-pair table when it will actually be emitted.
        if verbose:
            logger.info(
                '             gold\ttest\tboth\tPrec.\tRec.\tF-Score')
            rows = (
                ('Names', result.name),
                ('Arguments', result.argument),
                ('Properties', result.property),
                ('Constants', result.constant),
                ('Tops', result.top),
            )
            for label, counts in rows:
                logger.info(row_format, label, *counts, *_prf(*counts))

        totals = totals.add(result)

    return totals


def compute(
    golds: Iterable[Optional[_SemanticRepresentation]],
    tests: Iterable[Optional[_SemanticRepresentation]],
    name_weight: float = 1.0,
    argument_weight: float = 1.0,
    property_weight: float = 1.0,
    constant_weight: float = 1.0,
    top_weight: float = 1.0,
    ignore_missing_gold: bool = False,
    ignore_missing_test: bool = False,
) -> _Score:
    """
    Compute the precision, recall, and f-score for all pairs.

    The *golds* and *tests* arguments are iterables of PyDelphin
    dependency representations, such as EDS or DMRS. The precision and
    recall are computed as follows:

    - Precision = *matching_triples* / *test_triples*
    - Recall = *matching_triples* / *gold_triples*
    - F-score = 2 * (Precision * Recall) / (Precision + Recall)

    Each triple count is multiplied by the weight of its category
    before the counts are summed, so a weight of 0 removes a category
    from the score. If any of the weighted totals is 0, all three
    scores are 0.0.

    Arguments:
        golds: gold semantic representations
        tests: test semantic representations
        name_weight: weight applied to the name score
        argument_weight: weight applied to the argument score
        property_weight: weight applied to the property score
        constant_weight: weight applied to the constant score
        top_weight: weight applied to the top score
        ignore_missing_gold: if ``True``, don't count missing gold
            items as mismatches
        ignore_missing_test: if ``True``, don't count missing test
            items as mismatches
    Returns:
        A tuple of (precision, recall, f-score)
    """
    # The constant weight takes part in the score, so report it as well.
    logger.info('Computing EDM (N=%g, A=%g, P=%g, C=%g, T=%g)',
                name_weight, argument_weight, property_weight,
                constant_weight, top_weight)

    totals: _Match = _accumulate(
        golds, tests, ignore_missing_gold, ignore_missing_test
    )

    # Pair every category's counts with the weight it contributes.
    weighted = (
        (totals.name, name_weight),
        (totals.argument, argument_weight),
        (totals.property, property_weight),
        (totals.constant, constant_weight),
        (totals.top, top_weight),
    )
    gold_total = sum(count.gold * weight for count, weight in weighted)
    test_total = sum(count.test * weight for count, weight in weighted)
    both_total = sum(count.both * weight for count, weight in weighted)

    return _prf(gold_total, test_total, both_total)
def _prf(g, t, b) -> _Score:
    """
    Return the precision, recall, and f-score for the given counts.

    Arguments:
        g: (weighted) number of gold triples
        t: (weighted) number of test triples
        b: (weighted) number of triples found in both
    Returns:
        A (precision, recall, f-score) tuple; all three are 0.0 when
        any of *g*, *t*, or *b* is 0.
    """
    # A zero anywhere would either divide by zero or give a score of 0
    # anyway (no matches means p + r == 0), so short-circuit.
    if t == 0 or g == 0 or b == 0:
        return _Score(0.0, 0.0, 0.0)

    p = b / t
    r = b / g
    f = 2 * (p * r) / (p + r)
    return _Score(p, r, f)