Ensure Span.as_doc keeps the entity links, and add a unit test

This commit is contained in:
svlandeg 2019-06-25 15:28:51 +02:00
parent 58a5b40ef6
commit 8608685543
8 changed files with 56 additions and 3 deletions

View File

@ -82,6 +82,7 @@ cdef enum attr_id_t:
DEP DEP
ENT_IOB ENT_IOB
ENT_TYPE ENT_TYPE
ENT_KB_ID
HEAD HEAD
SENT_START SENT_START
SPACY SPACY

View File

@ -84,6 +84,7 @@ IDS = {
"DEP": DEP, "DEP": DEP,
"ENT_IOB": ENT_IOB, "ENT_IOB": ENT_IOB,
"ENT_TYPE": ENT_TYPE, "ENT_TYPE": ENT_TYPE,
"ENT_KB_ID": ENT_KB_ID,
"HEAD": HEAD, "HEAD": HEAD,
"SENT_START": SENT_START, "SENT_START": SENT_START,
"SPACY": SPACY, "SPACY": SPACY,

View File

@ -81,6 +81,7 @@ cdef enum symbol_t:
DEP DEP
ENT_IOB ENT_IOB
ENT_TYPE ENT_TYPE
ENT_KB_ID
HEAD HEAD
SENT_START SENT_START
SPACY SPACY

View File

@ -86,6 +86,7 @@ IDS = {
"DEP": DEP, "DEP": DEP,
"ENT_IOB": ENT_IOB, "ENT_IOB": ENT_IOB,
"ENT_TYPE": ENT_TYPE, "ENT_TYPE": ENT_TYPE,
"ENT_KB_ID": ENT_KB_ID,
"HEAD": HEAD, "HEAD": HEAD,
"SENT_START": SENT_START, "SENT_START": SENT_START,
"SPACY": SPACY, "SPACY": SPACY,

View File

@ -5,6 +5,7 @@ import pytest
from spacy.kb import KnowledgeBase from spacy.kb import KnowledgeBase
from spacy.lang.en import English from spacy.lang.en import English
from spacy.pipeline import EntityRuler
@pytest.fixture @pytest.fixture
@ -101,3 +102,44 @@ def test_candidate_generation(nlp):
assert(len(mykb.get_candidates('douglas')) == 2) assert(len(mykb.get_candidates('douglas')) == 2)
assert(len(mykb.get_candidates('adam')) == 1) assert(len(mykb.get_candidates('adam')) == 1)
assert(len(mykb.get_candidates('shrubbery')) == 0) assert(len(mykb.get_candidates('shrubbery')) == 0)
def test_preserving_links_asdoc(nlp):
    """Test that Span.as_doc preserves the existing entity links."""
    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

    # adding entities
    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
    mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1])

    # adding aliases
    mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7])
    mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6])

    # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
    sentencizer = nlp.create_pipe("sentencizer")
    nlp.add_pipe(sentencizer)

    ruler = EntityRuler(nlp)
    patterns = [{"label": "GPE", "pattern": "Boston"},
                {"label": "GPE", "pattern": "Denver"}]
    ruler.add_patterns(patterns)
    nlp.add_pipe(ruler)

    el_pipe = nlp.create_pipe(name='entity_linker', config={})
    el_pipe.set_kb(mykb)
    el_pipe.begin_training()
    # rely purely on the prior probability from the KB, since the model is untrained
    el_pipe.context_weight = 0
    el_pipe.prior_weight = 1
    nlp.add_pipe(el_pipe, last=True)

    # test whether the entity links are preserved by the `as_doc()` function
    text = "She lives in Boston. He lives in Denver."
    doc = nlp(text)
    # guard against a vacuous pass: the ruler must have produced entities
    assert doc.ents
    for ent in doc.ents:
        orig_text = ent.text
        orig_kb_id = ent.kb_id_
        sent_doc = ent.sent.as_doc()
        # map the sentence-doc entities by surface text so we can assert that
        # the original entity was not silently dropped by as_doc()
        sent_ents = {s_ent.text: s_ent.kb_id_ for s_ent in sent_doc.ents}
        assert orig_text in sent_ents
        assert sent_ents[orig_text] == orig_kb_id

View File

@ -22,7 +22,7 @@ from ..lexeme cimport Lexeme, EMPTY_LEXEME
from ..typedefs cimport attr_t, flags_t from ..typedefs cimport attr_t, flags_t
from ..attrs cimport ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX, CLUSTER from ..attrs cimport ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX, CLUSTER
from ..attrs cimport LENGTH, POS, LEMMA, TAG, DEP, HEAD, SPACY, ENT_IOB from ..attrs cimport LENGTH, POS, LEMMA, TAG, DEP, HEAD, SPACY, ENT_IOB
from ..attrs cimport ENT_TYPE, SENT_START, attr_id_t from ..attrs cimport ENT_TYPE, ENT_KB_ID, SENT_START, attr_id_t
from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t
from ..attrs import intify_attrs, IDS from ..attrs import intify_attrs, IDS
@ -64,6 +64,8 @@ cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil:
return token.ent_iob return token.ent_iob
elif feat_name == ENT_TYPE: elif feat_name == ENT_TYPE:
return token.ent_type return token.ent_type
elif feat_name == ENT_KB_ID:
return token.ent_kb_id
else: else:
return Lexeme.get_struct_attr(token.lex, feat_name) return Lexeme.get_struct_attr(token.lex, feat_name)
@ -850,7 +852,7 @@ cdef class Doc:
DOCS: https://spacy.io/api/doc#to_bytes DOCS: https://spacy.io/api/doc#to_bytes
""" """
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE] array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE] # TODO: ENT_KB_ID ?
if self.is_tagged: if self.is_tagged:
array_head.append(TAG) array_head.append(TAG)
# If doc parsed add head and dep attribute # If doc parsed add head and dep attribute
@ -1004,6 +1006,7 @@ cdef class Doc:
""" """
cdef unicode tag, lemma, ent_type cdef unicode tag, lemma, ent_type
deprecation_warning(Warnings.W013.format(obj="Doc")) deprecation_warning(Warnings.W013.format(obj="Doc"))
# TODO: ENT_KB_ID ?
if len(args) == 3: if len(args) == 3:
deprecation_warning(Warnings.W003) deprecation_warning(Warnings.W003)
tag, lemma, ent_type = args tag, lemma, ent_type = args

View File

@ -210,7 +210,7 @@ cdef class Span:
words = [t.text for t in self] words = [t.text for t in self]
spaces = [bool(t.whitespace_) for t in self] spaces = [bool(t.whitespace_) for t in self]
cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces) cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces)
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE] array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_KB_ID]
if self.doc.is_tagged: if self.doc.is_tagged:
array_head.append(TAG) array_head.append(TAG)
# If doc parsed add head and dep attribute # If doc parsed add head and dep attribute

View File

@ -53,6 +53,8 @@ cdef class Token:
return token.ent_iob return token.ent_iob
elif feat_name == ENT_TYPE: elif feat_name == ENT_TYPE:
return token.ent_type return token.ent_type
elif feat_name == ENT_KB_ID:
return token.ent_kb_id
elif feat_name == SENT_START: elif feat_name == SENT_START:
return token.sent_start return token.sent_start
else: else:
@ -79,5 +81,7 @@ cdef class Token:
token.ent_iob = value token.ent_iob = value
elif feat_name == ENT_TYPE: elif feat_name == ENT_TYPE:
token.ent_type = value token.ent_type = value
elif feat_name == ENT_KB_ID:
token.ent_kb_id = value
elif feat_name == SENT_START: elif feat_name == SENT_START:
token.sent_start = value token.sent_start = value