diff --git a/spacy/cli/find_threshold.py b/spacy/cli/find_threshold.py index a451b7d4a..082fa3b82 100644 --- a/spacy/cli/find_threshold.py +++ b/spacy/cli/find_threshold.py @@ -6,7 +6,6 @@ from typing import Optional, Tuple, Any, Dict, List import numpy import wasabi.tables -from confection import Config from ..pipeline import TrainablePipe, Pipe from ..errors import Errors @@ -87,7 +86,7 @@ def find_threshold( use_gpu: int = _DEFAULTS["use_gpu"], # type: ignore gold_preproc: bool = _DEFAULTS["gold_preproc"], # type: ignore silent: bool = True, -) -> Tuple[float, float]: +) -> Tuple[float, float, Dict[float, float]]: """ Runs prediction trials for `textcat` models with varying tresholds to maximize the specified metric. model (Union[str, Path]): Path to file with trained model. @@ -102,7 +101,8 @@ def find_threshold( tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due to train/test skew. silent (bool): Whether to print non-error-related output to stdout. - RETURNS (Tuple[float, float]): Best found threshold with corresponding F-score. + RETURNS (Tuple[float, float, Dict[float, float]]): Best found threshold, the corresponding score, scores for all + evaluated thresholds. """ setup_gpu(use_gpu, silent=silent) @@ -138,18 +138,39 @@ def find_threshold( ) -> Dict[str, Any]: """Set item in nested dictionary. Adapated from https://stackoverflow.com/a/54138200. config (Dict[str, Any]): Configuration dictionary. - keys (List[Any]): + keys (List[Any]): Path to value to set. value (float): Value to set. RETURNS (Dict[str, Any]): Updated dictionary. """ functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value return config + def filter_config(config: Dict[str, Any], keys: List[str]) -> Dict[str, Any]: + """Filters provided config dictionary so that only the specified keys path remains. + config (Dict[str, Any]): Configuration dictionary. + keys (List[str]): Path of nested keys to retain.
+ RETURNS (Dict[str, Any]): Filtered dictionary. + """ + return { + keys[0]: filter_config(config[keys[0]], keys[1:]) + if len(keys) > 1 + else config[keys[0]] + } + # Evaluate with varying threshold values. scores: Dict[float, float] = {} + config_keys_full = ["components", pipe_name, *config_keys] for threshold in numpy.linspace(0, 1, n_trials): - pipe.cfg = set_nested_item(pipe.cfg, config_keys, threshold) - nlp._pipe_configs[pipe_name] = Config(pipe.cfg) + # Reload pipeline with overrides specifying the new threshold. + nlp = util.load_model( + model, + config=set_nested_item( + filter_config(nlp.config, config_keys_full).copy(), + config_keys_full, + threshold, + ), + ) + nlp.get_pipe(pipe_name).cfg = set_nested_item(pipe.cfg, config_keys, threshold) scores[threshold] = nlp.evaluate(dev_dataset)[scores_key] if not ( isinstance(scores[threshold], float) or isinstance(scores[threshold], int) @@ -170,4 +191,4 @@ def find_threshold( ), ) - return best_threshold, scores[best_threshold] + return best_threshold, scores[best_threshold], scores diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index 733c7c876..5aa4fe43b 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -36,7 +36,7 @@ from spacy.tokens.span import Span from spacy.training import Example, docs_to_json, offsets_to_biluo_tags from spacy.training.converters import conll_ner_to_docs, conllu_to_docs from spacy.training.converters import iob_to_docs -from spacy.pipeline import TextCategorizer, Pipe, SpanCategorizer +from spacy.pipeline import TextCategorizer from spacy.util import ENV_VARS, get_minor_version, load_model_from_config, load_config from ..cli.init_pipeline import _init_labels @@ -860,6 +860,8 @@ def test_span_length_freq_dist_output_must_be_correct(): def test_cli_find_threshold(capsys): + thresholds = numpy.linspace(0, 1, 10) + def make_examples(_nlp: Language) -> List[Example]: docs: List[Example] = [] @@ -921,33 +923,35 @@ def test_cli_find_threshold(capsys): ) 
with make_tempdir() as nlp_dir: nlp.to_disk(nlp_dir) - assert ( - find_threshold( - model=nlp_dir, - data_path=docs_dir / "docs.spacy", - pipe_name="tc_multi", - threshold_key="threshold", - scores_key="cats_macro_f", - silent=False, - )[0] - == numpy.linspace(0, 1, 10)[1] + res = find_threshold( + model=nlp_dir, + data_path=docs_dir / "docs.spacy", + pipe_name="tc_multi", + threshold_key="threshold", + scores_key="cats_macro_f", + silent=False, ) + assert res[0] != thresholds[0] + assert thresholds[0] < res[0] < thresholds[9] + assert res[1] == 1.0 + assert res[2][1.0] == 0.0 # Test with spancat. nlp, _ = init_nlp((("spancat", {}),)) with make_tempdir() as nlp_dir: nlp.to_disk(nlp_dir) - assert ( - find_threshold( - model=nlp_dir, - data_path=docs_dir / "docs.spacy", - pipe_name="spancat", - threshold_key="threshold", - scores_key="spans_sc_f", - silent=True, - )[0] - == numpy.linspace(0, 1, 10)[1] + res = find_threshold( + model=nlp_dir, + data_path=docs_dir / "docs.spacy", + pipe_name="spancat", + threshold_key="threshold", + scores_key="spans_sc_f", + silent=True, ) + assert res[0] != thresholds[0] + assert thresholds[0] < res[0] < thresholds[8] + assert res[1] == 1.0 + assert res[2][1.0] == 0.0 # Having multiple textcat_multilabel components should work, since the name has to be specified. nlp, _ = init_nlp((("textcat_multilabel", {}),))