Suggest refactor of entity linker

This commit is contained in:
Matthew Honnibal 2021-01-20 10:46:32 +11:00
parent 88acbfc050
commit a0b211c582

View File

@ -330,87 +330,97 @@ class EntityLinker(TrainablePipe):
return final_kb_ids return final_kb_ids
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
for i, doc in enumerate(docs): sents = self._get_sents(docs)
sentences = [s for s in doc.sents] windows = [self._get_window([sent.as_doc() for sent in sents])
if len(doc) > 0:
# Looping through each sentence and each entity
# This may go wrong if there are entities across sentences - which shouldn't happen normally.
for sent_index, sent in enumerate(sentences):
if sent.ents:
# get n_neightbour sentences, clipped to the length of the document
start_sentence = max(0, sent_index - self.n_sents)
end_sentence = min(
len(sentences) - 1, sent_index + self.n_sents
)
start_token = sentences[start_sentence].start
end_token = sentences[end_sentence].end
sent_doc = doc[start_token:end_token].as_doc()
# currently, the context is the same for each entity in a sentence (should be refined) # currently, the context is the same for each entity in a sentence (should be refined)
xp = self.model.ops.xp
if self.cfg.get("incl_context"): if self.cfg.get("incl_context"):
sentence_encoding = self.model.predict([sent_doc])[0] sentence_encodings = self.model.predict(windows)
sentence_encoding_t = sentence_encoding.T
sentence_norm = xp.linalg.norm(sentence_encoding_t)
for ent in sent.ents:
entity_count += 1
to_discard = self.cfg.get("labels_discard", [])
if to_discard and ent.label_ in to_discard:
# ignoring this entity - setting to NIL
final_kb_ids.append(self.NIL)
else: else:
sentence_encodings = None
final_kb_ids = self._encodings2predictions(sents, sentence_encodings)
return final_kb_ids
def _get_sents(self, docs):
"""Get a flat list of sentences that have at least one entity."""
sents = []
for doc in docs:
for sent in doc.sents:
if sent.ents:
sents.append(sent)
return sents
def _get_window(self, sent):
"""Get a surrounding window of 3 sentences around a sentence."""
start = sent[0].nbor(-1).sent.start
end = sent[-1].nbor(1).sent.end
return sent.doc[start:end]
def _encoding2predictions(self, sents, sent_encodings):
if sent_encodings is not None:
se_T = [se.T for se in sent_encodings]
se_norms = [xp.linalg.norm(se_t) for se_t in sent_encodings]
else:
se_T = [None] * len(sents)
se_norms = [None] * len(sents)
final_kb_ids = []
for sent in sents:
for ent in sent.ents:
final_kb_ids.append(
self._predict_entity(ent, se_T[i], se_norms[i])
)
return final_kb_ids
def _predict_entity(self, ent, sent_encode, sent_norm):
if ent.label_ in self.cfg.get("labels_discard", []):
# ignoring this entity - setting to NIL
return self.NIL
candidates = self.get_candidates(self.kb, ent) candidates = self.get_candidates(self.kb, ent)
if not candidates: if not candidates:
# no prediction possible for this entity - setting to NIL # no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL) return self.NIL
elif len(candidates) == 1: elif len(candidates) == 1:
# shortcut for efficiency reasons: take the 1 candidate # shortcut for efficiency reasons: take the 1 candidate
# TODO: thresholding # TODO: thresholding
final_kb_ids.append(candidates[0].entity_) return candidates[0].entity_
else: else:
random.shuffle(candidates) random.shuffle(candidates)
scores = self._score_candidates(candidates, sent_encode, sent_norm)
# TODO: thresholding
best_index = scores.argmax().item()
return candidates[best_index].entity_
def _score_candidates(self, candidates, se_encode_T, se_norm):
xp = self.model.ops.xp
# set all prior probabilities to 0 if incl_prior=False # set all prior probabilities to 0 if incl_prior=False
if self.cfg.get("incl_prior"):
prior_probs = xp.asarray( prior_probs = xp.asarray(
[c.prior_prob for c in candidates] [c.prior_prob for c in candidates]
) )
if not self.cfg.get("incl_prior"): else:
prior_probs = xp.asarray( prior_probs = xp.asarray(
[0.0 for _ in candidates] [0.0 for _ in candidates]
) )
scores = prior_probs if se_encode_T is None or not self.cfg.get("incl_context"):
return prior_probs
# add in similarity from the context # add in similarity from the context
if self.cfg.get("incl_context"):
entity_encodings = xp.asarray( entity_encodings = xp.asarray(
[c.entity_vector for c in candidates] [c.entity_vector for c in candidates]
) )
entity_norm = xp.linalg.norm( entity_norm = xp.linalg.norm(
entity_encodings, axis=1 entity_encodings, axis=1
) )
if len(entity_encodings) != len(prior_probs):
raise RuntimeError(
Errors.E147.format(
method="predict",
msg="vectors not of equal length",
)
)
# cosine similarity # cosine similarity
sims = xp.dot( sims = xp.dot(
entity_encodings, sentence_encoding_t entity_encodings, se_encode_T
) / (sentence_norm * entity_norm) )
sims /= (se_norm * entity_norm)
if sims.shape != prior_probs.shape: if sims.shape != prior_probs.shape:
raise ValueError(Errors.E161) raise ValueError(Errors.E161)
scores = ( scores = (
prior_probs + sims - (prior_probs * sims) prior_probs + sims - (prior_probs * sims)
) )
# TODO: thresholding return scores
best_index = scores.argmax().item()
best_candidate = candidates[best_index]
final_kb_ids.append(best_candidate.entity_)
if not (len(final_kb_ids) == entity_count):
err = Errors.E147.format(
method="predict", msg="result variables not of equal length"
)
raise RuntimeError(err)
return final_kb_ids
def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None: def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:
"""Modify a batch of documents, using pre-computed scores. """Modify a batch of documents, using pre-computed scores.