From afd77a529baee6bee0ab61e5ffb405b7500fc43e Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 10 Jun 2015 14:08:30 +0200 Subject: [PATCH] * Prepare for break transition, with fast-forwarding. 86.5 on 1k nw gold preproc --- spacy/syntax/arc_eager.pyx | 35 +++++++++++++++++++---------------- spacy/syntax/stateclass.pxd | 4 ++++ spacy/syntax/stateclass.pyx | 33 ++++++++++++++++++++++++++++----- 3 files changed, 51 insertions(+), 21 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 222388b69..1bd7c00f5 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -112,11 +112,12 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil: cdef class Shift: @staticmethod cdef bint is_valid(StateClass st, int label) nogil: - return st.buffer_length() >= 2 and not st.shifted[st.B(0)] + return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and not st.B_(0).sent_end @staticmethod - cdef int transition(StateClass state, int label) nogil: - state.push() + cdef int transition(StateClass st, int label) nogil: + st.push() + st.fast_forward() @staticmethod cdef int cost(StateClass st, const GoldParseC* gold, int label) nogil: @@ -142,6 +143,7 @@ cdef class Reduce: st.pop() else: st.unshift() + st.fast_forward() @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: @@ -159,17 +161,13 @@ cdef class Reduce: cdef class LeftArc: @staticmethod cdef bint is_valid(StateClass st, int label) nogil: - if NON_MONOTONIC: - return st.stack_depth() >= 1 and st.buffer_length() >= 1 #and not missing_brackets(s) - else: - return st.stack_depth() >= 1 and st.buffer_length() >= 1 and not st.has_head(st.S(0)) + return not st.B_(0).sent_end @staticmethod cdef int transition(StateClass st, int label) nogil: st.add_arc(st.B(0), st.S(0), label) st.pop() - if st.empty(): - st.push() + st.fast_forward() @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: @@ -190,12 +188,13 @@ cdef class LeftArc: cdef class RightArc: @staticmethod cdef bint is_valid(StateClass st, int label) nogil: - return st.stack_depth() >= 1 and st.buffer_length() >= 1 + return not st.B_(0).sent_end @staticmethod cdef int transition(StateClass st, int label) nogil: st.add_arc(st.S(0), st.B(0), label) st.push() + st.fast_forward() @staticmethod cdef inline int cost(StateClass s, const GoldParseC* gold, int label) nogil: @@ -221,7 +220,7 @@ cdef class Break: cdef int i if not USE_BREAK: return False - elif st.eol(): + elif st.at_break(): return False elif st.stack_depth() < 1: return False @@ -230,9 +229,13 @@ cdef class Break: @staticmethod cdef int transition(StateClass st, int label) nogil: - #st.set_sent_start() - while st.stack_depth() >= 2 and st.buffer_length() == 0: - Reduce.transition(st, -1) + st.set_break(st.B(0)) + while st.stack_depth() >= 2 and st.has_head(st.S(0)): + st.pop() + if st.stack_depth() == 1: + st.pop() + else: + st.unshift() @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: @@ -338,7 +341,7 @@ cdef class ArcEager(TransitionSystem): return t cdef int initialize_state(self, StateClass st) except -1: - st.push() + st.fast_forward() cdef int finalize_state(self, StateClass st) except -1: cdef int root_label = self.strings['ROOT'] @@ -410,7 +413,7 @@ cdef class ArcEager(TransitionSystem): best = self.c[i] score = scores[i] assert best.clas < self.n_moves - assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length()) + assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length(), stcls.is_final(), stcls._b_i, stcls.length) # Label Shift moves with the best Right-Arc label, for non-monotonic # actions if best.move == SHIFT: diff --git a/spacy/syntax/stateclass.pxd b/spacy/syntax/stateclass.pxd index dcc57474c..54b039208 100644 --- a/spacy/syntax/stateclass.pxd +++ b/spacy/syntax/stateclass.pxd @@ -62,6 +62,8 @@ cdef class StateClass: cdef bint entity_is_open(self) nogil cdef bint eol(self) nogil + + cdef bint at_break(self) nogil cdef bint is_final(self) nogil @@ -96,3 +98,5 @@ cdef class StateClass: cdef void set_break(self, int i) nogil cdef void clone(self, StateClass src) nogil + + cdef void fast_forward(self) nogil diff --git a/spacy/syntax/stateclass.pyx b/spacy/syntax/stateclass.pyx index 1e4f3b3f0..8b6abfdab 100644 --- a/spacy/syntax/stateclass.pyx +++ b/spacy/syntax/stateclass.pyx @@ -14,7 +14,7 @@ cdef class StateClass: self._ents = mem.alloc(length, sizeof(Entity)) self.mem = mem self.length = length - self._break = length + self._break = -1 self._s_i = 0 self._b_i = 0 self._e_i = 0 @@ -105,10 +105,13 @@ cdef class StateClass: return self._s_i <= 0 cdef bint eol(self) nogil: - return self._b_i >= self._break + return self.buffer_length() == 0 + + cdef bint at_break(self) nogil: + return self._break != -1 cdef bint is_final(self) nogil: - return self.stack_depth() <= 1 and self._b_i >= self.length + return self.stack_depth() <= 0 and self._b_i >= self.length cdef bint has_head(self, int i) nogil: return self.safe_get(i).head != 0 @@ -131,14 +134,17 @@ cdef class StateClass: return self._s_i cdef int buffer_length(self) nogil: - return self._break - self._b_i + if self._break != -1: + return self._break - self._b_i + else: + return self.length - self._b_i cdef void push(self) nogil: self._stack[self._s_i] = self.B(0) self._s_i += 1 self._b_i += 1 if self._b_i >= self._break: - self._break = self.length + self._break = -1 cdef void pop(self) nogil: self._s_i -= 1 @@ -149,6 +155,23 @@ cdef class StateClass: self._s_i -= 1 self.shifted[self.B(0)] = True + cdef void fast_forward(self) nogil: + while self.buffer_length() == 0 or self.stack_depth() == 0: + if self.buffer_length() == 1 and self.stack_depth() == 0: + self.push() + self.pop() + elif self.buffer_length() == 0 and self.stack_depth() == 1: + self.pop() + elif self.buffer_length() == 0 and self.stack_depth() >= 2: + if self.has_head(self.S(0)): + self.pop() + else: + self.unshift() + elif self.buffer_length() >= 2 and self.stack_depth() == 0: + self.push() + else: + break + cdef void add_arc(self, int head, int child, int label) nogil: if self.has_head(child): self.del_arc(self.H(child), child)