mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 17:54:39 +03:00
* Fix bint/int typing problem in TransitionSystem. In C++ bint* means bool*, but in C it means int*. So, type-casting to bint* is unsafe.
This commit is contained in:
parent
6cfa83157e
commit
e29daea85f
|
@ -383,7 +383,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
elif USE_ROOT_ARC_SEGMENT and st._sent[i].dep == self.root_label:
|
elif USE_ROOT_ARC_SEGMENT and st._sent[i].dep == self.root_label:
|
||||||
st._sent[i].head = 0
|
st._sent[i].head = 0
|
||||||
|
|
||||||
cdef int set_valid(self, bint* output, StateClass stcls) nogil:
|
cdef int set_valid(self, int* output, StateClass stcls) nogil:
|
||||||
cdef bint[N_MOVES] is_valid
|
cdef bint[N_MOVES] is_valid
|
||||||
is_valid[SHIFT] = Shift.is_valid(stcls, -1)
|
is_valid[SHIFT] = Shift.is_valid(stcls, -1)
|
||||||
is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
|
is_valid[REDUCE] = Reduce.is_valid(stcls, -1)
|
||||||
|
@ -394,7 +394,7 @@ cdef class ArcEager(TransitionSystem):
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
output[i] = is_valid[self.c[i].move]
|
output[i] = is_valid[self.c[i].move]
|
||||||
|
|
||||||
cdef int set_costs(self, bint* is_valid, int* costs,
|
cdef int set_costs(self, int* is_valid, int* costs,
|
||||||
StateClass stcls, GoldParse gold) except -1:
|
StateClass stcls, GoldParse gold) except -1:
|
||||||
cdef int i, move, label
|
cdef int i, move, label
|
||||||
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
||||||
|
|
|
@ -93,7 +93,7 @@ cdef class Parser:
|
||||||
while not stcls.is_final():
|
while not stcls.is_final():
|
||||||
memset(eg.scores, 0, eg.nr_class * sizeof(weight_t))
|
memset(eg.scores, 0, eg.nr_class * sizeof(weight_t))
|
||||||
|
|
||||||
self.moves.set_valid(<bint*>eg.is_valid, stcls)
|
self.moves.set_valid(eg.is_valid, stcls)
|
||||||
fill_context(eg.atoms, stcls)
|
fill_context(eg.atoms, stcls)
|
||||||
self.model.set_scores(eg.scores, eg.atoms)
|
self.model.set_scores(eg.scores, eg.atoms)
|
||||||
eg.guess = arg_max_if_true(eg.scores, eg.is_valid, self.model.n_classes)
|
eg.guess = arg_max_if_true(eg.scores, eg.is_valid, self.model.n_classes)
|
||||||
|
@ -113,7 +113,7 @@ cdef class Parser:
|
||||||
while not stcls.is_final():
|
while not stcls.is_final():
|
||||||
memset(eg.c.scores, 0, eg.c.nr_class * sizeof(weight_t))
|
memset(eg.c.scores, 0, eg.c.nr_class * sizeof(weight_t))
|
||||||
|
|
||||||
self.moves.set_costs(<bint*>eg.c.is_valid, eg.c.costs, stcls, gold)
|
self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold)
|
||||||
|
|
||||||
fill_context(eg.c.atoms, stcls)
|
fill_context(eg.c.atoms, stcls)
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,7 @@ cdef class TransitionSystem:
|
||||||
|
|
||||||
cdef Transition init_transition(self, int clas, int move, int label) except *
|
cdef Transition init_transition(self, int clas, int move, int label) except *
|
||||||
|
|
||||||
cdef int set_valid(self, bint* output, StateClass state) nogil
|
cdef int set_valid(self, int* output, StateClass state) nogil
|
||||||
|
|
||||||
cdef int set_costs(self, bint* is_valid, int* costs,
|
cdef int set_costs(self, int* is_valid, int* costs,
|
||||||
StateClass state, GoldParse gold) except -1
|
StateClass state, GoldParse gold) except -1
|
||||||
|
|
|
@ -44,12 +44,12 @@ cdef class TransitionSystem:
|
||||||
cdef Transition init_transition(self, int clas, int move, int label) except *:
|
cdef Transition init_transition(self, int clas, int move, int label) except *:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
cdef int set_valid(self, bint* is_valid, StateClass stcls) nogil:
|
cdef int set_valid(self, int* is_valid, StateClass stcls) nogil:
|
||||||
cdef int i
|
cdef int i
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
is_valid[i] = self.c[i].is_valid(stcls, self.c[i].label)
|
is_valid[i] = self.c[i].is_valid(stcls, self.c[i].label)
|
||||||
|
|
||||||
cdef int set_costs(self, bint* is_valid, int* costs,
|
cdef int set_costs(self, int* is_valid, int* costs,
|
||||||
StateClass stcls, GoldParse gold) except -1:
|
StateClass stcls, GoldParse gold) except -1:
|
||||||
cdef int i
|
cdef int i
|
||||||
self.set_valid(is_valid, stcls)
|
self.set_valid(is_valid, stcls)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user