Set pretrained_vectors in begin_training

This commit is contained in:
Matthew Honnibal 2018-03-28 16:32:41 +02:00
parent 95a9615221
commit 9bf6e93b3e
2 changed files with 5 additions and 3 deletions

View File

@ -516,6 +516,7 @@ class Tagger(Pipe):
vocab.morphology = Morphology(vocab.strings, new_tag_map,
vocab.morphology.lemmatizer,
exc=vocab.morphology.exc)
self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors')
if self.model is True:
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
link_vectors_to_models(self.vocab)
@ -910,12 +911,15 @@ class TextCategorizer(Pipe):
self.labels.append(label)
return 1
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None):
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None,
**kwargs):
if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
token_vector_width = pipeline[0].model.nO
else:
token_vector_width = 64
if self.model is True:
self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors')
self.model = self.Model(len(self.labels), token_vector_width,
**self.cfg)
link_vectors_to_models(self.vocab)

View File

@ -896,7 +896,6 @@ cdef class Parser:
# TODO: Remove this once we don't have to handle previous models
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
print("Create parser model", self.cfg)
path = util.ensure_path(path)
if self.model is True:
self.model, cfg = self.Model(**self.cfg)
@ -944,7 +943,6 @@ cdef class Parser:
# TODO: Remove this once we don't have to handle previous models
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
print("Create parser model", self.cfg)
if self.model is True:
self.model, cfg = self.Model(**self.cfg)
else: