mirror of https://github.com/explosion/spaCy.git

commit 150e7c46d7
@@ -37,14 +37,11 @@ def build_wl_coref_model(
     except ValueError:
         # happens with transformer listener
         dim = 768
-     
+
     with Model.define_operators({">>": chain}):
         # TODO chain tok2vec with these models
-        # TODO fix device - should be automatic
-        device = "cuda:0"
         coref_scorer = PyTorchWrapper(
             CorefScorer(
-                device,
                 dim,
                 embedding_size,
                 hidden_size,
@@ -64,7 +61,6 @@ def build_wl_coref_model(
                 # TODO this was hardcoded to 1024, check
                 hidden_size,
                 sp_embedding_size,
-                device
             ),
 
             convert_inputs=convert_span_predictor_inputs
@@ -266,7 +262,6 @@ class CorefScorer(torch.nn.Module):
     """
     def __init__(
         self,
-        device: str,
         dim: int, # tok2vec size
         dist_emb_size: int,
         hidden_size: int,
@@ -283,8 +278,7 @@ class CorefScorer(torch.nn.Module):
             epochs_trained (int): the number of epochs finished
                 (useful for warm start)
         """
-        # device, dist_emb_size, hidden_size, n_layers, dropout_rate
-        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate).to(device)
+        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
         #TODO clean this up
         bert_emb = dim
         pair_emb = bert_emb * 3 + self.pw.shape
@@ -293,7 +287,7 @@ class CorefScorer(torch.nn.Module):
             hidden_size,
             n_layers,
             dropout_rate
-        ).to(device)
+        )
         self.lstm = torch.nn.LSTM(
             input_size=bert_emb,
             hidden_size=bert_emb,
@@ -304,7 +298,7 @@ class CorefScorer(torch.nn.Module):
             bert_emb,
             dropout_rate,
             roughk
-        ).to(device)
+        )
         self.batch_size = batch_size
 
     def forward(
@@ -453,7 +447,6 @@ class AnaphoricityScorer(torch.nn.Module):
         return out
 
-
 
 class RoughScorer(torch.nn.Module):
     """
     Is needed to give a roughly estimate of the anaphoricity of two candidates,
@@ -484,7 +477,6 @@ class RoughScorer(torch.nn.Module):
         pair_mask = torch.arange(mentions.shape[0])
         pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
         pair_mask = torch.log((pair_mask > 0).to(torch.float))
-        pair_mask = pair_mask.to(mentions.device)
         bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
         rough_scores = pair_mask + bilinear_scores
 
@@ -511,7 +503,7 @@ class RoughScorer(torch.nn.Module):
 
 
 class SpanPredictor(torch.nn.Module):
-    def __init__(self, input_size: int, distance_emb_size: int, device):
+    def __init__(self, input_size: int, distance_emb_size: int):
         super().__init__()
         self.ffnn = torch.nn.Sequential(
             torch.nn.Linear(input_size * 2 + 64, input_size),
@@ -522,7 +514,6 @@ class SpanPredictor(torch.nn.Module):
             torch.nn.Dropout(0.3),
             torch.nn.Linear(256, 64),
         )
-        self.device = device
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(64, 4, 3, 1, 1),
             torch.nn.Conv1d(4, 2, 3, 1, 1)
@@ -600,17 +591,10 @@ class DistancePairwiseEncoder(torch.nn.Module):
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.shape = emb_size
 
-    @property
-    def device(self) -> torch.device:
-        """ A workaround to get current device (which is assumed to be the
-        device of the first parameter of one of the submodules) """
-        return next(self.distance_emb.parameters()).device
-
-
     def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                 top_indices: torch.Tensor
         ) -> torch.Tensor:
-        word_ids = torch.arange(0, top_indices.size(0), device=self.device)
+        word_ids = torch.arange(0, top_indices.size(0))
         distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
                     ).clamp_min_(min=1)
         log_distance = distance.to(torch.float).log2().floor_()
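
Note on the pattern above: this diff moves the PyTorch modules toward the usual device-agnostic idiom, where a torch.nn.Module stores no device string, a single .to(...) on the top-level module relocates every parameter, and tensors created inside forward() can take their device from the incoming data. A minimal sketch of that idiom, using a hypothetical ToyScorer class (not one of the spaCy classes above):

    import torch

    class ToyScorer(torch.nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            self.bilinear = torch.nn.Linear(dim, dim)

        def forward(self, mentions: torch.Tensor) -> torch.Tensor:
            # Build the antecedent mask on whatever device the input
            # already lives on, not on a device fixed at __init__ time.
            pair_mask = torch.arange(mentions.shape[0], device=mentions.device)
            pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
            pair_mask = torch.log((pair_mask > 0).to(torch.float))
            return pair_mask + self.bilinear(mentions) @ mentions.T

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    scorer = ToyScorer(64).to(device)  # one call relocates all parameters
    scores = scorer(torch.randn(5, 64, device=device))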
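For context on why the explicit device can go away: PyTorchWrapper (from thinc.api) mediates between Thinc arrays and torch tensors at the model boundary, which is why the wrapped module should not pin itself to a hard-coded "cuda:0". A small standalone Thinc usage sketch, unrelated to the coref classes themselves:

    import torch
    from thinc.api import Linear, PyTorchWrapper, chain

    # Compose a torch module with a Thinc layer; the wrapper converts
    # arrays to tensors (and back) at the boundary.
    model = chain(PyTorchWrapper(torch.nn.Linear(16, 8)), Linear(4, 8))
    X = model.ops.alloc2f(2, 16)  # a (2, 16) float32 batch of zeros
    model.initialize(X=X)
    Y, backprop = model(X, is_train=True)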