Mirror of https://github.com/explosion/spaCy.git (synced 2025-07-18 20:22:25 +03:00)
add backprop callback to spanpredictor
This commit is contained in:
parent 3ba913109d
commit 2a1ad4c5d2
@@ -104,13 +104,13 @@ def convert_coref_scorer_inputs(
     # just use the first
     # TODO real batching
     X = X[0]

     word_features = xp2torch(X, requires_grad=is_train)

     def backprop(args: ArgsKwargs) -> List[Floats2d]:
         # convert to xp and wrap in list
         gradients = torch2xp(args.args[0])
         return [gradients]

     return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
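For context, here is a minimal sketch of how a converter like the one above attaches to Thinc's PyTorchWrapper. DemoModule and its width are hypothetical stand-ins, not spaCy's coref scorer; only the converter signature (model, X, is_train) -> (ArgsKwargs, backprop) follows Thinc's actual convention.

from typing import List

import torch
from thinc.api import ArgsKwargs, PyTorchWrapper, torch2xp, xp2torch
from thinc.types import Floats2d


class DemoModule(torch.nn.Module):
    # Hypothetical stand-in for the coref scorer network.
    def __init__(self, width: int):
        super().__init__()
        self.linear = torch.nn.Linear(width, width)

    def forward(self, word_features):
        return self.linear(word_features)


def convert_inputs(model, X: List[Floats2d], is_train: bool):
    # Mirror the diff: take the first doc's features (no real batching yet).
    word_features = xp2torch(X[0], requires_grad=is_train)

    def backprop(args: ArgsKwargs) -> List[Floats2d]:
        # Thinc hands the torch gradients back as an ArgsKwargs; convert the
        # gradient for the first positional input to xp and wrap it in a
        # list so it matches the type of the original input.
        return [torch2xp(args.args[0])]

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop


model = PyTorchWrapper(DemoModule(64), convert_inputs=convert_inputs)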
@@ -141,16 +141,22 @@ def convert_span_predictor_inputs(
 ):
     tok2vec, (sent_ids, head_ids) = X
     # Normally we should use the input is_train, but for these two it's not relevant

+    def backprop(args: ArgsKwargs) -> List[Floats2d]:
+        # convert to xp and wrap in list
+        gradients = torch2xp(args.args[1])
+        return [[gradients], None]
+
+    word_features = xp2torch(tok2vec[0], requires_grad=is_train)
     sent_ids = xp2torch(sent_ids[0], requires_grad=False)
     if not head_ids[0].size:
         head_ids = torch.empty(size=(0,))
     else:
         head_ids = xp2torch(head_ids[0], requires_grad=False)
-    word_features = xp2torch(tok2vec[0], requires_grad=is_train)

     argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
     # TODO actually support backprop
-    return argskwargs, lambda dX: [[]]
+    return argskwargs, backprop


 # TODO This probably belongs in the component, not the model.
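The new callback's return value mirrors the converter's input nesting. Below is a standalone sketch of that gradient hand-off using toy arrays rather than the real SpanPredictor inputs: only word_features is created with requires_grad, so it is the only input that receives a gradient, which is why args.args[1] (the second positional argument) is the one converted back.

import numpy
import torch
from thinc.api import torch2xp, xp2torch

word_features = xp2torch(numpy.ones((5, 4), dtype="f"), requires_grad=True)
sent_ids = xp2torch(numpy.zeros((5,), dtype="int64"), requires_grad=False)

loss = word_features.sum()  # stand-in for the span predictor's loss
loss.backward()

# In the converter, Thinc delivers this gradient via args.args[1], matching
# ArgsKwargs(args=(sent_ids, word_features, head_ids), ...); the integer ID
# tensors get no gradient, hence the None.
gradients = torch2xp(word_features.grad)
dX = [[gradients], None]  # same nesting as (tok2vec, (sent_ids, head_ids))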
@@ -247,7 +253,6 @@ def head_data_forward(model, docs, is_train):
                 heads.append(span[0].i)
         heads = model.ops.asarray2i(heads)
         head_ids.append(heads)
-
     # each of these is a list with one entry per doc
     # backprop is just a placeholder
     # TODO it would probably be better to have a list of tuples than two lists of arrays
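head_data_forward builds one integer array per doc through the model's Ops. A quick illustration of that call, assuming NumpyOps purely for a standalone demo; inside the pipeline, model.ops is whatever backend spaCy selected (numpy or cupy), which keeps the arrays on the right device.

from thinc.api import NumpyOps

ops = NumpyOps()
heads = ops.asarray2i([0, 3, 7])  # head token indices as an int32 array
head_ids = [heads]                # one entry per doc, as in head_data_forward
print(head_ids[0].dtype)          # int32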
@@ -584,7 +589,6 @@ class SpanPredictor(torch.nn.Module):
-
         scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'))
         scores[rows, cols] = res[padding_mask]

         # Make sure that start <= head <= end during inference
         if not self.training:
             valid_starts = torch.log((relative_positions >= 0).to(torch.float))
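The valid_starts line uses a common masking idiom: the log of a boolean mask is 0 where the condition holds and -inf elsewhere, so adding it to the scores rules invalid start positions out of any subsequent softmax or argmax. A toy demonstration, with illustrative shapes rather than the SpanPredictor's:

import torch

scores = torch.tensor([[1.0, 5.0, 3.0]])
relative_positions = torch.tensor([[-1, 0, 2]])

# log over a {0., 1.} mask yields {-inf, 0.}
valid_starts = torch.log((relative_positions >= 0).to(torch.float))
# valid_starts == tensor([[-inf, 0., 0.]])
masked = scores + valid_starts  # tensor([[-inf, 5., 3.]])
print(masked.argmax(dim=1))     # tensor([1]); position 0 can never be chosen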