Reload pipeline on threshold change. Adjust tests. Remove confection reference.

This commit is contained in:
Raphael Mitsch 2022-09-05 11:39:29 +02:00
parent 73432c6bfb
commit 20c4a0d613
2 changed files with 53 additions and 28 deletions

View File

@ -6,7 +6,6 @@ from typing import Optional, Tuple, Any, Dict, List
import numpy
import wasabi.tables
from confection import Config
from ..pipeline import TrainablePipe, Pipe
from ..errors import Errors
@ -87,7 +86,7 @@ def find_threshold(
use_gpu: int = _DEFAULTS["use_gpu"], # type: ignore
gold_preproc: bool = _DEFAULTS["gold_preproc"], # type: ignore
silent: bool = True,
) -> Tuple[float, float]:
) -> Tuple[float, float, Dict[float, float]]:
"""
Runs prediction trials for `textcat` models with varying tresholds to maximize the specified metric.
model (Union[str, Path]): Path to file with trained model.
@ -102,7 +101,8 @@ def find_threshold(
tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due
to train/test skew.
silent (bool): Whether to print non-error-related output to stdout.
RETURNS (Tuple[float, float]): Best found threshold with corresponding F-score.
RETURNS (Tuple[float, float, Dict[float, float]]): Best found threshold, the corresponding score, scores for all
evaluated thresholds.
"""
setup_gpu(use_gpu, silent=silent)
@ -138,18 +138,39 @@ def find_threshold(
) -> Dict[str, Any]:
"""Set item in nested dictionary. Adapated from https://stackoverflow.com/a/54138200.
config (Dict[str, Any]): Configuration dictionary.
keys (List[Any]):
keys (List[Any]): Path to value to set.
value (float): Value to set.
RETURNS (Dict[str, Any]): Updated dictionary.
"""
functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value
return config
def filter_config(config: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
"""Filters provided config dictionary so that only the specified keys path remains.
config (Dict[str, Any]): Configuration dictionary.
keys (List[Any]): Path to value to set.
RETURNS (Dict[str, Any]): Filtered dictionary.
"""
return {
keys[0]: filter_config(config[keys[0]], keys[1:])
if len(keys) > 1
else config[keys[0]]
}
# Evaluate with varying threshold values.
scores: Dict[float, float] = {}
config_keys_full = ["components", pipe_name, *config_keys]
for threshold in numpy.linspace(0, 1, n_trials):
pipe.cfg = set_nested_item(pipe.cfg, config_keys, threshold)
nlp._pipe_configs[pipe_name] = Config(pipe.cfg)
# Reload pipeline with overrides specifying the new threshold.
nlp = util.load_model(
model,
config=set_nested_item(
filter_config(nlp.config, config_keys_full).copy(),
config_keys_full,
threshold,
),
)
nlp.get_pipe(pipe_name).cfg = set_nested_item(pipe.cfg, config_keys, threshold)
scores[threshold] = nlp.evaluate(dev_dataset)[scores_key]
if not (
isinstance(scores[threshold], float) or isinstance(scores[threshold], int)
@ -170,4 +191,4 @@ def find_threshold(
),
)
return best_threshold, scores[best_threshold]
return best_threshold, scores[best_threshold], scores

View File

@ -36,7 +36,7 @@ from spacy.tokens.span import Span
from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
from spacy.training.converters import iob_to_docs
from spacy.pipeline import TextCategorizer, Pipe, SpanCategorizer
from spacy.pipeline import TextCategorizer
from spacy.util import ENV_VARS, get_minor_version, load_model_from_config, load_config
from ..cli.init_pipeline import _init_labels
@ -860,6 +860,8 @@ def test_span_length_freq_dist_output_must_be_correct():
def test_cli_find_threshold(capsys):
thresholds = numpy.linspace(0, 1, 10)
def make_examples(_nlp: Language) -> List[Example]:
docs: List[Example] = []
@ -921,33 +923,35 @@ def test_cli_find_threshold(capsys):
)
with make_tempdir() as nlp_dir:
nlp.to_disk(nlp_dir)
assert (
find_threshold(
res = find_threshold(
model=nlp_dir,
data_path=docs_dir / "docs.spacy",
pipe_name="tc_multi",
threshold_key="threshold",
scores_key="cats_macro_f",
silent=False,
)[0]
== numpy.linspace(0, 1, 10)[1]
)
assert res[0] != thresholds[0]
assert thresholds[0] < res[0] < thresholds[9]
assert res[1] == 1.0
assert res[2][1.0] == 0.0
# Test with spancat.
nlp, _ = init_nlp((("spancat", {}),))
with make_tempdir() as nlp_dir:
nlp.to_disk(nlp_dir)
assert (
find_threshold(
res = find_threshold(
model=nlp_dir,
data_path=docs_dir / "docs.spacy",
pipe_name="spancat",
threshold_key="threshold",
scores_key="spans_sc_f",
silent=True,
)[0]
== numpy.linspace(0, 1, 10)[1]
)
assert res[0] != thresholds[0]
assert thresholds[0] < res[0] < thresholds[8]
assert res[1] == 1.0
assert res[2][1.0] == 0.0
# Having multiple textcat_multilabel components should work, since the name has to be specified.
nlp, _ = init_nlp((("textcat_multilabel", {}),))