TextCatParametricAttention.v1: set key transform dimensions

This is necessary for tok2vec implementations that initialize
lazily (e.g. curated transformers).
Author: Daniël de Kok
Date:   2024-01-19 10:28:54 +01:00
Parent: e2a3952de5
Commit: ba15b8edd0
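
The fix matters because a lazily initialized tok2vec (such as a curated-transformers model) does not expose its width at construction time, so any layer whose dimensions depend on that width has to be sized later, during initialization. Below is a minimal Thinc sketch of that situation, not the spaCy source; the Linear layer and the width value are illustrative.

    from thinc.api import Linear

    # A stand-in for the key transform: built with its dimensions unset,
    # just like a layer that depends on a lazily initialized tok2vec.
    key_transform = Linear()
    print(key_transform.has_dim("nI"))  # None: dimension registered but not set

    # Once the tok2vec width is known, the dimensions can be filled in and
    # the layer can allocate its parameters.
    tok2vec_width = 768  # illustrative value
    key_transform.set_dim("nI", tok2vec_width)
    key_transform.set_dim("nO", tok2vec_width)
    key_transform.initialize()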

@@ -264,6 +264,7 @@ def _build_parametric_attention_with_residual_nonlinear(
         parametric_attention.set_ref("tok2vec", tok2vec)
         parametric_attention.set_ref("attention_layer", attention_layer)
+        parametric_attention.set_ref("key_transform", key_transform)
         parametric_attention.set_ref("nonlinear_layer", nonlinear_layer)
         parametric_attention.set_ref("norm_layer", norm_layer)
@@ -273,8 +274,10 @@
 def _init_parametric_attention_with_residual_nonlinear(model, X, Y) -> Model:
     tok2vec_width = get_tok2vec_width(model)
     model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
-    model.get_ref("nonlinear_layer").set_dim("nO", tok2vec_width)
+    model.get_ref("key_transform").set_dim("nI", tok2vec_width)
+    model.get_ref("key_transform").set_dim("nO", tok2vec_width)
     model.get_ref("nonlinear_layer").set_dim("nI", tok2vec_width)
+    model.get_ref("nonlinear_layer").set_dim("nO", tok2vec_width)
     model.get_ref("norm_layer").set_dim("nI", tok2vec_width)
     model.get_ref("norm_layer").set_dim("nO", tok2vec_width)
     init_chain(model, X, Y)
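
Put together, the initialization callback now sizes the key transform (alongside the nonlinear and norm layers) from the tok2vec width before the rest of the chain is initialized. A rough, self-contained sketch of that flow, not the spaCy source; Linear layers stand in for the real tok2vec and key transform, and the width is illustrative.

    from thinc.api import Linear, chain

    tok2vec = Linear(nO=768, nI=300)  # stand-in tok2vec; its nO plays the role of the lazily discovered width
    key_transform = Linear()          # dimensions unset until initialization
    model = chain(tok2vec, key_transform)
    model.set_ref("tok2vec", tok2vec)
    model.set_ref("key_transform", key_transform)

    # What the patched init does for the key transform: read the tok2vec
    # width and use it for both the input and output dimension.
    width = model.get_ref("tok2vec").get_dim("nO")
    model.get_ref("key_transform").set_dim("nI", width)
    model.get_ref("key_transform").set_dim("nO", width)
    model.initialize()                # the whole chain can now allocate parameters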