mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
fix micro PRF for textcat (#6130)
* fix micro PRF for textcat * small fix
This commit is contained in:
parent
17a6b0a173
commit
c645c4e7ce
|
@ -240,7 +240,7 @@ class Scorer:
|
||||||
pred_per_feat[field].add((gold_i, feat))
|
pred_per_feat[field].add((gold_i, feat))
|
||||||
for field in per_feat:
|
for field in per_feat:
|
||||||
per_feat[field].score_set(
|
per_feat[field].score_set(
|
||||||
pred_per_feat.get(field, set()), gold_per_feat.get(field, set()),
|
pred_per_feat.get(field, set()), gold_per_feat.get(field, set())
|
||||||
)
|
)
|
||||||
result = {k: v.to_dict() for k, v in per_feat.items()}
|
result = {k: v.to_dict() for k, v in per_feat.items()}
|
||||||
return {f"{attr}_per_feat": result}
|
return {f"{attr}_per_feat": result}
|
||||||
|
@ -418,9 +418,9 @@ class Scorer:
|
||||||
f_per_type[pred_label].fp += 1
|
f_per_type[pred_label].fp += 1
|
||||||
micro_prf = PRFScore()
|
micro_prf = PRFScore()
|
||||||
for label_prf in f_per_type.values():
|
for label_prf in f_per_type.values():
|
||||||
micro_prf.tp = label_prf.tp
|
micro_prf.tp += label_prf.tp
|
||||||
micro_prf.fn = label_prf.fn
|
micro_prf.fn += label_prf.fn
|
||||||
micro_prf.fp = label_prf.fp
|
micro_prf.fp += label_prf.fp
|
||||||
n_cats = len(f_per_type) + 1e-100
|
n_cats = len(f_per_type) + 1e-100
|
||||||
macro_p = sum(prf.precision for prf in f_per_type.values()) / n_cats
|
macro_p = sum(prf.precision for prf in f_per_type.values()) / n_cats
|
||||||
macro_r = sum(prf.recall for prf in f_per_type.values()) / n_cats
|
macro_r = sum(prf.recall for prf in f_per_type.values()) / n_cats
|
||||||
|
|
|
@ -8,6 +8,7 @@ from spacy.language import Language
|
||||||
from spacy.pipeline import TextCategorizer
|
from spacy.pipeline import TextCategorizer
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
||||||
|
from spacy.scorer import Scorer
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
from ...cli.train import verify_textcat_config
|
from ...cli.train import verify_textcat_config
|
||||||
|
@ -224,3 +225,31 @@ def test_positive_class_not_binary():
|
||||||
assert textcat.labels == ("SOME", "THING", "POS")
|
assert textcat.labels == ("SOME", "THING", "POS")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
verify_textcat_config(nlp, pipe_config)
|
verify_textcat_config(nlp, pipe_config)
|
||||||
|
|
||||||
|
def test_textcat_evaluation():
|
||||||
|
train_examples = []
|
||||||
|
nlp = English()
|
||||||
|
ref1 = nlp("one")
|
||||||
|
ref1.cats = {"winter": 1.0, "summer": 1.0, "spring": 1.0, "autumn": 1.0}
|
||||||
|
pred1 = nlp("one")
|
||||||
|
pred1.cats = {"winter": 1.0, "summer": 0.0, "spring": 1.0, "autumn": 1.0}
|
||||||
|
train_examples.append(Example(pred1, ref1))
|
||||||
|
|
||||||
|
ref2 = nlp("two")
|
||||||
|
ref2.cats = {"winter": 0.0, "summer": 0.0, "spring": 1.0, "autumn": 1.0}
|
||||||
|
pred2 = nlp("two")
|
||||||
|
pred2.cats = {"winter": 1.0, "summer": 0.0, "spring": 0.0, "autumn": 1.0}
|
||||||
|
train_examples.append(Example(pred2, ref2))
|
||||||
|
|
||||||
|
scores = Scorer().score_cats(train_examples, "cats", labels=["winter", "summer", "spring", "autumn"])
|
||||||
|
assert scores["cats_f_per_type"]["winter"]["p"] == 1/2
|
||||||
|
assert scores["cats_f_per_type"]["winter"]["r"] == 1/1
|
||||||
|
assert scores["cats_f_per_type"]["summer"]["p"] == 0
|
||||||
|
assert scores["cats_f_per_type"]["summer"]["r"] == 0/1
|
||||||
|
assert scores["cats_f_per_type"]["spring"]["p"] == 1/1
|
||||||
|
assert scores["cats_f_per_type"]["spring"]["r"] == 1/2
|
||||||
|
assert scores["cats_f_per_type"]["autumn"]["p"] == 2/2
|
||||||
|
assert scores["cats_f_per_type"]["autumn"]["r"] == 2/2
|
||||||
|
|
||||||
|
assert scores["cats_micro_p"] == 4/5
|
||||||
|
assert scores["cats_micro_r"] == 4/6
|
||||||
|
|
Loading…
Reference in New Issue
Block a user