mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Fix skipped documents in entity scorer (#6137)
* Fix skipped documents in entity scorer * Add back the skipping of unannotated entities * Update spacy/scorer.py * Use more specific NER scorer * Fix import * Fix get_ner_prf * Add scorer * Fix scorer Co-authored-by: Ines Montani <ines@ines.io>
This commit is contained in:
parent
2abb4ba9db
commit
16475528f7
|
@ -6,7 +6,7 @@ from .transition_parser cimport Parser
|
||||||
from ._parser_internals.ner cimport BiluoPushDown
|
from ._parser_internals.ner cimport BiluoPushDown
|
||||||
|
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..scorer import Scorer
|
from ..scorer import get_ner_prf, PRFScore
|
||||||
from ..training import validate_examples
|
from ..training import validate_examples
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,9 +117,18 @@ cdef class EntityRecognizer(Parser):
|
||||||
"""Score a batch of examples.
|
"""Score a batch of examples.
|
||||||
|
|
||||||
examples (Iterable[Example]): The examples to score.
|
examples (Iterable[Example]): The examples to score.
|
||||||
RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.
|
RETURNS (Dict[str, Any]): The NER precision, recall and f-scores.
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/entityrecognizer#score
|
DOCS: https://nightly.spacy.io/api/entityrecognizer#score
|
||||||
"""
|
"""
|
||||||
validate_examples(examples, "EntityRecognizer.score")
|
validate_examples(examples, "EntityRecognizer.score")
|
||||||
return Scorer.score_spans(examples, "ents", **kwargs)
|
score_per_type = get_ner_prf(examples)
|
||||||
|
totals = PRFScore()
|
||||||
|
for prf in score_per_type.values():
|
||||||
|
totals += prf
|
||||||
|
return {
|
||||||
|
"ents_p": totals.precision,
|
||||||
|
"ents_r": totals.recall,
|
||||||
|
"ents_f": totals.fscore,
|
||||||
|
"ents_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
|
||||||
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from typing import Optional, Iterable, Dict, Any, Callable, TYPE_CHECKING
|
from typing import Optional, Iterable, Dict, Any, Callable, TYPE_CHECKING
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
from .training import Example
|
from .training import Example
|
||||||
from .tokens import Token, Doc, Span
|
from .tokens import Token, Doc, Span
|
||||||
|
@ -23,6 +24,19 @@ class PRFScore:
|
||||||
self.fp = 0
|
self.fp = 0
|
||||||
self.fn = 0
|
self.fn = 0
|
||||||
|
|
||||||
|
def __iadd__(self, other):
|
||||||
|
self.tp += other.tp
|
||||||
|
self.fp += other.fp
|
||||||
|
self.fn += other.fn
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __add__(self, other):
|
||||||
|
return PRFScore(
|
||||||
|
tp=self.tp+other.tp,
|
||||||
|
fp=self.fp+other.fp,
|
||||||
|
fn=self.fn+other.fn
|
||||||
|
)
|
||||||
|
|
||||||
def score_set(self, cand: set, gold: set) -> None:
|
def score_set(self, cand: set, gold: set) -> None:
|
||||||
self.tp += len(cand.intersection(gold))
|
self.tp += len(cand.intersection(gold))
|
||||||
self.fp += len(cand - gold)
|
self.fp += len(cand - gold)
|
||||||
|
@ -295,20 +309,19 @@ class Scorer:
|
||||||
# Find all predidate labels, for all and per type
|
# Find all predidate labels, for all and per type
|
||||||
gold_spans = set()
|
gold_spans = set()
|
||||||
pred_spans = set()
|
pred_spans = set()
|
||||||
# Special case for ents:
|
|
||||||
# If we have missing values in the gold, we can't easily tell
|
|
||||||
# whether our NER predictions are true.
|
|
||||||
# It seems bad but it's what we've always done.
|
|
||||||
if attr == "ents" and not all(token.ent_iob != 0 for token in gold_doc):
|
|
||||||
continue
|
|
||||||
for span in getter(gold_doc, attr):
|
for span in getter(gold_doc, attr):
|
||||||
gold_span = (span.label_, span.start, span.end - 1)
|
gold_span = (span.label_, span.start, span.end - 1)
|
||||||
gold_spans.add(gold_span)
|
gold_spans.add(gold_span)
|
||||||
gold_per_type[span.label_].add((span.label_, span.start, span.end - 1))
|
gold_per_type[span.label_].add((span.label_, span.start, span.end - 1))
|
||||||
pred_per_type = {label: set() for label in labels}
|
pred_per_type = {label: set() for label in labels}
|
||||||
for span in example.get_aligned_spans_x2y(getter(pred_doc, attr)):
|
align_x2y = example.alignment.x2y
|
||||||
pred_spans.add((span.label_, span.start, span.end - 1))
|
for pred_span in getter(pred_doc, attr):
|
||||||
pred_per_type[span.label_].add((span.label_, span.start, span.end - 1))
|
indices = align_x2y[pred_span.start : pred_span.end].dataXd.ravel()
|
||||||
|
if len(indices):
|
||||||
|
g_span = gold_doc[indices[0] : indices[-1]]
|
||||||
|
span = (pred_span.label_, indices[0], indices[-1])
|
||||||
|
pred_spans.add(span)
|
||||||
|
pred_per_type[pred_span.label_].add(span)
|
||||||
# Scores per label
|
# Scores per label
|
||||||
for k, v in score_per_type.items():
|
for k, v in score_per_type.items():
|
||||||
if k in pred_per_type:
|
if k in pred_per_type:
|
||||||
|
@ -613,6 +626,39 @@ class Scorer:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_ner_prf(examples: Iterable[Example]) -> Dict[str, PRFScore]:
|
||||||
|
"""Compute per-entity PRFScore objects for a sequence of examples. The
|
||||||
|
results are returned as a dictionary keyed by the entity type. You can
|
||||||
|
add the PRFScore objects to get micro-averaged total.
|
||||||
|
"""
|
||||||
|
scores = defaultdict(PRFScore)
|
||||||
|
for eg in examples:
|
||||||
|
if not eg.y.has_annotation("ENT_IOB"):
|
||||||
|
continue
|
||||||
|
golds = {(e.label_, e.start, e.end) for e in eg.y.ents}
|
||||||
|
align_x2y = eg.alignment.x2y
|
||||||
|
preds = set()
|
||||||
|
for pred_ent in eg.x.ents:
|
||||||
|
if pred_ent.label_ not in scores:
|
||||||
|
scores[pred_ent.label_] = PRFScore()
|
||||||
|
indices = align_x2y[pred_ent.start : pred_ent.end].dataXd.ravel()
|
||||||
|
if len(indices):
|
||||||
|
g_span = eg.y[indices[0] : indices[-1] + 1]
|
||||||
|
# Check we aren't missing annotation on this span. If so,
|
||||||
|
# our prediction is neither right nor wrong, we just
|
||||||
|
# ignore it.
|
||||||
|
if all(token.ent_iob != 0 for token in g_span):
|
||||||
|
key = (pred_ent.label_, indices[0], indices[-1] + 1)
|
||||||
|
if key in golds:
|
||||||
|
scores[pred_ent.label_].tp += 1
|
||||||
|
golds.remove(key)
|
||||||
|
else:
|
||||||
|
scores[pred_ent.label_].fp += 1
|
||||||
|
for label, start, end in golds:
|
||||||
|
scores[label].fn += 1
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
#############################################################################
|
#############################################################################
|
||||||
#
|
#
|
||||||
# The following implementation of roc_auc_score() is adapted from
|
# The following implementation of roc_auc_score() is adapted from
|
||||||
|
|
Loading…
Reference in New Issue
Block a user