2019-09-15 23:31:31 +03:00
|
|
|
from numpy.testing import assert_almost_equal, assert_array_almost_equal
|
|
|
|
import pytest
|
2019-08-01 18:15:36 +03:00
|
|
|
from pytest import approx
|
2019-11-11 19:35:27 +03:00
|
|
|
from spacy.gold import Example, GoldParse
|
2019-09-15 23:31:31 +03:00
|
|
|
from spacy.scorer import Scorer, ROCAUCScore
|
|
|
|
from spacy.scorer import _roc_auc_score, _roc_curve
|
2019-08-01 18:15:36 +03:00
|
|
|
from .util import get_doc
|
2020-04-02 15:46:32 +03:00
|
|
|
from spacy.lang.en import English
|
2019-08-01 18:15:36 +03:00
|
|
|
|
2019-10-31 23:18:16 +03:00
|
|
|
# Dependency-parse fixture: a list of [text, annotation] pairs.
# "heads" holds one head token index per whitespace-separated word of the
# text; "deps" holds the matching dependency label for each word.
test_las_apple = [
    [
        "Apple is looking at buying U.K. startup for $ 1 billion",
        {
            "heads": [2, 2, 2, 2, 3, 6, 4, 4, 10, 10, 7],
            "deps": [
                "nsubj",
                "aux",
                "ROOT",
                "prep",
                "pcomp",
                "compound",
                "dobj",
                "prep",
                "quantmod",
                "compound",
                "pobj",
            ],
        },
    ]
]
|
|
|
|
|
2019-08-01 18:15:36 +03:00
|
|
|
# NER fixture: [text, annotation] pairs whose "entities" are
# [start_char, end_char, label] character-offset spans into the text.
test_ner_cardinal = [
    ["100 - 200", {"entities": [[0, 3, "CARDINAL"], [6, 9, "CARDINAL"]]}]
]
|
|
|
|
|
|
|
|
# NER fixture: [text, annotation] pairs whose "entities" are
# (start_char, end_char, label) character-offset spans into the text.
test_ner_apple = [
    [
        "Apple is looking at buying U.K. startup for $1 billion",
        {"entities": [(0, 5, "ORG"), (27, 31, "GPE"), (44, 54, "MONEY")]},
    ]
]
|
|
|
|
|
2020-04-02 15:46:32 +03:00
|
|
|
@pytest.fixture
def tagged_doc():
    """Return an English ``Doc`` with hand-assigned tag, POS and morph values.

    The text is run through a blank ``English()`` pipeline, then the first
    ``len(tags)`` tokens get their fine-grained tag, coarse POS and
    morphological features overwritten from the parallel lists below.
    """
    text = "Sarah's sister flew to Silicon Valley via London."
    tags = ["NNP", "POS", "NN", "VBD", "IN", "NNP", "NNP", "IN", "NNP", "."]
    pos = [
        "PROPN",
        "PART",
        "NOUN",
        "VERB",
        "ADP",
        "PROPN",
        "PROPN",
        "ADP",
        "PROPN",
        "PUNCT",
    ]
    morphs = [
        "NounType=prop|Number=sing",
        "Poss=yes",
        "Number=sing",
        "Tense=past|VerbForm=fin",
        "",
        "NounType=prop|Number=sing",
        "NounType=prop|Number=sing",
        "",
        "NounType=prop|Number=sing",
        "PunctType=peri",
    ]
    nlp = English()
    doc = nlp(text)
    # Index into the doc (rather than zipping over it) so that a tokenization
    # yielding fewer tokens than annotations still fails loudly.
    for i, (tag, coarse_pos, morph) in enumerate(zip(tags, pos, morphs)):
        token = doc[i]
        token.tag_ = tag
        token.pos_ = coarse_pos
        token.morph_ = morph
    doc.is_tagged = True
    return doc
|
|
|
|
|
2019-08-18 16:09:16 +03:00
|
|
|
|
2019-10-31 23:18:16 +03:00
|
|
|
def test_las_per_type(en_vocab):
    """Check UAS, LAS and per-label P/R/F as reported by ``Scorer``.

    The ``test_las_apple`` sentence is scored twice: first with a predicted
    parse equal to the gold parse, then with token 0 relabelled from
    ``nsubj`` to ``compound`` in the predicted Doc only.
    """
    # Gold and Doc are identical
    scorer = Scorer()
    for input_, annot in test_las_apple:
        doc = get_doc(
            en_vocab,
            words=input_.split(" "),
            # h - i rewrites each head index as an offset from its own token,
            # which is the form handed to get_doc here.
            heads=([h - i for i, h in enumerate(annot["heads"])]),
            deps=annot["deps"],
        )
        gold = GoldParse(doc, heads=annot["heads"], deps=annot["deps"])
        scorer.score((doc, gold))
    results = scorer.scores

    assert results["uas"] == 100
    assert results["las"] == 100
    assert results["las_per_type"]["nsubj"]["p"] == 100
    assert results["las_per_type"]["nsubj"]["r"] == 100
    assert results["las_per_type"]["nsubj"]["f"] == 100
    assert results["las_per_type"]["compound"]["p"] == 100
    assert results["las_per_type"]["compound"]["r"] == 100
    assert results["las_per_type"]["compound"]["f"] == 100

    # One dep is incorrect in Doc
    scorer = Scorer()
    for input_, annot in test_las_apple:
        doc = get_doc(
            en_vocab,
            words=input_.split(" "),
            heads=([h - i for i, h in enumerate(annot["heads"])]),
            deps=annot["deps"],
        )
        gold = GoldParse(doc, heads=annot["heads"], deps=annot["deps"])
        # Corrupt the prediction after the gold parse has been built: the
        # only gold "nsubj" arc is now predicted as "compound".
        doc[0].dep_ = "compound"
        scorer.score((doc, gold))
    results = scorer.scores

    # Heads are untouched, so only the labelled score drops (10 of 11 arcs).
    assert results["uas"] == 100
    assert_almost_equal(results["las"], 90.9090909)
    assert results["las_per_type"]["nsubj"]["p"] == 0
    assert results["las_per_type"]["nsubj"]["r"] == 0
    assert results["las_per_type"]["nsubj"]["f"] == 0
    # Three arcs are predicted "compound", two of which are in the gold parse.
    assert_almost_equal(results["las_per_type"]["compound"]["p"], 66.6666666)
    assert results["las_per_type"]["compound"]["r"] == 100
    assert results["las_per_type"]["compound"]["f"] == 80
|
|
|
|
|
|
|
|
|
2019-08-01 18:15:36 +03:00
|
|
|
def test_ner_per_type(en_vocab):
    """Check overall and per-entity-type P/R/F as reported by ``Scorer``.

    First scores ``test_ner_cardinal`` with predicted entities that match the
    gold spans, then ``test_ner_apple`` with a prediction that misses the
    MONEY entity and adds a spurious ORG.
    """
    # Gold and Doc are identical
    scorer = Scorer()
    for input_, annot in test_ner_cardinal:
        doc = get_doc(
            en_vocab,
            words=input_.split(" "),
            # Predicted entities are given as token spans; the gold
            # annotation below uses character offsets.
            ents=[[0, 1, "CARDINAL"], [2, 3, "CARDINAL"]],
        )
        ex = Example(doc=doc)
        ex.set_token_annotation(entities=annot["entities"])
        scorer.score(ex)
    results = scorer.scores

    assert results["ents_p"] == 100
    assert results["ents_f"] == 100
    assert results["ents_r"] == 100
    assert results["ents_per_type"]["CARDINAL"]["p"] == 100
    assert results["ents_per_type"]["CARDINAL"]["f"] == 100
    assert results["ents_per_type"]["CARDINAL"]["r"] == 100

    # Doc has one missing and one extra entity
    # Entity type MONEY is not present in Doc
    scorer = Scorer()
    for input_, annot in test_ner_apple:
        doc = get_doc(
            en_vocab,
            words=input_.split(" "),
            # "Apple" and "U.K." match the gold spans; the ORG on token 6
            # ("startup") has no gold counterpart.
            ents=[[0, 1, "ORG"], [5, 6, "GPE"], [6, 7, "ORG"]],
        )
        ex = Example(doc=doc)
        ex.set_token_annotation(entities=annot["entities"])
        scorer.score(ex)
    results = scorer.scores

    # Two of three predicted and two of three gold entities are matched.
    assert results["ents_p"] == approx(66.66666)
    assert results["ents_r"] == approx(66.66666)
    assert results["ents_f"] == approx(66.66666)
    assert "GPE" in results["ents_per_type"]
    assert "MONEY" in results["ents_per_type"]
    assert "ORG" in results["ents_per_type"]
    assert results["ents_per_type"]["GPE"]["p"] == 100
    assert results["ents_per_type"]["GPE"]["r"] == 100
    assert results["ents_per_type"]["GPE"]["f"] == 100
    assert results["ents_per_type"]["MONEY"]["p"] == 0
    assert results["ents_per_type"]["MONEY"]["r"] == 0
    assert results["ents_per_type"]["MONEY"]["f"] == 0
    assert results["ents_per_type"]["ORG"]["p"] == 50
    assert results["ents_per_type"]["ORG"]["r"] == 100
    assert results["ents_per_type"]["ORG"]["f"] == approx(66.66666)
|
2019-09-15 23:31:31 +03:00
|
|
|
|
|
|
|
|
2020-04-02 15:46:32 +03:00
|
|
|
def test_tag_score(tagged_doc):
    """Check tag, POS and morph accuracies as reported by ``Scorer``.

    The ``tagged_doc`` fixture is scored against a gold parse copied from its
    own annotations, then against a gold parse altered at a few positions.
    """
    # Gold and Doc are identical
    scorer = Scorer()
    gold = GoldParse(
        tagged_doc,
        tags=[t.tag_ for t in tagged_doc],
        pos=[t.pos_ for t in tagged_doc],
        morphs=[t.morph_ for t in tagged_doc],
    )
    scorer.score((tagged_doc, gold))
    results = scorer.scores

    assert results["tags_acc"] == 100
    assert results["pos_acc"] == 100
    assert results["morphs_acc"] == 100
    assert results["morphs_per_type"]["NounType"]["f"] == 100

    # Gold differs from Doc: one tag, one POS and two morph values are changed
    scorer = Scorer()
    tags = [t.tag_ for t in tagged_doc]
    tags[0] = "NN"
    pos = [t.pos_ for t in tagged_doc]
    pos[1] = "X"
    morphs = [t.morph_ for t in tagged_doc]
    morphs[1] = "Number=sing"
    morphs[2] = "Number=plur"
    gold = GoldParse(tagged_doc, tags=tags, pos=pos, morphs=morphs)
    scorer.score((tagged_doc, gold))
    results = scorer.scores

    # 1 of 10 tags, 1 of 10 POS values and 2 of 10 morph strings disagree.
    assert results["tags_acc"] == 90
    assert results["pos_acc"] == 90
    assert results["morphs_acc"] == approx(80)
    assert results["morphs_per_type"]["Poss"]["f"] == 0.0
    assert results["morphs_per_type"]["Number"]["f"] == approx(72.727272)
|
|
|
|
|
|
|
|
|
2019-09-15 23:31:31 +03:00
|
|
|
def test_roc_auc_score():
    """Check ``_roc_curve``, ``_roc_auc_score`` and the ``ROCAUCScore`` wrapper.

    Covers small binary-classification inputs with known curves and areas,
    then the single-class inputs for which ``_roc_auc_score`` raises
    ``ValueError`` and ``ROCAUCScore.score`` is ``-inf``.
    """
    # Binary classification, toy tests from scikit-learn test suite
    # Each case: (y_true, y_score, expected FPR, expected TPR, expected AUC).
    # The first array returned by _roc_curve is treated as the false-positive
    # rate: in the AUC == 1.0 cases it is the one that stays at 0 while the
    # other already reaches 1.
    cases = [
        ([0, 1], [0, 1], [0, 0, 1], [0, 1, 1], 1.0),
        ([0, 1], [1, 0], [0, 1, 1], [0, 0, 1], 0.0),
        ([1, 0], [1, 1], [0, 1], [0, 1], 0.5),
        ([1, 0], [1, 0], [0, 0, 1], [0, 1, 1], 1.0),
        ([1, 0], [0.5, 0.5], [0, 1], [0, 1], 0.5),
    ]
    for y_true, y_score, expected_fpr, expected_tpr, expected_auc in cases:
        fpr, tpr, _ = _roc_curve(y_true, y_score)
        roc_auc = _roc_auc_score(y_true, y_score)
        assert_array_almost_equal(fpr, expected_fpr)
        assert_array_almost_equal(tpr, expected_tpr)
        assert_almost_equal(roc_auc, expected_auc)

    # same result as the last case above with ROCAUCScore wrapper
    score = ROCAUCScore()
    score.score_set(0.5, 1)
    score.score_set(0.5, 0)
    assert_almost_equal(score.score, 0.5)

    # check that errors are raised in undefined cases and score is -inf
    # (only one class present in y_true: all negatives, then all positives)
    for label in (0, 1):
        y_true = [label, label]
        y_score = [0.25, 0.75]
        with pytest.raises(ValueError):
            _roc_auc_score(y_true, y_score)

        score = ROCAUCScore()
        score.score_set(0.25, label)
        score.score_set(0.75, label)
        assert score.score == -float("inf")
|