diff --git a/spacy/_ml.py b/spacy/_ml.py index f589704a6..ac7849bbb 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -19,6 +19,8 @@ import numpy def _init_for_precomputed(W, ops): + if (W**2).sum() != 0.: + return reshaped = W.reshape((W.shape[1], W.shape[0] * W.shape[2])) ops.xavier_uniform_init(reshaped) W[:] = reshaped.reshape(W.shape) @@ -247,6 +249,7 @@ def doc2feats(cols=None): model.cols = cols return model + def print_shape(prefix): def forward(X, drop=0.): return X, lambda dX, **kwargs: dX