diff --git a/examples/pipeline/wiki_entity_linking/train_descriptions.py b/examples/pipeline/wiki_entity_linking/train_descriptions.py
new file mode 100644
index 000000000..63149b5f7
--- /dev/null
+++ b/examples/pipeline/wiki_entity_linking/train_descriptions.py
@@ -0,0 +1,113 @@
+from random import shuffle
+
+from examples.pipeline.wiki_entity_linking import kb_creator
+
+import numpy as np
+
+from spacy._ml import zero_init, create_default_optimizer
+from spacy.cli.pretrain import get_cossim_loss
+
+from thinc.v2v import Model
+from thinc.api import chain
+from thinc.neural._classes.affine import Affine
+
+
+class EntityEncoder:
+
+    INPUT_DIM = 300  # dimension of pre-trained vectors
+    DESC_WIDTH = 64
+
+    DROP = 0
+    EPOCHS = 5
+    STOP_THRESHOLD = 0.05
+
+    BATCH_SIZE = 1000
+
+    def __init__(self, kb, nlp):
+        self.nlp = nlp
+        self.kb = kb
+
+    def run(self, entity_descr_output):
+        id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
+
+        processed, loss = self._train_model(entity_descr_output, id_to_descr)
+        print("Trained on", processed, "entities across", self.EPOCHS, "epochs")
+        print("Final loss:", loss)
+        print()
+
+        # TODO: apply and write to file afterwards !
+        # self._apply_encoder(id_to_descr)
+
+    def _train_model(self, entity_descr_output, id_to_descr):
+        # TODO: when loss gets too low, a 'mean of empty slice' warning is thrown by numpy
+
+        self._build_network(self.INPUT_DIM, self.DESC_WIDTH)
+
+        processed = 0
+        loss = 1
+
+        for i in range(self.EPOCHS):
+            entity_keys = list(id_to_descr.keys())
+            shuffle(entity_keys)
+
+            batch_nr = 0
+            start = 0
+            stop = min(self.BATCH_SIZE, len(entity_keys))
+
+            while loss > self.STOP_THRESHOLD and start < len(entity_keys):
+                batch = []
+                for e in entity_keys[start:stop]:
+                    descr = id_to_descr[e]
+                    doc = self.nlp(descr)
+                    doc_vector = self._get_doc_embedding(doc)
+                    batch.append(doc_vector)
+
+                loss = self.update(batch)
+                print(i, batch_nr, loss)
+                processed += len(batch)
+
+                batch_nr += 1
+                start = start + self.BATCH_SIZE
+                stop = min(stop + self.BATCH_SIZE, len(entity_keys))
+
+        return processed, loss
+
+    def _apply_encoder(self, id_to_descr):
+        for id, descr in id_to_descr.items():
+            doc = self.nlp(descr)
+            doc_vector = self._get_doc_embedding(doc)
+            encoding = self.encoder(np.asarray([doc_vector]))
+
+    @staticmethod
+    def _get_doc_embedding(doc):
+        indices = np.zeros((len(doc),), dtype="i")
+        for i, word in enumerate(doc):
+            if word.orth in doc.vocab.vectors.key2row:
+                indices[i] = doc.vocab.vectors.key2row[word.orth]
+            else:
+                indices[i] = 0
+        word_vectors = doc.vocab.vectors.data[indices]
+        doc_vector = np.mean(word_vectors, axis=0)  # TODO: min? max?
+        return doc_vector
+
+    def _build_network(self, orig_width, hidden_with):
+        with Model.define_operators({">>": chain}):
+            self.encoder = (
+                Affine(hidden_with, orig_width)
+            )
+            self.model = self.encoder >> zero_init(Affine(orig_width, hidden_with, drop_factor=0.0))
+
+        self.sgd = create_default_optimizer(self.model.ops)
+
+    def update(self, vectors):
+        predictions, bp_model = self.model.begin_update(np.asarray(vectors), drop=self.DROP)
+
+        loss, d_scores = self.get_loss(scores=predictions, golds=np.asarray(vectors))
+        bp_model(d_scores, sgd=self.sgd)
+
+        return loss / len(vectors)
+
+    @staticmethod
+    def get_loss(golds, scores):
+        loss, gradients = get_cossim_loss(scores, golds)
+        return loss, gradients
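For context: `_build_network` above wires the 64-dimensional encoder into a small reconstruction model (`encoder >> Affine(300, 64)`), and `update` trains it by comparing the reconstructed vector against the original mean-of-word-vectors description embedding with a cosine loss (`get_cossim_loss`). Below is a minimal numpy-only sketch of those two ingredients; the function names are hypothetical and it is not part of the diff.

```python
import numpy as np

def mean_doc_vector(word_vectors):
    # mirrors _get_doc_embedding: average the (n_words, 300) rows into one vector
    return np.mean(np.asarray(word_vectors, dtype="float64"), axis=0)

def cosine_reconstruction_loss(prediction, target, eps=1e-8):
    # 1 - cosine similarity: approaches 0 when the reconstruction points in the
    # same direction as the original description vector
    num = np.dot(prediction, target)
    denom = np.linalg.norm(prediction) * np.linalg.norm(target) + eps
    return 1.0 - num / denom
```

Driving this loss towards 0 is what the `while loss > self.STOP_THRESHOLD` check in `_train_model` relies on.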
diff --git a/examples/pipeline/wiki_entity_linking/train_el.py b/examples/pipeline/wiki_entity_linking/train_el.py
index b9a0dc843..143e38d99 100644
--- a/examples/pipeline/wiki_entity_linking/train_el.py
+++ b/examples/pipeline/wiki_entity_linking/train_el.py
@@ -31,7 +31,7 @@ class EL_Model:
     PRINT_BATCH_LOSS = False
     EPS = 0.0000000005
 
-    BATCH_SIZE = 5
+    BATCH_SIZE = 100
 
     DOC_CUTOFF = 300  # number of characters from the doc context
     INPUT_DIM = 300  # dimension of pre-trained vectors
@@ -41,9 +41,9 @@ class EL_Model:
     ARTICLE_WIDTH = 128
     SENT_WIDTH = 64
 
-    DROP = 0.1
-    LEARN_RATE = 0.001
-    EPOCHS = 5
+    DROP = 0.4
+    LEARN_RATE = 0.005
+    EPOCHS = 10
     L2 = 1e-6
 
     name = "entity_linker"
@@ -62,12 +62,14 @@ class EL_Model:
     def train_model(self, training_dir, entity_descr_output, trainlimit=None, devlimit=None, to_print=True):
         np.seterr(divide="raise", over="warn", under="ignore", invalid="raise")
 
+        id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
+
         train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts = \
-            self._get_training_data(training_dir, entity_descr_output, False, trainlimit, to_print=False)
+            self._get_training_data(training_dir, id_to_descr, False, trainlimit, to_print=False)
         train_clusters = list(train_ent.keys())
 
         dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts = \
-            self._get_training_data(training_dir, entity_descr_output, True, devlimit, to_print=False)
+            self._get_training_data(training_dir, id_to_descr, True, devlimit, to_print=False)
         dev_clusters = list(dev_ent.keys())
 
         dev_pos_count = len([g for g in dev_gold.values() if g])
@@ -386,9 +388,7 @@ class EL_Model:
             bp_doc(doc_gradients, sgd=self.sgd_article)
             bp_sent(sent_gradients, sgd=self.sgd_sent)
 
-    def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print):
-        id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
-
+    def _get_training_data(self, training_dir, id_to_descr, dev, limit, to_print):
         correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir,
                                                                                          collect_correct=True,
                                                                                          collect_incorrect=True)
diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
index 40d737a6f..1f4b4b67e 100644
--- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
+++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
@@ -2,6 +2,7 @@ from __future__ import unicode_literals
 
 from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el
+from examples.pipeline.wiki_entity_linking.train_descriptions import EntityEncoder
 from examples.pipeline.wiki_entity_linking.train_el import EL_Model
 
 import spacy
@@ -38,11 +39,14 @@ if __name__ == "__main__":
     to_read_kb = True
     to_test_kb = False
 
+    # run entity description pre-training
+    run_desc_training = True
+
     # create training dataset
     create_wp_training = False
 
-    # run training
-    run_training = True
+    # run EL training
+    run_el_training = False
 
     # apply named entity linking to the dev dataset
     apply_to_dev = False
@@ -101,17 +105,25 @@ if __name__ == "__main__":
            run_el.run_el_toy_example(kb=my_kb, nlp=my_nlp)
            print()
 
+    # STEP 4b : read KB back in from file, create entity descriptions
+    # TODO: write back to file
+    if run_desc_training:
+        print("STEP 4b: training entity descriptions", datetime.datetime.now())
+        my_nlp = spacy.load('en_core_web_md')
+        EntityEncoder(my_kb, my_nlp).run(entity_descr_output=ENTITY_DESCR)
+        print()
+
     # STEP 5: create a training dataset from WP
     if create_wp_training:
         print("STEP 5: create training dataset", datetime.datetime.now())
         training_set_creator.create_training(kb=my_kb, entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
 
     # STEP 6: apply the EL algorithm on the training dataset
-    if run_training:
+    if run_el_training:
         print("STEP 6: training", datetime.datetime.now())
         my_nlp = spacy.load('en_core_web_md')
         trainer = EL_Model(kb=my_kb, nlp=my_nlp)
-        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=50, devlimit=20)
+        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=10000, devlimit=500)
         print()
 
     # STEP 7: apply the EL algorithm on the dev dataset
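Both trainers above walk the data in fixed-size slices over a shuffled list of entity keys (`BATCH_SIZE = 1000` for the description encoder, now 100 for the EL model). A generic sketch of that batching pattern, assuming nothing beyond the standard library and not part of the diff:

```python
from random import shuffle

def minibatch_keys(keys, batch_size):
    """Yield successive slices of at most batch_size items from a shuffled copy."""
    keys = list(keys)
    shuffle(keys)
    for start in range(0, len(keys), batch_size):
        yield keys[start:start + batch_size]

# e.g. one epoch over the description dict:
# for batch in minibatch_keys(id_to_descr, 100):
#     ...  # encode each description in the batch and run one update
```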
"__main__": to_read_kb = True to_test_kb = False + # run entity description pre-training + run_desc_training = True + # create training dataset create_wp_training = False - # run training - run_training = True + # run EL training + run_el_training = False # apply named entity linking to the dev dataset apply_to_dev = False @@ -101,17 +105,25 @@ if __name__ == "__main__": run_el.run_el_toy_example(kb=my_kb, nlp=my_nlp) print() + # STEP 4b : read KB back in from file, create entity descriptions + # TODO: write back to file + if run_desc_training: + print("STEP 4b: training entity descriptions", datetime.datetime.now()) + my_nlp = spacy.load('en_core_web_md') + EntityEncoder(my_kb, my_nlp).run(entity_descr_output=ENTITY_DESCR) + print() + # STEP 5: create a training dataset from WP if create_wp_training: print("STEP 5: create training dataset", datetime.datetime.now()) training_set_creator.create_training(kb=my_kb, entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR) # STEP 6: apply the EL algorithm on the training dataset - if run_training: + if run_el_training: print("STEP 6: training", datetime.datetime.now()) my_nlp = spacy.load('en_core_web_md') trainer = EL_Model(kb=my_kb, nlp=my_nlp) - trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=50, devlimit=20) + trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=10000, devlimit=500) print() # STEP 7: apply the EL algorithm on the dev dataset diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index c8afd431e..d0c83b56e 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1177,6 +1177,8 @@ class EntityLinker(Pipe): def predict(self, docs): self.require_model() + final_entities = list() + final_kb_ids = list() for i, article_doc in enumerate(docs): doc_encoding = self.article_encoder([article_doc]) for ent in article_doc.ents: @@ -1188,23 +1190,27 @@ class EntityLinker(Pipe): candidates = self.kb.get_candidates(ent.text) if candidates: - highest_sim = -5 - best_i = -1 with self.use_avg_params: + scores = list() for c in candidates: + prior_prob = c.prior_prob kb_id = c.entity_ description = self.id_to_descr.get(kb_id) entity_encodings = self.entity_encoder([description]) # TODO: static entity vectors ? sim = cosine(entity_encodings, mention_enc_t) - if sim >= highest_sim: - best_i = i - highest_sim = sim + score = prior_prob + sim - (prior_prob*sim) # TODO: weights ? + scores.append(score) - # TODO best_candidate = max(candidates, key=lambda c: c.prior_prob) + best_index = scores.index(max(scores)) + best_candidate = candidates[best_index] + final_entities.append(ent) + final_kb_ids.append(best_candidate) + + return final_entities, final_kb_ids def set_annotations(self, docs, entities, kb_ids=None): - for token, kb_id in zip(entities, kb_ids): - token.ent_kb_id_ = kb_id + for entity, kb_id in zip(entities, kb_ids): + entity.ent_kb_id_ = kb_id class Sentencizer(object): """Segment the Doc into sentences using a rule-based strategy.