From 774f5732bdab210381fffce8c06adca9fc0152e5 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 4 Oct 2017 14:55:15 +0200 Subject: [PATCH] Fix dimensionality of textcat when no vectors available --- spacy/_ml.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/spacy/_ml.py b/spacy/_ml.py index dc458d6ac..b02bd27d9 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -570,6 +570,7 @@ def foreach(layer, drop_factor=1.0): def build_text_classifier(nr_class, width=64, **cfg): nr_vector = cfg.get('nr_vector', 5000) + pretrained_dims = cfg.get('pretrained_dims', 0) with Model.define_operators({'>>': chain, '+': add, '|': concatenate, '**': clone}): if cfg.get('low_data'): @@ -577,7 +578,7 @@ def build_text_classifier(nr_class, width=64, **cfg): SpacyVectors >> flatten_add_lengths >> with_getitem(0, - Affine(width, 300) + Affine(width, pretrained_dims) ) >> ParametricAttention(width) >> Pooling(sum_pool) @@ -604,16 +605,22 @@ def build_text_classifier(nr_class, width=64, **cfg): ) ) - static_vectors = ( - SpacyVectors - >> with_flatten(Affine(width, 300)) - ) - - cnn_model = ( + if pretrained_dims: + static_vectors = ( + SpacyVectors + >> with_flatten(Affine(width, pretrained_dims)) + ) # TODO Make concatenate support lists - concatenate_lists(trained_vectors, static_vectors) + vectors = concatenate_lists(trained_vectors, static_vectors) + vectors_width = width*2 + else: + vectors = trained_vectors + vectors_width = width + static_vectors = None + cnn_model = ( + vectors >> with_flatten( - LN(Maxout(width, width*2)) + LN(Maxout(width, vectors_width)) >> Residual( (ExtractWindow(nW=1) >> zero_init(Maxout(width, width*3))) ) ** 2, pad=2