diff --git a/spacy/ml/_precomputable_affine.py b/spacy/ml/_precomputable_affine.py index c7328bad9..ec95cdafd 100644 --- a/spacy/ml/_precomputable_affine.py +++ b/spacy/ml/_precomputable_affine.py @@ -110,7 +110,8 @@ def init(model, X=None, Y=None): pad = model.ops.alloc4f(1, nF, nO, nP) ops = model.ops - W = normal_init(ops, W.shape, fan_in=nF * nI) + scale = float(ops.xp.sqrt(1.0 / (nF * nI))) + W = normal_init(ops, W.shape, mean=scale) model.set_param("W", W) model.set_param("b", b) model.set_param("pad", pad)