diff --git a/spacy/_ml.py b/spacy/_ml.py index ba9c3b634..dd4a86ac1 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -72,10 +72,10 @@ def _flatten_add_lengths(seqs, pad=0, drop=0.0): def _zero_init(model): - def _zero_init_impl(self, X, y): + def _zero_init_impl(self, *args, **kwargs): self.W.fill(0) - model.on_data_hooks.append(_zero_init_impl) + model.on_init_hooks.append(_zero_init_impl) if model.W is not None: model.W.fill(0.0) return model @@ -594,7 +594,7 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False, if exclusive_classes: output_layer = Softmax(nr_class, tok2vec.nO) else: - output_layer = zero_init(Affine(nr_class, tok2vec.nO)) >> logistic + output_layer = zero_init(Affine(nr_class, tok2vec.nO, drop_factor=0.0)) >> logistic model = tok2vec >> flatten_add_lengths >> Pooling(mean_pool) >> output_layer model.tok2vec = chain(tok2vec, flatten) model.nO = nr_class