mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	handle misaligned tokenization
This commit is contained in:
		
							parent
							
								
									11a17976ec
								
							
						
					
					
						commit
						6e46ecfa2c
					
				|  | @ -7,7 +7,7 @@ from thinc.types import Floats2d, Ints1d, Ragged | |||
| from spacy.language import Language | ||||
| from spacy.pipeline.trainable_pipe import TrainablePipe | ||||
| from spacy.scorer import Scorer | ||||
| from spacy.tokens import Doc | ||||
| from spacy.tokens import Doc, Span | ||||
| from spacy.training import Example | ||||
| from spacy.errors import Errors | ||||
| 
 | ||||
|  | @ -160,6 +160,12 @@ class _MaxInt(int): | |||
|         return True | ||||
| 
 | ||||
| 
 | ||||
| def _char_indices(span: Span) -> Tuple[int, int]: | ||||
|     start = span[0].idx | ||||
|     end = span[-1].idx + len(span[-1]) | ||||
|     return start, end | ||||
| 
 | ||||
| 
 | ||||
| class SpanFinder(TrainablePipe): | ||||
|     """Pipeline that learns span boundaries""" | ||||
| 
 | ||||
|  | @ -282,38 +288,52 @@ class SpanFinder(TrainablePipe): | |||
|         scores: Scores representing the model's predictions. | ||||
|         RETURNS (Tuple[float, float]): The loss and the gradient. | ||||
|         """ | ||||
|         reference_truths = self._get_aligned_truth_scores(examples) | ||||
|         d_scores = scores - self.model.ops.asarray2f(reference_truths) | ||||
|         truths, masks = self._get_aligned_truth_scores(examples, self.model.ops) | ||||
|         d_scores = scores - self.model.ops.asarray2f(truths) | ||||
|         d_scores *= masks | ||||
|         loss = float((d_scores**2).sum()) | ||||
|         return loss, d_scores | ||||
| 
 | ||||
|     def _get_aligned_truth_scores(self, examples) -> List[Tuple[int, int]]: | ||||
|     def _get_aligned_truth_scores(self, examples, ops) -> Tuple[Floats2d, Floats2d]: | ||||
|         """Align scores of the predictions to the references for calculating the loss""" | ||||
|         # TODO: handle misaligned (None) alignments | ||||
|         reference_truths = [] | ||||
| 
 | ||||
|         truths = [] | ||||
|         masks = [] | ||||
|         for eg in examples: | ||||
|             if eg.x.text != eg.y.text: | ||||
|                 raise ValueError(Errors.E1053.format(component="span_finder")) | ||||
|             start_indices = set() | ||||
|             end_indices = set() | ||||
| 
 | ||||
|             n_tokens = len(eg.predicted) | ||||
|             truth = ops.xp.zeros((n_tokens, 2), dtype="float32") | ||||
|             mask = ops.xp.ones((n_tokens, 2), dtype="float32") | ||||
|             if self.training_key in eg.reference.spans: | ||||
|                 for span in eg.reference.spans[self.training_key]: | ||||
|                     start_indices.add(eg.reference[span.start].idx) | ||||
|                     end_indices.add( | ||||
|                         eg.reference[span.end - 1].idx + len(eg.reference[span.end - 1]) | ||||
|                     ref_start_char, ref_end_char = _char_indices(span) | ||||
|                     pred_span = eg.predicted.char_span( | ||||
|                         ref_start_char, ref_end_char, alignment_mode="expand" | ||||
|                     ) | ||||
| 
 | ||||
|             for token in eg.predicted: | ||||
|                 reference_truths.append( | ||||
|                     ( | ||||
|                         1 if token.idx in start_indices else 0, | ||||
|                         1 if token.idx + len(token) in end_indices else 0, | ||||
|                     ) | ||||
|                 ) | ||||
| 
 | ||||
|         return reference_truths | ||||
|                     pred_start_char, pred_end_char = _char_indices(pred_span) | ||||
|                     start_match = pred_start_char == ref_start_char | ||||
|                     end_match = pred_end_char == ref_end_char | ||||
|                     # Tokenizations line up. | ||||
|                     if start_match and end_match: | ||||
|                         truth[pred_span[0].i, 0] = 1 | ||||
|                         truth[pred_span[-1].i, 1] = 1 | ||||
|                     # Start character index lines up, but not the end. | ||||
|                     elif start_match and not end_match: | ||||
|                         truth[pred_span[0].i, 0] = 1 | ||||
|                         mask[pred_span[-1].i, 1] = 0 | ||||
|                     # End character index lines up, but not the start. | ||||
|                     elif not start_match and end_match: | ||||
|                         truth[pred_span[-1].i, 1] = 1 | ||||
|                         mask[pred_span[0].i, 0] = 0 | ||||
|                     # Neither of them match. | ||||
|                     else: | ||||
|                         mask[pred_span[0].i, 0] = 0 | ||||
|                         mask[pred_span[-1].i, 1] = 0 | ||||
|             truths.append(truth) | ||||
|             masks.append(mask) | ||||
|         truths = ops.xp.concatenate(truths, axis=0) | ||||
|         masks = ops.xp.concatenate(masks, axis=0) | ||||
|         return truths, mask | ||||
| 
 | ||||
|     def _get_reference(self, docs) -> List[Tuple[int, int]]: | ||||
|         """Create a reference list of token probabilities""" | ||||
|  |  | |||
|  | @ -63,15 +63,16 @@ def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_tr | |||
|     example.reference.spans[TRAINING_KEY] = [example.reference.char_span(5, 9)] | ||||
|     span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY}) | ||||
|     nlp.initialize() | ||||
|     ops = span_finder.model.ops | ||||
|     if predicted.text != reference.text: | ||||
|         with pytest.raises( | ||||
|             ValueError, match="must match between reference and predicted" | ||||
|         ): | ||||
|             span_finder._get_aligned_truth_scores([example]) | ||||
|             span_finder._get_aligned_truth_scores([example], ops) | ||||
|         return | ||||
|     truth_scores = span_finder._get_aligned_truth_scores([example]) | ||||
|     truth_scores, masks = span_finder._get_aligned_truth_scores([example], ops) | ||||
|     assert len(truth_scores) == len(tokens_predicted) | ||||
|     assert truth_scores == reference_truths | ||||
|     ops.xp.testing.assert_array_equal(truth_scores, ops.xp.asarray(reference_truths)) | ||||
| 
 | ||||
| 
 | ||||
| def test_span_finder_model(): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user