diff --git a/spacy/cli/find_threshold.py b/spacy/cli/find_threshold.py index 1ffc04bbd..6f8bb68b8 100644 --- a/spacy/cli/find_threshold.py +++ b/spacy/cli/find_threshold.py @@ -1,10 +1,11 @@ from pathlib import Path import logging -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Dict, cast import numpy import wasabi.tables +from pipeline import Pipe from ._util import app, Arg, Opt from .. import util from ..pipeline import MultiLabel_TextCategorizer @@ -60,10 +61,10 @@ def find_threshold( model_path: Union[str, Path], doc_path: Union[str, Path], *, - average: str = _DEFAULTS["average"], - pipe_name: Optional[str] = _DEFAULTS["pipe_name"], - n_trials: int = _DEFAULTS["n_trials"], - beta: float = _DEFAULTS["beta"], + average: str = _DEFAULTS["average"], # type: ignore + pipe_name: Optional[str] = _DEFAULTS["pipe_name"], # type: ignore + n_trials: int = _DEFAULTS["n_trials"], # type: ignore + beta: float = _DEFAULTS["beta"], # type: ignore verbose: bool = True, ) -> Tuple[float, float]: """ @@ -80,7 +81,7 @@ def find_threshold( """ nlp = util.load_model(model_path) - pipe: Optional[MultiLabel_TextCategorizer] = None + pipe: Optional[Pipe] = None selected_pipe_name: Optional[str] = pipe_name if average not in ("micro", "macro"): @@ -99,7 +100,6 @@ def find_threshold( exits=1, ) pipe = _pipe - print(pipe_name, _pipe_name, pipe.labels) break elif pipe_name is None: if isinstance(_pipe, MultiLabel_TextCategorizer): @@ -121,6 +121,8 @@ def find_threshold( "No component of type `MultiLabel_TextCategorizer` found in pipeline.", exits=1, ) + # This is purely for MyPy. Type checking is done in loop above already. + assert isinstance(pipe, MultiLabel_TextCategorizer) if verbose: print( @@ -134,8 +136,8 @@ def find_threshold( t: {True: ref_pos_counts.copy(), False: ref_pos_counts.copy()} for t in thresholds } - f_scores_per_label = {t: ref_pos_counts.copy() for t in thresholds} - f_scores = {t: 0 for t in thresholds} + f_scores_per_label = {t: {label: 0.0 for label in pipe.labels} for t in thresholds} + f_scores = {t: 0.0 for t in thresholds} # Count true/false positives for provided docs. doc_bin = DocBin() @@ -196,7 +198,7 @@ def find_threshold( [f_scores_per_label[threshold][label] for label in ref_pos_counts] ) / len(ref_pos_counts) - best_threshold = max(f_scores, key=f_scores.get) + best_threshold = max(f_scores.keys(), key=(lambda key: f_scores[key])) if verbose: print( f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}.", diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index 8b0bab5b6..48cc364f0 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -876,7 +876,7 @@ def test_cli_find_threshold(capsys): ] def init_nlp( - component_factory_names: Tuple[str] = (), + component_factory_names: Tuple[str, ...] = (), ) -> Tuple[Language, List[Example]]: _nlp = English()