diff --git a/spacy/gold.pyx b/spacy/gold.pyx index 3a5f508b4..a2d9f2f05 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -73,9 +73,15 @@ def merge_sents(sents): return [(m_deps, m_brackets)] +_NORM_MAP = {"``": '"', "''": '"'} + + +def _normalize(tokens): + return [_NORM_MAP.get(word, word) for word in tokens] + + def align(tokens_a, tokens_b): - """Calculate alignment tables between two tokenizations, using the Levenshtein - algorithm. The alignment is case-insensitive. + """Calculate alignment tables between two tokenizations. tokens_a (List[str]): The candidate tokenization. tokens_b (List[str]): The reference tokenization. @@ -92,23 +98,52 @@ def align(tokens_a, tokens_b): * b2a_multi (Dict[int, int]): As with `a2b_multi`, but mapping the other direction. """ - if tokens_a == tokens_b: - alignment = numpy.arange(len(tokens_a)) - return 0, alignment, alignment, {}, {} - tokens_a = [w.replace(" ", "").lower() for w in tokens_a] - tokens_b = [w.replace(" ", "").lower() for w in tokens_b] - cost, i2j, j2i, matrix = _align.align(tokens_a, tokens_b) - i2j_multi, j2i_multi = _align.multi_align(i2j, j2i, [len(w) for w in tokens_a], - [len(w) for w in tokens_b]) - for i, j in list(i2j_multi.items()): - if i2j_multi.get(i+1) != j and i2j_multi.get(i-1) != j: - i2j[i] = j - i2j_multi.pop(i) - for j, i in list(j2i_multi.items()): - if j2i_multi.get(j+1) != i and j2i_multi.get(j-1) != i: - j2i[j] = i - j2i_multi.pop(j) - return cost, i2j, j2i, i2j_multi, j2i_multi + tokens_a = _normalize(tokens_a) + tokens_b = _normalize(tokens_b) + cost = 0 + a2b = numpy.empty(len(tokens_a), dtype="i") + b2a = numpy.empty(len(tokens_b), dtype="i") + a2b_multi = {} + b2a_multi = {} + i = 0 + j = 0 + offset_a = 0 + offset_b = 0 + while i < len(tokens_a) and j < len(tokens_b): + a = tokens_a[i][offset_a:] + b = tokens_b[j][offset_b:] + a2b[i] = b2a[j] = -1 + if a == b: + if offset_a == offset_b == 0: + a2b[i] = j + b2a[j] = i + elif offset_a == 0: + cost += 2 + a2b_multi[i] = j + elif offset_b == 0: + cost += 2 + b2a_multi[j] = i + offset_a = offset_b = 0 + i += 1 + j += 1 + elif b.startswith(a): + cost += 1 + if offset_a == 0: + a2b_multi[i] = j + i += 1 + offset_a = 0 + offset_b += len(a) + elif a.startswith(b): + cost += 1 + if offset_b == 0: + b2a_multi[j] = i + j += 1 + offset_b = 0 + offset_a += len(b) + else: + assert "".join(tokens_a) != "".join(tokens_b) + raise ValueError(f"{tokens_a} and {tokens_b} is different texts.") + return cost, a2b, b2a, a2b_multi, b2a_multi class GoldCorpus(object): diff --git a/spacy/tests/test_gold.py b/spacy/tests/test_gold.py index a7c29f8db..0f3112e07 100644 --- a/spacy/tests/test_gold.py +++ b/spacy/tests/test_gold.py @@ -3,7 +3,7 @@ from __future__ import unicode_literals from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags from spacy.gold import spans_from_biluo_tags, GoldParse, iob_to_biluo -from spacy.gold import GoldCorpus, docs_to_json +from spacy.gold import GoldCorpus, docs_to_json, align from spacy.lang.en import English from spacy.tokens import Doc from .util import make_tempdir @@ -175,3 +175,34 @@ def test_roundtrip_docs_to_json(): assert "BAKING" in goldparse.cats assert cats["TRAVEL"] == goldparse.cats["TRAVEL"] assert cats["BAKING"] == goldparse.cats["BAKING"] + + +@pytest.mark.parametrize( + "tokens_a,tokens_b,expected", + [ + (["a", "b", "c"], ["ab", "c"], (3, [-1, -1, 1], [-1, 2], {0: 0, 1: 0}, {})), + ( + ["a", "b", "``", "c"], + ['ab"', "c"], + (4, [-1, -1, -1, 1], [-1, 3], {0: 0, 1: 0, 2: 0}, {}), + ), + (["a", "bc"], ["ab", "c"], (4, [-1, -1], [-1, -1], {0: 0}, {1: 1})), + ( + ["ab", "c", "d"], + ["a", "b", "cd"], + (6, [-1, -1, -1], [-1, -1, -1], {1: 2, 2: 2}, {0: 0, 1: 0}), + ), + ( + ["a", "b", "cd"], + ["a", "b", "c", "d"], + (3, [0, 1, -1], [0, 1, -1, -1], {}, {2: 2, 3: 2}), + ), + ], +) +def test_align(tokens_a, tokens_b, expected): + cost, a2b, b2a, a2b_multi, b2a_multi = align(tokens_a, tokens_b) + assert (cost, list(a2b), list(b2a), a2b_multi, b2a_multi) == expected + # check symmetry + cost, a2b, b2a, a2b_multi, b2a_multi = align(tokens_b, tokens_a) + assert (cost, list(b2a), list(a2b), b2a_multi, a2b_multi) == expected +