diff --git a/spacy/cli/find_threshold.py b/spacy/cli/find_threshold.py index a37f276b4..32cb7a555 100644 --- a/spacy/cli/find_threshold.py +++ b/spacy/cli/find_threshold.py @@ -39,19 +39,17 @@ def find_threshold_cli( # fmt: on ): """ - Runs prediction trials for models with varying tresholds to maximize the specified metric from CLI. - model (Path): Path to file with trained model. - data_path (Path): Path to file with DocBin with docs to use for threshold search. - pipe_name (str): Name of pipe to examine thresholds for. - threshold_key (str): Key of threshold attribute in component's configuration. - scores_key (str): Metric to optimize. - n_trials (int): Number of trials to determine optimal thresholds - code_path (Optional[Path]): Path to Python file with additional code (registered functions) to be imported. - use_gpu (int): GPU ID or -1 for CPU. - gold_preproc (bool): Whether to use gold preprocessing. Gold preprocessing helps the annotations align to the - tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due - to train/test skew. - silent (bool): Display more information for debugging purposes + Runs prediction trials for models with varying thresholds to maximize the + specified metric. The search space for the threshold is traversed + linearly from 0 to 1 in n_trials steps. + + This is applicable only for components whose predictions are influenced + by thresholds (e.g. textcat_multilabel and spancat, but not textcat). + + Note that the full path to the corresponding threshold attribute in the + config has to be provided. 
+ + DOCS: https://spacy.io/api/cli#find-threshold """ util.logger.setLevel(logging.DEBUG if verbose else logging.INFO) @@ -105,7 +103,9 @@ def find_threshold( nlp = util.load_model(model) if pipe_name not in nlp.component_names: - raise AttributeError(Errors.E001.format(name=pipe_name)) + raise AttributeError( + Errors.E001.format(name=pipe_name, opts=nlp.component_names) + ) pipe = nlp.get_pipe(pipe_name) if not hasattr(pipe, "scorer"): raise AttributeError(Errors.E1045) @@ -190,7 +190,7 @@ def find_threshold( if scores_key not in eval_scores: wasabi.msg.fail( title=f"Failed to look up score `{scores_key}` in evaluation results.", - text=f"Make sure you specified the correct value for `scores_key` correctly. The following scores are " + text=f"Make sure you specified the correct value for `scores_key`. The following scores are " f"available: {eval_scores.keys()}", exits=1, ) diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index 6d45ba53b..f29568bab 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -966,7 +966,7 @@ def test_cli_find_threshold(capsys): nlp, _ = init_nlp() with make_tempdir() as nlp_dir: nlp.to_disk(nlp_dir) - with pytest.raises(SystemExit) as error: + with pytest.raises(AttributeError): find_threshold( model=nlp_dir, data_path=docs_dir / "docs.spacy", @@ -975,7 +975,6 @@ def test_cli_find_threshold(capsys): scores_key="cats_macro_f", silent=True, ) - assert error.value.code == 1 @pytest.mark.parametrize(