mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Add 'reapply' combinator, for iterated CNN
This commit is contained in:
parent
40a4873b70
commit
a186596307
22
spacy/_ml.py
22
spacy/_ml.py
|
@ -271,6 +271,28 @@ def Tok2Vec(width, embed_size, pretrained_dims=0, **kwargs):
|
|||
return tok2vec
|
||||
|
||||
|
||||
def reapply(layer, n_times):
|
||||
def reapply_fwd(X, drop=0.):
|
||||
backprops = []
|
||||
for i in range(n_times):
|
||||
Y, backprop = layer.begin_update(X, drop=drop)
|
||||
X = Y
|
||||
backprops.append(backprop)
|
||||
def reapply_bwd(dY, sgd=None):
|
||||
dX = None
|
||||
for backprop in reversed(backprops):
|
||||
dY = backprop(dY, sgd=sgd)
|
||||
if dX is None:
|
||||
dX = dY
|
||||
else:
|
||||
dX += dY
|
||||
return dX
|
||||
return Y, reapply_bwd
|
||||
return wrap(reapply_fwd, layer)
|
||||
|
||||
|
||||
|
||||
|
||||
def asarray(ops, dtype):
|
||||
def forward(X, drop=0.):
|
||||
return ops.asarray(X, dtype=dtype), None
|
||||
|
|
Loading…
Reference in New Issue
Block a user