mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-25 00:34:20 +03:00
Fix many-to-one alignment
This commit is contained in:
parent
4890ee1732
commit
6138439469
|
@ -90,7 +90,7 @@ from .compat import unicode_
|
|||
from murmurhash.mrmr cimport hash32
|
||||
|
||||
|
||||
def align(S, T, many_to_one=False, one_to_many=False):
|
||||
def align(S, T):
|
||||
cdef int m = len(S)
|
||||
cdef int n = len(T)
|
||||
cdef np.ndarray matrix = numpy.zeros((m+1, n+1), dtype='int32')
|
||||
|
@ -126,39 +126,58 @@ def multi_align(np.ndarray i2j, np.ndarray j2i, i_lengths, j_lengths):
|
|||
i2j_multi: {1: 1, 2: 1}
|
||||
j2i_multi: {}
|
||||
'''
|
||||
i_starts = numpy.cumsum([0] + i_lengths[:-1])
|
||||
j_starts = numpy.cumsum([0] + j_lengths[:-1])
|
||||
i2j_miss = _get_regions(i2j, i_starts)
|
||||
j2i_miss = _get_regions(j2i, j_starts)
|
||||
i2j_miss = _get_regions(i2j, i_lengths)
|
||||
j2i_miss = _get_regions(j2i, j_lengths)
|
||||
|
||||
i2j_multi = _get_mapping(i2j_miss, j2i_miss, i_lengths, j_lengths)
|
||||
j2i_multi = _get_mapping(j2i_miss, i2j_miss, j_lengths, i_lengths)
|
||||
i2j_multi, j2i_multi = _get_mapping(i2j_miss, j2i_miss, i_lengths, j_lengths)
|
||||
return i2j_multi, j2i_multi
|
||||
|
||||
|
||||
def _get_regions(alignment, starts):
|
||||
def _get_regions(alignment, lengths):
|
||||
regions = {}
|
||||
start = None
|
||||
offset = 0
|
||||
for i in range(len(alignment)):
|
||||
if alignment[i] < 0:
|
||||
if start is None:
|
||||
start = starts[i]
|
||||
start = offset
|
||||
regions.setdefault(start, [])
|
||||
regions[start].append(i)
|
||||
else:
|
||||
start = None
|
||||
offset += lengths[i]
|
||||
return regions
|
||||
|
||||
|
||||
def _get_mapping(miss1, miss2, lengths1, lengths2):
|
||||
output = {}
|
||||
i2j = {}
|
||||
j2i = {}
|
||||
for start, region1 in miss1.items():
|
||||
region2 = miss2.get(start, [])
|
||||
if len(region2) == 1:
|
||||
if sum(lengths1[i] for i in region1):
|
||||
for i in region1:
|
||||
output[i] = region2[0]
|
||||
return output
|
||||
if not region1 or start not in miss2:
|
||||
continue
|
||||
region2 = miss2[start]
|
||||
if sum(lengths1[i] for i in region1) == sum(lengths2[i] for i in region2):
|
||||
j = region2.pop(0)
|
||||
buff = []
|
||||
# Consume tokens from region 1, until we meet the length of the
|
||||
# first token in region2. If we do, align the tokens. If
|
||||
# we exceed the length, break.
|
||||
while region1:
|
||||
buff.append(region1.pop(0))
|
||||
if sum(lengths1[i] for i in buff) == lengths2[j]:
|
||||
for i in buff:
|
||||
i2j[i] = j
|
||||
j2i[j] = buff[-1]
|
||||
j += 1
|
||||
buff = []
|
||||
elif sum(lengths1[i] for i in buff) > lengths2[j]:
|
||||
break
|
||||
else:
|
||||
if buff and sum(lengths1[i] for i in buff) == lengths2[j]:
|
||||
for i in buff:
|
||||
i2j[i] = j
|
||||
j2i[j] = buff[-1]
|
||||
return i2j, j2i
|
||||
|
||||
|
||||
def _convert_sequence(seq):
|
||||
|
|
|
@ -63,8 +63,6 @@ def merge_sents(sents):
|
|||
|
||||
punct_re = re.compile(r'\W')
|
||||
def align(cand_words, gold_words):
|
||||
cand_words = [punct_re.sub('', w).lower() for w in cand_words]
|
||||
gold_words = [punct_re.sub('', w).lower() for w in gold_words]
|
||||
if cand_words == gold_words:
|
||||
alignment = numpy.arange(len(cand_words))
|
||||
return 0, alignment, alignment, {}, {}
|
||||
|
@ -389,7 +387,7 @@ cdef class GoldParse:
|
|||
for i, gold_i in enumerate(self.cand_to_gold):
|
||||
if doc[i].text.isspace():
|
||||
self.words[i] = doc[i].text
|
||||
self.tags[i] = 'SP'
|
||||
self.tags[i] = '_SP'
|
||||
self.heads[i] = None
|
||||
self.labels[i] = None
|
||||
self.ner[i] = 'O'
|
||||
|
|
Loading…
Reference in New Issue
Block a user