train and predict per article (saving time for doc encoding)

This commit is contained in:
svlandeg 2019-05-13 17:02:34 +02:00
parent 3b81b00954
commit 4142e8dd1b
2 changed files with 103 additions and 81 deletions

View File

@ -46,11 +46,11 @@ class EL_Model():
dev_instances, dev_pos, dev_neg, dev_doc = self._get_training_data(training_dir, dev_instances, dev_pos, dev_neg, dev_doc = self._get_training_data(training_dir,
entity_descr_output, entity_descr_output,
True, True,
limit, to_print) limit / 10, to_print)
if to_print: if to_print:
print("Training on", len(train_instances), "instance clusters") print("Training on", len(train_instances.values()), "articles")
print("Dev test on", len(dev_instances), "instance clusters") print("Dev test on", len(dev_instances.values()), "articles")
print() print()
self.sgd_entity = self.begin_training(self.entity_encoder) self.sgd_entity = self.begin_training(self.entity_encoder)
@ -60,28 +60,31 @@ class EL_Model():
losses = {} losses = {}
for inst_cluster in train_instances: instance_count = 0
pos_ex = train_pos.get(inst_cluster)
neg_exs = train_neg.get(inst_cluster, [])
if pos_ex and neg_exs: for article_id, inst_cluster_set in train_instances.items():
article = inst_cluster.split(sep="_")[0] article_doc = train_doc[article_id]
entity_id = inst_cluster.split(sep="_")[1] pos_ex_list = list()
article_doc = train_doc[article] neg_exs_list = list()
self.update(article_doc, pos_ex, neg_exs, losses=losses) for inst_cluster in inst_cluster_set:
instance_count += 1
pos_ex_list.append(train_pos.get(inst_cluster))
neg_exs_list.append(train_neg.get(inst_cluster, []))
self.update(article_doc, pos_ex_list, neg_exs_list, losses=losses)
p, r, fscore = self._test_dev(dev_instances, dev_pos, dev_neg, dev_doc) p, r, fscore = self._test_dev(dev_instances, dev_pos, dev_neg, dev_doc)
print(round(fscore, 1)) print(round(fscore, 1))
# TODO
# elif not pos_ex: if to_print:
# print("Weird. Couldn't find pos example for", inst_cluster) print("Trained on", instance_count, "instance clusters")
# elif not neg_exs:
# print("Weird. Couldn't find neg examples for", inst_cluster)
def _test_dev(self, dev_instances, dev_pos, dev_neg, dev_doc): def _test_dev(self, dev_instances, dev_pos, dev_neg, dev_doc):
predictions = list() predictions = list()
golds = list() golds = list()
for inst_cluster in dev_instances: for article_id, inst_cluster_set in dev_instances.items():
for inst_cluster in inst_cluster_set:
pos_ex = dev_pos.get(inst_cluster) pos_ex = dev_pos.get(inst_cluster)
neg_exs = dev_neg.get(inst_cluster, []) neg_exs = dev_neg.get(inst_cluster, [])
ex_to_id = dict() ex_to_id = dict()
@ -103,7 +106,6 @@ class EL_Model():
predictions.append(ex_to_id[best_entity]) predictions.append(ex_to_id[best_entity])
golds.append(ex_to_id[pos_ex]) golds.append(ex_to_id[pos_ex])
# TODO: use lowest_mse and combine with prior probability # TODO: use lowest_mse and combine with prior probability
p, r, F = run_el.evaluate(predictions, golds, to_print=False) p, r, F = run_el.evaluate(predictions, golds, to_print=False)
return p, r, F return p, r, F
@ -161,16 +163,24 @@ class EL_Model():
sgd = create_default_optimizer(model.ops) sgd = create_default_optimizer(model.ops)
return sgd return sgd
def update(self, article_doc, true_entity, false_entities, drop=0., losses=None): def update(self, article_doc, true_entity_list, false_entities_list, drop=0., losses=None):
# TODO: one call only to begin_update ?
entity_diffs = None
doc_diffs = None
doc_encoding, article_bp = self.article_encoder.begin_update([article_doc], drop=drop) doc_encoding, article_bp = self.article_encoder.begin_update([article_doc], drop=drop)
for i, true_entity in enumerate(true_entity_list):
false_entities = false_entities_list[i]
true_entity_encoding, true_entity_bp = self.entity_encoder.begin_update([true_entity], drop=drop) true_entity_encoding, true_entity_bp = self.entity_encoder.begin_update([true_entity], drop=drop)
# print("encoding dim", len(true_entity_encoding[0])) # print("encoding dim", len(true_entity_encoding[0]))
consensus_encoding = self._calculate_consensus(doc_encoding, true_entity_encoding) consensus_encoding = self._calculate_consensus(doc_encoding, true_entity_encoding)
consensus_encoding_t = consensus_encoding.transpose() # consensus_encoding_t = consensus_encoding.transpose()
doc_mse, doc_diffs = self._calculate_similarity(doc_encoding, consensus_encoding) doc_mse, doc_diff = self._calculate_similarity(doc_encoding, consensus_encoding)
entity_mses = list() entity_mses = list()
@ -183,6 +193,13 @@ class EL_Model():
# false_exp_sum = 0 # false_exp_sum = 0
if doc_diffs is not None:
doc_diffs += doc_diff
entity_diffs += true_diffs
else:
doc_diffs = doc_diff
entity_diffs = true_diffs
for false_entity in false_entities: for false_entity in false_entities:
false_entity_encoding, false_entity_bp = self.entity_encoder.begin_update([false_entity], drop=drop) false_entity_encoding, false_entity_bp = self.entity_encoder.begin_update([false_entity], drop=drop)
false_mse, false_diffs = self._calculate_similarity(false_entity_encoding, consensus_encoding) false_mse, false_diffs = self._calculate_similarity(false_entity_encoding, consensus_encoding)
@ -213,7 +230,11 @@ class EL_Model():
# TODO: proper backpropagation taking ranking of elements into account ? # TODO: proper backpropagation taking ranking of elements into account ?
# TODO backpropagation also for negative examples # TODO backpropagation also for negative examples
true_entity_bp(true_diffs, sgd=self.sgd_entity)
if doc_diffs is not None:
doc_diffs = doc_diffs / len(true_entity_list)
true_entity_bp(entity_diffs, sgd=self.sgd_entity)
article_bp(doc_diffs, sgd=self.sgd_article) article_bp(doc_diffs, sgd=self.sgd_article)
@ -268,7 +289,7 @@ class EL_Model():
collect_incorrect=True) collect_incorrect=True)
instances = list() instance_by_doc = dict()
local_vectors = list() # TODO: local vectors local_vectors = list() # TODO: local vectors
doc_by_article = dict() doc_by_article = dict()
pos_entities = dict() pos_entities = dict()
@ -280,18 +301,19 @@ class EL_Model():
if dev == run_el.is_dev(f): if dev == run_el.is_dev(f):
article_id = f.replace(".txt", "") article_id = f.replace(".txt", "")
if cnt % 500 == 0 and to_print: if cnt % 500 == 0 and to_print:
print(datetime.datetime.now(), "processed", cnt, "files in the dev dataset") print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
cnt += 1 cnt += 1
if article_id not in doc_by_article: if article_id not in doc_by_article:
with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file: with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
text = file.read() text = file.read()
doc = self.nlp(text) doc = self.nlp(text)
doc_by_article[article_id] = doc doc_by_article[article_id] = doc
instance_by_doc[article_id] = set()
for mention, entity_pos in correct_entries[article_id].items(): for mention, entity_pos in correct_entries[article_id].items():
descr = id_to_descr.get(entity_pos) descr = id_to_descr.get(entity_pos)
if descr: if descr:
instances.append(article_id + "_" + mention) instance_by_doc[article_id].add(article_id + "_" + mention)
doc_descr = self.nlp(descr) doc_descr = self.nlp(descr)
doc_descr._.entity_id = entity_pos doc_descr._.entity_id = entity_pos
pos_entities[article_id + "_" + mention] = doc_descr pos_entities[article_id + "_" + mention] = doc_descr
@ -308,6 +330,6 @@ class EL_Model():
if to_print: if to_print:
print() print()
print("Processed", cnt, "dev articles") print("Processed", cnt, "training articles, dev=" + str(dev))
print() print()
return instances, pos_entities, neg_entities, doc_by_article return instance_by_doc, pos_entities, neg_entities, doc_by_article

View File

@ -111,7 +111,7 @@ if __name__ == "__main__":
print("STEP 6: training ", datetime.datetime.now()) print("STEP 6: training ", datetime.datetime.now())
my_nlp = spacy.load('en_core_web_md') my_nlp = spacy.load('en_core_web_md')
trainer = EL_Model(kb=my_kb, nlp=my_nlp) trainer = EL_Model(kb=my_kb, nlp=my_nlp)
trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, limit=50) trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, limit=500)
print() print()
# STEP 7: apply the EL algorithm on the dev dataset # STEP 7: apply the EL algorithm on the dev dataset