mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Fix loading of models when custom vectors are added
This commit is contained in:
parent
0ddb152be0
commit
3836199a83
|
@ -206,7 +206,7 @@ class Pipe(object):
|
|||
"""Load the pipe from a bytestring."""
|
||||
def load_model(b):
|
||||
# TODO: Remove this once we don't have to handle previous models
|
||||
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
|
||||
if self.cfg.get('pretrained_dims') and 'pretrained_vectors' not in self.cfg:
|
||||
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
|
||||
if self.model is True:
|
||||
self.model = self.Model(**self.cfg)
|
||||
|
@ -233,7 +233,7 @@ class Pipe(object):
|
|||
"""Load the pipe from disk."""
|
||||
def load_model(p):
|
||||
# TODO: Remove this once we don't have to handle previous models
|
||||
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
|
||||
if self.cfg.get('pretrained_dims') and 'pretrained_vectors' not in self.cfg:
|
||||
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
|
||||
if self.model is True:
|
||||
self.model = self.Model(**self.cfg)
|
||||
|
@ -578,7 +578,7 @@ class Tagger(Pipe):
|
|||
def from_bytes(self, bytes_data, **exclude):
|
||||
def load_model(b):
|
||||
# TODO: Remove this once we don't have to handle previous models
|
||||
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
|
||||
if self.cfg.get('pretrained_dims') and 'pretrained_vectors' not in self.cfg:
|
||||
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
|
||||
|
||||
if self.model is True:
|
||||
|
@ -619,7 +619,7 @@ class Tagger(Pipe):
|
|||
def from_disk(self, path, **exclude):
|
||||
def load_model(p):
|
||||
# TODO: Remove this once we don't have to handle previous models
|
||||
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
|
||||
if self.cfg.get('pretrained_dims') and 'pretrained_vectors' not in self.cfg:
|
||||
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
|
||||
if self.model is True:
|
||||
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
|
||||
|
|
|
@ -901,7 +901,7 @@ cdef class Parser:
|
|||
util.from_disk(path, deserializers, exclude)
|
||||
if 'model' not in exclude:
|
||||
# TODO: Remove this once we don't have to handle previous models
|
||||
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
|
||||
if self.cfg.get('pretrained_dims') and 'pretrained_vectors' not in self.cfg:
|
||||
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
|
||||
path = util.ensure_path(path)
|
||||
if self.model is True:
|
||||
|
@ -948,7 +948,7 @@ cdef class Parser:
|
|||
msg = util.from_bytes(bytes_data, deserializers, exclude)
|
||||
if 'model' not in exclude:
|
||||
# TODO: Remove this once we don't have to handle previous models
|
||||
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
|
||||
if self.cfg.get('pretrained_dims') and 'pretrained_vectors' not in self.cfg:
|
||||
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
|
||||
if self.model is True:
|
||||
self.model, cfg = self.Model(**self.cfg)
|
||||
|
|
Loading…
Reference in New Issue
Block a user