Fix Alignment class for undersegmentation

This commit is contained in:
Matthew Honnibal 2018-04-02 23:39:26 +02:00
parent e6641a11b1
commit c8ba54e052
3 changed files with 42 additions and 16 deletions

View File

@ -108,25 +108,37 @@ class Alignment(object):
many remaining subtokens align to the same value. many remaining subtokens align to the same value.
''' '''
output = [] output = []
for i, alignment in enumerate(self._t2y): for i, alignment in enumerate(self._y2t):
if len(alignment) == 1 and alignment[0][1] == 0: if isinstance(alignment, int):
output.append(items[alignment[0][0]]) output.append(items[alignment])
elif isinstance(alignment, tuple):
output.append((items[alignment[0]], alignment[1]))
else: else:
output.append([]) output.append([])
for j1, j2 in alignment: for entry in alignment:
output[-1].append((items[j1], j2)) if isinstance(entry, int):
output[-1].append(items[entry])
else:
output[-1].append((items[entry[0]], entry[1]))
return output return output
def index_to_yours(self, index): def index_to_yours(self, index):
'''Translate an index that points into their tokens to point into yours''' '''Translate an index that points into their tokens to point into yours'''
alignment = self._t2y[index] alignment = self._t2y[index]
if len(alignment) == 1 and alignment[0][2] == 0: return alignment
return alignment[0][0] #if isinstance(alignment, int):
else: # return alignment
output = [] #elif len(alignment) == 1 and isinstance(alignment, int):
for i1, i2, n_to_go in alignment: # return alignment[0]
output.append((i1, i2, n_to_go)) #elif len(alignment) == 1:
return output # return alignment[0][0]
#if len(alignment) == 1 and alignment[0][2] == 0:
# return alignment[0][0]
#else:
# output = []
# for i1, i2, n_to_go in alignment:
# output.append((i1, i2, n_to_go))
# return output
def to_theirs(self, items): def to_theirs(self, items):
raise NotImplementedError raise NotImplementedError
@ -203,8 +215,10 @@ class Alignment(object):
# Apply the alignment to get the new values # Apply the alignment to get the new values
new = [] new = []
for head_vals in heads: for head_vals in heads:
if not isinstance(head_vals, list): if isinstance(head_vals, int):
head_vals = [(head_vals, 0)] head_vals = [(head_vals, 0)]
elif isinstance(head_vals, tuple):
head_vals = [None]
for head_val in head_vals: for head_val in head_vals:
if not isinstance(head_val, tuple): if not isinstance(head_val, tuple):
head_val = (head_val, 0) head_val = (head_val, 0)
@ -269,15 +283,15 @@ def _convert_multi_align(one2one, many2one, one2many, one2part):
seen_j = Counter() seen_j = Counter()
for i, j in enumerate(one2one): for i, j in enumerate(one2one):
if j != -1: if j != -1:
output.append(j) output.append(int(j))
elif i in many2one: elif i in many2one:
j = many2one[i] j = many2one[i]
output.append((j, seen_j[j])) output.append((int(j), seen_j[j]))
seen_j[j] += 1 seen_j[j] += 1
elif i in one2many: elif i in one2many:
output.append([]) output.append([])
for j in one2many[i]: for j in one2many[i]:
output[-1].append(j) output[-1].append(int(j))
elif i in one2part: elif i in one2part:
output.append(one2part[i]) output.append(one2part[i])
else: else:

View File

@ -32,6 +32,7 @@ cdef class GoldParse:
cdef public list ents cdef public list ents
cdef public dict brackets cdef public dict brackets
cdef public object cats cdef public object cats
cdef public object _alignment
cdef readonly list cand_to_gold cdef readonly list cand_to_gold
cdef readonly list gold_to_cand cdef readonly list gold_to_cand

View File

@ -465,6 +465,17 @@ cdef class GoldParse:
self.heads[i] = None self.heads[i] = None
self.labels[i] = None self.labels[i] = None
self.ner[i] = 'O' self.ner[i] = 'O'
elif isinstance(self.labels[i], tuple):
sub_i = self.labels[i][1]
# If we're at the end of the subtoken, use the head and label
if (i+1) == len(self.labels) \
or not isinstance(self.labels[i+1], tuple) \
or self.labels[i][1] < sub_i:
self.labels[i] = self.labels[i][0]
self.heads[i] = self.heads[i][0]
else:
self.labels[i] = 'subtok'
self.heads[i] = i+1
cycle = nonproj.contains_cycle(self._alignment.flatten(self.heads)) cycle = nonproj.contains_cycle(self._alignment.flatten(self.heads))
if cycle is not None: if cycle is not None: