mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-10 00:20:35 +03:00
Improve initialization for mutually textcat
This commit is contained in:
parent
5063d999e5
commit
d13b9373bf
|
@@ -72,10 +72,10 @@ def _flatten_add_lengths(seqs, pad=0, drop=0.0):
|
||||||
|
|
||||||
|
|
||||||
def _zero_init(model):
|
def _zero_init(model):
|
||||||
def _zero_init_impl(self, X, y):
|
def _zero_init_impl(self, *args, **kwargs):
|
||||||
self.W.fill(0)
|
self.W.fill(0)
|
||||||
|
|
||||||
model.on_data_hooks.append(_zero_init_impl)
|
model.on_init_hooks.append(_zero_init_impl)
|
||||||
if model.W is not None:
|
if model.W is not None:
|
||||||
model.W.fill(0.0)
|
model.W.fill(0.0)
|
||||||
return model
|
return model
|
||||||
|
@@ -594,7 +594,7 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False,
|
||||||
if exclusive_classes:
|
if exclusive_classes:
|
||||||
output_layer = Softmax(nr_class, tok2vec.nO)
|
output_layer = Softmax(nr_class, tok2vec.nO)
|
||||||
else:
|
else:
|
||||||
output_layer = zero_init(Affine(nr_class, tok2vec.nO)) >> logistic
|
output_layer = zero_init(Affine(nr_class, tok2vec.nO, drop_factor=0.0)) >> logistic
|
||||||
model = tok2vec >> flatten_add_lengths >> Pooling(mean_pool) >> output_layer
|
model = tok2vec >> flatten_add_lengths >> Pooling(mean_pool) >> output_layer
|
||||||
model.tok2vec = chain(tok2vec, flatten)
|
model.tok2vec = chain(tok2vec, flatten)
|
||||||
model.nO = nr_class
|
model.nO = nr_class
|
||||||
|
|
Loading…
Reference in New Issue
Block a user