diff --git a/examples/pipeline/wiki_entity_linking/train_el.py b/examples/pipeline/wiki_entity_linking/train_el.py index 7fd301e02..1e2c25ffc 100644 --- a/examples/pipeline/wiki_entity_linking/train_el.py +++ b/examples/pipeline/wiki_entity_linking/train_el.py @@ -46,11 +46,11 @@ class EL_Model(): dev_instances, dev_pos, dev_neg, dev_doc = self._get_training_data(training_dir, entity_descr_output, True, - limit, to_print) + limit / 10, to_print) if to_print: - print("Training on", len(train_instances), "instance clusters") - print("Dev test on", len(dev_instances), "instance clusters") + print("Training on", len(train_instances.values()), "articles") + print("Dev test on", len(dev_instances.values()), "articles") print() self.sgd_entity = self.begin_training(self.entity_encoder) @@ -60,49 +60,51 @@ class EL_Model(): losses = {} - for inst_cluster in train_instances: - pos_ex = train_pos.get(inst_cluster) - neg_exs = train_neg.get(inst_cluster, []) + instance_count = 0 + + for article_id, inst_cluster_set in train_instances.items(): + article_doc = train_doc[article_id] + pos_ex_list = list() + neg_exs_list = list() + for inst_cluster in inst_cluster_set: + instance_count += 1 + pos_ex_list.append(train_pos.get(inst_cluster)) + neg_exs_list.append(train_neg.get(inst_cluster, [])) + + self.update(article_doc, pos_ex_list, neg_exs_list, losses=losses) + p, r, fscore = self._test_dev(dev_instances, dev_pos, dev_neg, dev_doc) + print(round(fscore, 1)) + + if to_print: + print("Trained on", instance_count, "instance clusters") - if pos_ex and neg_exs: - article = inst_cluster.split(sep="_")[0] - entity_id = inst_cluster.split(sep="_")[1] - article_doc = train_doc[article] - self.update(article_doc, pos_ex, neg_exs, losses=losses) - p, r, fscore = self._test_dev(dev_instances, dev_pos, dev_neg, dev_doc) - print(round(fscore, 1)) - # TODO - # elif not pos_ex: - # print("Weird. Couldn't find pos example for", inst_cluster) - # elif not neg_exs: - # print("Weird. Couldn't find neg examples for", inst_cluster) def _test_dev(self, dev_instances, dev_pos, dev_neg, dev_doc): predictions = list() golds = list() - for inst_cluster in dev_instances: - pos_ex = dev_pos.get(inst_cluster) - neg_exs = dev_neg.get(inst_cluster, []) - ex_to_id = dict() + for article_id, inst_cluster_set in dev_instances.items(): + for inst_cluster in inst_cluster_set: + pos_ex = dev_pos.get(inst_cluster) + neg_exs = dev_neg.get(inst_cluster, []) + ex_to_id = dict() - if pos_ex and neg_exs: - ex_to_id[pos_ex] = pos_ex._.entity_id - for neg_ex in neg_exs: - ex_to_id[neg_ex] = neg_ex._.entity_id + if pos_ex and neg_exs: + ex_to_id[pos_ex] = pos_ex._.entity_id + for neg_ex in neg_exs: + ex_to_id[neg_ex] = neg_ex._.entity_id - article = inst_cluster.split(sep="_")[0] - entity_id = inst_cluster.split(sep="_")[1] - article_doc = dev_doc[article] + article = inst_cluster.split(sep="_")[0] + entity_id = inst_cluster.split(sep="_")[1] + article_doc = dev_doc[article] - examples = list(neg_exs) - examples.append(pos_ex) - shuffle(examples) - - best_entity, lowest_mse = self._predict(examples, article_doc) - predictions.append(ex_to_id[best_entity]) - golds.append(ex_to_id[pos_ex]) + examples = list(neg_exs) + examples.append(pos_ex) + shuffle(examples) + best_entity, lowest_mse = self._predict(examples, article_doc) + predictions.append(ex_to_id[best_entity]) + golds.append(ex_to_id[pos_ex]) # TODO: use lowest_mse and combine with prior probability p, r, F = run_el.evaluate(predictions, golds, to_print=False) @@ -161,60 +163,79 @@ class EL_Model(): sgd = create_default_optimizer(model.ops) return sgd - def update(self, article_doc, true_entity, false_entities, drop=0., losses=None): + def update(self, article_doc, true_entity_list, false_entities_list, drop=0., losses=None): + # TODO: one call only to begin_update ? + + entity_diffs = None + doc_diffs = None + doc_encoding, article_bp = self.article_encoder.begin_update([article_doc], drop=drop) - true_entity_encoding, true_entity_bp = self.entity_encoder.begin_update([true_entity], drop=drop) - # print("encoding dim", len(true_entity_encoding[0])) + for i, true_entity in enumerate(true_entity_list): + false_entities = false_entities_list[i] - consensus_encoding = self._calculate_consensus(doc_encoding, true_entity_encoding) - consensus_encoding_t = consensus_encoding.transpose() + true_entity_encoding, true_entity_bp = self.entity_encoder.begin_update([true_entity], drop=drop) + # print("encoding dim", len(true_entity_encoding[0])) - doc_mse, doc_diffs = self._calculate_similarity(doc_encoding, consensus_encoding) + consensus_encoding = self._calculate_consensus(doc_encoding, true_entity_encoding) + # consensus_encoding_t = consensus_encoding.transpose() - entity_mses = list() + doc_mse, doc_diff = self._calculate_similarity(doc_encoding, consensus_encoding) - true_mse, true_diffs = self._calculate_similarity(true_entity_encoding, consensus_encoding) - # print("true_mse", true_mse) - # print("true_diffs", true_diffs) - entity_mses.append(true_mse) - # true_exp = np.exp(true_entity_encoding.dot(consensus_encoding_t)) - # print("true_exp", true_exp) + entity_mses = list() - # false_exp_sum = 0 + true_mse, true_diffs = self._calculate_similarity(true_entity_encoding, consensus_encoding) + # print("true_mse", true_mse) + # print("true_diffs", true_diffs) + entity_mses.append(true_mse) + # true_exp = np.exp(true_entity_encoding.dot(consensus_encoding_t)) + # print("true_exp", true_exp) - for false_entity in false_entities: - false_entity_encoding, false_entity_bp = self.entity_encoder.begin_update([false_entity], drop=drop) - false_mse, false_diffs = self._calculate_similarity(false_entity_encoding, consensus_encoding) - # print("false_mse", false_mse) - # false_exp = np.exp(false_entity_encoding.dot(consensus_encoding_t)) - # print("false_exp", false_exp) - # print("false_diffs", false_diffs) - entity_mses.append(false_mse) - # if false_mse > true_mse: - # true_diffs = true_diffs - false_diffs ??? - # false_exp_sum += false_exp + # false_exp_sum = 0 - # prob = true_exp / false_exp_sum - # print("prob", prob) + if doc_diffs is not None: + doc_diffs += doc_diff + entity_diffs += true_diffs + else: + doc_diffs = doc_diff + entity_diffs = true_diffs - entity_mses = sorted(entity_mses) - # mse_sum = sum(entity_mses) - # entity_probs = [1 - x/mse_sum for x in entity_mses] - # print("entity_mses", entity_mses) - # print("entity_probs", entity_probs) - true_index = entity_mses.index(true_mse) - # print("true index", true_index) - # print("true prob", entity_probs[true_index]) + for false_entity in false_entities: + false_entity_encoding, false_entity_bp = self.entity_encoder.begin_update([false_entity], drop=drop) + false_mse, false_diffs = self._calculate_similarity(false_entity_encoding, consensus_encoding) + # print("false_mse", false_mse) + # false_exp = np.exp(false_entity_encoding.dot(consensus_encoding_t)) + # print("false_exp", false_exp) + # print("false_diffs", false_diffs) + entity_mses.append(false_mse) + # if false_mse > true_mse: + # true_diffs = true_diffs - false_diffs ??? + # false_exp_sum += false_exp - # print("training loss", true_mse) + # prob = true_exp / false_exp_sum + # print("prob", prob) - # print() + entity_mses = sorted(entity_mses) + # mse_sum = sum(entity_mses) + # entity_probs = [1 - x/mse_sum for x in entity_mses] + # print("entity_mses", entity_mses) + # print("entity_probs", entity_probs) + true_index = entity_mses.index(true_mse) + # print("true index", true_index) + # print("true prob", entity_probs[true_index]) + + # print("training loss", true_mse) + + # print() # TODO: proper backpropagation taking ranking of elements into account ? # TODO backpropagation also for negative examples - true_entity_bp(true_diffs, sgd=self.sgd_entity) - article_bp(doc_diffs, sgd=self.sgd_article) + + if doc_diffs is not None: + doc_diffs = doc_diffs / len(true_entity_list) + + true_entity_bp(entity_diffs, sgd=self.sgd_entity) + article_bp(doc_diffs, sgd=self.sgd_article) # TODO delete ? @@ -268,7 +289,7 @@ class EL_Model(): collect_incorrect=True) - instances = list() + instance_by_doc = dict() local_vectors = list() # TODO: local vectors doc_by_article = dict() pos_entities = dict() @@ -280,18 +301,19 @@ class EL_Model(): if dev == run_el.is_dev(f): article_id = f.replace(".txt", "") if cnt % 500 == 0 and to_print: - print(datetime.datetime.now(), "processed", cnt, "files in the dev dataset") + print(datetime.datetime.now(), "processed", cnt, "files in the training dataset") cnt += 1 if article_id not in doc_by_article: with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file: text = file.read() doc = self.nlp(text) doc_by_article[article_id] = doc + instance_by_doc[article_id] = set() for mention, entity_pos in correct_entries[article_id].items(): descr = id_to_descr.get(entity_pos) if descr: - instances.append(article_id + "_" + mention) + instance_by_doc[article_id].add(article_id + "_" + mention) doc_descr = self.nlp(descr) doc_descr._.entity_id = entity_pos pos_entities[article_id + "_" + mention] = doc_descr @@ -308,6 +330,6 @@ class EL_Model(): if to_print: print() - print("Processed", cnt, "dev articles") + print("Processed", cnt, "training articles, dev=" + str(dev)) print() - return instances, pos_entities, neg_entities, doc_by_article + return instance_by_doc, pos_entities, neg_entities, doc_by_article diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py index 83650aa8d..581d38b1b 100644 --- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py +++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py @@ -111,7 +111,7 @@ if __name__ == "__main__": print("STEP 6: training ", datetime.datetime.now()) my_nlp = spacy.load('en_core_web_md') trainer = EL_Model(kb=my_kb, nlp=my_nlp) - trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, limit=50) + trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, limit=500) print() # STEP 7: apply the EL algorithm on the dev dataset