diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py index cbb20c7e3..e0c11ed99 100644 --- a/spacy/ml/models/textcat.py +++ b/spacy/ml/models/textcat.py @@ -4,7 +4,7 @@ from thinc.types import Floats2d from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum -from thinc.api import with_cpu, Relu, residual +from thinc.api import with_cpu, Relu, residual, LayerNorm from thinc.layers.chain import init as init_chain from ...attrs import ORTH @@ -72,13 +72,14 @@ def build_text_classifier_v2( attention_layer = ParametricAttention( width ) # TODO: benchmark performance difference of this layer - maxout_layer = Maxout(nO=width, nI=width, dropout=0.0, normalize=True) + maxout_layer = Maxout(nO=width, nI=width) + norm_layer = LayerNorm(nI=width) cnn_model = ( tok2vec >> list2ragged() >> attention_layer >> reduce_sum() - >> residual(maxout_layer) + >> residual(maxout_layer >> norm_layer >> Dropout(0.0)) ) nO_double = nO * 2 if nO else None @@ -93,6 +94,7 @@ def build_text_classifier_v2( model.set_ref("output_layer", linear_model.get_ref("output_layer")) model.set_ref("attention_layer", attention_layer) model.set_ref("maxout_layer", maxout_layer) + model.set_ref("norm_layer", norm_layer) model.attrs["multi_label"] = not exclusive_classes model.init = init_ensemble_textcat @@ -104,6 +106,7 @@ def init_ensemble_textcat(model, X, Y) -> Model: model.get_ref("attention_layer").set_dim("nO", tok2vec_width) model.get_ref("maxout_layer").set_dim("nO", tok2vec_width) model.get_ref("maxout_layer").set_dim("nI", tok2vec_width) + model.get_ref("norm_layer").set_dim("nI", tok2vec_width) init_chain(model, X, Y) return model