mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
* Add separate Param and AdadeltaParam classes. AdadeltaParam seems broken.
This commit is contained in:
parent
1dff04acb5
commit
df8179ca4f
|
@ -43,6 +43,39 @@ floatX = theano.config.floatX
|
||||||
def th_share(w, name=''):
|
def th_share(w, name=''):
|
||||||
return theano.shared(value=w, borrow=True, name=name)
|
return theano.shared(value=w, borrow=True, name=name)
|
||||||
|
|
||||||
|
class Param(object):
|
||||||
|
def __init__(self, numpy_data, name='?', wrapper=th_share):
|
||||||
|
self.curr = wrapper(numpy_data, name=name+'_curr')
|
||||||
|
self.step = wrapper(numpy.zeros(numpy_data.shape, numpy_data.dtype),
|
||||||
|
name=name+'_step')
|
||||||
|
|
||||||
|
def updates(self, cost, timestep, eta, mu):
|
||||||
|
step = (mu * self.step) - T.grad(cost, self.curr)
|
||||||
|
curr = self.curr + (eta * step)
|
||||||
|
return [(self.curr, curr), (self.step, step)]
|
||||||
|
|
||||||
|
|
||||||
|
class AdadeltaParam(object):
|
||||||
|
def __init__(self, numpy_data, name='?', wrapper=th_share):
|
||||||
|
self.curr = wrapper(numpy_data, name=name+'_curr')
|
||||||
|
# accu: accumulate gradient magnitudes
|
||||||
|
self.accu = wrapper(numpy.zeros(numpy_data.shape, dtype=numpy_data.dtype))
|
||||||
|
# delta_accu: accumulate update magnitudes (recursively!)
|
||||||
|
self.delta_accu = wrapper(numpy.zeros(numpy_data.shape, dtype=numpy_data.dtype))
|
||||||
|
|
||||||
|
def updates(self, cost, timestep, eps, rho):
|
||||||
|
# update accu (as in rmsprop)
|
||||||
|
grad = T.grad(cost, self.curr)
|
||||||
|
accu_new = rho * self.accu + (1 - rho) * grad ** 2
|
||||||
|
|
||||||
|
# compute parameter update, using the 'old' delta_accu
|
||||||
|
update = (grad * T.sqrt(self.delta_accu + eps) /
|
||||||
|
T.sqrt(accu_new + eps))
|
||||||
|
# update delta_accu (as accu, but accumulating updates)
|
||||||
|
delta_accu_new = rho * self.delta_accu + (1 - rho) * update ** 2
|
||||||
|
return [(self.curr, self.curr - update), (self.accu, accu_new),
|
||||||
|
(self.delta_accu, delta_accu_new)]
|
||||||
|
|
||||||
|
|
||||||
class AvgParam(object):
|
class AvgParam(object):
|
||||||
def __init__(self, numpy_data, name='?', wrapper=th_share):
|
def __init__(self, numpy_data, name='?', wrapper=th_share):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user