fix micro PRF for textcat (#6130)

* fix micro PRF for textcat

* small fix
This commit is contained in:
Sofie Van Landeghem 2020-09-24 10:31:17 +02:00 committed by GitHub
parent 17a6b0a173
commit c645c4e7ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 4 deletions

View File

@ -240,7 +240,7 @@ class Scorer:
pred_per_feat[field].add((gold_i, feat)) pred_per_feat[field].add((gold_i, feat))
for field in per_feat: for field in per_feat:
per_feat[field].score_set( per_feat[field].score_set(
pred_per_feat.get(field, set()), gold_per_feat.get(field, set()), pred_per_feat.get(field, set()), gold_per_feat.get(field, set())
) )
result = {k: v.to_dict() for k, v in per_feat.items()} result = {k: v.to_dict() for k, v in per_feat.items()}
return {f"{attr}_per_feat": result} return {f"{attr}_per_feat": result}
@ -418,9 +418,9 @@ class Scorer:
f_per_type[pred_label].fp += 1 f_per_type[pred_label].fp += 1
micro_prf = PRFScore() micro_prf = PRFScore()
for label_prf in f_per_type.values(): for label_prf in f_per_type.values():
micro_prf.tp = label_prf.tp micro_prf.tp += label_prf.tp
micro_prf.fn = label_prf.fn micro_prf.fn += label_prf.fn
micro_prf.fp = label_prf.fp micro_prf.fp += label_prf.fp
n_cats = len(f_per_type) + 1e-100 n_cats = len(f_per_type) + 1e-100
macro_p = sum(prf.precision for prf in f_per_type.values()) / n_cats macro_p = sum(prf.precision for prf in f_per_type.values()) / n_cats
macro_r = sum(prf.recall for prf in f_per_type.values()) / n_cats macro_r = sum(prf.recall for prf in f_per_type.values()) / n_cats

View File

@ -8,6 +8,7 @@ from spacy.language import Language
from spacy.pipeline import TextCategorizer from spacy.pipeline import TextCategorizer
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
from spacy.scorer import Scorer
from ..util import make_tempdir from ..util import make_tempdir
from ...cli.train import verify_textcat_config from ...cli.train import verify_textcat_config
@ -224,3 +225,31 @@ def test_positive_class_not_binary():
assert textcat.labels == ("SOME", "THING", "POS") assert textcat.labels == ("SOME", "THING", "POS")
with pytest.raises(ValueError): with pytest.raises(ValueError):
verify_textcat_config(nlp, pipe_config) verify_textcat_config(nlp, pipe_config)
def test_textcat_evaluation():
train_examples = []
nlp = English()
ref1 = nlp("one")
ref1.cats = {"winter": 1.0, "summer": 1.0, "spring": 1.0, "autumn": 1.0}
pred1 = nlp("one")
pred1.cats = {"winter": 1.0, "summer": 0.0, "spring": 1.0, "autumn": 1.0}
train_examples.append(Example(pred1, ref1))
ref2 = nlp("two")
ref2.cats = {"winter": 0.0, "summer": 0.0, "spring": 1.0, "autumn": 1.0}
pred2 = nlp("two")
pred2.cats = {"winter": 1.0, "summer": 0.0, "spring": 0.0, "autumn": 1.0}
train_examples.append(Example(pred2, ref2))
scores = Scorer().score_cats(train_examples, "cats", labels=["winter", "summer", "spring", "autumn"])
assert scores["cats_f_per_type"]["winter"]["p"] == 1/2
assert scores["cats_f_per_type"]["winter"]["r"] == 1/1
assert scores["cats_f_per_type"]["summer"]["p"] == 0
assert scores["cats_f_per_type"]["summer"]["r"] == 0/1
assert scores["cats_f_per_type"]["spring"]["p"] == 1/1
assert scores["cats_f_per_type"]["spring"]["r"] == 1/2
assert scores["cats_f_per_type"]["autumn"]["p"] == 2/2
assert scores["cats_f_per_type"]["autumn"]["r"] == 2/2
assert scores["cats_micro_p"] == 4/5
assert scores["cats_micro_r"] == 4/6