* Prepare for break transition, with fast-forwarding. 86.5 on 1k nw gold preproc

This commit is contained in:
Matthew Honnibal 2015-06-10 14:08:30 +02:00
parent 495f528709
commit afd77a529b
3 changed files with 51 additions and 21 deletions

View File

@ -112,11 +112,12 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
cdef class Shift:
# Decide whether the Shift move may be applied to `st`; `label` is not consulted.
# NOTE(review): this listing looks like a rendered diff whose +/- markers were
# stripped, so the two returns below are presumably the pre-change and
# post-change versions of one line -- confirm against the repository.
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
# Requires at least two buffer tokens and that B(0) is not flagged in st.shifted.
return st.buffer_length() >= 2 and not st.shifted[st.B(0)]
# Same test, additionally refusing when the buffer-front token has sent_end set.
return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and not st.B_(0).sent_end
# Apply the Shift move to the parser state; `label` is not used in the body.
# NOTE(review): two signatures appear below (parameter `state`, then `st`);
# presumably the first signature/body pair is the pre-change text of a diff and
# the second is its replacement -- confirm against the repository.
@staticmethod
cdef int transition(StateClass state, int label) nogil:
state.push()
cdef int transition(StateClass st, int label) nogil:
# push() first, then fast_forward() so the state gets a chance to advance
# itself before the next move is chosen.
st.push()
st.fast_forward()
@staticmethod
cdef int cost(StateClass st, const GoldParseC* gold, int label) nogil:
@ -142,6 +143,7 @@ cdef class Reduce:
st.pop()
else:
st.unshift()
st.fast_forward()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
@ -159,17 +161,13 @@ cdef class Reduce:
cdef class LeftArc:
# Decide whether the LeftArc move may be applied to `st`; `label` is not consulted.
# NOTE(review): the NON_MONOTONIC if/else and the final return cannot both be
# live code; presumably the if/else is the pre-change text of a diff and the
# last return is its replacement -- confirm against the repository.
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
# Both branches need a non-empty stack and a non-empty buffer; the else branch
# additionally requires that S(0) has no head yet.
if NON_MONOTONIC:
return st.stack_depth() >= 1 and st.buffer_length() >= 1 #and not missing_brackets(s)
else:
return st.stack_depth() >= 1 and st.buffer_length() >= 1 and not st.has_head(st.S(0))
# Valid unless the buffer-front token has sent_end set. No stack/buffer size
# check is made on this line.
return not st.B_(0).sent_end
# Apply the LeftArc move: attach S(0) under B(0) with `label`, then pop the stack.
@staticmethod
cdef int transition(StateClass st, int label) nogil:
# First argument is passed as the head and second as the child, so the
# buffer-front token becomes the head of the stack-top token.
st.add_arc(st.B(0), st.S(0), label)
st.pop()
# NOTE(review): indentation is lost in this listing, so it is unclear which of
# the calls below sit under the `if`. Presumably `if st.empty(): st.push()` is
# the pre-change text and `st.fast_forward()` replaces it -- confirm.
if st.empty():
st.push()
st.fast_forward()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
@ -190,12 +188,13 @@ cdef class LeftArc:
cdef class RightArc:
# Decide whether the RightArc move may be applied to `st`; `label` is not consulted.
# NOTE(review): the two returns below are presumably the pre-change and
# post-change versions of one line in a diff -- confirm against the repository.
@staticmethod
cdef bint is_valid(StateClass st, int label) nogil:
# Requires a non-empty stack and a non-empty buffer.
return st.stack_depth() >= 1 and st.buffer_length() >= 1
# Valid unless the buffer-front token has sent_end set. No stack/buffer size
# check is made on this line.
return not st.B_(0).sent_end
# Apply the RightArc move: attach B(0) under S(0) with `label`, then push B(0).
@staticmethod
cdef int transition(StateClass st, int label) nogil:
# First argument is passed as the head and second as the child, so the
# stack-top token becomes the head of the buffer-front token.
st.add_arc(st.S(0), st.B(0), label)
st.push()
# Give the state a chance to advance itself before the next move is chosen.
st.fast_forward()
@staticmethod
cdef inline int cost(StateClass s, const GoldParseC* gold, int label) nogil:
@ -221,7 +220,7 @@ cdef class Break:
cdef int i
if not USE_BREAK:
return False
elif st.eol():
elif st.at_break():
return False
elif st.stack_depth() < 1:
return False
@ -230,9 +229,13 @@ cdef class Break:
# Apply the Break move to the parser state; `label` is not used in the body.
# NOTE(review): this listing appears to interleave pre-change and post-change
# lines of a diff. Presumably the first `while` with Reduce.transition is the
# old body and everything from set_break() onward is the new one -- confirm.
@staticmethod
cdef int transition(StateClass st, int label) nogil:
#st.set_sent_start()
# Apply Reduce repeatedly while the stack holds two or more items and the
# buffer is empty.
while st.stack_depth() >= 2 and st.buffer_length() == 0:
Reduce.transition(st, -1)
# Record a break at the buffer-front token.
st.set_break(st.B(0))
# Pop stack items that already have a head, always leaving at least one.
while st.stack_depth() >= 2 and st.has_head(st.S(0)):
st.pop()
# A lone remaining item is popped; otherwise the stack top is unshifted.
if st.stack_depth() == 1:
st.pop()
else:
st.unshift()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
@ -338,7 +341,7 @@ cdef class ArcEager(TransitionSystem):
return t
# Prepare a parser state before any move is applied.
# Declared `except -1`, so a return value of -1 signals an error to callers.
cdef int initialize_state(self, StateClass st) except -1:
# Move the first buffer token onto the stack ...
st.push()
# ... then let the state advance itself if the stack or buffer is empty.
st.fast_forward()
cdef int finalize_state(self, StateClass st) except -1:
cdef int root_label = self.strings['ROOT']
@ -410,7 +413,7 @@ cdef class ArcEager(TransitionSystem):
best = self.c[i]
score = scores[i]
assert best.clas < self.n_moves
assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length())
assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length(), stcls.is_final(), stcls._b_i, stcls.length)
# Label Shift moves with the best Right-Arc label, for non-monotonic
# actions
if best.move == SHIFT:

View File

@ -62,6 +62,8 @@ cdef class StateClass:
# Boolean query methods on StateClass. Each takes no arguments besides self,
# returns bint, and is callable without the GIL.
cdef bint entity_is_open(self) nogil
cdef bint eol(self) nogil
cdef bint at_break(self) nogil
cdef bint is_final(self) nogil
@ -96,3 +98,5 @@ cdef class StateClass:
# Methods on StateClass that return no value (void) and are callable without
# the GIL. Only signatures are declared here.
cdef void set_break(self, int i) nogil
cdef void clone(self, StateClass src) nogil
cdef void fast_forward(self) nogil

View File

@ -14,7 +14,7 @@ cdef class StateClass:
self._ents = <Entity*>mem.alloc(length, sizeof(Entity))
self.mem = mem
self.length = length
self._break = length
self._break = -1
self._s_i = 0
self._b_i = 0
self._e_i = 0
@ -105,10 +105,13 @@ cdef class StateClass:
return self._s_i <= 0
# Report whether the buffer is exhausted.
# NOTE(review): the two returns below are presumably the pre-change and
# post-change versions of one line in a diff -- confirm against the repository.
cdef bint eol(self) nogil:
# True once the buffer index has reached the _break position.
return self._b_i >= self._break
# True when buffer_length() reports zero remaining tokens.
return self.buffer_length() == 0
# Report whether a break is currently set, i.e. _break holds something other
# than the -1 "unset" value.
cdef bint at_break(self) nogil:
return self._break != -1
# Report whether parsing is finished: the buffer index has reached self.length
# and the stack is small enough.
# NOTE(review): the two returns below differ only in the stack bound (<= 1 vs
# <= 0); presumably they are the pre-change and post-change versions of one
# line in a diff -- confirm against the repository.
cdef bint is_final(self) nogil:
return self.stack_depth() <= 1 and self._b_i >= self.length
return self.stack_depth() <= 0 and self._b_i >= self.length
cdef bint has_head(self, int i) nogil:
return self.safe_get(i).head != 0
@ -131,14 +134,17 @@ cdef class StateClass:
return self._s_i
# Return how many tokens remain in the buffer from index _b_i.
# NOTE(review): the bare return and the if/else below are presumably the
# pre-change and post-change bodies of a diff -- confirm against the repository.
cdef int buffer_length(self) nogil:
# Counts up to _break unconditionally.
return self._break - self._b_i
# With a break set (_break != -1) count up to the break; otherwise count up to
# the end of the input (self.length).
if self._break != -1:
return self._break - self._b_i
else:
return self.length - self._b_i
# Move the buffer-front token onto the stack and advance both indices.
cdef void push(self) nogil:
# Store B(0) in the next free stack slot, then bump the stack and buffer indices.
self._stack[self._s_i] = self.B(0)
self._s_i += 1
self._b_i += 1
# Once the buffer index reaches the break position, the break is reset.
# NOTE(review): the two assignments below are presumably the pre-change reset
# value (self.length) and the post-change one (-1) from a diff -- confirm.
if self._b_i >= self._break:
self._break = self.length
self._break = -1
cdef void pop(self) nogil:
self._s_i -= 1
@ -149,6 +155,23 @@ cdef class StateClass:
self._s_i -= 1
self.shifted[self.B(0)] = True
cdef void fast_forward(self) nogil:
    # Advance the state until both the stack and the buffer are non-empty, or
    # until no further automatic step applies. Each pass re-reads the two
    # sizes and performs at most one step.
    cdef int n_buffer
    cdef int n_stack
    while True:
        n_buffer = self.buffer_length()
        n_stack = self.stack_depth()
        if n_stack == 0:
            if n_buffer == 1:
                # A single remaining token: push it and pop it straight away.
                self.push()
                self.pop()
            elif n_buffer >= 2:
                # Empty stack but more input: move the next token across.
                self.push()
            else:
                # Stack and buffer both exhausted.
                return
        elif n_buffer == 0:
            if n_stack == 1:
                self.pop()
            elif n_stack >= 2:
                # Drop the stack top if it already has a head; otherwise
                # unshift it.
                if self.has_head(self.S(0)):
                    self.pop()
                else:
                    self.unshift()
            else:
                return
        else:
            # Stack and buffer are both non-empty: nothing to skip.
            return
cdef void add_arc(self, int head, int child, int label) nogil:
if self.has_head(child):
self.del_arc(self.H(child), child)