diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index c5591a9f3..8fad61538 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -23,6 +23,7 @@ from ._parser_internals cimport _beam_utils from ._parser_internals import _beam_utils from ..vocab cimport Vocab from ._parser_internals.transition_system cimport TransitionSystem +from ..typedefs cimport weight_t from ..training import validate_examples, validate_get_examples from ..errors import Errors, Warnings @@ -335,8 +336,8 @@ class Parser(TrainablePipe): cdef int nO = moves.n_moves cdef int nS = sum([len(history) for history in histories]) cdef Pool mem = Pool() + cdef np.ndarray costs_i is_valid = mem.alloc(nO, sizeof(int)) - c_costs = mem.alloc(nO, sizeof(float)) states = moves.init_batch([eg.x for eg in examples]) batch = [] for eg, s, h in zip(examples, states, histories): @@ -347,13 +348,12 @@ class Parser(TrainablePipe): while batch: costs = numpy.zeros((len(batch), nO), dtype="f") for i, (eg, state, history, gold) in enumerate(batch): + costs_i = costs[i] clas = history.pop(0) - moves.set_costs(is_valid, c_costs, state.c, gold) + moves.set_costs(is_valid, costs_i.data, state.c, gold) action = moves.c[clas] action.do(state.c, action.label) state.c.history.push_back(clas) - for j in range(nO): - costs[i, j] = c_costs[j] output.append(costs) batch = [(eg, s, h, g) for eg, s, h, g in batch if len(h) != 0] return self.model.ops.xp.vstack(output)