mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-09 16:10:33 +03:00
* Add get_valid method
This commit is contained in:
parent
d82f9d958d
commit
c7876aa8b6
|
@ -120,6 +120,20 @@ cdef class ArcEager(TransitionSystem):
|
||||||
if state.sent[i].head == 0 and state.sent[i].dep == 0:
|
if state.sent[i].head == 0 and state.sent[i].dep == 0:
|
||||||
state.sent[i].dep = root_label
|
state.sent[i].dep = root_label
|
||||||
|
|
||||||
|
cdef bint* get_valid(self, const State* s) except NULL:
|
||||||
|
cdef bint[N_MOVES] is_valid
|
||||||
|
is_valid[SHIFT] = _can_shift(s)
|
||||||
|
is_valid[REDUCE] = _can_reduce(s)
|
||||||
|
is_valid[LEFT] = _can_left(s)
|
||||||
|
is_valid[RIGHT] = _can_right(s)
|
||||||
|
is_valid[BREAK] = _can_break(s)
|
||||||
|
is_valid[CONSTITUENT] = _can_constituent(s)
|
||||||
|
is_valid[ADJUST] = _can_adjust(s)
|
||||||
|
cdef int i
|
||||||
|
for i in range(self.n_moves):
|
||||||
|
self._is_valid[i] = is_valid[self.c[i].move]
|
||||||
|
return self._is_valid
|
||||||
|
|
||||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||||
cdef bint[N_MOVES] is_valid
|
cdef bint[N_MOVES] is_valid
|
||||||
is_valid[SHIFT] = _can_shift(s)
|
is_valid[SHIFT] = _can_shift(s)
|
||||||
|
@ -451,4 +465,3 @@ cdef inline bint _can_adjust(const State* s) nogil:
|
||||||
# return False
|
# return False
|
||||||
#elif b0 >= b1:
|
#elif b0 >= b1:
|
||||||
# return False
|
# return False
|
||||||
return True
|
|
||||||
|
|
|
@ -140,6 +140,13 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
t.score = score
|
t.score = score
|
||||||
return t
|
return t
|
||||||
|
|
||||||
|
cdef bint* get_valid(self, const State* s) except NULL:
|
||||||
|
cdef int i
|
||||||
|
for i in range(self.n_moves):
|
||||||
|
m = &self.c[i]
|
||||||
|
self._is_valid[i] = _is_valid(m.move, m.label, s)
|
||||||
|
return self._is_valid
|
||||||
|
|
||||||
|
|
||||||
cdef int _get_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _get_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||||
if not _is_valid(self.move, self.label, s):
|
if not _is_valid(self.move, self.label, s):
|
||||||
|
|
|
@ -28,6 +28,7 @@ cdef class TransitionSystem:
|
||||||
cdef Pool mem
|
cdef Pool mem
|
||||||
cdef StringStore strings
|
cdef StringStore strings
|
||||||
cdef const Transition* c
|
cdef const Transition* c
|
||||||
|
cdef bint* _is_valid
|
||||||
cdef readonly int n_moves
|
cdef readonly int n_moves
|
||||||
|
|
||||||
cdef int initialize_state(self, State* state) except -1
|
cdef int initialize_state(self, State* state) except -1
|
||||||
|
@ -39,6 +40,8 @@ cdef class TransitionSystem:
|
||||||
|
|
||||||
cdef Transition init_transition(self, int clas, int move, int label) except *
|
cdef Transition init_transition(self, int clas, int move, int label) except *
|
||||||
|
|
||||||
|
cdef bint* get_valid(self, const State* state) except NULL
|
||||||
|
|
||||||
cdef Transition best_valid(self, const weight_t* scores, const State* state) except *
|
cdef Transition best_valid(self, const weight_t* scores, const State* state) except *
|
||||||
|
|
||||||
cdef Transition best_gold(self, const weight_t* scores, const State* state,
|
cdef Transition best_gold(self, const weight_t* scores, const State* state,
|
||||||
|
|
|
@ -15,6 +15,7 @@ cdef class TransitionSystem:
|
||||||
def __init__(self, StringStore string_table, dict labels_by_action):
|
def __init__(self, StringStore string_table, dict labels_by_action):
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
self.n_moves = sum(len(labels) for labels in labels_by_action.values())
|
self.n_moves = sum(len(labels) for labels in labels_by_action.values())
|
||||||
|
self._is_valid = <bint*>self.mem.alloc(self.n_moves, sizeof(bint))
|
||||||
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
|
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
|
||||||
cdef int i = 0
|
cdef int i = 0
|
||||||
cdef int label_id
|
cdef int label_id
|
||||||
|
@ -44,6 +45,9 @@ cdef class TransitionSystem:
|
||||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
cdef bint* get_valid(self, const State* state) except NULL:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
cdef Transition best_gold(self, const weight_t* scores, const State* s,
|
cdef Transition best_gold(self, const weight_t* scores, const State* s,
|
||||||
GoldParse gold) except *:
|
GoldParse gold) except *:
|
||||||
cdef Transition best
|
cdef Transition best
|
||||||
|
|
Loading…
Reference in New Issue
Block a user