mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Parser: use C saxpy/sgemm provided by the Ops implementation (#10773)
* Parser: use C saxpy/sgemm provided by the Ops implementation This is a backport of https://github.com/explosion/spaCy/pull/10747 from the parser refactor branch. It eliminates the explicit calls to BLIS, instead using the saxpy/sgemm provided by the Ops implementation. This allows us to use Accelerate in the parser on M1 Macs (with an updated thinc-apple-ops). Performance of the de_core_news_lg pipe: BLIS 0.7.0, no thinc-apple-ops: 6385 WPS BLIS 0.7.0, thinc-apple-ops: 36455 WPS BLIS 0.9.0, no thinc-apple-ops: 19188 WPS BLIS 0.9.0, thinc-apple-ops: 36682 WPS This PR, thinc-apple-ops: 38726 WPS Performance of the de_core_news_lg pipe (only tok2vec -> parser): BLIS 0.7.0, no thinc-apple-ops: 13907 WPS BLIS 0.7.0, thinc-apple-ops: 73172 WPS BLIS 0.9.0, no thinc-apple-ops: 41576 WPS BLIS 0.9.0, thinc-apple-ops: 72569 WPS This PR, thinc-apple-ops: 87061 WPS * Require thinc >=8.1.0,<8.2.0 * Lower thinc lowerbound to 8.1.0.dev0 * Use best CPU ops for CBLAS when the parser model is on the GPU * Fix another unguarded cblas() call * Fix: use ops as a shorthand for self.model.ops Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
This commit is contained in:
parent
6172af8158
commit
85dd2b6c04
|
@ -6,7 +6,6 @@ requires = [
|
|||
"preshed>=3.0.2,<3.1.0",
|
||||
"murmurhash>=0.28.0,<1.1.0",
|
||||
"thinc>=8.1.0.dev0,<8.2.0",
|
||||
"blis>=0.9.0,<0.10.0",
|
||||
"pathy",
|
||||
"numpy>=1.15.0",
|
||||
]
|
||||
|
|
|
@ -4,7 +4,6 @@ spacy-loggers>=1.0.0,<2.0.0
|
|||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
thinc>=8.1.0.dev0,<8.2.0
|
||||
blis>=0.9.0,<0.10.0
|
||||
ml_datasets>=0.2.0,<0.3.0
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
wasabi>=0.9.1,<1.1.0
|
||||
|
|
|
@ -47,7 +47,6 @@ install_requires =
|
|||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
thinc>=8.1.0.dev0,<8.2.0
|
||||
blis>=0.9.0,<0.10.0
|
||||
wasabi>=0.9.1,<1.1.0
|
||||
srsly>=2.4.3,<3.0.0
|
||||
catalogue>=2.0.6,<2.1.0
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from libc.string cimport memset, memcpy
|
||||
from thinc.backends.cblas cimport CBlas
|
||||
from ..typedefs cimport weight_t, hash_t
|
||||
from ..pipeline._parser_internals._state cimport StateC
|
||||
|
||||
|
@ -38,7 +39,7 @@ cdef ActivationsC alloc_activations(SizesC n) nogil
|
|||
|
||||
cdef void free_activations(const ActivationsC* A) nogil
|
||||
|
||||
cdef void predict_states(ActivationsC* A, StateC** states,
|
||||
cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
|
||||
const WeightsC* W, SizesC n) nogil
|
||||
|
||||
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil
|
||||
|
|
|
@ -4,11 +4,10 @@ from libc.math cimport exp
|
|||
from libc.string cimport memset, memcpy
|
||||
from libc.stdlib cimport calloc, free, realloc
|
||||
from thinc.backends.linalg cimport Vec, VecVec
|
||||
cimport blis.cy
|
||||
|
||||
import numpy
|
||||
import numpy.random
|
||||
from thinc.api import Model, CupyOps, NumpyOps
|
||||
from thinc.api import Model, CupyOps, NumpyOps, get_ops
|
||||
|
||||
from .. import util
|
||||
from ..errors import Errors
|
||||
|
@ -91,7 +90,7 @@ cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
|
|||
A._curr_size = n.states
|
||||
|
||||
|
||||
cdef void predict_states(ActivationsC* A, StateC** states,
|
||||
cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
|
||||
const WeightsC* W, SizesC n) nogil:
|
||||
cdef double one = 1.0
|
||||
resize_activations(A, n)
|
||||
|
@ -99,7 +98,7 @@ cdef void predict_states(ActivationsC* A, StateC** states,
|
|||
states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
|
||||
memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
|
||||
memset(A.hiddens, 0, n.states * n.hiddens * sizeof(float))
|
||||
sum_state_features(A.unmaxed,
|
||||
sum_state_features(cblas, A.unmaxed,
|
||||
W.feat_weights, A.token_ids, n.states, n.feats, n.hiddens * n.pieces)
|
||||
for i in range(n.states):
|
||||
VecVec.add_i(&A.unmaxed[i*n.hiddens*n.pieces],
|
||||
|
@ -113,12 +112,10 @@ cdef void predict_states(ActivationsC* A, StateC** states,
|
|||
memcpy(A.scores, A.hiddens, n.states * n.classes * sizeof(float))
|
||||
else:
|
||||
# Compute hidden-to-output
|
||||
blis.cy.gemm(blis.cy.NO_TRANSPOSE, blis.cy.TRANSPOSE,
|
||||
n.states, n.classes, n.hiddens, one,
|
||||
<float*>A.hiddens, n.hiddens, 1,
|
||||
<float*>W.hidden_weights, n.hiddens, 1,
|
||||
one,
|
||||
<float*>A.scores, n.classes, 1)
|
||||
cblas.sgemm()(False, True, n.states, n.classes, n.hiddens,
|
||||
1.0, <const float *>A.hiddens, n.hiddens,
|
||||
<const float *>W.hidden_weights, n.hiddens,
|
||||
0.0, A.scores, n.classes)
|
||||
# Add bias
|
||||
for i in range(n.states):
|
||||
VecVec.add_i(&A.scores[i*n.classes],
|
||||
|
@ -135,7 +132,7 @@ cdef void predict_states(ActivationsC* A, StateC** states,
|
|||
A.scores[i*n.classes+j] = min_
|
||||
|
||||
|
||||
cdef void sum_state_features(float* output,
|
||||
cdef void sum_state_features(CBlas cblas, float* output,
|
||||
const float* cached, const int* token_ids, int B, int F, int O) nogil:
|
||||
cdef int idx, b, f, i
|
||||
cdef const float* feature
|
||||
|
@ -150,9 +147,7 @@ cdef void sum_state_features(float* output,
|
|||
else:
|
||||
idx = token_ids[f] * id_stride + f*O
|
||||
feature = &cached[idx]
|
||||
blis.cy.axpyv(blis.cy.NO_CONJUGATE, O, one,
|
||||
<float*>feature, 1,
|
||||
&output[b*O], 1)
|
||||
cblas.saxpy()(O, one, <const float*>feature, 1, &output[b*O], 1)
|
||||
token_ids += F
|
||||
|
||||
|
||||
|
@ -443,9 +438,15 @@ cdef class precompute_hiddens:
|
|||
# - Output from backward on GPU
|
||||
bp_hiddens = self._bp_hiddens
|
||||
|
||||
cdef CBlas cblas
|
||||
if isinstance(self.ops, CupyOps):
|
||||
cblas = get_ops("cpu").cblas()
|
||||
else:
|
||||
cblas = self.ops.cblas()
|
||||
|
||||
feat_weights = self.get_feat_weights()
|
||||
cdef int[:, ::1] ids = token_ids
|
||||
sum_state_features(<float*>state_vector.data,
|
||||
sum_state_features(cblas, <float*>state_vector.data,
|
||||
feat_weights, &ids[0,0],
|
||||
token_ids.shape[0], self.nF, self.nO*self.nP)
|
||||
state_vector += self.bias
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from cymem.cymem cimport Pool
|
||||
from thinc.backends.cblas cimport CBlas
|
||||
|
||||
from ..vocab cimport Vocab
|
||||
from .trainable_pipe cimport TrainablePipe
|
||||
|
@ -12,7 +13,7 @@ cdef class Parser(TrainablePipe):
|
|||
cdef readonly TransitionSystem moves
|
||||
cdef public object _multitasks
|
||||
|
||||
cdef void _parseC(self, StateC** states,
|
||||
cdef void _parseC(self, CBlas cblas, StateC** states,
|
||||
WeightsC weights, SizesC sizes) nogil
|
||||
|
||||
cdef void c_transition_batch(self, StateC** states, const float* scores,
|
||||
|
|
|
@ -9,7 +9,7 @@ from libc.stdlib cimport calloc, free
|
|||
import random
|
||||
|
||||
import srsly
|
||||
from thinc.api import set_dropout_rate, CupyOps
|
||||
from thinc.api import get_ops, set_dropout_rate, CupyOps
|
||||
from thinc.extra.search cimport Beam
|
||||
import numpy.random
|
||||
import numpy
|
||||
|
@ -259,6 +259,12 @@ cdef class Parser(TrainablePipe):
|
|||
def greedy_parse(self, docs, drop=0.):
|
||||
cdef vector[StateC*] states
|
||||
cdef StateClass state
|
||||
ops = self.model.ops
|
||||
cdef CBlas cblas
|
||||
if isinstance(ops, CupyOps):
|
||||
cblas = get_ops("cpu").cblas()
|
||||
else:
|
||||
cblas = ops.cblas()
|
||||
self._ensure_labels_are_added(docs)
|
||||
set_dropout_rate(self.model, drop)
|
||||
batch = self.moves.init_batch(docs)
|
||||
|
@ -269,8 +275,7 @@ cdef class Parser(TrainablePipe):
|
|||
states.push_back(state.c)
|
||||
sizes = get_c_sizes(model, states.size())
|
||||
with nogil:
|
||||
self._parseC(&states[0],
|
||||
weights, sizes)
|
||||
self._parseC(cblas, &states[0], weights, sizes)
|
||||
model.clear_memory()
|
||||
del model
|
||||
return batch
|
||||
|
@ -297,14 +302,13 @@ cdef class Parser(TrainablePipe):
|
|||
del model
|
||||
return list(batch)
|
||||
|
||||
cdef void _parseC(self, StateC** states,
|
||||
cdef void _parseC(self, CBlas cblas, StateC** states,
|
||||
WeightsC weights, SizesC sizes) nogil:
|
||||
cdef int i, j
|
||||
cdef vector[StateC*] unfinished
|
||||
cdef ActivationsC activations = alloc_activations(sizes)
|
||||
while sizes.states >= 1:
|
||||
predict_states(&activations,
|
||||
states, &weights, sizes)
|
||||
predict_states(cblas, &activations, states, &weights, sizes)
|
||||
# Validate actions, argmax, take action.
|
||||
self.c_transition_batch(states,
|
||||
activations.scores, sizes.classes, sizes.states)
|
||||
|
|
Loading…
Reference in New Issue
Block a user