Fix skipped documents in entity scorer (#6137)

* Fix skipped documents in entity scorer

* Add back the skipping of unannotated entities

* Update spacy/scorer.py

* Use more specific NER scorer

* Fix import

* Fix get_ner_prf

* Add scorer

* Fix scorer

Co-authored-by: Ines Montani <ines@ines.io>
Matthew Honnibal 2020-09-24 20:38:57 +02:00 committed by GitHub
parent 2abb4ba9db
commit 16475528f7
2 changed files with 67 additions and 12 deletions

spacy/pipeline/ner.pyx

@@ -6,7 +6,7 @@ from .transition_parser cimport Parser
from ._parser_internals.ner cimport BiluoPushDown
from ..language import Language
from ..scorer import Scorer
from ..scorer import get_ner_prf, PRFScore
from ..training import validate_examples
@@ -117,9 +117,18 @@ cdef class EntityRecognizer(Parser):
"""Score a batch of examples.
examples (Iterable[Example]): The examples to score.
RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.
RETURNS (Dict[str, Any]): The NER precision, recall and f-scores.
DOCS: https://nightly.spacy.io/api/entityrecognizer#score
"""
validate_examples(examples, "EntityRecognizer.score")
return Scorer.score_spans(examples, "ents", **kwargs)
score_per_type = get_ner_prf(examples)
totals = PRFScore()
for prf in score_per_type.values():
totals += prf
return {
"ents_p": totals.precision,
"ents_r": totals.recall,
"ents_f": totals.fscore,
"ents_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
}
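
Note: the returned dict keeps the same keys that Scorer.score_spans previously produced for the "ents" attribute, so downstream code reading ents_p / ents_r / ents_f should keep working. A rough, made-up sketch of the resulting shape (assuming PRFScore.to_dict yields "p", "r" and "f" keys):

    {
        "ents_p": 0.85,
        "ents_r": 0.80,
        "ents_f": 0.82,
        "ents_per_type": {
            "PERSON": {"p": 0.90, "r": 0.85, "f": 0.87},
            "ORG": {"p": 0.80, "r": 0.75, "f": 0.77},
        },
    }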

spacy/scorer.py

@@ -1,5 +1,6 @@
from typing import Optional, Iterable, Dict, Any, Callable, TYPE_CHECKING
import numpy as np
from collections import defaultdict
from .training import Example
from .tokens import Token, Doc, Span
@@ -23,6 +24,19 @@ class PRFScore:
self.fp = 0
self.fn = 0
def __iadd__(self, other):
self.tp += other.tp
self.fp += other.fp
self.fn += other.fn
return self
def __add__(self, other):
return PRFScore(
tp=self.tp+other.tp,
fp=self.fp+other.fp,
fn=self.fn+other.fn
)
def score_set(self, cand: set, gold: set) -> None:
self.tp += len(cand.intersection(gold))
self.fp += len(cand - gold)
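
With __iadd__ and __add__ in place, micro-averaging is a matter of pooling raw tp/fp/fn counts before computing ratios, which is what the totals accumulator in EntityRecognizer.score above does. A minimal sketch with made-up counts (attributes are set directly rather than via the constructor, in case the keyword signature differs):

    person = PRFScore()
    person.tp, person.fp, person.fn = 8, 2, 1   # hypothetical PERSON counts
    org = PRFScore()
    org.tp, org.fp, org.fn = 3, 3, 4            # hypothetical ORG counts

    totals = PRFScore()
    for prf in (person, org):
        totals += prf                           # __iadd__ pools the raw counts

    # Micro-averaged precision and recall are both 11 / (11 + 5) here.
    print(totals.precision, totals.recall, totals.fscore)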
@@ -295,20 +309,19 @@ class Scorer:
# Find all predicted labels, for all and per type
gold_spans = set()
pred_spans = set()
# Special case for ents:
# If we have missing values in the gold, we can't easily tell
# whether our NER predictions are true.
# It seems bad but it's what we've always done.
if attr == "ents" and not all(token.ent_iob != 0 for token in gold_doc):
continue
for span in getter(gold_doc, attr):
gold_span = (span.label_, span.start, span.end - 1)
gold_spans.add(gold_span)
gold_per_type[span.label_].add((span.label_, span.start, span.end - 1))
pred_per_type = {label: set() for label in labels}
for span in example.get_aligned_spans_x2y(getter(pred_doc, attr)):
pred_spans.add((span.label_, span.start, span.end - 1))
pred_per_type[span.label_].add((span.label_, span.start, span.end - 1))
align_x2y = example.alignment.x2y
for pred_span in getter(pred_doc, attr):
indices = align_x2y[pred_span.start : pred_span.end].dataXd.ravel()
if len(indices):
g_span = gold_doc[indices[0] : indices[-1]]
span = (pred_span.label_, indices[0], indices[-1])
pred_spans.add(span)
pred_per_type[pred_span.label_].add(span)
# Scores per label
for k, v in score_per_type.items():
if k in pred_per_type:
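
The substantive change in this hunk is that predicted spans are now projected onto the gold tokenization by slicing alignment.x2y directly, rather than via get_aligned_spans_x2y. A minimal sketch of what that mapping returns, using two hand-built Docs (the words here are purely illustrative):

    from spacy.tokens import Doc
    from spacy.training import Example
    from spacy.vocab import Vocab

    vocab = Vocab()
    # The predicted doc keeps "U.K." as one token; the gold doc splits it.
    pred = Doc(vocab, words=["I", "like", "U.K.", "startups"])
    gold = Doc(vocab, words=["I", "like", "U", ".", "K", ".", "startups"])
    eg = Example(pred, gold)

    # Gold token indices covered by predicted token 2 ("U.K.").
    indices = eg.alignment.x2y[2:3].dataXd.ravel()
    print(indices)  # expected to cover gold tokens 2-5 ("U", ".", "K", ".")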
@@ -613,6 +626,39 @@ class Scorer:
}
def get_ner_prf(examples: Iterable[Example]) -> Dict[str, PRFScore]:
"""Compute per-entity PRFScore objects for a sequence of examples. The
results are returned as a dictionary keyed by the entity type. You can
add the PRFScore objects to get a micro-averaged total.
"""
scores = defaultdict(PRFScore)
for eg in examples:
if not eg.y.has_annotation("ENT_IOB"):
continue
golds = {(e.label_, e.start, e.end) for e in eg.y.ents}
align_x2y = eg.alignment.x2y
preds = set()
for pred_ent in eg.x.ents:
if pred_ent.label_ not in scores:
scores[pred_ent.label_] = PRFScore()
indices = align_x2y[pred_ent.start : pred_ent.end].dataXd.ravel()
if len(indices):
g_span = eg.y[indices[0] : indices[-1] + 1]
# Check we aren't missing annotation on this span. If annotation
# is missing, the prediction is neither right nor wrong; we just
# ignore it.
if all(token.ent_iob != 0 for token in g_span):
key = (pred_ent.label_, indices[0], indices[-1] + 1)
if key in golds:
scores[pred_ent.label_].tp += 1
golds.remove(key)
else:
scores[pred_ent.label_].fp += 1
for label, start, end in golds:
scores[label].fn += 1
return scores
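
For reference, a minimal sketch of calling get_ner_prf directly; it assumes en_core_web_sm is installed, and the texts and character offsets are illustrative. The second example carries no entity annotation at all, so the has_annotation("ENT_IOB") check above should skip it instead of counting its predictions as false positives:

    import spacy
    from spacy.scorer import get_ner_prf
    from spacy.training import Example

    nlp = spacy.load("en_core_web_sm")
    annotated = Example.from_dict(
        nlp("Apple is looking at buying a U.K. startup."),
        {"entities": [(0, 5, "ORG"), (29, 33, "GPE")]},
    )
    unannotated = Example.from_dict(nlp("No entity annotation here."), {})

    scores = get_ner_prf([annotated, unannotated])
    for label, prf in scores.items():
        print(label, prf.to_dict())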
#############################################################################
#
# The following implementation of roc_auc_score() is adapted from