mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-26 11:53:40 +03:00
raise error when setting overlapping entities as doc.ents (#2880)
This commit is contained in:
parent
071789467e
commit
57f274b693
|
@ -250,7 +250,9 @@ class Errors(object):
|
||||||
E096 = ("Invalid object passed to displaCy: Can only visualize Doc or "
|
E096 = ("Invalid object passed to displaCy: Can only visualize Doc or "
|
||||||
"Span objects, or dicts if set to manual=True.")
|
"Span objects, or dicts if set to manual=True.")
|
||||||
E097 = ("Can't merge non-disjoint spans. '{token}' is already part of tokens to merge")
|
E097 = ("Can't merge non-disjoint spans. '{token}' is already part of tokens to merge")
|
||||||
|
E098 = ("Trying to set conflicting doc.ents: '{span1}' and '{span2}'. A token"
|
||||||
|
" can only be part of one entity, so make sure the entities you're "
|
||||||
|
"setting don't overlap.")
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
class TempErrors(object):
|
class TempErrors(object):
|
||||||
|
|
|
@ -3,6 +3,7 @@ from __future__ import unicode_literals
|
||||||
|
|
||||||
from ...pipeline import EntityRecognizer
|
from ...pipeline import EntityRecognizer
|
||||||
from ..util import get_doc
|
from ..util import get_doc
|
||||||
|
from ...tokens import Span
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -22,3 +23,13 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
|
||||||
|
|
||||||
doc.ents = [(doc.vocab.strings['WORD'], 0, 2)]
|
doc.ents = [(doc.vocab.strings['WORD'], 0, 2)]
|
||||||
assert [w.ent_iob_ for w in doc] == ['B', 'I', '', '']
|
assert [w.ent_iob_ for w in doc] == ['B', 'I', '', '']
|
||||||
|
|
||||||
|
def test_add_overlapping_entities(en_vocab):
    """Setting doc.ents with a span that overlaps an existing entity must
    raise ValueError (error E098: a token can belong to only one entity).

    Regression test for spaCy issue #2880.
    """
    text = ["Louisiana", "Office", "of", "Conservation"]
    doc = get_doc(en_vocab, text)
    # First entity covers all four tokens.
    entity = Span(doc, 0, 4, label=391)
    doc.ents = [entity]

    # Second entity overlaps token 0, so assignment must be rejected.
    new_entity = Span(doc, 0, 1, label=392)
    with pytest.raises(ValueError):
        doc.ents = list(doc.ents) + [new_entity]
||||||
|
|
|
@ -17,9 +17,10 @@ def test_issue242(en_tokenizer):
|
||||||
matcher.add('FOOD', None, *patterns)
|
matcher.add('FOOD', None, *patterns)
|
||||||
|
|
||||||
matches = [(ent_type, start, end) for ent_type, start, end in matcher(doc)]
|
matches = [(ent_type, start, end) for ent_type, start, end in matcher(doc)]
|
||||||
doc.ents += tuple(matches)
|
|
||||||
match1, match2 = matches
|
match1, match2 = matches
|
||||||
assert match1[1] == 3
|
assert match1[1] == 3
|
||||||
assert match1[2] == 5
|
assert match1[2] == 5
|
||||||
assert match2[1] == 4
|
assert match2[1] == 4
|
||||||
assert match2[2] == 6
|
assert match2[2] == 6
|
||||||
|
|
||||||
|
doc.ents += tuple([match2])
|
||||||
|
|
|
@ -458,6 +458,21 @@ cdef class Doc:
|
||||||
# prediction
|
# prediction
|
||||||
# 3. Test basic data-driven ORTH gazetteer
|
# 3. Test basic data-driven ORTH gazetteer
|
||||||
# 4. Test more nuanced date and currency regex
|
# 4. Test more nuanced date and currency regex
|
||||||
|
|
||||||
|
tokens_in_ents = {}
|
||||||
|
cdef attr_t entity_type
|
||||||
|
cdef int ent_start, ent_end
|
||||||
|
for ent_info in ents:
|
||||||
|
entity_type, ent_start, ent_end = get_entity_info(ent_info)
|
||||||
|
for token_index in range(ent_start, ent_end):
|
||||||
|
if token_index in tokens_in_ents.keys():
|
||||||
|
raise ValueError(Errors.E098.format(
|
||||||
|
span1=(tokens_in_ents[token_index][0],
|
||||||
|
tokens_in_ents[token_index][1],
|
||||||
|
self.vocab.strings[tokens_in_ents[token_index][2]]),
|
||||||
|
span2=(ent_start, ent_end, self.vocab.strings[entity_type])))
|
||||||
|
tokens_in_ents[token_index] = (ent_start, ent_end, entity_type)
|
||||||
|
|
||||||
cdef int i
|
cdef int i
|
||||||
for i in range(self.length):
|
for i in range(self.length):
|
||||||
self.c[i].ent_type = 0
|
self.c[i].ent_type = 0
|
||||||
|
@ -465,15 +480,7 @@ cdef class Doc:
|
||||||
cdef attr_t ent_type
|
cdef attr_t ent_type
|
||||||
cdef int start, end
|
cdef int start, end
|
||||||
for ent_info in ents:
|
for ent_info in ents:
|
||||||
if isinstance(ent_info, Span):
|
ent_type, start, end = get_entity_info(ent_info)
|
||||||
ent_id = ent_info.ent_id
|
|
||||||
ent_type = ent_info.label
|
|
||||||
start = ent_info.start
|
|
||||||
end = ent_info.end
|
|
||||||
elif len(ent_info) == 3:
|
|
||||||
ent_type, start, end = ent_info
|
|
||||||
else:
|
|
||||||
ent_id, ent_type, start, end = ent_info
|
|
||||||
if ent_type is None or ent_type < 0:
|
if ent_type is None or ent_type < 0:
|
||||||
# Mark as O
|
# Mark as O
|
||||||
for i in range(start, end):
|
for i in range(start, end):
|
||||||
|
@ -1062,3 +1069,14 @@ def fix_attributes(doc, attributes):
|
||||||
attributes[ENT_TYPE] = doc.vocab.strings[attributes['label']]
|
attributes[ENT_TYPE] = doc.vocab.strings[attributes['label']]
|
||||||
if 'ent_type' in attributes:
|
if 'ent_type' in attributes:
|
||||||
attributes[ENT_TYPE] = attributes['ent_type']
|
attributes[ENT_TYPE] = attributes['ent_type']
|
||||||
|
|
||||||
|
def get_entity_info(ent_info):
    """Normalize an entity specification to a (label, start, end) triple.

    ent_info may be one of:
      * a Span object — label/start/end are read from its attributes,
      * a 3-tuple (ent_type, start, end),
      * a 4-tuple (ent_id, ent_type, start, end) — the id is discarded.

    Returns:
        tuple: (ent_type, start, end) with start/end as token offsets.
    """
    if isinstance(ent_info, Span):
        ent_type = ent_info.label
        start = ent_info.start
        end = ent_info.end
    elif len(ent_info) == 3:
        ent_type, start, end = ent_info
    else:
        # 4-tuple form: leading entity id is not needed by callers here.
        ent_id, ent_type, start, end = ent_info
    return ent_type, start, end
|
Loading…
Reference in New Issue
Block a user