Clarify parser model CPU/GPU code

The previous version worked with previous thinc, but only
because some thinc ops happened to have gpu/cpu compatible
implementations. It's better to call the right Ops instance.
This commit is contained in:
Matthw Honnibal 2019-10-20 17:15:17 +02:00
parent ee56c6a4e1
commit 3a67aa857e

View File

@ -19,7 +19,7 @@ from thinc.extra.search cimport Beam
from thinc.api import chain, clone
from thinc.v2v import Model, Maxout, Affine
from thinc.misc import LayerNorm
from thinc.neural.ops import CupyOps
from thinc.neural.ops import CupyOps, NumpyOps
from thinc.neural.util import get_array_module
from thinc.linalg cimport Vec, VecVec
cimport blis.cy
@ -425,28 +425,38 @@ cdef class precompute_hiddens:
def backward(d_state_vector_ids, sgd=None):
d_state_vector, token_ids = d_state_vector_ids
d_state_vector = bp_nonlinearity(d_state_vector, sgd)
# This will usually be on GPU
if not isinstance(d_state_vector, self.ops.xp.ndarray):
d_state_vector = self.ops.xp.array(d_state_vector)
d_tokens = bp_hiddens((d_state_vector, token_ids), sgd)
return d_tokens
return state_vector, backward
def _nonlinearity(self, state_vector):
if isinstance(state_vector, numpy.ndarray):
ops = NumpyOps()
else:
ops = CupyOps()
if self.nP == 1:
state_vector = state_vector.reshape(state_vector.shape[:-1])
mask = state_vector >= 0.
state_vector *= mask
else:
state_vector, mask = self.ops.maxout(state_vector)
state_vector, mask = ops.maxout(state_vector)
def backprop_nonlinearity(d_best, sgd=None):
if isinstance(d_best, numpy.ndarray):
ops = NumpyOps()
else:
ops = CupyOps()
mask_ = ops.asarray(mask)
# This will usually be on GPU
d_best = ops.asarray(d_best)
# Fix nans (which can occur from unseen classes.)
d_best[self.ops.xp.isnan(d_best)] = 0.
d_best[ops.xp.isnan(d_best)] = 0.
if self.nP == 1:
d_best *= mask
d_best *= mask_
d_best = d_best.reshape((d_best.shape + (1,)))
return d_best
else:
return self.ops.backprop_maxout(d_best, mask, self.nP)
return ops.backprop_maxout(d_best, mask_, self.nP)
return state_vector, backprop_nonlinearity