diff --git a/spacy/cli/find_threshold.py b/spacy/cli/find_threshold.py
index eb072817b..1ffc04bbd 100644
--- a/spacy/cli/find_threshold.py
+++ b/spacy/cli/find_threshold.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 import logging
-from typing import Optional
+from typing import Optional, Tuple, Union
 import numpy
 import wasabi.tables
 
@@ -57,23 +57,26 @@ def find_threshold_cli(
 
 
 def find_threshold(
-    model_path: Path,
-    doc_path: Path,
+    model_path: Union[str, Path],
+    doc_path: Union[str, Path],
     *,
     average: str = _DEFAULTS["average"],
    pipe_name: Optional[str] = _DEFAULTS["pipe_name"],
     n_trials: int = _DEFAULTS["n_trials"],
     beta: float = _DEFAULTS["beta"],
-) -> None:
+    verbose: bool = True,
+) -> Tuple[float, float]:
     """
-    Runs prediction trials for `textcat` models with varying tresholds to maximize the specified metric.
-    model_path (Path): Path to file with trained model.
-    doc_path (Path): Path to file with DocBin with docs to use for threshold search.
+    Runs prediction trials for `textcat` models with varying thresholds to maximize the specified metric.
+    model_path (Union[str, Path]): Path to file with trained model.
+    doc_path (Union[str, Path]): Path to file with DocBin with docs to use for threshold search.
     average (str): How to average F-scores across labels. One of ('micro', 'macro').
     pipe_name (Optional[str]): Name of pipe to examine thresholds for. If None, pipe of type
-        MultiLabel_TextCategorizer is seleted. If there are multiple, an error is raised.
-    n_trials (int): Number of trials to determine optimal thresholds
+        MultiLabel_TextCategorizer is selected. If there are multiple, an error is raised.
+    n_trials (int): Number of trials to determine optimal thresholds.
     beta (float): Beta for F1 calculation. Ignored if different metric is used.
+    verbose (bool): Whether to print non-error-related output to stdout.
+    RETURNS (Tuple[float, float]): Best found threshold with corresponding F-score.
     """
 
     nlp = util.load_model(model_path)
@@ -90,10 +93,12 @@
         if pipe_name and _pipe_name == pipe_name:
             if not isinstance(_pipe, MultiLabel_TextCategorizer):
                 wasabi.msg.fail(
-                    "Specified component {component} is not of type `MultiLabel_TextCategorizer`.",
+                    "Specified component '{component}' is not of type `MultiLabel_TextCategorizer`.".format(
+                        component=pipe_name
+                    ),
                     exits=1,
                 )
             pipe = _pipe
             break
         elif pipe_name is None:
             if isinstance(_pipe, MultiLabel_TextCategorizer):
@@ -116,10 +121,11 @@
             exits=1,
         )
 
-    print(
-        f"Searching threshold with the best {average} F-score for pipe '{selected_pipe_name}' with {n_trials} trials"
-        f" and beta = {beta}."
-    )
+    if verbose:
+        print(
+            f"Searching threshold with the best {average} F-score for component '{selected_pipe_name}' with {n_trials} "
+            f"trials and beta = {beta}."
+        )
 
     thresholds = numpy.linspace(0, 1, n_trials)
     ref_pos_counts = {label: 0 for label in pipe.labels}
@@ -146,13 +152,15 @@
         # Collect count stats per threshold value and label.
         for threshold in thresholds:
             for label, score in pred_doc.cats.items():
+                if label not in pipe.labels:
+                    continue
                 label_value = int(score >= threshold)
                 if label_value == ref_doc.cats[label] == 1:
                     pred_pos_counts[threshold][True][label] += 1
                 elif label_value == 1 and ref_doc.cats[label] == 0:
                     pred_pos_counts[threshold][False][label] += 1
 
-    # Compute f_scores.
+    # Compute F-scores.
     for threshold in thresholds:
         for label in ref_pos_counts:
             n_pos_preds = (
@@ -188,20 +196,21 @@
         ) / len(ref_pos_counts)
 
     best_threshold = max(f_scores, key=f_scores.get)
-    print(
-        f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}."
-    )
-    print(
-        wasabi.tables.table(
-            data=[
-                (threshold, label, f_score)
-                for threshold, label_f_scores in f_scores_per_label.items()
-                for label, f_score in label_f_scores.items()
-            ],
-            header=["Threshold", "Label", "F-Score"],
-        ),
-        wasabi.tables.table(
-            data=[(threshold, f_score) for threshold, f_score in f_scores.items()],
-            header=["Threshold", f"F-Score ({average})"],
-        ),
-    )
+    if verbose:
+        print(
+            f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}.",
+            wasabi.tables.table(
+                data=[
+                    (threshold, label, f_score)
+                    for threshold, label_f_scores in f_scores_per_label.items()
+                    for label, f_score in label_f_scores.items()
+                ],
+                header=["Threshold", "Label", "F-Score"],
+            ),
+            wasabi.tables.table(
+                data=[(threshold, f_score) for threshold, f_score in f_scores.items()],
+                header=["Threshold", f"F-Score ({average})"],
+            ),
+        )
+
+    return best_threshold, f_scores[best_threshold]
diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py
index 838e00369..264549a69 100644
--- a/spacy/tests/test_cli.py
+++ b/spacy/tests/test_cli.py
@@ -1,8 +1,8 @@
 import os
 import math
-from random import sample
-from typing import Counter
+from typing import Counter, Tuple, List
+import numpy
 
 import pytest
 import srsly
 from click import NoSuchOption
@@ -26,19 +26,21 @@
 from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
 from spacy.cli.package import get_third_party_dependencies
 from spacy.cli.package import _is_permitted_package_name
 from spacy.cli.validate import get_model_pkgs
+from spacy.cli.find_threshold import find_threshold
 from spacy.lang.en import English
 from spacy.lang.nl import Dutch
 from spacy.language import Language
 from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
-from spacy.tokens import Doc
+from spacy.tokens import Doc, DocBin
 from spacy.tokens.span import Span
 from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
 from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
 from spacy.training.converters import iob_to_docs
+from spacy.pipeline import TextCategorizer
 from spacy.util import ENV_VARS, get_minor_version, load_model_from_config, load_config
 
 from ..cli.init_pipeline import _init_labels
 from .util import make_tempdir
 
 
@@ -855,3 +857,123 @@
     span_freqs = _get_spans_length_freq_dist(sample_span_lengths, threshold)
     assert sum(span_freqs.values()) >= threshold
     assert list(span_freqs.keys()) == [3, 1, 4, 5, 2]
+
+
+def test_cli_find_threshold(capsys):
+    def make_get_examples_multi_label(_nlp: Language) -> List[Example]:
+        return [
+            Example.from_dict(_nlp.make_doc(t[0]), t[1])
+            for t in [
+                (
+                    "I'm angry and confused",
+                    {"cats": {"ANGRY": 1.0, "CONFUSED": 1.0, "HAPPY": 0.0}},
+                ),
+                (
+                    "I'm confused but happy",
+                    {"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}},
+                ),
+            ]
+        ]
+
+    def init_nlp(
+        component_factory_names: Tuple[str, ...] = (),
+    ) -> Tuple[Language, List[Example]]:
+        _nlp = English()
+
+        textcat: TextCategorizer = _nlp.add_pipe(factory_name="textcat_multilabel", name="tc_multi")  # type: ignore
+        textcat.add_label("ANGRY")
+        textcat.add_label("CONFUSED")
+        textcat.add_label("HAPPY")
+        for cfn in component_factory_names:
+            comp = _nlp.add_pipe(cfn)
+            if isinstance(comp, TextCategorizer):
+                comp.add_label("dummy")
+
+        _nlp.initialize()
+
+        _examples = make_get_examples_multi_label(_nlp)
+        for _ in range(5):
+            _nlp.update(_examples)
+
+        return _nlp, _examples
+
+    with make_tempdir() as docs_dir:
+        # Check whether find_threshold() identifies the lowest threshold above 0 as the (first) ideal threshold, as
+        # this matches the current model behavior with the examples above. This may break once the model behavior
+        # changes and serves mostly as a smoke test.
+        nlp, examples = init_nlp()
+        DocBin(docs=[example.reference for example in examples]).to_disk(
+            docs_dir / "docs"
+        )
+        with make_tempdir() as nlp_dir:
+            nlp.to_disk(nlp_dir)
+            assert (
+                find_threshold(nlp_dir, docs_dir / "docs", verbose=False)[0]
+                == numpy.linspace(0, 1, 10)[1]
+            )
+
+        # Specifying the name of a non-MultiLabel_TextCategorizer component should fail.
+        nlp, _ = init_nlp(("sentencizer",))
+        with make_tempdir() as nlp_dir:
+            nlp.to_disk(nlp_dir)
+            with pytest.raises(SystemExit) as error:
+                find_threshold(nlp_dir, docs_dir / "docs", pipe_name="sentencizer")
+            assert error.value.code == 1
+
+        # Having multiple textcat_multilabel components without specifying the name should fail.
+        nlp, _ = init_nlp(("textcat_multilabel",))
+        with make_tempdir() as nlp_dir:
+            nlp.to_disk(nlp_dir)
+            with pytest.raises(SystemExit) as error:
+                find_threshold(nlp_dir, docs_dir / "docs")
+            assert error.value.code == 1
+
+        # Having multiple textcat_multilabel components should work when specifying the name.
+        nlp, _ = init_nlp(("textcat_multilabel",))
+        with make_tempdir() as nlp_dir:
+            nlp.to_disk(nlp_dir)
+            assert (
+                find_threshold(
+                    nlp_dir, docs_dir / "docs", pipe_name="tc_multi", verbose=False
+                )[0]
+                == numpy.linspace(0, 1, 10)[1]
+            )
+
+        # Specifying the name of a non-existent pipe should fail.
+        nlp, _ = init_nlp()
+        with make_tempdir() as nlp_dir:
+            nlp.to_disk(nlp_dir)
+            with pytest.raises(SystemExit) as error:
+                find_threshold(nlp_dir, docs_dir / "docs", pipe_name="_")
+            assert error.value.code == 1
+
+        # Using a pipeline with no textcat components should fail.
+        nlp = English()
+        with make_tempdir() as nlp_dir:
+            nlp.to_disk(nlp_dir)
+            with pytest.raises(SystemExit) as error:
+                find_threshold(nlp_dir, docs_dir / "docs")
+            assert error.value.code == 1
+
+        # Specifying scores not in range 0 <= x <= 1 should fail.
+        nlp, _ = init_nlp()
+        DocBin(
+            docs=[
+                Example.from_dict(nlp.make_doc(t[0]), t[1]).reference
+                for t in [
+                    (
+                        "I'm angry and confused",
+                        {"cats": {"ANGRY": 1.0, "CONFUSED": 2.0, "HAPPY": 0.0}},
+                    ),
+                    (
+                        "I'm confused but happy",
+                        {"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}},
+                    ),
+                ]
+            ]
+        ).to_disk(docs_dir / "docs")
+        with make_tempdir() as nlp_dir:
+            nlp.to_disk(nlp_dir)
+            with pytest.raises(SystemExit) as error:
+                find_threshold(nlp_dir, docs_dir / "docs")
+            assert error.value.code == 1
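
Usage note (illustrative, not part of the patch): with this change, find_threshold() returns the best threshold together with its F-score instead of None, so it can be called programmatically as well as via the CLI. A minimal sketch, assuming a trained pipeline with a textcat_multilabel component saved at ./model and a DocBin of annotated reference docs at ./dev.spacy (both paths are hypothetical placeholders):

    from spacy.cli.find_threshold import find_threshold

    # Both paths below are hypothetical placeholders; point them at a real
    # trained pipeline and a DocBin with reference categories.
    best_threshold, best_f_score = find_threshold(
        "./model",        # pipeline containing a textcat_multilabel component
        "./dev.spacy",    # DocBin used for the threshold search
        average="micro",  # one of ('micro', 'macro')
        verbose=False,    # suppress the per-threshold output tables
    )
    print(f"Best threshold: {best_threshold:.4f} (F-score: {best_f_score:.4f})")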