mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Ensure that doc.ents preserves kb_id annotations (#4294)
* bugfix: ensure doc.ents preserves kb_id annotations * fix backward compatibility * additional test
This commit is contained in:
parent
57f4c088be
commit
03ac29f437
|
@ -6,6 +6,7 @@ import pytest
|
||||||
from spacy.kb import KnowledgeBase
|
from spacy.kb import KnowledgeBase
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.pipeline import EntityRuler
|
from spacy.pipeline import EntityRuler
|
||||||
|
from spacy.tokens import Span
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -171,3 +172,31 @@ def test_preserving_links_asdoc(nlp):
|
||||||
for s_ent in sent_doc.ents:
|
for s_ent in sent_doc.ents:
|
||||||
if s_ent.text == orig_text:
|
if s_ent.text == orig_text:
|
||||||
assert s_ent.kb_id_ == orig_kb_id
|
assert s_ent.kb_id_ == orig_kb_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_preserving_links_ents(nlp):
|
||||||
|
"""Test that doc.ents preserves KB annotations"""
|
||||||
|
text = "She lives in Boston. He lives in Denver."
|
||||||
|
doc = nlp(text)
|
||||||
|
assert len(list(doc.ents)) == 0
|
||||||
|
|
||||||
|
boston_ent = Span(doc, 3, 4, label="LOC", kb_id="Q1")
|
||||||
|
doc.ents = [boston_ent]
|
||||||
|
assert len(list(doc.ents)) == 1
|
||||||
|
assert list(doc.ents)[0].label_ == "LOC"
|
||||||
|
assert list(doc.ents)[0].kb_id_ == "Q1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_preserving_links_ents_2(nlp):
|
||||||
|
"""Test that doc.ents preserves KB annotations"""
|
||||||
|
text = "She lives in Boston. He lives in Denver."
|
||||||
|
doc = nlp(text)
|
||||||
|
assert len(list(doc.ents)) == 0
|
||||||
|
|
||||||
|
loc = doc.vocab.strings.add("LOC")
|
||||||
|
q1 = doc.vocab.strings.add("Q1")
|
||||||
|
|
||||||
|
doc.ents = [(loc, q1, 3, 4)]
|
||||||
|
assert len(list(doc.ents)) == 1
|
||||||
|
assert list(doc.ents)[0].label_ == "LOC"
|
||||||
|
assert list(doc.ents)[0].kb_id_ == "Q1"
|
||||||
|
|
|
@ -534,7 +534,7 @@ cdef class Doc:
|
||||||
cdef attr_t entity_type
|
cdef attr_t entity_type
|
||||||
cdef int ent_start, ent_end
|
cdef int ent_start, ent_end
|
||||||
for ent_info in ents:
|
for ent_info in ents:
|
||||||
entity_type, ent_start, ent_end = get_entity_info(ent_info)
|
entity_type, kb_id, ent_start, ent_end = get_entity_info(ent_info)
|
||||||
for token_index in range(ent_start, ent_end):
|
for token_index in range(ent_start, ent_end):
|
||||||
if token_index in tokens_in_ents.keys():
|
if token_index in tokens_in_ents.keys():
|
||||||
raise ValueError(Errors.E103.format(
|
raise ValueError(Errors.E103.format(
|
||||||
|
@ -542,7 +542,7 @@ cdef class Doc:
|
||||||
tokens_in_ents[token_index][1],
|
tokens_in_ents[token_index][1],
|
||||||
self.vocab.strings[tokens_in_ents[token_index][2]]),
|
self.vocab.strings[tokens_in_ents[token_index][2]]),
|
||||||
span2=(ent_start, ent_end, self.vocab.strings[entity_type])))
|
span2=(ent_start, ent_end, self.vocab.strings[entity_type])))
|
||||||
tokens_in_ents[token_index] = (ent_start, ent_end, entity_type)
|
tokens_in_ents[token_index] = (ent_start, ent_end, entity_type, kb_id)
|
||||||
cdef int i
|
cdef int i
|
||||||
for i in range(self.length):
|
for i in range(self.length):
|
||||||
self.c[i].ent_type = 0
|
self.c[i].ent_type = 0
|
||||||
|
@ -551,16 +551,18 @@ cdef class Doc:
|
||||||
cdef attr_t ent_type
|
cdef attr_t ent_type
|
||||||
cdef int start, end
|
cdef int start, end
|
||||||
for ent_info in ents:
|
for ent_info in ents:
|
||||||
ent_type, start, end = get_entity_info(ent_info)
|
ent_type, ent_kb_id, start, end = get_entity_info(ent_info)
|
||||||
if ent_type is None or ent_type < 0:
|
if ent_type is None or ent_type < 0:
|
||||||
# Mark as O
|
# Mark as O
|
||||||
for i in range(start, end):
|
for i in range(start, end):
|
||||||
self.c[i].ent_type = 0
|
self.c[i].ent_type = 0
|
||||||
|
self.c[i].ent_kb_id = 0
|
||||||
self.c[i].ent_iob = 2
|
self.c[i].ent_iob = 2
|
||||||
else:
|
else:
|
||||||
# Mark (inside) as I
|
# Mark (inside) as I
|
||||||
for i in range(start, end):
|
for i in range(start, end):
|
||||||
self.c[i].ent_type = ent_type
|
self.c[i].ent_type = ent_type
|
||||||
|
self.c[i].ent_kb_id = ent_kb_id
|
||||||
self.c[i].ent_iob = 1
|
self.c[i].ent_iob = 1
|
||||||
# Set start as B
|
# Set start as B
|
||||||
self.c[start].ent_iob = 3
|
self.c[start].ent_iob = 3
|
||||||
|
@ -1251,10 +1253,14 @@ def fix_attributes(doc, attributes):
|
||||||
def get_entity_info(ent_info):
|
def get_entity_info(ent_info):
|
||||||
if isinstance(ent_info, Span):
|
if isinstance(ent_info, Span):
|
||||||
ent_type = ent_info.label
|
ent_type = ent_info.label
|
||||||
|
ent_kb_id = ent_info.kb_id
|
||||||
start = ent_info.start
|
start = ent_info.start
|
||||||
end = ent_info.end
|
end = ent_info.end
|
||||||
elif len(ent_info) == 3:
|
elif len(ent_info) == 3:
|
||||||
ent_type, start, end = ent_info
|
ent_type, start, end = ent_info
|
||||||
|
ent_kb_id = 0
|
||||||
|
elif len(ent_info) == 4:
|
||||||
|
ent_type, ent_kb_id, start, end = ent_info
|
||||||
else:
|
else:
|
||||||
ent_id, ent_type, start, end = ent_info
|
ent_id, ent_kb_id, ent_type, start, end = ent_info
|
||||||
return ent_type, start, end
|
return ent_type, ent_kb_id, start, end
|
||||||
|
|
Loading…
Reference in New Issue
Block a user