Write moves costs directly into numpy array (#10163)

This avoids elementwise indexing and the allocation of an additional
array.

This gives a ~15% speed improvement when using batch_by_sequence with size
32.
This commit is contained in:
Daniël de Kok 2022-02-04 21:07:14 +01:00 committed by GitHub
parent b68d7a1ebf
commit 7b02f0fe5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -23,6 +23,7 @@ from ._parser_internals cimport _beam_utils
from ._parser_internals import _beam_utils from ._parser_internals import _beam_utils
from ..vocab cimport Vocab from ..vocab cimport Vocab
from ._parser_internals.transition_system cimport TransitionSystem from ._parser_internals.transition_system cimport TransitionSystem
from ..typedefs cimport weight_t
from ..training import validate_examples, validate_get_examples from ..training import validate_examples, validate_get_examples
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
@@ -335,8 +336,8 @@ class Parser(TrainablePipe):
cdef int nO = moves.n_moves cdef int nO = moves.n_moves
cdef int nS = sum([len(history) for history in histories]) cdef int nS = sum([len(history) for history in histories])
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef np.ndarray costs_i
is_valid = <int*>mem.alloc(nO, sizeof(int)) is_valid = <int*>mem.alloc(nO, sizeof(int))
c_costs = <float*>mem.alloc(nO, sizeof(float))
states = moves.init_batch([eg.x for eg in examples]) states = moves.init_batch([eg.x for eg in examples])
batch = [] batch = []
for eg, s, h in zip(examples, states, histories): for eg, s, h in zip(examples, states, histories):
@@ -347,13 +348,12 @@ class Parser(TrainablePipe):
while batch: while batch:
costs = numpy.zeros((len(batch), nO), dtype="f") costs = numpy.zeros((len(batch), nO), dtype="f")
for i, (eg, state, history, gold) in enumerate(batch): for i, (eg, state, history, gold) in enumerate(batch):
costs_i = costs[i]
clas = history.pop(0) clas = history.pop(0)
moves.set_costs(is_valid, c_costs, state.c, gold) moves.set_costs(is_valid, <weight_t*>costs_i.data, state.c, gold)
action = moves.c[clas] action = moves.c[clas]
action.do(state.c, action.label) action.do(state.c, action.label)
state.c.history.push_back(clas) state.c.history.push_back(clas)
for j in range(nO):
costs[i, j] = c_costs[j]
output.append(costs) output.append(costs)
batch = [(eg, s, h, g) for eg, s, h, g in batch if len(h) != 0] batch = [(eg, s, h, g) for eg, s, h, g in batch if len(h) != 0]
return self.model.ops.xp.vstack(output) return self.model.ops.xp.vstack(output)