[#4525] fix gold.align (#4526)

* fix: gold.align

* fix align

* remove old align
This commit is contained in:
tamuhey 2019-10-27 21:38:04 +09:00 committed by Matthew Honnibal
parent a9c6104047
commit 554850206c
2 changed files with 86 additions and 20 deletions

View File

@ -73,9 +73,15 @@ def merge_sents(sents):
return [(m_deps, m_brackets)]
_NORM_MAP = {"``": '"', "''": '"'}


def _normalize(tokens):
    """Replace tokens listed in `_NORM_MAP` with their normalized form.

    A token made of two backticks or of two single quotes becomes a single
    double-quote character; every other token is passed through unchanged.

    tokens (Iterable[str]): The tokens to normalize.
    RETURNS (List[str]): A new list with one entry per input token.
    """
    normalized = []
    for token in tokens:
        # Fall back to the token itself when no replacement is registered.
        normalized.append(_NORM_MAP.get(token, token))
    return normalized
def align(tokens_a, tokens_b):
    """Calculate alignment tables between two tokenizations.

    Both token sequences are normalized with `_normalize` and then walked in
    parallel, consuming the shorter of the two pending token fragments at
    each step. The concatenated text of both tokenizations must be identical
    after normalization.

    tokens_a (List[str]): The candidate tokenization.
    tokens_b (List[str]): The reference tokenization.
    RETURNS: (tuple): A 5-tuple consisting of the following information:
      * cost (int): Misalignment score. It stays 0 while tokens pair up
        one-to-one, grows by 1 each time a fragment of one token is consumed
        by a shorter token on the other side, and by 2 when a token ends
        together with a token that it only partially covers.
      * a2b (numpy.ndarray[dtype='i']): For each token in `tokens_a`, the
        index of the identical token in `tokens_b`, or -1 if there is no
        one-to-one match.
      * b2a (numpy.ndarray[dtype='i']): The same table in the other
        direction.
      * a2b_multi (Dict[int, int]): Maps the index of a token in `tokens_a`
        that lies entirely inside a longer token of `tokens_b` to the index
        of that longer token.
      * b2a_multi (Dict[int, int]): As with `a2b_multi`, but mapping the other
        direction.
    RAISES: ValueError if the two tokenizations do not spell the same text.
    """
    tokens_a = _normalize(tokens_a)
    tokens_b = _normalize(tokens_b)
    len_a = len(tokens_a)
    len_b = len(tokens_b)
    cost = 0
    # Start from "unaligned" everywhere so that no slot is ever returned
    # holding uninitialised memory; only exact one-to-one matches overwrite it.
    a2b = numpy.full(len_a, -1, dtype="i")
    b2a = numpy.full(len_b, -1, dtype="i")
    a2b_multi = {}
    b2a_multi = {}
    i = 0
    j = 0
    # Number of characters of the current token already consumed on each side.
    offset_a = 0
    offset_b = 0
    while i < len_a and j < len_b:
        a = tokens_a[i][offset_a:]
        b = tokens_b[j][offset_b:]
        if a == b:
            # Both pending fragments end at the same character.
            if offset_a == 0 and offset_b == 0:
                # Two whole tokens are identical: a one-to-one alignment.
                a2b[i] = j
                b2a[j] = i
            elif offset_a == 0:
                # The whole of token i is the tail of the longer token j.
                cost += 2
                a2b_multi[i] = j
            elif offset_b == 0:
                # The whole of token j is the tail of the longer token i.
                cost += 2
                b2a_multi[j] = i
            offset_a = offset_b = 0
            i += 1
            j += 1
        elif b.startswith(a):
            # Token i ends first; it is contained in token j only if it was
            # not itself already partially consumed.
            cost += 1
            if offset_a == 0:
                a2b_multi[i] = j
            i += 1
            offset_a = 0
            offset_b += len(a)
        elif a.startswith(b):
            # Mirror image of the branch above.
            cost += 1
            if offset_b == 0:
                b2a_multi[j] = i
            j += 1
            offset_b = 0
            offset_a += len(b)
        else:
            raise ValueError(f"{tokens_a} and {tokens_b} is different texts.")
    # The loop stops as soon as either side is exhausted. Any text left on the
    # other side means the inputs do not spell the same string; leftover empty
    # tokens are harmless and simply stay unaligned (-1).
    rest_a = "".join(tokens_a[i:])[offset_a:]
    rest_b = "".join(tokens_b[j:])[offset_b:]
    if rest_a or rest_b:
        raise ValueError(f"{tokens_a} and {tokens_b} is different texts.")
    return cost, a2b, b2a, a2b_multi, b2a_multi
class GoldCorpus(object):

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags
from spacy.gold import spans_from_biluo_tags, GoldParse, iob_to_biluo
from spacy.gold import GoldCorpus, docs_to_json
from spacy.gold import GoldCorpus, docs_to_json, align
from spacy.lang.en import English
from spacy.tokens import Doc
from .util import make_tempdir
@ -175,3 +175,34 @@ def test_roundtrip_docs_to_json():
assert "BAKING" in goldparse.cats
assert cats["TRAVEL"] == goldparse.cats["TRAVEL"]
assert cats["BAKING"] == goldparse.cats["BAKING"]
@pytest.mark.parametrize(
    "tokens_a,tokens_b,expected",
    [
        (["a", "b", "c"], ["ab", "c"], (3, [-1, -1, 1], [-1, 2], {0: 0, 1: 0}, {})),
        (
            ["a", "b", "``", "c"],
            ['ab"', "c"],
            (4, [-1, -1, -1, 1], [-1, 3], {0: 0, 1: 0, 2: 0}, {}),
        ),
        (["a", "bc"], ["ab", "c"], (4, [-1, -1], [-1, -1], {0: 0}, {1: 1})),
        (
            ["ab", "c", "d"],
            ["a", "b", "cd"],
            (6, [-1, -1, -1], [-1, -1, -1], {1: 2, 2: 2}, {0: 0, 1: 0}),
        ),
        (
            ["a", "b", "cd"],
            ["a", "b", "c", "d"],
            (3, [0, 1, -1], [0, 1, -1, -1], {}, {2: 2, 3: 2}),
        ),
    ],
)
def test_align(tokens_a, tokens_b, expected):
    """Check `align` against the expected tables in both argument orders."""
    fwd_cost, fwd_a2b, fwd_b2a, fwd_a2b_multi, fwd_b2a_multi = align(
        tokens_a, tokens_b
    )
    forward = (fwd_cost, list(fwd_a2b), list(fwd_b2a), fwd_a2b_multi, fwd_b2a_multi)
    assert forward == expected
    # Swapping the inputs must swap the roles of the a->b and b->a tables, so
    # the result is re-ordered before it is compared with the same expectation.
    rev_cost, rev_a2b, rev_b2a, rev_a2b_multi, rev_b2a_multi = align(
        tokens_b, tokens_a
    )
    mirrored = (rev_cost, list(rev_b2a), list(rev_a2b), rev_b2a_multi, rev_a2b_multi)
    assert mirrored == expected