From 81fd595223ced0df15318cca692fdf0d8e8f79fd Mon Sep 17 00:00:00 2001 From: graue70 <23035329+graue70@users.noreply.github.com> Date: Thu, 8 Apr 2021 09:34:14 +0200 Subject: [PATCH] Fix __add__ method of PRFScore (#7557) * Add failing test for PRFScore * Fix erroneous implementation of __add__ * Simplify constructor Co-authored-by: Sofie Van Landeghem Co-authored-by: Sofie Van Landeghem --- spacy/scorer.py | 14 ++++++++++---- spacy/tests/test_scorer.py | 22 +++++++++++++++++++++- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/spacy/scorer.py b/spacy/scorer.py index f28cb5639..8061aa329 100644 --- a/spacy/scorer.py +++ b/spacy/scorer.py @@ -20,10 +20,16 @@ MISSING_VALUES = frozenset([None, 0, ""]) class PRFScore: """A precision / recall / F score.""" - def __init__(self) -> None: - self.tp = 0 - self.fp = 0 - self.fn = 0 + def __init__( + self, + *, + tp: int = 0, + fp: int = 0, + fn: int = 0, + ) -> None: + self.tp = tp + self.fp = fp + self.fn = fn def __len__(self) -> int: return self.tp + self.fp + self.fn diff --git a/spacy/tests/test_scorer.py b/spacy/tests/test_scorer.py index 4dddca404..ecdaee768 100644 --- a/spacy/tests/test_scorer.py +++ b/spacy/tests/test_scorer.py @@ -3,7 +3,7 @@ import pytest from pytest import approx from spacy.training import Example from spacy.training.iob_utils import offsets_to_biluo_tags -from spacy.scorer import Scorer, ROCAUCScore +from spacy.scorer import Scorer, ROCAUCScore, PRFScore from spacy.scorer import _roc_auc_score, _roc_curve from spacy.lang.en import English from spacy.tokens import Doc @@ -403,3 +403,23 @@ def test_roc_auc_score(): score.score_set(0.75, 1) with pytest.raises(ValueError): _ = score.score # noqa: F841 + + +def test_prf_score(): + cand = {"hi", "ho"} + gold1 = {"yo", "hi"} + gold2 = set() + + a = PRFScore() + a.score_set(cand=cand, gold=gold1) + assert (a.precision, a.recall, a.fscore) == approx((0.5, 0.5, 0.5)) + + b = PRFScore() + b.score_set(cand=cand, gold=gold2) + assert (b.precision, b.recall, b.fscore) == approx((0.0, 0.0, 0.0)) + + c = a + b + assert (c.precision, c.recall, c.fscore) == approx((0.25, 0.5, 0.33333333)) + + a += b + assert (a.precision, a.recall, a.fscore) == approx((c.precision, c.recall, c.fscore))