* Pass a StateC pointer into the transition and validation methods in the parser, so that the GIL can be released over a batch of documents

This commit is contained in:
Matthew Honnibal 2016-02-01 03:00:15 +01:00
parent a47f00901b
commit 28e5ad62bc
3 changed files with 6 additions and 6 deletions

View File

@ -114,7 +114,7 @@ cdef class Parser:
cdef void parseC(self, Doc tokens, StateClass stcls, Example eg) nogil:
while not stcls.is_final():
self.model.set_featuresC(&eg.c, stcls)
self.moves.set_valid(eg.c.is_valid, stcls)
self.moves.set_valid(eg.c.is_valid, stcls.c)
self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat)
guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class)
@ -210,7 +210,7 @@ cdef class StepwiseState:
def predict(self):
self.eg.reset()
self.parser.model.set_featuresC(&self.eg.c, self.stcls)
self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls)
self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c)
self.parser.model.set_scoresC(self.eg.c.scores,
self.eg.c.features, self.eg.c.nr_feat)

View File

@ -47,7 +47,7 @@ cdef class TransitionSystem:
cdef Transition init_transition(self, int clas, int move, int label) except *
cdef int set_valid(self, int* output, StateClass state) nogil
cdef int set_valid(self, int* output, const StateC* st) nogil
cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass state, GoldParse gold) except -1

View File

@ -66,15 +66,15 @@ cdef class TransitionSystem:
action = self.lookup_transition(move_name)
return action.is_valid(stcls.c, action.label)
cdef int set_valid(self, int* is_valid, StateClass stcls) nogil:
cdef int set_valid(self, int* is_valid, const StateC* st) nogil:
cdef int i
for i in range(self.n_moves):
is_valid[i] = self.c[i].is_valid(stcls.c, self.c[i].label)
is_valid[i] = self.c[i].is_valid(st, self.c[i].label)
cdef int set_costs(self, int* is_valid, weight_t* costs,
StateClass stcls, GoldParse gold) except -1:
cdef int i
self.set_valid(is_valid, stcls)
self.set_valid(is_valid, stcls.c)
for i in range(self.n_moves):
if is_valid[i]:
costs[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)