diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.py index 21779ddaa..a691ab4ae 100644 --- a/spacy/ml/tb_framework.py +++ b/spacy/ml/tb_framework.py @@ -39,7 +39,7 @@ def forward(model, X, is_train): def init(model, X=None, Y=None): model.get_ref("tok2vec").initialize(X=X) - model.get_ref("lower").initialize() + lower = model.get_ref("lower").initialize() if model.attrs["has_upper"]: statevecs = model.ops.alloc2f(2, lower.get_dim("nO")) model.get_ref("upper").initialize(X=statevecs)