mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
span predictor device fix
This commit is contained in:
parent
1a782592c4
commit
0076f0f617
|
@@ -182,11 +182,12 @@ class SpanPredictor(torch.nn.Module):
|
||||||
torch.Tensor: span start/end scores, (n_heads x n_words x 2)
|
torch.Tensor: span start/end scores, (n_heads x n_words x 2)
|
||||||
"""
|
"""
|
||||||
# If we don't receive heads, return empty
|
# If we don't receive heads, return empty
|
||||||
|
device = heads_ids.device
|
||||||
if heads_ids.nelement() == 0:
|
if heads_ids.nelement() == 0:
|
||||||
return torch.empty(size=(0,))
|
return torch.empty(size=(0,))
|
||||||
# Obtain distance embedding indices, [n_heads, n_words]
|
# Obtain distance embedding indices, [n_heads, n_words]
|
||||||
relative_positions = heads_ids.unsqueeze(1) - torch.arange(
|
relative_positions = heads_ids.unsqueeze(1) - torch.arange(
|
||||||
words.shape[0]
|
words.shape[0], device=device
|
||||||
).unsqueeze(0)
|
).unsqueeze(0)
|
||||||
md = self.max_distance
|
md = self.max_distance
|
||||||
# make all valid distances positive
|
# make all valid distances positive
|
||||||
|
@@ -210,20 +211,26 @@ class SpanPredictor(torch.nn.Module):
|
||||||
dim=1,
|
dim=1,
|
||||||
)
|
)
|
||||||
lengths = same_sent.sum(dim=1)
|
lengths = same_sent.sum(dim=1)
|
||||||
padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
|
padding_mask = torch.arange(
|
||||||
|
0, lengths.max().item(), device=device
|
||||||
|
).unsqueeze(0)
|
||||||
# (n_heads x max_sent_len)
|
# (n_heads x max_sent_len)
|
||||||
padding_mask = padding_mask < lengths.unsqueeze(1)
|
padding_mask = padding_mask < lengths.unsqueeze(1)
|
||||||
# (n_heads x max_sent_len x input_size * 2 + distance_emb_size)
|
# (n_heads x max_sent_len x input_size * 2 + distance_emb_size)
|
||||||
# This is necessary to allow the convolution layer to look at several
|
# This is necessary to allow the convolution layer to look at several
|
||||||
# word scores
|
# word scores
|
||||||
padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
|
padded_pairs = torch.zeros(
|
||||||
|
*padding_mask.shape, pair_matrix.shape[-1], device=device
|
||||||
|
)
|
||||||
padded_pairs[padding_mask] = pair_matrix
|
padded_pairs[padding_mask] = pair_matrix
|
||||||
res = self.ffnn(padded_pairs) # (n_heads x n_candidates x last_layer_output)
|
res = self.ffnn(padded_pairs) # (n_heads x n_candidates x last_layer_output)
|
||||||
res = self.conv(res.permute(0, 2, 1)).permute(
|
res = self.conv(res.permute(0, 2, 1)).permute(
|
||||||
0, 2, 1
|
0, 2, 1
|
||||||
) # (n_heads x n_candidates, 2)
|
) # (n_heads x n_candidates, 2)
|
||||||
|
|
||||||
scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf"))
|
scores = torch.full(
|
||||||
|
(heads_ids.shape[0], words.shape[0], 2), float("-inf"), device=device
|
||||||
|
)
|
||||||
scores[rows, cols] = res[padding_mask]
|
scores[rows, cols] = res[padding_mask]
|
||||||
# Make sure that start <= head <= end during inference
|
# Make sure that start <= head <= end during inference
|
||||||
if not self.training:
|
if not self.training:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user