mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
small fixes
This commit is contained in:
parent
7edb2e1711
commit
89e322a637
|
@ -71,7 +71,7 @@ class EL_Model:
|
|||
self._test_dev(train_inst, train_pos, train_neg, train_doc, print_string="train_random", calc_random=True)
|
||||
self._test_dev(dev_inst, dev_pos, dev_neg, dev_doc, print_string="dev_random", calc_random=True)
|
||||
print()
|
||||
self._test_dev(train_inst, train_pos, train_neg, train_doc, print_string="train_pre", calc_random=False)
|
||||
self._test_dev(train_inst, train_pos, train_neg, train_doc, print_string="train_pre", avg=False)
|
||||
self._test_dev(dev_inst, dev_pos, dev_neg, dev_doc, print_string="dev_pre", avg=False)
|
||||
|
||||
instance_pos_count = 0
|
||||
|
@ -113,7 +113,7 @@ class EL_Model:
|
|||
golds.append(float(0.0))
|
||||
instance_neg_count += 1
|
||||
|
||||
for k in range(5):
|
||||
for k in range(10):
|
||||
print()
|
||||
print("update", k)
|
||||
print()
|
||||
|
@ -182,7 +182,7 @@ class EL_Model:
|
|||
def _predict(self, article_doc, entity, avg=False, apply_threshold=True):
|
||||
if avg:
|
||||
with self.article_encoder.use_params(self.sgd_article.averages) \
|
||||
and self.entity_encoder.use_params(self.sgd_article.averages):
|
||||
and self.entity_encoder.use_params(self.sgd_entity.averages):
|
||||
doc_encoding = self.article_encoder([article_doc])[0]
|
||||
entity_encoding = self.entity_encoder([entity])[0]
|
||||
|
||||
|
@ -228,7 +228,7 @@ class EL_Model:
|
|||
|
||||
@staticmethod
|
||||
def _encoder(in_width, hidden_width):
|
||||
conv_depth = 1
|
||||
conv_depth = 2
|
||||
cnn_maxout_pieces = 3
|
||||
|
||||
with Model.define_operators({">>": chain}):
|
||||
|
@ -261,16 +261,10 @@ class EL_Model:
|
|||
return loss, d_scores
|
||||
|
||||
def update(self, article_docs, entities, golds, apply_threshold=True):
|
||||
print("article_docs", len(article_docs))
|
||||
for a in article_docs:
|
||||
print(a[0:10], a[-10:])
|
||||
doc_encoding, bp_doc = self.article_encoder.begin_update([a], drop=self.DROP)
|
||||
print(doc_encoding)
|
||||
|
||||
doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=self.DROP)
|
||||
print("doc_encodings", len(doc_encodings), doc_encodings)
|
||||
|
||||
entity_encodings, bp_encoding = self.entity_encoder.begin_update(entities, drop=self.DROP)
|
||||
entity_encodings, bp_entity = self.entity_encoder.begin_update(entities, drop=self.DROP)
|
||||
print("entity_encodings", len(entity_encodings), entity_encodings)
|
||||
|
||||
concat_encodings = [list(entity_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))]
|
||||
|
@ -298,15 +292,19 @@ class EL_Model:
|
|||
# print("d_scores", d_scores)
|
||||
|
||||
model_gradient = bp_model(d_scores, sgd=self.sgd)
|
||||
# print("model_gradient", model_gradient)
|
||||
print("model_gradient", model_gradient)
|
||||
|
||||
doc_gradient = [x[0:self.ARTICLE_WIDTH] for x in model_gradient]
|
||||
# print("doc_gradient", doc_gradient)
|
||||
entity_gradient = [x[self.ARTICLE_WIDTH:] for x in model_gradient]
|
||||
# print("entity_gradient", entity_gradient)
|
||||
doc_gradient = list()
|
||||
entity_gradient = list()
|
||||
for x in model_gradient:
|
||||
doc_gradient.append(list(x[0:self.ARTICLE_WIDTH]))
|
||||
entity_gradient.append(list(x[self.ARTICLE_WIDTH:]))
|
||||
|
||||
print("doc_gradient", doc_gradient)
|
||||
print("entity_gradient", entity_gradient)
|
||||
|
||||
bp_doc(doc_gradient)
|
||||
bp_encoding(entity_gradient)
|
||||
bp_entity(entity_gradient)
|
||||
|
||||
def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print):
|
||||
id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
|
||||
|
|
Loading…
Reference in New Issue
Block a user