mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
CPU/GPU compat
This commit is contained in:
parent
f99f5b75dc
commit
700979fb3c
|
@ -256,7 +256,7 @@ cdef class Parser:
|
|||
|
||||
self._cost_batch(costs, is_valid, states, golds)
|
||||
self._set_gradient(d_scores, scores, is_valid, costs)
|
||||
losses.append(numpy.abs(d_scores).sum())
|
||||
losses.append(self.model.ops.xp.abs(d_scores).sum())
|
||||
if force_gold:
|
||||
softmaxed *= costs <= 0
|
||||
return finish_update(d_scores, sgd=sgd)
|
||||
|
@ -312,9 +312,9 @@ cdef class Parser:
|
|||
n = gradients.shape[0]
|
||||
scores = scores * is_valid
|
||||
g_scores = scores * is_valid * (costs <= 0.)
|
||||
exps = numpy.exp(scores - scores.max(axis=1).reshape((n, 1)))
|
||||
exps = self.model.ops.xp.exp(scores - scores.max(axis=1).reshape((n, 1)))
|
||||
exps *= is_valid
|
||||
g_exps = numpy.exp(g_scores - g_scores.max(axis=1).reshape((n, 1)))
|
||||
g_exps = self.model.ops.xp.exp(g_scores - g_scores.max(axis=1).reshape((n, 1)))
|
||||
g_exps *= costs <= 0.
|
||||
g_exps *= is_valid
|
||||
gradients[:] = exps / exps.sum(axis=1).reshape((n, 1))
|
||||
|
|
Loading…
Reference in New Issue
Block a user