From 65f2270d597428386824c6d7be30e64ac33aeaa9 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 25 Jan 2021 11:22:43 +1100 Subject: [PATCH] Revert "Fix parser set_annotations during update" This reverts commit eb138c89edb306608826dca50619ea8a60de2b14. --- spacy/pipeline/transition_parser.pyx | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 422246164..b93565178 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -290,6 +290,9 @@ cdef class Parser(TrainablePipe): cdef void c_transition_batch(self, StateC** states, const float* scores, int nr_class, int batch_size) nogil: + # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc + with gil: + assert self.moves.n_moves > 0, Errors.E924.format(name=self.name) is_valid = calloc(self.moves.n_moves, sizeof(int)) cdef int i, guess cdef Transition action @@ -307,7 +310,6 @@ cdef class Parser(TrainablePipe): def update(self, examples, *, drop=0., sgd=None, losses=None): cdef StateClass state - cdef Transition action if losses is None: losses = {} losses.setdefault(self.name, 0.) @@ -349,9 +351,6 @@ cdef class Parser(TrainablePipe): all_states = list(states) states_golds = list(zip(states, golds, state2doc)) n_moves = 0 - mem = Pool() - is_valid = mem.alloc(self.moves.n_moves, sizeof(int)) - cdef float[::1] scores_row while states_golds: states, golds, state2doc = zip(*states_golds) scores, backprop = model.begin_update(states) @@ -361,20 +360,10 @@ cdef class Parser(TrainablePipe): # can't normalize by the number of states either, as then we'd # be getting smaller gradients for states in long sequences. backprop(d_scores) - # Ugh, we need to get the actions for the histories, so we're - # duplicating work that's being done in transition_states. This - # should be refactored. - scores_view = scores - for i, state in enumerate(states): - self.moves.set_valid(is_valid, state.c) - scores_row = scores[i] - guess = arg_max_if_valid(&scores_row[0], is_valid, scores.shape[1]) - if guess == -1: - raise ValueError("Could not find valid transition") - histories[state2doc[i]].append(guess) - # Follow the predicted action - action = self.moves.c[guess] - action.do(state.c, action.label) + # Follow the predicted action + actions = self.transition_states(states, scores) + for i, action in enumerate(actions): + histories[i].append(action) states_golds = [ s for s in zip(states, golds, state2doc) if not s[0].is_final()