mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	handle empty head_ids
This commit is contained in:
		
							parent
							
								
									7304604edd
								
							
						
					
					
						commit
						4fc40340f9
					
				|  | @ -133,6 +133,7 @@ def convert_coref_scorer_outputs( | |||
|     indices_xp = torch2xp(indices) | ||||
|     return (scores_xp, indices_xp), convert_for_torch_backward | ||||
| 
 | ||||
| 
 | ||||
| def convert_span_predictor_inputs( | ||||
|     model: Model, | ||||
|     X: Tuple[Ints1d, Floats2d, Ints1d], | ||||
|  | @ -141,13 +142,17 @@ def convert_span_predictor_inputs( | |||
|     tok2vec, (sent_ids, head_ids) = X | ||||
|     # Normally we should use the input is_train, but for these two it's not relevant | ||||
|     sent_ids = xp2torch(sent_ids[0], requires_grad=False) | ||||
|     head_ids = xp2torch(head_ids[0], requires_grad=False) | ||||
|     if not head_ids[0].size: | ||||
|         head_ids = torch.empty(size=(0,)) | ||||
|     else: | ||||
|         head_ids = xp2torch(head_ids[0], requires_grad=False) | ||||
|     word_features = xp2torch(tok2vec[0], requires_grad=is_train) | ||||
| 
 | ||||
|     argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={}) | ||||
|     # TODO actually support backprop | ||||
|     return argskwargs, lambda dX: [[]] | ||||
| 
 | ||||
| 
 | ||||
| # TODO This probably belongs in the component, not the model. | ||||
| def predict_span_clusters(span_predictor: Model, | ||||
|                           sent_ids: Ints1d, | ||||
|  | @ -543,6 +548,9 @@ class SpanPredictor(torch.nn.Module): | |||
|         Returns: | ||||
|             torch.Tensor: span start/end scores, [n_heads, n_words, 2] | ||||
|         """ | ||||
|         # If we don't receive heads, return empty | ||||
|         if heads_ids.nelement() == 0: | ||||
|             return torch.empty(size=(0,)) | ||||
|         # Obtain distance embedding indices, [n_heads, n_words] | ||||
|         relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0)) | ||||
|         # make all valid distances positive | ||||
|  | @ -550,7 +558,6 @@ class SpanPredictor(torch.nn.Module): | |||
|         # "too_far" | ||||
|         emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127 | ||||
|         # Obtain "same sentence" boolean mask, [n_heads, n_words] | ||||
|         sent_id = torch.tensor(sent_id) | ||||
|         heads_ids = heads_ids.long() | ||||
|         same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)) | ||||
|         # To save memory, only pass candidates from one sentence for each head | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user