commit 1a782592c4
parent 9f9453865a

make sure same device

@@ -2,13 +2,16 @@ from typing import List, Tuple
 
 from thinc.api import Model, chain
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d, Ints2d, Ints1d
+from thinc.types import Floats2d
 from thinc.util import torch, xp2torch, torch2xp
 
 from ...tokens import Doc
 from ...util import registry
 
 
+EPSILON = 1e-7
+
+
 @registry.architectures("spacy.Coref.v1")
 def build_wl_coref_model(
     tok2vec: Model[List[Doc], List[Floats2d]],
@@ -42,7 +45,9 @@ def build_wl_coref_model(
     return coref_model
 
 
-def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
+def convert_coref_clusterer_inputs(
+    model: Model, X: List[Floats2d], is_train: bool
+):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
@@ -50,7 +55,7 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     word_features = xp2torch(X, requires_grad=is_train)
 
     # TODO fix or remove type annotations
     def backprop(args: ArgsKwargs): #-> List[Floats2d]:
         # convert to xp and wrap in list
         gradients = torch2xp(args.args[0])
         return [gradients]
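For readers skimming the converter: xp2torch and torch2xp are thinc helpers that move arrays across the numpy/cupy <-> torch boundary, which is all this function does around the wrapped module. A minimal, standalone sketch of the round trip (illustrative, not from this file):

    import numpy
    from thinc.util import xp2torch, torch2xp

    features = numpy.zeros((10, 64), dtype="float32")
    # numpy/cupy array -> torch tensor, tracking gradients for training
    word_features = xp2torch(features, requires_grad=True)
    # torch tensor -> numpy/cupy array, detached from the autograd graph
    gradients = torch2xp(word_features)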
@@ -58,7 +63,9 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     return ArgsKwargs(args=(word_features,), kwargs={}), backprop
 
 
-def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
+def convert_coref_clusterer_outputs(
+    model: Model, inputs_outputs, is_train: bool
+):
     _, outputs = inputs_outputs
     scores, indices = outputs
 
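Both converters exist to be handed to thinc's PyTorchWrapper, which is how this file turns the torch module into a thinc Model. A rough wiring sketch with a stand-in module (torch.nn.Linear is a placeholder; the converters assume CorefClusterer's (scores, indices) output, so this only illustrates construction):

    import torch
    from thinc.api import PyTorchWrapper

    # placeholder module; the real code wraps CorefClusterer here
    torch_module = torch.nn.Linear(64, 64)
    wrapped = PyTorchWrapper(
        torch_module,
        convert_inputs=convert_coref_clusterer_inputs,
        convert_outputs=convert_coref_clusterer_outputs,
    )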
@@ -149,10 +156,10 @@ class CorefClusterer(torch.nn.Module):
         a_scores_lst: List[torch.Tensor] = []
 
         for i in range(0, len(words), batch_size):
-            pw_batch = pw[i : i + batch_size]
-            words_batch = words[i : i + batch_size]
-            top_indices_batch = top_indices[i : i + batch_size]
-            top_rough_scores_batch = top_rough_scores[i : i + batch_size]
+            pw_batch = pw[i:i + batch_size]
+            words_batch = words[i:i + batch_size]
+            top_indices_batch = top_indices[i:i + batch_size]
+            top_rough_scores_batch = top_rough_scores[i:i + batch_size]
 
             # a_scores_batch [batch_size, n_ants]
             a_scores_batch = self.a_scorer(
@@ -168,7 +175,6 @@ class CorefClusterer(torch.nn.Module):
         return coref_scores, top_indices
 
 
-EPSILON = 1e-7
 # Note this function is kept here to keep a torch dep out of coref_util.
 def add_dummy(tensor: torch.Tensor, eps: bool = False):
     """Prepends zeros (or a very small value if eps is True)
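The docstring is cut off in this view. The gist: add_dummy prepends a dummy antecedent slot (zeros, or EPSILON when eps is True) so every mention has an explicit "no antecedent" option, which is why EPSILON now lives at module top. A toy illustration; the dim=1 concatenation is inferred from the docstring, not copied from the file:

    import torch

    EPSILON = 1e-7
    scores = torch.randn(4, 3)           # [n_mentions, n_antecedents]
    dummy = torch.full((4, 1), EPSILON)  # zeros when eps is False
    with_dummy = torch.cat((dummy, scores), dim=1)  # [n_mentions, n_antecedents + 1]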
@@ -294,7 +300,7 @@ class RoughScorer(torch.nn.Module):
         self.k = antecedent_limit
 
     def forward(
-        self,  # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
+        self,  # type: ignore
         mentions: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -305,6 +311,7 @@ class RoughScorer(torch.nn.Module):
         pair_mask = torch.arange(mentions.shape[0])
         pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
         pair_mask = torch.log((pair_mask > 0).to(torch.float))
+        pair_mask = pair_mask.to(mentions.device)
         bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
         rough_scores = pair_mask + bilinear_scores
         top_scores, indices = torch.topk(
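The added line is the point of this commit: torch.arange allocates on the CPU by default, so when mentions lives on a GPU the later pair_mask + bilinear_scores raises "Expected all tensors to be on the same device". A condensed sketch of the fixed computation (the function name is mine; the body mirrors the hunk above):

    import torch

    def rough_pair_mask(mentions: torch.Tensor) -> torch.Tensor:
        # torch.arange lands on the CPU regardless of mentions' device
        pair_mask = torch.arange(mentions.shape[0])
        # pair_mask[i, j] > 0 iff candidate j precedes mention i
        pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
        # log turns disallowed pairs (False -> 0.0) into -inf
        pair_mask = torch.log((pair_mask > 0).to(torch.float))
        # the fix: move the mask onto the same device as the scores
        return pair_mask.to(mentions.device)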
@@ -340,5 +347,6 @@ class DistancePairwiseEncoder(torch.nn.Module):
         log_distance = distance.to(torch.float).log2().floor_()
         log_distance = log_distance.clamp_max_(max=6).to(torch.long)
         distance = torch.where(distance < 5, distance - 1, log_distance + 2)
+        distance = distance.to(top_indices.device)
         distance = self.distance_emb(distance)
         return self.dropout(distance)
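For reference, the surrounding lines implement distance bucketing: distances 1-4 get their own buckets, larger distances share log2 buckets capped at 6, and the added line applies the same device fix before the embedding lookup. A worked example of the bucketing:

    import torch

    distance = torch.tensor([1, 2, 3, 4, 5, 16, 64, 1000])
    log_distance = distance.to(torch.float).log2().floor_()
    log_distance = log_distance.clamp_max_(max=6).to(torch.long)
    buckets = torch.where(distance < 5, distance - 1, log_distance + 2)
    print(buckets)  # tensor([0, 1, 2, 3, 4, 6, 8, 8])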