diff --git a/examples/pipeline/wiki_entity_linking/train_el.py b/examples/pipeline/wiki_entity_linking/train_el.py index 8dcea9256..c91058d5f 100644 --- a/examples/pipeline/wiki_entity_linking/train_el.py +++ b/examples/pipeline/wiki_entity_linking/train_el.py @@ -7,7 +7,7 @@ from os import listdir from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator -from spacy._ml import SpacyVectors, create_default_optimizer, zero_init +from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, cosine from thinc.api import chain from thinc.v2v import Model, Maxout, Softmax, Affine, ReLu @@ -33,14 +33,12 @@ class EL_Model(): self.article_encoder = self._simple_encoder(width=300) def train_model(self, training_dir, entity_descr_output, limit=None, to_print=True): - instances, gold_vectors, entity_descriptions, doc_by_article = self._get_training_data(training_dir, + instances, pos_entities, neg_entities, doc_by_article = self._get_training_data(training_dir, entity_descr_output, limit, to_print) if to_print: - print("Training on", len(gold_vectors), "instances") - print(" - pos:", len([x for x in gold_vectors if x]), "instances") - print(" - pos:", len([x for x in gold_vectors if not x]), "instances") + print("Training on", len(instances), "instance clusters") print() self.sgd_entity = self.begin_training(self.entity_encoder) @@ -48,11 +46,20 @@ class EL_Model(): losses = {} - for inst, label, entity_descr in zip(instances, gold_vectors, entity_descriptions): - article = inst.split(sep="_")[0] - entity_id = inst.split(sep="_")[1] - article_doc = doc_by_article[article] - self.update(article_doc, entity_descr, label, losses=losses) + for inst_cluster in instances: + pos_ex = pos_entities.get(inst_cluster) + neg_exs = neg_entities.get(inst_cluster, []) + + if pos_ex and neg_exs: + article = inst_cluster.split(sep="_")[0] + entity_id = inst_cluster.split(sep="_")[1] + article_doc = doc_by_article[article] + self.update(article_doc, pos_ex, neg_exs, losses=losses) + # TODO + # elif not pos_ex: + # print("Weird. Couldn't find pos example for", inst_cluster) + # elif not neg_exs: + # print("Weird. Couldn't find neg examples for", inst_cluster) def _simple_encoder(self, width): with Model.define_operators({">>": chain}): @@ -69,22 +76,29 @@ class EL_Model(): sgd = create_default_optimizer(model.ops) return sgd - def update(self, article_doc, entity_descr, label, drop=0., losses=None): - entity_encoding, entity_bp = self.entity_encoder.begin_update([entity_descr], drop=drop) + def update(self, article_doc, true_entity, false_entities, drop=0., losses=None): doc_encoding, article_bp = self.article_encoder.begin_update([article_doc], drop=drop) + true_entity_encoding, true_entity_bp = self.entity_encoder.begin_update([true_entity], drop=drop) + # true_similarity = cosine(true_entity_encoding, doc_encoding) + # print("true_similarity", true_similarity) + + # for false_entity in false_entities: + # false_entity_encoding, false_entity_bp = self.entity_encoder.begin_update([false_entity], drop=drop) + # false_similarity = cosine(false_entity_encoding, doc_encoding) + # print("false_similarity", false_similarity) + # print("entity/article output dim", len(entity_encoding[0]), len(doc_encoding[0])) - mse, diffs = self._calculate_similarity(entity_encoding, doc_encoding) + mse, diffs = self._calculate_similarity(true_entity_encoding, doc_encoding) # print() # TODO: proper backpropagation taking ranking of elements into account ? # TODO backpropagation also for negative examples - if label: - entity_bp(diffs, sgd=self.sgd_entity) - article_bp(diffs, sgd=self.sgd_article) - print(mse) + true_entity_bp(diffs, sgd=self.sgd_entity) + article_bp(diffs, sgd=self.sgd_article) + print(mse) # TODO delete ? @@ -115,7 +129,7 @@ class EL_Model(): raise ValueError("To calculate similarity, both vectors should be of equal length") diffs = (vector2 - vector1) - error_sum = (diffs ** 2).sum(axis=1) + error_sum = (diffs ** 2).sum() mean_square_error = error_sum / len(vector1) return float(mean_square_error), diffs @@ -130,10 +144,10 @@ class EL_Model(): collect_incorrect=True) instances = list() - entity_descriptions = list() local_vectors = list() # TODO: local vectors - gold_vectors = list() doc_by_article = dict() + pos_entities = dict() + neg_entities = dict() cnt = 0 for f in listdir(training_dir): @@ -149,25 +163,24 @@ class EL_Model(): doc = self.nlp(text) doc_by_article[article_id] = doc - for mention_pos, entity_pos in correct_entries[article_id].items(): + for mention, entity_pos in correct_entries[article_id].items(): descr = id_to_descr.get(entity_pos) if descr: - instances.append(article_id + "_" + entity_pos) - doc = self.nlp(descr) - entity_descriptions.append(doc) - gold_vectors.append(True) + instances.append(article_id + "_" + mention) + doc_descr = self.nlp(descr) + pos_entities[article_id + "_" + mention] = doc_descr - for mention_neg, entity_negs in incorrect_entries[article_id].items(): + for mention, entity_negs in incorrect_entries[article_id].items(): for entity_neg in entity_negs: descr = id_to_descr.get(entity_neg) if descr: - instances.append(article_id + "_" + entity_neg) - doc = self.nlp(descr) - entity_descriptions.append(doc) - gold_vectors.append(False) + doc_descr = self.nlp(descr) + descr_list = neg_entities.get(article_id + "_" + mention, []) + descr_list.append(doc_descr) + neg_entities[article_id + "_" + mention] = descr_list if to_print: print() print("Processed", cnt, "dev articles") print() - return instances, gold_vectors, entity_descriptions, doc_by_article + return instances, pos_entities, neg_entities, doc_by_article