Merge pull request #5819 from explosion/feature/component-scores

Ines Montani 2020-07-28 10:40:56 +02:00 committed by GitHub
commit 9b704c3db3
18 changed files with 229 additions and 83 deletions

View File

@@ -82,8 +82,7 @@ def evaluate(
         "NER P": "ents_p",
         "NER R": "ents_r",
         "NER F": "ents_f",
-        "Textcat AUC": "textcat_macro_auc",
-        "Textcat F": "textcat_macro_f",
+        "Textcat": "cats_score",
         "Sent P": "sents_p",
         "Sent R": "sents_r",
         "Sent F": "sents_f",
@@ -91,6 +90,8 @@ def evaluate(
     results = {}
     for metric, key in metrics.items():
         if key in scores:
+            if key == "cats_score":
+                metric = metric + " (" + scores.get("cats_score_desc", "unk") + ")"
             results[metric] = f"{scores[key]*100:.2f}"
     data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
@@ -99,12 +100,12 @@ def evaluate(
     if "ents_per_type" in scores:
         if scores["ents_per_type"]:
             print_ents_per_type(msg, scores["ents_per_type"])
-    if "textcat_f_per_cat" in scores:
-        if scores["textcat_f_per_cat"]:
-            print_textcats_f_per_cat(msg, scores["textcat_f_per_cat"])
-    if "textcat_auc_per_cat" in scores:
-        if scores["textcat_auc_per_cat"]:
-            print_textcats_auc_per_cat(msg, scores["textcat_auc_per_cat"])
+    if "cats_f_per_type" in scores:
+        if scores["cats_f_per_type"]:
+            print_textcats_f_per_cat(msg, scores["cats_f_per_type"])
+    if "cats_auc_per_type" in scores:
+        if scores["cats_auc_per_type"]:
+            print_textcats_auc_per_cat(msg, scores["cats_auc_per_type"])
     if displacy_path:
         factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
@@ -170,7 +171,7 @@ def print_textcats_f_per_cat(msg: Printer, scores: Dict[str, Dict[str, float]])
         data,
         header=("", "P", "R", "F"),
         aligns=("l", "r", "r", "r"),
-        title="Textcat F (per type)",
+        title="Textcat F (per label)",
     )

View File

@@ -395,7 +395,7 @@ def subdivide_batch(batch, accumulate_gradient):
 def setup_printer(
     training: Union[Dict[str, Any], Config], nlp: Language
 ) -> Callable[[Dict[str, Any]], None]:
-    score_cols = training["scores"]
+    score_cols = list(training["score_weights"])
     score_widths = [max(len(col), 6) for col in score_cols]
     loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names]
     loss_widths = [max(len(col), 8) for col in loss_cols]
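To make the change concrete, a tiny sketch of how the printer now derives its columns from the weights table rather than the old `scores` list (the weight values below are hypothetical):

```python
# Hypothetical weights, e.g. combined from component defaults.
score_weights = {"dep_uas": 0.5, "dep_las": 0.5, "sents_f": 0.0}
score_cols = list(score_weights)                         # ["dep_uas", "dep_las", "sents_f"]
score_widths = [max(len(col), 6) for col in score_cols]  # column widths, at least 6 chars
```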

View File

@@ -34,8 +34,8 @@ seed = 0
 accumulate_gradient = 1
 use_pytorch_for_gpu_memory = false
 # Control how scores are printed and checkpoints are evaluated.
-scores = ["speed", "tag_acc", "dep_uas", "dep_las", "ents_f"]
-score_weights = {"tag_acc": 0.2, "dep_las": 0.4, "ents_f": 0.4}
+scores = ["token_acc", "speed"]
+score_weights = {}
 # These settings are invalid for the transformer models.
 init_tok2vec = null
 discard_oversize = false

View File

@@ -21,7 +21,7 @@ from .pipe_analysis import analyze_pipes, analyze_all_pipes, validate_attrs
 from .gold import Example
 from .scorer import Scorer
 from .util import link_vectors_to_models, create_default_optimizer, registry
-from .util import SimpleFrozenDict
+from .util import SimpleFrozenDict, combine_score_weights
 from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
 from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
 from .lang.punctuation import TOKENIZER_INFIXES
@@ -219,16 +219,25 @@ class Language:
     @property
     def config(self) -> Config:
         self._config.setdefault("nlp", {})
+        self._config.setdefault("training", {})
         self._config["nlp"]["lang"] = self.lang
         # We're storing the filled config for each pipeline component and so
         # we can populate the config again later
         pipeline = {}
+        scores = self._config["training"].get("scores", [])
+        score_weights = []
         for pipe_name in self.pipe_names:
             pipe_meta = self.get_pipe_meta(pipe_name)
             pipe_config = self.get_pipe_config(pipe_name)
             pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config}
+            scores.extend(pipe_meta.scores)
+            if pipe_meta.default_score_weights:
+                score_weights.append(pipe_meta.default_score_weights)
         self._config["nlp"]["pipeline"] = self.pipe_names
         self._config["components"] = pipeline
+        self._config["training"]["scores"] = sorted(set(scores))
+        combined_score_weights = combine_score_weights(score_weights)
+        self._config["training"]["score_weights"] = combined_score_weights
         if not srsly.is_json_serializable(self._config):
             raise ValueError(Errors.E961.format(config=self._config))
         return self._config
@@ -349,6 +358,8 @@ class Language:
         assigns: Iterable[str] = tuple(),
         requires: Iterable[str] = tuple(),
         retokenizes: bool = False,
+        scores: Iterable[str] = tuple(),
+        default_score_weights: Dict[str, float] = SimpleFrozenDict(),
         func: Optional[Callable] = None,
     ) -> Callable:
         """Register a new pipeline component factory. Can be used as a decorator
@@ -394,6 +405,8 @@ class Language:
             default_config=default_config,
             assigns=validate_attrs(assigns),
             requires=validate_attrs(requires),
+            scores=scores,
+            default_score_weights=default_score_weights,
             retokenizes=retokenizes,
         )
         cls.set_factory_meta(name, factory_meta)
@@ -418,6 +431,8 @@ class Language:
         assigns: Iterable[str] = tuple(),
         requires: Iterable[str] = tuple(),
         retokenizes: bool = False,
+        scores: Iterable[str] = tuple(),
+        default_score_weights: Dict[str, float] = SimpleFrozenDict(),
         func: Optional[Callable[[Doc], Doc]] = None,
     ) -> Callable:
         """Register a new pipeline component. Can be used for stateless function
@@ -451,6 +466,8 @@ class Language:
             assigns=assigns,
             requires=requires,
             retokenizes=retokenizes,
+            scores=scores,
+            default_score_weights=default_score_weights,
             func=factory_func,
         )
         return component_func
@@ -1487,6 +1504,8 @@ class FactoryMeta:
     assigns: Iterable[str] = tuple()
     requires: Iterable[str] = tuple()
     retokenizes: bool = False
+    scores: Iterable[str] = tuple()
+    default_score_weights: Dict[str, float] = None

 def _get_config_overrides(
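To make the new factory arguments concrete, a minimal sketch of a custom stateless component declaring its scores (the component name and its no-op body are made up for illustration; the keyword arguments are the ones added to `Language.component` / `Language.factory` above):

```python
from spacy.language import Language

@Language.component(
    "my_sentence_marker",  # hypothetical name, for illustration only
    assigns=["token.is_sent_start"],
    scores=["sents_p", "sents_r", "sents_f"],
    default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
)
def my_sentence_marker(doc):
    # A real component would set token.is_sent_start here.
    return doc

nlp = Language()
nlp.add_pipe("my_sentence_marker")
# The component's declared scores and default weights now show up in the
# generated training config:
print(nlp.config["training"]["scores"])         # includes "sents_f", "sents_p", "sents_r"
print(nlp.config["training"]["score_weights"])  # {"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0}
```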

View File

@@ -42,7 +42,9 @@ DEFAULT_PARSER_MODEL = Config().from_str(default_model_config)["model"]
         "learn_tokens": False,
         "min_action_freq": 30,
         "model": DEFAULT_PARSER_MODEL,
-    }
+    },
+    scores=["dep_uas", "dep_las", "dep_las_per_type", "sents_p", "sents_r", "sents_f"],
+    default_score_weights={"dep_uas": 0.5, "dep_las": 0.5, "sents_f": 0.0},
 )
 def make_parser(
     nlp: Language,
@@ -120,4 +122,5 @@ cdef class DependencyParser(Parser):
         results.update(Scorer.score_spans(examples, "sents", **kwargs))
         results.update(Scorer.score_deps(examples, "dep", getter=dep_getter,
                                          ignore_labels=("p", "punct"), **kwargs))
+        del results["sents_per_type"]
         return results

View File

@@ -8,6 +8,7 @@ from ..errors import Errors
 from ..util import ensure_path, to_disk, from_disk
 from ..tokens import Doc, Span
 from ..matcher import Matcher, PhraseMatcher
+from ..scorer import Scorer

 DEFAULT_ENT_ID_SEP = "||"
@@ -23,6 +24,8 @@ PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
         "overwrite_ents": False,
         "ent_id_sep": DEFAULT_ENT_ID_SEP,
     },
+    scores=["ents_p", "ents_r", "ents_f", "ents_per_type"],
+    default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0},
 )
 def make_entity_ruler(
     nlp: Language,
@@ -309,6 +312,9 @@ class EntityRuler:
             label = f"{label}{self.ent_id_sep}{ent_id}"
         return label

+    def score(self, examples, **kwargs):
+        return Scorer.score_spans(examples, "ents", **kwargs)
+
     def from_bytes(
         self, patterns_bytes: bytes, exclude: Iterable[str] = tuple()
     ) -> "EntityRuler":

View File

@@ -39,7 +39,9 @@ DEFAULT_MORPH_MODEL = Config().from_str(default_model_config)["model"]
 @Language.factory(
     "morphologizer",
     assigns=["token.morph", "token.pos"],
-    default_config={"model": DEFAULT_MORPH_MODEL}
+    default_config={"model": DEFAULT_MORPH_MODEL},
+    scores=["pos_acc", "morph_acc", "morph_per_feat"],
+    default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5},
 )
 def make_morphologizer(
     nlp: Language,

View File

@@ -40,7 +40,10 @@ DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"]
         "learn_tokens": False,
         "min_action_freq": 30,
         "model": DEFAULT_NER_MODEL,
-    }
+    },
+    scores=["ents_p", "ents_r", "ents_f", "ents_per_type"],
+    default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0},
 )
 def make_ner(
     nlp: Language,

View File

@@ -13,7 +13,9 @@ from .. import util
 @Language.factory(
     "sentencizer",
     assigns=["token.is_sent_start", "doc.sents"],
-    default_config={"punct_chars": None}
+    default_config={"punct_chars": None},
+    scores=["sents_p", "sents_r", "sents_f"],
+    default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
 )
 def make_sentencizer(
     nlp: Language,
@@ -156,7 +158,9 @@ class Sentencizer(Pipe):
         DOCS: https://spacy.io/api/sentencizer#score
         """
-        return Scorer.score_spans(examples, "sents", **kwargs)
+        results = Scorer.score_spans(examples, "sents", **kwargs)
+        del results["sents_per_type"]
+        return results

     def to_bytes(self, exclude=tuple()):
         """Serialize the sentencizer to a bytestring.

View File

@@ -33,7 +33,9 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
 @Language.factory(
     "senter",
     assigns=["token.is_sent_start"],
-    default_config={"model": DEFAULT_SENTER_MODEL}
+    default_config={"model": DEFAULT_SENTER_MODEL},
+    scores=["sents_p", "sents_r", "sents_f"],
+    default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
 )
 def make_senter(nlp: Language, name: str, model: Model):
     return SentenceRecognizer(nlp.vocab, model, name)
@@ -151,7 +153,9 @@ class SentenceRecognizer(Tagger):
         RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.

         DOCS: https://spacy.io/api/sentencerecognizer#score
         """
-        return Scorer.score_spans(examples, "sents", **kwargs)
+        results = Scorer.score_spans(examples, "sents", **kwargs)
+        del results["sents_per_type"]
+        return results

     def to_bytes(self, exclude=tuple()):
         """Serialize the pipe to a bytestring.

View File

@@ -8,6 +8,7 @@ from ..gold import Example, spans_from_biluo_tags, iob_to_biluo, biluo_to_iob
 from ..tokens import Doc
 from ..language import Language
 from ..vocab import Vocab
+from ..scorer import Scorer
 from .. import util
 from .pipe import Pipe
@@ -34,6 +35,9 @@ DEFAULT_SIMPLE_NER_MODEL = Config().from_str(default_model_config)["model"]
     "simple_ner",
     assigns=["doc.ents"],
     default_config={"labels": [], "model": DEFAULT_SIMPLE_NER_MODEL},
+    scores=["ents_p", "ents_r", "ents_f", "ents_per_type"],
+    default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0},
 )
 def make_simple_ner(
     nlp: Language, name: str, model: Model, labels: Iterable[str]
@@ -173,6 +177,9 @@ class SimpleNER(Pipe):
     def init_multitask_objectives(self, *args, **kwargs):
         pass

+    def score(self, examples, **kwargs):
+        return Scorer.score_spans(examples, "ents", **kwargs)
+
 def _has_ner(example: Example) -> bool:
     for ner_tag in example.get_aligned_ner():

View File

@@ -39,7 +39,9 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
 @Language.factory(
     "tagger",
     assigns=["token.tag"],
-    default_config={"model": DEFAULT_TAGGER_MODEL, "set_morphology": False}
+    default_config={"model": DEFAULT_TAGGER_MODEL, "set_morphology": False},
+    scores=["tag_acc", "pos_acc", "lemma_acc"],
+    default_score_weights={"tag_acc": 1.0},
 )
 def make_tagger(nlp: Language, name: str, model: Model, set_morphology: bool):
     return Tagger(nlp.vocab, model, name, set_morphology=set_morphology)

View File

@@ -56,6 +56,8 @@ dropout = null
     "textcat",
     assigns=["doc.cats"],
     default_config={"labels": [], "model": DEFAULT_TEXTCAT_MODEL},
+    scores=["cats_score", "cats_score_desc", "cats_p", "cats_r", "cats_f", "cats_macro_f", "cats_macro_auc", "cats_f_per_type", "cats_macro_auc_per_type"],
+    default_score_weights={"cats_score": 1.0},
 )
 def make_textcat(
     nlp: Language, name: str, model: Model, labels: Iterable[str]
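For illustration, a sketch of how these defaults surface in the generated config (assuming a blank English pipeline where the default textcat model can be built, and the empty `score_weights` default from the config change above):

```python
from spacy.lang.en import English

nlp = English()
nlp.add_pipe("textcat")
# The factory meta contributes its score names to training.scores ...
assert "cats_score" in nlp.config["training"]["scores"]
# ... and, as the only weighted component, its normalized weights:
print(nlp.config["training"]["score_weights"])  # {"cats_score": 1.0}
```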

View File

@@ -298,7 +298,8 @@ class Scorer:
         **cfg
     ):
         """Returns PRF and ROC AUC scores for a doc-level attribute with a
-        dict with scores for each label like Doc.cats.
+        dict with scores for each label like Doc.cats. The reported overall
+        score depends on the scorer settings.

         examples (Iterable[Example]): Examples to score
         attr (str): The attribute to score.
@@ -309,11 +310,16 @@ class Scorer:
             Defaults to True.
         positive_label (str): The positive label for a binary task with
             exclusive classes. Defaults to None.
-        RETURNS (dict): A dictionary containing the scores:
-            for binary exclusive with positive label: attr_p/r/f,
-            for 3+ exclusive classes, macro-averaged fscore: attr_macro_f,
-            for multilabel, macro-averaged AUC: attr_macro_auc,
-            for all: attr_f_per_type, attr_auc_per_type
+        RETURNS (dict): A dictionary containing the scores, with inapplicable
+            scores as None:
+            for all:
+                attr_score (one of attr_f / attr_macro_f / attr_macro_auc),
+                attr_score_desc (text description of the overall score),
+                attr_f_per_type,
+                attr_auc_per_type
+            for binary exclusive with positive label: attr_p/r/f
+            for 3+ exclusive classes, macro-averaged fscore: attr_macro_f
+            for multilabel, macro-averaged AUC: attr_macro_auc
         """
         score = PRFScore()
         f_per_type = dict()
@@ -362,6 +368,13 @@ class Scorer:
                 )
             )
         results = {
+            attr + "_score": None,
+            attr + "_score_desc": None,
+            attr + "_p": None,
+            attr + "_r": None,
+            attr + "_f": None,
+            attr + "_macro_f": None,
+            attr + "_macro_auc": None,
             attr + "_f_per_type": {k: v.to_dict() for k, v in f_per_type.items()},
             attr + "_auc_per_type": {k: v.score for k, v in auc_per_type.items()},
         }
@@ -369,16 +382,22 @@ class Scorer:
             results[attr + "_p"] = score.precision
             results[attr + "_r"] = score.recall
             results[attr + "_f"] = score.fscore
+            results[attr + "_score"] = results[attr + "_f"]
+            results[attr + "_score_desc"] = "F (" + positive_label + ")"
         elif not multi_label:
             results[attr + "_macro_f"] = sum(
                 [score.fscore for label, score in f_per_type.items()]
             ) / (len(f_per_type) + 1e-100)
+            results[attr + "_score"] = results[attr + "_macro_f"]
+            results[attr + "_score_desc"] = "macro F"
         else:
             results[attr + "_macro_auc"] = max(
                 sum([score.score for label, score in auc_per_type.items()])
                 / (len(auc_per_type) + 1e-100),
                 -1,
             )
+            results[attr + "_score"] = results[attr + "_macro_auc"]
+            results[attr + "_score_desc"] = "macro AUC"
         return results

     @staticmethod
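A small usage sketch of the new overall score keys (the text, labels and predicted `cats` are made up for illustration):

```python
from spacy.lang.en import English
from spacy.gold import Example
from spacy.scorer import Scorer

nlp = English()
doc = nlp.make_doc("This is great")
example = Example.from_dict(doc, {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}})
# Pretend the pipeline predicted the right label with full confidence.
example.predicted.cats = {"POSITIVE": 1.0, "NEGATIVE": 0.0}

scores = Scorer.score_cats(
    [example], "cats", labels=["POSITIVE", "NEGATIVE"], multi_label=False
)
# The overall score plus a description of how it was computed, e.g. "macro F".
print(scores["cats_score"], scores["cats_score_desc"])
# Keys that don't apply to this setting are reported as None.
assert scores["cats_macro_auc"] is None
```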

View File

@@ -3,7 +3,7 @@ from spacy.language import Language
 from spacy.lang.en import English
 from spacy.lang.de import German
 from spacy.tokens import Doc
-from spacy.util import registry, SimpleFrozenDict
+from spacy.util import registry, SimpleFrozenDict, combine_score_weights
 from thinc.api import Model, Linear
 from thinc.config import ConfigValidationError
 from pydantic import StrictInt, StrictStr
@@ -328,3 +328,54 @@ def test_language_factories_invalid():
     assert len(nlp.factories)
     with pytest.raises(NotImplementedError):
         nlp.factories["foo"] = "bar"
+
+
+@pytest.mark.parametrize(
+    "weights,expected",
+    [
+        ([{"a": 1.0}, {"b": 1.0}, {"c": 1.0}], {"a": 0.33, "b": 0.33, "c": 0.33}),
+        ([{"a": 1.0}, {"b": 50}, {"c": 123}], {"a": 0.33, "b": 0.33, "c": 0.33}),
+        (
+            [{"a": 0.7, "b": 0.3}, {"c": 1.0}, {"d": 0.5, "e": 0.5}],
+            {"a": 0.23, "b": 0.1, "c": 0.33, "d": 0.17, "e": 0.17},
+        ),
+        (
+            [{"a": 100, "b": 400}, {"c": 0.5, "d": 0.5}],
+            {"a": 0.1, "b": 0.4, "c": 0.25, "d": 0.25},
+        ),
+        (
+            [{"a": 0.5, "b": 0.5}, {"b": 1.0}],
+            {"a": 0.25, "b": 0.75},
+        ),
+    ],
+)
+def test_language_factories_combine_score_weights(weights, expected):
+    result = combine_score_weights(weights)
+    assert sum(result.values()) in (0.99, 1.0)
+    assert result == expected
+
+
+def test_language_factories_scores():
+    name = "test_language_factories_scores"
+    func = lambda doc: doc
+    weights1 = {"a1": 0.5, "a2": 0.5}
+    weights2 = {"b1": 0.2, "b2": 0.7, "b3": 0.1}
+    Language.component(
+        f"{name}1", scores=list(weights1), default_score_weights=weights1, func=func,
+    )
+    Language.component(
+        f"{name}2", scores=list(weights2), default_score_weights=weights2, func=func,
+    )
+    meta1 = Language.get_factory_meta(f"{name}1")
+    assert meta1.default_score_weights == weights1
+    meta2 = Language.get_factory_meta(f"{name}2")
+    assert meta2.default_score_weights == weights2
+    nlp = Language()
+    nlp._config["training"]["scores"] = ["speed"]
+    nlp._config["training"]["score_weights"] = {}
+    nlp.add_pipe(f"{name}1")
+    nlp.add_pipe(f"{name}2")
+    cfg = nlp.config["training"]
+    assert cfg["scores"] == sorted(
+        ["speed", *list(weights1.keys()), *list(weights2.keys())]
+    )
+    expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05}
+    assert cfg["score_weights"] == expected_weights

View File

@@ -121,6 +121,8 @@ def test_overfitting_IO():
         train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}}
     )
     assert scores["cats_f"] == 1.0
+    assert scores["cats_score"] == 1.0
+    assert "cats_score_desc" in scores

     # fmt: off

View File

@@ -1130,6 +1130,25 @@ def get_arg_names(func: Callable) -> List[str]:
     return list(set([*argspec.args, *argspec.kwonlyargs]))


+def combine_score_weights(weights: List[Dict[str, float]]) -> Dict[str, float]:
+    """Combine and normalize score weights defined by components, e.g.
+    {"ents_r": 0.2, "ents_p": 0.3, "ents_f": 0.5} and {"some_other_score": 1.0}.
+
+    weights (List[dict]): The weights defined by the components.
+    RETURNS (Dict[str, float]): The combined and normalized weights.
+    """
+    result = {}
+    for w_dict in weights:
+        # We need to account for weights that don't sum to 1.0 and normalize
+        # the score weights accordingly, then divide score by the number of
+        # components.
+        total = sum(w_dict.values())
+        for key, value in w_dict.items():
+            weight = round(value / total / len(weights), 2)
+            result[key] = result.get(key, 0.0) + weight
+    return result
+
+
 class DummyTokenizer:
     # add dummy methods for to_bytes, from_bytes, to_disk and from_disk to
     # allow serialization (see #1557)
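For intuition, a quick sketch of what the helper produces, reusing one of the cases from the test above: each component's weights are first normalized to sum to 1.0, then divided by the number of components.

```python
from spacy.util import combine_score_weights

weights = [{"a": 0.7, "b": 0.3}, {"c": 1.0}, {"d": 0.5, "e": 0.5}]
print(combine_score_weights(weights))
# {"a": 0.23, "b": 0.1, "c": 0.33, "d": 0.17, "e": 0.17}  (rounded to two decimals)
```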

View File

@@ -8,8 +8,8 @@ source: spacy/scorer.py
 The `Scorer` computes evaluation scores. It's typically created by
 [`Language.evaluate`](/api/language#evaluate).

-In addition, the `Scorer` provides a number of evaluation methods for
-evaluating `Token` and `Doc` attributes.
+In addition, the `Scorer` provides a number of evaluation methods for evaluating
+`Token` and `Doc` attributes.

 ## Scorer.\_\_init\_\_ {#init tag="method"}
@@ -28,10 +28,10 @@ Create a new `Scorer`.
 > scorer = Scorer(nlp)
 > ```

 | Name        | Type     | Description |
 | ----------- | -------- | ----------- |
 | `nlp`       | Language | The pipeline to use for scoring, where each pipeline component may provide a scoring method. If none is provided, then a default pipeline for the multi-language code `xx` is constructed containing: `senter`, `tagger`, `morphologizer`, `parser`, `ner`, `textcat`. |
 | **RETURNS** | `Scorer` | The newly created object. |

 ## Scorer.score {#score tag="method"}
@@ -39,13 +39,13 @@ Calculate the scores for a list of [`Example`](/api/example) objects using the
 scoring methods provided by the components in the pipeline.

 The returned `Dict` contains the scores provided by the individual pipeline
 components. For the scoring methods provided by the `Scorer` and used by the
 core pipeline components, the individual score names start with the `Token` or
 `Doc` attribute being scored: `token_acc`, `token_p/r/f`, `sents_p/r/f`,
 `tag_acc`, `pos_acc`, `morph_acc`, `morph_per_feat`, `lemma_acc`, `dep_uas`,
 `dep_las`, `dep_las_per_type`, `ents_p/r/f`, `ents_per_type`,
 `textcat_macro_auc`, `textcat_macro_f`.

 > #### Example
 >
 > ```python
@@ -53,19 +53,20 @@ core pipeline components, the individual score names start with the `Token` or
 > scorer.score(examples)
 > ```

 | Name        | Type                | Description |
 | ----------- | ------------------- | ----------- |
 | `examples`  | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
 | **RETURNS** | `Dict`              | A dictionary of scores. |

 ## Scorer.score_tokenization {#score_tokenization tag="staticmethod"}

 Scores the tokenization:

-* `token_acc`: # correct tokens / # gold tokens
-* `token_p/r/f`: PRF for token character spans
+- `token_acc`: # correct tokens / # gold tokens
+- `token_p/r/f`: PRF for token character spans

 | Name        | Type                | Description |
 | ----------- | ------------------- | ----------- |
 | `examples`  | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
 | **RETURNS** | `Dict`              | A dictionary containing the scores `token_acc/p/r/f`. |
@@ -73,61 +74,62 @@ Scores the tokenization:

 Scores a single token attribute.

 | Name        | Type                | Description |
 | ----------- | ------------------- | ----------- |
 | `examples`  | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
 | `attr`      | `str`               | The attribute to score. |
 | `getter`    | `callable`          | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
 | **RETURNS** | `Dict`              | A dictionary containing the score `attr_acc`. |

 ## Scorer.score_token_attr_per_feat {#score_token_attr_per_feat tag="staticmethod"}

 Scores a single token attribute per feature for a token attribute in UFEATS
 format.

 | Name        | Type                | Description |
 | ----------- | ------------------- | ----------- |
 | `examples`  | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
 | `attr`      | `str`               | The attribute to score. |
 | `getter`    | `callable`          | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
 | **RETURNS** | `Dict`              | A dictionary containing the per-feature PRF scores under the key `attr_per_feat`. |

 ## Scorer.score_spans {#score_spans tag="staticmethod"}

 Returns PRF scores for labeled or unlabeled spans.

 | Name        | Type                | Description |
 | ----------- | ------------------- | ----------- |
 | `examples`  | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
 | `attr`      | `str`               | The attribute to score. |
 | `getter`    | `callable`          | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the `Span` objects for an individual `Doc`. |
 | **RETURNS** | `Dict`              | A dictionary containing the PRF scores under the keys `attr_p/r/f` and the per-type PRF scores under `attr_per_type`. |

 ## Scorer.score_deps {#score_deps tag="staticmethod"}

 Calculate the UAS, LAS, and LAS per type scores for dependency parses.

 | Name            | Type                | Description |
 | --------------- | ------------------- | ----------- |
 | `examples`      | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
 | `attr`          | `str`               | The attribute containing the dependency label. |
 | `getter`        | `callable`          | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
 | `head_attr`     | `str`               | The attribute containing the head token. |
 | `head_getter`   | `callable`          | Defaults to `getattr`. If provided, `head_getter(token, attr)` should return the head for an individual `Token`. |
 | `ignore_labels` | `Tuple`             | Labels to ignore while scoring (e.g., `punct`). |
 | **RETURNS**     | `Dict`              | A dictionary containing the scores: `attr_uas`, `attr_las`, and `attr_las_per_type`. |

 ## Scorer.score_cats {#score_cats tag="staticmethod"}

 Calculate PRF and ROC AUC scores for a doc-level attribute that is a dict
-containing scores for each label like `Doc.cats`.
+containing scores for each label like `Doc.cats`. The reported overall score
+depends on the scorer settings.

 | Name             | Type                | Description |
 | ---------------- | ------------------- | ----------- |
 | `examples`       | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
 | `attr`           | `str`               | The attribute to score. |
 | `getter`         | `callable`          | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the cats for an individual `Doc`. |
 | labels           | `Iterable[str]`     | The set of possible labels. Defaults to `[]`. |
-| multi_label      | `bool`              | Whether the attribute allows multiple labels. Defaults to `True`. |
-| positive_label   | `str`               | The positive label for a binary task with exclusive classes. Defaults to `None`. |
-| **RETURNS**      | `Dict`              | A dictionary containing the scores: 1) for binary exclusive with positive label: `attr_p/r/f`; 2) for 3+ exclusive classes, macro-averaged fscore: `attr_macro_f`; 3) for multilabel, macro-averaged AUC: `attr_macro_auc`; 4) for all: `attr_f_per_type`, `attr_auc_per_type` |
+| `multi_label`    | `bool`              | Whether the attribute allows multiple labels. Defaults to `True`. |
+| `positive_label` | `str`               | The positive label for a binary task with exclusive classes. Defaults to `None`. |
+| **RETURNS**      | `Dict`              | A dictionary containing the scores, with inapplicable scores as `None`: 1) for all: `attr_score` (one of `attr_f` / `attr_macro_f` / `attr_macro_auc`), `attr_score_desc` (text description of the overall score), `attr_f_per_type`, `attr_auc_per_type`; 2) for binary exclusive with positive label: `attr_p/r/f`; 3) for 3+ exclusive classes, macro-averaged fscore: `attr_macro_f`; 4) for multilabel, macro-averaged AUC: `attr_macro_auc` |
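For illustration, a usage sketch in the docs' own example style; `examples` is assumed to be an iterable of `Example` objects with both gold and predicted `Doc.cats`, and the labels are made up:

> #### Example
>
> ```python
> from spacy.scorer import Scorer
>
> labels = ["POSITIVE", "NEGATIVE"]
> scores = Scorer.score_cats(examples, "cats", labels=labels, multi_label=False)
> print(scores["cats_score"], scores["cats_score_desc"])  # e.g. the macro F score, "macro F"
> ```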