Mirror of https://github.com/explosion/spaCy.git (synced 2024-12-24 17:06:29 +03:00)
Update vectors in textcat
This commit is contained in:
parent b5bbfec591
commit 9e1b11dd81
|
@ -9,6 +9,7 @@ from ... import util
|
|||
from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER
|
||||
from ...util import registry
|
||||
from ..extract_ngrams import extract_ngrams
|
||||
from ..staticvectors import StaticVectors
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.TextCatCNN.v1")
|
||||
|
@ -101,13 +102,7 @@ def build_text_classifier(
|
|||
)
|
||||
|
||||
if pretrained_vectors:
|
||||
nlp = util.load_model(pretrained_vectors)
|
||||
vectors = nlp.vocab.vectors
|
||||
vector_dim = vectors.data.shape[1]
|
||||
|
||||
static_vectors = SpacyVectors(vectors) >> with_array(
|
||||
Linear(width, vector_dim)
|
||||
)
|
||||
static_vectors = StaticVectors(width)
|
||||
vector_layer = trained_vectors | static_vectors
|
||||
vectors_width = width * 2
|
||||
else:
|
||||
|
@ -158,14 +153,10 @@ def build_text_classifier(
|
|||
|
||||
@registry.architectures.register("spacy.TextCatLowData.v1")
|
||||
def build_text_classifier_lowdata(width, pretrained_vectors, dropout, nO=None):
|
||||
nlp = util.load_model(pretrained_vectors)
|
||||
vectors = nlp.vocab.vectors
|
||||
vector_dim = vectors.data.shape[1]
|
||||
|
||||
# Note, before v.3, this was the default if setting "low_data" and "pretrained_dims"
|
||||
with Model.define_operators({">>": chain, "**": clone}):
|
||||
model = (
|
||||
SpacyVectors(vectors)
|
||||
StaticVectors(width)
|
||||
>> list2ragged()
|
||||
>> with_ragged(0, Linear(width, vector_dim))
|
||||
>> ParametricAttention(width)
|
||||
|
|
Loading…
Reference in New Issue
Block a user