Dont overwrite pretrained_dims setting from cfg. Fixes #1727

This commit is contained in:
Matthew Honnibal 2018-01-23 19:10:49 +01:00
parent 7e6dc283db
commit 85c942a6e3
2 changed files with 3 additions and 3 deletions

View File

@ -532,7 +532,7 @@ class Tagger(Pipe):
else: else:
serialize['model'] = self.model.to_bytes serialize['model'] = self.model.to_bytes
serialize['vocab'] = self.vocab.to_bytes serialize['vocab'] = self.vocab.to_bytes
serialize['cfg'] = lambda: ujson.dumps(self.cfg)
tag_map = OrderedDict(sorted(self.vocab.morphology.tag_map.items())) tag_map = OrderedDict(sorted(self.vocab.morphology.tag_map.items()))
serialize['tag_map'] = lambda: msgpack.dumps( serialize['tag_map'] = lambda: msgpack.dumps(
tag_map, use_bin_type=True, encoding='utf8') tag_map, use_bin_type=True, encoding='utf8')
@ -565,7 +565,7 @@ class Tagger(Pipe):
return self return self
def to_disk(self, path, **exclude): def to_disk(self, path, **exclude):
self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1] self.cfg.setdefault('pretrained_dims', self.vocab.vectors.data.shape[1])
tag_map = OrderedDict(sorted(self.vocab.morphology.tag_map.items())) tag_map = OrderedDict(sorted(self.vocab.morphology.tag_map.items()))
serialize = OrderedDict(( serialize = OrderedDict((
('vocab', lambda p: self.vocab.to_disk(p)), ('vocab', lambda p: self.vocab.to_disk(p)),

View File

@ -892,7 +892,7 @@ cdef class Parser:
if 'model' not in exclude: if 'model' not in exclude:
path = util.ensure_path(path) path = util.ensure_path(path)
if self.model is True: if self.model is True:
self.cfg['pretrained_dims'] = self.vocab.vectors_length self.cfg.setdefault('pretrained_dims', self.vocab.vectors_length)
self.model, cfg = self.Model(**self.cfg) self.model, cfg = self.Model(**self.cfg)
else: else:
cfg = {} cfg = {}