mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Try to make sum_state_features faster
This commit is contained in:
parent
987e1533a4
commit
25280b7013
|
@ -165,16 +165,17 @@ cdef void sum_state_features(float* output,
|
||||||
cdef const float* feature
|
cdef const float* feature
|
||||||
padding = cached
|
padding = cached
|
||||||
cached += F * O
|
cached += F * O
|
||||||
|
cdef int id_stride = F*O
|
||||||
|
cdef float one = 1.
|
||||||
for b in range(B):
|
for b in range(B):
|
||||||
for f in range(F):
|
for f in range(F):
|
||||||
if token_ids[f] < 0:
|
if token_ids[f] < 0:
|
||||||
feature = &padding[f*O]
|
feature = &padding[f*O]
|
||||||
else:
|
else:
|
||||||
idx = token_ids[f] * F * O + f*O
|
idx = token_ids[f] * id_stride + f*O
|
||||||
feature = &cached[idx]
|
feature = &cached[idx]
|
||||||
VecVec.add_i(output,
|
openblas.simple_axpy(&output[b*O], O,
|
||||||
feature, 1., O)
|
feature, one)
|
||||||
output += O
|
|
||||||
token_ids += F
|
token_ids += F
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user