diff --git a/spacy/cli/find_threshold.py b/spacy/cli/find_threshold.py index 082fa3b82..0b8e6fbdb 100644 --- a/spacy/cli/find_threshold.py +++ b/spacy/cli/find_threshold.py @@ -7,7 +7,6 @@ from typing import Optional, Tuple, Any, Dict, List import numpy import wasabi.tables -from ..pipeline import TrainablePipe, Pipe from ..errors import Errors from ..training import Corpus from ._util import app, Arg, Opt, import_code, setup_gpu @@ -111,14 +110,11 @@ def find_threshold( wasabi.msg.fail("Evaluation data not found", data_path, exits=1) nlp = util.load_model(model) - pipe: Optional[Pipe] = None + pipe: Optional[Any] = None try: pipe = nlp.get_pipe(pipe_name) except KeyError as err: wasabi.msg.fail(title=str(err), exits=1) - - if not isinstance(pipe, TrainablePipe): - raise TypeError(Errors.E1044) if not hasattr(pipe, "scorer"): raise AttributeError(Errors.E1045) @@ -170,10 +166,16 @@ def find_threshold( threshold, ), ) - nlp.get_pipe(pipe_name).cfg = set_nested_item(pipe.cfg, config_keys, threshold) + if hasattr(pipe, "cfg"): + setattr( + nlp.get_pipe(pipe_name), + "cfg", + set_nested_item(getattr(pipe, "cfg"), config_keys, threshold), + ) + scores[threshold] = nlp.evaluate(dev_dataset)[scores_key] - if not ( - isinstance(scores[threshold], float) or isinstance(scores[threshold], int) + if not isinstance(scores[threshold], float) and not isinstance( + scores[threshold], int ): wasabi.msg.fail( f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric "