mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Suggest refactor of entity linker
This commit is contained in:
parent
88acbfc050
commit
a0b211c582
|
@ -330,87 +330,97 @@ class EntityLinker(TrainablePipe):
|
||||||
return final_kb_ids
|
return final_kb_ids
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
for i, doc in enumerate(docs):
|
sents = self._get_sents(docs)
|
||||||
sentences = [s for s in doc.sents]
|
windows = [self._get_window([sent.as_doc() for sent in sents])
|
||||||
if len(doc) > 0:
|
|
||||||
# Looping through each sentence and each entity
|
|
||||||
# This may go wrong if there are entities across sentences - which shouldn't happen normally.
|
|
||||||
for sent_index, sent in enumerate(sentences):
|
|
||||||
if sent.ents:
|
|
||||||
# get n_neightbour sentences, clipped to the length of the document
|
|
||||||
start_sentence = max(0, sent_index - self.n_sents)
|
|
||||||
end_sentence = min(
|
|
||||||
len(sentences) - 1, sent_index + self.n_sents
|
|
||||||
)
|
|
||||||
start_token = sentences[start_sentence].start
|
|
||||||
end_token = sentences[end_sentence].end
|
|
||||||
sent_doc = doc[start_token:end_token].as_doc()
|
|
||||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||||
xp = self.model.ops.xp
|
|
||||||
if self.cfg.get("incl_context"):
|
if self.cfg.get("incl_context"):
|
||||||
sentence_encoding = self.model.predict([sent_doc])[0]
|
sentence_encodings = self.model.predict(windows)
|
||||||
sentence_encoding_t = sentence_encoding.T
|
|
||||||
sentence_norm = xp.linalg.norm(sentence_encoding_t)
|
|
||||||
for ent in sent.ents:
|
|
||||||
entity_count += 1
|
|
||||||
to_discard = self.cfg.get("labels_discard", [])
|
|
||||||
if to_discard and ent.label_ in to_discard:
|
|
||||||
# ignoring this entity - setting to NIL
|
|
||||||
final_kb_ids.append(self.NIL)
|
|
||||||
else:
|
else:
|
||||||
|
sentence_encodings = None
|
||||||
|
final_kb_ids = self._encodings2predictions(sents, sentence_encodings)
|
||||||
|
return final_kb_ids
|
||||||
|
|
||||||
|
def _get_sents(self, docs):
|
||||||
|
"""Get a flat list of sentences that have at least one entity."""
|
||||||
|
sents = []
|
||||||
|
for doc in docs:
|
||||||
|
for sent in doc.sents:
|
||||||
|
if sent.ents:
|
||||||
|
sents.append(sent)
|
||||||
|
return sents
|
||||||
|
|
||||||
|
def _get_window(self, sent):
|
||||||
|
"""Get a surrounding window of 3 sentences around a sentence."""
|
||||||
|
start = sent[0].nbor(-1).sent.start
|
||||||
|
end = sent[-1].nbor(1).sent.end
|
||||||
|
return sent.doc[start:end]
|
||||||
|
|
||||||
|
def _encoding2predictions(self, sents, sent_encodings):
|
||||||
|
if sent_encodings is not None:
|
||||||
|
se_T = [se.T for se in sent_encodings]
|
||||||
|
se_norms = [xp.linalg.norm(se_t) for se_t in sent_encodings]
|
||||||
|
else:
|
||||||
|
se_T = [None] * len(sents)
|
||||||
|
se_norms = [None] * len(sents)
|
||||||
|
final_kb_ids = []
|
||||||
|
for sent in sents:
|
||||||
|
for ent in sent.ents:
|
||||||
|
final_kb_ids.append(
|
||||||
|
self._predict_entity(ent, se_T[i], se_norms[i])
|
||||||
|
)
|
||||||
|
return final_kb_ids
|
||||||
|
|
||||||
|
def _predict_entity(self, ent, sent_encode, sent_norm):
|
||||||
|
if ent.label_ in self.cfg.get("labels_discard", []):
|
||||||
|
# ignoring this entity - setting to NIL
|
||||||
|
return self.NIL
|
||||||
candidates = self.get_candidates(self.kb, ent)
|
candidates = self.get_candidates(self.kb, ent)
|
||||||
if not candidates:
|
if not candidates:
|
||||||
# no prediction possible for this entity - setting to NIL
|
# no prediction possible for this entity - setting to NIL
|
||||||
final_kb_ids.append(self.NIL)
|
return self.NIL
|
||||||
elif len(candidates) == 1:
|
elif len(candidates) == 1:
|
||||||
# shortcut for efficiency reasons: take the 1 candidate
|
# shortcut for efficiency reasons: take the 1 candidate
|
||||||
# TODO: thresholding
|
# TODO: thresholding
|
||||||
final_kb_ids.append(candidates[0].entity_)
|
return candidates[0].entity_
|
||||||
else:
|
else:
|
||||||
random.shuffle(candidates)
|
random.shuffle(candidates)
|
||||||
|
scores = self._score_candidates(candidates, sent_encode, sent_norm)
|
||||||
|
# TODO: thresholding
|
||||||
|
best_index = scores.argmax().item()
|
||||||
|
return candidates[best_index].entity_
|
||||||
|
|
||||||
|
def _score_candidates(self, candidates, se_encode_T, se_norm):
|
||||||
|
xp = self.model.ops.xp
|
||||||
# set all prior probabilities to 0 if incl_prior=False
|
# set all prior probabilities to 0 if incl_prior=False
|
||||||
|
if self.cfg.get("incl_prior"):
|
||||||
prior_probs = xp.asarray(
|
prior_probs = xp.asarray(
|
||||||
[c.prior_prob for c in candidates]
|
[c.prior_prob for c in candidates]
|
||||||
)
|
)
|
||||||
if not self.cfg.get("incl_prior"):
|
else:
|
||||||
prior_probs = xp.asarray(
|
prior_probs = xp.asarray(
|
||||||
[0.0 for _ in candidates]
|
[0.0 for _ in candidates]
|
||||||
)
|
)
|
||||||
scores = prior_probs
|
if se_encode_T is None or not self.cfg.get("incl_context"):
|
||||||
|
return prior_probs
|
||||||
|
|
||||||
# add in similarity from the context
|
# add in similarity from the context
|
||||||
if self.cfg.get("incl_context"):
|
|
||||||
entity_encodings = xp.asarray(
|
entity_encodings = xp.asarray(
|
||||||
[c.entity_vector for c in candidates]
|
[c.entity_vector for c in candidates]
|
||||||
)
|
)
|
||||||
entity_norm = xp.linalg.norm(
|
entity_norm = xp.linalg.norm(
|
||||||
entity_encodings, axis=1
|
entity_encodings, axis=1
|
||||||
)
|
)
|
||||||
if len(entity_encodings) != len(prior_probs):
|
|
||||||
raise RuntimeError(
|
|
||||||
Errors.E147.format(
|
|
||||||
method="predict",
|
|
||||||
msg="vectors not of equal length",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# cosine similarity
|
# cosine similarity
|
||||||
sims = xp.dot(
|
sims = xp.dot(
|
||||||
entity_encodings, sentence_encoding_t
|
entity_encodings, se_encode_T
|
||||||
) / (sentence_norm * entity_norm)
|
)
|
||||||
|
sims /= (se_norm * entity_norm)
|
||||||
if sims.shape != prior_probs.shape:
|
if sims.shape != prior_probs.shape:
|
||||||
raise ValueError(Errors.E161)
|
raise ValueError(Errors.E161)
|
||||||
scores = (
|
scores = (
|
||||||
prior_probs + sims - (prior_probs * sims)
|
prior_probs + sims - (prior_probs * sims)
|
||||||
)
|
)
|
||||||
# TODO: thresholding
|
return scores
|
||||||
best_index = scores.argmax().item()
|
|
||||||
best_candidate = candidates[best_index]
|
|
||||||
final_kb_ids.append(best_candidate.entity_)
|
|
||||||
if not (len(final_kb_ids) == entity_count):
|
|
||||||
err = Errors.E147.format(
|
|
||||||
method="predict", msg="result variables not of equal length"
|
|
||||||
)
|
|
||||||
raise RuntimeError(err)
|
|
||||||
return final_kb_ids
|
|
||||||
|
|
||||||
def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:
|
def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:
|
||||||
"""Modify a batch of documents, using pre-computed scores.
|
"""Modify a batch of documents, using pre-computed scores.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user