mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
* Add separate Param and AdadeltaParam classes. AdadeltaParam seems broken.
This commit is contained in:
parent
1dff04acb5
commit
df8179ca4f
|
@ -43,6 +43,39 @@ floatX = theano.config.floatX
|
|||
def th_share(w, name=''):
    """Wrap ``w`` in a Theano shared variable.

    Args:
        w: Initial value handed to ``theano.shared``.
        name: Label attached to the shared variable (empty by default).

    Returns:
        Whatever ``theano.shared`` returns for ``w``.
    """
    # borrow=True is forwarded unchanged. NOTE(review): per Theano's docs this
    # lets the shared variable alias ``w`` instead of copying it -- confirm
    # callers do not mutate ``w`` afterwards.
    shared_var = theano.shared(value=w, name=name, borrow=True)
    return shared_var
|
||||
|
||||
class Param(object):
    """A trainable value paired with a same-shaped buffer for its last step.

    Both attributes are whatever ``wrapper`` returns (Theano shared variables
    when the default ``th_share`` is used):

    * ``curr`` -- the current parameter value, labelled ``<name>_curr``.
    * ``step`` -- the previous step, zero-initialised, labelled ``<name>_step``.
    """

    def __init__(self, numpy_data, name='?', wrapper=th_share):
        """Create the value and its zeroed step buffer.

        Args:
            numpy_data: Initial value; its ``shape`` and ``dtype`` are reused
                for the step buffer.
            name: Prefix for the labels of the wrapped variables.
            wrapper: Callable ``wrapper(array, name=...)`` used to wrap each
                array.
        """
        self.curr = wrapper(numpy_data, name=name + '_curr')
        zero_step = numpy.zeros(numpy_data.shape, numpy_data.dtype)
        self.step = wrapper(zero_step, name=name + '_step')

    def updates(self, cost, timestep, eta, mu):
        """Build the update pairs for one momentum-style step.

        The new step is ``mu * step - grad(cost, curr)`` and the new value is
        ``curr + eta * new_step``.

        Args:
            cost: Symbolic expression differentiated with respect to ``curr``.
            timestep: Not used by this rule.
            eta: Multiplier applied to the new step before it is added.
            mu: Multiplier applied to the previous step.

        Returns:
            ``[(curr, new_curr), (step, new_step)]``.
        """
        carried_over = mu * self.step
        gradient = T.grad(cost, self.curr)
        new_step = carried_over - gradient
        new_curr = self.curr + (eta * new_step)
        return [(self.curr, new_curr), (self.step, new_step)]
|
||||
|
||||
|
||||
class AdadeltaParam(object):
    """A trainable value with the two running accumulators Adadelta needs.

    All attributes are whatever ``wrapper`` returns (Theano shared variables
    when the default ``th_share`` is used):

    * ``curr`` -- the current parameter value, labelled ``<name>_curr``.
    * ``accu`` -- running average of squared gradients, ``<name>_accu``.
    * ``delta_accu`` -- running average of squared updates,
      ``<name>_delta_accu``.
    """

    def __init__(self, numpy_data, name='?', wrapper=th_share):
        """Create the value and its zero-initialised accumulators.

        Args:
            numpy_data: Initial value; its ``shape`` and ``dtype`` are reused
                for both accumulators.
            name: Prefix for the labels of the wrapped variables.
            wrapper: Callable ``wrapper(array, name=...)`` used to wrap each
                array.
        """
        self.curr = wrapper(numpy_data, name=name + '_curr')
        # accu: accumulate gradient magnitudes
        # Labelled like ``curr`` so the accumulators are identifiable in
        # graph/debug output instead of being anonymous.
        self.accu = wrapper(
            numpy.zeros(numpy_data.shape, dtype=numpy_data.dtype),
            name=name + '_accu')
        # delta_accu: accumulate update magnitudes (recursively!)
        self.delta_accu = wrapper(
            numpy.zeros(numpy_data.shape, dtype=numpy_data.dtype),
            name=name + '_delta_accu')

    def updates(self, cost, timestep, eps, rho):
        """Build the update pairs for one Adadelta step.

        Args:
            cost: Symbolic expression differentiated with respect to ``curr``.
            timestep: Not used by this rule.
            eps: Constant added inside both square roots.
            rho: Decay factor of the two running averages.

        Returns:
            ``[(curr, curr - update), (accu, accu_new),
            (delta_accu, delta_accu_new)]``.

        NOTE(review): ``Param.updates`` takes ``(eta, mu)`` in the positions
        where this method takes ``(eps, rho)``. If a caller passes the same
        positional arguments to both classes, a learning rate would land in
        ``eps`` -- verify against the caller.
        """
        # update accu (as in rmsprop)
        grad = T.grad(cost, self.curr)
        accu_new = rho * self.accu + (1 - rho) * grad ** 2

        # compute parameter update, using the 'old' delta_accu
        update = (grad * T.sqrt(self.delta_accu + eps) /
                  T.sqrt(accu_new + eps))
        # update delta_accu (as accu, but accumulating updates)
        delta_accu_new = rho * self.delta_accu + (1 - rho) * update ** 2
        return [(self.curr, self.curr - update), (self.accu, accu_new),
                (self.delta_accu, delta_accu_new)]
|
||||
|
||||
|
||||
class AvgParam(object):
|
||||
def __init__(self, numpy_data, name='?', wrapper=th_share):
|
||||
|
|
Loading…
Reference in New Issue
Block a user