speed up predictions

This commit is contained in:
svlandeg 2019-06-11 14:18:20 +02:00
parent fe1ed432ef
commit 66813a1fdc

View File

@@ -115,8 +115,8 @@ def run_pipeline():
# STEP 6: create the entity linking pipe
if train_pipe:
train_limit = 5
dev_limit = 2
train_limit = 100
dev_limit = 20
print("Training on", train_limit, "articles")
print("Dev testing on", dev_limit, "articles")
print()
@@ -155,22 +155,25 @@ def run_pipeline():
losses=losses,
)
# print(" measuring accuracy 1-1")
el_pipe.context_weight = 1
el_pipe.prior_weight = 1
dev_acc_1_1 = _measure_accuracy(dev_data, nlp)
train_acc_1_1 = _measure_accuracy(train_data, nlp)
dev_acc_1_1 = _measure_accuracy(dev_data, el_pipe)
train_acc_1_1 = _measure_accuracy(train_data, el_pipe)
# print(" measuring accuracy 0-1")
el_pipe.context_weight = 0
el_pipe.prior_weight = 1
dev_acc_0_1 = _measure_accuracy(dev_data, nlp)
train_acc_0_1 = _measure_accuracy(train_data, nlp)
dev_acc_0_1 = _measure_accuracy(dev_data, el_pipe)
train_acc_0_1 = _measure_accuracy(train_data, el_pipe)
# print(" measuring accuracy 1-0")
el_pipe.context_weight = 1
el_pipe.prior_weight = 0
dev_acc_1_0 = _measure_accuracy(dev_data, nlp)
train_acc_1_0 = _measure_accuracy(train_data, nlp)
dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe)
train_acc_1_0 = _measure_accuracy(train_data, el_pipe)
print("Epoch, train loss, train/dev acc, 1-1, 0-1, 1-0:", itn, losses['entity_linker'],
print("Epoch, train loss, train/dev acc, 1-1, 0-1, 1-0:", itn, round(losses['entity_linker'], 2),
round(train_acc_1_1, 2), round(train_acc_0_1, 2), round(train_acc_1_0, 2), "/",
round(dev_acc_1_1, 2), round(dev_acc_0_1, 2), round(dev_acc_1_0, 2))
@@ -184,12 +187,13 @@ def run_pipeline():
print("STOP", datetime.datetime.now())
def _measure_accuracy(data, nlp):
def _measure_accuracy(data, el_pipe):
correct = 0
incorrect = 0
texts = [d.text for d, g in data]
docs = list(nlp.pipe(texts))
docs = [d for d, g in data]
docs = el_pipe.pipe(docs)
golds = [g for d, g in data]
for doc, gold in zip(docs, golds):