From 2a1ad4c5d294de02af668e07d19894491afc3204 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A1d=C3=A1r=20=C3=81kos?= Date: Fri, 8 Apr 2022 14:56:44 +0200 Subject: [PATCH] add backprop callback to spanpredictor --- spacy/ml/models/coref.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 7972f9160..0b533daf0 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -104,13 +104,13 @@ def convert_coref_scorer_inputs( # just use the first # TODO real batching X = X[0] - - word_features = xp2torch(X, requires_grad=is_train) + def backprop(args: ArgsKwargs) -> List[Floats2d]: # convert to xp and wrap in list gradients = torch2xp(args.args[0]) return [gradients] + return ArgsKwargs(args=(word_features, ), kwargs={}), backprop @@ -141,16 +141,22 @@ def convert_span_predictor_inputs( ): tok2vec, (sent_ids, head_ids) = X # Normally we shoudl use the input is_train, but for these two it's not relevant + + def backprop(args: ArgsKwargs) -> List[Floats2d]: + # convert to xp and wrap in list + gradients = torch2xp(args.args[1]) + return [[gradients], None] + + word_features = xp2torch(tok2vec[0], requires_grad=is_train) sent_ids = xp2torch(sent_ids[0], requires_grad=False) if not head_ids[0].size: head_ids = torch.empty(size=(0,)) else: head_ids = xp2torch(head_ids[0], requires_grad=False) - word_features = xp2torch(tok2vec[0], requires_grad=is_train) argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={}) # TODO actually support backprop - return argskwargs, lambda dX: [[]] + return argskwargs, backprop # TODO This probably belongs in the component, not the model. @@ -247,7 +253,6 @@ def head_data_forward(model, docs, is_train): heads.append(span[0].i) heads = model.ops.asarray2i(heads) head_ids.append(heads) - # each of these is a list with one entry per doc # backprop is just a placeholder # TODO it would probably be better to have a list of tuples than two lists of arrays @@ -584,7 +589,6 @@ class SpanPredictor(torch.nn.Module): scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf')) scores[rows, cols] = res[padding_mask] - # Make sure that start <= head <= end during inference if not self.training: valid_starts = torch.log((relative_positions >= 0).to(torch.float))