Clean up pw_prod loss

This doesn't change the math but makes the transposes slightly easier to
understand (maybe?).
This commit is contained in:
Paul O'Leary McCann 2021-07-03 18:32:36 +09:00
parent b02df61eb9
commit 3f66e18592

View File

@ -458,14 +458,13 @@ def pairwise_product(bilinear, dropout, vecs: Floats2d, is_train):
# A neat side effect of this is that we don't have to pass the backprops # A neat side effect of this is that we don't have to pass the backprops
# around separately because the closure handles them. # around separately because the closure handles them.
source, source_b = bilinear(vecs, is_train) source, source_b = bilinear(vecs, is_train)
target, target_b = dropout(vecs, is_train) target, target_b = dropout(vecs.T, is_train)
pw_prod = source @ target.T pw_prod = source @ target
def backward(d_prod: Floats2d) -> Floats2d: def backward(d_prod: Floats2d) -> Floats2d:
dS = source_b(d_prod @ target) dS = source_b(d_prod @ target.T)
#dT = target_b(d_prod @ source) dT = target_b(source.T @ d_prod)
dT = target_b( (source.T @ d_prod).T ) dX = dS + dT.T
dX = dS + dT
return dX return dX
return pw_prod, backward return pw_prod, backward