evaluating on dev set during training

svlandeg 2019-05-13 14:26:04 +02:00
parent b6d788064a
commit 3b81b00954
2 changed files with 90 additions and 22 deletions

View File

@@ -70,12 +70,10 @@ def is_dev(file_name):
     return file_name.endswith("3.txt")
 
 
-def evaluate(predictions, golds):
+def evaluate(predictions, golds, to_print=True):
     if len(predictions) != len(golds):
         raise ValueError("predictions and gold entities should have the same length")
-    print("Evaluating", len(golds), "entities")
-
     tp = 0
     fp = 0
     fn = 0
@@ -89,17 +87,22 @@ def evaluate(predictions, golds):
         else:
             fp += 1
 
-    print("tp", tp)
-    print("fp", fp)
-    print("fn", fn)
+    if to_print:
+        print("Evaluating", len(golds), "entities")
+        print("tp", tp)
+        print("fp", fp)
+        print("fn", fn)
 
-    precision = tp / (tp + fp + 0.0000001)
-    recall = tp / (tp + fn + 0.0000001)
+    precision = 100 * tp / (tp + fp + 0.0000001)
+    recall = 100 * tp / (tp + fn + 0.0000001)
     fscore = 2 * recall * precision / (recall + precision + 0.0000001)
 
-    print("precision", round(100 * precision, 1), "%")
-    print("recall", round(100 * recall, 1), "%")
-    print("Fscore", round(100 * fscore, 1), "%")
+    if to_print:
+        print("precision", round(precision, 1), "%")
+        print("recall", round(recall, 1), "%")
+        print("Fscore", round(fscore, 1), "%")
+
+    return precision, recall, fscore
 
 
 def _prepare_pipeline(nlp, kb):
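
With the `to_print` flag and the explicit return value added above, `evaluate()` can now be called silently from a training loop. A minimal usage sketch (the entity IDs are made up, and the exact scores depend on the full loop body, which this hunk only partly shows):

    predictions = ["Q1", "Q2", "Q3"]
    golds = ["Q1", "Q9", "Q3"]
    p, r, f = evaluate(predictions, golds, to_print=False)
    # after this commit the three values are already scaled to percentages
    print(round(p, 1), round(r, 1), round(f, 1))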

View File

@@ -5,6 +5,7 @@ import os
 import datetime
 from os import listdir
 import numpy as np
+from random import shuffle
 
 from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator
@@ -16,6 +17,8 @@ from thinc.t2v import Pooling, sum_pool, mean_pool
 from thinc.t2t import ExtractWindow, ParametricAttention
 from thinc.misc import Residual, LayerNorm as LN
 
+from spacy.tokens import Doc
+
 """ TODO: this code needs to be implemented in pipes.pyx"""
@@ -33,34 +36,93 @@ class EL_Model():
         self.article_encoder = self._simple_encoder(in_width=300, out_width=96)
 
     def train_model(self, training_dir, entity_descr_output, limit=None, to_print=True):
-        instances, pos_entities, neg_entities, doc_by_article = self._get_training_data(training_dir,
-                                                                                        entity_descr_output,
-                                                                                        limit, to_print)
+        Doc.set_extension("entity_id", default=None)
+
+        train_instances, train_pos, train_neg, train_doc = self._get_training_data(training_dir,
+                                                                                   entity_descr_output,
+                                                                                   False,
+                                                                                   limit, to_print)
+
+        dev_instances, dev_pos, dev_neg, dev_doc = self._get_training_data(training_dir,
+                                                                           entity_descr_output,
+                                                                           True,
+                                                                           limit, to_print)
 
         if to_print:
-            print("Training on", len(instances), "instance clusters")
+            print("Training on", len(train_instances), "instance clusters")
+            print("Dev test on", len(dev_instances), "instance clusters")
             print()
 
         self.sgd_entity = self.begin_training(self.entity_encoder)
         self.sgd_article = self.begin_training(self.article_encoder)
 
+        self._test_dev(dev_instances, dev_pos, dev_neg, dev_doc)
+
         losses = {}
 
-        for inst_cluster in instances:
-            pos_ex = pos_entities.get(inst_cluster)
-            neg_exs = neg_entities.get(inst_cluster, [])
+        for inst_cluster in train_instances:
+            pos_ex = train_pos.get(inst_cluster)
+            neg_exs = train_neg.get(inst_cluster, [])
 
             if pos_ex and neg_exs:
                 article = inst_cluster.split(sep="_")[0]
                 entity_id = inst_cluster.split(sep="_")[1]
-                article_doc = doc_by_article[article]
+                article_doc = train_doc[article]
                 self.update(article_doc, pos_ex, neg_exs, losses=losses)
+                p, r, fscore = self._test_dev(dev_instances, dev_pos, dev_neg, dev_doc)
+                print(round(fscore, 1))
 
             # TODO
             # elif not pos_ex:
             #     print("Weird. Couldn't find pos example for", inst_cluster)
             # elif not neg_exs:
             #     print("Weird. Couldn't find neg examples for", inst_cluster)
 
+    def _test_dev(self, dev_instances, dev_pos, dev_neg, dev_doc):
+        predictions = list()
+        golds = list()
+
+        for inst_cluster in dev_instances:
+            pos_ex = dev_pos.get(inst_cluster)
+            neg_exs = dev_neg.get(inst_cluster, [])
+            ex_to_id = dict()
+
+            if pos_ex and neg_exs:
+                ex_to_id[pos_ex] = pos_ex._.entity_id
+                for neg_ex in neg_exs:
+                    ex_to_id[neg_ex] = neg_ex._.entity_id
+
+                article = inst_cluster.split(sep="_")[0]
+                entity_id = inst_cluster.split(sep="_")[1]
+                article_doc = dev_doc[article]
+
+                examples = list(neg_exs)
+                examples.append(pos_ex)
+                shuffle(examples)
+
+                best_entity, lowest_mse = self._predict(examples, article_doc)
+                predictions.append(ex_to_id[best_entity])
+                golds.append(ex_to_id[pos_ex])
+
+        # TODO: use lowest_mse and combine with prior probability
+        p, r, F = run_el.evaluate(predictions, golds, to_print=False)
+        return p, r, F
+
+    def _predict(self, entities, article_doc):
+        doc_encoding = self.article_encoder([article_doc])
+
+        lowest_mse = None
+        best_entity = None
+
+        for entity in entities:
+            entity_encoding = self.entity_encoder([entity])
+            mse, _ = self._calculate_similarity(doc_encoding, entity_encoding)
+            if not best_entity or mse < lowest_mse:
+                lowest_mse = mse
+                best_entity = entity
+
+        return best_entity, lowest_mse
+
     def _simple_encoder(self, in_width, out_width):
         conv_depth = 1
         cnn_maxout_pieces = 3
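
The new `_predict()` is essentially a nearest-neighbour search under mean squared error: encode the article once, encode each candidate entity, and keep the candidate whose encoding is closest. A self-contained sketch of that ranking logic, with plain numpy vectors standing in for the thinc encoder outputs (function and variable names are illustrative):

    import numpy as np

    def rank_by_mse(doc_encoding, entity_encodings):
        # Keep the candidate whose encoding is closest to the article
        # encoding under mean squared error, mirroring _predict above.
        best_index = None
        lowest_mse = None
        for i, enc in enumerate(entity_encodings):
            mse = float(np.mean((doc_encoding - enc) ** 2))
            if best_index is None or mse < lowest_mse:
                best_index, lowest_mse = i, mse
        return best_index, lowest_mse

    doc_enc = np.array([0.1, 0.2, 0.3])
    candidates = [np.array([0.9, 0.9, 0.9]), np.array([0.1, 0.2, 0.25])]
    print(rank_by_mse(doc_enc, candidates))  # -> (1, ...): the closer candidate wins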
@@ -145,7 +207,7 @@ class EL_Model():
         # print("true index", true_index)
         # print("true prob", entity_probs[true_index])
 
-        print(true_mse)
+        # print("training loss", true_mse)
 
         # print()
@@ -198,13 +260,14 @@ class EL_Model():
     def _get_labels(self):
         return tuple(self.labels)
 
-    def _get_training_data(self, training_dir, entity_descr_output, limit, to_print):
+    def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print):
         id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
 
         correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir,
                                                                                          collect_correct=True,
                                                                                          collect_incorrect=True)
 
         instances = list()
         local_vectors = list()  # TODO: local vectors
         doc_by_article = dict()
@@ -214,7 +277,7 @@ class EL_Model():
         cnt = 0
         for f in listdir(training_dir):
             if not limit or cnt < limit:
-                if not run_el.is_dev(f):
+                if dev == run_el.is_dev(f):
                     article_id = f.replace(".txt", "")
                     if cnt % 500 == 0 and to_print:
                         print(datetime.datetime.now(), "processed", cnt, "files in the dev dataset")
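
The new `dev` flag works because the corpus is split purely by file name: `is_dev()` (shown in the first file of this commit) marks every article file ending in "3.txt" as dev data, so `dev == run_el.is_dev(f)` selects exactly one of the two disjoint subsets. A small illustration with made-up file names:

    def is_dev(file_name):
        # same convention as run_el.is_dev above
        return file_name.endswith("3.txt")

    files = ["article_1.txt", "article_3.txt", "article_13.txt", "article_20.txt"]
    train = [f for f in files if not is_dev(f)]  # ['article_1.txt', 'article_20.txt']
    dev = [f for f in files if is_dev(f)]        # ['article_3.txt', 'article_13.txt']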
@@ -230,6 +293,7 @@ class EL_Model():
                     if descr:
                         instances.append(article_id + "_" + mention)
                         doc_descr = self.nlp(descr)
+                        doc_descr._.entity_id = entity_pos
                         pos_entities[article_id + "_" + mention] = doc_descr
 
                 for mention, entity_negs in incorrect_entries[article_id].items():
@@ -237,6 +301,7 @@ class EL_Model():
                         descr = id_to_descr.get(entity_neg)
                         if descr:
                             doc_descr = self.nlp(descr)
+                            doc_descr._.entity_id = entity_neg
                             descr_list = neg_entities.get(article_id + "_" + mention, [])
                             descr_list.append(doc_descr)
                             neg_entities[article_id + "_" + mention] = descr_list
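
The `._.entity_id` writes in the last two hunks rely on spaCy's custom-attribute mechanism: `Doc.set_extension("entity_id", default=None)`, registered once in `train_model()`, gives every `Doc` that slot, which is how each description doc remembers which knowledge-base entity it describes. A minimal standalone sketch (the description text and ID are illustrative):

    import spacy
    from spacy.tokens import Doc

    # Registering the extension once adds a ._.entity_id slot to every Doc;
    # the attribute name mirrors the one used in this commit.
    if not Doc.has_extension("entity_id"):
        Doc.set_extension("entity_id", default=None)

    nlp = spacy.blank("en")
    doc_descr = nlp("A famous science-fiction writer.")  # stand-in description text
    doc_descr._.entity_id = "Q42"                        # illustrative KB identifier
    print(doc_descr._.entity_id)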