diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index b10d55267..0c89a2e14 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -6,6 +6,7 @@ import pytest from spacy.kb import KnowledgeBase from spacy.lang.en import English from spacy.pipeline import EntityRuler +from spacy.tokens import Span @pytest.fixture @@ -171,3 +172,31 @@ def test_preserving_links_asdoc(nlp): for s_ent in sent_doc.ents: if s_ent.text == orig_text: assert s_ent.kb_id_ == orig_kb_id + + +def test_preserving_links_ents(nlp): + """Test that doc.ents preserves KB annotations""" + text = "She lives in Boston. He lives in Denver." + doc = nlp(text) + assert len(list(doc.ents)) == 0 + + boston_ent = Span(doc, 3, 4, label="LOC", kb_id="Q1") + doc.ents = [boston_ent] + assert len(list(doc.ents)) == 1 + assert list(doc.ents)[0].label_ == "LOC" + assert list(doc.ents)[0].kb_id_ == "Q1" + + +def test_preserving_links_ents_2(nlp): + """Test that doc.ents preserves KB annotations""" + text = "She lives in Boston. He lives in Denver." + doc = nlp(text) + assert len(list(doc.ents)) == 0 + + loc = doc.vocab.strings.add("LOC") + q1 = doc.vocab.strings.add("Q1") + + doc.ents = [(loc, q1, 3, 4)] + assert len(list(doc.ents)) == 1 + assert list(doc.ents)[0].label_ == "LOC" + assert list(doc.ents)[0].kb_id_ == "Q1" diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index e5c213383..e863b0807 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -534,7 +534,7 @@ cdef class Doc: cdef attr_t entity_type cdef int ent_start, ent_end for ent_info in ents: - entity_type, ent_start, ent_end = get_entity_info(ent_info) + entity_type, kb_id, ent_start, ent_end = get_entity_info(ent_info) for token_index in range(ent_start, ent_end): if token_index in tokens_in_ents.keys(): raise ValueError(Errors.E103.format( @@ -542,7 +542,7 @@ cdef class Doc: tokens_in_ents[token_index][1], self.vocab.strings[tokens_in_ents[token_index][2]]), span2=(ent_start, ent_end, self.vocab.strings[entity_type]))) - tokens_in_ents[token_index] = (ent_start, ent_end, entity_type) + tokens_in_ents[token_index] = (ent_start, ent_end, entity_type, kb_id) cdef int i for i in range(self.length): self.c[i].ent_type = 0 @@ -551,16 +551,18 @@ cdef class Doc: cdef attr_t ent_type cdef int start, end for ent_info in ents: - ent_type, start, end = get_entity_info(ent_info) + ent_type, ent_kb_id, start, end = get_entity_info(ent_info) if ent_type is None or ent_type < 0: # Mark as O for i in range(start, end): self.c[i].ent_type = 0 + self.c[i].ent_kb_id = 0 self.c[i].ent_iob = 2 else: # Mark (inside) as I for i in range(start, end): self.c[i].ent_type = ent_type + self.c[i].ent_kb_id = ent_kb_id self.c[i].ent_iob = 1 # Set start as B self.c[start].ent_iob = 3 @@ -1251,10 +1253,14 @@ def fix_attributes(doc, attributes): def get_entity_info(ent_info): if isinstance(ent_info, Span): ent_type = ent_info.label + ent_kb_id = ent_info.kb_id start = ent_info.start end = ent_info.end elif len(ent_info) == 3: ent_type, start, end = ent_info + ent_kb_id = 0 + elif len(ent_info) == 4: + ent_type, ent_kb_id, start, end = ent_info else: - ent_id, ent_type, start, end = ent_info - return ent_type, start, end + ent_id, ent_kb_id, ent_type, start, end = ent_info + return ent_type, ent_kb_id, start, end