diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index fb9b001d5..55d5dcc3f 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -114,7 +114,7 @@ cdef class Parser: cdef void parseC(self, Doc tokens, StateClass stcls, Example eg) nogil: while not stcls.is_final(): self.model.set_featuresC(&eg.c, stcls) - self.moves.set_valid(eg.c.is_valid, stcls) + self.moves.set_valid(eg.c.is_valid, stcls.c) self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat) guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class) @@ -210,7 +210,7 @@ cdef class StepwiseState: def predict(self): self.eg.reset() self.parser.model.set_featuresC(&self.eg.c, self.stcls) - self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls) + self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c) self.parser.model.set_scoresC(self.eg.c.scores, self.eg.c.features, self.eg.c.nr_feat) diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd index 23d3561c4..d9dbe454d 100644 --- a/spacy/syntax/transition_system.pxd +++ b/spacy/syntax/transition_system.pxd @@ -47,7 +47,7 @@ cdef class TransitionSystem: cdef Transition init_transition(self, int clas, int move, int label) except * - cdef int set_valid(self, int* output, StateClass state) nogil + cdef int set_valid(self, int* output, const StateC* st) nogil cdef int set_costs(self, int* is_valid, weight_t* costs, StateClass state, GoldParse gold) except -1 diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index 0b2a03202..4228e8e67 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -66,15 +66,15 @@ cdef class TransitionSystem: action = self.lookup_transition(move_name) return action.is_valid(stcls.c, action.label) - cdef int set_valid(self, int* is_valid, StateClass stcls) nogil: + cdef int set_valid(self, int* is_valid, const StateC* st) nogil: cdef int i for i in range(self.n_moves): - is_valid[i] = self.c[i].is_valid(stcls.c, self.c[i].label) + is_valid[i] = self.c[i].is_valid(st, self.c[i].label) cdef int set_costs(self, int* is_valid, weight_t* costs, StateClass stcls, GoldParse gold) except -1: cdef int i - self.set_valid(is_valid, stcls) + self.set_valid(is_valid, stcls.c) for i in range(self.n_moves): if is_valid[i]: costs[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)