Try to make sum_state_features faster

This commit is contained in:
Matthew Honnibal 2018-03-27 10:08:38 +00:00
parent 987e1533a4
commit 25280b7013

View File

@ -165,16 +165,17 @@ cdef void sum_state_features(float* output,
cdef const float* feature
padding = cached
cached += F * O
cdef int id_stride = F*O
cdef float one = 1.
for b in range(B):
for f in range(F):
if token_ids[f] < 0:
feature = &padding[f*O]
else:
idx = token_ids[f] * F * O + f*O
idx = token_ids[f] * id_stride + f*O
feature = &cached[idx]
VecVec.add_i(output,
feature, 1., O)
output += O
openblas.simple_axpy(&output[b*O], O,
feature, one)
token_ids += F