From 0486ccabfdbfd6ee4531574ad18b5dde085b43be Mon Sep 17 00:00:00 2001
From: svlandeg
Date: Fri, 7 Jun 2019 13:54:45 +0200
Subject: [PATCH] introduce goldparse.links

---
 .../training_set_creator.py                  | 14 ++--
 .../wiki_entity_linking/wiki_nel_pipeline.py | 34 +++++---
 spacy/gold.pxd                               |  1 +
 spacy/gold.pyx                               |  5 +-
 spacy/pipeline/pipes.pyx                     | 81 +++++++++++--------
 5 files changed, 82 insertions(+), 53 deletions(-)

diff --git a/examples/pipeline/wiki_entity_linking/training_set_creator.py b/examples/pipeline/wiki_entity_linking/training_set_creator.py
index c1879e2fb..156bce05f 100644
--- a/examples/pipeline/wiki_entity_linking/training_set_creator.py
+++ b/examples/pipeline/wiki_entity_linking/training_set_creator.py
@@ -303,8 +303,7 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
                                                              collect_correct=True,
                                                              collect_incorrect=True)
 
-    docs = list()
-    golds = list()
+    data = []
 
     cnt = 0
     next_entity_nr = 1
@@ -323,7 +322,7 @@
                         article_doc = nlp(text)
                         truncated_text = text[0:min(doc_cutoff, len(text))]
 
-                        gold_entities = dict()
+                        gold_entities = list()
 
                         # process all positive and negative entities, collect all relevant mentions in this article
                         for mention, entity_pos in correct_entries[article_id].items():
@@ -337,11 +336,10 @@
 
                         # store gold entities
                         for match_id, start, end in matches:
-                            gold_entities[(start, end, entity_pos)] = 1.0
+                            gold_entities.append((start, end, entity_pos))
 
-                        gold = GoldParse(doc=article_doc, cats=gold_entities)
-                        docs.append(article_doc)
-                        golds.append(gold)
+                        gold = GoldParse(doc=article_doc, links=gold_entities)
+                        data.append((article_doc, gold))
 
                         cnt += 1
                 except Exception as e:
@@ -352,7 +350,7 @@
     print()
     print("Processed", cnt, "training articles, dev=" + str(dev))
    print()
-    return docs, golds
+    return data
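Note on the hunks above: read_training previously smuggled the entity annotations through GoldParse.cats as a dict of (start, end, kb_id) keys; it now collects them as plain tuples in gold.links and returns a single list of (doc, gold) pairs. A minimal sketch of that data shape, with an invented sentence and KB ID; the offsets are token indices, matching the Matcher output this script works with:

    # Sketch of the (doc, gold) pairs read_training now returns. Text and
    # KB ID "Q7259" are illustrative only; (0, 2) are token offsets, as
    # produced by the Matcher in the code above.
    import spacy
    from spacy.gold import GoldParse

    nlp = spacy.blank("en")
    doc = nlp("Ada Lovelace wrote the first program.")

    # tokens [0, 2) ("Ada Lovelace") link to the knowledge-base ID Q7259
    gold = GoldParse(doc=doc, links=[(0, 2, "Q7259")])
    data = [(doc, gold)]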
diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
index 08f4adda0..b66f8b316 100644
--- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
+++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
@@ -1,6 +1,10 @@
 # coding: utf-8
 from __future__ import unicode_literals
 
+import random
+
+from spacy.util import minibatch, compounding
+
 from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el
 from examples.pipeline.wiki_entity_linking.train_el import EL_Model
 
@@ -23,9 +27,11 @@
 VOCAB_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/vocab'
 TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
 
-MAX_CANDIDATES=10
-MIN_PAIR_OCC=5
-DOC_CHAR_CUTOFF=300
+MAX_CANDIDATES = 10
+MIN_PAIR_OCC = 5
+DOC_CHAR_CUTOFF = 300
+EPOCHS = 5
+DROPOUT = 0.1
 
 if __name__ == "__main__":
     print("START", datetime.datetime.now())
@@ -115,7 +121,7 @@ if __name__ == "__main__":
 
     if train_pipe:
         id_to_descr = kb_creator._get_id_to_description(ENTITY_DESCR)
-        docs, golds = training_set_creator.read_training(nlp=nlp,
+        train_data = training_set_creator.read_training(nlp=nlp,
                                                          training_dir=TRAINING_DIR,
                                                          id_to_descr=id_to_descr,
                                                          doc_cutoff=DOC_CHAR_CUTOFF,
@@ -123,12 +129,6 @@ if __name__ == "__main__":
                                                          limit=10,
                                                          to_print=False)
 
-        # for doc, gold in zip(docs, golds):
-        #     print("doc", doc)
-        #     for entity, label in gold.cats.items():
-        #        print("entity", entity, label)
-        #     print()
-
         el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb})
         nlp.add_pipe(el_pipe, last=True)
 
@@ -136,6 +136,20 @@ if __name__ == "__main__":
 
         with nlp.disable_pipes(*other_pipes):  # only train Entity Linking
             nlp.begin_training()
+            for itn in range(EPOCHS):
+                random.shuffle(train_data)
+                losses = {}
+                batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
+                for batch in batches:
+                    docs, golds = zip(*batch)
+                    nlp.update(
+                        docs,
+                        golds,
+                        drop=DROPOUT,
+                        losses=losses,
+                    )
+                print("Losses", losses)
+
     ### BELOW CODE IS DEPRECATED ###
 
     # STEP 6: apply the EL algorithm on the training dataset - TODO deprecated - code moved to pipes.pyx
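The new training loop above is the standard spaCy pattern: shuffle once per epoch, draw minibatches on a compounding size schedule, and call nlp.update with dropout, accumulating per-pipe losses. A standalone look at the batching schedule, with placeholder items standing in for the (doc, gold) pairs of train_data:

    # compounding(4.0, 32.0, 1.001) starts at 4 and multiplies by 1.001
    # after each draw, so batch sizes creep upward toward 32.
    from spacy.util import minibatch, compounding

    train_data = list(range(200))  # stand-in for (article_doc, gold) tuples
    sizes = compounding(4.0, 32.0, 1.001)
    for i, batch in enumerate(minibatch(train_data, size=sizes)):
        print("batch", i, "->", len(batch), "items")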
""" if words is None: @@ -485,6 +487,7 @@ cdef class GoldParse: self.c.ner = self.mem.alloc(len(doc), sizeof(Transition)) self.cats = {} if cats is None else dict(cats) + self.links = links self.words = [None] * len(doc) self.tags = [None] * len(doc) self.heads = [None] * len(doc) diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index a3caae455..f15ffd036 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1115,48 +1115,61 @@ class EntityLinker(Pipe): self.sgd_mention = create_default_optimizer(self.mention_encoder.ops) def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None): - """ docs should be a tuple of (entity_docs, article_docs, sentence_docs) TODO """ self.require_model() if len(docs) != len(golds): - raise ValueError(Errors.E077.format(value="loss", n_docs=len(docs), + raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs), n_golds=len(golds))) - entity_docs, article_docs, sentence_docs = docs - assert len(entity_docs) == len(article_docs) == len(sentence_docs) + if isinstance(docs, Doc): + docs = [docs] + golds = [golds] - if isinstance(entity_docs, Doc): - entity_docs = [entity_docs] - article_docs = [article_docs] - sentence_docs = [sentence_docs] + for doc, gold in zip(docs, golds): + print("doc", doc) + for entity in gold.links: + start, end, gold_kb = entity + print("entity", entity) + mention = doc[start:end].text + print("mention", mention) + candidates = self.kb.get_candidates(mention) + for c in candidates: + prior_prob = c.prior_prob + kb_id = c.entity_ + print("candidate", kb_id, prior_prob) + entity_encoding = c.entity_vector + print() - entity_encodings = None #TODO - doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop) - sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop) + print() - concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in - range(len(article_docs))] - mention_encodings, bp_cont = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP) - - loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None) - - mention_gradient = bp_cont(d_scores, sgd=self.sgd_cont) - - # gradient : concat (doc+sent) vs. desc - sent_start = self.article_encoder.nO - sent_gradients = list() - doc_gradients = list() - for x in mention_gradient: - doc_gradients.append(list(x[0:sent_start])) - sent_gradients.append(list(x[sent_start:])) - - bp_doc(doc_gradients, sgd=self.sgd_article) - bp_sent(sent_gradients, sgd=self.sgd_sent) - - if losses is not None: - losses.setdefault(self.name, 0.0) - losses[self.name] += loss - return loss + # entity_encodings = None #TODO + # doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop) + # sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop) + # + # concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in + # range(len(article_docs))] + # mention_encodings, bp_cont = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP) + # + # loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None) + # + # mention_gradient = bp_cont(d_scores, sgd=self.sgd_cont) + # + # # gradient : concat (doc+sent) vs. 
diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
index a3caae455..f15ffd036 100644
--- a/spacy/pipeline/pipes.pyx
+++ b/spacy/pipeline/pipes.pyx
@@ -1115,48 +1115,61 @@ class EntityLinker(Pipe):
         self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
 
     def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
-        """ docs should be a tuple of (entity_docs, article_docs, sentence_docs) TODO """
         self.require_model()
 
         if len(docs) != len(golds):
-            raise ValueError(Errors.E077.format(value="loss", n_docs=len(docs),
+            raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs),
                                                 n_golds=len(golds)))
 
-        entity_docs, article_docs, sentence_docs = docs
-        assert len(entity_docs) == len(article_docs) == len(sentence_docs)
+        if isinstance(docs, Doc):
+            docs = [docs]
+            golds = [golds]
 
-        if isinstance(entity_docs, Doc):
-            entity_docs = [entity_docs]
-            article_docs = [article_docs]
-            sentence_docs = [sentence_docs]
+        for doc, gold in zip(docs, golds):
+            print("doc", doc)
+            for entity in gold.links:
+                start, end, gold_kb = entity
+                print("entity", entity)
+                mention = doc[start:end].text
+                print("mention", mention)
+                candidates = self.kb.get_candidates(mention)
+                for c in candidates:
+                    prior_prob = c.prior_prob
+                    kb_id = c.entity_
+                    print("candidate", kb_id, prior_prob)
+                    entity_encoding = c.entity_vector
+                print()
 
-        entity_encodings = None #TODO
-        doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
-        sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
+            print()
 
-        concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
-                            range(len(article_docs))]
-        mention_encodings, bp_cont = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
-
-        loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
-
-        mention_gradient = bp_cont(d_scores, sgd=self.sgd_cont)
-
-        # gradient : concat (doc+sent) vs. desc
-        sent_start = self.article_encoder.nO
-        sent_gradients = list()
-        doc_gradients = list()
-        for x in mention_gradient:
-            doc_gradients.append(list(x[0:sent_start]))
-            sent_gradients.append(list(x[sent_start:]))
-
-        bp_doc(doc_gradients, sgd=self.sgd_article)
-        bp_sent(sent_gradients, sgd=self.sgd_sent)
-
-        if losses is not None:
-            losses.setdefault(self.name, 0.0)
-            losses[self.name] += loss
-        return loss
+        # entity_encodings = None #TODO
+        # doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
+        # sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
+        #
+        # concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
+        #                     range(len(article_docs))]
+        # mention_encodings, bp_cont = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
+        #
+        # loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
+        #
+        # mention_gradient = bp_cont(d_scores, sgd=self.sgd_cont)
+        #
+        # # gradient : concat (doc+sent) vs. desc
+        # sent_start = self.article_encoder.nO
+        # sent_gradients = list()
+        # doc_gradients = list()
+        # for x in mention_gradient:
+        #     doc_gradients.append(list(x[0:sent_start]))
+        #     sent_gradients.append(list(x[sent_start:]))
+        #
+        # bp_doc(doc_gradients, sgd=self.sgd_article)
+        # bp_sent(sent_gradients, sgd=self.sgd_sent)
+        #
+        # if losses is not None:
+        #     losses.setdefault(self.name, 0.0)
+        #     losses[self.name] += loss
+        # return loss
+        return None
 
     def get_loss(self, docs, golds, scores):
         loss, gradients = get_cossim_loss(scores, golds)
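The rewritten update() above walks gold.links, pulls each mention out of the Doc and asks the knowledge base for candidates, each carrying a prior probability, a KB ID and an entity vector; the encoder and cosine-loss plumbing is parked in the commented block until the new inputs are wired through. A sketch of that lookup against a toy KB; the KnowledgeBase constructor and add_entity signature were still settling at this point in development, so the shapes below (as they later stabilized in spaCy 2.2) are an assumption:

    # Toy knowledge base exercising the kb.get_candidates() path used in
    # update(). Entity, alias, frequency and vector values are invented;
    # constructor/add_entity shapes are assumed, not taken from this patch.
    import spacy
    from spacy.kb import KnowledgeBase

    nlp = spacy.blank("en")
    kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)
    kb.add_entity(entity="Q42", freq=12, entity_vector=[1.0, 0.0, 0.0])
    kb.add_alias(alias="Douglas Adams", entities=["Q42"], probabilities=[0.9])

    # mirrors the candidate loop in update(): prior probability, KB ID,
    # pretrained entity vector
    for c in kb.get_candidates("Douglas Adams"):
        print("candidate", c.entity_, c.prior_prob, c.entity_vector)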