diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index 42ffae22d..dc7ce46fe 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -1017,8 +1017,6 @@ def test_local_remote_storage_pull_missing(): def test_cli_find_threshold(capsys): - thresholds = numpy.linspace(0, 1, 10) - def make_examples(nlp: Language) -> List[Example]: docs: List[Example] = [] @@ -1082,8 +1080,6 @@ def test_cli_find_threshold(capsys): scores_key="cats_macro_f", silent=True, ) - assert best_threshold != thresholds[0] - assert thresholds[0] < best_threshold < thresholds[9] assert best_score == max(res.values()) assert res[1.0] == 0.0 @@ -1091,7 +1087,7 @@ def test_cli_find_threshold(capsys): nlp, _ = init_nlp((("spancat", {}),)) with make_tempdir() as nlp_dir: nlp.to_disk(nlp_dir) - res = find_threshold( + best_threshold, best_score, res = find_threshold( model=nlp_dir, data_path=docs_dir / "docs.spacy", pipe_name="spancat", @@ -1099,10 +1095,8 @@ def test_cli_find_threshold(capsys): scores_key="spans_sc_f", silent=True, ) - assert res[0] != thresholds[0] - assert thresholds[0] < res[0] < thresholds[8] - assert res[1] >= 0.6 - assert res[2][1.0] == 0.0 + assert best_score == max(res.values()) + assert res[1.0] == 0.0 # Having multiple textcat_multilabel components should work, since the name has to be specified. nlp, _ = init_nlp((("textcat_multilabel", {}),))