Fix test issue. Update docstring.

This commit is contained in:
Raphael Mitsch 2022-11-17 11:50:42 +01:00
parent d080808bf7
commit 7b4da3f36d
2 changed files with 16 additions and 17 deletions

View File

@ -39,19 +39,17 @@ def find_threshold_cli(
# fmt: on
):
"""
Runs prediction trials for models with varying tresholds to maximize the specified metric from CLI.
model (Path): Path to file with trained model.
data_path (Path): Path to file with DocBin with docs to use for threshold search.
pipe_name (str): Name of pipe to examine thresholds for.
threshold_key (str): Key of threshold attribute in component's configuration.
scores_key (str): Metric to optimize.
n_trials (int): Number of trials to determine optimal thresholds
code_path (Optional[Path]): Path to Python file with additional code (registered functions) to be imported.
use_gpu (int): GPU ID or -1 for CPU.
gold_preproc (bool): Whether to use gold preprocessing. Gold preprocessing helps the annotations align to the
tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due
to train/test skew.
silent (bool): Display more information for debugging purposes
Runs prediction trials for models with varying tresholds to maximize the
specified metric. The search space for the threshold is traversed
linearly from 0 to 1 in n_trials steps.
This is applicable only for components whose predictions are influenced
by thresholds (e.g. textcat_multilabel and spancat, but not textcat).
Note that the full path to the corresponding threshold attribute in the
config has to be provided.
DOCS: https://spacy.io/api/cli#find-threshold
"""
util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
@ -105,7 +103,9 @@ def find_threshold(
nlp = util.load_model(model)
if pipe_name not in nlp.component_names:
raise AttributeError(Errors.E001.format(name=pipe_name))
raise AttributeError(
Errors.E001.format(name=pipe_name, opts=nlp.component_names)
)
pipe = nlp.get_pipe(pipe_name)
if not hasattr(pipe, "scorer"):
raise AttributeError(Errors.E1045)
@ -190,7 +190,7 @@ def find_threshold(
if scores_key not in eval_scores:
wasabi.msg.fail(
title=f"Failed to look up score `{scores_key}` in evaluation results.",
text=f"Make sure you specified the correct value for `scores_key` correctly. The following scores are "
text=f"Make sure you specified the correct value for `scores_key`. The following scores are "
f"available: {eval_scores.keys()}",
exits=1,
)

View File

@ -966,7 +966,7 @@ def test_cli_find_threshold(capsys):
nlp, _ = init_nlp()
with make_tempdir() as nlp_dir:
nlp.to_disk(nlp_dir)
with pytest.raises(SystemExit) as error:
with pytest.raises(AttributeError):
find_threshold(
model=nlp_dir,
data_path=docs_dir / "docs.spacy",
@ -975,7 +975,6 @@ def test_cli_find_threshold(capsys):
scores_key="cats_macro_f",
silent=True,
)
assert error.value.code == 1
@pytest.mark.parametrize(