mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
fix convolution layer
This commit is contained in:
parent
dd691d0053
commit
7edb2e1711
|
@ -12,9 +12,9 @@ from examples.pipeline.wiki_entity_linking import run_el, training_set_creator,
|
|||
|
||||
from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, logistic
|
||||
|
||||
from thinc.api import chain, concatenate, flatten_add_lengths, clone
|
||||
from thinc.api import chain, concatenate, flatten_add_lengths, clone, with_flatten
|
||||
from thinc.v2v import Model, Maxout, Affine
|
||||
from thinc.t2v import Pooling, mean_pool
|
||||
from thinc.t2v import Pooling, mean_pool, sum_pool
|
||||
from thinc.t2t import ParametricAttention
|
||||
from thinc.misc import Residual
|
||||
from thinc.misc import LayerNorm as LN
|
||||
|
@ -96,13 +96,13 @@ class EL_Model:
|
|||
try:
|
||||
# if to_print:
|
||||
# print()
|
||||
# print(article_count, "Training on article", article_id)
|
||||
print(article_count, "Training on article", article_id)
|
||||
article_count += 1
|
||||
article_docs = list()
|
||||
entities = list()
|
||||
golds = list()
|
||||
for inst_cluster in inst_cluster_set:
|
||||
if instance_pos_count < 2: # TODO remove
|
||||
if instance_pos_count < 2: # TODO del
|
||||
article_docs.append(train_doc[article_id])
|
||||
entities.append(train_pos.get(inst_cluster))
|
||||
golds.append(float(1.0))
|
||||
|
@ -228,16 +228,23 @@ class EL_Model:
|
|||
|
||||
@staticmethod
|
||||
def _encoder(in_width, hidden_width):
|
||||
conv_depth = 1
|
||||
cnn_maxout_pieces = 3
|
||||
|
||||
with Model.define_operators({">>": chain}):
|
||||
convolution = Residual((ExtractWindow(nW=1) >> LN(Maxout(in_width, in_width * 3, pieces=cnn_maxout_pieces))))
|
||||
|
||||
encoder = SpacyVectors \
|
||||
>> flatten_add_lengths \
|
||||
>> ParametricAttention(in_width)\
|
||||
>> Pooling(mean_pool) \
|
||||
>> (ExtractWindow(nW=1) >> LN(Maxout(in_width, in_width * 3))) \
|
||||
>> zero_init(Affine(hidden_width, in_width, drop_factor=0.0))
|
||||
>> with_flatten(LN(Maxout(in_width, in_width)) >> convolution ** conv_depth, pad=conv_depth) \
|
||||
>> flatten_add_lengths \
|
||||
>> ParametricAttention(in_width)\
|
||||
>> Pooling(mean_pool) \
|
||||
>> Residual(zero_init(Maxout(in_width, in_width))) \
|
||||
>> zero_init(Affine(hidden_width, in_width, drop_factor=0.0))
|
||||
|
||||
# TODO: ReLu instead of LN(Maxout) ?
|
||||
# TODO: more convolutions ?
|
||||
# sum_pool or mean_pool ?
|
||||
|
||||
return encoder
|
||||
|
||||
|
@ -261,16 +268,17 @@ class EL_Model:
|
|||
print(doc_encoding)
|
||||
|
||||
doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=self.DROP)
|
||||
entity_encodings, bp_encoding = self.entity_encoder.begin_update(entities, drop=self.DROP)
|
||||
concat_encodings = [list(entity_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))]
|
||||
|
||||
print("doc_encodings", len(doc_encodings), doc_encodings)
|
||||
|
||||
entity_encodings, bp_encoding = self.entity_encoder.begin_update(entities, drop=self.DROP)
|
||||
print("entity_encodings", len(entity_encodings), entity_encodings)
|
||||
print("concat_encodings", len(concat_encodings), concat_encodings)
|
||||
|
||||
concat_encodings = [list(entity_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))]
|
||||
# print("concat_encodings", len(concat_encodings), concat_encodings)
|
||||
|
||||
predictions, bp_model = self.model.begin_update(np.asarray(concat_encodings), drop=self.DROP)
|
||||
print("predictions", predictions)
|
||||
predictions = self.model.ops.flatten(predictions)
|
||||
print("predictions", predictions)
|
||||
golds = self.model.ops.asarray(golds)
|
||||
|
||||
loss, d_scores = self.get_loss(predictions, golds)
|
||||
|
@ -287,15 +295,15 @@ class EL_Model:
|
|||
|
||||
d_scores = d_scores.reshape((-1, 1))
|
||||
d_scores = d_scores.astype(np.float32)
|
||||
print("d_scores", d_scores)
|
||||
# print("d_scores", d_scores)
|
||||
|
||||
model_gradient = bp_model(d_scores, sgd=self.sgd)
|
||||
print("model_gradient", model_gradient)
|
||||
# print("model_gradient", model_gradient)
|
||||
|
||||
doc_gradient = [x[0:self.ARTICLE_WIDTH] for x in model_gradient]
|
||||
print("doc_gradient", doc_gradient)
|
||||
# print("doc_gradient", doc_gradient)
|
||||
entity_gradient = [x[self.ARTICLE_WIDTH:] for x in model_gradient]
|
||||
print("entity_gradient", entity_gradient)
|
||||
# print("entity_gradient", entity_gradient)
|
||||
|
||||
bp_doc(doc_gradient)
|
||||
bp_encoding(entity_gradient)
|
||||
|
|
Loading…
Reference in New Issue
Block a user