mirror of https://github.com/explosion/spaCy.git

Make PrecomputableAffines work

parent 61bc203f3f
commit 03a215c5fd
spacy/_ml.py: 48 changed lines (30 additions, 18 deletions)
@@ -30,6 +30,8 @@ from . import util
 import numpy
 import io
 
+from blis.py import einsum
+
 # TODO: Unset this once we don't want to support models previous models.
 import thinc.neural._classes.layernorm
 thinc.neural._classes.layernorm.set_compat_six_eight(False)
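The only functional change in this first hunk is the new blis.py einsum import, which the rewritten forward and backward passes below rely on for their matrix multiplies. As a point of reference (illustration only, not part of the commit), the 'ab,bc->ac' spec used later is an ordinary matrix product, shown here with NumPy's einsum:

import numpy

X = numpy.random.randn(5, 4)
W = numpy.random.randn(4, 6)
# 'ab,bc->ac' contracts the shared axis b: a plain matrix multiply.
assert numpy.allclose(numpy.einsum('ab,bc->ac', X, W), X @ W)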
@@ -105,9 +107,7 @@ def _preprocess_doc(docs, drop=0.):
 def _init_for_precomputed(W, ops):
     if (W**2).sum() != 0.:
         return
-    reshaped = W.reshape((W.shape[1], W.shape[0] * W.shape[2]))
-    ops.xavier_uniform_init(reshaped)
-    W[:] = reshaped.reshape(W.shape)
+    ops.xavier_uniform_init(W, inplace=True)
 
 
 @describe.on_data(_set_dimensions_if_needed)
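For orientation, a minimal NumPy sketch of what the new initializer does, assuming the flattened (nI, nF * nO) weight layout introduced below: if the weights are still all zeros (untrained), fill them with Xavier-uniform values in place, instead of the old reshape, init and copy-back dance. The xavier_uniform_ helper is a hypothetical stand-in for thinc's ops.xavier_uniform_init(W, inplace=True), not the library call itself.

import numpy

def xavier_uniform_(W):
    # Stand-in for ops.xavier_uniform_init(W, inplace=True):
    # uniform in [-s, s] with s = sqrt(6 / (fan_in + fan_out)).
    scale = numpy.sqrt(6. / (W.shape[0] + W.shape[1]))
    W[:] = numpy.random.uniform(-scale, scale, W.shape)

nI, nF, nO = 4, 2, 3
W = numpy.zeros((nI, nF * nO), dtype='f')  # flattened weight layout
if (W ** 2).sum() == 0.:                   # only touch untrained weights
    xavier_uniform_(W)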
@@ -116,7 +116,7 @@ def _init_for_precomputed(W, ops):
     nF=Dimension("Number of features"),
     nO=Dimension("Output size"),
     W=Synapses("Weights matrix",
-        lambda obj: (obj.nF, obj.nO, obj.nI),
+        lambda obj: (obj.nI, obj.nF * obj.nO),
         lambda W, ops: _init_for_precomputed(W, ops)),
     b=Biases("Bias vector",
         lambda obj: (obj.nO,)),
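The weights move from a 3-d tensor of shape (nF, nO, nI) to a 2-d matrix of shape (nI, nF * nO), so the forward pass becomes a single matrix multiply instead of a tensordot. A hedged NumPy sketch of the equivalence, with illustrative sizes only:

import numpy

nF, nO, nI, nN = 2, 3, 4, 5
X = numpy.random.randn(nN, nI)

# Old layout: W3 has shape (nF, nO, nI); the forward pass contracts over nI.
W3 = numpy.random.randn(nF, nO, nI)
Yf_old = numpy.tensordot(X, W3, axes=[[1], [2]])          # (nN, nF, nO)

# New layout: the same parameters viewed as (nI, nF * nO),
# so the forward pass is one plain matrix multiply.
W2 = W3.transpose((2, 0, 1)).reshape((nI, nF * nO))
Yf_new = numpy.dot(X, W2).reshape((nN, nF, nO))

assert numpy.allclose(Yf_old, Yf_new)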
@@ -130,31 +130,43 @@ class PrecomputableAffine(Model):
         self.nI = nI
         self.nF = nF
 
+    @property
+    def nIF(self):
+        return self.nI * self.nF
+
+    @property
+    def nFO(self):
+        return self.nF * self.nO
+
     def begin_update(self, X, drop=0.):
+        nN = X.shape[0]
         # X: (b, i)
-        # Yf: (b, f, i)
+        # Xf: (b, f, i)
+        # Yf: (b, f, o)
         # dY: (b, o)
         # dYf: (b, f, o)
-        #Yf = numpy.einsum('bi,foi->bfo', X, self.W)
-        Yf = self.ops.xp.tensordot(
-            X, self.W, axes=[[1], [2]])
-        Yf += self.b
+        # W: (i, fo)
+        # Yf = numpy.einsum('bi,i_fo->b_fo', X, self.W)
+        Yf = einsum('ab,bc->ac', X, self.W).reshape((nN, self.nF, self.nO))
         def backward(dY_ids, sgd=None):
-            tensordot = self.ops.xp.tensordot
             dY, ids = dY_ids
+            nB = ids.shape[0]
             Xf = X[ids]
+            Xf = Xf.reshape((nB, self.nIF))
 
-            #dXf = numpy.einsum('bo,foi->bfi', dY, self.W)
-            dXf = tensordot(dY, self.W, axes=[[1], [1]])
-            #dW = numpy.einsum('bo,bfi->ofi', dY, Xf)
-            dW = tensordot(dY, Xf, axes=[[0], [0]])
-            # ofi -> foi
-            self.d_W += dW.transpose((1, 0, 2))
-            self.d_b += dY.sum(axis=0)
+            dW_re = self.d_W.reshape((self.nIF, self.nO))
+            W_re = self.d_W.reshape((self.nIF, self.nO))
+            # bo,if_o->bif
+            dXf = einsum('ab,cb->ac', dY, W_re)
+            # b_if,bo->if_o
+            einsum('ab,ac->bc', Xf, dY, out=dW_re)
+            # self.d_b += dY.sum(axis=0)
 
             if sgd is not None:
                 sgd(self._mem.weights, self._mem.gradient, key=self.id)
-            return dXf
+            dXf = dXf.reshape((nB, self.nI, self.nF))
+            dXf = dXf.transpose((0, 2, 1))
+            return self.ops.xp.ascontiguousarray(dXf)
         return Yf, backward
 
 