mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 12:42:20 +03:00
Clean up pw_prod loss
This doesn't change the math but makes the transposes slightly easier to understand (maybe?).
This commit is contained in:
parent
b02df61eb9
commit
3f66e18592
|
@@ -458,14 +458,13 @@ def pairwise_product(bilinear, dropout, vecs: Floats2d, is_train):
|
|||
# A neat side effect of this is that we don't have to pass the backprops
|
||||
# around separately because the closure handles them.
|
||||
source, source_b = bilinear(vecs, is_train)
|
||||
target, target_b = dropout(vecs, is_train)
|
||||
pw_prod = source @ target.T
|
||||
target, target_b = dropout(vecs.T, is_train)
|
||||
pw_prod = source @ target
|
||||
|
||||
def backward(d_prod: Floats2d) -> Floats2d:
|
||||
dS = source_b(d_prod @ target)
|
||||
#dT = target_b(d_prod @ source)
|
||||
dT = target_b( (source.T @ d_prod).T )
|
||||
dX = dS + dT
|
||||
dS = source_b(d_prod @ target.T)
|
||||
dT = target_b(source.T @ d_prod)
|
||||
dX = dS + dT.T
|
||||
return dX
|
||||
|
||||
return pw_prod, backward
|
||||
|
|
Loading…
Reference in New Issue
Block a user