Construct TextCatEnsemble.v2 using helper function

This commit is contained in:
Daniël de Kok 2024-01-24 14:59:01 +01:00
parent ce4ea5ffa7
commit e722284ff4

View File

@ -19,6 +19,7 @@ from thinc.api import (
clone,
concatenate,
list2ragged,
noop,
reduce_first,
reduce_last,
reduce_max,
@ -148,50 +149,26 @@ def build_text_classifier_v2(
linear_model: Model[List[Doc], Floats2d],
nO: Optional[int] = None,
) -> Model[List[Doc], Floats2d]:
# TODO: build the model with _build_parametric_attention_with_residual_nonlinear
# in spaCy v4. We don't do this in spaCy v3 to preserve model
# compatibility.
width = tok2vec.maybe_get_dim("nO")
exclusive_classes = not linear_model.attrs["multi_label"]
parametric_attention = _build_parametric_attention_with_residual_nonlinear(
tok2vec=tok2vec,
nonlinear_layer=Maxout(nI=width, nO=width),
key_transform=noop(),
)
with Model.define_operators({">>": chain, "|": concatenate}):
width = tok2vec.maybe_get_dim("nO")
attention_layer = ParametricAttention(width)
maxout_layer = Maxout(nO=width, nI=width)
norm_layer = LayerNorm(nI=width)
cnn_model = (
tok2vec
>> list2ragged()
>> attention_layer
>> reduce_sum()
>> residual(maxout_layer >> norm_layer >> Dropout(0.0))
)
nO_double = nO * 2 if nO else None
if exclusive_classes:
output_layer = Softmax(nO=nO, nI=nO_double)
else:
output_layer = Linear(nO=nO, nI=nO_double) >> Logistic()
model = (linear_model | cnn_model) >> output_layer
model = (linear_model | parametric_attention) >> output_layer
model.set_ref("tok2vec", tok2vec)
if model.has_dim("nO") is not False and nO is not None:
model.set_dim("nO", cast(int, nO))
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
model.set_ref("attention_layer", attention_layer)
model.set_ref("maxout_layer", maxout_layer)
model.set_ref("norm_layer", norm_layer)
model.attrs["multi_label"] = not exclusive_classes
model.init = init_ensemble_textcat # type: ignore[assignment]
return model
def init_ensemble_textcat(model, X, Y) -> Model:
tok2vec_width = get_tok2vec_width(model)
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
model.get_ref("norm_layer").set_dim("nI", tok2vec_width)
model.get_ref("norm_layer").set_dim("nO", tok2vec_width)
init_chain(model, X, Y)
return model