Mirror of https://github.com/explosion/spaCy.git, synced 2025-04-20 00:51:58 +03:00
TextCatParametricAttention.v1: set key transform dimensions
This is necessary for tok2vec implementations that initialize lazily (e.g. curated transformers).
parent e2a3952de5
commit ba15b8edd0
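
The pattern behind this change, as a minimal Thinc sketch (illustrative only: the names encoder/key_transform and the Linear layers stand in for the real tok2vec and key transform, not the spaCy implementation): a wrapper model registers named references to its sublayers, and its init callback reads the encoder width, which may only be known at initialization time, and copies it onto layers that cannot infer it themselves.

import numpy
from thinc.api import Linear, Model, chain


def forward(model: Model, X, is_train: bool):
    # Delegate to the wrapped chain.
    return model.layers[0](X, is_train)


def init(model: Model, X=None, Y=None) -> Model:
    # Read the encoder width at initialization time; in spaCy this is
    # what get_tok2vec_width(model) does once the tok2vec is available.
    width = model.get_ref("encoder").get_dim("nO")
    # Copy the width onto the layer that cannot infer it on its own,
    # *before* the chained initializer runs.
    model.get_ref("key_transform").set_dim("nI", width)
    model.get_ref("key_transform").set_dim("nO", width)
    model.layers[0].initialize(X=X, Y=Y)
    return model


encoder = Linear(nO=8)    # stands in for the tok2vec layer
key_transform = Linear()  # dimensions deliberately left unset
model = Model("sketch", forward, init=init, layers=[chain(encoder, key_transform)])
model.set_ref("encoder", encoder)
model.set_ref("key_transform", key_transform)
model.initialize(X=numpy.zeros((2, 4), dtype="f"))
assert model.get_ref("key_transform").get_dim("nO") == 8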
spacy/ml/models/textcat.py

@@ -264,6 +264,7 @@ def _build_parametric_attention_with_residual_nonlinear(
     parametric_attention.set_ref("tok2vec", tok2vec)
     parametric_attention.set_ref("attention_layer", attention_layer)
+    parametric_attention.set_ref("key_transform", key_transform)
     parametric_attention.set_ref("nonlinear_layer", nonlinear_layer)
     parametric_attention.set_ref("norm_layer", norm_layer)
 
     return parametric_attention
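
Note that set_ref registers the sublayer under a stable name on the wrapping model; without the added line, the init callback in the next hunk would have no way to reach the key transform. A minimal round-trip, with illustrative names:

from thinc.api import Linear, Model

key_transform = Linear()
wrapper = Model(
    "wrapper",
    lambda model, X, is_train: model.layers[0](X, is_train),
    layers=[key_transform],
)
wrapper.set_ref("key_transform", key_transform)
assert wrapper.get_ref("key_transform") is key_transform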
@@ -273,8 +274,10 @@ def _build_parametric_attention_with_residual_nonlinear(
 def _init_parametric_attention_with_residual_nonlinear(model, X, Y) -> Model:
     tok2vec_width = get_tok2vec_width(model)
     model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
-    model.get_ref("nonlinear_layer").set_dim("nO", tok2vec_width)
+    model.get_ref("key_transform").set_dim("nI", tok2vec_width)
+    model.get_ref("key_transform").set_dim("nO", tok2vec_width)
     model.get_ref("nonlinear_layer").set_dim("nI", tok2vec_width)
+    model.get_ref("nonlinear_layer").set_dim("nO", tok2vec_width)
     model.get_ref("norm_layer").set_dim("nI", tok2vec_width)
     model.get_ref("norm_layer").set_dim("nO", tok2vec_width)
     init_chain(model, X, Y)
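
The ordering in this hunk matters: the dimensions are copied onto the sublayers before init_chain delegates to the chained initializer, because once a Thinc layer has been initialized its dimensions are fixed and set_dim refuses to change them (short of passing force=True). A small illustration of that behavior, with assumed shapes:

import numpy
from thinc.api import Linear

layer = Linear()
layer.initialize(
    X=numpy.zeros((2, 4), dtype="f"),
    Y=numpy.zeros((2, 8), dtype="f"),
)
try:
    layer.set_dim("nI", 16)  # nI was fixed at 4 during initialization
except ValueError:
    print("dimensions cannot be changed once set")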
|
Loading…
Reference in New Issue
Block a user