improve speed of prediction loop

This commit is contained in:
svlandeg 2019-06-26 13:53:10 +02:00
parent bee23cd8af
commit 1de61f68d6
2 changed files with 42 additions and 30 deletions

View File

@ -76,7 +76,7 @@ def run_pipeline():
# write the NLP object, read back in and test again
to_write_nlp = False
to_read_nlp = True
to_read_nlp = False
test_from_file = True
# STEP 1 : create prior probabilities from WP (run only once)
@ -252,9 +252,17 @@ def run_pipeline():
print("reading from", NLP_2_DIR)
nlp_3 = spacy.load(NLP_2_DIR)
print("running toy example with NLP 3")
run_el_toy_example(nlp=nlp_3)
# testing performance with an NLP model from file
if test_from_file:
dev_limit = 5000
dev_data = training_set_creator.read_training(nlp=nlp_3,
nlp_2 = spacy.load(NLP_1_DIR)
nlp_3 = spacy.load(NLP_2_DIR)
el_pipe = nlp_3.get_pipe("entity_linker")
dev_limit = 10000
dev_data = training_set_creator.read_training(nlp=nlp_2,
training_dir=TRAINING_DIR,
dev=True,
limit=dev_limit)
@ -262,12 +270,9 @@ def run_pipeline():
print("Dev testing from file on", len(dev_data), "articles")
print()
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data)
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe=el_pipe)
print("dev acc combo avg:", round(dev_acc_combo, 3),
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
else:
print("running toy example with NLP 3")
run_el_toy_example(nlp=nlp_3)
print()
print("STOP", datetime.datetime.now())
@ -280,7 +285,9 @@ def _measure_accuracy(data, el_pipe=None):
docs = [d for d, g in data if len(d) > 0]
if el_pipe is not None:
docs = el_pipe.pipe(docs)
print("applying el_pipe", datetime.datetime.now())
docs = list(el_pipe.pipe(docs, batch_size=10000000000))
print("done applying el_pipe", datetime.datetime.now())
golds = [g for d, g in data if len(d) > 0]
for doc, gold in zip(docs, golds):

View File

@ -3,8 +3,6 @@
# coding: utf8
from __future__ import unicode_literals
import numpy as np
import numpy
import srsly
from collections import OrderedDict
@ -12,6 +10,7 @@ from thinc.api import chain
from thinc.v2v import Affine, Maxout, Softmax
from thinc.misc import LayerNorm
from thinc.neural.util import to_categorical
from thinc.neural.util import get_array_module
from ..cli.pretrain import get_cossim_loss
from .functions import merge_subtokens
@ -1151,7 +1150,7 @@ class EntityLinker(Pipe):
if len(entity_encodings) > 0:
context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop)
entity_encodings = np.asarray(entity_encodings, dtype=np.float32)
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
loss, d_scores = self.get_loss(scores=context_encodings, golds=entity_encodings, docs=None)
bp_context(d_scores, sgd=sgd)
@ -1192,24 +1191,30 @@ class EntityLinker(Pipe):
if isinstance(docs, Doc):
docs = [docs]
context_encodings = self.model(docs)
xp = get_array_module(context_encodings)
for i, doc in enumerate(docs):
if len(doc) > 0:
context_encoding = self.model([doc])
context_enc_t = np.transpose(context_encoding)
context_encoding = context_encodings[i]
context_enc_t = context_encoding.T
norm_1 = xp.linalg.norm(context_enc_t)
for ent in doc.ents:
candidates = self.kb.get_candidates(ent.text)
if candidates:
scores = []
for c in candidates:
prior_prob = c.prior_prob * self.prior_weight
kb_id = c.entity_
entity_encoding = c.entity_vector
sim = float(cosine(np.asarray([entity_encoding]), context_enc_t)) * self.context_weight
score = prior_prob + sim - (prior_prob*sim)
scores.append(score)
prior_probs = xp.asarray([c.prior_prob for c in candidates])
prior_probs *= self.prior_weight
entity_encodings = xp.asarray([c.entity_vector for c in candidates])
norm_2 = xp.linalg.norm(entity_encodings, axis=1)
# cosine similarity
sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2)
sims *= self.context_weight
scores = prior_probs + sims - (prior_probs*sims)
best_index = scores.argmax()
# TODO: thresholding
best_index = scores.index(max(scores))
best_candidate = candidates[best_index]
final_entities.append(ent)
final_kb_ids.append(best_candidate.entity_)