Allow PRFScore to be constructed with initial counts.

PRFScore.__init__ now accepts keyword-only tp, fp and fn arguments, each
defaulting to 0, so a call with no arguments still starts from zero counts.
A new test, test_prf_score, exercises score_set, the + operator and the +=
operator on PRFScore objects.

diff --git a/spacy/scorer.py b/spacy/scorer.py
index f28cb5639..8061aa329 100644
--- a/spacy/scorer.py
+++ b/spacy/scorer.py
@@ -20,10 +20,16 @@ MISSING_VALUES = frozenset([None, 0, ""])
 class PRFScore:
     """A precision / recall / F score."""
 
-    def __init__(self) -> None:
-        self.tp = 0
-        self.fp = 0
-        self.fn = 0
+    def __init__(
+        self,
+        *,
+        tp: int = 0,
+        fp: int = 0,
+        fn: int = 0,
+    ) -> None:
+        self.tp = tp
+        self.fp = fp
+        self.fn = fn
 
     def __len__(self) -> int:
         return self.tp + self.fp + self.fn
diff --git a/spacy/tests/test_scorer.py b/spacy/tests/test_scorer.py
index 4dddca404..ecdaee768 100644
--- a/spacy/tests/test_scorer.py
+++ b/spacy/tests/test_scorer.py
@@ -3,7 +3,7 @@ import pytest
 from pytest import approx
 from spacy.training import Example
 from spacy.training.iob_utils import offsets_to_biluo_tags
-from spacy.scorer import Scorer, ROCAUCScore
+from spacy.scorer import Scorer, ROCAUCScore, PRFScore
 from spacy.scorer import _roc_auc_score, _roc_curve
 from spacy.lang.en import English
 from spacy.tokens import Doc
@@ -403,3 +403,23 @@ def test_roc_auc_score():
     score.score_set(0.75, 1)
     with pytest.raises(ValueError):
         _ = score.score  # noqa: F841
+
+
+def test_prf_score():
+    cand = {"hi", "ho"}
+    gold1 = {"yo", "hi"}
+    gold2 = set()
+
+    a = PRFScore()
+    a.score_set(cand=cand, gold=gold1)
+    assert (a.precision, a.recall, a.fscore) == approx((0.5, 0.5, 0.5))
+
+    b = PRFScore()
+    b.score_set(cand=cand, gold=gold2)
+    assert (b.precision, b.recall, b.fscore) == approx((0.0, 0.0, 0.0))
+
+    c = a + b
+    assert (c.precision, c.recall, c.fscore) == approx((0.25, 0.5, 0.33333333))
+
+    a += b
+    assert (a.precision, a.recall, a.fscore) == approx((c.precision, c.recall, c.fscore))