diff --git a/spacy/tests/training/test_training.py b/spacy/tests/training/test_training.py
index ba485ab45..25c55e12c 100644
--- a/spacy/tests/training/test_training.py
+++ b/spacy/tests/training/test_training.py
@@ -514,6 +514,11 @@ def test_roundtrip_docs_to_docbin(doc):
             ([[0], [1], [2, 3]], [[0], [1], [2], [2]]),
         ),
         ([" ", "a"], ["a"], ([[], [0]], [[1]])),
+        (
+            ["a", "''", "'", ","],
+            ["a'", "''", ","],
+            ([[0], [0, 1], [1], [2]], [[0, 1], [1, 2], [3]]),
+        ),
     ],
 )
 def test_align(tokens_a, tokens_b, expected):  # noqa
@@ -698,7 +703,7 @@ def test_alignment_spaces(en_vocab):
     align = Alignment.from_strings(other_tokens, spacy_tokens)
     assert list(align.x2y.lengths) == [0, 3, 1, 1, 1, 1, 1]
     assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5]
-    assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2,]
+    assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
     assert list(align.y2x.dataXd) == [1, 1, 1, 2, 3, 4, 5, 6]
 
     # multiple leading whitespace tokens
@@ -707,7 +712,7 @@ def test_alignment_spaces(en_vocab):
     align = Alignment.from_strings(other_tokens, spacy_tokens)
     assert list(align.x2y.lengths) == [0, 0, 3, 1, 1, 1, 1, 1]
     assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5]
-    assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2,]
+    assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
     assert list(align.y2x.dataXd) == [2, 2, 2, 3, 4, 5, 6, 7]
 
     # both with leading whitespace, not identical
diff --git a/spacy/training/align.pyx b/spacy/training/align.pyx
index b9d89f789..0ef1fd35d 100644
--- a/spacy/training/align.pyx
+++ b/spacy/training/align.pyx
@@ -7,8 +7,8 @@ from ..errors import Errors
 
 def get_alignments(A: List[str], B: List[str]) -> Tuple[List[List[int]], List[List[int]]]:
     # Create character-to-token mappings
-    char_to_token_a = tuple(chain(*((i,) * len(x) for i, x in enumerate(A))))
-    char_to_token_b = tuple(chain(*((i,) * len(x) for i, x in enumerate(B))))
+    char_to_token_a = tuple(chain(*((i,) * len(x.lower()) for i, x in enumerate(A))))
+    char_to_token_b = tuple(chain(*((i,) * len(x.lower()) for i, x in enumerate(B))))
     str_a = "".join(A).lower()
     str_b = "".join(B).lower()
     cdef int len_str_a = len(str_a)
@@ -36,8 +36,14 @@ def get_alignments(A: List[str], B: List[str]) -> Tuple[List[List[int]], List[Li
         if prev_token_idx_b != token_idx_b:
             b2a.append(set())
         # Process the alignment at the current position
-        if A[token_idx_a] == B[token_idx_b]:
-            # Current tokens are identical
+        if A[token_idx_a] == B[token_idx_b] and \
+                (char_idx_a == 0 or \
+                 char_to_token_a[char_idx_a - 1] < token_idx_a) and \
+                (char_idx_b == 0 or \
+                 char_to_token_b[char_idx_b - 1] < token_idx_b):
+            # Current tokens are identical and both character offsets are the
+            # start of a token (either at the beginning of the document or the
+            # previous character belongs to a different token)
             a2b[-1].add(token_idx_b)
             b2a[-1].add(token_idx_a)
             char_idx_a += len(A[token_idx_a])
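
For reference, the new parametrized case covers tokenizations whose quote characters overlap across token boundaries. A minimal standalone sketch of the patched behavior, assuming `get_alignments` is imported from the module this patch modifies (`spacy.training.align`) and that the parametrized `test_align` exercises it directly, with the expected values taken from the added test case:

```python
# Sketch only: import path assumed from the patched module spacy/training/align.pyx
from spacy.training.align import get_alignments

# Two tokenizations of "a''',": the apostrophes are split differently,
# so A's "''" and B's "''" cover overlapping but unequal character spans.
tokens_a = ["a", "''", "'", ","]
tokens_b = ["a'", "''", ","]

a2b, b2a = get_alignments(tokens_a, tokens_b)

# With the start-of-token check, a whole-token match is only taken when both
# character offsets sit at a token boundary, so the quotes align many-to-many.
assert a2b == [[0], [0, 1], [1], [2]]
assert b2a == [[0, 1], [1, 2], [3]]
```

Without the start-of-token condition, the whole-token branch could fire while the character offset is still in the middle of a token (here, at the second character of A's `''` but the first character of B's `''`), advancing both sides by full token lengths and skipping A's lone `'` entirely, which left the resulting alignment misaligned.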