Fix the fix for textcat init functionality

This commit is contained in:
svlandeg 2024-05-14 18:45:51 +02:00
parent 5992e927b9
commit e32a394ff0

View File

@@ -256,9 +256,9 @@ def _init_parametric_attention_with_residual_nonlinear(model, X, Y) -> Model:
     tok2vec_width = get_tok2vec_width(model)
     model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
-    if model.get_ref("key_transform").has_dim("nI"):
+    if model.get_ref("key_transform").has_dim("nI") is None:
         model.get_ref("key_transform").set_dim("nI", tok2vec_width)
-    if model.get_ref("key_transform").has_dim("nO"):
+    if model.get_ref("key_transform").has_dim("nO") is None:
         model.get_ref("key_transform").set_dim("nO", tok2vec_width)
     model.get_ref("nonlinear_layer").set_dim("nI", tok2vec_width)
     model.get_ref("nonlinear_layer").set_dim("nO", tok2vec_width)