mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
Simplify PrecomputableAffine slightly
This commit is contained in:
parent
a4164f67ca
commit
433dc3c9c9
|
@ -49,17 +49,14 @@ def forward(model, X, is_train):
|
||||||
model.inc_grad("b", dY.sum(axis=0))
|
model.inc_grad("b", dY.sum(axis=0))
|
||||||
dY = dY.reshape((dY.shape[0], nO * nP))
|
dY = dY.reshape((dY.shape[0], nO * nP))
|
||||||
|
|
||||||
Wopfi = model.ops.as_contig(W.transpose((1, 2, 0, 3)))
|
Wopfi = W.transpose((1, 2, 0, 3))
|
||||||
Wopfi = Wopfi.reshape((nO * nP, nF * nI))
|
Wopfi = Wopfi.reshape((nO * nP, nF * nI))
|
||||||
dXf = model.ops.gemm(dY.reshape((dY.shape[0], nO * nP)), Wopfi)
|
dXf = model.ops.gemm(dY.reshape((dY.shape[0], nO * nP)), Wopfi)
|
||||||
|
|
||||||
# Reuse the buffer
|
dWopfi = model.ops.gemm(dY, Xf, trans1=True)
|
||||||
dWopfi = Wopfi
|
|
||||||
dWopfi.fill(0.0)
|
|
||||||
model.ops.gemm(dY, Xf, out=dWopfi, trans1=True)
|
|
||||||
dWopfi = dWopfi.reshape((nO, nP, nF, nI))
|
dWopfi = dWopfi.reshape((nO, nP, nF, nI))
|
||||||
# (o, p, f, i) --> (f, o, p, i)
|
# (o, p, f, i) --> (f, o, p, i)
|
||||||
dWopfi = model.ops.as_contig(dWopfi.transpose((2, 0, 1, 3)))
|
dWopfi = dWopfi.transpose((2, 0, 1, 3))
|
||||||
model.inc_grad("W", dWopfi)
|
model.inc_grad("W", dWopfi)
|
||||||
return dXf.reshape((dXf.shape[0], nF, nI))
|
return dXf.reshape((dXf.shape[0], nF, nI))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user