diff --git a/spacy/pipeline/span_finder.py b/spacy/pipeline/span_finder.py
index af4eab636..50112d692 100644
--- a/spacy/pipeline/span_finder.py
+++ b/spacy/pipeline/span_finder.py
@@ -7,7 +7,7 @@ from thinc.types import Floats2d, Ints1d, Ragged
 from spacy.language import Language
 from spacy.pipeline.trainable_pipe import TrainablePipe
 from spacy.scorer import Scorer
-from spacy.tokens import Doc
+from spacy.tokens import Doc, Span
 from spacy.training import Example
 from spacy.errors import Errors
 
@@ -160,6 +160,12 @@ class _MaxInt(int):
         return True
 
 
+def _char_indices(span: Span) -> Tuple[int, int]:
+    start = span[0].idx
+    end = span[-1].idx + len(span[-1])
+    return start, end
+
+
 class SpanFinder(TrainablePipe):
     """Pipeline that learns span boundaries"""
 
@@ -282,38 +288,52 @@ class SpanFinder(TrainablePipe):
         scores: Scores representing the model's predictions.
         RETURNS (Tuple[float, float]): The loss and the gradient.
         """
-        reference_truths = self._get_aligned_truth_scores(examples)
-        d_scores = scores - self.model.ops.asarray2f(reference_truths)
+        truths, masks = self._get_aligned_truth_scores(examples, self.model.ops)
+        d_scores = scores - self.model.ops.asarray2f(truths)
+        d_scores *= masks
         loss = float((d_scores**2).sum())
         return loss, d_scores
 
-    def _get_aligned_truth_scores(self, examples) -> List[Tuple[int, int]]:
+    def _get_aligned_truth_scores(self, examples, ops) -> Tuple[Floats2d, Floats2d]:
         """Align scores of the predictions to the references for calculating the loss"""
-        # TODO: handle misaligned (None) alignments
-        reference_truths = []
-
+        truths = []
+        masks = []
         for eg in examples:
             if eg.x.text != eg.y.text:
                 raise ValueError(Errors.E1053.format(component="span_finder"))
-            start_indices = set()
-            end_indices = set()
-
+            n_tokens = len(eg.predicted)
+            truth = ops.xp.zeros((n_tokens, 2), dtype="float32")
+            mask = ops.xp.ones((n_tokens, 2), dtype="float32")
             if self.training_key in eg.reference.spans:
                 for span in eg.reference.spans[self.training_key]:
-                    start_indices.add(eg.reference[span.start].idx)
-                    end_indices.add(
-                        eg.reference[span.end - 1].idx + len(eg.reference[span.end - 1])
+                    ref_start_char, ref_end_char = _char_indices(span)
+                    pred_span = eg.predicted.char_span(
+                        ref_start_char, ref_end_char, alignment_mode="expand"
                     )
-
-            for token in eg.predicted:
-                reference_truths.append(
-                    (
-                        1 if token.idx in start_indices else 0,
-                        1 if token.idx + len(token) in end_indices else 0,
-                    )
-                )
-
-        return reference_truths
+                    pred_start_char, pred_end_char = _char_indices(pred_span)
+                    start_match = pred_start_char == ref_start_char
+                    end_match = pred_end_char == ref_end_char
+                    # Tokenizations line up.
+                    if start_match and end_match:
+                        truth[pred_span[0].i, 0] = 1
+                        truth[pred_span[-1].i, 1] = 1
+                    # Start character index lines up, but not the end.
+                    elif start_match and not end_match:
+                        truth[pred_span[0].i, 0] = 1
+                        mask[pred_span[-1].i, 1] = 0
+                    # End character index lines up, but not the start.
+                    elif not start_match and end_match:
+                        truth[pred_span[-1].i, 1] = 1
+                        mask[pred_span[0].i, 0] = 0
+                    # Neither of them match.
+                    else:
+                        mask[pred_span[0].i, 0] = 0
+                        mask[pred_span[-1].i, 1] = 0
+            truths.append(truth)
+            masks.append(mask)
+        truths = ops.xp.concatenate(truths, axis=0)
+        masks = ops.xp.concatenate(masks, axis=0)
+        return truths, masks
 
     def _get_reference(self, docs) -> List[Tuple[int, int]]:
         """Create a reference list of token probabilities"""
diff --git a/spacy/tests/pipeline/test_span_finder.py b/spacy/tests/pipeline/test_span_finder.py
index 921601283..55d126ecc 100644
--- a/spacy/tests/pipeline/test_span_finder.py
+++ b/spacy/tests/pipeline/test_span_finder.py
@@ -63,15 +63,16 @@ def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_tr
     example.reference.spans[TRAINING_KEY] = [example.reference.char_span(5, 9)]
     span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
     nlp.initialize()
+    ops = span_finder.model.ops
     if predicted.text != reference.text:
         with pytest.raises(
             ValueError, match="must match between reference and predicted"
         ):
-            span_finder._get_aligned_truth_scores([example])
+            span_finder._get_aligned_truth_scores([example], ops)
         return
-    truth_scores = span_finder._get_aligned_truth_scores([example])
+    truth_scores, masks = span_finder._get_aligned_truth_scores([example], ops)
     assert len(truth_scores) == len(tokens_predicted)
-    assert truth_scores == reference_truths
+    ops.xp.testing.assert_array_equal(truth_scores, ops.xp.asarray(reference_truths))
 
 
 def test_span_finder_model():
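
Note for reviewers: below is a minimal standalone sketch, outside the patch, of the alignment rule the new _get_aligned_truth_scores applies. The hand-built Doc objects, the example text, and the character offsets are illustrative assumptions (the real code operates on Example objects and the model ops); it shows how char_span(..., alignment_mode="expand") decides whether a boundary becomes a positive cell in truth or a zeroed cell in mask.

    import numpy
    from spacy.tokens import Doc
    from spacy.vocab import Vocab

    vocab = Vocab()
    # Same text, two different tokenizations: the reference splits "high-quality".
    reference = Doc(
        vocab, words=["high", "-", "quality", "data"], spaces=[False, False, True, False]
    )
    predicted = Doc(vocab, words=["high-quality", "data"], spaces=[True, False])
    assert reference.text == predicted.text == "high-quality data"

    # Gold span "quality data" in the reference tokenization: chars [5, 17).
    ref_span = reference.char_span(5, 17)
    ref_start = ref_span[0].idx
    ref_end = ref_span[-1].idx + len(ref_span[-1])

    # alignment_mode="expand" grows the span to whole predicted tokens instead
    # of returning None, so the gold span survives the tokenization mismatch.
    pred_span = predicted.char_span(ref_start, ref_end, alignment_mode="expand")
    pred_start = pred_span[0].idx
    pred_end = pred_span[-1].idx + len(pred_span[-1])

    truth = numpy.zeros((len(predicted), 2), dtype="float32")
    mask = numpy.ones((len(predicted), 2), dtype="float32")
    # Here the end boundary survives retokenization but the start does not:
    # the patch's `not start_match and end_match` branch.
    assert pred_end == ref_end and pred_start != ref_start
    truth[pred_span[-1].i, 1] = 1  # the end token stays a positive example
    mask[pred_span[0].i, 0] = 0  # the unalignable start gets no gradient

    # get_loss then multiplies the gradient by the mask, so masked cells
    # contribute nothing to the loss:
    scores = numpy.asarray([[0.8, 0.1], [0.3, 0.6]], dtype="float32")
    d_scores = (scores - truth) * mask
    loss = float((d_scores**2).sum())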
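
On the test change: ops.xp resolves to numpy on CPU and cupy on GPU, and both expose testing.assert_array_equal, so the new assertion is backend-agnostic. It also replaces the old `assert truth_scores == reference_truths`, which no longer works now that truth_scores is an array (`==` becomes element-wise, and the resulting array has no single truth value). A quick numpy stand-in with hypothetical label values:

    import numpy

    # Hypothetical per-token (start, end) labels, as the test parametrizes them.
    reference_truths = [(0, 0), (1, 0), (0, 1)]
    truth_scores = numpy.asarray(reference_truths, dtype="float32")
    # Raises AssertionError with a per-element report on any mismatch.
    numpy.testing.assert_array_equal(truth_scores, numpy.asarray(reference_truths))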