* Prepare for the Break transition, with fast-forwarding. Scores 86.5 on 1k nw with gold preprocessing

This commit is contained in:
Matthew Honnibal 2015-06-10 14:08:30 +02:00
parent 495f528709
commit afd77a529b
3 changed files with 51 additions and 21 deletions

View File

@ -112,11 +112,12 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
cdef class Shift: cdef class Shift:
@staticmethod @staticmethod
cdef bint is_valid(StateClass st, int label) nogil: cdef bint is_valid(StateClass st, int label) nogil:
return st.buffer_length() >= 2 and not st.shifted[st.B(0)] return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and not st.B_(0).sent_end
@staticmethod @staticmethod
cdef int transition(StateClass state, int label) nogil: cdef int transition(StateClass st, int label) nogil:
state.push() st.push()
st.fast_forward()
@staticmethod @staticmethod
cdef int cost(StateClass st, const GoldParseC* gold, int label) nogil: cdef int cost(StateClass st, const GoldParseC* gold, int label) nogil:
@ -142,6 +143,7 @@ cdef class Reduce:
st.pop() st.pop()
else: else:
st.unshift() st.unshift()
st.fast_forward()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
@ -159,17 +161,13 @@ cdef class Reduce:
cdef class LeftArc: cdef class LeftArc:
@staticmethod @staticmethod
cdef bint is_valid(StateClass st, int label) nogil: cdef bint is_valid(StateClass st, int label) nogil:
if NON_MONOTONIC: return not st.B_(0).sent_end
return st.stack_depth() >= 1 and st.buffer_length() >= 1 #and not missing_brackets(s)
else:
return st.stack_depth() >= 1 and st.buffer_length() >= 1 and not st.has_head(st.S(0))
@staticmethod @staticmethod
cdef int transition(StateClass st, int label) nogil: cdef int transition(StateClass st, int label) nogil:
st.add_arc(st.B(0), st.S(0), label) st.add_arc(st.B(0), st.S(0), label)
st.pop() st.pop()
if st.empty(): st.fast_forward()
st.push()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
@ -190,12 +188,13 @@ cdef class LeftArc:
cdef class RightArc: cdef class RightArc:
@staticmethod @staticmethod
cdef bint is_valid(StateClass st, int label) nogil: cdef bint is_valid(StateClass st, int label) nogil:
return st.stack_depth() >= 1 and st.buffer_length() >= 1 return not st.B_(0).sent_end
@staticmethod @staticmethod
cdef int transition(StateClass st, int label) nogil: cdef int transition(StateClass st, int label) nogil:
st.add_arc(st.S(0), st.B(0), label) st.add_arc(st.S(0), st.B(0), label)
st.push() st.push()
st.fast_forward()
@staticmethod @staticmethod
cdef inline int cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef inline int cost(StateClass s, const GoldParseC* gold, int label) nogil:
@ -221,7 +220,7 @@ cdef class Break:
cdef int i cdef int i
if not USE_BREAK: if not USE_BREAK:
return False return False
elif st.eol(): elif st.at_break():
return False return False
elif st.stack_depth() < 1: elif st.stack_depth() < 1:
return False return False
@ -230,9 +229,13 @@ cdef class Break:
@staticmethod @staticmethod
cdef int transition(StateClass st, int label) nogil: cdef int transition(StateClass st, int label) nogil:
#st.set_sent_start() st.set_break(st.B(0))
while st.stack_depth() >= 2 and st.buffer_length() == 0: while st.stack_depth() >= 2 and st.has_head(st.S(0)):
Reduce.transition(st, -1) st.pop()
if st.stack_depth() == 1:
st.pop()
else:
st.unshift()
@staticmethod @staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil:
@ -338,7 +341,7 @@ cdef class ArcEager(TransitionSystem):
return t return t
cdef int initialize_state(self, StateClass st) except -1: cdef int initialize_state(self, StateClass st) except -1:
st.push() st.fast_forward()
cdef int finalize_state(self, StateClass st) except -1: cdef int finalize_state(self, StateClass st) except -1:
cdef int root_label = self.strings['ROOT'] cdef int root_label = self.strings['ROOT']
@ -410,7 +413,7 @@ cdef class ArcEager(TransitionSystem):
best = self.c[i] best = self.c[i]
score = scores[i] score = scores[i]
assert best.clas < self.n_moves assert best.clas < self.n_moves
assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length()) assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length(), stcls.is_final(), stcls._b_i, stcls.length)
# Label Shift moves with the best Right-Arc label, for non-monotonic # Label Shift moves with the best Right-Arc label, for non-monotonic
# actions # actions
if best.move == SHIFT: if best.move == SHIFT:

View File

@ -62,6 +62,8 @@ cdef class StateClass:
cdef bint entity_is_open(self) nogil cdef bint entity_is_open(self) nogil
cdef bint eol(self) nogil cdef bint eol(self) nogil
cdef bint at_break(self) nogil
cdef bint is_final(self) nogil cdef bint is_final(self) nogil
@ -96,3 +98,5 @@ cdef class StateClass:
cdef void set_break(self, int i) nogil cdef void set_break(self, int i) nogil
cdef void clone(self, StateClass src) nogil cdef void clone(self, StateClass src) nogil
cdef void fast_forward(self) nogil

View File

@ -14,7 +14,7 @@ cdef class StateClass:
self._ents = <Entity*>mem.alloc(length, sizeof(Entity)) self._ents = <Entity*>mem.alloc(length, sizeof(Entity))
self.mem = mem self.mem = mem
self.length = length self.length = length
self._break = length self._break = -1
self._s_i = 0 self._s_i = 0
self._b_i = 0 self._b_i = 0
self._e_i = 0 self._e_i = 0
@ -105,10 +105,13 @@ cdef class StateClass:
return self._s_i <= 0 return self._s_i <= 0
cdef bint eol(self) nogil: cdef bint eol(self) nogil:
return self._b_i >= self._break return self.buffer_length() == 0
cdef bint at_break(self) nogil:
    # Reports whether a break index is currently recorded in _break.
    # -1 is the "no break" value: the constructor starts _break at -1 and
    # push() resets it to -1 once the buffer index reaches the break.
    if self._break == -1:
        return False
    return True
cdef bint is_final(self) nogil: cdef bint is_final(self) nogil:
return self.stack_depth() <= 1 and self._b_i >= self.length return self.stack_depth() <= 0 and self._b_i >= self.length
cdef bint has_head(self, int i) nogil: cdef bint has_head(self, int i) nogil:
return self.safe_get(i).head != 0 return self.safe_get(i).head != 0
@ -131,14 +134,17 @@ cdef class StateClass:
return self._s_i return self._s_i
cdef int buffer_length(self) nogil: cdef int buffer_length(self) nogil:
return self._break - self._b_i if self._break != -1:
return self._break - self._b_i
else:
return self.length - self._b_i
cdef void push(self) nogil: cdef void push(self) nogil:
self._stack[self._s_i] = self.B(0) self._stack[self._s_i] = self.B(0)
self._s_i += 1 self._s_i += 1
self._b_i += 1 self._b_i += 1
if self._b_i >= self._break: if self._b_i >= self._break:
self._break = self.length self._break = -1
cdef void pop(self) nogil: cdef void pop(self) nogil:
self._s_i -= 1 self._s_i -= 1
@ -149,6 +155,23 @@ cdef class StateClass:
self._s_i -= 1 self._s_i -= 1
self.shifted[self.B(0)] = True self.shifted[self.B(0)] = True
cdef void fast_forward(self) nogil:
# Advance the state while either the buffer or the stack is empty, applying
# the one move each branch below selects, and stop as soon as both are
# non-empty or no branch applies.
# buffer_length() is measured up to the recorded break when _break != -1,
# otherwise up to self.length, so an "empty buffer" here can also mean that
# the break position has been reached.
while self.buffer_length() == 0 or self.stack_depth() == 0:
if self.buffer_length() == 1 and self.stack_depth() == 0:
# Exactly one word left and nothing on the stack: push it and pop it
# again straight away, which consumes the word.
self.push()
self.pop()
elif self.buffer_length() == 0 and self.stack_depth() == 1:
# Buffer exhausted with a single stack item: remove that item.
self.pop()
elif self.buffer_length() == 0 and self.stack_depth() >= 2:
# Buffer exhausted with several stack items: drop the top item if it
# already has a head, otherwise return it to the buffer via unshift().
if self.has_head(self.S(0)):
self.pop()
else:
self.unshift()
elif self.buffer_length() >= 2 and self.stack_depth() == 0:
# Empty stack but at least two words waiting: push the next word.
self.push()
else:
# No branch applies (notably buffer and stack both empty): leave the
# loop so it cannot spin forever.
break
cdef void add_arc(self, int head, int child, int label) nogil: cdef void add_arc(self, int head, int child, int label) nogil:
if self.has_head(child): if self.has_head(child):
self.del_arc(self.H(child), child) self.del_arc(self.H(child), child)