Bugfix/nel crossing sentence (#7630)

* ensure each entity gets a KB ID, even when it is not fully contained within a single sentence (see the sketch below)

* cleanup
Sofie Van Landeghem, 2021-04-12 10:08:01 +02:00, committed by GitHub
parent 673e2bc4c0
commit 27dbbb9903
2 changed files with 127 additions and 70 deletions
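
Why this fix matters: EntityLinker.predict used to loop over sentences and then over sent.ents, so an entity that crossed a sentence boundary never received a KB ID and the predictions went out of sync with doc.ents. The fixed code loops over doc.ents directly and finds each entity's sentence via ent.sent. Below is a minimal sketch of that lookup, assuming the default sentencizer and English tokenizer behavior; the hand-set entity span stands in for a NER model and is illustrative, not part of the commit.

from spacy.lang.en import English
from spacy.tokens import Span

nlp = English()
nlp.add_pipe("sentencizer")
doc = nlp("Mahler 's Symphony No. 8 was beautiful.")
# The tokenizer splits "No." into "No" + "." and the sentencizer then starts
# a new sentence at "8", so the span "Symphony No. 8" (tokens 2-6) crosses a
# sentence boundary.
doc.ents = [Span(doc, 2, 6, label="WORK")]
sentences = [s for s in doc.sents]
assert len(sentences) == 2
for ent in doc.ents:  # iterate entities directly, as the fixed predict() does
    # ent.sent is the sentence containing the entity's first token, so the
    # lookup succeeds even though the entity spans two sentences
    sent_index = sentences.index(ent.sent)
    assert sent_index == 0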

spacy/pipeline/entity_linker.py

@@ -300,77 +300,77 @@ class EntityLinker(TrainablePipe):
         for i, doc in enumerate(docs):
             sentences = [s for s in doc.sents]
             if len(doc) > 0:
-                # Looping through each sentence and each entity
-                # This may go wrong if there are entities across sentences - which shouldn't happen normally.
-                for sent_index, sent in enumerate(sentences):
-                    if sent.ents:
-                        # get n_neightbour sentences, clipped to the length of the document
-                        start_sentence = max(0, sent_index - self.n_sents)
-                        end_sentence = min(
-                            len(sentences) - 1, sent_index + self.n_sents
-                        )
-                        start_token = sentences[start_sentence].start
-                        end_token = sentences[end_sentence].end
-                        sent_doc = doc[start_token:end_token].as_doc()
-                        # currently, the context is the same for each entity in a sentence (should be refined)
-                        xp = self.model.ops.xp
-                        if self.incl_context:
-                            sentence_encoding = self.model.predict([sent_doc])[0]
-                            sentence_encoding_t = sentence_encoding.T
-                            sentence_norm = xp.linalg.norm(sentence_encoding_t)
-                        for ent in sent.ents:
-                            entity_count += 1
-                            if ent.label_ in self.labels_discard:
-                                # ignoring this entity - setting to NIL
-                                final_kb_ids.append(self.NIL)
-                            else:
-                                candidates = self.get_candidates(self.kb, ent)
-                                if not candidates:
-                                    # no prediction possible for this entity - setting to NIL
-                                    final_kb_ids.append(self.NIL)
-                                elif len(candidates) == 1:
-                                    # shortcut for efficiency reasons: take the 1 candidate
-                                    # TODO: thresholding
-                                    final_kb_ids.append(candidates[0].entity_)
-                                else:
-                                    random.shuffle(candidates)
-                                    # set all prior probabilities to 0 if incl_prior=False
-                                    prior_probs = xp.asarray(
-                                        [c.prior_prob for c in candidates]
-                                    )
-                                    if not self.incl_prior:
-                                        prior_probs = xp.asarray(
-                                            [0.0 for _ in candidates]
-                                        )
-                                    scores = prior_probs
-                                    # add in similarity from the context
-                                    if self.incl_context:
-                                        entity_encodings = xp.asarray(
-                                            [c.entity_vector for c in candidates]
-                                        )
-                                        entity_norm = xp.linalg.norm(
-                                            entity_encodings, axis=1
-                                        )
-                                        if len(entity_encodings) != len(prior_probs):
-                                            raise RuntimeError(
-                                                Errors.E147.format(
-                                                    method="predict",
-                                                    msg="vectors not of equal length",
-                                                )
-                                            )
-                                        # cosine similarity
-                                        sims = xp.dot(
-                                            entity_encodings, sentence_encoding_t
-                                        ) / (sentence_norm * entity_norm)
-                                        if sims.shape != prior_probs.shape:
-                                            raise ValueError(Errors.E161)
-                                        scores = (
-                                            prior_probs + sims - (prior_probs * sims)
-                                        )
-                                    # TODO: thresholding
-                                    best_index = scores.argmax().item()
-                                    best_candidate = candidates[best_index]
-                                    final_kb_ids.append(best_candidate.entity_)
+                # Looping through each entity (TODO: rewrite)
+                for ent in doc.ents:
+                    sent = ent.sent
+                    sent_index = sentences.index(sent)
+                    assert sent_index >= 0
+                    # get n_neightbour sentences, clipped to the length of the document
+                    start_sentence = max(0, sent_index - self.n_sents)
+                    end_sentence = min(
+                        len(sentences) - 1, sent_index + self.n_sents
+                    )
+                    start_token = sentences[start_sentence].start
+                    end_token = sentences[end_sentence].end
+                    sent_doc = doc[start_token:end_token].as_doc()
+                    # currently, the context is the same for each entity in a sentence (should be refined)
+                    xp = self.model.ops.xp
+                    if self.incl_context:
+                        sentence_encoding = self.model.predict([sent_doc])[0]
+                        sentence_encoding_t = sentence_encoding.T
+                        sentence_norm = xp.linalg.norm(sentence_encoding_t)
+                    entity_count += 1
+                    if ent.label_ in self.labels_discard:
+                        # ignoring this entity - setting to NIL
+                        final_kb_ids.append(self.NIL)
+                    else:
+                        candidates = self.get_candidates(self.kb, ent)
+                        if not candidates:
+                            # no prediction possible for this entity - setting to NIL
+                            final_kb_ids.append(self.NIL)
+                        elif len(candidates) == 1:
+                            # shortcut for efficiency reasons: take the 1 candidate
+                            # TODO: thresholding
+                            final_kb_ids.append(candidates[0].entity_)
+                        else:
+                            random.shuffle(candidates)
+                            # set all prior probabilities to 0 if incl_prior=False
+                            prior_probs = xp.asarray(
+                                [c.prior_prob for c in candidates]
+                            )
+                            if not self.incl_prior:
+                                prior_probs = xp.asarray(
+                                    [0.0 for _ in candidates]
+                                )
+                            scores = prior_probs
+                            # add in similarity from the context
+                            if self.incl_context:
+                                entity_encodings = xp.asarray(
+                                    [c.entity_vector for c in candidates]
+                                )
+                                entity_norm = xp.linalg.norm(
+                                    entity_encodings, axis=1
+                                )
+                                if len(entity_encodings) != len(prior_probs):
+                                    raise RuntimeError(
+                                        Errors.E147.format(
+                                            method="predict",
+                                            msg="vectors not of equal length",
+                                        )
+                                    )
+                                # cosine similarity
+                                sims = xp.dot(
+                                    entity_encodings, sentence_encoding_t
+                                ) / (sentence_norm * entity_norm)
+                                if sims.shape != prior_probs.shape:
+                                    raise ValueError(Errors.E161)
+                                scores = (
+                                    prior_probs + sims - (prior_probs * sims)
+                                )
+                            # TODO: thresholding
+                            best_index = scores.argmax().item()
+                            best_candidate = candidates[best_index]
+                            final_kb_ids.append(best_candidate.entity_)
         if not (len(final_kb_ids) == entity_count):
             err = Errors.E147.format(
                 method="predict", msg="result variables not of equal length"

spacy/tests/regression/test_issue7001-8000.py

@@ -1,4 +1,6 @@
+from spacy.kb import KnowledgeBase
 from spacy.lang.en import English
+from spacy.training import Example
 
 
 def test_issue7065():
@@ -16,3 +18,58 @@ def test_issue7065():
     ent = doc.ents[0]
     assert ent.start < sent0.end < ent.end
     assert sentences.index(ent.sent) == 0
+
+
+def test_issue7065_b():
+    # Test that the NEL doesn't crash when an entity crosses a sentence boundary
+    nlp = English()
+    vector_length = 3
+    nlp.add_pipe("sentencizer")
+    text = "Mahler 's Symphony No. 8 was beautiful."
+    entities = [(0, 6, "PERSON"), (10, 24, "WORK")]
+    links = {(0, 6): {"Q7304": 1.0, "Q270853": 0.0},
+             (10, 24): {"Q7304": 0.0, "Q270853": 1.0}}
+    sent_starts = [1, -1, 0, 0, 0, 0, 0, 0, 0]
+    doc = nlp(text)
+    example = Example.from_dict(doc, {"entities": entities, "links": links, "sent_starts": sent_starts})
+    train_examples = [example]
+
+    def create_kb(vocab):
+        # create artificial KB
+        mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
+        mykb.add_entity(entity="Q270853", freq=12, entity_vector=[9, 1, -7])
+        mykb.add_alias(
+            alias="No. 8",
+            entities=["Q270853"],
+            probabilities=[1.0],
+        )
+        mykb.add_entity(entity="Q7304", freq=12, entity_vector=[6, -4, 3])
+        mykb.add_alias(
+            alias="Mahler",
+            entities=["Q7304"],
+            probabilities=[1.0],
+        )
+        return mykb
+
+    # Create the Entity Linker component and add it to the pipeline
+    entity_linker = nlp.add_pipe("entity_linker", last=True)
+    entity_linker.set_kb(create_kb)
+
+    # train the NEL pipe
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    for i in range(2):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+
+    # Add a custom rule-based component to mimick NER
+    patterns = [
+        {"label": "PERSON", "pattern": [{"LOWER": "mahler"}]},
+        {"label": "WORK", "pattern": [{"LOWER": "symphony"}, {"LOWER": "no"}, {"LOWER": "."}, {"LOWER": "8"}]}
+    ]
+    ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
+    ruler.add_patterns(patterns)
+
+    # test the trained model - this should not throw E148
+    doc = nlp(text)
+    assert doc
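
Why this test exercises the crossing case: the English tokenizer splits "No." into "No" and "." (hence the nine sent_starts values and the four-token entity_ruler pattern), and the sentencizer then starts a new sentence at "8". The WORK entity at characters 10-24, "Symphony No. 8", therefore spans two sentences at inference time, while "Mahler" (characters 0-6) stays inside the first. Before the fix, the final doc = nlp(text) call raised E148 because the crossing entity got no KB ID; with the fix the pipeline completes, which is all the closing assert needs to verify.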