mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-25 00:34:20 +03:00
Implement AddHistory layer wrapper
This commit is contained in:
parent
d4ca6cef9e
commit
18347ab69c
31
spacy/_ml.py
31
spacy/_ml.py
|
@ -78,6 +78,37 @@ def add_tuples(X, drop=0.):
|
|||
return (vals1+vals2, length), add_tuples_bwd
|
||||
|
||||
|
||||
def AddHistory(layer, decay=0.0001):
    """Wrap `layer` so its input is augmented with "history" features.

    A matrix with one row per output unit of `layer` is maintained. For each
    input row, the wrapper computes the dot product of that row against
    every history row and appends the resulting `layer.nO` values to the
    input before calling `layer.begin_update`. After the forward pass, the
    history row selected by each output row's argmax is nudged towards the
    corresponding input row by an exponential moving average.

    layer: The model to wrap. Must expose `ops`, `nI`, `nO` and
        `begin_update(X, drop=...)`.
    decay: Weight of the newest input in the moving-average update.
    RETURNS: The wrapped model. Its `history` attribute is the history
        matrix, or None until the first forward pass if `layer.nI` is not
        set when the wrapper is created.
    """
    ops = layer.ops
    # Single-slot list used as a mutable closure cell (no `nonlocal`, so the
    # code keeps working on Python 2).
    history_cell = []
    if layer.nI:
        # The wrapped layer consumes the input columns plus `nO` history
        # columns, so the raw input width is `nI - nO`. Store the buffer in
        # the cell so the forward pass updates the very array that is
        # published as `model.history`, instead of allocating a second one.
        history_cell.append(ops.allocate((layer.nO, layer.nI - layer.nO)))

    def history_fwd(X, drop=0.):
        if not history_cell:
            # Input width was unknown at construction time: allocate lazily
            # from the first batch and publish the buffer on the model.
            history_cell.append(ops.allocate((layer.nO, X.shape[1])))
            model.history = history_cell[0]
        average_inputs = history_cell[0]
        # Contract the feature axis of X with that of the history matrix:
        # one value per (input row, history row) pair.
        hist = ops.xp.tensordot(X, average_inputs, axes=[[1], [1]])
        X_hist = ops.xp.hstack((X, hist))
        Y, bp_Y = layer.begin_update(X_hist, drop=drop)
        # Updates are applied row by row on purpose: several rows of the
        # batch may select the same history row, and each update must see
        # the result of the previous one.
        for i in range(Y.shape[0]):
            amax = Y[i].argmax()
            average_inputs[amax] *= 1 - decay
            average_inputs[amax] += decay * X[i]

        def history_bwd(dY, sgd=None):
            dX_hist = bp_Y(dY, sgd=sgd)
            # Only the leading columns belong to the original input; the
            # gradient for the appended history features is dropped.
            return dX_hist[:, :X.shape[1]]

        return Y, history_bwd

    model = wrap(history_fwd, layer)
    model.history = history_cell[0] if history_cell else None
    return model
|
||||
|
||||
|
||||
def _zero_init(model):
|
||||
def _zero_init_impl(self, X, y):
|
||||
self.W.fill(0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user