mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Set pretrained_vectors in begin_training
This commit is contained in:
parent
95a9615221
commit
9bf6e93b3e
|
@ -516,6 +516,7 @@ class Tagger(Pipe):
|
|||
vocab.morphology = Morphology(vocab.strings, new_tag_map,
|
||||
vocab.morphology.lemmatizer,
|
||||
exc=vocab.morphology.exc)
|
||||
self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors')
|
||||
if self.model is True:
|
||||
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
|
||||
link_vectors_to_models(self.vocab)
|
||||
|
@ -910,12 +911,15 @@ class TextCategorizer(Pipe):
|
|||
self.labels.append(label)
|
||||
return 1
|
||||
|
||||
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None):
|
||||
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None,
|
||||
**kwargs):
|
||||
if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
|
||||
token_vector_width = pipeline[0].model.nO
|
||||
else:
|
||||
token_vector_width = 64
|
||||
|
||||
if self.model is True:
|
||||
self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors')
|
||||
self.model = self.Model(len(self.labels), token_vector_width,
|
||||
**self.cfg)
|
||||
link_vectors_to_models(self.vocab)
|
||||
|
|
|
@ -896,7 +896,6 @@ cdef class Parser:
|
|||
# TODO: Remove this once we don't have to handle previous models
|
||||
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
|
||||
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
|
||||
print("Create parser model", self.cfg)
|
||||
path = util.ensure_path(path)
|
||||
if self.model is True:
|
||||
self.model, cfg = self.Model(**self.cfg)
|
||||
|
@ -944,7 +943,6 @@ cdef class Parser:
|
|||
# TODO: Remove this once we don't have to handle previous models
|
||||
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
|
||||
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
|
||||
print("Create parser model", self.cfg)
|
||||
if self.model is True:
|
||||
self.model, cfg = self.Model(**self.cfg)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue
Block a user