mirror of
https://github.com/explosion/spaCy.git
synced 2024-09-21 11:29:13 +03:00
Update vectors in textcat
This commit is contained in:
parent
b5bbfec591
commit
9e1b11dd81
|
@ -9,6 +9,7 @@ from ... import util
|
||||||
from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER
|
from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER
|
||||||
from ...util import registry
|
from ...util import registry
|
||||||
from ..extract_ngrams import extract_ngrams
|
from ..extract_ngrams import extract_ngrams
|
||||||
|
from ..staticvectors import StaticVectors
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.TextCatCNN.v1")
|
@registry.architectures.register("spacy.TextCatCNN.v1")
|
||||||
|
@ -101,13 +102,7 @@ def build_text_classifier(
|
||||||
)
|
)
|
||||||
|
|
||||||
if pretrained_vectors:
|
if pretrained_vectors:
|
||||||
nlp = util.load_model(pretrained_vectors)
|
static_vectors = StaticVectors(width)
|
||||||
vectors = nlp.vocab.vectors
|
|
||||||
vector_dim = vectors.data.shape[1]
|
|
||||||
|
|
||||||
static_vectors = SpacyVectors(vectors) >> with_array(
|
|
||||||
Linear(width, vector_dim)
|
|
||||||
)
|
|
||||||
vector_layer = trained_vectors | static_vectors
|
vector_layer = trained_vectors | static_vectors
|
||||||
vectors_width = width * 2
|
vectors_width = width * 2
|
||||||
else:
|
else:
|
||||||
|
@ -158,14 +153,10 @@ def build_text_classifier(
|
||||||
|
|
||||||
@registry.architectures.register("spacy.TextCatLowData.v1")
|
@registry.architectures.register("spacy.TextCatLowData.v1")
|
||||||
def build_text_classifier_lowdata(width, pretrained_vectors, dropout, nO=None):
|
def build_text_classifier_lowdata(width, pretrained_vectors, dropout, nO=None):
|
||||||
nlp = util.load_model(pretrained_vectors)
|
|
||||||
vectors = nlp.vocab.vectors
|
|
||||||
vector_dim = vectors.data.shape[1]
|
|
||||||
|
|
||||||
# Note, before v.3, this was the default if setting "low_data" and "pretrained_dims"
|
# Note, before v.3, this was the default if setting "low_data" and "pretrained_dims"
|
||||||
with Model.define_operators({">>": chain, "**": clone}):
|
with Model.define_operators({">>": chain, "**": clone}):
|
||||||
model = (
|
model = (
|
||||||
SpacyVectors(vectors)
|
StaticVectors(width)
|
||||||
>> list2ragged()
|
>> list2ragged()
|
||||||
>> with_ragged(0, Linear(width, vector_dim))
|
>> with_ragged(0, Linear(width, vector_dim))
|
||||||
>> ParametricAttention(width)
|
>> ParametricAttention(width)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user