make sure same device

kadarakos 2022-06-28 12:53:20 +00:00
parent 9f9453865a
commit 1a782592c4


@@ -2,13 +2,16 @@ from typing import List, Tuple
 from thinc.api import Model, chain
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d, Ints2d, Ints1d
+from thinc.types import Floats2d
 from thinc.util import torch, xp2torch, torch2xp
 from ...tokens import Doc
 from ...util import registry
+EPSILON = 1e-7
 @registry.architectures("spacy.Coref.v1")
 def build_wl_coref_model(
     tok2vec: Model[List[Doc], List[Floats2d]],
@@ -42,7 +45,9 @@ def build_wl_coref_model(
     return coref_model
-def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
+def convert_coref_clusterer_inputs(
+    model: Model, X: List[Floats2d], is_train: bool
+):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
@@ -50,7 +55,7 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     word_features = xp2torch(X, requires_grad=is_train)
     # TODO fix or remove type annotations
-    def backprop(args: ArgsKwargs): #-> List[Floats2d]:
+    def backprop(args: ArgsKwargs):  # -> List[Floats2d]:
         # convert to xp and wrap in list
         gradients = torch2xp(args.args[0])
         return [gradients]
@@ -58,7 +63,9 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     return ArgsKwargs(args=(word_features,), kwargs={}), backprop
-def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
+def convert_coref_clusterer_outputs(
+    model: Model, inputs_outputs, is_train: bool
+):
     _, outputs = inputs_outputs
     scores, indices = outputs
@@ -149,10 +156,10 @@ class CorefClusterer(torch.nn.Module):
         a_scores_lst: List[torch.Tensor] = []
         for i in range(0, len(words), batch_size):
-            pw_batch = pw[i : i + batch_size]
-            words_batch = words[i : i + batch_size]
-            top_indices_batch = top_indices[i : i + batch_size]
-            top_rough_scores_batch = top_rough_scores[i : i + batch_size]
+            pw_batch = pw[i:i + batch_size]
+            words_batch = words[i:i + batch_size]
+            top_indices_batch = top_indices[i:i + batch_size]
+            top_rough_scores_batch = top_rough_scores[i:i + batch_size]
             # a_scores_batch  [batch_size, n_ants]
             a_scores_batch = self.a_scorer(
@@ -168,7 +175,6 @@ class CorefClusterer(torch.nn.Module):
         return coref_scores, top_indices
-EPSILON = 1e-7
 # Note this function is kept here to keep a torch dep out of coref_util.
 def add_dummy(tensor: torch.Tensor, eps: bool = False):
     """Prepends zeros (or a very small value if eps is True)
@@ -294,7 +300,7 @@ class RoughScorer(torch.nn.Module):
         self.k = antecedent_limit
     def forward(
-        self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
+        self,  # type: ignore
        mentions: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -305,6 +311,7 @@ class RoughScorer(torch.nn.Module):
         pair_mask = torch.arange(mentions.shape[0])
         pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
         pair_mask = torch.log((pair_mask > 0).to(torch.float))
+        pair_mask = pair_mask.to(mentions.device)
         bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
         rough_scores = pair_mask + bilinear_scores
         top_scores, indices = torch.topk(
@@ -340,5 +347,6 @@ class DistancePairwiseEncoder(torch.nn.Module):
         log_distance = distance.to(torch.float).log2().floor_()
         log_distance = log_distance.clamp_max_(max=6).to(torch.long)
         distance = torch.where(distance < 5, distance - 1, log_distance + 2)
+        distance = distance.to(top_indices.device)
         distance = self.distance_emb(distance)
         return self.dropout(distance)
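
Both added lines address the same PyTorch pitfall: tensors created without an explicit device argument (such as the torch.arange result in RoughScorer.forward) are allocated on the default device, while the tensors they are combined with may already live on the GPU, so mixing them raises a RuntimeError about mismatched devices. A minimal sketch of the pattern, mirroring the pair_mask fix; the function and variable names here are hypothetical, not from this file:

    import torch

    def masked_scores(scores: torch.Tensor) -> torch.Tensor:
        # torch.arange allocates on the CPU by default, even when
        # `scores` lives on a GPU, so the mask must be moved over
        # before the two tensors are added together.
        n = scores.shape[0]
        mask = torch.arange(n)
        mask = mask.unsqueeze(1) - mask.unsqueeze(0)
        mask = torch.log((mask > 0).to(torch.float))
        mask = mask.to(scores.device)  # the "same device" fix
        return scores + mask

An alternative would be to pass device=scores.device to torch.arange when the tensor is created; moving the finished tensor instead, as this commit does, leaves the construction code untouched.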