mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-08 14:14:57 +03:00
Make beta a component scorer setting.
This commit is contained in:
parent
ea9737a664
commit
110850f095
|
@ -1,4 +1,5 @@
|
|||
import functools
|
||||
from functools import partial
|
||||
import operator
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
@ -115,8 +116,12 @@ def find_threshold(
|
|||
pipe = nlp.get_pipe(pipe_name)
|
||||
except KeyError as err:
|
||||
wasabi.msg.fail(title=str(err), exits=1)
|
||||
|
||||
if not isinstance(pipe, TrainablePipe):
|
||||
raise TypeError(Errors.E1044)
|
||||
if not hasattr(pipe, "scorer"):
|
||||
raise AttributeError(Errors.E1045)
|
||||
setattr(pipe, "scorer", partial(pipe.scorer.func, beta=beta))
|
||||
|
||||
if not silent:
|
||||
wasabi.msg.info(
|
||||
|
@ -145,9 +150,7 @@ def find_threshold(
|
|||
scores: Dict[float, float] = {}
|
||||
for threshold in numpy.linspace(0, 1, n_trials):
|
||||
pipe.cfg = set_nested_item(pipe.cfg, config_keys, threshold)
|
||||
scores[threshold] = nlp.evaluate(dev_dataset, scorer_cfg={"beta": beta})[
|
||||
scores_key
|
||||
]
|
||||
scores[threshold] = nlp.evaluate(dev_dataset)[scores_key]
|
||||
if not (
|
||||
isinstance(scores[threshold], float) or isinstance(scores[threshold], int)
|
||||
):
|
||||
|
|
|
@ -939,7 +939,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
"`{arg2}`={arg2_values} but these arguments are conflicting.")
|
||||
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
|
||||
"{value}.")
|
||||
E1044 = ("Only components of type `TrainablePipe` are supported by `find_threshold()`.")
|
||||
E1044 = ("`find_threshold()` only supports components of type `TrainablePipe`.")
|
||||
E1045 = ("`find_threshold()` only supports components with a `scorer` attribute.")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from functools import partial
|
||||
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
|
||||
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
|
||||
from thinc.api import Optimizer
|
||||
|
@ -165,8 +166,8 @@ def spancat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
|||
|
||||
|
||||
@registry.scorers("spacy.spancat_scorer.v1")
|
||||
def make_spancat_scorer():
|
||||
return spancat_score
|
||||
def make_spancat_scorer(beta: float = 1.0):
|
||||
return partial(spancat_score, beta=beta)
|
||||
|
||||
|
||||
class SpanCategorizer(TrainablePipe):
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from functools import partial
|
||||
from typing import Iterable, Optional, Dict, List, Callable, Any
|
||||
from thinc.types import Floats2d
|
||||
from thinc.api import Model, Config
|
||||
|
@ -121,8 +122,8 @@ def textcat_multilabel_score(examples: Iterable[Example], **kwargs) -> Dict[str,
|
|||
|
||||
|
||||
@registry.scorers("spacy.textcat_multilabel_scorer.v1")
|
||||
def make_textcat_multilabel_scorer():
|
||||
return textcat_multilabel_score
|
||||
def make_textcat_multilabel_scorer(beta: float = 1.0):
|
||||
return partial(textcat_multilabel_score, beta=beta)
|
||||
|
||||
|
||||
class MultiLabel_TextCategorizer(TextCategorizer):
|
||||
|
|
|
@ -102,7 +102,7 @@ class ROCAUCScore:
|
|||
class Scorer:
|
||||
"""Compute evaluation scores."""
|
||||
|
||||
BETA = 1
|
||||
BETA = 1.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -336,6 +336,7 @@ class Scorer:
|
|||
has_annotation: Optional[Callable[[Doc], bool]] = None,
|
||||
labeled: bool = True,
|
||||
allow_overlap: bool = False,
|
||||
beta: float = 1.0,
|
||||
**cfg,
|
||||
) -> Dict[str, Any]:
|
||||
"""Returns PRF scores for labeled spans.
|
||||
|
@ -353,12 +354,12 @@ class Scorer:
|
|||
equal if their start and end match, irrespective of their label.
|
||||
allow_overlap (bool): Whether or not to allow overlapping spans.
|
||||
If set to 'False', the alignment will automatically resolve conflicts.
|
||||
beta (float): Beta coefficient for F-score calculation. Defaults to 1.0.
|
||||
RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under
|
||||
the keys attr_p/r/f and the per-type PRF scores under attr_per_type.
|
||||
|
||||
DOCS: https://spacy.io/api/scorer#score_spans
|
||||
"""
|
||||
beta = cfg.get("beta", Scorer.BETA)
|
||||
score = PRFScore(beta=beta)
|
||||
score_per_type = dict()
|
||||
for example in examples:
|
||||
|
@ -439,6 +440,7 @@ class Scorer:
|
|||
multi_label: bool = True,
|
||||
positive_label: Optional[str] = None,
|
||||
threshold: Optional[float] = None,
|
||||
beta: float = 1.0,
|
||||
**cfg,
|
||||
) -> Dict[str, Any]:
|
||||
"""Returns PRF and ROC AUC scores for a doc-level attribute with a
|
||||
|
@ -458,6 +460,7 @@ class Scorer:
|
|||
threshold (float): Cutoff to consider a prediction "positive". Defaults
|
||||
to 0.5 for multi-label, and 0.0 (i.e. whatever's highest scoring)
|
||||
otherwise.
|
||||
beta (float): Beta coefficient for F-score calculation. Defaults to 1.0.
|
||||
RETURNS (Dict[str, Any]): A dictionary containing the scores, with
|
||||
inapplicable scores as None:
|
||||
for all:
|
||||
|
@ -475,7 +478,6 @@ class Scorer:
|
|||
|
||||
DOCS: https://spacy.io/api/scorer#score_cats
|
||||
"""
|
||||
beta = cfg.get("beta", Scorer.BETA)
|
||||
if threshold is None:
|
||||
threshold = 0.5 if multi_label else 0.0
|
||||
f_per_type = {label: PRFScore(beta=beta) for label in labels}
|
||||
|
|
|
@ -928,12 +928,12 @@ def test_cli_find_threshold(capsys):
|
|||
pipe_name="tc_multi",
|
||||
threshold_key="threshold",
|
||||
scores_key="cats_macro_f",
|
||||
silent=True,
|
||||
silent=False,
|
||||
)[0]
|
||||
== numpy.linspace(0, 1, 10)[1]
|
||||
)
|
||||
|
||||
# Specifying name of non-MultiLabel_TextCategorizer component should fail.
|
||||
# Test with spancat.
|
||||
nlp, _ = init_nlp((("spancat", {}),))
|
||||
with make_tempdir() as nlp_dir:
|
||||
nlp.to_disk(nlp_dir)
|
||||
|
|
Loading…
Reference in New Issue
Block a user