Make beta a component scorer setting.

This commit is contained in:
Raphael Mitsch 2022-09-02 12:35:46 +02:00
parent ea9737a664
commit 110850f095
6 changed files with 21 additions and 13 deletions

View File

@ -1,4 +1,5 @@
import functools import functools
from functools import partial
import operator import operator
from pathlib import Path from pathlib import Path
import logging import logging
@ -115,8 +116,12 @@ def find_threshold(
pipe = nlp.get_pipe(pipe_name) pipe = nlp.get_pipe(pipe_name)
except KeyError as err: except KeyError as err:
wasabi.msg.fail(title=str(err), exits=1) wasabi.msg.fail(title=str(err), exits=1)
if not isinstance(pipe, TrainablePipe): if not isinstance(pipe, TrainablePipe):
raise TypeError(Errors.E1044) raise TypeError(Errors.E1044)
if not hasattr(pipe, "scorer"):
raise AttributeError(Errors.E1045)
setattr(pipe, "scorer", partial(pipe.scorer.func, beta=beta))
if not silent: if not silent:
wasabi.msg.info( wasabi.msg.info(
@ -145,9 +150,7 @@ def find_threshold(
scores: Dict[float, float] = {} scores: Dict[float, float] = {}
for threshold in numpy.linspace(0, 1, n_trials): for threshold in numpy.linspace(0, 1, n_trials):
pipe.cfg = set_nested_item(pipe.cfg, config_keys, threshold) pipe.cfg = set_nested_item(pipe.cfg, config_keys, threshold)
scores[threshold] = nlp.evaluate(dev_dataset, scorer_cfg={"beta": beta})[ scores[threshold] = nlp.evaluate(dev_dataset)[scores_key]
scores_key
]
if not ( if not (
isinstance(scores[threshold], float) or isinstance(scores[threshold], int) isinstance(scores[threshold], float) or isinstance(scores[threshold], int)
): ):

View File

@ -939,7 +939,8 @@ class Errors(metaclass=ErrorsWithCodes):
"`{arg2}`={arg2_values} but these arguments are conflicting.") "`{arg2}`={arg2_values} but these arguments are conflicting.")
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got " E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
"{value}.") "{value}.")
E1044 = ("Only components of type `TrainablePipe` are supported by `find_threshold()`.") E1044 = ("`find_threshold()` only supports components of type `TrainablePipe`.")
E1045 = ("`find_threshold()` only supports components with a `scorer` attribute.")
# Deprecated model shortcuts, only used in errors and warnings # Deprecated model shortcuts, only used in errors and warnings

View File

@ -1,3 +1,4 @@
from functools import partial
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
from thinc.api import Optimizer from thinc.api import Optimizer
@ -165,8 +166,8 @@ def spancat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
@registry.scorers("spacy.spancat_scorer.v1") @registry.scorers("spacy.spancat_scorer.v1")
def make_spancat_scorer(): def make_spancat_scorer(beta: float = 1.0):
return spancat_score return partial(spancat_score, beta=beta)
class SpanCategorizer(TrainablePipe): class SpanCategorizer(TrainablePipe):

View File

@ -1,3 +1,4 @@
from functools import partial
from typing import Iterable, Optional, Dict, List, Callable, Any from typing import Iterable, Optional, Dict, List, Callable, Any
from thinc.types import Floats2d from thinc.types import Floats2d
from thinc.api import Model, Config from thinc.api import Model, Config
@ -121,8 +122,8 @@ def textcat_multilabel_score(examples: Iterable[Example], **kwargs) -> Dict[str,
@registry.scorers("spacy.textcat_multilabel_scorer.v1") @registry.scorers("spacy.textcat_multilabel_scorer.v1")
def make_textcat_multilabel_scorer(): def make_textcat_multilabel_scorer(beta: float = 1.0):
return textcat_multilabel_score return partial(textcat_multilabel_score, beta=beta)
class MultiLabel_TextCategorizer(TextCategorizer): class MultiLabel_TextCategorizer(TextCategorizer):

View File

@ -102,7 +102,7 @@ class ROCAUCScore:
class Scorer: class Scorer:
"""Compute evaluation scores.""" """Compute evaluation scores."""
BETA = 1 BETA = 1.0
def __init__( def __init__(
self, self,
@ -336,6 +336,7 @@ class Scorer:
has_annotation: Optional[Callable[[Doc], bool]] = None, has_annotation: Optional[Callable[[Doc], bool]] = None,
labeled: bool = True, labeled: bool = True,
allow_overlap: bool = False, allow_overlap: bool = False,
beta: float = 1.0,
**cfg, **cfg,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Returns PRF scores for labeled spans. """Returns PRF scores for labeled spans.
@ -353,12 +354,12 @@ class Scorer:
equal if their start and end match, irrespective of their label. equal if their start and end match, irrespective of their label.
allow_overlap (bool): Whether or not to allow overlapping spans. allow_overlap (bool): Whether or not to allow overlapping spans.
If set to 'False', the alignment will automatically resolve conflicts. If set to 'False', the alignment will automatically resolve conflicts.
beta (float): Beta coefficient for F-score calculation. Defaults to 1.0.
RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under
the keys attr_p/r/f and the per-type PRF scores under attr_per_type. the keys attr_p/r/f and the per-type PRF scores under attr_per_type.
DOCS: https://spacy.io/api/scorer#score_spans DOCS: https://spacy.io/api/scorer#score_spans
""" """
beta = cfg.get("beta", Scorer.BETA)
score = PRFScore(beta=beta) score = PRFScore(beta=beta)
score_per_type = dict() score_per_type = dict()
for example in examples: for example in examples:
@ -439,6 +440,7 @@ class Scorer:
multi_label: bool = True, multi_label: bool = True,
positive_label: Optional[str] = None, positive_label: Optional[str] = None,
threshold: Optional[float] = None, threshold: Optional[float] = None,
beta: float = 1.0,
**cfg, **cfg,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Returns PRF and ROC AUC scores for a doc-level attribute with a """Returns PRF and ROC AUC scores for a doc-level attribute with a
@ -458,6 +460,7 @@ class Scorer:
threshold (float): Cutoff to consider a prediction "positive". Defaults threshold (float): Cutoff to consider a prediction "positive". Defaults
to 0.5 for multi-label, and 0.0 (i.e. whatever's highest scoring) to 0.5 for multi-label, and 0.0 (i.e. whatever's highest scoring)
otherwise. otherwise.
beta (float): Beta coefficient for F-score calculation.
RETURNS (Dict[str, Any]): A dictionary containing the scores, with RETURNS (Dict[str, Any]): A dictionary containing the scores, with
inapplicable scores as None: inapplicable scores as None:
for all: for all:
@ -475,7 +478,6 @@ class Scorer:
DOCS: https://spacy.io/api/scorer#score_cats DOCS: https://spacy.io/api/scorer#score_cats
""" """
beta = cfg.get("beta", Scorer.BETA)
if threshold is None: if threshold is None:
threshold = 0.5 if multi_label else 0.0 threshold = 0.5 if multi_label else 0.0
f_per_type = {label: PRFScore(beta=beta) for label in labels} f_per_type = {label: PRFScore(beta=beta) for label in labels}

View File

@ -928,12 +928,12 @@ def test_cli_find_threshold(capsys):
pipe_name="tc_multi", pipe_name="tc_multi",
threshold_key="threshold", threshold_key="threshold",
scores_key="cats_macro_f", scores_key="cats_macro_f",
silent=True, silent=False,
)[0] )[0]
== numpy.linspace(0, 1, 10)[1] == numpy.linspace(0, 1, 10)[1]
) )
# Specifying name of non-MultiLabel_TextCategorizer component should fail. # Test with spancat.
nlp, _ = init_nlp((("spancat", {}),)) nlp, _ = init_nlp((("spancat", {}),))
with make_tempdir() as nlp_dir: with make_tempdir() as nlp_dir:
nlp.to_disk(nlp_dir) nlp.to_disk(nlp_dir)