mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-28 10:14:07 +03:00
Fix non-maxout parser
This commit is contained in:
parent
1036798155
commit
e7556ff048
|
@@ -117,7 +117,7 @@ cdef class precompute_hiddens:
|
||||||
cached = gpu_cached
|
cached = gpu_cached
|
||||||
self.nF = cached.shape[1]
|
self.nF = cached.shape[1]
|
||||||
self.nP = getattr(lower_model, 'nP', 1)
|
self.nP = getattr(lower_model, 'nP', 1)
|
||||||
self.nO = cached.shape[2] // self.nP
|
self.nO = cached.shape[2]
|
||||||
self.ops = lower_model.ops
|
self.ops = lower_model.ops
|
||||||
self.bias = lower_model.b
|
self.bias = lower_model.b
|
||||||
self._is_synchronized = False
|
self._is_synchronized = False
|
||||||
|
@@ -150,7 +150,7 @@ cdef class precompute_hiddens:
|
||||||
sum_state_features(<float*>state_vector.data,
|
sum_state_features(<float*>state_vector.data,
|
||||||
feat_weights, &ids[0,0],
|
feat_weights, &ids[0,0],
|
||||||
token_ids.shape[0], self.nF, self.nO*self.nP)
|
token_ids.shape[0], self.nF, self.nO*self.nP)
|
||||||
state_vector += self.bias.ravel()
|
state_vector += self.bias
|
||||||
state_vector, bp_nonlinearity = self._nonlinearity(state_vector)
|
state_vector, bp_nonlinearity = self._nonlinearity(state_vector)
|
||||||
|
|
||||||
def backward(d_state_vector, sgd=None):
|
def backward(d_state_vector, sgd=None):
|
||||||
|
@@ -164,6 +164,7 @@ cdef class precompute_hiddens:
|
||||||
|
|
||||||
def _nonlinearity(self, state_vector):
|
def _nonlinearity(self, state_vector):
|
||||||
if self.nP == 1:
|
if self.nP == 1:
|
||||||
|
state_vector = state_vector.reshape(state_vector.shape[:-1])
|
||||||
mask = state_vector >= 0.
|
mask = state_vector >= 0.
|
||||||
state_vector *= mask
|
state_vector *= mask
|
||||||
else:
|
else:
|
||||||
|
@@ -171,7 +172,9 @@ cdef class precompute_hiddens:
|
||||||
|
|
||||||
def backprop_nonlinearity(d_best, sgd=None):
|
def backprop_nonlinearity(d_best, sgd=None):
|
||||||
if self.nP == 1:
|
if self.nP == 1:
|
||||||
return d_best * mask
|
d_best *= mask
|
||||||
|
d_best = d_best.reshape((d_best.shape + (1,)))
|
||||||
|
return d_best
|
||||||
else:
|
else:
|
||||||
return self.ops.backprop_maxout(d_best, mask, self.nP)
|
return self.ops.backprop_maxout(d_best, mask, self.nP)
|
||||||
return state_vector, backprop_nonlinearity
|
return state_vector, backprop_nonlinearity
|
||||||
|
|
Loading…
Reference in New Issue
Block a user