From db422abf011fb9b0dabde5e22b9d7fa0b05424b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A1d=C3=A1r=20=C3=81kos?= Date: Fri, 18 Mar 2022 16:24:26 +0100 Subject: [PATCH] remove unnecessary .device --- spacy/ml/models/coref.py | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index f40a4c110..fea4bc21a 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -38,14 +38,11 @@ def build_wl_coref_model( except ValueError: # happens with transformer listener dim = 768 - + with Model.define_operators({">>": chain}): # TODO chain tok2vec with these models - # TODO fix device - should be automatic - device = "cuda:0" coref_scorer = PyTorchWrapper( CorefScorer( - device, dim, embedding_size, hidden_size, @@ -65,7 +62,6 @@ def build_wl_coref_model( # TODO this was hardcoded to 1024, check hidden_size, sp_embedding_size, - device ), convert_inputs=convert_span_predictor_inputs @@ -205,7 +201,6 @@ class CorefScorer(torch.nn.Module): """ def __init__( self, - device: str, dim: int, # tok2vec size dist_emb_size: int, hidden_size: int, @@ -222,8 +217,7 @@ class CorefScorer(torch.nn.Module): epochs_trained (int): the number of epochs finished (useful for warm start) """ - # device, dist_emb_size, hidden_size, n_layers, dropout_rate - self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate).to(device) + self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate) #TODO clean this up bert_emb = dim pair_emb = bert_emb * 3 + self.pw.shape @@ -232,7 +226,7 @@ class CorefScorer(torch.nn.Module): hidden_size, n_layers, dropout_rate - ).to(device) + ) self.lstm = torch.nn.LSTM( input_size=bert_emb, hidden_size=bert_emb, @@ -243,7 +237,7 @@ class CorefScorer(torch.nn.Module): bert_emb, dropout_rate, roughk - ).to(device) + ) self.batch_size = batch_size def forward( @@ -392,7 +386,6 @@ class AnaphoricityScorer(torch.nn.Module): return out - class RoughScorer(torch.nn.Module): """ Is needed to give a roughly estimate of the anaphoricity of two candidates, @@ -423,7 +416,6 @@ class RoughScorer(torch.nn.Module): pair_mask = torch.arange(mentions.shape[0]) pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0) pair_mask = torch.log((pair_mask > 0).to(torch.float)) - pair_mask = pair_mask.to(mentions.device) bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T) rough_scores = pair_mask + bilinear_scores @@ -450,7 +442,7 @@ class RoughScorer(torch.nn.Module): class SpanPredictor(torch.nn.Module): - def __init__(self, input_size: int, distance_emb_size: int, device): + def __init__(self, input_size: int, distance_emb_size: int): super().__init__() self.ffnn = torch.nn.Sequential( torch.nn.Linear(input_size * 2 + 64, input_size), @@ -461,7 +453,6 @@ class SpanPredictor(torch.nn.Module): torch.nn.Dropout(0.3), torch.nn.Linear(256, 64), ) - self.device = device self.conv = torch.nn.Sequential( torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1) @@ -529,6 +520,8 @@ class SpanPredictor(torch.nn.Module): valid_positions = torch.stack((valid_starts, valid_ends), dim=2) return scores + valid_positions return scores + + class DistancePairwiseEncoder(torch.nn.Module): def __init__(self, embedding_size, dropout_rate): @@ -538,17 +531,10 @@ class DistancePairwiseEncoder(torch.nn.Module): self.dropout = torch.nn.Dropout(dropout_rate) self.shape = emb_size - @property - def device(self) -> torch.device: - """ A workaround to get current device (which is assumed to be the - device of the first parameter of one of the submodules) """ - return next(self.distance_emb.parameters()).device - - def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch top_indices: torch.Tensor ) -> torch.Tensor: - word_ids = torch.arange(0, top_indices.size(0), device=self.device) + word_ids = torch.arange(0, top_indices.size(0)) distance = (word_ids.unsqueeze(1) - word_ids[top_indices] ).clamp_min_(min=1) log_distance = distance.to(torch.float).log2().floor_()