mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-21 17:41:59 +03:00
handle misaligned tokenization
This commit is contained in:
parent
11a17976ec
commit
6e46ecfa2c
|
@ -7,7 +7,7 @@ from thinc.types import Floats2d, Ints1d, Ragged
|
|||
from spacy.language import Language
|
||||
from spacy.pipeline.trainable_pipe import TrainablePipe
|
||||
from spacy.scorer import Scorer
|
||||
from spacy.tokens import Doc
|
||||
from spacy.tokens import Doc, Span
|
||||
from spacy.training import Example
|
||||
from spacy.errors import Errors
|
||||
|
||||
|
@ -160,6 +160,12 @@ class _MaxInt(int):
|
|||
return True
|
||||
|
||||
|
||||
def _char_indices(span: Span) -> Tuple[int, int]:
|
||||
start = span[0].idx
|
||||
end = span[-1].idx + len(span[-1])
|
||||
return start, end
|
||||
|
||||
|
||||
class SpanFinder(TrainablePipe):
|
||||
"""Pipeline that learns span boundaries"""
|
||||
|
||||
|
@ -282,38 +288,52 @@ class SpanFinder(TrainablePipe):
|
|||
scores: Scores representing the model's predictions.
|
||||
RETURNS (Tuple[float, float]): The loss and the gradient.
|
||||
"""
|
||||
reference_truths = self._get_aligned_truth_scores(examples)
|
||||
d_scores = scores - self.model.ops.asarray2f(reference_truths)
|
||||
truths, masks = self._get_aligned_truth_scores(examples, self.model.ops)
|
||||
d_scores = scores - self.model.ops.asarray2f(truths)
|
||||
d_scores *= masks
|
||||
loss = float((d_scores**2).sum())
|
||||
return loss, d_scores
|
||||
|
||||
def _get_aligned_truth_scores(self, examples) -> List[Tuple[int, int]]:
|
||||
def _get_aligned_truth_scores(self, examples, ops) -> Tuple[Floats2d, Floats2d]:
|
||||
"""Align scores of the predictions to the references for calculating the loss"""
|
||||
# TODO: handle misaligned (None) alignments
|
||||
reference_truths = []
|
||||
|
||||
truths = []
|
||||
masks = []
|
||||
for eg in examples:
|
||||
if eg.x.text != eg.y.text:
|
||||
raise ValueError(Errors.E1053.format(component="span_finder"))
|
||||
start_indices = set()
|
||||
end_indices = set()
|
||||
|
||||
n_tokens = len(eg.predicted)
|
||||
truth = ops.xp.zeros((n_tokens, 2), dtype="float32")
|
||||
mask = ops.xp.ones((n_tokens, 2), dtype="float32")
|
||||
if self.training_key in eg.reference.spans:
|
||||
for span in eg.reference.spans[self.training_key]:
|
||||
start_indices.add(eg.reference[span.start].idx)
|
||||
end_indices.add(
|
||||
eg.reference[span.end - 1].idx + len(eg.reference[span.end - 1])
|
||||
ref_start_char, ref_end_char = _char_indices(span)
|
||||
pred_span = eg.predicted.char_span(
|
||||
ref_start_char, ref_end_char, alignment_mode="expand"
|
||||
)
|
||||
|
||||
for token in eg.predicted:
|
||||
reference_truths.append(
|
||||
(
|
||||
1 if token.idx in start_indices else 0,
|
||||
1 if token.idx + len(token) in end_indices else 0,
|
||||
)
|
||||
)
|
||||
|
||||
return reference_truths
|
||||
pred_start_char, pred_end_char = _char_indices(pred_span)
|
||||
start_match = pred_start_char == ref_start_char
|
||||
end_match = pred_end_char == ref_end_char
|
||||
# Tokenizations line up.
|
||||
if start_match and end_match:
|
||||
truth[pred_span[0].i, 0] = 1
|
||||
truth[pred_span[-1].i, 1] = 1
|
||||
# Start character index lines up, but not the end.
|
||||
elif start_match and not end_match:
|
||||
truth[pred_span[0].i, 0] = 1
|
||||
mask[pred_span[-1].i, 1] = 0
|
||||
# End character index lines up, but not the start.
|
||||
elif not start_match and end_match:
|
||||
truth[pred_span[-1].i, 1] = 1
|
||||
mask[pred_span[0].i, 0] = 0
|
||||
# Neither of them match.
|
||||
else:
|
||||
mask[pred_span[0].i, 0] = 0
|
||||
mask[pred_span[-1].i, 1] = 0
|
||||
truths.append(truth)
|
||||
masks.append(mask)
|
||||
truths = ops.xp.concatenate(truths, axis=0)
|
||||
masks = ops.xp.concatenate(masks, axis=0)
|
||||
return truths, mask
|
||||
|
||||
def _get_reference(self, docs) -> List[Tuple[int, int]]:
|
||||
"""Create a reference list of token probabilities"""
|
||||
|
|
|
@ -63,15 +63,16 @@ def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_tr
|
|||
example.reference.spans[TRAINING_KEY] = [example.reference.char_span(5, 9)]
|
||||
span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
|
||||
nlp.initialize()
|
||||
ops = span_finder.model.ops
|
||||
if predicted.text != reference.text:
|
||||
with pytest.raises(
|
||||
ValueError, match="must match between reference and predicted"
|
||||
):
|
||||
span_finder._get_aligned_truth_scores([example])
|
||||
span_finder._get_aligned_truth_scores([example], ops)
|
||||
return
|
||||
truth_scores = span_finder._get_aligned_truth_scores([example])
|
||||
truth_scores, masks = span_finder._get_aligned_truth_scores([example], ops)
|
||||
assert len(truth_scores) == len(tokens_predicted)
|
||||
assert truth_scores == reference_truths
|
||||
ops.xp.testing.assert_array_equal(truth_scores, ops.xp.asarray(reference_truths))
|
||||
|
||||
|
||||
def test_span_finder_model():
|
||||
|
|
Loading…
Reference in New Issue
Block a user