add backprop callback to spanpredictor

Kádár Ákos 2022-04-08 14:56:44 +02:00
parent 3ba913109d
commit 2a1ad4c5d2


@@ -104,13 +104,13 @@ def convert_coref_scorer_inputs(
    # just use the first
    # TODO real batching
    X = X[0]

    word_features = xp2torch(X, requires_grad=is_train)

    def backprop(args: ArgsKwargs) -> List[Floats2d]:
        # convert to xp and wrap in list
        gradients = torch2xp(args.args[0])
        return [gradients]

    return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
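
The converter above has the shape thinc's PyTorchWrapper expects from a convert_inputs callback: the forward direction packs xp arrays into an ArgsKwargs of torch tensors, and the backprop callback it returns receives the torch gradients in the same positions and converts them back with torch2xp. A minimal standalone sketch of that round trip, with hypothetical shapes and no spaCy code:

import numpy
import torch
from thinc.api import ArgsKwargs, xp2torch, torch2xp

# hypothetical input: one doc's word features (the converter uses X[0])
X = [numpy.zeros((5, 64), dtype="f")]
word_features = xp2torch(X[0], requires_grad=True)

def backprop(args: ArgsKwargs):
    # args.args[0] is the torch gradient for word_features; convert it back
    # to an xp array and wrap it in a list to mirror the model's input
    return [torch2xp(args.args[0])]

inputs = ArgsKwargs(args=(word_features,), kwargs={})
# simulate the gradient the wrapper would hand back after the torch backward pass
dX = backprop(ArgsKwargs(args=(torch.ones(5, 64),), kwargs={}))
assert dX[0].shape == (5, 64)
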
@@ -141,16 +141,22 @@ def convert_span_predictor_inputs(
 ):
     tok2vec, (sent_ids, head_ids) = X
     # Normally we should use the input is_train, but for these two it's not relevant
+
+    def backprop(args: ArgsKwargs) -> List[Floats2d]:
+        # convert to xp and wrap in list
+        gradients = torch2xp(args.args[1])
+        return [[gradients], None]
+
+    word_features = xp2torch(tok2vec[0], requires_grad=is_train)
     sent_ids = xp2torch(sent_ids[0], requires_grad=False)
     if not head_ids[0].size:
         head_ids = torch.empty(size=(0,))
     else:
         head_ids = xp2torch(head_ids[0], requires_grad=False)
-    word_features = xp2torch(tok2vec[0], requires_grad=is_train)
     argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
-    # TODO actually support backprop
-    return argskwargs, lambda dX: [[]]
+    return argskwargs, backprop

 # TODO This probably belongs in the component, not the model.
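
The new backprop mirrors the converter's input X = (tok2vec, (sent_ids, head_ids)): the forward ArgsKwargs is (sent_ids, word_features, head_ids), so args.args[1] is the only differentiable tensor, and the result is wrapped as [[gradients], None] so the tok2vec list gets its gradient while the id arrays get none. A small sketch of that index/structure alignment, using made-up shapes:

import torch
from thinc.api import ArgsKwargs, torch2xp

# pretend gradients handed back in the same order as the forward args
# (sent_ids, word_features, head_ids); only index 1 carries a real gradient
d_args = ArgsKwargs(args=(None, torch.ones(5, 64), None), kwargs={})
gradients = torch2xp(d_args.args[1])
dX = [[gradients], None]  # same nesting as X = (tok2vec, (sent_ids, head_ids))
assert dX[0][0].shape == (5, 64)
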
@@ -247,7 +253,6 @@ def head_data_forward(model, docs, is_train):
                heads.append(span[0].i)
        heads = model.ops.asarray2i(heads)
        head_ids.append(heads)

    # each of these is a list with one entry per doc
    # backprop is just a placeholder
    # TODO it would probably be better to have a list of tuples than two lists of arrays
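
As the comments note, this layer returns two parallel lists with one integer array per doc (sentence ids and head token ids) plus a placeholder backprop. A tiny sketch of that output structure, with made-up values and a stand-in for model.ops.asarray2i:

import numpy

def asarray2i(values):  # stand-in for model.ops.asarray2i
    return numpy.asarray(values, dtype="int32")

# one entry per doc in each list
sent_ids = [asarray2i([0, 0, 0, 1, 1]), asarray2i([0, 0, 1])]
head_ids = [asarray2i([1, 4]), asarray2i([2])]
Y = (sent_ids, head_ids)  # returned alongside a placeholder backprop callback
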
@@ -584,7 +589,6 @@ class SpanPredictor(torch.nn.Module):
        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'))
        scores[rows, cols] = res[padding_mask]
        # Make sure that start <= head <= end during inference
        if not self.training:
            valid_starts = torch.log((relative_positions >= 0).to(torch.float))
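
The -inf fill together with torch.log of a boolean mask is the usual way to constrain a later softmax: log(1) = 0 leaves allowed scores untouched, while log(0) = -inf drives disallowed positions to zero probability. A standalone sketch of the idiom (not the SpanPredictor code itself):

import torch

scores = torch.tensor([[1.0, 2.0, 3.0]])
relative_positions = torch.tensor([[-1, 0, 2]])  # made-up offsets from the head token
valid_starts = torch.log((relative_positions >= 0).to(torch.float))  # [-inf, 0., 0.]
masked = scores + valid_starts  # [-inf, 2.0, 3.0]
probs = masked.softmax(dim=-1)  # the first position gets probability 0.0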