From c645c4e7ceddbd819b7a56e56f013bb8447dea4b Mon Sep 17 00:00:00 2001
From: Sofie Van Landeghem
Date: Thu, 24 Sep 2020 10:31:17 +0200
Subject: [PATCH] fix micro PRF for textcat (#6130)

* fix micro PRF for textcat

* small fix
---
 spacy/scorer.py                      |  8 ++++----
 spacy/tests/pipeline/test_textcat.py | 29 ++++++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/spacy/scorer.py b/spacy/scorer.py
index da22d59d4..c50de3d43 100644
--- a/spacy/scorer.py
+++ b/spacy/scorer.py
@@ -240,7 +240,7 @@ class Scorer:
                     pred_per_feat[field].add((gold_i, feat))
         for field in per_feat:
             per_feat[field].score_set(
-                pred_per_feat.get(field, set()), gold_per_feat.get(field, set()),
+                pred_per_feat.get(field, set()), gold_per_feat.get(field, set())
             )
         result = {k: v.to_dict() for k, v in per_feat.items()}
         return {f"{attr}_per_feat": result}
@@ -418,9 +418,9 @@ class Scorer:
                         f_per_type[pred_label].fp += 1
         micro_prf = PRFScore()
         for label_prf in f_per_type.values():
-            micro_prf.tp = label_prf.tp
-            micro_prf.fn = label_prf.fn
-            micro_prf.fp = label_prf.fp
+            micro_prf.tp += label_prf.tp
+            micro_prf.fn += label_prf.fn
+            micro_prf.fp += label_prf.fp
         n_cats = len(f_per_type) + 1e-100
         macro_p = sum(prf.precision for prf in f_per_type.values()) / n_cats
         macro_r = sum(prf.recall for prf in f_per_type.values()) / n_cats
diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py
index 99b5132ca..232b53e1d 100644
--- a/spacy/tests/pipeline/test_textcat.py
+++ b/spacy/tests/pipeline/test_textcat.py
@@ -8,6 +8,7 @@ from spacy.language import Language
 from spacy.pipeline import TextCategorizer
 from spacy.tokens import Doc
 from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
+from spacy.scorer import Scorer
 
 from ..util import make_tempdir
 from ...cli.train import verify_textcat_config
@@ -224,3 +225,31 @@ def test_positive_class_not_binary():
     assert textcat.labels == ("SOME", "THING", "POS")
     with pytest.raises(ValueError):
         verify_textcat_config(nlp, pipe_config)
+
+def test_textcat_evaluation():
+    train_examples = []
+    nlp = English()
+    ref1 = nlp("one")
+    ref1.cats = {"winter": 1.0, "summer": 1.0, "spring": 1.0, "autumn": 1.0}
+    pred1 = nlp("one")
+    pred1.cats = {"winter": 1.0, "summer": 0.0, "spring": 1.0, "autumn": 1.0}
+    train_examples.append(Example(pred1, ref1))
+
+    ref2 = nlp("two")
+    ref2.cats = {"winter": 0.0, "summer": 0.0, "spring": 1.0, "autumn": 1.0}
+    pred2 = nlp("two")
+    pred2.cats = {"winter": 1.0, "summer": 0.0, "spring": 0.0, "autumn": 1.0}
+    train_examples.append(Example(pred2, ref2))
+
+    scores = Scorer().score_cats(train_examples, "cats", labels=["winter", "summer", "spring", "autumn"])
+    assert scores["cats_f_per_type"]["winter"]["p"] == 1/2
+    assert scores["cats_f_per_type"]["winter"]["r"] == 1/1
+    assert scores["cats_f_per_type"]["summer"]["p"] == 0
+    assert scores["cats_f_per_type"]["summer"]["r"] == 0/1
+    assert scores["cats_f_per_type"]["spring"]["p"] == 1/1
+    assert scores["cats_f_per_type"]["spring"]["r"] == 1/2
+    assert scores["cats_f_per_type"]["autumn"]["p"] == 2/2
+    assert scores["cats_f_per_type"]["autumn"]["r"] == 2/2
+
+    assert scores["cats_micro_p"] == 4/5
+    assert scores["cats_micro_r"] == 4/6
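
Note on the scorer.py change: micro-averaged PRF has to pool true positives, false positives and false negatives across all labels before computing precision and recall, whereas the old `=` assignments left `micro_prf` holding only the counts of whichever label the loop visited last. The sketch below is illustrative only and not part of the diff; it reproduces that pooling with the per-label counts implied by the two docs in `test_textcat_evaluation`, and the `PRFScore` class here is a simplified stand-in assumed to mirror the `tp`/`fp`/`fn` fields and smoothed `precision`/`recall` properties of `spacy.scorer.PRFScore`.

from dataclasses import dataclass


@dataclass
class PRFScore:
    # Simplified stand-in for spacy.scorer.PRFScore (assumption: same fields).
    tp: int = 0
    fp: int = 0
    fn: int = 0

    @property
    def precision(self) -> float:
        return self.tp / (self.tp + self.fp + 1e-100)

    @property
    def recall(self) -> float:
        return self.tp / (self.tp + self.fn + 1e-100)


# Per-label counts implied by the two example docs in the test above.
f_per_type = {
    "winter": PRFScore(tp=1, fp=1, fn=0),
    "summer": PRFScore(tp=0, fp=0, fn=1),
    "spring": PRFScore(tp=1, fp=0, fn=1),
    "autumn": PRFScore(tp=2, fp=0, fn=0),
}

micro_prf = PRFScore()
for label_prf in f_per_type.values():
    micro_prf.tp += label_prf.tp  # accumulate across labels (the fix)
    micro_prf.fn += label_prf.fn
    micro_prf.fp += label_prf.fp

assert abs(micro_prf.precision - 4 / 5) < 1e-9  # matches cats_micro_p
assert abs(micro_prf.recall - 4 / 6) < 1e-9     # matches cats_micro_r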