span predictor device fix

This commit is contained in:
kadarakos 2022-06-29 06:58:47 +00:00
parent 1a782592c4
commit 0076f0f617

View File

@@ -182,11 +182,12 @@ class SpanPredictor(torch.nn.Module):
torch.Tensor: span start/end scores, (n_heads x n_words x 2)
"""
# If we don't receive heads, return empty
device = heads_ids.device
if heads_ids.nelement() == 0:
return torch.empty(size=(0,))
# Obtain distance embedding indices, [n_heads, n_words]
relative_positions = heads_ids.unsqueeze(1) - torch.arange(
words.shape[0]
words.shape[0], device=device
).unsqueeze(0)
md = self.max_distance
# make all valid distances positive
@@ -210,20 +211,26 @@ class SpanPredictor(torch.nn.Module):
dim=1,
)
lengths = same_sent.sum(dim=1)
padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
padding_mask = torch.arange(
0, lengths.max().item(), device=device
).unsqueeze(0)
# (n_heads x max_sent_len)
padding_mask = padding_mask < lengths.unsqueeze(1)
# (n_heads x max_sent_len x input_size * 2 + distance_emb_size)
# This is necessary to allow the convolution layer to look at several
# word scores
padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
padded_pairs = torch.zeros(
*padding_mask.shape, pair_matrix.shape[-1], device=device
)
padded_pairs[padding_mask] = pair_matrix
res = self.ffnn(padded_pairs) # (n_heads x n_candidates x last_layer_output)
res = self.conv(res.permute(0, 2, 1)).permute(
0, 2, 1
) # (n_heads x n_candidates, 2)
scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf"))
scores = torch.full(
(heads_ids.shape[0], words.shape[0], 2), float("-inf"), device=device
)
scores[rows, cols] = res[padding_mask]
# Make sure that start <= head <= end during inference
if not self.training: