Further model deserialization fixes re #1727

This commit is contained in:
Matthew Honnibal 2018-01-23 19:16:05 +01:00
parent 91e916cb67
commit f3753c2453

View File

@ -174,7 +174,7 @@ class Pipe(object):
"""Load the pipe from a bytestring.""" """Load the pipe from a bytestring."""
def load_model(b): def load_model(b):
if self.model is True: if self.model is True:
self.cfg['pretrained_dims'] = self.vocab.vectors_length self.cfg.setdefault('pretrained_dims', self.vocab.vectors_length)
self.model = self.Model(**self.cfg) self.model = self.Model(**self.cfg)
self.model.from_bytes(b) self.model.from_bytes(b)
@ -199,7 +199,7 @@ class Pipe(object):
"""Load the pipe from disk.""" """Load the pipe from disk."""
def load_model(p): def load_model(p):
if self.model is True: if self.model is True:
self.cfg['pretrained_dims'] = self.vocab.vectors_length self.cfg.setdefault('pretrained_dims', self.vocab.vectors_length)
self.model = self.Model(**self.cfg) self.model = self.Model(**self.cfg)
self.model.from_bytes(p.open('rb').read()) self.model.from_bytes(p.open('rb').read())