From 3f66e185927dcae90bbaaf51ce93e19acf0349cd Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Sat, 3 Jul 2021 18:32:36 +0900 Subject: [PATCH] Clean up pw_prod loss This doesn't change the math but makes the transposes slightly easier to understand (maybe?). --- spacy/ml/models/coref.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 6f2408df5..2155d489c 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -458,14 +458,13 @@ def pairwise_product(bilinear, dropout, vecs: Floats2d, is_train): # A neat side effect of this is that we don't have to pass the backprops # around separately because the closure handles them. source, source_b = bilinear(vecs, is_train) - target, target_b = dropout(vecs, is_train) - pw_prod = source @ target.T + target, target_b = dropout(vecs.T, is_train) + pw_prod = source @ target def backward(d_prod: Floats2d) -> Floats2d: - dS = source_b(d_prod @ target) - #dT = target_b(d_prod @ source) - dT = target_b( (source.T @ d_prod).T ) - dX = dS + dT + dS = source_b(d_prod @ target.T) + dT = target_b(source.T @ d_prod) + dX = dS + dT.T return dX return pw_prod, backward