fix for context encoder optimizer

svlandeg 2019-07-03 13:35:36 +02:00
parent 3420cbe496
commit 8840d4b1b3
2 changed files with 23 additions and 21 deletions

View File

@@ -73,11 +73,11 @@ def run_pipeline():
     measure_performance = True
 
     # test the EL pipe on a simple example
-    to_test_pipeline = False
+    to_test_pipeline = True
 
     # write the NLP object, read back in and test again
-    to_write_nlp = False
-    to_read_nlp = False
+    to_write_nlp = True
+    to_read_nlp = True
     test_from_file = False
 
     # STEP 1 : create prior probabilities from WP (run only once)
@@ -154,8 +154,8 @@ def run_pipeline():
         optimizer.L2 = L2
 
         # define the size (nr of entities) of training and dev set
-        train_limit = 50000
-        dev_limit = 50000
+        train_limit = 5
+        dev_limit = 5
 
         train_data = training_set_creator.read_training(nlp=nlp_2,
                                                         training_dir=TRAINING_DIR,
@@ -250,7 +250,8 @@ def run_pipeline():
         print("STEP 9: testing NLP IO", datetime.datetime.now())
         print()
         print("writing to", NLP_2_DIR)
-        nlp_2.to_disk(NLP_2_DIR)
+        with el_pipe.model.use_params(optimizer.averages) and el_pipe.model.tok2vec.use_params(el_pipe.sgd_context.averages):
+            nlp_2.to_disk(NLP_2_DIR)
         print()
 
     # verify that the IO has gone correctly
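Note on the new with-statement: Python evaluates `with a and b:` as a single expression before entering the block, so only the right-hand context manager is actually entered; here that means only the tok2vec parameter swap takes effect. A minimal sketch of entering both (assuming `use_params` is an ordinary context manager, as in Thinc):

# Sketch, not the commit's code: nest the managers so both the EL model's
# and the context encoder's averaged parameters are active during the write.
with el_pipe.model.use_params(optimizer.averages):
    with el_pipe.model.tok2vec.use_params(el_pipe.sgd_context.averages):
        nlp_2.to_disk(NLP_2_DIR)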

View File

@@ -1082,12 +1082,8 @@ class EntityLinker(Pipe):
     def __init__(self, **cfg):
         self.model = True
         self.kb = None
+        self.sgd_context = None
         self.cfg = dict(cfg)
-        self.context_weight = cfg.get("context_weight", 1)
-        self.prior_weight = cfg.get("prior_weight", 1)
-        self.context_width = cfg.get("context_width")
-        self.type_to_int = cfg.get("type_to_int", dict())
-        self.sgd_context = None
 
     def set_kb(self, kb):
         self.kb = kb
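The __init__ change drops the eagerly cached hyperparameters (context_weight, prior_weight, context_width, type_to_int): self.cfg can be replaced after construction, e.g. when the pipe is deserialized, so the rest of the diff reads these values from self.cfg at the point of use. A minimal sketch of the two patterns, with illustrative names:

# Sketch only. Caching at construction time freezes the value:
class CachedPipe:
    def __init__(self, **cfg):
        self.cfg = dict(cfg)
        self.prior_weight = cfg.get("prior_weight", 1)  # ignores later cfg updates

# Looking the value up at use time tracks the current cfg:
class LazyPipe:
    def __init__(self, **cfg):
        self.cfg = dict(cfg)

    def score(self, prior_prob):
        return prior_prob * self.cfg.get("prior_weight", 1)  # reflects cfg as loaded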
@@ -1112,6 +1108,7 @@ class EntityLinker(Pipe):
         if sgd is None:
             sgd = self.create_optimizer()
         return sgd
+
     def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
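sgd_context is the dedicated optimizer for the context encoder: update() below passes it to bp_context(..., sgd=self.sgd_context), while the main optimizer handles the rest of the model. The diff does not show where it is instantiated; a plausible sketch, reusing the create_optimizer() call visible above:

# Hypothetical wiring, not shown in this diff:
def begin_training(self, sgd=None, **kwargs):
    self.sgd_context = self.create_optimizer()  # updates only the context encoder
    if sgd is None:
        sgd = self.create_optimizer()  # main optimizer for the linker model
    return sgd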
@@ -1138,6 +1135,8 @@ class EntityLinker(Pipe):
         priors = []
         type_vectors = []
+        type_to_int = self.cfg.get("type_to_int", dict())
+
 
         for doc, gold in zip(docs, golds):
             ents_by_offset = dict()
             for ent in doc.ents:
@ -1148,9 +1147,9 @@ class EntityLinker(Pipe):
gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
assert gold_ent is not None
type_vector = [0 for i in range(len(self.type_to_int))]
if len(self.type_to_int) > 0:
type_vector[self.type_to_int[gold_ent.label_]] = 1
type_vector = [0 for i in range(len(type_to_int))]
if len(type_to_int) > 0:
type_vector[type_to_int[gold_ent.label_]] = 1
candidates = self.kb.get_candidates(mention)
random.shuffle(candidates)
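type_vector is a one-hot encoding of the gold entity's label over the configured type_to_int mapping, and stays all-zero when no mapping is configured. For example, with an illustrative mapping:

# Illustrative values, not from the commit:
type_to_int = {"PERSON": 0, "ORG": 1, "GPE": 2}
type_vector = [0 for i in range(len(type_to_int))]
if len(type_to_int) > 0:
    type_vector[type_to_int["ORG"]] = 1
# type_vector == [0, 1, 0]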
@@ -1162,7 +1161,7 @@ class EntityLinker(Pipe):
                     context_docs.append(doc)
                     type_vectors.append(type_vector)
 
-                    if self.prior_weight > 0:
+                    if self.cfg.get("prior_weight", 1) > 0:
                         priors.append([c.prior_prob])
                     else:
                         priors.append([0])
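When prior_weight is 0, each candidate's prior probability is replaced by 0, so training sees the same muted prior feature that predict() produces further down. With illustrative numbers:

# Illustrative: prior_weight == 0 mutes the prior feature for every candidate.
prior_weight = 0
candidate_priors = [0.7, 0.2, 0.1]
priors = [[p] if prior_weight > 0 else [0] for p in candidate_priors]
# priors == [[0], [0], [0]]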
@@ -1187,7 +1186,7 @@ class EntityLinker(Pipe):
             loss, d_scores = self.get_loss(prediction=pred, golds=cats, docs=None)
             mention_gradient = bp_mention(d_scores, sgd=sgd)
 
-            context_gradients = [list(x[0:self.context_width]) for x in mention_gradient]
+            context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient]
             bp_context(self.model.ops.asarray(context_gradients, dtype="float32"), sgd=self.sgd_context)
 
             if losses is not None:
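Each row of mention_gradient spans the concatenated [context_encoding, entity_encoding] input, so slicing off the first context_width columns isolates the gradient that flows back through the context encoder via its own optimizer. A shape-only sketch with assumed dimensions:

import numpy as np

# Assumed sizes, for illustration only:
context_width, entity_width, n_mentions = 128, 64, 5
mention_gradient = np.zeros((n_mentions, context_width + entity_width), dtype="float32")

# Keep only the context part of each row; the entity part is not
# backpropagated through tok2vec.
context_gradients = [list(x[0:context_width]) for x in mention_gradient]
assert len(context_gradients[0]) == context_width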
@@ -1235,13 +1234,15 @@ class EntityLinker(Pipe):
             context_encodings = self.model.tok2vec(docs)
             xp = get_array_module(context_encodings)
+            type_to_int = self.cfg.get("type_to_int", dict())
+
 
             for i, doc in enumerate(docs):
                 if len(doc) > 0:
                     context_encoding = context_encodings[i]
                     for ent in doc.ents:
-                        type_vector = [0 for i in range(len(self.type_to_int))]
-                        if len(self.type_to_int) > 0:
-                            type_vector[self.type_to_int[ent.label_]] = 1
+                        type_vector = [0 for i in range(len(type_to_int))]
+                        if len(type_to_int) > 0:
+                            type_vector[type_to_int[ent.label_]] = 1
 
                         candidates = self.kb.get_candidates(ent.text)
                         if candidates:
@@ -1249,10 +1250,10 @@ class EntityLinker(Pipe):
                         # this will set the prior probabilities to 0 (just like in training) if their weight is 0
                         prior_probs = xp.asarray([[c.prior_prob] for c in candidates])
-                        prior_probs *= self.prior_weight
+                        prior_probs *= self.cfg.get("prior_weight", 1)
                         scores = prior_probs
 
-                        if self.context_weight > 0:
+                        if self.cfg.get("context_weight", 1) > 0:
                             entity_encodings = xp.asarray([c.entity_vector for c in candidates])
                             assert len(entity_encodings) == len(prior_probs)
                             mention_encodings = [list(context_encoding) + list(entity_encodings[i])