Mirror of https://github.com/explosion/spaCy.git, synced 2025-07-18 20:22:25 +03:00
handle empty head_ids
This commit is contained in:
parent 7304604edd
commit 4fc40340f9
@@ -133,6 +133,7 @@ def convert_coref_scorer_outputs(
     indices_xp = torch2xp(indices)
     return (scores_xp, indices_xp), convert_for_torch_backward
 
+
 def convert_span_predictor_inputs(
     model: Model,
     X: Tuple[Ints1d, Floats2d, Ints1d],
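Both functions touched by this diff are Thinc conversion callbacks for a PyTorch module wrapped via Thinc's PyTorchWrapper. As a rough sketch of that contract (simplified and illustrative; convert_inputs_sketch is not the spaCy source, only ArgsKwargs and xp2torch are real Thinc API), a convert-inputs callback receives (model, X, is_train), packages the arrays into an ArgsKwargs for the torch module, and returns it together with a backprop callback:

# Hedged sketch of the Thinc convert_inputs contract; the function
# name and body are illustrative.
from thinc.api import ArgsKwargs, xp2torch

def convert_inputs_sketch(model, X, is_train: bool):
    # Package the input array as positional torch arguments.
    args = ArgsKwargs(args=(xp2torch(X, requires_grad=is_train),), kwargs={})
    # Pair them with a backprop callback; returning [[]] mirrors the
    # "TODO actually support backprop" placeholder in the diff below.
    return args, lambda dX: [[]]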
@@ -141,13 +142,17 @@ def convert_span_predictor_inputs(
     tok2vec, (sent_ids, head_ids) = X
     # Normally we should use the input is_train, but for these two it's not relevant
     sent_ids = xp2torch(sent_ids[0], requires_grad=False)
-    head_ids = xp2torch(head_ids[0], requires_grad=False)
+    if not head_ids[0].size:
+        head_ids = torch.empty(size=(0,))
+    else:
+        head_ids = xp2torch(head_ids[0], requires_grad=False)
     word_features = xp2torch(tok2vec[0], requires_grad=is_train)
 
     argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
     # TODO actually support backprop
     return argskwargs, lambda dX: [[]]
 
+
 # TODO This probably belongs in the component, not the model.
 def predict_span_clusters(span_predictor: Model,
                           sent_ids: Ints1d,
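The guard added above exists because head_ids[0] can be a zero-length array when a document has no candidate heads; the commit short-circuits that case with torch.empty(size=(0,)) instead of converting. A minimal, self-contained sketch of the pattern (names are illustrative; torch.from_numpy stands in for thinc's xp2torch):

import numpy
import torch

def to_torch_head_ids(head_ids: numpy.ndarray) -> torch.Tensor:
    # numpy's .size is 0 for an empty array, so `not head_ids.size`
    # is True exactly when there are no candidate heads.
    if not head_ids.size:
        # Hand back a well-formed empty tensor of shape (0,).
        return torch.empty(size=(0,))
    # Otherwise convert normally (xp2torch in the real code).
    return torch.from_numpy(head_ids)

print(to_torch_head_ids(numpy.zeros((0,), dtype="int64")).shape)  # torch.Size([0])
print(to_torch_head_ids(numpy.array([2, 5, 7])))                  # tensor([2, 5, 7])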
@@ -543,6 +548,9 @@ class SpanPredictor(torch.nn.Module):
         Returns:
             torch.Tensor: span start/end scores, [n_heads, n_words, 2]
         """
+        # If we don't receive heads, return empty
+        if heads_ids.nelement() == 0:
+            return torch.empty(size=(0,))
         # Obtain distance embedding indices, [n_heads, n_words]
         relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0))
         # make all valid distances positive
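The forward pass applies the same defensive pattern from the consumer side: torch's Tensor.nelement() counts entries, so the check returns an empty result before any distance-embedding work. A runnable sketch (simplified; the real SpanPredictor.forward takes more arguments and keeps computing past this point):

import torch

def forward_sketch(heads_ids: torch.Tensor, words: torch.Tensor) -> torch.Tensor:
    # If we don't receive heads, return empty before doing any work.
    if heads_ids.nelement() == 0:
        return torch.empty(size=(0,))
    # [n_heads, n_words]: relative position of every word to each head.
    return heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0)

print(forward_sketch(torch.empty(0), torch.zeros(4, 8)).shape)  # torch.Size([0])
print(forward_sketch(torch.tensor([1, 3]), torch.zeros(4, 8)))
# tensor([[ 1,  0, -1, -2],
#         [ 3,  2,  1,  0]])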
@@ -550,7 +558,6 @@ class SpanPredictor(torch.nn.Module):
         # "too_far"
         emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
         # Obtain "same sentence" boolean mask, [n_heads, n_words]
-        sent_id = torch.tensor(sent_id)
         heads_ids = heads_ids.long()
         same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
         # To save memory, only pass candidates from one sentence for each head
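A small worked example of the two operations in this hunk, assuming (as the removed line suggests) that sent_id already arrives as a torch tensor, which makes the torch.tensor(sent_id) re-wrap redundant. Out-of-range distances are bucketed into the "too_far" embedding index 127, and the same-sentence mask compares each head's sentence id against every word's (all values below are illustrative):

import torch

# Distance-embedding indices for 2 heads over 5 words.
emb_ids = torch.tensor([[3, -2, 130,  10,  0],
                        [1,  5,  -7, 200, 64]])
# Bucket every distance outside [0, 126] into the "too_far" index 127.
emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
print(emb_ids)
# tensor([[  3, 127, 127,  10,   0],
#         [  1,   5, 127, 127,  64]])

# Same-sentence mask: one sentence id per word; heads are word indices.
sent_id = torch.tensor([0, 0, 1, 1, 1])
heads_ids = torch.tensor([1, 3]).long()
same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
print(same_sent)
# tensor([[ True,  True, False, False, False],
#         [False, False,  True,  True,  True]])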