speeding up training

This commit is contained in:
svlandeg 2019-06-12 13:37:05 +02:00
parent 66813a1fdc
commit 6521cfa132
2 changed files with 57 additions and 43 deletions

View File

@ -115,6 +115,7 @@ def run_pipeline():
# STEP 6: create the entity linking pipe # STEP 6: create the entity linking pipe
if train_pipe: if train_pipe:
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
train_limit = 100 train_limit = 100
dev_limit = 20 dev_limit = 20
print("Training on", train_limit, "articles") print("Training on", train_limit, "articles")
@ -131,7 +132,7 @@ def run_pipeline():
training_dir=TRAINING_DIR, training_dir=TRAINING_DIR,
dev=True, dev=True,
limit=dev_limit, limit=dev_limit,
to_print=False) to_print=False)
el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb, "doc_cutoff": DOC_CHAR_CUTOFF}) el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb, "doc_cutoff": DOC_CHAR_CUTOFF})
nlp.add_pipe(el_pipe, last=True) nlp.add_pipe(el_pipe, last=True)
@ -147,35 +148,40 @@ def run_pipeline():
with nlp.disable_pipes(*other_pipes): with nlp.disable_pipes(*other_pipes):
for batch in batches: for batch in batches:
docs, golds = zip(*batch) try:
nlp.update( docs, golds = zip(*batch)
docs, nlp.update(
golds, docs,
drop=DROPOUT, golds,
losses=losses, drop=DROPOUT,
) losses=losses,
)
except Exception as e:
print("Error updating batch", e)
# print(" measuring accuracy 1-1") print("Epoch, train loss", itn, round(losses['entity_linker'], 2))
el_pipe.context_weight = 1
el_pipe.prior_weight = 1
dev_acc_1_1 = _measure_accuracy(dev_data, el_pipe)
train_acc_1_1 = _measure_accuracy(train_data, el_pipe)
# print(" measuring accuracy 0-1") # baseline using only prior probabilities
el_pipe.context_weight = 0 el_pipe.context_weight = 0
el_pipe.prior_weight = 1 el_pipe.prior_weight = 1
dev_acc_0_1 = _measure_accuracy(dev_data, el_pipe) dev_acc_0_1 = _measure_accuracy(dev_data, el_pipe)
train_acc_0_1 = _measure_accuracy(train_data, el_pipe) train_acc_0_1 = _measure_accuracy(train_data, el_pipe)
# print(" measuring accuracy 1-0") # print(" measuring accuracy 1-1")
el_pipe.context_weight = 1 el_pipe.context_weight = 1
el_pipe.prior_weight = 0 el_pipe.prior_weight = 1
dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe) dev_acc_1_1 = _measure_accuracy(dev_data, el_pipe)
train_acc_1_0 = _measure_accuracy(train_data, el_pipe) train_acc_1_1 = _measure_accuracy(train_data, el_pipe)
print("Epoch, train loss, train/dev acc, 1-1, 0-1, 1-0:", itn, round(losses['entity_linker'], 2), # print(" measuring accuracy 1-0")
round(train_acc_1_1, 2), round(train_acc_0_1, 2), round(train_acc_1_0, 2), "/", el_pipe.context_weight = 1
round(dev_acc_1_1, 2), round(dev_acc_0_1, 2), round(dev_acc_1_0, 2)) el_pipe.prior_weight = 0
dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe)
train_acc_1_0 = _measure_accuracy(train_data, el_pipe)
print("train/dev acc, 1-1, 0-1, 1-0:" ,
round(train_acc_1_1, 2), round(train_acc_0_1, 2), round(train_acc_1_0, 2), "/",
round(dev_acc_1_1, 2), round(dev_acc_0_1, 2), round(dev_acc_1_0, 2))
# test Entity Linker # test Entity Linker
if to_test_pipeline: if to_test_pipeline:
@ -193,26 +199,29 @@ def _measure_accuracy(data, el_pipe):
docs = [d for d, g in data] docs = [d for d, g in data]
docs = el_pipe.pipe(docs) docs = el_pipe.pipe(docs)
golds = [g for d, g in data] golds = [g for d, g in data]
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
correct_entries_per_article = dict() try:
for entity in gold.links: correct_entries_per_article = dict()
start, end, gold_kb = entity for entity in gold.links:
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb start, end, gold_kb = entity
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
for ent in doc.ents: for ent in doc.ents:
if ent.label_ == "PERSON": # TODO: expand to other types if ent.label_ == "PERSON": # TODO: expand to other types
pred_entity = ent.kb_id_ pred_entity = ent.kb_id_
start = ent.start start = ent.start
end = ent.end end = ent.end
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None) gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
if gold_entity is not None: if gold_entity is not None:
if gold_entity == pred_entity: if gold_entity == pred_entity:
correct += 1 correct += 1
else: else:
incorrect += 1 incorrect += 1
except Exception as e:
print("Error assessing accuracy", e)
if correct == incorrect == 0: if correct == incorrect == 0:
return 0 return 0

View File

@ -1220,8 +1220,13 @@ class EntityLinker(Pipe):
def predict(self, docs): def predict(self, docs):
self.require_model() self.require_model()
if isinstance(docs, Doc):
docs = [docs]
final_entities = list() final_entities = list()
final_kb_ids = list() final_kb_ids = list()
for i, article_doc in enumerate(docs): for i, article_doc in enumerate(docs):
doc_encoding = self.article_encoder([article_doc]) doc_encoding = self.article_encoder([article_doc])
for ent in article_doc.ents: for ent in article_doc.ents: