mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-01 00:17:44 +03:00 
			
		
		
		
	Fix alignment issues
I believe this resolves issues caused by tokenization mismatches between the reference and predicted docs.
This commit is contained in:
		
							parent
							
								
									cf33b48fe0
								
							
						
					
					
						commit
						b09bbc7f5e
					
				|  | @ -231,16 +231,29 @@ class SpanPredictor(TrainablePipe): | |||
|         for eg in examples: | ||||
|             starts = [] | ||||
|             ends = [] | ||||
|             keeps = [] | ||||
|             sidx = 0 | ||||
|             for key, sg in eg.reference.spans.items(): | ||||
|                 if key.startswith(self.output_prefix): | ||||
|                     for mention in sg: | ||||
|                         starts.append(mention.start) | ||||
|                         ends.append(mention.end) | ||||
|                     for ii, mention in enumerate(sg): | ||||
|                         sidx += 1 | ||||
|                         # convert to span in pred | ||||
|                         sch, ech = (mention.start_char, mention.end_char) | ||||
|                         span = eg.predicted.char_span(sch, ech) | ||||
|                         # TODO add to errors.py | ||||
|                         if span is None: | ||||
|                             warnings.warn("Could not align gold span in span predictor, skipping") | ||||
|                             continue | ||||
|                         starts.append(span.start) | ||||
|                         ends.append(span.end) | ||||
|                         keeps.append(sidx - 1) | ||||
| 
 | ||||
|             starts = self.model.ops.xp.asarray(starts) | ||||
|             ends = self.model.ops.xp.asarray(ends) | ||||
|             start_scores = span_scores[:, :, 0] | ||||
|             end_scores = span_scores[:, :, 1] | ||||
|             start_scores = span_scores[:, :, 0][keeps] | ||||
|             end_scores = span_scores[:, :, 1][keeps] | ||||
| 
 | ||||
| 
 | ||||
|             n_classes = start_scores.shape[1] | ||||
|             start_probs = ops.softmax(start_scores, axis=1) | ||||
|             end_probs = ops.softmax(end_scores, axis=1) | ||||
|  | @ -248,7 +261,14 @@ class SpanPredictor(TrainablePipe): | |||
|             end_targets = to_categorical(ends, n_classes) | ||||
|             start_grads = start_probs - start_targets | ||||
|             end_grads = end_probs - end_targets | ||||
|             grads = ops.xp.stack((start_grads, end_grads), axis=2) | ||||
|             # now return to original shape, with 0s | ||||
|             final_start_grads = ops.alloc2f(*span_scores[:, :, 0].shape) | ||||
|             final_start_grads[keeps] = start_grads | ||||
|             final_end_grads = ops.alloc2f(*final_start_grads.shape) | ||||
|             final_end_grads[keeps] = end_grads | ||||
|             # XXX Note this only works with fake batching | ||||
|             grads = ops.xp.stack((final_start_grads, final_end_grads), axis=2) | ||||
| 
 | ||||
|             loss = float((grads**2).sum()) | ||||
|         return loss, grads | ||||
| 
 | ||||
|  | @ -267,6 +287,7 @@ class SpanPredictor(TrainablePipe): | |||
|             if not ex.predicted.spans: | ||||
|                 # set placeholder for shape inference | ||||
|                 doc = ex.predicted | ||||
|                 # TODO should be able to check if there are some valid docs in the batch | ||||
|                 assert len(doc) > 2, "Coreference requires at least two tokens" | ||||
|                 doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]] | ||||
|             X.append(ex.predicted) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user