mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-27 10:26:35 +03:00
fix textcat init functionality
This commit is contained in:
parent
c27679f210
commit
5992e927b9
|
@ -169,23 +169,6 @@ def build_text_classifier_v2(
|
||||||
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
|
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
|
||||||
model.attrs["multi_label"] = not exclusive_classes
|
model.attrs["multi_label"] = not exclusive_classes
|
||||||
|
|
||||||
model.init = init_ensemble_textcat # type: ignore[assignment]
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def init_ensemble_textcat(model, X, Y) -> Model:
|
|
||||||
# When tok2vec is lazily initialized, we need to initialize it before
|
|
||||||
# the rest of the chain to ensure that we can get its width.
|
|
||||||
tok2vec = model.get_ref("tok2vec")
|
|
||||||
tok2vec.initialize(X)
|
|
||||||
|
|
||||||
tok2vec_width = get_tok2vec_width(model)
|
|
||||||
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
|
|
||||||
model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
|
|
||||||
model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
|
|
||||||
model.get_ref("norm_layer").set_dim("nI", tok2vec_width)
|
|
||||||
model.get_ref("norm_layer").set_dim("nO", tok2vec_width)
|
|
||||||
init_chain(model, X, Y)
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@ -273,7 +256,9 @@ def _init_parametric_attention_with_residual_nonlinear(model, X, Y) -> Model:
|
||||||
|
|
||||||
tok2vec_width = get_tok2vec_width(model)
|
tok2vec_width = get_tok2vec_width(model)
|
||||||
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
|
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
|
||||||
|
if model.get_ref("key_transform").has_dim("nI"):
|
||||||
model.get_ref("key_transform").set_dim("nI", tok2vec_width)
|
model.get_ref("key_transform").set_dim("nI", tok2vec_width)
|
||||||
|
if model.get_ref("key_transform").has_dim("nO"):
|
||||||
model.get_ref("key_transform").set_dim("nO", tok2vec_width)
|
model.get_ref("key_transform").set_dim("nO", tok2vec_width)
|
||||||
model.get_ref("nonlinear_layer").set_dim("nI", tok2vec_width)
|
model.get_ref("nonlinear_layer").set_dim("nI", tok2vec_width)
|
||||||
model.get_ref("nonlinear_layer").set_dim("nO", tok2vec_width)
|
model.get_ref("nonlinear_layer").set_dim("nO", tok2vec_width)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user