precompute_hiddens/Parser: do not look up CPU ops (3.4) (#11069)

* precompute_hiddens/Parser: do not look up CPU ops

`get_ops("cpu")` is quite expensive. To avoid that cost, we want to cache the
result as in #11068. However, for 3.x we do not want to change the ABI,
so instead we avoid the expensive lookup by using a shared module-level
`NumpyOps` instance. This should have minimal impact, since `get_ops("cpu")`
was only used when the model ops were `CupyOps`. If the ops are `AppleOps`,
we still pass through the correct BLAS implementation, because non-Cupy ops
keep using their own `cblas()`.

* Rename _NUMPY_OPS -> NUMPY_OPS
This commit is contained in:
Daniël de Kok 2022-07-05 10:53:42 +02:00 committed by GitHub
parent 78a84f0d78
commit a06cbae70d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 3 deletions

View File

@ -441,7 +441,7 @@ cdef class precompute_hiddens:
cdef CBlas cblas cdef CBlas cblas
if isinstance(self.ops, CupyOps): if isinstance(self.ops, CupyOps):
cblas = get_ops("cpu").cblas() cblas = NUMPY_OPS.cblas()
else: else:
cblas = self.ops.cblas() cblas = self.ops.cblas()

View File

@ -9,7 +9,7 @@ from libc.stdlib cimport calloc, free
import random import random
import srsly import srsly
from thinc.api import get_ops, set_dropout_rate, CupyOps from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps
from thinc.extra.search cimport Beam from thinc.extra.search cimport Beam
import numpy.random import numpy.random
import numpy import numpy
@ -30,6 +30,9 @@ from ..errors import Errors, Warnings
from .. import util from .. import util
NUMPY_OPS = NumpyOps()
cdef class Parser(TrainablePipe): cdef class Parser(TrainablePipe):
""" """
Base class of the DependencyParser and EntityRecognizer. Base class of the DependencyParser and EntityRecognizer.
@ -262,7 +265,7 @@ cdef class Parser(TrainablePipe):
ops = self.model.ops ops = self.model.ops
cdef CBlas cblas cdef CBlas cblas
if isinstance(ops, CupyOps): if isinstance(ops, CupyOps):
cblas = get_ops("cpu").cblas() cblas = NUMPY_OPS.cblas()
else: else:
cblas = ops.cblas() cblas = ops.cblas()
self._ensure_labels_are_added(docs) self._ensure_labels_are_added(docs)