raise error when setting overlapping entities as doc.ents (#2880)

This commit is contained in:
Grivaz 2018-10-26 17:29:16 -04:00 committed by Matthew Honnibal
parent 071789467e
commit 57f274b693
4 changed files with 43 additions and 11 deletions

View File

@ -250,7 +250,9 @@ class Errors(object):
E096 = ("Invalid object passed to displaCy: Can only visualize Doc or "
"Span objects, or dicts if set to manual=True.")
E097 = ("Can't merge non-disjoint spans. '{token}' is already part of tokens to merge")
E098 = ("Trying to set conflicting doc.ents: '{span1}' and '{span2}'. A token"
" can only be part of one entity, so make sure the entities you're "
"setting don't overlap.")
@add_codes
class TempErrors(object):

View File

@ -3,6 +3,7 @@ from __future__ import unicode_literals
from ...pipeline import EntityRecognizer
from ..util import get_doc
from ...tokens import Span
import pytest
@ -22,3 +23,13 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
doc.ents = [(doc.vocab.strings['WORD'], 0, 2)]
assert [w.ent_iob_ for w in doc] == ['B', 'I', '', '']
def test_add_overlapping_entities(en_vocab):
text = ["Louisiana", "Office", "of", "Conservation"]
doc = get_doc(en_vocab, text)
entity = Span(doc, 0, 4, label=391)
doc.ents = [entity]
new_entity = Span(doc, 0, 1, label=392)
with pytest.raises(ValueError):
doc.ents = list(doc.ents) + [new_entity]

View File

@ -17,9 +17,10 @@ def test_issue242(en_tokenizer):
matcher.add('FOOD', None, *patterns)
matches = [(ent_type, start, end) for ent_type, start, end in matcher(doc)]
doc.ents += tuple(matches)
match1, match2 = matches
assert match1[1] == 3
assert match1[2] == 5
assert match2[1] == 4
assert match2[2] == 6
doc.ents += tuple([match2])

View File

@ -458,6 +458,21 @@ cdef class Doc:
# prediction
# 3. Test basic data-driven ORTH gazetteer
# 4. Test more nuanced date and currency regex
tokens_in_ents = {}
cdef attr_t entity_type
cdef int ent_start, ent_end
for ent_info in ents:
entity_type, ent_start, ent_end = get_entity_info(ent_info)
for token_index in range(ent_start, ent_end):
if token_index in tokens_in_ents.keys():
raise ValueError(Errors.E098.format(
span1=(tokens_in_ents[token_index][0],
tokens_in_ents[token_index][1],
self.vocab.strings[tokens_in_ents[token_index][2]]),
span2=(ent_start, ent_end, self.vocab.strings[entity_type])))
tokens_in_ents[token_index] = (ent_start, ent_end, entity_type)
cdef int i
for i in range(self.length):
self.c[i].ent_type = 0
@ -465,15 +480,7 @@ cdef class Doc:
cdef attr_t ent_type
cdef int start, end
for ent_info in ents:
if isinstance(ent_info, Span):
ent_id = ent_info.ent_id
ent_type = ent_info.label
start = ent_info.start
end = ent_info.end
elif len(ent_info) == 3:
ent_type, start, end = ent_info
else:
ent_id, ent_type, start, end = ent_info
ent_type, start, end = get_entity_info(ent_info)
if ent_type is None or ent_type < 0:
# Mark as O
for i in range(start, end):
@ -1062,3 +1069,14 @@ def fix_attributes(doc, attributes):
attributes[ENT_TYPE] = doc.vocab.strings[attributes['label']]
if 'ent_type' in attributes:
attributes[ENT_TYPE] = attributes['ent_type']
def get_entity_info(ent_info):
if isinstance(ent_info, Span):
ent_type = ent_info.label
start = ent_info.start
end = ent_info.end
elif len(ent_info) == 3:
ent_type, start, end = ent_info
else:
ent_id, ent_type, start, end = ent_info
return ent_type, start, end