diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py index 63dcb165a..8c7316f62 100644 --- a/spacy/ml/models/textcat.py +++ b/spacy/ml/models/textcat.py @@ -6,6 +6,7 @@ from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum from thinc.api import HashEmbed, with_array, with_cpu, uniqued from thinc.api import Relu, residual, expand_window +from thinc.layers.chain import init as init_chain from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER from ...util import registry @@ -13,6 +14,7 @@ from ..extract_ngrams import extract_ngrams from ..staticvectors import StaticVectors from ..featureextractor import FeatureExtractor from ...tokens import Doc +from .tok2vec import get_tok2vec_width @registry.architectures.register("spacy.TextCatCNN.v1") @@ -69,16 +71,17 @@ def build_text_classifier_v2( exclusive_classes = not linear_model.attrs["multi_label"] with Model.define_operators({">>": chain, "|": concatenate}): width = tok2vec.maybe_get_dim("nO") + attention_layer = ParametricAttention(width) # TODO: benchmark performance difference of this layer + maxout_layer = Maxout(nO=width, nI=width) + linear_layer = Linear(nO=nO, nI=width) cnn_model = ( - tok2vec - >> list2ragged() - >> ParametricAttention( - width - ) # TODO: benchmark performance difference of this layer - >> reduce_sum() - >> residual(Maxout(nO=width, nI=width)) - >> Linear(nO=nO, nI=width) - >> Dropout(0.0) + tok2vec + >> list2ragged() + >> attention_layer + >> reduce_sum() + >> residual(maxout_layer) + >> linear_layer + >> Dropout(0.0) ) nO_double = nO * 2 if nO else None @@ -91,7 +94,22 @@ def build_text_classifier_v2( if model.has_dim("nO") is not False: model.set_dim("nO", nO) model.set_ref("output_layer", linear_model.get_ref("output_layer")) + model.set_ref("attention_layer", attention_layer) + model.set_ref("maxout_layer", maxout_layer) + model.set_ref("linear_layer", linear_layer) model.attrs["multi_label"] = not exclusive_classes + + model.init = init_ensemble_textcat + return model + + +def init_ensemble_textcat(model, X, Y) -> Model: + tok2vec_width = get_tok2vec_width(model) + model.get_ref("attention_layer").set_dim("nO", tok2vec_width) + model.get_ref("maxout_layer").set_dim("nO", tok2vec_width) + model.get_ref("maxout_layer").set_dim("nI", tok2vec_width) + model.get_ref("linear_layer").set_dim("nI", tok2vec_width) + init_chain(model, X, Y) return model diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 8755d0d0d..0f727d85f 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -20,6 +20,17 @@ def tok2vec_listener_v1(width: int, upstream: str = "*"): return tok2vec +def get_tok2vec_width(model: Model): + nO = None + if model.has_ref("tok2vec"): + tok2vec = model.get_ref("tok2vec") + if tok2vec.has_dim("nO"): + nO = tok2vec.get_dim("nO") + elif tok2vec.has_ref("listener"): + nO = tok2vec.get_ref("listener").get_dim("nO") + return nO + + @registry.architectures.register("spacy.HashEmbedCNN.v1") def build_hash_embed_cnn_tok2vec( *,