Simplify PrecomputableAffine slightly

This commit is contained in:
Matthw Honnibal 2020-07-07 17:22:47 +02:00
parent a4164f67ca
commit 433dc3c9c9

View File

@ -49,17 +49,14 @@ def forward(model, X, is_train):
model.inc_grad("b", dY.sum(axis=0)) model.inc_grad("b", dY.sum(axis=0))
dY = dY.reshape((dY.shape[0], nO * nP)) dY = dY.reshape((dY.shape[0], nO * nP))
Wopfi = model.ops.as_contig(W.transpose((1, 2, 0, 3))) Wopfi = W.transpose((1, 2, 0, 3))
Wopfi = Wopfi.reshape((nO * nP, nF * nI)) Wopfi = Wopfi.reshape((nO * nP, nF * nI))
dXf = model.ops.gemm(dY.reshape((dY.shape[0], nO * nP)), Wopfi) dXf = model.ops.gemm(dY.reshape((dY.shape[0], nO * nP)), Wopfi)
# Reuse the buffer dWopfi = model.ops.gemm(dY, Xf, trans1=True)
dWopfi = Wopfi
dWopfi.fill(0.0)
model.ops.gemm(dY, Xf, out=dWopfi, trans1=True)
dWopfi = dWopfi.reshape((nO, nP, nF, nI)) dWopfi = dWopfi.reshape((nO, nP, nF, nI))
# (o, p, f, i) --> (f, o, p, i) # (o, p, f, i) --> (f, o, p, i)
dWopfi = model.ops.as_contig(dWopfi.transpose((2, 0, 1, 3))) dWopfi = dWopfi.transpose((2, 0, 1, 3))
model.inc_grad("W", dWopfi) model.inc_grad("W", dWopfi)
return dXf.reshape((dXf.shape[0], nF, nI)) return dXf.reshape((dXf.shape[0], nF, nI))