diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py
index 3e5471ab3..a9aba27fb 100644
--- a/spacy/ml/models/textcat.py
+++ b/spacy/ml/models/textcat.py
@@ -264,6 +264,7 @@ def _build_parametric_attention_with_residual_nonlinear(
 
     parametric_attention.set_ref("tok2vec", tok2vec)
     parametric_attention.set_ref("attention_layer", attention_layer)
+    parametric_attention.set_ref("key_transform", key_transform)
     parametric_attention.set_ref("nonlinear_layer", nonlinear_layer)
     parametric_attention.set_ref("norm_layer", norm_layer)
 
@@ -273,8 +274,10 @@
 def _init_parametric_attention_with_residual_nonlinear(model, X, Y) -> Model:
     tok2vec_width = get_tok2vec_width(model)
     model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
-    model.get_ref("nonlinear_layer").set_dim("nO", tok2vec_width)
+    model.get_ref("key_transform").set_dim("nI", tok2vec_width)
+    model.get_ref("key_transform").set_dim("nO", tok2vec_width)
     model.get_ref("nonlinear_layer").set_dim("nI", tok2vec_width)
+    model.get_ref("nonlinear_layer").set_dim("nO", tok2vec_width)
     model.get_ref("norm_layer").set_dim("nI", tok2vec_width)
     model.get_ref("norm_layer").set_dim("nO", tok2vec_width)
     init_chain(model, X, Y)