Fix loading models with pretrained vectors

This commit is contained in:
Matthew Honnibal 2018-04-03 23:11:48 +02:00
parent 96b612873b
commit 81f4005f3d

View File

@@ -636,11 +636,11 @@ class Language(object):
""" """
path = util.ensure_path(path) path = util.ensure_path(path)
deserializers = OrderedDict(( deserializers = OrderedDict((
('vocab', lambda p: self.vocab.from_disk(p)), ('meta.json', lambda p: self.meta.update(util.read_json(p))),
('vocab', lambda p: (
self.vocab.from_disk(p) and _fix_pretrained_vectors_name(self))),
('tokenizer', lambda p: self.tokenizer.from_disk(p, vocab=False)), ('tokenizer', lambda p: self.tokenizer.from_disk(p, vocab=False)),
('meta.json', lambda p: self.meta.update(util.read_json(p)))
)) ))
_fix_pretrained_vectors_name(self)
for name, proc in self.pipeline: for name, proc in self.pipeline:
if name in disable: if name in disable:
continue continue
@@ -682,11 +682,11 @@ class Language(object):
RETURNS (Language): The `Language` object. RETURNS (Language): The `Language` object.
""" """
deserializers = OrderedDict(( deserializers = OrderedDict((
('vocab', lambda b: self.vocab.from_bytes(b)), ('meta', lambda b: self.meta.update(ujson.loads(b))),
('vocab', lambda b: (
self.vocab.from_bytes(b) and _fix_pretrained_vectors_name(self))),
('tokenizer', lambda b: self.tokenizer.from_bytes(b, vocab=False)), ('tokenizer', lambda b: self.tokenizer.from_bytes(b, vocab=False)),
('meta', lambda b: self.meta.update(ujson.loads(b)))
)) ))
_fix_pretrained_vectors_name(self)
for i, (name, proc) in enumerate(self.pipeline): for i, (name, proc) in enumerate(self.pipeline):
if name in disable: if name in disable:
continue continue
@@ -708,12 +708,12 @@ def _fix_pretrained_vectors_name(nlp):
nlp.vocab.vectors.name = vectors_name nlp.vocab.vectors.name = vectors_name
else: else:
raise ValueError(Errors.E092) raise ValueError(Errors.E092)
link_vectors_to_models(nlp.vocab)
for name, proc in nlp.pipeline: for name, proc in nlp.pipeline:
if not hasattr(proc, 'cfg'): if not hasattr(proc, 'cfg'):
continue continue
if proc.cfg.get('pretrained_dims'): proc.cfg.setdefault('deprecation_fixes', {})
assert nlp.vocab.vectors.name proc.cfg['deprecation_fixes']['vectors_name'] = nlp.vocab.vectors.name
proc.cfg['pretrained_vectors'] = nlp.vocab.vectors.name
class DisabledPipes(list): class DisabledPipes(list):