Merge pull request #5819 from explosion/feature/component-scores

Ines Montani 2020-07-28 10:40:56 +02:00 committed by GitHub
commit 9b704c3db3
18 changed files with 229 additions and 83 deletions

View File

@ -82,8 +82,7 @@ def evaluate(
"NER P": "ents_p",
"NER R": "ents_r",
"NER F": "ents_f",
"Textcat AUC": "textcat_macro_auc",
"Textcat F": "textcat_macro_f",
"Textcat": "cats_score",
"Sent P": "sents_p",
"Sent R": "sents_r",
"Sent F": "sents_f",
@ -91,6 +90,8 @@ def evaluate(
results = {}
for metric, key in metrics.items():
if key in scores:
if key == "cats_score":
metric = metric + " (" + scores.get("cats_score_desc", "unk") + ")"
results[metric] = f"{scores[key]*100:.2f}"
data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
@ -99,12 +100,12 @@ def evaluate(
if "ents_per_type" in scores:
if scores["ents_per_type"]:
print_ents_per_type(msg, scores["ents_per_type"])
if "textcat_f_per_cat" in scores:
if scores["textcat_f_per_cat"]:
print_textcats_f_per_cat(msg, scores["textcat_f_per_cat"])
if "textcat_auc_per_cat" in scores:
if scores["textcat_auc_per_cat"]:
print_textcats_auc_per_cat(msg, scores["textcat_auc_per_cat"])
if "cats_f_per_type" in scores:
if scores["cats_f_per_type"]:
print_textcats_f_per_cat(msg, scores["cats_f_per_type"])
if "cats_auc_per_type" in scores:
if scores["cats_auc_per_type"]:
print_textcats_auc_per_cat(msg, scores["cats_auc_per_type"])
if displacy_path:
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
@ -170,7 +171,7 @@ def print_textcats_f_per_cat(msg: Printer, scores: Dict[str, Dict[str, float]])
data,
header=("", "P", "R", "F"),
aligns=("l", "r", "r", "r"),
title="Textcat F (per type)",
title="Textcat F (per label)",
)
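For reference, a minimal sketch (using a hypothetical scores dict) of how the evaluate command now builds the "Textcat" row label from `cats_score_desc`:

```python
# Hypothetical scores dict, as returned by evaluation; shows how the
# "Textcat" metric label picks up the overall score description.
scores = {"cats_score": 0.925, "cats_score_desc": "macro F"}

metric, key = "Textcat", "cats_score"
if key == "cats_score":
    metric = metric + " (" + scores.get("cats_score_desc", "unk") + ")"
print(metric, f"{scores[key] * 100:.2f}")  # Textcat (macro F) 92.50
```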

View File

@ -395,7 +395,7 @@ def subdivide_batch(batch, accumulate_gradient):
def setup_printer(
training: Union[Dict[str, Any], Config], nlp: Language
) -> Callable[[Dict[str, Any]], None]:
score_cols = training["scores"]
score_cols = list(training["score_weights"])
score_widths = [max(len(col), 6) for col in score_cols]
loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names]
loss_widths = [max(len(col), 8) for col in loss_cols]
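A small sketch of the effect on the training log columns, assuming a hypothetical `score_weights` value:

```python
# Hypothetical score_weights from the training config; the printer's columns
# are now derived from its keys rather than from a separate "scores" list.
score_weights = {"tag_acc": 0.33, "dep_las": 0.33, "ents_f": 0.33}
score_cols = list(score_weights)                          # ["tag_acc", "dep_las", "ents_f"]
score_widths = [max(len(col), 6) for col in score_cols]   # [7, 7, 6]
```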

View File

@ -34,8 +34,8 @@ seed = 0
accumulate_gradient = 1
use_pytorch_for_gpu_memory = false
# Control how scores are printed and checkpoints are evaluated.
scores = ["speed", "tag_acc", "dep_uas", "dep_las", "ents_f"]
score_weights = {"tag_acc": 0.2, "dep_las": 0.4, "ents_f": 0.4}
scores = ["token_acc", "speed"]
score_weights = {}
# These settings are invalid for the transformer models.
init_tok2vec = null
discard_oversize = false
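With these defaults, `score_weights` is left empty and gets filled in from the components' `default_score_weights`. As a sketch, with the tagger, parser and NER defaults declared elsewhere in this PR, the combined weights in `nlp.config["training"]["score_weights"]` would come out roughly as:

```python
# Sketch: combined weights for a tagger + parser + ner pipeline, normalized
# per component and divided by the number of components (3).
score_weights = {
    "tag_acc": 0.33,
    "dep_uas": 0.17, "dep_las": 0.17, "sents_f": 0.0,
    "ents_f": 0.33, "ents_p": 0.0, "ents_r": 0.0,
}
```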

View File

@ -21,7 +21,7 @@ from .pipe_analysis import analyze_pipes, analyze_all_pipes, validate_attrs
from .gold import Example
from .scorer import Scorer
from .util import link_vectors_to_models, create_default_optimizer, registry
from .util import SimpleFrozenDict
from .util import SimpleFrozenDict, combine_score_weights
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
from .lang.punctuation import TOKENIZER_INFIXES
@ -219,16 +219,25 @@ class Language:
@property
def config(self) -> Config:
self._config.setdefault("nlp", {})
self._config.setdefault("training", {})
self._config["nlp"]["lang"] = self.lang
# We're storing the filled config for each pipeline component, so we
# can populate the config again later
pipeline = {}
scores = self._config["training"].get("scores", [])
score_weights = []
for pipe_name in self.pipe_names:
pipe_meta = self.get_pipe_meta(pipe_name)
pipe_config = self.get_pipe_config(pipe_name)
pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config}
scores.extend(pipe_meta.scores)
if pipe_meta.default_score_weights:
score_weights.append(pipe_meta.default_score_weights)
self._config["nlp"]["pipeline"] = self.pipe_names
self._config["components"] = pipeline
self._config["training"]["scores"] = sorted(set(scores))
combined_score_weights = combine_score_weights(score_weights)
self._config["training"]["score_weights"] = combined_score_weights
if not srsly.is_json_serializable(self._config):
raise ValueError(Errors.E961.format(config=self._config))
return self._config
@ -349,6 +358,8 @@ class Language:
assigns: Iterable[str] = tuple(),
requires: Iterable[str] = tuple(),
retokenizes: bool = False,
scores: Iterable[str] = tuple(),
default_score_weights: Dict[str, float] = SimpleFrozenDict(),
func: Optional[Callable] = None,
) -> Callable:
"""Register a new pipeline component factory. Can be used as a decorator
@ -394,6 +405,8 @@ class Language:
default_config=default_config,
assigns=validate_attrs(assigns),
requires=validate_attrs(requires),
scores=scores,
default_score_weights=default_score_weights,
retokenizes=retokenizes,
)
cls.set_factory_meta(name, factory_meta)
@ -418,6 +431,8 @@ class Language:
assigns: Iterable[str] = tuple(),
requires: Iterable[str] = tuple(),
retokenizes: bool = False,
scores: Iterable[str] = tuple(),
default_score_weights: Dict[str, float] = SimpleFrozenDict(),
func: Optional[Callable[[Doc], Doc]] = None,
) -> Callable:
"""Register a new pipeline component. Can be used for stateless function
@ -451,6 +466,8 @@ class Language:
assigns=assigns,
requires=requires,
retokenizes=retokenizes,
scores=scores,
default_score_weights=default_score_weights,
func=factory_func,
)
return component_func
@ -1487,6 +1504,8 @@ class FactoryMeta:
assigns: Iterable[str] = tuple()
requires: Iterable[str] = tuple()
retokenizes: bool = False
scores: Iterable[str] = tuple()
default_score_weights: Dict[str, float] = None
def _get_config_overrides(
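A minimal usage sketch of the new factory arguments (component name and score key are hypothetical), mirroring the test added later in this diff:

```python
from spacy.language import Language

# Hypothetical stateless component that declares the scores it produces and
# a default weight for each one.
Language.component(
    "my_component",
    scores=["my_acc"],
    default_score_weights={"my_acc": 1.0},
    func=lambda doc: doc,
)

nlp = Language()
nlp.add_pipe("my_component")
# The declared scores and weights are merged into the training config.
print(nlp.config["training"]["scores"])         # includes "my_acc"
print(nlp.config["training"]["score_weights"])  # e.g. {"my_acc": 1.0}
```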

View File

@ -42,7 +42,9 @@ DEFAULT_PARSER_MODEL = Config().from_str(default_model_config)["model"]
"learn_tokens": False,
"min_action_freq": 30,
"model": DEFAULT_PARSER_MODEL,
}
},
scores=["dep_uas", "dep_las", "dep_las_per_type", "sents_p", "sents_r", "sents_f"],
default_score_weights={"dep_uas": 0.5, "dep_las": 0.5, "sents_f": 0.0},
)
def make_parser(
nlp: Language,
@ -120,4 +122,5 @@ cdef class DependencyParser(Parser):
results.update(Scorer.score_spans(examples, "sents", **kwargs))
results.update(Scorer.score_deps(examples, "dep", getter=dep_getter,
ignore_labels=("p", "punct"), **kwargs))
del results["sents_per_type"]
return results

View File

@ -8,6 +8,7 @@ from ..errors import Errors
from ..util import ensure_path, to_disk, from_disk
from ..tokens import Doc, Span
from ..matcher import Matcher, PhraseMatcher
from ..scorer import Scorer
DEFAULT_ENT_ID_SEP = "||"
@ -23,6 +24,8 @@ PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
"overwrite_ents": False,
"ent_id_sep": DEFAULT_ENT_ID_SEP,
},
scores=["ents_p", "ents_r", "ents_f", "ents_per_type"],
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0},
)
def make_entity_ruler(
nlp: Language,
@ -309,6 +312,9 @@ class EntityRuler:
label = f"{label}{self.ent_id_sep}{ent_id}"
return label
def score(self, examples, **kwargs):
return Scorer.score_spans(examples, "ents", **kwargs)
def from_bytes(
self, patterns_bytes: bytes, exclude: Iterable[str] = tuple()
) -> "EntityRuler":

View File

@ -39,7 +39,9 @@ DEFAULT_MORPH_MODEL = Config().from_str(default_model_config)["model"]
@Language.factory(
"morphologizer",
assigns=["token.morph", "token.pos"],
default_config={"model": DEFAULT_MORPH_MODEL}
default_config={"model": DEFAULT_MORPH_MODEL},
scores=["pos_acc", "morph_acc", "morph_per_feat"],
default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5},
)
def make_morphologizer(
nlp: Language,

View File

@ -40,7 +40,10 @@ DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"]
"learn_tokens": False,
"min_action_freq": 30,
"model": DEFAULT_NER_MODEL,
}
},
scores=["ents_p", "ents_r", "ents_f", "ents_per_type"],
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0},
)
def make_ner(
nlp: Language,

View File

@ -13,7 +13,9 @@ from .. import util
@Language.factory(
"sentencizer",
assigns=["token.is_sent_start", "doc.sents"],
default_config={"punct_chars": None}
default_config={"punct_chars": None},
scores=["sents_p", "sents_r", "sents_f"],
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
)
def make_sentencizer(
nlp: Language,
@ -156,7 +158,9 @@ class Sentencizer(Pipe):
DOCS: https://spacy.io/api/sentencizer#score
"""
return Scorer.score_spans(examples, "sents", **kwargs)
results = Scorer.score_spans(examples, "sents", **kwargs)
del results["sents_per_type"]
return results
def to_bytes(self, exclude=tuple()):
"""Serialize the sentencizer to a bytestring.

View File

@ -33,7 +33,9 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
@Language.factory(
"senter",
assigns=["token.is_sent_start"],
default_config={"model": DEFAULT_SENTER_MODEL}
default_config={"model": DEFAULT_SENTER_MODEL},
scores=["sents_p", "sents_r", "sents_f"],
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
)
def make_senter(nlp: Language, name: str, model: Model):
return SentenceRecognizer(nlp.vocab, model, name)
@ -151,7 +153,9 @@ class SentenceRecognizer(Tagger):
RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.
DOCS: https://spacy.io/api/sentencerecognizer#score
"""
return Scorer.score_spans(examples, "sents", **kwargs)
results = Scorer.score_spans(examples, "sents", **kwargs)
del results["sents_per_type"]
return results
def to_bytes(self, exclude=tuple()):
"""Serialize the pipe to a bytestring.

View File

@ -8,6 +8,7 @@ from ..gold import Example, spans_from_biluo_tags, iob_to_biluo, biluo_to_iob
from ..tokens import Doc
from ..language import Language
from ..vocab import Vocab
from ..scorer import Scorer
from .. import util
from .pipe import Pipe
@ -34,6 +35,9 @@ DEFAULT_SIMPLE_NER_MODEL = Config().from_str(default_model_config)["model"]
"simple_ner",
assigns=["doc.ents"],
default_config={"labels": [], "model": DEFAULT_SIMPLE_NER_MODEL},
scores=["ents_p", "ents_r", "ents_f", "ents_per_type"],
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0},
)
def make_simple_ner(
nlp: Language, name: str, model: Model, labels: Iterable[str]
@ -173,6 +177,9 @@ class SimpleNER(Pipe):
def init_multitask_objectives(self, *args, **kwargs):
pass
def score(self, examples, **kwargs):
return Scorer.score_spans(examples, "ents", **kwargs)
def _has_ner(example: Example) -> bool:
for ner_tag in example.get_aligned_ner():

View File

@ -39,7 +39,9 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
@Language.factory(
"tagger",
assigns=["token.tag"],
default_config={"model": DEFAULT_TAGGER_MODEL, "set_morphology": False}
default_config={"model": DEFAULT_TAGGER_MODEL, "set_morphology": False},
scores=["tag_acc", "pos_acc", "lemma_acc"],
default_score_weights={"tag_acc": 1.0},
)
def make_tagger(nlp: Language, name: str, model: Model, set_morphology: bool):
return Tagger(nlp.vocab, model, name, set_morphology=set_morphology)

View File

@ -56,6 +56,8 @@ dropout = null
"textcat",
assigns=["doc.cats"],
default_config={"labels": [], "model": DEFAULT_TEXTCAT_MODEL},
scores=["cats_score", "cats_score_desc", "cats_p", "cats_r", "cats_f", "cats_macro_f", "cats_macro_auc", "cats_f_per_type", "cats_macro_auc_per_type"],
default_score_weights={"cats_score": 1.0},
)
def make_textcat(
nlp: Language, name: str, model: Model, labels: Iterable[str]

View File

@ -298,7 +298,8 @@ class Scorer:
**cfg
):
"""Returns PRF and ROC AUC scores for a doc-level attribute with a
dict with scores for each label like Doc.cats.
dict with scores for each label like Doc.cats. The reported overall
score depends on the scorer settings.
examples (Iterable[Example]): Examples to score
attr (str): The attribute to score.
@ -309,11 +310,16 @@ class Scorer:
Defaults to True.
positive_label (str): The positive label for a binary task with
exclusive classes. Defaults to None.
RETURNS (dict): A dictionary containing the scores:
for binary exclusive with positive label: attr_p/r/f,
for 3+ exclusive classes, macro-averaged fscore: attr_macro_f,
for multilabel, macro-averaged AUC: attr_macro_auc,
for all: attr_f_per_type, attr_auc_per_type
RETURNS (dict): A dictionary containing the scores, with inapplicable
scores as None:
for all:
attr_score (one of attr_f / attr_macro_f / attr_macro_auc),
attr_score_desc (text description of the overall score),
attr_f_per_type,
attr_auc_per_type
for binary exclusive with positive label: attr_p/r/f
for 3+ exclusive classes, macro-averaged fscore: attr_macro_f
for multilabel, macro-averaged AUC: attr_macro_auc
"""
score = PRFScore()
f_per_type = dict()
@ -362,6 +368,13 @@ class Scorer:
)
)
results = {
attr + "_score": None,
attr + "_score_desc": None,
attr + "_p": None,
attr + "_r": None,
attr + "_f": None,
attr + "_macro_f": None,
attr + "_macro_auc": None,
attr + "_f_per_type": {k: v.to_dict() for k, v in f_per_type.items()},
attr + "_auc_per_type": {k: v.score for k, v in auc_per_type.items()},
}
@ -369,16 +382,22 @@ class Scorer:
results[attr + "_p"] = score.precision
results[attr + "_r"] = score.recall
results[attr + "_f"] = score.fscore
results[attr + "_score"] = results[attr + "_f"]
results[attr + "_score_desc"] = "F (" + positive_label + ")"
elif not multi_label:
results[attr + "_macro_f"] = sum(
[score.fscore for label, score in f_per_type.items()]
) / (len(f_per_type) + 1e-100)
results[attr + "_score"] = results[attr + "_macro_f"]
results[attr + "_score_desc"] = "macro F"
else:
results[attr + "_macro_auc"] = max(
sum([score.score for label, score in auc_per_type.items()])
/ (len(auc_per_type) + 1e-100),
-1,
)
results[attr + "_score"] = results[attr + "_macro_auc"]
results[attr + "_score_desc"] = "macro AUC"
return results
@staticmethod
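A usage sketch for the new unified keys (toy data; label names are hypothetical):

```python
from spacy.lang.en import English
from spacy.gold import Example
from spacy.scorer import Scorer

nlp = English()
# Toy example where the prediction matches the reference exactly.
pred = nlp.make_doc("This is great")
pred.cats = {"POSITIVE": 1.0, "NEGATIVE": 0.0}
ref = nlp.make_doc("This is great")
ref.cats = {"POSITIVE": 1.0, "NEGATIVE": 0.0}
scores = Scorer.score_cats(
    [Example(pred, ref)],
    "cats",
    labels=["POSITIVE", "NEGATIVE"],
    multi_label=False,
    positive_label="POSITIVE",
)
# cats_score holds the overall score (here the F-score for the positive label)
# and cats_score_desc describes which metric it is.
print(scores["cats_score"], scores["cats_score_desc"])  # e.g. 1.0 "F (POSITIVE)"
```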

View File

@ -3,7 +3,7 @@ from spacy.language import Language
from spacy.lang.en import English
from spacy.lang.de import German
from spacy.tokens import Doc
from spacy.util import registry, SimpleFrozenDict
from spacy.util import registry, SimpleFrozenDict, combine_score_weights
from thinc.api import Model, Linear
from thinc.config import ConfigValidationError
from pydantic import StrictInt, StrictStr
@ -328,3 +328,54 @@ def test_language_factories_invalid():
assert len(nlp.factories)
with pytest.raises(NotImplementedError):
nlp.factories["foo"] = "bar"
@pytest.mark.parametrize(
"weights,expected",
[
([{"a": 1.0}, {"b": 1.0}, {"c": 1.0}], {"a": 0.33, "b": 0.33, "c": 0.33}),
([{"a": 1.0}, {"b": 50}, {"c": 123}], {"a": 0.33, "b": 0.33, "c": 0.33}),
(
[{"a": 0.7, "b": 0.3}, {"c": 1.0}, {"d": 0.5, "e": 0.5}],
{"a": 0.23, "b": 0.1, "c": 0.33, "d": 0.17, "e": 0.17},
),
(
[{"a": 100, "b": 400}, {"c": 0.5, "d": 0.5}],
{"a": 0.1, "b": 0.4, "c": 0.25, "d": 0.25},
),
(
[{"a": 0.5, "b": 0.5}, {"b": 1.0}],
{"a": 0.25, "b": 0.75},
),
],
)
def test_language_factories_combine_score_weights(weights, expected):
result = combine_score_weights(weights)
assert sum(result.values()) in (0.99, 1.0)
assert result == expected
def test_language_factories_scores():
name = "test_language_factories_scores"
func = lambda doc: doc
weights1 = {"a1": 0.5, "a2": 0.5}
weights2 = {"b1": 0.2, "b2": 0.7, "b3": 0.1}
Language.component(
f"{name}1", scores=list(weights1), default_score_weights=weights1, func=func,
)
Language.component(
f"{name}2", scores=list(weights2), default_score_weights=weights2, func=func,
)
meta1 = Language.get_factory_meta(f"{name}1")
assert meta1.default_score_weights == weights1
meta2 = Language.get_factory_meta(f"{name}2")
assert meta2.default_score_weights == weights2
nlp = Language()
nlp._config["training"]["scores"] = ["speed"]
nlp._config["training"]["score_weights"] = {}
nlp.add_pipe(f"{name}1")
nlp.add_pipe(f"{name}2")
cfg = nlp.config["training"]
assert cfg["scores"] == sorted(["speed", *list(weights1.keys()), *list(weights2.keys())])
expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05}
assert cfg["score_weights"] == expected_weights

View File

@ -121,6 +121,8 @@ def test_overfitting_IO():
train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}}
)
assert scores["cats_f"] == 1.0
assert scores["cats_score"] == 1.0
assert "cats_score_desc" in scores
# fmt: off

View File

@ -1130,6 +1130,25 @@ def get_arg_names(func: Callable) -> List[str]:
return list(set([*argspec.args, *argspec.kwonlyargs]))
def combine_score_weights(weights: List[Dict[str, float]]) -> Dict[str, float]:
"""Combine and normalize score weights defined by components, e.g.
{"ents_r": 0.2, "ents_p": 0.3, "ents_f": 0.5} and {"some_other_score": 1.0}.
weights (List[dict]): The weights defined by the components.
RETURNS (Dict[str, float]): The combined and normalized weights.
"""
result = {}
for w_dict in weights:
# We need to account for weights that don't sum to 1.0 and normalize
# the score weights accordingly, then divide score by the number of
# components.
total = sum(w_dict.values())
for key, value in w_dict.items():
weight = round(value / total / len(weights), 2)
result[key] = result.get(key, 0.0) + weight
return result
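A usage sketch with the tagger and parser defaults from this PR:

```python
from spacy.util import combine_score_weights

# Each dict is normalized to sum to 1.0, then divided by the number of
# components, so the combined weights also sum to (roughly) 1.0.
weights = [
    {"tag_acc": 1.0},                                  # tagger
    {"dep_uas": 0.5, "dep_las": 0.5, "sents_f": 0.0},  # parser
]
print(combine_score_weights(weights))
# {'tag_acc': 0.5, 'dep_uas': 0.25, 'dep_las': 0.25, 'sents_f': 0.0}
```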
class DummyTokenizer:
# add dummy methods for to_bytes, from_bytes, to_disk and from_disk to
# allow serialization (see #1557)

View File

@ -8,8 +8,8 @@ source: spacy/scorer.py
The `Scorer` computes evaluation scores. It's typically created by
[`Language.evaluate`](/api/language#evaluate).
In addition, the `Scorer` provides a number of evaluation methods for
evaluating `Token` and `Doc` attributes.
In addition, the `Scorer` provides a number of evaluation methods for evaluating
`Token` and `Doc` attributes.
## Scorer.\_\_init\_\_ {#init tag="method"}
@ -28,10 +28,10 @@ Create a new `Scorer`.
> scorer = Scorer(nlp)
> ```
| Name | Type | Description |
| ------------ | -------- | ------------------------------------------------------------ |
| `nlp` | Language | The pipeline to use for scoring, where each pipeline component may provide a scoring method. If none is provided, then a default pipeline for the multi-language code `xx` is constructed containing: `senter`, `tagger`, `morphologizer`, `parser`, `ner`, `textcat`. |
| **RETURNS** | `Scorer` | The newly created object. |
| Name | Type | Description |
| ----------- | -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `nlp` | Language | The pipeline to use for scoring, where each pipeline component may provide a scoring method. If none is provided, then a default pipeline for the multi-language code `xx` is constructed containing: `senter`, `tagger`, `morphologizer`, `parser`, `ner`, `textcat`. |
| **RETURNS** | `Scorer` | The newly created object. |
## Scorer.score {#score tag="method"}
@ -39,13 +39,13 @@ Calculate the scores for a list of [`Example`](/api/example) objects using the
scoring methods provided by the components in the pipeline.
The returned `Dict` contains the scores provided by the individual pipeline
components. For the scoring methods provided by the `Scorer` and used by the
core pipeline components, the individual score names start with the `Token` or
`Doc` attribute being scored: `token_acc`, `token_p/r/f`, `sents_p/r/f`,
`tag_acc`, `pos_acc`, `morph_acc`, `morph_per_feat`, `lemma_acc`, `dep_uas`,
`dep_las`, `dep_las_per_type`, `ents_p/r/f`, `ents_per_type`,
`textcat_macro_auc`, `textcat_macro_f`.
components. For the scoring methods provided by the `Scorer` and used by the core
pipeline components, the individual score names start with the `Token` or `Doc`
attribute being scored: `token_acc`, `token_p/r/f`, `sents_p/r/f`, `tag_acc`,
`pos_acc`, `morph_acc`, `morph_per_feat`, `lemma_acc`, `dep_uas`, `dep_las`,
`dep_las_per_type`, `ents_p/r/f`, `ents_per_type`, `textcat_macro_auc`,
`textcat_macro_f`.
> #### Example
>
> ```python
@ -53,19 +53,20 @@ core pipeline components, the individual score names start with the `Token` or
> scorer.score(examples)
> ```
| Name | Type | Description |
| ----------- | --------- | --------------------------------------------------------------------------------------------------------|
| Name | Type | Description |
| ----------- | ------------------- | --------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| **RETURNS** | `Dict` | A dictionary of scores. |
## Scorer.score_tokenization {#score_tokenization tag="staticmethod"}
Scores the tokenization:
* `token_acc`: # correct tokens / # gold tokens
* `token_p/r/f`: PRF for token character spans
- `token_acc`: # correct tokens / # gold tokens
- `token_p/r/f`: PRF for token character spans
| Name | Type | Description |
| ----------- | --------- | --------------------------------------------------------------------------------------------------------|
| Name | Type | Description |
| ----------- | ------------------- | --------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| **RETURNS** | `Dict` | A dictionary containing the scores `token_acc/p/r/f`. |
@ -73,61 +74,62 @@ Scores the tokenization:
Scores a single token attribute.
| Name | Type | Description |
| ----------- | --------- | --------------------------------------------------------------------------------------------------------|
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| Name | Type | Description |
| ----------- | ------------------- | ----------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
| **RETURNS** | `Dict` | A dictionary containing the score `attr_acc`. |
| **RETURNS** | `Dict` | A dictionary containing the score `attr_acc`. |
## Scorer.score_token_attr_per_feat {#score_token_attr_per_feat tag="staticmethod"}
Scores a single token attribute per feature for a token attribute in UFEATS format.
Scores a single token attribute per feature for a token attribute in UFEATS
format.
| Name | Type | Description |
| ----------- | --------- | --------------------------------------------------------------------------------------------------------|
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| Name | Type | Description |
| ----------- | ------------------- | ----------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
| **RETURNS** | `Dict` | A dictionary containing the per-feature PRF scores under the key `attr_per_feat`. |
| **RETURNS** | `Dict` | A dictionary containing the per-feature PRF scores under the key `attr_per_feat`. |
## Scorer.score_spans {#score_spans tag="staticmethod"}
Returns PRF scores for labeled or unlabeled spans.
| Name | Type | Description |
| ----------- | --------- | --------------------------------------------------------------------------------------------------------|
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the `Span` objects for an individual `Doc`. |
| Name | Type | Description |
| ----------- | ------------------- | --------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the `Span` objects for an individual `Doc`. |
| **RETURNS** | `Dict` | A dictionary containing the PRF scores under the keys `attr_p/r/f` and the per-type PRF scores under `attr_per_type`. |
## Scorer.score_deps {#score_deps tag="staticmethod"}
Calculate the UAS, LAS, and LAS per type scores for dependency parses.
| Name | Type | Description |
| ----------- | --------- | --------------------------------------------------------------------------------------------------------|
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute containing the dependency label. |
| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
| `head_attr` | `str` | The attribute containing the head token. |
| `head_getter` | `callable` | Defaults to `getattr`. If provided, `head_getter(token, attr)` should return the head for an individual `Token`. |
| `ignore_labels` | `Tuple` | Labels to ignore while scoring (e.g., `punct`).
| **RETURNS** | `Dict` | A dictionary containing the scores: `attr_uas`, `attr_las`, and `attr_las_per_type`. |
| Name | Type | Description |
| --------------- | ------------------- | ----------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute containing the dependency label. |
| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
| `head_attr` | `str` | The attribute containing the head token. |
| `head_getter` | `callable` | Defaults to `getattr`. If provided, `head_getter(token, attr)` should return the head for an individual `Token`. |
| `ignore_labels` | `Tuple` | Labels to ignore while scoring (e.g., `punct`). |
| **RETURNS** | `Dict` | A dictionary containing the scores: `attr_uas`, `attr_las`, and `attr_las_per_type`. |
## Scorer.score_cats {#score_cats tag="staticmethod"}
Calculate PRF and ROC AUC scores for a doc-level attribute that is a dict
containing scores for each label like `Doc.cats`.
| Name | Type | Description |
| ----------- | --------- | --------------------------------------------------------------------------------------------------------|
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the cats for an individual `Doc`. |
| labels | `Iterable[str]` | The set of possible labels. Defaults to `[]`. |
| multi_label | `bool` | Whether the attribute allows multiple labels. Defaults to `True`. |
| positive_label | `str` | The positive label for a binary task with exclusive classes. Defaults to `None`. |
| **RETURNS** | `Dict` | A dictionary containing the scores: 1) for binary exclusive with positive label: `attr_p/r/f`; 2) for 3+ exclusive classes, macro-averaged fscore: `attr_macro_f`; 3) for multilabel, macro-averaged AUC: `attr_macro_auc`; 4) for all: `attr_f_per_type`, `attr_auc_per_type` |
containing scores for each label like `Doc.cats`. The reported overall score
depends on the scorer settings.
| Name | Type | Description |
| ---------------- | ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the cats for an individual `Doc`. |
| labels | `Iterable[str]` | The set of possible labels. Defaults to `[]`. |
| `multi_label` | `bool` | Whether the attribute allows multiple labels. Defaults to `True`. |
| `positive_label` | `str` | The positive label for a binary task with exclusive classes. Defaults to `None`. |
| **RETURNS** | `Dict` | A dictionary containing the scores, with inapplicable scores as `None`: 1) for all: `attr_score` (one of `attr_f` / `attr_macro_f` / `attr_macro_auc`), `attr_score_desc` (text description of the overall score), `attr_f_per_type`, `attr_auc_per_type`; 2) for binary exclusive with positive label: `attr_p/r/f`; 3) for 3+ exclusive classes, macro-averaged fscore: `attr_macro_f`; 4) for multilabel, macro-averaged AUC: `attr_macro_auc` |