mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 09:56:28 +03:00
Fix parser set_annotations during update
This commit is contained in:
parent
c6df0eafd0
commit
eb138c89ed
|
@ -290,9 +290,6 @@ cdef class Parser(TrainablePipe):
|
|||
|
||||
cdef void c_transition_batch(self, StateC** states, const float* scores,
|
||||
int nr_class, int batch_size) nogil:
|
||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||
with gil:
|
||||
assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
|
||||
is_valid = <int*>calloc(self.moves.n_moves, sizeof(int))
|
||||
cdef int i, guess
|
||||
cdef Transition action
|
||||
|
@ -310,6 +307,7 @@ cdef class Parser(TrainablePipe):
|
|||
|
||||
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
||||
cdef StateClass state
|
||||
cdef Transition action
|
||||
if losses is None:
|
||||
losses = {}
|
||||
losses.setdefault(self.name, 0.)
|
||||
|
@ -351,6 +349,9 @@ cdef class Parser(TrainablePipe):
|
|||
all_states = list(states)
|
||||
states_golds = list(zip(states, golds, state2doc))
|
||||
n_moves = 0
|
||||
mem = Pool()
|
||||
is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int))
|
||||
cdef float[::1] scores_row
|
||||
while states_golds:
|
||||
states, golds, state2doc = zip(*states_golds)
|
||||
scores, backprop = model.begin_update(states)
|
||||
|
@ -360,10 +361,20 @@ cdef class Parser(TrainablePipe):
|
|||
# can't normalize by the number of states either, as then we'd
|
||||
# be getting smaller gradients for states in long sequences.
|
||||
backprop(d_scores)
|
||||
# Ugh, we need to get the actions for the histories, so we're
|
||||
# duplicating work that's being done in transition_states. This
|
||||
# should be refactored.
|
||||
scores_view = scores
|
||||
for i, state in enumerate(states):
|
||||
self.moves.set_valid(is_valid, state.c)
|
||||
scores_row = scores[i]
|
||||
guess = arg_max_if_valid(&scores_row[0], is_valid, scores.shape[1])
|
||||
if guess == -1:
|
||||
raise ValueError("Could not find valid transition")
|
||||
histories[state2doc[i]].append(guess)
|
||||
# Follow the predicted action
|
||||
actions = self.transition_states(states, scores)
|
||||
for i, action in enumerate(actions):
|
||||
histories[i].append(action)
|
||||
action = self.moves.c[guess]
|
||||
action.do(state.c, action.label)
|
||||
states_golds = [
|
||||
s for s in zip(states, golds, state2doc)
|
||||
if not s[0].is_final()
|
||||
|
|
Loading…
Reference in New Issue
Block a user