diff --git a/spacy/_align.pyx b/spacy/_align.pyx index 545cc68b9..3c6a215db 100644 --- a/spacy/_align.pyx +++ b/spacy/_align.pyx @@ -108,25 +108,37 @@ class Alignment(object): many remaining subtokens align to the same value. ''' output = [] - for i, alignment in enumerate(self._t2y): - if len(alignment) == 1 and alignment[0][1] == 0: - output.append(items[alignment[0][0]]) + for i, alignment in enumerate(self._y2t): + if isinstance(alignment, int): + output.append(items[alignment]) + elif isinstance(alignment, tuple): + output.append((items[alignment[0]], alignment[1])) else: output.append([]) - for j1, j2 in alignment: - output[-1].append((items[j1], j2)) + for entry in alignment: + if isinstance(entry, int): + output[-1].append(items[entry]) + else: + output[-1].append((items[entry[0]], entry[1])) return output def index_to_yours(self, index): '''Translate an index that points into their tokens to point into yours''' alignment = self._t2y[index] - if len(alignment) == 1 and alignment[0][2] == 0: - return alignment[0][0] - else: - output = [] - for i1, i2, n_to_go in alignment: - output.append((i1, i2, n_to_go)) - return output + return alignment + #if isinstance(alignment, int): + # return alignment + #elif len(alignment) == 1 and isinstance(alignment, int): + # return alignment[0] + #elif len(alignment) == 1: + # return alignment[0][0] + #if len(alignment) == 1 and alignment[0][2] == 0: + # return alignment[0][0] + #else: + # output = [] + # for i1, i2, n_to_go in alignment: + # output.append((i1, i2, n_to_go)) + # return output def to_theirs(self, items): raise NotImplementedError @@ -203,8 +215,10 @@ class Alignment(object): # Apply the alignment to get the new values new = [] for head_vals in heads: - if not isinstance(head_vals, list): + if isinstance(head_vals, int): head_vals = [(head_vals, 0)] + elif isinstance(head_vals, tuple): + head_vals = [None] for head_val in head_vals: if not isinstance(head_val, tuple): head_val = (head_val, 0) @@ -269,15 +283,15 @@ def _convert_multi_align(one2one, many2one, one2many, one2part): seen_j = Counter() for i, j in enumerate(one2one): if j != -1: - output.append(j) + output.append(int(j)) elif i in many2one: j = many2one[i] - output.append((j, seen_j[j])) + output.append((int(j), seen_j[j])) seen_j[j] += 1 elif i in one2many: output.append([]) for j in one2many[i]: - output[-1].append(j) + output[-1].append(int(j)) elif i in one2part: output.append(one2part[i]) else: diff --git a/spacy/gold.pxd b/spacy/gold.pxd index 0a689ac62..2be87b72a 100644 --- a/spacy/gold.pxd +++ b/spacy/gold.pxd @@ -32,6 +32,7 @@ cdef class GoldParse: cdef public list ents cdef public dict brackets cdef public object cats + cdef public object _alignment cdef readonly list cand_to_gold cdef readonly list gold_to_cand diff --git a/spacy/gold.pyx b/spacy/gold.pyx index d6b38bb98..279d218dc 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -465,6 +465,17 @@ cdef class GoldParse: self.heads[i] = None self.labels[i] = None self.ner[i] = 'O' + elif isinstance(self.labels[i], tuple): + sub_i = self.labels[i][1] + # If we're at the end of the subtoken, use the head and label + if (i+1) == len(self.labels) \ + or not isinstance(self.labels[i+1], tuple) \ + or self.labels[i][1] < sub_i: + self.labels[i] = self.labels[i][0] + self.heads[i] = self.heads[i][0] + else: + self.labels[i] = 'subtok' + self.heads[i] = i+1 cycle = nonproj.contains_cycle(self._alignment.flatten(self.heads)) if cycle is not None: