From eb138c89edb306608826dca50619ea8a60de2b14 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 25 Jan 2021 10:52:40 +1100 Subject: [PATCH] Fix parser set_annotations during update --- spacy/pipeline/transition_parser.pyx | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index b93565178..422246164 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -290,9 +290,6 @@ cdef class Parser(TrainablePipe): cdef void c_transition_batch(self, StateC** states, const float* scores, int nr_class, int batch_size) nogil: - # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc - with gil: - assert self.moves.n_moves > 0, Errors.E924.format(name=self.name) is_valid = calloc(self.moves.n_moves, sizeof(int)) cdef int i, guess cdef Transition action @@ -310,6 +307,7 @@ cdef class Parser(TrainablePipe): def update(self, examples, *, drop=0., sgd=None, losses=None): cdef StateClass state + cdef Transition action if losses is None: losses = {} losses.setdefault(self.name, 0.) @@ -351,6 +349,9 @@ cdef class Parser(TrainablePipe): all_states = list(states) states_golds = list(zip(states, golds, state2doc)) n_moves = 0 + mem = Pool() + is_valid = mem.alloc(self.moves.n_moves, sizeof(int)) + cdef float[::1] scores_row while states_golds: states, golds, state2doc = zip(*states_golds) scores, backprop = model.begin_update(states) @@ -360,10 +361,20 @@ cdef class Parser(TrainablePipe): # can't normalize by the number of states either, as then we'd # be getting smaller gradients for states in long sequences. backprop(d_scores) - # Follow the predicted action - actions = self.transition_states(states, scores) - for i, action in enumerate(actions): - histories[i].append(action) + # Ugh, we need to get the actions for the histories, so we're + # duplicating work that's being done in transition_states. This + # should be refactored. + scores_view = scores + for i, state in enumerate(states): + self.moves.set_valid(is_valid, state.c) + scores_row = scores[i] + guess = arg_max_if_valid(&scores_row[0], is_valid, scores.shape[1]) + if guess == -1: + raise ValueError("Could not find valid transition") + histories[state2doc[i]].append(guess) + # Follow the predicted action + action = self.moves.c[guess] + action.do(state.c, action.label) states_golds = [ s for s in zip(states, golds, state2doc) if not s[0].is_final()