diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index d00f66c60..249c44672 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -1074,7 +1074,7 @@ def test_cli_find_threshold(capsys): ) with make_tempdir() as nlp_dir: nlp.to_disk(nlp_dir) - res = find_threshold( + best_threshold, best_score, res = find_threshold( model=nlp_dir, data_path=docs_dir / "docs.spacy", pipe_name="tc_multi", @@ -1082,10 +1082,10 @@ def test_cli_find_threshold(capsys): scores_key="cats_macro_f", silent=True, ) - assert res[0] != thresholds[0] - assert thresholds[0] < res[0] < thresholds[9] - assert res[1] == 1.0 - assert res[2][1.0] == 0.0 + assert best_threshold != thresholds[0] + assert thresholds[0] < best_threshold < thresholds[9] + assert best_score == max(res.values()) + assert res[1.0] == 0.0 # Test with spancat. nlp, _ = init_nlp((("spancat", {}),))