Fix dimensionality of textcat when no vectors available

This commit is contained in:
Matthew Honnibal 2017-10-04 14:55:15 +02:00
parent af75b74208
commit 774f5732bd

View File

@ -570,6 +570,7 @@ def foreach(layer, drop_factor=1.0):
def build_text_classifier(nr_class, width=64, **cfg):
nr_vector = cfg.get('nr_vector', 5000)
pretrained_dims = cfg.get('pretrained_dims', 0)
with Model.define_operators({'>>': chain, '+': add, '|': concatenate,
'**': clone}):
if cfg.get('low_data'):
@ -577,7 +578,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
SpacyVectors
>> flatten_add_lengths
>> with_getitem(0,
Affine(width, 300)
Affine(width, pretrained_dims)
)
>> ParametricAttention(width)
>> Pooling(sum_pool)
@ -604,16 +605,22 @@ def build_text_classifier(nr_class, width=64, **cfg):
)
)
if pretrained_dims:
static_vectors = (
SpacyVectors
>> with_flatten(Affine(width, 300))
>> with_flatten(Affine(width, pretrained_dims))
)
cnn_model = (
# TODO Make concatenate support lists
concatenate_lists(trained_vectors, static_vectors)
vectors = concatenate_lists(trained_vectors, static_vectors)
vectors_width = width*2
else:
vectors = trained_vectors
vectors_width = width
static_vectors = None
cnn_model = (
vectors
>> with_flatten(
LN(Maxout(width, width*2))
LN(Maxout(width, vectors_width))
>> Residual(
(ExtractWindow(nW=1) >> zero_init(Maxout(width, width*3)))
) ** 2, pad=2