Fix deserialization

This commit is contained in:
Matthew Honnibal 2017-07-23 14:11:07 +02:00
parent 2df563ad24
commit c4a81a47a4

View File

@@ -120,12 +120,19 @@ class BaseThincComponent(object):
deserialize = OrderedDict(( deserialize = OrderedDict((
('model', lambda p: self.model.from_bytes(p.open('rb').read())), ('model', lambda p: self.model.from_bytes(p.open('rb').read())),
('vocab', lambda p: self.vocab.from_disk(p)), ('vocab', lambda p: self.vocab.from_disk(p)),
('cfg', lambda p: self.cfg.update(ujson.load(p.open()))), ('cfg', lambda p: self.cfg.update(_load_cfg(p)))
)) ))
util.from_disk(path, deserialize, exclude) util.from_disk(path, deserialize, exclude)
return self return self
def _load_cfg(path):
    """Load a JSON config file from `path`, tolerating a missing file.

    path: A path-like object providing `exists()` and `open()`
        (e.g. a `pathlib.Path`).
    Returns the value parsed by `ujson.load`, or an empty dict when nothing
    exists at `path`.
    """
    if path.exists():
        # Open via a context manager so the file handle is closed as soon as
        # parsing finishes, rather than being left for the garbage collector.
        with path.open() as file_:
            return ujson.load(file_)
    else:
        return {}
class TokenVectorEncoder(BaseThincComponent): class TokenVectorEncoder(BaseThincComponent):
"""Assign position-sensitive vectors to tokens, using a CNN or RNN.""" """Assign position-sensitive vectors to tokens, using a CNN or RNN."""
name = 'tensorizer' name = 'tensorizer'