mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-23 15:54:13 +03:00
fix micro PRF for textcat (#6130)
* fix micro PRF for textcat * small fix
This commit is contained in:
parent
17a6b0a173
commit
c645c4e7ce
|
@ -240,7 +240,7 @@ class Scorer:
|
|||
pred_per_feat[field].add((gold_i, feat))
|
||||
for field in per_feat:
|
||||
per_feat[field].score_set(
|
||||
pred_per_feat.get(field, set()), gold_per_feat.get(field, set()),
|
||||
pred_per_feat.get(field, set()), gold_per_feat.get(field, set())
|
||||
)
|
||||
result = {k: v.to_dict() for k, v in per_feat.items()}
|
||||
return {f"{attr}_per_feat": result}
|
||||
|
@ -418,9 +418,9 @@ class Scorer:
|
|||
f_per_type[pred_label].fp += 1
|
||||
micro_prf = PRFScore()
|
||||
for label_prf in f_per_type.values():
|
||||
micro_prf.tp = label_prf.tp
|
||||
micro_prf.fn = label_prf.fn
|
||||
micro_prf.fp = label_prf.fp
|
||||
micro_prf.tp += label_prf.tp
|
||||
micro_prf.fn += label_prf.fn
|
||||
micro_prf.fp += label_prf.fp
|
||||
n_cats = len(f_per_type) + 1e-100
|
||||
macro_p = sum(prf.precision for prf in f_per_type.values()) / n_cats
|
||||
macro_r = sum(prf.recall for prf in f_per_type.values()) / n_cats
|
||||
|
|
|
@ -8,6 +8,7 @@ from spacy.language import Language
|
|||
from spacy.pipeline import TextCategorizer
|
||||
from spacy.tokens import Doc
|
||||
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
||||
from spacy.scorer import Scorer
|
||||
|
||||
from ..util import make_tempdir
|
||||
from ...cli.train import verify_textcat_config
|
||||
|
@ -224,3 +225,31 @@ def test_positive_class_not_binary():
|
|||
assert textcat.labels == ("SOME", "THING", "POS")
|
||||
with pytest.raises(ValueError):
|
||||
verify_textcat_config(nlp, pipe_config)
|
||||
|
||||
def test_textcat_evaluation():
|
||||
train_examples = []
|
||||
nlp = English()
|
||||
ref1 = nlp("one")
|
||||
ref1.cats = {"winter": 1.0, "summer": 1.0, "spring": 1.0, "autumn": 1.0}
|
||||
pred1 = nlp("one")
|
||||
pred1.cats = {"winter": 1.0, "summer": 0.0, "spring": 1.0, "autumn": 1.0}
|
||||
train_examples.append(Example(pred1, ref1))
|
||||
|
||||
ref2 = nlp("two")
|
||||
ref2.cats = {"winter": 0.0, "summer": 0.0, "spring": 1.0, "autumn": 1.0}
|
||||
pred2 = nlp("two")
|
||||
pred2.cats = {"winter": 1.0, "summer": 0.0, "spring": 0.0, "autumn": 1.0}
|
||||
train_examples.append(Example(pred2, ref2))
|
||||
|
||||
scores = Scorer().score_cats(train_examples, "cats", labels=["winter", "summer", "spring", "autumn"])
|
||||
assert scores["cats_f_per_type"]["winter"]["p"] == 1/2
|
||||
assert scores["cats_f_per_type"]["winter"]["r"] == 1/1
|
||||
assert scores["cats_f_per_type"]["summer"]["p"] == 0
|
||||
assert scores["cats_f_per_type"]["summer"]["r"] == 0/1
|
||||
assert scores["cats_f_per_type"]["spring"]["p"] == 1/1
|
||||
assert scores["cats_f_per_type"]["spring"]["r"] == 1/2
|
||||
assert scores["cats_f_per_type"]["autumn"]["p"] == 2/2
|
||||
assert scores["cats_f_per_type"]["autumn"]["r"] == 2/2
|
||||
|
||||
assert scores["cats_micro_p"] == 4/5
|
||||
assert scores["cats_micro_r"] == 4/6
|
||||
|
|
Loading…
Reference in New Issue
Block a user