handle misaligned tokenization

Repository: https://github.com/explosion/spaCy.git
Parent: 11a17976ec
Commit: 6e46ecfa2c
@@ -7,7 +7,7 @@ from thinc.types import Floats2d, Ints1d, Ragged
 from spacy.language import Language
 from spacy.pipeline.trainable_pipe import TrainablePipe
 from spacy.scorer import Scorer
-from spacy.tokens import Doc
+from spacy.tokens import Doc, Span
 from spacy.training import Example
 from spacy.errors import Errors
 
@@ -160,6 +160,12 @@ class _MaxInt(int):
         return True
 
 
+def _char_indices(span: Span) -> Tuple[int, int]:
+    start = span[0].idx
+    end = span[-1].idx + len(span[-1])
+    return start, end
+
+
 class SpanFinder(TrainablePipe):
     """Pipeline that learns span boundaries"""
 
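For illustration only (not part of the diff): the new _char_indices helper derives the character offsets covered by a Span from its first and last tokens, which matches the span's own start_char and end_char attributes. A minimal sketch:

import spacy

nlp = spacy.blank("en")
doc = nlp("New York City is large")
span = doc[0:3]  # "New York City"

start = span[0].idx                 # 0, character offset of the first token
end = span[-1].idx + len(span[-1])  # 13, offset just past the last token
assert (start, end) == (span.start_char, span.end_char)
assert doc.text[start:end] == "New York City"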
@@ -282,38 +288,52 @@ class SpanFinder(TrainablePipe):
         scores: Scores representing the model's predictions.
         RETURNS (Tuple[float, float]): The loss and the gradient.
         """
-        reference_truths = self._get_aligned_truth_scores(examples)
-        d_scores = scores - self.model.ops.asarray2f(reference_truths)
+        truths, masks = self._get_aligned_truth_scores(examples, self.model.ops)
+        d_scores = scores - self.model.ops.asarray2f(truths)
+        d_scores *= masks
         loss = float((d_scores**2).sum())
         return loss, d_scores
 
-    def _get_aligned_truth_scores(self, examples) -> List[Tuple[int, int]]:
+    def _get_aligned_truth_scores(self, examples, ops) -> Tuple[Floats2d, Floats2d]:
         """Align scores of the predictions to the references for calculating the loss"""
-        # TODO: handle misaligned (None) alignments
-        reference_truths = []
-
+        truths = []
+        masks = []
         for eg in examples:
             if eg.x.text != eg.y.text:
                 raise ValueError(Errors.E1053.format(component="span_finder"))
-            start_indices = set()
-            end_indices = set()
-
+            n_tokens = len(eg.predicted)
+            truth = ops.xp.zeros((n_tokens, 2), dtype="float32")
+            mask = ops.xp.ones((n_tokens, 2), dtype="float32")
             if self.training_key in eg.reference.spans:
                 for span in eg.reference.spans[self.training_key]:
-                    start_indices.add(eg.reference[span.start].idx)
-                    end_indices.add(
-                        eg.reference[span.end - 1].idx + len(eg.reference[span.end - 1])
-                    )
-
-            for token in eg.predicted:
-                reference_truths.append(
-                    (
-                        1 if token.idx in start_indices else 0,
-                        1 if token.idx + len(token) in end_indices else 0,
-                    )
-                )
-
-        return reference_truths
+                    ref_start_char, ref_end_char = _char_indices(span)
+                    pred_span = eg.predicted.char_span(
+                        ref_start_char, ref_end_char, alignment_mode="expand"
+                    )
+                    pred_start_char, pred_end_char = _char_indices(pred_span)
+                    start_match = pred_start_char == ref_start_char
+                    end_match = pred_end_char == ref_end_char
+                    # Tokenizations line up.
+                    if start_match and end_match:
+                        truth[pred_span[0].i, 0] = 1
+                        truth[pred_span[-1].i, 1] = 1
+                    # Start character index lines up, but not the end.
+                    elif start_match and not end_match:
+                        truth[pred_span[0].i, 0] = 1
+                        mask[pred_span[-1].i, 1] = 0
+                    # End character index lines up, but not the start.
+                    elif not start_match and end_match:
+                        truth[pred_span[-1].i, 1] = 1
+                        mask[pred_span[0].i, 0] = 0
+                    # Neither of them match.
+                    else:
+                        mask[pred_span[0].i, 0] = 0
+                        mask[pred_span[-1].i, 1] = 0
+            truths.append(truth)
+            masks.append(mask)
+        truths = ops.xp.concatenate(truths, axis=0)
+        masks = ops.xp.concatenate(masks, axis=0)
+        return truths, masks
 
     def _get_reference(self, docs) -> List[Tuple[int, int]]:
         """Create a reference list of token probabilities"""
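For illustration only (not part of the diff): the updated _get_aligned_truth_scores projects each reference span's character offsets onto the predicted tokenization with Doc.char_span(..., alignment_mode="expand"). Boundaries that can be reproduced exactly become positive labels in truth; boundaries that land inside a predicted token are masked out of the loss instead of being scored as negatives. A small sketch of that decision, using a hypothetical pair of docs with mismatched tokenizations:

from spacy.tokens import Doc
from spacy.vocab import Vocab

vocab = Vocab()
# Same text, two tokenizations: the predicted doc glues "York" and "." together.
reference = Doc(vocab, words=["I", "like", "New", "York", "."], spaces=[True, True, True, False, False])
predicted = Doc(vocab, words=["I", "like", "New", "York."], spaces=[True, True, True, False])

ref_span = reference[2:4]                          # "New York"
ref_start = ref_span[0].idx                        # 7
ref_end = ref_span[-1].idx + len(ref_span[-1])     # 15

# Snap the reference character offsets onto the predicted tokenization.
pred_span = predicted.char_span(ref_start, ref_end, alignment_mode="expand")
pred_start = pred_span[0].idx                      # 7  -> start boundary lines up
pred_end = pred_span[-1].idx + len(pred_span[-1])  # 16 -> end boundary does not

# In the updated loss: truth[pred_span[0].i, 0] = 1 for the matching start,
# while mask[pred_span[-1].i, 1] = 0 so the ambiguous end position is ignored.
print(pred_span.text, pred_start == ref_start, pred_end == ref_end)  # New York. True False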
@@ -63,15 +63,16 @@ def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_tr
     example.reference.spans[TRAINING_KEY] = [example.reference.char_span(5, 9)]
     span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY})
     nlp.initialize()
+    ops = span_finder.model.ops
     if predicted.text != reference.text:
         with pytest.raises(
             ValueError, match="must match between reference and predicted"
         ):
-            span_finder._get_aligned_truth_scores([example])
+            span_finder._get_aligned_truth_scores([example], ops)
         return
-    truth_scores = span_finder._get_aligned_truth_scores([example])
+    truth_scores, masks = span_finder._get_aligned_truth_scores([example], ops)
     assert len(truth_scores) == len(tokens_predicted)
-    assert truth_scores == reference_truths
+    ops.xp.testing.assert_array_equal(truth_scores, ops.xp.asarray(reference_truths))
 
 
 def test_span_finder_model():
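For illustration only (not part of the diff): a test along these lines needs an Example whose predicted and reference docs share the same text but tokenize it differently. One hypothetical way to build such a pair (the spans key name here is a placeholder, not the test's TRAINING_KEY value):

from spacy.tokens import Doc
from spacy.training import Example
from spacy.vocab import Vocab

vocab = Vocab()
# Same text, but "Mr." is split differently in the two docs.
predicted = Doc(vocab, words=["Mr.", "Blue", "Sky"], spaces=[True, True, False])
reference = Doc(vocab, words=["Mr", ".", "Blue", "Sky"], spaces=[False, True, True, False])
reference.spans["my_key"] = [reference.char_span(0, 8)]  # "Mr. Blue", placeholder key
example = Example(predicted, reference)
assert example.x.text == example.y.text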