diff --git a/spacy/errors.py b/spacy/errors.py
index 7c0f0efd3..138de0f57 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -250,7 +250,9 @@ class Errors(object):
     E096 = ("Invalid object passed to displaCy: Can only visualize Doc or "
             "Span objects, or dicts if set to manual=True.")
     E097 = ("Can't merge non-disjoint spans. '{token}' is already part of tokens to merge")
-
+    E098 = ("Trying to set conflicting doc.ents: '{span1}' and '{span2}'. A token"
+            " can only be part of one entity, so make sure the entities you're "
+            "setting don't overlap.")
 
 @add_codes
 class TempErrors(object):
diff --git a/spacy/tests/doc/test_add_entities.py b/spacy/tests/doc/test_add_entities.py
index 31d2b8420..f05da4fe8 100644
--- a/spacy/tests/doc/test_add_entities.py
+++ b/spacy/tests/doc/test_add_entities.py
@@ -3,6 +3,7 @@ from __future__ import unicode_literals
 
 from ...pipeline import EntityRecognizer
 from ..util import get_doc
+from ...tokens import Span
 
 import pytest
 
@@ -22,3 +23,13 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
     doc.ents = [(doc.vocab.strings['WORD'], 0, 2)]
     assert [w.ent_iob_ for w in doc] == ['B', 'I', '', '']
 
+
+def test_add_overlapping_entities(en_vocab):
+    text = ["Louisiana", "Office", "of", "Conservation"]
+    doc = get_doc(en_vocab, text)
+    entity = Span(doc, 0, 4, label=391)
+    doc.ents = [entity]
+
+    new_entity = Span(doc, 0, 1, label=392)
+    with pytest.raises(ValueError):
+        doc.ents = list(doc.ents) + [new_entity]
diff --git a/spacy/tests/regression/test_issue242.py b/spacy/tests/regression/test_issue242.py
index b5909fe65..4fff090ee 100644
--- a/spacy/tests/regression/test_issue242.py
+++ b/spacy/tests/regression/test_issue242.py
@@ -17,9 +17,10 @@ def test_issue242(en_tokenizer):
     matcher.add('FOOD', None, *patterns)
 
     matches = [(ent_type, start, end) for ent_type, start, end in matcher(doc)]
-    doc.ents += tuple(matches)
     match1, match2 = matches
     assert match1[1] == 3
     assert match1[2] == 5
     assert match2[1] == 4
     assert match2[2] == 6
+
+    doc.ents += tuple([match2])
diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index 73e7cf390..b29af1ff1 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -458,6 +458,21 @@ cdef class Doc:
             # prediction
            # 3. Test basic data-driven ORTH gazetteer
            # 4. Test more nuanced date and currency regex
+
+            tokens_in_ents = {}
+            cdef attr_t entity_type
+            cdef int ent_start, ent_end
+            for ent_info in ents:
+                entity_type, ent_start, ent_end = get_entity_info(ent_info)
+                for token_index in range(ent_start, ent_end):
+                    if token_index in tokens_in_ents.keys():
+                        raise ValueError(Errors.E098.format(
+                            span1=(tokens_in_ents[token_index][0],
+                                   tokens_in_ents[token_index][1],
+                                   self.vocab.strings[tokens_in_ents[token_index][2]]),
+                            span2=(ent_start, ent_end, self.vocab.strings[entity_type])))
+                    tokens_in_ents[token_index] = (ent_start, ent_end, entity_type)
+
             cdef int i
             for i in range(self.length):
                 self.c[i].ent_type = 0
@@ -465,15 +480,7 @@ cdef class Doc:
             cdef attr_t ent_type
             cdef int start, end
             for ent_info in ents:
-                if isinstance(ent_info, Span):
-                    ent_id = ent_info.ent_id
-                    ent_type = ent_info.label
-                    start = ent_info.start
-                    end = ent_info.end
-                elif len(ent_info) == 3:
-                    ent_type, start, end = ent_info
-                else:
-                    ent_id, ent_type, start, end = ent_info
+                ent_type, start, end = get_entity_info(ent_info)
                 if ent_type is None or ent_type < 0:
                     # Mark as O
                     for i in range(start, end):
@@ -1062,3 +1069,14 @@ def fix_attributes(doc, attributes):
         attributes[ENT_TYPE] = doc.vocab.strings[attributes['label']]
     if 'ent_type' in attributes:
         attributes[ENT_TYPE] = attributes['ent_type']
+
+def get_entity_info(ent_info):
+    if isinstance(ent_info, Span):
+        ent_type = ent_info.label
+        start = ent_info.start
+        end = ent_info.end
+    elif len(ent_info) == 3:
+        ent_type, start, end = ent_info
+    else:
+        ent_id, ent_type, start, end = ent_info
+    return ent_type, start, end
\ No newline at end of file
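
For context, a minimal sketch of the behaviour this patch introduces: setting doc.ents with spans that share a token now raises ValueError with error E098 instead of silently accepting the conflict. The example below constructs Doc and Span objects directly; the tokens and the "ORG"/"GPE" labels are illustrative only, not part of the patch.

    # Sketch of the new doc.ents check, assuming spaCy at this commit.
    from spacy.vocab import Vocab
    from spacy.tokens import Doc, Span

    vocab = Vocab()
    doc = Doc(vocab, words=["Louisiana", "Office", "of", "Conservation"])

    # Cover all four tokens with a single entity.
    doc.ents = [Span(doc, 0, 4, label=vocab.strings.add("ORG"))]

    # Token 0 is already part of the 0-4 entity, so the setter now
    # raises ValueError (E098) rather than overwriting the annotation.
    overlapping = Span(doc, 0, 1, label=vocab.strings.add("GPE"))
    try:
        doc.ents = list(doc.ents) + [overlapping]
    except ValueError as err:
        print(err)  # expected to name E098 and both conflicting spans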