mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	make sure same device
This commit is contained in:
		
							parent
							
								
									9f9453865a
								
							
						
					
					
						commit
						1a782592c4
					
				| 
						 | 
					@ -2,13 +2,16 @@ from typing import List, Tuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from thinc.api import Model, chain
 | 
					from thinc.api import Model, chain
 | 
				
			||||||
from thinc.api import PyTorchWrapper, ArgsKwargs
 | 
					from thinc.api import PyTorchWrapper, ArgsKwargs
 | 
				
			||||||
from thinc.types import Floats2d, Ints2d, Ints1d
 | 
					from thinc.types import Floats2d
 | 
				
			||||||
from thinc.util import torch, xp2torch, torch2xp
 | 
					from thinc.util import torch, xp2torch, torch2xp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ...tokens import Doc
 | 
					from ...tokens import Doc
 | 
				
			||||||
from ...util import registry
 | 
					from ...util import registry
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					EPSILON = 1e-7
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@registry.architectures("spacy.Coref.v1")
 | 
					@registry.architectures("spacy.Coref.v1")
 | 
				
			||||||
def build_wl_coref_model(
 | 
					def build_wl_coref_model(
 | 
				
			||||||
    tok2vec: Model[List[Doc], List[Floats2d]],
 | 
					    tok2vec: Model[List[Doc], List[Floats2d]],
 | 
				
			||||||
| 
						 | 
					@ -42,7 +45,9 @@ def build_wl_coref_model(
 | 
				
			||||||
    return coref_model
 | 
					    return coref_model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
 | 
					def convert_coref_clusterer_inputs(
 | 
				
			||||||
 | 
					        model: Model, X: List[Floats2d], is_train: bool
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
    # The input here is List[Floats2d], one for each doc
 | 
					    # The input here is List[Floats2d], one for each doc
 | 
				
			||||||
    # just use the first
 | 
					    # just use the first
 | 
				
			||||||
    # TODO real batching
 | 
					    # TODO real batching
 | 
				
			||||||
| 
						 | 
					@ -58,7 +63,9 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bo
 | 
				
			||||||
    return ArgsKwargs(args=(word_features,), kwargs={}), backprop
 | 
					    return ArgsKwargs(args=(word_features,), kwargs={}), backprop
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
 | 
					def convert_coref_clusterer_outputs(
 | 
				
			||||||
 | 
					        model: Model, inputs_outputs, is_train: bool
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
    _, outputs = inputs_outputs
 | 
					    _, outputs = inputs_outputs
 | 
				
			||||||
    scores, indices = outputs
 | 
					    scores, indices = outputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -168,7 +175,6 @@ class CorefClusterer(torch.nn.Module):
 | 
				
			||||||
        return coref_scores, top_indices
 | 
					        return coref_scores, top_indices
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
EPSILON = 1e-7
 | 
					 | 
				
			||||||
# Note this function is kept here to keep a torch dep out of coref_util.
 | 
					# Note this function is kept here to keep a torch dep out of coref_util.
 | 
				
			||||||
def add_dummy(tensor: torch.Tensor, eps: bool = False):
 | 
					def add_dummy(tensor: torch.Tensor, eps: bool = False):
 | 
				
			||||||
    """Prepends zeros (or a very small value if eps is True)
 | 
					    """Prepends zeros (or a very small value if eps is True)
 | 
				
			||||||
| 
						 | 
					@ -294,7 +300,7 @@ class RoughScorer(torch.nn.Module):
 | 
				
			||||||
        self.k = antecedent_limit
 | 
					        self.k = antecedent_limit
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(
 | 
					    def forward(
 | 
				
			||||||
        self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
 | 
					        self,  # type: ignore
 | 
				
			||||||
        mentions: torch.Tensor,
 | 
					        mentions: torch.Tensor,
 | 
				
			||||||
    ) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
					    ) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -305,6 +311,7 @@ class RoughScorer(torch.nn.Module):
 | 
				
			||||||
        pair_mask = torch.arange(mentions.shape[0])
 | 
					        pair_mask = torch.arange(mentions.shape[0])
 | 
				
			||||||
        pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
 | 
					        pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
 | 
				
			||||||
        pair_mask = torch.log((pair_mask > 0).to(torch.float))
 | 
					        pair_mask = torch.log((pair_mask > 0).to(torch.float))
 | 
				
			||||||
 | 
					        pair_mask = pair_mask.to(mentions.device)
 | 
				
			||||||
        bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
 | 
					        bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
 | 
				
			||||||
        rough_scores = pair_mask + bilinear_scores
 | 
					        rough_scores = pair_mask + bilinear_scores
 | 
				
			||||||
        top_scores, indices = torch.topk(
 | 
					        top_scores, indices = torch.topk(
 | 
				
			||||||
| 
						 | 
					@ -340,5 +347,6 @@ class DistancePairwiseEncoder(torch.nn.Module):
 | 
				
			||||||
        log_distance = distance.to(torch.float).log2().floor_()
 | 
					        log_distance = distance.to(torch.float).log2().floor_()
 | 
				
			||||||
        log_distance = log_distance.clamp_max_(max=6).to(torch.long)
 | 
					        log_distance = log_distance.clamp_max_(max=6).to(torch.long)
 | 
				
			||||||
        distance = torch.where(distance < 5, distance - 1, log_distance + 2)
 | 
					        distance = torch.where(distance < 5, distance - 1, log_distance + 2)
 | 
				
			||||||
 | 
					        distance = distance.to(top_indices.device)
 | 
				
			||||||
        distance = self.distance_emb(distance)
 | 
					        distance = self.distance_emb(distance)
 | 
				
			||||||
        return self.dropout(distance)
 | 
					        return self.dropout(distance)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user