From 16475528f735114370d2db48b576106b1a6451e5 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Thu, 24 Sep 2020 20:38:57 +0200
Subject: [PATCH] Fix skipped documents in entity scorer (#6137)

* Fix skipped documents in entity scorer

* Add back the skipping of unannotated entities

* Update spacy/scorer.py

* Use more specific NER scorer

* Fix import

* Fix get_ner_prf

* Add scorer

* Fix scorer

Co-authored-by: Ines Montani
---
 spacy/pipeline/ner.pyx | 15 ++++++++--
 spacy/scorer.py        | 64 ++++++++++++++++++++++++++++++++++++------
 2 files changed, 67 insertions(+), 12 deletions(-)

diff --git a/spacy/pipeline/ner.pyx b/spacy/pipeline/ner.pyx
index c9b0a5031..fc0dda40d 100644
--- a/spacy/pipeline/ner.pyx
+++ b/spacy/pipeline/ner.pyx
@@ -6,7 +6,7 @@ from .transition_parser cimport Parser
 from ._parser_internals.ner cimport BiluoPushDown
 
 from ..language import Language
-from ..scorer import Scorer
+from ..scorer import get_ner_prf, PRFScore
 from ..training import validate_examples
 
 
@@ -117,9 +117,18 @@ cdef class EntityRecognizer(Parser):
         """Score a batch of examples.
 
         examples (Iterable[Example]): The examples to score.
-        RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.
+        RETURNS (Dict[str, Any]): The NER precision, recall and f-scores.
 
         DOCS: https://nightly.spacy.io/api/entityrecognizer#score
         """
         validate_examples(examples, "EntityRecognizer.score")
-        return Scorer.score_spans(examples, "ents", **kwargs)
+        score_per_type = get_ner_prf(examples)
+        totals = PRFScore()
+        for prf in score_per_type.values():
+            totals += prf
+        return {
+            "ents_p": totals.precision,
+            "ents_r": totals.recall,
+            "ents_f": totals.fscore,
+            "ents_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
+        }
diff --git a/spacy/scorer.py b/spacy/scorer.py
index cd3b013cd..c1795847d 100644
--- a/spacy/scorer.py
+++ b/spacy/scorer.py
@@ -1,5 +1,6 @@
 from typing import Optional, Iterable, Dict, Any, Callable, TYPE_CHECKING
 import numpy as np
+from collections import defaultdict
 
 from .training import Example
 from .tokens import Token, Doc, Span
@@ -23,6 +24,19 @@ class PRFScore:
         self.fp = 0
         self.fn = 0
 
+    def __iadd__(self, other):
+        self.tp += other.tp
+        self.fp += other.fp
+        self.fn += other.fn
+        return self
+
+    def __add__(self, other):
+        result = PRFScore()
+        result.tp = self.tp + other.tp
+        result.fp = self.fp + other.fp
+        result.fn = self.fn + other.fn
+        return result
+
     def score_set(self, cand: set, gold: set) -> None:
         self.tp += len(cand.intersection(gold))
         self.fp += len(cand - gold)
@@ -295,20 +309,19 @@ class Scorer:
             # Find all predicted labels, for all and per type
             gold_spans = set()
             pred_spans = set()
-            # Special case for ents:
-            # If we have missing values in the gold, we can't easily tell
-            # whether our NER predictions are true.
-            # It seems bad but it's what we've always done.
- if attr == "ents" and not all(token.ent_iob != 0 for token in gold_doc): - continue for span in getter(gold_doc, attr): gold_span = (span.label_, span.start, span.end - 1) gold_spans.add(gold_span) gold_per_type[span.label_].add((span.label_, span.start, span.end - 1)) pred_per_type = {label: set() for label in labels} - for span in example.get_aligned_spans_x2y(getter(pred_doc, attr)): - pred_spans.add((span.label_, span.start, span.end - 1)) - pred_per_type[span.label_].add((span.label_, span.start, span.end - 1)) + align_x2y = example.alignment.x2y + for pred_span in getter(pred_doc, attr): + indices = align_x2y[pred_span.start : pred_span.end].dataXd.ravel() + if len(indices): + g_span = gold_doc[indices[0] : indices[-1]] + span = (pred_span.label_, indices[0], indices[-1]) + pred_spans.add(span) + pred_per_type[pred_span.label_].add(span) # Scores per label for k, v in score_per_type.items(): if k in pred_per_type: @@ -613,6 +626,39 @@ class Scorer: } +def get_ner_prf(examples: Iterable[Example]) -> Dict[str, PRFScore]: + """Compute per-entity PRFScore objects for a sequence of examples. The + results are returned as a dictionary keyed by the entity type. You can + add the PRFScore objects to get micro-averaged total. + """ + scores = defaultdict(PRFScore) + for eg in examples: + if not eg.y.has_annotation("ENT_IOB"): + continue + golds = {(e.label_, e.start, e.end) for e in eg.y.ents} + align_x2y = eg.alignment.x2y + preds = set() + for pred_ent in eg.x.ents: + if pred_ent.label_ not in scores: + scores[pred_ent.label_] = PRFScore() + indices = align_x2y[pred_ent.start : pred_ent.end].dataXd.ravel() + if len(indices): + g_span = eg.y[indices[0] : indices[-1] + 1] + # Check we aren't missing annotation on this span. If so, + # our prediction is neither right nor wrong, we just + # ignore it. + if all(token.ent_iob != 0 for token in g_span): + key = (pred_ent.label_, indices[0], indices[-1] + 1) + if key in golds: + scores[pred_ent.label_].tp += 1 + golds.remove(key) + else: + scores[pred_ent.label_].fp += 1 + for label, start, end in golds: + scores[label].fn += 1 + return scores + + ############################################################################# # # The following implementation of roc_auc_score() is adapted from