mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
span predictor device fix
This commit is contained in:
parent
1a782592c4
commit
0076f0f617
|
@ -182,11 +182,12 @@ class SpanPredictor(torch.nn.Module):
|
|||
torch.Tensor: span start/end scores, (n_heads x n_words x 2)
|
||||
"""
|
||||
# If we don't receive heads, return empty
|
||||
device = heads_ids.device
|
||||
if heads_ids.nelement() == 0:
|
||||
return torch.empty(size=(0,))
|
||||
# Obtain distance embedding indices, [n_heads, n_words]
|
||||
relative_positions = heads_ids.unsqueeze(1) - torch.arange(
|
||||
words.shape[0]
|
||||
words.shape[0], device=device
|
||||
).unsqueeze(0)
|
||||
md = self.max_distance
|
||||
# make all valid distances positive
|
||||
|
@ -210,20 +211,26 @@ class SpanPredictor(torch.nn.Module):
|
|||
dim=1,
|
||||
)
|
||||
lengths = same_sent.sum(dim=1)
|
||||
padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
|
||||
padding_mask = torch.arange(
|
||||
0, lengths.max().item(), device=device
|
||||
).unsqueeze(0)
|
||||
# (n_heads x max_sent_len)
|
||||
padding_mask = padding_mask < lengths.unsqueeze(1)
|
||||
# (n_heads x max_sent_len x input_size * 2 + distance_emb_size)
|
||||
# This is necessary to allow the convolution layer to look at several
|
||||
# word scores
|
||||
padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
|
||||
padded_pairs = torch.zeros(
|
||||
*padding_mask.shape, pair_matrix.shape[-1], device=device
|
||||
)
|
||||
padded_pairs[padding_mask] = pair_matrix
|
||||
res = self.ffnn(padded_pairs) # (n_heads x n_candidates x last_layer_output)
|
||||
res = self.conv(res.permute(0, 2, 1)).permute(
|
||||
0, 2, 1
|
||||
) # (n_heads x n_candidates, 2)
|
||||
|
||||
scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf"))
|
||||
scores = torch.full(
|
||||
(heads_ids.shape[0], words.shape[0], 2), float("-inf"), device=device
|
||||
)
|
||||
scores[rows, cols] = res[padding_mask]
|
||||
# Make sure that start <= head <= end during inference
|
||||
if not self.training:
|
||||
|
|
Loading…
Reference in New Issue
Block a user