mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-25 00:34:20 +03:00
* Add get_valid method
This commit is contained in:
parent
d82f9d958d
commit
c7876aa8b6
|
@ -120,6 +120,20 @@ cdef class ArcEager(TransitionSystem):
|
|||
if state.sent[i].head == 0 and state.sent[i].dep == 0:
|
||||
state.sent[i].dep = root_label
|
||||
|
||||
cdef bint* get_valid(self, const State* s) except NULL:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
is_valid[SHIFT] = _can_shift(s)
|
||||
is_valid[REDUCE] = _can_reduce(s)
|
||||
is_valid[LEFT] = _can_left(s)
|
||||
is_valid[RIGHT] = _can_right(s)
|
||||
is_valid[BREAK] = _can_break(s)
|
||||
is_valid[CONSTITUENT] = _can_constituent(s)
|
||||
is_valid[ADJUST] = _can_adjust(s)
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
self._is_valid[i] = is_valid[self.c[i].move]
|
||||
return self._is_valid
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
is_valid[SHIFT] = _can_shift(s)
|
||||
|
@ -451,4 +465,3 @@ cdef inline bint _can_adjust(const State* s) nogil:
|
|||
# return False
|
||||
#elif b0 >= b1:
|
||||
# return False
|
||||
return True
|
||||
|
|
|
@ -140,6 +140,13 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
t.score = score
|
||||
return t
|
||||
|
||||
cdef bint* get_valid(self, const State* s) except NULL:
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
m = &self.c[i]
|
||||
self._is_valid[i] = _is_valid(m.move, m.label, s)
|
||||
return self._is_valid
|
||||
|
||||
|
||||
cdef int _get_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
if not _is_valid(self.move, self.label, s):
|
||||
|
|
|
@ -28,6 +28,7 @@ cdef class TransitionSystem:
|
|||
cdef Pool mem
|
||||
cdef StringStore strings
|
||||
cdef const Transition* c
|
||||
cdef bint* _is_valid
|
||||
cdef readonly int n_moves
|
||||
|
||||
cdef int initialize_state(self, State* state) except -1
|
||||
|
@ -39,6 +40,8 @@ cdef class TransitionSystem:
|
|||
|
||||
cdef Transition init_transition(self, int clas, int move, int label) except *
|
||||
|
||||
cdef bint* get_valid(self, const State* state) except NULL
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* state) except *
|
||||
|
||||
cdef Transition best_gold(self, const weight_t* scores, const State* state,
|
||||
|
|
|
@ -15,6 +15,7 @@ cdef class TransitionSystem:
|
|||
def __init__(self, StringStore string_table, dict labels_by_action):
|
||||
self.mem = Pool()
|
||||
self.n_moves = sum(len(labels) for labels in labels_by_action.values())
|
||||
self._is_valid = <bint*>self.mem.alloc(self.n_moves, sizeof(bint))
|
||||
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
|
||||
cdef int i = 0
|
||||
cdef int label_id
|
||||
|
@ -43,6 +44,9 @@ cdef class TransitionSystem:
|
|||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||
raise NotImplementedError
|
||||
|
||||
cdef bint* get_valid(self, const State* state) except NULL:
|
||||
raise NotImplementedError
|
||||
|
||||
cdef Transition best_gold(self, const weight_t* scores, const State* s,
|
||||
GoldParse gold) except *:
|
||||
|
|
Loading…
Reference in New Issue
Block a user