mirror of https://github.com/explosion/spaCy.git
ensure Span.as_doc keeps the entity links + unit test
commit 8608685543
parent 58a5b40ef6
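
This change registers a new ENT_KB_ID token attribute and threads it through the attribute machinery: the attrs and symbols enums, the token get/set struct helpers, and the array_head that Span.as_doc uses to copy token attributes into the new Doc. Knowledge-base IDs assigned by the entity linker therefore survive the copy. Serialization in Doc.to_bytes and the deprecated Doc merge path are left unchanged for now (marked TODO). A regression test covers the round trip with a prior-probability-only entity linker. File names on the hunks below are inferred from their context lines; the original diff view did not preserve them.

Below is a minimal sketch of the behaviour under test, separate from the commit itself; it assumes a spaCy version in which Span accepts a kb_id argument:

    import spacy
    from spacy.tokens import Span

    nlp = spacy.blank("en")
    doc = nlp("She lives in Boston.")

    # Attach an entity with a knowledge-base ID, as the entity_linker
    # pipe would after disambiguation. "Q100" is an arbitrary example ID.
    doc.ents = [Span(doc, 3, 4, label="GPE", kb_id="Q100")]

    # Export through Span.as_doc. Before this fix the KB ID was dropped,
    # because ENT_KB_ID was missing from as_doc's array_head.
    new_doc = doc[:].as_doc()
    for ent in new_doc.ents:
        print(ent.text, ent.label_, ent.kb_id_)  # Boston GPE Q100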
--- a/spacy/attrs.pxd
+++ b/spacy/attrs.pxd
@@ -82,6 +82,7 @@ cdef enum attr_id_t:
     DEP
     ENT_IOB
     ENT_TYPE
+    ENT_KB_ID
     HEAD
     SENT_START
     SPACY
--- a/spacy/attrs.pyx
+++ b/spacy/attrs.pyx
@@ -84,6 +84,7 @@ IDS = {
     "DEP": DEP,
     "ENT_IOB": ENT_IOB,
     "ENT_TYPE": ENT_TYPE,
+    "ENT_KB_ID": ENT_KB_ID,
     "HEAD": HEAD,
     "SENT_START": SENT_START,
     "SPACY": SPACY,
--- a/spacy/symbols.pxd
+++ b/spacy/symbols.pxd
@@ -81,6 +81,7 @@ cdef enum symbol_t:
     DEP
     ENT_IOB
     ENT_TYPE
+    ENT_KB_ID
     HEAD
     SENT_START
     SPACY
--- a/spacy/symbols.pyx
+++ b/spacy/symbols.pyx
@@ -86,6 +86,7 @@ IDS = {
     "DEP": DEP,
     "ENT_IOB": ENT_IOB,
     "ENT_TYPE": ENT_TYPE,
+    "ENT_KB_ID": ENT_KB_ID,
     "HEAD": HEAD,
     "SENT_START": SENT_START,
     "SPACY": SPACY,
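ENT_KB_ID is registered in both attrs and symbols, mirroring ENT_IOB and ENT_TYPE, so the new attribute can be referred to by name (for example in Doc.to_array calls) as well as by its enum value.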
--- a/spacy/tests/pipeline/test_entity_linker.py
+++ b/spacy/tests/pipeline/test_entity_linker.py
@@ -5,6 +5,7 @@ import pytest

 from spacy.kb import KnowledgeBase
 from spacy.lang.en import English
+from spacy.pipeline import EntityRuler


 @pytest.fixture
@@ -101,3 +102,44 @@ def test_candidate_generation(nlp):
     assert(len(mykb.get_candidates('douglas')) == 2)
     assert(len(mykb.get_candidates('adam')) == 1)
     assert(len(mykb.get_candidates('shrubbery')) == 0)
+
+
+def test_preserving_links_asdoc(nlp):
+    """Test that Span.as_doc preserves the existing entity links"""
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+
+    # adding entities
+    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1])
+
+    # adding aliases
+    mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7])
+    mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6])
+
+    # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
+    sentencizer = nlp.create_pipe("sentencizer")
+    nlp.add_pipe(sentencizer)
+
+    ruler = EntityRuler(nlp)
+    patterns = [{"label": "GPE", "pattern": "Boston"},
+                {"label": "GPE", "pattern": "Denver"}]
+    ruler.add_patterns(patterns)
+    nlp.add_pipe(ruler)
+
+    el_pipe = nlp.create_pipe(name='entity_linker', config={})
+    el_pipe.set_kb(mykb)
+    el_pipe.begin_training()
+    el_pipe.context_weight = 0
+    el_pipe.prior_weight = 1
+    nlp.add_pipe(el_pipe, last=True)
+
+    # test whether the entity links are preserved by the `as_doc()` function
+    text = "She lives in Boston. He lives in Denver."
+    doc = nlp(text)
+    for ent in doc.ents:
+        orig_text = ent.text
+        orig_kb_id = ent.kb_id_
+        sent_doc = ent.sent.as_doc()
+        for s_ent in sent_doc.ents:
+            if s_ent.text == orig_text:
+                assert s_ent.kb_id_ == orig_kb_id
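A note on the test setup: with context_weight = 0 and prior_weight = 1, the entity linker appears to score candidates from the knowledge base's prior probabilities alone, so no trained context model is needed. The test should be runnable on its own with pytest -k test_preserving_links_asdoc.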
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -22,7 +22,7 @@ from ..lexeme cimport Lexeme, EMPTY_LEXEME
 from ..typedefs cimport attr_t, flags_t
 from ..attrs cimport ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX, CLUSTER
 from ..attrs cimport LENGTH, POS, LEMMA, TAG, DEP, HEAD, SPACY, ENT_IOB
-from ..attrs cimport ENT_TYPE, SENT_START, attr_id_t
+from ..attrs cimport ENT_TYPE, ENT_KB_ID, SENT_START, attr_id_t
 from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t

 from ..attrs import intify_attrs, IDS
@@ -64,6 +64,8 @@ cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil:
         return token.ent_iob
     elif feat_name == ENT_TYPE:
         return token.ent_type
+    elif feat_name == ENT_KB_ID:
+        return token.ent_kb_id
     else:
         return Lexeme.get_struct_attr(token.lex, feat_name)

@@ -850,7 +852,7 @@ cdef class Doc:

         DOCS: https://spacy.io/api/doc#to_bytes
         """
-        array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE]
+        array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE]  # TODO: ENT_KB_ID ?
         if self.is_tagged:
             array_head.append(TAG)
         # If doc parsed add head and dep attribute
@@ -1004,6 +1006,7 @@ cdef class Doc:
         """
         cdef unicode tag, lemma, ent_type
         deprecation_warning(Warnings.W013.format(obj="Doc"))
+        # TODO: ENT_KB_ID ?
         if len(args) == 3:
             deprecation_warning(Warnings.W003)
             tag, lemma, ent_type = args
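Both TODO markers flag spots in doc.pyx where ENT_KB_ID is not yet handled: Doc.to_bytes still serializes without it, and the deprecated in-place merge path does not carry it over.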
--- a/spacy/tokens/span.pyx
+++ b/spacy/tokens/span.pyx
@@ -210,7 +210,7 @@ cdef class Span:
         words = [t.text for t in self]
         spaces = [bool(t.whitespace_) for t in self]
         cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces)
-        array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE]
+        array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_KB_ID]
         if self.doc.is_tagged:
             array_head.append(TAG)
         # If doc parsed add head and dep attribute
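This one-line change is the core of the fix: Span.as_doc copies token attributes into the freshly created Doc via to_array/from_array over array_head, so adding ENT_KB_ID to that list is what preserves the entity links.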
--- a/spacy/tokens/token.pxd
+++ b/spacy/tokens/token.pxd
@@ -53,6 +53,8 @@ cdef class Token:
             return token.ent_iob
         elif feat_name == ENT_TYPE:
             return token.ent_type
+        elif feat_name == ENT_KB_ID:
+            return token.ent_kb_id
         elif feat_name == SENT_START:
             return token.sent_start
         else:
@@ -79,5 +81,7 @@ cdef class Token:
             token.ent_iob = value
         elif feat_name == ENT_TYPE:
             token.ent_type = value
+        elif feat_name == ENT_KB_ID:
+            token.ent_kb_id = value
         elif feat_name == SENT_START:
             token.sent_start = value
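The helpers in token.pxd back Doc.from_array and the token attribute getters, so extending them makes ENT_KB_ID both readable when a span's attributes are exported and writable when the new Doc is populated.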