handle empty head_ids

Kádár Ákos 2022-03-28 11:28:21 +02:00
parent 7304604edd
commit 4fc40340f9


@@ -133,6 +133,7 @@ def convert_coref_scorer_outputs(
     indices_xp = torch2xp(indices)
     return (scores_xp, indices_xp), convert_for_torch_backward
 
+
 def convert_span_predictor_inputs(
     model: Model,
     X: Tuple[Ints1d, Floats2d, Ints1d],
@@ -141,13 +142,17 @@ def convert_span_predictor_inputs(
     tok2vec, (sent_ids, head_ids) = X
     # Normally we should use the input is_train, but for these two it's not relevant
     sent_ids = xp2torch(sent_ids[0], requires_grad=False)
-    head_ids = xp2torch(head_ids[0], requires_grad=False)
+    if not head_ids[0].size:
+        head_ids = torch.empty(size=(0,))
+    else:
+        head_ids = xp2torch(head_ids[0], requires_grad=False)
     word_features = xp2torch(tok2vec[0], requires_grad=is_train)
     argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
     # TODO actually support backprop
     return argskwargs, lambda dX: [[]]
 
 
 # TODO This probably belongs in the component, not the model.
 def predict_span_clusters(span_predictor: Model,
                           sent_ids: Ints1d,
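
The guard added above only affects batches that arrive with no head indices: when head_ids[0] is an empty array, xp2torch is skipped and an empty tensor is created directly, so the downstream SpanPredictor can detect that case. A minimal sketch of the same pattern, using plain numpy/torch in place of thinc's xp2torch (the helper name to_head_ids_tensor is made up for illustration):

    import numpy
    import torch

    def to_head_ids_tensor(head_ids: numpy.ndarray) -> torch.Tensor:
        # An empty array has size 0, so skip conversion and hand back an
        # explicitly empty tensor that callers can check with nelement().
        if not head_ids.size:
            return torch.empty(size=(0,))
        return torch.from_numpy(head_ids)

    print(to_head_ids_tensor(numpy.array([], dtype="int64")).nelement())  # 0
    print(to_head_ids_tensor(numpy.array([2, 5], dtype="int64")).shape)   # torch.Size([2])
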
@@ -543,6 +548,9 @@ class SpanPredictor(torch.nn.Module):
         Returns:
             torch.Tensor: span start/end scores, [n_heads, n_words, 2]
         """
+        # If we don't receive heads, return empty
+        if heads_ids.nelement() == 0:
+            return torch.empty(size=(0,))
         # Obtain distance embedding indices, [n_heads, n_words]
         relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0))
         # make all valid distances positive
@@ -550,7 +558,6 @@ class SpanPredictor(torch.nn.Module):
         # "too_far"
         emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
         # Obtain "same sentence" boolean mask, [n_heads, n_words]
-        sent_id = torch.tensor(sent_id)
         heads_ids = heads_ids.long()
         same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
         # To save memory, only pass candidates from one sentence for each head
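
The forward-pass change mirrors the converter: if heads_ids has no elements, SpanPredictor.forward returns an empty tensor before any [n_heads, n_words] matrices are built. The last hunk also drops the sent_id = torch.tensor(sent_id) conversion, presumably because sent_id already arrives as a tensor from convert_span_predictor_inputs. A rough, self-contained sketch of the early-exit behaviour (function and argument names are hypothetical, not the actual model code):

    import torch

    def relative_positions_or_empty(heads_ids: torch.Tensor, n_words: int) -> torch.Tensor:
        # With no heads there is nothing to score, so bail out before
        # allocating any [n_heads, n_words] matrices.
        if heads_ids.nelement() == 0:
            return torch.empty(size=(0,))
        # Distance from every head to every word, as in the hunk above.
        return heads_ids.unsqueeze(1) - torch.arange(n_words).unsqueeze(0)

    print(relative_positions_or_empty(torch.tensor([], dtype=torch.long), 4).shape)  # torch.Size([0])
    print(relative_positions_or_empty(torch.tensor([1, 3]), 4).shape)                # torch.Size([2, 4])
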