mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Make test_cli_find_threshold() more robust. (#12148)
This commit is contained in:
parent
f9e020dd67
commit
950fceceb6
|
@ -1074,7 +1074,7 @@ def test_cli_find_threshold(capsys):
|
||||||
)
|
)
|
||||||
with make_tempdir() as nlp_dir:
|
with make_tempdir() as nlp_dir:
|
||||||
nlp.to_disk(nlp_dir)
|
nlp.to_disk(nlp_dir)
|
||||||
res = find_threshold(
|
best_threshold, best_score, res = find_threshold(
|
||||||
model=nlp_dir,
|
model=nlp_dir,
|
||||||
data_path=docs_dir / "docs.spacy",
|
data_path=docs_dir / "docs.spacy",
|
||||||
pipe_name="tc_multi",
|
pipe_name="tc_multi",
|
||||||
|
@ -1082,10 +1082,10 @@ def test_cli_find_threshold(capsys):
|
||||||
scores_key="cats_macro_f",
|
scores_key="cats_macro_f",
|
||||||
silent=True,
|
silent=True,
|
||||||
)
|
)
|
||||||
assert res[0] != thresholds[0]
|
assert best_threshold != thresholds[0]
|
||||||
assert thresholds[0] < res[0] < thresholds[9]
|
assert thresholds[0] < best_threshold < thresholds[9]
|
||||||
assert res[1] == 1.0
|
assert best_score == max(res.values())
|
||||||
assert res[2][1.0] == 0.0
|
assert res[1.0] == 0.0
|
||||||
|
|
||||||
# Test with spancat.
|
# Test with spancat.
|
||||||
nlp, _ = init_nlp((("spancat", {}),))
|
nlp, _ = init_nlp((("spancat", {}),))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user