mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
* Fix missing root labels bug identified in Issue #57
This commit is contained in:
parent
693c5a1558
commit
b3fd48c97b
|
@ -88,9 +88,15 @@ cdef class ArcEager(TransitionSystem):
|
|||
t.get_cost = get_cost_funcs[move]
|
||||
return t
|
||||
|
||||
cdef int first_state(self, State* state) except -1:
|
||||
cdef int initialize_state(self, State* state) except -1:
|
||||
push_stack(state)
|
||||
|
||||
cdef int finalize_state(self, State* state) except -1:
|
||||
cdef int root_label = self.strings['ROOT']
|
||||
for i in range(state.sent_len):
|
||||
if state.sent[i].head == 0 and state.sent[i].dep == 0:
|
||||
state.sent[i].dep = root_label
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
is_valid[SHIFT] = _can_shift(s)
|
||||
|
|
|
@ -124,9 +124,6 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
t.get_cost = _get_cost
|
||||
return t
|
||||
|
||||
cdef int first_state(self, State* state) except -1:
|
||||
pass
|
||||
|
||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||
cdef int best = -1
|
||||
cdef weight_t score = -90000
|
||||
|
|
|
@ -83,15 +83,14 @@ cdef class GreedyParser:
|
|||
cdef int n_feats
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||
self.moves.first_state(state)
|
||||
self.moves.initialize_state(state)
|
||||
cdef Transition guess
|
||||
while not is_final(state):
|
||||
fill_context(context, state)
|
||||
scores = self.model.score(context)
|
||||
guess = self.moves.best_valid(scores, state)
|
||||
#print self.moves.move_name(guess.move, guess.label),
|
||||
#print print_state(state, [w.orth_ for w in tokens])
|
||||
guess.do(&guess, state)
|
||||
self.moves.finalize_state(state)
|
||||
tokens.set_parse(state.sent)
|
||||
return 0
|
||||
|
||||
|
@ -99,7 +98,7 @@ cdef class GreedyParser:
|
|||
self.moves.preprocess_gold(gold)
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||
self.moves.first_state(state)
|
||||
self.moves.initialize_state(state)
|
||||
|
||||
cdef int cost
|
||||
cdef const Feature* feats
|
||||
|
@ -117,3 +116,4 @@ cdef class GreedyParser:
|
|||
self.model.update(context, guess.clas, best.clas, cost)
|
||||
|
||||
guess.do(&guess, state)
|
||||
self.moves.finalize_state(state)
|
||||
|
|
|
@ -30,7 +30,8 @@ cdef class TransitionSystem:
|
|||
cdef const Transition* c
|
||||
cdef readonly int n_moves
|
||||
|
||||
cdef int first_state(self, State* state) except -1
|
||||
cdef int initialize_state(self, State* state) except -1
|
||||
cdef int finalize_state(self, State* state) except -1
|
||||
|
||||
cdef int preprocess_gold(self, GoldParse gold) except -1
|
||||
|
||||
|
|
|
@ -26,8 +26,11 @@ cdef class TransitionSystem:
|
|||
i += 1
|
||||
self.c = moves
|
||||
|
||||
cdef int first_state(self, State* state) except -1:
|
||||
raise NotImplementedError
|
||||
cdef int initialize_state(self, State* state) except -1:
|
||||
pass
|
||||
|
||||
cdef int finalize_state(self, State* state) except -1:
|
||||
pass
|
||||
|
||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||
raise NotImplementedError
|
||||
|
|
Loading…
Reference in New Issue
Block a user