Probably fix pw prod backprop

I think this change is correct, but intuition doesn't really help
here...
This commit is contained in:
Paul O'Leary McCann 2021-06-17 21:23:00 +09:00
parent ccf561112a
commit 5c98c4c3b9

View File

@ -459,11 +459,12 @@ def pairwise_product(bilinear, dropout, vecs: Floats2d, is_train):
# around separately because the closure handles them. # around separately because the closure handles them.
source, source_b = bilinear(vecs, is_train) source, source_b = bilinear(vecs, is_train)
target, target_b = dropout(vecs, is_train) target, target_b = dropout(vecs, is_train)
pw_prod = bilinear.ops.xp.matmul(source, target.T) pw_prod = source @ target.T
def backward(d_prod: Floats2d) -> Floats2d: def backward(d_prod: Floats2d) -> Floats2d:
dS = source_b(d_prod @ target) dS = source_b(d_prod @ target)
dT = target_b(d_prod @ source) #dT = target_b(d_prod @ source)
dT = target_b( (source.T @ d_prod).T )
dX = dS + dT dX = dS + dT
return dX return dX