Fix parser set_annotations during update

This commit is contained in:
Matthew Honnibal 2021-01-25 10:52:40 +11:00
parent c6df0eafd0
commit eb138c89ed

View File

@ -290,9 +290,6 @@ cdef class Parser(TrainablePipe):
cdef void c_transition_batch(self, StateC** states, const float* scores,
int nr_class, int batch_size) nogil:
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
with gil:
assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
is_valid = <int*>calloc(self.moves.n_moves, sizeof(int))
cdef int i, guess
cdef Transition action
@ -310,6 +307,7 @@ cdef class Parser(TrainablePipe):
def update(self, examples, *, drop=0., sgd=None, losses=None):
cdef StateClass state
cdef Transition action
if losses is None:
losses = {}
losses.setdefault(self.name, 0.)
@ -351,6 +349,9 @@ cdef class Parser(TrainablePipe):
all_states = list(states)
states_golds = list(zip(states, golds, state2doc))
n_moves = 0
mem = Pool()
is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int))
cdef float[::1] scores_row
while states_golds:
states, golds, state2doc = zip(*states_golds)
scores, backprop = model.begin_update(states)
@ -360,10 +361,20 @@ cdef class Parser(TrainablePipe):
# can't normalize by the number of states either, as then we'd
# be getting smaller gradients for states in long sequences.
backprop(d_scores)
# Ugh, we need to get the actions for the histories, so we're
# duplicating work that's being done in transition_states. This
# should be refactored.
scores_view = scores
for i, state in enumerate(states):
self.moves.set_valid(is_valid, state.c)
scores_row = scores[i]
guess = arg_max_if_valid(&scores_row[0], is_valid, scores.shape[1])
if guess == -1:
raise ValueError("Could not find valid transition")
histories[state2doc[i]].append(guess)
# Follow the predicted action
actions = self.transition_states(states, scores)
for i, action in enumerate(actions):
histories[i].append(action)
action = self.moves.c[guess]
action.do(state.c, action.label)
states_golds = [
s for s in zip(states, golds, state2doc)
if not s[0].is_final()