mirror of https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
make sure same device
This commit is contained in:
parent 9f9453865a
commit 1a782592c4
@@ -2,13 +2,16 @@ from typing import List, Tuple
 
 from thinc.api import Model, chain
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d, Ints2d, Ints1d
+from thinc.types import Floats2d
 from thinc.util import torch, xp2torch, torch2xp
 
 from ...tokens import Doc
 from ...util import registry
 
 
+EPSILON = 1e-7
+
+
 @registry.architectures("spacy.Coref.v1")
 def build_wl_coref_model(
     tok2vec: Model[List[Doc], List[Floats2d]],
@@ -42,7 +45,9 @@ def build_wl_coref_model(
     return coref_model
 
 
-def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
+def convert_coref_clusterer_inputs(
+    model: Model, X: List[Floats2d], is_train: bool
+):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
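Note: the rewrapped signature above is one of the two conversion hooks that thinc's PyTorchWrapper accepts. For reference, a minimal sketch of such a convert_inputs hook, built around the return line visible in the next hunk; the function name and the backprop body are illustrative, not the file's exact code:

from typing import List

from thinc.api import ArgsKwargs, Model
from thinc.types import Floats2d
from thinc.util import torch2xp, xp2torch


def convert_inputs_sketch(model: Model, X: List[Floats2d], is_train: bool):
    # Only the first doc's word features are used for now
    # (see the "TODO real batching" comment above).
    word_features = xp2torch(X[0], requires_grad=is_train)

    def backprop(args: ArgsKwargs) -> List[Floats2d]:
        # Convert the gradient of the single torch input back to numpy/cupy.
        return [torch2xp(args.args[0])]

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop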
@@ -58,7 +63,9 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bo
     return ArgsKwargs(args=(word_features,), kwargs={}), backprop
 
 
-def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
+def convert_coref_clusterer_outputs(
+    model: Model, inputs_outputs, is_train: bool
+):
     _, outputs = inputs_outputs
     scores, indices = outputs
 
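The matching convert_outputs hook receives (inputs, outputs) from the torch module, as the unpacking above shows. A hedged sketch of the rest of the function, assuming the usual PyTorchWrapper gradient protocol (grad_tensors passed back via ArgsKwargs); everything past the two visible lines is reconstructed, not quoted:

from thinc.api import ArgsKwargs, Model
from thinc.types import Floats2d
from thinc.util import torch2xp, xp2torch


def convert_outputs_sketch(model: Model, inputs_outputs, is_train: bool):
    _, outputs = inputs_outputs
    scores, indices = outputs

    def convert_for_torch_backward(dY: Floats2d) -> ArgsKwargs:
        # Only the scores carry gradient; indices are integer positions.
        return ArgsKwargs(
            args=([scores],), kwargs={"grad_tensors": [xp2torch(dY)]}
        )

    return (torch2xp(scores), torch2xp(indices)), convert_for_torch_backward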
@@ -149,10 +156,10 @@ class CorefClusterer(torch.nn.Module):
         a_scores_lst: List[torch.Tensor] = []
 
         for i in range(0, len(words), batch_size):
-            pw_batch = pw[i : i + batch_size]
-            words_batch = words[i : i + batch_size]
-            top_indices_batch = top_indices[i : i + batch_size]
-            top_rough_scores_batch = top_rough_scores[i : i + batch_size]
+            pw_batch = pw[i:i + batch_size]
+            words_batch = words[i:i + batch_size]
+            top_indices_batch = top_indices[i:i + batch_size]
+            top_rough_scores_batch = top_rough_scores[i:i + batch_size]
 
             # a_scores_batch [batch_size, n_ants]
             a_scores_batch = self.a_scorer(
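The four changed lines above are whitespace-only: pw[i : i + batch_size] and pw[i:i + batch_size] denote the same slice, so the batching behaviour is unchanged. Quick check with plain torch tensors:

import torch

pw = torch.arange(10)
batch_size = 4
for i in range(0, len(pw), batch_size):
    assert torch.equal(pw[i : i + batch_size], pw[i:i + batch_size])
    # The final batch may be shorter; slicing clamps at the end.
    print(pw[i:i + batch_size].tolist())  # [0..3], [4..7], [8, 9]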
@@ -168,7 +175,6 @@ class CorefClusterer(torch.nn.Module):
         return coref_scores, top_indices
 
 
-EPSILON = 1e-7
 # Note this function is kept here to keep a torch dep out of coref_util.
 def add_dummy(tensor: torch.Tensor, eps: bool = False):
     """Prepends zeros (or a very small value if eps is True)
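EPSILON used to be defined here, next to its only consumer, and now lives at the top of the module (first hunk). The add_dummy body is truncated in this view; a sketch of what a helper with this docstring plausibly does — prepend a dummy "no antecedent" column of zeros, or of EPSILON so a later log() stays finite. Illustrative only, not the file's code:

import torch

EPSILON = 1e-7


def add_dummy_sketch(tensor: torch.Tensor, eps: bool = False) -> torch.Tensor:
    shape = list(tensor.shape)
    shape[1] = 1  # one dummy column per row
    fill = EPSILON if eps else 0.0
    dummy = torch.full(shape, fill, device=tensor.device, dtype=tensor.dtype)
    return torch.cat((dummy, tensor), dim=1)


assert add_dummy_sketch(torch.ones(3, 4)).shape == (3, 5)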
@@ -294,7 +300,7 @@ class RoughScorer(torch.nn.Module):
         self.k = antecedent_limit
 
     def forward(
-        self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
+        self,  # type: ignore
         mentions: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -305,6 +311,7 @@ class RoughScorer(torch.nn.Module):
         pair_mask = torch.arange(mentions.shape[0])
         pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
         pair_mask = torch.log((pair_mask > 0).to(torch.float))
+        pair_mask = pair_mask.to(mentions.device)
         bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
         rough_scores = pair_mask + bilinear_scores
         top_scores, indices = torch.topk(
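This is the commit's point: torch.arange with no device argument allocates on the CPU, so without the added line, pair_mask + bilinear_scores raises a device-mismatch RuntimeError whenever mentions live on the GPU. A standalone sketch of the masked rough scoring, with the bilinear scorer stubbed by a plain matmul:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
mentions = torch.randn(5, 8, device=device)

pair_mask = torch.arange(mentions.shape[0])      # always CPU here
pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
# 0.0 where j < i (valid antecedent), -inf elsewhere
pair_mask = torch.log((pair_mask > 0).to(torch.float))
pair_mask = pair_mask.to(mentions.device)        # the fix in this commit

bilinear_scores = mentions.mm(mentions.T)        # stand-in for the scorer
rough_scores = pair_mask + bilinear_scores       # devices now agree
top_scores, indices = torch.topk(rough_scores, k=3, dim=1)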
@@ -340,5 +347,6 @@ class DistancePairwiseEncoder(torch.nn.Module):
         log_distance = distance.to(torch.float).log2().floor_()
         log_distance = log_distance.clamp_max_(max=6).to(torch.long)
         distance = torch.where(distance < 5, distance - 1, log_distance + 2)
+        distance = distance.to(top_indices.device)
         distance = self.distance_emb(distance)
         return self.dropout(distance)
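Same device issue as in RoughScorer: the bucketed distance indices are derived from top_indices but, before this commit, could sit on the CPU while the embedding lookup ran on the GPU. The bucketing itself maps distances 1-4 to exact buckets and anything longer to log2 buckets capped at 6. A demo of the arithmetic above:

import torch

distance = torch.tensor([1, 2, 4, 5, 8, 64, 200])
log_distance = distance.to(torch.float).log2().floor_()
log_distance = log_distance.clamp_max_(max=6).to(torch.long)
# distances < 5 keep exact buckets 0-3; longer ones get log2 buckets + 2
bucket = torch.where(distance < 5, distance - 1, log_distance + 2)
print(bucket.tolist())  # [0, 1, 3, 4, 5, 8, 8]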