ensure Span.as_doc keeps the entity links + unit test

This commit is contained in:
svlandeg 2019-06-25 15:28:51 +02:00
parent 58a5b40ef6
commit 8608685543
8 changed files with 56 additions and 3 deletions

View File

@ -82,6 +82,7 @@ cdef enum attr_id_t:
DEP
ENT_IOB
ENT_TYPE
ENT_KB_ID
HEAD
SENT_START
SPACY

View File

@ -84,6 +84,7 @@ IDS = {
"DEP": DEP,
"ENT_IOB": ENT_IOB,
"ENT_TYPE": ENT_TYPE,
"ENT_KB_ID": ENT_KB_ID,
"HEAD": HEAD,
"SENT_START": SENT_START,
"SPACY": SPACY,

View File

@ -81,6 +81,7 @@ cdef enum symbol_t:
DEP
ENT_IOB
ENT_TYPE
ENT_KB_ID
HEAD
SENT_START
SPACY

View File

@ -86,6 +86,7 @@ IDS = {
"DEP": DEP,
"ENT_IOB": ENT_IOB,
"ENT_TYPE": ENT_TYPE,
"ENT_KB_ID": ENT_KB_ID,
"HEAD": HEAD,
"SENT_START": SENT_START,
"SPACY": SPACY,

View File

@ -5,6 +5,7 @@ import pytest
from spacy.kb import KnowledgeBase
from spacy.lang.en import English
from spacy.pipeline import EntityRuler
@pytest.fixture
@ -101,3 +102,44 @@ def test_candidate_generation(nlp):
assert(len(mykb.get_candidates('douglas')) == 2)
assert(len(mykb.get_candidates('adam')) == 1)
assert(len(mykb.get_candidates('shrubbery')) == 0)
def test_preserving_links_asdoc(nlp):
"""Test that Span.as_doc preserves the existing entity links"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1])
# adding aliases
mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7])
mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6])
# set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
sentencizer = nlp.create_pipe("sentencizer")
nlp.add_pipe(sentencizer)
ruler = EntityRuler(nlp)
patterns = [{"label": "GPE", "pattern": "Boston"},
{"label": "GPE", "pattern": "Denver"}]
ruler.add_patterns(patterns)
nlp.add_pipe(ruler)
el_pipe = nlp.create_pipe(name='entity_linker', config={})
el_pipe.set_kb(mykb)
el_pipe.begin_training()
el_pipe.context_weight = 0
el_pipe.prior_weight = 1
nlp.add_pipe(el_pipe, last=True)
# test whether the entity links are preserved by the `as_doc()` function
text = "She lives in Boston. He lives in Denver."
doc = nlp(text)
for ent in doc.ents:
orig_text = ent.text
orig_kb_id = ent.kb_id_
sent_doc = ent.sent.as_doc()
for s_ent in sent_doc.ents:
if s_ent.text == orig_text:
assert s_ent.kb_id_ == orig_kb_id

View File

@ -22,7 +22,7 @@ from ..lexeme cimport Lexeme, EMPTY_LEXEME
from ..typedefs cimport attr_t, flags_t
from ..attrs cimport ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX, CLUSTER
from ..attrs cimport LENGTH, POS, LEMMA, TAG, DEP, HEAD, SPACY, ENT_IOB
from ..attrs cimport ENT_TYPE, SENT_START, attr_id_t
from ..attrs cimport ENT_TYPE, ENT_KB_ID, SENT_START, attr_id_t
from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t
from ..attrs import intify_attrs, IDS
@ -64,6 +64,8 @@ cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil:
return token.ent_iob
elif feat_name == ENT_TYPE:
return token.ent_type
elif feat_name == ENT_KB_ID:
return token.ent_kb_id
else:
return Lexeme.get_struct_attr(token.lex, feat_name)
@ -850,7 +852,7 @@ cdef class Doc:
DOCS: https://spacy.io/api/doc#to_bytes
"""
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE]
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE] # TODO: ENT_KB_ID ?
if self.is_tagged:
array_head.append(TAG)
# If doc parsed add head and dep attribute
@ -1004,6 +1006,7 @@ cdef class Doc:
"""
cdef unicode tag, lemma, ent_type
deprecation_warning(Warnings.W013.format(obj="Doc"))
# TODO: ENT_KB_ID ?
if len(args) == 3:
deprecation_warning(Warnings.W003)
tag, lemma, ent_type = args

View File

@ -210,7 +210,7 @@ cdef class Span:
words = [t.text for t in self]
spaces = [bool(t.whitespace_) for t in self]
cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces)
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE]
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_KB_ID]
if self.doc.is_tagged:
array_head.append(TAG)
# If doc parsed add head and dep attribute

View File

@ -53,6 +53,8 @@ cdef class Token:
return token.ent_iob
elif feat_name == ENT_TYPE:
return token.ent_type
elif feat_name == ENT_KB_ID:
return token.ent_kb_id
elif feat_name == SENT_START:
return token.sent_start
else:
@ -79,5 +81,7 @@ cdef class Token:
token.ent_iob = value
elif feat_name == ENT_TYPE:
token.ent_type = value
elif feat_name == ENT_KB_ID:
token.ent_kb_id = value
elif feat_name == SENT_START:
token.sent_start = value