mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
raise error when setting overlapping entities as doc.ents (#2880)
This commit is contained in:
parent
071789467e
commit
57f274b693
|
@ -250,7 +250,9 @@ class Errors(object):
|
|||
E096 = ("Invalid object passed to displaCy: Can only visualize Doc or "
|
||||
"Span objects, or dicts if set to manual=True.")
|
||||
E097 = ("Can't merge non-disjoint spans. '{token}' is already part of tokens to merge")
|
||||
|
||||
E098 = ("Trying to set conflicting doc.ents: '{span1}' and '{span2}'. A token"
|
||||
" can only be part of one entity, so make sure the entities you're "
|
||||
"setting don't overlap.")
|
||||
|
||||
@add_codes
|
||||
class TempErrors(object):
|
||||
|
|
|
@ -3,6 +3,7 @@ from __future__ import unicode_literals
|
|||
|
||||
from ...pipeline import EntityRecognizer
|
||||
from ..util import get_doc
|
||||
from ...tokens import Span
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -22,3 +23,13 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
|
|||
|
||||
doc.ents = [(doc.vocab.strings['WORD'], 0, 2)]
|
||||
assert [w.ent_iob_ for w in doc] == ['B', 'I', '', '']
|
||||
|
||||
def test_add_overlapping_entities(en_vocab):
|
||||
text = ["Louisiana", "Office", "of", "Conservation"]
|
||||
doc = get_doc(en_vocab, text)
|
||||
entity = Span(doc, 0, 4, label=391)
|
||||
doc.ents = [entity]
|
||||
|
||||
new_entity = Span(doc, 0, 1, label=392)
|
||||
with pytest.raises(ValueError):
|
||||
doc.ents = list(doc.ents) + [new_entity]
|
||||
|
|
|
@ -17,9 +17,10 @@ def test_issue242(en_tokenizer):
|
|||
matcher.add('FOOD', None, *patterns)
|
||||
|
||||
matches = [(ent_type, start, end) for ent_type, start, end in matcher(doc)]
|
||||
doc.ents += tuple(matches)
|
||||
match1, match2 = matches
|
||||
assert match1[1] == 3
|
||||
assert match1[2] == 5
|
||||
assert match2[1] == 4
|
||||
assert match2[2] == 6
|
||||
|
||||
doc.ents += tuple([match2])
|
||||
|
|
|
@ -458,6 +458,21 @@ cdef class Doc:
|
|||
# prediction
|
||||
# 3. Test basic data-driven ORTH gazetteer
|
||||
# 4. Test more nuanced date and currency regex
|
||||
|
||||
tokens_in_ents = {}
|
||||
cdef attr_t entity_type
|
||||
cdef int ent_start, ent_end
|
||||
for ent_info in ents:
|
||||
entity_type, ent_start, ent_end = get_entity_info(ent_info)
|
||||
for token_index in range(ent_start, ent_end):
|
||||
if token_index in tokens_in_ents.keys():
|
||||
raise ValueError(Errors.E098.format(
|
||||
span1=(tokens_in_ents[token_index][0],
|
||||
tokens_in_ents[token_index][1],
|
||||
self.vocab.strings[tokens_in_ents[token_index][2]]),
|
||||
span2=(ent_start, ent_end, self.vocab.strings[entity_type])))
|
||||
tokens_in_ents[token_index] = (ent_start, ent_end, entity_type)
|
||||
|
||||
cdef int i
|
||||
for i in range(self.length):
|
||||
self.c[i].ent_type = 0
|
||||
|
@ -465,15 +480,7 @@ cdef class Doc:
|
|||
cdef attr_t ent_type
|
||||
cdef int start, end
|
||||
for ent_info in ents:
|
||||
if isinstance(ent_info, Span):
|
||||
ent_id = ent_info.ent_id
|
||||
ent_type = ent_info.label
|
||||
start = ent_info.start
|
||||
end = ent_info.end
|
||||
elif len(ent_info) == 3:
|
||||
ent_type, start, end = ent_info
|
||||
else:
|
||||
ent_id, ent_type, start, end = ent_info
|
||||
ent_type, start, end = get_entity_info(ent_info)
|
||||
if ent_type is None or ent_type < 0:
|
||||
# Mark as O
|
||||
for i in range(start, end):
|
||||
|
@ -1062,3 +1069,14 @@ def fix_attributes(doc, attributes):
|
|||
attributes[ENT_TYPE] = doc.vocab.strings[attributes['label']]
|
||||
if 'ent_type' in attributes:
|
||||
attributes[ENT_TYPE] = attributes['ent_type']
|
||||
|
||||
def get_entity_info(ent_info):
|
||||
if isinstance(ent_info, Span):
|
||||
ent_type = ent_info.label
|
||||
start = ent_info.start
|
||||
end = ent_info.end
|
||||
elif len(ent_info) == 3:
|
||||
ent_type, start, end = ent_info
|
||||
else:
|
||||
ent_id, ent_type, start, end = ent_info
|
||||
return ent_type, start, end
|
Loading…
Reference in New Issue
Block a user