Make beta a component scorer setting.

This commit is contained in:
Raphael Mitsch 2022-09-02 12:35:46 +02:00
parent ea9737a664
commit 110850f095
6 changed files with 21 additions and 13 deletions

View File

@ -1,4 +1,5 @@
import functools
from functools import partial
import operator
from pathlib import Path
import logging
@ -115,8 +116,12 @@ def find_threshold(
pipe = nlp.get_pipe(pipe_name)
except KeyError as err:
wasabi.msg.fail(title=str(err), exits=1)
if not isinstance(pipe, TrainablePipe):
raise TypeError(Errors.E1044)
if not hasattr(pipe, "scorer"):
raise AttributeError(Errors.E1045)
setattr(pipe, "scorer", partial(pipe.scorer.func, beta=beta))
if not silent:
wasabi.msg.info(
@ -145,9 +150,7 @@ def find_threshold(
scores: Dict[float, float] = {}
for threshold in numpy.linspace(0, 1, n_trials):
pipe.cfg = set_nested_item(pipe.cfg, config_keys, threshold)
scores[threshold] = nlp.evaluate(dev_dataset, scorer_cfg={"beta": beta})[
scores_key
]
scores[threshold] = nlp.evaluate(dev_dataset)[scores_key]
if not (
isinstance(scores[threshold], float) or isinstance(scores[threshold], int)
):

View File

@ -939,7 +939,8 @@ class Errors(metaclass=ErrorsWithCodes):
"`{arg2}`={arg2_values} but these arguments are conflicting.")
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
"{value}.")
E1044 = ("Only components of type `TrainablePipe` are supported by `find_threshold()`.")
E1044 = ("`find_threshold()` only supports components of type `TrainablePipe`.")
E1045 = ("`find_threshold()` only supports components with a `scorer` attribute.")
# Deprecated model shortcuts, only used in errors and warnings

View File

@ -1,3 +1,4 @@
from functools import partial
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
from thinc.api import Optimizer
@ -165,8 +166,8 @@ def spancat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
@registry.scorers("spacy.spancat_scorer.v1")
def make_spancat_scorer():
return spancat_score
def make_spancat_scorer(beta: float = 1.0):
return partial(spancat_score, beta=beta)
class SpanCategorizer(TrainablePipe):

View File

@ -1,3 +1,4 @@
from functools import partial
from typing import Iterable, Optional, Dict, List, Callable, Any
from thinc.types import Floats2d
from thinc.api import Model, Config
@ -121,8 +122,8 @@ def textcat_multilabel_score(examples: Iterable[Example], **kwargs) -> Dict[str,
@registry.scorers("spacy.textcat_multilabel_scorer.v1")
def make_textcat_multilabel_scorer():
return textcat_multilabel_score
def make_textcat_multilabel_scorer(beta: float = 1.0):
return partial(textcat_multilabel_score, beta=beta)
class MultiLabel_TextCategorizer(TextCategorizer):

View File

@ -102,7 +102,7 @@ class ROCAUCScore:
class Scorer:
"""Compute evaluation scores."""
BETA = 1
BETA = 1.0
def __init__(
self,
@ -336,6 +336,7 @@ class Scorer:
has_annotation: Optional[Callable[[Doc], bool]] = None,
labeled: bool = True,
allow_overlap: bool = False,
beta: float = 1.0,
**cfg,
) -> Dict[str, Any]:
"""Returns PRF scores for labeled spans.
@ -353,12 +354,12 @@ class Scorer:
equal if their start and end match, irrespective of their label.
allow_overlap (bool): Whether or not to allow overlapping spans.
If set to 'False', the alignment will automatically resolve conflicts.
beta (float): Beta coefficient for F-score calculation. Defaults to 1.0.
RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under
the keys attr_p/r/f and the per-type PRF scores under attr_per_type.
DOCS: https://spacy.io/api/scorer#score_spans
"""
beta = cfg.get("beta", Scorer.BETA)
score = PRFScore(beta=beta)
score_per_type = dict()
for example in examples:
@ -439,6 +440,7 @@ class Scorer:
multi_label: bool = True,
positive_label: Optional[str] = None,
threshold: Optional[float] = None,
beta: float = 1.0,
**cfg,
) -> Dict[str, Any]:
"""Returns PRF and ROC AUC scores for a doc-level attribute with a
@ -458,6 +460,7 @@ class Scorer:
threshold (float): Cutoff to consider a prediction "positive". Defaults
to 0.5 for multi-label, and 0.0 (i.e. whatever's highest scoring)
otherwise.
beta (float): Beta coefficient for F-score calculation. Defaults to 1.0.
RETURNS (Dict[str, Any]): A dictionary containing the scores, with
inapplicable scores as None:
for all:
@ -475,7 +478,6 @@ class Scorer:
DOCS: https://spacy.io/api/scorer#score_cats
"""
beta = cfg.get("beta", Scorer.BETA)
if threshold is None:
threshold = 0.5 if multi_label else 0.0
f_per_type = {label: PRFScore(beta=beta) for label in labels}

View File

@ -928,12 +928,12 @@ def test_cli_find_threshold(capsys):
pipe_name="tc_multi",
threshold_key="threshold",
scores_key="cats_macro_f",
silent=True,
silent=False,
)[0]
== numpy.linspace(0, 1, 10)[1]
)
# Specifying name of non-MultiLabel_TextCategorizer component should fail.
# Test with spancat.
nlp, _ = init_nlp((("spancat", {}),))
with make_tempdir() as nlp_dir:
nlp.to_disk(nlp_dir)