Vectorize update in AddHistory

This commit is contained in:
Matthew Honnibal 2017-09-14 20:34:40 +02:00
parent 18347ab69c
commit d84607f6bb

View File

@@ -81,30 +81,27 @@ def add_tuples(X, drop=0.):
 def AddHistory(layer, decay=0.0001):
     ops = layer.ops
     nonlocals = []
-    if layer.nI:
-        average_inputs = ops.allocate((layer.nO, layer.nI-layer.nO))
-    nonlocals = []
     def history_fwd(X, drop=0.):
         if not nonlocals:
-            nonlocals.append(ops.allocate((layer.nO, X.shape[1])))
+            if hasattr(layer, 'nO'):
+                nO = layer.nO
+            else:
+                nO = layer._layers[-1].nO
+            nonlocals.append(ops.allocate((nO, X.shape[1])))
             model.history = nonlocals[0]
         average_inputs = nonlocals[0]
         hist = ops.xp.tensordot(X, average_inputs, axes=[[1], [1]])
         X_hist = ops.xp.hstack((X, hist))
         Y, bp_Y = layer.begin_update(X_hist, drop=drop)
-        for i in range(Y.shape[0]):
-            amax = Y[i].argmax()
-            average_inputs[amax] *= 1-decay
-            average_inputs[amax] += decay * X[i]
+        amax = Y.argmax(axis=1)
+        average_inputs *= 1-decay
+        ops.scatter_add(average_inputs, amax, X * decay)
         def history_bwd(dY, sgd=None):
             dX_hist = bp_Y(dY, sgd=sgd)
             dX = dX_hist[:, :X.shape[1]]
-            return dX
+            return ops.xp.ascontiguousarray(dX)
         return Y, history_bwd
     model = wrap(history_fwd, layer)
-    if layer.nI:
-        model.history = average_inputs
-    else:
-        model.history = None
+    model.history = None
     return model