diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py
index 3e5471ab3..2d6baa1bc 100644
--- a/spacy/ml/models/textcat.py
+++ b/spacy/ml/models/textcat.py
@@ -19,6 +19,7 @@ from thinc.api import (
     clone,
     concatenate,
     list2ragged,
+    noop,
     reduce_first,
     reduce_last,
     reduce_max,
@@ -148,50 +149,26 @@ def build_text_classifier_v2(
     linear_model: Model[List[Doc], Floats2d],
     nO: Optional[int] = None,
 ) -> Model[List[Doc], Floats2d]:
-    # TODO: build the model with _build_parametric_attention_with_residual_nonlinear
-    # in spaCy v4. We don't do this in spaCy v3 to preserve model
-    # compatibility.
+    width = tok2vec.maybe_get_dim("nO")
     exclusive_classes = not linear_model.attrs["multi_label"]
+    parametric_attention = _build_parametric_attention_with_residual_nonlinear(
+        tok2vec=tok2vec,
+        nonlinear_layer=Maxout(nI=width, nO=width),
+        key_transform=noop(),
+    )
     with Model.define_operators({">>": chain, "|": concatenate}):
-        width = tok2vec.maybe_get_dim("nO")
-        attention_layer = ParametricAttention(width)
-        maxout_layer = Maxout(nO=width, nI=width)
-        norm_layer = LayerNorm(nI=width)
-        cnn_model = (
-            tok2vec
-            >> list2ragged()
-            >> attention_layer
-            >> reduce_sum()
-            >> residual(maxout_layer >> norm_layer >> Dropout(0.0))
-        )
-
         nO_double = nO * 2 if nO else None
         if exclusive_classes:
             output_layer = Softmax(nO=nO, nI=nO_double)
         else:
             output_layer = Linear(nO=nO, nI=nO_double) >> Logistic()
-        model = (linear_model | cnn_model) >> output_layer
+        model = (linear_model | parametric_attention) >> output_layer
         model.set_ref("tok2vec", tok2vec)
     if model.has_dim("nO") is not False and nO is not None:
         model.set_dim("nO", cast(int, nO))
     model.set_ref("output_layer", linear_model.get_ref("output_layer"))
-    model.set_ref("attention_layer", attention_layer)
-    model.set_ref("maxout_layer", maxout_layer)
-    model.set_ref("norm_layer", norm_layer)
     model.attrs["multi_label"] = not exclusive_classes
-    model.init = init_ensemble_textcat  # type: ignore[assignment]
-    return model
-
-
-def init_ensemble_textcat(model, X, Y) -> Model:
-    tok2vec_width = get_tok2vec_width(model)
-    model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
-    model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
-    model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
-    model.get_ref("norm_layer").set_dim("nI", tok2vec_width)
-    model.get_ref("norm_layer").set_dim("nO", tok2vec_width)
-    init_chain(model, X, Y)
     return model
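For readers without the rest of textcat.py at hand: below is a minimal sketch of what `_build_parametric_attention_with_residual_nonlinear` plausibly looks like, reconstructed from the removed `cnn_model` chain and the new call site. The helper's actual definition is not part of this hunk. The use of thinc's `ParametricAttention_v2` (the variant that accepts a `key_transform`) is an assumption inferred from the `key_transform=noop()` argument, since the older `ParametricAttention` takes no such parameter.

```python
# Hypothetical reconstruction of the helper this diff calls into; not the
# definition from this patch. The chain mirrors the removed cnn_model:
# tok2vec >> list2ragged >> attention >> reduce_sum >> residual(nonlinear >> norm).
from typing import List, Optional

from thinc.api import (
    Dropout,
    LayerNorm,
    Model,
    ParametricAttention_v2,  # assumed: the key_transform-aware attention layer
    chain,
    list2ragged,
    reduce_sum,
    residual,
)
from thinc.types import Floats2d

from spacy.tokens import Doc


def _build_parametric_attention_with_residual_nonlinear(
    *,
    tok2vec: Model[List[Doc], List[Floats2d]],
    nonlinear_layer: Model[Floats2d, Floats2d],
    key_transform: Optional[Model[Floats2d, Floats2d]] = None,
) -> Model[List[Doc], Floats2d]:
    with Model.define_operators({">>": chain}):
        # Width may be unset at construction time; thinc infers missing
        # dims during initialization, as the removed code also relied on.
        width = tok2vec.maybe_get_dim("nO")
        attention_layer = ParametricAttention_v2(
            nO=width, key_transform=key_transform
        )
        norm_layer = LayerNorm(nI=width)
        parametric_attention = (
            tok2vec
            >> list2ragged()
            >> attention_layer
            >> reduce_sum()
            >> residual(nonlinear_layer >> norm_layer >> Dropout(0.0))
        )
    return parametric_attention
```

Note that the dimension plumbing done by the removed `init_ensemble_textcat` (propagating the tok2vec width into the attention, maxout, and norm layers) presumably moves into the helper's own init callback; that part is omitted from the sketch above.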