Add low_data mode in textcat

This commit is contained in:
Matthew Honnibal 2017-09-02 14:56:30 +02:00
parent ead78c7b9b
commit a3b69bcb3d
2 changed files with 26 additions and 9 deletions

View File

@ -510,9 +510,23 @@ def foreach(layer, drop_factor=1.0):
def build_text_classifier(nr_class, width=64, **cfg):
nr_vector = cfg.get('nr_vector', 200)
nr_vector = cfg.get('nr_vector', 5000)
with Model.define_operators({'>>': chain, '+': add, '|': concatenate,
'**': clone}):
if cfg.get('low_data'):
model = (
SpacyVectors
>> flatten_add_lengths
>> with_getitem(0, LN(Affine(width, 300)))
>> ParametricAttention(width)
>> Pooling(sum_pool)
>> Residual(ReLu(width, width)) ** 2
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
>> logistic
)
return model
lower = HashEmbed(width, nr_vector, column=1)
prefix = HashEmbed(width//2, nr_vector, column=2)
suffix = HashEmbed(width//2, nr_vector, column=3)
@ -523,7 +537,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
>> with_flatten(
uniqued(
(lower | prefix | suffix | shape)
>> LN(Maxout(width, 64+32+32+32)),
>> LN(Maxout(width, width+(width//2)*3)),
column=0
)
)
@ -537,14 +551,16 @@ def build_text_classifier(nr_class, width=64, **cfg):
cnn_model = (
# TODO Make concatenate support lists
concatenate_lists(trained_vectors, static_vectors)
>> flatten_add_lengths
>> with_getitem(0,
SELU(width, width*2)
>> (ExtractWindow(nW=1) >> SELU(width, width*3)) ** 2
>> with_flatten(
LN(Maxout(width, width*2))
>> Residual(
(ExtractWindow(nW=1) >> zero_init(Maxout(width, width*3)))
) ** 2, pad=2
)
>> flatten_add_lengths
>> ParametricAttention(width)
>> Pooling(sum_pool)
>> SELU(width, width) ** 2
>> Residual(zero_init(Maxout(width, width)))
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
)

View File

@ -638,12 +638,13 @@ class TextCategorizer(BaseThincComponent):
return mean_square_error, d_scores
def begin_training(self, gold_tuples=tuple(), pipeline=None):
if pipeline:
if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
token_vector_width = pipeline[0].model.nO
else:
token_vector_width = 64
if self.model is True:
self.model = self.Model(len(self.labels), token_vector_width)
self.model = self.Model(len(self.labels), token_vector_width,
**self.cfg)
cdef class EntityRecognizer(LinearParser):