Fix head alignment in GoldParse

This commit is contained in:
Matthew Honnibal 2018-04-03 01:54:45 +02:00
parent 88de8fe323
commit 9c5c940441

View File

@ -451,37 +451,35 @@ cdef class GoldParse:
annot_tuples = (range(len(words)), words, tags, heads, deps, entities) annot_tuples = (range(len(words)), words, tags, heads, deps, entities)
self.orig_annot = list(zip(*annot_tuples)) self.orig_annot = list(zip(*annot_tuples))
if words is not None: self.words = self._alignment.to_yours(words)
self.words = self._alignment.to_yours(words) self.tags = self._alignment.to_yours(tags)
if tags is not None: self.labels = self._alignment.to_yours(deps)
self.tags = self._alignment.to_yours(tags) self.tags = self._alignment.to_yours(tags)
if deps is not None: self.ner = self._alignment.to_yours(entities)
self.labels = self._alignment.to_yours(deps) for gold_i, gold_head in enumerate(heads):
if tags is not None: if gold_head is None:
self.tags = self._alignment.to_yours(tags) continue
if entities is not None: cand_i = self._alignment._t2y[gold_i]
self.ner = self._alignment.to_yours(entities) cand_head = self._alignment._t2y[gold_head]
if heads is not None: if cand_i is None or cand_head is None:
for gold_i, gold_head in enumerate(heads): continue
if gold_head is None: elif isinstance(cand_i, int):
continue self.heads[cand_i] = cand_head
cand_i = self._alignment._t2y[gold_i] elif isinstance(cand_i, list):
cand_head = self._alignment._t2y[gold_head] for sub_i in cand_i[:-1]:
if cand_i is None or cand_head is None: self.heads[sub_i] = sub_i+1
continue if isinstance(cand_head, list):
elif isinstance(cand_i, int): self.heads[cand_i[-1]] = cand_head[-1]
self.heads[cand_i] = cand_head else:
elif isinstance(cand_i, list):
for sub_i in cand_i[:-1]:
self.heads[sub_i] = sub_i+1
self.heads[cand_i[-1]] = cand_head self.heads[cand_i[-1]] = cand_head
elif isinstance(cand_i, tuple): elif isinstance(cand_i, tuple) and isinstance(cand_head, int):
cand_i, sub_i = cand_i # We only handle one-to-many or many-to-one, not many-to-many
if not isinstance(self.heads[cand_i], list): cand_i, sub_i = cand_i
self.heads[cand_i] = [] if not isinstance(self.heads[cand_i], list):
while len(self.heads[cand_i]) <= sub_i: self.heads[cand_i] = []
self.heads[cand_i].append(None) while len(self.heads[cand_i]) <= sub_i:
self.heads[cand_i][sub_i] = cand_head self.heads[cand_i].append(None)
self.heads[cand_i][sub_i] = cand_head
for i in range(len(doc)): for i in range(len(doc)):
# Fix spaces # Fix spaces
@ -500,7 +498,7 @@ cdef class GoldParse:
self.labels[i] = self.labels[i][0] self.labels[i] = self.labels[i][0]
else: else:
self.labels[i] = 'subtok' self.labels[i] = 'subtok'
self.heads[i] = i+1 #self.heads[i] = i+1
cycle = nonproj.contains_cycle(self._alignment.flatten(self.heads)) cycle = nonproj.contains_cycle(self._alignment.flatten(self.heads))
if cycle is not None: if cycle is not None: