small fixes

This commit is contained in:
svlandeg 2019-05-20 17:20:39 +02:00
parent 7edb2e1711
commit 89e322a637

View File

@ -71,7 +71,7 @@ class EL_Model:
self._test_dev(train_inst, train_pos, train_neg, train_doc, print_string="train_random", calc_random=True)
self._test_dev(dev_inst, dev_pos, dev_neg, dev_doc, print_string="dev_random", calc_random=True)
print()
self._test_dev(train_inst, train_pos, train_neg, train_doc, print_string="train_pre", calc_random=False)
self._test_dev(train_inst, train_pos, train_neg, train_doc, print_string="train_pre", avg=False)
self._test_dev(dev_inst, dev_pos, dev_neg, dev_doc, print_string="dev_pre", avg=False)
instance_pos_count = 0
@ -113,7 +113,7 @@ class EL_Model:
golds.append(float(0.0))
instance_neg_count += 1
for k in range(5):
for k in range(10):
print()
print("update", k)
print()
@ -182,7 +182,7 @@ class EL_Model:
def _predict(self, article_doc, entity, avg=False, apply_threshold=True):
if avg:
with self.article_encoder.use_params(self.sgd_article.averages) \
and self.entity_encoder.use_params(self.sgd_article.averages):
and self.entity_encoder.use_params(self.sgd_entity.averages):
doc_encoding = self.article_encoder([article_doc])[0]
entity_encoding = self.entity_encoder([entity])[0]
@ -228,7 +228,7 @@ class EL_Model:
@staticmethod
def _encoder(in_width, hidden_width):
conv_depth = 1
conv_depth = 2
cnn_maxout_pieces = 3
with Model.define_operators({">>": chain}):
@ -261,16 +261,10 @@ class EL_Model:
return loss, d_scores
def update(self, article_docs, entities, golds, apply_threshold=True):
print("article_docs", len(article_docs))
for a in article_docs:
print(a[0:10], a[-10:])
doc_encoding, bp_doc = self.article_encoder.begin_update([a], drop=self.DROP)
print(doc_encoding)
doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=self.DROP)
print("doc_encodings", len(doc_encodings), doc_encodings)
entity_encodings, bp_encoding = self.entity_encoder.begin_update(entities, drop=self.DROP)
entity_encodings, bp_entity = self.entity_encoder.begin_update(entities, drop=self.DROP)
print("entity_encodings", len(entity_encodings), entity_encodings)
concat_encodings = [list(entity_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))]
@ -298,15 +292,19 @@ class EL_Model:
# print("d_scores", d_scores)
model_gradient = bp_model(d_scores, sgd=self.sgd)
# print("model_gradient", model_gradient)
print("model_gradient", model_gradient)
doc_gradient = [x[0:self.ARTICLE_WIDTH] for x in model_gradient]
# print("doc_gradient", doc_gradient)
entity_gradient = [x[self.ARTICLE_WIDTH:] for x in model_gradient]
# print("entity_gradient", entity_gradient)
doc_gradient = list()
entity_gradient = list()
for x in model_gradient:
doc_gradient.append(list(x[0:self.ARTICLE_WIDTH]))
entity_gradient.append(list(x[self.ARTICLE_WIDTH:]))
print("doc_gradient", doc_gradient)
print("entity_gradient", entity_gradient)
bp_doc(doc_gradient)
bp_encoding(entity_gradient)
bp_entity(entity_gradient)
def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print):
id_to_descr = kb_creator._get_id_to_description(entity_descr_output)