From 97241a3ed78d7fa41aaea3de30843ca49b0ae6d0 Mon Sep 17 00:00:00 2001
From: svlandeg
Date: Wed, 22 May 2019 23:40:10 +0200
Subject: [PATCH] upsampling and batch processing

---
 .../pipeline/wiki_entity_linking/run_el.py    |  12 +-
 .../pipeline/wiki_entity_linking/train_el.py  | 294 +++++++++---------
 .../wiki_entity_linking/wiki_nel_pipeline.py  |   2 +-
 3 files changed, 157 insertions(+), 151 deletions(-)

diff --git a/examples/pipeline/wiki_entity_linking/run_el.py b/examples/pipeline/wiki_entity_linking/run_el.py
index 6ab7ea75f..273543306 100644
--- a/examples/pipeline/wiki_entity_linking/run_el.py
+++ b/examples/pipeline/wiki_entity_linking/run_el.py
@@ -78,8 +78,15 @@ def evaluate(predictions, golds, to_print=True):
     fp = 0
     fn = 0
 
+    corrects = 0
+    incorrects = 0
+
     for pred, gold in zip(predictions, golds):
         is_correct = pred == gold
+        if is_correct:
+            corrects += 1
+        else:
+            incorrects += 1
         if not pred:
             if not is_correct:  # we don't care about tn
                 fn += 1
@@ -98,12 +105,15 @@ def evaluate(predictions, golds, to_print=True):
     recall = 100 * tp / (tp + fn + 0.0000001)
     fscore = 2 * recall * precision / (recall + precision + 0.0000001)
 
+    accuracy = 100 * corrects / (corrects + incorrects + 0.0000001)
+
     if to_print:
         print("precision", round(precision, 1), "%")
         print("recall", round(recall, 1), "%")
         print("Fscore", round(fscore, 1), "%")
+        print("Accuracy", round(accuracy, 1), "%")
 
-    return precision, recall, fscore
+    return precision, recall, fscore, accuracy
 
 
 def _prepare_pipeline(nlp, kb):
diff --git a/examples/pipeline/wiki_entity_linking/train_el.py b/examples/pipeline/wiki_entity_linking/train_el.py
index a383a3687..cd6e9de4d 100644
--- a/examples/pipeline/wiki_entity_linking/train_el.py
+++ b/examples/pipeline/wiki_entity_linking/train_el.py
@@ -6,6 +6,7 @@ import datetime
 from os import listdir
 import numpy as np
 import random
+from random import shuffle
 
 from thinc.neural._classes.convolution import ExtractWindow
 
 from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator
@@ -26,17 +27,17 @@ from spacy.tokens import Doc
 
 class EL_Model:
 
-    PRINT_LOSS = False
-    PRINT_F = True
     PRINT_TRAIN = False
     EPS = 0.0000000005
     CUTOFF = 0.5
 
+    BATCH_SIZE = 5
+
     INPUT_DIM = 300
-    HIDDEN_1_WIDTH = 256  # 10
+    HIDDEN_1_WIDTH = 32  # 10
     HIDDEN_2_WIDTH = 32  # 6
-    ENTITY_WIDTH = 64  # 4
-    ARTICLE_WIDTH = 128  # 8
+    DESC_WIDTH = 64  # 4
+    ARTICLE_WIDTH = 64  # 8
 
     DROP = 0.1
 
@@ -48,7 +49,7 @@ class EL_Model:
         self.kb = kb
 
         self._build_cnn(in_width=self.INPUT_DIM,
-                        entity_width=self.ENTITY_WIDTH,
+                        desc_width=self.DESC_WIDTH,
                         article_width=self.ARTICLE_WIDTH,
                         hidden_1_width=self.HIDDEN_1_WIDTH,
                         hidden_2_width=self.HIDDEN_2_WIDTH)
@@ -57,121 +58,118 @@ class EL_Model:
         # raise errors instead of runtime warnings in case of int/float overflow
         np.seterr(all='raise')
 
-        train_inst, train_pos, train_neg, train_texts = self._get_training_data(training_dir,
-                                                                                entity_descr_output,
-                                                                                False,
-                                                                                trainlimit,
-                                                                                balance=True,
-                                                                                to_print=False)
+        train_ent, train_gold, train_desc, train_article, train_texts = self._get_training_data(training_dir,
+                                                                                                entity_descr_output,
+                                                                                                False,
+                                                                                                trainlimit,
+                                                                                                to_print=False)
+
+        train_pos_entities = [k for k, v in train_gold.items() if v]
+        train_neg_entities = [k for k, v in train_gold.items() if not v]
+
+        train_pos_count = len(train_pos_entities)
+        train_neg_count = len(train_neg_entities)
+
+        # upsample positives to 50-50 distribution
+        while train_pos_count < train_neg_count:
+            train_ent.append(random.choice(train_pos_entities))
+            train_pos_count += 1
+
+        # upsample negatives to 50-50 distribution
+        while train_neg_count < train_pos_count:
+            train_ent.append(random.choice(train_neg_entities))
+            train_neg_count += 1
+
+        shuffle(train_ent)
+
+        dev_ent, dev_gold, dev_desc, dev_article, dev_texts = self._get_training_data(training_dir,
+                                                                                      entity_descr_output,
+                                                                                      True,
+                                                                                      devlimit,
+                                                                                      to_print=False)
+        shuffle(dev_ent)
+
+        dev_pos_count = len([g for g in dev_gold.values() if g])
+        dev_neg_count = len([g for g in dev_gold.values() if not g])
 
-        dev_inst, dev_pos, dev_neg, dev_texts = self._get_training_data(training_dir,
-                                                                        entity_descr_output,
-                                                                        True,
-                                                                        devlimit,
-                                                                        balance=False,
-                                                                        to_print=False)
         self._begin_training()
 
         print()
-        self._test_dev(dev_inst, dev_pos, dev_neg, dev_texts, print_string="dev_random", calc_random=True)
-        self._test_dev(dev_inst, dev_pos, dev_neg, dev_texts, print_string="dev_pre", avg=False)
-
-        instance_pos_count = 0
-        instance_neg_count = 0
+        self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_random", calc_random=True)
+        print()
+        self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_pre", avg=True)
 
         if to_print:
             print()
-            print("Training on", len(train_inst.values()), "articles")
-            print("Dev test on", len(dev_inst.values()), "articles")
+            print("Training on", len(train_ent), "entities in", len(train_texts), "articles")
+            print("Training instances pos/neg", train_pos_count, train_neg_count)
+            print()
+            print("Dev test on", len(dev_ent), "entities in", len(dev_texts), "articles")
+            print("Dev instances pos/neg", dev_pos_count, dev_neg_count)
             print()
             print(" CUTOFF", self.CUTOFF)
             print(" INPUT_DIM", self.INPUT_DIM)
             print(" HIDDEN_1_WIDTH", self.HIDDEN_1_WIDTH)
-            print(" ENTITY_WIDTH", self.ENTITY_WIDTH)
+            print(" DESC_WIDTH", self.DESC_WIDTH)
             print(" ARTICLE_WIDTH", self.ARTICLE_WIDTH)
             print(" HIDDEN_2_WIDTH", self.HIDDEN_2_WIDTH)
             print(" DROP", self.DROP)
             print()
 
-        # TODO: proper batches. Currently 1 article at the time
-        # TODO shuffle data (currently positive is always followed by several negatives)
-        article_count = 0
-        for article_id, inst_cluster_set in train_inst.items():
-            try:
-                # if to_print:
-                #     print()
-                #     print(article_count, "Training on article", article_id)
-                article_count += 1
-                article_text = train_texts[article_id]
-                entities = list()
-                golds = list()
-                for inst_cluster in inst_cluster_set:
-                    entities.append(train_pos.get(inst_cluster))
-                    golds.append(float(1.0))
-                    instance_pos_count += 1
-                    for neg_entity in train_neg.get(inst_cluster, []):
-                        entities.append(neg_entity)
-                        golds.append(float(0.0))
-                        instance_neg_count += 1
+        start = 0
+        stop = min(self.BATCH_SIZE, len(train_ent))
+        processed = 0
 
-                self.update(article_text=article_text, entities=entities, golds=golds)
+        while start < len(train_ent):
+            next_batch = train_ent[start:stop]
 
-                # dev eval
-                self._test_dev(dev_inst, dev_pos, dev_neg, dev_texts, print_string="dev_inter_avg", avg=True)
-            except ValueError as e:
-                print("Error in article id", article_id)
+            golds = [train_gold[e] for e in next_batch]
+            descs = [train_desc[e] for e in next_batch]
+            articles = [train_texts[train_article[e]] for e in next_batch]
+
+            self.update(entities=next_batch, golds=golds, descs=descs, texts=articles)
+            self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_inter", avg=True)
+
+            processed += len(next_batch)
+
+            start = start + self.BATCH_SIZE
+            stop = min(stop + self.BATCH_SIZE, len(train_ent))
 
         if to_print:
             print()
-            print("Trained on", instance_pos_count, "/", instance_neg_count, "instances pos/neg")
+            print("Trained on", processed, "entities in total")
 
-    def _test_dev(self, instances, pos, neg, texts_by_id, print_string, avg=False, calc_random=False):
-        predictions = list()
-        golds = list()
+    def _test_dev(self, entities, gold_by_entity, desc_by_entity, article_by_entity, texts_by_id, print_string, avg=True, calc_random=False):
+        golds = [gold_by_entity[e] for e in entities]
 
-        for article_id, inst_cluster_set in instances.items():
-            for inst_cluster in inst_cluster_set:
-                pos_ex = pos.get(inst_cluster)
-                neg_exs = neg.get(inst_cluster, [])
+        if calc_random:
+            predictions = self._predict_random(entities=entities)
 
-                article = inst_cluster.split(sep="_")[0]
-                entity_id = inst_cluster.split(sep="_")[1]
-                article_doc = self.nlp(texts_by_id[article])
-                entities = [self.nlp(pos_ex)]
-                golds.append(float(1.0))
-                for neg_ex in neg_exs:
-                    entities.append(self.nlp(neg_ex))
-                    golds.append(float(0.0))
-
-                if calc_random:
-                    preds = self._predict_random(entities=entities)
-                else:
-                    preds = self._predict(article_doc=article_doc, entities=entities, avg=avg)
-                predictions.extend(preds)
+        else:
+            desc_docs = self.nlp.pipe([desc_by_entity[e] for e in entities])
+            article_docs = self.nlp.pipe([texts_by_id[article_by_entity[e]] for e in entities])
+            predictions = self._predict(entities=entities, article_docs=article_docs, desc_docs=desc_docs, avg=avg)
 
         # TODO: combine with prior probability
-        p, r, f = run_el.evaluate(predictions, golds, to_print=False)
-        if self.PRINT_F:
-            print("p/r/F", print_string, round(p, 1), round(r, 1), round(f, 1))
-
+        p, r, f, acc = run_el.evaluate(predictions, golds, to_print=False)
         loss, gradient = self.get_loss(self.model.ops.asarray(predictions), self.model.ops.asarray(golds))
-        if self.PRINT_LOSS:
-            print("loss", print_string, round(loss, 5))
+
+        print("p/r/F/acc/loss", print_string, round(p, 1), round(r, 1), round(f, 1), round(acc, 2), round(loss, 5))
 
         return loss, p, r, f
 
-    def _predict(self, article_doc, entities, avg=False, apply_threshold=True):
+    def _predict(self, entities, article_docs, desc_docs, avg=True, apply_threshold=True):
         if avg:
             with self.article_encoder.use_params(self.sgd_article.averages) \
-                 and self.entity_encoder.use_params(self.sgd_entity.averages):
-                doc_encoding = self.article_encoder([article_doc])[0]
-                entity_encodings = self.entity_encoder(entities)
+                 and self.desc_encoder.use_params(self.sgd_entity.averages):
+                doc_encodings = self.article_encoder(article_docs)
+                desc_encodings = self.desc_encoder(desc_docs)
 
         else:
-            doc_encoding = self.article_encoder([article_doc])[0]
-            entity_encodings = self.entity_encoder(entities)
+            doc_encodings = self.article_encoder(article_docs)
+            desc_encodings = self.desc_encoder(desc_docs)
 
-        concat_encodings = [list(entity_encodings[i]) + list(doc_encoding) for i in range(len(entities))]
+        concat_encodings = [list(desc_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))]
         np_array_list = np.asarray(concat_encodings)
 
         if avg:
@@ -189,16 +187,16 @@
 
     def _predict_random(self, entities, apply_threshold=True):
         if not apply_threshold:
-            return [float(random.uniform(0,1)) for e in entities]
+            return [float(random.uniform(0, 1)) for e in entities]
         else:
-            return [float(1.0) if random.uniform(0,1) > self.CUTOFF else float(0.0) for e in entities]
+            return [float(1.0) if random.uniform(0, 1) > self.CUTOFF else float(0.0) for e in entities]
 
-    def _build_cnn(self, in_width, entity_width, article_width, hidden_1_width, hidden_2_width):
+    def _build_cnn(self, in_width, desc_width, article_width, hidden_1_width, hidden_2_width):
         with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
-            self.entity_encoder = self._encoder(in_width=in_width, hidden_with=hidden_1_width, end_width=entity_width)
+            self.desc_encoder = self._encoder(in_width=in_width, hidden_with=hidden_1_width, end_width=desc_width)
             self.article_encoder = self._encoder(in_width=in_width, hidden_with=hidden_1_width, end_width=article_width)
 
-            in_width = entity_width + article_width
+            in_width = desc_width + article_width
             out_width = hidden_2_width
 
             self.model = Affine(out_width, in_width) \
@@ -229,80 +227,78 @@
 
     def _begin_training(self):
         self.sgd_article = create_default_optimizer(self.article_encoder.ops)
-        self.sgd_entity = create_default_optimizer(self.entity_encoder.ops)
+        self.sgd_entity = create_default_optimizer(self.desc_encoder.ops)
         self.sgd = create_default_optimizer(self.model.ops)
 
     @staticmethod
     def get_loss(predictions, golds):
         d_scores = (predictions - golds)
+        gradient = d_scores.mean()
         loss = (d_scores ** 2).mean()
-        return loss, d_scores
+        return loss, gradient
 
-    # TODO: multiple docs/articles
-    def update(self, article_text, entities, golds, apply_threshold=True):
-        article_doc = self.nlp(article_text)
-        # entity_docs = list(self.nlp.pipe(entities))
+    def update(self, entities, golds, descs, texts):
+        golds = self.model.ops.asarray(golds)
 
-        for entity, gold in zip(entities, golds):
-            doc_encodings, bp_doc = self.article_encoder.begin_update([article_doc], drop=self.DROP)
-            doc_encoding = doc_encodings[0]
+        desc_docs = self.nlp.pipe(descs)
+        article_docs = self.nlp.pipe(texts)
 
-            entity_doc = self.nlp(entity)
-            # print("entity_docs", type(entity_doc))
+        doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=self.DROP)
 
-            entity_encodings, bp_entity = self.entity_encoder.begin_update([entity_doc], drop=self.DROP)
-            entity_encoding = entity_encodings[0]
-            # print("entity_encoding", len(entity_encoding), entity_encoding)
+        desc_encodings, bp_entity = self.desc_encoder.begin_update(desc_docs, drop=self.DROP)
 
-            concat_encodings = [list(entity_encoding) + list(doc_encoding)]  # for i in range(len(entities))
-            # print("concat_encodings", len(concat_encodings), concat_encodings)
+        concat_encodings = [list(desc_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))]
 
-            prediction, bp_model = self.model.begin_update(np.asarray(concat_encodings), drop=self.DROP)
-            # predictions = self.model.ops.flatten(predictions)
+        predictions, bp_model = self.model.begin_update(np.asarray(concat_encodings), drop=self.DROP)
+        predictions = self.model.ops.flatten(predictions)
 
-            # print("prediction", prediction)
-            # golds = self.model.ops.asarray(golds)
-            # print("gold", gold)
+        # print("entities", entities)
+        # print("predictions", predictions)
+        # print("golds", golds)
 
-            loss, gradient = self.get_loss(prediction, gold)
+        loss, gradient = self.get_loss(predictions, golds)
 
-            if self.PRINT_LOSS and self.PRINT_TRAIN:
-                print("loss train", round(loss, 5))
+        if self.PRINT_TRAIN:
+            print("loss train", round(loss, 5))
 
-            gradient = float(gradient)
-            # print("gradient", gradient)
-            # print("loss", loss)
+        gradient = float(gradient)
+        # print("gradient", gradient)
+        # print("loss", loss)
 
-            model_gradient = bp_model(gradient, sgd=self.sgd)
-            # print("model_gradient", model_gradient)
+        model_gradient = bp_model(gradient, sgd=self.sgd)
+        # print("model_gradient", model_gradient)
 
-            # concat = entity + doc, but doc is the same within this function (TODO: multiple docs/articles)
-            doc_gradient = model_gradient[0][self.ENTITY_WIDTH:]
-            entity_gradients = list()
-            for x in model_gradient:
-                entity_gradients.append(list(x[0:self.ENTITY_WIDTH]))
+        # concat = desc + doc, but doc is the same within this function (TODO: multiple docs/articles)
+        doc_gradient = model_gradient[0][self.DESC_WIDTH:]
+        entity_gradients = list()
+        for x in model_gradient:
+            entity_gradients.append(list(x[0:self.DESC_WIDTH]))
 
-            # print("doc_gradient", doc_gradient)
-            # print("entity_gradients", entity_gradients)
+        # print("doc_gradient", doc_gradient)
+        # print("entity_gradients", entity_gradients)
 
-            bp_doc([doc_gradient], sgd=self.sgd_article)
-            bp_entity(entity_gradients, sgd=self.sgd_entity)
+        bp_doc([doc_gradient], sgd=self.sgd_article)
+        bp_entity(entity_gradients, sgd=self.sgd_entity)
 
-    def _get_training_data(self, training_dir, entity_descr_output, dev, limit, balance, to_print):
+    def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print):
         id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
 
         correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir,
                                                                                          collect_correct=True,
                                                                                          collect_incorrect=True)
 
-        instance_by_article = dict()
         local_vectors = list()  # TODO: local vectors
         text_by_article = dict()
-        pos_entities = dict()
-        neg_entities = dict()
+        gold_by_entity = dict()
+        desc_by_entity = dict()
+        article_by_entity = dict()
+        entities = list()
 
         cnt = 0
-        for f in listdir(training_dir):
+        next_entity_nr = 0
+        files = listdir(training_dir)
+        shuffle(files)
+        for f in files:
             if not limit or cnt < limit:
                 if dev == run_el.is_dev(f):
                     article_id = f.replace(".txt", "")
@@ -313,29 +309,29 @@
                     with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
                         text = file.read()
                         text_by_article[article_id] = text
-                    instance_by_article[article_id] = set()
                     for mention, entity_pos in correct_entries[article_id].items():
                        descr = id_to_descr.get(entity_pos)
                         if descr:
-                            instance_by_article[article_id].add(article_id + "_" + mention)
-                            pos_entities[article_id + "_" + mention] = descr
+                            entities.append(next_entity_nr)
+                            gold_by_entity[next_entity_nr] = 1
+                            desc_by_entity[next_entity_nr] = descr
+                            article_by_entity[next_entity_nr] = article_id
+                            next_entity_nr += 1
 
                     for mention, entity_negs in incorrect_entries[article_id].items():
-                        if not balance or pos_entities.get(article_id + "_" + mention):
-                            neg_count = 0
-                            for entity_neg in entity_negs:
-                                # if balance, keep only 1 negative instance for each positive instance
-                                if neg_count < 1 or not balance:
-                                    descr = id_to_descr.get(entity_neg)
-                                    if descr:
-                                        descr_list = neg_entities.get(article_id + "_" + mention, [])
-                                        descr_list.append(descr)
-                                        neg_entities[article_id + "_" + mention] = descr_list
-                                        neg_count += 1
+                        for entity_neg in entity_negs:
+                            descr = id_to_descr.get(entity_neg)
+                            if descr:
+                                entities.append(next_entity_nr)
+                                gold_by_entity[next_entity_nr] = 0
+                                desc_by_entity[next_entity_nr] = descr
+                                article_by_entity[next_entity_nr] = article_id
+                                next_entity_nr += 1
 
         if to_print:
             print()
            print("Processed", cnt, "training articles, dev=" + str(dev))
             print()
 
-        return instance_by_article, pos_entities, neg_entities, text_by_article
+        return entities, gold_by_entity, desc_by_entity, article_by_entity, text_by_article
+
diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
index 319b1e1c8..715282642 100644
--- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
+++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
@@ -111,7 +111,7 @@ if __name__ == "__main__":
         print("STEP 6: training", datetime.datetime.now())
         my_nlp = spacy.load('en_core_web_md')
         trainer = EL_Model(kb=my_kb, nlp=my_nlp)
-        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=100, devlimit=20)
+        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=400, devlimit=50)
         print()
 
         # STEP 7: apply the EL algorithm on the dev dataset
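
Editor notes: the sketches below illustrate the techniques this patch applies. They are not part of the patch, and every name in them is made up for the example.

The evaluate() change in run_el.py tracks accuracy next to tp/fp/fn. The same arithmetic, isolated from the pipeline (the epsilon mirrors the patch's guard against division by zero; predictions and golds are 1.0/0.0 floats as produced by _predict()):

    def metrics(predictions, golds, eps=0.0000001):
        # Count true positives, false positives, false negatives and exact matches.
        tp = sum(1 for p, g in zip(predictions, golds) if p and g)
        fp = sum(1 for p, g in zip(predictions, golds) if p and not g)
        fn = sum(1 for p, g in zip(predictions, golds) if not p and g)
        correct = sum(1 for p, g in zip(predictions, golds) if p == g)
        precision = 100 * tp / (tp + fp + eps)
        recall = 100 * tp / (tp + fn + eps)
        fscore = 2 * recall * precision / (recall + precision + eps)
        accuracy = 100 * correct / (len(golds) + eps)
        return precision, recall, fscore, accuracy

    print(metrics([1.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.0, 1.0]))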
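train_model() balances the class distribution by repeatedly appending randomly chosen minority-class entity ids (sampling with replacement) until positives and negatives are equally frequent, then shuffles. A minimal standalone sketch of that upsampling step, with hypothetical names:

    import random

    def upsample_to_balance(ids, gold_by_id):
        # Duplicate randomly chosen ids from the minority class until both
        # classes are equally frequent, then shuffle so the duplicates are
        # spread across later minibatches.
        pos = [i for i in ids if gold_by_id[i]]
        neg = [i for i in ids if not gold_by_id[i]]
        out = list(ids)
        minority = pos if len(pos) < len(neg) else neg
        if minority:
            for _ in range(abs(len(pos) - len(neg))):
                out.append(random.choice(minority))
        random.shuffle(out)
        return out

    gold = {1: 1, 2: 0, 3: 0, 4: 0}
    print(sorted(upsample_to_balance(list(gold), gold)))  # id 1 appears 3 times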
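The new training loop replaces the one-article-at-a-time updates with fixed-size slices of the shuffled entity list (BATCH_SIZE = 5). Its start/stop bookkeeping is equivalent to this generator (a sketch, not the patch's code):

    def minibatch(items, batch_size=5):
        # Yield consecutive fixed-size slices; the last slice may be shorter.
        for start in range(0, len(items), batch_size):
            yield items[start:start + batch_size]

    for batch in minibatch(list(range(12)), batch_size=5):
        print(batch)  # [0..4], then [5..9], then [10, 11]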
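get_loss() now returns the mean squared error together with the residual mean collapsed to a single scalar, which update() feeds back through bp_model(). The same arithmetic in plain NumPy; note the usual per-example MSE gradient would be 2 * d_scores / n, so the single scalar is this patch's simplification:

    import numpy as np

    def get_loss(predictions, golds):
        # Mean squared error, plus the mean residual as a scalar 'gradient'.
        d_scores = predictions - golds
        return (d_scores ** 2).mean(), d_scores.mean()

    loss, gradient = get_loss(np.asarray([0.9, 0.2]), np.asarray([1.0, 0.0]))
    print(float(loss), float(gradient))  # 0.025 0.05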
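Finally, _predict() and update() score a candidate by concatenating its description encoding (DESC_WIDTH columns) with its article encoding (ARTICLE_WIDTH columns), and update() splits the model gradient back along the same boundary: the first DESC_WIDTH columns flow to the desc encoder, the rest to the article encoder. A NumPy sketch of that concat/split symmetry, with illustrative shapes:

    import numpy as np

    DESC_WIDTH, ARTICLE_WIDTH = 64, 64
    desc_encodings = np.random.rand(3, DESC_WIDTH)    # one row per candidate
    doc_encodings = np.random.rand(3, ARTICLE_WIDTH)  # matching article rows

    # forward: row-wise concatenation, as in concat_encodings above
    concat = np.hstack([desc_encodings, doc_encodings])
    print(concat.shape)  # (3, 128): DESC_WIDTH + ARTICLE_WIDTH

    # backward: slice a gradient of the same shape back into its two parts
    model_gradient = np.random.rand(3, DESC_WIDTH + ARTICLE_WIDTH)
    entity_gradients = model_gradient[:, :DESC_WIDTH]  # to the desc encoder
    doc_gradient = model_gradient[0][DESC_WIDTH:]      # row 0 only, as in update()
    print(entity_gradients.shape, doc_gradient.shape)  # (3, 64) (64,)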