Update vectors in textcat

This commit is contained in:
Matthew Honnibal 2020-07-29 14:35:36 +02:00
parent b5bbfec591
commit 9e1b11dd81

View File

@ -9,6 +9,7 @@ from ... import util
from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER
from ...util import registry
from ..extract_ngrams import extract_ngrams
from ..staticvectors import StaticVectors
@registry.architectures.register("spacy.TextCatCNN.v1")
@ -101,13 +102,7 @@ def build_text_classifier(
)
if pretrained_vectors:
nlp = util.load_model(pretrained_vectors)
vectors = nlp.vocab.vectors
vector_dim = vectors.data.shape[1]
static_vectors = SpacyVectors(vectors) >> with_array(
Linear(width, vector_dim)
)
static_vectors = StaticVectors(width)
vector_layer = trained_vectors | static_vectors
vectors_width = width * 2
else:
@ -158,14 +153,10 @@ def build_text_classifier(
@registry.architectures.register("spacy.TextCatLowData.v1")
def build_text_classifier_lowdata(width, pretrained_vectors, dropout, nO=None):
nlp = util.load_model(pretrained_vectors)
vectors = nlp.vocab.vectors
vector_dim = vectors.data.shape[1]
# Note, before v.3, this was the default if setting "low_data" and "pretrained_dims"
with Model.define_operators({">>": chain, "**": clone}):
model = (
SpacyVectors(vectors)
StaticVectors(width)
>> list2ragged()
>> with_ragged(0, Linear(width, vector_dim))
>> ParametricAttention(width)