Update vectors in textcat

This commit is contained in:
Matthew Honnibal 2020-07-29 14:35:36 +02:00
parent b5bbfec591
commit 9e1b11dd81

View File

@ -9,6 +9,7 @@ from ... import util
from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER
from ...util import registry from ...util import registry
from ..extract_ngrams import extract_ngrams from ..extract_ngrams import extract_ngrams
from ..staticvectors import StaticVectors
@registry.architectures.register("spacy.TextCatCNN.v1") @registry.architectures.register("spacy.TextCatCNN.v1")
@ -101,13 +102,7 @@ def build_text_classifier(
) )
if pretrained_vectors: if pretrained_vectors:
nlp = util.load_model(pretrained_vectors) static_vectors = StaticVectors(width)
vectors = nlp.vocab.vectors
vector_dim = vectors.data.shape[1]
static_vectors = SpacyVectors(vectors) >> with_array(
Linear(width, vector_dim)
)
vector_layer = trained_vectors | static_vectors vector_layer = trained_vectors | static_vectors
vectors_width = width * 2 vectors_width = width * 2
else: else:
@ -158,14 +153,10 @@ def build_text_classifier(
@registry.architectures.register("spacy.TextCatLowData.v1") @registry.architectures.register("spacy.TextCatLowData.v1")
def build_text_classifier_lowdata(width, pretrained_vectors, dropout, nO=None): def build_text_classifier_lowdata(width, pretrained_vectors, dropout, nO=None):
nlp = util.load_model(pretrained_vectors)
vectors = nlp.vocab.vectors
vector_dim = vectors.data.shape[1]
# Note, before v.3, this was the default if setting "low_data" and "pretrained_dims" # Note, before v.3, this was the default if setting "low_data" and "pretrained_dims"
with Model.define_operators({">>": chain, "**": clone}): with Model.define_operators({">>": chain, "**": clone}):
model = ( model = (
SpacyVectors(vectors) StaticVectors(width)
>> list2ragged() >> list2ragged()
>> with_ragged(0, Linear(width, vector_dim)) >> with_ragged(0, Linear(width, vector_dim))
>> ParametricAttention(width) >> ParametricAttention(width)