diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 840066bc7..6f2408df5 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -459,11 +459,12 @@ def pairwise_product(bilinear, dropout, vecs: Floats2d, is_train): # around separately because the closure handles them. source, source_b = bilinear(vecs, is_train) target, target_b = dropout(vecs, is_train) - pw_prod = bilinear.ops.xp.matmul(source, target.T) + pw_prod = source @ target.T def backward(d_prod: Floats2d) -> Floats2d: dS = source_b(d_prod @ target) - dT = target_b(d_prod @ source) + #dT = target_b(d_prod @ source) + dT = target_b( (source.T @ d_prod).T ) dX = dS + dT return dX