mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
* fix: gold.align * fix align * remove old align
This commit is contained in:
parent
a9c6104047
commit
554850206c
|
@ -73,9 +73,15 @@ def merge_sents(sents):
|
|||
return [(m_deps, m_brackets)]
|
||||
|
||||
|
||||
_NORM_MAP = {"``": '"', "''": '"'}
|
||||
|
||||
|
||||
def _normalize(tokens):
|
||||
return [_NORM_MAP.get(word, word) for word in tokens]
|
||||
|
||||
|
||||
def align(tokens_a, tokens_b):
|
||||
"""Calculate alignment tables between two tokenizations, using the Levenshtein
|
||||
algorithm. The alignment is case-insensitive.
|
||||
"""Calculate alignment tables between two tokenizations.
|
||||
|
||||
tokens_a (List[str]): The candidate tokenization.
|
||||
tokens_b (List[str]): The reference tokenization.
|
||||
|
@ -92,23 +98,52 @@ def align(tokens_a, tokens_b):
|
|||
* b2a_multi (Dict[int, int]): As with `a2b_multi`, but mapping the other
|
||||
direction.
|
||||
"""
|
||||
if tokens_a == tokens_b:
|
||||
alignment = numpy.arange(len(tokens_a))
|
||||
return 0, alignment, alignment, {}, {}
|
||||
tokens_a = [w.replace(" ", "").lower() for w in tokens_a]
|
||||
tokens_b = [w.replace(" ", "").lower() for w in tokens_b]
|
||||
cost, i2j, j2i, matrix = _align.align(tokens_a, tokens_b)
|
||||
i2j_multi, j2i_multi = _align.multi_align(i2j, j2i, [len(w) for w in tokens_a],
|
||||
[len(w) for w in tokens_b])
|
||||
for i, j in list(i2j_multi.items()):
|
||||
if i2j_multi.get(i+1) != j and i2j_multi.get(i-1) != j:
|
||||
i2j[i] = j
|
||||
i2j_multi.pop(i)
|
||||
for j, i in list(j2i_multi.items()):
|
||||
if j2i_multi.get(j+1) != i and j2i_multi.get(j-1) != i:
|
||||
j2i[j] = i
|
||||
j2i_multi.pop(j)
|
||||
return cost, i2j, j2i, i2j_multi, j2i_multi
|
||||
tokens_a = _normalize(tokens_a)
|
||||
tokens_b = _normalize(tokens_b)
|
||||
cost = 0
|
||||
a2b = numpy.empty(len(tokens_a), dtype="i")
|
||||
b2a = numpy.empty(len(tokens_b), dtype="i")
|
||||
a2b_multi = {}
|
||||
b2a_multi = {}
|
||||
i = 0
|
||||
j = 0
|
||||
offset_a = 0
|
||||
offset_b = 0
|
||||
while i < len(tokens_a) and j < len(tokens_b):
|
||||
a = tokens_a[i][offset_a:]
|
||||
b = tokens_b[j][offset_b:]
|
||||
a2b[i] = b2a[j] = -1
|
||||
if a == b:
|
||||
if offset_a == offset_b == 0:
|
||||
a2b[i] = j
|
||||
b2a[j] = i
|
||||
elif offset_a == 0:
|
||||
cost += 2
|
||||
a2b_multi[i] = j
|
||||
elif offset_b == 0:
|
||||
cost += 2
|
||||
b2a_multi[j] = i
|
||||
offset_a = offset_b = 0
|
||||
i += 1
|
||||
j += 1
|
||||
elif b.startswith(a):
|
||||
cost += 1
|
||||
if offset_a == 0:
|
||||
a2b_multi[i] = j
|
||||
i += 1
|
||||
offset_a = 0
|
||||
offset_b += len(a)
|
||||
elif a.startswith(b):
|
||||
cost += 1
|
||||
if offset_b == 0:
|
||||
b2a_multi[j] = i
|
||||
j += 1
|
||||
offset_b = 0
|
||||
offset_a += len(b)
|
||||
else:
|
||||
assert "".join(tokens_a) != "".join(tokens_b)
|
||||
raise ValueError(f"{tokens_a} and {tokens_b} is different texts.")
|
||||
return cost, a2b, b2a, a2b_multi, b2a_multi
|
||||
|
||||
|
||||
class GoldCorpus(object):
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import unicode_literals
|
|||
|
||||
from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags
|
||||
from spacy.gold import spans_from_biluo_tags, GoldParse, iob_to_biluo
|
||||
from spacy.gold import GoldCorpus, docs_to_json
|
||||
from spacy.gold import GoldCorpus, docs_to_json, align
|
||||
from spacy.lang.en import English
|
||||
from spacy.tokens import Doc
|
||||
from .util import make_tempdir
|
||||
|
@ -175,3 +175,34 @@ def test_roundtrip_docs_to_json():
|
|||
assert "BAKING" in goldparse.cats
|
||||
assert cats["TRAVEL"] == goldparse.cats["TRAVEL"]
|
||||
assert cats["BAKING"] == goldparse.cats["BAKING"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tokens_a,tokens_b,expected",
|
||||
[
|
||||
(["a", "b", "c"], ["ab", "c"], (3, [-1, -1, 1], [-1, 2], {0: 0, 1: 0}, {})),
|
||||
(
|
||||
["a", "b", "``", "c"],
|
||||
['ab"', "c"],
|
||||
(4, [-1, -1, -1, 1], [-1, 3], {0: 0, 1: 0, 2: 0}, {}),
|
||||
),
|
||||
(["a", "bc"], ["ab", "c"], (4, [-1, -1], [-1, -1], {0: 0}, {1: 1})),
|
||||
(
|
||||
["ab", "c", "d"],
|
||||
["a", "b", "cd"],
|
||||
(6, [-1, -1, -1], [-1, -1, -1], {1: 2, 2: 2}, {0: 0, 1: 0}),
|
||||
),
|
||||
(
|
||||
["a", "b", "cd"],
|
||||
["a", "b", "c", "d"],
|
||||
(3, [0, 1, -1], [0, 1, -1, -1], {}, {2: 2, 3: 2}),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_align(tokens_a, tokens_b, expected):
|
||||
cost, a2b, b2a, a2b_multi, b2a_multi = align(tokens_a, tokens_b)
|
||||
assert (cost, list(a2b), list(b2a), a2b_multi, b2a_multi) == expected
|
||||
# check symmetry
|
||||
cost, a2b, b2a, a2b_multi, b2a_multi = align(tokens_b, tokens_a)
|
||||
assert (cost, list(b2a), list(a2b), b2a_multi, a2b_multi) == expected
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user