mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
Fix parser set_annotations during update
This commit is contained in:
parent
c6df0eafd0
commit
eb138c89ed
|
@ -290,9 +290,6 @@ cdef class Parser(TrainablePipe):
|
||||||
|
|
||||||
cdef void c_transition_batch(self, StateC** states, const float* scores,
|
cdef void c_transition_batch(self, StateC** states, const float* scores,
|
||||||
int nr_class, int batch_size) nogil:
|
int nr_class, int batch_size) nogil:
|
||||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
|
||||||
with gil:
|
|
||||||
assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
|
|
||||||
is_valid = <int*>calloc(self.moves.n_moves, sizeof(int))
|
is_valid = <int*>calloc(self.moves.n_moves, sizeof(int))
|
||||||
cdef int i, guess
|
cdef int i, guess
|
||||||
cdef Transition action
|
cdef Transition action
|
||||||
|
@ -310,6 +307,7 @@ cdef class Parser(TrainablePipe):
|
||||||
|
|
||||||
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
|
cdef Transition action
|
||||||
if losses is None:
|
if losses is None:
|
||||||
losses = {}
|
losses = {}
|
||||||
losses.setdefault(self.name, 0.)
|
losses.setdefault(self.name, 0.)
|
||||||
|
@ -351,6 +349,9 @@ cdef class Parser(TrainablePipe):
|
||||||
all_states = list(states)
|
all_states = list(states)
|
||||||
states_golds = list(zip(states, golds, state2doc))
|
states_golds = list(zip(states, golds, state2doc))
|
||||||
n_moves = 0
|
n_moves = 0
|
||||||
|
mem = Pool()
|
||||||
|
is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int))
|
||||||
|
cdef float[::1] scores_row
|
||||||
while states_golds:
|
while states_golds:
|
||||||
states, golds, state2doc = zip(*states_golds)
|
states, golds, state2doc = zip(*states_golds)
|
||||||
scores, backprop = model.begin_update(states)
|
scores, backprop = model.begin_update(states)
|
||||||
|
@ -360,10 +361,20 @@ cdef class Parser(TrainablePipe):
|
||||||
# can't normalize by the number of states either, as then we'd
|
# can't normalize by the number of states either, as then we'd
|
||||||
# be getting smaller gradients for states in long sequences.
|
# be getting smaller gradients for states in long sequences.
|
||||||
backprop(d_scores)
|
backprop(d_scores)
|
||||||
|
# Ugh, we need to get the actions for the histories, so we're
|
||||||
|
# duplicating work that's being done in transition_states. This
|
||||||
|
# should be refactored.
|
||||||
|
scores_view = scores
|
||||||
|
for i, state in enumerate(states):
|
||||||
|
self.moves.set_valid(is_valid, state.c)
|
||||||
|
scores_row = scores[i]
|
||||||
|
guess = arg_max_if_valid(&scores_row[0], is_valid, scores.shape[1])
|
||||||
|
if guess == -1:
|
||||||
|
raise ValueError("Could not find valid transition")
|
||||||
|
histories[state2doc[i]].append(guess)
|
||||||
# Follow the predicted action
|
# Follow the predicted action
|
||||||
actions = self.transition_states(states, scores)
|
action = self.moves.c[guess]
|
||||||
for i, action in enumerate(actions):
|
action.do(state.c, action.label)
|
||||||
histories[i].append(action)
|
|
||||||
states_golds = [
|
states_golds = [
|
||||||
s for s in zip(states, golds, state2doc)
|
s for s in zip(states, golds, state2doc)
|
||||||
if not s[0].is_final()
|
if not s[0].is_final()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user