mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Add LAS per dependency to Scorer (#4560)
This commit is contained in:
parent
de98d66f87
commit
56ad3a3988
|
@ -82,6 +82,7 @@ class Scorer(object):
|
|||
self.sbd = PRFScore()
|
||||
self.unlabelled = PRFScore()
|
||||
self.labelled = PRFScore()
|
||||
self.labelled_per_dep = dict()
|
||||
self.tags = PRFScore()
|
||||
self.ner = PRFScore()
|
||||
self.ner_per_ents = dict()
|
||||
|
@ -124,9 +125,18 @@ class Scorer(object):
|
|||
|
||||
@property
|
||||
def las(self):
|
||||
"""RETURNS (float): Labelled depdendency score."""
|
||||
"""RETURNS (float): Labelled dependency score."""
|
||||
return self.labelled.fscore * 100
|
||||
|
||||
@property
|
||||
def las_per_type(self):
|
||||
"""RETURNS (dict): Scores per dependency label.
|
||||
"""
|
||||
return {
|
||||
k: {"p": v.precision * 100, "r": v.recall * 100, "f": v.fscore * 100}
|
||||
for k, v in self.labelled_per_dep.items()
|
||||
}
|
||||
|
||||
@property
|
||||
def ents_p(self):
|
||||
"""RETURNS (float): Named entity accuracy (precision)."""
|
||||
|
@ -196,6 +206,7 @@ class Scorer(object):
|
|||
return {
|
||||
"uas": self.uas,
|
||||
"las": self.las,
|
||||
"las_per_type": self.las_per_type,
|
||||
"ents_p": self.ents_p,
|
||||
"ents_r": self.ents_r,
|
||||
"ents_f": self.ents_f,
|
||||
|
@ -223,13 +234,20 @@ class Scorer(object):
|
|||
doc, tuple(zip(*gold.orig_annot)) + (gold.cats,)
|
||||
)
|
||||
gold_deps = set()
|
||||
gold_deps_per_dep = {}
|
||||
gold_tags = set()
|
||||
gold_ents = set(tags_to_entities([annot[-1] for annot in gold.orig_annot]))
|
||||
for id_, word, tag, head, dep, ner in gold.orig_annot:
|
||||
gold_tags.add((id_, tag))
|
||||
if dep not in (None, "") and dep.lower() not in punct_labels:
|
||||
gold_deps.add((id_, head, dep.lower()))
|
||||
if dep.lower() not in self.labelled_per_dep:
|
||||
self.labelled_per_dep[dep.lower()] = PRFScore()
|
||||
if dep.lower() not in gold_deps_per_dep:
|
||||
gold_deps_per_dep[dep.lower()] = set()
|
||||
gold_deps_per_dep[dep.lower()].add((id_, head, dep.lower()))
|
||||
cand_deps = set()
|
||||
cand_deps_per_dep = {}
|
||||
cand_tags = set()
|
||||
for token in doc:
|
||||
if token.orth_.isspace():
|
||||
|
@ -249,6 +267,11 @@ class Scorer(object):
|
|||
self.labelled.fp += 1
|
||||
else:
|
||||
cand_deps.add((gold_i, gold_head, token.dep_.lower()))
|
||||
if token.dep_.lower() not in self.labelled_per_dep:
|
||||
self.labelled_per_dep[token.dep_.lower()] = PRFScore()
|
||||
if token.dep_.lower() not in cand_deps_per_dep:
|
||||
cand_deps_per_dep[token.dep_.lower()] = set()
|
||||
cand_deps_per_dep[token.dep_.lower()].add((gold_i, gold_head, token.dep_.lower()))
|
||||
if "-" not in [token[-1] for token in gold.orig_annot]:
|
||||
# Find all NER labels in gold and doc
|
||||
ent_labels = set([x[0] for x in gold_ents] + [k.label_ for k in doc.ents])
|
||||
|
@ -280,6 +303,8 @@ class Scorer(object):
|
|||
self.ner.score_set(cand_ents, gold_ents)
|
||||
self.tags.score_set(cand_tags, gold_tags)
|
||||
self.labelled.score_set(cand_deps, gold_deps)
|
||||
for dep in self.labelled_per_dep:
|
||||
self.labelled_per_dep[dep].score_set(cand_deps_per_dep.get(dep, set()), gold_deps_per_dep.get(dep, set()))
|
||||
self.unlabelled.score_set(
|
||||
set(item[:2] for item in cand_deps), set(item[:2] for item in gold_deps)
|
||||
)
|
||||
|
|
|
@ -9,6 +9,14 @@ from spacy.scorer import Scorer, ROCAUCScore
|
|||
from spacy.scorer import _roc_auc_score, _roc_curve
|
||||
from .util import get_doc
|
||||
|
||||
test_las_apple = [
|
||||
[
|
||||
"Apple is looking at buying U.K. startup for $ 1 billion",
|
||||
{"heads": [2, 2, 2, 2, 3, 6, 4, 4, 10, 10, 7],
|
||||
"deps": ['nsubj', 'aux', 'ROOT', 'prep', 'pcomp', 'compound', 'dobj', 'prep', 'quantmod', 'compound', 'pobj']},
|
||||
]
|
||||
]
|
||||
|
||||
test_ner_cardinal = [
|
||||
["100 - 200", {"entities": [[0, 3, "CARDINAL"], [6, 9, "CARDINAL"]]}]
|
||||
]
|
||||
|
@ -21,6 +29,53 @@ test_ner_apple = [
|
|||
]
|
||||
|
||||
|
||||
def test_las_per_type(en_vocab):
|
||||
# Gold and Doc are identical
|
||||
scorer = Scorer()
|
||||
for input_, annot in test_las_apple:
|
||||
doc = get_doc(
|
||||
en_vocab,
|
||||
words=input_.split(" "),
|
||||
heads=([h - i for i, h in enumerate(annot["heads"])]),
|
||||
deps=annot["deps"],
|
||||
)
|
||||
gold = GoldParse(doc, heads=annot["heads"], deps=annot["deps"])
|
||||
scorer.score(doc, gold)
|
||||
results = scorer.scores
|
||||
|
||||
assert results["uas"] == 100
|
||||
assert results["las"] == 100
|
||||
assert results["las_per_type"]["nsubj"]["p"] == 100
|
||||
assert results["las_per_type"]["nsubj"]["r"] == 100
|
||||
assert results["las_per_type"]["nsubj"]["f"] == 100
|
||||
assert results["las_per_type"]["compound"]["p"] == 100
|
||||
assert results["las_per_type"]["compound"]["r"] == 100
|
||||
assert results["las_per_type"]["compound"]["f"] == 100
|
||||
|
||||
# One dep is incorrect in Doc
|
||||
scorer = Scorer()
|
||||
for input_, annot in test_las_apple:
|
||||
doc = get_doc(
|
||||
en_vocab,
|
||||
words=input_.split(" "),
|
||||
heads=([h - i for i, h in enumerate(annot["heads"])]),
|
||||
deps=annot["deps"]
|
||||
)
|
||||
gold = GoldParse(doc, heads=annot["heads"], deps=annot["deps"])
|
||||
doc[0].dep_ = "compound"
|
||||
scorer.score(doc, gold)
|
||||
results = scorer.scores
|
||||
|
||||
assert results["uas"] == 100
|
||||
assert_almost_equal(results["las"], 90.9090909)
|
||||
assert results["las_per_type"]["nsubj"]["p"] == 0
|
||||
assert results["las_per_type"]["nsubj"]["r"] == 0
|
||||
assert results["las_per_type"]["nsubj"]["f"] == 0
|
||||
assert_almost_equal(results["las_per_type"]["compound"]["p"], 66.6666666)
|
||||
assert results["las_per_type"]["compound"]["r"] == 100
|
||||
assert results["las_per_type"]["compound"]["f"] == 80
|
||||
|
||||
|
||||
def test_ner_per_type(en_vocab):
|
||||
# Gold and Doc are identical
|
||||
scorer = Scorer()
|
||||
|
|
Loading…
Reference in New Issue
Block a user