mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-05 14:10:34 +03:00
Fix dimensionality of textcat when no vectors available
This commit is contained in:
parent
af75b74208
commit
774f5732bd
25
spacy/_ml.py
25
spacy/_ml.py
|
@ -570,6 +570,7 @@ def foreach(layer, drop_factor=1.0):
|
||||||
|
|
||||||
def build_text_classifier(nr_class, width=64, **cfg):
|
def build_text_classifier(nr_class, width=64, **cfg):
|
||||||
nr_vector = cfg.get('nr_vector', 5000)
|
nr_vector = cfg.get('nr_vector', 5000)
|
||||||
|
pretrained_dims = cfg.get('pretrained_dims', 0)
|
||||||
with Model.define_operators({'>>': chain, '+': add, '|': concatenate,
|
with Model.define_operators({'>>': chain, '+': add, '|': concatenate,
|
||||||
'**': clone}):
|
'**': clone}):
|
||||||
if cfg.get('low_data'):
|
if cfg.get('low_data'):
|
||||||
|
@ -577,7 +578,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
||||||
SpacyVectors
|
SpacyVectors
|
||||||
>> flatten_add_lengths
|
>> flatten_add_lengths
|
||||||
>> with_getitem(0,
|
>> with_getitem(0,
|
||||||
Affine(width, 300)
|
Affine(width, pretrained_dims)
|
||||||
)
|
)
|
||||||
>> ParametricAttention(width)
|
>> ParametricAttention(width)
|
||||||
>> Pooling(sum_pool)
|
>> Pooling(sum_pool)
|
||||||
|
@ -604,16 +605,22 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
static_vectors = (
|
if pretrained_dims:
|
||||||
SpacyVectors
|
static_vectors = (
|
||||||
>> with_flatten(Affine(width, 300))
|
SpacyVectors
|
||||||
)
|
>> with_flatten(Affine(width, pretrained_dims))
|
||||||
|
)
|
||||||
cnn_model = (
|
|
||||||
# TODO Make concatenate support lists
|
# TODO Make concatenate support lists
|
||||||
concatenate_lists(trained_vectors, static_vectors)
|
vectors = concatenate_lists(trained_vectors, static_vectors)
|
||||||
|
vectors_width = width*2
|
||||||
|
else:
|
||||||
|
vectors = trained_vectors
|
||||||
|
vectors_width = width
|
||||||
|
static_vectors = None
|
||||||
|
cnn_model = (
|
||||||
|
vectors
|
||||||
>> with_flatten(
|
>> with_flatten(
|
||||||
LN(Maxout(width, width*2))
|
LN(Maxout(width, vectors_width))
|
||||||
>> Residual(
|
>> Residual(
|
||||||
(ExtractWindow(nW=1) >> zero_init(Maxout(width, width*3)))
|
(ExtractWindow(nW=1) >> zero_init(Maxout(width, width*3)))
|
||||||
) ** 2, pad=2
|
) ** 2, pad=2
|
||||||
|
|
Loading…
Reference in New Issue
Block a user