Fix __add__ method of PRFScore (#7557)

* Add failing test for PRFScore

* Fix erroneous implementation of __add__

* Simplify constructor

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
graue70 2021-04-08 09:34:14 +02:00 committed by GitHub
parent de4f4c9b8a
commit 81fd595223
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 5 deletions

View File

@ -20,10 +20,16 @@ MISSING_VALUES = frozenset([None, 0, ""])
class PRFScore:
"""A precision / recall / F score."""
def __init__(self) -> None:
self.tp = 0
self.fp = 0
self.fn = 0
def __init__(
self,
*,
tp: int = 0,
fp: int = 0,
fn: int = 0,
) -> None:
self.tp = tp
self.fp = fp
self.fn = fn
def __len__(self) -> int:
return self.tp + self.fp + self.fn

View File

@ -3,7 +3,7 @@ import pytest
from pytest import approx
from spacy.training import Example
from spacy.training.iob_utils import offsets_to_biluo_tags
from spacy.scorer import Scorer, ROCAUCScore
from spacy.scorer import Scorer, ROCAUCScore, PRFScore
from spacy.scorer import _roc_auc_score, _roc_curve
from spacy.lang.en import English
from spacy.tokens import Doc
@ -403,3 +403,23 @@ def test_roc_auc_score():
score.score_set(0.75, 1)
with pytest.raises(ValueError):
_ = score.score # noqa: F841
def test_prf_score():
cand = {"hi", "ho"}
gold1 = {"yo", "hi"}
gold2 = set()
a = PRFScore()
a.score_set(cand=cand, gold=gold1)
assert (a.precision, a.recall, a.fscore) == approx((0.5, 0.5, 0.5))
b = PRFScore()
b.score_set(cand=cand, gold=gold2)
assert (b.precision, b.recall, b.fscore) == approx((0.0, 0.0, 0.0))
c = a + b
assert (c.precision, c.recall, c.fscore) == approx((0.25, 0.5, 0.33333333))
a += b
assert (a.precision, a.recall, a.fscore) == approx((c.precision, c.recall, c.fscore))