From 1a782592c4dde437e240fed095ed472c029927de Mon Sep 17 00:00:00 2001
From: kadarakos
Date: Tue, 28 Jun 2022 12:53:20 +0000
Subject: [PATCH] make sure same device

---
 spacy/ml/models/coref.py | 28 ++++++++++++++++++----------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py
index a8c880a39..660ef68c5 100644
--- a/spacy/ml/models/coref.py
+++ b/spacy/ml/models/coref.py
@@ -2,13 +2,16 @@ from typing import List, Tuple
 
 from thinc.api import Model, chain
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d, Ints2d, Ints1d
+from thinc.types import Floats2d
 from thinc.util import torch, xp2torch, torch2xp
 
 from ...tokens import Doc
 from ...util import registry
 
 
+EPSILON = 1e-7
+
+
 @registry.architectures("spacy.Coref.v1")
 def build_wl_coref_model(
     tok2vec: Model[List[Doc], List[Floats2d]],
@@ -42,7 +45,9 @@ def build_wl_coref_model(
     return coref_model
 
 
-def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
+def convert_coref_clusterer_inputs(
+    model: Model, X: List[Floats2d], is_train: bool
+):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
@@ -50,7 +55,7 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bo
     word_features = xp2torch(X, requires_grad=is_train)
 
     # TODO fix or remove type annotations
-    def backprop(args: ArgsKwargs): #-> List[Floats2d]:
+    def backprop(args: ArgsKwargs):  # -> List[Floats2d]:
         # convert to xp and wrap in list
         gradients = torch2xp(args.args[0])
         return [gradients]
@@ -58,7 +63,9 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bo
     return ArgsKwargs(args=(word_features,), kwargs={}), backprop
 
 
-def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
+def convert_coref_clusterer_outputs(
+    model: Model, inputs_outputs, is_train: bool
+):
     _, outputs = inputs_outputs
     scores, indices = outputs
 
@@ -149,10 +156,10 @@ class CorefClusterer(torch.nn.Module):
         a_scores_lst: List[torch.Tensor] = []
 
         for i in range(0, len(words), batch_size):
-            pw_batch = pw[i : i + batch_size]
-            words_batch = words[i : i + batch_size]
-            top_indices_batch = top_indices[i : i + batch_size]
-            top_rough_scores_batch = top_rough_scores[i : i + batch_size]
+            pw_batch = pw[i:i + batch_size]
+            words_batch = words[i:i + batch_size]
+            top_indices_batch = top_indices[i:i + batch_size]
+            top_rough_scores_batch = top_rough_scores[i:i + batch_size]
 
             # a_scores_batch [batch_size, n_ants]
             a_scores_batch = self.a_scorer(
@@ -168,7 +175,6 @@ class CorefClusterer(torch.nn.Module):
         return coref_scores, top_indices
 
 
-EPSILON = 1e-7
 # Note this function is kept here to keep a torch dep out of coref_util.
 def add_dummy(tensor: torch.Tensor, eps: bool = False):
     """Prepends zeros (or a very small value if eps is True)
@@ -294,7 +300,7 @@ class RoughScorer(torch.nn.Module):
         self.k = antecedent_limit
 
     def forward(
-        self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
+        self,  # type: ignore
         mentions: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -305,6 +311,7 @@ class RoughScorer(torch.nn.Module):
         pair_mask = torch.arange(mentions.shape[0])
         pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
         pair_mask = torch.log((pair_mask > 0).to(torch.float))
+        pair_mask = pair_mask.to(mentions.device)
         bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
         rough_scores = pair_mask + bilinear_scores
         top_scores, indices = torch.topk(
@@ -340,5 +347,6 @@ class DistancePairwiseEncoder(torch.nn.Module):
         log_distance = distance.to(torch.float).log2().floor_()
         log_distance = log_distance.clamp_max_(max=6).to(torch.long)
         distance = torch.where(distance < 5, distance - 1, log_distance + 2)
+        distance = distance.to(top_indices.device)
         distance = self.distance_emb(distance)
         return self.dropout(distance)
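
Why the two added .to(...) calls matter: torch.arange (used to build pair_mask in RoughScorer) and the distance buckets in DistancePairwiseEncoder.forward are allocated on the CPU by default, so on a GPU run they end up on a different device than the mention tensors they are combined with and PyTorch raises a "same device" RuntimeError. A minimal sketch of the failure mode and the fix, outside the patch and with illustrative tensor names (falls back to CPU when no CUDA device is available):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    mentions = torch.randn(4, 8, device=device)  # stand-in for the mention embeddings

    # torch.arange defaults to the CPU, so mixing the resulting mask with a
    # CUDA tensor would raise "Expected all tensors to be on the same device".
    pair_mask = torch.arange(mentions.shape[0])
    pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
    pair_mask = torch.log((pair_mask > 0).to(torch.float))

    # The patch moves the mask onto the input's device before combining:
    pair_mask = pair_mask.to(mentions.device)
    rough_scores = pair_mask + mentions.mm(mentions.T)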