Improve initialization for non-mutually-exclusive textcat

This commit is contained in:
Matthew Honnibal 2019-02-23 12:27:45 +01:00
parent 5063d999e5
commit d13b9373bf

View File

@ -72,10 +72,10 @@ def _flatten_add_lengths(seqs, pad=0, drop=0.0):
def _zero_init(model):
def _zero_init_impl(self, X, y):
def _zero_init_impl(self, *args, **kwargs):
self.W.fill(0)
model.on_data_hooks.append(_zero_init_impl)
model.on_init_hooks.append(_zero_init_impl)
if model.W is not None:
model.W.fill(0.0)
return model
@ -594,7 +594,7 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False,
if exclusive_classes:
output_layer = Softmax(nr_class, tok2vec.nO)
else:
output_layer = zero_init(Affine(nr_class, tok2vec.nO)) >> logistic
output_layer = zero_init(Affine(nr_class, tok2vec.nO, drop_factor=0.0)) >> logistic
model = tok2vec >> flatten_add_lengths >> Pooling(mean_pool) >> output_layer
model.tok2vec = chain(tok2vec, flatten)
model.nO = nr_class