Reload pipeline on threshold change. Adjust tests. Remove confection reference.

This commit is contained in:
Raphael Mitsch 2022-09-05 11:39:29 +02:00
parent 73432c6bfb
commit 20c4a0d613
2 changed files with 53 additions and 28 deletions

View File

@ -6,7 +6,6 @@ from typing import Optional, Tuple, Any, Dict, List
import numpy import numpy
import wasabi.tables import wasabi.tables
from confection import Config
from ..pipeline import TrainablePipe, Pipe from ..pipeline import TrainablePipe, Pipe
from ..errors import Errors from ..errors import Errors
@ -87,7 +86,7 @@ def find_threshold(
use_gpu: int = _DEFAULTS["use_gpu"], # type: ignore use_gpu: int = _DEFAULTS["use_gpu"], # type: ignore
gold_preproc: bool = _DEFAULTS["gold_preproc"], # type: ignore gold_preproc: bool = _DEFAULTS["gold_preproc"], # type: ignore
silent: bool = True, silent: bool = True,
) -> Tuple[float, float]: ) -> Tuple[float, float, Dict[float, float]]:
""" """
Runs prediction trials for `textcat` models with varying tresholds to maximize the specified metric. Runs prediction trials for `textcat` models with varying tresholds to maximize the specified metric.
model (Union[str, Path]): Path to file with trained model. model (Union[str, Path]): Path to file with trained model.
@ -102,7 +101,8 @@ def find_threshold(
tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due
to train/test skew. to train/test skew.
silent (bool): Whether to print non-error-related output to stdout. silent (bool): Whether to print non-error-related output to stdout.
RETURNS (Tuple[float, float]): Best found threshold with corresponding F-score. RETURNS (Tuple[float, float, Dict[float, float]]): Best found threshold, the corresponding score, scores for all
evaluated thresholds.
""" """
setup_gpu(use_gpu, silent=silent) setup_gpu(use_gpu, silent=silent)
@ -138,18 +138,39 @@ def find_threshold(
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Set item in nested dictionary. Adapated from https://stackoverflow.com/a/54138200. """Set item in nested dictionary. Adapated from https://stackoverflow.com/a/54138200.
config (Dict[str, Any]): Configuration dictionary. config (Dict[str, Any]): Configuration dictionary.
keys (List[Any]): keys (List[Any]): Path to value to set.
value (float): Value to set. value (float): Value to set.
RETURNS (Dict[str, Any]): Updated dictionary. RETURNS (Dict[str, Any]): Updated dictionary.
""" """
functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value
return config return config
def filter_config(config: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
"""Filters provided config dictionary so that only the specified keys path remains.
config (Dict[str, Any]): Configuration dictionary.
keys (List[Any]): Path to value to set.
RETURNS (Dict[str, Any]): Filtered dictionary.
"""
return {
keys[0]: filter_config(config[keys[0]], keys[1:])
if len(keys) > 1
else config[keys[0]]
}
# Evaluate with varying threshold values. # Evaluate with varying threshold values.
scores: Dict[float, float] = {} scores: Dict[float, float] = {}
config_keys_full = ["components", pipe_name, *config_keys]
for threshold in numpy.linspace(0, 1, n_trials): for threshold in numpy.linspace(0, 1, n_trials):
pipe.cfg = set_nested_item(pipe.cfg, config_keys, threshold) # Reload pipeline with overrides specifying the new threshold.
nlp._pipe_configs[pipe_name] = Config(pipe.cfg) nlp = util.load_model(
model,
config=set_nested_item(
filter_config(nlp.config, config_keys_full).copy(),
config_keys_full,
threshold,
),
)
nlp.get_pipe(pipe_name).cfg = set_nested_item(pipe.cfg, config_keys, threshold)
scores[threshold] = nlp.evaluate(dev_dataset)[scores_key] scores[threshold] = nlp.evaluate(dev_dataset)[scores_key]
if not ( if not (
isinstance(scores[threshold], float) or isinstance(scores[threshold], int) isinstance(scores[threshold], float) or isinstance(scores[threshold], int)
@ -170,4 +191,4 @@ def find_threshold(
), ),
) )
return best_threshold, scores[best_threshold] return best_threshold, scores[best_threshold], scores

View File

@ -36,7 +36,7 @@ from spacy.tokens.span import Span
from spacy.training import Example, docs_to_json, offsets_to_biluo_tags from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
from spacy.training.converters import conll_ner_to_docs, conllu_to_docs from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
from spacy.training.converters import iob_to_docs from spacy.training.converters import iob_to_docs
from spacy.pipeline import TextCategorizer, Pipe, SpanCategorizer from spacy.pipeline import TextCategorizer
from spacy.util import ENV_VARS, get_minor_version, load_model_from_config, load_config from spacy.util import ENV_VARS, get_minor_version, load_model_from_config, load_config
from ..cli.init_pipeline import _init_labels from ..cli.init_pipeline import _init_labels
@ -860,6 +860,8 @@ def test_span_length_freq_dist_output_must_be_correct():
def test_cli_find_threshold(capsys): def test_cli_find_threshold(capsys):
thresholds = numpy.linspace(0, 1, 10)
def make_examples(_nlp: Language) -> List[Example]: def make_examples(_nlp: Language) -> List[Example]:
docs: List[Example] = [] docs: List[Example] = []
@ -921,33 +923,35 @@ def test_cli_find_threshold(capsys):
) )
with make_tempdir() as nlp_dir: with make_tempdir() as nlp_dir:
nlp.to_disk(nlp_dir) nlp.to_disk(nlp_dir)
assert ( res = find_threshold(
find_threshold( model=nlp_dir,
model=nlp_dir, data_path=docs_dir / "docs.spacy",
data_path=docs_dir / "docs.spacy", pipe_name="tc_multi",
pipe_name="tc_multi", threshold_key="threshold",
threshold_key="threshold", scores_key="cats_macro_f",
scores_key="cats_macro_f", silent=False,
silent=False,
)[0]
== numpy.linspace(0, 1, 10)[1]
) )
assert res[0] != thresholds[0]
assert thresholds[0] < res[0] < thresholds[9]
assert res[1] == 1.0
assert res[2][1.0] == 0.0
# Test with spancat. # Test with spancat.
nlp, _ = init_nlp((("spancat", {}),)) nlp, _ = init_nlp((("spancat", {}),))
with make_tempdir() as nlp_dir: with make_tempdir() as nlp_dir:
nlp.to_disk(nlp_dir) nlp.to_disk(nlp_dir)
assert ( res = find_threshold(
find_threshold( model=nlp_dir,
model=nlp_dir, data_path=docs_dir / "docs.spacy",
data_path=docs_dir / "docs.spacy", pipe_name="spancat",
pipe_name="spancat", threshold_key="threshold",
threshold_key="threshold", scores_key="spans_sc_f",
scores_key="spans_sc_f", silent=True,
silent=True,
)[0]
== numpy.linspace(0, 1, 10)[1]
) )
assert res[0] != thresholds[0]
assert thresholds[0] < res[0] < thresholds[8]
assert res[1] == 1.0
assert res[2][1.0] == 0.0
# Having multiple textcat_multilabel components should work, since the name has to be specified. # Having multiple textcat_multilabel components should work, since the name has to be specified.
nlp, _ = init_nlp((("textcat_multilabel", {}),)) nlp, _ = init_nlp((("textcat_multilabel", {}),))