spaCy (https://github.com/explosion/spaCy.git)

commit 89e322a637: "small fixes"
parent 7edb2e1711
@@ -33,9 +33,9 @@ class EL_Model:
     CUTOFF = 0.5

     INPUT_DIM = 300
     ENTITY_WIDTH = 4 # 64
     ARTICLE_WIDTH = 8 # 128
     HIDDEN_WIDTH = 6 # 64

     DROP = 0.00

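Note on the widths above (not part of the commit): the entity encoder and the article encoder produce vectors of ENTITY_WIDTH and ARTICLE_WIDTH, and update() further down concatenates them before the hidden layer. A minimal sketch of that dimension bookkeeping, reusing the constants shown here:

# Sketch only: how the widths fit together. The real model concatenates the
# entity encoding and the article encoding (see concat_encodings in update()).
import numpy

ENTITY_WIDTH = 4    # width of one entity encoding
ARTICLE_WIDTH = 8   # width of one article/document encoding

entity_encoding = numpy.zeros(ENTITY_WIDTH, dtype="float32")
doc_encoding = numpy.zeros(ARTICLE_WIDTH, dtype="float32")

# Same concatenation pattern as concat_encodings: entity part first, doc part second.
concat = list(entity_encoding) + list(doc_encoding)
assert len(concat) == ENTITY_WIDTH + ARTICLE_WIDTH  # 12-dim input to the hidden layer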
@@ -71,7 +71,7 @@ class EL_Model:
         self._test_dev(train_inst, train_pos, train_neg, train_doc, print_string="train_random", calc_random=True)
         self._test_dev(dev_inst, dev_pos, dev_neg, dev_doc, print_string="dev_random", calc_random=True)
         print()
-        self._test_dev(train_inst, train_pos, train_neg, train_doc, print_string="train_pre", calc_random=False)
+        self._test_dev(train_inst, train_pos, train_neg, train_doc, print_string="train_pre", avg=False)
         self._test_dev(dev_inst, dev_pos, dev_neg, dev_doc, print_string="dev_pre", avg=False)

         instance_pos_count = 0
@@ -113,7 +113,7 @@ class EL_Model:
                     golds.append(float(0.0))
                     instance_neg_count += 1

-        for k in range(5):
+        for k in range(10):
             print()
             print("update", k)
             print()
@@ -182,7 +182,7 @@ class EL_Model:
     def _predict(self, article_doc, entity, avg=False, apply_threshold=True):
         if avg:
             with self.article_encoder.use_params(self.sgd_article.averages) \
-                    and self.entity_encoder.use_params(self.sgd_article.averages):
+                    and self.entity_encoder.use_params(self.sgd_entity.averages):
                 doc_encoding = self.article_encoder([article_doc])[0]
                 entity_encoding = self.entity_encoder([entity])[0]

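The fix above points the entity encoder at its own optimizer's averages (sgd_entity) instead of the article optimizer's. One general-Python aside that the commit does not touch: combining two use_params() context managers with `and` only enters one of them, because `a and b` evaluates to a single object. A self-contained sketch of that behaviour with contextlib (the use_params name here is just a stand-in, not Thinc's implementation):

# Sketch only: why `with cm_a and cm_b:` enters a single context manager.
from contextlib import contextmanager

entered = []

@contextmanager
def use_params(name):          # stand-in for encoder.use_params(averages)
    entered.append(name)
    yield

with use_params("article") and use_params("entity"):
    pass
print(entered)                 # ['entity'] -- the first manager was never entered

entered.clear()
with use_params("article"), use_params("entity"):
    pass
print(entered)                 # ['article', 'entity'] -- both sets of averages applied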
@@ -228,7 +228,7 @@ class EL_Model:

     @staticmethod
     def _encoder(in_width, hidden_width):
-        conv_depth = 1
+        conv_depth = 2
         cnn_maxout_pieces = 3

         with Model.define_operators({">>": chain}):
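conv_depth controls how deep the convolutional part of _encoder is; raising it to 2 makes the stack one block deeper. For orientation, a much-reduced sketch of the Model.define_operators pattern used here, assuming the Thinc 7.x API that spaCy 2.x builds on (the layers below are stand-ins, not the actual encoder definition):

# Sketch only, assuming Thinc 7.x: mapping ">>" to chain lets layers be piped
# together inside the with-block, which is how _encoder composes its stack.
from thinc.api import chain
from thinc.v2v import Model, Affine, Maxout

def tiny_encoder(in_width, hidden_width, pieces=3):
    with Model.define_operators({">>": chain}):
        # A small stand-in stack: a maxout layer followed by an affine projection.
        encoder = Maxout(hidden_width, in_width, pieces=pieces) \
                  >> Affine(hidden_width, hidden_width)
    return encoder

model = tiny_encoder(300, 6)   # INPUT_DIM -> HIDDEN_WIDTH, mirroring the constants above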
@@ -261,16 +261,10 @@ class EL_Model:
         return loss, d_scores

     def update(self, article_docs, entities, golds, apply_threshold=True):
-        print("article_docs", len(article_docs))
-        for a in article_docs:
-            print(a[0:10], a[-10:])
-            doc_encoding, bp_doc = self.article_encoder.begin_update([a], drop=self.DROP)
-            print(doc_encoding)
-
         doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=self.DROP)
         print("doc_encodings", len(doc_encodings), doc_encodings)

-        entity_encodings, bp_encoding = self.entity_encoder.begin_update(entities, drop=self.DROP)
+        entity_encodings, bp_entity = self.entity_encoder.begin_update(entities, drop=self.DROP)
         print("entity_encodings", len(entity_encodings), entity_encodings)

         concat_encodings = [list(entity_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))]
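The deleted block encoded the articles one at a time purely for debugging; the surviving code calls begin_update once over the whole batch. For readers unfamiliar with the Thinc 7.x training contract this relies on, a minimal, hypothetical sketch (Affine stands in for the article encoder; the shapes echo INPUT_DIM and ARTICLE_WIDTH):

# Sketch only, assuming Thinc 7.x: begin_update runs the forward pass over a
# batch and returns a backprop callback, the pattern update() uses for both encoders.
import numpy
from thinc.v2v import Affine
from thinc.neural.ops import NumpyOps
from thinc.neural.optimizers import Adam

model = Affine(8, 300)                         # stand-in for article_encoder (300 -> 8)
sgd = Adam(NumpyOps(), 0.001)

X = numpy.zeros((4, 300), dtype="float32")     # a batch of 4 "article" vectors
Y, backprop = model.begin_update(X, drop=0.0)  # forward pass over the whole batch
d_Y = numpy.ones_like(Y)                       # pretend gradient coming from the loss
d_X = backprop(d_Y, sgd=sgd)                   # backprop and apply the optimizer, as bp_doc/bp_entity do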
@@ -298,15 +292,19 @@ class EL_Model:
         # print("d_scores", d_scores)

         model_gradient = bp_model(d_scores, sgd=self.sgd)
-        # print("model_gradient", model_gradient)
+        print("model_gradient", model_gradient)

-        doc_gradient = [x[0:self.ARTICLE_WIDTH] for x in model_gradient]
-        # print("doc_gradient", doc_gradient)
-        entity_gradient = [x[self.ARTICLE_WIDTH:] for x in model_gradient]
-        # print("entity_gradient", entity_gradient)
+        doc_gradient = list()
+        entity_gradient = list()
+        for x in model_gradient:
+            doc_gradient.append(list(x[0:self.ARTICLE_WIDTH]))
+            entity_gradient.append(list(x[self.ARTICLE_WIDTH:]))
+
+        print("doc_gradient", doc_gradient)
+        print("entity_gradient", entity_gradient)

         bp_doc(doc_gradient)
-        bp_encoding(entity_gradient)
+        bp_entity(entity_gradient)

     def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print):
         id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
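The rewritten block splits each row of model_gradient back into a document part and an entity part so that each encoder's backprop callback receives a gradient of its own width. One hedged observation, not something this commit changes: concat_encodings places the entity encoding first, so a split that mirrors that ordering would take the first ENTITY_WIDTH values for the entity gradient rather than the first ARTICLE_WIDTH values for the doc gradient. A small numpy sketch of that bookkeeping, using the widths defined at the top of the class:

# Sketch only: splitting a concatenated gradient row back into per-encoder parts.
# Ordering follows concat_encodings (entity encoding first, then doc encoding).
import numpy

ENTITY_WIDTH = 4
ARTICLE_WIDTH = 8

model_gradient = numpy.arange(ENTITY_WIDTH + ARTICLE_WIDTH, dtype="float32").reshape(1, -1)

entity_gradient = [list(x[:ENTITY_WIDTH]) for x in model_gradient]   # first 4 values per row
doc_gradient = [list(x[ENTITY_WIDTH:]) for x in model_gradient]      # remaining 8 values per row

assert len(entity_gradient[0]) == ENTITY_WIDTH
assert len(doc_gradient[0]) == ARTICLE_WIDTH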