mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
precompute_hiddens/Parser: look up CPU ops once (v4) (#11068)
* precompute_hiddens/Parser: look up CPU ops once * precompute_hiddens: make cpu_ops private
This commit is contained in:
parent
b2d05f9f66
commit
e581eeac34
|
@ -347,6 +347,7 @@ cdef class precompute_hiddens:
|
|||
cdef bint _is_synchronized
|
||||
cdef public object ops
|
||||
cdef public object numpy_ops
|
||||
cdef public object _cpu_ops
|
||||
cdef np.ndarray _features
|
||||
cdef np.ndarray _cached
|
||||
cdef np.ndarray bias
|
||||
|
@ -377,6 +378,7 @@ cdef class precompute_hiddens:
|
|||
self.nO = cached.shape[2]
|
||||
self.ops = lower_model.ops
|
||||
self.numpy_ops = NumpyOps()
|
||||
self._cpu_ops = get_ops("cpu") if isinstance(self.ops, CupyOps) else self.ops
|
||||
assert activation in (None, "relu", "maxout")
|
||||
self.activation = activation
|
||||
self._is_synchronized = False
|
||||
|
@ -439,11 +441,7 @@ cdef class precompute_hiddens:
|
|||
# - Output from backward on GPU
|
||||
bp_hiddens = self._bp_hiddens
|
||||
|
||||
cdef CBlas cblas
|
||||
if isinstance(self.ops, CupyOps):
|
||||
cblas = NUMPY_OPS.cblas()
|
||||
else:
|
||||
cblas = self.ops.cblas()
|
||||
cdef CBlas cblas = self._cpu_ops.cblas()
|
||||
|
||||
feat_weights = self.get_feat_weights()
|
||||
cdef int[:, ::1] ids = token_ids
|
||||
|
|
|
@ -12,6 +12,7 @@ cdef class Parser(TrainablePipe):
|
|||
cdef public object _rehearsal_model
|
||||
cdef readonly TransitionSystem moves
|
||||
cdef public object _multitasks
|
||||
cdef object _cpu_ops
|
||||
|
||||
cdef void _parseC(self, CBlas cblas, StateC** states,
|
||||
WeightsC weights, SizesC sizes) nogil
|
||||
|
|
|
@ -123,6 +123,7 @@ cdef class Parser(TrainablePipe):
|
|||
|
||||
self._rehearsal_model = None
|
||||
self.scorer = scorer
|
||||
self._cpu_ops = get_ops("cpu") if isinstance(self.model.ops, CupyOps) else self.model.ops
|
||||
|
||||
def __getnewargs_ex__(self):
|
||||
"""This allows pickling the Parser and its keyword-only init arguments"""
|
||||
|
@ -262,12 +263,7 @@ cdef class Parser(TrainablePipe):
|
|||
def greedy_parse(self, docs, drop=0.):
|
||||
cdef vector[StateC*] states
|
||||
cdef StateClass state
|
||||
ops = self.model.ops
|
||||
cdef CBlas cblas
|
||||
if isinstance(ops, CupyOps):
|
||||
cblas = NUMPY_OPS.cblas()
|
||||
else:
|
||||
cblas = ops.cblas()
|
||||
cdef CBlas cblas = self._cpu_ops.cblas()
|
||||
self._ensure_labels_are_added(docs)
|
||||
set_dropout_rate(self.model, drop)
|
||||
batch = self.moves.init_batch(docs)
|
||||
|
|
Loading…
Reference in New Issue
Block a user