diff --git a/spacy/ml/parser_model.pyx b/spacy/ml/parser_model.pyx index 961bf4d70..055fa0bad 100644 --- a/spacy/ml/parser_model.pyx +++ b/spacy/ml/parser_model.pyx @@ -347,6 +347,7 @@ cdef class precompute_hiddens: cdef bint _is_synchronized cdef public object ops cdef public object numpy_ops + cdef public object _cpu_ops cdef np.ndarray _features cdef np.ndarray _cached cdef np.ndarray bias @@ -377,6 +378,7 @@ cdef class precompute_hiddens: self.nO = cached.shape[2] self.ops = lower_model.ops self.numpy_ops = NumpyOps() + self._cpu_ops = get_ops("cpu") if isinstance(self.ops, CupyOps) else self.ops assert activation in (None, "relu", "maxout") self.activation = activation self._is_synchronized = False @@ -439,11 +441,7 @@ cdef class precompute_hiddens: # - Output from backward on GPU bp_hiddens = self._bp_hiddens - cdef CBlas cblas - if isinstance(self.ops, CupyOps): - cblas = NUMPY_OPS.cblas() - else: - cblas = self.ops.cblas() + cdef CBlas cblas = self._cpu_ops.cblas() feat_weights = self.get_feat_weights() cdef int[:, ::1] ids = token_ids diff --git a/spacy/pipeline/transition_parser.pxd b/spacy/pipeline/transition_parser.pxd index 1521fde60..f20e69a6e 100644 --- a/spacy/pipeline/transition_parser.pxd +++ b/spacy/pipeline/transition_parser.pxd @@ -12,6 +12,7 @@ cdef class Parser(TrainablePipe): cdef public object _rehearsal_model cdef readonly TransitionSystem moves cdef public object _multitasks + cdef object _cpu_ops cdef void _parseC(self, CBlas cblas, StateC** states, WeightsC weights, SizesC sizes) nogil diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 1327db2ce..340334b1a 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -123,6 +123,7 @@ cdef class Parser(TrainablePipe): self._rehearsal_model = None self.scorer = scorer + self._cpu_ops = get_ops("cpu") if isinstance(self.model.ops, CupyOps) else self.model.ops def __getnewargs_ex__(self): """This allows pickling the Parser and its keyword-only init arguments""" @@ -262,12 +263,7 @@ cdef class Parser(TrainablePipe): def greedy_parse(self, docs, drop=0.): cdef vector[StateC*] states cdef StateClass state - ops = self.model.ops - cdef CBlas cblas - if isinstance(ops, CupyOps): - cblas = NUMPY_OPS.cblas() - else: - cblas = ops.cblas() + cdef CBlas cblas = self._cpu_ops.cblas() self._ensure_labels_are_added(docs) set_dropout_rate(self.model, drop) batch = self.moves.init_batch(docs)