mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	handle misaligned tokenization
This commit is contained in:
		
							parent
							
								
									11a17976ec
								
							
						
					
					
						commit
						6e46ecfa2c
					
				|  | @ -7,7 +7,7 @@ from thinc.types import Floats2d, Ints1d, Ragged | ||||||
| from spacy.language import Language | from spacy.language import Language | ||||||
| from spacy.pipeline.trainable_pipe import TrainablePipe | from spacy.pipeline.trainable_pipe import TrainablePipe | ||||||
| from spacy.scorer import Scorer | from spacy.scorer import Scorer | ||||||
| from spacy.tokens import Doc | from spacy.tokens import Doc, Span | ||||||
| from spacy.training import Example | from spacy.training import Example | ||||||
| from spacy.errors import Errors | from spacy.errors import Errors | ||||||
| 
 | 
 | ||||||
|  | @ -160,6 +160,12 @@ class _MaxInt(int): | ||||||
|         return True |         return True | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def _char_indices(span: Span) -> Tuple[int, int]: | ||||||
|  |     start = span[0].idx | ||||||
|  |     end = span[-1].idx + len(span[-1]) | ||||||
|  |     return start, end | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class SpanFinder(TrainablePipe): | class SpanFinder(TrainablePipe): | ||||||
|     """Pipeline that learns span boundaries""" |     """Pipeline that learns span boundaries""" | ||||||
| 
 | 
 | ||||||
|  | @ -282,38 +288,52 @@ class SpanFinder(TrainablePipe): | ||||||
|         scores: Scores representing the model's predictions. |         scores: Scores representing the model's predictions. | ||||||
|         RETURNS (Tuple[float, float]): The loss and the gradient. |         RETURNS (Tuple[float, float]): The loss and the gradient. | ||||||
|         """ |         """ | ||||||
|         reference_truths = self._get_aligned_truth_scores(examples) |         truths, masks = self._get_aligned_truth_scores(examples, self.model.ops) | ||||||
|         d_scores = scores - self.model.ops.asarray2f(reference_truths) |         d_scores = scores - self.model.ops.asarray2f(truths) | ||||||
|  |         d_scores *= masks | ||||||
|         loss = float((d_scores**2).sum()) |         loss = float((d_scores**2).sum()) | ||||||
|         return loss, d_scores |         return loss, d_scores | ||||||
| 
 | 
 | ||||||
|     def _get_aligned_truth_scores(self, examples) -> List[Tuple[int, int]]: |     def _get_aligned_truth_scores(self, examples, ops) -> Tuple[Floats2d, Floats2d]: | ||||||
|         """Align scores of the predictions to the references for calculating the loss""" |         """Align scores of the predictions to the references for calculating the loss""" | ||||||
|         # TODO: handle misaligned (None) alignments |         truths = [] | ||||||
|         reference_truths = [] |         masks = [] | ||||||
| 
 |  | ||||||
|         for eg in examples: |         for eg in examples: | ||||||
|             if eg.x.text != eg.y.text: |             if eg.x.text != eg.y.text: | ||||||
|                 raise ValueError(Errors.E1053.format(component="span_finder")) |                 raise ValueError(Errors.E1053.format(component="span_finder")) | ||||||
|             start_indices = set() |             n_tokens = len(eg.predicted) | ||||||
|             end_indices = set() |             truth = ops.xp.zeros((n_tokens, 2), dtype="float32") | ||||||
| 
 |             mask = ops.xp.ones((n_tokens, 2), dtype="float32") | ||||||
|             if self.training_key in eg.reference.spans: |             if self.training_key in eg.reference.spans: | ||||||
|                 for span in eg.reference.spans[self.training_key]: |                 for span in eg.reference.spans[self.training_key]: | ||||||
|                     start_indices.add(eg.reference[span.start].idx) |                     ref_start_char, ref_end_char = _char_indices(span) | ||||||
|                     end_indices.add( |                     pred_span = eg.predicted.char_span( | ||||||
|                         eg.reference[span.end - 1].idx + len(eg.reference[span.end - 1]) |                         ref_start_char, ref_end_char, alignment_mode="expand" | ||||||
|                     ) |                     ) | ||||||
| 
 |                     pred_start_char, pred_end_char = _char_indices(pred_span) | ||||||
|             for token in eg.predicted: |                     start_match = pred_start_char == ref_start_char | ||||||
|                 reference_truths.append( |                     end_match = pred_end_char == ref_end_char | ||||||
|                     ( |                     # Tokenizations line up. | ||||||
|                         1 if token.idx in start_indices else 0, |                     if start_match and end_match: | ||||||
|                         1 if token.idx + len(token) in end_indices else 0, |                         truth[pred_span[0].i, 0] = 1 | ||||||
|                     ) |                         truth[pred_span[-1].i, 1] = 1 | ||||||
|                 ) |                     # Start character index lines up, but not the end. | ||||||
| 
 |                     elif start_match and not end_match: | ||||||
|         return reference_truths |                         truth[pred_span[0].i, 0] = 1 | ||||||
|  |                         mask[pred_span[-1].i, 1] = 0 | ||||||
|  |                     # End character index lines up, but not the start. | ||||||
|  |                     elif not start_match and end_match: | ||||||
|  |                         truth[pred_span[-1].i, 1] = 1 | ||||||
|  |                         mask[pred_span[0].i, 0] = 0 | ||||||
|  |                     # Neither of them match. | ||||||
|  |                     else: | ||||||
|  |                         mask[pred_span[0].i, 0] = 0 | ||||||
|  |                         mask[pred_span[-1].i, 1] = 0 | ||||||
|  |             truths.append(truth) | ||||||
|  |             masks.append(mask) | ||||||
|  |         truths = ops.xp.concatenate(truths, axis=0) | ||||||
|  |         masks = ops.xp.concatenate(masks, axis=0) | ||||||
|  |         return truths, mask | ||||||
| 
 | 
 | ||||||
|     def _get_reference(self, docs) -> List[Tuple[int, int]]: |     def _get_reference(self, docs) -> List[Tuple[int, int]]: | ||||||
|         """Create a reference list of token probabilities""" |         """Create a reference list of token probabilities""" | ||||||
|  |  | ||||||
|  | @ -63,15 +63,16 @@ def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_tr | ||||||
|     example.reference.spans[TRAINING_KEY] = [example.reference.char_span(5, 9)] |     example.reference.spans[TRAINING_KEY] = [example.reference.char_span(5, 9)] | ||||||
|     span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY}) |     span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY}) | ||||||
|     nlp.initialize() |     nlp.initialize() | ||||||
|  |     ops = span_finder.model.ops | ||||||
|     if predicted.text != reference.text: |     if predicted.text != reference.text: | ||||||
|         with pytest.raises( |         with pytest.raises( | ||||||
|             ValueError, match="must match between reference and predicted" |             ValueError, match="must match between reference and predicted" | ||||||
|         ): |         ): | ||||||
|             span_finder._get_aligned_truth_scores([example]) |             span_finder._get_aligned_truth_scores([example], ops) | ||||||
|         return |         return | ||||||
|     truth_scores = span_finder._get_aligned_truth_scores([example]) |     truth_scores, masks = span_finder._get_aligned_truth_scores([example], ops) | ||||||
|     assert len(truth_scores) == len(tokens_predicted) |     assert len(truth_scores) == len(tokens_predicted) | ||||||
|     assert truth_scores == reference_truths |     ops.xp.testing.assert_array_equal(truth_scores, ops.xp.asarray(reference_truths)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_span_finder_model(): | def test_span_finder_model(): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user