mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
Improve initialization for non-mutually-exclusive textcat
This commit is contained in:
parent
5063d999e5
commit
d13b9373bf
|
@ -72,10 +72,10 @@ def _flatten_add_lengths(seqs, pad=0, drop=0.0):
|
|||
|
||||
|
||||
def _zero_init(model):
|
||||
def _zero_init_impl(self, X, y):
|
||||
def _zero_init_impl(self, *args, **kwargs):
|
||||
self.W.fill(0)
|
||||
|
||||
model.on_data_hooks.append(_zero_init_impl)
|
||||
model.on_init_hooks.append(_zero_init_impl)
|
||||
if model.W is not None:
|
||||
model.W.fill(0.0)
|
||||
return model
|
||||
|
@ -594,7 +594,7 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False,
|
|||
if exclusive_classes:
|
||||
output_layer = Softmax(nr_class, tok2vec.nO)
|
||||
else:
|
||||
output_layer = zero_init(Affine(nr_class, tok2vec.nO)) >> logistic
|
||||
output_layer = zero_init(Affine(nr_class, tok2vec.nO, drop_factor=0.0)) >> logistic
|
||||
model = tok2vec >> flatten_add_lengths >> Pooling(mean_pool) >> output_layer
|
||||
model.tok2vec = chain(tok2vec, flatten)
|
||||
model.nO = nr_class
|
||||
|
|
Loading…
Reference in New Issue
Block a user