From 925a852bb6450e16a23346e97a1813fc0fcb22a0 Mon Sep 17 00:00:00 2001 From: adrianeboyd Date: Thu, 1 Aug 2019 17:15:36 +0200 Subject: [PATCH] Improve NER per type scoring (#4052) * Improve NER per type scoring * include all gold labels in per type scoring, not only when recall > 0 * improve efficiency of per type scoring * Create Scorer tests, initially with NER tests * move regression test #3968 (per type NER scoring) to Scorer tests * add new test for per type NER scoring with imperfect P/R/F and per type P/R/F including a case where R == 0.0 --- spacy/scorer.py | 26 +++++---- spacy/tests/regression/test_issue3968.py | 34 ----------- spacy/tests/test_scorer.py | 73 ++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 45 deletions(-) delete mode 100644 spacy/tests/regression/test_issue3968.py create mode 100644 spacy/tests/test_scorer.py diff --git a/spacy/scorer.py b/spacy/scorer.py index 34a9b7620..1362e9b4d 100644 --- a/spacy/scorer.py +++ b/spacy/scorer.py @@ -159,12 +159,19 @@ class Scorer(object): else: cand_deps.add((gold_i, gold_head, token.dep_.lower())) if "-" not in [token[-1] for token in gold.orig_annot]: + # Find all NER labels in gold and doc + ent_labels = set([x[0] for x in gold_ents] + + [k.label_ for k in doc.ents]) + # Set up all labels for per type scoring and prepare gold per type + gold_per_ents = {ent_label: set() for ent_label in ent_labels} + for ent_label in ent_labels: + if ent_label not in self.ner_per_ents: + self.ner_per_ents[ent_label] = PRFScore() + gold_per_ents[ent_label].update([x for x in gold_ents if x[0] == ent_label]) + # Find all candidate labels, for all and per type cand_ents = set() - current_ent = {k.label_: set() for k in doc.ents} - current_gold = {k.label_: set() for k in doc.ents} + cand_per_ents = {ent_label: set() for ent_label in ent_labels} for ent in doc.ents: - if ent.label_ not in self.ner_per_ents: - self.ner_per_ents[ent.label_] = PRFScore() first = gold.cand_to_gold[ent.start] last = gold.cand_to_gold[ent.end - 1] if first is None or last is None: @@ -172,14 +179,11 @@ class Scorer(object): self.ner_per_ents[ent.label_].fp += 1 else: cand_ents.add((ent.label_, first, last)) - current_ent[ent.label_].update([x for x in cand_ents if x[0] == ent.label_]) - current_gold[ent.label_].update([x for x in gold_ents if x[0] == ent.label_]) + cand_per_ents[ent.label_].add((ent.label_, first, last)) # Scores per ent - [ - v.score_set(current_ent[k], current_gold[k]) - for k, v in self.ner_per_ents.items() - if k in current_ent - ] + for k, v in self.ner_per_ents.items(): + if k in cand_per_ents: + v.score_set(cand_per_ents[k], gold_per_ents[k]) # Score for all ents self.ner.score_set(cand_ents, gold_ents) self.tags.score_set(cand_tags, gold_tags) diff --git a/spacy/tests/regression/test_issue3968.py b/spacy/tests/regression/test_issue3968.py deleted file mode 100644 index 7e970a3a9..000000000 --- a/spacy/tests/regression/test_issue3968.py +++ /dev/null @@ -1,34 +0,0 @@ -# coding: utf-8 -from __future__ import unicode_literals - -from spacy.gold import GoldParse -from spacy.scorer import Scorer -from ..util import get_doc - -test_samples = [ - [ - "100 - 200", - { - "entities": [ - [0, 3, "CARDINAL"], - [6, 9, "CARDINAL"] - ] - } - ] -] - -def test_issue3625(en_vocab): - scorer = Scorer() - for input_, annot in test_samples: - doc = get_doc(en_vocab, words = input_.split(' '), ents = [[0,1,'CARDINAL'], [2,3,'CARDINAL']]); - gold = GoldParse(doc, entities = annot['entities']) - scorer.score(doc, gold) - results = scorer.scores - - # Expects total accuracy and accuracy for each each entity to be 100% - assert results['ents_p'] == 100 - assert results['ents_f'] == 100 - assert results['ents_r'] == 100 - assert results['ents_per_type']['CARDINAL']['p'] == 100 - assert results['ents_per_type']['CARDINAL']['f'] == 100 - assert results['ents_per_type']['CARDINAL']['r'] == 100 diff --git a/spacy/tests/test_scorer.py b/spacy/tests/test_scorer.py new file mode 100644 index 000000000..a88aef368 --- /dev/null +++ b/spacy/tests/test_scorer.py @@ -0,0 +1,73 @@ +# coding: utf-8 +from __future__ import unicode_literals + +from pytest import approx +from spacy.gold import GoldParse +from spacy.scorer import Scorer +from .util import get_doc + +test_ner_cardinal = [ + [ + "100 - 200", + { + "entities": [ + [0, 3, "CARDINAL"], + [6, 9, "CARDINAL"] + ] + } + ] +] + +test_ner_apple = [ + [ + "Apple is looking at buying U.K. startup for $1 billion", + { + "entities": [ + (0, 5, "ORG"), + (27, 31, "GPE"), + (44, 54, "MONEY"), + ] + } + ] +] + +def test_ner_per_type(en_vocab): + # Gold and Doc are identical + scorer = Scorer() + for input_, annot in test_ner_cardinal: + doc = get_doc(en_vocab, words = input_.split(' '), ents = [[0, 1, 'CARDINAL'], [2, 3, 'CARDINAL']]) + gold = GoldParse(doc, entities = annot['entities']) + scorer.score(doc, gold) + results = scorer.scores + + assert results['ents_p'] == 100 + assert results['ents_f'] == 100 + assert results['ents_r'] == 100 + assert results['ents_per_type']['CARDINAL']['p'] == 100 + assert results['ents_per_type']['CARDINAL']['f'] == 100 + assert results['ents_per_type']['CARDINAL']['r'] == 100 + + # Doc has one missing and one extra entity + # Entity type MONEY is not present in Doc + scorer = Scorer() + for input_, annot in test_ner_apple: + doc = get_doc(en_vocab, words = input_.split(' '), ents = [[0, 1, 'ORG'], [5, 6, 'GPE'], [6, 7, 'ORG']]) + gold = GoldParse(doc, entities = annot['entities']) + scorer.score(doc, gold) + results = scorer.scores + + assert results['ents_p'] == approx(66.66666) + assert results['ents_r'] == approx(66.66666) + assert results['ents_f'] == approx(66.66666) + assert 'GPE' in results['ents_per_type'] + assert 'MONEY' in results['ents_per_type'] + assert 'ORG' in results['ents_per_type'] + assert results['ents_per_type']['GPE']['p'] == 100 + assert results['ents_per_type']['GPE']['r'] == 100 + assert results['ents_per_type']['GPE']['f'] == 100 + assert results['ents_per_type']['MONEY']['p'] == 0 + assert results['ents_per_type']['MONEY']['r'] == 0 + assert results['ents_per_type']['MONEY']['f'] == 0 + assert results['ents_per_type']['ORG']['p'] == 50 + assert results['ents_per_type']['ORG']['r'] == 100 + assert results['ents_per_type']['ORG']['f'] == approx(66.66666)