fix textcat init functionality

This commit is contained in:
svlandeg 2024-05-14 18:38:11 +02:00
parent c27679f210
commit 5992e927b9

View File

@ -169,23 +169,6 @@ def build_text_classifier_v2(
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
model.attrs["multi_label"] = not exclusive_classes
model.init = init_ensemble_textcat # type: ignore[assignment]
return model
def init_ensemble_textcat(model, X, Y) -> Model:
# When tok2vec is lazily initialized, we need to initialize it before
# the rest of the chain to ensure that we can get its width.
tok2vec = model.get_ref("tok2vec")
tok2vec.initialize(X)
tok2vec_width = get_tok2vec_width(model)
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
model.get_ref("norm_layer").set_dim("nI", tok2vec_width)
model.get_ref("norm_layer").set_dim("nO", tok2vec_width)
init_chain(model, X, Y)
return model
@ -273,8 +256,10 @@ def _init_parametric_attention_with_residual_nonlinear(model, X, Y) -> Model:
tok2vec_width = get_tok2vec_width(model)
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
model.get_ref("key_transform").set_dim("nI", tok2vec_width)
model.get_ref("key_transform").set_dim("nO", tok2vec_width)
if model.get_ref("key_transform").has_dim("nI"):
model.get_ref("key_transform").set_dim("nI", tok2vec_width)
if model.get_ref("key_transform").has_dim("nO"):
model.get_ref("key_transform").set_dim("nO", tok2vec_width)
model.get_ref("nonlinear_layer").set_dim("nI", tok2vec_width)
model.get_ref("nonlinear_layer").set_dim("nO", tok2vec_width)
model.get_ref("norm_layer").set_dim("nI", tok2vec_width)