mirror of
https://github.com/explosion/spaCy.git
synced 2025-05-30 18:53:36 +03:00
define pretrained_dims which is used by build_text_classifier (#5004)
This commit is contained in:
parent
3b22eb651b
commit
72c964bcf4
|
@ -608,6 +608,7 @@ class Language(object):
|
||||||
link_vectors_to_models(self.vocab)
|
link_vectors_to_models(self.vocab)
|
||||||
if self.vocab.vectors.data.shape[1]:
|
if self.vocab.vectors.data.shape[1]:
|
||||||
cfg["pretrained_vectors"] = self.vocab.vectors.name
|
cfg["pretrained_vectors"] = self.vocab.vectors.name
|
||||||
|
cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
|
||||||
if sgd is None:
|
if sgd is None:
|
||||||
sgd = create_default_optimizer(Model.ops)
|
sgd = create_default_optimizer(Model.ops)
|
||||||
self._optimizer = sgd
|
self._optimizer = sgd
|
||||||
|
|
|
@ -1044,6 +1044,7 @@ class TextCategorizer(Pipe):
|
||||||
self.add_label(cat)
|
self.add_label(cat)
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
self.cfg["pretrained_vectors"] = kwargs.get("pretrained_vectors")
|
self.cfg["pretrained_vectors"] = kwargs.get("pretrained_vectors")
|
||||||
|
self.cfg["pretrained_dims"] = kwargs.get("pretrained_dims")
|
||||||
self.require_labels()
|
self.require_labels()
|
||||||
self.model = self.Model(len(self.labels), **self.cfg)
|
self.model = self.Model(len(self.labels), **self.cfg)
|
||||||
link_vectors_to_models(self.vocab)
|
link_vectors_to_models(self.vocab)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user