diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index 2362d925c..5882b37aa 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -157,6 +157,45 @@ cdef cppclass StateC: else: ids[i] = -1 + int can_push() nogil const: + if this.buffer_length == 0: + return 0 + else: + return 1 + + int can_pop() nogil const: + if this.stack_depth() < 1: + return 0 + else: + return 1 + + int can_arc() nogil const: + if this.at_break(): + return 0 + elif this.stack_depth() < 1: + return 0 + elif this.buffer_length == 0: + return 0 + else: + return 1 + + int can_break() nogil const: + if this.buffer_length == 0: + return False + elif this.B_(0).l_edge < 0: + return False + elif this._sent[this.B_(0).l_edge].sent_start < 0: + return False + elif this.stack_depth() < 1: # ?? I guess stops first action break? + return False + elif this.at_break(): + return False + else: + return True + + int can_split() nogil const: + return 0 + int S(int i) nogil const: if i >= this._s_i: return -1 @@ -265,7 +304,7 @@ cdef cppclass StateC: return this._n_until_break == 0 bint is_final() nogil const: - return this.stack_depth() <= 0 and this.buffer_length == 0 + return this.stack_depth() <= 1 and this.buffer_length == 0 bint has_head(int i) nogil const: return this.safe_get(i).head != 0 @@ -287,6 +326,12 @@ cdef cppclass StateC: int stack_depth() nogil const: return this._s_i + int segment_length() nogil const: + if this._n_until_break != -1: + return this._n_until_break + else: + return this.buffer_length + uint64_t hash() nogil const: cdef TokenC[11] sig sig[0] = this.S_(2)[0] @@ -460,69 +505,3 @@ cdef cppclass StateC: this._n_until_break = src._n_until_break this.offset = src.offset this._empty_token = src._empty_token - - void fast_forward() nogil: - # space token attachement policy: - # - attach space tokens always to the last preceding real token - # - except if it's the beginning of a sentence, then attach to the first following - # - boundary case: a document containing multiple space tokens but nothing else, - # then make the last space token the head of all others - - while is_space_token(this.B_(0)) \ - or this.eol() \ - or this.stack_depth() == 0: - if this.eol(): - # remove the last sentence's root from the stack - if this.stack_depth() == 1: - this.pop() - # parser got stuck: reduce stack or unshift - elif this.stack_depth() > 1: - if this.has_head(this.S(0)): - this.pop() - else: - this.unshift() - # stack is empty but there is another sentence on the buffer - elif this.buffer_length != 0: - this.push() - else: # stack empty and nothing else coming - break - - elif is_space_token(this.B_(0)): - # the normal case: we're somewhere inside a sentence - if this.stack_depth() > 0: - # assert not is_space_token(this.S_(0)) - # attach all coming space tokens to their last preceding - # real token (which should be on the top of the stack) - while is_space_token(this.B_(0)): - this.add_arc(this.S(0),this.B(0),0) - this.push() - this.pop() - # the rare case: we're at the beginning of a document: - # space tokens are attached to the first real token on the buffer - elif this.stack_depth() == 0: - # store all space tokens on the stack until a real token shows up - # or the last token on the buffer is reached - while is_space_token(this.B_(0)) and this.buffer_length > 1: - this.push() - # empty the stack by attaching all space tokens to the - # first token on the buffer - # boundary case: if all tokens are space tokens, the last one - # becomes the head of all others - while this.stack_depth() > 0: - this.add_arc(this.B(0),this.S(0),0) - this.pop() - # move the first token onto the stack - this.push() - - elif this.stack_depth() == 0: - # for one token sentences (?) - if this.buffer_length == 1: - this.push() - this.pop() - # with an empty stack and a non-empty buffer - # only shift is valid anyway - elif this.buffer_length != 0: - this.push() - - else: # can this even happen? - break diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index e2e7d5f34..82ac43cb1 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -66,7 +66,7 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: cdef weight_t cost = 0 cdef int i, B_i - for i in range(stcls.c.buffer_length): + for i in range(stcls.c.segment_length()): B_i = stcls.B(i) cost += gold.heads[B_i] == target cost += gold.heads[target] == B_i @@ -74,8 +74,8 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nog break if BINARY_COSTS and cost >= 1: return cost - if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0: - cost += 1 + #if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0: + # cost += 1 return cost @@ -117,15 +117,23 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil: cdef class Shift: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: - sent_start = st._sent[st.B_(0).l_edge].sent_start - return st.buffer_length >= 2 and not st.shifted[st.B(0)] and sent_start != 1 + if not st.can_push(): + return False + elif st.stack_depth() == 0: # If the stack is empty, we must push + return True + elif st.shifted[st.B(0)]: + return False + elif st.at_break(): + return False + else: + return True @staticmethod cdef int transition(StateC* st, attr_t label) nogil: - if label != 0: - st.split(st.B(1), label) + #if label != 0: + # st.split(st.B(1), label) + st.shifted[st.B(0)] = 1 st.push() - st.fast_forward() @staticmethod cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil: @@ -138,7 +146,7 @@ cdef class Shift: @staticmethod cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: return 0 - #if gold.fused_tokens[s.B(1)] == label: + #if gold.fused_tokens[s.B(1)] == label: TODO # return 0 #else: # return 1 @@ -147,15 +155,21 @@ cdef class Shift: cdef class Reduce: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: - return st.stack_depth() >= 2 + if st.stack_depth() >= 2: + return True + elif st.at_break() and st.stack_depth() == 1: + return True + else: + return False @staticmethod cdef int transition(StateC* st, attr_t label) nogil: if st.has_head(st.S(0)): st.pop() + elif st.stack_depth() == 1 and st.at_break(): + st.pop() else: st.unshift() - st.fast_forward() @staticmethod cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: @@ -165,15 +179,15 @@ cdef class Reduce: cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil: cost = pop_cost(st, gold, st.S(0)) if not st.has_head(st.S(0)): - # Decrement cost for the arcs e save + # Decrement cost for the arcs we save for i in range(1, st.stack_depth()): S_i = st.S(i) if gold.heads[st.S(0)] == S_i: cost -= 1 if gold.heads[S_i] == st.S(0): cost -= 1 - if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0: - cost -= 1 + #if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0: + # cost -= 1 return cost @staticmethod @@ -184,18 +198,18 @@ cdef class Reduce: cdef class LeftArc: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: - sent_start = st._sent[st.B_(0).l_edge].sent_start - return sent_start != 1 + return st.can_arc() @staticmethod cdef int transition(StateC* st, attr_t label) nogil: st.add_arc(st.B(0), st.S(0), label) st.pop() - st.fast_forward() @staticmethod cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: - return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label) + cdef weight_t move_cost = LeftArc.move_cost(s, gold) + cdef weight_t label_cost = LeftArc.label_cost(s, gold, label) + return move_cost + label_cost @staticmethod cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil: @@ -220,14 +234,17 @@ cdef class RightArc: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: # If there's (perhaps partial) parse pre-set, don't allow cycle. - sent_start = st._sent[st.B_(0).l_edge].sent_start - return sent_start != 1 and st.H(st.S(0)) != st.B(0) + if not st.can_arc(): + return 0 + elif st.H(st.S(0)) == st.B(0): + return 0 + else: + return 1 @staticmethod cdef int transition(StateC* st, attr_t label) nogil: st.add_arc(st.S(0), st.B(0), label) st.push() - st.fast_forward() @staticmethod cdef inline weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: @@ -253,21 +270,13 @@ cdef class Break: cdef int i if not USE_BREAK: return False - elif st.at_break(): - return False - elif st.stack_depth() < 1: - return False - elif st.B_(0).l_edge < 0: - return False - elif st._sent[st.B_(0).l_edge].sent_start < 0: - return False else: - return True + return st.can_break() @staticmethod cdef int transition(StateC* st, attr_t label) nogil: st.set_break(0) - st.fast_forward() + st.pop() @staticmethod cdef weight_t cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: @@ -317,7 +326,6 @@ cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: st._sent[i].dep = 0 st._sent[i].l_kids = 0 st._sent[i].r_kids = 0 - st.fast_forward() return st @@ -520,7 +528,6 @@ cdef class ArcEager(TransitionSystem): st._sent[i].dep = 0 st._sent[i].l_kids = 0 st._sent[i].r_kids = 0 - st.fast_forward() cdef int finalize_state(self, StateC* st) nogil: cdef int i diff --git a/spacy/syntax/stateclass.pxd b/spacy/syntax/stateclass.pxd index 8331e674d..b3cb28b68 100644 --- a/spacy/syntax/stateclass.pxd +++ b/spacy/syntax/stateclass.pxd @@ -137,6 +137,3 @@ cdef class StateClass: cdef inline void clone(self, StateClass src) nogil: self.c.clone(src.c) - - cdef inline void fast_forward(self) nogil: - self.c.fast_forward() diff --git a/spacy/syntax/stateclass.pyx b/spacy/syntax/stateclass.pyx index cd0cab2a6..a5349733a 100644 --- a/spacy/syntax/stateclass.pyx +++ b/spacy/syntax/stateclass.pyx @@ -30,28 +30,32 @@ cdef class StateClass: def get_S(self, int i): return self.c.S(i) - def push_stack(self, fast_forward=True): + def can_push(self): + return self.c.can_push() + + def can_pop(self): + return self.c.can_pop() + + def can_break(self): + return self.c.can_break() + + def can_arc(self): + return self.c.can_arc() + + def push_stack(self): self.c.push() - if fast_forward: - self.c.fast_forward() - def pop_stack(self, fast_forward=True): + def pop_stack(self): self.c.pop() - if fast_forward: - self.c.fast_forward() - def unshift(self, fast_forward=True): + def unshift(self): self.c.unshift() - if fast_forward: - self.c.fast_forward() def set_break(self, int i): self.c.set_break(i) - def split_token(self, int i, int n, fast_forward=True): + def split_token(self, int i, int n): self.c.split(i, n) - if fast_forward: - self.c.fast_forward() def get_doc(self, vocab): cdef Doc doc = Doc(vocab) diff --git a/spacy/tests/parser/test_add_label.py b/spacy/tests/parser/test_add_label.py index 9493452a1..7698e2aec 100644 --- a/spacy/tests/parser/test_add_label.py +++ b/spacy/tests/parser/test_add_label.py @@ -41,7 +41,6 @@ def test_init_parser(parser): pass # TODO: This is flakey, because it depends on what the parser first learns. -@pytest.mark.xfail def test_add_label(parser): doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd']) doc = parser(doc) diff --git a/spacy/tests/parser/test_nn_beam.py b/spacy/tests/parser/test_nn_beam.py index ab8bf012b..b4c15c0d7 100644 --- a/spacy/tests/parser/test_nn_beam.py +++ b/spacy/tests/parser/test_nn_beam.py @@ -17,7 +17,10 @@ def vocab(): @pytest.fixture def moves(vocab): - aeager = ArcEager(vocab.strings, {}) + aeager = ArcEager(vocab.strings) + aeager.add_action(0, '') + aeager.add_action(1, '') + aeager.add_action(4, 'ROOT') aeager.add_action(2, 'nsubj') aeager.add_action(3, 'dobj') aeager.add_action(2, 'aux') diff --git a/spacy/tests/parser/test_split_word.py b/spacy/tests/parser/test_split_word.py index e44a27fdb..0b2b092b1 100644 --- a/spacy/tests/parser/test_split_word.py +++ b/spacy/tests/parser/test_split_word.py @@ -39,5 +39,5 @@ def test_split(): doc = get_doc('abcd') state = StateClass(doc, max_split=3) assert state.queue == [0, 1, 2, 3] - state.split_token(1, 2, fast_forward=False) + state.split_token(1, 2) assert state.queue == [0, 1, 1*4+1, 2*4+1, 2, 3]