handle misaligned tokenization

This commit is contained in:
kadarakos 2023-05-31 16:56:01 +00:00
parent 11a17976ec
commit 6e46ecfa2c
2 changed files with 47 additions and 26 deletions

View File

@ -7,7 +7,7 @@ from thinc.types import Floats2d, Ints1d, Ragged
from spacy.language import Language
from spacy.pipeline.trainable_pipe import TrainablePipe
from spacy.scorer import Scorer
from spacy.tokens import Doc
from spacy.tokens import Doc, Span
from spacy.training import Example
from spacy.errors import Errors
@ -160,6 +160,12 @@ class _MaxInt(int):
return True
def _char_indices(span: Span) -> Tuple[int, int]:
start = span[0].idx
end = span[-1].idx + len(span[-1])
return start, end
class SpanFinder(TrainablePipe):
"""Pipeline that learns span boundaries"""
@ -282,38 +288,52 @@ class SpanFinder(TrainablePipe):
scores: Scores representing the model's predictions.
RETURNS (Tuple[float, float]): The loss and the gradient.
"""
reference_truths = self._get_aligned_truth_scores(examples)
d_scores = scores - self.model.ops.asarray2f(reference_truths)
truths, masks = self._get_aligned_truth_scores(examples, self.model.ops)
d_scores = scores - self.model.ops.asarray2f(truths)
d_scores *= masks
loss = float((d_scores**2).sum())
return loss, d_scores
def _get_aligned_truth_scores(self, examples) -> List[Tuple[int, int]]:
def _get_aligned_truth_scores(self, examples, ops) -> Tuple[Floats2d, Floats2d]:
"""Align scores of the predictions to the references for calculating the loss"""
# TODO: handle misaligned (None) alignments
reference_truths = []
truths = []
masks = []
for eg in examples:
if eg.x.text != eg.y.text:
raise ValueError(Errors.E1053.format(component="span_finder"))
start_indices = set()
end_indices = set()
n_tokens = len(eg.predicted)
truth = ops.xp.zeros((n_tokens, 2), dtype="float32")
mask = ops.xp.ones((n_tokens, 2), dtype="float32")
if self.training_key in eg.reference.spans:
for span in eg.reference.spans[self.training_key]:
start_indices.add(eg.reference[span.start].idx)
end_indices.add(
eg.reference[span.end - 1].idx + len(eg.reference[span.end - 1])
ref_start_char, ref_end_char = _char_indices(span)
pred_span = eg.predicted.char_span(
ref_start_char, ref_end_char, alignment_mode="expand"
)
for token in eg.predicted:
reference_truths.append(
(
1 if token.idx in start_indices else 0,
1 if token.idx + len(token) in end_indices else 0,
)
)
return reference_truths
pred_start_char, pred_end_char = _char_indices(pred_span)
start_match = pred_start_char == ref_start_char
end_match = pred_end_char == ref_end_char
# Tokenizations line up.
if start_match and end_match:
truth[pred_span[0].i, 0] = 1
truth[pred_span[-1].i, 1] = 1
# Start character index lines up, but not the end.
elif start_match and not end_match:
truth[pred_span[0].i, 0] = 1
mask[pred_span[-1].i, 1] = 0
# End character index lines up, but not the start.
elif not start_match and end_match:
truth[pred_span[-1].i, 1] = 1
mask[pred_span[0].i, 0] = 0
# Neither of them match.
else:
mask[pred_span[0].i, 0] = 0
mask[pred_span[-1].i, 1] = 0
truths.append(truth)
masks.append(mask)
truths = ops.xp.concatenate(truths, axis=0)
masks = ops.xp.concatenate(masks, axis=0)
return truths, mask
def _get_reference(self, docs) -> List[Tuple[int, int]]:
"""Create a reference list of token probabilities"""

View File

@ -63,15 +63,16 @@ def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_tr
example.reference.spans[TRAINING_KEY] = [example.reference.char_span(5, 9)]
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
nlp.initialize()
ops = span_finder.model.ops
if predicted.text != reference.text:
with pytest.raises(
ValueError, match="must match between reference and predicted"
):
span_finder._get_aligned_truth_scores([example])
span_finder._get_aligned_truth_scores([example], ops)
return
truth_scores = span_finder._get_aligned_truth_scores([example])
truth_scores, masks = span_finder._get_aligned_truth_scores([example], ops)
assert len(truth_scores) == len(tokens_predicted)
assert truth_scores == reference_truths
ops.xp.testing.assert_array_equal(truth_scores, ops.xp.asarray(reference_truths))
def test_span_finder_model():