mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 09:44:36 +03:00
Vectorize update in AddHistory
This commit is contained in:
parent
18347ab69c
commit
d84607f6bb
21
spacy/_ml.py
21
spacy/_ml.py
|
@ -81,30 +81,27 @@ def add_tuples(X, drop=0.):
|
||||||
def AddHistory(layer, decay=0.0001):
|
def AddHistory(layer, decay=0.0001):
|
||||||
ops = layer.ops
|
ops = layer.ops
|
||||||
nonlocals = []
|
nonlocals = []
|
||||||
if layer.nI:
|
|
||||||
average_inputs = ops.allocate((layer.nO, layer.nI-layer.nO))
|
|
||||||
nonlocals = []
|
|
||||||
def history_fwd(X, drop=0.):
|
def history_fwd(X, drop=0.):
|
||||||
if not nonlocals:
|
if not nonlocals:
|
||||||
nonlocals.append(ops.allocate((layer.nO, X.shape[1])))
|
if hasattr(layer, 'nO'):
|
||||||
|
nO = layer.nO
|
||||||
|
else:
|
||||||
|
nO = layer._layers[-1].nO
|
||||||
|
nonlocals.append(ops.allocate((nO, X.shape[1])))
|
||||||
model.history = nonlocals[0]
|
model.history = nonlocals[0]
|
||||||
average_inputs = nonlocals[0]
|
average_inputs = nonlocals[0]
|
||||||
hist = ops.xp.tensordot(X, average_inputs, axes=[[1], [1]])
|
hist = ops.xp.tensordot(X, average_inputs, axes=[[1], [1]])
|
||||||
X_hist = ops.xp.hstack((X, hist))
|
X_hist = ops.xp.hstack((X, hist))
|
||||||
Y, bp_Y = layer.begin_update(X_hist, drop=drop)
|
Y, bp_Y = layer.begin_update(X_hist, drop=drop)
|
||||||
for i in range(Y.shape[0]):
|
amax = Y.argmax(axis=1)
|
||||||
amax = Y[i].argmax()
|
average_inputs *= 1-decay
|
||||||
average_inputs[amax] *= 1-decay
|
ops.scatter_add(average_inputs, amax, X * decay)
|
||||||
average_inputs[amax] += decay * X[i]
|
|
||||||
def history_bwd(dY, sgd=None):
|
def history_bwd(dY, sgd=None):
|
||||||
dX_hist = bp_Y(dY, sgd=sgd)
|
dX_hist = bp_Y(dY, sgd=sgd)
|
||||||
dX = dX_hist[:, :X.shape[1]]
|
dX = dX_hist[:, :X.shape[1]]
|
||||||
return dX
|
return ops.xp.ascontiguousarray(dX)
|
||||||
return Y, history_bwd
|
return Y, history_bwd
|
||||||
model = wrap(history_fwd, layer)
|
model = wrap(history_fwd, layer)
|
||||||
if layer.nI:
|
|
||||||
model.history = average_inputs
|
|
||||||
else:
|
|
||||||
model.history = None
|
model.history = None
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user