mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Add LAS per dependency to Scorer (#4560)
This commit is contained in:
parent
de98d66f87
commit
56ad3a3988
|
@ -82,6 +82,7 @@ class Scorer(object):
|
||||||
self.sbd = PRFScore()
|
self.sbd = PRFScore()
|
||||||
self.unlabelled = PRFScore()
|
self.unlabelled = PRFScore()
|
||||||
self.labelled = PRFScore()
|
self.labelled = PRFScore()
|
||||||
|
self.labelled_per_dep = dict()
|
||||||
self.tags = PRFScore()
|
self.tags = PRFScore()
|
||||||
self.ner = PRFScore()
|
self.ner = PRFScore()
|
||||||
self.ner_per_ents = dict()
|
self.ner_per_ents = dict()
|
||||||
|
@ -124,9 +125,18 @@ class Scorer(object):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def las(self):
|
def las(self):
|
||||||
"""RETURNS (float): Labelled depdendency score."""
|
"""RETURNS (float): Labelled dependency score."""
|
||||||
return self.labelled.fscore * 100
|
return self.labelled.fscore * 100
|
||||||
|
|
||||||
|
@property
|
||||||
|
def las_per_type(self):
|
||||||
|
"""RETURNS (dict): Scores per dependency label.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
k: {"p": v.precision * 100, "r": v.recall * 100, "f": v.fscore * 100}
|
||||||
|
for k, v in self.labelled_per_dep.items()
|
||||||
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ents_p(self):
|
def ents_p(self):
|
||||||
"""RETURNS (float): Named entity accuracy (precision)."""
|
"""RETURNS (float): Named entity accuracy (precision)."""
|
||||||
|
@ -196,6 +206,7 @@ class Scorer(object):
|
||||||
return {
|
return {
|
||||||
"uas": self.uas,
|
"uas": self.uas,
|
||||||
"las": self.las,
|
"las": self.las,
|
||||||
|
"las_per_type": self.las_per_type,
|
||||||
"ents_p": self.ents_p,
|
"ents_p": self.ents_p,
|
||||||
"ents_r": self.ents_r,
|
"ents_r": self.ents_r,
|
||||||
"ents_f": self.ents_f,
|
"ents_f": self.ents_f,
|
||||||
|
@ -223,13 +234,20 @@ class Scorer(object):
|
||||||
doc, tuple(zip(*gold.orig_annot)) + (gold.cats,)
|
doc, tuple(zip(*gold.orig_annot)) + (gold.cats,)
|
||||||
)
|
)
|
||||||
gold_deps = set()
|
gold_deps = set()
|
||||||
|
gold_deps_per_dep = {}
|
||||||
gold_tags = set()
|
gold_tags = set()
|
||||||
gold_ents = set(tags_to_entities([annot[-1] for annot in gold.orig_annot]))
|
gold_ents = set(tags_to_entities([annot[-1] for annot in gold.orig_annot]))
|
||||||
for id_, word, tag, head, dep, ner in gold.orig_annot:
|
for id_, word, tag, head, dep, ner in gold.orig_annot:
|
||||||
gold_tags.add((id_, tag))
|
gold_tags.add((id_, tag))
|
||||||
if dep not in (None, "") and dep.lower() not in punct_labels:
|
if dep not in (None, "") and dep.lower() not in punct_labels:
|
||||||
gold_deps.add((id_, head, dep.lower()))
|
gold_deps.add((id_, head, dep.lower()))
|
||||||
|
if dep.lower() not in self.labelled_per_dep:
|
||||||
|
self.labelled_per_dep[dep.lower()] = PRFScore()
|
||||||
|
if dep.lower() not in gold_deps_per_dep:
|
||||||
|
gold_deps_per_dep[dep.lower()] = set()
|
||||||
|
gold_deps_per_dep[dep.lower()].add((id_, head, dep.lower()))
|
||||||
cand_deps = set()
|
cand_deps = set()
|
||||||
|
cand_deps_per_dep = {}
|
||||||
cand_tags = set()
|
cand_tags = set()
|
||||||
for token in doc:
|
for token in doc:
|
||||||
if token.orth_.isspace():
|
if token.orth_.isspace():
|
||||||
|
@ -249,6 +267,11 @@ class Scorer(object):
|
||||||
self.labelled.fp += 1
|
self.labelled.fp += 1
|
||||||
else:
|
else:
|
||||||
cand_deps.add((gold_i, gold_head, token.dep_.lower()))
|
cand_deps.add((gold_i, gold_head, token.dep_.lower()))
|
||||||
|
if token.dep_.lower() not in self.labelled_per_dep:
|
||||||
|
self.labelled_per_dep[token.dep_.lower()] = PRFScore()
|
||||||
|
if token.dep_.lower() not in cand_deps_per_dep:
|
||||||
|
cand_deps_per_dep[token.dep_.lower()] = set()
|
||||||
|
cand_deps_per_dep[token.dep_.lower()].add((gold_i, gold_head, token.dep_.lower()))
|
||||||
if "-" not in [token[-1] for token in gold.orig_annot]:
|
if "-" not in [token[-1] for token in gold.orig_annot]:
|
||||||
# Find all NER labels in gold and doc
|
# Find all NER labels in gold and doc
|
||||||
ent_labels = set([x[0] for x in gold_ents] + [k.label_ for k in doc.ents])
|
ent_labels = set([x[0] for x in gold_ents] + [k.label_ for k in doc.ents])
|
||||||
|
@ -280,6 +303,8 @@ class Scorer(object):
|
||||||
self.ner.score_set(cand_ents, gold_ents)
|
self.ner.score_set(cand_ents, gold_ents)
|
||||||
self.tags.score_set(cand_tags, gold_tags)
|
self.tags.score_set(cand_tags, gold_tags)
|
||||||
self.labelled.score_set(cand_deps, gold_deps)
|
self.labelled.score_set(cand_deps, gold_deps)
|
||||||
|
for dep in self.labelled_per_dep:
|
||||||
|
self.labelled_per_dep[dep].score_set(cand_deps_per_dep.get(dep, set()), gold_deps_per_dep.get(dep, set()))
|
||||||
self.unlabelled.score_set(
|
self.unlabelled.score_set(
|
||||||
set(item[:2] for item in cand_deps), set(item[:2] for item in gold_deps)
|
set(item[:2] for item in cand_deps), set(item[:2] for item in gold_deps)
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,6 +9,14 @@ from spacy.scorer import Scorer, ROCAUCScore
|
||||||
from spacy.scorer import _roc_auc_score, _roc_curve
|
from spacy.scorer import _roc_auc_score, _roc_curve
|
||||||
from .util import get_doc
|
from .util import get_doc
|
||||||
|
|
||||||
|
test_las_apple = [
|
||||||
|
[
|
||||||
|
"Apple is looking at buying U.K. startup for $ 1 billion",
|
||||||
|
{"heads": [2, 2, 2, 2, 3, 6, 4, 4, 10, 10, 7],
|
||||||
|
"deps": ['nsubj', 'aux', 'ROOT', 'prep', 'pcomp', 'compound', 'dobj', 'prep', 'quantmod', 'compound', 'pobj']},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
test_ner_cardinal = [
|
test_ner_cardinal = [
|
||||||
["100 - 200", {"entities": [[0, 3, "CARDINAL"], [6, 9, "CARDINAL"]]}]
|
["100 - 200", {"entities": [[0, 3, "CARDINAL"], [6, 9, "CARDINAL"]]}]
|
||||||
]
|
]
|
||||||
|
@ -21,6 +29,53 @@ test_ner_apple = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_las_per_type(en_vocab):
|
||||||
|
# Gold and Doc are identical
|
||||||
|
scorer = Scorer()
|
||||||
|
for input_, annot in test_las_apple:
|
||||||
|
doc = get_doc(
|
||||||
|
en_vocab,
|
||||||
|
words=input_.split(" "),
|
||||||
|
heads=([h - i for i, h in enumerate(annot["heads"])]),
|
||||||
|
deps=annot["deps"],
|
||||||
|
)
|
||||||
|
gold = GoldParse(doc, heads=annot["heads"], deps=annot["deps"])
|
||||||
|
scorer.score(doc, gold)
|
||||||
|
results = scorer.scores
|
||||||
|
|
||||||
|
assert results["uas"] == 100
|
||||||
|
assert results["las"] == 100
|
||||||
|
assert results["las_per_type"]["nsubj"]["p"] == 100
|
||||||
|
assert results["las_per_type"]["nsubj"]["r"] == 100
|
||||||
|
assert results["las_per_type"]["nsubj"]["f"] == 100
|
||||||
|
assert results["las_per_type"]["compound"]["p"] == 100
|
||||||
|
assert results["las_per_type"]["compound"]["r"] == 100
|
||||||
|
assert results["las_per_type"]["compound"]["f"] == 100
|
||||||
|
|
||||||
|
# One dep is incorrect in Doc
|
||||||
|
scorer = Scorer()
|
||||||
|
for input_, annot in test_las_apple:
|
||||||
|
doc = get_doc(
|
||||||
|
en_vocab,
|
||||||
|
words=input_.split(" "),
|
||||||
|
heads=([h - i for i, h in enumerate(annot["heads"])]),
|
||||||
|
deps=annot["deps"]
|
||||||
|
)
|
||||||
|
gold = GoldParse(doc, heads=annot["heads"], deps=annot["deps"])
|
||||||
|
doc[0].dep_ = "compound"
|
||||||
|
scorer.score(doc, gold)
|
||||||
|
results = scorer.scores
|
||||||
|
|
||||||
|
assert results["uas"] == 100
|
||||||
|
assert_almost_equal(results["las"], 90.9090909)
|
||||||
|
assert results["las_per_type"]["nsubj"]["p"] == 0
|
||||||
|
assert results["las_per_type"]["nsubj"]["r"] == 0
|
||||||
|
assert results["las_per_type"]["nsubj"]["f"] == 0
|
||||||
|
assert_almost_equal(results["las_per_type"]["compound"]["p"], 66.6666666)
|
||||||
|
assert results["las_per_type"]["compound"]["r"] == 100
|
||||||
|
assert results["las_per_type"]["compound"]["f"] == 80
|
||||||
|
|
||||||
|
|
||||||
def test_ner_per_type(en_vocab):
|
def test_ner_per_type(en_vocab):
|
||||||
# Gold and Doc are identical
|
# Gold and Doc are identical
|
||||||
scorer = Scorer()
|
scorer = Scorer()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user