Fix test issue. Update docstring.

Raphael Mitsch 2022-11-17 11:50:42 +01:00
parent d080808bf7
commit 7b4da3f36d
2 changed files with 16 additions and 17 deletions

View File

@@ -39,19 +39,17 @@ def find_threshold_cli(
     # fmt: on
 ):
     """
-    Runs prediction trials for models with varying tresholds to maximize the specified metric from CLI.
-    model (Path): Path to file with trained model.
-    data_path (Path): Path to file with DocBin with docs to use for threshold search.
-    pipe_name (str): Name of pipe to examine thresholds for.
-    threshold_key (str): Key of threshold attribute in component's configuration.
-    scores_key (str): Metric to optimize.
-    n_trials (int): Number of trials to determine optimal thresholds
-    code_path (Optional[Path]): Path to Python file with additional code (registered functions) to be imported.
-    use_gpu (int): GPU ID or -1 for CPU.
-    gold_preproc (bool): Whether to use gold preprocessing. Gold preprocessing helps the annotations align to the
-        tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due
-        to train/test skew.
-    silent (bool): Display more information for debugging purposes
+    Runs prediction trials for models with varying tresholds to maximize the
+    specified metric. The search space for the threshold is traversed
+    linearly from 0 to 1 in n_trials steps.
+
+    This is applicable only for components whose predictions are influenced
+    by thresholds (e.g. textcat_multilabel and spancat, but not textcat).
+
+    Note that the full path to the corresponding threshold attribute in the
+    config has to be provided.
+
+    DOCS: https://spacy.io/api/cli#find-threshold
     """
     util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
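The rewritten docstring pins down the search strategy: the threshold space is traversed linearly from 0 to 1 in n_trials steps. A minimal sketch of what that traversal amounts to, not spaCy's actual implementation; evaluate_with_threshold is a hypothetical stand-in for re-running evaluation with a candidate threshold set in the component's config:

    import numpy

    def evaluate_with_threshold(threshold: float) -> float:
        # Hypothetical stand-in: set the component's threshold attribute,
        # evaluate, and return the metric selected by scores_key.
        return 1.0 - abs(threshold - 0.5)  # dummy score that peaks at 0.5

    n_trials = 11
    # n_trials evenly spaced thresholds between 0 and 1, as the docstring describes.
    thresholds = numpy.linspace(0, 1, n_trials)
    best_threshold, best_score = max(
        ((t, evaluate_with_threshold(t)) for t in thresholds),
        key=lambda pair: pair[1],
    )
    print(best_threshold, best_score)  # 0.5 1.0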
@@ -105,7 +103,9 @@ def find_threshold(
     nlp = util.load_model(model)
     if pipe_name not in nlp.component_names:
-        raise AttributeError(Errors.E001.format(name=pipe_name))
+        raise AttributeError(
+            Errors.E001.format(name=pipe_name, opts=nlp.component_names)
+        )
     pipe = nlp.get_pipe(pipe_name)
     if not hasattr(pipe, "scorer"):
         raise AttributeError(Errors.E1045)
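Passing opts makes the error actionable: the message now lists the component names that do exist. Roughly what the user sees (E001 wording paraphrased, component names invented for illustration):

    component_names = ["tok2vec", "tagger"]
    pipe_name = "textcat"
    try:
        if pipe_name not in component_names:
            # Paraphrase of the effect of formatting Errors.E001 with the
            # newly added opts field; the real code uses Errors.E001.format().
            raise AttributeError(
                f"No component '{pipe_name}' found in pipeline. "
                f"Available names: {component_names}"
            )
    except AttributeError as e:
        print(e)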
@@ -190,7 +190,7 @@ def find_threshold(
     if scores_key not in eval_scores:
         wasabi.msg.fail(
             title=f"Failed to look up score `{scores_key}` in evaluation results.",
-            text=f"Make sure you specified the correct value for `scores_key` correctly. The following scores are "
+            text=f"Make sure you specified the correct value for `scores_key`. The following scores are "
             f"available: {eval_scores.keys()}",
             exits=1,
         )
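For context on the test change below: wasabi.msg.fail(..., exits=1) prints the message and then exits the process, which callers observe as a SystemExit with code 1; that is what the old test asserted. A minimal sketch:

    import wasabi

    try:
        # With exits=1, wasabi prints the failure message and exits, which
        # surfaces to the caller as SystemExit carrying that code.
        wasabi.msg.fail(title="Example failure", text="Details here.", exits=1)
    except SystemExit as e:
        print(e.code)  # 1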

View File

@@ -966,7 +966,7 @@ def test_cli_find_threshold(capsys):
     nlp, _ = init_nlp()
     with make_tempdir() as nlp_dir:
         nlp.to_disk(nlp_dir)
-        with pytest.raises(SystemExit) as error:
+        with pytest.raises(AttributeError):
             find_threshold(
                 model=nlp_dir,
                 data_path=docs_dir / "docs.spacy",
@@ -975,7 +975,6 @@ def test_cli_find_threshold(capsys):
                 scores_key="cats_macro_f",
                 silent=True,
             )
-        assert error.value.code == 1

 @pytest.mark.parametrize(
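The test now matches the earlier change in find_threshold(): an unknown pipe name raises AttributeError (via Errors.E001) before execution reaches any wasabi.msg.fail call, so the SystemExit expectation and its exit-code assertion no longer apply. A self-contained sketch of the old and new expectations, with hypothetical stand-in helpers:

    import pytest

    def fail_with_exit():
        # Stand-in for wasabi.msg.fail(..., exits=1), which exits with code 1.
        raise SystemExit(1)

    def lookup_pipe(component_names, pipe_name):
        # Stand-in for the Errors.E001 check in find_threshold().
        if pipe_name not in component_names:
            raise AttributeError(f"No component '{pipe_name}' found in pipeline.")

    def test_expectations():
        # Old expectation: SystemExit with code 1, hence error.value.code == 1.
        with pytest.raises(SystemExit) as error:
            fail_with_exit()
        assert error.value.code == 1
        # New expectation: the exception type is asserted directly, so no
        # exit-code check is needed.
        with pytest.raises(AttributeError):
            lookup_pipe(["tok2vec", "tagger"], "textcat")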