diff --git a/spacy/cli/find_threshold.py b/spacy/cli/find_threshold.py index bfb15f39a..a37f276b4 100644 --- a/spacy/cli/find_threshold.py +++ b/spacy/cli/find_threshold.py @@ -30,7 +30,7 @@ def find_threshold_cli( data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True), pipe_name: str = Arg(..., help="Name of pipe to examine thresholds for"), threshold_key: str = Arg(..., help="Key of threshold attribute in component's configuration"), - scores_key: str = Arg(..., help="Name of score to metric to optimize"), + scores_key: str = Arg(..., help="Metric to optimize"), n_trials: int = Opt(_DEFAULTS["n_trials"], "--n_trials", "-n", help="Number of trials to determine optimal thresholds"), code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"), use_gpu: int = Opt(_DEFAULTS["use_gpu"], "--gpu-id", "-g", help="GPU ID or -1 for CPU"), @@ -39,12 +39,12 @@ def find_threshold_cli( # fmt: on ): """ - Runs prediction trials for `textcat` models with varying tresholds to maximize the specified metric from CLI. + Runs prediction trials for models with varying tresholds to maximize the specified metric from CLI. model (Path): Path to file with trained model. data_path (Path): Path to file with DocBin with docs to use for threshold search. pipe_name (str): Name of pipe to examine thresholds for. threshold_key (str): Key of threshold attribute in component's configuration. - scores_key (str): Name of score to metric to optimize. + scores_key (str): Metric to optimize. n_trials (int): Number of trials to determine optimal thresholds code_path (Optional[Path]): Path to Python file with additional code (registered functions) to be imported. use_gpu (int): GPU ID or -1 for CPU. @@ -82,7 +82,7 @@ def find_threshold( silent: bool = True, ) -> Tuple[float, float, Dict[float, float]]: """ - Runs prediction trials for `textcat` models with varying tresholds to maximize the specified metric. + Runs prediction trials for models with varying tresholds to maximize the specified metric. model (Union[str, Path]): Path to file with trained model. data_path (Union[str, Path]): Path to file with DocBin with docs to use for threshold search. pipe_name (str): Name of pipe to examine thresholds for. @@ -104,11 +104,9 @@ def find_threshold( wasabi.msg.fail("Evaluation data not found", data_path, exits=1) nlp = util.load_model(model) - pipe: Optional[Any] = None - try: - pipe = nlp.get_pipe(pipe_name) - except KeyError as err: - wasabi.msg.fail(title=str(err), exits=1) + if pipe_name not in nlp.component_names: + raise AttributeError(Errors.E001.format(name=pipe_name)) + pipe = nlp.get_pipe(pipe_name) if not hasattr(pipe, "scorer"): raise AttributeError(Errors.E1045) @@ -141,14 +139,24 @@ def find_threshold( functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value return config - def filter_config(config: Dict[str, Any], keys: List[str]) -> Dict[str, Any]: + def filter_config( + config: Dict[str, Any], keys: List[str], full_key: str + ) -> Dict[str, Any]: """Filters provided config dictionary so that only the specified keys path remains. config (Dict[str, Any]): Configuration dictionary. keys (List[Any]): Path to value to set. + full_key (str): Full user-specified key. RETURNS (Dict[str, Any]): Filtered dictionary. """ + if keys[0] not in config: + wasabi.msg.fail( + title=f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.", + text=f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: " + f"{config.keys()}", + exits=1, + ) return { - keys[0]: filter_config(config[keys[0]], keys[1:]) + keys[0]: filter_config(config[keys[0]], keys[1:], full_key) if len(keys) > 1 else config[keys[0]] } @@ -164,7 +172,9 @@ def find_threshold( nlp = util.load_model( model, config=set_nested_item( - filter_config(nlp.config, config_keys_full).copy(), + filter_config( + nlp.config, config_keys_full, ".".join(config_keys_full) + ).copy(), config_keys_full, threshold, ), @@ -176,7 +186,16 @@ def find_threshold( set_nested_item(getattr(pipe, "cfg"), config_keys, threshold), ) - scores[threshold] = nlp.evaluate(dev_dataset)[scores_key] + eval_scores = nlp.evaluate(dev_dataset) + if scores_key not in eval_scores: + wasabi.msg.fail( + title=f"Failed to look up score `{scores_key}` in evaluation results.", + text=f"Make sure you specified the correct value for `scores_key` correctly. The following scores are " + f"available: {eval_scores.keys()}", + exits=1, + ) + scores[threshold] = eval_scores[scores_key] + if not isinstance(scores[threshold], (float, int)): wasabi.msg.fail( f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric "