mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-06 21:30:22 +03:00
Add tests for score_cats with thresholds
This commit is contained in:
parent
ff4ea7cbee
commit
7b139d1bcd
|
@ -474,3 +474,50 @@ def test_prf_score():
|
|||
assert (a.precision, a.recall, a.fscore) == approx(
|
||||
(c.precision, c.recall, c.fscore)
|
||||
)
|
||||
|
||||
|
||||
def test_score_cats(en_tokenizer):
|
||||
text = "some text"
|
||||
gold_doc = en_tokenizer(text)
|
||||
gold_doc.cats = {"POSITIVE": 1.0, "NEGATIVE": 0.0}
|
||||
pred_doc = en_tokenizer(text)
|
||||
pred_doc.cats = {"POSITIVE": 0.75, "NEGATIVE": 0.25}
|
||||
example = Example(pred_doc, gold_doc)
|
||||
# threshold is ignored for multi_label=False
|
||||
scores1 = Scorer.score_cats(
|
||||
[example],
|
||||
"cats",
|
||||
labels=list(gold_doc.cats.keys()),
|
||||
multi_label=False,
|
||||
positive_label="POSITIVE",
|
||||
threshold=0.1,
|
||||
)
|
||||
scores2 = Scorer.score_cats(
|
||||
[example],
|
||||
"cats",
|
||||
labels=list(gold_doc.cats.keys()),
|
||||
multi_label=False,
|
||||
positive_label="POSITIVE",
|
||||
threshold=0.9,
|
||||
)
|
||||
assert scores1["cats_score"] == 1.0
|
||||
assert scores2["cats_score"] == 1.0
|
||||
assert scores1 == scores2
|
||||
# threshold is relevant for multi_label=True
|
||||
scores = Scorer.score_cats(
|
||||
[example],
|
||||
"cats",
|
||||
labels=list(gold_doc.cats.keys()),
|
||||
multi_label=True,
|
||||
threshold=0.9,
|
||||
)
|
||||
assert scores["cats_macro_f"] == 0.0
|
||||
# threshold is relevant for multi_label=True
|
||||
scores = Scorer.score_cats(
|
||||
[example],
|
||||
"cats",
|
||||
labels=list(gold_doc.cats.keys()),
|
||||
multi_label=True,
|
||||
threshold=0.1,
|
||||
)
|
||||
assert scores["cats_macro_f"] == 0.5
|
||||
|
|
Loading…
Reference in New Issue
Block a user