mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Handle missing reference values in scorer (#6286)
* Handle missing reference values in scorer Handle missing values in reference doc during scoring where it is possible to detect an unset state for the attribute. If no reference docs contain annotation, `None` is returned instead of a score. `spacy evaluate` displays `-` for missing scores and the missing scores are saved as `None`/`null` in the metrics. Attributes without unset states: * `token.head`: relies on `token.dep` to recognize unset values * `doc.cats`: unable to handle missing annotation Additional changes: * add optional `has_annotation` check to `score_scans` to replace `doc.sents` hack * update `score_token_attr_per_feat` to handle missing and empty morph representations * fix bug in `Doc.has_annotation` for normalization of `IS_SENT_START` vs. `SENT_START` * Fix import * Update return types
This commit is contained in:
parent
5d2cb86c34
commit
a4b32b9552
|
@ -98,10 +98,13 @@ def evaluate(
|
|||
if key in scores:
|
||||
if key == "cats_score":
|
||||
metric = metric + " (" + scores.get("cats_score_desc", "unk") + ")"
|
||||
if key == "speed":
|
||||
results[metric] = f"{scores[key]:.0f}"
|
||||
if isinstance(scores[key], (int, float)):
|
||||
if key == "speed":
|
||||
results[metric] = f"{scores[key]:.0f}"
|
||||
else:
|
||||
results[metric] = f"{scores[key]*100:.2f}"
|
||||
else:
|
||||
results[metric] = f"{scores[key]*100:.2f}"
|
||||
results[metric] = "-"
|
||||
data[re.sub(r"[\s/]", "_", key.lower())] = scores[key]
|
||||
|
||||
msg.table(results, title="Results")
|
||||
|
|
|
@ -226,6 +226,9 @@ class AttributeRuler(Pipe):
|
|||
|
||||
DOCS: https://nightly.spacy.io/api/tagger#score
|
||||
"""
|
||||
def morph_key_getter(token, attr):
|
||||
return getattr(token, attr).key
|
||||
|
||||
validate_examples(examples, "AttributeRuler.score")
|
||||
results = {}
|
||||
attrs = set()
|
||||
|
@ -237,7 +240,8 @@ class AttributeRuler(Pipe):
|
|||
elif attr == POS:
|
||||
results.update(Scorer.score_token_attr(examples, "pos", **kwargs))
|
||||
elif attr == MORPH:
|
||||
results.update(Scorer.score_token_attr(examples, "morph", **kwargs))
|
||||
results.update(Scorer.score_token_attr(examples, "morph", getter=morph_key_getter, **kwargs))
|
||||
results.update(Scorer.score_token_attr_per_feat(examples, "morph", getter=morph_key_getter, **kwargs))
|
||||
elif attr == LEMMA:
|
||||
results.update(Scorer.score_token_attr(examples, "lemma", **kwargs))
|
||||
return results
|
||||
|
|
|
@ -155,13 +155,16 @@ cdef class DependencyParser(Parser):
|
|||
|
||||
DOCS: https://nightly.spacy.io/api/dependencyparser#score
|
||||
"""
|
||||
def has_sents(doc):
|
||||
return doc.has_annotation("SENT_START")
|
||||
|
||||
validate_examples(examples, "DependencyParser.score")
|
||||
def dep_getter(token, attr):
|
||||
dep = getattr(token, attr)
|
||||
dep = token.vocab.strings.as_string(dep).lower()
|
||||
return dep
|
||||
results = {}
|
||||
results.update(Scorer.score_spans(examples, "sents", **kwargs))
|
||||
results.update(Scorer.score_spans(examples, "sents", has_annotation=has_sents, **kwargs))
|
||||
kwargs.setdefault("getter", dep_getter)
|
||||
kwargs.setdefault("ignore_labels", ("p", "punct"))
|
||||
results.update(Scorer.score_deps(examples, "dep", **kwargs))
|
||||
|
|
|
@ -10,7 +10,7 @@ from ..errors import Errors
|
|||
from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList
|
||||
from ..tokens import Doc, Span
|
||||
from ..matcher import Matcher, PhraseMatcher
|
||||
from ..scorer import Scorer
|
||||
from ..scorer import get_ner_prf
|
||||
from ..training import validate_examples
|
||||
|
||||
|
||||
|
@ -340,7 +340,7 @@ class EntityRuler(Pipe):
|
|||
|
||||
def score(self, examples, **kwargs):
|
||||
validate_examples(examples, "EntityRuler.score")
|
||||
return Scorer.score_spans(examples, "ents", **kwargs)
|
||||
return get_ner_prf(examples)
|
||||
|
||||
def from_bytes(
|
||||
self, patterns_bytes: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
|
||||
|
|
|
@ -251,10 +251,13 @@ class Morphologizer(Tagger):
|
|||
|
||||
DOCS: https://nightly.spacy.io/api/morphologizer#score
|
||||
"""
|
||||
def morph_key_getter(token, attr):
|
||||
return getattr(token, attr).key
|
||||
|
||||
validate_examples(examples, "Morphologizer.score")
|
||||
results = {}
|
||||
results.update(Scorer.score_token_attr(examples, "pos", **kwargs))
|
||||
results.update(Scorer.score_token_attr(examples, "morph", **kwargs))
|
||||
results.update(Scorer.score_token_attr(examples, "morph", getter=morph_key_getter, **kwargs))
|
||||
results.update(Scorer.score_token_attr_per_feat(examples,
|
||||
"morph", **kwargs))
|
||||
"morph", getter=morph_key_getter, **kwargs))
|
||||
return results
|
||||
|
|
|
@ -122,13 +122,4 @@ cdef class EntityRecognizer(Parser):
|
|||
DOCS: https://nightly.spacy.io/api/entityrecognizer#score
|
||||
"""
|
||||
validate_examples(examples, "EntityRecognizer.score")
|
||||
score_per_type = get_ner_prf(examples)
|
||||
totals = PRFScore()
|
||||
for prf in score_per_type.values():
|
||||
totals += prf
|
||||
return {
|
||||
"ents_p": totals.precision,
|
||||
"ents_r": totals.recall,
|
||||
"ents_f": totals.fscore,
|
||||
"ents_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
|
||||
}
|
||||
return get_ner_prf(examples)
|
||||
|
|
|
@ -155,8 +155,11 @@ class Sentencizer(Pipe):
|
|||
|
||||
DOCS: https://nightly.spacy.io/api/sentencizer#score
|
||||
"""
|
||||
def has_sents(doc):
|
||||
return doc.has_annotation("SENT_START")
|
||||
|
||||
validate_examples(examples, "Sentencizer.score")
|
||||
results = Scorer.score_spans(examples, "sents", **kwargs)
|
||||
results = Scorer.score_spans(examples, "sents", has_annotation=has_sents, **kwargs)
|
||||
del results["sents_per_type"]
|
||||
return results
|
||||
|
||||
|
|
|
@ -160,7 +160,10 @@ class SentenceRecognizer(Tagger):
|
|||
RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.
|
||||
DOCS: https://nightly.spacy.io/api/sentencerecognizer#score
|
||||
"""
|
||||
def has_sents(doc):
|
||||
return doc.has_annotation("SENT_START")
|
||||
|
||||
validate_examples(examples, "SentenceRecognizer.score")
|
||||
results = Scorer.score_spans(examples, "sents", **kwargs)
|
||||
results = Scorer.score_spans(examples, "sents", has_annotation=has_sents, **kwargs)
|
||||
del results["sents_per_type"]
|
||||
return results
|
||||
|
|
257
spacy/scorer.py
257
spacy/scorer.py
|
@ -1,9 +1,9 @@
|
|||
from typing import Optional, Iterable, Dict, Any, Callable, TYPE_CHECKING
|
||||
from typing import Optional, Iterable, Dict, Set, Any, Callable, TYPE_CHECKING
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
|
||||
from .training import Example
|
||||
from .tokens import Token, Doc, Span
|
||||
from .tokens import Token, Doc, Span, MorphAnalysis
|
||||
from .errors import Errors
|
||||
from .util import get_lang_class, SimpleFrozenList
|
||||
from .morphology import Morphology
|
||||
|
@ -13,7 +13,8 @@ if TYPE_CHECKING:
|
|||
from .language import Language # noqa: F401
|
||||
|
||||
|
||||
DEFAULT_PIPELINE = ["senter", "tagger", "morphologizer", "parser", "ner", "textcat"]
|
||||
DEFAULT_PIPELINE = ("senter", "tagger", "morphologizer", "parser", "ner", "textcat")
|
||||
MISSING_VALUES = frozenset([None, 0, ""])
|
||||
|
||||
|
||||
class PRFScore:
|
||||
|
@ -24,6 +25,9 @@ class PRFScore:
|
|||
self.fp = 0
|
||||
self.fn = 0
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.tp + self.fp + self.fn
|
||||
|
||||
def __iadd__(self, other):
|
||||
self.tp += other.tp
|
||||
self.fp += other.fp
|
||||
|
@ -94,7 +98,7 @@ class Scorer:
|
|||
self,
|
||||
nlp: Optional["Language"] = None,
|
||||
default_lang: str = "xx",
|
||||
default_pipeline=DEFAULT_PIPELINE,
|
||||
default_pipeline: Iterable[str] = DEFAULT_PIPELINE,
|
||||
**cfg,
|
||||
) -> None:
|
||||
"""Initialize the Scorer.
|
||||
|
@ -126,13 +130,13 @@ class Scorer:
|
|||
return scores
|
||||
|
||||
@staticmethod
|
||||
def score_tokenization(examples: Iterable[Example], **cfg) -> Dict[str, float]:
|
||||
def score_tokenization(examples: Iterable[Example], **cfg) -> Dict[str, Any]:
|
||||
"""Returns accuracy and PRF scores for tokenization.
|
||||
* token_acc: # correct tokens / # gold tokens
|
||||
* token_p/r/f: PRF for token character spans
|
||||
|
||||
examples (Iterable[Example]): Examples to score
|
||||
RETURNS (Dict[str, float]): A dictionary containing the scores
|
||||
RETURNS (Dict[str, Any]): A dictionary containing the scores
|
||||
token_acc/p/r/f.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/scorer#score_tokenization
|
||||
|
@ -142,6 +146,8 @@ class Scorer:
|
|||
for example in examples:
|
||||
gold_doc = example.reference
|
||||
pred_doc = example.predicted
|
||||
if gold_doc.has_unknown_spaces:
|
||||
continue
|
||||
align = example.alignment
|
||||
gold_spans = set()
|
||||
pred_spans = set()
|
||||
|
@ -158,12 +164,20 @@ class Scorer:
|
|||
else:
|
||||
acc_score.tp += 1
|
||||
prf_score.score_set(pred_spans, gold_spans)
|
||||
return {
|
||||
"token_acc": acc_score.fscore,
|
||||
"token_p": prf_score.precision,
|
||||
"token_r": prf_score.recall,
|
||||
"token_f": prf_score.fscore,
|
||||
}
|
||||
if len(acc_score) > 0:
|
||||
return {
|
||||
"token_acc": acc_score.fscore,
|
||||
"token_p": prf_score.precision,
|
||||
"token_r": prf_score.recall,
|
||||
"token_f": prf_score.fscore,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"token_acc": None,
|
||||
"token_p": None,
|
||||
"token_r": None,
|
||||
"token_f": None
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def score_token_attr(
|
||||
|
@ -171,8 +185,9 @@ class Scorer:
|
|||
attr: str,
|
||||
*,
|
||||
getter: Callable[[Token, str], Any] = getattr,
|
||||
missing_values: Set[Any] = MISSING_VALUES,
|
||||
**cfg,
|
||||
) -> Dict[str, float]:
|
||||
) -> Dict[str, Any]:
|
||||
"""Returns an accuracy score for a token-level attribute.
|
||||
|
||||
examples (Iterable[Example]): Examples to score
|
||||
|
@ -180,7 +195,7 @@ class Scorer:
|
|||
getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
|
||||
getter(token, attr) should return the value of the attribute for an
|
||||
individual token.
|
||||
RETURNS (Dict[str, float]): A dictionary containing the accuracy score
|
||||
RETURNS (Dict[str, Any]): A dictionary containing the accuracy score
|
||||
under the key attr_acc.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/scorer#score_token_attr
|
||||
|
@ -191,17 +206,27 @@ class Scorer:
|
|||
pred_doc = example.predicted
|
||||
align = example.alignment
|
||||
gold_tags = set()
|
||||
missing_indices = set()
|
||||
for gold_i, token in enumerate(gold_doc):
|
||||
gold_tags.add((gold_i, getter(token, attr)))
|
||||
value = getter(token, attr)
|
||||
if value not in missing_values:
|
||||
gold_tags.add((gold_i, getter(token, attr)))
|
||||
else:
|
||||
missing_indices.add(gold_i)
|
||||
pred_tags = set()
|
||||
for token in pred_doc:
|
||||
if token.orth_.isspace():
|
||||
continue
|
||||
if align.x2y.lengths[token.i] == 1:
|
||||
gold_i = align.x2y[token.i].dataXd[0, 0]
|
||||
pred_tags.add((gold_i, getter(token, attr)))
|
||||
if gold_i not in missing_indices:
|
||||
pred_tags.add((gold_i, getter(token, attr)))
|
||||
tag_score.score_set(pred_tags, gold_tags)
|
||||
return {f"{attr}_acc": tag_score.fscore}
|
||||
score_key = f"{attr}_acc"
|
||||
if len(tag_score) == 0:
|
||||
return {score_key: None}
|
||||
else:
|
||||
return {score_key: tag_score.fscore}
|
||||
|
||||
@staticmethod
|
||||
def score_token_attr_per_feat(
|
||||
|
@ -209,8 +234,9 @@ class Scorer:
|
|||
attr: str,
|
||||
*,
|
||||
getter: Callable[[Token, str], Any] = getattr,
|
||||
missing_values: Set[Any] = MISSING_VALUES,
|
||||
**cfg,
|
||||
):
|
||||
) -> Dict[str, Any]:
|
||||
"""Return PRF scores per feat for a token attribute in UFEATS format.
|
||||
|
||||
examples (Iterable[Example]): Examples to score
|
||||
|
@ -218,7 +244,7 @@ class Scorer:
|
|||
getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
|
||||
getter(token, attr) should return the value of the attribute for an
|
||||
individual token.
|
||||
RETURNS (dict): A dictionary containing the per-feat PRF scores unders
|
||||
RETURNS (dict): A dictionary containing the per-feat PRF scores under
|
||||
the key attr_per_feat.
|
||||
"""
|
||||
per_feat = {}
|
||||
|
@ -227,9 +253,11 @@ class Scorer:
|
|||
gold_doc = example.reference
|
||||
align = example.alignment
|
||||
gold_per_feat = {}
|
||||
missing_indices = set()
|
||||
for gold_i, token in enumerate(gold_doc):
|
||||
morph = str(getter(token, attr))
|
||||
if morph:
|
||||
value = getter(token, attr)
|
||||
morph = gold_doc.vocab.strings[value]
|
||||
if value not in missing_values and morph != Morphology.EMPTY_MORPH:
|
||||
for feat in morph.split(Morphology.FEATURE_SEP):
|
||||
field, values = feat.split(Morphology.FIELD_SEP)
|
||||
if field not in per_feat:
|
||||
|
@ -237,27 +265,35 @@ class Scorer:
|
|||
if field not in gold_per_feat:
|
||||
gold_per_feat[field] = set()
|
||||
gold_per_feat[field].add((gold_i, feat))
|
||||
else:
|
||||
missing_indices.add(gold_i)
|
||||
pred_per_feat = {}
|
||||
for token in pred_doc:
|
||||
if token.orth_.isspace():
|
||||
continue
|
||||
if align.x2y.lengths[token.i] == 1:
|
||||
gold_i = align.x2y[token.i].dataXd[0, 0]
|
||||
morph = str(getter(token, attr))
|
||||
if morph:
|
||||
for feat in morph.split("|"):
|
||||
field, values = feat.split("=")
|
||||
if field not in per_feat:
|
||||
per_feat[field] = PRFScore()
|
||||
if field not in pred_per_feat:
|
||||
pred_per_feat[field] = set()
|
||||
pred_per_feat[field].add((gold_i, feat))
|
||||
if gold_i not in missing_indices:
|
||||
value = getter(token, attr)
|
||||
morph = gold_doc.vocab.strings[value]
|
||||
if value not in missing_values and morph != Morphology.EMPTY_MORPH:
|
||||
for feat in morph.split(Morphology.FEATURE_SEP):
|
||||
field, values = feat.split(Morphology.FIELD_SEP)
|
||||
if field not in per_feat:
|
||||
per_feat[field] = PRFScore()
|
||||
if field not in pred_per_feat:
|
||||
pred_per_feat[field] = set()
|
||||
pred_per_feat[field].add((gold_i, feat))
|
||||
for field in per_feat:
|
||||
per_feat[field].score_set(
|
||||
pred_per_feat.get(field, set()), gold_per_feat.get(field, set())
|
||||
)
|
||||
result = {k: v.to_dict() for k, v in per_feat.items()}
|
||||
return {f"{attr}_per_feat": result}
|
||||
score_key = f"{attr}_per_feat"
|
||||
if any([len(v) for v in per_feat.values()]):
|
||||
result = {k: v.to_dict() for k, v in per_feat.items()}
|
||||
return {score_key: result}
|
||||
else:
|
||||
return {score_key: None}
|
||||
|
||||
@staticmethod
|
||||
def score_spans(
|
||||
|
@ -265,6 +301,7 @@ class Scorer:
|
|||
attr: str,
|
||||
*,
|
||||
getter: Callable[[Doc, str], Iterable[Span]] = getattr,
|
||||
has_annotation: Optional[Callable[[Doc], bool]] = None,
|
||||
**cfg,
|
||||
) -> Dict[str, Any]:
|
||||
"""Returns PRF scores for labeled spans.
|
||||
|
@ -284,18 +321,10 @@ class Scorer:
|
|||
for example in examples:
|
||||
pred_doc = example.predicted
|
||||
gold_doc = example.reference
|
||||
# TODO
|
||||
# This is a temporary hack to work around the problem that the scorer
|
||||
# fails if you have examples that are not fully annotated for all
|
||||
# the tasks in your pipeline. For instance, you might have a corpus
|
||||
# of NER annotations that does not set sentence boundaries, but the
|
||||
# pipeline includes a parser or senter, and then the score_weights
|
||||
# are used to evaluate that component. When the scorer attempts
|
||||
# to read the sentences from the gold document, it fails.
|
||||
try:
|
||||
list(getter(gold_doc, attr))
|
||||
except ValueError:
|
||||
continue
|
||||
# Option to handle docs without sents
|
||||
if has_annotation is not None:
|
||||
if not has_annotation(gold_doc):
|
||||
continue
|
||||
# Find all labels in gold and doc
|
||||
labels = set(
|
||||
[k.label_ for k in getter(gold_doc, attr)]
|
||||
|
@ -323,13 +352,21 @@ class Scorer:
|
|||
v.score_set(pred_per_type[k], gold_per_type[k])
|
||||
# Score for all labels
|
||||
score.score_set(pred_spans, gold_spans)
|
||||
results = {
|
||||
f"{attr}_p": score.precision,
|
||||
f"{attr}_r": score.recall,
|
||||
f"{attr}_f": score.fscore,
|
||||
f"{attr}_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
|
||||
}
|
||||
return results
|
||||
if len(score) > 0:
|
||||
return {
|
||||
f"{attr}_p": score.precision,
|
||||
f"{attr}_r": score.recall,
|
||||
f"{attr}_f": score.fscore,
|
||||
f"{attr}_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
f"{attr}_p": None,
|
||||
f"{attr}_r": None,
|
||||
f"{attr}_f": None,
|
||||
f"{attr}_per_type": None,
|
||||
}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def score_cats(
|
||||
|
@ -390,9 +427,6 @@ class Scorer:
|
|||
pred_cats = getter(example.predicted, attr)
|
||||
gold_cats = getter(example.reference, attr)
|
||||
|
||||
# I think the AUC metric is applicable regardless of whether we're
|
||||
# doing multi-label classification? Unsure. If not, move this into
|
||||
# the elif pred_cats and gold_cats block below.
|
||||
for label in labels:
|
||||
pred_score = pred_cats.get(label, 0.0)
|
||||
gold_score = gold_cats.get(label, 0.0)
|
||||
|
@ -542,6 +576,7 @@ class Scorer:
|
|||
head_attr: str = "head",
|
||||
head_getter: Callable[[Token, str], Token] = getattr,
|
||||
ignore_labels: Iterable[str] = SimpleFrozenList(),
|
||||
missing_values: Set[Any] = MISSING_VALUES,
|
||||
**cfg,
|
||||
) -> Dict[str, Any]:
|
||||
"""Returns the UAS, LAS, and LAS per type scores for dependency
|
||||
|
@ -566,6 +601,7 @@ class Scorer:
|
|||
unlabelled = PRFScore()
|
||||
labelled = PRFScore()
|
||||
labelled_per_dep = dict()
|
||||
missing_indices = set()
|
||||
for example in examples:
|
||||
gold_doc = example.reference
|
||||
pred_doc = example.predicted
|
||||
|
@ -575,13 +611,16 @@ class Scorer:
|
|||
for gold_i, token in enumerate(gold_doc):
|
||||
dep = getter(token, attr)
|
||||
head = head_getter(token, head_attr)
|
||||
if dep not in ignore_labels:
|
||||
gold_deps.add((gold_i, head.i, dep))
|
||||
if dep not in labelled_per_dep:
|
||||
labelled_per_dep[dep] = PRFScore()
|
||||
if dep not in gold_deps_per_dep:
|
||||
gold_deps_per_dep[dep] = set()
|
||||
gold_deps_per_dep[dep].add((gold_i, head.i, dep))
|
||||
if dep not in missing_values:
|
||||
if dep not in ignore_labels:
|
||||
gold_deps.add((gold_i, head.i, dep))
|
||||
if dep not in labelled_per_dep:
|
||||
labelled_per_dep[dep] = PRFScore()
|
||||
if dep not in gold_deps_per_dep:
|
||||
gold_deps_per_dep[dep] = set()
|
||||
gold_deps_per_dep[dep].add((gold_i, head.i, dep))
|
||||
else:
|
||||
missing_indices.add(gold_i)
|
||||
pred_deps = set()
|
||||
pred_deps_per_dep = {}
|
||||
for token in pred_doc:
|
||||
|
@ -591,25 +630,26 @@ class Scorer:
|
|||
gold_i = None
|
||||
else:
|
||||
gold_i = align.x2y[token.i].dataXd[0, 0]
|
||||
dep = getter(token, attr)
|
||||
head = head_getter(token, head_attr)
|
||||
if dep not in ignore_labels and token.orth_.strip():
|
||||
if align.x2y.lengths[head.i] == 1:
|
||||
gold_head = align.x2y[head.i].dataXd[0, 0]
|
||||
else:
|
||||
gold_head = None
|
||||
# None is indistinct, so we can't just add it to the set
|
||||
# Multiple (None, None) deps are possible
|
||||
if gold_i is None or gold_head is None:
|
||||
unlabelled.fp += 1
|
||||
labelled.fp += 1
|
||||
else:
|
||||
pred_deps.add((gold_i, gold_head, dep))
|
||||
if dep not in labelled_per_dep:
|
||||
labelled_per_dep[dep] = PRFScore()
|
||||
if dep not in pred_deps_per_dep:
|
||||
pred_deps_per_dep[dep] = set()
|
||||
pred_deps_per_dep[dep].add((gold_i, gold_head, dep))
|
||||
if gold_i not in missing_indices:
|
||||
dep = getter(token, attr)
|
||||
head = head_getter(token, head_attr)
|
||||
if dep not in ignore_labels and token.orth_.strip():
|
||||
if align.x2y.lengths[head.i] == 1:
|
||||
gold_head = align.x2y[head.i].dataXd[0, 0]
|
||||
else:
|
||||
gold_head = None
|
||||
# None is indistinct, so we can't just add it to the set
|
||||
# Multiple (None, None) deps are possible
|
||||
if gold_i is None or gold_head is None:
|
||||
unlabelled.fp += 1
|
||||
labelled.fp += 1
|
||||
else:
|
||||
pred_deps.add((gold_i, gold_head, dep))
|
||||
if dep not in labelled_per_dep:
|
||||
labelled_per_dep[dep] = PRFScore()
|
||||
if dep not in pred_deps_per_dep:
|
||||
pred_deps_per_dep[dep] = set()
|
||||
pred_deps_per_dep[dep].add((gold_i, gold_head, dep))
|
||||
labelled.score_set(pred_deps, gold_deps)
|
||||
for dep in labelled_per_dep:
|
||||
labelled_per_dep[dep].score_set(
|
||||
|
@ -618,29 +658,34 @@ class Scorer:
|
|||
unlabelled.score_set(
|
||||
set(item[:2] for item in pred_deps), set(item[:2] for item in gold_deps)
|
||||
)
|
||||
return {
|
||||
f"{attr}_uas": unlabelled.fscore,
|
||||
f"{attr}_las": labelled.fscore,
|
||||
f"{attr}_las_per_type": {
|
||||
k: v.to_dict() for k, v in labelled_per_dep.items()
|
||||
},
|
||||
}
|
||||
if len(unlabelled) > 0:
|
||||
return {
|
||||
f"{attr}_uas": unlabelled.fscore,
|
||||
f"{attr}_las": labelled.fscore,
|
||||
f"{attr}_las_per_type": {
|
||||
k: v.to_dict() for k, v in labelled_per_dep.items()
|
||||
},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
f"{attr}_uas": None,
|
||||
f"{attr}_las": None,
|
||||
f"{attr}_las_per_type": None,
|
||||
}
|
||||
|
||||
|
||||
def get_ner_prf(examples: Iterable[Example]) -> Dict[str, PRFScore]:
|
||||
"""Compute per-entity PRFScore objects for a sequence of examples. The
|
||||
results are returned as a dictionary keyed by the entity type. You can
|
||||
add the PRFScore objects to get micro-averaged total.
|
||||
def get_ner_prf(examples: Iterable[Example]) -> Dict[str, Any]:
|
||||
"""Compute micro-PRF and per-entity PRF scores for a sequence of examples.
|
||||
"""
|
||||
scores = defaultdict(PRFScore)
|
||||
score_per_type = defaultdict(PRFScore)
|
||||
for eg in examples:
|
||||
if not eg.y.has_annotation("ENT_IOB"):
|
||||
continue
|
||||
golds = {(e.label_, e.start, e.end) for e in eg.y.ents}
|
||||
align_x2y = eg.alignment.x2y
|
||||
for pred_ent in eg.x.ents:
|
||||
if pred_ent.label_ not in scores:
|
||||
scores[pred_ent.label_] = PRFScore()
|
||||
if pred_ent.label_ not in score_per_type:
|
||||
score_per_type[pred_ent.label_] = PRFScore()
|
||||
indices = align_x2y[pred_ent.start : pred_ent.end].dataXd.ravel()
|
||||
if len(indices):
|
||||
g_span = eg.y[indices[0] : indices[-1] + 1]
|
||||
|
@ -650,13 +695,29 @@ def get_ner_prf(examples: Iterable[Example]) -> Dict[str, PRFScore]:
|
|||
if all(token.ent_iob != 0 for token in g_span):
|
||||
key = (pred_ent.label_, indices[0], indices[-1] + 1)
|
||||
if key in golds:
|
||||
scores[pred_ent.label_].tp += 1
|
||||
score_per_type[pred_ent.label_].tp += 1
|
||||
golds.remove(key)
|
||||
else:
|
||||
scores[pred_ent.label_].fp += 1
|
||||
score_per_type[pred_ent.label_].fp += 1
|
||||
for label, start, end in golds:
|
||||
scores[label].fn += 1
|
||||
return scores
|
||||
score_per_type[label].fn += 1
|
||||
totals = PRFScore()
|
||||
for prf in score_per_type.values():
|
||||
totals += prf
|
||||
if len(totals) > 0:
|
||||
return {
|
||||
"ents_p": totals.precision,
|
||||
"ents_r": totals.recall,
|
||||
"ents_f": totals.fscore,
|
||||
"ents_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"ents_p": None,
|
||||
"ents_r": None,
|
||||
"ents_f": None,
|
||||
"ents_per_type": None,
|
||||
}
|
||||
|
||||
|
||||
#############################################################################
|
||||
|
|
|
@ -160,8 +160,8 @@ def test_attributeruler_score(nlp, pattern_dicts):
|
|||
scores = nlp.evaluate(dev_examples)
|
||||
# "cat" is the only correct lemma
|
||||
assert scores["lemma_acc"] == pytest.approx(0.2)
|
||||
# the empty morphs are correct
|
||||
assert scores["morph_acc"] == pytest.approx(0.6)
|
||||
# no morphs are set
|
||||
assert scores["morph_acc"] == None
|
||||
|
||||
|
||||
def test_attributeruler_rule_order(nlp):
|
||||
|
|
|
@ -277,6 +277,62 @@ def test_tag_score(tagged_doc):
|
|||
assert results["morph_per_feat"]["Number"]["f"] == approx(0.72727272)
|
||||
|
||||
|
||||
def test_partial_annotation(en_tokenizer):
|
||||
pred_doc = en_tokenizer("a b c d e")
|
||||
pred_doc[0].tag_ = "A"
|
||||
pred_doc[0].pos_ = "X"
|
||||
pred_doc[0].set_morph("Feat=Val")
|
||||
pred_doc[0].dep_ = "dep"
|
||||
|
||||
# unannotated reference
|
||||
ref_doc = en_tokenizer("a b c d e")
|
||||
ref_doc.has_unknown_spaces = True
|
||||
example = Example(pred_doc, ref_doc)
|
||||
scorer = Scorer()
|
||||
scores = scorer.score([example])
|
||||
for key in scores:
|
||||
# cats doesn't have an unset state
|
||||
if key.startswith("cats"):
|
||||
continue
|
||||
assert scores[key] == None
|
||||
|
||||
# partially annotated reference, not overlapping with predicted annotation
|
||||
ref_doc = en_tokenizer("a b c d e")
|
||||
ref_doc.has_unknown_spaces = True
|
||||
ref_doc[1].tag_ = "A"
|
||||
ref_doc[1].pos_ = "X"
|
||||
ref_doc[1].set_morph("Feat=Val")
|
||||
ref_doc[1].dep_ = "dep"
|
||||
example = Example(pred_doc, ref_doc)
|
||||
scorer = Scorer()
|
||||
scores = scorer.score([example])
|
||||
assert scores["token_acc"] == None
|
||||
assert scores["tag_acc"] == 0.0
|
||||
assert scores["pos_acc"] == 0.0
|
||||
assert scores["morph_acc"] == 0.0
|
||||
assert scores["dep_uas"] == 1.0
|
||||
assert scores["dep_las"] == 0.0
|
||||
assert scores["sents_f"] == None
|
||||
|
||||
# partially annotated reference, overlapping with predicted annotation
|
||||
ref_doc = en_tokenizer("a b c d e")
|
||||
ref_doc.has_unknown_spaces = True
|
||||
ref_doc[0].tag_ = "A"
|
||||
ref_doc[0].pos_ = "X"
|
||||
ref_doc[1].set_morph("Feat=Val")
|
||||
ref_doc[1].dep_ = "dep"
|
||||
example = Example(pred_doc, ref_doc)
|
||||
scorer = Scorer()
|
||||
scores = scorer.score([example])
|
||||
assert scores["token_acc"] == None
|
||||
assert scores["tag_acc"] == 1.0
|
||||
assert scores["pos_acc"] == 1.0
|
||||
assert scores["morph_acc"] == 0.0
|
||||
assert scores["dep_uas"] == 1.0
|
||||
assert scores["dep_las"] == 0.0
|
||||
assert scores["sents_f"] == None
|
||||
|
||||
|
||||
def test_roc_auc_score():
|
||||
# Binary classification, toy tests from scikit-learn test suite
|
||||
y_true = [0, 1]
|
||||
|
|
|
@ -399,14 +399,13 @@ cdef class Doc:
|
|||
return True
|
||||
cdef int i
|
||||
cdef int range_start = 0
|
||||
if attr == "IS_SENT_START" or attr == self.vocab.strings["IS_SENT_START"]:
|
||||
attr = SENT_START
|
||||
attr = intify_attr(attr)
|
||||
# adjust attributes
|
||||
if attr == HEAD:
|
||||
# HEAD does not have an unset state, so rely on DEP
|
||||
attr = DEP
|
||||
elif attr == self.vocab.strings["IS_SENT_START"]:
|
||||
# as in Matcher, allow IS_SENT_START as an alias of SENT_START
|
||||
attr = SENT_START
|
||||
# special cases for sentence boundaries
|
||||
if attr == SENT_START:
|
||||
if "sents" in self.user_hooks:
|
||||
|
|
|
@ -683,6 +683,7 @@ The L2 norm of the document's vector representation.
|
|||
| `user_hooks` | A dictionary that allows customization of the `Doc`'s properties. ~~Dict[str, Callable]~~ |
|
||||
| `user_token_hooks` | A dictionary that allows customization of properties of `Token` children. ~~Dict[str, Callable]~~ |
|
||||
| `user_span_hooks` | A dictionary that allows customization of properties of `Span` children. ~~Dict[str, Callable]~~ |
|
||||
| `has_unknown_spaces` | Whether the document was constructed without known spacing between tokens (typically when created from gold tokenization). ~~bool~~ |
|
||||
| `_` | User space for adding custom [attribute extensions](/usage/processing-pipelines#custom-components-attributes). ~~Underscore~~ |
|
||||
|
||||
## Serialization fields {#serialization-fields}
|
||||
|
|
|
@ -68,6 +68,8 @@ Scores the tokenization:
|
|||
- `token_p`, `token_r`, `token_f`: precision, recall and F-score for token
|
||||
character spans
|
||||
|
||||
Docs with `has_unknown_spaces` are skipped during scoring.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
|
@ -81,7 +83,8 @@ Scores the tokenization:
|
|||
|
||||
## Scorer.score_token_attr {#score_token_attr tag="staticmethod" new="3"}
|
||||
|
||||
Scores a single token attribute.
|
||||
Scores a single token attribute. Tokens with missing values in the reference doc
|
||||
are skipped during scoring.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
|
@ -90,20 +93,22 @@ Scores a single token attribute.
|
|||
> print(scores["pos_acc"])
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `examples` | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~ |
|
||||
| `attr` | The attribute to score. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `getter` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. ~~Callable[[Token, str], Any]~~ |
|
||||
| **RETURNS** | A dictionary containing the score `{attr}_acc`. ~~Dict[str, float]~~ |
|
||||
| Name | Description |
|
||||
| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `examples` | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~ |
|
||||
| `attr` | The attribute to score. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `getter` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. ~~Callable[[Token, str], Any]~~ |
|
||||
| `missing_values` | Attribute values to treat as missing annotation in the reference annotation. Defaults to `{0, None, ""}`. ~~Set[Any]~~ |
|
||||
| **RETURNS** | A dictionary containing the score `{attr}_acc`. ~~Dict[str, float]~~ |
|
||||
|
||||
## Scorer.score_token_attr_per_feat {#score_token_attr_per_feat tag="staticmethod" new="3"}
|
||||
|
||||
Scores a single token attribute per feature for a token attribute in the
|
||||
Universal Dependencies
|
||||
[FEATS](https://universaldependencies.org/format.html#morphological-annotation)
|
||||
format.
|
||||
format. Tokens with missing values in the reference doc are skipped during
|
||||
scoring.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
|
@ -112,13 +117,14 @@ format.
|
|||
> print(scores["morph_per_feat"])
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `examples` | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~ |
|
||||
| `attr` | The attribute to score. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `getter` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. ~~Callable[[Token, str], Any]~~ |
|
||||
| **RETURNS** | A dictionary containing the per-feature PRF scores under the key `{attr}_per_feat`. ~~Dict[str, Dict[str, float]]~~ |
|
||||
| Name | Description |
|
||||
| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `examples` | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~ |
|
||||
| `attr` | The attribute to score. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `getter` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. ~~Callable[[Token, str], Any]~~ |
|
||||
| `missing_values` | Attribute values to treat as missing annotation in the reference annotation. Defaults to `{0, None, ""}`. ~~Set[Any]~~ |
|
||||
| **RETURNS** | A dictionary containing the per-feature PRF scores under the key `{attr}_per_feat`. ~~Dict[str, Dict[str, float]]~~ |
|
||||
|
||||
## Scorer.score_spans {#score_spans tag="staticmethod" new="3"}
|
||||
|
||||
|
@ -131,17 +137,19 @@ Returns PRF scores for labeled or unlabeled spans.
|
|||
> print(scores["ents_f"])
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `examples` | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~ |
|
||||
| `attr` | The attribute to score. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `getter` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the `Span` objects for an individual `Doc`. ~~Callable[[Doc, str], Iterable[Span]]~~ |
|
||||
| **RETURNS** | A dictionary containing the PRF scores under the keys `{attr}_p`, `{attr}_r`, `{attr}_f` and the per-type PRF scores under `{attr}_per_type`. ~~Dict[str, Union[float, Dict[str, float]]]~~ |
|
||||
| Name | Description |
|
||||
| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `examples` | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~ |
|
||||
| `attr` | The attribute to score. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `getter` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the `Span` objects for an individual `Doc`. ~~Callable[[Doc, str], Iterable[Span]]~~ |
|
||||
| `has_annotation` | Defaults to `None`. If provided, `has_annotation(doc)` should return whether a `Doc` has annotation for this `attr`. Docs without annotation are skipped for scoring purposes. ~~str~~ |
|
||||
| **RETURNS** | A dictionary containing the PRF scores under the keys `{attr}_p`, `{attr}_r`, `{attr}_f` and the per-type PRF scores under `{attr}_per_type`. ~~Dict[str, Union[float, Dict[str, float]]]~~ |
|
||||
|
||||
## Scorer.score_deps {#score_deps tag="staticmethod" new="3"}
|
||||
|
||||
Calculate the UAS, LAS, and LAS per type scores for dependency parses.
|
||||
Calculate the UAS, LAS, and LAS per type scores for dependency parses. Tokens
|
||||
with missing values for the `attr` (typically `dep`) are skipped during scoring.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
|
@ -160,16 +168,17 @@ Calculate the UAS, LAS, and LAS per type scores for dependency parses.
|
|||
> print(scores["dep_uas"], scores["dep_las"])
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `examples` | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~ |
|
||||
| `attr` | The attribute to score. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `getter` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. ~~Callable[[Token, str], Any]~~ |
|
||||
| `head_attr` | The attribute containing the head token. ~~str~~ |
|
||||
| `head_getter` | Defaults to `getattr`. If provided, `head_getter(token, attr)` should return the head for an individual `Token`. ~~Callable[[Doc, str], Token]~~ |
|
||||
| `ignore_labels` | Labels to ignore while scoring (e.g. `"punct"`). ~~Iterable[str]~~ |
|
||||
| **RETURNS** | A dictionary containing the scores: `{attr}_uas`, `{attr}_las`, and `{attr}_las_per_type`. ~~Dict[str, Union[float, Dict[str, float]]]~~ |
|
||||
| Name | Description |
|
||||
| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `examples` | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~ |
|
||||
| `attr` | The attribute to score. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `getter` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. ~~Callable[[Token, str], Any]~~ |
|
||||
| `head_attr` | The attribute containing the head token. ~~str~~ |
|
||||
| `head_getter` | Defaults to `getattr`. If provided, `head_getter(token, attr)` should return the head for an individual `Token`. ~~Callable[[Doc, str], Token]~~ |
|
||||
| `ignore_labels` | Labels to ignore while scoring (e.g. `"punct"`). ~~Iterable[str]~~ |
|
||||
| `missing_values` | Attribute values to treat as missing annotation in the reference annotation. Defaults to `{0, None, ""}`. ~~Set[Any]~~ |
|
||||
| **RETURNS** | A dictionary containing the scores: `{attr}_uas`, `{attr}_las`, and `{attr}_las_per_type`. ~~Dict[str, Union[float, Dict[str, float]]]~~ |
|
||||
|
||||
## Scorer.score_cats {#score_cats tag="staticmethod" new="3"}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user