diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py index e0c11ed99..0234530e6 100644 --- a/spacy/ml/models/textcat.py +++ b/spacy/ml/models/textcat.py @@ -107,6 +107,7 @@ def init_ensemble_textcat(model, X, Y) -> Model: model.get_ref("maxout_layer").set_dim("nO", tok2vec_width) model.get_ref("maxout_layer").set_dim("nI", tok2vec_width) model.get_ref("norm_layer").set_dim("nI", tok2vec_width) + model.get_ref("norm_layer").set_dim("nO", tok2vec_width) init_chain(model, X, Y) return model