From a0b211c582f25ae601d09f8e56bf713a8926b2c1 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 20 Jan 2021 10:46:32 +1100 Subject: [PATCH] Suggest refactor of entity linker --- spacy/pipeline/entity_linker.py | 170 +++++++++++++++++--------------- 1 file changed, 90 insertions(+), 80 deletions(-) diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 5ccfe9940..c6cc45632 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -330,87 +330,97 @@ class EntityLinker(TrainablePipe): return final_kb_ids if isinstance(docs, Doc): docs = [docs] - for i, doc in enumerate(docs): - sentences = [s for s in doc.sents] - if len(doc) > 0: - # Looping through each sentence and each entity - # This may go wrong if there are entities across sentences - which shouldn't happen normally. - for sent_index, sent in enumerate(sentences): - if sent.ents: - # get n_neightbour sentences, clipped to the length of the document - start_sentence = max(0, sent_index - self.n_sents) - end_sentence = min( - len(sentences) - 1, sent_index + self.n_sents - ) - start_token = sentences[start_sentence].start - end_token = sentences[end_sentence].end - sent_doc = doc[start_token:end_token].as_doc() - # currently, the context is the same for each entity in a sentence (should be refined) - xp = self.model.ops.xp - if self.cfg.get("incl_context"): - sentence_encoding = self.model.predict([sent_doc])[0] - sentence_encoding_t = sentence_encoding.T - sentence_norm = xp.linalg.norm(sentence_encoding_t) - for ent in sent.ents: - entity_count += 1 - to_discard = self.cfg.get("labels_discard", []) - if to_discard and ent.label_ in to_discard: - # ignoring this entity - setting to NIL - final_kb_ids.append(self.NIL) - else: - candidates = self.get_candidates(self.kb, ent) - if not candidates: - # no prediction possible for this entity - setting to NIL - final_kb_ids.append(self.NIL) - elif len(candidates) == 1: - # shortcut for efficiency 
reasons: take the 1 candidate - # TODO: thresholding - final_kb_ids.append(candidates[0].entity_) - else: - random.shuffle(candidates) - # set all prior probabilities to 0 if incl_prior=False - prior_probs = xp.asarray( - [c.prior_prob for c in candidates] - ) - if not self.cfg.get("incl_prior"): - prior_probs = xp.asarray( - [0.0 for _ in candidates] - ) - scores = prior_probs - # add in similarity from the context - if self.cfg.get("incl_context"): - entity_encodings = xp.asarray( - [c.entity_vector for c in candidates] - ) - entity_norm = xp.linalg.norm( - entity_encodings, axis=1 - ) - if len(entity_encodings) != len(prior_probs): - raise RuntimeError( - Errors.E147.format( - method="predict", - msg="vectors not of equal length", - ) - ) - # cosine similarity - sims = xp.dot( - entity_encodings, sentence_encoding_t - ) / (sentence_norm * entity_norm) - if sims.shape != prior_probs.shape: - raise ValueError(Errors.E161) - scores = ( - prior_probs + sims - (prior_probs * sims) - ) - # TODO: thresholding - best_index = scores.argmax().item() - best_candidate = candidates[best_index] - final_kb_ids.append(best_candidate.entity_) - if not (len(final_kb_ids) == entity_count): - err = Errors.E147.format( - method="predict", msg="result variables not of equal length" - ) - raise RuntimeError(err) + sents = self._get_sents(docs) + windows = [self._get_window(sent).as_doc() for sent in sents] + # currently, the context is the same for each entity in a sentence (should be refined) + if self.cfg.get("incl_context"): + sentence_encodings = self.model.predict(windows) + else: + sentence_encodings = None + final_kb_ids = self._encodings2predictions(sents, sentence_encodings) + return final_kb_ids + + def _get_sents(self, docs): + """Get a flat list of sentences that have at least one entity.""" + sents = [] + for doc in docs: + for sent in doc.sents: + if sent.ents: + sents.append(sent) + return sents + + def _get_window(self, sent): + """Get a surrounding window of 3 
sentences around a sentence.""" + start = sent.start if sent.start == 0 else sent[0].nbor(-1).sent.start + end = sent.end if sent.end >= len(sent.doc) else sent[-1].nbor(1).sent.end + return sent.doc[start:end] + + def _encodings2predictions(self, sents, sent_encodings): + if sent_encodings is not None: + se_T = [se.T for se in sent_encodings] + se_norms = [self.model.ops.xp.linalg.norm(se) for se in sent_encodings] + else: + se_T = [None] * len(sents) + se_norms = [None] * len(sents) + final_kb_ids = [] + for i, sent in enumerate(sents): + for ent in sent.ents: + final_kb_ids.append( + self._predict_entity(ent, se_T[i], se_norms[i]) + ) + return final_kb_ids + + def _predict_entity(self, ent, sent_encode, sent_norm): + if ent.label_ in (self.cfg.get("labels_discard") or []): + # ignoring this entity - setting to NIL + return self.NIL + candidates = self.get_candidates(self.kb, ent) + if not candidates: + # no prediction possible for this entity - setting to NIL + return self.NIL + elif len(candidates) == 1: + # shortcut for efficiency reasons: take the 1 candidate + # TODO: thresholding + return candidates[0].entity_ + else: + random.shuffle(candidates) + scores = self._score_candidates(candidates, sent_encode, sent_norm) + # TODO: thresholding + best_index = scores.argmax().item() + return candidates[best_index].entity_ + + def _score_candidates(self, candidates, se_encode_T, se_norm): + xp = self.model.ops.xp + # set all prior probabilities to 0 if incl_prior=False + if self.cfg.get("incl_prior"): + prior_probs = xp.asarray( + [c.prior_prob for c in candidates] + ) + else: + prior_probs = xp.asarray( + [0.0 for _ in candidates] + ) + if se_encode_T is None or not self.cfg.get("incl_context"): + return prior_probs + + # add in similarity from the context + entity_encodings = xp.asarray( + [c.entity_vector for c in candidates] + ) + entity_norm = xp.linalg.norm( + entity_encodings, axis=1 + ) + # cosine similarity + sims = xp.dot( + entity_encodings, se_encode_T + ) + sims /= (se_norm * entity_norm) + if sims.shape != prior_probs.shape: + raise ValueError(Errors.E161) + 
scores = ( + prior_probs + sims - (prior_probs * sims) + ) + return scores def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None: """Modify a batch of documents, using pre-computed scores.