fix entity linker (cf PR #5548)

This commit is contained in:
svlandeg 2020-06-20 21:47:23 +02:00
parent dc069e90b3
commit c9242e9bf4

View File

@ -1219,13 +1219,11 @@ class EntityLinker(Pipe):
sent_doc = doc[start_token:end_token].as_doc() sent_doc = doc[start_token:end_token].as_doc()
sentence_docs.append(sent_doc) sentence_docs.append(sent_doc)
sentence_encodings, bp_context = self.model.begin_update(sentence_docs, drop=drop)
loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds, docs=None)
bp_context(d_scores, sgd=sgd)
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
sentence_encodings, bp_context = self.model.begin_update(sentence_docs) sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds) loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds)
bp_context(d_scores) bp_context(d_scores)
if sgd is not None: if sgd is not None:
self.model.finish_update(sgd) self.model.finish_update(sgd)
@ -1306,22 +1304,28 @@ class EntityLinker(Pipe):
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
sentences = [s for s in doc.sents] sentences = [s for s in doc.sents]
if len(doc) > 0: if len(doc) > 0:
# Looping through each sentence and each entity # Looping through each sentence and each entity
# This may go wrong if there are entities across sentences - which shouldn't happen normally. # This may go wrong if there are entities across sentences - which shouldn't happen normally.
for sent in doc.sents: for sent_index, sent in enumerate(sentences):
sent_doc = sent.as_doc() # get n_neightbour sentences, clipped to the length of the document
start_sentence = max(0, sent_index - self.n_sents)
end_sentence = min(len(sentences) -1, sent_index + self.n_sents)
start_token = sentences[start_sentence].start
end_token = sentences[end_sentence].end
sent_doc = doc[start_token:end_token].as_doc()
# currently, the context is the same for each entity in a sentence (should be refined) # currently, the context is the same for each entity in a sentence (should be refined)
sentence_encoding = self.model.predict([sent_doc])[0] sentence_encoding = self.model.predict([sent_doc])[0]
xp = get_array_module(sentence_encoding) xp = get_array_module(sentence_encoding)
sentence_encoding_t = sentence_encoding.T sentence_encoding_t = sentence_encoding.T
sentence_norm = xp.linalg.norm(sentence_encoding_t) sentence_norm = xp.linalg.norm(sentence_encoding_t)
for ent in sent_doc.ents: for ent in sent.ents:
entity_count += 1 entity_count += 1
to_discard = self.cfg.get("labels_discard", []) to_discard = self.cfg.get("labels_discard", [])
@ -1337,21 +1341,11 @@ class EntityLinker(Pipe):
final_kb_ids.append(self.NIL) final_kb_ids.append(self.NIL)
final_tensors.append(sentence_encoding) final_tensors.append(sentence_encoding)
sent_doc = doc[sent.start:sent.end].as_doc() elif len(candidates) == 1:
# shortcut for efficiency reasons: take the 1 candidate
# currently, the context is the same for each entity in a sentence (should be refined) # TODO: thresholding
sentence_encoding = self.model([sent_doc])[0] final_kb_ids.append(candidates[0].entity_)
xp = get_array_module(sentence_encoding)
sentence_encoding_t = sentence_encoding.T
sentence_norm = xp.linalg.norm(sentence_encoding_t)
for ent in sent.ents:
entity_count += 1
to_discard = self.cfg.get("labels_discard", [])
if to_discard and ent.label_ in to_discard:
# ignoring this entity - setting to NIL
final_kb_ids.append(self.NIL)
final_tensors.append(sentence_encoding) final_tensors.append(sentence_encoding)
else: else: