mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Improve NER per type scoring (#4052)
* Improve NER per type scoring * include all gold labels in per type scoring, not only when recall > 0 * improve efficiency of per type scoring * Create Scorer tests, initially with NER tests * move regression test #3968 (per type NER scoring) to Scorer tests * add new test for per type NER scoring with imperfect P/R/F and per type P/R/F including a case where R == 0.0
This commit is contained in:
parent
f7d950de6d
commit
925a852bb6
|
@ -159,12 +159,19 @@ class Scorer(object):
|
|||
else:
|
||||
cand_deps.add((gold_i, gold_head, token.dep_.lower()))
|
||||
if "-" not in [token[-1] for token in gold.orig_annot]:
|
||||
# Find all NER labels in gold and doc
|
||||
ent_labels = set([x[0] for x in gold_ents]
|
||||
+ [k.label_ for k in doc.ents])
|
||||
# Set up all labels for per type scoring and prepare gold per type
|
||||
gold_per_ents = {ent_label: set() for ent_label in ent_labels}
|
||||
for ent_label in ent_labels:
|
||||
if ent_label not in self.ner_per_ents:
|
||||
self.ner_per_ents[ent_label] = PRFScore()
|
||||
gold_per_ents[ent_label].update([x for x in gold_ents if x[0] == ent_label])
|
||||
# Find all candidate labels, for all and per type
|
||||
cand_ents = set()
|
||||
current_ent = {k.label_: set() for k in doc.ents}
|
||||
current_gold = {k.label_: set() for k in doc.ents}
|
||||
cand_per_ents = {ent_label: set() for ent_label in ent_labels}
|
||||
for ent in doc.ents:
|
||||
if ent.label_ not in self.ner_per_ents:
|
||||
self.ner_per_ents[ent.label_] = PRFScore()
|
||||
first = gold.cand_to_gold[ent.start]
|
||||
last = gold.cand_to_gold[ent.end - 1]
|
||||
if first is None or last is None:
|
||||
|
@ -172,14 +179,11 @@ class Scorer(object):
|
|||
self.ner_per_ents[ent.label_].fp += 1
|
||||
else:
|
||||
cand_ents.add((ent.label_, first, last))
|
||||
current_ent[ent.label_].update([x for x in cand_ents if x[0] == ent.label_])
|
||||
current_gold[ent.label_].update([x for x in gold_ents if x[0] == ent.label_])
|
||||
cand_per_ents[ent.label_].add((ent.label_, first, last))
|
||||
# Scores per ent
|
||||
[
|
||||
v.score_set(current_ent[k], current_gold[k])
|
||||
for k, v in self.ner_per_ents.items()
|
||||
if k in current_ent
|
||||
]
|
||||
for k, v in self.ner_per_ents.items():
|
||||
if k in cand_per_ents:
|
||||
v.score_set(cand_per_ents[k], gold_per_ents[k])
|
||||
# Score for all ents
|
||||
self.ner.score_set(cand_ents, gold_ents)
|
||||
self.tags.score_set(cand_tags, gold_tags)
|
||||
|
|
|
@ -1,34 +0,0 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from spacy.gold import GoldParse
|
||||
from spacy.scorer import Scorer
|
||||
from ..util import get_doc
|
||||
|
||||
test_samples = [
|
||||
[
|
||||
"100 - 200",
|
||||
{
|
||||
"entities": [
|
||||
[0, 3, "CARDINAL"],
|
||||
[6, 9, "CARDINAL"]
|
||||
]
|
||||
}
|
||||
]
|
||||
]
|
||||
|
||||
def test_issue3625(en_vocab):
|
||||
scorer = Scorer()
|
||||
for input_, annot in test_samples:
|
||||
doc = get_doc(en_vocab, words = input_.split(' '), ents = [[0,1,'CARDINAL'], [2,3,'CARDINAL']]);
|
||||
gold = GoldParse(doc, entities = annot['entities'])
|
||||
scorer.score(doc, gold)
|
||||
results = scorer.scores
|
||||
|
||||
# Expects total accuracy and accuracy for each each entity to be 100%
|
||||
assert results['ents_p'] == 100
|
||||
assert results['ents_f'] == 100
|
||||
assert results['ents_r'] == 100
|
||||
assert results['ents_per_type']['CARDINAL']['p'] == 100
|
||||
assert results['ents_per_type']['CARDINAL']['f'] == 100
|
||||
assert results['ents_per_type']['CARDINAL']['r'] == 100
|
73
spacy/tests/test_scorer.py
Normal file
73
spacy/tests/test_scorer.py
Normal file
|
@ -0,0 +1,73 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from pytest import approx
|
||||
from spacy.gold import GoldParse
|
||||
from spacy.scorer import Scorer
|
||||
from .util import get_doc
|
||||
|
||||
test_ner_cardinal = [
|
||||
[
|
||||
"100 - 200",
|
||||
{
|
||||
"entities": [
|
||||
[0, 3, "CARDINAL"],
|
||||
[6, 9, "CARDINAL"]
|
||||
]
|
||||
}
|
||||
]
|
||||
]
|
||||
|
||||
test_ner_apple = [
|
||||
[
|
||||
"Apple is looking at buying U.K. startup for $1 billion",
|
||||
{
|
||||
"entities": [
|
||||
(0, 5, "ORG"),
|
||||
(27, 31, "GPE"),
|
||||
(44, 54, "MONEY"),
|
||||
]
|
||||
}
|
||||
]
|
||||
]
|
||||
|
||||
def test_ner_per_type(en_vocab):
|
||||
# Gold and Doc are identical
|
||||
scorer = Scorer()
|
||||
for input_, annot in test_ner_cardinal:
|
||||
doc = get_doc(en_vocab, words = input_.split(' '), ents = [[0, 1, 'CARDINAL'], [2, 3, 'CARDINAL']])
|
||||
gold = GoldParse(doc, entities = annot['entities'])
|
||||
scorer.score(doc, gold)
|
||||
results = scorer.scores
|
||||
|
||||
assert results['ents_p'] == 100
|
||||
assert results['ents_f'] == 100
|
||||
assert results['ents_r'] == 100
|
||||
assert results['ents_per_type']['CARDINAL']['p'] == 100
|
||||
assert results['ents_per_type']['CARDINAL']['f'] == 100
|
||||
assert results['ents_per_type']['CARDINAL']['r'] == 100
|
||||
|
||||
# Doc has one missing and one extra entity
|
||||
# Entity type MONEY is not present in Doc
|
||||
scorer = Scorer()
|
||||
for input_, annot in test_ner_apple:
|
||||
doc = get_doc(en_vocab, words = input_.split(' '), ents = [[0, 1, 'ORG'], [5, 6, 'GPE'], [6, 7, 'ORG']])
|
||||
gold = GoldParse(doc, entities = annot['entities'])
|
||||
scorer.score(doc, gold)
|
||||
results = scorer.scores
|
||||
|
||||
assert results['ents_p'] == approx(66.66666)
|
||||
assert results['ents_r'] == approx(66.66666)
|
||||
assert results['ents_f'] == approx(66.66666)
|
||||
assert 'GPE' in results['ents_per_type']
|
||||
assert 'MONEY' in results['ents_per_type']
|
||||
assert 'ORG' in results['ents_per_type']
|
||||
assert results['ents_per_type']['GPE']['p'] == 100
|
||||
assert results['ents_per_type']['GPE']['r'] == 100
|
||||
assert results['ents_per_type']['GPE']['f'] == 100
|
||||
assert results['ents_per_type']['MONEY']['p'] == 0
|
||||
assert results['ents_per_type']['MONEY']['r'] == 0
|
||||
assert results['ents_per_type']['MONEY']['f'] == 0
|
||||
assert results['ents_per_type']['ORG']['p'] == 50
|
||||
assert results['ents_per_type']['ORG']['r'] == 100
|
||||
assert results['ents_per_type']['ORG']['f'] == approx(66.66666)
|
Loading…
Reference in New Issue
Block a user