diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py
index a64a2487a..139917581 100644
--- a/spacy/ml/models/textcat.py
+++ b/spacy/ml/models/textcat.py
@@ -9,6 +9,7 @@ from ... import util
 from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER
 from ...util import registry
 from ..extract_ngrams import extract_ngrams
+from ..staticvectors import StaticVectors
 
 
 @registry.architectures.register("spacy.TextCatCNN.v1")
@@ -101,13 +102,7 @@ def build_text_classifier(
     )
 
     if pretrained_vectors:
-        nlp = util.load_model(pretrained_vectors)
-        vectors = nlp.vocab.vectors
-        vector_dim = vectors.data.shape[1]
-
-        static_vectors = SpacyVectors(vectors) >> with_array(
-            Linear(width, vector_dim)
-        )
+        static_vectors = StaticVectors(width)
         vector_layer = trained_vectors | static_vectors
         vectors_width = width * 2
     else:
@@ -158,14 +153,10 @@
 
 @registry.architectures.register("spacy.TextCatLowData.v1")
 def build_text_classifier_lowdata(width, pretrained_vectors, dropout, nO=None):
-    nlp = util.load_model(pretrained_vectors)
-    vectors = nlp.vocab.vectors
-    vector_dim = vectors.data.shape[1]
-
     # Note, before v.3, this was the default if setting "low_data" and "pretrained_dims"
     with Model.define_operators({">>": chain, "**": clone}):
         model = (
-            SpacyVectors(vectors)
+            StaticVectors(width)
             >> list2ragged()
-            >> with_ragged(0, Linear(width, vector_dim))
+            >> with_ragged(0, Linear(width, width))
             >> ParametricAttention(width)