From 807cb2e3703806c96c8e4c47db2fc0ef5380d39b Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Mon, 8 May 2017 14:24:43 +0200
Subject: [PATCH] Add PretrainableMaxouts

---
 spacy/_ml.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/spacy/_ml.py b/spacy/_ml.py
index e62971a69..d3bb903e7 100644
--- a/spacy/_ml.py
+++ b/spacy/_ml.py
@@ -61,6 +61,64 @@ class PrecomputableAffine(Model):
         return Yf, backward
 
 
+@describe.on_data(_set_dimensions_if_needed)
+@describe.attributes(
+    nI=Dimension("Input size"),
+    nF=Dimension("Number of features"),
+    nP=Dimension("Number of pieces"),
+    nO=Dimension("Output size"),
+    W=Synapses("Weights matrix",
+        lambda obj: (obj.nF, obj.nO, obj.nP, obj.nI),
+        lambda W, ops: ops.xavier_uniform_init(W)),
+    b=Biases("Bias vector",
+        lambda obj: (obj.nO, obj.nP)),
+    d_W=Gradient("W"),
+    d_b=Gradient("b")
+)
+class PrecomputableMaxouts(Model):
+    def __init__(self, nO=None, nI=None, nF=None, pieces=2, **kwargs):
+        Model.__init__(self, **kwargs)
+        self.nO = nO
+        self.nP = pieces
+        self.nI = nI
+        self.nF = nF
+
+    def begin_update(self, X, drop=0.):
+        # X: (b, i)
+        # Yfp: (f, b, o, p)
+        # Yf: (f, b, o)
+        # Xf: (b, f, i)
+        # dY: (b, o)
+        # dYp: (b, o, p)
+        # W: (f, o, p, i)
+        # b: (o, p)
+        Yfp = numpy.einsum('bi,fopi->fbop', X, self.W)
+        Yfp += self.b
+        Yf = self.ops.allocate((self.nF, X.shape[0], self.nO))
+        which = self.ops.allocate((self.nF, X.shape[0], self.nO), dtype='i')
+        for i in range(self.nF):
+            Yf[i], which[i] = self.ops.maxout(Yfp[i])
+        def backward(dY_ids, sgd=None):
+            dY, ids = dY_ids
+            Xf = X[ids]
+
+            dYp = self.ops.allocate((dY.shape[0], self.nO, self.nP))
+            for i in range(self.nF):
+                dYp += self.ops.backprop_maxout(dY, which[i], self.nP)
+
+            dXf = numpy.einsum('bop,fopi->bfi', dYp, self.W)
+            dW = numpy.einsum('bop,bfi->fopi', dYp, Xf)
+            db = dYp.sum(axis=0)
+
+            self.d_W += dW
+            self.d_b += db
+
+            if sgd is not None:
+                sgd(self._mem.weights, self._mem.gradient, key=self.id)
+            return dXf
+        return Yf, backward
+
+
 def get_col(idx):
     def forward(X, drop=0.):
        assert len(X.shape) <= 3
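
For reference, a standalone NumPy sketch of the forward computation the patch adds: the einsum precomputes one pre-activation per feature and per maxout piece, and the maxout takes the max over the pieces axis while `which` records the winning piece so the gradient can be routed back through it (cf. `ops.backprop_maxout` above). The function name, the argmax/take_along_axis formulation, and the example sizes are illustrative assumptions, not thinc's actual `ops.maxout` implementation.

import numpy

def precomputed_maxout_forward(X, W, b):
    # Shapes follow the comments in the patch:
    # X: (b, i), W: (f, o, p, i), b: (o, p)
    # Yfp: (f, b, o, p) -- one pre-activation per maxout piece
    Yfp = numpy.einsum('bi,fopi->fbop', X, W) + b
    # Max over the pieces axis, remembering which piece won
    which = Yfp.argmax(axis=-1)                                # (f, b, o)
    Yf = numpy.take_along_axis(Yfp, which[..., None], axis=-1).squeeze(-1)
    return Yf, which                                           # both (f, b, o)

nF, nO, nP, nI, batch = 5, 4, 2, 3, 7
rng = numpy.random.default_rng(0)
Yf, which = precomputed_maxout_forward(
    rng.normal(size=(batch, nI)),
    rng.normal(size=(nF, nO, nP, nI)),
    rng.normal(size=(nO, nP)))
assert Yf.shape == (nF, batch, nO) and which.shape == (nF, batch, nO)

The `ids` unpacked in `backward` (and the gather `Xf = X[ids]`) suggest the per-token outputs are precomputed once per batch and later gathered row-by-row per parser state; that gather-and-sum reuse is what the "Precomputable" in the class name refers to.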