mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 09:44:36 +03:00
* Fix missing root labels bug identified in Issue #57
This commit is contained in:
parent
693c5a1558
commit
b3fd48c97b
|
@ -88,9 +88,15 @@ cdef class ArcEager(TransitionSystem):
|
||||||
t.get_cost = get_cost_funcs[move]
|
t.get_cost = get_cost_funcs[move]
|
||||||
return t
|
return t
|
||||||
|
|
||||||
cdef int first_state(self, State* state) except -1:
|
cdef int initialize_state(self, State* state) except -1:
|
||||||
push_stack(state)
|
push_stack(state)
|
||||||
|
|
||||||
|
cdef int finalize_state(self, State* state) except -1:
|
||||||
|
cdef int root_label = self.strings['ROOT']
|
||||||
|
for i in range(state.sent_len):
|
||||||
|
if state.sent[i].head == 0 and state.sent[i].dep == 0:
|
||||||
|
state.sent[i].dep = root_label
|
||||||
|
|
||||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||||
cdef bint[N_MOVES] is_valid
|
cdef bint[N_MOVES] is_valid
|
||||||
is_valid[SHIFT] = _can_shift(s)
|
is_valid[SHIFT] = _can_shift(s)
|
||||||
|
|
|
@ -124,9 +124,6 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
t.get_cost = _get_cost
|
t.get_cost = _get_cost
|
||||||
return t
|
return t
|
||||||
|
|
||||||
cdef int first_state(self, State* state) except -1:
|
|
||||||
pass
|
|
||||||
|
|
||||||
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
|
||||||
cdef int best = -1
|
cdef int best = -1
|
||||||
cdef weight_t score = -90000
|
cdef weight_t score = -90000
|
||||||
|
|
|
@ -83,15 +83,14 @@ cdef class GreedyParser:
|
||||||
cdef int n_feats
|
cdef int n_feats
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||||
self.moves.first_state(state)
|
self.moves.initialize_state(state)
|
||||||
cdef Transition guess
|
cdef Transition guess
|
||||||
while not is_final(state):
|
while not is_final(state):
|
||||||
fill_context(context, state)
|
fill_context(context, state)
|
||||||
scores = self.model.score(context)
|
scores = self.model.score(context)
|
||||||
guess = self.moves.best_valid(scores, state)
|
guess = self.moves.best_valid(scores, state)
|
||||||
#print self.moves.move_name(guess.move, guess.label),
|
|
||||||
#print print_state(state, [w.orth_ for w in tokens])
|
|
||||||
guess.do(&guess, state)
|
guess.do(&guess, state)
|
||||||
|
self.moves.finalize_state(state)
|
||||||
tokens.set_parse(state.sent)
|
tokens.set_parse(state.sent)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
@ -99,7 +98,7 @@ cdef class GreedyParser:
|
||||||
self.moves.preprocess_gold(gold)
|
self.moves.preprocess_gold(gold)
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||||
self.moves.first_state(state)
|
self.moves.initialize_state(state)
|
||||||
|
|
||||||
cdef int cost
|
cdef int cost
|
||||||
cdef const Feature* feats
|
cdef const Feature* feats
|
||||||
|
@ -117,3 +116,4 @@ cdef class GreedyParser:
|
||||||
self.model.update(context, guess.clas, best.clas, cost)
|
self.model.update(context, guess.clas, best.clas, cost)
|
||||||
|
|
||||||
guess.do(&guess, state)
|
guess.do(&guess, state)
|
||||||
|
self.moves.finalize_state(state)
|
||||||
|
|
|
@ -30,7 +30,8 @@ cdef class TransitionSystem:
|
||||||
cdef const Transition* c
|
cdef const Transition* c
|
||||||
cdef readonly int n_moves
|
cdef readonly int n_moves
|
||||||
|
|
||||||
cdef int first_state(self, State* state) except -1
|
cdef int initialize_state(self, State* state) except -1
|
||||||
|
cdef int finalize_state(self, State* state) except -1
|
||||||
|
|
||||||
cdef int preprocess_gold(self, GoldParse gold) except -1
|
cdef int preprocess_gold(self, GoldParse gold) except -1
|
||||||
|
|
||||||
|
|
|
@ -26,8 +26,11 @@ cdef class TransitionSystem:
|
||||||
i += 1
|
i += 1
|
||||||
self.c = moves
|
self.c = moves
|
||||||
|
|
||||||
cdef int first_state(self, State* state) except -1:
|
cdef int initialize_state(self, State* state) except -1:
|
||||||
raise NotImplementedError
|
pass
|
||||||
|
|
||||||
|
cdef int finalize_state(self, State* state) except -1:
|
||||||
|
pass
|
||||||
|
|
||||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
Loading…
Reference in New Issue
Block a user