* Allow gold parse to cut non-projective arcs

This commit is contained in:
Matthew Honnibal 2015-05-31 01:11:56 +02:00
parent d512d20d81
commit 87d6551d19

View File

@ -163,7 +163,7 @@ def _consume_ent(tags):
cdef class GoldParse: cdef class GoldParse:
def __init__(self, tokens, annot_tuples, brackets=tuple()): def __init__(self, tokens, annot_tuples, brackets=tuple(), make_projective=False):
self.mem = Pool() self.mem = Pool()
self.loss = 0 self.loss = 0
self.length = len(tokens) self.length = len(tokens)
@ -197,6 +197,24 @@ cdef class GoldParse:
self.labels[i] = annot_tuples[4][gold_i] self.labels[i] = annot_tuples[4][gold_i]
self.ner[i] = annot_tuples[5][gold_i] self.ner[i] = annot_tuples[5][gold_i]
# If we have any non-projective arcs, i.e. crossing brackets, consider
# the heads for those words missing in the gold-standard.
# This way, we can train from these sentences
cdef int w1, w2, h1, h2
if make_projective:
heads = list(self.heads)
for w1 in range(self.length):
if heads[w1] is not None:
h1 = heads[w1]
for w2 in range(w1+1, self.length):
if heads[w2] is not None:
h2 = heads[w2]
if _arcs_cross(w1, h1, w2, h2):
self.heads[w1] = None
self.labels[w1] = ''
self.heads[w2] = None
self.labels[w2] = ''
self.brackets = {} self.brackets = {}
for (gold_start, gold_end, label_str) in brackets: for (gold_start, gold_end, label_str) in brackets:
start = self.gold_to_cand[gold_start] start = self.gold_to_cand[gold_start]
@ -210,17 +228,25 @@ cdef class GoldParse:
@property @property
def is_projective(self): def is_projective(self):
heads = [head for (id_, word, tag, head, dep, ner) in self.orig_annot] heads = list(self.heads)
deps = sorted([sorted(arc) for arc in enumerate(heads)]) for w1 in range(self.length):
for w1, h1 in deps: if heads[w1] is not None:
for w2, h2 in deps: h1 = heads[w1]
if w1 < w2 < h1 < h2: for w2 in range(self.length):
if heads[w2] is not None and _arcs_cross(w1, h1, w2, heads[w2]):
return False return False
elif w1 < w2 == h2 < h1:
return False
else:
return True return True
cdef int _arcs_cross(int w1, int h1, int w2, int h2) except -1:
if w1 > h1:
w1, h1 = h1, w1
if w2 > h2:
w2, h2 = h2, w2
if w1 > w2:
w1, h1, w2, h2 = w2, h2, w1, h1
return w1 < w2 < h1 < h2 or w1 < w2 == h2 < h1
def is_punct_label(label): def is_punct_label(label):
return label == 'P' or label.lower() == 'punct' return label == 'P' or label.lower() == 'punct'