* Add flag to toggle handling of multi-root inputs without the Break transition. Clear up now unused best_valid stuff.

This commit is contained in:
Matthew Honnibal 2015-06-14 00:28:37 +02:00
parent 75289b4761
commit 399f15fbdf

View File

@ -20,6 +20,7 @@ from .stateclass cimport StateClass
DEF NON_MONOTONIC = True DEF NON_MONOTONIC = True
DEF USE_BREAK = False DEF USE_BREAK = False
DEF USE_ROOT_ARC_SEGMENT = True
cdef weight_t MIN_SCORE = -90000 cdef weight_t MIN_SCORE = -90000
@ -86,7 +87,7 @@ cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child)
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil: cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
if gold.labels[child] == -1: if gold.labels[child] == -1:
return True return True
elif _is_gold_root(gold, head) and _is_gold_root(gold, child): elif USE_ROOT_ARC_SEGMENT and _is_gold_root(gold, head) and _is_gold_root(gold, child):
return True return True
elif gold.heads[child] == head: elif gold.heads[child] == head:
return True return True
@ -352,10 +353,14 @@ cdef class ArcEager(TransitionSystem):
st.fast_forward() st.fast_forward()
cdef int finalize_state(self, StateClass st) except -1: cdef int finalize_state(self, StateClass st) except -1:
cdef int root_label = self.strings['ROOT'] cdef int root_label = self.strings['root']
for i in range(st.length): for i in range(st.length):
if st._sent[i].head == 0 and st._sent[i].dep == 0: if st._sent[i].head == 0 and st._sent[i].dep == 0:
st._sent[i].dep = root_label st._sent[i].dep = root_label
# If we're not using the Break transition, we segment via root-labelled
# arcs between the root words.
elif USE_ROOT_ARC_SEGMENT and st._sent[i].dep == root_label:
st._sent[i].head = 0
cdef int set_valid(self, bint* output, StateClass stcls) except -1: cdef int set_valid(self, bint* output, StateClass stcls) except -1:
cdef bint[N_MOVES] is_valid cdef bint[N_MOVES] is_valid
@ -422,12 +427,4 @@ cdef class ArcEager(TransitionSystem):
score = scores[i] score = scores[i]
assert best.clas < self.n_moves assert best.clas < self.n_moves
assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length(), stcls.is_final(), stcls._b_i, stcls.length) assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length(), stcls.is_final(), stcls._b_i, stcls.length)
# Label Shift moves with the best Right-Arc label, for non-monotonic
# actions
if best.move == SHIFT:
score = MIN_SCORE
for i in range(self.n_moves):
if self.c[i].move == RIGHT and scores[i] > score:
best.label = self.c[i].label
score = scores[i]
return best return best