Mirror of https://github.com/explosion/spaCy.git, synced 2025-08-08 06:04:57 +03:00
Reload pipeline on threshold change. Adjust tests. Remove confection reference.
commit 20c4a0d613
parent 73432c6bfb
@@ -6,7 +6,6 @@ from typing import Optional, Tuple, Any, Dict, List
 
 import numpy
 import wasabi.tables
-from confection import Config
 
 from ..pipeline import TrainablePipe, Pipe
 from ..errors import Errors
@@ -87,7 +86,7 @@ def find_threshold(
     use_gpu: int = _DEFAULTS["use_gpu"],  # type: ignore
     gold_preproc: bool = _DEFAULTS["gold_preproc"],  # type: ignore
     silent: bool = True,
-) -> Tuple[float, float]:
+) -> Tuple[float, float, Dict[float, float]]:
     """
     Runs prediction trials for `textcat` models with varying tresholds to maximize the specified metric.
     model (Union[str, Path]): Path to file with trained model.
@@ -102,7 +101,8 @@ def find_threshold(
         tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due
         to train/test skew.
     silent (bool): Whether to print non-error-related output to stdout.
-    RETURNS (Tuple[float, float]): Best found threshold with corresponding F-score.
+    RETURNS (Tuple[float, float, Dict[float, float]]): Best found threshold, the corresponding score, scores for all
+        evaluated thresholds.
     """
 
     setup_gpu(use_gpu, silent=silent)
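For context, the changed return annotation means callers now unpack three values: the best threshold, its score, and the per-threshold score table. A minimal usage sketch, assuming the function is imported from spacy.cli.find_threshold as in the tests; the model path, data path, and component name below are hypothetical placeholders:

from spacy.cli.find_threshold import find_threshold

# Hypothetical paths and component name, for illustration only.
best_threshold, best_score, all_scores = find_threshold(
    model="training/model-best",
    data_path="corpus/dev.spacy",
    pipe_name="textcat_multilabel",
    threshold_key="threshold",
    scores_key="cats_macro_f",
)
print(best_threshold, best_score)
print(all_scores)  # maps each evaluated threshold to its score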
@@ -138,18 +138,39 @@ def find_threshold(
     ) -> Dict[str, Any]:
         """Set item in nested dictionary. Adapated from https://stackoverflow.com/a/54138200.
         config (Dict[str, Any]): Configuration dictionary.
-        keys (List[Any]):
+        keys (List[Any]): Path to value to set.
         value (float): Value to set.
         RETURNS (Dict[str, Any]): Updated dictionary.
         """
         functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value
         return config
 
+    def filter_config(config: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
+        """Filters provided config dictionary so that only the specified keys path remains.
+        config (Dict[str, Any]): Configuration dictionary.
+        keys (List[Any]): Path to value to set.
+        RETURNS (Dict[str, Any]): Filtered dictionary.
+        """
+        return {
+            keys[0]: filter_config(config[keys[0]], keys[1:])
+            if len(keys) > 1
+            else config[keys[0]]
+        }
+
     # Evaluate with varying threshold values.
     scores: Dict[float, float] = {}
+    config_keys_full = ["components", pipe_name, *config_keys]
     for threshold in numpy.linspace(0, 1, n_trials):
-        pipe.cfg = set_nested_item(pipe.cfg, config_keys, threshold)
-        nlp._pipe_configs[pipe_name] = Config(pipe.cfg)
+        # Reload pipeline with overrides specifying the new threshold.
+        nlp = util.load_model(
+            model,
+            config=set_nested_item(
+                filter_config(nlp.config, config_keys_full).copy(),
+                config_keys_full,
+                threshold,
+            ),
+        )
+        nlp.get_pipe(pipe_name).cfg = set_nested_item(pipe.cfg, config_keys, threshold)
         scores[threshold] = nlp.evaluate(dev_dataset)[scores_key]
         if not (
             isinstance(scores[threshold], float) or isinstance(scores[threshold], int)
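To make the two helpers above concrete, here is a standalone sketch of their behavior on a toy config dict; it mirrors the nested functions in the diff but is not the spaCy code itself, and the component name "tc_multi" is just an example:

import functools
import operator
from typing import Any, Dict, List


def filter_config(config: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
    # Keep only the path named by `keys`, recursing one level per key.
    return {
        keys[0]: filter_config(config[keys[0]], keys[1:])
        if len(keys) > 1
        else config[keys[0]]
    }


def set_nested_item(config: Dict[str, Any], keys: List[Any], value: float) -> Dict[str, Any]:
    # Walk to the parent of the last key, then assign in place.
    functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value
    return config


cfg = {"components": {"tc_multi": {"threshold": 0.5, "model": {"@architectures": "x"}}}}
keys = ["components", "tc_multi", "threshold"]
print(filter_config(cfg, keys))
# {'components': {'tc_multi': {'threshold': 0.5}}}
print(set_nested_item(filter_config(cfg, keys), keys, 0.25))
# {'components': {'tc_multi': {'threshold': 0.25}}}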
@@ -170,4 +191,4 @@ def find_threshold(
             ),
         )
 
-    return best_threshold, scores[best_threshold]
+    return best_threshold, scores[best_threshold], scores
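The trial loop now rebuilds the pipeline from disk for every threshold instead of patching the live component's cfg, presumably so the new value is picked up wherever the component reads its config at construction time. The same reload-with-overrides idea is available through spaCy's public loading API; a small sketch, where the model path and component name are placeholders:

import spacy

# Re-create the saved pipeline, overriding one nested config value.
# "training/model-best" and "textcat_multilabel" are hypothetical.
nlp = spacy.load(
    "training/model-best",
    config={"components": {"textcat_multilabel": {"threshold": 0.25}}},
)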
@@ -36,7 +36,7 @@ from spacy.tokens.span import Span
 from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
 from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
 from spacy.training.converters import iob_to_docs
-from spacy.pipeline import TextCategorizer, Pipe, SpanCategorizer
+from spacy.pipeline import TextCategorizer
 from spacy.util import ENV_VARS, get_minor_version, load_model_from_config, load_config
 
 from ..cli.init_pipeline import _init_labels
@@ -860,6 +860,8 @@ def test_span_length_freq_dist_output_must_be_correct():
 
 
 def test_cli_find_threshold(capsys):
+    thresholds = numpy.linspace(0, 1, 10)
+
     def make_examples(_nlp: Language) -> List[Example]:
         docs: List[Example] = []
 
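For reference, numpy.linspace(0, 1, 10) yields ten evenly spaced values from 0.0 to 1.0 inclusive, which is why the assertions below can treat thresholds[9] as 1.0 and use the float 1.0 as a key into the returned score dict:

import numpy

thresholds = numpy.linspace(0, 1, 10)
print(thresholds[0], thresholds[1], thresholds[9])  # 0.0 0.111... 1.0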
@@ -921,33 +923,35 @@ def test_cli_find_threshold(capsys):
     )
     with make_tempdir() as nlp_dir:
         nlp.to_disk(nlp_dir)
-        assert (
-            find_threshold(
-                model=nlp_dir,
-                data_path=docs_dir / "docs.spacy",
-                pipe_name="tc_multi",
-                threshold_key="threshold",
-                scores_key="cats_macro_f",
-                silent=False,
-            )[0]
-            == numpy.linspace(0, 1, 10)[1]
-        )
+        res = find_threshold(
+            model=nlp_dir,
+            data_path=docs_dir / "docs.spacy",
+            pipe_name="tc_multi",
+            threshold_key="threshold",
+            scores_key="cats_macro_f",
+            silent=False,
+        )
+        assert res[0] != thresholds[0]
+        assert thresholds[0] < res[0] < thresholds[9]
+        assert res[1] == 1.0
+        assert res[2][1.0] == 0.0
 
     # Test with spancat.
     nlp, _ = init_nlp((("spancat", {}),))
     with make_tempdir() as nlp_dir:
         nlp.to_disk(nlp_dir)
-        assert (
-            find_threshold(
-                model=nlp_dir,
-                data_path=docs_dir / "docs.spacy",
-                pipe_name="spancat",
-                threshold_key="threshold",
-                scores_key="spans_sc_f",
-                silent=True,
-            )[0]
-            == numpy.linspace(0, 1, 10)[1]
-        )
+        res = find_threshold(
+            model=nlp_dir,
+            data_path=docs_dir / "docs.spacy",
+            pipe_name="spancat",
+            threshold_key="threshold",
+            scores_key="spans_sc_f",
+            silent=True,
+        )
+        assert res[0] != thresholds[0]
+        assert thresholds[0] < res[0] < thresholds[8]
+        assert res[1] == 1.0
+        assert res[2][1.0] == 0.0
 
     # Having multiple textcat_multilabel components should work, since the name has to be specified.
     nlp, _ = init_nlp((("textcat_multilabel", {}),))