mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Fix __add__ method of PRFScore (#7557)
* Add failing test for PRFScore * Fix erroneous implementation of __add__ * Simplify constructor Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
de4f4c9b8a
commit
81fd595223
|
@ -20,10 +20,16 @@ MISSING_VALUES = frozenset([None, 0, ""])
|
||||||
class PRFScore:
|
class PRFScore:
|
||||||
"""A precision / recall / F score."""
|
"""A precision / recall / F score."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(
|
||||||
self.tp = 0
|
self,
|
||||||
self.fp = 0
|
*,
|
||||||
self.fn = 0
|
tp: int = 0,
|
||||||
|
fp: int = 0,
|
||||||
|
fn: int = 0,
|
||||||
|
) -> None:
|
||||||
|
self.tp = tp
|
||||||
|
self.fp = fp
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.tp + self.fp + self.fn
|
return self.tp + self.fp + self.fn
|
||||||
|
|
|
@ -3,7 +3,7 @@ import pytest
|
||||||
from pytest import approx
|
from pytest import approx
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
from spacy.training.iob_utils import offsets_to_biluo_tags
|
from spacy.training.iob_utils import offsets_to_biluo_tags
|
||||||
from spacy.scorer import Scorer, ROCAUCScore
|
from spacy.scorer import Scorer, ROCAUCScore, PRFScore
|
||||||
from spacy.scorer import _roc_auc_score, _roc_curve
|
from spacy.scorer import _roc_auc_score, _roc_curve
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
|
@ -403,3 +403,23 @@ def test_roc_auc_score():
|
||||||
score.score_set(0.75, 1)
|
score.score_set(0.75, 1)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
_ = score.score # noqa: F841
|
_ = score.score # noqa: F841
|
||||||
|
|
||||||
|
|
||||||
|
def test_prf_score():
|
||||||
|
cand = {"hi", "ho"}
|
||||||
|
gold1 = {"yo", "hi"}
|
||||||
|
gold2 = set()
|
||||||
|
|
||||||
|
a = PRFScore()
|
||||||
|
a.score_set(cand=cand, gold=gold1)
|
||||||
|
assert (a.precision, a.recall, a.fscore) == approx((0.5, 0.5, 0.5))
|
||||||
|
|
||||||
|
b = PRFScore()
|
||||||
|
b.score_set(cand=cand, gold=gold2)
|
||||||
|
assert (b.precision, b.recall, b.fscore) == approx((0.0, 0.0, 0.0))
|
||||||
|
|
||||||
|
c = a + b
|
||||||
|
assert (c.precision, c.recall, c.fscore) == approx((0.25, 0.5, 0.33333333))
|
||||||
|
|
||||||
|
a += b
|
||||||
|
assert (a.precision, a.recall, a.fscore) == approx((c.precision, c.recall, c.fscore))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user