Parser: use C saxpy/sgemm provided by the Ops implementation (#10773)

* Parser: use C saxpy/sgemm provided by the Ops implementation

This is a backport of https://github.com/explosion/spaCy/pull/10747
from the parser refactor branch. It eliminates the explicit calls
to BLIS, instead using the saxpy/sgemm provided by the Ops
implementation.

This allows us to use Accelerate in the parser on M1 Macs (with
an updated thinc-apple-ops).

Performance of the de_core_news_lg pipe (WPS = words per second):

BLIS 0.7.0, no thinc-apple-ops:  6385 WPS
BLIS 0.7.0, thinc-apple-ops:    36455 WPS
BLIS 0.9.0, no thinc-apple-ops: 19188 WPS
BLIS 0.9.0, thinc-apple-ops:    36682 WPS
This PR, thinc-apple-ops:       38726 WPS

Performance of the de_core_news_lg pipe (only tok2vec -> parser):

BLIS 0.7.0, no thinc-apple-ops: 13907 WPS
BLIS 0.7.0, thinc-apple-ops:    73172 WPS
BLIS 0.9.0, no thinc-apple-ops: 41576 WPS
BLIS 0.9.0, thinc-apple-ops:    72569 WPS
This PR, thinc-apple-ops:       87061 WPS

* Require thinc >=8.1.0,<8.2.0

* Lower thinc lower bound to 8.1.0.dev0

* Use best CPU ops for CBLAS when the parser model is on the GPU

* Fix another unguarded cblas() call

* Fix: use ops as a shorthand for self.model.ops

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
This commit is contained in:
Daniël de Kok 2022-05-27 11:20:52 +02:00 committed by GitHub
parent 6172af8158
commit 85dd2b6c04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 30 additions and 26 deletions

View File

@ -6,7 +6,6 @@ requires = [
"preshed>=3.0.2,<3.1.0", "preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0", "murmurhash>=0.28.0,<1.1.0",
"thinc>=8.1.0.dev0,<8.2.0", "thinc>=8.1.0.dev0,<8.2.0",
"blis>=0.9.0,<0.10.0",
"pathy", "pathy",
"numpy>=1.15.0", "numpy>=1.15.0",
] ]

View File

@ -4,7 +4,6 @@ spacy-loggers>=1.0.0,<2.0.0
cymem>=2.0.2,<2.1.0 cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0 preshed>=3.0.2,<3.1.0
thinc>=8.1.0.dev0,<8.2.0 thinc>=8.1.0.dev0,<8.2.0
blis>=0.9.0,<0.10.0
ml_datasets>=0.2.0,<0.3.0 ml_datasets>=0.2.0,<0.3.0
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0
wasabi>=0.9.1,<1.1.0 wasabi>=0.9.1,<1.1.0

View File

@ -47,7 +47,6 @@ install_requires =
cymem>=2.0.2,<2.1.0 cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0 preshed>=3.0.2,<3.1.0
thinc>=8.1.0.dev0,<8.2.0 thinc>=8.1.0.dev0,<8.2.0
blis>=0.9.0,<0.10.0
wasabi>=0.9.1,<1.1.0 wasabi>=0.9.1,<1.1.0
srsly>=2.4.3,<3.0.0 srsly>=2.4.3,<3.0.0
catalogue>=2.0.6,<2.1.0 catalogue>=2.0.6,<2.1.0

View File

@ -1,4 +1,5 @@
from libc.string cimport memset, memcpy from libc.string cimport memset, memcpy
from thinc.backends.cblas cimport CBlas
from ..typedefs cimport weight_t, hash_t from ..typedefs cimport weight_t, hash_t
from ..pipeline._parser_internals._state cimport StateC from ..pipeline._parser_internals._state cimport StateC
@ -38,7 +39,7 @@ cdef ActivationsC alloc_activations(SizesC n) nogil
cdef void free_activations(const ActivationsC* A) nogil cdef void free_activations(const ActivationsC* A) nogil
cdef void predict_states(ActivationsC* A, StateC** states, cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
const WeightsC* W, SizesC n) nogil const WeightsC* W, SizesC n) nogil
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil

View File

@ -4,11 +4,10 @@ from libc.math cimport exp
from libc.string cimport memset, memcpy from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free, realloc from libc.stdlib cimport calloc, free, realloc
from thinc.backends.linalg cimport Vec, VecVec from thinc.backends.linalg cimport Vec, VecVec
cimport blis.cy
import numpy import numpy
import numpy.random import numpy.random
from thinc.api import Model, CupyOps, NumpyOps from thinc.api import Model, CupyOps, NumpyOps, get_ops
from .. import util from .. import util
from ..errors import Errors from ..errors import Errors
@ -91,7 +90,7 @@ cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
A._curr_size = n.states A._curr_size = n.states
cdef void predict_states(ActivationsC* A, StateC** states, cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
const WeightsC* W, SizesC n) nogil: const WeightsC* W, SizesC n) nogil:
cdef double one = 1.0 cdef double one = 1.0
resize_activations(A, n) resize_activations(A, n)
@ -99,7 +98,7 @@ cdef void predict_states(ActivationsC* A, StateC** states,
states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats) states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float)) memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
memset(A.hiddens, 0, n.states * n.hiddens * sizeof(float)) memset(A.hiddens, 0, n.states * n.hiddens * sizeof(float))
sum_state_features(A.unmaxed, sum_state_features(cblas, A.unmaxed,
W.feat_weights, A.token_ids, n.states, n.feats, n.hiddens * n.pieces) W.feat_weights, A.token_ids, n.states, n.feats, n.hiddens * n.pieces)
for i in range(n.states): for i in range(n.states):
VecVec.add_i(&A.unmaxed[i*n.hiddens*n.pieces], VecVec.add_i(&A.unmaxed[i*n.hiddens*n.pieces],
@ -113,12 +112,10 @@ cdef void predict_states(ActivationsC* A, StateC** states,
memcpy(A.scores, A.hiddens, n.states * n.classes * sizeof(float)) memcpy(A.scores, A.hiddens, n.states * n.classes * sizeof(float))
else: else:
# Compute hidden-to-output # Compute hidden-to-output
blis.cy.gemm(blis.cy.NO_TRANSPOSE, blis.cy.TRANSPOSE, cblas.sgemm()(False, True, n.states, n.classes, n.hiddens,
n.states, n.classes, n.hiddens, one, 1.0, <const float *>A.hiddens, n.hiddens,
<float*>A.hiddens, n.hiddens, 1, <const float *>W.hidden_weights, n.hiddens,
<float*>W.hidden_weights, n.hiddens, 1, 0.0, A.scores, n.classes)
one,
<float*>A.scores, n.classes, 1)
# Add bias # Add bias
for i in range(n.states): for i in range(n.states):
VecVec.add_i(&A.scores[i*n.classes], VecVec.add_i(&A.scores[i*n.classes],
@ -135,7 +132,7 @@ cdef void predict_states(ActivationsC* A, StateC** states,
A.scores[i*n.classes+j] = min_ A.scores[i*n.classes+j] = min_
cdef void sum_state_features(float* output, cdef void sum_state_features(CBlas cblas, float* output,
const float* cached, const int* token_ids, int B, int F, int O) nogil: const float* cached, const int* token_ids, int B, int F, int O) nogil:
cdef int idx, b, f, i cdef int idx, b, f, i
cdef const float* feature cdef const float* feature
@ -150,9 +147,7 @@ cdef void sum_state_features(float* output,
else: else:
idx = token_ids[f] * id_stride + f*O idx = token_ids[f] * id_stride + f*O
feature = &cached[idx] feature = &cached[idx]
blis.cy.axpyv(blis.cy.NO_CONJUGATE, O, one, cblas.saxpy()(O, one, <const float*>feature, 1, &output[b*O], 1)
<float*>feature, 1,
&output[b*O], 1)
token_ids += F token_ids += F
@ -443,9 +438,15 @@ cdef class precompute_hiddens:
# - Output from backward on GPU # - Output from backward on GPU
bp_hiddens = self._bp_hiddens bp_hiddens = self._bp_hiddens
cdef CBlas cblas
if isinstance(self.ops, CupyOps):
cblas = get_ops("cpu").cblas()
else:
cblas = self.ops.cblas()
feat_weights = self.get_feat_weights() feat_weights = self.get_feat_weights()
cdef int[:, ::1] ids = token_ids cdef int[:, ::1] ids = token_ids
sum_state_features(<float*>state_vector.data, sum_state_features(cblas, <float*>state_vector.data,
feat_weights, &ids[0,0], feat_weights, &ids[0,0],
token_ids.shape[0], self.nF, self.nO*self.nP) token_ids.shape[0], self.nF, self.nO*self.nP)
state_vector += self.bias state_vector += self.bias

View File

@ -1,4 +1,5 @@
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from thinc.backends.cblas cimport CBlas
from ..vocab cimport Vocab from ..vocab cimport Vocab
from .trainable_pipe cimport TrainablePipe from .trainable_pipe cimport TrainablePipe
@ -12,7 +13,7 @@ cdef class Parser(TrainablePipe):
cdef readonly TransitionSystem moves cdef readonly TransitionSystem moves
cdef public object _multitasks cdef public object _multitasks
cdef void _parseC(self, StateC** states, cdef void _parseC(self, CBlas cblas, StateC** states,
WeightsC weights, SizesC sizes) nogil WeightsC weights, SizesC sizes) nogil
cdef void c_transition_batch(self, StateC** states, const float* scores, cdef void c_transition_batch(self, StateC** states, const float* scores,

View File

@ -9,7 +9,7 @@ from libc.stdlib cimport calloc, free
import random import random
import srsly import srsly
from thinc.api import set_dropout_rate, CupyOps from thinc.api import get_ops, set_dropout_rate, CupyOps
from thinc.extra.search cimport Beam from thinc.extra.search cimport Beam
import numpy.random import numpy.random
import numpy import numpy
@ -259,6 +259,12 @@ cdef class Parser(TrainablePipe):
def greedy_parse(self, docs, drop=0.): def greedy_parse(self, docs, drop=0.):
cdef vector[StateC*] states cdef vector[StateC*] states
cdef StateClass state cdef StateClass state
ops = self.model.ops
cdef CBlas cblas
if isinstance(ops, CupyOps):
cblas = get_ops("cpu").cblas()
else:
cblas = ops.cblas()
self._ensure_labels_are_added(docs) self._ensure_labels_are_added(docs)
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
batch = self.moves.init_batch(docs) batch = self.moves.init_batch(docs)
@ -269,8 +275,7 @@ cdef class Parser(TrainablePipe):
states.push_back(state.c) states.push_back(state.c)
sizes = get_c_sizes(model, states.size()) sizes = get_c_sizes(model, states.size())
with nogil: with nogil:
self._parseC(&states[0], self._parseC(cblas, &states[0], weights, sizes)
weights, sizes)
model.clear_memory() model.clear_memory()
del model del model
return batch return batch
@ -297,14 +302,13 @@ cdef class Parser(TrainablePipe):
del model del model
return list(batch) return list(batch)
cdef void _parseC(self, StateC** states, cdef void _parseC(self, CBlas cblas, StateC** states,
WeightsC weights, SizesC sizes) nogil: WeightsC weights, SizesC sizes) nogil:
cdef int i, j cdef int i, j
cdef vector[StateC*] unfinished cdef vector[StateC*] unfinished
cdef ActivationsC activations = alloc_activations(sizes) cdef ActivationsC activations = alloc_activations(sizes)
while sizes.states >= 1: while sizes.states >= 1:
predict_states(&activations, predict_states(cblas, &activations, states, &weights, sizes)
states, &weights, sizes)
# Validate actions, argmax, take action. # Validate actions, argmax, take action.
self.c_transition_batch(states, self.c_transition_batch(states,
activations.scores, sizes.classes, sizes.states) activations.scores, sizes.classes, sizes.states)