From 5992e927b9db218e27afeaba4f3bd5d76a64fca4 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Tue, 14 May 2024 18:38:11 +0200 Subject: [PATCH] fix textcat init functionality --- spacy/ml/models/textcat.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py index b0a6d78a6..61c8681a4 100644 --- a/spacy/ml/models/textcat.py +++ b/spacy/ml/models/textcat.py @@ -169,23 +169,6 @@ def build_text_classifier_v2( model.set_ref("output_layer", linear_model.get_ref("output_layer")) model.attrs["multi_label"] = not exclusive_classes - model.init = init_ensemble_textcat # type: ignore[assignment] - return model - - -def init_ensemble_textcat(model, X, Y) -> Model: - # When tok2vec is lazily initialized, we need to initialize it before - # the rest of the chain to ensure that we can get its width. - tok2vec = model.get_ref("tok2vec") - tok2vec.initialize(X) - - tok2vec_width = get_tok2vec_width(model) - model.get_ref("attention_layer").set_dim("nO", tok2vec_width) - model.get_ref("maxout_layer").set_dim("nO", tok2vec_width) - model.get_ref("maxout_layer").set_dim("nI", tok2vec_width) - model.get_ref("norm_layer").set_dim("nI", tok2vec_width) - model.get_ref("norm_layer").set_dim("nO", tok2vec_width) - init_chain(model, X, Y) return model @@ -273,8 +256,10 @@ def _init_parametric_attention_with_residual_nonlinear(model, X, Y) -> Model: tok2vec_width = get_tok2vec_width(model) model.get_ref("attention_layer").set_dim("nO", tok2vec_width) - model.get_ref("key_transform").set_dim("nI", tok2vec_width) - model.get_ref("key_transform").set_dim("nO", tok2vec_width) + if model.get_ref("key_transform").has_dim("nI"): + model.get_ref("key_transform").set_dim("nI", tok2vec_width) + if model.get_ref("key_transform").has_dim("nO"): + model.get_ref("key_transform").set_dim("nO", tok2vec_width) model.get_ref("nonlinear_layer").set_dim("nI", tok2vec_width) model.get_ref("nonlinear_layer").set_dim("nO", tok2vec_width) model.get_ref("norm_layer").set_dim("nI", tok2vec_width)