fix bug in updating tree structure when introducing additional roots

This commit is contained in:
Wolfgang Seeker 2016-04-25 12:01:19 +02:00
parent b6477fc4f4
commit f57f843e85
3 changed files with 106 additions and 19 deletions

View File

@ -399,31 +399,34 @@ cdef class ArcEager(TransitionSystem):
cdef TokenC* orig_head
cdef int new_edge
cdef int child_i
cdef TokenC* head_i
cdef int head_i
for i in range(st.length):
if st._sent[i].head == 0 and st._sent[i].dep == 0:
st._sent[i].dep = self.root_label
# If we're not using the Break transition, we segment via root-labelled
# arcs between the root words.
elif USE_ROOT_ARC_SEGMENT and st._sent[i].dep == self.root_label:
orig_head_id = st._sent[i].head
orig_head_id = i + st._sent[i].head
orig_head = &st._sent[orig_head_id]
if i < orig_head_id: # i is left dependent
orig_head.l_kids -= 1
if i == orig_head.l_edge: # i is left-most child
# find the second left-most child and make it the new l_edge
new_edge = orig_head_id
child_i = i
child_i = i+1
while child_i < orig_head_id:
if st._sent[child_i].head == orig_head_id:
if child_i + st._sent[child_i].head == orig_head_id:
new_edge = child_i
break
child_i += 1
# then walk up the path to root and update the l_edges of all ancestors
# the logic here works because the tree is guaranteed to be projective
head_i = &st._sent[orig_head.head]
while head_i.l_edge == orig_head.l_edge:
head_i.l_edge = new_edge
head_i = &st._sent[head_i.head]
head_i = orig_head_id + orig_head.head
while st._sent[head_i].l_edge == orig_head.l_edge:
st._sent[head_i].l_edge = new_edge
if st._sent[head_i].head == 0:
break
head_i += st._sent[head_i].head
orig_head.l_edge = new_edge
elif i > orig_head_id: # i is right dependent
@ -431,24 +434,28 @@ cdef class ArcEager(TransitionSystem):
if i == orig_head.r_edge:
# find the second right-most child and make it the new r_edge
new_edge = orig_head_id
child_i = i
child_i = i-1
while child_i > orig_head_id:
if st._sent[child_i].head == orig_head_id:
if child_i + st._sent[child_i].head == orig_head_id:
new_edge = child_i
break
child_i -= 1
# then walk up the path to root and update the l_edges of all ancestors
# then walk up the path to root and update the r_edges of all ancestors
# the logic here works because the tree is guaranteed to be projective
head_i = &st._sent[orig_head.head]
while head_i.r_edge == orig_head.r_edge:
head_i.r_edge = new_edge
head_i = &st._sent[head_i.head]
head_i = orig_head_id + orig_head.head
while st._sent[head_i].r_edge == orig_head.r_edge:
st._sent[head_i].r_edge = new_edge
if st._sent[head_i].head == 0:
break
head_i += st._sent[head_i].head
orig_head.r_edge = new_edge
# note that this can create non-projective trees if there are arcs
# note that this may create non-projective trees if there are arcs
# between nodes on both sides of the new root node
st._sent[i].head = 0
st._sent[st._sent[i].l_edge].sent_start = True
cdef int set_valid(self, int* output, const StateC* st) nogil:
cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = Shift.is_valid(st, -1)

View File

@ -19,3 +19,82 @@ def test_one_word_sentence(EN):
with EN.parser.step_through(doc) as _:
pass
assert doc[0].dep != 0
def apply_transition_sequence(model, doc, sequence):
with model.parser.step_through(doc) as stepwise:
for transition in sequence:
stepwise.transition(transition)
@pytest.mark.models
def test_arc_eager_finalize_state(EN):
# right branching
example = EN.tokenizer.tokens_from_list(u"a b c d e".split(' '))
apply_transition_sequence(EN, example, ['R-nsubj','D','R-nsubj','R-nsubj','D','R-ROOT'])
print [ '%s/%s' % (t.dep_,t.head.i) for t in example ]
assert example[0].n_lefts == 0
assert example[0].n_rights == 2
assert example[0].left_edge.i == 0
assert example[0].right_edge.i == 3
assert example[0].head.i == 0
assert example[1].n_lefts == 0
assert example[1].n_rights == 0
assert example[1].left_edge.i == 1
assert example[1].right_edge.i == 1
assert example[1].head.i == 0
assert example[2].n_lefts == 0
assert example[2].n_rights == 1
assert example[2].left_edge.i == 2
assert example[2].right_edge.i == 3
assert example[2].head.i == 0
assert example[3].n_lefts == 0
assert example[3].n_rights == 0
assert example[3].left_edge.i == 3
assert example[3].right_edge.i == 3
assert example[3].head.i == 2
assert example[4].n_lefts == 0
assert example[4].n_rights == 0
assert example[4].left_edge.i == 4
assert example[4].right_edge.i == 4
assert example[4].head.i == 4
# left branching
example = EN.tokenizer.tokens_from_list(u"a b c d e".split(' '))
apply_transition_sequence(EN, example, ['S','L-nsubj','L-ROOT','S','L-nsubj','L-nsubj'])
print [ '%s/%s' % (t.dep_,t.head.i) for t in example ]
assert example[0].n_lefts == 0
assert example[0].n_rights == 0
assert example[0].left_edge.i == 0
assert example[0].right_edge.i == 0
assert example[0].head.i == 0
assert example[1].n_lefts == 0
assert example[1].n_rights == 0
assert example[1].left_edge.i == 1
assert example[1].right_edge.i == 1
assert example[1].head.i == 2
assert example[2].n_lefts == 1
assert example[2].n_rights == 0
assert example[2].left_edge.i == 1
assert example[2].right_edge.i == 2
assert example[2].head.i == 4
assert example[3].n_lefts == 0
assert example[3].n_rights == 0
assert example[3].left_edge.i == 3
assert example[3].right_edge.i == 3
assert example[3].head.i == 4
assert example[4].n_lefts == 2
assert example[4].n_rights == 0
assert example[4].left_edge.i == 1
assert example[4].right_edge.i == 4
assert example[4].head.i == 4

View File

@ -135,12 +135,13 @@ def test_sbd_for_root_label_dependents(EN):
make sure that the parser properly introduces a sentence boundary without
the break transition by checking for dependents with the root label
"""
example = EN.tokenizer.tokens_from_list(u"I bought a couch from IKEA. It was n't very comfortable .".split(' '))
example = EN.tokenizer.tokens_from_list(u"I saw a firefly It glowed".split(' '))
EN.tagger(example)
apply_transition_sequence(EN, example, ['L-nsubj','S','L-det','R-dobj','D','R-prep','R-pobj','D','D','S','L-nsubj','R-ROOT','R-neg','D','S','L-advmod','R-acomp','D','R-punct'])
apply_transition_sequence(EN, example, ['L-nsubj','S','L-det','R-dobj','D','S','L-nsubj','R-ROOT'])
print ['%s/%s' % (t.dep_,t.head.i) for t in example]
assert example[1].head.i == 1
assert example[7].head.i == 7
assert example[5].head.i == 5
sents = list(example.sents)
assert len(sents) == 2