handle misaligned tokenization

kadarakos 2023-05-31 16:56:01 +00:00
parent 11a17976ec
commit 6e46ecfa2c
2 changed files with 47 additions and 26 deletions
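
The change addresses a real gap: the reference (gold) Doc and the predicted Doc may tokenize the same text differently, so a gold span's character boundaries need not coincide with token boundaries in the predicted Doc, and those positions previously could not be scored. A minimal sketch of the underlying problem, not part of this commit (the example text is illustrative only):

# Reference splits "high-risk" into three tokens; predicted keeps it as one token.
from spacy.tokens import Doc
from spacy.vocab import Vocab

vocab = Vocab()
reference = Doc(vocab, words=["a", "high", "-", "risk", "decision"],
                spaces=[True, False, False, True, False])
predicted = Doc(vocab, words=["a", "high-risk", "decision"],
                spaces=[True, True, False])
assert reference.text == predicted.text                     # "a high-risk decision"

gold = reference.char_span(2, 6)                            # "high"
print(predicted.char_span(2, 6))                            # None: chars 2-6 end mid-token
print(predicted.char_span(2, 6, alignment_mode="expand"))   # "high-risk"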


@@ -7,7 +7,7 @@ from thinc.types import Floats2d, Ints1d, Ragged
 from spacy.language import Language
 from spacy.pipeline.trainable_pipe import TrainablePipe
 from spacy.scorer import Scorer
-from spacy.tokens import Doc
+from spacy.tokens import Doc, Span
 from spacy.training import Example
 from spacy.errors import Errors
@@ -160,6 +160,12 @@ class _MaxInt(int):
         return True


+def _char_indices(span: Span) -> Tuple[int, int]:
+    start = span[0].idx
+    end = span[-1].idx + len(span[-1])
+    return start, end
+
+
 class SpanFinder(TrainablePipe):
     """Pipeline that learns span boundaries"""
@@ -282,38 +288,52 @@ class SpanFinder(TrainablePipe):
         scores: Scores representing the model's predictions.
         RETURNS (Tuple[float, float]): The loss and the gradient.
         """
-        reference_truths = self._get_aligned_truth_scores(examples)
-        d_scores = scores - self.model.ops.asarray2f(reference_truths)
+        truths, masks = self._get_aligned_truth_scores(examples, self.model.ops)
+        d_scores = scores - self.model.ops.asarray2f(truths)
+        d_scores *= masks
         loss = float((d_scores**2).sum())
         return loss, d_scores

-    def _get_aligned_truth_scores(self, examples) -> List[Tuple[int, int]]:
+    def _get_aligned_truth_scores(self, examples, ops) -> Tuple[Floats2d, Floats2d]:
         """Align scores of the predictions to the references for calculating the loss"""
-        # TODO: handle misaligned (None) alignments
-        reference_truths = []
+        truths = []
+        masks = []
         for eg in examples:
             if eg.x.text != eg.y.text:
                 raise ValueError(Errors.E1053.format(component="span_finder"))
-            start_indices = set()
-            end_indices = set()
+            n_tokens = len(eg.predicted)
+            truth = ops.xp.zeros((n_tokens, 2), dtype="float32")
+            mask = ops.xp.ones((n_tokens, 2), dtype="float32")
             if self.training_key in eg.reference.spans:
                 for span in eg.reference.spans[self.training_key]:
-                    start_indices.add(eg.reference[span.start].idx)
-                    end_indices.add(
-                        eg.reference[span.end - 1].idx + len(eg.reference[span.end - 1])
-                    )
-            for token in eg.predicted:
-                reference_truths.append(
-                    (
-                        1 if token.idx in start_indices else 0,
-                        1 if token.idx + len(token) in end_indices else 0,
-                    )
-                )
-        return reference_truths
+                    ref_start_char, ref_end_char = _char_indices(span)
+                    pred_span = eg.predicted.char_span(
+                        ref_start_char, ref_end_char, alignment_mode="expand"
+                    )
+                    pred_start_char, pred_end_char = _char_indices(pred_span)
+                    start_match = pred_start_char == ref_start_char
+                    end_match = pred_end_char == ref_end_char
+                    # Tokenizations line up.
+                    if start_match and end_match:
+                        truth[pred_span[0].i, 0] = 1
+                        truth[pred_span[-1].i, 1] = 1
+                    # Start character index lines up, but not the end.
+                    elif start_match and not end_match:
+                        truth[pred_span[0].i, 0] = 1
+                        mask[pred_span[-1].i, 1] = 0
+                    # End character index lines up, but not the start.
+                    elif not start_match and end_match:
+                        truth[pred_span[-1].i, 1] = 1
+                        mask[pred_span[0].i, 0] = 0
+                    # Neither of them match.
+                    else:
+                        mask[pred_span[0].i, 0] = 0
+                        mask[pred_span[-1].i, 1] = 0
+            truths.append(truth)
+            masks.append(mask)
+        truths = ops.xp.concatenate(truths, axis=0)
+        masks = ops.xp.concatenate(masks, axis=0)
+        return truths, masks

     def _get_reference(self, docs) -> List[Tuple[int, int]]:
         """Create a reference list of token probabilities"""


@@ -63,15 +63,16 @@ def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_truths
     example.reference.spans[TRAINING_KEY] = [example.reference.char_span(5, 9)]
     span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
     nlp.initialize()
+    ops = span_finder.model.ops
     if predicted.text != reference.text:
         with pytest.raises(
             ValueError, match="must match between reference and predicted"
         ):
-            span_finder._get_aligned_truth_scores([example])
+            span_finder._get_aligned_truth_scores([example], ops)
         return
-    truth_scores = span_finder._get_aligned_truth_scores([example])
+    truth_scores, masks = span_finder._get_aligned_truth_scores([example], ops)
     assert len(truth_scores) == len(tokens_predicted)
-    assert truth_scores == reference_truths
+    ops.xp.testing.assert_array_equal(truth_scores, ops.xp.asarray(reference_truths))


 def test_span_finder_model():
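
For completeness, a hedged sketch of how a test case with genuinely misaligned tokenization could be built (not part of the test file; "my_key" is a placeholder for the pipe's training key):

from spacy.tokens import Doc
from spacy.training import Example
from spacy.vocab import Vocab

vocab = Vocab()
predicted = Doc(vocab, words=["it", "is", "high-risk"],
                spaces=[True, True, False])
reference = Doc(vocab, words=["it", "is", "high", "-", "risk"],
                spaces=[True, True, False, False, False])
reference.spans["my_key"] = [reference.char_span(6, 10)]   # "high"
example = Example(predicted, reference)
# The gold end (char 10) falls inside the predicted token "high-risk", so only
# the start position can be supervised; the end position would be masked out.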