Update for CBlas changes in Thinc 8.1.0.dev2 (#10970)

This commit is contained in:
Daniël de Kok 2022-06-16 11:42:34 +02:00 committed by GitHub
parent 0d352c46ed
commit 3d3fbeda9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 7 additions and 6 deletions

View File

@ -5,7 +5,7 @@ requires = [
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0",
"thinc>=8.1.0.dev0,<8.2.0",
"thinc>=8.1.0.dev2,<8.2.0",
"pathy",
"numpy>=1.15.0",
]

View File

@ -3,7 +3,7 @@ spacy-legacy>=3.0.9,<3.1.0
spacy-loggers>=1.0.0,<2.0.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=8.1.0.dev0,<8.2.0
thinc>=8.1.0.dev2,<8.2.0
ml_datasets>=0.2.0,<0.3.0
murmurhash>=0.28.0,<1.1.0
wasabi>=0.9.1,<1.1.0

View File

@ -38,7 +38,7 @@ setup_requires =
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
murmurhash>=0.28.0,<1.1.0
thinc>=8.1.0.dev0,<8.2.0
thinc>=8.1.0.dev2,<8.2.0
install_requires =
# Our libraries
spacy-legacy>=3.0.9,<3.1.0
@ -46,7 +46,7 @@ install_requires =
murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=8.1.0.dev0,<8.2.0
thinc>=8.1.0.dev2,<8.2.0
wasabi>=0.9.1,<1.1.0
srsly>=2.4.3,<3.0.0
catalogue>=2.0.6,<2.1.0

View File

@ -4,6 +4,7 @@ from libc.math cimport exp
from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free, realloc
from thinc.backends.linalg cimport Vec, VecVec
from thinc.backends.cblas cimport saxpy, sgemm
import numpy
import numpy.random
@ -112,7 +113,7 @@ cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
memcpy(A.scores, A.hiddens, n.states * n.classes * sizeof(float))
else:
# Compute hidden-to-output
cblas.sgemm()(False, True, n.states, n.classes, n.hiddens,
sgemm(cblas)(False, True, n.states, n.classes, n.hiddens,
1.0, <const float *>A.hiddens, n.hiddens,
<const float *>W.hidden_weights, n.hiddens,
0.0, A.scores, n.classes)
@ -147,7 +148,7 @@ cdef void sum_state_features(CBlas cblas, float* output,
else:
idx = token_ids[f] * id_stride + f*O
feature = &cached[idx]
cblas.saxpy()(O, one, <const float*>feature, 1, &output[b*O], 1)
saxpy(cblas)(O, one, <const float*>feature, 1, &output[b*O], 1)
token_ids += F