From 0076f0f617d73e8c1a4415a375c4e0eff9e7103d Mon Sep 17 00:00:00 2001 From: kadarakos Date: Wed, 29 Jun 2022 06:58:47 +0000 Subject: [PATCH] span predictor device fix --- spacy/ml/models/span_predictor.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/spacy/ml/models/span_predictor.py b/spacy/ml/models/span_predictor.py index 378b79e9b..d44e632bd 100644 --- a/spacy/ml/models/span_predictor.py +++ b/spacy/ml/models/span_predictor.py @@ -182,11 +182,12 @@ class SpanPredictor(torch.nn.Module): torch.Tensor: span start/end scores, (n_heads x n_words x 2) """ # If we don't receive heads, return empty + device = heads_ids.device if heads_ids.nelement() == 0: return torch.empty(size=(0,)) # Obtain distance embedding indices, [n_heads, n_words] relative_positions = heads_ids.unsqueeze(1) - torch.arange( - words.shape[0] + words.shape[0], device=device ).unsqueeze(0) md = self.max_distance # make all valid distances positive @@ -210,20 +211,26 @@ class SpanPredictor(torch.nn.Module): dim=1, ) lengths = same_sent.sum(dim=1) - padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0) + padding_mask = torch.arange( + 0, lengths.max().item(), device=device + ).unsqueeze(0) # (n_heads x max_sent_len) padding_mask = padding_mask < lengths.unsqueeze(1) # (n_heads x max_sent_len x input_size * 2 + distance_emb_size) # This is necessary to allow the convolution layer to look at several # word scores - padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1]) + padded_pairs = torch.zeros( + *padding_mask.shape, pair_matrix.shape[-1], device=device + ) padded_pairs[padding_mask] = pair_matrix res = self.ffnn(padded_pairs) # (n_heads x n_candidates x last_layer_output) res = self.conv(res.permute(0, 2, 1)).permute( 0, 2, 1 ) # (n_heads x n_candidates, 2) - scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf")) + scores = torch.full( + (heads_ids.shape[0], words.shape[0], 2), float("-inf"), device=device + ) scores[rows, cols] = res[padding_mask] # Make sure that start <= head <= end during inference if not self.training: