mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 20:52:23 +03:00
Probably fix pw prod backprop
I think this change is correct, but intuition doesn't really help here...
This commit is contained in:
parent
ccf561112a
commit
5c98c4c3b9
|
@ -459,11 +459,12 @@ def pairwise_product(bilinear, dropout, vecs: Floats2d, is_train):
|
|||
# around separately because the closure handles them.
|
||||
source, source_b = bilinear(vecs, is_train)
|
||||
target, target_b = dropout(vecs, is_train)
|
||||
pw_prod = bilinear.ops.xp.matmul(source, target.T)
|
||||
pw_prod = source @ target.T
|
||||
|
||||
def backward(d_prod: Floats2d) -> Floats2d:
|
||||
dS = source_b(d_prod @ target)
|
||||
dT = target_b(d_prod @ source)
|
||||
#dT = target_b(d_prod @ source)
|
||||
dT = target_b( (source.T @ d_prod).T )
|
||||
dX = dS + dT
|
||||
return dX
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user