Fix loading of models when custom vectors are added

This commit is contained in:
Matthew Honnibal 2018-04-10 22:19:05 +02:00
parent 0ddb152be0
commit 3836199a83
2 changed files with 6 additions and 6 deletions

View File

@ -206,7 +206,7 @@ class Pipe(object):
"""Load the pipe from a bytestring."""
def load_model(b):
# TODO: Remove this once we don't have to handle previous models
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
if self.cfg.get('pretrained_dims') and 'pretrained_vectors' not in self.cfg:
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
if self.model is True:
self.model = self.Model(**self.cfg)
@ -233,7 +233,7 @@ class Pipe(object):
"""Load the pipe from disk."""
def load_model(p):
# TODO: Remove this once we don't have to handle previous models
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
if self.cfg.get('pretrained_dims') and 'pretrained_vectors' not in self.cfg:
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
if self.model is True:
self.model = self.Model(**self.cfg)
@ -578,7 +578,7 @@ class Tagger(Pipe):
def from_bytes(self, bytes_data, **exclude):
def load_model(b):
# TODO: Remove this once we don't have to handle previous models
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
if self.cfg.get('pretrained_dims') and 'pretrained_vectors' not in self.cfg:
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
if self.model is True:
@ -619,7 +619,7 @@ class Tagger(Pipe):
def from_disk(self, path, **exclude):
def load_model(p):
# TODO: Remove this once we don't have to handle previous models
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
if self.cfg.get('pretrained_dims') and 'pretrained_vectors' not in self.cfg:
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
if self.model is True:
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)

View File

@ -901,7 +901,7 @@ cdef class Parser:
util.from_disk(path, deserializers, exclude)
if 'model' not in exclude:
# TODO: Remove this once we don't have to handle previous models
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
if self.cfg.get('pretrained_dims') and 'pretrained_vectors' not in self.cfg:
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
path = util.ensure_path(path)
if self.model is True:
@ -948,7 +948,7 @@ cdef class Parser:
msg = util.from_bytes(bytes_data, deserializers, exclude)
if 'model' not in exclude:
# TODO: Remove this once we don't have to handle previous models
if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
if self.cfg.get('pretrained_dims') and 'pretrained_vectors' not in self.cfg:
self.cfg['pretrained_vectors'] = self.vocab.vectors.name
if self.model is True:
self.model, cfg = self.Model(**self.cfg)