diff --git a/spacy/syntax/arc_eager.pxd b/spacy/syntax/arc_eager.pxd index 1390d949c..5b7a6e3db 100644 --- a/spacy/syntax/arc_eager.pxd +++ b/spacy/syntax/arc_eager.pxd @@ -5,6 +5,13 @@ from thinc.typedefs cimport weight_t from .stateclass cimport StateClass from .transition_system cimport TransitionSystem, Transition +from ..gold cimport GoldParseC + cdef class ArcEager(TransitionSystem): pass + + +cdef int push_cost(StateClass stcls, const GoldParseC* gold, int target) nogil +cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil + diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 1bd7c00f5..e5257b18a 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -222,20 +222,20 @@ cdef class Break: return False elif st.at_break(): return False + elif st.B(0) == 0: + return False elif st.stack_depth() < 1: return False + elif (st.S(0) + 1) != st.B(0): + # Must break at the token boundary + return False else: return True @staticmethod cdef int transition(StateClass st, int label) nogil: st.set_break(st.B(0)) - while st.stack_depth() >= 2 and st.has_head(st.S(0)): - st.pop() - if st.stack_depth() == 1: - st.pop() - else: - st.unshift() + st.fast_forward() @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: @@ -243,32 +243,37 @@ cdef class Break: @staticmethod cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil: - # When we break, we can't reach any arcs between stack and buffer - # So cost is number of deps between S0...Sn and B0...Nn - cdef int cost = 0 - cdef int i, j, B_i, S_i - for i in range(s.buffer_length()): - B_i = s.B(i) - for j in range(s.stack_depth()): - S_i = s.S(j) - cost += gold.heads[B_i] == S_i - cost += gold.heads[S_i] == B_i - return cost + # Check for sentence boundary --- if it's here, we can't have any deps + # between stack and buffer, so rest of action is irrelevant. + s0_root = _get_root(s.S(0), gold) + b0_root = _get_root(s.B(0), gold) + if s0_root == -1 or b0_root == -1 or s0_root != b0_root: + return 0 + else: + return 1 @staticmethod cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil: return 0 +cdef int _get_root(int word, const GoldParseC* gold) nogil: + while gold.heads[word] != word and gold.labels[word] != -1 and word >= 0: + word = gold.heads[word] + if gold.labels[word] == -1: + return -1 + else: + return word + cdef class ArcEager(TransitionSystem): @classmethod def get_labels(cls, gold_parses): - move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {}, - LEFT: {'ROOT': True}, BREAK: {'ROOT': True}} + move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {'root': True}, + LEFT: {'root': True}, BREAK: {'root': True}} for raw_text, sents in gold_parses: for (ids, words, tags, heads, labels, iob), ctnts in sents: for child, head, label in zip(ids, heads, labels): - if label != 'ROOT': + if label != 'root': if head < child: move_labels[RIGHT][label] = True elif head > child: @@ -341,6 +346,9 @@ cdef class ArcEager(TransitionSystem): return t cdef int initialize_state(self, StateClass st) except -1: + # Ensure sent_end is set to 0 throughout + for i in range(st.length): + st._sent[i].sent_end = False st.fast_forward() cdef int finalize_state(self, StateClass st) except -1: diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index c1005adae..c27bae1f2 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -322,7 +322,6 @@ cdef class Out: cdef int g_act = gold.ner[s.B(0)].move cdef int g_tag = gold.ner[s.B(0)].label - if g_act == MISSING: return 0 elif g_act == BEGIN: diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 4be1046bc..30d5b5f92 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -36,6 +36,8 @@ from ..tokens cimport Tokens, TokenC from ..strings cimport StringStore from .arc_eager cimport TransitionSystem, Transition +from .arc_eager cimport push_cost, arc_cost + from .transition_system import OracleError from ..gold cimport GoldParse @@ -95,11 +97,12 @@ cdef class Parser: cdef Transition guess words = [w.orth_ for w in tokens] while not stcls.is_final(): - #print stcls.print_state(words) fill_context(context, stcls) scores = self.model.score(context) guess = self.moves.best_valid(scores, stcls) + #print self.moves.move_name(guess.move, guess.label), stcls.print_state(words) guess.do(stcls, guess.label) + assert stcls._s_i >= 0 self.moves.finalize_state(stcls) tokens.set_parse(stcls._sent) @@ -128,12 +131,27 @@ cdef class Parser: cdef atom_t[CONTEXT_SIZE] context loss = 0 words = [w.orth_ for w in tokens] + history = [] while not stcls.is_final(): + assert stcls._s_i >= 0 fill_context(context, stcls) scores = self.model.score(context) guess = self.moves.best_valid(scores, stcls) - best = self.moves.best_gold(scores, stcls, gold) + try: + best = self.moves.best_gold(scores, stcls, gold) + except: + history.append((self.moves.move_name(guess.move, guess.label), '!', stcls.print_state(words))) + for i, word in enumerate(words): + print gold.orig_annot[i] + print '\n'.join('\t'.join(s) for s in history) + print words[gold.c.heads[stcls.S(0)]] + print words[gold.c.heads[stcls.B(0)]] + print push_cost(stcls, &gold.c, stcls.B(0)) + print arc_cost(stcls, &gold.c, stcls.S(0), stcls.B(0)) + self.moves.set_valid(self.moves._is_valid, stcls) + raise cost = guess.get_cost(stcls, &gold.c, guess.label) + history.append((self.moves.move_name(guess.move, guess.label), str(cost), stcls.print_state(words))) self.model.update(context, guess.clas, best.clas, cost) guess.do(stcls, guess.label) loss += cost diff --git a/spacy/syntax/stateclass.pyx b/spacy/syntax/stateclass.pyx index 8b6abfdab..5c1895a1e 100644 --- a/spacy/syntax/stateclass.pyx +++ b/spacy/syntax/stateclass.pyx @@ -143,7 +143,7 @@ cdef class StateClass: self._stack[self._s_i] = self.B(0) self._s_i += 1 self._b_i += 1 - if self._b_i >= self._break: + if self._b_i > self._break: self._break = -1 cdef void pop(self) nogil: @@ -167,7 +167,7 @@ cdef class StateClass: self.pop() else: self.unshift() - elif self.buffer_length() >= 2 and self.stack_depth() == 0: + elif (self.length - self._b_i) >= 1 and self.stack_depth() == 0: self.push() else: break @@ -208,10 +208,9 @@ cdef class StateClass: self._sent[i].ent_iob = ent_iob self._sent[i].ent_type = ent_type - cdef void set_break(self, int i) nogil: - if 0 <= i < self.length: - self._sent[i].sent_end = True - self._break = i + cdef void set_break(self, int _) nogil: + self._sent[self.B(0)].sent_end = True + self._break = self._b_i cdef void clone(self, StateClass src) nogil: memcpy(self._sent, src._sent, self.length * sizeof(TokenC)) @@ -229,7 +228,7 @@ cdef class StateClass: third = words[self.S(2)] + '_%d' % self.S_(2).head n0 = words[self.B(0)] n1 = words[self.B(1)] - return ' '.join((str(self.buffer_length()), str(self.stack_depth()), third, second, top, '|', n0, n1)) + return ' '.join((str(self.buffer_length()), str(self.B_(0).sent_end), str(self._b_i), str(self._break), str(self.length), str(self.stack_depth()), third, second, top, '|', n0, n1)) # From https://en.wikipedia.org/wiki/Hamming_weight