Make test_cli_find_threshold() more robust. (#12148)

This commit is contained in:
Raphael Mitsch 2023-01-23 14:42:33 +01:00 committed by GitHub
parent f9e020dd67
commit 950fceceb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1074,7 +1074,7 @@ def test_cli_find_threshold(capsys):
) )
with make_tempdir() as nlp_dir: with make_tempdir() as nlp_dir:
nlp.to_disk(nlp_dir) nlp.to_disk(nlp_dir)
res = find_threshold( best_threshold, best_score, res = find_threshold(
model=nlp_dir, model=nlp_dir,
data_path=docs_dir / "docs.spacy", data_path=docs_dir / "docs.spacy",
pipe_name="tc_multi", pipe_name="tc_multi",
@ -1082,10 +1082,10 @@ def test_cli_find_threshold(capsys):
scores_key="cats_macro_f", scores_key="cats_macro_f",
silent=True, silent=True,
) )
assert res[0] != thresholds[0] assert best_threshold != thresholds[0]
assert thresholds[0] < res[0] < thresholds[9] assert thresholds[0] < best_threshold < thresholds[9]
assert res[1] == 1.0 assert best_score == max(res.values())
assert res[2][1.0] == 0.0 assert res[1.0] == 0.0
# Test with spancat. # Test with spancat.
nlp, _ = init_nlp((("spancat", {}),)) nlp, _ = init_nlp((("spancat", {}),))