mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Fix skipped documents in entity scorer (#6137)
* Fix skipped documents in entity scorer * Add back the skipping of unannotated entities * Update spacy/scorer.py * Use more specific NER scorer * Fix import * Fix get_ner_prf * Add scorer * Fix scorer Co-authored-by: Ines Montani <ines@ines.io>
This commit is contained in:
parent
2abb4ba9db
commit
16475528f7
|
@ -6,7 +6,7 @@ from .transition_parser cimport Parser
|
|||
from ._parser_internals.ner cimport BiluoPushDown
|
||||
|
||||
from ..language import Language
|
||||
from ..scorer import Scorer
|
||||
from ..scorer import get_ner_prf, PRFScore
|
||||
from ..training import validate_examples
|
||||
|
||||
|
||||
|
@ -117,9 +117,18 @@ cdef class EntityRecognizer(Parser):
|
|||
"""Score a batch of examples.
|
||||
|
||||
examples (Iterable[Example]): The examples to score.
|
||||
RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.
|
||||
RETURNS (Dict[str, Any]): The NER precision, recall and f-scores.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/entityrecognizer#score
|
||||
"""
|
||||
validate_examples(examples, "EntityRecognizer.score")
|
||||
return Scorer.score_spans(examples, "ents", **kwargs)
|
||||
score_per_type = get_ner_prf(examples)
|
||||
totals = PRFScore()
|
||||
for prf in score_per_type.values():
|
||||
totals += prf
|
||||
return {
|
||||
"ents_p": totals.precision,
|
||||
"ents_r": totals.recall,
|
||||
"ents_f": totals.fscore,
|
||||
"ents_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Optional, Iterable, Dict, Any, Callable, TYPE_CHECKING
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
|
||||
from .training import Example
|
||||
from .tokens import Token, Doc, Span
|
||||
|
@ -23,6 +24,19 @@ class PRFScore:
|
|||
self.fp = 0
|
||||
self.fn = 0
|
||||
|
||||
def __iadd__(self, other):
|
||||
self.tp += other.tp
|
||||
self.fp += other.fp
|
||||
self.fn += other.fn
|
||||
return self
|
||||
|
||||
def __add__(self, other):
|
||||
return PRFScore(
|
||||
tp=self.tp+other.tp,
|
||||
fp=self.fp+other.fp,
|
||||
fn=self.fn+other.fn
|
||||
)
|
||||
|
||||
def score_set(self, cand: set, gold: set) -> None:
|
||||
self.tp += len(cand.intersection(gold))
|
||||
self.fp += len(cand - gold)
|
||||
|
@ -295,20 +309,19 @@ class Scorer:
|
|||
# Find all predidate labels, for all and per type
|
||||
gold_spans = set()
|
||||
pred_spans = set()
|
||||
# Special case for ents:
|
||||
# If we have missing values in the gold, we can't easily tell
|
||||
# whether our NER predictions are true.
|
||||
# It seems bad but it's what we've always done.
|
||||
if attr == "ents" and not all(token.ent_iob != 0 for token in gold_doc):
|
||||
continue
|
||||
for span in getter(gold_doc, attr):
|
||||
gold_span = (span.label_, span.start, span.end - 1)
|
||||
gold_spans.add(gold_span)
|
||||
gold_per_type[span.label_].add((span.label_, span.start, span.end - 1))
|
||||
pred_per_type = {label: set() for label in labels}
|
||||
for span in example.get_aligned_spans_x2y(getter(pred_doc, attr)):
|
||||
pred_spans.add((span.label_, span.start, span.end - 1))
|
||||
pred_per_type[span.label_].add((span.label_, span.start, span.end - 1))
|
||||
align_x2y = example.alignment.x2y
|
||||
for pred_span in getter(pred_doc, attr):
|
||||
indices = align_x2y[pred_span.start : pred_span.end].dataXd.ravel()
|
||||
if len(indices):
|
||||
g_span = gold_doc[indices[0] : indices[-1]]
|
||||
span = (pred_span.label_, indices[0], indices[-1])
|
||||
pred_spans.add(span)
|
||||
pred_per_type[pred_span.label_].add(span)
|
||||
# Scores per label
|
||||
for k, v in score_per_type.items():
|
||||
if k in pred_per_type:
|
||||
|
@ -613,6 +626,39 @@ class Scorer:
|
|||
}
|
||||
|
||||
|
||||
def get_ner_prf(examples: Iterable[Example]) -> Dict[str, PRFScore]:
|
||||
"""Compute per-entity PRFScore objects for a sequence of examples. The
|
||||
results are returned as a dictionary keyed by the entity type. You can
|
||||
add the PRFScore objects to get micro-averaged total.
|
||||
"""
|
||||
scores = defaultdict(PRFScore)
|
||||
for eg in examples:
|
||||
if not eg.y.has_annotation("ENT_IOB"):
|
||||
continue
|
||||
golds = {(e.label_, e.start, e.end) for e in eg.y.ents}
|
||||
align_x2y = eg.alignment.x2y
|
||||
preds = set()
|
||||
for pred_ent in eg.x.ents:
|
||||
if pred_ent.label_ not in scores:
|
||||
scores[pred_ent.label_] = PRFScore()
|
||||
indices = align_x2y[pred_ent.start : pred_ent.end].dataXd.ravel()
|
||||
if len(indices):
|
||||
g_span = eg.y[indices[0] : indices[-1] + 1]
|
||||
# Check we aren't missing annotation on this span. If so,
|
||||
# our prediction is neither right nor wrong, we just
|
||||
# ignore it.
|
||||
if all(token.ent_iob != 0 for token in g_span):
|
||||
key = (pred_ent.label_, indices[0], indices[-1] + 1)
|
||||
if key in golds:
|
||||
scores[pred_ent.label_].tp += 1
|
||||
golds.remove(key)
|
||||
else:
|
||||
scores[pred_ent.label_].fp += 1
|
||||
for label, start, end in golds:
|
||||
scores[label].fn += 1
|
||||
return scores
|
||||
|
||||
|
||||
#############################################################################
|
||||
#
|
||||
# The following implementation of roc_auc_score() is adapted from
|
||||
|
|
Loading…
Reference in New Issue
Block a user