mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	* Move `thinc.extra.search` to `spacy.pipeline._parser_internals` Backport of: https://github.com/explosion/spaCy/pull/11317 Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> * Replace references to `thinc.backends.linalg` with `CBlas` Backport of: https://github.com/explosion/spaCy/pull/11292 Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> * Use cross entropy from `thinc.legacy` * Require thinc>=9.0.0.dev0,<9.1.0 Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
		
			
				
	
	
		
			501 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Cython
		
	
	
	
	
	
			
		
		
	
	
			501 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Cython
		
	
	
	
	
	
| # cython: infer_types=True, cdivision=True, boundscheck=False
 | |
| cimport numpy as np
 | |
| from libc.math cimport exp
 | |
| from libc.string cimport memset, memcpy
 | |
| from libc.stdlib cimport calloc, free, realloc
 | |
| from thinc.backends.cblas cimport saxpy, sgemm
 | |
| 
 | |
| import numpy
 | |
| import numpy.random
 | |
| from thinc.api import Model, CupyOps, NumpyOps, get_ops
 | |
| 
 | |
| from .. import util
 | |
| from ..errors import Errors
 | |
| from ..typedefs cimport weight_t, class_t, hash_t
 | |
| from ..pipeline._parser_internals.stateclass cimport StateClass
 | |
| 
 | |
| 
 | |
| cdef WeightsC get_c_weights(model) except *:
 | |
|     cdef WeightsC output
 | |
|     cdef precompute_hiddens state2vec = model.state2vec
 | |
|     output.feat_weights = state2vec.get_feat_weights()
 | |
|     output.feat_bias = <const float*>state2vec.bias.data
 | |
|     cdef np.ndarray vec2scores_W
 | |
|     cdef np.ndarray vec2scores_b
 | |
|     if model.vec2scores is None:
 | |
|         output.hidden_weights = NULL
 | |
|         output.hidden_bias = NULL
 | |
|     else:
 | |
|         vec2scores_W = model.vec2scores.get_param("W")
 | |
|         vec2scores_b = model.vec2scores.get_param("b")
 | |
|         output.hidden_weights = <const float*>vec2scores_W.data
 | |
|         output.hidden_bias = <const float*>vec2scores_b.data
 | |
|     cdef np.ndarray class_mask = model._class_mask
 | |
|     output.seen_classes = <const float*>class_mask.data
 | |
|     return output
 | |
| 
 | |
| 
 | |
| cdef SizesC get_c_sizes(model, int batch_size) except *:
 | |
|     cdef SizesC output
 | |
|     output.states = batch_size
 | |
|     if model.vec2scores is None:
 | |
|         output.classes = model.state2vec.get_dim("nO")
 | |
|     else:
 | |
|         output.classes = model.vec2scores.get_dim("nO")
 | |
|     output.hiddens = model.state2vec.get_dim("nO")
 | |
|     output.pieces = model.state2vec.get_dim("nP")
 | |
|     output.feats = model.state2vec.get_dim("nF")
 | |
|     output.embed_width = model.tokvecs.shape[1]
 | |
|     return output
 | |
| 
 | |
| 
 | |
| cdef ActivationsC alloc_activations(SizesC n) nogil:
 | |
|     cdef ActivationsC A
 | |
|     memset(&A, 0, sizeof(A))
 | |
|     resize_activations(&A, n)
 | |
|     return A
 | |
| 
 | |
| 
 | |
| cdef void free_activations(const ActivationsC* A) nogil:
 | |
|     free(A.token_ids)
 | |
|     free(A.scores)
 | |
|     free(A.unmaxed)
 | |
|     free(A.hiddens)
 | |
|     free(A.is_valid)
 | |
| 
 | |
| 
 | |
| cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
 | |
|     if n.states <= A._max_size:
 | |
|         A._curr_size = n.states
 | |
|         return
 | |
|     if A._max_size == 0:
 | |
|         A.token_ids = <int*>calloc(n.states * n.feats, sizeof(A.token_ids[0]))
 | |
|         A.scores = <float*>calloc(n.states * n.classes, sizeof(A.scores[0]))
 | |
|         A.unmaxed = <float*>calloc(n.states * n.hiddens * n.pieces, sizeof(A.unmaxed[0]))
 | |
|         A.hiddens = <float*>calloc(n.states * n.hiddens, sizeof(A.hiddens[0]))
 | |
|         A.is_valid = <int*>calloc(n.states * n.classes, sizeof(A.is_valid[0]))
 | |
|         A._max_size = n.states
 | |
|     else:
 | |
|         A.token_ids = <int*>realloc(A.token_ids,
 | |
|             n.states * n.feats * sizeof(A.token_ids[0]))
 | |
|         A.scores = <float*>realloc(A.scores,
 | |
|             n.states * n.classes * sizeof(A.scores[0]))
 | |
|         A.unmaxed = <float*>realloc(A.unmaxed,
 | |
|             n.states * n.hiddens * n.pieces * sizeof(A.unmaxed[0]))
 | |
|         A.hiddens = <float*>realloc(A.hiddens,
 | |
|             n.states * n.hiddens * sizeof(A.hiddens[0]))
 | |
|         A.is_valid = <int*>realloc(A.is_valid,
 | |
|             n.states * n.classes * sizeof(A.is_valid[0]))
 | |
|         A._max_size = n.states
 | |
|     A._curr_size = n.states
 | |
| 
 | |
| 
 | |
| cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
 | |
|         const WeightsC* W, SizesC n) nogil:
 | |
|     cdef double one = 1.0
 | |
|     resize_activations(A, n)
 | |
|     for i in range(n.states):
 | |
|         states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
 | |
|     memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
 | |
|     memset(A.hiddens, 0, n.states * n.hiddens * sizeof(float))
 | |
|     sum_state_features(cblas, A.unmaxed,
 | |
|         W.feat_weights, A.token_ids, n.states, n.feats, n.hiddens * n.pieces)
 | |
|     for i in range(n.states):
 | |
|         saxpy(cblas)(n.hiddens * n.pieces, 1., W.feat_bias, 1, &A.unmaxed[i*n.hiddens*n.pieces], 1)
 | |
|         for j in range(n.hiddens):
 | |
|             index = i * n.hiddens * n.pieces + j * n.pieces
 | |
|             which = _arg_max(&A.unmaxed[index], n.pieces)
 | |
|             A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which]
 | |
|     memset(A.scores, 0, n.states * n.classes * sizeof(float))
 | |
|     if W.hidden_weights == NULL:
 | |
|         memcpy(A.scores, A.hiddens, n.states * n.classes * sizeof(float))
 | |
|     else:
 | |
|         # Compute hidden-to-output
 | |
|         sgemm(cblas)(False, True, n.states, n.classes, n.hiddens,
 | |
|             1.0, <const float *>A.hiddens, n.hiddens,
 | |
|             <const float *>W.hidden_weights, n.hiddens,
 | |
|             0.0, A.scores, n.classes)
 | |
|         # Add bias
 | |
|         for i in range(n.states):
 | |
|             saxpy(cblas)(n.classes, 1., W.hidden_bias, 1, &A.scores[i*n.classes], 1)
 | |
|     # Set unseen classes to minimum value
 | |
|     i = 0
 | |
|     min_ = A.scores[0]
 | |
|     for i in range(1, n.states * n.classes):
 | |
|         if A.scores[i] < min_:
 | |
|             min_ = A.scores[i]
 | |
|     for i in range(n.states):
 | |
|         for j in range(n.classes):
 | |
|             if not W.seen_classes[j]:
 | |
|                 A.scores[i*n.classes+j] = min_
 | |
| 
 | |
| 
 | |
| cdef void sum_state_features(CBlas cblas, float* output,
 | |
|         const float* cached, const int* token_ids, int B, int F, int O) nogil:
 | |
|     cdef int idx, b, f, i
 | |
|     cdef const float* feature
 | |
|     padding = cached
 | |
|     cached += F * O
 | |
|     cdef int id_stride = F*O
 | |
|     cdef float one = 1.
 | |
|     for b in range(B):
 | |
|         for f in range(F):
 | |
|             if token_ids[f] < 0:
 | |
|                 feature = &padding[f*O]
 | |
|             else:
 | |
|                 idx = token_ids[f] * id_stride + f*O
 | |
|                 feature = &cached[idx]
 | |
|             saxpy(cblas)(O, one, <const float*>feature, 1, &output[b*O], 1)
 | |
|         token_ids += F
 | |
| 
 | |
| 
 | |
| cdef void cpu_log_loss(float* d_scores,
 | |
|         const float* costs, const int* is_valid, const float* scores,
 | |
|         int O) nogil:
 | |
|     """Do multi-label log loss"""
 | |
|     cdef double max_, gmax, Z, gZ
 | |
|     best = arg_max_if_gold(scores, costs, is_valid, O)
 | |
|     guess = _arg_max(scores, O)
 | |
| 
 | |
|     if best == -1 or guess == -1:
 | |
|         # These shouldn't happen, but if they do, we want to make sure we don't
 | |
|         # cause an OOB access.
 | |
|         return
 | |
|     Z = 1e-10
 | |
|     gZ = 1e-10
 | |
|     max_ = scores[guess]
 | |
|     gmax = scores[best]
 | |
|     for i in range(O):
 | |
|         Z += exp(scores[i] - max_)
 | |
|         if costs[i] <= costs[best]:
 | |
|             gZ += exp(scores[i] - gmax)
 | |
|     for i in range(O):
 | |
|         if costs[i] <= costs[best]:
 | |
|             d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ)
 | |
|         else:
 | |
|             d_scores[i] = exp(scores[i]-max_) / Z
 | |
| 
 | |
| 
 | |
| cdef int arg_max_if_gold(const weight_t* scores, const weight_t* costs,
 | |
|         const int* is_valid, int n) nogil:
 | |
|     # Find minimum cost
 | |
|     cdef float cost = 1
 | |
|     for i in range(n):
 | |
|         if is_valid[i] and costs[i] < cost:
 | |
|             cost = costs[i]
 | |
|     # Now find best-scoring with that cost
 | |
|     cdef int best = -1
 | |
|     for i in range(n):
 | |
|         if costs[i] <= cost and is_valid[i]:
 | |
|             if best == -1 or scores[i] > scores[best]:
 | |
|                 best = i
 | |
|     return best
 | |
| 
 | |
| 
 | |
| cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
 | |
|     cdef int best = -1
 | |
|     for i in range(n):
 | |
|         if is_valid[i] >= 1:
 | |
|             if best == -1 or scores[i] > scores[best]:
 | |
|                 best = i
 | |
|     return best
 | |
| 
 | |
| 
 | |
| 
 | |
| class ParserStepModel(Model):
 | |
|     def __init__(self, docs, layers, *, has_upper, unseen_classes=None, train=True,
 | |
|             dropout=0.1):
 | |
|         Model.__init__(self, name="parser_step_model", forward=step_forward)
 | |
|         self.attrs["has_upper"] = has_upper
 | |
|         self.attrs["dropout_rate"] = dropout
 | |
|         self.tokvecs, self.bp_tokvecs = layers[0](docs, is_train=train)
 | |
|         if layers[1].get_dim("nP") >= 2:
 | |
|             activation = "maxout"
 | |
|         elif has_upper:
 | |
|             activation = None
 | |
|         else:
 | |
|             activation = "relu"
 | |
|         self.state2vec = precompute_hiddens(len(docs), self.tokvecs, layers[1],
 | |
|                                             activation=activation, train=train)
 | |
|         if has_upper:
 | |
|             self.vec2scores = layers[-1]
 | |
|         else:
 | |
|             self.vec2scores = None
 | |
|         self.cuda_stream = util.get_cuda_stream(non_blocking=True)
 | |
|         self.backprops = []
 | |
|         self._class_mask = numpy.zeros((self.nO,), dtype='f')
 | |
|         self._class_mask.fill(1)
 | |
|         if unseen_classes is not None:
 | |
|             for class_ in unseen_classes:
 | |
|                 self._class_mask[class_] = 0.
 | |
| 
 | |
|     def clear_memory(self):
 | |
|         del self.tokvecs
 | |
|         del self.bp_tokvecs
 | |
|         del self.state2vec
 | |
|         del self.backprops
 | |
|         del self._class_mask
 | |
| 
 | |
|     @property
 | |
|     def nO(self):
 | |
|         if self.attrs["has_upper"]:
 | |
|             return self.vec2scores.get_dim("nO")
 | |
|         else:
 | |
|             return self.state2vec.get_dim("nO")
 | |
| 
 | |
|     def class_is_unseen(self, class_):
 | |
|         return self._class_mask[class_]
 | |
| 
 | |
|     def mark_class_unseen(self, class_):
 | |
|         self._class_mask[class_] = 0
 | |
| 
 | |
|     def mark_class_seen(self, class_):
 | |
|         self._class_mask[class_] = 1
 | |
| 
 | |
|     def get_token_ids(self, states):
 | |
|         cdef StateClass state
 | |
|         states = [state for state in states if not state.is_final()]
 | |
|         cdef np.ndarray ids = numpy.zeros((len(states), self.state2vec.nF),
 | |
|                                           dtype='i', order='C')
 | |
|         ids.fill(-1)
 | |
|         c_ids = <int*>ids.data
 | |
|         for state in states:
 | |
|             state.c.set_context_tokens(c_ids, ids.shape[1])
 | |
|             c_ids += ids.shape[1]
 | |
|         return ids
 | |
| 
 | |
|     def backprop_step(self, token_ids, d_vector, get_d_tokvecs):
 | |
|         if isinstance(self.state2vec.ops, CupyOps) \
 | |
|         and not isinstance(token_ids, self.state2vec.ops.xp.ndarray):
 | |
|             # Move token_ids and d_vector to GPU, asynchronously
 | |
|             self.backprops.append((
 | |
|                 util.get_async(self.cuda_stream, token_ids),
 | |
|                 util.get_async(self.cuda_stream, d_vector),
 | |
|                 get_d_tokvecs
 | |
|             ))
 | |
|         else:
 | |
|             self.backprops.append((token_ids, d_vector, get_d_tokvecs))
 | |
| 
 | |
| 
 | |
|     def finish_steps(self, golds):
 | |
|         # Add a padding vector to the d_tokvecs gradient, so that missing
 | |
|         # values don't affect the real gradient.
 | |
|         d_tokvecs = self.ops.alloc((self.tokvecs.shape[0]+1, self.tokvecs.shape[1]))
 | |
|         # Tells CUDA to block, so our async copies complete.
 | |
|         if self.cuda_stream is not None:
 | |
|             self.cuda_stream.synchronize()
 | |
|         for ids, d_vector, bp_vector in self.backprops:
 | |
|             d_state_features = bp_vector((d_vector, ids))
 | |
|             ids = ids.flatten()
 | |
|             d_state_features = d_state_features.reshape(
 | |
|                 (ids.size, d_state_features.shape[2]))
 | |
|             self.ops.scatter_add(d_tokvecs, ids,
 | |
|                 d_state_features)
 | |
|         # Padded -- see update()
 | |
|         self.bp_tokvecs(d_tokvecs[:-1])
 | |
|         return d_tokvecs
 | |
| 
 | |
| NUMPY_OPS = NumpyOps()
 | |
| 
 | |
| def step_forward(model: ParserStepModel, states, is_train):
 | |
|     token_ids = model.get_token_ids(states)
 | |
|     vector, get_d_tokvecs = model.state2vec(token_ids, is_train)
 | |
|     mask = None
 | |
|     if model.attrs["has_upper"]:
 | |
|         dropout_rate = model.attrs["dropout_rate"]
 | |
|         if is_train and dropout_rate > 0:
 | |
|             mask = NUMPY_OPS.get_dropout_mask(vector.shape, 0.1)
 | |
|             vector *= mask
 | |
|         scores, get_d_vector = model.vec2scores(vector, is_train)
 | |
|     else:
 | |
|         scores = NumpyOps().asarray(vector)
 | |
|         get_d_vector = lambda d_scores: d_scores
 | |
|     # If the class is unseen, make sure its score is minimum
 | |
|     scores[:, model._class_mask == 0] = numpy.nanmin(scores)
 | |
| 
 | |
|     def backprop_parser_step(d_scores):
 | |
|         # Zero vectors for unseen classes
 | |
|         d_scores *= model._class_mask
 | |
|         d_vector = get_d_vector(d_scores)
 | |
|         if mask is not None:
 | |
|             d_vector *= mask
 | |
|         model.backprop_step(token_ids, d_vector, get_d_tokvecs)
 | |
|         return None
 | |
|     return scores, backprop_parser_step
 | |
| 
 | |
| 
 | |
| cdef class precompute_hiddens:
 | |
|     """Allow a model to be "primed" by pre-computing input features in bulk.
 | |
| 
 | |
|     This is used for the parser, where we want to take a batch of documents,
 | |
|     and compute vectors for each (token, position) pair. These vectors can then
 | |
|     be reused, especially for beam-search.
 | |
| 
 | |
|     Let's say we're using 12 features for each state, e.g. word at start of
 | |
|     buffer, three words on stack, their children, etc. In the normal arc-eager
 | |
|     system, a document of length N is processed in 2*N states. This means we'll
 | |
|     create 2*N*12 feature vectors --- but if we pre-compute, we only need
 | |
|     N*12 vector computations. The saving for beam-search is much better:
 | |
|     if we have a beam of k, we'll normally make 2*N*12*K computations --
 | |
|     so we can save the factor k. This also gives a nice CPU/GPU division:
 | |
|     we can do all our hard maths up front, packed into large multiplications,
 | |
|     and do the hard-to-program parsing on the CPU.
 | |
|     """
 | |
|     cdef readonly int nF, nO, nP
 | |
|     cdef bint _is_synchronized
 | |
|     cdef public object ops
 | |
|     cdef public object numpy_ops
 | |
|     cdef public object _cpu_ops
 | |
|     cdef np.ndarray _features
 | |
|     cdef np.ndarray _cached
 | |
|     cdef np.ndarray bias
 | |
|     cdef object _cuda_stream
 | |
|     cdef object _bp_hiddens
 | |
|     cdef object activation
 | |
| 
 | |
|     def __init__(self, batch_size, tokvecs, lower_model, cuda_stream=None,
 | |
|                  activation="maxout", train=False):
 | |
|         gpu_cached, bp_features = lower_model(tokvecs, train)
 | |
|         cdef np.ndarray cached
 | |
|         if not isinstance(gpu_cached, numpy.ndarray):
 | |
|             # Note the passing of cuda_stream here: it lets
 | |
|             # cupy make the copy asynchronously.
 | |
|             # We then have to block before first use.
 | |
|             cached = gpu_cached.get(stream=cuda_stream)
 | |
|         else:
 | |
|             cached = gpu_cached
 | |
|         if not isinstance(lower_model.get_param("b"), numpy.ndarray):
 | |
|             self.bias = lower_model.get_param("b").get(stream=cuda_stream)
 | |
|         else:
 | |
|             self.bias = lower_model.get_param("b")
 | |
|         self.nF = cached.shape[1]
 | |
|         if lower_model.has_dim("nP"):
 | |
|             self.nP = lower_model.get_dim("nP")
 | |
|         else:
 | |
|             self.nP = 1
 | |
|         self.nO = cached.shape[2]
 | |
|         self.ops = lower_model.ops
 | |
|         self.numpy_ops = NumpyOps()
 | |
|         self._cpu_ops = get_ops("cpu") if isinstance(self.ops, CupyOps) else self.ops
 | |
|         assert activation in (None, "relu", "maxout")
 | |
|         self.activation = activation
 | |
|         self._is_synchronized = False
 | |
|         self._cuda_stream = cuda_stream
 | |
|         self._cached = cached
 | |
|         self._bp_hiddens = bp_features
 | |
| 
 | |
|     cdef const float* get_feat_weights(self) except NULL:
 | |
|         if not self._is_synchronized and self._cuda_stream is not None:
 | |
|             self._cuda_stream.synchronize()
 | |
|             self._is_synchronized = True
 | |
|         return <float*>self._cached.data
 | |
| 
 | |
|     def has_dim(self, name):
 | |
|         if name == "nF":
 | |
|             return self.nF if self.nF is not None else True
 | |
|         elif name == "nP":
 | |
|             return self.nP if self.nP is not None else True
 | |
|         elif name == "nO":
 | |
|             return self.nO if self.nO is not None else True
 | |
|         else:
 | |
|             return False
 | |
| 
 | |
|     def get_dim(self, name):
 | |
|         if name == "nF":
 | |
|             return self.nF
 | |
|         elif name == "nP":
 | |
|             return self.nP
 | |
|         elif name == "nO":
 | |
|             return self.nO
 | |
|         else:
 | |
|             raise ValueError(Errors.E1033.format(name=name))
 | |
| 
 | |
|     def set_dim(self, name, value):
 | |
|         if name == "nF":
 | |
|             self.nF = value
 | |
|         elif name == "nP":
 | |
|             self.nP = value
 | |
|         elif name == "nO":
 | |
|             self.nO = value
 | |
|         else:
 | |
|             raise ValueError(Errors.E1033.format(name=name))
 | |
| 
 | |
|     def __call__(self, X, bint is_train):
 | |
|         if is_train:
 | |
|             return self.begin_update(X)
 | |
|         else:
 | |
|             return self.predict(X), lambda X: X
 | |
| 
 | |
|     def predict(self, X):
 | |
|         return self.begin_update(X)[0]
 | |
| 
 | |
|     def begin_update(self, token_ids):
 | |
|         cdef np.ndarray state_vector = numpy.zeros(
 | |
|             (token_ids.shape[0], self.nO, self.nP), dtype='f')
 | |
|         # This is tricky, but (assuming GPU available);
 | |
|         # - Input to forward on CPU
 | |
|         # - Output from forward on CPU
 | |
|         # - Input to backward on GPU!
 | |
|         # - Output from backward on GPU
 | |
|         bp_hiddens = self._bp_hiddens
 | |
| 
 | |
|         cdef CBlas cblas = self._cpu_ops.cblas()
 | |
| 
 | |
|         feat_weights = self.get_feat_weights()
 | |
|         cdef int[:, ::1] ids = token_ids
 | |
|         sum_state_features(cblas, <float*>state_vector.data,
 | |
|             feat_weights, &ids[0,0],
 | |
|             token_ids.shape[0], self.nF, self.nO*self.nP)
 | |
|         state_vector += self.bias
 | |
|         state_vector, bp_nonlinearity = self._nonlinearity(state_vector)
 | |
| 
 | |
|         def backward(d_state_vector_ids):
 | |
|             d_state_vector, token_ids = d_state_vector_ids
 | |
|             d_state_vector = bp_nonlinearity(d_state_vector)
 | |
|             d_tokens = bp_hiddens((d_state_vector, token_ids))
 | |
|             return d_tokens
 | |
|         return state_vector, backward
 | |
| 
 | |
|     def _nonlinearity(self, state_vector):
 | |
|         if self.activation == "maxout":
 | |
|             return self._maxout_nonlinearity(state_vector)
 | |
|         else:
 | |
|             return self._relu_nonlinearity(state_vector)
 | |
| 
 | |
|     def _maxout_nonlinearity(self, state_vector):
 | |
|         state_vector, mask = self.numpy_ops.maxout(state_vector)
 | |
|         # We're outputting to CPU, but we need this variable on GPU for the
 | |
|         # backward pass.
 | |
|         mask = self.ops.asarray(mask)
 | |
| 
 | |
|         def backprop_maxout(d_best):
 | |
|             return self.ops.backprop_maxout(d_best, mask, self.nP)
 | |
|         
 | |
|         return state_vector, backprop_maxout
 | |
| 
 | |
|     def _relu_nonlinearity(self, state_vector):
 | |
|         state_vector = state_vector.reshape((state_vector.shape[0], -1))
 | |
|         mask = state_vector >= 0.
 | |
|         state_vector *= mask
 | |
|         # We're outputting to CPU, but we need this variable on GPU for the
 | |
|         # backward pass.
 | |
|         mask = self.ops.asarray(mask)
 | |
| 
 | |
|         def backprop_relu(d_best):
 | |
|             d_best *= mask
 | |
|             return d_best.reshape((d_best.shape + (1,)))
 | |
|  
 | |
|         return state_vector, backprop_relu
 | |
| 
 | |
| cdef inline int _arg_max(const float* scores, const int n_classes) nogil:
 | |
|     if n_classes == 2:
 | |
|         return 0 if scores[0] > scores[1] else 1
 | |
|     cdef int i
 | |
|     cdef int best = 0
 | |
|     cdef float mode = scores[0]
 | |
|     for i in range(1, n_classes):
 | |
|         if scores[i] > mode:
 | |
|             mode = scores[i]
 | |
|             best = i
 | |
|     return best
 |