From 27dbbb99031c4859272cdd36688547b6b1ba0d0e Mon Sep 17 00:00:00 2001
From: Sofie Van Landeghem
Date: Mon, 12 Apr 2021 10:08:01 +0200
Subject: [PATCH] Bugfix/nel crossing sentence (#7630)

* ensure each entity gets a KB ID, even when it's not within a sentence

* cleanup
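
For context, the bug needs an entity that spans a sentence boundary. A
minimal repro sketch (not part of the patch; it uses the same sentencizer
setup as the regression test below):

    import spacy

    nlp = spacy.blank("en")
    nlp.add_pipe("sentencizer")
    doc = nlp("Mahler 's Symphony No. 8 was beautiful.")
    # the "." inside "No. 8" opens a new sentence, so an entity covering
    # "Symphony No. 8" crosses a sentence boundary
    print([sent.text for sent in doc.sents])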
---
 spacy/pipeline/entity_linker.py          | 140 +++++++++++------------
 spacy/tests/regression/test_issue7065.py |  57 +++++++++
 2 files changed, 127 insertions(+), 70 deletions(-)

diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py
index 630057c3f..6ab52fb35 100644
--- a/spacy/pipeline/entity_linker.py
+++ b/spacy/pipeline/entity_linker.py
@@ -300,77 +300,77 @@ class EntityLinker(TrainablePipe):
         for i, doc in enumerate(docs):
             sentences = [s for s in doc.sents]
             if len(doc) > 0:
-                # Looping through each sentence and each entity
-                # This may go wrong if there are entities across sentences - which shouldn't happen normally.
-                for sent_index, sent in enumerate(sentences):
-                    if sent.ents:
-                        # get n_neightbour sentences, clipped to the length of the document
-                        start_sentence = max(0, sent_index - self.n_sents)
-                        end_sentence = min(
-                            len(sentences) - 1, sent_index + self.n_sents
-                        )
-                        start_token = sentences[start_sentence].start
-                        end_token = sentences[end_sentence].end
-                        sent_doc = doc[start_token:end_token].as_doc()
-                        # currently, the context is the same for each entity in a sentence (should be refined)
-                        xp = self.model.ops.xp
-                        if self.incl_context:
-                            sentence_encoding = self.model.predict([sent_doc])[0]
-                            sentence_encoding_t = sentence_encoding.T
-                            sentence_norm = xp.linalg.norm(sentence_encoding_t)
-                        for ent in sent.ents:
-                            entity_count += 1
-                            if ent.label_ in self.labels_discard:
-                                # ignoring this entity - setting to NIL
-                                final_kb_ids.append(self.NIL)
-                            else:
-                                candidates = self.get_candidates(self.kb, ent)
-                                if not candidates:
-                                    # no prediction possible for this entity - setting to NIL
-                                    final_kb_ids.append(self.NIL)
-                                elif len(candidates) == 1:
-                                    # shortcut for efficiency reasons: take the 1 candidate
-                                    # TODO: thresholding
-                                    final_kb_ids.append(candidates[0].entity_)
-                                else:
-                                    random.shuffle(candidates)
-                                    # set all prior probabilities to 0 if incl_prior=False
-                                    prior_probs = xp.asarray(
-                                        [c.prior_prob for c in candidates]
+                # Looping through each entity (TODO: rewrite)
+                for ent in doc.ents:
+                    sent = ent.sent
+                    sent_index = sentences.index(sent)
+                    assert sent_index >= 0
+                    # get n_neighbour sentences, clipped to the length of the document
+                    start_sentence = max(0, sent_index - self.n_sents)
+                    end_sentence = min(
+                        len(sentences) - 1, sent_index + self.n_sents
+                    )
+                    start_token = sentences[start_sentence].start
+                    end_token = sentences[end_sentence].end
+                    sent_doc = doc[start_token:end_token].as_doc()
+                    # currently, the context is the same for each entity in a sentence (should be refined)
+                    xp = self.model.ops.xp
+                    if self.incl_context:
+                        sentence_encoding = self.model.predict([sent_doc])[0]
+                        sentence_encoding_t = sentence_encoding.T
+                        sentence_norm = xp.linalg.norm(sentence_encoding_t)
+                    entity_count += 1
+                    if ent.label_ in self.labels_discard:
+                        # ignoring this entity - setting to NIL
+                        final_kb_ids.append(self.NIL)
+                    else:
+                        candidates = self.get_candidates(self.kb, ent)
+                        if not candidates:
+                            # no prediction possible for this entity - setting to NIL
+                            final_kb_ids.append(self.NIL)
+                        elif len(candidates) == 1:
+                            # shortcut for efficiency reasons: take the 1 candidate
+                            # TODO: thresholding
+                            final_kb_ids.append(candidates[0].entity_)
+                        else:
+                            random.shuffle(candidates)
+                            # set all prior probabilities to 0 if incl_prior=False
+                            prior_probs = xp.asarray(
+                                [c.prior_prob for c in candidates]
+                            )
+                            if not self.incl_prior:
+                                prior_probs = xp.asarray(
+                                    [0.0 for _ in candidates]
+                                )
+                            scores = prior_probs
+                            # add in similarity from the context
+                            if self.incl_context:
+                                entity_encodings = xp.asarray(
+                                    [c.entity_vector for c in candidates]
+                                )
+                                entity_norm = xp.linalg.norm(
+                                    entity_encodings, axis=1
+                                )
+                                if len(entity_encodings) != len(prior_probs):
+                                    raise RuntimeError(
+                                        Errors.E147.format(
+                                            method="predict",
+                                            msg="vectors not of equal length",
+                                        )
                                     )
-                                    if not self.incl_prior:
-                                        prior_probs = xp.asarray(
-                                            [0.0 for _ in candidates]
-                                        )
-                                    scores = prior_probs
-                                    # add in similarity from the context
-                                    if self.incl_context:
-                                        entity_encodings = xp.asarray(
-                                            [c.entity_vector for c in candidates]
-                                        )
-                                        entity_norm = xp.linalg.norm(
-                                            entity_encodings, axis=1
-                                        )
-                                        if len(entity_encodings) != len(prior_probs):
-                                            raise RuntimeError(
-                                                Errors.E147.format(
-                                                    method="predict",
-                                                    msg="vectors not of equal length",
-                                                )
-                                            )
-                                        # cosine similarity
-                                        sims = xp.dot(
-                                            entity_encodings, sentence_encoding_t
-                                        ) / (sentence_norm * entity_norm)
-                                        if sims.shape != prior_probs.shape:
-                                            raise ValueError(Errors.E161)
-                                        scores = (
-                                            prior_probs + sims - (prior_probs * sims)
-                                        )
-                                    # TODO: thresholding
-                                    best_index = scores.argmax().item()
-                                    best_candidate = candidates[best_index]
-                                    final_kb_ids.append(best_candidate.entity_)
+                                # cosine similarity
+                                sims = xp.dot(
+                                    entity_encodings, sentence_encoding_t
+                                ) / (sentence_norm * entity_norm)
+                                if sims.shape != prior_probs.shape:
+                                    raise ValueError(Errors.E161)
+                                scores = (
+                                    prior_probs + sims - (prior_probs * sims)
+                                )
+                            # TODO: thresholding
+                            best_index = scores.argmax().item()
+                            best_candidate = candidates[best_index]
+                            final_kb_ids.append(best_candidate.entity_)
         if not (len(final_kb_ids) == entity_count):
             err = Errors.E147.format(
                 method="predict", msg="result variables not of equal length"
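A note on the scoring in the hunk above: each candidate's final score
combines the alias prior p with the context similarity s as p + s - p*s,
i.e. 1 - (1 - p)(1 - s), so a candidate ranks high when either signal is
strong. A standalone numpy sketch with made-up values (not part of the
patch):

    import numpy as np

    prior_probs = np.asarray([0.8, 0.1, 0.6])  # alias priors per candidate
    sims = np.asarray([0.2, 0.9, 0.5])         # context cosine similarities
    scores = prior_probs + sims - (prior_probs * sims)
    print(scores)           # [0.84 0.91 0.8 ]
    print(scores.argmax())  # prints 1: the middle candidate wins on context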
8", + entities=["Q270853"], + probabilities=[1.0], + ) + mykb.add_entity(entity="Q7304", freq=12, entity_vector=[6, -4, 3]) + mykb.add_alias( + alias="Mahler", + entities=["Q7304"], + probabilities=[1.0], + ) + return mykb + + # Create the Entity Linker component and add it to the pipeline + entity_linker = nlp.add_pipe("entity_linker", last=True) + entity_linker.set_kb(create_kb) + + # train the NEL pipe + optimizer = nlp.initialize(get_examples=lambda: train_examples) + for i in range(2): + losses = {} + nlp.update(train_examples, sgd=optimizer, losses=losses) + + # Add a custom rule-based component to mimick NER + patterns = [ + {"label": "PERSON", "pattern": [{"LOWER": "mahler"}]}, + {"label": "WORK", "pattern": [{"LOWER": "symphony"}, {"LOWER": "no"}, {"LOWER": "."}, {"LOWER": "8"}]} + ] + ruler = nlp.add_pipe("entity_ruler", before="entity_linker") + ruler.add_patterns(patterns) + + # test the trained model - this should not throw E148 + doc = nlp(text) + assert doc