Mirror of https://github.com/explosion/spaCy.git (synced 2025-07-18 20:22:25 +03:00)
add backprop callback to spanpredictor
commit 2a1ad4c5d2 (parent 3ba913109d)
@@ -104,13 +104,13 @@ def convert_coref_scorer_inputs(
    # just use the first
    # TODO real batching
    X = X[0]

    word_features = xp2torch(X, requires_grad=is_train)

    def backprop(args: ArgsKwargs) -> List[Floats2d]:
        # convert to xp and wrap in list
        gradients = torch2xp(args.args[0])
        return [gradients]

    return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
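For orientation, converters like convert_coref_scorer_inputs are the convert_inputs hooks of Thinc's PyTorch wrapper: they turn the component's arrays into torch tensors for the forward pass and hand back a callback that maps torch gradients onto the original input structure. The sketch below shows that wiring under stated assumptions; TinyScorer and the wrapping call are illustrative stand-ins, not spaCy's actual model factory.

# Illustrative sketch only: a toy torch module wrapped with a custom
# convert_inputs, mirroring the pattern in the hunk above.
import torch
from thinc.api import ArgsKwargs, PyTorchWrapper_v2, torch2xp, xp2torch


class TinyScorer(torch.nn.Module):  # stand-in for the real coref scorer
    def __init__(self, width: int):
        super().__init__()
        self.linear = torch.nn.Linear(width, 1)

    def forward(self, word_features):
        return self.linear(word_features)


def convert_inputs(model, X, is_train):
    # X is assumed to be a list of 2d arrays, one per doc; only the first
    # is used, matching the "just use the first" note above.
    word_features = xp2torch(X[0], requires_grad=is_train)

    def backprop(args: ArgsKwargs):
        # torch gradient back to xp, wrapped to match the structure of X
        return [torch2xp(args.args[0])]

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop


wrapped = PyTorchWrapper_v2(TinyScorer(64), convert_inputs=convert_inputs)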
@@ -141,16 +141,22 @@ def convert_span_predictor_inputs(
 ):
     tok2vec, (sent_ids, head_ids) = X
     # Normally we should use the input is_train, but for these two it's not relevant
+
+    def backprop(args: ArgsKwargs) -> List[Floats2d]:
+        # convert to xp and wrap in list
+        gradients = torch2xp(args.args[1])
+        return [[gradients], None]
+
+    word_features = xp2torch(tok2vec[0], requires_grad=is_train)
     sent_ids = xp2torch(sent_ids[0], requires_grad=False)
     if not head_ids[0].size:
         head_ids = torch.empty(size=(0,))
     else:
         head_ids = xp2torch(head_ids[0], requires_grad=False)
-    word_features = xp2torch(tok2vec[0], requires_grad=is_train)
 
     argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
-    # TODO actually support backprop
-    return argskwargs, lambda dX: [[]]
+    return argskwargs, backprop
 
 
 # TODO This probably belongs in the component, not the model.
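The new backprop callback mirrors the shape of the converter's input X = (tok2vec, (sent_ids, head_ids)): only word_features (positional argument 1 of the ArgsKwargs) was created with requires_grad, so it is the only gradient converted back, and the integer id arrays come back as None. A minimal, self-contained illustration of that convention, using dummy tensors in place of the gradients Thinc would pass:

# Toy illustration of the gradient-unpacking convention; the dummy tensors
# stand in for the gradients the PyTorch shim would hand back.
import torch
from thinc.api import ArgsKwargs, torch2xp


def backprop(args: ArgsKwargs):
    # args.args mirrors the forward args (sent_ids, word_features, head_ids);
    # only index 1 was built with requires_grad=True, so it is the only
    # meaningful gradient.
    gradients = torch2xp(args.args[1])
    # Match X = (tok2vec, (sent_ids, head_ids)): a one-element list for the
    # tok2vec gradient, None for the integer id arrays.
    return [[gradients], None]


dummy = ArgsKwargs(args=(torch.zeros(4), torch.zeros(4, 8), torch.zeros(2)), kwargs={})
d_tok2vec, d_ids = backprop(dummy)
print(type(d_tok2vec[0]), d_ids)  # <class 'numpy.ndarray'> None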
@@ -247,7 +253,6 @@ def head_data_forward(model, docs, is_train):
            heads.append(span[0].i)
        heads = model.ops.asarray2i(heads)
        head_ids.append(heads)

    # each of these is a list with one entry per doc
    # backprop is just a placeholder
    # TODO it would probably be better to have a list of tuples than two lists of arrays
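The surrounding function gathers one integer array of head token indices per Doc, using the first token of each annotated span as its head. A standalone sketch of that collection pattern; the span key name used here is only an assumption for illustration:

# Hedged sketch of per-doc head-id collection; "coref_head_clusters" is an
# assumed span key, not necessarily the one the component uses.
import numpy
from spacy.tokens import Doc
from spacy.vocab import Vocab

doc = Doc(Vocab(), words=["Sarah", "called", "her", "mother", "yesterday"])
doc.spans["coref_head_clusters"] = [doc[0:1], doc[2:4]]

head_ids = []
for d in [doc]:
    heads = [span[0].i for span in d.spans["coref_head_clusters"]]
    # one int array per doc, like model.ops.asarray2i(heads) above
    head_ids.append(numpy.asarray(heads, dtype="int32"))

print(head_ids)  # [array([0, 2], dtype=int32)]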
@@ -584,7 +589,6 @@ class SpanPredictor(torch.nn.Module):

        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'))
        scores[rows, cols] = res[padding_mask]

        # Make sure that start <= head <= end during inference
        if not self.training:
            valid_starts = torch.log((relative_positions >= 0).to(torch.float))
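The valid_starts line uses a standard masking trick: taking the log of a {0, 1} mask gives {-inf, 0}, so adding it to the scores rules out start positions to the right of the head without changing valid ones. A small standalone demonstration with invented values:

# Demonstration of log-of-mask masking; numbers are invented.
import torch

relative_positions = torch.tensor([[-2., -1., 0., 1., 2.]])
start_scores = torch.tensor([[0.3, 1.2, 0.5, -0.1, 0.9]])

valid_starts = torch.log((relative_positions >= 0).to(torch.float))
print(valid_starts)                # tensor([[-inf, -inf, 0., 0., 0.]])
print(start_scores + valid_starts) # masked positions become -inf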