mirror of https://github.com/explosion/spaCy.git (synced 2025-07-18 12:12:20 +03:00)
handle empty head_ids
This commit is contained in:
parent
7304604edd
commit
4fc40340f9
@@ -133,6 +133,7 @@ def convert_coref_scorer_outputs(
     indices_xp = torch2xp(indices)
     return (scores_xp, indices_xp), convert_for_torch_backward


 def convert_span_predictor_inputs(
     model: Model,
     X: Tuple[Ints1d, Floats2d, Ints1d],
@@ -141,13 +142,17 @@ def convert_span_predictor_inputs(
     tok2vec, (sent_ids, head_ids) = X
     # Normally we should use the input is_train, but for these two it's not relevant
     sent_ids = xp2torch(sent_ids[0], requires_grad=False)
-    head_ids = xp2torch(head_ids[0], requires_grad=False)
+    if not head_ids[0].size:
+        head_ids = torch.empty(size=(0,))
+    else:
+        head_ids = xp2torch(head_ids[0], requires_grad=False)
     word_features = xp2torch(tok2vec[0], requires_grad=is_train)

     argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
     # TODO actually support backprop
     return argskwargs, lambda dX: [[]]


 # TODO This probably belongs in the component, not the model.
 def predict_span_clusters(span_predictor: Model,
                           sent_ids: Ints1d,
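The guard relies on numpy semantics: a zero-length array has .size == 0, which is falsy, so the conversion only calls xp2torch on non-empty input. A minimal standalone sketch of the same pattern, using plain numpy/torch instead of thinc's xp2torch (the helper name to_head_ids_tensor is made up for illustration):

import numpy
import torch

def to_head_ids_tensor(head_ids: numpy.ndarray) -> torch.Tensor:
    # Same guard as the diff: .size is 0 for an empty array, so we
    # return an explicitly empty tensor instead of converting it.
    if not head_ids.size:
        return torch.empty(size=(0,))
    return torch.from_numpy(head_ids)

print(to_head_ids_tensor(numpy.array([], dtype="int64")).nelement())  # 0
print(to_head_ids_tensor(numpy.array([2, 5], dtype="int64")))         # tensor([2, 5])

Downstream code can then detect the no-heads case cheaply with tensor.nelement() == 0, which is exactly what the SpanPredictor hunk below checks.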
@@ -543,6 +548,9 @@ class SpanPredictor(torch.nn.Module):
         Returns:
             torch.Tensor: span start/end scores, [n_heads, n_words, 2]
         """
+        # If we don't receive heads, return empty
+        if heads_ids.nelement() == 0:
+            return torch.empty(size=(0,))
         # Obtain distance embedding indices, [n_heads, n_words]
         relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0))
         # make all valid distances positive
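For intuition about what the early return protects: the line right after it broadcasts every head against every word position, which only makes sense with at least one head. A toy-sized sketch with made-up values (4 words, heads at positions 1 and 3):

import torch

heads_ids = torch.tensor([1, 3])
# Broadcast [n_heads, 1] against [1, n_words] to get [n_heads, n_words].
relative_positions = heads_ids.unsqueeze(1) - torch.arange(4).unsqueeze(0)
print(relative_positions)
# tensor([[ 1,  0, -1, -2],
#         [ 3,  2,  1,  0]])

# The new guard fires before this when there are no heads:
print(torch.empty(size=(0,)).nelement() == 0)  # True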
@@ -550,7 +558,6 @@ class SpanPredictor(torch.nn.Module):
         # "too_far"
         emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
         # Obtain "same sentence" boolean mask, [n_heads, n_words]
         sent_id = torch.tensor(sent_id)
-        heads_ids = heads_ids.long()
         same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
         # To save memory, only pass candidates from one sentence for each head
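Two details in this hunk are easy to misread: + between boolean masks acts as elementwise OR in PyTorch, and the "same sentence" mask is another broadcast comparison. A small self-contained sketch with made-up sentence ids:

import torch

# `+` on boolean masks is elementwise OR: clamp out-of-range ids to 127.
emb_ids = torch.tensor([[-5, 0, 130]])
emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
print(emb_ids)  # tensor([[127,   0, 127]])

# Same-sentence mask, [n_heads, n_words]: True where a head and a
# candidate word share a sentence id.
sent_id = torch.tensor([0, 0, 1, 1])  # sentence id per word
heads_ids = torch.tensor([1, 3])      # head word indices
same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
print(same_sent)
# tensor([[ True,  True, False, False],
#         [False, False,  True,  True]])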