mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Fix dimensionality of textcat when no vectors available
This commit is contained in:
parent
af75b74208
commit
774f5732bd
25
spacy/_ml.py
25
spacy/_ml.py
|
@ -570,6 +570,7 @@ def foreach(layer, drop_factor=1.0):
|
|||
|
||||
def build_text_classifier(nr_class, width=64, **cfg):
|
||||
nr_vector = cfg.get('nr_vector', 5000)
|
||||
pretrained_dims = cfg.get('pretrained_dims', 0)
|
||||
with Model.define_operators({'>>': chain, '+': add, '|': concatenate,
|
||||
'**': clone}):
|
||||
if cfg.get('low_data'):
|
||||
|
@ -577,7 +578,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
|||
SpacyVectors
|
||||
>> flatten_add_lengths
|
||||
>> with_getitem(0,
|
||||
Affine(width, 300)
|
||||
Affine(width, pretrained_dims)
|
||||
)
|
||||
>> ParametricAttention(width)
|
||||
>> Pooling(sum_pool)
|
||||
|
@ -604,16 +605,22 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
|||
)
|
||||
)
|
||||
|
||||
static_vectors = (
|
||||
SpacyVectors
|
||||
>> with_flatten(Affine(width, 300))
|
||||
)
|
||||
|
||||
cnn_model = (
|
||||
if pretrained_dims:
|
||||
static_vectors = (
|
||||
SpacyVectors
|
||||
>> with_flatten(Affine(width, pretrained_dims))
|
||||
)
|
||||
# TODO Make concatenate support lists
|
||||
concatenate_lists(trained_vectors, static_vectors)
|
||||
vectors = concatenate_lists(trained_vectors, static_vectors)
|
||||
vectors_width = width*2
|
||||
else:
|
||||
vectors = trained_vectors
|
||||
vectors_width = width
|
||||
static_vectors = None
|
||||
cnn_model = (
|
||||
vectors
|
||||
>> with_flatten(
|
||||
LN(Maxout(width, width*2))
|
||||
LN(Maxout(width, vectors_width))
|
||||
>> Residual(
|
||||
(ExtractWindow(nW=1) >> zero_init(Maxout(width, width*3)))
|
||||
) ** 2, pad=2
|
||||
|
|
Loading…
Reference in New Issue
Block a user