mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-05 04:40:20 +03:00
Write moves costs directly into numpy array (#10163)
This avoids element-wise indexing and the allocation of an additional array. It gives a ~15% speed improvement when using batch_by_sequence with a size of 32.
This commit is contained in:
parent
b68d7a1ebf
commit
7b02f0fe5d
|
@ -23,6 +23,7 @@ from ._parser_internals cimport _beam_utils
|
||||||
from ._parser_internals import _beam_utils
|
from ._parser_internals import _beam_utils
|
||||||
from ..vocab cimport Vocab
|
from ..vocab cimport Vocab
|
||||||
from ._parser_internals.transition_system cimport TransitionSystem
|
from ._parser_internals.transition_system cimport TransitionSystem
|
||||||
|
from ..typedefs cimport weight_t
|
||||||
|
|
||||||
from ..training import validate_examples, validate_get_examples
|
from ..training import validate_examples, validate_get_examples
|
||||||
from ..errors import Errors, Warnings
|
from ..errors import Errors, Warnings
|
||||||
|
@ -335,8 +336,8 @@ class Parser(TrainablePipe):
|
||||||
cdef int nO = moves.n_moves
|
cdef int nO = moves.n_moves
|
||||||
cdef int nS = sum([len(history) for history in histories])
|
cdef int nS = sum([len(history) for history in histories])
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
|
cdef np.ndarray costs_i
|
||||||
is_valid = <int*>mem.alloc(nO, sizeof(int))
|
is_valid = <int*>mem.alloc(nO, sizeof(int))
|
||||||
c_costs = <float*>mem.alloc(nO, sizeof(float))
|
|
||||||
states = moves.init_batch([eg.x for eg in examples])
|
states = moves.init_batch([eg.x for eg in examples])
|
||||||
batch = []
|
batch = []
|
||||||
for eg, s, h in zip(examples, states, histories):
|
for eg, s, h in zip(examples, states, histories):
|
||||||
|
@ -347,13 +348,12 @@ class Parser(TrainablePipe):
|
||||||
while batch:
|
while batch:
|
||||||
costs = numpy.zeros((len(batch), nO), dtype="f")
|
costs = numpy.zeros((len(batch), nO), dtype="f")
|
||||||
for i, (eg, state, history, gold) in enumerate(batch):
|
for i, (eg, state, history, gold) in enumerate(batch):
|
||||||
|
costs_i = costs[i]
|
||||||
clas = history.pop(0)
|
clas = history.pop(0)
|
||||||
moves.set_costs(is_valid, c_costs, state.c, gold)
|
moves.set_costs(is_valid, <weight_t*>costs_i.data, state.c, gold)
|
||||||
action = moves.c[clas]
|
action = moves.c[clas]
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
state.c.history.push_back(clas)
|
state.c.history.push_back(clas)
|
||||||
for j in range(nO):
|
|
||||||
costs[i, j] = c_costs[j]
|
|
||||||
output.append(costs)
|
output.append(costs)
|
||||||
batch = [(eg, s, h, g) for eg, s, h, g in batch if len(h) != 0]
|
batch = [(eg, s, h, g) for eg, s, h, g in batch if len(h) != 0]
|
||||||
return self.model.ops.xp.vstack(output)
|
return self.model.ops.xp.vstack(output)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user