introduce goldparse.links

svlandeg 2019-06-07 13:54:45 +02:00
parent a5c061f506
commit 0486ccabfd
5 changed files with 82 additions and 53 deletions

View File

@@ -303,8 +303,7 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
                                             collect_correct=True,
                                             collect_incorrect=True)

-    docs = list()
-    golds = list()
+    data = []

     cnt = 0
     next_entity_nr = 1
@@ -323,7 +322,7 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
             article_doc = nlp(text)
             truncated_text = text[0:min(doc_cutoff, len(text))]

-            gold_entities = dict()
+            gold_entities = list()

             # process all positive and negative entities, collect all relevant mentions in this article
             for mention, entity_pos in correct_entries[article_id].items():
@@ -337,11 +336,10 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
                     # store gold entities
                     for match_id, start, end in matches:
-                        gold_entities[(start, end, entity_pos)] = 1.0
+                        gold_entities.append((start, end, entity_pos))

-                gold = GoldParse(doc=article_doc, cats=gold_entities)
-                docs.append(article_doc)
-                golds.append(gold)
+                gold = GoldParse(doc=article_doc, links=gold_entities)
+                data.append((article_doc, gold))

                 cnt += 1
         except Exception as e:
@@ -352,7 +350,7 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
         print()
         print("Processed", cnt, "training articles, dev=" + str(dev))
         print()
-    return docs, golds
+    return data
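
After these hunks, `read_training` returns one list of `(doc, gold)` tuples instead of two parallel lists, and each gold's entity annotations travel in `links` rather than `cats`. A minimal sketch of the new shape, assuming a spaCy build from this branch; the sentence and the KB id are made up for illustration:

    from spacy.gold import GoldParse
    from spacy.lang.en import English

    nlp = English()
    doc = nlp("Douglas Adams wrote the book.")

    # read_training now emits one (doc, gold) tuple per article, and
    # gold.links holds one (start, end, kb_id) triple per gold mention
    gold = GoldParse(doc=doc, links=[(0, 2, "Q42")])
    train_data = [(doc, gold)]

    for doc, gold in train_data:
        for link in gold.links:
            print("link:", link)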

View File

@@ -1,6 +1,10 @@
 # coding: utf-8
 from __future__ import unicode_literals

+import random
+from spacy.util import minibatch, compounding
+
 from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el
 from examples.pipeline.wiki_entity_linking.train_el import EL_Model
@ -23,9 +27,11 @@ VOCAB_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/vocab'
TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/' TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
MAX_CANDIDATES=10 MAX_CANDIDATES = 10
MIN_PAIR_OCC=5 MIN_PAIR_OCC = 5
DOC_CHAR_CUTOFF=300 DOC_CHAR_CUTOFF = 300
EPOCHS = 5
DROPOUT = 0.1
if __name__ == "__main__": if __name__ == "__main__":
print("START", datetime.datetime.now()) print("START", datetime.datetime.now())
@@ -115,7 +121,7 @@ if __name__ == "__main__":
     if train_pipe:
         id_to_descr = kb_creator._get_id_to_description(ENTITY_DESCR)

-        docs, golds = training_set_creator.read_training(nlp=nlp,
+        train_data = training_set_creator.read_training(nlp=nlp,
                                                          training_dir=TRAINING_DIR,
                                                          id_to_descr=id_to_descr,
                                                          doc_cutoff=DOC_CHAR_CUTOFF,
@@ -123,12 +129,6 @@ if __name__ == "__main__":
                                                          limit=10,
                                                          to_print=False)

-        # for doc, gold in zip(docs, golds):
-        #     print("doc", doc)
-        #     for entity, label in gold.cats.items():
-        #         print("entity", entity, label)
-        #     print()

         el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb})
         nlp.add_pipe(el_pipe, last=True)
@@ -136,6 +136,20 @@ if __name__ == "__main__":
         with nlp.disable_pipes(*other_pipes):  # only train Entity Linking
             nlp.begin_training()

+            for itn in range(EPOCHS):
+                random.shuffle(train_data)
+                losses = {}
+                batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
+                for batch in batches:
+                    docs, golds = zip(*batch)
+                    nlp.update(
+                        docs,
+                        golds,
+                        drop=DROPOUT,
+                        losses=losses,
+                    )
+                print("Losses", losses)
+
     ### BELOW CODE IS DEPRECATED ###

     # STEP 6: apply the EL algorithm on the training dataset - TODO deprecated - code moved to pipes.pyx
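
The added loop is spaCy's standard training recipe: shuffle the data each epoch, batch it with a compounding batch size, and feed docs and golds to `nlp.update`. The batch-size schedule can be inspected in isolation with the same `spacy.util` helpers used above (released API, nothing branch-specific):

    from itertools import islice
    from spacy.util import compounding, minibatch

    # compounding(4.0, 32.0, 1.001) starts at 4.0 and multiplies by
    # 1.001 each draw, growing toward (and capping at) 32.0.
    print(list(islice(compounding(4.0, 32.0, 1.001), 4)))
    # -> [4.0, 4.004, 4.008004, 4.012012004]

    # minibatch() draws one size per batch, so early batches are small
    # (fast, noisy updates) and later ones grow larger (stabler updates).
    for batch in minibatch(range(20), size=compounding(4.0, 32.0, 1.001)):
        print(len(batch), batch)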

View File

@@ -31,6 +31,7 @@ cdef class GoldParse:
     cdef public list ents
     cdef public dict brackets
     cdef public object cats
+    cdef public list links

     cdef readonly list cand_to_gold
     cdef readonly list gold_to_cand

View File

@@ -427,7 +427,7 @@ cdef class GoldParse:
     def __init__(self, doc, annot_tuples=None, words=None, tags=None,
                  heads=None, deps=None, entities=None, make_projective=False,
-                 cats=None, **_):
+                 cats=None, links=None, **_):
         """Create a GoldParse.

         doc (Doc): The document the annotations refer to.
@@ -450,6 +450,8 @@ cdef class GoldParse:
             examples of a label to have the value 0.0. Labels not in the
             dictionary are treated as missing - the gradient for those labels
             will be zero.
+        links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples,
+            representing the external ID of an entity in a knowledge base.
         RETURNS (GoldParse): The newly constructed object.
         """
         if words is None:
@@ -485,6 +487,7 @@ cdef class GoldParse:
         self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))

         self.cats = {} if cats is None else dict(cats)
+        self.links = links
         self.words = [None] * len(doc)
         self.tags = [None] * len(doc)
         self.heads = [None] * len(doc)
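
One behavioural detail of this hunk: `cats` is normalised to a fresh dict, while `links` is stored exactly as passed, so it stays `None` when omitted. Code that iterates `gold.links` (as the `EntityLinker.update` rewrite below does) therefore only works on golds built with links. A small sketch, assuming a build of this branch; the KB id is illustrative:

    from spacy.gold import GoldParse
    from spacy.lang.en import English

    doc = English()("Douglas Adams wrote the book.")

    gold = GoldParse(doc=doc)
    print(gold.cats)   # {}   -- always a dict, even when cats is None
    print(gold.links)  # None -- stored as passed, no default container

    gold = GoldParse(doc=doc, links=[(0, 2, "Q42")])
    print(gold.links)  # [(0, 2, 'Q42')]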

View File

@@ -1115,48 +1115,61 @@ class EntityLinker(Pipe):
         self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)

     def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
-        """ docs should be a tuple of (entity_docs, article_docs, sentence_docs) """
+        """ TODO """
         self.require_model()

         if len(docs) != len(golds):
-            raise ValueError(Errors.E077.format(value="loss", n_docs=len(docs),
+            raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs),
                                                 n_golds=len(golds)))

-        entity_docs, article_docs, sentence_docs = docs
-        assert len(entity_docs) == len(article_docs) == len(sentence_docs)
-
-        if isinstance(entity_docs, Doc):
-            entity_docs = [entity_docs]
-            article_docs = [article_docs]
-            sentence_docs = [sentence_docs]
+        if isinstance(docs, Doc):
+            docs = [docs]
+            golds = [golds]

+        for doc, gold in zip(docs, golds):
+            print("doc", doc)
+            for entity in gold.links:
+                start, end, gold_kb = entity
+                print("entity", entity)
+                mention = doc[start:end].text
+                print("mention", mention)
+                candidates = self.kb.get_candidates(mention)
+                for c in candidates:
+                    prior_prob = c.prior_prob
+                    kb_id = c.entity_
+                    print("candidate", kb_id, prior_prob)
+                    entity_encoding = c.entity_vector
+            print()
+        print()

-        entity_encodings = None #TODO
-        doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
-        sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
-
-        concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
-                            range(len(article_docs))]
-        mention_encodings, bp_cont = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
-
-        loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
-
-        mention_gradient = bp_cont(d_scores, sgd=self.sgd_cont)
-
-        # gradient : concat (doc+sent) vs. desc
-        sent_start = self.article_encoder.nO
-        sent_gradients = list()
-        doc_gradients = list()
-        for x in mention_gradient:
-            doc_gradients.append(list(x[0:sent_start]))
-            sent_gradients.append(list(x[sent_start:]))
-
-        bp_doc(doc_gradients, sgd=self.sgd_article)
-        bp_sent(sent_gradients, sgd=self.sgd_sent)
-
-        if losses is not None:
-            losses.setdefault(self.name, 0.0)
-            losses[self.name] += loss
-        return loss
+        # entity_encodings = None #TODO
+        # doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
+        # sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
+        #
+        # concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
+        #                     range(len(article_docs))]
+        # mention_encodings, bp_cont = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
+        #
+        # loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
+        #
+        # mention_gradient = bp_cont(d_scores, sgd=self.sgd_cont)
+        #
+        # # gradient : concat (doc+sent) vs. desc
+        # sent_start = self.article_encoder.nO
+        # sent_gradients = list()
+        # doc_gradients = list()
+        # for x in mention_gradient:
+        #     doc_gradients.append(list(x[0:sent_start]))
+        #     sent_gradients.append(list(x[sent_start:]))
+        #
+        # bp_doc(doc_gradients, sgd=self.sgd_article)
+        # bp_sent(sent_gradients, sgd=self.sgd_sent)
+        #
+        # if losses is not None:
+        #     losses.setdefault(self.name, 0.0)
+        #     losses[self.name] += loss
+        # return loss
+        return None

     def get_loss(self, docs, golds, scores):
         loss, gradients = get_cossim_loss(scores, golds)
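
The rewritten `update` body walks each gold link, slices the mention out of the doc, and prints every KB candidate with its prior probability and entity vector; the encoder/gradient machinery is parked in comments for now. The candidate half of that loop can be reproduced against a toy knowledge base. A sketch using the released spaCy 2.2 `KnowledgeBase` API (keyword names were still shifting on this branch); the entities, alias, and probabilities are invented:

    from spacy.kb import KnowledgeBase
    from spacy.lang.en import English

    nlp = English()
    kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)

    # Two toy entities plus one ambiguous alias pointing at both.
    kb.add_entity(entity="Q42", freq=50, entity_vector=[1.0, 0.0, 0.0])
    kb.add_entity(entity="Q5", freq=20, entity_vector=[0.0, 1.0, 0.0])
    kb.add_alias(alias="Adams", entities=["Q42", "Q5"], probabilities=[0.8, 0.2])

    # Mirrors the candidate loop that update() now prints out.
    for c in kb.get_candidates("Adams"):
        print("candidate", c.entity_, c.prior_prob, c.entity_vector)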