mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-07 21:54:54 +03:00
Fix test issue. Update docstring.
This commit is contained in:
parent
d080808bf7
commit
7b4da3f36d
|
@ -39,19 +39,17 @@ def find_threshold_cli(
|
|||
# fmt: on
|
||||
):
|
||||
"""
|
||||
Runs prediction trials for models with varying thresholds to maximize the specified metric from CLI.
|
||||
model (Path): Path to file with trained model.
|
||||
data_path (Path): Path to file with DocBin with docs to use for threshold search.
|
||||
pipe_name (str): Name of pipe to examine thresholds for.
|
||||
threshold_key (str): Key of threshold attribute in component's configuration.
|
||||
scores_key (str): Metric to optimize.
|
||||
n_trials (int): Number of trials to determine optimal thresholds
|
||||
code_path (Optional[Path]): Path to Python file with additional code (registered functions) to be imported.
|
||||
use_gpu (int): GPU ID or -1 for CPU.
|
||||
gold_preproc (bool): Whether to use gold preprocessing. Gold preprocessing helps the annotations align to the
|
||||
tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due
|
||||
to train/test skew.
|
||||
silent (bool): Display more information for debugging purposes
|
||||
Runs prediction trials for models with varying thresholds to maximize the
|
||||
specified metric. The search space for the threshold is traversed
|
||||
linearly from 0 to 1 in n_trials steps.
|
||||
|
||||
This is applicable only for components whose predictions are influenced
|
||||
by thresholds (e.g. textcat_multilabel and spancat, but not textcat).
|
||||
|
||||
Note that the full path to the corresponding threshold attribute in the
|
||||
config has to be provided.
|
||||
|
||||
DOCS: https://spacy.io/api/cli#find-threshold
|
||||
"""
|
||||
|
||||
util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||
|
@ -105,7 +103,9 @@ def find_threshold(
|
|||
nlp = util.load_model(model)
|
||||
|
||||
if pipe_name not in nlp.component_names:
|
||||
raise AttributeError(Errors.E001.format(name=pipe_name))
|
||||
raise AttributeError(
|
||||
Errors.E001.format(name=pipe_name, opts=nlp.component_names)
|
||||
)
|
||||
pipe = nlp.get_pipe(pipe_name)
|
||||
if not hasattr(pipe, "scorer"):
|
||||
raise AttributeError(Errors.E1045)
|
||||
|
@ -190,7 +190,7 @@ def find_threshold(
|
|||
if scores_key not in eval_scores:
|
||||
wasabi.msg.fail(
|
||||
title=f"Failed to look up score `{scores_key}` in evaluation results.",
|
||||
text=f"Make sure you specified the correct value for `scores_key` correctly. The following scores are "
|
||||
text=f"Make sure you specified the correct value for `scores_key`. The following scores are "
|
||||
f"available: {eval_scores.keys()}",
|
||||
exits=1,
|
||||
)
|
||||
|
|
|
@ -966,7 +966,7 @@ def test_cli_find_threshold(capsys):
|
|||
nlp, _ = init_nlp()
|
||||
with make_tempdir() as nlp_dir:
|
||||
nlp.to_disk(nlp_dir)
|
||||
with pytest.raises(SystemExit) as error:
|
||||
with pytest.raises(AttributeError):
|
||||
find_threshold(
|
||||
model=nlp_dir,
|
||||
data_path=docs_dir / "docs.spacy",
|
||||
|
@ -975,7 +975,6 @@ def test_cli_find_threshold(capsys):
|
|||
scores_key="cats_macro_f",
|
||||
silent=True,
|
||||
)
|
||||
assert error.value.code == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
Loading…
Reference in New Issue
Block a user